Bladeren bron

feat(hal): impl trap handler for riscv64

Heinz 7 maanden geleden
bovenliggende
commit
89366051d1

+ 253 - 0
crates/eonix_hal/src/arch/riscv64/trap/mod.rs

@@ -0,0 +1,253 @@
+mod trap_context;
+
+use eonix_hal_traits::{context::RawTaskContext, trap::TrapReturn};
+pub use trap_context::*;
+
+use riscv::{
+    asm::sfence_vma_all,
+    register::{
+        sstatus::{self, Sstatus},
+        stvec::{self, Stvec}
+    }
+};
+use sbi::SbiError;
+use core::arch::{global_asm, naked_asm};
+
+use super::context::TaskContext;
+
+use super::config::platform::virt::*;
+
+//global_asm!(include_str!("trap.S"));
+
+use riscv::register::{scause, sepc, stval};
+
+//#[eonix_percpu::define_percpu]
+//static TRAP_HANDLER: unsafe extern "C" fn() = default_trap_handler;
+
+#[eonix_percpu::define_percpu]
+static TRAP_HANDLER: unsafe extern "C" fn() = default_trap_handler;
+
+#[eonix_percpu::define_percpu]
+static CAPTURER_CONTEXT: TaskContext = TaskContext::new();
+
+/// This value will never be used.
+static mut DIRTY_TRAP_CONTEXT: TaskContext = TaskContext::new();
+
+// TODO: is need to save kernel's callee saved registers?
+global_asm!(
+    r"
+    .altmacro
+    .macro SAVE_GP n
+        sd x\n, \n*8(sp)
+    .endm
+    .macro LOAD_GP n
+        ld x\n, \n*8(sp)
+    .endm
+
+    .section .text
+        .globl _raw_trap_entry
+        .globl return_to_user
+        .align 2
+
+    _raw_trap_entry:
+        # swap sp and sscratch(previously stored user TrapContext's address in return_to_user)
+        csrrw sp, sscratch, sp
+
+        sd x1, 1*8(sp)
+        .set n, 3
+        .rept 29
+            SAVE_GP %n
+            .set n, n+1
+        .endr
+
+        csrr t0, sstatus
+        csrr t1, sepc
+        csrr t2, scause
+        csrr t3, stval
+        sd t0, 32*8(sp)     # save sstatus into the TrapContext
+        sd t1, 33*8(sp)     # save sepc into the TrapContext
+        sd t2, 34*8(sp)     # save scause into the TrapContext
+        sd t3, 35*8(sp)     # save stval into the TrapContext
+
+        csrr t0, sscratch
+        sd t0, 2*8(sp)      # save user stack pointer into the TrapContext
+
+        addi t0, tp, {handler}
+        ld t1, 0(t0)
+        jr t1
+
+    _raw_trap_return:
+        # sscratch store the TrapContext's address
+        csrw sscratch, a0
+
+        mv sp, a0
+        # now sp points to TrapContext in kernel space
+
+        # restore sstatus and sepc
+        ld t0, 32*8(sp)
+        ld t1, 33*8(sp)
+        ld t2, 34*8(sp)
+        ld t3, 35*8(sp)
+        csrw sstatus, t0
+        csrw sepc, t1
+        csrw scause, t2
+        csrw stval, t3
+
+        # save x* expect x0 and sp
+        ld x1, 1*8(sp)
+        .set n, 3
+        .rept 29
+            LOAD_GP %n
+            .set n, n+1
+        .endr
+        ld sp, 2*8(sp)
+
+        sret
+    ",
+    handler = sym _percpu_inner_TRAP_HANDLER,
+
+);
+
+unsafe extern "C" {
+    fn _default_trap_handler(trap_context: &mut TrapContext);
+    fn _raw_trap_entry();
+    fn _raw_trap_return();
+}
+
+/// TODO:
+/// default_trap_handler
+/// captured_trap_handler
+/// _raw_trap_entry应该是做好了
+/// _raw_trap_return应该是做好了
+#[unsafe(naked)]
+unsafe extern "C" fn default_trap_handler() {
+    naked_asm!(
+        "mv t0, sp",
+        "andi sp, sp, -16",
+        "mv a0, t0",
+        "call {handle_trap}",
+
+        "mv sp, t0",
+
+        "j {trap_return}",
+
+        handle_trap = sym _default_trap_handler,
+        trap_return = sym _raw_trap_return,
+    );
+}
+
+#[unsafe(naked)]
+unsafe extern "C" fn captured_trap_handler() {
+    naked_asm!(
+        "addi sp, sp, -16",
+        "sd ra, 8(sp)",
+
+        "la a0, {from_context}",
+
+        "mv t0, tp",
+        "la t1, {to_context}",
+        "add a1, t0, t1",
+
+        "call {switch}",
+
+        "ld ra, 8(sp)",
+        "addi sp, sp, 16",
+        "ret",
+
+        from_context = sym DIRTY_TRAP_CONTEXT,
+        to_context = sym _percpu_inner_CAPTURER_CONTEXT,
+        switch = sym TaskContext::switch,
+    );
+}
+
+#[unsafe(naked)]
+unsafe extern "C" fn captured_trap_return(trap_context: usize) -> ! {
+    naked_asm!(
+        "la t0, {trap_return}",
+        "jalr zero, t0, 0",
+        trap_return = sym _raw_trap_return,
+    );
+}
+
+impl TrapReturn for TrapContext {
+    unsafe fn trap_return(&mut self) {
+        let irq_states = disable_irqs_save();
+        let old_handler = TRAP_HANDLER.swap(captured_trap_handler);
+
+        let mut to_ctx = TaskContext::new();
+        to_ctx.set_program_counter(captured_trap_return as _);
+        to_ctx.set_stack_pointer(&raw mut *self as usize);
+        to_ctx.set_interrupt_enabled(false);
+
+        unsafe {
+            TaskContext::switch(CAPTURER_CONTEXT.as_mut(), &mut to_ctx);
+        }
+
+        TRAP_HANDLER.set(old_handler);
+        irq_states.restore();
+    }
+}
+
+fn setup_trap_handler(trap_entry_addr: usize) {
+    unsafe {
+        stvec::write(Stvec::from_bits(trap_entry_addr));
+    }
+    sfence_vma_all();
+}
+
+pub fn setup_trap() {
+    setup_trap_handler(_raw_trap_entry as usize);
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct IrqState(usize);
+
+impl IrqState {
+    #[inline]
+    pub fn save() -> Self {
+        let sstatus_val = sstatus::read().bits();
+
+        unsafe {
+            sstatus::clear_sie();
+        }
+
+        IrqState(sstatus_val)
+    }
+
+    #[inline]
+    pub fn restore(self) {
+        let Self(state) = self;
+        unsafe {
+            sstatus::write(Sstatus::from_bits(state));
+        }
+    }
+
+    #[inline]
+    pub fn was_enabled(&self) -> bool {
+        (self.0 & (1 << 1)) != 0
+    }
+}
+
+#[inline]
+pub fn disable_irqs() {
+    unsafe {
+        sstatus::clear_sie();
+    }
+}
+
+#[inline]
+pub fn enable_irqs() {
+    unsafe {
+        sstatus::set_sie();
+    }
+}
+
+#[inline]
+pub fn disable_irqs_save() -> IrqState {
+    unsafe {
+        let original_sstatus_bits = sstatus::read().bits();
+        sstatus::clear_sie();
+
+        IrqState(original_sstatus_bits)
+    }
+}

+ 104 - 0
crates/eonix_hal/src/arch/riscv64/trap/trap.S

@@ -0,0 +1,104 @@
+// some old code
+
+.altmacro
+.macro SAVE_GP n
+    sd x\n, \n*8(sp)
+.endm
+.macro LOAD_GP n
+    ld x\n, \n*8(sp)
+.endm
+
+.section .text
+    .globl trap_from_user
+    .globl return_to_user
+    .align 2
+
+_raw_trap_entry:
+    # swap sp and sscratch(previously stored user TrapContext's address in return_to_user)
+    csrrw sp, sscratch, sp
+
+    sd x1, 1*8(sp)
+    .set n, 3
+    .rept 29
+        SAVE_GP %n
+        .set n, n+1
+    .endr
+
+    csrr t0, sstatus
+    csrr t1, sepc
+    csrr t2, scause
+    csrr t3, stval
+    sd t0, 32*8(sp)     # save sstatus into the TrapContext
+    sd t1, 33*8(sp)     # save sepc into the TrapContext
+    sd t2, 34*8(sp)     # save scause into the TrapContext
+    sd t3, 35*8(sp)     # save stval into the TrapContext
+
+    csrr t0, sscratch
+    sd t0, 2*8(sp)      # save user stack pointer into the TrapContext
+
+    ld ra, 37*8(sp)
+    ld s0, 38*8(sp)
+    ld s1, 39*8(sp)
+    ld s2, 40*8(sp)
+    ld s3, 41*8(sp)
+    ld s4, 42*8(sp)
+    ld s5, 43*8(sp)
+    ld s6, 44*8(sp)
+    ld s7, 45*8(sp)
+    ld s8, 46*8(sp)
+    ld s9, 47*8(sp)
+    ld s10, 48*8(sp)
+    ld s11, 49*8(sp)
+
+    ld fp, 50*8(sp)
+    ld tp, 51*8(sp)
+
+    ld sp, 36*8(sp)
+    ret
+
+# a0: pointer to TrapContext in user space (constant)
+_raw_trap_return:
+    # sscratch store the TrapContext's address
+    csrw sscratch, a0
+
+    # offset in TrapContext's order
+    sd sp, 36*8(a0)
+    sd ra, 37*8(a0)
+    sd s0, 38*8(a0)
+    sd s1, 39*8(a0) 
+    sd s2, 40*8(a0)
+    sd s3, 41*8(a0)
+    sd s4, 42*8(a0)
+    sd s5, 43*8(a0)
+    sd s6, 44*8(a0)
+    sd s7, 45*8(a0)
+    sd s8, 46*8(a0)
+    sd s9, 47*8(a0)
+    sd s10, 48*8(a0)
+    sd s11, 49*8(a0)
+    sd fp, 50*8(a0)
+    sd tp, 51*8(a0)
+
+    mv sp, a0
+    # now sp points to TrapContext in kernel space
+
+    # restore sstatus and sepc
+    ld t0, 32*8(sp)
+    ld t1, 33*8(sp)
+    ld t2, 34*8(sp)
+    ld t3, 35*8(sp)
+    csrw sstatus, t0
+    csrw sepc, t1
+    csrw scause, t2
+    csrw stval, t3
+
+    # save x* expect x0 and sp
+    ld x1, 1*8(sp)
+    .set n, 3
+    .rept 29
+        LOAD_GP %n
+        .set n, n+1
+    .endr
+    ld sp, 2*8(sp)
+
+    sret

+ 194 - 0
crates/eonix_hal/src/arch/riscv64/trap/trap_context.rs

@@ -0,0 +1,194 @@
+use core::arch::asm;
+use eonix_hal_traits::{fault::{Fault, PageFaultErrorCode}, trap::{RawTrapContext, TrapType}};
+use riscv::{
+    interrupt::{Exception, Interrupt, Trap}, register::{
+        scause, sie, sstatus::{self, Sstatus, SPP}, stval
+    }, ExceptionNumber, InterruptNumber
+};
+
+/// Floating-point registers context.
+#[repr(C)]
+#[derive(Debug, Clone, Copy, Default)]
+pub struct FpuRegisters {
+    pub f: [u64; 32],
+    pub fcsr: u32,
+}
+
+/// Saved CPU context when a trap (interrupt or exception) occurs on RISC-V 64.
+#[repr(C)]
+#[derive(Debug, Clone, Copy)]
+pub struct TrapContext {
+    pub x: [usize; 32],
+
+    // CSRs
+    pub sstatus: Sstatus, // sstatus CSR value. Contains privilege mode, interrupt enable, FPU state.
+    pub sepc: usize,    // sepc (Supervisor Exception Program Counter). Program counter at trap.
+    pub scause: usize,  // S-mode Trap Cause Register
+    pub stval: usize,
+    
+    //pub kernel_sp: usize,
+    //pub kernel_ra: usize,
+    //pub kernel_s: [usize; 12],
+    //pub kernel_fp: usize,
+    //pub kernel_tp: usize,
+}
+
+impl TrapContext {
+    fn syscall_no(&self) -> usize {
+        self.x[17]
+    }
+
+    fn syscall_args(&self) -> [usize; 6] {
+        [
+            self.x[10],
+            self.x[11],
+            self.x[12],
+            self.x[13],
+            self.x[14],
+            self.x[15],
+        ]
+    }
+}
+
+impl RawTrapContext for TrapContext {
+    /// TODO: temporarily all zero, may change in future
+    fn new() -> Self {
+        Self {
+            x: [0; 32],
+            sstatus: sstatus::read(),
+            sepc: 0,
+            scause: 0,
+            stval: 0,
+            //kernel_sp: 0,
+            //kernel_ra: 0,
+            //kernel_s: [0; 12],
+            //kernel_fp: 0,
+            //kernel_tp: 0
+        }
+    }
+
+    fn trap_type(&self) -> TrapType {
+        let scause = scause::Scause::from_bits(self.scause);
+        let cause = scause.cause();
+        match cause {
+            Trap::Interrupt(i) => {
+                match Interrupt::from_number(i).unwrap() {
+                    Interrupt::SupervisorTimer => TrapType::Timer,
+                    // TODO: need to read plic
+                    Interrupt::SupervisorExternal => TrapType::Irq(0),
+                    // soft interrupt
+                    _ => TrapType::Fault(Fault::Unknown(i)),
+                }
+            }
+            Trap::Exception(e) => {
+                match Exception::from_number(e).unwrap() {
+                    Exception::InstructionMisaligned |
+                    Exception::LoadMisaligned |
+                    Exception::InstructionFault |
+                    Exception::LoadFault |
+                    Exception::StoreFault |
+                    Exception::StoreMisaligned => {
+                        TrapType::Fault(Fault::BadAccess)
+                    },
+                    Exception::IllegalInstruction => {
+                        TrapType::Fault(Fault::InvalidOp)
+                    }
+                    Exception::UserEnvCall => {
+                        TrapType::Syscall { 
+                            no: self.syscall_no(),
+                            args: self.syscall_args()
+                        }
+                    },
+                    Exception::InstructionPageFault |
+                    Exception::LoadPageFault |
+                    Exception::StorePageFault => {
+                        let e = Exception::from_number(e).unwrap();
+                        TrapType::Fault(Fault::PageFault(self.get_page_fault_error_code(e)))
+                    },
+                    // breakpoint and supervisor env call
+                    _ => TrapType::Fault(Fault::Unknown(e)),
+                }
+            },
+        }
+    }
+
+    fn get_program_counter(&self) -> usize {
+        self.sepc
+    }
+
+    fn get_stack_pointer(&self) -> usize {
+        self.x[2]
+    }
+
+    fn set_program_counter(&mut self, pc: usize) {
+        self.sepc = pc;
+    }
+
+    fn set_stack_pointer(&mut self, sp: usize) {
+        self.x[2] = sp;
+    }
+
+    fn is_interrupt_enabled(&self) -> bool {
+        self.sstatus.sie()
+    }
+
+    /// TODO: may need more precise control
+    fn set_interrupt_enabled(&mut self, enabled: bool) {
+        if enabled {
+            self.sstatus.set_sie(enabled);
+            unsafe { 
+                sie::set_sext();
+                sie::set_ssoft();
+                sie::set_stimer();
+            };
+        } else {
+            self.sstatus.set_sie(enabled);
+            unsafe { 
+                sie::clear_sext();
+                sie::clear_ssoft();
+                sie::clear_stimer();
+            };
+        }
+    }
+
+    fn is_user_mode(&self) -> bool {
+        self.sstatus.spp() == SPP::User
+    }
+
+    fn set_user_mode(&mut self, user: bool) {
+        match user {
+            true => self.sstatus.set_spp(SPP::User),
+            false => self.sstatus.set_spp(SPP::Supervisor),
+        }
+    }
+
+    fn set_user_return_value(&mut self, retval: usize) {
+        self.sepc = retval;
+    }
+}
+
+impl TrapContext {
+    /// TODO: get PageFaultErrorCode also need check pagetable
+    fn get_page_fault_error_code(&self, exception_type: Exception) -> PageFaultErrorCode {
+        let scause_val = self.scause;
+        let mut error_code = PageFaultErrorCode::empty();
+
+        match exception_type {
+            Exception::InstructionPageFault => {
+                error_code |= PageFaultErrorCode::InstructionFetch;
+                error_code |= PageFaultErrorCode::Read;
+            }
+            Exception::LoadPageFault => {
+                error_code |= PageFaultErrorCode::Read;
+            }
+            Exception::StorePageFault => {
+                error_code |= PageFaultErrorCode::Write;
+            }
+            _ => {
+                unreachable!();
+            }
+        }
+        // TODO: here need check pagetable to confirm NonPresent and UserAccess
+        error_code
+    }
+}