Review notes (this text precedes the first "diff --git" line and belongs to no hunk):
- Adds TimeToFirstItemStream, a Stream wrapper that records the time elapsed between the
  `list` call and the wrapped stream's first `Poll::Ready` (first item, error, or end of
  stream) into the RequestDetails entry pushed for that call.
- `InstrumentedObjectStore::requests` becomes Arc<Mutex<Vec<RequestDetails>>> so the
  returned stream can hold a handle to the same Vec.
- The duration is written only when the returned stream is polled to readiness; a stream
  that is never polled leaves `duration: None`.
- NOTE(review): the entry is located by the Vec index captured at push time. If the Vec is
  emptied before the first poll (presumably what `take_requests()` does -- confirm),
  `get_mut` finds nothing or updates an unrelated, later entry.

diff --git a/datafusion-cli/src/object_storage/instrumented.rs b/datafusion-cli/src/object_storage/instrumented.rs
index c4b63b417fe4..a3352f9a2121 100644
--- a/datafusion-cli/src/object_storage/instrumented.rs
+++ b/datafusion-cli/src/object_storage/instrumented.rs
@@ -35,7 +35,7 @@ use datafusion::{
     error::DataFusionError,
     execution::object_store::{DefaultObjectStoreRegistry, ObjectStoreRegistry},
 };
-use futures::stream::BoxStream;
+use futures::stream::{BoxStream, Stream};
 use object_store::{
     path::Path, GetOptions, GetRange, GetResult, ListResult, MultipartUpload, ObjectMeta,
     ObjectStore, PutMultipartOptions, PutOptions, PutPayload, PutResult, Result,
@@ -43,6 +43,58 @@ use object_store::{
 use parking_lot::{Mutex, RwLock};
 use url::Url;
 
+/// A stream wrapper that measures the time until the first response(item or end of stream) is yielded
+struct TimeToFirstItemStream<S> {
+    inner: S,
+    start: Instant,
+    request_index: usize,
+    requests: Arc<Mutex<Vec<RequestDetails>>>,
+    first_item_yielded: bool,
+}
+
+impl<S> TimeToFirstItemStream<S> {
+    fn new(
+        inner: S,
+        start: Instant,
+        request_index: usize,
+        requests: Arc<Mutex<Vec<RequestDetails>>>,
+    ) -> Self {
+        Self {
+            inner,
+            start,
+            request_index,
+            requests,
+            first_item_yielded: false,
+        }
+    }
+}
+
+impl<S> Stream for TimeToFirstItemStream<S>
+where
+    S: Stream<Item = Result<ObjectMeta>> + Unpin,
+{
+    type Item = Result<ObjectMeta>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        let poll_result = std::pin::Pin::new(&mut self.inner).poll_next(cx);
+
+        if !self.first_item_yielded && poll_result.is_ready() {
+            self.first_item_yielded = true;
+            let elapsed = self.start.elapsed();
+
+            let mut requests = self.requests.lock();
+            if let Some(request) = requests.get_mut(self.request_index) {
+                request.duration = Some(elapsed);
+            }
+        }
+
+        poll_result
+    }
+}
+
 /// The profiling mode to use for an [`InstrumentedObjectStore`] instance. Collecting profiling
 /// data will have a small negative impact on both CPU and memory usage. Default is `Disabled`
 #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
@@ -91,7 +143,7 @@ impl From<u8> for InstrumentedObjectStoreMode {
 pub struct InstrumentedObjectStore {
     inner: Arc<dyn ObjectStore>,
     instrument_mode: AtomicU8,
-    requests: Mutex<Vec<RequestDetails>>,
+    requests: Arc<Mutex<Vec<RequestDetails>>>,
 }
 
 impl InstrumentedObjectStore {
@@ -100,7 +152,7 @@ impl InstrumentedObjectStore {
         Self {
             inner: object_store,
             instrument_mode,
-            requests: Mutex::new(Vec::new()),
+            requests: Arc::new(Mutex::new(Vec::new())),
         }
     }
 
@@ -218,19 +270,31 @@ impl InstrumentedObjectStore {
         prefix: Option<&Path>,
     ) -> BoxStream<'static, Result<ObjectMeta>> {
         let timestamp = Utc::now();
-        let ret = self.inner.list(prefix);
+        let start = Instant::now();
+        let inner_stream = self.inner.list(prefix);
+
+        let request_index = {
+            let mut requests = self.requests.lock();
+            requests.push(RequestDetails {
+                op: Operation::List,
+                path: prefix.cloned().unwrap_or_else(|| Path::from("")),
+                timestamp,
+                duration: None,
+                size: None,
+                range: None,
+                extra_display: None,
+            });
+            requests.len() - 1
+        };
 
-        self.requests.lock().push(RequestDetails {
-            op: Operation::List,
-            path: prefix.cloned().unwrap_or_else(|| Path::from("")),
-            timestamp,
-            duration: None, // list returns a stream, so the duration isn't meaningful
-            size: None,
-            range: None,
-            extra_display: None,
-        });
+        let wrapped_stream = TimeToFirstItemStream::new(
+            inner_stream,
+            start,
+            request_index,
+            Arc::clone(&self.requests),
+        );
 
-        ret
+        Box::pin(wrapped_stream)
     }
 
     async fn instrumented_list_with_delimiter(
@@ -758,6 +822,7 @@ impl ObjectStoreRegistry for InstrumentedObjectStoreRegistry {
 
 #[cfg(test)]
 mod tests {
+    use futures::StreamExt;
     use object_store::WriteMultipart;
 
     use super::*;
@@ -896,13 +961,15 @@ mod tests {
         instrumented.set_instrument_mode(InstrumentedObjectStoreMode::Trace);
         assert!(instrumented.requests.lock().is_empty());
 
-        let _ = instrumented.list(Some(&path));
+        let mut stream = instrumented.list(Some(&path));
+        // Consume at least one item from the stream to trigger duration measurement
+        let _ = stream.next().await;
         assert_eq!(instrumented.requests.lock().len(), 1);
 
         let request = instrumented.take_requests().pop().unwrap();
         assert_eq!(request.op, Operation::List);
         assert_eq!(request.path, path);
-        assert!(request.duration.is_none());
+        assert!(request.duration.is_some());
         assert!(request.size.is_none());
         assert!(request.range.is_none());
         assert!(request.extra_display.is_none());