Skip to content

Commit 4cefd1b

Browse files
authored
Merge pull request #23097 from ziggoon/master
std.heap.PageAllocator updates to fix race condition and utilize NtAllocateVirtualMemory / NtFreeVirtualMemory instead of VirtualAlloc / VirtualFree
2 parents 1e0739f + 5b03e24 commit 4cefd1b

File tree

3 files changed

+116
-62
lines changed

3 files changed

+116
-62
lines changed

lib/std/heap/PageAllocator.zig

Lines changed: 65 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,14 @@ const maxInt = std.math.maxInt;
66
const assert = std.debug.assert;
77
const native_os = builtin.os.tag;
88
const windows = std.os.windows;
9+
const ntdll = windows.ntdll;
910
const posix = std.posix;
1011
const page_size_min = std.heap.page_size_min;
1112

13+
const SUCCESS = @import("../os/windows/ntstatus.zig").NTSTATUS.SUCCESS;
14+
const MEM_RESERVE_PLACEHOLDER = windows.MEM_RESERVE_PLACEHOLDER;
15+
const MEM_PRESERVE_PLACEHOLDER = windows.MEM_PRESERVE_PLACEHOLDER;
16+
1217
pub const vtable: Allocator.VTable = .{
1318
.alloc = alloc,
1419
.resize = resize,
@@ -22,51 +27,62 @@ pub fn map(n: usize, alignment: mem.Alignment) ?[*]u8 {
2227
const alignment_bytes = alignment.toByteUnits();
2328

2429
if (native_os == .windows) {
25-
// According to official documentation, VirtualAlloc aligns to page
26-
// boundary, however, empirically it reserves pages on a 64K boundary.
27-
// Since it is very likely the requested alignment will be honored,
28-
// this logic first tries a call with exactly the size requested,
29-
// before falling back to the loop below.
30-
// https://devblogs.microsoft.com/oldnewthing/?p=42223
31-
const addr = windows.VirtualAlloc(
32-
null,
33-
// VirtualAlloc will round the length to a multiple of page size.
34-
// "If the lpAddress parameter is NULL, this value is rounded up to
35-
// the next page boundary".
36-
n,
37-
windows.MEM_COMMIT | windows.MEM_RESERVE,
38-
windows.PAGE_READWRITE,
39-
) catch return null;
40-
41-
if (mem.isAligned(@intFromPtr(addr), alignment_bytes))
42-
return @ptrCast(addr);
43-
44-
// Fallback: reserve a range of memory large enough to find a
45-
// sufficiently aligned address, then free the entire range and
46-
// immediately allocate the desired subset. Another thread may have won
47-
// the race to map the target range, in which case a retry is needed.
48-
windows.VirtualFree(addr, 0, windows.MEM_RELEASE);
30+
var base_addr: ?*anyopaque = null;
31+
var size: windows.SIZE_T = n;
32+
33+
var status = ntdll.NtAllocateVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&base_addr), 0, &size, windows.MEM_COMMIT | windows.MEM_RESERVE, windows.PAGE_READWRITE);
34+
35+
if (status == SUCCESS and mem.isAligned(@intFromPtr(base_addr), alignment_bytes)) {
36+
return @ptrCast(base_addr);
37+
}
38+
39+
if (status == SUCCESS) {
40+
var region_size: windows.SIZE_T = 0;
41+
_ = ntdll.NtFreeVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&base_addr), &region_size, windows.MEM_RELEASE);
42+
}
4943

5044
const overalloc_len = n + alignment_bytes - page_size;
5145
const aligned_len = mem.alignForward(usize, n, page_size);
5246

53-
while (true) {
54-
const reserved_addr = windows.VirtualAlloc(
55-
null,
56-
overalloc_len,
57-
windows.MEM_RESERVE,
58-
windows.PAGE_NOACCESS,
59-
) catch return null;
60-
const aligned_addr = mem.alignForward(usize, @intFromPtr(reserved_addr), alignment_bytes);
61-
windows.VirtualFree(reserved_addr, 0, windows.MEM_RELEASE);
62-
const ptr = windows.VirtualAlloc(
63-
@ptrFromInt(aligned_addr),
64-
aligned_len,
65-
windows.MEM_COMMIT | windows.MEM_RESERVE,
66-
windows.PAGE_READWRITE,
67-
) catch continue;
68-
return @ptrCast(ptr);
47+
base_addr = null;
48+
size = overalloc_len;
49+
50+
status = ntdll.NtAllocateVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&base_addr), 0, &size, windows.MEM_RESERVE | MEM_RESERVE_PLACEHOLDER, windows.PAGE_NOACCESS);
51+
52+
if (status != SUCCESS) return null;
53+
54+
const placeholder_addr = @intFromPtr(base_addr);
55+
const aligned_addr = mem.alignForward(usize, placeholder_addr, alignment_bytes);
56+
const prefix_size = aligned_addr - placeholder_addr;
57+
58+
if (prefix_size > 0) {
59+
var prefix_base = base_addr;
60+
var prefix_size_param: windows.SIZE_T = prefix_size;
61+
_ = ntdll.NtFreeVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&prefix_base), &prefix_size_param, windows.MEM_RELEASE | MEM_PRESERVE_PLACEHOLDER);
6962
}
63+
64+
const suffix_start = aligned_addr + aligned_len;
65+
const suffix_size = (placeholder_addr + overalloc_len) - suffix_start;
66+
if (suffix_size > 0) {
67+
var suffix_base = @as(?*anyopaque, @ptrFromInt(suffix_start));
68+
var suffix_size_param: windows.SIZE_T = suffix_size;
69+
_ = ntdll.NtFreeVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&suffix_base), &suffix_size_param, windows.MEM_RELEASE | MEM_PRESERVE_PLACEHOLDER);
70+
}
71+
72+
base_addr = @ptrFromInt(aligned_addr);
73+
size = aligned_len;
74+
75+
status = ntdll.NtAllocateVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&base_addr), 0, &size, windows.MEM_COMMIT | MEM_PRESERVE_PLACEHOLDER, windows.PAGE_READWRITE);
76+
77+
if (status == SUCCESS) {
78+
return @ptrCast(base_addr);
79+
}
80+
81+
base_addr = @as(?*anyopaque, @ptrFromInt(aligned_addr));
82+
size = aligned_len;
83+
_ = ntdll.NtFreeVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&base_addr), &size, windows.MEM_RELEASE);
84+
85+
return null;
7086
}
7187

7288
const aligned_len = mem.alignForward(usize, n, page_size);
@@ -104,26 +120,14 @@ fn alloc(context: *anyopaque, n: usize, alignment: mem.Alignment, ra: usize) ?[*
104120
return map(n, alignment);
105121
}
106122

107-
fn resize(
108-
context: *anyopaque,
109-
memory: []u8,
110-
alignment: mem.Alignment,
111-
new_len: usize,
112-
return_address: usize,
113-
) bool {
123+
fn resize(context: *anyopaque, memory: []u8, alignment: mem.Alignment, new_len: usize, return_address: usize) bool {
114124
_ = context;
115125
_ = alignment;
116126
_ = return_address;
117127
return realloc(memory, new_len, false) != null;
118128
}
119129

120-
fn remap(
121-
context: *anyopaque,
122-
memory: []u8,
123-
alignment: mem.Alignment,
124-
new_len: usize,
125-
return_address: usize,
126-
) ?[*]u8 {
130+
fn remap(context: *anyopaque, memory: []u8, alignment: mem.Alignment, new_len: usize, return_address: usize) ?[*]u8 {
127131
_ = context;
128132
_ = alignment;
129133
_ = return_address;
@@ -139,7 +143,9 @@ fn free(context: *anyopaque, memory: []u8, alignment: mem.Alignment, return_addr
139143

140144
pub fn unmap(memory: []align(page_size_min) u8) void {
141145
if (native_os == .windows) {
142-
windows.VirtualFree(memory.ptr, 0, windows.MEM_RELEASE);
146+
var base_addr: ?*anyopaque = memory.ptr;
147+
var region_size: windows.SIZE_T = 0;
148+
_ = ntdll.NtFreeVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&base_addr), &region_size, windows.MEM_RELEASE);
143149
} else {
144150
const page_aligned_len = mem.alignForward(usize, memory.len, std.heap.pageSize());
145151
posix.munmap(memory.ptr[0..page_aligned_len]);
@@ -157,13 +163,10 @@ pub fn realloc(uncasted_memory: []u8, new_len: usize, may_move: bool) ?[*]u8 {
157163
const old_addr_end = base_addr + memory.len;
158164
const new_addr_end = mem.alignForward(usize, base_addr + new_len, page_size);
159165
if (old_addr_end > new_addr_end) {
160-
// For shrinking that is not releasing, we will only decommit
161-
// the pages not needed anymore.
162-
windows.VirtualFree(
163-
@ptrFromInt(new_addr_end),
164-
old_addr_end - new_addr_end,
165-
windows.MEM_DECOMMIT,
166-
);
166+
var decommit_addr: ?*anyopaque = @ptrFromInt(new_addr_end);
167+
var decommit_size: windows.SIZE_T = old_addr_end - new_addr_end;
168+
169+
_ = ntdll.NtAllocateVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&decommit_addr), 0, &decommit_size, windows.MEM_RESET, windows.PAGE_NOACCESS);
167170
}
168171
return memory.ptr;
169172
}

lib/std/os/windows.zig

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1758,6 +1758,38 @@ pub fn TerminateProcess(hProcess: HANDLE, uExitCode: UINT) TerminateProcessError
17581758
}
17591759
}
17601760

1761+
pub const NtAllocateVirtualMemoryError = error{
1762+
AccessDenied,
1763+
InvalidParameter,
1764+
NoMemory,
1765+
Unexpected,
1766+
};
1767+
1768+
pub fn NtAllocateVirtualMemory(hProcess: HANDLE, addr: ?*PVOID, zero_bits: ULONG_PTR, size: ?*SIZE_T, alloc_type: ULONG, protect: ULONG) NtAllocateVirtualMemoryError!void {
1769+
return switch (ntdll.NtAllocateVirtualMemory(hProcess, addr, zero_bits, size, alloc_type, protect)) {
1770+
.SUCCESS => return,
1771+
.ACCESS_DENIED => NtAllocateVirtualMemoryError.AccessDenied,
1772+
.INVALID_PARAMETER => NtAllocateVirtualMemoryError.InvalidParameter,
1773+
.NO_MEMORY => NtAllocateVirtualMemoryError.NoMemory,
1774+
else => |st| unexpectedStatus(st),
1775+
};
1776+
}
1777+
1778+
pub const NtFreeVirtualMemoryError = error{
1779+
AccessDenied,
1780+
InvalidParameter,
1781+
Unexpected,
1782+
};
1783+
1784+
pub fn NtFreeVirtualMemory(hProcess: HANDLE, addr: ?*PVOID, size: *SIZE_T, free_type: ULONG) NtFreeVirtualMemoryError!void {
1785+
return switch (ntdll.NtFreeVirtualMemory(hProcess, addr, size, free_type)) {
1786+
.SUCCESS => return,
1787+
.ACCESS_DENIED => NtFreeVirtualMemoryError.AccessDenied,
1788+
.INVALID_PARAMETER => NtFreeVirtualMemoryError.InvalidParameter,
1789+
else => NtFreeVirtualMemoryError.Unexpected,
1790+
};
1791+
}
1792+
17611793
pub const VirtualAllocError = error{Unexpected};
17621794

17631795
pub fn VirtualAlloc(addr: ?LPVOID, size: usize, alloc_type: DWORD, flProtect: DWORD) VirtualAllocError!LPVOID {
@@ -3539,6 +3571,8 @@ pub const MEM_LARGE_PAGES = 0x20000000;
35393571
pub const MEM_PHYSICAL = 0x400000;
35403572
pub const MEM_TOP_DOWN = 0x100000;
35413573
pub const MEM_WRITE_WATCH = 0x200000;
3574+
pub const MEM_RESERVE_PLACEHOLDER = 0x00040000;
3575+
pub const MEM_PRESERVE_PLACEHOLDER = 0x00000400;
35423576

35433577
// Protect values
35443578
pub const PAGE_EXECUTE = 0x10;

lib/std/os/windows/ntdll.zig

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ const BOOL = windows.BOOL;
55
const DWORD = windows.DWORD;
66
const DWORD64 = windows.DWORD64;
77
const ULONG = windows.ULONG;
8+
const ULONG_PTR = windows.ULONG_PTR;
89
const NTSTATUS = windows.NTSTATUS;
910
const WORD = windows.WORD;
1011
const HANDLE = windows.HANDLE;
@@ -358,3 +359,19 @@ pub extern "ntdll" fn NtCreateNamedPipeFile(
358359
OutboundQuota: ULONG,
359360
DefaultTimeout: *LARGE_INTEGER,
360361
) callconv(.winapi) NTSTATUS;
362+
363+
pub extern "ntdll" fn NtAllocateVirtualMemory(
364+
ProcessHandle: HANDLE,
365+
BaseAddress: ?*PVOID,
366+
ZeroBits: ULONG_PTR,
367+
RegionSize: ?*SIZE_T,
368+
AllocationType: ULONG,
369+
PageProtection: ULONG,
370+
) callconv(.winapi) NTSTATUS;
371+
372+
pub extern "ntdll" fn NtFreeVirtualMemory(
373+
ProcessHandle: HANDLE,
374+
BaseAddress: ?*PVOID,
375+
RegionSize: *SIZE_T,
376+
FreeType: ULONG,
377+
) callconv(.winapi) NTSTATUS;

0 commit comments

Comments
 (0)