Skip to content

Commit e3aa1c2

Browse files
committed
use _continue_interrupt_flag to reduce binary size
1 parent 52d5185 commit e3aa1c2

File tree

2 files changed

+154
-103
lines changed

2 files changed

+154
-103
lines changed

riscv-rt/macros/src/lib.rs

Lines changed: 151 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -313,11 +313,122 @@ pub fn loop_global_asm(input: TokenStream) -> TokenStream {
313313
res.parse().unwrap()
314314
}
315315

316+
#[derive(Clone, Copy)]
316317
enum RiscvArch {
317318
Rv32,
318319
Rv64,
319320
}
320321

322+
const TRAP_SIZE: usize = 16;
323+
324+
#[rustfmt::skip]
325+
const TRAP_FRAME: [&str; TRAP_SIZE] = [
326+
"ra",
327+
"t0",
328+
"t1",
329+
"t2",
330+
"t3",
331+
"t4",
332+
"t5",
333+
"t6",
334+
"a0",
335+
"a1",
336+
"a2",
337+
"a3",
338+
"a4",
339+
"a5",
340+
"a6",
341+
"a7",
342+
];
343+
344+
fn store_trap<T: FnMut(&str) -> bool>(arch: RiscvArch, mut filter: T) -> String {
345+
let (width, store) = match arch {
346+
RiscvArch::Rv32 => (4, "sw"),
347+
RiscvArch::Rv64 => (8, "sd"),
348+
};
349+
let mut stores = Vec::new();
350+
for (i, reg) in TRAP_FRAME
351+
.iter()
352+
.enumerate()
353+
.filter(|(_, &reg)| filter(reg))
354+
{
355+
stores.push(format!("{store} {reg}, {i}*{width}(sp)"));
356+
}
357+
stores.join("\n")
358+
}
359+
360+
fn load_trap(arch: RiscvArch) -> String {
361+
let (width, load) = match arch {
362+
RiscvArch::Rv32 => (4, "lw"),
363+
RiscvArch::Rv64 => (8, "ld"),
364+
};
365+
let mut loads = Vec::new();
366+
for (i, reg) in TRAP_FRAME.iter().enumerate() {
367+
loads.push(format!("{load} {reg}, {i}*{width}(sp)"));
368+
}
369+
loads.join("\n")
370+
}
371+
372+
#[proc_macro]
373+
pub fn weak_start_trap_riscv32(_input: TokenStream) -> TokenStream {
374+
weak_start_trap(RiscvArch::Rv32)
375+
}
376+
377+
#[proc_macro]
378+
pub fn weak_start_trap_riscv64(_input: TokenStream) -> TokenStream {
379+
weak_start_trap(RiscvArch::Rv64)
380+
}
381+
382+
fn weak_start_trap(arch: RiscvArch) -> TokenStream {
383+
let width = match arch {
384+
RiscvArch::Rv32 => 4,
385+
RiscvArch::Rv64 => 8,
386+
};
387+
// ensure we do not break that sp is 16-byte aligned
388+
if (TRAP_SIZE * width) % 16 != 0 {
389+
return parse::Error::new(Span::call_site(), "Trap frame size must be 16-byte aligned")
390+
.to_compile_error()
391+
.into();
392+
}
393+
let store = store_trap(arch, |_| true);
394+
let load = load_trap(arch);
395+
396+
#[cfg(feature = "s-mode")]
397+
let ret = "sret";
398+
#[cfg(not(feature = "s-mode"))]
399+
let ret = "mret";
400+
401+
let instructions: proc_macro2::TokenStream = format!(
402+
"
403+
core::arch::global_asm!(
404+
\".section .trap, \\\"ax\\\"
405+
.align {width}
406+
.weak _start_trap
407+
_start_trap:
408+
addi sp, sp, - {TRAP_SIZE} * {width}
409+
{store}
410+
add a0, sp, zero
411+
jal ra, _start_trap_rust
412+
{load}
413+
addi sp, sp, {TRAP_SIZE} * {width}
414+
{ret}
415+
\");"
416+
)
417+
.parse()
418+
.unwrap();
419+
420+
#[cfg(feature = "v-trap")]
421+
let v_trap = v_trap::continue_interrupt_trap(arch);
422+
#[cfg(not(feature = "v-trap"))]
423+
let v_trap = proc_macro2::TokenStream::new();
424+
425+
quote!(
426+
#instructions
427+
#v_trap
428+
)
429+
.into()
430+
}
431+
321432
#[proc_macro_attribute]
322433
pub fn interrupt_riscv32(args: TokenStream, input: TokenStream) -> TokenStream {
323434
interrupt(args, input, RiscvArch::Rv32)
@@ -376,7 +487,7 @@ fn interrupt(args: TokenStream, input: TokenStream, _arch: RiscvArch) -> TokenSt
376487
#[cfg(not(feature = "v-trap"))]
377488
let start_trap = proc_macro2::TokenStream::new();
378489
#[cfg(feature = "v-trap")]
379-
let start_trap = v_trap::start_interrupt_trap_asm(ident, _arch);
490+
let start_trap = v_trap::start_interrupt_trap(ident, _arch);
380491

381492
quote!(
382493
#start_trap
@@ -390,45 +501,41 @@ fn interrupt(args: TokenStream, input: TokenStream, _arch: RiscvArch) -> TokenSt
390501
mod v_trap {
391502
use super::*;
392503

393-
const TRAP_SIZE: usize = 16;
394-
395-
#[rustfmt::skip]
396-
const TRAP_FRAME: [&str; TRAP_SIZE] = [
397-
"ra",
398-
"t0",
399-
"t1",
400-
"t2",
401-
"t3",
402-
"t4",
403-
"t5",
404-
"t6",
405-
"a0",
406-
"a1",
407-
"a2",
408-
"a3",
409-
"a4",
410-
"a5",
411-
"a6",
412-
"a7",
413-
];
414-
415-
pub(crate) fn start_interrupt_trap_asm(
504+
pub(crate) fn start_interrupt_trap(
416505
ident: &syn::Ident,
417506
arch: RiscvArch,
418507
) -> proc_macro2::TokenStream {
419-
let function = ident.to_string();
420-
let (width, store, load) = match arch {
421-
RiscvArch::Rv32 => (4, "sw", "lw"),
422-
RiscvArch::Rv64 => (8, "sd", "ld"),
508+
let interrupt = ident.to_string();
509+
let width = match arch {
510+
RiscvArch::Rv32 => 4,
511+
RiscvArch::Rv64 => 8,
423512
};
513+
let store = store_trap(arch, |r| r == "a0");
424514

425-
let (mut stores, mut loads) = (Vec::new(), Vec::new());
426-
for (i, r) in TRAP_FRAME.iter().enumerate() {
427-
stores.push(format!(" {store} {r}, {i}*{width}(sp)"));
428-
loads.push(format!(" {load} {r}, {i}*{width}(sp)"));
429-
}
430-
let store = stores.join("\n");
431-
let load = loads.join("\n");
515+
let instructions = format!(
516+
"
517+
core::arch::global_asm!(
518+
\".section .trap, \\\"ax\\\"
519+
.align {width}
520+
.global _start_{interrupt}_trap
521+
_start_{interrupt}_trap:
522+
addi sp, sp, -{TRAP_SIZE} * {width} // allocate space for trap frame
523+
{store} // store trap partially (only register a0)
524+
la a0, {interrupt} // load interrupt handler address into a0
525+
j _continue_interrupt_trap // jump to common part of interrupt trap
526+
\");"
527+
);
528+
529+
instructions.parse().unwrap()
530+
}
531+
532+
pub(crate) fn continue_interrupt_trap(arch: RiscvArch) -> proc_macro2::TokenStream {
533+
let width = match arch {
534+
RiscvArch::Rv32 => 4,
535+
RiscvArch::Rv64 => 8,
536+
};
537+
let store = store_trap(arch, |reg| reg != "a0");
538+
let load = load_trap(arch);
432539

433540
#[cfg(feature = "s-mode")]
434541
let ret = "sret";
@@ -439,16 +546,15 @@ mod v_trap {
439546
"
440547
core::arch::global_asm!(
441548
\".section .trap, \\\"ax\\\"
442-
.align {width}
443-
.global _start_{function}_trap
444-
_start_{function}_trap:
445-
addi sp, sp, - {TRAP_SIZE} * {width}
446-
{store}
447-
call {function}
448-
{load}
449-
addi sp, sp, {TRAP_SIZE} * {width}
450-
{ret}\"
451-
);"
549+
.align {width} // TODO is this necessary?
550+
.global _continue_interrupt_trap
551+
_continue_interrupt_trap:
552+
{store} // store trap partially (all registers except a0)
553+
jalr ra, a0, 0 // jump to corresponding interrupt handler (address stored in a0)
554+
{load} // restore trap frame
555+
addi sp, sp, {TRAP_SIZE} * {width} // deallocate space for trap frame
556+
{ret} // return from interrupt
557+
\");"
452558
);
453559

454560
instructions.parse().unwrap()

riscv-rt/src/asm.rs

Lines changed: 3 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -277,65 +277,10 @@ _pre_init_trap:
277277
j _pre_init_trap",
278278
);
279279

280-
/// Trap entry point (_start_trap). It saves caller saved registers, calls
281-
/// _start_trap_rust, restores caller saved registers and then returns.
282-
///
283-
/// # Usage
284-
///
285-
/// The macro takes 5 arguments:
286-
/// - `$STORE`: the instruction used to store a register in the stack (e.g. `sd` for riscv64)
287-
/// - `$LOAD`: the instruction used to load a register from the stack (e.g. `ld` for riscv64)
288-
/// - `$BYTES`: the number of bytes used to store a register (e.g. 8 for riscv64)
289-
/// - `$TRAP_SIZE`: the number of registers to store in the stack (e.g. 32 for all the user registers)
290-
/// - list of tuples of the form `($REG, $LOCATION)`, where:
291-
/// - `$REG`: the register to store/load
292-
/// - `$LOCATION`: the location in the stack where to store/load the register
293-
#[rustfmt::skip]
294-
macro_rules! trap_handler {
295-
($STORE:ident, $LOAD:ident, $BYTES:literal, $TRAP_SIZE:literal, [$(($REG:ident, $LOCATION:literal)),*]) => {
296-
// ensure we do not break that sp is 16-byte aligned
297-
const _: () = assert!(($TRAP_SIZE * $BYTES) % 16 == 0);
298-
global_asm!(
299-
"
300-
.section .trap, \"ax\"
301-
.weak _start_trap
302-
_start_trap:",
303-
// save space for trap handler in stack
304-
concat!("addi sp, sp, -", stringify!($TRAP_SIZE * $BYTES)),
305-
// save registers in the desired order
306-
$(concat!(stringify!($STORE), " ", stringify!($REG), ", ", stringify!($LOCATION * $BYTES), "(sp)"),)*
307-
// call rust trap handler
308-
"add a0, sp, zero
309-
jal ra, _start_trap_rust",
310-
// restore registers in the desired order
311-
$(concat!(stringify!($LOAD), " ", stringify!($REG), ", ", stringify!($LOCATION * $BYTES), "(sp)"),)*
312-
// free stack
313-
concat!("addi sp, sp, ", stringify!($TRAP_SIZE * $BYTES)),
314-
);
315-
cfg_global_asm!(
316-
// return from trap
317-
#[cfg(feature = "s-mode")]
318-
"sret",
319-
#[cfg(not(feature = "s-mode"))]
320-
"mret",
321-
);
322-
};
323-
}
324-
325-
#[rustfmt::skip]
326280
#[cfg(riscv32)]
327-
trap_handler!(
328-
sw, lw, 4, 16,
329-
[(ra, 0), (t0, 1), (t1, 2), (t2, 3), (t3, 4), (t4, 5), (t5, 6), (t6, 7),
330-
(a0, 8), (a1, 9), (a2, 10), (a3, 11), (a4, 12), (a5, 13), (a6, 14), (a7, 15)]
331-
);
332-
#[rustfmt::skip]
281+
riscv_rt_macros::weak_start_trap_riscv32!();
333282
#[cfg(riscv64)]
334-
trap_handler!(
335-
sd, ld, 8, 16,
336-
[(ra, 0), (t0, 1), (t1, 2), (t2, 3), (t3, 4), (t4, 5), (t5, 6), (t6, 7),
337-
(a0, 8), (a1, 9), (a2, 10), (a3, 11), (a4, 12), (a5, 13), (a6, 14), (a7, 15)]
338-
);
283+
riscv_rt_macros::weak_start_trap_riscv64!();
339284

340285
#[cfg(feature = "v-trap")]
341286
cfg_global_asm!(
@@ -345,7 +290,7 @@ cfg_global_asm!(
345290
.type _vector_table, @function
346291
347292
.option push
348-
.balign 0x100 // TODO check if this is the correct alignment
293+
.balign 0x4 // TODO check if this is the correct alignment
349294
.option norelax
350295
.option norvc
351296

0 commit comments

Comments
 (0)