//! An asynchronous readers-writer lock (`RwLock`) built on an atomic
//! reader/writer counter and two wait lists.
  1. mod guard;
  2. use crate::WaitList;
  3. use core::{
  4. cell::UnsafeCell,
  5. pin::pin,
  6. sync::atomic::{AtomicIsize, Ordering},
  7. };
  8. pub use guard::{RwLockReadGuard, RwLockWriteGuard};
/// An asynchronous readers-writer lock protecting a value of type `T`.
///
/// All state lives in `counter` (see its comment); tasks that cannot take
/// the lock immediately park themselves on `read_wait` / `write_wait`.
#[derive(Debug, Default)]
pub struct RwLock<T>
where
    T: ?Sized,
{
    // Lock state: 0 = unlocked, n > 0 = n active readers, -1 = one writer.
    counter: AtomicIsize,
    // Tasks waiting to acquire a read lock.
    read_wait: WaitList,
    // Tasks waiting to acquire the write lock.
    write_wait: WaitList,
    // The protected value; accessed only via the guards or `get_mut`.
    value: UnsafeCell<T>,
}
  19. impl<T> RwLock<T> {
  20. pub const fn new(value: T) -> Self {
  21. Self {
  22. counter: AtomicIsize::new(0),
  23. read_wait: WaitList::new(),
  24. write_wait: WaitList::new(),
  25. value: UnsafeCell::new(value),
  26. }
  27. }
  28. }
impl<T> RwLock<T>
where
    T: ?Sized,
{
    /// Builds a write guard for a lock that is already held exclusively.
    ///
    /// # Safety
    /// This function is unsafe because the caller MUST ensure that we've got the
    /// write access before calling this function.
    unsafe fn write_lock(&self) -> RwLockWriteGuard<'_, T> {
        RwLockWriteGuard {
            lock: self,
            // SAFETY: We are holding the write lock, so we can safely access the value.
            value: unsafe { &mut *self.value.get() },
        }
    }

    /// Builds a read guard for a lock on which read access is already held.
    ///
    /// # Safety
    /// This function is unsafe because the caller MUST ensure that we've got the
    /// read access before calling this function.
    unsafe fn read_lock(&self) -> RwLockReadGuard<'_, T> {
        RwLockReadGuard {
            lock: self,
            // SAFETY: We are holding the read lock, so we can safely access the value.
            value: unsafe { &*self.value.get() },
        }
    }

    /// Releases the write lock (counter -1 -> 0) and wakes waiters,
    /// preferring a queued writer over the queued readers.
    ///
    /// # Safety
    /// This function is unsafe because the caller MUST ensure that we won't hold any
    /// references to the value after calling this function.
    pub(self) unsafe fn write_unlock(&self) {
        // Release publishes the writes made under the lock to the next
        // acquirer (whose CAS uses Acquire).
        let old = self.counter.swap(0, Ordering::Release);
        debug_assert_eq!(
            old, -1,
            "RwLock::write_unlock(): erroneous counter value: {}",
            old
        );
        // Hand off to one waiting writer if there is one; otherwise release
        // the whole crowd of waiting readers at once.
        if !self.write_wait.notify_one() {
            self.read_wait.notify_all();
        }
    }

    /// Releases one read lock; the last reader out wakes waiters.
    ///
    /// # Safety
    /// This function is unsafe because the caller MUST ensure that we won't hold any
    /// references to the value after calling this function.
    pub(self) unsafe fn read_unlock(&self) {
        // `fetch_sub` returns the *previous* counter value; Release pairs
        // with the Acquire in the lock paths.
        match self.counter.fetch_sub(1, Ordering::Release) {
            // Other readers are still active; nothing to wake.
            2.. => {}
            // We were the last reader: prefer a queued writer, else wake readers.
            1 => {
                if !self.write_wait.notify_one() {
                    self.read_wait.notify_all();
                }
            }
            // A previous value of 0 or less means unlock without a read lock held.
            val => unreachable!("RwLock::read_unlock(): erroneous counter value: {}", val),
        }
    }

    /// Attempts to take the write lock without waiting.
    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
        // 0 -> -1 claims exclusive ownership; Acquire synchronizes with the
        // Release in the matching unlock.
        self.counter
            .compare_exchange(0, -1, Ordering::Acquire, Ordering::Relaxed)
            .ok()
            // SAFETY: the successful CAS above granted us write access.
            .map(|_| unsafe { self.write_lock() })
    }

    /// Like `try_write`, but `compare_exchange_weak` may fail spuriously;
    /// only used inside the retry loop of `write_slow_path`.
    fn try_write_weak(&self) -> Option<RwLockWriteGuard<'_, T>> {
        self.counter
            .compare_exchange_weak(0, -1, Ordering::Acquire, Ordering::Relaxed)
            .ok()
            // SAFETY: the successful CAS above granted us write access.
            .map(|_| unsafe { self.write_lock() })
    }

    /// Attempts to take a read lock without waiting.
    ///
    /// Refuses while writers are queued (writer preference), and fails when a
    /// writer holds the lock or the increment CAS races with another task.
    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
        // We'll spin if we fail here anyway.
        if self.write_wait.has_waiters() {
            return None;
        }
        let counter = self.counter.load(Ordering::Relaxed);
        if counter >= 0 {
            // counter -> counter + 1 registers us as one more reader.
            self.counter
                .compare_exchange(counter, counter + 1, Ordering::Acquire, Ordering::Relaxed)
                .ok()
                // SAFETY: the successful CAS above granted us read access.
                .map(|_| unsafe { self.read_lock() })
        } else {
            // counter == -1: a writer currently holds the lock.
            None
        }
    }

    /// Like `try_read`, but weak and *without* the writer-preference check;
    /// only used inside the retry loop of `read_slow_path`.
    fn try_read_weak(&self) -> Option<RwLockReadGuard<'_, T>> {
        // TODO: If we check write waiters here, we would lose wakeups.
        // Try locking the wait lists to prevent this.
        let counter = self.counter.load(Ordering::Relaxed);
        if counter >= 0 {
            self.counter
                .compare_exchange_weak(counter, counter + 1, Ordering::Acquire, Ordering::Relaxed)
                .ok()
                // SAFETY: the successful CAS above granted us read access.
                .map(|_| unsafe { self.read_lock() })
        } else {
            None
        }
    }

    /// Waits until the write lock can be acquired.
    ///
    /// Lost-wakeup protocol: enqueue on the wait list *first*, then retry the
    /// lock, and only then await — an unlock that happens between the retry
    /// and the await still finds us in the list and notifies us.
    #[cold]
    async fn write_slow_path(&self) -> RwLockWriteGuard<'_, T> {
        loop {
            let mut wait = pin!(self.write_wait.prepare_to_wait());
            wait.as_mut().add_to_wait_list();
            if let Some(guard) = self.try_write_weak() {
                return guard;
            }
            wait.await;
        }
    }

    /// Waits until a read lock can be acquired; same register-then-retry
    /// protocol as `write_slow_path`.
    #[cold]
    async fn read_slow_path(&self) -> RwLockReadGuard<'_, T> {
        loop {
            let mut wait = pin!(self.read_wait.prepare_to_wait());
            wait.as_mut().add_to_wait_list();
            if let Some(guard) = self.try_read_weak() {
                return guard;
            }
            wait.await;
        }
    }

    /// Locks this `RwLock` for writing, waiting if it is currently held.
    pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
        if let Some(guard) = self.try_write() {
            // Quick path
            guard
        } else {
            self.write_slow_path().await
        }
    }

    /// Locks this `RwLock` for reading, waiting if a writer holds or wants it.
    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
        if let Some(guard) = self.try_read() {
            // Quick path
            guard
        } else {
            self.read_slow_path().await
        }
    }

    /// Returns a mutable reference to the protected value without locking.
    pub fn get_mut(&mut self) -> &mut T {
        // SAFETY: The exclusive access to the lock is guaranteed by the borrow checker.
        unsafe { &mut *self.value.get() }
    }
}
// SAFETY: As long as the value protected by the lock is able to be shared between threads,
// we can send the lock between threads.
unsafe impl<T> Send for RwLock<T> where T: ?Sized + Send {}
// SAFETY: `RwLock` can provide shared access to the value it protects, so it is safe to
// implement `Sync` for it. However, this is only true if the value itself is `Sync`.
// `T: Send` is also required because a write guard obtained through a shared
// reference hands out `&mut T` on whichever thread took the lock.
unsafe impl<T> Sync for RwLock<T> where T: ?Sized + Send + Sync {}