refactor: better abstraction for context switch

shao 3 weeks ago
parent commit f048367b02

+ 2 - 6
Cargo.lock

@@ -1,6 +1,6 @@
 # This file is automatically @generated by Cargo.
 # It is not intended for manual editing.
-version = 3
+version = 4
 
 [[package]]
 name = "aho-corasick"
@@ -15,7 +15,7 @@ dependencies = [
 name = "arch"
 version = "0.1.0"
 dependencies = [
- "x86_64",
+ "cfg-if",
 ]
 
 [[package]]
@@ -336,7 +336,3 @@ name = "windows_x86_64_msvc"
 version = "0.52.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
-
-[[package]]
-name = "x86_64"
-version = "0.1.0"

+ 1 - 1
Cargo.toml

@@ -14,7 +14,7 @@ lazy_static = { version = "1.5.0", features = ["spin_no_std"] }
 spin = "0.9.8"
 
 [features]
-default = []
+# default = ["debug_syscall"]
 debug_syscall = []
 
 [build-dependencies]
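Note: with `default = []` removed and only a commented-out suggestion left in its place, the crate now builds with no default features, so syscall tracing stays opt-in. Assuming a standard Cargo setup, it can be enabled per build with `cargo build --features debug_syscall` rather than by editing the manifest.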

+ 1 - 1
arch/Cargo.toml

@@ -4,4 +4,4 @@ version = "0.1.0"
 edition = "2021"
 
 [dependencies]
-x86_64 = { path="./x86_64" }
+cfg-if = "1.0"

+ 9 - 90
arch/src/lib.rs

@@ -1,94 +1,13 @@
 #![no_std]
 
-pub mod vm {
-    pub fn invlpg(vaddr: usize) {
-        x86_64::vm::invlpg(vaddr)
-    }
-
-    pub fn invlpg_all() {
-        x86_64::vm::invlpg_all()
-    }
-
-    pub fn current_page_table() -> usize {
-        x86_64::vm::get_cr3()
-    }
-
-    pub fn switch_page_table(pfn: usize) {
-        x86_64::vm::set_cr3(pfn)
-    }
-}
-
-pub mod task {
-    #[inline(always)]
-    pub fn halt() {
-        x86_64::task::halt()
-    }
-
-    #[inline(always)]
-    pub fn pause() {
-        x86_64::task::pause()
-    }
-
-    #[inline(always)]
-    pub fn freeze() -> ! {
-        x86_64::task::freeze()
-    }
-
-    /// Switch to the `next` task. `IF` state is also switched.
-    ///
-    /// This function should only be used to switch between tasks that do not need SMP synchronization.
-    ///
-    /// # Arguments
-    /// * `current_task_sp` - Pointer to the stack pointer of the current task.
-    /// * `next_task_sp` - Pointer to the stack pointer of the next task.
-    #[inline(always)]
-    pub fn context_switch_light(current_task_sp: *mut usize, next_task_sp: *mut usize) {
-        x86_64::task::context_switch_light(current_task_sp, next_task_sp);
-    }
-}
-
-pub mod interrupt {
-    #[inline(always)]
-    pub fn enable() {
-        x86_64::interrupt::enable()
-    }
-
-    #[inline(always)]
-    pub fn disable() {
-        x86_64::interrupt::disable()
+cfg_if::cfg_if! {
+    if #[cfg(target_arch = "x86_64")] {
+        mod x86_64;
+        pub use self::x86_64::*;
+    } else if #[cfg(target_arch = "riscv64")] {
+        mod riscv;
+        pub use self::riscv::*;
+    } else if #[cfg(target_arch = "aarch64")] {
+        // TODO: aarch64 backend not yet implemented.
     }
 }
-
-pub mod io {
-    #[inline(always)]
-    pub fn inb(port: u16) -> u8 {
-        x86_64::io::inb(port)
-    }
-
-    #[inline(always)]
-    pub fn outb(port: u16, data: u8) {
-        x86_64::io::outb(port, data)
-    }
-
-    #[inline(always)]
-    pub fn inw(port: u16) -> u16 {
-        x86_64::io::inw(port)
-    }
-
-    #[inline(always)]
-    pub fn outw(port: u16, data: u16) {
-        x86_64::io::outw(port, data)
-    }
-
-    #[inline(always)]
-    pub fn inl(port: u16) -> u32 {
-        x86_64::io::inl(port)
-    }
-
-    #[inline(always)]
-    pub fn outl(port: u16, data: u32) {
-        x86_64::io::outl(port, data)
-    }
-}
-
-pub use x86_64;
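The new `arch/src/lib.rs` turns the crate into a thin dispatcher: `cfg_if` selects one backend module per target and re-exports its items, so callers write `arch::halt()` or `arch::flush_tlb(..)` with no platform prefix. The `riscv` backend referenced above is not part of this commit; the following is a minimal sketch of the contract such a module would have to satisfy, assuming standard riscv64 privileged instructions (`sfence.vma`, `wfi`):

    // arch/src/riscv/mod.rs -- hypothetical sketch, not shown in this commit.
    // It only illustrates that each backend must export the same API surface
    // (flush_tlb, flush_tlb_all, halt, ...) that lib.rs re-exports.
    use core::arch::asm;

    #[inline(always)]
    pub fn flush_tlb(vaddr: usize) {
        // sfence.vma with an address operand invalidates matching TLB entries.
        unsafe { asm!("sfence.vma {}, zero", in(reg) vaddr) };
    }

    #[inline(always)]
    pub fn flush_tlb_all() {
        unsafe { asm!("sfence.vma zero, zero") };
    }

    #[inline(always)]
    pub fn halt() {
        // Wait-for-interrupt, the riscv analogue of x86 `hlt`.
        unsafe { asm!("wfi") };
    }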

+ 78 - 0
arch/src/x86_64/context.rs

@@ -0,0 +1,78 @@
+use core::arch::global_asm;
+
+#[repr(C)]
+#[derive(Debug, Default)]
+struct ContextSwitchFrame {
+    r15: u64,
+    r14: u64,
+    r13: u64,
+    r12: u64,
+    rbx: u64,
+    rbp: u64,
+    eflags: u64,
+    rip: u64,
+}
+
+/// Hardware state of a task that must be saved and restored across a context switch.
+pub struct TaskContext {
+    /// The kernel stack pointer
+    pub rsp: u64,
+    // TODO: extended state, i.e., FP/SIMD registers, is not yet saved.
+}
+
+impl TaskContext {
+    pub const fn new() -> Self {
+        Self {
+            rsp: 0,
+        }
+    }
+
+    pub fn init(&mut self, entry: usize, kstack_top: usize) {
+        unsafe {
+            let frame_ptr = (kstack_top as *mut ContextSwitchFrame).sub(1);
+            core::ptr::write(
+                frame_ptr,
+                ContextSwitchFrame {
+                    rip: entry as u64,
+                    eflags: 0x200,
+                    ..Default::default()
+                },
+            );
+            self.rsp = frame_ptr as u64;
+        }
+    }
+
+    #[inline(always)]
+    pub fn switch_to(&mut self, next_task: &mut Self) {
+        unsafe { __switch_to(&mut self.rsp, &mut next_task.rsp) }
+    }
+}
+
+global_asm!(
+    r#"
+    .global __switch_to
+__switch_to:
+    pushf
+    push    rbp
+    push    rbx
+    push    r12
+    push    r13
+    push    r14
+    push    r15
+    mov     [rdi], rsp
+
+    mov     rsp, [rsi]
+    pop     r15
+    pop     r14
+    pop     r13
+    pop     r12
+    pop     rbx
+    pop     rbp
+    popf
+    ret
+    "#
+);
+
+extern "C" {
+    fn __switch_to(current_context_sp: &mut u64, next_context_sp: &mut u64);
+}
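`TaskContext::init` is what makes `__switch_to` double as a task trampoline: it writes a `ContextSwitchFrame` just below `kstack_top` with zeroed callee-saved registers, `eflags = 0x200` (IF set), and `rip = entry`, then points `rsp` at that frame. The first time `__switch_to` loads this `rsp`, the `pop`s consume the zeroed registers, `popf` enables interrupts, and `ret` jumps to `entry`. A usage sketch under that assumption (the helper name is hypothetical; `kstack_top` must be the high end of a fresh, 16-byte-aligned kernel stack):

    // Hypothetical helper, not in this commit: prepare a context that will
    // start executing `entry` the first time something switches to it.
    fn prepare_kernel_task(entry: extern "C" fn(), kstack_top: usize) -> TaskContext {
        let mut ctx = TaskContext::new();
        // Writes ContextSwitchFrame { rip: entry, eflags: 0x200, .. } below
        // kstack_top and sets ctx.rsp to its address.
        ctx.init(entry as usize, kstack_top);
        ctx
    }
    // Later: current.switch_to(&mut ctx) saves the caller's frame and
    // `ret`s into `entry` on the new stack.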

+ 60 - 0
arch/src/x86_64/interrupt.rs

@@ -0,0 +1,60 @@
+use core::arch::asm;
+
+/// Saved registers when a trap (interrupt or exception) occurs.
+#[allow(missing_docs)]
+#[repr(C)]
+#[derive(Debug, Default, Clone, Copy)]
+pub struct InterruptContext {
+    pub rax: u64,
+    pub rbx: u64,
+    pub rcx: u64,
+    pub rdx: u64,
+    pub rdi: u64,
+    pub rsi: u64,
+    pub r8: u64,
+    pub r9: u64,
+    pub r10: u64,
+    pub r11: u64,
+    pub r12: u64,
+    pub r13: u64,
+    pub r14: u64,
+    pub r15: u64,
+    pub rbp: u64,
+
+    pub int_no: u64,
+    pub error_code: u64,
+
+    // Pushed by CPU
+    pub rip: u64,
+    pub cs: u64,
+    pub eflags: u64,
+    pub rsp: u64,
+    pub ss: u64,
+}
+
+pub fn enable_irqs() {
+    unsafe {
+        asm!("sti");
+    }
+}
+
+pub fn disable_irqs() {
+    unsafe {
+        asm!("cli");
+    }
+}
+
+pub fn lidt(base: usize, limit: u16) {
+    let mut idt_descriptor = [0u16; 5];
+
+    idt_descriptor[0] = limit;
+    idt_descriptor[1] = base as u16;
+    idt_descriptor[2] = (base >> 16) as u16;
+    idt_descriptor[3] = (base >> 32) as u16;
+    idt_descriptor[4] = (base >> 48) as u16;
+
+    unsafe {
+        asm!("lidt ({})", in(reg) &idt_descriptor, options(att_syntax));
+    }
+}
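`lidt` assembles the 10-byte IDTR image from five `u16`s: word 0 is the table limit, words 1 through 4 are the 64-bit base in little-endian order. An equivalent formulation with a packed struct may read more clearly; this is a sketch of the same layout, not what the commit uses:

    // Same IDTR layout expressed as a struct: limit (u16) followed
    // immediately by base (u64), with no padding.
    #[repr(C, packed)]
    struct IdtDescriptor {
        limit: u16,
        base: u64,
    }

    pub fn lidt_struct(base: usize, limit: u16) {
        let idtr = IdtDescriptor { limit, base: base as u64 };
        unsafe {
            core::arch::asm!("lidt ({})", in(reg) &idtr, options(att_syntax));
        }
    }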

+ 214 - 0
arch/src/x86_64/interrupt.s

@@ -0,0 +1,214 @@
+.text
+
+#define RAX     0x00
+#define RBX     0x08
+#define RCX     0x10
+#define RDX     0x18
+#define RDI     0x20
+#define RSI     0x28
+#define R8      0x30
+#define R9      0x38
+#define R10     0x40
+#define R11     0x48
+#define R12     0x50
+#define R13     0x58
+#define R14     0x60
+#define R15     0x68
+#define RBP     0x70
+#define INT_NO  0x78
+#define ERRCODE 0x80
+#define RIP     0x88
+#define CS      0x90
+#define FLAGS   0x98
+#define RSP     0xa0
+#define SS      0xa8
+
+.macro movcfi reg, offset
+	mov \reg, \offset(%rsp)
+	.cfi_rel_offset \reg, \offset
+.endm
+
+.macro movrst reg, offset
+	mov \offset(%rsp), \reg
+	.cfi_restore \reg
+.endm
+
+ISR_stub:
+	.cfi_startproc
+	.cfi_signal_frame
+	.cfi_def_cfa_offset 0x18
+	.cfi_offset %rsp, 0x10
+
+	sub $0x78, %rsp
+	.cfi_def_cfa_offset 0x90
+
+	movcfi %rax, RAX
+	movcfi %rbx, RBX
+	movcfi %rcx, RCX
+	movcfi %rdx, RDX
+	movcfi %rdi, RDI
+	movcfi %rsi, RSI
+	movcfi %r8,  R8
+	movcfi %r9,  R9
+	movcfi %r10, R10
+	movcfi %r11, R11
+	movcfi %r12, R12
+	movcfi %r13, R13
+	movcfi %r14, R14
+	movcfi %r15, R15
+	movcfi %rbp, RBP
+
+	mov INT_NO(%rsp), %rax
+	sub $ISR0, %rax
+	shr $3, %rax
+	mov %rax, INT_NO(%rsp)
+
+	mov %rsp, %rbx
+	.cfi_def_cfa_register %rbx
+
+	and $~0xf, %rsp
+	sub $512, %rsp
+	fxsave (%rsp)
+
+	mov %rbx, %rdi
+	mov %rsp, %rsi
+	call interrupt_handler
+
+ISR_stub_restore:
+	fxrstor (%rsp)
+	mov %rbx, %rsp
+	.cfi_def_cfa_register %rsp
+
+	movrst %rax, RAX
+	movrst %rbx, RBX
+	movrst %rcx, RCX
+	movrst %rdx, RDX
+	movrst %rdi, RDI
+	movrst %rsi, RSI
+	movrst %r8,  R8
+	movrst %r9,  R9
+	movrst %r10, R10
+	movrst %r11, R11
+	movrst %r12, R12
+	movrst %r13, R13
+	movrst %r14, R14
+	movrst %r15, R15
+	movrst %rbp, RBP
+
+	add $0x88, %rsp
+	.cfi_def_cfa_offset 0x08
+
+	iretq
+	.cfi_endproc
+
+.altmacro
+.macro build_isr_no_err name
+	.align 8
+	.globl ISR\name
+	.type  ISR\name @function
+	ISR\name:
+		.cfi_startproc
+		.cfi_signal_frame
+		.cfi_def_cfa_offset 0x08
+		.cfi_offset %rsp, 0x10
+
+		.cfi_same_value %rax
+		.cfi_same_value %rbx
+		.cfi_same_value %rcx
+		.cfi_same_value %rdx
+		.cfi_same_value %rdi
+		.cfi_same_value %rsi
+		.cfi_same_value %r8
+		.cfi_same_value %r9
+		.cfi_same_value %r10
+		.cfi_same_value %r11
+		.cfi_same_value %r12
+		.cfi_same_value %r13
+		.cfi_same_value %r14
+		.cfi_same_value %r15
+		.cfi_same_value %rbp
+
+		push %rbp # push placeholder for error code
+		.cfi_def_cfa_offset 0x10
+
+		call ISR_stub
+		.cfi_endproc
+.endm
+
+.altmacro
+.macro build_isr_err name
+	.align 8
+	.globl ISR\name
+	.type  ISR\name @function
+	ISR\name:
+		.cfi_startproc
+		.cfi_signal_frame
+		.cfi_def_cfa_offset 0x10
+		.cfi_offset %rsp, 0x10
+
+		.cfi_same_value %rax
+		.cfi_same_value %rbx
+		.cfi_same_value %rcx
+		.cfi_same_value %rdx
+		.cfi_same_value %rdi
+		.cfi_same_value %rsi
+		.cfi_same_value %r8
+		.cfi_same_value %r9
+		.cfi_same_value %r10
+		.cfi_same_value %r11
+		.cfi_same_value %r12
+		.cfi_same_value %r13
+		.cfi_same_value %r14
+		.cfi_same_value %r15
+		.cfi_same_value %rbp
+
+		call ISR_stub
+		.cfi_endproc
+.endm
+
+build_isr_no_err 0
+build_isr_no_err 1
+build_isr_no_err 2
+build_isr_no_err 3
+build_isr_no_err 4
+build_isr_no_err 5
+build_isr_no_err 6
+build_isr_no_err 7
+build_isr_err    8
+build_isr_no_err 9
+build_isr_err    10
+build_isr_err    11
+build_isr_err    12
+build_isr_err    13
+build_isr_err    14
+build_isr_no_err 15
+build_isr_no_err 16
+build_isr_err    17
+build_isr_no_err 18
+build_isr_no_err 19
+build_isr_no_err 20
+build_isr_err    21
+build_isr_no_err 22
+build_isr_no_err 23
+build_isr_no_err 24
+build_isr_no_err 25
+build_isr_no_err 26
+build_isr_no_err 27
+build_isr_no_err 28
+build_isr_err    29
+build_isr_err    30
+build_isr_no_err 31
+
+.set i, 32
+.rept 0x80+1
+	build_isr_no_err %i
+	.set i, i+1
+.endr
+
+.section .rodata
+
+.align 8
+.globl ISR_START_ADDR
+.type  ISR_START_ADDR @object
+ISR_START_ADDR:
+	.quad ISR0
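A note on `ISR_stub`'s vector recovery: the stubs never push an explicit vector number. Each `ISR\name` is `.align 8` and at most six bytes long (`push %rbp` is one byte, `call ISR_stub` five), so the return address that `call` leaves in the `INT_NO` slot always falls inside `[ISR0 + 8n, ISR0 + 8n + 8)` for vector n, and the `sub $ISR0` / `shr $3` pair reduces it to n. Illustrative arithmetic only, not kernel code:

    // Worked example of the recovery in ISR_stub. For the page-fault stub
    // (vector 14), the stub starts at ISR0 + 0x70 and the pushed return
    // address is ISR0 + 0x76; 0x76 >> 3 == 14.
    fn vector_from_return_address(ret_addr: usize, isr0: usize) -> usize {
        (ret_addr - isr0) >> 3
    }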

+ 0 - 0
arch/x86_64/src/io.rs → arch/src/x86_64/io.rs


+ 91 - 0
arch/src/x86_64/mod.rs

@@ -0,0 +1,91 @@
+mod context;
+mod interrupt;
+mod io;
+
+pub use self::context::*;
+pub use self::interrupt::*;
+pub use self::io::*;
+
+use core::arch::asm;
+
+#[inline(always)]
+pub fn flush_tlb(vaddr: usize) {
+    unsafe {
+        asm!(
+            "invlpg ({})",
+            in(reg) vaddr,
+            options(att_syntax)
+        );
+    }
+}
+
+#[inline(always)]
+pub fn flush_tlb_all() {
+    unsafe {
+        asm!(
+            "mov %cr3, %rax",
+            "mov %rax, %cr3",
+            out("rax") _,
+            options(att_syntax)
+        );
+    }
+}
+
+#[inline(always)]
+pub fn get_root_page_table() -> usize {
+    let cr3: usize;
+    unsafe {
+        asm!(
+            "mov %cr3, {0}",
+            out(reg) cr3,
+            options(att_syntax)
+        );
+    }
+    cr3
+}
+
+#[inline(always)]
+pub fn set_root_page_table(pfn: usize) {
+    unsafe {
+        asm!(
+            "mov {0}, %cr3",
+            in(reg) pfn,
+            options(att_syntax)
+        );
+    }
+}
+
+#[inline(always)]
+pub fn get_page_fault_address() -> usize {
+    let cr2: usize;
+    unsafe {
+        asm!(
+            "mov %cr2, {}",
+            out(reg) cr2,
+            options(att_syntax)
+        );
+    }
+    cr2
+}
+
+#[inline(always)]
+pub fn halt() {
+    unsafe {
+        asm!("hlt", options(att_syntax, nostack));
+    }
+}
+
+#[inline(always)]
+pub fn pause() {
+    unsafe {
+        asm!("pause", options(att_syntax, nostack));
+    }
+}
+
+#[inline(always)]
+pub fn freeze() -> ! {
+    loop {
+        interrupt::disable_irqs();
+        halt();
+    }
+}
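These free functions replace the old `vm`/`task` modules with flatter names (`invlpg` becomes `flush_tlb`, `get_cr3` becomes `get_root_page_table`, and so on). The split between `flush_tlb` and `flush_tlb_all` exists to support the caller-side heuristic in `src/kernel/mem/page_table.rs` below: per-page `invlpg` for small unmaps, a full CR3 reload otherwise. A sketch of that pattern:

    // Sketch of the calling pattern these primitives support, mirroring the
    // `use_invlpg = range.len() / 4096 < 4` heuristic in page_table.rs.
    fn flush_range(start_vaddr: usize, num_pages: usize) {
        if num_pages < 4 {
            for i in 0..num_pages {
                flush_tlb(start_vaddr + i * 4096);
            }
        } else {
            flush_tlb_all();
        }
    }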

+ 0 - 6
arch/x86_64/Cargo.toml

@@ -1,6 +0,0 @@
-[package]
-name = "x86_64"
-version = "0.1.0"
-edition = "2021"
-
-[dependencies]

+ 0 - 27
arch/x86_64/src/interrupt.rs

@@ -1,27 +0,0 @@
-use core::arch::asm;
-
-pub fn enable() {
-    unsafe {
-        asm!("sti");
-    }
-}
-
-pub fn disable() {
-    unsafe {
-        asm!("cli");
-    }
-}
-
-pub fn lidt(base: usize, limit: u16) {
-    let mut idt_descriptor = [0u16; 5];
-
-    idt_descriptor[0] = limit;
-    idt_descriptor[1] = base as u16;
-    idt_descriptor[2] = (base >> 16) as u16;
-    idt_descriptor[3] = (base >> 32) as u16;
-    idt_descriptor[4] = (base >> 48) as u16;
-
-    unsafe {
-        asm!("lidt ({})", in(reg) &idt_descriptor, options(att_syntax));
-    }
-}

+ 0 - 69
arch/x86_64/src/lib.rs

@@ -1,69 +0,0 @@
-#![no_std]
-
-pub mod vm {
-    use core::arch::asm;
-
-    #[inline(always)]
-    pub fn invlpg(vaddr: usize) {
-        unsafe {
-            asm!(
-                "invlpg ({})",
-                in(reg) vaddr,
-                options(att_syntax)
-            );
-        }
-    }
-
-    #[inline(always)]
-    pub fn invlpg_all() {
-        unsafe {
-            asm!(
-                "mov %cr3, %rax",
-                "mov %rax, %cr3",
-                out("rax") _,
-                options(att_syntax)
-            );
-        }
-    }
-
-    #[inline(always)]
-    pub fn get_cr3() -> usize {
-        let cr3: usize;
-        unsafe {
-            asm!(
-                "mov %cr3, {0}",
-                out(reg) cr3,
-                options(att_syntax)
-            );
-        }
-        cr3
-    }
-
-    #[inline(always)]
-    pub fn set_cr3(pfn: usize) {
-        unsafe {
-            asm!(
-                "mov {0}, %cr3",
-                in(reg) pfn,
-                options(att_syntax)
-            );
-        }
-    }
-
-    #[inline(always)]
-    pub fn get_cr2() -> usize {
-        let cr2: usize;
-        unsafe {
-            asm!(
-                "mov %cr2, {}",
-                out(reg) cr2,
-                options(att_syntax)
-            );
-        }
-        cr2
-    }
-}
-
-pub mod interrupt;
-pub mod io;
-pub mod task;

+ 0 - 97
arch/x86_64/src/task.rs

@@ -1,97 +0,0 @@
-use core::arch::{asm, global_asm};
-
-use crate::interrupt;
-
-#[inline(always)]
-pub fn halt() {
-    unsafe {
-        asm!("hlt", options(att_syntax, nostack));
-    }
-}
-
-#[inline(always)]
-pub fn pause() {
-    unsafe {
-        asm!("pause", options(att_syntax, nostack));
-    }
-}
-
-#[inline(always)]
-pub fn freeze() -> ! {
-    loop {
-        interrupt::disable();
-        halt();
-    }
-}
-
-global_asm!(
-    r"
-    .macro movcfi reg, offset
-        mov \reg, \offset(%rsp)
-        .cfi_rel_offset \reg, \offset
-    .endm
-
-    .macro movrst reg, offset
-        mov \offset(%rsp), \reg
-        .cfi_restore \reg
-    .endm
-
-    .globl __context_switch_light
-    .type __context_switch_light @function
-    __context_switch_light:
-    .cfi_startproc
-
-        pushf
-    .cfi_def_cfa_offset 0x10
-
-        sub $0x38, %rsp  # extra 8 bytes to align to 16 bytes
-    .cfi_def_cfa_offset 0x48
-
-        movcfi %rbx, 0x08
-        movcfi %rbp, 0x10
-        movcfi %r12, 0x18
-        movcfi %r13, 0x20
-        movcfi %r14, 0x28
-        movcfi %r15, 0x30
-
-        push (%rdi)      # save sp of previous stack frame of current
-                         # acts as saving bp
-    .cfi_def_cfa_offset 0x50
-
-        mov %rsp, (%rdi) # save sp of current stack
-        mov (%rsi), %rsp # load sp of target stack
-
-        pop (%rsi)       # load sp of previous stack frame of target
-                         # acts as restoring previous bp
-    .cfi_def_cfa_offset 0x48
-
-        pop %rax         # align to 16 bytes
-    .cfi_def_cfa_offset 0x40
-
-        mov 0x28(%rsp), %r15
-        mov 0x20(%rsp), %r14
-        mov 0x18(%rsp), %r13
-        mov 0x10(%rsp), %r12
-        mov 0x08(%rsp), %rbp
-        mov 0x00(%rsp), %rbx
-
-        add $0x30, %rsp
-    .cfi_def_cfa_offset 0x10
-
-        popf
-    .cfi_def_cfa_offset 0x08
-
-        ret
-    .cfi_endproc
-    ",
-    options(att_syntax),
-);
-
-extern "C" {
-    fn __context_switch_light(current_task_sp: *mut usize, next_task_sp: *mut usize);
-}
-
-#[inline(always)]
-pub fn context_switch_light(current_task_sp: *mut usize, next_task_sp: *mut usize) {
-    unsafe { __context_switch_light(current_task_sp, next_task_sp) }
-}

+ 2 - 2
src/driver.rs

@@ -14,10 +14,10 @@ impl Port8 {
     }
 
     pub fn read(&self) -> u8 {
-        arch::io::inb(self.no)
+        arch::inb(self.no)
     }
 
     pub fn write(&self, data: u8) {
-        arch::io::outb(self.no, data)
+        arch::outb(self.no, data)
     }
 }

+ 2 - 2
src/driver/timer.rs

@@ -4,7 +4,7 @@ const COUNT: Port8 = Port8::new(0x40);
 const CONTROL: Port8 = Port8::new(0x43);
 
 pub fn init() {
-    arch::interrupt::disable();
+    arch::disable_irqs();
     // Set interval
     CONTROL.write(0x34);
 
@@ -12,5 +12,5 @@ pub fn init() {
     // 0x2e9a = 11930 = 100Hz
     COUNT.write(0x9a);
     COUNT.write(0x2e);
-    arch::interrupt::enable();
+    arch::enable_irqs();
 }
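For reference, the PIT programming above: `0x34` selects channel 0, lobyte/hibyte access, mode 2 (rate generator), and the 16-bit divisor is then written low byte first. With the PIT's nominal 1.193182 MHz input clock, the chosen divisor gives roughly the 100 Hz the comment claims:

    // Divisor arithmetic for the code above (nominal PIT input clock).
    const PIT_HZ: u32 = 1_193_182;
    const DIVISOR: u32 = 0x2e9a; // 11930, written as 0x9a then 0x2e
    const TICK_HZ: u32 = PIT_HZ / DIVISOR; // 100 (about 100.02 Hz exactly)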

+ 6 - 5
src/kernel/interrupt.rs

@@ -2,9 +2,10 @@ use alloc::boxed::Box;
 use alloc::vec;
 use alloc::vec::Vec;
 
+use arch::InterruptContext;
 use lazy_static::lazy_static;
 
-use crate::bindings::root::{interrupt_stack, mmx_registers, EINVAL};
+use crate::bindings::root::{mmx_registers, EINVAL};
 use crate::{driver::Port8, prelude::*};
 
 use super::mem::handle_page_fault;
@@ -86,7 +87,7 @@ fn irq_handler(irqno: usize) {
     }
 }
 
-fn fault_handler(int_stack: &mut interrupt_stack) {
+fn fault_handler(int_stack: &mut InterruptContext) {
     match int_stack.int_no {
         // Page fault
         14 => handle_page_fault(int_stack),
@@ -97,7 +98,7 @@ fn fault_handler(int_stack: &mut interrupt_stack) {
 }
 
 #[no_mangle]
-pub extern "C" fn interrupt_handler(int_stack: *mut interrupt_stack, mmxregs: *mut mmx_registers) {
+pub extern "C" fn interrupt_handler(int_stack: *mut InterruptContext, mmxregs: *mut mmx_registers) {
     let int_stack = unsafe { &mut *int_stack };
     let mmxregs = unsafe { &mut *mmxregs };
 
@@ -105,7 +106,7 @@ pub extern "C" fn interrupt_handler(int_stack: *mut interrupt_stack, mmxregs: *m
         // Fault
         0..0x20 => fault_handler(int_stack),
         // Syscall
-        0x80 => handle_syscall32(int_stack.regs.rax as usize, int_stack, mmxregs),
+        0x80 => handle_syscall32(int_stack.rax as usize, int_stack, mmxregs),
         // IRQ
         no => irq_handler(no as usize - 0x20),
     }
@@ -124,7 +125,7 @@ where
 }
 
 pub fn init() -> KResult<()> {
-    arch::x86_64::interrupt::lidt(
+    arch::lidt(
         IDT.as_ptr() as usize,
         (size_of::<IDTEntry>() * 256 - 1) as u16,
     );

+ 8 - 8
src/kernel/mem/mm_list/page_fault.rs

@@ -1,8 +1,8 @@
+use arch::InterruptContext;
 use bindings::kernel::mem::paging::pfn_to_page;
 use bindings::{PA_A, PA_ANON, PA_COW, PA_MMAP, PA_P, PA_RW};
 use bitflags::bitflags;
 
-use crate::bindings::root::interrupt_stack;
 use crate::kernel::mem::paging::{Page, PageBuffer};
 use crate::kernel::mem::phys::{CachedPP, PhysPtr};
 use crate::kernel::mem::{Mapping, VRange};
@@ -34,7 +34,7 @@ struct FixEntry {
 impl MMList {
     fn handle_page_fault(
         &self,
-        int_stack: &mut interrupt_stack,
+        int_stack: &mut InterruptContext,
         addr: VAddr,
         error: PageFaultError,
     ) -> Result<(), Signal> {
@@ -157,8 +157,8 @@ extern "C" {
 /// Try to fix the page fault by jumping to the `error` address.
 ///
 /// Panic if we can't find the `ip` in the fix list.
-fn try_page_fault_fix(int_stack: &mut interrupt_stack, addr: VAddr) {
-    let ip = int_stack.v_rip as u64;
+fn try_page_fault_fix(int_stack: &mut InterruptContext, addr: VAddr) {
+    let ip = int_stack.rip as u64;
     // TODO: Use `op_type` to fix.
 
     // SAFETY: `FIX_START` and `FIX_END` are defined in the linker script in `.rodata` section.
@@ -171,7 +171,7 @@ fn try_page_fault_fix(int_stack: &mut interrupt_stack, addr: VAddr) {
 
     for entry in entries.iter() {
         if ip >= entry.start && ip < entry.start + entry.length {
-            int_stack.v_rip = entry.jump_address as usize;
+            int_stack.rip = entry.jump_address as u64;
             return;
         }
     }
@@ -186,9 +186,9 @@ fn kernel_page_fault_die(vaddr: VAddr, ip: usize) -> ! {
     )
 }
 
-pub fn handle_page_fault(int_stack: &mut interrupt_stack) {
+pub fn handle_page_fault(int_stack: &mut InterruptContext) {
     let error = PageFaultError::from_bits_truncate(int_stack.error_code);
-    let vaddr = VAddr(arch::x86_64::vm::get_cr2());
+    let vaddr = VAddr(arch::get_page_fault_address());
 
     let result = Thread::current()
         .process
@@ -199,7 +199,7 @@ pub fn handle_page_fault(int_stack: &mut interrupt_stack) {
         println_debug!(
             "Page fault on {:#x} in user space at {:#x}",
             vaddr.0,
-            int_stack.v_rip
+            int_stack.rip
         );
         ProcessList::kill_current(signal)
     }

+ 6 - 6
src/kernel/mem/page_table.rs

@@ -222,7 +222,7 @@ impl PageTable {
     }
 
     pub fn switch(&self) {
-        arch::vm::switch_page_table(self.page.as_phys())
+        arch::set_root_page_table(self.page.as_phys())
     }
 
     pub fn unmap(&self, area: &MMArea) {
@@ -230,7 +230,7 @@ impl PageTable {
         let use_invlpg = range.len() / 4096 < 4;
         let iter = self.iter_user(range).unwrap();
 
-        if self.page.as_phys() != arch::vm::current_page_table() {
+        if self.page.as_phys() != arch::get_root_page_table() {
             for pte in iter {
                 pte.take();
             }
@@ -242,19 +242,19 @@ impl PageTable {
                 pte.take();
 
                 let pfn = range.start().floor().0 + offset_pages * 4096;
-                arch::vm::invlpg(pfn);
+                arch::flush_tlb(pfn);
             }
         } else {
             for pte in iter {
                 pte.take();
             }
-            arch::vm::invlpg_all();
+            arch::flush_tlb_all();
         }
     }
 
     pub fn lazy_invalidate_tlb_all(&self) {
-        if self.page.as_phys() == arch::vm::current_page_table() {
-            arch::vm::invlpg_all();
+        if self.page.as_phys() == arch::get_root_page_table() {
+            arch::flush_tlb_all();
         }
     }
 

+ 25 - 22
src/kernel/syscall.rs

@@ -1,8 +1,11 @@
 use crate::{
-    bindings::root::{interrupt_stack, mmx_registers},
+    bindings::root::mmx_registers,
     kernel::task::{ProcessList, Signal},
     println_warn,
 };
+use arch::InterruptContext;
+
+extern crate arch;
 
 mod file_rw;
 mod mm;
@@ -75,22 +78,22 @@ impl<'a, T: 'a> MapArgument<'a, *mut T> for MapArgumentImpl {
 
 macro_rules! arg_register {
     (0, $is:ident) => {
-        $is.regs.rbx
+        $is.rbx
     };
     (1, $is:ident) => {
-        $is.regs.rcx
+        $is.rcx
     };
     (2, $is:ident) => {
-        $is.regs.rdx
+        $is.rdx
     };
     (3, $is:ident) => {
-        $is.regs.rsi
+        $is.rsi
     };
     (4, $is:ident) => {
-        $is.regs.rdi
+        $is.rdi
     };
     (5, $is:ident) => {
-        $is.regs.rbp
+        $is.rbp
     };
 }
 
@@ -144,7 +147,7 @@ macro_rules! syscall32_call {
 
 macro_rules! define_syscall32 {
     ($name:ident, $handler:ident) => {
-        fn $name(_int_stack: &mut $crate::bindings::root::interrupt_stack,
+        fn $name(_int_stack: &mut $crate::kernel::syscall::arch::InterruptContext,
             _mmxregs: &mut $crate::bindings::root::mmx_registers) -> usize {
             use $crate::kernel::syscall::MapReturnValue;
 
@@ -156,7 +159,7 @@ macro_rules! define_syscall32 {
     };
     ($name:ident, $handler:ident, $($arg:ident: $argt:ty),*) => {
         fn $name(
-            int_stack: &mut $crate::bindings::root::interrupt_stack,
+            int_stack: &mut $crate::kernel::syscall::arch::InterruptContext,
             _mmxregs: &mut $crate::bindings::root::mmx_registers) -> usize {
             use $crate::kernel::syscall::syscall32_call;
 
@@ -180,13 +183,13 @@ use super::task::Thread;
 pub(self) use {arg_register, define_syscall32, format_expand, register_syscall, syscall32_call};
 
 pub(self) struct SyscallHandler {
-    handler: fn(&mut interrupt_stack, &mut mmx_registers) -> usize,
+    handler: fn(&mut InterruptContext, &mut mmx_registers) -> usize,
     name: &'static str,
 }
 
 pub(self) fn register_syscall_handler(
     no: usize,
-    handler: fn(&mut interrupt_stack, &mut mmx_registers) -> usize,
+    handler: fn(&mut InterruptContext, &mut mmx_registers) -> usize,
     name: &'static str,
 ) {
     // SAFETY: `SYSCALL_HANDLERS` is never modified after initialization.
@@ -210,7 +213,7 @@ const SYSCALL_HANDLERS_SIZE: usize = 404;
 static mut SYSCALL_HANDLERS: [Option<SyscallHandler>; SYSCALL_HANDLERS_SIZE] =
     [const { None }; SYSCALL_HANDLERS_SIZE];
 
-pub fn handle_syscall32(no: usize, int_stack: &mut interrupt_stack, mmxregs: &mut mmx_registers) {
+pub fn handle_syscall32(no: usize, int_stack: &mut InterruptContext, mmxregs: &mut mmx_registers) {
     // SAFETY: `SYSCALL_HANDLERS` are never modified after initialization.
     let syscall = unsafe { SYSCALL_HANDLERS.get(no) }.and_then(Option::as_ref);
 
@@ -220,19 +223,19 @@ pub fn handle_syscall32(no: usize, int_stack: &mut interrupt_stack, mmxregs: &mu
             ProcessList::kill_current(Signal::SIGSYS);
         }
         Some(handler) => {
-            arch::interrupt::enable();
+            arch::enable_irqs();
             let retval = (handler.handler)(int_stack, mmxregs);
 
             // SAFETY: `int_stack` is always valid.
-            int_stack.regs.rax = retval as u64;
-            int_stack.regs.r8 = 0;
-            int_stack.regs.r9 = 0;
-            int_stack.regs.r10 = 0;
-            int_stack.regs.r11 = 0;
-            int_stack.regs.r12 = 0;
-            int_stack.regs.r13 = 0;
-            int_stack.regs.r14 = 0;
-            int_stack.regs.r15 = 0;
+            int_stack.rax = retval as u64;
+            int_stack.r8 = 0;
+            int_stack.r9 = 0;
+            int_stack.r10 = 0;
+            int_stack.r11 = 0;
+            int_stack.r12 = 0;
+            int_stack.r13 = 0;
+            int_stack.r14 = 0;
+            int_stack.r15 = 0;
         }
     }
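The mechanical change in this file is that `InterruptContext` exposes registers as direct fields, so the `arg_register!` table maps the i386 syscall argument order (ebx, ecx, edx, esi, edi, ebp) straight onto the new struct. For illustration only, what an expansion looks like after the change:

    // arg_register!(1, int_stack) now expands to `int_stack.rcx` instead of
    // the old bindgen path `int_stack.regs.rcx`.
    fn second_syscall_arg(int_stack: &arch::InterruptContext) -> u64 {
        int_stack.rcx
    }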
 

+ 15 - 60
src/kernel/syscall/procops.rs

@@ -1,8 +1,9 @@
-use core::arch::global_asm;
+use core::arch::{asm, global_asm, naked_asm};
 
 use alloc::borrow::ToOwned;
 use alloc::ffi::CString;
-use bindings::{interrupt_stack, mmx_registers, EINVAL, ENOENT, ENOTDIR, ESRCH};
+use arch::InterruptContext;
+use bindings::{mmx_registers, EINVAL, ENOENT, ENOTDIR, ESRCH};
 use bitflags::bitflags;
 
 use crate::elf::ParsedElf32;
@@ -105,14 +106,14 @@ fn do_execve(exec: &[u8], argv: Vec<CString>, envp: Vec<CString>) -> KResult<(VA
     }
 }
 
-fn sys_execve(int_stack: &mut interrupt_stack, _mmxregs: &mut mmx_registers) -> usize {
+fn sys_execve(int_stack: &mut InterruptContext, _mmxregs: &mut mmx_registers) -> usize {
     match (|| -> KResult<()> {
-        let exec = int_stack.regs.rbx as *const u8;
+        let exec = int_stack.rbx as *const u8;
         let exec = UserString::new(exec)?;
 
         // TODO!!!!!: copy from user
-        let mut argv: UserPointer<u32> = UserPointer::new_vaddr(int_stack.regs.rcx as _)?;
-        let mut envp: UserPointer<u32> = UserPointer::new_vaddr(int_stack.regs.rdx as _)?;
+        let mut argv: UserPointer<u32> = UserPointer::new_vaddr(int_stack.rcx as _)?;
+        let mut envp: UserPointer<u32> = UserPointer::new_vaddr(int_stack.rdx as _)?;
 
         let mut argv_vec = Vec::new();
         let mut envp_vec = Vec::new();
@@ -141,8 +142,8 @@ fn sys_execve(int_stack: &mut interrupt_stack, _mmxregs: &mut mmx_registers) ->
 
         let (ip, sp) = do_execve(exec.as_cstr().to_bytes(), argv_vec, envp_vec)?;
 
-        int_stack.v_rip = ip.0;
-        int_stack.rsp = sp.0;
+        int_stack.rip = ip.0 as u64;
+        int_stack.rsp = sp.0 as u64;
         Ok(())
     })() {
         Ok(_) => 0,
@@ -446,65 +447,19 @@ define_syscall32!(sys_rt_sigprocmask, do_rt_sigprocmask,
 define_syscall32!(sys_rt_sigaction, do_rt_sigaction,
     signum: u32, act: *const UserSignalAction, oldact: *mut UserSignalAction, sigsetsize: usize);
 
-extern "C" {
-    fn ISR_stub_restore();
-    fn new_process_return();
-}
-
-unsafe extern "C" fn real_new_process_return() {
-    // We don't land on the typical `Scheduler::schedule()` function, so we need to
-    // manually enable preemption.
-    preempt::enable();
-}
 
-global_asm!(
-    r"
-        .globl new_process_return
-        new_process_return:
-            call {0}
-            jmp {1}
-    ",
-    sym real_new_process_return,
-    sym ISR_stub_restore,
-    options(att_syntax),
-);
 
-fn sys_fork(int_stack: &mut interrupt_stack, mmxregs: &mut mmx_registers) -> usize {
+fn sys_fork(int_stack: &mut InterruptContext, mmxregs: &mut mmx_registers) -> usize {
     let new_thread = Thread::new_cloned(Thread::current());
-
-    // TODO: We should make the preparation of the kernel stack more abstract.
-    //       Currently, we can see that we are directly writing to the kernel stack,
-    //       which is platform dependent.
-    new_thread.prepare_kernel_stack(|kstack| {
-        let mut writer = kstack.get_writer();
-
-        // We make the child process return to `ISR_stub_restore`, pretending that we've
-        // just returned from a interrupt handler.
-        writer.entry = new_process_return;
-
-        let mut new_int_stack = int_stack.clone();
-
-        // Child's return value: 0
-        new_int_stack.regs.rax = 0;
-
-        writer.write(new_int_stack);
-
-        // In `ISR_stub_restore`, we will restore the mmx register context, followed by
-        // restoring the stack pointer by moving the value in `rbx` to `rsp`, which should
-        // point to the interrupt stack.
-        writer.rbx = writer.get_current_sp();
-
-        // Push the mmx register context to the stack.
-        writer.write(mmxregs.clone());
-
-        writer.finish();
-    });
-
+    let mut new_int_stack = int_stack.clone();
+    new_int_stack.rax = 0;
+    new_int_stack.eflags = 0x200;
+    new_thread.fork_init(new_int_stack);
     Scheduler::get().lock_irq().uwake(&new_thread);
     new_thread.process.pid as usize
 }
 
-fn sys_sigreturn(int_stack: &mut interrupt_stack, mmxregs: &mut mmx_registers) -> usize {
+fn sys_sigreturn(int_stack: &mut InterruptContext, mmxregs: &mut mmx_registers) -> usize {
     let result = Thread::current().signal_list.restore(int_stack, mmxregs);
     match result {
         Ok(ret) => ret,

+ 11 - 83
src/kernel/task/kstack.rs

@@ -1,91 +1,19 @@
+use arch::InterruptContext;
+
 use crate::kernel::mem::{
     paging::Page,
     phys::{CachedPP, PhysPtr},
 };
 
-use core::cell::UnsafeCell;
-
 pub struct KernelStack {
     pages: Page,
     bottom: usize,
-    sp: UnsafeCell<usize>,
-}
-
-pub struct KernelStackWriter<'lt> {
-    sp: &'lt mut usize,
-    prev_sp: usize,
-
-    pub entry: unsafe extern "C" fn(),
-    pub flags: usize,
-    pub r15: usize,
-    pub r14: usize,
-    pub r13: usize,
-    pub r12: usize,
-    pub rbp: usize,
-    pub rbx: usize,
 }
 
 unsafe extern "C" fn __not_assigned_entry() {
     panic!("__not_assigned_entry called");
 }
 
-impl<'lt> KernelStackWriter<'lt> {
-    fn new(sp: &'lt mut usize) -> Self {
-        let prev_sp = *sp;
-
-        Self {
-            sp,
-            entry: __not_assigned_entry,
-            flags: 0,
-            r15: 0,
-            r14: 0,
-            r13: 0,
-            r12: 0,
-            rbp: 0,
-            rbx: 0,
-            prev_sp,
-        }
-    }
-
-    /// `data` and current sp should have an alignment of 16 bytes.
-    /// Otherwise, extra padding is added.
-    pub fn write<T: Copy>(&mut self, data: T) {
-        *self.sp -= core::mem::size_of::<T>();
-        *self.sp &= !0xf; // Align to 16 bytes
-
-        // SAFETY: `sp` is always valid.
-        unsafe {
-            (*self.sp as *mut T).write(data);
-        }
-    }
-
-    pub fn get_current_sp(&self) -> usize {
-        *self.sp
-    }
-
-    fn push(&mut self, val: usize) {
-        *self.sp -= core::mem::size_of::<usize>();
-
-        // SAFETY: `sp` is always valid.
-        unsafe {
-            (*self.sp as *mut usize).write(val);
-        }
-    }
-
-    pub fn finish(mut self) {
-        self.push(self.entry as usize);
-        self.push(self.flags); // rflags
-        self.push(self.r15); // r15
-        self.push(self.r14); // r14
-        self.push(self.r13); // r13
-        self.push(self.r12); // r12
-        self.push(self.rbp); // rbp
-        self.push(self.rbx); // rbx
-        self.push(0); // 0 for alignment
-        self.push(self.prev_sp) // previous sp
-    }
-}
-
 impl KernelStack {
     /// Kernel stack page order
     /// 7 for `2^7 = 128 pages = 512 KiB`
@@ -98,7 +26,6 @@ impl KernelStack {
         Self {
             pages,
             bottom,
-            sp: UnsafeCell::new(bottom),
         }
     }
 
@@ -112,15 +39,16 @@ impl KernelStack {
         }
     }
 
-    pub fn get_writer(&mut self) -> KernelStackWriter {
-        KernelStackWriter::new(self.sp.get_mut())
+    pub fn get_stack_bottom(&self) -> usize {
+        self.bottom
     }
 
-    /// Get a pointer to `self.sp` so we can use it in `context_switch()`.
-    ///
-    /// # Safety
-    /// Save the pointer somewhere or pass it to a function that will use it is UB.
-    pub unsafe fn get_sp_ptr(&self) -> *mut usize {
-        self.sp.get()
+    pub fn init(&self, interrupt_context: InterruptContext) -> usize {
+        let mut sp = self.bottom - core::mem::size_of::<InterruptContext>();
+        sp &= !0xf;
+        unsafe {
+            (sp as *mut InterruptContext).write(interrupt_context);
+        }
+        sp
     }
 }
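`KernelStack` no longer carries a mutable `sp` and a writer; `init` does the one job the fork path needs: place a single `InterruptContext` just below `bottom` (the high end of the stack), align down to 16 bytes, and hand the resulting `sp` back so `TaskContext::init` can put its `ContextSwitchFrame` beneath it. The resulting layout, as a comment sketch (addresses decrease downward):

    // bottom (high address)
    //   +---------------------+
    //   | InterruptContext    |  <- written by KernelStack::init()
    //   +---------------------+  <- sp returned by init(), 16-byte aligned
    //   | ContextSwitchFrame  |  <- written by TaskContext::init(entry, sp)
    //   +---------------------+  <- TaskContext.rsp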

+ 9 - 9
src/kernel/task/scheduler.rs

@@ -59,14 +59,14 @@ impl Scheduler {
     }
 
     pub(super) fn set_idle(thread: Arc<Thread>) {
-        thread.prepare_kernel_stack(|kstack| {
-            let mut writer = kstack.get_writer();
-            writer.flags = 0x200;
-            writer.entry = idle_task;
-            writer.finish();
-        });
+        // thread.prepare_kernel_stack(|kstack| {
+        //     let mut writer = kstack.get_writer();
+        //     writer.flags = 0x200;
+        //     writer.entry = idle_task;
+        //     writer.finish();
+        // });
         // We don't wake the idle thread to prevent from accidentally being scheduled there.
-
+        thread.init(idle_task as usize);
         // TODO!!!: Set per cpu variable.
         unsafe { IDLE_TASK = Some(thread) };
     }
@@ -174,7 +174,7 @@ impl Scheduler {
 
 fn context_switch_light(from: &Arc<Thread>, to: &Arc<Thread>) {
     unsafe {
-        arch::task::context_switch_light(from.get_sp_ptr(), to.get_sp_ptr());
+        arch::TaskContext::switch_to(&mut *from.get_context_mut_ptr(), &mut *to.get_context_mut_ptr());
     }
 }
 
@@ -203,7 +203,7 @@ extern "C" fn idle_task() {
         // No thread to run, halt the cpu and rerun the loop.
         if scheduler.ready.is_empty() {
             drop(scheduler);
-            arch::task::halt();
+            arch::halt();
             continue;
         }
 

+ 13 - 12
src/kernel/task/signal.rs

@@ -10,8 +10,9 @@ use crate::{
 };
 
 use alloc::collections::{binary_heap::BinaryHeap, btree_map::BTreeMap};
+use arch::InterruptContext;
 use bindings::{
-    interrupt_stack, mmx_registers, EFAULT, EINVAL, SA_RESTORER, SIGABRT, SIGBUS, SIGCHLD, SIGCONT,
+    mmx_registers, EFAULT, EINVAL, SA_RESTORER, SIGABRT, SIGBUS, SIGCHLD, SIGCONT,
     SIGFPE, SIGILL, SIGKILL, SIGQUIT, SIGSEGV, SIGSTOP, SIGSYS, SIGTRAP, SIGTSTP, SIGTTIN, SIGTTOU,
     SIGURG, SIGWINCH, SIGXCPU, SIGXFSZ,
 };
@@ -168,14 +169,14 @@ impl SignalAction {
         &self,
         signum: u32,
         old_mask: u64,
-        int_stack: &mut interrupt_stack,
+        int_stack: &mut InterruptContext,
         mmxregs: &mut mmx_registers,
     ) -> KResult<()> {
         if self.sa_flags & SA_RESTORER as usize == 0 {
             return Err(EINVAL);
         }
 
-        const CONTEXT_SIZE: usize = size_of::<interrupt_stack>()
+        const CONTEXT_SIZE: usize = size_of::<InterruptContext>()
             + size_of::<mmx_registers>()
             + size_of::<usize>() // old_mask
             + size_of::<u32>(); // `sa_handler` argument: `signum`
@@ -183,9 +184,9 @@ impl SignalAction {
         // Save current interrupt context to 128 bytes above current user stack
         // and align to 16 bytes. Then we push the return address of the restorer.
         // TODO!!!: Determine the size of the return address
-        let sp = ((int_stack.rsp - 128 - CONTEXT_SIZE) & !0xf) - size_of::<u32>();
+        let sp = ((int_stack.rsp as usize - 128 - CONTEXT_SIZE) & !0xf) - size_of::<u32>();
         let restorer_address: u32 = self.sa_restorer as u32;
-        let mut stack = UserBuffer::new(sp as *mut _, CONTEXT_SIZE + size_of::<u32>())?;
+        let mut stack = UserBuffer::new(sp as *mut u8, CONTEXT_SIZE + size_of::<u32>())?;
 
         stack.copy(&restorer_address)?.ok_or(EFAULT)?; // Restorer address
         stack.copy(&signum)?.ok_or(EFAULT)?; // `sa_handler` argument: `signum`
@@ -193,8 +194,8 @@ impl SignalAction {
         stack.copy(mmxregs)?.ok_or(EFAULT)?; // MMX registers
         stack.copy(int_stack)?.ok_or(EFAULT)?; // Interrupt stack
 
-        int_stack.v_rip = self.sa_handler;
-        int_stack.rsp = sp;
+        int_stack.rip = self.sa_handler as u64;
+        int_stack.rsp = sp as u64;
         Ok(())
     }
 }
@@ -333,7 +334,7 @@ impl SignalList {
     /// # Safety
     /// This function might never return. Caller must make sure that local variables
     /// that own resources are dropped before calling this function.
-    pub fn handle(&self, int_stack: &mut interrupt_stack, mmxregs: &mut mmx_registers) {
+    pub fn handle(&self, int_stack: &mut InterruptContext, mmxregs: &mut mmx_registers) {
         loop {
             let signal = {
                 let signal = match self.inner.lock_irq().pop() {
@@ -396,18 +397,18 @@ impl SignalList {
     /// used to store the syscall return value to prevent the original value being clobbered.
     pub fn restore(
         &self,
-        int_stack: &mut interrupt_stack,
+        int_stack: &mut InterruptContext,
         mmxregs: &mut mmx_registers,
     ) -> KResult<usize> {
-        let old_mask_vaddr = int_stack.rsp;
+        let old_mask_vaddr = int_stack.rsp as usize;
         let old_mmxregs_vaddr = old_mask_vaddr + size_of::<usize>();
         let old_int_stack_vaddr = old_mmxregs_vaddr + size_of::<mmx_registers>();
 
         let old_mask = UserPointer::<u64>::new_vaddr(old_mask_vaddr)?.read()?;
         *mmxregs = UserPointer::<mmx_registers>::new_vaddr(old_mmxregs_vaddr)?.read()?;
-        *int_stack = UserPointer::<interrupt_stack>::new_vaddr(old_int_stack_vaddr)?.read()?;
+        *int_stack = UserPointer::<InterruptContext>::new_vaddr(old_int_stack_vaddr)?.read()?;
 
         self.inner.lock_irq().set_mask(old_mask);
-        Ok(int_stack.regs.rax as usize)
+        Ok(int_stack.rax as usize)
     }
 }

+ 58 - 20
src/kernel/task/thread.rs

@@ -1,6 +1,6 @@
 use core::{
-    arch::asm,
-    cell::RefCell,
+    arch::{asm, naked_asm},
+    cell::{RefCell, UnsafeCell},
     cmp,
     sync::atomic::{self, AtomicU32},
 };
@@ -29,10 +29,11 @@ use lazy_static::lazy_static;
 use crate::kernel::vfs::filearray::FileArray;
 
 use super::{
-    signal::{RaiseResult, Signal, SignalList},
-    KernelStack, Scheduler,
+    kstack, signal::{RaiseResult, Signal, SignalList}, KernelStack, Scheduler
 };
 
+use arch::{TaskContext, InterruptContext};
+
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub enum ThreadState {
     Preparing,
@@ -212,13 +213,16 @@ pub struct Thread {
     /// Thread state for scheduler use.
     pub state: Spin<ThreadState>,
 
+    /// Thread context
+    pub context: UnsafeCell<TaskContext>,
+
     /// Kernel stack
     /// Never access this directly.
     ///
     /// We can only touch kernel stack when the process is neither running nor sleeping.
     /// AKA, the process is in the ready queue and will return to `schedule` context.
     kstack: RefCell<KernelStack>,
-
+    
     inner: Spin<ThreadInner>,
 }
 
@@ -809,6 +813,7 @@ impl Thread {
             fs_context: FsContext::new_for_init(),
             signal_list: SignalList::new(),
             kstack: RefCell::new(KernelStack::new()),
+            context: UnsafeCell::new(TaskContext::new()),
             state: Spin::new(ThreadState::Preparing),
             inner: Spin::new(ThreadInner {
                 name,
@@ -838,6 +843,7 @@ impl Thread {
             fs_context: FsContext::new_cloned(&other.fs_context),
             signal_list,
             kstack: RefCell::new(KernelStack::new()),
+            context: UnsafeCell::new(TaskContext::new()),
             state: Spin::new(ThreadState::Preparing),
             inner: Spin::new(ThreadInner {
                 name: other_inner.name.clone(),
@@ -945,31 +951,34 @@ impl Thread {
         Ok(())
     }
 
-    /// This function is used to prepare the kernel stack for the thread in `Preparing` state.
-    ///
-    /// # Safety
-    /// Calling this function on a thread that is not in `Preparing` state will panic.
-    pub fn prepare_kernel_stack<F: FnOnce(&mut KernelStack)>(&self, func: F) {
+    pub fn fork_init(&self, interrupt_context: InterruptContext) {
         let mut state = self.state.lock();
-        assert!(matches!(*state, ThreadState::Preparing));
+        *state = ThreadState::USleep;
 
-        // SAFETY: We are in the preparing state with `state` locked.
-        func(&mut self.kstack.borrow_mut());
+        let sp = self.kstack.borrow().init(interrupt_context);
+        unsafe {
+            (&mut *self.get_context_mut_ptr()).init(fork_return as usize, sp);
+        }
+    }
 
-        // Enter USleep state. Await for the thread to be scheduled manually.
+    pub fn init(&self, entry: usize) {
+        let mut state = self.state.lock();
         *state = ThreadState::USleep;
+        unsafe {
+            (&mut *self.get_context_mut_ptr()).init(entry, self.get_kstack_bottom());
+        }
     }
 
     pub fn load_interrupt_stack(&self) {
         self.kstack.borrow().load_interrupt_stack();
     }
 
-    /// Get a pointer to `self.sp` so we can use it in `context_switch()`.
-    ///
-    /// # Safety
-    /// Save the pointer somewhere or pass it to a function that will use it is UB.
-    pub unsafe fn get_sp_ptr(&self) -> *mut usize {
-        self.kstack.borrow().get_sp_ptr()
+    pub fn get_kstack_bottom(&self) -> usize {
+        self.kstack.borrow().get_stack_bottom()
+    }
+
+    pub unsafe fn get_context_mut_ptr(&self) -> *mut TaskContext {
+        self.context.get() 
     }
 
     pub fn set_name(&self, name: Arc<[u8]>) {
@@ -981,6 +990,35 @@ impl Thread {
     }
 }
 
+#[naked]
+unsafe extern "C" fn fork_return() {
+    // We don't land on the typical `Scheduler::schedule()` function, so we need to
+    // manually enable preemption.
+    naked_asm! {
+        "
+        call {preempt_enable}
+        pop rax
+        pop rbx
+        pop rcx
+        pop rdx 
+        pop rdi
+        pop rsi
+        pop r8
+        pop r9
+        pop r10
+        pop r11
+        pop r12
+        pop r13
+        pop r14
+        pop r15
+        pop rbp
+        add rsp, 16
+        iretq
+        ",
+        preempt_enable = sym preempt::enable,
+    }
+}
+
 // TODO: Maybe we can find a better way instead of using `RefCell` for `KernelStack`?
 unsafe impl Sync for Thread {}
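Taken together with `kstack.rs` and `context.rs`, the fork path now reads as one sequence; a summary, hedged against details not shown in this diff (e.g. the exact `preempt` API):

    // 1. sys_fork clones the parent's InterruptContext, sets rax = 0 (the
    //    child's return value) and eflags = 0x200, and calls fork_init.
    // 2. fork_init: sp = kstack.init(new_int_stack) writes that context onto
    //    the child's kernel stack; context.init(fork_return, sp) plants a
    //    ContextSwitchFrame below it.
    // 3. On the first switch to the child, __switch_to pops the frame and
    //    `ret`s into fork_return.
    // 4. fork_return re-enables preemption, pops the full register image,
    //    skips int_no/error_code with `add rsp, 16`, and `iretq`s to user mode.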
 

+ 16 - 17
src/lib.rs

@@ -5,6 +5,8 @@
 #![feature(arbitrary_self_types)]
 #![feature(get_mut_unchecked)]
 #![feature(macro_metavar_expr)]
+#![feature(naked_functions)]
+
 extern crate alloc;
 
 #[allow(warnings)]
@@ -56,7 +58,7 @@ fn panic(info: &core::panic::PanicInfo) -> ! {
     println_fatal!();
     println_fatal!("{}", info.message());
 
-    arch::task::freeze()
+    arch::freeze()
 }
 
 extern "C" {
@@ -124,30 +126,29 @@ pub extern "C" fn rust_kinit(early_kstack_pfn: usize) -> ! {
     // We need root dentry to be present in constructor of `FsContext`.
     // So call `init_vfs` first, then `init_multitasking`.
     init_multitasking();
-    Thread::current().prepare_kernel_stack(|kstack| {
-        let mut writer = kstack.get_writer();
-        writer.entry = to_init_process;
-        writer.flags = 0x200;
-        writer.rbp = 0;
-        writer.rbx = early_kstack_pfn; // `to_init_process` arg
-        writer.finish();
-    });
 
+    Thread::current().init(init_process as usize);
     // To satisfy the `Scheduler` "preempt count == 0" assertion.
     preempt::disable();
 
     Scheduler::get().lock().uwake(Thread::current());
 
-    arch::task::context_switch_light(
-        CachedPP::new(early_kstack_pfn).as_ptr(), // We will never come back
-        unsafe { Scheduler::idle_task().get_sp_ptr() },
-    );
-    arch::task::freeze()
+    let mut unused_ctx = arch::TaskContext::new();
+    unused_ctx.init(to_init_process as usize, early_kstack_pfn + 0xffffff0000000000);
+    unsafe {
+        arch::TaskContext::switch_to(
+            &mut unused_ctx, // We will never come back
+            &mut *Scheduler::idle_task().get_context_mut_ptr()
+        );
+    }
+
+    arch::freeze()
 }
 
 /// We enter this function with `preempt count == 0`
 extern "C" fn init_process(early_kstack_pfn: usize) {
-    unsafe { Page::take_pfn(early_kstack_pfn, 9) };
+    // TODO: pass `early_kstack_pfn` in and free it here.
+    // unsafe { Page::take_pfn(early_kstack_pfn, 9) };
     preempt::enable();
 
     kernel::timer::init().unwrap();
@@ -208,8 +209,6 @@ extern "C" fn init_process(early_kstack_pfn: usize) {
 
     unsafe {
         asm!(
-            "mov %ax, %fs",
-            "mov %ax, %gs",
             "mov ${ds}, %rax",
             "mov %ax, %ds",
             "mov %ax, %es",