Procházet zdrojové kódy

riscv64, trap: rework to fix nested captured traps

The previous implementation has some bugs inside that will cause kernel
space nested traps to lose some required information:

- In kernel mode, trap contexts are saved above the current stack frame
  without exception, which is not what we want. We expect to read the
  trap data in the CAPTURED context.
- The capturer task context is not saved as well, which will mess up the
  nested traps completely.
- We are reading page fault virtual addresses in TrapContext::trap_type,
  which won't work since if the inner trap is captured, and the outer
  trap interleaves with the trap_type() call, we will lose the stval
  data in the inner trap.

The solution is to separate our "normal" trap handling procedure out of
captured trap handling procedure. We swap the stvec CSR when we set up
captured traps and restore it afterwards so the two approach don't have
to tell then apart in trap entries. Then, we can store the TrapContext
pointer in sscratch without having to distinguish between trap handling
types. In the way, we keep the procedure simple.

The register stval is saved together with other registers to be used in
page faults.

Signed-off-by: greatbridf <greatbridf@icloud.com>
greatbridf před 6 měsíci
rodič
revize
661a15940b

+ 0 - 8
crates/eonix_hal/src/arch/riscv64/bootstrap.rs

@@ -3,7 +3,6 @@ use super::{
     console::write_str,
     cpu::{CPUID, CPU_COUNT},
     time::set_next_timer,
-    trap::TRAP_SCRATCH,
 };
 use crate::{
     arch::{
@@ -234,13 +233,6 @@ fn setup_cpu(alloc: impl PageAlloc, hart_id: usize) {
     }
 
     percpu_area.register(cpu.cpuid());
-
-    unsafe {
-        // SAFETY: Interrupts are disabled.
-        TRAP_SCRATCH
-            .as_mut()
-            .set_kernel_tp(PercpuArea::get_for(cpu.cpuid()).unwrap().cast());
-    }
 }
 
 fn get_ap_start_addr() -> usize {

+ 10 - 3
crates/eonix_hal/src/arch/riscv64/cpu.rs

@@ -1,9 +1,13 @@
 use super::{
     interrupt::InterruptControl,
-    trap::{setup_trap, TRAP_SCRATCH},
+    trap::{setup_trap, TrapContext},
 };
 use crate::arch::fdt::{FdtExt, FDT};
-use core::{arch::asm, pin::Pin, ptr::NonNull, sync::atomic::AtomicUsize};
+use core::{
+    arch::asm, cell::UnsafeCell, mem::MaybeUninit, pin::Pin, ptr::NonNull,
+    sync::atomic::AtomicUsize,
+};
+use eonix_hal_traits::trap::RawTrapContext;
 use eonix_preempt::PreemptGuard;
 use eonix_sync_base::LazyLock;
 use riscv::register::{
@@ -17,6 +21,9 @@ pub static CPU_COUNT: AtomicUsize = AtomicUsize::new(0);
 #[eonix_percpu::define_percpu]
 pub static CPUID: usize = 0;
 
+#[eonix_percpu::define_percpu]
+static DEFAULT_TRAP_CONTEXT: MaybeUninit<TrapContext> = MaybeUninit::uninit();
+
 #[eonix_percpu::define_percpu]
 static LOCAL_CPU: LazyLock<CPU> = LazyLock::new(|| CPU::new(CPUID.get()));
 
@@ -56,7 +63,7 @@ impl CPU {
         interrupt.init();
 
         sstatus::set_sum();
-        sscratch::write(TRAP_SCRATCH.as_ptr() as usize);
+        sscratch::write(DEFAULT_TRAP_CONTEXT.as_ptr() as usize);
     }
 
     pub unsafe fn load_interrupt_stack(self: Pin<&mut Self>, sp: u64) {}

+ 177 - 0
crates/eonix_hal/src/arch/riscv64/trap/captured.rs

@@ -0,0 +1,177 @@
+use crate::{arch::trap::Registers, context::TaskContext, trap::TrapContext};
+use core::{arch::naked_asm, mem::MaybeUninit};
+use eonix_hal_traits::context::RawTaskContext;
+
+static mut DIRTY_TASK_CONTEXT: MaybeUninit<TaskContext> = MaybeUninit::uninit();
+
+// If captured trap context is present, we use it directly.
+// We need to restore the kernel tp from that TrapContext but sp is
+// fine since we will use TaskContext::switch.
+#[unsafe(naked)]
+pub(super) unsafe extern "C" fn _captured_trap_entry() -> ! {
+    naked_asm!(
+        "csrrw t0, sscratch, t0",
+        "sd    tp, {tp}(t0)",
+        "ld    tp, {ra}(t0)", // Load kernel tp from trap_ctx.ra
+        "sd    ra, {ra}(t0)",
+        "ld    ra, {sp}(t0)", // Load capturer task context from trap_ctx.sp
+        "sd    sp, {sp}(t0)",
+        "sd    gp, {gp}(t0)",
+        "sd    a0, {a0}(t0)",
+        "sd    a1, {a1}(t0)",
+        "sd    a2, {a2}(t0)",
+        "sd    a3, {a3}(t0)",
+        "sd    a4, {a4}(t0)",
+        "sd    t1, {t1}(t0)",
+        "sd    a5, {a5}(t0)",
+        "sd    a6, {a6}(t0)",
+        "sd    a7, {a7}(t0)",
+        "sd    t3, {t3}(t0)",
+        "sd    t4, {t4}(t0)",
+        "sd    t5, {t5}(t0)",
+        "sd    t2, {t2}(t0)",
+        "sd    t6, {t6}(t0)",
+        "sd    s0, {s0}(t0)",
+        "sd    s1, {s1}(t0)",
+        "sd    s2, {s2}(t0)",
+        "sd    s3, {s3}(t0)",
+        "sd    s4, {s4}(t0)",
+        "sd    s5, {s5}(t0)",
+        "sd    s6, {s6}(t0)",
+        "sd    s7, {s7}(t0)",
+        "sd    s8, {s8}(t0)",
+        "sd    s9, {s9}(t0)",
+        "sd    s10, {s10}(t0)",
+        "sd    s11, {s11}(t0)",
+        "csrr  t2, sstatus",
+        "csrr  t3, sepc",
+        "csrr  t4, scause",
+        "csrr  t5, stval",
+        "csrrw t6, sscratch, t0",
+        "sd    t6, {t0}(t0)",
+        "sd    t2, {sstatus}(t0)",
+        "sd    t3, {sepc}(t0)",
+        "sd    t4, {scause}(t0)",
+        "sd    t5, {stval}(t0)",
+        "la    a0, {dirty_task_context}",
+        "mv    a1, ra",
+        "j     {task_context_switch}",
+        ra = const Registers::OFFSET_RA,
+        sp = const Registers::OFFSET_SP,
+        gp = const Registers::OFFSET_GP,
+        tp = const Registers::OFFSET_TP,
+        t1 = const Registers::OFFSET_T1,
+        t2 = const Registers::OFFSET_T2,
+        t0 = const Registers::OFFSET_T0,
+        a0 = const Registers::OFFSET_A0,
+        a1 = const Registers::OFFSET_A1,
+        a2 = const Registers::OFFSET_A2,
+        a3 = const Registers::OFFSET_A3,
+        a4 = const Registers::OFFSET_A4,
+        a5 = const Registers::OFFSET_A5,
+        a6 = const Registers::OFFSET_A6,
+        a7 = const Registers::OFFSET_A7,
+        t3 = const Registers::OFFSET_T3,
+        t4 = const Registers::OFFSET_T4,
+        t5 = const Registers::OFFSET_T5,
+        t6 = const Registers::OFFSET_T6,
+        s0 = const Registers::OFFSET_S0,
+        s1 = const Registers::OFFSET_S1,
+        s2 = const Registers::OFFSET_S2,
+        s3 = const Registers::OFFSET_S3,
+        s4 = const Registers::OFFSET_S4,
+        s5 = const Registers::OFFSET_S5,
+        s6 = const Registers::OFFSET_S6,
+        s7 = const Registers::OFFSET_S7,
+        s8 = const Registers::OFFSET_S8,
+        s9 = const Registers::OFFSET_S9,
+        s10 = const Registers::OFFSET_S10,
+        s11 = const Registers::OFFSET_S11,
+        sstatus = const TrapContext::OFFSET_SSTATUS,
+        sepc = const TrapContext::OFFSET_SEPC,
+        scause = const TrapContext::OFFSET_SCAUSE,
+        stval = const TrapContext::OFFSET_STVAL,
+        dirty_task_context = sym DIRTY_TASK_CONTEXT,
+        task_context_switch = sym TaskContext::switch,
+    );
+}
+
+#[unsafe(naked)]
+pub(super) unsafe extern "C" fn _captured_trap_return(ctx: &mut TrapContext) -> ! {
+    naked_asm!(
+        "csrr   t0,  sscratch",
+        "ld     t1,  {sstatus}(t0)",
+        "ld     t2,  {sepc}(t0)",
+        "csrw   sstatus, t1",
+        "csrw   sepc, t2",
+        "mv     t4,  tp",
+        "mv     t5,  sp",
+        "ld     tp,  {tp}(t0)",
+        "ld     ra,  {ra}(t0)",
+        "ld     sp,  {sp}(t0)",
+        "sd     t4,  {ra}(t0)", // Store kernel tp to trap_ctx.ra
+        "sd     t5,  {sp}(t0)", // Store capturer task context to trap_ctx.sp
+        "ld     gp,  {gp}(t0)",
+        "ld     a0,  {a0}(t0)",
+        "ld     a1,  {a1}(t0)",
+        "ld     a2,  {a2}(t0)",
+        "ld     a3,  {a3}(t0)",
+        "ld     a4,  {a4}(t0)",
+        "ld     t1,  {t1}(t0)",
+        "ld     a5,  {a5}(t0)",
+        "ld     a6,  {a6}(t0)",
+        "ld     a7,  {a7}(t0)",
+        "ld     t3,  {t3}(t0)",
+        "ld     t4,  {t4}(t0)",
+        "ld     t5,  {t5}(t0)",
+        "ld     t2,  {t2}(t0)",
+        "ld     t6,  {t6}(t0)",
+        "ld     s0,  {s0}(t0)",
+        "ld     s1,  {s1}(t0)",
+        "ld     s2,  {s2}(t0)",
+        "ld     s3,  {s3}(t0)",
+        "ld     s4,  {s4}(t0)",
+        "ld     s5,  {s5}(t0)",
+        "ld     s6,  {s6}(t0)",
+        "ld     s7,  {s7}(t0)",
+        "ld     s8,  {s8}(t0)",
+        "ld     s9,  {s9}(t0)",
+        "ld     s10, {s10}(t0)",
+        "ld     s11, {s11}(t0)",
+        "ld     t0,  {t0}(t0)",
+        "sret",
+        ra = const Registers::OFFSET_RA,
+        sp = const Registers::OFFSET_SP,
+        gp = const Registers::OFFSET_GP,
+        tp = const Registers::OFFSET_TP,
+        t1 = const Registers::OFFSET_T1,
+        t2 = const Registers::OFFSET_T2,
+        t0 = const Registers::OFFSET_T0,
+        a0 = const Registers::OFFSET_A0,
+        a1 = const Registers::OFFSET_A1,
+        a2 = const Registers::OFFSET_A2,
+        a3 = const Registers::OFFSET_A3,
+        a4 = const Registers::OFFSET_A4,
+        a5 = const Registers::OFFSET_A5,
+        a6 = const Registers::OFFSET_A6,
+        a7 = const Registers::OFFSET_A7,
+        t3 = const Registers::OFFSET_T3,
+        t4 = const Registers::OFFSET_T4,
+        t5 = const Registers::OFFSET_T5,
+        t6 = const Registers::OFFSET_T6,
+        s0 = const Registers::OFFSET_S0,
+        s1 = const Registers::OFFSET_S1,
+        s2 = const Registers::OFFSET_S2,
+        s3 = const Registers::OFFSET_S3,
+        s4 = const Registers::OFFSET_S4,
+        s5 = const Registers::OFFSET_S5,
+        s6 = const Registers::OFFSET_S6,
+        s7 = const Registers::OFFSET_S7,
+        s8 = const Registers::OFFSET_S8,
+        s9 = const Registers::OFFSET_S9,
+        s10 = const Registers::OFFSET_S10,
+        s11 = const Registers::OFFSET_S11,
+        sstatus = const TrapContext::OFFSET_SSTATUS,
+        sepc = const TrapContext::OFFSET_SEPC,
+    );
+}

+ 134 - 0
crates/eonix_hal/src/arch/riscv64/trap/default.rs

@@ -0,0 +1,134 @@
+use super::Registers;
+use crate::trap::TrapContext;
+use core::arch::naked_asm;
+
+unsafe extern "C" {
+    fn _default_trap_handler(trap_context: &mut TrapContext);
+}
+
+#[unsafe(naked)]
+pub(super) unsafe extern "C" fn _default_trap_entry() -> ! {
+    naked_asm!(
+        "csrrw t0,      sscratch, t0",
+        "sd    tp,      {tp}(t0)",
+        "sd    ra,      {ra}(t0)",
+        "sd    sp,      {sp}(t0)",
+        "sd    gp,      {gp}(t0)",
+        "sd    a0,      {a0}(t0)",
+        "sd    a1,      {a1}(t0)",
+        "sd    a2,      {a2}(t0)",
+        "sd    a3,      {a3}(t0)",
+        "sd    a4,      {a4}(t0)",
+        "sd    t1,      {t1}(t0)",
+        "sd    a5,      {a5}(t0)",
+        "sd    a6,      {a6}(t0)",
+        "sd    a7,      {a7}(t0)",
+        "sd    t3,      {t3}(t0)",
+        "sd    t4,      {t4}(t0)",
+        "sd    t5,      {t5}(t0)",
+        "sd    t2,      {t2}(t0)",
+        "sd    t6,      {t6}(t0)",
+        "sd    s0,      {s0}(t0)",
+        "sd    s1,      {s1}(t0)",
+        "sd    s2,      {s2}(t0)",
+        "sd    s3,      {s3}(t0)",
+        "sd    s4,      {s4}(t0)",
+        "sd    s5,      {s5}(t0)",
+        "sd    s6,      {s6}(t0)",
+        "sd    s7,      {s7}(t0)",
+        "sd    s8,      {s8}(t0)",
+        "sd    s9,      {s9}(t0)",
+        "sd    s10,     {s10}(t0)",
+        "sd    s11,     {s11}(t0)",
+        "mv    a0,      t0",
+        "csrrw t0,      sscratch, t0",
+        "sd    t0,      {t0}(a0)",
+        "csrr  t0,      sepc",
+        "csrr  t1,      scause",
+        "csrr  t2,      sstatus",
+        "csrr  t3,      stval",
+        "sd    t0,      {sepc}(a0)",
+        "sd    t1,      {scause}(a0)",
+        "sd    t2,      {sstatus}(a0)",
+        "sd    t3,      {stval}(a0)",
+
+        "la    t0,      {default_trap_handler}",
+        "jalr  t0",
+
+        "csrr  t0,      sscratch",
+        "ld    t1,      {sepc}(t0)",
+        "ld    t2,      {sstatus}(t0)",
+        "ld    tp,      {tp}(t0)",
+        "ld    ra,      {ra}(t0)",
+        "ld    sp,      {sp}(t0)",
+        "ld    gp,      {gp}(t0)",
+        "ld    a0,      {a0}(t0)",
+        "ld    a1,      {a1}(t0)",
+        "ld    a2,      {a2}(t0)",
+        "ld    a3,      {a3}(t0)",
+        "ld    a4,      {a4}(t0)",
+
+        "csrw  sepc,    t1",
+        "csrw  sstatus, t2",
+
+        "ld    t1,      {t1}(t0)",
+        "ld    a5,      {a5}(t0)",
+        "ld    a6,      {a6}(t0)",
+        "ld    a7,      {a7}(t0)",
+        "ld    t3,      {t3}(t0)",
+        "ld    t4,      {t4}(t0)",
+        "ld    t5,      {t5}(t0)",
+        "ld    t2,      {t2}(t0)",
+        "ld    t6,      {t6}(t0)",
+        "ld    s0,      {s0}(t0)",
+        "ld    s1,      {s1}(t0)",
+        "ld    s2,      {s2}(t0)",
+        "ld    s3,      {s3}(t0)",
+        "ld    s4,      {s4}(t0)",
+        "ld    s5,      {s5}(t0)",
+        "ld    s6,      {s6}(t0)",
+        "ld    s7,      {s7}(t0)",
+        "ld    s8,      {s8}(t0)",
+        "ld    s9,      {s9}(t0)",
+        "ld    s10,     {s10}(t0)",
+        "ld    s11,     {s11}(t0)",
+        "ld    t0,      {t0}(t0)",
+        "sret",
+        tp = const Registers::OFFSET_TP,
+        ra = const Registers::OFFSET_RA,
+        sp = const Registers::OFFSET_SP,
+        gp = const Registers::OFFSET_GP,
+        t0 = const Registers::OFFSET_T0,
+        t1 = const Registers::OFFSET_T1,
+        t2 = const Registers::OFFSET_T2,
+        t3 = const Registers::OFFSET_T3,
+        t4 = const Registers::OFFSET_T4,
+        t5 = const Registers::OFFSET_T5,
+        t6 = const Registers::OFFSET_T6,
+        a0 = const Registers::OFFSET_A0,
+        a1 = const Registers::OFFSET_A1,
+        a2 = const Registers::OFFSET_A2,
+        a3 = const Registers::OFFSET_A3,
+        a4 = const Registers::OFFSET_A4,
+        a5 = const Registers::OFFSET_A5,
+        a6 = const Registers::OFFSET_A6,
+        a7 = const Registers::OFFSET_A7,
+        s0 = const Registers::OFFSET_S0,
+        s1 = const Registers::OFFSET_S1,
+        s2 = const Registers::OFFSET_S2,
+        s3 = const Registers::OFFSET_S3,
+        s4 = const Registers::OFFSET_S4,
+        s5 = const Registers::OFFSET_S5,
+        s6 = const Registers::OFFSET_S6,
+        s7 = const Registers::OFFSET_S7,
+        s8 = const Registers::OFFSET_S8,
+        s9 = const Registers::OFFSET_S9,
+        s10 = const Registers::OFFSET_S10,
+        s11 = const Registers::OFFSET_S11,
+        sepc = const TrapContext::OFFSET_SEPC,
+        scause = const TrapContext::OFFSET_SCAUSE,
+        sstatus = const TrapContext::OFFSET_SSTATUS,
+        stval = const TrapContext::OFFSET_STVAL,
+        default_trap_handler = sym _default_trap_handler,
+    );
+}

+ 21 - 270
crates/eonix_hal/src/arch/riscv64/trap/mod.rs

@@ -1,18 +1,22 @@
+mod captured;
+mod default;
 mod trap_context;
 
 use super::config::platform::virt::*;
 use super::context::TaskContext;
+use captured::{_captured_trap_entry, _captured_trap_return};
 use core::arch::{global_asm, naked_asm};
 use core::mem::{offset_of, size_of};
 use core::num::NonZero;
 use core::ptr::NonNull;
+use default::_default_trap_entry;
 use eonix_hal_traits::{
     context::RawTaskContext,
     trap::{IrqState as IrqStateTrait, TrapReturn},
 };
 use riscv::register::sstatus::{self, Sstatus};
 use riscv::register::stvec::TrapMode;
-use riscv::register::{scause, sepc, stval};
+use riscv::register::{scause, sepc, sscratch, stval};
 use riscv::{
     asm::sfence_vma_all,
     register::stvec::{self, Stvec},
@@ -21,288 +25,35 @@ use sbi::SbiError;
 
 pub use trap_context::*;
 
-#[repr(C)]
-pub struct TrapScratch {
-    t1: u64,
-    t2: u64,
-    kernel_tp: Option<NonZero<u64>>,
-    trap_context: Option<NonNull<TrapContext>>,
-    handler: unsafe extern "C" fn(),
-    capturer_context: TaskContext,
-}
-
-#[eonix_percpu::define_percpu]
-pub(crate) static TRAP_SCRATCH: TrapScratch = TrapScratch {
-    t1: 0,
-    t2: 0,
-    kernel_tp: None,
-    trap_context: None,
-    handler: default_trap_handler,
-    capturer_context: TaskContext::new(),
-};
-
-static mut DIRTY_TASK_CONTEXT: TaskContext = TaskContext::new();
-
-#[unsafe(naked)]
-unsafe extern "C" fn _raw_trap_entry() -> ! {
-    naked_asm!(
-        "csrrw t0, sscratch, t0", // Swap t0 and sscratch
-        "sd    t1, 0(t0)",
-        "sd    t2, 8(t0)",
-        "csrr  t1, sstatus",
-        "andi  t1, t1, 0x100",
-        "beqz  t1, 2f",
-        // else SPP = 1, supervisor mode
-        "addi  t1, sp, -{trap_context_size}",
-        "mv    t2, tp",
-        "sd    ra, {ra}(t1)",
-        "sd    sp, {sp}(t1)",
-        "mv    sp, t1",
-        "j     4f",
-        // SPP = 0, user mode
-        "2:",
-        "ld    t1, 24(t0)", // Load captured TrapContext address
-        "mv    t2, tp",
-        "ld    tp, 16(t0)", // Restore kernel tp
-        // t0: &mut TrapScratch, t1: &mut TrapContext, t2: tp before trap
-        "3:",
-        "sd    ra, {ra}(t1)",
-        "sd    sp, {sp}(t1)",
-        "4:",
-        "sd    gp, {gp}(t1)",
-        "sd    t2, {tp}(t1)",
-        "ld    ra, 0(t0)",
-        "ld    t2, 8(t0)",
-        "sd    ra, {t1}(t1)",     // Save t1
-        "sd    t2, {t2}(t1)",     // Save t2
-        "ld    ra, 32(t0)",       // Load handler address
-        "csrrw t2, sscratch, t0", // Swap t0 and sscratch
-        "sd    t2, {t0}(t1)",
-        "sd    a0, {a0}(t1)",
-        "sd    a1, {a1}(t1)",
-        "sd    a2, {a2}(t1)",
-        "sd    a3, {a3}(t1)",
-        "sd    a4, {a4}(t1)",
-        "sd    a5, {a5}(t1)",
-        "sd    a6, {a6}(t1)",
-        "sd    a7, {a7}(t1)",
-        "sd    t3, {t3}(t1)",
-        "sd    t4, {t4}(t1)",
-        "sd    t5, {t5}(t1)",
-        "sd    t6, {t6}(t1)",
-        "sd    s0, {s0}(t1)",
-        "sd    s1, {s1}(t1)",
-        "sd    s2, {s2}(t1)",
-        "sd    s3, {s3}(t1)",
-        "sd    s4, {s4}(t1)",
-        "sd    s5, {s5}(t1)",
-        "sd    s6, {s6}(t1)",
-        "sd    s7, {s7}(t1)",
-        "sd    s8, {s8}(t1)",
-        "sd    s9, {s9}(t1)",
-        "sd    s10, {s10}(t1)",
-        "sd    s11, {s11}(t1)",
-        "csrr  t2, sstatus",
-        "csrr  t3, sepc",
-        "csrr  t4, scause",
-        "sd    t2, {sstatus}(t1)",
-        "sd    t3, {sepc}(t1)",
-        "sd    t4, {scause}(t1)",
-        "ret",
-        trap_context_size = const size_of::<TrapContext>(),
-        ra = const Registers::OFFSET_RA,
-        sp = const Registers::OFFSET_SP,
-        gp = const Registers::OFFSET_GP,
-        tp = const Registers::OFFSET_TP,
-        t1 = const Registers::OFFSET_T1,
-        t2 = const Registers::OFFSET_T2,
-        t0 = const Registers::OFFSET_T0,
-        a0 = const Registers::OFFSET_A0,
-        a1 = const Registers::OFFSET_A1,
-        a2 = const Registers::OFFSET_A2,
-        a3 = const Registers::OFFSET_A3,
-        a4 = const Registers::OFFSET_A4,
-        a5 = const Registers::OFFSET_A5,
-        a6 = const Registers::OFFSET_A6,
-        a7 = const Registers::OFFSET_A7,
-        t3 = const Registers::OFFSET_T3,
-        t4 = const Registers::OFFSET_T4,
-        t5 = const Registers::OFFSET_T5,
-        t6 = const Registers::OFFSET_T6,
-        s0 = const Registers::OFFSET_S0,
-        s1 = const Registers::OFFSET_S1,
-        s2 = const Registers::OFFSET_S2,
-        s3 = const Registers::OFFSET_S3,
-        s4 = const Registers::OFFSET_S4,
-        s5 = const Registers::OFFSET_S5,
-        s6 = const Registers::OFFSET_S6,
-        s7 = const Registers::OFFSET_S7,
-        s8 = const Registers::OFFSET_S8,
-        s9 = const Registers::OFFSET_S9,
-        s10 = const Registers::OFFSET_S10,
-        s11 = const Registers::OFFSET_S11,
-        sstatus = const TrapContext::OFFSET_SSTATUS,
-        sepc = const TrapContext::OFFSET_SEPC,
-        scause = const TrapContext::OFFSET_SCAUSE,
-    );
-}
-
-#[unsafe(naked)]
-unsafe extern "C" fn _raw_trap_return(ctx: &mut TrapContext) -> ! {
-    naked_asm!(
-        "ld ra, {ra}(a0)",
-        "ld sp, {sp}(a0)",
-        "ld gp, {gp}(a0)",
-        "ld tp, {tp}(a0)",
-        "ld t1, {t1}(a0)",
-        "ld t2, {t2}(a0)",
-        "ld t0, {t0}(a0)",
-        "ld a1, {a1}(a0)",
-        "ld a2, {a2}(a0)",
-        "ld a3, {a3}(a0)",
-        "ld a4, {a4}(a0)",
-        "ld a5, {a5}(a0)",
-        "ld a6, {a6}(a0)",
-        "ld a7, {a7}(a0)",
-        "ld t3, {t3}(a0)",
-        "ld t4, {sepc}(a0)",    // Load sepc from TrapContext
-        "ld t5, {sstatus}(a0)", // Load sstatus from TrapContext
-        "ld s0, {s0}(a0)",
-        "ld s1, {s1}(a0)",
-        "ld s2, {s2}(a0)",
-        "ld s3, {s3}(a0)",
-        "ld s4, {s4}(a0)",
-        "ld s5, {s5}(a0)",
-        "ld s6, {s6}(a0)",
-        "ld s7, {s7}(a0)",
-        "ld s8, {s8}(a0)",
-        "ld s9, {s9}(a0)",
-        "ld s10, {s10}(a0)",
-        "ld s11, {s11}(a0)",
-        "csrw sepc, t4",        // Restore sepc
-        "csrw sstatus, t5",     // Restore sstatus
-        "ld t4, {t4}(a0)",
-        "ld t5, {t5}(a0)",
-        "ld t6, {t6}(a0)",
-        "ld a0, {a0}(a0)",
-        "sret",
-        ra = const Registers::OFFSET_RA,
-        sp = const Registers::OFFSET_SP,
-        gp = const Registers::OFFSET_GP,
-        tp = const Registers::OFFSET_TP,
-        t1 = const Registers::OFFSET_T1,
-        t2 = const Registers::OFFSET_T2,
-        t0 = const Registers::OFFSET_T0,
-        a0 = const Registers::OFFSET_A0,
-        a1 = const Registers::OFFSET_A1,
-        a2 = const Registers::OFFSET_A2,
-        a3 = const Registers::OFFSET_A3,
-        a4 = const Registers::OFFSET_A4,
-        a5 = const Registers::OFFSET_A5,
-        a6 = const Registers::OFFSET_A6,
-        a7 = const Registers::OFFSET_A7,
-        t3 = const Registers::OFFSET_T3,
-        t4 = const Registers::OFFSET_T4,
-        t5 = const Registers::OFFSET_T5,
-        t6 = const Registers::OFFSET_T6,
-        s0 = const Registers::OFFSET_S0,
-        s1 = const Registers::OFFSET_S1,
-        s2 = const Registers::OFFSET_S2,
-        s3 = const Registers::OFFSET_S3,
-        s4 = const Registers::OFFSET_S4,
-        s5 = const Registers::OFFSET_S5,
-        s6 = const Registers::OFFSET_S6,
-        s7 = const Registers::OFFSET_S7,
-        s8 = const Registers::OFFSET_S8,
-        s9 = const Registers::OFFSET_S9,
-        s10 = const Registers::OFFSET_S10,
-        s11 = const Registers::OFFSET_S11,
-        sstatus = const TrapContext::OFFSET_SSTATUS,
-        sepc = const TrapContext::OFFSET_SEPC,
-    );
-}
-
-#[unsafe(naked)]
-unsafe extern "C" fn default_trap_handler() {
-    unsafe extern "C" {
-        fn _default_trap_handler(trap_context: &mut TrapContext);
-    }
-
-    naked_asm!(
-        "andi sp, sp, -16", // Align stack pointer to 16 bytes
-        "addi sp, sp, -16",
-        "mv   a0, t1",      // TrapContext pointer in t1
-        "sd   a0, 0(sp)",   // Save TrapContext pointer
-        "",
-        "call {default_handler}",
-        "",
-        "ld   a0, 0(sp)",   // Restore TrapContext pointer
-        "j {trap_return}",
-        default_handler = sym _default_trap_handler,
-        trap_return = sym _raw_trap_return,
-    );
-}
-
-#[unsafe(naked)]
-unsafe extern "C" fn captured_trap_handler() {
-    naked_asm!(
-        "la   a0, {dirty_task_context}",
-        "addi a1, t0, {capturer_context_offset}",
-        "j {switch}",
-        dirty_task_context = sym DIRTY_TASK_CONTEXT,
-        capturer_context_offset = const offset_of!(TrapScratch, capturer_context),
-        switch = sym TaskContext::switch,
-    );
-}
-
-#[unsafe(naked)]
-unsafe extern "C" fn captured_trap_return(trap_context: usize) -> ! {
-    naked_asm!(
-        "mv a0, sp",
-        "j {raw_trap_return}",
-        raw_trap_return = sym _raw_trap_return,
-    );
-}
-
-impl TrapScratch {
-    pub fn set_trap_context(&mut self, ctx: NonNull<TrapContext>) {
-        self.trap_context = Some(ctx);
-    }
-
-    pub fn clear_trap_context(&mut self) {
-        self.trap_context = None;
-    }
-
-    pub fn set_kernel_tp(&mut self, tp: NonNull<u8>) {
-        self.kernel_tp = Some(NonZero::new(tp.addr().get() as u64).unwrap());
-    }
-}
-
 impl TrapReturn for TrapContext {
     type TaskContext = TaskContext;
 
     unsafe fn trap_return(&mut self) {
         let irq_states = disable_irqs_save();
 
-        let old_handler =
-            core::mem::replace(&mut TRAP_SCRATCH.as_mut().handler, captured_trap_handler);
+        let old_stvec = stvec::read();
+        stvec::write({
+            let mut stvec_val = Stvec::from_bits(0);
+            stvec_val.set_address(_captured_trap_entry as usize);
+            stvec_val.set_trap_mode(TrapMode::Direct);
+            stvec_val
+        });
 
-        let old_trap_context = core::mem::replace(
-            &mut TRAP_SCRATCH.as_mut().trap_context,
-            Some(NonNull::from(&mut *self)),
-        );
+        let old_trap_ctx = sscratch::read();
+        sscratch::write(&raw mut *self as usize);
 
+        let mut from_ctx = TaskContext::new();
         let mut to_ctx = TaskContext::new();
-        to_ctx.set_program_counter(captured_trap_return as usize);
-        to_ctx.set_stack_pointer(&raw mut *self as usize);
+        to_ctx.set_program_counter(_captured_trap_return as usize);
+        to_ctx.set_stack_pointer(&raw mut from_ctx as usize);
         to_ctx.set_interrupt_enabled(false);
 
         unsafe {
-            TaskContext::switch(&mut TRAP_SCRATCH.as_mut().capturer_context, &mut to_ctx);
+            TaskContext::switch(&mut from_ctx, &mut to_ctx);
         }
 
-        TRAP_SCRATCH.as_mut().handler = old_handler;
-        TRAP_SCRATCH.as_mut().trap_context = old_trap_context;
+        sscratch::write(old_trap_ctx);
+        stvec::write(old_stvec);
 
         irq_states.restore();
     }
@@ -319,7 +70,7 @@ fn setup_trap_handler(trap_entry_addr: usize) {
 }
 
 pub fn setup_trap() {
-    setup_trap_handler(_raw_trap_entry as usize);
+    setup_trap_handler(_default_trap_entry as usize);
 }
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]

+ 47 - 50
crates/eonix_hal/src/arch/riscv64/trap/trap_context.rs

@@ -1,5 +1,5 @@
 use crate::{arch::time::set_next_timer, processor::CPU};
-use core::arch::asm;
+use core::{arch::asm, mem::offset_of};
 use eonix_hal_traits::{
     fault::{Fault, PageFaultErrorCode},
     trap::{RawTrapContext, TrapType},
@@ -18,24 +18,23 @@ use riscv::{
 #[repr(C)]
 #[derive(Default, Clone, Copy)]
 pub struct Registers {
+    tp: u64,
     ra: u64,
     sp: u64,
     gp: u64,
-    tp: u64,
-    t1: u64,
-    t2: u64,
-    t0: u64,
     a0: u64,
     a1: u64,
     a2: u64,
     a3: u64,
     a4: u64,
+    t1: u64,
     a5: u64,
     a6: u64,
     a7: u64,
     t3: u64,
     t4: u64,
     t5: u64,
+    t2: u64,
     t6: u64,
     s0: u64,
     s1: u64,
@@ -49,10 +48,11 @@ pub struct Registers {
     s9: u64,
     s10: u64,
     s11: u64,
+    t0: u64,
 }
 
 /// Saved CPU context when a trap (interrupt or exception) occurs on RISC-V 64.
-#[repr(C)]
+#[repr(C, align(16))]
 #[derive(Clone, Copy)]
 pub struct TrapContext {
     regs: Registers,
@@ -60,46 +60,48 @@ pub struct TrapContext {
     sstatus: Sstatus,
     sepc: usize,
     scause: Scause,
+    stval: usize,
 }
 
 impl Registers {
-    pub const OFFSET_RA: usize = 0 * 8;
-    pub const OFFSET_SP: usize = 1 * 8;
-    pub const OFFSET_GP: usize = 2 * 8;
-    pub const OFFSET_TP: usize = 3 * 8;
-    pub const OFFSET_T1: usize = 4 * 8;
-    pub const OFFSET_T2: usize = 5 * 8;
-    pub const OFFSET_T0: usize = 6 * 8;
-    pub const OFFSET_A0: usize = 7 * 8;
-    pub const OFFSET_A1: usize = 8 * 8;
-    pub const OFFSET_A2: usize = 9 * 8;
-    pub const OFFSET_A3: usize = 10 * 8;
-    pub const OFFSET_A4: usize = 11 * 8;
-    pub const OFFSET_A5: usize = 12 * 8;
-    pub const OFFSET_A6: usize = 13 * 8;
-    pub const OFFSET_A7: usize = 14 * 8;
-    pub const OFFSET_T3: usize = 15 * 8;
-    pub const OFFSET_T4: usize = 16 * 8;
-    pub const OFFSET_T5: usize = 17 * 8;
-    pub const OFFSET_T6: usize = 18 * 8;
-    pub const OFFSET_S0: usize = 19 * 8;
-    pub const OFFSET_S1: usize = 20 * 8;
-    pub const OFFSET_S2: usize = 21 * 8;
-    pub const OFFSET_S3: usize = 22 * 8;
-    pub const OFFSET_S4: usize = 23 * 8;
-    pub const OFFSET_S5: usize = 24 * 8;
-    pub const OFFSET_S6: usize = 25 * 8;
-    pub const OFFSET_S7: usize = 26 * 8;
-    pub const OFFSET_S8: usize = 27 * 8;
-    pub const OFFSET_S9: usize = 28 * 8;
-    pub const OFFSET_S10: usize = 29 * 8;
-    pub const OFFSET_S11: usize = 30 * 8;
+    pub const OFFSET_TP: usize = offset_of!(Registers, tp);
+    pub const OFFSET_SP: usize = offset_of!(Registers, sp);
+    pub const OFFSET_RA: usize = offset_of!(Registers, ra);
+    pub const OFFSET_GP: usize = offset_of!(Registers, gp);
+    pub const OFFSET_T1: usize = offset_of!(Registers, t1);
+    pub const OFFSET_T2: usize = offset_of!(Registers, t2);
+    pub const OFFSET_T0: usize = offset_of!(Registers, t0);
+    pub const OFFSET_A0: usize = offset_of!(Registers, a0);
+    pub const OFFSET_A1: usize = offset_of!(Registers, a1);
+    pub const OFFSET_A2: usize = offset_of!(Registers, a2);
+    pub const OFFSET_A3: usize = offset_of!(Registers, a3);
+    pub const OFFSET_A4: usize = offset_of!(Registers, a4);
+    pub const OFFSET_A5: usize = offset_of!(Registers, a5);
+    pub const OFFSET_A6: usize = offset_of!(Registers, a6);
+    pub const OFFSET_A7: usize = offset_of!(Registers, a7);
+    pub const OFFSET_T3: usize = offset_of!(Registers, t3);
+    pub const OFFSET_T4: usize = offset_of!(Registers, t4);
+    pub const OFFSET_T5: usize = offset_of!(Registers, t5);
+    pub const OFFSET_T6: usize = offset_of!(Registers, t6);
+    pub const OFFSET_S0: usize = offset_of!(Registers, s0);
+    pub const OFFSET_S1: usize = offset_of!(Registers, s1);
+    pub const OFFSET_S2: usize = offset_of!(Registers, s2);
+    pub const OFFSET_S3: usize = offset_of!(Registers, s3);
+    pub const OFFSET_S4: usize = offset_of!(Registers, s4);
+    pub const OFFSET_S5: usize = offset_of!(Registers, s5);
+    pub const OFFSET_S6: usize = offset_of!(Registers, s6);
+    pub const OFFSET_S7: usize = offset_of!(Registers, s7);
+    pub const OFFSET_S8: usize = offset_of!(Registers, s8);
+    pub const OFFSET_S9: usize = offset_of!(Registers, s9);
+    pub const OFFSET_S10: usize = offset_of!(Registers, s10);
+    pub const OFFSET_S11: usize = offset_of!(Registers, s11);
 }
 
 impl TrapContext {
-    pub const OFFSET_SSTATUS: usize = 31 * 8;
-    pub const OFFSET_SEPC: usize = 32 * 8;
-    pub const OFFSET_SCAUSE: usize = 33 * 8;
+    pub const OFFSET_SSTATUS: usize = offset_of!(TrapContext, sstatus);
+    pub const OFFSET_SEPC: usize = offset_of!(TrapContext, sepc);
+    pub const OFFSET_SCAUSE: usize = offset_of!(TrapContext, scause);
+    pub const OFFSET_STVAL: usize = offset_of!(TrapContext, stval);
 
     fn syscall_no(&self) -> usize {
         self.regs.a7 as usize
@@ -131,6 +133,7 @@ impl RawTrapContext for TrapContext {
             sstatus,
             sepc: 0,
             scause: Scause::from_bits(0),
+            stval: 0,
         }
     }
 
@@ -176,16 +179,10 @@ impl RawTrapContext for TrapContext {
                     },
                     exception @ (Exception::InstructionPageFault
                     | Exception::LoadPageFault
-                    | Exception::StorePageFault) => {
-                        #[inline(always)]
-                        fn get_page_fault_address() -> VAddr {
-                            VAddr::from(stval::read())
-                        }
-                        TrapType::Fault(Fault::PageFault {
-                            error_code: self.get_page_fault_error_code(exception),
-                            address: get_page_fault_address(),
-                        })
-                    }
+                    | Exception::StorePageFault) => TrapType::Fault(Fault::PageFault {
+                        error_code: self.get_page_fault_error_code(exception),
+                        address: VAddr::from(self.stval),
+                    }),
                     // breakpoint and supervisor env call
                     _ => TrapType::Fault(Fault::Unknown(e)),
                 }