rcu.rs

use crate::prelude::*;
use alloc::sync::Arc;
use core::{
    marker::PhantomData,
    ops::Deref,
    ptr::NonNull,
    sync::atomic::{AtomicPtr, Ordering},
};
use eonix_runtime::task::Task;
use eonix_sync::{Mutex, RwLock, RwLockReadGuard};
use pointers::BorrowedArc;
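
/// Read-side guard for an RCU-protected value.
///
/// Holding the guard keeps the global RCU read lock, so writers that call
/// `rcu_sync()` will wait until this guard is dropped before freeing memory.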
pub struct RCUReadGuard<'data, T: 'data> {
    value: T,
    _guard: RwLockReadGuard<'data, ()>,
    _phantom: PhantomData<&'data T>,
}
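
/// Global semaphore backing all RCU read-side critical sections.
/// Readers share it; `rcu_sync()` takes it exclusively to wait for them.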
static GLOBAL_RCU_SEM: RwLock<()> = RwLock::new(());

impl<'data, T> RCUReadGuard<'data, BorrowedArc<'data, T>> {
    fn lock(value: BorrowedArc<'data, T>) -> Self {
        Self {
            value,
            _guard: Task::block_on(GLOBAL_RCU_SEM.read()),
            _phantom: PhantomData,
        }
    }

    pub fn borrow(&self) -> BorrowedArc<'data, T> {
        // SAFETY: `self.value` borrows a live `Arc`, so the pointer is
        // non-null and stays valid for as long as this guard is held.
        unsafe {
            BorrowedArc::from_raw(NonNull::new_unchecked(
                &raw const *self.value.borrow() as *mut T,
            ))
        }
    }
}

impl<'data, T: 'data> Deref for RCUReadGuard<'data, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.value
    }
}
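
/// Wait for all current RCU readers to finish their critical sections.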
pub async fn rcu_sync() {
    // Taking (and immediately dropping) the writer side of the global RCU
    // semaphore ensures that all pre-existing readers are done.
    let _ = GLOBAL_RCU_SEM.write().await;
}
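
/// Intrusive link hooks that a node type must provide so it can live in an
/// [`RCUList`]. The two atomics hold raw `Arc` pointers to the neighboring
/// nodes and are manipulated by the list under its update lock.
///
/// A minimal sketch of a conforming node type (the `Node` name and `data`
/// field are illustrative, not part of this crate):
///
/// ```ignore
/// struct Node {
///     data: usize,
///     prev: AtomicPtr<Node>,
///     next: AtomicPtr<Node>,
/// }
///
/// impl RCUNode<Node> for Node {
///     fn rcu_prev(&self) -> &AtomicPtr<Node> {
///         &self.prev
///     }
///
///     fn rcu_next(&self) -> &AtomicPtr<Node> {
///         &self.next
///     }
/// }
/// ```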
pub trait RCUNode<MySelf> {
    fn rcu_prev(&self) -> &AtomicPtr<MySelf>;
    fn rcu_next(&self) -> &AtomicPtr<MySelf>;
}
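
/// An intrusive doubly linked list whose readers run lock-free under RCU.
/// Updaters serialize on `update_lock`; `reader_lock` is taken for write
/// only to fence out readers while a removed node's links are cleared.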
pub struct RCUList<T: RCUNode<T>> {
    head: AtomicPtr<T>,
    reader_lock: RwLock<()>,
    update_lock: Mutex<()>,
}

impl<T: RCUNode<T>> RCUList<T> {
    pub const fn new() -> Self {
        Self {
            head: AtomicPtr::new(core::ptr::null_mut()),
            reader_lock: RwLock::new(()),
            update_lock: Mutex::new(()),
        }
    }
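
    /// Insert `new_node` at the head of the list.
    ///
    /// The list takes ownership of one `Arc` reference per raw pointer it
    /// stores (via `Arc::into_raw`); `remove`/`replace` reclaim them later.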
    pub fn insert(&self, new_node: Arc<T>) {
        let _lck = self.update_lock.lock();
        let old_head = self.head.load(Ordering::Acquire);

        new_node
            .rcu_prev()
            .store(core::ptr::null_mut(), Ordering::Release);
        new_node.rcu_next().store(old_head, Ordering::Release);

        if let Some(old_head) = unsafe { old_head.as_ref() } {
            old_head
                .rcu_prev()
                .store(Arc::into_raw(new_node.clone()) as *mut _, Ordering::Release);
        }

        self.head
            .store(Arc::into_raw(new_node) as *mut _, Ordering::Release);
    }
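
    /// Unlink `node` from the list.
    ///
    /// The node is unlinked first; its own links are cleared only after all
    /// readers have been fenced out via a write acquisition of `reader_lock`.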
    pub fn remove(&self, node: &Arc<T>) {
        let _lck = self.update_lock.lock();
        let prev = node.rcu_prev().load(Ordering::Acquire);
        let next = node.rcu_next().load(Ordering::Acquire);

        if let Some(next) = unsafe { next.as_ref() } {
            let me = next.rcu_prev().swap(prev, Ordering::AcqRel);
            debug_assert!(me == Arc::as_ptr(node) as *mut _);
            // Reclaim the reference that `next`'s prev link held on us.
            unsafe { Arc::from_raw(me) };
        }

        {
            let prev_next =
                unsafe { prev.as_ref().map(|rcu| rcu.rcu_next()) }.unwrap_or(&self.head);
            let me = prev_next.swap(next, Ordering::AcqRel);
            debug_assert!(me == Arc::as_ptr(node) as *mut _);
            // Reclaim the reference that `prev`'s next link (or `head`) held on us.
            unsafe { Arc::from_raw(me) };
        }

        // Wait for readers to leave before clearing the node's own links.
        let _lck = Task::block_on(self.reader_lock.write());
        node.rcu_prev()
            .store(core::ptr::null_mut(), Ordering::Release);
        node.rcu_next()
            .store(core::ptr::null_mut(), Ordering::Release);
    }
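
    /// Substitute `old_node` with `new_node` in place, then clear
    /// `old_node`'s links once readers have been fenced out.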
    pub fn replace(&self, old_node: &Arc<T>, new_node: Arc<T>) {
        let _lck = self.update_lock.lock();
        let prev = old_node.rcu_prev().load(Ordering::Acquire);
        let next = old_node.rcu_next().load(Ordering::Acquire);

        new_node.rcu_prev().store(prev, Ordering::Release);
        new_node.rcu_next().store(next, Ordering::Release);

        {
            let prev_next =
                unsafe { prev.as_ref().map(|rcu| rcu.rcu_next()) }.unwrap_or(&self.head);
            let old = prev_next.swap(Arc::into_raw(new_node.clone()) as *mut _, Ordering::AcqRel);
            debug_assert!(old == Arc::as_ptr(old_node) as *mut _);
            unsafe { Arc::from_raw(old) };
        }

        if let Some(next) = unsafe { next.as_ref() } {
            let old = next
                .rcu_prev()
                .swap(Arc::into_raw(new_node.clone()) as *mut _, Ordering::AcqRel);
            debug_assert!(old == Arc::as_ptr(old_node) as *mut _);
            unsafe { Arc::from_raw(old) };
        }

        // Wait for readers to leave before clearing the old node's links.
        let _lck = Task::block_on(self.reader_lock.write());
        old_node
            .rcu_prev()
            .store(core::ptr::null_mut(), Ordering::Release);
        old_node
            .rcu_next()
            .store(core::ptr::null_mut(), Ordering::Release);
    }
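
    /// Iterate over the list under the read lock; nodes are yielded as
    /// [`BorrowedArc`]s that stay alive for the lifetime of the iterator.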
    pub fn iter(&self) -> RCUIterator<T> {
        let _lck = Task::block_on(self.reader_lock.read());

        RCUIterator {
            // SAFETY: We have a read lock, so the node is still alive.
            cur: NonNull::new(self.head.load(Ordering::SeqCst)),
            _lock: _lck,
        }
    }
}
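
/// Iterator over an [`RCUList`], holding the list's read lock so that nodes
/// cannot be freed out from under it.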
pub struct RCUIterator<'lt, T: RCUNode<T>> {
    cur: Option<NonNull<T>>,
    _lock: RwLockReadGuard<'lt, ()>,
}

impl<'lt, T: RCUNode<T>> Iterator for RCUIterator<'lt, T> {
    type Item = BorrowedArc<'lt, T>;

    fn next(&mut self) -> Option<Self::Item> {
        match self.cur {
            None => None,
            Some(pointer) => {
                // SAFETY: We have a read lock, so the node is still alive.
                let reference = unsafe { pointer.as_ref() };
                self.cur = NonNull::new(reference.rcu_next().load(Ordering::SeqCst));
                Some(unsafe { BorrowedArc::from_raw(pointer) })
            }
        }
    }
}
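
/// An RCU-protected `Arc` slot: readers get a [`BorrowedArc`] without
/// touching the reference count, while writers swap the pointer and defer
/// freeing the old value until readers are done.
///
/// A minimal usage sketch (the surrounding async context and `BorrowedArc`'s
/// `Deref` impl are assumed, not shown here):
///
/// ```ignore
/// let ptr = RCUPointer::new(Arc::new(42u32));
///
/// // Reader side: pin the value for the critical section.
/// if let Some(guard) = ptr.load() {
///     assert_eq!(**guard, 42);
/// }
///
/// // Writer side: publish a new value, then wait before dropping the old.
/// let old = unsafe { ptr.swap(Some(Arc::new(43u32))) };
/// rcu_sync().await;
/// drop(old);
/// ```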
pub struct RCUPointer<T>(AtomicPtr<T>);

impl<T: core::fmt::Debug> core::fmt::Debug for RCUPointer<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match NonNull::new(self.0.load(Ordering::Acquire)) {
            Some(pointer) => {
                // Note: borrows without entering a read-side critical
                // section; only suitable for debug output.
                let borrowed = unsafe { BorrowedArc::from_raw(pointer) };
                f.write_str("RCUPointer of ")?;
                borrowed.fmt(f)
            }
            None => f.debug_tuple("NULL RCUPointer").finish(),
        }
    }
}

impl<T> RCUPointer<T> {
    pub const fn empty() -> Self {
        Self(AtomicPtr::new(core::ptr::null_mut()))
    }

    pub fn new(value: Arc<T>) -> Self {
        Self(AtomicPtr::new(Arc::into_raw(value) as *mut T))
    }

    /// Load the pointer and enter a read-side critical section; the returned
    /// guard keeps the value alive until it is dropped.
    pub fn load<'lt>(&self) -> Option<RCUReadGuard<'lt, BorrowedArc<'lt, T>>> {
        NonNull::new(self.0.load(Ordering::Acquire))
            .map(|p| RCUReadGuard::lock(unsafe { BorrowedArc::from_raw(p) }))
    }

    /// Load the pointer under a read-side critical section that the caller
    /// already holds, borrowing its lifetime instead of locking again.
    pub fn load_protected<'a, U: 'a>(
        &self,
        _guard: &RCUReadGuard<'a, U>,
    ) -> Option<BorrowedArc<'a, T>> {
        NonNull::new(self.0.load(Ordering::Acquire)).map(|p| unsafe { BorrowedArc::from_raw(p) })
    }

    /// # Safety
    /// Caller must ensure no writers are updating the pointer.
    pub unsafe fn load_locked<'lt>(&self) -> Option<BorrowedArc<'lt, T>> {
        NonNull::new(self.0.load(Ordering::Acquire)).map(|p| unsafe { BorrowedArc::from_raw(p) })
    }

    /// Replace the stored pointer with `new`, returning the previous value.
    ///
    /// # Safety
    /// Caller must ensure that the returned `Arc`, if any, is not freed
    /// until all readers are done, e.g. by calling `rcu_sync()` first.
    pub unsafe fn swap(&self, new: Option<Arc<T>>) -> Option<Arc<T>> {
        let new = new
            .map(|arc| Arc::into_raw(arc) as *mut T)
            .unwrap_or(core::ptr::null_mut());
        let old = self.0.swap(new, Ordering::AcqRel);
        if old.is_null() {
            None
        } else {
            Some(unsafe { Arc::from_raw(old) })
        }
    }

    /// Exchange the values of the two pointers.
    ///
    /// # Safety
    /// Readers may be present, but the caller must ensure that we are the
    /// only one **altering** the pointers.
    pub unsafe fn exchange(old: &Self, new: &Self) {
        let old_value = old.0.load(Ordering::Acquire);
        let new_value = new.0.swap(old_value, Ordering::AcqRel);
        old.0.store(new_value, Ordering::Release);
    }
}
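
/// Dropping the pointer releases its stored reference, waiting for readers
/// first if that reference is the last one.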
impl<T> Drop for RCUPointer<T> {
    fn drop(&mut self) {
        // SAFETY: We wait for all readers via `rcu_sync()` before the value
        // can be freed.
        if let Some(arc) = unsafe { self.swap(None) } {
            // Wait only if we hold the last reference: dropping it frees the
            // value, which readers might still be using.
            if Arc::strong_count(&arc) == 1 {
                Task::block_on(rcu_sync());
            }
        }
    }
}