Forráskód Böngészése

partial work: working trap

greatbridf 7 hónapja
szülő
commit
2d868ba813

+ 1 - 3
arch/src/riscv64/fpu.rs

@@ -1,8 +1,6 @@
 use core::arch::asm;
-
 use eonix_hal_traits::fpu::RawFpuState;
 
-
 #[repr(C)]
 #[derive(Debug, Clone, Copy, Default)]
 pub struct FpuState {
@@ -107,7 +105,7 @@ impl RawFpuState for FpuState {
             base = in(reg) base_ptr,
             fcsr_val = out(reg) _fcsr_val,
             fcsr_ptr = in(reg) fcsr_ptr,
-            options(nostack, nomem, preserves_flags));
+            options(nostack, preserves_flags));
         }
     }
 }

+ 1 - 5
arch/src/riscv64/mod.rs

@@ -1,9 +1,5 @@
 use core::arch::asm;
-
-use eonix_mm::{
-    address::{Addr, PAddr, VAddr},
-    paging::PFN,
-};
+use eonix_mm::{address::VAddr, paging::PFN};
 use riscv::{
     asm::{sfence_vma, sfence_vma_all},
     register::{satp, stval},

+ 9 - 15
crates/eonix_hal/src/arch/riscv64/bootstrap.rs

@@ -1,6 +1,7 @@
 use super::{
     config::{self, mm::*},
     console::write_str,
+    trap::TRAP_SCRATCH,
 };
 use crate::{
     arch::{
@@ -26,13 +27,7 @@ use eonix_mm::{
 };
 use eonix_percpu::PercpuArea;
 use fdt::Fdt;
-use riscv::{
-    asm::sfence_vma_all,
-    register::{
-        satp,
-        sstatus::{self, FS},
-    },
-};
+use riscv::{asm::sfence_vma_all, register::satp};
 use sbi::legacy::console_putchar;
 
 #[unsafe(link_section = ".bootstrap.stack")]
@@ -100,7 +95,6 @@ unsafe extern "C" fn _start(hart_id: usize, dtb_addr: usize) -> ! {
 /// TODO:
 /// 启动所有的cpu
 pub unsafe extern "C" fn riscv64_start(hart_id: usize, dtb_addr: PAddr) -> ! {
-    enable_fpu();
     let fdt = Fdt::from_ptr(ArchPhysAccess::as_ptr(dtb_addr).as_ptr())
         .expect("Failed to parse DTB from static memory.");
 
@@ -190,13 +184,6 @@ fn setup_kernel_page_table(alloc: impl PageAlloc) {
     sfence_vma_all();
 }
 
-fn enable_fpu() {
-    unsafe {
-        // FS (Floating-point Status) Initial (0b01)
-        sstatus::set_fs(FS::Initial);
-    }
-}
-
 /// set up tp register to percpu
 fn setup_cpu(alloc: impl PageAlloc, hart_id: usize) {
     let mut percpu_area = PercpuArea::new(|layout| {
@@ -227,6 +214,13 @@ fn setup_cpu(alloc: impl PageAlloc, hart_id: usize) {
     }
 
     percpu_area.register(cpu.cpuid());
+
+    unsafe {
+        // SAFETY: Interrupts are disabled.
+        TRAP_SCRATCH
+            .as_mut()
+            .set_kernel_tp(PercpuArea::get_for(cpu.cpuid()).unwrap().cast());
+    }
 }
 
 /// TODO

+ 13 - 66
crates/eonix_hal/src/arch/riscv64/context.rs

@@ -1,16 +1,15 @@
 use core::arch::naked_asm;
-
 use eonix_hal_traits::context::RawTaskContext;
 use riscv::register::sstatus::Sstatus;
 
 #[repr(C)]
-#[derive(Debug, Default)]
+#[derive(Debug)]
 pub struct TaskContext {
     // s0-11
     s: [u64; 12],
     sp: u64,
     ra: u64,
-    sstatus: u64,
+    sstatus: Sstatus,
 }
 
 impl RawTaskContext for TaskContext {
@@ -27,19 +26,11 @@ impl RawTaskContext for TaskContext {
     }
 
     fn is_interrupt_enabled(&self) -> bool {
-        let sstatus_val = Sstatus::from_bits(self.sstatus as usize);
-        sstatus_val.sie()
+        self.sstatus.sie()
     }
 
     fn set_interrupt_enabled(&mut self, is_enabled: bool) {
-        // sstatus: SIE bit is bit 1
-        const SSTATUS_SIE_BIT: u64 = 1 << 1; // 0x2
-
-        if is_enabled {
-            self.sstatus |= SSTATUS_SIE_BIT;
-        } else {
-            self.sstatus &= !SSTATUS_SIE_BIT;
-        }
+        self.sstatus.set_sie(is_enabled);
     }
 
     fn call(&mut self, func: unsafe extern "C" fn(usize) -> !, arg: usize) {
@@ -49,57 +40,8 @@ impl RawTaskContext for TaskContext {
         self.set_program_counter(Self::do_call as usize);
     }
 
-    unsafe extern "C" fn switch(from: &mut Self, to: &mut Self) {
-        unsafe { Self::__task_context_switch(from, to) }
-    }
-}
-
-impl TaskContext {
-    pub const fn new() -> Self {
-        Self {
-            s: [0; 12],
-            sp: 0,
-            ra: 0,
-            sstatus: 0,
-        }
-    }
-
-    pub fn ip(&mut self, ip: usize) {
-        self.ra = ip as u64;
-    }
-
-    pub fn entry_point(&mut self, entry: usize) {
-        self.ra = entry as u64;
-    }
-
-    pub fn sstatus(&mut self, status: usize) {
-        self.sstatus = status as u64;
-    }
-
-    pub fn interrupt(&mut self, is_enabled: bool) {
-        // sstatus: SIE bit is bit 1
-        const SSTATUS_SIE_BIT: u64 = 1 << 1; // 0x2
-
-        if is_enabled {
-            self.sstatus |= SSTATUS_SIE_BIT;
-        } else {
-            self.sstatus &= !SSTATUS_SIE_BIT;
-        }
-    }
-
-    /// SPIE bit
-    pub fn interrupt_on_return(&mut self, will_enable: bool) {
-        const SSTATUS_SPIE_BIT: u64 = 1 << 5;
-        if will_enable {
-            self.sstatus |= SSTATUS_SPIE_BIT;
-        } else {
-            self.sstatus &= !SSTATUS_SPIE_BIT;
-        }
-    }
-
     #[unsafe(naked)]
-    #[unsafe(no_mangle)]
-    pub unsafe extern "C" fn __task_context_switch(from: *mut Self, to: *const Self) {
+    unsafe extern "C" fn switch(from: &mut Self, to: &mut Self) {
         // Input arguments `from` and `to` will be in `a0` (x10) and `a1` (x11).
         naked_asm!(
             // Save current task's callee-saved registers to `from` context
@@ -139,10 +81,15 @@ impl TaskContext {
             "ret",
         );
     }
+}
 
-    pub fn switch(from: &mut Self, to: &mut Self) {
-        unsafe {
-            TaskContext::__task_context_switch(from, to);
+impl TaskContext {
+    pub const fn new() -> Self {
+        Self {
+            s: [0; 12],
+            sp: 0,
+            ra: 0,
+            sstatus: Sstatus::from_bits(1 << 13), // Set FS = Initial
         }
     }
 

+ 14 - 9
crates/eonix_hal/src/arch/riscv64/cpu.rs

@@ -1,9 +1,15 @@
-use super::{interrupt::InterruptControl, trap::setup_trap};
+use super::{
+    interrupt::InterruptControl,
+    trap::{setup_trap, TRAP_SCRATCH},
+};
 use crate::arch::fdt::{FdtExt, FDT};
-use core::pin::Pin;
+use core::{pin::Pin, ptr::NonNull};
 use eonix_preempt::PreemptGuard;
 use eonix_sync_base::LazyLock;
-use riscv::register::{mhartid, sscratch, sstatus};
+use riscv::register::{
+    medeleg::{self, Medeleg},
+    mhartid, sscratch, sstatus,
+};
 use riscv_peripheral::plic::PLIC;
 use sbi::PhysicalAddress;
 
@@ -50,11 +56,8 @@ impl CPU {
         let interrupt = self.map_unchecked_mut(|me| &mut me.interrupt);
         interrupt.init();
 
-        let mut current_sstatus = sstatus::read();
-        current_sstatus.set_spp(sstatus::SPP::Supervisor);
-        current_sstatus.set_sum(true);
-        current_sstatus.set_mxr(true);
-        sstatus::write(current_sstatus);
+        sstatus::set_sum();
+        sscratch::write(TRAP_SCRATCH.as_ptr() as usize);
     }
 
     /// Boot all other hart.
@@ -67,7 +70,9 @@ impl CPU {
     }
 
     pub unsafe fn load_interrupt_stack(self: Pin<&mut Self>, sp: u64) {
-        sscratch::write(sp as usize);
+        TRAP_SCRATCH
+            .as_mut()
+            .set_trap_context(NonNull::new(sp as *mut _).unwrap());
     }
 
     pub fn set_tls32(self: Pin<&mut Self>, _user_tls: &UserTLS) {

+ 178 - 118
crates/eonix_hal/src/arch/riscv64/trap/mod.rs

@@ -3,6 +3,9 @@ mod trap_context;
 use super::config::platform::virt::*;
 use super::context::TaskContext;
 use core::arch::{global_asm, naked_asm};
+use core::mem::{offset_of, size_of};
+use core::num::NonZero;
+use core::ptr::NonNull;
 use eonix_hal_traits::{context::RawTaskContext, trap::TrapReturn};
 use riscv::register::sie::Sie;
 use riscv::register::stvec::TrapMode;
@@ -18,11 +21,25 @@ use sbi::SbiError;
 
 pub use trap_context::*;
 
-#[eonix_percpu::define_percpu]
-static TRAP_HANDLER: unsafe extern "C" fn() = default_trap_handler;
+#[repr(C)]
+pub struct TrapScratch {
+    t1: u64,
+    t2: u64,
+    kernel_tp: Option<NonZero<u64>>,
+    trap_context: Option<NonNull<TrapContext>>,
+    handler: unsafe extern "C" fn(),
+    capturer_context: TaskContext,
+}
 
 #[eonix_percpu::define_percpu]
-static CAPTURER_CONTEXT: TaskContext = TaskContext::new();
+pub(crate) static TRAP_SCRATCH: TrapScratch = TrapScratch {
+    t1: 0,
+    t2: 0,
+    kernel_tp: None,
+    trap_context: None,
+    handler: default_trap_handler,
+    capturer_context: TaskContext::new(),
+};
 
 /// This value will never be used.
 static mut DIRTY_TRAP_CONTEXT: TaskContext = TaskContext::new();
@@ -30,109 +47,147 @@ static mut DIRTY_TRAP_CONTEXT: TaskContext = TaskContext::new();
 #[unsafe(naked)]
 unsafe extern "C" fn _raw_trap_entry() -> ! {
     naked_asm!(
-        "j {entry}",
-        entry = sym _raw_trap_entry,
+        "csrrw t0, sscratch, t0", // Swap t0 and sscratch
+        "sd    t1, 0(t0)",
+        "sd    t2, 8(t0)",
+        "csrr  t1, sstatus",
+        "andi  t1, t1, 0x10",
+        "beqz  t1, 2f",
+        // else SPP = 1, supervisor mode
+        "addi  t1, sp, -{trap_context_size}",
+        "mv    t2, tp",
+        "j     3f",
+        // SPP = 0, user mode
+        "2:",
+        "ld    t1, 24(t0)", // Load captured TrapContext address
+        "mv    t2, tp",
+        "ld    tp, 16(t0)", // Restore kernel tp
+        // t0: &mut TrapScratch, t1: &mut TrapContext, t2: tp before trap
+        "3:",
+        "sd    ra, {ra}(t1)",
+        "sd    sp, {sp}(t1)",
+        "sd    gp, {gp}(t1)",
+        "sd    t2, {tp}(t1)",
+        "ld    ra, 0(t0)",
+        "ld    t2, 8(t0)",
+        "sd    ra, {t1}(t1)",     // Save t1
+        "sd    t2, {t2}(t1)",     // Save t2
+        "ld    ra, 32(t0)",       // Load handler address
+        "csrrw t2, sscratch, t0", // Swap to and sscratch
+        "sd    t2, {t0}(t1)",
+        "sd    a0, {a0}(t1)",
+        "sd    a1, {a1}(t1)",
+        "sd    a2, {a2}(t1)",
+        "sd    a3, {a3}(t1)",
+        "sd    a4, {a4}(t1)",
+        "sd    a5, {a5}(t1)",
+        "sd    a6, {a6}(t1)",
+        "sd    a7, {a7}(t1)",
+        "sd    t3, {t3}(t1)",
+        "sd    t4, {t4}(t1)",
+        "sd    t5, {t5}(t1)",
+        "sd    t6, {t6}(t1)",
+        "csrr  t2, sstatus",
+        "csrr  t3, sepc",
+        "csrr  t4, scause",
+        "sd    t2, {sstatus}(t1)",
+        "sd    t3, {sepc}(t1)",
+        "sd    t4, {scause}(t1)",
+        "ret",
+        trap_context_size = const size_of::<TrapContext>(),
+        ra = const Registers::OFFSET_RA,
+        sp = const Registers::OFFSET_SP,
+        gp = const Registers::OFFSET_GP,
+        tp = const Registers::OFFSET_TP,
+        t1 = const Registers::OFFSET_T1,
+        t2 = const Registers::OFFSET_T2,
+        t0 = const Registers::OFFSET_T0,
+        a0 = const Registers::OFFSET_A0,
+        a1 = const Registers::OFFSET_A1,
+        a2 = const Registers::OFFSET_A2,
+        a3 = const Registers::OFFSET_A3,
+        a4 = const Registers::OFFSET_A4,
+        a5 = const Registers::OFFSET_A5,
+        a6 = const Registers::OFFSET_A6,
+        a7 = const Registers::OFFSET_A7,
+        t3 = const Registers::OFFSET_T3,
+        t4 = const Registers::OFFSET_T4,
+        t5 = const Registers::OFFSET_T5,
+        t6 = const Registers::OFFSET_T6,
+        sstatus = const TrapContext::OFFSET_SSTATUS,
+        sepc = const TrapContext::OFFSET_SEPC,
+        scause = const TrapContext::OFFSET_SCAUSE,
     );
 }
 
-// TODO: is need to save kernel's callee saved registers?
-global_asm!(
-    r"
-    .altmacro
-    .macro SAVE_GP n
-        sd x\n, \n*8(sp)
-    .endm
-    .macro LOAD_GP n
-        ld x\n, \n*8(sp)
-    .endm
-
-    .section .text
-        .globl __raw_trap_entry
-        .globl return_to_user
-        .align 2
-
-    __raw_trap_entry:
-        # swap sp and sscratch(previously stored user TrapContext's address in return_to_user)
-        csrrw sp, sscratch, sp
-
-        sd x1, 1*8(sp)
-        .set n, 3
-        .rept 29
-            SAVE_GP %n
-            .set n, n+1
-        .endr
-
-        csrr t0, sstatus
-        csrr t1, sepc
-        csrr t2, scause
-        csrr t3, stval
-        sd t0, 32*8(sp)     # save sstatus into the TrapContext
-        sd t1, 33*8(sp)     # save sepc into the TrapContext
-        sd t2, 34*8(sp)     # save scause into the TrapContext
-        sd t3, 35*8(sp)     # save stval into the TrapContext
-
-        csrr t0, sscratch
-        sd t0, 2*8(sp)      # save user stack pointer into the TrapContext
-
-        la t0, {handler}
-        ld t1, 0(t0)
-        jr t1
-
-    _raw_trap_return:
-        # sscratch store the TrapContext's address
-        csrw sscratch, a0
-
-        mv sp, a0
-        # now sp points to TrapContext in kernel space
-
-        # restore sstatus and sepc
-        ld t0, 32*8(sp)
-        ld t1, 33*8(sp)
-        ld t2, 34*8(sp)
-        ld t3, 35*8(sp)
-        csrw sstatus, t0
-        csrw sepc, t1
-        csrw scause, t2
-        csrw stval, t3
-
-        # save x* expect x0 and sp
-        ld x1, 1*8(sp)
-        .set n, 3
-        .rept 29
-            LOAD_GP %n
-            .set n, n+1
-        .endr
-        ld sp, 2*8(sp)
-
-        sret
-    ",
-    handler = sym _percpu_inner_TRAP_HANDLER,
-
-);
-
-unsafe extern "C" {
-    fn _default_trap_handler(trap_context: &mut TrapContext);
-    fn _raw_trap_return();
+#[unsafe(naked)]
+unsafe extern "C" fn _raw_trap_return(ctx: &mut TrapContext) -> ! {
+    naked_asm!(
+        "ld ra, {ra}(a0)",
+        "ld sp, {sp}(a0)",
+        "ld gp, {gp}(a0)",
+        "ld tp, {tp}(a0)",
+        "ld t1, {t1}(a0)",
+        "ld t2, {t2}(a0)",
+        "ld t0, {t0}(a0)",
+        "ld a1, {a1}(a0)",
+        "ld a2, {a2}(a0)",
+        "ld a3, {a3}(a0)",
+        "ld a4, {a4}(a0)",
+        "ld a5, {a5}(a0)",
+        "ld a6, {a6}(a0)",
+        "ld a7, {a7}(a0)",
+        "ld t3, {t3}(a0)",
+        "ld t4, {sepc}(a0)",    // Load sepc from TrapContext
+        "ld t5, {sstatus}(a0)", // Load sstatus from TrapContext
+        "csrw sepc, t4",        // Restore sepc
+        "csrw sstatus, t5",     // Restore sstatus
+        "ld t4, {t4}(a0)",
+        "ld t5, {t5}(a0)",
+        "ld t6, {t6}(a0)",
+        "ld a0, {a0}(a0)",
+        "sret",
+        ra = const Registers::OFFSET_RA,
+        sp = const Registers::OFFSET_SP,
+        gp = const Registers::OFFSET_GP,
+        tp = const Registers::OFFSET_TP,
+        t1 = const Registers::OFFSET_T1,
+        t2 = const Registers::OFFSET_T2,
+        t0 = const Registers::OFFSET_T0,
+        a0 = const Registers::OFFSET_A0,
+        a1 = const Registers::OFFSET_A1,
+        a2 = const Registers::OFFSET_A2,
+        a3 = const Registers::OFFSET_A3,
+        a4 = const Registers::OFFSET_A4,
+        a5 = const Registers::OFFSET_A5,
+        a6 = const Registers::OFFSET_A6,
+        a7 = const Registers::OFFSET_A7,
+        t3 = const Registers::OFFSET_T3,
+        t4 = const Registers::OFFSET_T4,
+        t5 = const Registers::OFFSET_T5,
+        t6 = const Registers::OFFSET_T6,
+        sstatus = const TrapContext::OFFSET_SSTATUS,
+        sepc = const TrapContext::OFFSET_SEPC,
+    );
 }
 
-/// TODO:
-/// default_trap_handler
-/// captured_trap_handler
-/// _raw_trap_entry应该是做好了
-/// _raw_trap_return应该是做好了
 #[unsafe(naked)]
 unsafe extern "C" fn default_trap_handler() {
-    naked_asm!(
-        "mv t0, sp",
-        "andi sp, sp, -16",
-        "mv a0, t0",
-        "call {handle_trap}",
-
-        "mv sp, t0",
+    unsafe extern "C" {
+        fn _default_trap_handler(trap_context: &mut TrapContext);
+    }
 
+    naked_asm!(
+        "andi sp, sp, -16", // Align stack pointer to 16 bytes
+        "addi sp, sp, -16",
+        "mv   a0, t1",      // TrapContext pointer in t1
+        "sd   a0, 0(sp)",   // Save TrapContext pointer
+        "",
+        "call {default_handler}",
+        "",
+        "ld   a0, 0(sp)",   // Restore TrapContext pointer
         "j {trap_return}",
-
-        handle_trap = sym _default_trap_handler,
+        default_handler = sym _default_trap_handler,
         trap_return = sym _raw_trap_return,
     );
 }
@@ -140,23 +195,11 @@ unsafe extern "C" fn default_trap_handler() {
 #[unsafe(naked)]
 unsafe extern "C" fn captured_trap_handler() {
     naked_asm!(
-        "addi sp, sp, -16",
-        "sd ra, 8(sp)",
-
         "la a0, {from_context}",
-
-        "mv t0, tp",
-        "la t1, {to_context}",
-        "add a1, t0, t1",
-
-        "call {switch}",
-
-        "ld ra, 8(sp)",
-        "addi sp, sp, 16",
-        "ret",
-
+        "addi a1, t0, {capturer_context_offset}",
+        "j {switch}",
         from_context = sym DIRTY_TRAP_CONTEXT,
-        to_context = sym _percpu_inner_CAPTURER_CONTEXT,
+        capturer_context_offset = const offset_of!(TrapScratch, capturer_context),
         switch = sym TaskContext::switch,
     );
 }
@@ -164,16 +207,33 @@ unsafe extern "C" fn captured_trap_handler() {
 #[unsafe(naked)]
 unsafe extern "C" fn captured_trap_return(trap_context: usize) -> ! {
     naked_asm!(
-        "la t0, {trap_return}",
-        "jalr zero, t0, 0",
-        trap_return = sym _raw_trap_return,
+        "mv a0, sp",
+        "j {raw_trap_return}",
+        raw_trap_return = sym _raw_trap_return,
     );
 }
 
+impl TrapScratch {
+    pub fn set_trap_context(&mut self, ctx: NonNull<TrapContext>) {
+        self.trap_context = Some(ctx);
+    }
+
+    pub fn clear_trap_context(&mut self) {
+        self.trap_context = None;
+    }
+
+    pub fn set_kernel_tp(&mut self, tp: NonNull<u8>) {
+        self.kernel_tp = Some(NonZero::new(tp.addr().get() as u64).unwrap());
+    }
+}
+
 impl TrapReturn for TrapContext {
     unsafe fn trap_return(&mut self) {
         let irq_states = disable_irqs_save();
-        let old_handler = TRAP_HANDLER.swap(captured_trap_handler);
+        let old_handler = {
+            let trap_scratch = TRAP_SCRATCH.as_mut();
+            core::mem::replace(&mut trap_scratch.handler, captured_trap_handler)
+        };
 
         let mut to_ctx = TaskContext::new();
         to_ctx.set_program_counter(captured_trap_return as _);
@@ -181,10 +241,10 @@ impl TrapReturn for TrapContext {
         to_ctx.set_interrupt_enabled(false);
 
         unsafe {
-            TaskContext::switch(CAPTURER_CONTEXT.as_mut(), &mut to_ctx);
+            TaskContext::switch(&mut TRAP_SCRATCH.as_mut().capturer_context, &mut to_ctx);
         }
 
-        TRAP_HANDLER.set(old_handler);
+        TRAP_SCRATCH.as_mut().handler = old_handler;
         irq_states.restore();
     }
 }

+ 0 - 104
crates/eonix_hal/src/arch/riscv64/trap/trap.S

@@ -1,104 +0,0 @@
-// some old code
-
-.altmacro
-.macro SAVE_GP n
-    sd x\n, \n*8(sp)
-.endm
-.macro LOAD_GP n
-    ld x\n, \n*8(sp)
-.endm
-
-.section .text
-    .globl trap_from_user
-    .globl return_to_user
-    .align 2
-
-_raw_trap_entry:
-    # swap sp and sscratch(previously stored user TrapContext's address in return_to_user)
-    csrrw sp, sscratch, sp
-
-    sd x1, 1*8(sp)
-    .set n, 3
-    .rept 29
-        SAVE_GP %n
-        .set n, n+1
-    .endr
-
-    csrr t0, sstatus
-    csrr t1, sepc
-    csrr t2, scause
-    csrr t3, stval
-    sd t0, 32*8(sp)     # save sstatus into the TrapContext
-    sd t1, 33*8(sp)     # save sepc into the TrapContext
-    sd t2, 34*8(sp)     # save scause into the TrapContext
-    sd t3, 35*8(sp)     # save stval into the TrapContext
-
-    csrr t0, sscratch
-    sd t0, 2*8(sp)      # save user stack pointer into the TrapContext
-
-    ld ra, 37*8(sp)
-    ld s0, 38*8(sp)
-    ld s1, 39*8(sp)
-    ld s2, 40*8(sp)
-    ld s3, 41*8(sp)
-    ld s4, 42*8(sp)
-    ld s5, 43*8(sp)
-    ld s6, 44*8(sp)
-    ld s7, 45*8(sp)
-    ld s8, 46*8(sp)
-    ld s9, 47*8(sp)
-    ld s10, 48*8(sp)
-    ld s11, 49*8(sp)
-
-    ld fp, 50*8(sp)
-    ld tp, 51*8(sp)
-
-    ld sp, 36*8(sp)
-    ret
-
-# a0: pointer to TrapContext in user space (constant)
-_raw_trap_return:
-    # sscratch store the TrapContext's address
-    csrw sscratch, a0
-
-    # offset in TrapContext's order
-    sd sp, 36*8(a0)
-    sd ra, 37*8(a0)
-    sd s0, 38*8(a0)
-    sd s1, 39*8(a0) 
-    sd s2, 40*8(a0)
-    sd s3, 41*8(a0)
-    sd s4, 42*8(a0)
-    sd s5, 43*8(a0)
-    sd s6, 44*8(a0)
-    sd s7, 45*8(a0)
-    sd s8, 46*8(a0)
-    sd s9, 47*8(a0)
-    sd s10, 48*8(a0)
-    sd s11, 49*8(a0)
-    sd fp, 50*8(a0)
-    sd tp, 51*8(a0)
-
-    mv sp, a0
-    # now sp points to TrapContext in kernel space
-
-    # restore sstatus and sepc
-    ld t0, 32*8(sp)
-    ld t1, 33*8(sp)
-    ld t2, 34*8(sp)
-    ld t3, 35*8(sp)
-    csrw sstatus, t0
-    csrw sepc, t1
-    csrw scause, t2
-    csrw stval, t3
-
-    # save x* expect x0 and sp
-    ld x1, 1*8(sp)
-    .set n, 3
-    .rept 29
-        LOAD_GP %n
-        .set n, n+1
-    .endr
-    ld sp, 2*8(sp)
-
-    sret

+ 74 - 33
crates/eonix_hal/src/arch/riscv64/trap/trap_context.rs

@@ -3,12 +3,12 @@ use eonix_hal_traits::{
     fault::{Fault, PageFaultErrorCode},
     trap::{RawTrapContext, TrapType},
 };
+use eonix_mm::address::VAddr;
 use riscv::{
     interrupt::{Exception, Interrupt, Trap},
     register::{
         scause::{self, Scause},
-        sie,
-        sstatus::{self, Sstatus, SPP},
+        sstatus::{self, Sstatus, FS, SPP},
         stval,
     },
     ExceptionNumber, InterruptNumber,
@@ -22,36 +22,92 @@ pub struct FpuRegisters {
     pub fcsr: u32,
 }
 
+#[repr(C)]
+#[derive(Default, Clone, Copy)]
+pub struct Registers {
+    ra: u64,
+    sp: u64,
+    gp: u64,
+    tp: u64,
+    t1: u64,
+    t2: u64,
+    t0: u64,
+    a0: u64,
+    a1: u64,
+    a2: u64,
+    a3: u64,
+    a4: u64,
+    a5: u64,
+    a6: u64,
+    a7: u64,
+    t3: u64,
+    t4: u64,
+    t5: u64,
+    t6: u64,
+}
+
 /// Saved CPU context when a trap (interrupt or exception) occurs on RISC-V 64.
 #[repr(C)]
-#[derive(Debug, Clone, Copy)]
+#[derive(Clone, Copy)]
 pub struct TrapContext {
-    pub x: [usize; 32],
+    regs: Registers,
+
+    sstatus: Sstatus,
+    sepc: usize,
+    scause: Scause,
+}
 
-    // CSRs
-    pub sstatus: Sstatus, // sstatus CSR value. Contains privilege mode, interrupt enable, FPU state.
-    pub sepc: usize,      // sepc (Supervisor Exception Program Counter). Program counter at trap.
-    pub scause: Scause,   // S-mode Trap Cause Register
+impl Registers {
+    pub const OFFSET_RA: usize = 0 * 8;
+    pub const OFFSET_SP: usize = 1 * 8;
+    pub const OFFSET_GP: usize = 2 * 8;
+    pub const OFFSET_TP: usize = 3 * 8;
+    pub const OFFSET_T1: usize = 4 * 8;
+    pub const OFFSET_T2: usize = 5 * 8;
+    pub const OFFSET_T0: usize = 6 * 8;
+    pub const OFFSET_A0: usize = 7 * 8;
+    pub const OFFSET_A1: usize = 8 * 8;
+    pub const OFFSET_A2: usize = 9 * 8;
+    pub const OFFSET_A3: usize = 10 * 8;
+    pub const OFFSET_A4: usize = 11 * 8;
+    pub const OFFSET_A5: usize = 12 * 8;
+    pub const OFFSET_A6: usize = 13 * 8;
+    pub const OFFSET_A7: usize = 14 * 8;
+    pub const OFFSET_T3: usize = 15 * 8;
+    pub const OFFSET_T4: usize = 16 * 8;
+    pub const OFFSET_T5: usize = 17 * 8;
+    pub const OFFSET_T6: usize = 18 * 8;
 }
 
 impl TrapContext {
+    pub const OFFSET_SSTATUS: usize = 19 * 8;
+    pub const OFFSET_SEPC: usize = 20 * 8;
+    pub const OFFSET_SCAUSE: usize = 21 * 8;
+
     fn syscall_no(&self) -> usize {
-        self.x[17]
+        self.regs.a7 as usize
     }
 
     fn syscall_args(&self) -> [usize; 6] {
         [
-            self.x[10], self.x[11], self.x[12], self.x[13], self.x[14], self.x[15],
+            self.regs.a0 as usize,
+            self.regs.a1 as usize,
+            self.regs.a2 as usize,
+            self.regs.a3 as usize,
+            self.regs.a4 as usize,
+            self.regs.a5 as usize,
         ]
     }
 }
 
 impl RawTrapContext for TrapContext {
-    /// TODO: temporarily all zero, may change in future
     fn new() -> Self {
+        let mut sstatus = Sstatus::from_bits(0);
+        sstatus.set_fs(FS::Initial);
+
         Self {
-            x: [0; 32],
-            sstatus: sstatus::read(),
+            regs: Registers::default(),
+            sstatus,
             sepc: 0,
             scause: Scause::from_bits(0),
         }
@@ -100,7 +156,7 @@ impl RawTrapContext for TrapContext {
     }
 
     fn get_stack_pointer(&self) -> usize {
-        self.x[2]
+        self.regs.sp as usize
     }
 
     fn set_program_counter(&mut self, pc: usize) {
@@ -108,30 +164,15 @@ impl RawTrapContext for TrapContext {
     }
 
     fn set_stack_pointer(&mut self, sp: usize) {
-        self.x[2] = sp;
+        self.regs.sp = sp as u64;
     }
 
     fn is_interrupt_enabled(&self) -> bool {
-        self.sstatus.sie()
+        self.sstatus.spie()
     }
 
-    /// TODO: may need more precise control
     fn set_interrupt_enabled(&mut self, enabled: bool) {
-        if enabled {
-            self.sstatus.set_sie(enabled);
-            unsafe {
-                sie::set_sext();
-                sie::set_ssoft();
-                sie::set_stimer();
-            };
-        } else {
-            self.sstatus.set_sie(enabled);
-            unsafe {
-                sie::clear_sext();
-                sie::clear_ssoft();
-                sie::clear_stimer();
-            };
-        }
+        self.sstatus.set_spie(enabled);
     }
 
     fn is_user_mode(&self) -> bool {
@@ -146,7 +187,7 @@ impl RawTrapContext for TrapContext {
     }
 
     fn set_user_return_value(&mut self, retval: usize) {
-        self.sepc = retval;
+        self.regs.a0 = retval as u64;
     }
 }
 

+ 30 - 36
crates/eonix_hal/src/arch/x86_64/context.rs

@@ -35,39 +35,6 @@ impl TaskContext {
         }
     }
 
-    #[unsafe(naked)]
-    unsafe extern "C" fn _switch(from: &mut Self, to: &mut Self) {
-        naked_asm!(
-            "pop %rax",
-            "pushf",
-            "pop %rcx",
-            "mov %r12, (%rdi)",
-            "mov %r13, 8(%rdi)",
-            "mov %r14, 16(%rdi)",
-            "mov %r15, 24(%rdi)",
-            "mov %rbx, 32(%rdi)",
-            "mov %rbp, 40(%rdi)",
-            "mov %rsp, 48(%rdi)",
-            "mov %rax, 56(%rdi)",
-            "mov %rcx, 64(%rdi)",
-            "",
-            "mov (%rsi), %r12",
-            "mov 8(%rsi), %r13",
-            "mov 16(%rsi), %r14",
-            "mov 24(%rsi), %r15",
-            "mov 32(%rsi), %rbx",
-            "mov 40(%rsi), %rbp",
-            "mov 48(%rsi), %rdi", // store next stack pointer
-            "mov 56(%rsi), %rax",
-            "mov 64(%rsi), %rcx",
-            "push %rcx",
-            "popf",
-            "xchg %rdi, %rsp", // switch to new stack
-            "jmp *%rax",
-            options(att_syntax),
-        );
-    }
-
     #[unsafe(naked)]
     unsafe extern "C" fn do_call() -> ! {
         naked_asm!(
@@ -111,9 +78,36 @@ impl RawTaskContext for TaskContext {
         self.rbp = 0; // NULL previous stack frame
     }
 
+    #[unsafe(naked)]
     unsafe extern "C" fn switch(from: &mut Self, to: &mut Self) {
-        unsafe {
-            Self::_switch(from, to);
-        }
+        naked_asm!(
+            "pop %rax",
+            "pushf",
+            "pop %rcx",
+            "mov %r12, (%rdi)",
+            "mov %r13, 8(%rdi)",
+            "mov %r14, 16(%rdi)",
+            "mov %r15, 24(%rdi)",
+            "mov %rbx, 32(%rdi)",
+            "mov %rbp, 40(%rdi)",
+            "mov %rsp, 48(%rdi)",
+            "mov %rax, 56(%rdi)",
+            "mov %rcx, 64(%rdi)",
+            "",
+            "mov (%rsi), %r12",
+            "mov 8(%rsi), %r13",
+            "mov 16(%rsi), %r14",
+            "mov 24(%rsi), %r15",
+            "mov 32(%rsi), %rbx",
+            "mov 40(%rsi), %rbp",
+            "mov 48(%rsi), %rdi", // store next stack pointer
+            "mov 56(%rsi), %rax",
+            "mov 64(%rsi), %rcx",
+            "push %rcx",
+            "popf",
+            "xchg %rdi, %rsp", // switch to new stack
+            "jmp *%rax",
+            options(att_syntax),
+        );
     }
 }

+ 4 - 1
crates/eonix_hal/src/arch/x86_64/cpu.rs

@@ -1,6 +1,8 @@
 use super::gdt::{GDTEntry, GDT};
 use super::interrupt::InterruptControl;
+use super::trap::TrapContext;
 use core::marker::PhantomPinned;
+use core::mem::size_of;
 use core::pin::Pin;
 use eonix_preempt::PreemptGuard;
 use eonix_sync_base::LazyLock;
@@ -114,7 +116,8 @@ impl CPU {
 
     pub unsafe fn load_interrupt_stack(self: Pin<&mut Self>, rsp: u64) {
         unsafe {
-            self.map_unchecked_mut(|me| &mut me.tss).set_rsp0(rsp);
+            self.map_unchecked_mut(|me| &mut me.tss)
+                .set_rsp0(rsp + size_of::<TrapContext>() as u64);
         }
     }
 

+ 1 - 4
src/kernel/pcie/driver.rs

@@ -2,10 +2,7 @@ use super::{
     device::{PCIDevice, PCIE_DEVICES},
     error::PciError,
 };
-use crate::{
-    kernel::constants::{EEXIST, ENOENT},
-    KResult,
-};
+use crate::{kernel::constants::EEXIST, KResult};
 use alloc::{
     collections::btree_map::{self, BTreeMap},
     sync::Arc,

+ 1 - 1
src/kernel/task/thread.rs

@@ -467,7 +467,7 @@ impl<F: Future> Contexted for ThreadRunnable<F> {
             // SAFETY:
             CPU::local()
                 .as_mut()
-                .load_interrupt_stack(trap_ctx_ptr.add(1).addr() as u64);
+                .load_interrupt_stack(trap_ctx_ptr as u64);
         }
     }