rcu.rs

use crate::prelude::*;
use alloc::sync::Arc;
use core::{
    ops::Deref,
    ptr::NonNull,
    sync::atomic::{AtomicPtr, Ordering},
};
use eonix_runtime::task::Task;
use eonix_sync::{Mutex, RwLock, RwLockReadGuard};
use pointers::BorrowedArc;
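
/// Read-side critical section guard: holds the global RCU semaphore shared and
/// hands out the protected `value` via `Deref` until it is dropped.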
pub struct RCUReadGuard<'data, T: 'data> {
    value: T,
    _guard: RwLockReadGuard<'data, ()>,
    _phantom: PhantomData<&'data T>,
}
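
/// Global RCU semaphore. Readers hold it shared for the lifetime of an
/// `RCUReadGuard`; `rcu_sync()` takes it exclusively to wait for them.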
static GLOBAL_RCU_SEM: RwLock<()> = RwLock::new(());

impl<'data, T: 'data> RCUReadGuard<'data, T> {
    fn lock(value: T) -> Self {
        Self {
            value,
            _guard: Task::block_on(GLOBAL_RCU_SEM.read()),
            _phantom: PhantomData,
        }
    }
}

impl<'data, T: 'data> Deref for RCUReadGuard<'data, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.value
    }
}
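
/// Wait for a grace period: returns once every read-side critical section that
/// was entered through `RCUReadGuard::lock` before this call has ended.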
pub async fn rcu_sync() {
    // Lock the global RCU semaphore to ensure that all readers are done.
    let _ = GLOBAL_RCU_SEM.write().await;
}
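
/// Intrusive links for [`RCUList`]. Implementors embed the `prev`/`next`
/// pointers themselves and expose them through these accessors.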
pub trait RCUNode<MySelf> {
    fn rcu_prev(&self) -> &AtomicPtr<MySelf>;
    fn rcu_next(&self) -> &AtomicPtr<MySelf>;
}
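
// Illustrative sketch (not part of the original file): a hypothetical entry
// type embedding the intrusive links that `RCUNode` expects. `MyEntry` and its
// fields are assumptions for demonstration only.
//
// pub struct MyEntry {
//     pub value: usize,
//     prev: AtomicPtr<MyEntry>,
//     next: AtomicPtr<MyEntry>,
// }
//
// impl RCUNode<MyEntry> for MyEntry {
//     fn rcu_prev(&self) -> &AtomicPtr<MyEntry> {
//         &self.prev
//     }
//
//     fn rcu_next(&self) -> &AtomicPtr<MyEntry> {
//         &self.next
//     }
// }

/// Doubly linked, RCU-protected intrusive list. Writers serialize on
/// `update_lock`; `remove` and `replace` additionally take `reader_lock`
/// exclusively before clearing a detached node's links, so concurrent
/// iterators are finished with the node first.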
pub struct RCUList<T: RCUNode<T>> {
    head: AtomicPtr<T>,
    reader_lock: RwLock<()>,
    update_lock: Mutex<()>,
}

impl<T: RCUNode<T>> RCUList<T> {
    pub const fn new() -> Self {
        Self {
            head: AtomicPtr::new(core::ptr::null_mut()),
            reader_lock: RwLock::new(()),
            update_lock: Mutex::new(()),
        }
    }
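
    /// Insert `new_node` at the head of the list; the list keeps strong
    /// references to it through the raw `head` and `rcu_prev` pointers.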
    pub fn insert(&self, new_node: Arc<T>) {
        let _lck = self.update_lock.lock();

        let old_head = self.head.load(Ordering::Acquire);
        new_node
            .rcu_prev()
            .store(core::ptr::null_mut(), Ordering::Release);
        new_node.rcu_next().store(old_head, Ordering::Release);

        if let Some(old_head) = unsafe { old_head.as_ref() } {
            old_head
                .rcu_prev()
                .store(Arc::into_raw(new_node.clone()) as *mut _, Ordering::Release);
        }

        self.head
            .store(Arc::into_raw(new_node) as *mut _, Ordering::Release);
    }
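
    /// Unlink `node` from the list, dropping the strong references the list
    /// held, then wait for concurrent readers before clearing the node's links.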
    pub fn remove(&self, node: &Arc<T>) {
        let _lck = self.update_lock.lock();

        let prev = node.rcu_prev().load(Ordering::Acquire);
        let next = node.rcu_next().load(Ordering::Acquire);

        if let Some(next) = unsafe { next.as_ref() } {
            let me = next.rcu_prev().swap(prev, Ordering::AcqRel);
            debug_assert!(me == Arc::as_ptr(node) as *mut _);
            unsafe { Arc::from_raw(me) };
        }

        {
            let prev_next =
                unsafe { prev.as_ref().map(|rcu| rcu.rcu_next()) }.unwrap_or(&self.head);

            let me = prev_next.swap(next, Ordering::AcqRel);
            debug_assert!(me == Arc::as_ptr(node) as *mut _);
            unsafe { Arc::from_raw(me) };
        }

        // Wait for all current readers to finish before clearing the links of
        // the now-unreachable node.
        let _reader_lck = Task::block_on(self.reader_lock.write());
        node.rcu_prev()
            .store(core::ptr::null_mut(), Ordering::Release);
        node.rcu_next()
            .store(core::ptr::null_mut(), Ordering::Release);
    }
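
    /// Substitute `new_node` for `old_node` in place, then wait for concurrent
    /// readers before clearing `old_node`'s links.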
    pub fn replace(&self, old_node: &Arc<T>, new_node: Arc<T>) {
        let _lck = self.update_lock.lock();

        let prev = old_node.rcu_prev().load(Ordering::Acquire);
        let next = old_node.rcu_next().load(Ordering::Acquire);

        new_node.rcu_prev().store(prev, Ordering::Release);
        new_node.rcu_next().store(next, Ordering::Release);

        {
            let prev_next =
                unsafe { prev.as_ref().map(|rcu| rcu.rcu_next()) }.unwrap_or(&self.head);

            let old = prev_next.swap(Arc::into_raw(new_node.clone()) as *mut _, Ordering::AcqRel);
            debug_assert!(old == Arc::as_ptr(old_node) as *mut _);
            unsafe { Arc::from_raw(old) };
        }

        if let Some(next) = unsafe { next.as_ref() } {
            let old = next
                .rcu_prev()
                .swap(Arc::into_raw(new_node.clone()) as *mut _, Ordering::AcqRel);
            debug_assert!(old == Arc::as_ptr(old_node) as *mut _);
            unsafe { Arc::from_raw(old) };
        }

        // Wait for all current readers to finish before clearing the links of
        // the replaced node.
        let _reader_lck = Task::block_on(self.reader_lock.write());
        old_node
            .rcu_prev()
            .store(core::ptr::null_mut(), Ordering::Release);
        old_node
            .rcu_next()
            .store(core::ptr::null_mut(), Ordering::Release);
    }
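
    /// Lock the list for reading and return an iterator over the nodes that
    /// are currently linked.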
    pub fn iter(&self) -> RCUIterator<T> {
        let _lck = Task::block_on(self.reader_lock.read());

        RCUIterator {
            // SAFETY: We have a read lock, so the node is still alive.
            cur: NonNull::new(self.head.load(Ordering::SeqCst)),
            _lock: _lck,
        }
    }
}
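
/// Iterator over an [`RCUList`]; it holds the list's reader lock for as long
/// as it is alive.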
pub struct RCUIterator<'lt, T: RCUNode<T>> {
    cur: Option<NonNull<T>>,
    _lock: RwLockReadGuard<'lt, ()>,
}

impl<'lt, T: RCUNode<T>> Iterator for RCUIterator<'lt, T> {
    type Item = BorrowedArc<'lt, T>;

    fn next(&mut self) -> Option<Self::Item> {
        match self.cur {
            None => None,
            Some(pointer) => {
                // SAFETY: We have a read lock, so the node is still alive.
                let reference = unsafe { pointer.as_ref() };
                self.cur = NonNull::new(reference.rcu_next().load(Ordering::SeqCst));
                Some(unsafe { BorrowedArc::from_raw(pointer) })
            }
        }
    }
}
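
// Illustrative sketch (not part of the original file): typical traversal,
// assuming a `list: RCUList<MyEntry>` for the hypothetical `MyEntry` above.
//
// for entry in list.iter() {
//     // `entry` is a `BorrowedArc<MyEntry>` that stays valid while the
//     // iterator's read lock is held.
//     do_something(&entry);
// }

/// An RCU-protected `Arc` slot: readers borrow the current value under the
/// global RCU semaphore, writers swap in a new `Arc` and free the old one only
/// after a grace period.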
pub struct RCUPointer<T>(AtomicPtr<T>);

impl<T: core::fmt::Debug> core::fmt::Debug for RCUPointer<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match NonNull::new(self.0.load(Ordering::Acquire)) {
            Some(pointer) => {
                let borrowed = unsafe { BorrowedArc::from_raw(pointer) };
                f.write_str("RCUPointer of ")?;
                borrowed.fmt(f)
            }
            None => f.debug_tuple("NULL RCUPointer").finish(),
        }
    }
}

impl<T> RCUPointer<T> {
    pub fn empty() -> Self {
        Self(AtomicPtr::new(core::ptr::null_mut()))
    }
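
    /// Enter a read-side critical section and borrow the current value, if
    /// any. The guard keeps the global RCU semaphore held shared, so a
    /// concurrent `rcu_sync()` will wait for it to be dropped.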
    pub fn load<'lt>(&self) -> Option<RCUReadGuard<'lt, BorrowedArc<'lt, T>>> {
        NonNull::new(self.0.load(Ordering::Acquire))
            .map(|p| RCUReadGuard::lock(unsafe { BorrowedArc::from_raw(p) }))
    }

    /// # Safety
    /// Caller must ensure no writers are updating the pointer.
    pub unsafe fn load_locked<'lt>(&self) -> Option<BorrowedArc<'lt, T>> {
        NonNull::new(self.0.load(Ordering::Acquire)).map(|p| unsafe { BorrowedArc::from_raw(p) })
    }

    /// # Safety
    /// Caller must ensure that the actual pointer is freed after all readers are done.
    pub unsafe fn swap(&self, new: Option<Arc<T>>) -> Option<Arc<T>> {
        let new = new
            .map(|arc| Arc::into_raw(arc) as *mut T)
            .unwrap_or(core::ptr::null_mut());
        let old = self.0.swap(new, Ordering::AcqRel);

        if old.is_null() {
            None
        } else {
            Some(unsafe { Arc::from_raw(old) })
        }
    }
}
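
// Illustrative sketch (not part of the original file): the writer-side pattern
// for `RCUPointer`, assuming a hypothetical `ptr: RCUPointer<Config>` updated
// from an async context.
//
// async fn publish(ptr: &RCUPointer<Config>, new_cfg: Arc<Config>) {
//     // SAFETY: we wait for a grace period below before dropping the old value.
//     let old = unsafe { ptr.swap(Some(new_cfg)) };
//     if old.is_some() {
//         // Make sure every reader that may still borrow the old value is done.
//         rcu_sync().await;
//     }
//     // `old` is dropped here, after the grace period.
// }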

impl<T> Drop for RCUPointer<T> {
    fn drop(&mut self) {
        // SAFETY: We wait for a grace period below before the value can be freed.
        if let Some(arc) = unsafe { self.swap(None) } {
            // Only wait if we hold the last strong reference: dropping it will
            // free the value, so all readers must be done first.
            if Arc::strong_count(&arc) == 1 {
                Task::block_on(rcu_sync());
            }
        }
    }
}