mutex.rs 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. mod guard;
  2. use crate::WaitList;
  3. use core::{
  4. cell::UnsafeCell,
  5. pin::pin,
  6. sync::atomic::{AtomicBool, Ordering},
  7. };
  8. pub use guard::MutexGuard;
  9. #[derive(Debug, Default)]
  10. pub struct Mutex<T>
  11. where
  12. T: ?Sized,
  13. {
  14. locked: AtomicBool,
  15. wait_list: WaitList,
  16. value: UnsafeCell<T>,
  17. }
  18. impl<T> Mutex<T> {
  19. pub const fn new(value: T) -> Self {
  20. Self {
  21. locked: AtomicBool::new(false),
  22. wait_list: WaitList::new(),
  23. value: UnsafeCell::new(value),
  24. }
  25. }
  26. }
  27. impl<T> Mutex<T>
  28. where
  29. T: ?Sized,
  30. {
  31. /// # Safety
  32. /// This function is unsafe because the caller MUST ensure that we've got the
  33. /// exclusive access before calling this function.
  34. unsafe fn get_lock(&self) -> MutexGuard<'_, T> {
  35. MutexGuard {
  36. lock: self,
  37. // SAFETY: We are holding the lock, so we can safely access the value.
  38. value: unsafe { &mut *self.value.get() },
  39. }
  40. }
  41. pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
  42. self.locked
  43. .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
  44. .ok()
  45. .map(|_| unsafe { self.get_lock() })
  46. }
  47. fn try_lock_weak(&self) -> Option<MutexGuard<'_, T>> {
  48. self.locked
  49. .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
  50. .ok()
  51. .map(|_| unsafe { self.get_lock() })
  52. }
  53. #[cold]
  54. async fn lock_slow_path(&self) -> MutexGuard<'_, T> {
  55. loop {
  56. let mut wait = pin!(self.wait_list.prepare_to_wait());
  57. wait.as_mut().add_to_wait_list();
  58. if let Some(guard) = self.try_lock_weak() {
  59. return guard;
  60. }
  61. wait.await;
  62. }
  63. }
  64. pub async fn lock(&self) -> MutexGuard<'_, T> {
  65. if let Some(guard) = self.try_lock() {
  66. // Quick path
  67. guard
  68. } else {
  69. self.lock_slow_path().await
  70. }
  71. }
  72. pub fn get_mut(&mut self) -> &mut T {
  73. // SAFETY: The exclusive access to the lock is guaranteed by the borrow checker.
  74. unsafe { &mut *self.value.get() }
  75. }
  76. }
  77. // SAFETY: As long as the value protected by the lock is able to be shared between threads,
  78. // we can send the lock between threads.
  79. unsafe impl<T> Send for Mutex<T> where T: ?Sized + Send {}
  80. // SAFETY: `RwLock` can provide exclusive access to the value it protects, so it is safe to
  81. // implement `Sync` for it as long as the protected value is `Send`.
  82. unsafe impl<T> Sync for Mutex<T> where T: ?Sized + Send {}