task: fix infinite sleep in stackful tasks

Stackful tasks might be woken up after returning Poll::Pending but
before actually being put to sleep. Such a wakeup leaves no trace, so
we sleep forever: we end up on neither the wait list nor the ready
queue.
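
For illustration, a minimal sketch of the buggy loop (simplified from
the patched code below; wait_for_wakeups() here is a stub standing in
for the real sleep primitive):

    use core::future::Future;
    use core::pin::Pin;
    use core::task::{Context, Poll};

    // Stand-in for the kernel's sleep primitive.
    fn wait_for_wakeups() { /* puts the current task to sleep */ }

    fn run_to_completion<F: Future>(
        mut future: Pin<&mut F>,
        cx: &mut Context<'_>,
    ) -> F::Output {
        loop {
            match future.as_mut().poll(cx) {
                Poll::Ready(output) => break output,
                Poll::Pending => {
                    // The waker may fire right here: the task is on
                    // neither the wait list nor the ready queue yet,
                    // and the wakeup leaves no trace, so the sleep
                    // below may never end.
                    wait_for_wakeups();
                }
            }
        }
    }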

The fix is to record the wakeup in the stackful waker and to check
that record before putting ourselves to sleep via wait_for_wakeups().
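
A standalone sketch of that pattern, using std's park/unpark in place
of the runtime's wait list and ready queue (WakeSaver here is a
simplified analogue of the one added below):

    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::task::{Wake, Waker};
    use std::thread;

    struct WakeSaver {
        thread: thread::Thread,
        woken: AtomicBool,
    }

    impl Wake for WakeSaver {
        fn wake(self: Arc<Self>) {
            // Record the wakeup *before* notifying, so the polling
            // loop can observe it even if it has not gone to sleep.
            self.woken.store(true, Ordering::Release);
            self.thread.unpark();
        }
    }

    fn main() {
        let saver = Arc::new(WakeSaver {
            thread: thread::current(),
            woken: AtomicBool::new(false),
        });
        let _waker = Waker::from(saver.clone());

        // A wakeup racing ahead of the sleep.
        let remote = saver.clone();
        thread::spawn(move || remote.wake());

        // Check the saved flag before sleeping; without it, a wakeup
        // that fired in this window would be lost.
        while !saver.woken.swap(false, Ordering::Acquire) {
            thread::park(); // may return spuriously, hence the loop
        }
        println!("wakeup observed, no infinite sleep");
    }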

Also, implement Drop for RCUPointer by using call_rcu to drop the
underlying data. We must bound T: Send + Sync + 'static in order to
send the Arc to the runtime.
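
The bound exists because the deferred-drop closure is handed off to
run later, possibly on another CPU. A sketch with an assumed shape of
call_rcu (not the actual eonix signature):

    use std::sync::Arc;

    // Assumed: the callback is queued and runs after a grace period,
    // so it must be Send and own its captures ('static). It is invoked
    // immediately here just to keep the sketch runnable.
    fn call_rcu<F: FnOnce() + Send + 'static>(callback: F) {
        callback();
    }

    fn defer_drop<T: Send + Sync + 'static>(arc: Arc<T>) {
        // Moving the Arc into a Send + 'static closure is what forces
        // the bound on RCUPointer<T>: Arc<T> is Send only when
        // T: Send + Sync.
        call_rcu(move || drop(arc));
    }

    fn main() {
        defer_drop(Arc::new(vec![1, 2, 3]));
    }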

Signed-off-by: greatbridf <greatbridf@icloud.com>
greatbridf committed 5 months ago
parent commit 3fb4966118
2 changed files, 57 insertions(+), 12 deletions(-):
  src/kernel/task.rs  +39 −3
  src/rcu.rs          +18 −9

src/kernel/task.rs  +39 −3

@@ -10,7 +10,6 @@ mod signal;
 mod thread;
 
 pub use clone::{do_clone, CloneArgs, CloneFlags};
-use eonix_runtime::task::Task;
 pub use futex::{futex_wait, futex_wake, parse_futexop, FutexFlags, FutexOp, RobustListHead};
 pub use kernel_stack::KernelStack;
 pub use loader::ProgramLoader;
@@ -84,10 +83,14 @@ where
         interrupt::{default_fault_handler, default_irq_handler},
         timer::{should_reschedule, timer_interrupt},
     };
+    use alloc::sync::Arc;
+    use alloc::task::Wake;
     use core::cell::UnsafeCell;
     use core::future::Future;
     use core::pin::Pin;
     use core::ptr::NonNull;
+    use core::sync::atomic::AtomicBool;
+    use core::sync::atomic::Ordering;
     use core::task::Context;
     use core::task::Poll;
     use core::task::Waker;
@@ -97,6 +100,7 @@ where
     use eonix_hal::trap::TrapContext;
     use eonix_preempt::assert_preempt_enabled;
     use eonix_runtime::executor::Stack;
+    use eonix_runtime::task::Task;
     use thread::wait_for_wakeups;
 
     let stack = KernelStack::new();
@@ -105,18 +109,46 @@ where
     where
         F: Future,
     {
-        let waker = Waker::from(Task::current().clone());
+        struct WakeSaver {
+            task: Arc<Task>,
+            woken: AtomicBool,
+        }
+
+        impl Wake for WakeSaver {
+            fn wake_by_ref(self: &Arc<Self>) {
+                // SAFETY: If we read true below in the loop, we must have been
+                //         woken up and acquired our waker's work by the runtime.
+                self.woken.store(true, Ordering::Relaxed);
+                self.task.wake_by_ref();
+            }
+
+            fn wake(self: Arc<Self>) {
+                self.wake_by_ref();
+            }
+        }
+
+        let wake_saver = Arc::new(WakeSaver {
+            task: Task::current().clone(),
+            woken: AtomicBool::new(false),
+        });
+        let waker = Waker::from(wake_saver.clone());
         let mut cx = Context::from_waker(&waker);
 
         let output = loop {
             match future.as_mut().poll(&mut cx) {
                 Poll::Ready(output) => break output,
                 Poll::Pending => {
+                    assert_preempt_enabled!("Blocking in stackful futures is not allowed.");
+
                     if Task::current().is_ready() {
                         continue;
                     }
 
-                    assert_preempt_enabled!("Blocking in stackful futures is not allowed.");
+                    // SAFETY: The runtime must have ensured that we can see the
+                    //         work done by the waker.
+                    if wake_saver.woken.swap(false, Ordering::Relaxed) {
+                        continue;
+                    }
 
                     unsafe {
                         #[cfg(target_arch = "riscv64")]
@@ -129,6 +161,10 @@ where
             }
         };
 
+        drop(cx);
+        drop(waker);
+        drop(wake_saver);
+
         unsafe {
             output_ptr.write(Some(output));
         }

src/rcu.rs  +18 −9

@@ -194,9 +194,15 @@ impl<'lt, T: RCUNode<T>> Iterator for RCUIterator<'lt, T> {
     }
 }
 
-pub struct RCUPointer<T>(AtomicPtr<T>);
-
-impl<T: core::fmt::Debug> core::fmt::Debug for RCUPointer<T> {
+pub struct RCUPointer<T>(AtomicPtr<T>)
+where
+    T: Send + Sync + 'static;
+
+impl<T> core::fmt::Debug for RCUPointer<T>
+where
+    T: core::fmt::Debug,
+    T: Send + Sync + 'static,
+{
     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
         match NonNull::new(self.0.load(Ordering::Acquire)) {
             Some(pointer) => {
@@ -209,7 +215,10 @@ impl<T: core::fmt::Debug> core::fmt::Debug for RCUPointer<T> {
     }
 }
 
-impl<T> RCUPointer<T> {
+impl<T> RCUPointer<T>
+where
+    T: Send + Sync + 'static,
+{
     pub const fn empty() -> Self {
         Self(AtomicPtr::new(core::ptr::null_mut()))
     }
@@ -266,16 +275,16 @@ impl<T> RCUPointer<T> {
     }
 }
 
-impl<T> Drop for RCUPointer<T> {
+impl<T> Drop for RCUPointer<T>
+where
+    T: Send + Sync + 'static,
+{
     fn drop(&mut self) {
         // SAFETY: We defer the drop via `call_rcu()` until all readers are done.
         if let Some(arc) = unsafe { self.swap(None) } {
             // Only defer the drop when we hold the last reference; otherwise
             // another Arc keeps the data alive and no grace period is needed.
             if Arc::strong_count(&arc) == 1 {
-                call_rcu(move || {
-                    let _ = arc;
-                    todo!();
-                });
+                call_rcu(move || drop(arc));
             }
         }
     }