Prechádzať zdrojové kódy

wait_list: add `get_waker_function` used in `SignalList::signal_waker`

greatbridf 9 mesiacov pred
rodič
commit
3179e41a7c

+ 64 - 34
crates/eonix_sync/src/wait_list.rs

@@ -1,12 +1,12 @@
-mod prepare;
+mod wait_handle;
 mod wait_object;
 
 use crate::{LazyLock, Spin};
-use core::{fmt, sync::atomic::Ordering};
-use intrusive_collections::LinkedList;
-use wait_object::WaitObjectAdapter;
+use core::fmt;
+use intrusive_collections::{linked_list::CursorMut, LinkedList};
+use wait_object::{WaitObject, WaitObjectAdapter};
 
-pub use prepare::Prepare;
+pub use wait_handle::WaitHandle;
 
 pub struct WaitList {
     waiters: LazyLock<Spin<LinkedList<WaitObjectAdapter>>>,
@@ -26,20 +26,17 @@ impl WaitList {
     pub fn notify_one(&self) -> bool {
         let mut waiters = self.waiters.lock();
         let mut waiter = waiters.front_mut();
-        if let Some(waiter) = waiter.get() {
-            // SAFETY: `wait_object` is a valid reference to a `WaitObject` because we
-            //         won't drop the wait object until the waiting thread will be woken
-            //         up and make sure that it is not on the list.
-            waiter.woken_up.store(true, Ordering::Release);
 
-            if let Some(waker) = waiter.waker.lock().take() {
-                waker.wake();
+        if !waiter.is_null() {
+            unsafe {
+                // SAFETY: `waiter` is not null.
+                self.notify_waiter_unchecked(&mut waiter);
             }
-        }
 
-        // We need to remove the node from the list AFTER we've finished accessing it so
-        // the waiter knows when it is safe to release the wait object node.
-        waiter.remove().is_some()
+            true
+        } else {
+            false
+        }
     }
 
     pub fn notify_all(&self) -> usize {
@@ -48,31 +45,64 @@ impl WaitList {
         let mut count = 0;
 
         while !waiter.is_null() {
-            if let Some(waiter) = waiter.get() {
-                // SAFETY: `wait_object` is a valid reference to a `WaitObject` because we
-                //         won't drop the wait object until the waiting thread will be woken
-                //         up and make sure that it is not on the list.
-                waiter.woken_up.store(true, Ordering::Release);
-
-                if let Some(waker) = waiter.waker.lock().take() {
-                    waker.wake();
-                }
-            } else {
-                unreachable!("Invalid state.");
+            unsafe {
+                // SAFETY: `waiter` is not null.
+                self.notify_waiter_unchecked(&mut waiter);
             }
-
             count += 1;
-
-            // We need to remove the node from the list AFTER we've finished accessing it so
-            // the waiter knows when it is safe to release the wait object node.
-            waiter.remove();
         }
 
         count
     }
 
-    pub fn prepare_to_wait(&self) -> Prepare<'_> {
-        Prepare::new(self)
+    pub fn prepare_to_wait(&self) -> WaitHandle<'_> {
+        WaitHandle::new(self)
+    }
+}
+
+impl WaitList {
+    unsafe fn notify_waiter_unchecked(&self, waiter: &mut CursorMut<'_, WaitObjectAdapter>) {
+        let wait_object = unsafe {
+            // SAFETY: The caller guarantees that `waiter` should be `Some`.
+            //         `wait_object` is a valid reference to a `WaitObject` because we
+            //         won't drop the wait object until the waiting thread will be woken
+            //         up and make sure that it is not on the list.
+            waiter.get().unwrap_unchecked()
+        };
+
+        wait_object.set_woken_up();
+
+        if let Some(waker) = wait_object.take_waker() {
+            waker.wake();
+        }
+
+        // Acknowledge the wait object that we're done.
+        unsafe {
+            waiter.remove().unwrap_unchecked().clear_wait_list();
+        }
+    }
+
+    pub(self) fn notify_waiter(&self, wait_object: &WaitObject) {
+        let mut waiters = self.waiters.lock();
+        if !wait_object.on_list() {
+            return;
+        }
+
+        assert_eq!(
+            wait_object.wait_list(),
+            self,
+            "Wait object is not in the wait list."
+        );
+
+        let mut waiter = unsafe {
+            // SAFETY: `wait_object` is on the `waiters` list.
+            waiters.cursor_mut_from_ptr(wait_object)
+        };
+
+        unsafe {
+            // SAFETY: We got the cursor from a valid wait object, which can't be null.
+            self.notify_waiter_unchecked(&mut waiter);
+        }
     }
 }
 

+ 70 - 44
crates/eonix_sync/src/wait_list/prepare.rs → crates/eonix_sync/src/wait_list/wait_handle.rs

@@ -1,13 +1,13 @@
 use super::{wait_object::WaitObject, WaitList};
 use core::{
     cell::UnsafeCell,
+    hint::spin_loop,
     pin::Pin,
-    sync::atomic::Ordering,
     task::{Context, Poll, Waker},
 };
 use intrusive_collections::UnsafeRef;
 
-pub struct Prepare<'a> {
+pub struct WaitHandle<'a> {
     wait_list: &'a WaitList,
     wait_object: UnsafeCell<WaitObject>,
     state: State,
@@ -27,11 +27,14 @@ struct PrepareSplit<'a> {
     wait_object: Pin<&'a WaitObject>,
 }
 
-impl<'a> Prepare<'a> {
+// SAFETY: All access to `wait_object` is protected.
+unsafe impl Sync for WaitHandle<'_> {}
+
+impl<'a> WaitHandle<'a> {
     pub const fn new(wait_list: &'a WaitList) -> Self {
         Self {
             wait_list,
-            wait_object: UnsafeCell::new(WaitObject::new()),
+            wait_object: UnsafeCell::new(WaitObject::new(wait_list)),
             state: State::Init,
         }
     }
@@ -59,13 +62,19 @@ impl<'a> Prepare<'a> {
     }
 
     fn set_state(self: Pin<&mut Self>, state: State) {
-        // SAFETY: We only touch `state`, which is `Unpin`.
         unsafe {
+            // SAFETY: We only touch `state`, which is `Unpin`.
             let this = self.get_unchecked_mut();
             this.state = state;
         }
     }
 
+    fn wait_until_off_list(&self) {
+        while self.wait_object().on_list() {
+            spin_loop();
+        }
+    }
+
     /// # Returns
     /// Whether we've been woken up or not.
     fn do_add_to_wait_list(mut self: Pin<&mut Self>, waker: Option<&Waker>) -> bool {
@@ -91,8 +100,7 @@ impl<'a> Prepare<'a> {
                 waiters.push_back(wait_object_ref);
 
                 if let Some(waker) = waker.cloned() {
-                    let old_waker = wait_object.waker.lock().replace(waker);
-                    assert!(old_waker.is_none(), "Waker already set");
+                    wait_object.save_waker(waker);
                     *state = State::WakerSet;
                 } else {
                     *state = State::OnList;
@@ -103,22 +111,22 @@ impl<'a> Prepare<'a> {
             // We are already on the wait list, so we can just set the waker.
             State::OnList => {
                 // If we are already woken up, we can just return.
-                if wait_object.woken_up.load(Ordering::Acquire) {
+                if wait_object.woken_up() {
                     *state = State::WokenUp;
                     return true;
                 }
 
                 if let Some(waker) = waker {
                     // Lock the waker and check if it is already set.
-                    let mut waker_lock = wait_object.waker.lock();
-                    if wait_object.woken_up.load(Ordering::Acquire) {
+                    let waker_set = wait_object.save_waker_if_not_woken_up(&waker);
+
+                    if waker_set {
+                        *state = State::WakerSet;
+                    } else {
+                        // We are already woken up, so we can just return.
                         *state = State::WokenUp;
                         return true;
                     }
-
-                    let old_waker = waker_lock.replace(waker.clone());
-                    assert!(old_waker.is_none(), "Waker already set");
-                    *state = State::WakerSet;
                 }
 
                 return false;
@@ -130,32 +138,51 @@ impl<'a> Prepare<'a> {
     pub fn add_to_wait_list(self: Pin<&mut Self>) {
         self.do_add_to_wait_list(None);
     }
+
+    /// # Safety
+    /// The caller MUST guarantee that the last use of the returned function
+    /// is before `self` is dropped. Otherwise the value referred to in this
+    /// function will be dangling and will cause undefined behavior.
+    pub unsafe fn get_waker_function(self: Pin<&Self>) -> impl Fn() + Send + Sync + 'static {
+        let wait_list: &WaitList = unsafe {
+            // SAFETY: The caller guarantees that the last use of returned function
+            //         is before `self` is dropped.
+            &*(self.wait_list as *const _)
+        };
+
+        let wait_object = unsafe {
+            // SAFETY: The caller guarantees that the last use of returned function
+            //         is before `self` is dropped.
+            &*self.wait_object.get()
+        };
+
+        move || {
+            wait_list.notify_waiter(wait_object);
+        }
+    }
 }
 
-impl Future for Prepare<'_> {
+impl Future for WaitHandle<'_> {
     type Output = ();
 
     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
         match self.state {
             State::Init | State::OnList => {
                 if self.as_mut().do_add_to_wait_list(Some(cx.waker())) {
-                    // Make sure we're off the wait list.
-                    while self.wait_object().on_list() {}
+                    self.wait_until_off_list();
                     Poll::Ready(())
                 } else {
                     Poll::Pending
                 }
             }
             State::WakerSet => {
-                if !self.as_ref().wait_object().woken_up.load(Ordering::Acquire) {
+                if !self.as_ref().wait_object().woken_up() {
                     // If we read `woken_up == false`, we can guarantee that we have a spurious
                     // wakeup. In this case, we MUST be still on the wait list, so no more
                     // actions are required.
                     Poll::Pending
                 } else {
-                    // Make sure we're off the wait list.
-                    while self.wait_object().on_list() {}
-
+                    self.wait_until_off_list();
                     self.set_state(State::WokenUp);
                     Poll::Ready(())
                 }
@@ -165,30 +192,29 @@ impl Future for Prepare<'_> {
     }
 }
 
-impl Drop for Prepare<'_> {
+impl Drop for WaitHandle<'_> {
     fn drop(&mut self) {
-        match self.state {
-            State::Init | State::WokenUp => {}
-            State::OnList | State::WakerSet => {
-                let wait_object = self.wait_object();
-                if wait_object.woken_up.load(Ordering::Acquire) {
-                    // We've woken up by someone. It won't be long before they
-                    // remove us from the list. So spin until we are off the list.
-                    // And we're done.
-                    while wait_object.on_list() {}
-                } else {
-                    // Lock the list and try again.
-                    let mut waiters = self.wait_list.waiters.lock();
-
-                    if wait_object.on_list() {
-                        let mut cursor = unsafe {
-                            // SAFETY: The list is locked so no one could be polling nodes
-                            //         off while we are trying to remove it.
-                            waiters.cursor_mut_from_ptr(wait_object)
-                        };
-                        assert!(cursor.remove().is_some());
-                    }
-                }
+        if matches!(self.state, State::Init | State::WokenUp) {
+            return;
+        }
+
+        let wait_object = self.wait_object();
+        if wait_object.woken_up() {
+            // We've woken up by someone. It won't be long before they
+            // remove us from the list. So spin until we are off the list.
+            // And we're done.
+            self.wait_until_off_list();
+        } else {
+            // Lock the list and try again.
+            let mut waiters = self.wait_list.waiters.lock();
+
+            if wait_object.on_list() {
+                let mut cursor = unsafe {
+                    // SAFETY: The list is locked so no one could be polling nodes
+                    //         off while we are trying to remove it.
+                    waiters.cursor_mut_from_ptr(wait_object)
+                };
+                assert!(cursor.remove().is_some());
             }
         }
     }

+ 79 - 6
crates/eonix_sync/src/wait_list/wait_object.rs

@@ -1,5 +1,13 @@
+use super::WaitList;
 use crate::Spin;
-use core::{marker::PhantomPinned, pin::Pin, sync::atomic::AtomicBool, task::Waker};
+use core::{
+    cell::UnsafeCell,
+    marker::PhantomPinned,
+    pin::Pin,
+    ptr::null_mut,
+    sync::atomic::{AtomicBool, AtomicPtr, Ordering},
+    task::Waker,
+};
 use intrusive_collections::{intrusive_adapter, LinkedListAtomicLink, UnsafeRef};
 
 intrusive_adapter!(
@@ -8,23 +16,88 @@ intrusive_adapter!(
 );
 
 pub struct WaitObject {
-    pub(super) woken_up: AtomicBool,
-    pub(super) waker: Spin<Option<Waker>>,
+    woken_up: AtomicBool,
+    waker_lock: Spin<()>,
+    waker: UnsafeCell<Option<Waker>>,
+    wait_list: AtomicPtr<WaitList>,
     link: LinkedListAtomicLink,
     _pinned: PhantomPinned,
 }
 
+// SAFETY: `WaitObject` is `Sync` because we sync the `waker` access with a spinlock.
+unsafe impl Sync for WaitObject {}
+
 impl WaitObject {
-    pub const fn new() -> Self {
+    pub const fn new(wait_list: &WaitList) -> Self {
         Self {
             woken_up: AtomicBool::new(false),
-            waker: Spin::new(None),
+            waker_lock: Spin::new(()),
+            waker: UnsafeCell::new(None),
+            wait_list: AtomicPtr::new(wait_list as *const _ as *mut _),
             link: LinkedListAtomicLink::new(),
             _pinned: PhantomPinned,
         }
     }
 
+    pub fn save_waker(&self, waker: Waker) {
+        let _lock = self.waker_lock.lock();
+        unsafe {
+            // SAFETY: We're holding the waker lock.
+            let old_waker = (*self.waker.get()).replace(waker);
+            assert!(old_waker.is_none(), "Waker already set.");
+        }
+    }
+
+    /// Save the waker if the wait object was not woken up atomically.
+    ///
+    /// # Returns
+    /// Whether the waker was saved.
+    pub fn save_waker_if_not_woken_up(&self, waker: &Waker) -> bool {
+        let _lock = self.waker_lock.lock();
+        if self.woken_up() {
+            return false;
+        }
+
+        unsafe {
+            // SAFETY: We're holding the waker lock.
+            let old_waker = (*self.waker.get()).replace(waker.clone());
+            assert!(old_waker.is_none(), "Waker already set.");
+        }
+
+        true
+    }
+
+    pub fn take_waker(&self) -> Option<Waker> {
+        let _lock = self.waker_lock.lock();
+        unsafe {
+            // SAFETY: We're holding the waker lock.
+            self.waker.get().as_mut().unwrap().take()
+        }
+    }
+
+    /// Check whether someone had woken up the wait object.
+    ///
+    /// Does an `Acquire` operation.
+    pub fn woken_up(&self) -> bool {
+        self.woken_up.load(Ordering::Acquire)
+    }
+
+    /// Set the wait object as woken up.
+    ///
+    /// Does a `Release` operation.
+    pub fn set_woken_up(&self) {
+        self.woken_up.store(true, Ordering::Release);
+    }
+
+    pub fn wait_list(&self) -> *const WaitList {
+        self.wait_list.load(Ordering::Acquire)
+    }
+
+    pub fn clear_wait_list(&self) {
+        self.wait_list.store(null_mut(), Ordering::Release);
+    }
+
     pub fn on_list(&self) -> bool {
-        self.link.is_linked()
+        !self.wait_list.load(Ordering::Acquire).is_null()
     }
 }

+ 5 - 2
src/sync/condvar.rs

@@ -28,7 +28,7 @@ impl<const I: bool> CondVar<I> {
         self.wait_list.notify_all();
     }
 
-    /// Unlock the `guard`. Then wait until being waken up.
+    /// Unlock the `guard`. Then wait until being woken up.
     /// Return the relocked `guard`.
     pub async fn wait<G>(&self, guard: G) -> G
     where
@@ -38,7 +38,10 @@ impl<const I: bool> CondVar<I> {
         let mut wait_handle = pin!(self.wait_list.prepare_to_wait());
         wait_handle.as_mut().add_to_wait_list();
 
-        let interrupt_waker = pin!(|| {});
+        let interrupt_waker = pin!(unsafe {
+            // SAFETY: We won't use the waker after the wait_handle is dropped.
+            wait_handle.as_ref().get_waker_function()
+        });
 
         if I {
             // Prohibit the thread from being woken up by a signal.