diff --git a/utils/src/rw_task_lock.rs b/utils/src/rw_task_lock.rs index aa522f23..139248f8 100644 --- a/utils/src/rw_task_lock.rs +++ b/utils/src/rw_task_lock.rs @@ -1,11 +1,15 @@ use std::future::Future; use std::mem::replace; use std::ops::Deref; +use std::pin::Pin; +use std::sync::Mutex; use thiserror::Error; use tokio::sync::{RwLock, RwLockReadGuard}; use tokio::task::{JoinError, JoinHandle}; +type BoxedFuture = Pin> + Send>>; + #[derive(Debug, Error)] #[non_exhaustive] pub enum RwTaskLockError { @@ -16,7 +20,21 @@ pub enum RwTaskLockError { CalledAfterError, } +/// Internal state for `RwTaskLock`. +/// +/// The `Delayed` variant wraps the future in `std::sync::Mutex` to satisfy `Sync` bounds. +/// This is necessary because: +/// - `tokio::sync::RwLock` requires `T: Send + Sync` to be `Sync` (since multiple readers can hold `&T` +/// simultaneously). +/// - `BoxedFuture` (`Pin + Send>>`) is `Send` but NOT `Sync`. +/// - `std::sync::Mutex` is `Sync` when `T: Send` (exclusive access only). +/// - We're gauranteed that we won't have contention, and move the value out of the Mutex immediately to use, so it's +/// just a quick trick to satisfy the `Sync` bounds. +/// +/// We use `RwLock` (not `Mutex`) for the outer lock because once the value is `Ready`, +/// multiple callers can hold read guards simultaneously to access the cached result. enum RwTaskLockState { + Delayed(Mutex>), Pending(JoinHandle>), Ready(T), Error, @@ -93,6 +111,34 @@ where } } + /// From a future yielding Result that will not start until `run_delayed()` is called + /// or `read()` is called. + pub fn from_task_delayed(fut: Fut) -> Self + where + Fut: Future> + Send + 'static, + { + Self { + state: RwLock::new(RwTaskLockState::Delayed(Mutex::new(Box::pin(fut)))), + } + } + + /// Start the delayed future as a background task, moving it to the Pending state. + pub fn run_delayed(&self) { + let state = self.state.try_write(); + if let Ok(mut state_guard) = state { + match replace(&mut *state_guard, RwTaskLockState::Error) { + RwTaskLockState::Delayed(fut_mutex) => { + let fut = fut_mutex.into_inner().unwrap(); + let task = tokio::spawn(fut); + *state_guard = RwTaskLockState::Pending(task); + }, + other => { + *state_guard = other; + }, + } + } + } + /// Awaitable read: yields a custom read guard or error. pub async fn read(&self) -> Result, E> { // Fast path @@ -104,6 +150,7 @@ where }, RwTaskLockState::Error => return Err(E::from(RwTaskLockError::CalledAfterError)), RwTaskLockState::Pending(_) => {}, + RwTaskLockState::Delayed(_) => {}, } } // Acquire write lock to initialize if necessary @@ -127,6 +174,18 @@ where }, }; }, + RwTaskLockState::Delayed(fut_mutex) => { + let mut fut = fut_mutex.into_inner().unwrap(); + match fut.as_mut().await { + Ok(v) => { + *state = RwTaskLockState::Ready(v); + }, + Err(e) => { + *state = RwTaskLockState::Error; + return Err(e); + }, + }; + }, }; Ok(RwTaskLockReadGuard { @@ -206,6 +265,16 @@ where *state_lg = Pending(tokio::spawn(updater(v))); Ok(()) }, + Delayed(fut_mutex) => { + // Execute the delayed future, then chain the updater. + let fut = fut_mutex.into_inner().unwrap(); + let new_task = tokio::spawn(async move { + let current = fut.await?; + updater(current).await + }); + *state_lg = Pending(new_task); + Ok(()) + }, Error => { // Can't update if in error. *state_lg = Error; @@ -218,6 +287,9 @@ where #[cfg(test)] mod tests { + use std::sync::Arc; + use std::sync::atomic::{AtomicBool, Ordering}; + use super::*; #[tokio::test] @@ -252,7 +324,6 @@ mod tests { #[tokio::test] async fn test_concurrent_read() { - use std::sync::Arc; let lock = Arc::new(RwTaskLock::from_task(async { tokio::time::sleep(std::time::Duration::from_millis(30)).await; Ok::<_, RwTaskLockError>("concurrent".to_string()) @@ -282,7 +353,6 @@ mod tests { #[tokio::test] async fn test_update_chained_pending() { - use std::sync::Arc; let lock = Arc::new(RwTaskLock::from_task(async { tokio::time::sleep(std::time::Duration::from_millis(20)).await; Ok::<_, RwTaskLockError>(5) @@ -321,4 +391,117 @@ mod tests { let guard = lock.read().await.unwrap(); assert_eq!(*guard, 22); } + + #[tokio::test] + async fn test_delayed_read_executes_future() { + let flag = Arc::new(AtomicBool::new(false)); + let flag_ = flag.clone(); + let lock = RwTaskLock::from_task_delayed(async move { + flag_.store(true, Ordering::Relaxed); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + Ok::<_, RwTaskLockError>(42) + }); + assert!(!flag.load(Ordering::Relaxed)); + let guard = lock.read().await.unwrap(); + assert!(flag.load(Ordering::Relaxed)); + assert_eq!(*guard, 42); + let guard2 = lock.read().await.unwrap(); + assert_eq!(*guard2, 42); + } + + #[tokio::test] + async fn test_delayed_run_delayed_spawns_task() { + let flag = Arc::new(AtomicBool::new(false)); + let flag_ = flag.clone(); + let lock = RwTaskLock::from_task_delayed(async move { + flag_.store(true, Ordering::Relaxed); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + Ok::<_, RwTaskLockError>(100) + }); + assert!(!flag.load(Ordering::Relaxed)); + lock.run_delayed(); + tokio::time::sleep(std::time::Duration::from_millis(5)).await; + assert!(flag.load(Ordering::Relaxed)); + let guard = lock.read().await.unwrap(); + assert_eq!(*guard, 100); + } + + #[tokio::test] + async fn test_delayed_error() { + let flag = Arc::new(AtomicBool::new(false)); + let flag_ = flag.clone(); + let lock = RwTaskLock::::from_task_delayed(async move { + flag_.store(true, Ordering::Relaxed); + Err(RwTaskLockError::CalledAfterError) + }); + assert!(!flag.load(Ordering::Relaxed)); + let result = lock.read().await; + assert!(flag.load(Ordering::Relaxed)); + assert!(matches!(result, Err(RwTaskLockError::CalledAfterError))); + let result2 = lock.read().await; + assert!(matches!(result2, Err(RwTaskLockError::CalledAfterError))); + } + + #[tokio::test] + async fn test_delayed_concurrent_read() { + let flag = Arc::new(AtomicBool::new(false)); + let flag_ = flag.clone(); + let lock = Arc::new(RwTaskLock::from_task_delayed(async move { + flag_.store(true, Ordering::Relaxed); + tokio::time::sleep(std::time::Duration::from_millis(30)).await; + Ok::<_, RwTaskLockError>("delayed_concurrent".to_string()) + })); + assert!(!flag.load(Ordering::Relaxed)); + let lock1 = lock.clone(); + let lock2 = lock.clone(); + let (a, b) = tokio::join!(lock1.read(), lock2.read()); + assert!(flag.load(Ordering::Relaxed)); + assert_eq!(*a.unwrap(), "delayed_concurrent"); + assert_eq!(*b.unwrap(), "delayed_concurrent"); + } + + #[tokio::test] + async fn test_delayed_update() { + let flag = Arc::new(AtomicBool::new(false)); + let flag_ = flag.clone(); + let lock = RwTaskLock::from_task_delayed(async move { + flag_.store(true, Ordering::Relaxed); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + Ok::<_, RwTaskLockError>(5) + }); + assert!(!flag.load(Ordering::Relaxed)); + lock.update(|v| async move { Ok::<_, RwTaskLockError>(v * 2) }).await.unwrap(); + let guard = lock.read().await.unwrap(); + assert!(flag.load(Ordering::Relaxed)); + assert_eq!(*guard, 10); + } + + #[tokio::test] + async fn test_delayed_run_delayed_then_read() { + let flag = Arc::new(AtomicBool::new(false)); + let flag_ = flag.clone(); + let lock = RwTaskLock::from_task_delayed(async move { + flag_.store(true, Ordering::Relaxed); + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + Ok::<_, RwTaskLockError>(200) + }); + assert!(!flag.load(Ordering::Relaxed)); + lock.run_delayed(); + tokio::time::sleep(std::time::Duration::from_millis(25)).await; + assert!(flag.load(Ordering::Relaxed)); + let guard = lock.read().await.unwrap(); + assert_eq!(*guard, 200); + } + + #[tokio::test] + async fn test_delayed_does_not_run_without_trigger() { + let flag = Arc::new(AtomicBool::new(false)); + let flag_ = flag.clone(); + let _lock = RwTaskLock::from_task_delayed(async move { + flag_.store(true, Ordering::Relaxed); + Ok::<_, RwTaskLockError>(999) + }); + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + assert!(!flag.load(Ordering::Relaxed)); + } }