Skip to content

Commit b13fa6e

Browse files
committed
WIP: FS/GS-based TLS access abstraction
Signed-off-by: Joe Richey <[email protected]>
1 parent 8060f05 commit b13fa6e

File tree

2 files changed

+346
-0
lines changed

2 files changed

+346
-0
lines changed

src/instructions/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ pub mod random;
88
pub mod segmentation;
99
pub mod tables;
1010
pub mod tlb;
11+
#[cfg(feature = "inline_asm")]
12+
pub mod tls;
1113

1214
/// Halts the CPU until the next interrupt arrives.
1315
#[inline]

src/instructions/tls.rs

+344
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,344 @@
1+
#![allow(missing_docs)]
2+
3+
//! TODO: Document module?
4+
5+
use core::marker::PhantomData;
6+
use core::mem::{size_of, MaybeUninit};
7+
use core::ptr::NonNull;
8+
9+
use crate::VirtAddr;
10+
11+
use super::segmentation::{rdfsbase, rdgsbase, wrfsbase, wrgsbase};
12+
13+
/// TODO: Document
14+
pub trait Segment {
15+
unsafe fn get_base() -> VirtAddr;
16+
unsafe fn set_base(addr: VirtAddr);
17+
18+
unsafe fn read_u64(off: usize) -> u64;
19+
unsafe fn read_u32(off: usize) -> u32;
20+
unsafe fn read_u16(off: usize) -> u16;
21+
unsafe fn read_u8(off: usize) -> u8;
22+
unsafe fn write_u64(off: usize, val: u64);
23+
unsafe fn write_u32(off: usize, val: u32);
24+
unsafe fn write_u16(off: usize, val: u16);
25+
unsafe fn write_u8(off: usize, val: u8);
26+
27+
#[inline]
28+
unsafe fn read<T: Copy>(off: usize) -> T {
29+
let mut val: MaybeUninit<T> = MaybeUninit::uninit();
30+
read_ptr::<Self>(off, val.as_mut_ptr() as *mut u8, size_of::<T>());
31+
val.assume_init()
32+
}
33+
#[inline]
34+
unsafe fn write<T: Copy>(off: usize, val: T) {
35+
write_ptr::<Self>(off, &val as *const T as *const u8, size_of::<T>())
36+
}
37+
}
38+
39+
/// TODO: Document
40+
#[derive(Debug)]
41+
pub struct Wrapper<S, T>(PhantomData<(S, *mut T)>);
42+
unsafe impl<S, T> Send for Wrapper<S, T> {}
43+
unsafe impl<S, T> Sync for Wrapper<S, T> {}
44+
45+
impl<S: Segment, T> Wrapper<S, T> {
46+
pub const fn new() -> Self {
47+
Self(PhantomData)
48+
}
49+
pub unsafe fn init(&self, new: Option<NonNull<T>>) -> Option<NonNull<T>> {
50+
let old = S::get_base().as_mut_ptr();
51+
S::set_base(match new {
52+
None => VirtAddr::new(0),
53+
Some(p) => VirtAddr::from_ptr(p.as_ptr()),
54+
});
55+
NonNull::new(old)
56+
}
57+
58+
// Hidden helper functions to help with type deduction
59+
#[doc(hidden)]
60+
#[inline]
61+
pub const unsafe fn __uninit(&self) -> MaybeUninit<T> {
62+
MaybeUninit::uninit()
63+
}
64+
#[doc(hidden)]
65+
#[inline]
66+
pub unsafe fn __read<U: Copy>(&self, off: usize) -> U {
67+
S::read::<U>(off)
68+
}
69+
#[doc(hidden)]
70+
#[inline]
71+
pub unsafe fn __write<U: Copy>(&self, off: usize, _: *const U, val: U) {
72+
S::write::<U>(off, val)
73+
}
74+
}
75+
76+
// Hidden helper functions to help with type deduction
77+
#[doc(hidden)]
78+
#[inline]
79+
pub const unsafe fn __ptr_val_agree<U: Copy>(_: *const U, _: U) {}
80+
81+
/// TODO: Document
82+
#[macro_export]
83+
macro_rules! tls_read {
84+
($wrapper:path, $field:tt) => {{
85+
// TODO: Move offset into const when this is stable
86+
let u: MaybeUninit<_> = $wrapper.__uninit();
87+
let base: *const _ = u.as_ptr();
88+
let field: *const _ = ::core::ptr::addr_of!((*base).$field);
89+
let offset: isize = (field as *const u8).offset_from(base as *const u8);
90+
91+
let val = $wrapper.__read(offset as usize);
92+
__ptr_val_agree(field, val);
93+
val
94+
}};
95+
}
96+
97+
/// TODO: Document
98+
#[macro_export]
99+
macro_rules! tls_write {
100+
($wrapper:path, $field:tt, $val:expr) => {{
101+
let u: MaybeUninit<_> = $wrapper.__uninit();
102+
let base: *const _ = u.as_ptr();
103+
let field: *const _ = ::core::ptr::addr_of!((*base).$field);
104+
let offset: isize = (field as *const u8).offset_from(base as *const u8);
105+
106+
$wrapper.__write(offset as usize, field, $val);
107+
}};
108+
}
109+
110+
/// TODO: Document
111+
#[derive(Debug)]
112+
pub struct FS(());
113+
114+
impl Segment for FS {
115+
unsafe fn get_base() -> VirtAddr {
116+
// SAFETY: rdfsbase always returns a canonical address
117+
VirtAddr::new_unsafe(rdfsbase())
118+
}
119+
unsafe fn set_base(addr: VirtAddr) {
120+
wrfsbase(addr.as_u64())
121+
}
122+
unsafe fn read_u64(off: usize) -> u64 {
123+
let val: u64;
124+
asm!(
125+
"mov {}, qword ptr fs:[{}]",
126+
lateout(reg) val, in(reg) off,
127+
options(nostack, preserves_flags, pure, readonly),
128+
);
129+
val
130+
}
131+
unsafe fn read_u32(off: usize) -> u32 {
132+
let val: u32;
133+
asm!(
134+
"mov {:e}, dword ptr fs:[{}]",
135+
lateout(reg) val, in(reg) off,
136+
options(nostack, preserves_flags, pure, readonly),
137+
);
138+
val
139+
}
140+
unsafe fn read_u16(off: usize) -> u16 {
141+
let val: u32; // Avoid partial register issues
142+
asm!(
143+
"movzx {:e}, word ptr fs:[{}]",
144+
lateout(reg) val, in(reg) off,
145+
options(nostack, preserves_flags, pure, readonly),
146+
);
147+
val as u16
148+
}
149+
unsafe fn read_u8(off: usize) -> u8 {
150+
let val: u32; // Avoid partial register issues
151+
asm!(
152+
"movzx {:e}, byte ptr fs:[{}]",
153+
lateout(reg) val, in(reg) off,
154+
options(nostack, preserves_flags, pure, readonly),
155+
);
156+
val as u8
157+
}
158+
unsafe fn write_u64(off: usize, val: u64) {
159+
asm!(
160+
"mov qword ptr fs:[{}], {}",
161+
in(reg) off, in(reg) val,
162+
options(nostack, preserves_flags),
163+
);
164+
}
165+
unsafe fn write_u32(off: usize, val: u32) {
166+
asm!(
167+
"mov dword ptr fs:[{}], {:e}",
168+
in(reg) off, in(reg) val,
169+
options(nostack, preserves_flags),
170+
);
171+
}
172+
unsafe fn write_u16(off: usize, val: u16) {
173+
asm!(
174+
"mov word ptr fs:[{}], {:x}",
175+
in(reg) off, in(reg) val,
176+
options(nostack, preserves_flags),
177+
);
178+
}
179+
unsafe fn write_u8(off: usize, val: u8) {
180+
asm!(
181+
"mov byte ptr fs:[{}], {}",
182+
in(reg) off, in(reg_byte) val,
183+
options(nostack, preserves_flags),
184+
);
185+
}
186+
}
187+
188+
/// TODO: Document
189+
#[derive(Debug)]
190+
pub struct GS(());
191+
192+
impl Segment for GS {
193+
unsafe fn get_base() -> VirtAddr {
194+
// SAFETY: rdfsbase always returns a canonical address
195+
VirtAddr::new_unsafe(rdgsbase())
196+
}
197+
unsafe fn set_base(addr: VirtAddr) {
198+
wrgsbase(addr.as_u64())
199+
}
200+
unsafe fn read_u64(off: usize) -> u64 {
201+
let val: u64;
202+
asm!(
203+
"mov {}, qword ptr gs:[{}]",
204+
lateout(reg) val, in(reg) off,
205+
options(nostack, preserves_flags, pure, readonly),
206+
);
207+
val
208+
}
209+
unsafe fn read_u32(off: usize) -> u32 {
210+
let val: u32;
211+
asm!(
212+
"mov {:e}, dword ptr gs:[{}]",
213+
lateout(reg) val, in(reg) off,
214+
options(nostack, preserves_flags, pure, readonly),
215+
);
216+
val
217+
}
218+
unsafe fn read_u16(off: usize) -> u16 {
219+
let val: u32; // Avoid partial register issues
220+
asm!(
221+
"movzx {:e}, word ptr gs:[{}]",
222+
lateout(reg) val, in(reg) off,
223+
options(nostack, preserves_flags, pure, readonly),
224+
);
225+
val as u16
226+
}
227+
unsafe fn read_u8(off: usize) -> u8 {
228+
let val: u32; // Avoid partial register issues
229+
asm!(
230+
"movzx {:e}, byte ptr gs:[{}]",
231+
lateout(reg) val, in(reg) off,
232+
options(nostack, preserves_flags, pure, readonly),
233+
);
234+
val as u8
235+
}
236+
unsafe fn write_u64(off: usize, val: u64) {
237+
asm!(
238+
"mov qword ptr gs:[{}], {}",
239+
in(reg) off, in(reg) val,
240+
options(nostack, preserves_flags),
241+
);
242+
}
243+
unsafe fn write_u32(off: usize, val: u32) {
244+
asm!(
245+
"mov dword ptr gs:[{}], {:e}",
246+
in(reg) off, in(reg) val,
247+
options(nostack, preserves_flags),
248+
);
249+
}
250+
unsafe fn write_u16(off: usize, val: u16) {
251+
asm!(
252+
"mov word ptr gs:[{}], {:x}",
253+
in(reg) off, in(reg) val,
254+
options(nostack, preserves_flags),
255+
);
256+
}
257+
unsafe fn write_u8(off: usize, val: u8) {
258+
asm!(
259+
"mov byte ptr gs:[{}], {}",
260+
in(reg) off, in(reg_byte) val,
261+
options(nostack, preserves_flags),
262+
);
263+
}
264+
}
265+
266+
#[inline]
267+
unsafe fn read_ptr<S: Segment + ?Sized>(off: usize, p: *mut u8, size: usize) {
268+
if size >= 8 {
269+
(p as *mut u64).write_unaligned(S::read_u64(off));
270+
read_ptr::<S>(off + 8, p.offset(8), size - 8);
271+
} else if size == 4 {
272+
(p as *mut u32).write_unaligned(S::read_u32(off));
273+
} else if size == 2 {
274+
(p as *mut u16).write_unaligned(S::read_u16(off));
275+
} else if size == 1 {
276+
p.write(S::read_u8(off));
277+
} else if size > 0 {
278+
read_cold::<S>(off, p, size);
279+
}
280+
}
281+
282+
#[cold]
283+
unsafe fn read_cold<S: Segment + ?Sized>(off: usize, p: *mut u8, size: usize) {
284+
match size {
285+
7 => {
286+
(p as *mut u32).write_unaligned(S::read_u32(off));
287+
(p.offset(4) as *mut u16).write_unaligned(S::read_u16(off + 4));
288+
p.offset(6).write(S::read_u8(off + 6));
289+
}
290+
6 => {
291+
(p as *mut u32).write_unaligned(S::read_u32(off));
292+
(p.offset(4) as *mut u16).write_unaligned(S::read_u16(off + 4));
293+
}
294+
5 => {
295+
(p as *mut u32).write_unaligned(S::read_u32(off));
296+
p.offset(4).write(S::read_u8(off + 4));
297+
}
298+
3 => {
299+
(p as *mut u16).write_unaligned(S::read_u16(off));
300+
p.offset(2).write(S::read_u8(off + 2));
301+
}
302+
_ => core::hint::unreachable_unchecked(),
303+
}
304+
}
305+
306+
#[inline]
307+
unsafe fn write_ptr<S: Segment + ?Sized>(off: usize, p: *const u8, size: usize) {
308+
if size >= 8 {
309+
S::write_u64(off, (p as *const u64).read_unaligned());
310+
write_ptr::<S>(off + 8, p.offset(8), size - 8);
311+
} else if size == 4 {
312+
S::write_u32(off, (p as *const u32).read_unaligned());
313+
} else if size == 2 {
314+
S::write_u16(off, (p as *const u16).read_unaligned());
315+
} else if size == 1 {
316+
S::write_u8(off, p.read());
317+
} else if size > 0 {
318+
write_cold::<S>(off, p, size);
319+
}
320+
}
321+
322+
#[cold]
323+
unsafe fn write_cold<S: Segment + ?Sized>(off: usize, p: *const u8, size: usize) {
324+
match size {
325+
7 => {
326+
S::write_u32(off, (p as *const u32).read_unaligned());
327+
S::write_u16(off + 4, (p.offset(4) as *const u16).read_unaligned());
328+
S::write_u8(off + 6, p.offset(6).read_unaligned());
329+
}
330+
6 => {
331+
S::write_u32(off, (p as *const u32).read_unaligned());
332+
S::write_u16(off + 4, (p.offset(4) as *const u16).read_unaligned());
333+
}
334+
5 => {
335+
S::write_u32(off, (p as *const u32).read_unaligned());
336+
S::write_u8(off + 4, p.offset(4).read_unaligned());
337+
}
338+
3 => {
339+
S::write_u16(off, (p as *const u16).read_unaligned());
340+
S::write_u8(off + 2, p.offset(2).read_unaligned());
341+
}
342+
_ => core::hint::unreachable_unchecked(),
343+
}
344+
}

0 commit comments

Comments
 (0)