From e7b920ab19a4193c5ab39efac2dfe8b16c94ba7f Mon Sep 17 00:00:00 2001 From: di Date: Fri, 14 Nov 2025 01:09:21 +0800 Subject: [PATCH] basic structure; rewrites the driver loop --- cas_client/src/download_utils.rs | 116 +++++++---- cas_client/src/lib.rs | 1 + cas_client/src/memory_cache.rs | 42 ++++ cas_client/src/remote_client.rs | 334 +++++++++++++++++++++++++++++-- chunk_cache/src/memory.rs | 147 ++++++++++++++ 5 files changed, 586 insertions(+), 54 deletions(-) create mode 100644 cas_client/src/memory_cache.rs create mode 100644 chunk_cache/src/memory.rs diff --git a/cas_client/src/download_utils.rs b/cas_client/src/download_utils.rs index 9fddfc826..e24abdb0f 100644 --- a/cas_client/src/download_utils.rs +++ b/cas_client/src/download_utils.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::collections::HashMap; use std::io::Write; use std::sync::{Arc, Mutex, RwLock}; @@ -15,10 +16,12 @@ use http::header::RANGE; use merklehash::MerkleHash; use reqwest::Response; use reqwest_middleware::ClientWithMiddleware; +use tokio::io::AsyncWriteExt; use tracing::{debug, error, info, trace, warn}; use url::Url; use utils::singleflight::Group; +use crate::SequentialOutput; use crate::error::{CasClientError, Result}; use crate::http_client::Api; use crate::output_provider::SeekingOutputProvider; @@ -281,7 +284,7 @@ impl FetchTermDownload { } } -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub(crate) struct ChunkRangeWrite { pub chunk_range: ChunkRange, pub unpacked_length: u32, @@ -290,58 +293,95 @@ pub(crate) struct ChunkRangeWrite { pub writer_offset: u64, } +impl ChunkRangeWrite { + pub fn write_to_seek_writer(&self, download: &TermDownloadOutput, writer: &SeekingOutputProvider) -> Result { + let mut writer = writer.get_writer_at(self.writer_offset)?; + let data_sub_range_sliced = self.derive_write_bytes(&download); + writer.write_all(data_sub_range_sliced)?; + writer.flush()?; + Ok(self.take) + } + + pub async fn write_to_sequential_writer( + &self, + download: &TermDownloadOutput, + writer: &mut SequentialOutput, + ) -> Result { + let data_sub_range_sliced = self.derive_write_bytes(&download); + writer.write_all(data_sub_range_sliced).await?; + writer.flush().await?; + Ok(self.take) + } + + fn derive_write_bytes<'a>(&self, download: &'a TermDownloadOutput) -> &'a [u8] { + let TermDownloadOutput { + data, + chunk_byte_indices, + chunk_range, + } = download; + debug_assert_eq!(chunk_byte_indices.len(), (chunk_range.end - chunk_range.start + 1) as usize); + debug_assert_eq!(*chunk_byte_indices.last().expect("checked len is something") as usize, data.len()); + + debug_assert!(self.chunk_range.start >= chunk_range.start); + debug_assert!(self.chunk_range.end > chunk_range.start); + debug_assert!( + self.chunk_range.start < chunk_range.end, + "{} < {} ;;; write {:?} term {:?}", + self.chunk_range.start, + chunk_range.end, + self.chunk_range, + chunk_range + ); + debug_assert!(self.chunk_range.end <= chunk_range.end); + + let start_chunk_offset_index = self.chunk_range.start - chunk_range.start; + let end_chunk_offset_index = self.chunk_range.end - chunk_range.start; + let start_chunk_offset = chunk_byte_indices[start_chunk_offset_index as usize] as usize; + let end_chunk_offset = chunk_byte_indices[end_chunk_offset_index as usize] as usize; + let data_sub_range = &data[start_chunk_offset..end_chunk_offset]; + debug_assert_eq!(data_sub_range.len(), self.unpacked_length as usize); + + debug_assert!(data_sub_range.len() as u64 >= self.skip_bytes + self.take); + &data_sub_range[(self.skip_bytes as 
usize)..((self.skip_bytes + self.take) as usize)] + } +} + +#[derive(Debug)] +pub(crate) struct DownloadAndWriteAllSequential { + pub download: FetchTermDownload, + pub xorb_hash: MerkleHash, + pub writes: Vec, +} + +impl DownloadAndWriteAllSequential { + pub async fn run(self) -> Result)>> { + let download_result = self.download.run().await?; + Ok(TermDownloadResult { + payload: (download_result.payload, self.xorb_hash, self.writes), + duration: download_result.duration, + n_retries_on_403: download_result.n_retries_on_403, + }) + } +} + /// Helper object containing the structs needed when downloading and writing a term in parallel /// during reconstruction. #[derive(Debug)] -pub(crate) struct FetchTermDownloadOnceAndWriteEverywhereUsed { +pub(crate) struct DownloadAndWriteAllParallel { pub download: FetchTermDownload, - // pub write_offset: u64, // start position of the writer to write to pub output: SeekingOutputProvider, pub writes: Vec, } -impl FetchTermDownloadOnceAndWriteEverywhereUsed { +impl DownloadAndWriteAllParallel { /// Download the term and write it to the underlying storage, retry on 403 pub async fn run(self) -> Result> { let download_result = self.download.run().await?; - let TermDownloadOutput { - data, - chunk_byte_indices, - chunk_range, - } = download_result.payload; - debug_assert_eq!(chunk_byte_indices.len(), (chunk_range.end - chunk_range.start + 1) as usize); - debug_assert_eq!(*chunk_byte_indices.last().expect("checked len is something") as usize, data.len()); // write out the data let mut total_written = 0; for write in self.writes { - debug_assert!(write.chunk_range.start >= chunk_range.start); - debug_assert!(write.chunk_range.end > chunk_range.start); - debug_assert!( - write.chunk_range.start < chunk_range.end, - "{} < {} ;;; write {:?} term {:?}", - write.chunk_range.start, - chunk_range.end, - write.chunk_range, - chunk_range - ); - debug_assert!(write.chunk_range.end <= chunk_range.end); - - let start_chunk_offset_index = write.chunk_range.start - chunk_range.start; - let end_chunk_offset_index = write.chunk_range.end - chunk_range.start; - let start_chunk_offset = chunk_byte_indices[start_chunk_offset_index as usize] as usize; - let end_chunk_offset = chunk_byte_indices[end_chunk_offset_index as usize] as usize; - let data_sub_range = &data[start_chunk_offset..end_chunk_offset]; - debug_assert_eq!(data_sub_range.len(), write.unpacked_length as usize); - - debug_assert!(data_sub_range.len() as u64 >= write.skip_bytes + write.take); - let data_sub_range_sliced = - &data_sub_range[(write.skip_bytes as usize)..((write.skip_bytes + write.take) as usize)]; - - let mut writer = self.output.get_writer_at(write.writer_offset)?; - writer.write_all(data_sub_range_sliced)?; - writer.flush()?; - total_written += write.take; + total_written += write.write_to_seek_writer(&download_result.payload, &self.output)?; } Ok(TermDownloadResult { diff --git a/cas_client/src/lib.rs b/cas_client/src/lib.rs index 91b377254..d47cfe966 100644 --- a/cas_client/src/lib.rs +++ b/cas_client/src/lib.rs @@ -20,6 +20,7 @@ mod http_client; mod interface; #[cfg(not(target_family = "wasm"))] mod local_client; +mod memory_cache; #[cfg(not(target_family = "wasm"))] mod output_provider; pub mod remote_client; diff --git a/cas_client/src/memory_cache.rs b/cas_client/src/memory_cache.rs new file mode 100644 index 000000000..37e427d41 --- /dev/null +++ b/cas_client/src/memory_cache.rs @@ -0,0 +1,42 @@ +// A single-threaded cache with Bélády's optimal replacement policy, with optional disk back up. 
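+//
+// For intuition, Bélády's (clairvoyant) policy evicts the entry whose next use lies furthest in
+// the future, which is only possible because the caller can tell the cache when each entry will
+// be needed again. A minimal sketch of the eviction choice (illustrative only, assuming each key
+// carries its pending future offsets and a hypothetical helper next_use(k) returns the smallest
+// one still outstanding):
+//
+//     let victim = keys.iter().max_by_key(|k| next_use(k)).expect("cache is non-empty");
+//
+// An evicted entry may be spilled to the optional DiskStorage below instead of being dropped.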
+ +use std::collections::HashMap; +use std::fmt::Display; + +use bytes::Bytes; +use serde::Serialize; +use tempfile::TempDir; + +pub struct ClairvoyantHybridCache { + memory: HashMap, + disk: Option, +} + +impl ClairvoyantHybridCache {} + +struct DiskStorage { + _tempdir: TempDir, +} + +impl DiskStorage { + fn put(&self, k: K, v: &[u8]) -> std::io::Result<()> + where + K: Display, + { + todo!() + } + + fn get(&self, k: &K) -> std::io::Result + where + K: Display, + { + todo!() + } + + fn remove(&self, k: &K) + where + K: Display, + { + todo!() + } +} diff --git a/cas_client/src/remote_client.rs b/cas_client/src/remote_client.rs index 89f2e7543..240bb5f40 100644 --- a/cas_client/src/remote_client.rs +++ b/cas_client/src/remote_client.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::mem::take; use std::path::PathBuf; use std::sync::Arc; @@ -34,6 +34,7 @@ use xet_runtime::{GlobalSemaphoreHandle, XetRuntime, global_semaphore_handle}; use crate::download_utils::*; use crate::error::{CasClientError, Result}; use crate::http_client::{Api, ResponseErrorLogger, RetryConfig}; +use crate::memory_cache::MemoryCache; #[cfg(not(target_family = "wasm"))] use crate::output_provider::{SeekingOutputProvider, SequentialOutput}; use crate::retry_wrapper::RetryWrapper; @@ -135,7 +136,79 @@ pub(crate) async fn get_reconstruction_with_endpoint_and_client( #[cfg(not(target_family = "wasm"))] #[allow(clippy::too_many_arguments)] -pub(crate) async fn map_fetch_info_into_download_tasks( +pub(crate) async fn map_fetch_info_into_download_and_write_sequential_tasks( + segment: Arc, + terms: Vec, + offset_into_first_range: u64, + base_write_negative_offset: u64, + chunk_cache: Option>, + client: Arc, + range_download_single_flight: Arc>, +) -> Result> { + // the actual segment length. + // the file_range end may actually exceed the file total length for the last segment. 
+ // in that case, the maximum length of this segment will be the total of all terms given + // minus the start offset + let seg_len = segment + .file_range + .length() + .min(terms.iter().fold(0, |acc, term| acc + term.unpacked_length as u64) - offset_into_first_range); + + let initial_writer_offset = segment.file_range.start - base_write_negative_offset; + let mut total_taken = 0; + + let mut fetch_info_term_map: HashMap<(MerkleHash, ChunkRange), DownloadAndWriteAllSequential> = HashMap::new(); + for (i, term) in terms.into_iter().enumerate() { + let (individual_fetch_info, _) = segment.find((term.hash, term.range)).await?; + + let skip_bytes = if i == 0 { offset_into_first_range } else { 0 }; + // amount to take is min of the whole term after skipped bytes or the remainder of the segment + let take = (term.unpacked_length as u64 - skip_bytes).min(seg_len - total_taken); + let write_term = ChunkRangeWrite { + // term details + chunk_range: term.range, + unpacked_length: term.unpacked_length, + + // write details + skip_bytes, + take, + writer_offset: initial_writer_offset + total_taken, + }; + + let task = fetch_info_term_map + .entry((term.hash.into(), individual_fetch_info.range)) + .or_insert_with(|| DownloadAndWriteAllSequential { + download: FetchTermDownload { + hash: term.hash.into(), + range: individual_fetch_info.range, + fetch_info: segment.clone(), + chunk_cache: chunk_cache.clone(), + client: client.clone(), + range_download_single_flight: range_download_single_flight.clone(), + }, + xorb_hash: term.hash.into(), + writes: vec![], + }); + task.writes.push(write_term); + + total_taken += take; + } + + let mut tasks: Vec = fetch_info_term_map.into_values().collect(); + + // Sort by the write position of the first term that will use (part) of the downloaded range, + // those with preceding write position get to download first, so that the download result + // is used immediately instead of being cached. + // For each task, the corresponding write positions are already sorted (so the first term is at + // index 0). This is because we append terms in sequential order to respective tasks. + tasks.sort_by_key(|t| t.writes[0].writer_offset); + + Ok(tasks) +} + +#[cfg(not(target_family = "wasm"))] +#[allow(clippy::too_many_arguments)] +pub(crate) async fn map_fetch_info_into_download_and_write_parallel_tasks( segment: Arc, terms: Vec, offset_into_first_range: u64, @@ -144,7 +217,7 @@ pub(crate) async fn map_fetch_info_into_download_tasks( client: Arc, range_download_single_flight: Arc>, output_provider: &SeekingOutputProvider, -) -> Result> { +) -> Result> { // the actual segment length. // the file_range end may actually exceed the file total length for the last segment. 
// in that case, the maximum length of this segment will be the total of all terms given @@ -157,8 +230,7 @@ pub(crate) async fn map_fetch_info_into_download_tasks( let initial_writer_offset = segment.file_range.start - base_write_negative_offset; let mut total_taken = 0; - let mut fetch_info_term_map: HashMap<(MerkleHash, ChunkRange), FetchTermDownloadOnceAndWriteEverywhereUsed> = - HashMap::new(); + let mut fetch_info_term_map: HashMap<(MerkleHash, ChunkRange), DownloadAndWriteAllParallel> = HashMap::new(); for (i, term) in terms.into_iter().enumerate() { let (individual_fetch_info, _) = segment.find((term.hash, term.range)).await?; @@ -178,7 +250,7 @@ pub(crate) async fn map_fetch_info_into_download_tasks( let task = fetch_info_term_map .entry((term.hash.into(), individual_fetch_info.range)) - .or_insert_with(|| FetchTermDownloadOnceAndWriteEverywhereUsed { + .or_insert_with(|| DownloadAndWriteAllParallel { download: FetchTermDownload { hash: term.hash.into(), range: individual_fetch_info.range, @@ -333,6 +405,237 @@ impl RemoteClient { Ok(response) } + // Segmented download such that the file reconstruction and fetch info is not queried in its entirety + // at the beginning of the download, but queried in segments. Range downloads are executed with + // a certain degree of parallelism, but writing out to storage is sequential. Ideal when the external + // storage uses HDDs. + #[instrument(skip_all, name = "RemoteClient::reconstruct_file_segmented", fields(file.hash = file_hash.hex() + ))] + async fn reconstruct_file_to_writer_segmented_sequential_write_2( + &self, + file_hash: &MerkleHash, + byte_range: Option, + mut writer: SequentialOutput, + progress_updater: Option>, + ) -> Result { + let call_id = FN_CALL_ID.fetch_add(1, Ordering::Relaxed); + info!( + call_id, + %file_hash, + ?byte_range, + "Starting reconstruct_file_to_writer_segmented", + ); + + // Use an unlimited queue size, as queue size is inherently bounded by degree of concurrency. + let mut task_queue: VecDeque> = VecDeque::new(); + let mut running_downloads: JoinSet< + Result<(TermDownloadResult<(TermDownloadOutput, MerkleHash, Vec)>, OwnedSemaphorePermit)>, + > = JoinSet::new(); + + // derive the actual range to reconstruct + let file_reconstruct_range = byte_range.unwrap_or_else(FileRange::full); + let base_write_negative_offset = file_reconstruct_range.start; + + // kick-start the download by enqueue the fetch info task. + task_queue.push_back(DownloadQueueItem::Metadata(FetchInfo::new( + *file_hash, + file_reconstruct_range, + self.endpoint.clone(), + self.authenticated_http_client_with_retry.clone(), + ))); + + // Start the queue processing logic + // + // If the queue item is `DownloadQueueItem::Metadata`, it fetches the file reconstruction info + // of the first segment, whose size is linear to `num_concurrent_range_gets`. Once fetched, term + // download tasks are enqueued and spawned with the degree of concurrency equal to `num_concurrent_range_gets`. + // After the above, a task that defines fetching the remainder of the file reconstruction info is enqueued, + // which will execute after the first of the above term download tasks finishes. 
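+        //
+        // Writing stays strictly sequential: `write_pos` tracks the next byte offset the
+        // sequential writer expects, and downloads that finish out of order are parked in
+        // `data_available_writes`, keyed by their writer offset, until the cursor reaches them.
+        // A minimal sketch of that flush step (illustrative only; the hypothetical `flush_one`
+        // stands in for pulling the cached bytes and writing them out):
+        //
+        //     while let Some(w) = data_available_writes.remove(&write_pos) {
+        //         write_pos += flush_one(w)?;
+        //     }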
+        let term_download_client = self.http_client_with_retry.clone();
+        let download_scheduler = DownloadSegmentLengthTuner::from_configurable_constants();
+
+        let download_concurrency_limiter =
+            XetRuntime::current().global_semaphore(*DOWNLOAD_CHUNK_RANGE_CONCURRENCY_LIMITER);
+
+        info!(concurrency_limit = *NUM_CONCURRENT_RANGE_GETS, "Starting segmented download");
+
+        let mut data_available_writes: HashMap<u64, ChunkRangeWrite> = HashMap::new();
+        let coalesced_range_reuse_cache = if self.chunk_cache.is_some() {
+            None
+        } else {
+            Some(MemoryCache::default())
+        };
+
+        async fn write_all_if_data_available(
+            write_pos: &mut u64,
+            writer: &mut SequentialOutput,
+            data_available_writes: &mut HashMap<u64, ChunkRangeWrite>,
+            coalesced_range_reuse_cache: &Option<MemoryCache>,
+            persistent_cache: &Option<Arc<dyn ChunkCache>>,
+        ) -> Result<u64> {
+            let mut write_len = 0;
+            // Flush parked writes that have become contiguous with the current write position.
+            loop {
+                if let Some(write) = data_available_writes.remove(write_pos) {
+                    if let Some(cache) = coalesced_range_reuse_cache {
+                        // TODO: fetch the coalesced range for `write` from the in-memory cache
+                        // and write it out through `writer`.
+                    } else if let Some(cache) = persistent_cache {
+                        // TODO: fetch the range for `write` from the persistent chunk cache
+                        // and write it out through `writer`.
+                    }
+                } else {
+                    // Nothing is pending at the current write position yet.
+                    break;
+                }
+            }
+
+            Ok(write_len)
+        }
+
+        async fn process_result(
+            (result, permit): (
+                TermDownloadResult<(TermDownloadOutput, MerkleHash, Vec<ChunkRangeWrite>)>,
+                OwnedSemaphorePermit,
+            ),
+            write_pos: &mut u64,
+            writer: &mut SequentialOutput,
+            data_available_writes: &mut HashMap<u64, ChunkRangeWrite>,
+            coalesced_range_reuse_cache: &Option<MemoryCache>,
+            persistent_cache: &Option<Arc<dyn ChunkCache>>,
+        ) -> Result<u64> {
+            let mut write_len = 0;
+
+            let (download, xorb_hash, writes) = result.payload;
+
+            // write out if the term at this position uses part of the downloaded range
+            let mut i = 0;
+            while let Some(write) = writes.get(i)
+                && *write_pos == write.writer_offset
+            {
+                let len = write.write_to_sequential_writer(&download, writer).await?;
+                write_len += len;
+                *write_pos += len;
+                i += 1;
+            }
+
+            // cache the data for later writes
+            if let Some(cache) = coalesced_range_reuse_cache
+                && i < writes.len()
+            {
+                let ttl: Vec<u64> = writes[i..].iter().map(|w| w.writer_offset).collect();
+                for w in writes[i..].iter() {
+                    data_available_writes.insert(w.writer_offset, *w);
+                }
+                cache
+                    .put(&xorb_hash, download.chunk_range, download.chunk_byte_indices, download.data, ttl)
+                    .await?;
+            }
+
+            drop(permit);
+            write_len += write_all_if_data_available(
+                write_pos,
+                writer,
+                data_available_writes,
+                coalesced_range_reuse_cache,
+                persistent_cache,
+            )
+            .await?;
+
+            Ok(write_len)
+        }
+
+        let mut write_pos = 0;
+        while let Some(item) = task_queue.pop_front() {
+            // first try to join some tasks
+            while let Some(result) = running_downloads.try_join_next() {
+                let write_len = process_result(
+                    result??,
+                    &mut write_pos,
+                    &mut writer,
+                    &mut data_available_writes,
+                    &coalesced_range_reuse_cache,
+                    &self.chunk_cache,
+                )
+                .await?;
+                if let Some(updater) = progress_updater.as_ref() {
+                    updater.update(write_len).await;
+                }
+            }
+
+            match item {
+                DownloadQueueItem::End => {
+                    // everything processed
+                    debug!(call_id, "download queue emptied");
+                    break;
+                },
+                DownloadQueueItem::DownloadTask(term_download) => {
+                    // acquire the permit before spawning the task, so that there's a limited
+                    // number of active downloads.
+ let permit = download_concurrency_limiter.clone().acquire_owned().await?; + debug!(call_id, "spawning 1 download task"); + running_downloads.spawn(async move { + let data = term_download.run().await?; + Ok((data, permit)) + }); + }, + DownloadQueueItem::Metadata(fetch_info) => { + // query for the file info of the first segment + let segment_size = download_scheduler.next_segment_size()?; + debug!(call_id, segment_size, "querying file info"); + let (segment, maybe_remainder) = fetch_info.take_segment(segment_size); + + let Some((offset_into_first_range, terms)) = segment.query().await? else { + // signal termination + task_queue.push_back(DownloadQueueItem::End); + continue; + }; + + let segment = Arc::new(segment); + + // define the term download tasks + let tasks = map_fetch_info_into_download_and_write_sequential_tasks( + segment.clone(), + terms, + offset_into_first_range, + base_write_negative_offset, + self.chunk_cache.clone(), + term_download_client.clone(), + self.range_download_single_flight.clone(), + ) + .await?; + + debug!(call_id, num_tasks = tasks.len(), "enqueueing download tasks"); + for task_def in tasks { + task_queue.push_back(DownloadQueueItem::DownloadTask(task_def)); + } + + // enqueue the remainder of file info fetch task + if let Some(remainder) = maybe_remainder { + task_queue.push_back(DownloadQueueItem::Metadata(remainder)); + } else { + task_queue.push_back(DownloadQueueItem::End); + } + }, + } + } + + while let Some(result) = running_downloads.join_next().await { + let write_len = process_result( + result??, + &mut write_pos, + &mut writer, + &mut data_available_writes, + &coalesced_range_reuse_cache, + &self.chunk_cache, + ) + .await?; + if let Some(updater) = progress_updater.as_ref() { + updater.update(write_len).await; + } + } + + info!( + call_id, + %file_hash, + ?byte_range, + "Completed reconstruct_file_to_writer_segmented" + ); + Ok(write_pos) + } + // Segmented download such that the file reconstruction and fetch info is not queried in its entirety // at the beginning of the download, but queried in segments. Range downloads are executed with // a certain degree of parallelism, but writing out to storage is sequential. Ideal when the external @@ -525,8 +828,7 @@ impl RemoteClient { ); // Use the unlimited queue, as queue size is inherently bounded by degree of concurrency. - let (task_tx, mut task_rx) = - mpsc::unbounded_channel::>(); + let mut task_queue: VecDeque> = VecDeque::new(); let mut running_downloads = JoinSet::>>::new(); // derive the actual range to reconstruct @@ -534,12 +836,12 @@ impl RemoteClient { let base_write_negative_offset = file_reconstruct_range.start; // kick-start the download by enqueue the fetch info task. - task_tx.send(DownloadQueueItem::Metadata(FetchInfo::new( + task_queue.push_back(DownloadQueueItem::Metadata(FetchInfo::new( *file_hash, file_reconstruct_range, self.endpoint.clone(), self.authenticated_http_client_with_retry.clone(), - )))?; + ))); // Start the queue processing logic // @@ -567,7 +869,7 @@ impl RemoteClient { }; let mut total_written = 0; - while let Some(item) = task_rx.recv().await { + while let Some(item) = task_queue.pop_front() { // first try to join some tasks while let Some(result) = running_downloads.try_join_next() { let write_len = process_result(result??, &mut total_written, &download_scheduler)?; @@ -602,14 +904,14 @@ impl RemoteClient { let Some((offset_into_first_range, terms)) = segment.query().await? 
else { // signal termination - task_tx.send(DownloadQueueItem::End)?; + task_queue.push_back(DownloadQueueItem::End); continue; }; let segment = Arc::new(segment); // define the term download tasks - let tasks = map_fetch_info_into_download_tasks( + let tasks = map_fetch_info_into_download_and_write_parallel_tasks( segment.clone(), terms, offset_into_first_range, @@ -623,14 +925,14 @@ impl RemoteClient { debug!(call_id, num_tasks = tasks.len(), "enqueueing download tasks"); for task_def in tasks { - task_tx.send(DownloadQueueItem::DownloadTask(task_def))?; + task_queue.push_back(DownloadQueueItem::DownloadTask(task_def)); } // enqueue the remainder of file info fetch task if let Some(remainder) = maybe_remainder { - task_tx.send(DownloadQueueItem::Metadata(remainder))?; + task_queue.push_back(DownloadQueueItem::Metadata(remainder)); } else { - task_tx.send(DownloadQueueItem::End)?; + task_queue.push_back(DownloadQueueItem::End); } }, } diff --git a/chunk_cache/src/memory.rs b/chunk_cache/src/memory.rs new file mode 100644 index 000000000..f3716aa5f --- /dev/null +++ b/chunk_cache/src/memory.rs @@ -0,0 +1,147 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use cas_types::ChunkRange; +use merklehash::MerkleHash; +use tokio::sync::RwLock; + +use crate::CacheRange; +use crate::error::ChunkCacheError; + +#[derive(Debug, Clone)] +struct MemoryCacheItem { + range: ChunkRange, + chunk_byte_indices: Vec, + data: Vec, + ttl: Vec, +} + +#[derive(Debug, Clone, Default)] +struct CacheState { + inner: HashMap>, + num_items: usize, + total_bytes: u64, +} + +impl CacheState { + fn find_match(&self, key: &MerkleHash, range: &ChunkRange) -> Option<&MemoryCacheItem> { + let items = self.inner.get(key)?; + + items + .iter() + .find(|&item| item.range.start <= range.start && range.end <= item.range.end) + .map(|v| v as _) + } +} + +/// MemoryCache is a ChunkCache implementor that stores data in memory +#[derive(Debug, Clone, Default)] +pub struct MemoryCache { + state: Arc>, +} + +impl MemoryCache { + pub async fn num_items(&self) -> usize { + self.state.read().await.num_items + } + + pub async fn total_bytes(&self) -> u64 { + self.state.read().await.total_bytes + } + + pub async fn get(&self, key: &MerkleHash) -> Result { + if range.start >= range.end { + return Err(ChunkCacheError::InvalidArguments); + } + + let state = self.state.read().await; + let Some(cache_item) = state.find_match(key, range) else { + return Ok(None); + }; + + // Extract the requested range from the cached item + let start_idx = (range.start - cache_item.range.start) as usize; + let end_idx = (range.end - cache_item.range.start) as usize; + + if end_idx >= cache_item.chunk_byte_indices.len() { + return Err(ChunkCacheError::BadRange); + } + + let start_byte = cache_item.chunk_byte_indices[start_idx] as usize; + let end_byte = cache_item.chunk_byte_indices[end_idx] as usize; + + if end_byte > cache_item.data.len() { + return Err(ChunkCacheError::BadRange); + } + + let data = cache_item.data[start_byte..end_byte].to_vec(); + let offsets: Vec = cache_item.chunk_byte_indices[start_idx..=end_idx] + .iter() + .map(|v| v - cache_item.chunk_byte_indices[start_idx]) + .collect(); + + Ok(Some(CacheRange { + offsets, + data, + range: *range, + })) + } + + pub async fn put( + &self, + key: &MerkleHash, + range: ChunkRange, + chunk_byte_indices: Vec, + data: Vec, + ttl: Vec, + ) -> Result<(), ChunkCacheError> { + // Validate inputs + if range.start >= range.end + || chunk_byte_indices.len() != (range.end - range.start + 1) as usize + || 
chunk_byte_indices.is_empty() + || chunk_byte_indices[0] != 0 + || *chunk_byte_indices.last().unwrap() as usize != data.len() + || !strictly_increasing(&chunk_byte_indices) + { + return Err(ChunkCacheError::InvalidArguments); + } + + let data_len = data.len() as u64; + + let mut state = self.state.write().await; + + // Check if we already have this exact range cached + if let Some(items) = state.inner.get(key) { + for item in items.iter() { + if item.range == range { + // Already cached + return Ok(()); + } + } + } + + // Add the new item + let cache_item = MemoryCacheItem { + range, + chunk_byte_indices, + data, + ttl, + }; + + state.total_bytes += data_len; + state.num_items += 1; + + state.inner.entry(key.clone()).or_insert_with(Vec::new).push(cache_item); + + Ok(()) + } +} + +fn strictly_increasing(chunk_byte_indices: &[u32]) -> bool { + for i in 1..chunk_byte_indices.len() { + if chunk_byte_indices[i - 1] >= chunk_byte_indices[i] { + return false; + } + } + true +}
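
A minimal usage sketch of the MemoryCache added above (illustrative only, not part of the patch;
it assumes that get also receives the requested ChunkRange, as its body implies, and that the
relevant types are in scope):

    // Cache one contiguous chunk range of a xorb, then serve a narrower range from it.
    async fn cache_round_trip(
        cache: &MemoryCache,
        xorb: MerkleHash,
        full: ChunkRange,
        sub: ChunkRange,
        chunk_byte_indices: Vec<u32>,
        data: Vec<u8>,
    ) -> Result<(), ChunkCacheError> {
        // chunk_byte_indices must start at 0, be strictly increasing, end at data.len(),
        // and hold one entry per chunk boundary in `full` (chunk count + 1).
        cache.put(&xorb, full, chunk_byte_indices, data, vec![]).await?;

        // Any range covered by `full` is a hit; offsets come back rebased so the first
        // requested chunk starts at byte 0 of the returned data.
        if let Some(hit) = cache.get(&xorb, &sub).await? {
            debug_assert_eq!(hit.offsets[0], 0);
            debug_assert_eq!(hit.data.len() as u32, *hit.offsets.last().unwrap());
        }
        Ok(())
    }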