rcu.rs

use crate::{kernel::task::block_on, prelude::*};
use alloc::sync::Arc;
use core::{
    marker::PhantomData,
    ops::Deref,
    ptr::NonNull,
    sync::atomic::{AtomicPtr, Ordering},
};
use eonix_runtime::scheduler::RUNTIME;
use eonix_sync::{Mutex, RwLock, RwLockReadGuard};
use pointers::BorrowedArc;
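
/// A read-side RCU guard.
///
/// Holding the guard keeps the global RCU lock shared, so writers calling
/// [`rcu_sync`] wait until the guard is dropped before reclaiming memory.
/// Dereferences to the protected value.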
pub struct RCUReadGuard<'data, T: 'data> {
    value: T,
    _guard: RwLockReadGuard<'data, ()>,
    _phantom: PhantomData<&'data T>,
}
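
/// Global sleepable RCU lock: readers hold it shared for the lifetime of an
/// [`RCUReadGuard`]; [`rcu_sync`] acquires it exclusively to wait for all
/// existing readers to finish.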
static GLOBAL_RCU_SEM: RwLock<()> = RwLock::new(());

impl<'data, T> RCUReadGuard<'data, BorrowedArc<'data, T>> {
    fn lock(value: BorrowedArc<'data, T>) -> Self {
        Self {
            value,
            _guard: block_on(GLOBAL_RCU_SEM.read()),
            _phantom: PhantomData,
        }
    }

    pub fn borrow(&self) -> BorrowedArc<'data, T> {
        // SAFETY: The value stays alive for as long as this guard (and thus
        // the RCU read lock) is held, so re-borrowing the same pointer is sound.
        unsafe {
            BorrowedArc::from_raw(NonNull::new_unchecked(
                &raw const *self.value.borrow() as *mut T,
            ))
        }
    }
}

impl<'data, T: 'data> Deref for RCUReadGuard<'data, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.value
    }
}
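
/// Wait for a grace period: returns once every RCU read-side critical section
/// that was active when this call started has completed.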
pub async fn rcu_sync() {
    // Lock the global RCU semaphore to ensure that all readers are done.
    let _ = GLOBAL_RCU_SEM.write().await;
}
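
/// Schedule `func` to run after a grace period, i.e. once all current readers
/// are done. Typically used to free memory that readers may still be referencing.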
pub fn call_rcu(func: impl FnOnce() + Send + 'static) {
    RUNTIME.spawn(async move {
        // Wait for all readers to finish.
        rcu_sync().await;
        func();
    });
}
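
/// Intrusive node interface for [`RCUList`]: implementors expose the atomic
/// `prev`/`next` link fields that the list manipulates.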
pub trait RCUNode<MySelf> {
    fn rcu_prev(&self) -> &AtomicPtr<MySelf>;
    fn rcu_next(&self) -> &AtomicPtr<MySelf>;
}
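
/// An intrusive, RCU-protected doubly linked list.
///
/// Writers serialize on `update_lock`; `reader_lock` is taken exclusively only
/// when a node's links must be reset while no reader may still observe it.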
pub struct RCUList<T: RCUNode<T>> {
    head: AtomicPtr<T>,
    reader_lock: RwLock<()>,
    update_lock: Mutex<()>,
}

impl<T: RCUNode<T>> RCUList<T> {
    pub const fn new() -> Self {
        Self {
            head: AtomicPtr::new(core::ptr::null_mut()),
            reader_lock: RwLock::new(()),
            update_lock: Mutex::new(()),
        }
    }
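
    /// Insert `new_node` at the head of the list.
    ///
    /// The list takes one strong reference to the node for each link pointing
    /// at it: from `head` and, if present, from the old head's `prev`.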
    pub fn insert(&self, new_node: Arc<T>) {
        let _lck = self.update_lock.lock();
        let old_head = self.head.load(Ordering::Acquire);

        new_node
            .rcu_prev()
            .store(core::ptr::null_mut(), Ordering::Release);
        new_node.rcu_next().store(old_head, Ordering::Release);

        if let Some(old_head) = unsafe { old_head.as_ref() } {
            old_head
                .rcu_prev()
                .store(Arc::into_raw(new_node.clone()) as *mut _, Ordering::Release);
        }

        self.head
            .store(Arc::into_raw(new_node) as *mut _, Ordering::Release);
    }
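
    /// Unlink `node` from the list and drop the list's references to it.
    ///
    /// The node's own links are cleared only after taking the reader lock
    /// exclusively, so no reader can still be traversing through it.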
    pub fn remove(&self, node: &Arc<T>) {
        let _lck = self.update_lock.lock();

        let prev = node.rcu_prev().load(Ordering::Acquire);
        let next = node.rcu_next().load(Ordering::Acquire);

        if let Some(next) = unsafe { next.as_ref() } {
            let me = next.rcu_prev().swap(prev, Ordering::AcqRel);
            debug_assert!(me == Arc::as_ptr(node) as *mut _);
            // Reclaim the strong reference previously held by `next.prev`.
            unsafe { Arc::from_raw(me) };
        }

        {
            let prev_next =
                unsafe { prev.as_ref().map(|rcu| rcu.rcu_next()) }.unwrap_or(&self.head);
            let me = prev_next.swap(next, Ordering::AcqRel);
            debug_assert!(me == Arc::as_ptr(node) as *mut _);
            // Reclaim the strong reference previously held by `prev.next` (or `head`).
            unsafe { Arc::from_raw(me) };
        }

        let _lck = block_on(self.reader_lock.write());
        node.rcu_prev()
            .store(core::ptr::null_mut(), Ordering::Release);
        node.rcu_next()
            .store(core::ptr::null_mut(), Ordering::Release);
    }
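
    /// Atomically substitute `new_node` for `old_node` in the list.
    ///
    /// Readers already holding a reference to `old_node` keep seeing it; new
    /// traversals see `new_node`. `old_node`'s links are cleared under the
    /// exclusive reader lock, mirroring [`RCUList::remove`].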
    pub fn replace(&self, old_node: &Arc<T>, new_node: Arc<T>) {
        let _lck = self.update_lock.lock();

        let prev = old_node.rcu_prev().load(Ordering::Acquire);
        let next = old_node.rcu_next().load(Ordering::Acquire);

        new_node.rcu_prev().store(prev, Ordering::Release);
        new_node.rcu_next().store(next, Ordering::Release);

        {
            let prev_next =
                unsafe { prev.as_ref().map(|rcu| rcu.rcu_next()) }.unwrap_or(&self.head);
            let old = prev_next.swap(Arc::into_raw(new_node.clone()) as *mut _, Ordering::AcqRel);
            debug_assert!(old == Arc::as_ptr(old_node) as *mut _);
            // Reclaim the strong reference that `prev.next` (or `head`) held on `old_node`.
            unsafe { Arc::from_raw(old) };
        }

        if let Some(next) = unsafe { next.as_ref() } {
            let old = next
                .rcu_prev()
                .swap(Arc::into_raw(new_node.clone()) as *mut _, Ordering::AcqRel);
            debug_assert!(old == Arc::as_ptr(old_node) as *mut _);
            // Reclaim the strong reference that `next.prev` held on `old_node`.
            unsafe { Arc::from_raw(old) };
        }

        let _lck = block_on(self.reader_lock.write());
        old_node
            .rcu_prev()
            .store(core::ptr::null_mut(), Ordering::Release);
        old_node
            .rcu_next()
            .store(core::ptr::null_mut(), Ordering::Release);
    }
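
    /// Iterate over the list under the shared reader lock.
    ///
    /// The returned iterator holds the read guard, so the nodes it yields stay
    /// alive for the iterator's lifetime.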
    pub fn iter(&self) -> RCUIterator<T> {
        let _lck = block_on(self.reader_lock.read());

        RCUIterator {
            // SAFETY: We have a read lock, so the node is still alive.
            cur: NonNull::new(self.head.load(Ordering::SeqCst)),
            _lock: _lck,
        }
    }
}
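
/// Iterator over an [`RCUList`], holding the list's reader lock for as long as
/// it lives.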
pub struct RCUIterator<'lt, T: RCUNode<T>> {
    cur: Option<NonNull<T>>,
    _lock: RwLockReadGuard<'lt, ()>,
}

impl<'lt, T: RCUNode<T>> Iterator for RCUIterator<'lt, T> {
    type Item = BorrowedArc<'lt, T>;

    fn next(&mut self) -> Option<Self::Item> {
        match self.cur {
            None => None,
            Some(pointer) => {
                // SAFETY: We have a read lock, so the node is still alive.
                let reference = unsafe { pointer.as_ref() };
                self.cur = NonNull::new(reference.rcu_next().load(Ordering::SeqCst));
                Some(unsafe { BorrowedArc::from_raw(pointer) })
            }
        }
    }
}
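
/// An RCU-protected pointer to an `Arc<T>` (possibly null).
///
/// Readers use [`RCUPointer::load`] to obtain a guard that pins the value;
/// writers publish a new value with [`RCUPointer::swap`] and defer freeing the
/// old one with [`call_rcu`]. A minimal usage sketch (illustrative only;
/// `Config`, `new_config` and `use_config` are placeholders):
///
/// ```ignore
/// static CURRENT: RCUPointer<Config> = RCUPointer::empty();
///
/// // Reader: the value stays alive for the lifetime of the guard.
/// if let Some(config) = CURRENT.load() {
///     use_config(&config);
/// }
///
/// // Writer: publish a new value, free the old one after a grace period.
/// if let Some(old) = unsafe { CURRENT.swap(Some(Arc::new(new_config))) } {
///     call_rcu(move || drop(old));
/// }
/// ```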
pub struct RCUPointer<T>(AtomicPtr<T>)
where
    T: Send + Sync + 'static;

impl<T> core::fmt::Debug for RCUPointer<T>
where
    T: core::fmt::Debug,
    T: Send + Sync + 'static,
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match NonNull::new(self.0.load(Ordering::Acquire)) {
            Some(pointer) => {
                let borrowed = unsafe { BorrowedArc::from_raw(pointer) };
                f.write_str("RCUPointer of ")?;
                borrowed.fmt(f)
            }
            None => f.debug_tuple("NULL RCUPointer").finish(),
        }
    }
}

impl<T> RCUPointer<T>
where
    T: Send + Sync + 'static,
{
    pub const fn empty() -> Self {
        Self(AtomicPtr::new(core::ptr::null_mut()))
    }

    pub fn new(value: Arc<T>) -> Self {
        Self(AtomicPtr::new(Arc::into_raw(value) as *mut T))
    }
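
    /// Load the pointer for reading. The returned guard holds the global RCU
    /// read lock, so the value cannot be freed while the guard is alive.
    /// Returns `None` if the pointer is null.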
    pub fn load<'lt>(&self) -> Option<RCUReadGuard<'lt, BorrowedArc<'lt, T>>> {
        NonNull::new(self.0.load(Ordering::Acquire))
            .map(|p| RCUReadGuard::lock(unsafe { BorrowedArc::from_raw(p) }))
    }
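
    /// Load the pointer under an RCU read guard that the caller already holds,
    /// without taking the global read lock again.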
    pub fn load_protected<'a, U: 'a>(
        &self,
        _guard: &RCUReadGuard<'a, U>,
    ) -> Option<BorrowedArc<'a, T>> {
        NonNull::new(self.0.load(Ordering::Acquire)).map(|p| unsafe { BorrowedArc::from_raw(p) })
    }
    /// # Safety
    /// Caller must ensure no writers are updating the pointer.
    pub unsafe fn load_locked<'lt>(&self) -> Option<BorrowedArc<'lt, T>> {
        NonNull::new(self.0.load(Ordering::Acquire)).map(|p| unsafe { BorrowedArc::from_raw(p) })
    }

    /// # Safety
    /// Caller must ensure that the old value is freed only after all readers are done.
    pub unsafe fn swap(&self, new: Option<Arc<T>>) -> Option<Arc<T>> {
        let new = new
            .map(|arc| Arc::into_raw(arc) as *mut T)
            .unwrap_or(core::ptr::null_mut());
        let old = self.0.swap(new, Ordering::AcqRel);

        if old.is_null() {
            None
        } else {
            Some(unsafe { Arc::from_raw(old) })
        }
    }

    /// Exchange the values of the two pointers.
    ///
    /// # Safety
    /// Presence of readers is acceptable, but the caller must ensure that we
    /// are the only one **altering** the pointers.
    pub unsafe fn exchange(old: &Self, new: &Self) {
        let old_value = old.0.load(Ordering::Acquire);
        let new_value = new.0.swap(old_value, Ordering::AcqRel);
        old.0.store(new_value, Ordering::Release);
    }
}
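
// Dropping an `RCUPointer` defers the actual free through `call_rcu` when it
// holds the last strong reference, so in-flight readers never see freed memory.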
impl<T> Drop for RCUPointer<T>
where
    T: Send + Sync + 'static,
{
    fn drop(&mut self) {
        // SAFETY: The old value is not freed immediately; `call_rcu` waits for
        // all readers to finish before dropping it.
        if let Some(arc) = unsafe { self.swap(None) } {
            // We only need to defer the drop if we hold the last strong
            // reference; otherwise dropping `arc` here won't free the value.
            if Arc::strong_count(&arc) == 1 {
                call_rcu(move || drop(arc));
            }
        }
    }
}