Эх сурвалжийг харах

fix(lock): add lock_irq to avoid deadlocks

greatbridf 11 сар өмнө
parent
commit
bd0da59162

+ 2 - 1
CMakeLists.txt

@@ -38,6 +38,7 @@ set(BOOTLOADER_SOURCES src/boot.s
 set(KERNEL_MAIN_SOURCES src/fs/fat.cpp
                         src/kinit.cpp
                         src/kernel/async/waitlist.cc
+                        src/kernel/async/lock.cc
                         src/kernel/allocator.cc
                         src/kernel/interrupt.cpp
                         src/kernel/process.cpp
@@ -64,6 +65,7 @@ set(KERNEL_MAIN_SOURCES src/fs/fat.cpp
                         include/asm/sys.h
                         include/fs/fat.hpp
                         include/kernel/async/waitlist.hpp
+                        include/kernel/async/lock.hpp
                         include/kernel/tty.hpp
                         include/kernel/interrupt.h
                         include/kernel/irq.hpp
@@ -99,7 +101,6 @@ set(KERNEL_MAIN_SOURCES src/fs/fat.cpp
                         include/types/status.h
                         include/types/allocator.hpp
                         include/types/cplusplus.hpp
-                        include/types/lock.hpp
                         include/types/string.hpp
                         include/kernel/log.hpp
                         )

+ 63 - 0
include/kernel/async/lock.hpp

@@ -0,0 +1,63 @@
+#pragma once
+
+#include <stdint.h>
+
+namespace kernel::async {
+
+using spinlock_t = uint32_t volatile;
+using preempt_count_t = size_t;
+
+void preempt_disable();
+void preempt_enable();
+preempt_count_t preempt_count();
+
+void init_spinlock(spinlock_t& lock);
+
+void spin_lock(spinlock_t& lock);
+void spin_unlock(spinlock_t& lock);
+
+uint32_t spin_lock_irqsave(spinlock_t& lock);
+void spin_unlock_irqrestore(spinlock_t& lock, uint32_t state);
+
+class mutex {
+private:
+    spinlock_t m_lock;
+
+public:
+    constexpr mutex() : m_lock {0} { }
+    mutex(const mutex&) = delete;
+    ~mutex();
+
+    void lock();
+    void unlock();
+
+    uint32_t lock_irq();
+    void unlock_irq(uint32_t state);
+};
+
+class lock_guard {
+private:
+    mutex& m_mtx;
+
+public:
+    explicit inline lock_guard(mutex& mtx)
+        : m_mtx {mtx} { m_mtx.lock(); }
+    lock_guard(const lock_guard&) = delete;
+
+    inline ~lock_guard() { m_mtx.unlock(); }
+};
+
+class lock_guard_irq {
+private:
+    mutex& m_mtx;
+    uint32_t state;
+
+public:
+    explicit inline lock_guard_irq(mutex& mtx)
+        : m_mtx {mtx} { state = m_mtx.lock_irq(); }
+    lock_guard_irq(const lock_guard_irq&) = delete;
+
+    inline ~lock_guard_irq() { m_mtx.unlock_irq(state); }
+};
+
+} // namespace kernel::async

+ 3 - 4
include/kernel/async/waitlist.hpp

@@ -2,15 +2,14 @@
 
 #include <set>
 
-#include <types/lock.hpp>
-
 #include <kernel/task/forward.hpp>
+#include <kernel/async/lock.hpp>
 
 namespace kernel::async {
 
 class wait_list {
 private:
-    types::mutex m_mtx;
+    mutex m_mtx;
     std::set<task::thread*> m_subscribers;
 
     wait_list(const wait_list&) = delete;
@@ -19,7 +18,7 @@ public:
     explicit wait_list() = default;
 
     // @return whether the wait is interrupted
-    bool wait(types::mutex& lck);
+    bool wait(mutex& lck);
 
     void subscribe();
 

+ 1 - 2
include/kernel/process.hpp

@@ -21,7 +21,6 @@
 #include <types/path.hpp>
 #include <types/status.h>
 #include <types/types.h>
-#include <types/lock.hpp>
 
 #include <kernel/async/waitlist.hpp>
 #include <kernel/interrupt.h>
@@ -180,7 +179,7 @@ public:
     std::set<kernel::task::thread> thds;
     kernel::async::wait_list waitlist;
 
-    types::mutex mtx_waitprocs;
+    kernel::async::mutex mtx_waitprocs;
     std::list<wait_obj> waitprocs;
 
     process_attr attr {};

+ 0 - 2
include/kernel/task/readyqueue.hpp

@@ -2,8 +2,6 @@
 
 #include <list>
 
-#include <types/lock.hpp>
-
 #include <kernel/task/thread.hpp>
 
 namespace kernel::task::dispatcher {

+ 2 - 2
include/kernel/tty.hpp

@@ -7,9 +7,9 @@
 #include <types/allocator.hpp>
 #include <types/buffer.hpp>
 #include <types/cplusplus.hpp>
-#include <types/lock.hpp>
 
 #include <kernel/async/waitlist.hpp>
+#include <kernel/async/lock.hpp>
 
 class tty : public types::non_copyable {
 public:
@@ -55,7 +55,7 @@ public:
     termios termio;
 
 protected:
-    types::mutex mtx_buf;
+    kernel::async::mutex mtx_buf;
     types::buffer buf;
     kernel::async::wait_list waitlist;
 

+ 5 - 5
include/kernel/vfs/file.hpp

@@ -1,15 +1,15 @@
 #pragma once
 
-#include <kernel/async/waitlist.hpp>
-#include <kernel/vfs/dentry.hpp>
-
 #include <errno.h>
 #include <fcntl.h>
 #include <sys/types.h>
 
 #include <types/types.h>
 #include <types/buffer.hpp>
-#include <types/lock.hpp>
+
+#include <kernel/async/waitlist.hpp>
+#include <kernel/async/lock.hpp>
+#include <kernel/vfs/dentry.hpp>
 
 namespace fs {
 
@@ -22,7 +22,7 @@ private:
 private:
     types::buffer buf;
     kernel::async::wait_list waitlist;
-    types::mutex mtx;
+    kernel::async::mutex mtx;
     uint32_t flags;
 
 public:

+ 3 - 2
include/types/allocator.hpp

@@ -8,7 +8,8 @@
 #include <stdint.h>
 #include <types/cplusplus.hpp>
 #include <types/types.h>
-#include <types/lock.hpp>
+
+#include <kernel/async/lock.hpp>
 
 namespace kernel::kinit {
 
@@ -27,7 +28,7 @@ private:
     byte* p_start;
     byte* p_limit;
     byte* p_break;
-    types::mutex mtx;
+    kernel::async::mutex mtx;
 
     constexpr byte* brk(byte* addr)
     {

+ 0 - 67
include/types/lock.hpp

@@ -1,67 +0,0 @@
-#pragma once
-
-#include <stdint.h>
-
-inline void spin_lock(uint32_t volatile* lock_addr)
-{
-    asm volatile(
-        "%=:\n\t\
-         movl $1, %%eax\n\t\
-         xchgl %%eax, (%0)\n\t\
-         cmp $0, %%eax\n\t\
-         jne %=b\n\t\
-        "
-        :
-        : "r"(lock_addr)
-        : "eax", "memory");
-}
-
-inline void spin_unlock(uint32_t volatile* lock_addr)
-{
-    asm volatile(
-        "movl $0, %%eax\n\
-         xchgl %%eax, (%0)"
-        :
-        : "r"(lock_addr)
-        : "eax", "memory");
-}
-
-namespace types {
-
-struct mutex {
-    using mtx_t = volatile uint32_t;
-
-    mtx_t m_lock = 0;
-
-    inline void lock(void)
-    {
-        spin_lock(&m_lock);
-    }
-
-    inline void unlock(void)
-    {
-        spin_unlock(&m_lock);
-    }
-};
-
-class lock_guard {
-private:
-    mutex& m_mtx;
-
-public:
-    explicit lock_guard(mutex& mtx)
-        : m_mtx(mtx)
-    {
-        mtx.lock();
-    }
-
-    lock_guard(const lock_guard&) = delete;
-    lock_guard(lock_guard&&) = delete;
-
-    ~lock_guard()
-    {
-        m_mtx.unlock();
-    }
-};
-
-} // namespace types

+ 0 - 1
include/types/types.h

@@ -27,6 +27,5 @@
 #endif
 
 #ifdef __cplusplus
-#include <types/allocator.hpp>
 #include <types/cplusplus.hpp>
 #endif

+ 4 - 3
src/kernel/allocator.cc

@@ -1,5 +1,4 @@
 #include <types/allocator.hpp>
-#include <types/lock.hpp>
 
 #include <bit>
 #include <cstddef>
@@ -7,6 +6,8 @@
 #include <assert.h>
 #include <stdint.h>
 
+#include <kernel/async/lock.hpp>
+
 namespace types::memory {
 
 struct mem_blk_flags {
@@ -104,7 +105,7 @@ brk_memory_allocator::brk_memory_allocator(byte* start, size_type size)
 
 void* brk_memory_allocator::allocate(size_type size)
 {
-    types::lock_guard lck(mtx);
+    kernel::async::lock_guard_irq lck(mtx);
     // align to 8 bytes boundary
     size = (size + 7) & ~7;
 
@@ -138,7 +139,7 @@ void* brk_memory_allocator::allocate(size_type size)
 
 void brk_memory_allocator::deallocate(void* ptr)
 {
-    types::lock_guard lck(mtx);
+    kernel::async::lock_guard_irq lck(mtx);
     auto* blk = aspblk(aspbyte(ptr) - sizeof(mem_blk));
 
     blk->flags.is_free = 1;

+ 135 - 0
src/kernel/async/lock.cc

@@ -0,0 +1,135 @@
+#include <assert.h>
+#include <stdint.h>
+
+#include <kernel/async/lock.hpp>
+
+namespace kernel::async {
+
+static inline void _raw_spin_lock(spinlock_t* lock_addr)
+{
+    asm volatile(
+        "%=:\n\t\
+         movl $1, %%eax\n\t\
+         xchgl %%eax, (%0)\n\t\
+         cmp $0, %%eax\n\t\
+         jne %=b\n\t\
+        "
+        :
+        : "r"(lock_addr)
+        : "eax", "memory");
+}
+
+static inline void _raw_spin_unlock(spinlock_t* lock_addr)
+{
+    asm volatile(
+        "movl $0, %%eax\n\
+         xchgl %%eax, (%0)"
+        :
+        : "r"(lock_addr)
+        : "eax", "memory");
+}
+
+static inline uint32_t _save_interrupt_state()
+{
+    uint32_t retval;
+    asm volatile(
+        "pushfl\n\t"
+        "popl %0\n\t"
+        "cli"
+        : "=g"(retval)
+        :
+        :
+        );
+
+    return retval;
+}
+
+static inline void _restore_interrupt_state(uint32_t flags)
+{
+    asm volatile(
+        "pushl %0\n\t"
+        "popfl"
+        :
+        : "g"(flags)
+        :
+        );
+}
+
+// TODO: mark as _per_cpu
+static inline preempt_count_t& _preempt_count()
+{
+    static preempt_count_t _preempt_count;
+    assert(!(_preempt_count & 0x80000000));
+    return _preempt_count;
+}
+
+void preempt_disable()
+{
+    ++_preempt_count();
+}
+
+void preempt_enable()
+{
+    --_preempt_count();
+}
+
+preempt_count_t preempt_count()
+{
+    return _preempt_count();
+}
+
+void spin_lock(spinlock_t& lock)
+{
+    preempt_disable();
+    _raw_spin_lock(&lock);
+}
+
+void spin_unlock(spinlock_t& lock)
+{
+    _raw_spin_unlock(&lock);
+    preempt_enable();
+}
+
+uint32_t spin_lock_irqsave(spinlock_t& lock)
+{
+    auto state = _save_interrupt_state();
+    preempt_disable();
+
+    _raw_spin_lock(&lock);
+
+    return state;
+}
+
+void spin_unlock_irqrestore(spinlock_t& lock, uint32_t state)
+{
+    _raw_spin_unlock(&lock);
+    preempt_enable();
+    _restore_interrupt_state(state);
+}
+
+mutex::~mutex()
+{
+    assert(m_lock == 0);
+}
+
+void mutex::lock()
+{
+    spin_lock(m_lock);
+}
+
+void mutex::unlock()
+{
+    spin_unlock(m_lock);
+}
+
+uint32_t mutex::lock_irq()
+{
+    return spin_lock_irqsave(m_lock);
+}
+
+void mutex::unlock_irq(uint32_t state)
+{
+    spin_unlock_irqrestore(m_lock, state);
+}
+
+} // namespace kernel::async

+ 5 - 6
src/kernel/async/waitlist.cc

@@ -2,14 +2,13 @@
 
 #include <assert.h>
 
-#include <types/lock.hpp>
-
+#include <kernel/async/lock.hpp>
 #include <kernel/process.hpp>
 #include <kernel/task/thread.hpp>
 
 using namespace kernel::async;
 
-bool wait_list::wait(types::mutex& lock)
+bool wait_list::wait(mutex& lock)
 {
     this->subscribe();
 
@@ -26,7 +25,7 @@ bool wait_list::wait(types::mutex& lock)
 
 void wait_list::subscribe()
 {
-    types::lock_guard lck(m_mtx);
+    lock_guard lck(m_mtx);
 
     auto* thd = current_thread;
 
@@ -38,7 +37,7 @@ void wait_list::subscribe()
 
 void wait_list::notify_one()
 {
-    types::lock_guard lck(m_mtx);
+    lock_guard lck(m_mtx);
 
     if (m_subscribers.empty())
         return;
@@ -51,7 +50,7 @@ void wait_list::notify_one()
 
 void wait_list::notify_all()
 {
-    types::lock_guard lck(m_mtx);
+    lock_guard lck(m_mtx);
 
     if (m_subscribers.empty())
         return;

+ 7 - 5
src/kernel/interrupt.cpp

@@ -1,8 +1,14 @@
 #include <list>
 #include <vector>
 
-#include <asm/port_io.h>
 #include <assert.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include <types/size.h>
+#include <types/types.h>
+
+#include <asm/port_io.h>
 #include <kernel/hw/keyboard.h>
 #include <kernel/hw/serial.h>
 #include <kernel/hw/timer.h>
@@ -14,10 +20,6 @@
 #include <kernel/process.hpp>
 #include <kernel/vfs.hpp>
 #include <kernel/vga.hpp>
-#include <stdint.h>
-#include <stdio.h>
-#include <types/size.h>
-#include <types/types.h>
 
 struct IDT_entry {
     uint16_t offset_low;

+ 22 - 16
src/kernel/process.cpp

@@ -2,15 +2,22 @@
 #include <queue>
 #include <utility>
 
+#include <assert.h>
 #include <stdint.h>
 #include <stdio.h>
 #include <bits/alltypes.h>
 #include <sys/wait.h>
 
+#include <types/allocator.hpp>
+#include <types/bitmap.hpp>
+#include <types/cplusplus.hpp>
+#include <types/elf.hpp>
+#include <types/size.h>
+#include <types/status.h>
+#include <types/types.h>
+
 #include <asm/port_io.h>
 #include <asm/sys.h>
-#include <assert.h>
-#include <fs/fat.hpp>
 #include <kernel/interrupt.h>
 #include <kernel/log.hpp>
 #include <kernel/mem.h>
@@ -19,22 +26,18 @@
 #include <kernel/process.hpp>
 #include <kernel/signal.hpp>
 #include <kernel/vfs.hpp>
+#include <kernel/async/lock.hpp>
 #include <kernel/user/thread_local.hpp>
 #include <kernel/task/thread.hpp>
 #include <kernel/task/readyqueue.hpp>
 
-#include <types/allocator.hpp>
-#include <types/bitmap.hpp>
-#include <types/cplusplus.hpp>
-#include <types/elf.hpp>
-#include <types/lock.hpp>
-#include <types/size.h>
-#include <types/status.h>
-#include <types/types.h>
+using kernel::async::mutex;
+using kernel::async::lock_guard, kernel::async::lock_guard_irq;
 
 static void (*volatile kthreadd_new_thd_func)(void*);
 static void* volatile kthreadd_new_thd_data;
-static types::mutex kthreadd_mtx;
+
+static mutex kthreadd_mtx;
 
 namespace kernel {
 
@@ -240,7 +243,7 @@ void kernel_threadd_main(void)
             void* data = nullptr;
 
             if (1) {
-                types::lock_guard lck(kthreadd_mtx);
+                lock_guard lck(kthreadd_mtx);
 
                 if (kthreadd_new_thd_func) {
                     func = std::exchange(kthreadd_new_thd_func, nullptr);
@@ -357,10 +360,10 @@ void proclist::kill(pid_t pid, int exit_code)
 
     bool flag = false;
     if (1) {
-        types::lock_guard lck(init.mtx_waitprocs);
+        lock_guard_irq lck(init.mtx_waitprocs);
 
         if (1) {
-            types::lock_guard lck(proc.mtx_waitprocs);
+            lock_guard_irq lck(proc.mtx_waitprocs);
 
             for (const auto& item : proc.waitprocs) {
                 if (WIFSTOPPED(item.code) || WIFCONTINUED(item.code))
@@ -378,7 +381,7 @@ void proclist::kill(pid_t pid, int exit_code)
         init.waitlist.notify_all();
 
     if (1) {
-        types::lock_guard lck(parent.mtx_waitprocs);
+        lock_guard_irq lck(parent.mtx_waitprocs);
         parent.waitprocs.push_back({ pid, exit_code });
     }
 
@@ -495,7 +498,7 @@ void NORETURN _kernel_init(void)
 
 void k_new_thread(void (*func)(void*), void* data)
 {
-    types::lock_guard lck(kthreadd_mtx);
+    lock_guard lck(kthreadd_mtx);
     kthreadd_new_thd_func = func;
     kthreadd_new_thd_data = data;
 }
@@ -537,6 +540,9 @@ void NORETURN init_scheduler(void)
 extern "C" void asm_ctx_switch(uint32_t** curr_esp, uint32_t** next_esp);
 bool schedule()
 {
+    if (kernel::async::preempt_count() != 0)
+        return true;
+
     auto* next_thd = kernel::task::dispatcher::next();
     process* proc = nullptr;
     kernel::task::thread* curr_thd = nullptr;

+ 2 - 2
src/kernel/syscall.cpp

@@ -20,6 +20,7 @@
 #include <sys/utsname.h>
 #include <sys/wait.h>
 
+#include <kernel/async/lock.hpp>
 #include <kernel/user/thread_local.hpp>
 #include <kernel/task/readyqueue.hpp>
 #include <kernel/task/thread.hpp>
@@ -38,7 +39,6 @@
 #include <types/allocator.hpp>
 #include <types/elf.hpp>
 #include <types/path.hpp>
-#include <types/lock.hpp>
 #include <types/status.h>
 #include <types/string.hpp>
 #include <types/types.h>
@@ -226,7 +226,7 @@ int _syscall_waitpid(interrupt_stack* data)
         return -EINVAL;
 
     auto& cv = current_process->waitlist;
-    types::lock_guard lck(current_process->mtx_waitprocs);
+    kernel::async::lock_guard lck(current_process->mtx_waitprocs);
 
     auto& waitlist = current_process->waitprocs;
 

+ 6 - 6
src/kernel/task/readyqueue.cc

@@ -2,32 +2,32 @@
 
 #include <list>
 
-#include <types/lock.hpp>
-
+#include <kernel/async/lock.hpp>
 #include <kernel/task/thread.hpp>
 
 using namespace kernel::task;
+using kernel::async::mutex, kernel::async::lock_guard_irq;
 
-static types::mutex dispatcher_mtx;
+static mutex dispatcher_mtx;
 static std::list<thread*> dispatcher_thds;
 
 void dispatcher::enqueue(thread* thd)
 {
-    types::lock_guard lck(dispatcher_mtx);
+    lock_guard_irq lck(dispatcher_mtx);
 
     dispatcher_thds.push_back(thd);
 }
 
 void dispatcher::dequeue(thread* thd)
 {
-    types::lock_guard lck(dispatcher_mtx);
+    lock_guard_irq lck(dispatcher_mtx);
 
     dispatcher_thds.remove(thd);
 }
 
 thread* dispatcher::next()
 {
-    types::lock_guard lck(dispatcher_mtx);
+    lock_guard_irq lck(dispatcher_mtx);
 
     auto* retval = dispatcher_thds.front();
 

+ 4 - 4
src/kernel/task/thread.cc

@@ -2,11 +2,10 @@
 
 #include <queue>
 
-#include <types/lock.hpp>
-
 #include <kernel/log.hpp>
 #include <kernel/mm.hpp>
 #include <kernel/signal.hpp>
+#include <kernel/async/lock.hpp>
 #include <kernel/task/readyqueue.hpp>
 
 using namespace kernel::task;
@@ -37,12 +36,12 @@ bool thread::operator==(const thread& rhs) const
 }
 
 static std::priority_queue<std::byte*> s_kstacks;
+static kernel::async::mutex s_mtx_kstacks;
 
 thread::kernel_stack::kernel_stack()
 {
     static int allocated;
-    static types::mutex mtx;
-    types::lock_guard lck(mtx);
+    kernel::async::lock_guard_irq lck(s_mtx_kstacks);
 
     if (!s_kstacks.empty()) {
         stack_base = s_kstacks.top();
@@ -84,6 +83,7 @@ thread::kernel_stack::kernel_stack(kernel_stack&& other)
 
 thread::kernel_stack::~kernel_stack()
 {
+    kernel::async::lock_guard_irq lck(s_mtx_kstacks);
     s_kstacks.push(stack_base);
 }
 

+ 3 - 4
src/kernel/tty.cpp

@@ -4,8 +4,7 @@
 #include <stdio.h>
 #include <termios.h>
 
-#include <types/lock.hpp>
-
+#include <kernel/async/lock.hpp>
 #include <kernel/hw/serial.h>
 #include <kernel/process.hpp>
 #include <kernel/tty.hpp>
@@ -56,7 +55,7 @@ void tty::print(const char* str)
 
 int tty::poll()
 {
-    types::lock_guard lck(this->mtx_buf);
+    kernel::async::lock_guard lck(this->mtx_buf);
     if (this->buf.empty()) {
         bool interrupted = this->waitlist.wait(this->mtx_buf);
 
@@ -76,7 +75,7 @@ size_t tty::read(char* buf, size_t buf_size, size_t n)
         if (n == 0)
             break;
 
-        types::lock_guard lck(this->mtx_buf);
+        kernel::async::lock_guard lck(this->mtx_buf);
 
         if (this->buf.empty()) {
             bool interrupted = this->waitlist.wait(this->mtx_buf);

+ 4 - 4
src/kernel/vfs.cpp

@@ -781,7 +781,7 @@ fs::pipe::pipe(void)
 void fs::pipe::close_read(void)
 {
     if (1) {
-        types::lock_guard lck(mtx);
+        kernel::async::lock_guard lck(mtx);
         flags &= (~READABLE);
     }
     waitlist.notify_all();
@@ -790,7 +790,7 @@ void fs::pipe::close_read(void)
 void fs::pipe::close_write(void)
 {
     if (1) {
-        types::lock_guard lck(mtx);
+        kernel::async::lock_guard lck(mtx);
         flags &= (~WRITABLE);
     }
     waitlist.notify_all();
@@ -801,7 +801,7 @@ int fs::pipe::write(const char* buf, size_t n)
     // TODO: check privilege
     // TODO: check EPIPE
     if (1) {
-        types::lock_guard lck(mtx);
+        kernel::async::lock_guard lck(mtx);
 
         if (!is_readable()) {
             current_thread->send_signal(SIGPIPE);
@@ -831,7 +831,7 @@ int fs::pipe::read(char* buf, size_t n)
 {
     // TODO: check privilege
     if (1) {
-        types::lock_guard lck(mtx);
+        kernel::async::lock_guard lck(mtx);
 
         if (!is_writeable()) {
             size_t orig_n = n;