Prechádzať zdrojové kódy

hal: remove dependency of TaskContext for TrapContext

TaskContext is useless except for in the `trap_return()` implementation.
But actually, this is unnecessary and makes the code ugly and weird.

Rewrite the `trap_return()` with native code to remove dependency.

Signed-off-by: greatbridf <greatbridf@icloud.com>
greatbridf 1 týždeň pred
rodič
commit
b877fff904

+ 11 - 3
crates/eonix_hal/eonix_hal_traits/src/trap.rs

@@ -2,7 +2,6 @@ use core::marker::PhantomData;
 
 use eonix_mm::address::VAddr;
 
-use crate::context::RawTaskContext;
 use crate::fault::Fault;
 
 /// A raw trap context.
@@ -40,8 +39,6 @@ pub trait RawTrapContext: Copy {
 
 #[doc(notable_trait)]
 pub trait TrapReturn {
-    type TaskContext: RawTaskContext;
-
     /// Return to the context before the trap occurred.
     ///
     /// # Safety
@@ -50,6 +47,17 @@ pub trait TrapReturn {
     /// points to a valid stack frame and the program counter points to some
     /// valid instruction.
     unsafe fn trap_return(&mut self);
+
+    /// Switch to the context before the trap occurred.
+    /// This function will NOT capture traps and will never return.
+    ///
+    /// # Safety
+    /// This function is unsafe because the caller MUST ensure that the
+    /// context before the trap is valid, that is, that the stack pointer
+    /// points to a valid stack frame and the program counter points to some
+    /// valid instruction. Besides, the caller MUST ensure that all variables
+    /// in the current context are released.
+    unsafe fn trap_return_noreturn(&mut self) -> !;
 }
 
 pub trait IrqState {

+ 189 - 95
crates/eonix_hal/src/arch/riscv64/trap/captured.rs

@@ -1,61 +1,110 @@
-use crate::{arch::trap::Registers, context::TaskContext, trap::TrapContext};
-use core::{arch::naked_asm, mem::MaybeUninit};
-use eonix_hal_traits::context::RawTaskContext;
+use core::arch::naked_asm;
+use core::mem::MaybeUninit;
 
-static mut DIRTY_TASK_CONTEXT: MaybeUninit<TaskContext> = MaybeUninit::uninit();
+use crate::arch::trap::Registers;
+use crate::trap::TrapContext;
 
 // If captured trap context is present, we use it directly.
-// We need to restore the kernel tp from that TrapContext but sp is
-// fine since we will use TaskContext::switch.
+// We need to restore the callee saved registers from that TrapContext.
 #[unsafe(naked)]
 pub(super) unsafe extern "C" fn _captured_trap_entry() -> ! {
     naked_asm!(
-        "csrrw t0, sscratch, t0",
-        "sd    tp, {tp}(t0)",
-        "ld    tp, {ra}(t0)", // Load kernel tp from trap_ctx.ra
-        "sd    ra, {ra}(t0)",
-        "ld    ra, {sp}(t0)", // Load capturer task context from trap_ctx.sp
-        "sd    sp, {sp}(t0)",
-        "sd    gp, {gp}(t0)",
-        "sd    a0, {a0}(t0)",
-        "sd    a1, {a1}(t0)",
-        "sd    a2, {a2}(t0)",
-        "sd    a3, {a3}(t0)",
-        "sd    a4, {a4}(t0)",
-        "sd    t1, {t1}(t0)",
-        "sd    a5, {a5}(t0)",
-        "sd    a6, {a6}(t0)",
-        "sd    a7, {a7}(t0)",
-        "sd    t3, {t3}(t0)",
-        "sd    t4, {t4}(t0)",
-        "sd    t5, {t5}(t0)",
-        "sd    t2, {t2}(t0)",
-        "sd    t6, {t6}(t0)",
-        "sd    s0, {s0}(t0)",
-        "sd    s1, {s1}(t0)",
-        "sd    s2, {s2}(t0)",
-        "sd    s3, {s3}(t0)",
-        "sd    s4, {s4}(t0)",
-        "sd    s5, {s5}(t0)",
-        "sd    s6, {s6}(t0)",
-        "sd    s7, {s7}(t0)",
-        "sd    s8, {s8}(t0)",
-        "sd    s9, {s9}(t0)",
-        "sd    s10, {s10}(t0)",
-        "sd    s11, {s11}(t0)",
-        "csrr  t2, sstatus",
-        "csrr  t3, sepc",
-        "csrr  t4, scause",
-        "csrr  t5, stval",
-        "csrrw t6, sscratch, t0",
-        "sd    t6, {t0}(t0)",
-        "sd    t2, {sstatus}(t0)",
-        "sd    t3, {sepc}(t0)",
-        "sd    t4, {scause}(t0)",
-        "sd    t5, {stval}(t0)",
-        "la    a0, {dirty_task_context}",
-        "mv    a1, ra",
-        "j     {task_context_switch}",
+        "csrrw gp, sscratch, gp",
+
+        "sd    a0, {a0}(gp)",
+        "mv    a0, s0",
+        "ld    s0, {s0}(gp)",
+        "sd    a0, {s0}(gp)",
+
+        "sd    a1, {a1}(gp)",
+        "mv    a1, s1",
+        "ld    s1, {s1}(gp)",
+        "sd    a1, {s1}(gp)",
+
+        "sd    a2, {a2}(gp)",
+        "mv    a2, s2",
+        "ld    s2, {s2}(gp)",
+        "sd    a2, {s2}(gp)",
+
+        "sd    a3, {a3}(gp)",
+        "mv    a3, s3",
+        "ld    s3, {s3}(gp)",
+        "sd    a3, {s3}(gp)",
+
+        "sd    a4, {a4}(gp)",
+        "mv    a4, s4",
+        "ld    s4, {s4}(gp)",
+        "sd    a4, {s4}(gp)",
+
+        "sd    a5, {a5}(gp)",
+        "mv    a5, s5",
+        "ld    s5, {s5}(gp)",
+        "sd    a5, {s5}(gp)",
+
+        "sd    a6, {a6}(gp)",
+        "mv    a6, s6",
+        "ld    s6, {s6}(gp)",
+        "sd    a6, {s6}(gp)",
+
+        "sd    a7, {a7}(gp)",
+        "mv    a7, s7",
+        "ld    s7, {s7}(gp)",
+        "sd    a7, {s7}(gp)",
+
+        "sd    t0, {t0}(gp)",
+        "mv    t0, s8",
+        "ld    s8, {s8}(gp)",
+        "sd    t0, {s8}(gp)",
+
+        "sd    t1, {t1}(gp)",
+        "mv    t1, s9",
+        "ld    s9, {s9}(gp)",
+        "sd    t1, {s9}(gp)",
+
+        "sd    t2, {t2}(gp)",
+        "mv    t2, s10",
+        "ld    s10, {s10}(gp)",
+        "sd    t2, {s10}(gp)",
+
+        "sd    t3, {t3}(gp)",
+        "mv    t3, s11",
+        "ld    s11, {s11}(gp)",
+        "sd    t3, {s11}(gp)",
+
+        "sd    t4, {t4}(gp)",
+        "mv    t4, ra",
+        "ld    ra, {ra}(gp)",
+        "sd    t4, {ra}(gp)",
+
+        "sd    t5, {t5}(gp)",
+        "mv    t5, tp",
+        "ld    tp, {tp}(gp)",
+        "sd    t5, {tp}(gp)",
+
+        "sd    t6, {t6}(gp)",
+        "mv    t6, sp",
+        "ld    sp, {sp}(gp)",
+        "sd    t6, {sp}(gp)",
+
+        "ld    t5, {sstatus}(gp)",
+
+        "csrrw t0, sscratch, gp",
+        "sd    t0, {gp}(gp)",
+
+        "csrr  t1, sstatus",
+        "sd    t1, {sstatus}(gp)",
+
+        "csrr  t2, sepc",
+        "sd    t2, {sepc}(gp)",
+
+        "csrr  t3, scause",
+        "sd    t3, {scause}(gp)",
+
+        "csrr  t4, stval",
+        "sd    t4, {stval}(gp)",
+
+        "csrw  sstatus, t5",
+        "ret",
         ra = const Registers::OFFSET_RA,
         sp = const Registers::OFFSET_SP,
         gp = const Registers::OFFSET_GP,
@@ -91,54 +140,99 @@ pub(super) unsafe extern "C" fn _captured_trap_entry() -> ! {
         sepc = const TrapContext::OFFSET_SEPC,
         scause = const TrapContext::OFFSET_SCAUSE,
         stval = const TrapContext::OFFSET_STVAL,
-        dirty_task_context = sym DIRTY_TASK_CONTEXT,
-        task_context_switch = sym TaskContext::switch,
     );
 }
 
 #[unsafe(naked)]
-pub(super) unsafe extern "C" fn _captured_trap_return(ctx: &mut TrapContext) -> ! {
+pub(super) unsafe extern "C" fn _captured_trap_return(ctx: &mut TrapContext) {
     naked_asm!(
-        "csrr   t0,  sscratch",
-        "ld     t1,  {sstatus}(t0)",
-        "ld     t2,  {sepc}(t0)",
-        "csrw   sstatus, t1",
-        "csrw   sepc, t2",
-        "mv     t4,  tp",
-        "mv     t5,  sp",
-        "ld     tp,  {tp}(t0)",
-        "ld     ra,  {ra}(t0)",
-        "ld     sp,  {sp}(t0)",
-        "sd     t4,  {ra}(t0)", // Store kernel tp to trap_ctx.ra
-        "sd     t5,  {sp}(t0)", // Store capturer task context to trap_ctx.sp
-        "ld     gp,  {gp}(t0)",
-        "ld     a0,  {a0}(t0)",
-        "ld     a1,  {a1}(t0)",
-        "ld     a2,  {a2}(t0)",
-        "ld     a3,  {a3}(t0)",
-        "ld     a4,  {a4}(t0)",
-        "ld     t1,  {t1}(t0)",
-        "ld     a5,  {a5}(t0)",
-        "ld     a6,  {a6}(t0)",
-        "ld     a7,  {a7}(t0)",
-        "ld     t3,  {t3}(t0)",
-        "ld     t4,  {t4}(t0)",
-        "ld     t5,  {t5}(t0)",
-        "ld     t2,  {t2}(t0)",
-        "ld     t6,  {t6}(t0)",
-        "ld     s0,  {s0}(t0)",
-        "ld     s1,  {s1}(t0)",
-        "ld     s2,  {s2}(t0)",
-        "ld     s3,  {s3}(t0)",
-        "ld     s4,  {s4}(t0)",
-        "ld     s5,  {s5}(t0)",
-        "ld     s6,  {s6}(t0)",
-        "ld     s7,  {s7}(t0)",
-        "ld     s8,  {s8}(t0)",
-        "ld     s9,  {s9}(t0)",
-        "ld     s10, {s10}(t0)",
-        "ld     s11, {s11}(t0)",
-        "ld     t0,  {t0}(t0)",
+        "mv    t6, a0",
+
+        "mv    a0, s0",
+        "ld    s0, {s0}(t6)",
+        "sd    a0, {s0}(t6)",
+
+        "mv    a1, s1",
+        "ld    s1, {s1}(t6)",
+        "sd    a1, {s1}(t6)",
+
+        "mv    a2, s2",
+        "ld    s2, {s2}(t6)",
+        "sd    a2, {s2}(t6)",
+
+        "mv    a3, s3",
+        "ld    s3, {s3}(t6)",
+        "sd    a3, {s3}(t6)",
+
+        "mv    a4, s4",
+        "ld    s4, {s4}(t6)",
+        "sd    a4, {s4}(t6)",
+
+        "mv    a5, s5",
+        "ld    s5, {s5}(t6)",
+        "sd    a5, {s5}(t6)",
+
+        "mv    a6, s6",
+        "ld    s6, {s6}(t6)",
+        "sd    a6, {s6}(t6)",
+
+        "mv    a7, s7",
+        "ld    s7, {s7}(t6)",
+        "sd    a7, {s7}(t6)",
+
+        "mv    t0, s8",
+        "ld    s8, {s8}(t6)",
+        "sd    t0, {s8}(t6)",
+
+        "mv    t1, s9",
+        "ld    s9, {s9}(t6)",
+        "sd    t1, {s9}(t6)",
+
+        "mv     t2, s10",
+        "ld    s10, {s10}(t6)",
+        "sd     t2, {s10}(t6)",
+
+        "mv     t3, s11",
+        "ld    s11, {s11}(t6)",
+        "sd     t3, {s11}(t6)",
+
+        "mv     t4, ra",
+        "ld     ra, {ra}(t6)",
+        "sd     t4, {ra}(t6)",
+
+        "mv     t5, tp",
+        "ld     tp, {tp}(t6)",
+        "sd     t5, {tp}(t6)",
+
+        "mv     a0, sp",
+        "ld     sp, {sp}(t6)",
+        "sd     a0, {sp}(t6)",
+
+        "csrr   t4, sstatus",
+        "ld     t5, {sstatus}(t6)",
+        "sd     t4, {sstatus}(t6)",
+        "ld     t4, {sepc}(t6)",
+
+        "ld     gp, {gp}(t6)",
+        "ld     a0, {a0}(t6)",
+        "ld     a1, {a1}(t6)",
+        "ld     a2, {a2}(t6)",
+        "ld     a3, {a3}(t6)",
+        "ld     a4, {a4}(t6)",
+        "ld     a5, {a5}(t6)",
+        "ld     a6, {a6}(t6)",
+        "ld     a7, {a7}(t6)",
+
+        "csrw   sstatus, t5",
+        "csrw      sepc, t4",
+
+        "ld     t0, {t0}(t6)",
+        "ld     t1, {t1}(t6)",
+        "ld     t2, {t2}(t6)",
+        "ld     t3, {t3}(t6)",
+        "ld     t4, {t4}(t6)",
+        "ld     t5, {t5}(t6)",
+        "ld     t6, {t6}(t6)",
         "sret",
         ra = const Registers::OFFSET_RA,
         sp = const Registers::OFFSET_SP,

+ 19 - 22
crates/eonix_hal/src/arch/riscv64/trap/mod.rs

@@ -2,32 +2,25 @@ mod captured;
 mod default;
 mod trap_context;
 
-use super::config::platform::virt::*;
-use super::context::TaskContext;
-use captured::{_captured_trap_entry, _captured_trap_return};
 use core::arch::{global_asm, naked_asm};
 use core::mem::{offset_of, size_of};
 use core::num::NonZero;
 use core::ptr::NonNull;
+
+use captured::{_captured_trap_entry, _captured_trap_return};
 use default::_default_trap_entry;
-use eonix_hal_traits::{
-    context::RawTaskContext,
-    trap::{IrqState as IrqStateTrait, TrapReturn},
-};
+use eonix_hal_traits::context::RawTaskContext;
+use eonix_hal_traits::trap::{IrqState as IrqStateTrait, TrapReturn};
+use riscv::asm::sfence_vma_all;
 use riscv::register::sstatus::{self, Sstatus};
-use riscv::register::stvec::TrapMode;
+use riscv::register::stvec::{self, Stvec, TrapMode};
 use riscv::register::{scause, sepc, sscratch, stval};
-use riscv::{
-    asm::sfence_vma_all,
-    register::stvec::{self, Stvec},
-};
 use sbi::SbiError;
-
 pub use trap_context::*;
 
-impl TrapReturn for TrapContext {
-    type TaskContext = TaskContext;
+use super::config::platform::virt::*;
 
+impl TrapReturn for TrapContext {
     unsafe fn trap_return(&mut self) {
         let irq_states = disable_irqs_save();
 
@@ -42,14 +35,8 @@ impl TrapReturn for TrapContext {
         let old_trap_ctx = sscratch::read();
         sscratch::write(&raw mut *self as usize);
 
-        let mut from_ctx = TaskContext::new();
-        let mut to_ctx = TaskContext::new();
-        to_ctx.set_program_counter(_captured_trap_return as usize);
-        to_ctx.set_stack_pointer(&raw mut from_ctx as usize);
-        to_ctx.set_interrupt_enabled(false);
-
         unsafe {
-            TaskContext::switch(&mut from_ctx, &mut to_ctx);
+            _captured_trap_return(self);
         }
 
         sscratch::write(old_trap_ctx);
@@ -57,6 +44,16 @@ impl TrapReturn for TrapContext {
 
         irq_states.restore();
     }
+
+    unsafe fn trap_return_noreturn(&mut self) -> ! {
+        disable_irqs();
+
+        unsafe {
+            _captured_trap_return(self);
+        }
+
+        unreachable!("trap_return_noreturn should not return");
+    }
 }
 
 fn setup_trap_handler(trap_entry_addr: usize) {

+ 17 - 16
crates/eonix_hal/src/arch/riscv64/trap/trap_context.rs

@@ -16,37 +16,38 @@ use crate::processor::CPU;
 #[repr(C)]
 #[derive(Default, Clone, Copy)]
 pub struct Registers {
-    tp: u64,
+    s0: u64,
+    s1: u64,
+    s2: u64,
+    s3: u64,
+    s4: u64,
+    s5: u64,
+    s6: u64,
+    s7: u64,
+    s8: u64,
+    s9: u64,
+    s10: u64,
+    s11: u64,
     ra: u64,
+    tp: u64,
     sp: u64,
     gp: u64,
+
     a0: u64,
     a1: u64,
     a2: u64,
     a3: u64,
     a4: u64,
-    t1: u64,
     a5: u64,
     a6: u64,
     a7: u64,
+    t0: u64,
+    t1: u64,
+    t2: u64,
     t3: u64,
     t4: u64,
     t5: u64,
-    t2: u64,
     t6: u64,
-    s0: u64,
-    s1: u64,
-    s2: u64,
-    s3: u64,
-    s4: u64,
-    s5: u64,
-    s6: u64,
-    s7: u64,
-    s8: u64,
-    s9: u64,
-    s10: u64,
-    s11: u64,
-    t0: u64,
 }
 
 /// Saved CPU context when a trap (interrupt or exception) occurs on RISC-V 64.

+ 8 - 8
src/lib.rs

@@ -30,12 +30,10 @@ use alloc::ffi::CString;
 use core::hint::spin_loop;
 use core::sync::atomic::{AtomicBool, Ordering};
 
-use eonix_hal::context::TaskContext;
 use eonix_hal::processor::CPU;
 use eonix_hal::symbol_addr;
-use eonix_hal::traits::context::RawTaskContext;
-use eonix_hal::traits::trap::IrqState;
-use eonix_hal::trap::disable_irqs_save;
+use eonix_hal::traits::trap::{IrqState, RawTrapContext, TrapReturn};
+use eonix_hal::trap::{disable_irqs_save, TrapContext};
 use eonix_mm::address::PRange;
 use eonix_runtime::scheduler::RUNTIME;
 use kernel::mem::GlobalPageAlloc;
@@ -68,7 +66,7 @@ fn kernel_init(mut data: eonix_hal::bootstrap::BootStrapData) -> ! {
 
     drop(data);
 
-    let mut ctx = TaskContext::new();
+    let mut ctx = TrapContext::new();
     let stack_bottom = {
         let stack = KernelStack::new();
         let bottom = stack.get_bottom().addr().get();
@@ -77,11 +75,12 @@ fn kernel_init(mut data: eonix_hal::bootstrap::BootStrapData) -> ! {
         bottom
     };
     ctx.set_interrupt_enabled(true);
+    ctx.set_user_mode(false);
     ctx.set_program_counter(symbol_addr!(standard_main));
     ctx.set_stack_pointer(stack_bottom);
 
     unsafe {
-        TaskContext::switch_to_noreturn(&mut ctx);
+        ctx.trap_return_noreturn();
     }
 }
 
@@ -94,7 +93,7 @@ fn kernel_ap_main(_stack_range: PRange) -> ! {
 
     println_debug!("AP{} started", CPU::local().cpuid());
 
-    let mut ctx = TaskContext::new();
+    let mut ctx = TrapContext::new();
     let stack_bottom = {
         let stack = KernelStack::new();
         let bottom = stack.get_bottom().addr().get();
@@ -103,11 +102,12 @@ fn kernel_ap_main(_stack_range: PRange) -> ! {
         bottom
     };
     ctx.set_interrupt_enabled(true);
+    ctx.set_user_mode(false);
     ctx.set_program_counter(symbol_addr!(standard_main));
     ctx.set_stack_pointer(stack_bottom);
 
     unsafe {
-        TaskContext::switch_to_noreturn(&mut ctx);
+        ctx.trap_return_noreturn();
     }
 }