Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 185 additions & 2 deletions utils/src/rw_task_lock.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
use std::future::Future;
use std::mem::replace;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::Mutex;

use thiserror::Error;
use tokio::sync::{RwLock, RwLockReadGuard};
use tokio::task::{JoinError, JoinHandle};

type BoxedFuture<T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;

#[derive(Debug, Error)]
#[non_exhaustive]
pub enum RwTaskLockError {
Expand All @@ -16,7 +20,21 @@ pub enum RwTaskLockError {
CalledAfterError,
}

/// Internal state for `RwTaskLock`.
///
/// The `Delayed` variant wraps the future in `std::sync::Mutex` to satisfy `Sync` bounds.
/// This is necessary because:
/// - `tokio::sync::RwLock<T>` requires `T: Send + Sync` to be `Sync` (since multiple readers can hold `&T`
/// simultaneously).
/// - `BoxedFuture` (`Pin<Box<dyn Future<...> + Send>>`) is `Send` but NOT `Sync`.
/// - `std::sync::Mutex<T>` is `Sync` when `T: Send` (exclusive access only).
/// - We're gauranteed that we won't have contention, and move the value out of the Mutex immediately to use, so it's
/// just a quick trick to satisfy the `Sync` bounds.
///
/// We use `RwLock` (not `Mutex`) for the outer lock because once the value is `Ready`,
/// multiple callers can hold read guards simultaneously to access the cached result.
enum RwTaskLockState<T, E> {
Delayed(Mutex<BoxedFuture<T, E>>),
Pending(JoinHandle<Result<T, E>>),
Ready(T),
Error,
Expand Down Expand Up @@ -93,6 +111,34 @@ where
}
}

/// From a future yielding Result<T, E> that will not start until `run_delayed()` is called
/// or `read()` is called.
pub fn from_task_delayed<Fut>(fut: Fut) -> Self
where
Fut: Future<Output = Result<T, E>> + Send + 'static,
{
Self {
state: RwLock::new(RwTaskLockState::Delayed(Mutex::new(Box::pin(fut)))),
}
}

/// Start the delayed future as a background task, moving it to the Pending state.
pub fn run_delayed(&self) {
let state = self.state.try_write();
if let Ok(mut state_guard) = state {
match replace(&mut *state_guard, RwTaskLockState::Error) {
RwTaskLockState::Delayed(fut_mutex) => {
let fut = fut_mutex.into_inner().unwrap();
let task = tokio::spawn(fut);
*state_guard = RwTaskLockState::Pending(task);
},
other => {
*state_guard = other;
},
}
}
}

/// Awaitable read: yields a custom read guard or error.
pub async fn read(&self) -> Result<RwTaskLockReadGuard<'_, T, E>, E> {
// Fast path
Expand All @@ -104,6 +150,7 @@ where
},
RwTaskLockState::Error => return Err(E::from(RwTaskLockError::CalledAfterError)),
RwTaskLockState::Pending(_) => {},
RwTaskLockState::Delayed(_) => {},
}
}
// Acquire write lock to initialize if necessary
Expand All @@ -127,6 +174,18 @@ where
},
};
},
RwTaskLockState::Delayed(fut_mutex) => {
let mut fut = fut_mutex.into_inner().unwrap();
match fut.as_mut().await {
Ok(v) => {
*state = RwTaskLockState::Ready(v);
},
Err(e) => {
*state = RwTaskLockState::Error;
return Err(e);
},
};
},
};

Ok(RwTaskLockReadGuard {
Expand Down Expand Up @@ -206,6 +265,16 @@ where
*state_lg = Pending(tokio::spawn(updater(v)));
Ok(())
},
Delayed(fut_mutex) => {
// Execute the delayed future, then chain the updater.
let fut = fut_mutex.into_inner().unwrap();
let new_task = tokio::spawn(async move {
let current = fut.await?;
updater(current).await
});
*state_lg = Pending(new_task);
Ok(())
},
Error => {
// Can't update if in error.
*state_lg = Error;
Expand All @@ -218,6 +287,9 @@ where
#[cfg(test)]
mod tests {

use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

use super::*;

#[tokio::test]
Expand Down Expand Up @@ -252,7 +324,6 @@ mod tests {

#[tokio::test]
async fn test_concurrent_read() {
use std::sync::Arc;
let lock = Arc::new(RwTaskLock::from_task(async {
tokio::time::sleep(std::time::Duration::from_millis(30)).await;
Ok::<_, RwTaskLockError>("concurrent".to_string())
Expand Down Expand Up @@ -282,7 +353,6 @@ mod tests {

#[tokio::test]
async fn test_update_chained_pending() {
use std::sync::Arc;
let lock = Arc::new(RwTaskLock::from_task(async {
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
Ok::<_, RwTaskLockError>(5)
Expand Down Expand Up @@ -321,4 +391,117 @@ mod tests {
let guard = lock.read().await.unwrap();
assert_eq!(*guard, 22);
}

#[tokio::test]
async fn test_delayed_read_executes_future() {
let flag = Arc::new(AtomicBool::new(false));
let flag_ = flag.clone();
let lock = RwTaskLock::from_task_delayed(async move {
flag_.store(true, Ordering::Relaxed);
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
Ok::<_, RwTaskLockError>(42)
});
assert!(!flag.load(Ordering::Relaxed));
let guard = lock.read().await.unwrap();
assert!(flag.load(Ordering::Relaxed));
assert_eq!(*guard, 42);
let guard2 = lock.read().await.unwrap();
assert_eq!(*guard2, 42);
}

#[tokio::test]
async fn test_delayed_run_delayed_spawns_task() {
let flag = Arc::new(AtomicBool::new(false));
let flag_ = flag.clone();
let lock = RwTaskLock::from_task_delayed(async move {
flag_.store(true, Ordering::Relaxed);
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
Ok::<_, RwTaskLockError>(100)
});
assert!(!flag.load(Ordering::Relaxed));
lock.run_delayed();
tokio::time::sleep(std::time::Duration::from_millis(5)).await;
assert!(flag.load(Ordering::Relaxed));
let guard = lock.read().await.unwrap();
assert_eq!(*guard, 100);
}

#[tokio::test]
async fn test_delayed_error() {
let flag = Arc::new(AtomicBool::new(false));
let flag_ = flag.clone();
let lock = RwTaskLock::<u8, RwTaskLockError>::from_task_delayed(async move {
flag_.store(true, Ordering::Relaxed);
Err(RwTaskLockError::CalledAfterError)
});
assert!(!flag.load(Ordering::Relaxed));
let result = lock.read().await;
assert!(flag.load(Ordering::Relaxed));
assert!(matches!(result, Err(RwTaskLockError::CalledAfterError)));
let result2 = lock.read().await;
assert!(matches!(result2, Err(RwTaskLockError::CalledAfterError)));
}

#[tokio::test]
async fn test_delayed_concurrent_read() {
let flag = Arc::new(AtomicBool::new(false));
let flag_ = flag.clone();
let lock = Arc::new(RwTaskLock::from_task_delayed(async move {
flag_.store(true, Ordering::Relaxed);
tokio::time::sleep(std::time::Duration::from_millis(30)).await;
Ok::<_, RwTaskLockError>("delayed_concurrent".to_string())
}));
assert!(!flag.load(Ordering::Relaxed));
let lock1 = lock.clone();
let lock2 = lock.clone();
let (a, b) = tokio::join!(lock1.read(), lock2.read());
assert!(flag.load(Ordering::Relaxed));
assert_eq!(*a.unwrap(), "delayed_concurrent");
assert_eq!(*b.unwrap(), "delayed_concurrent");
}

#[tokio::test]
async fn test_delayed_update() {
let flag = Arc::new(AtomicBool::new(false));
let flag_ = flag.clone();
let lock = RwTaskLock::from_task_delayed(async move {
flag_.store(true, Ordering::Relaxed);
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
Ok::<_, RwTaskLockError>(5)
});
assert!(!flag.load(Ordering::Relaxed));
lock.update(|v| async move { Ok::<_, RwTaskLockError>(v * 2) }).await.unwrap();
let guard = lock.read().await.unwrap();
assert!(flag.load(Ordering::Relaxed));
assert_eq!(*guard, 10);
}

#[tokio::test]
async fn test_delayed_run_delayed_then_read() {
let flag = Arc::new(AtomicBool::new(false));
let flag_ = flag.clone();
let lock = RwTaskLock::from_task_delayed(async move {
flag_.store(true, Ordering::Relaxed);
tokio::time::sleep(std::time::Duration::from_millis(20)).await;
Ok::<_, RwTaskLockError>(200)
});
assert!(!flag.load(Ordering::Relaxed));
lock.run_delayed();
tokio::time::sleep(std::time::Duration::from_millis(25)).await;
assert!(flag.load(Ordering::Relaxed));
let guard = lock.read().await.unwrap();
assert_eq!(*guard, 200);
}

#[tokio::test]
async fn test_delayed_does_not_run_without_trigger() {
let flag = Arc::new(AtomicBool::new(false));
let flag_ = flag.clone();
let _lock = RwTaskLock::from_task_delayed(async move {
flag_.store(true, Ordering::Relaxed);
Ok::<_, RwTaskLockError>(999)
});
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
assert!(!flag.load(Ordering::Relaxed));
}
}
Loading