diff --git a/Cargo.toml b/Cargo.toml index ba4be9f..8212c1e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,14 +5,17 @@ edition = "2021" description = "Bufferering types for embedded-io" readme = "README.md" repository = "https://github.com/rmja/buffered-io" -authors = ["Rasmus Melchior Jacobsen "] +authors = ["Rasmus Melchior Jacobsen ", "Tage Johansson"] license = "MIT / Apache-2.0" keywords = ["embedded", "buffer", "embedded-io", "read", "write"] exclude = [".github"] +[features] +async = ["dep:embedded-io-async"] + [dependencies] embedded-io = { version = "0.7" } -embedded-io-async = { version = "0.7" } +embedded-io-async = { version = "0.7", optional = true } [dev-dependencies] embedded-io-async = { version = "0.7", features = ["std"] } diff --git a/README.md b/README.md index 8185d2a..b12d3ef 100644 --- a/README.md +++ b/README.md @@ -9,16 +9,14 @@ The `buffered-io` crate implements buffering for the `embedded-io`/`embedded-io- ## Example ```rust -tokio_test::block_on(async { - use buffered_io::asynch::BufferedWrite; - use embedded_io_async::Write; - - let uart_tx = Vec::new(); // The underlying uart peripheral implementing Write to where buffered bytes are written - let mut write_buf = [0; 120]; - let mut buffering = BufferedWrite::new(uart_tx, &mut write_buf); - buffering.write(b"hello").await.unwrap(); // This write is buffered - buffering.write(b" ").await.unwrap(); // This write is also buffered - buffering.write(b"world").await.unwrap(); // This write is also buffered - buffering.flush().await.unwrap(); // The string "hello world" is written to uart in one write -}) -``` \ No newline at end of file +use buffered_io::BufferedWrite; +use embedded_io::Write; + +let uart_tx = Vec::new(); // The underlying uart peripheral implementing Write to where buffered bytes are written +let mut write_buf = [0; 120]; +let mut buffering = BufferedWrite::new(uart_tx, &mut write_buf); +buffering.write(b"hello").unwrap(); // This write is buffered +buffering.write(b" ").unwrap(); // This write is also buffered +buffering.write(b"world").unwrap(); // This write is also buffered +buffering.flush().unwrap(); // The string "hello world" is written to uart in one write +``` diff --git a/src/asynch/mod.rs b/src/asynch/mod.rs deleted file mode 100644 index 8e9ae58..0000000 --- a/src/asynch/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -mod read; -mod write; - -pub use read::BufferedRead; -pub use write::BufferedWrite; - -/// Unable to bypass the current buffered reader or writer because there are buffered bytes. -#[derive(Debug)] -pub struct BypassError; diff --git a/src/lib.rs b/src/lib.rs index e711e28..3ced8e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,12 @@ #![doc = include_str!("../README.md")] #![cfg_attr(not(test), no_std)] -pub mod asynch; + +mod read; +mod write; + +pub use read::BufferedRead; +pub use write::BufferedWrite; + +/// Unable to bypass the current buffered reader or writer because there are buffered bytes. +#[derive(Debug)] +pub struct BypassError; diff --git a/src/asynch/read.rs b/src/read.rs similarity index 74% rename from src/asynch/read.rs rename to src/read.rs index b2e8a8d..dbcc520 100644 --- a/src/asynch/read.rs +++ b/src/read.rs @@ -1,18 +1,21 @@ -use embedded_io_async::{BufRead, Read, Write}; +#[cfg(feature = "async")] +mod asynch; + +use embedded_io::{BufRead, ErrorType, Read, Write}; use super::BypassError; /// A buffered [`Read`] /// /// The BufferedRead will read into the provided buffer to avoid small reads to the inner reader. -pub struct BufferedRead<'buf, T: Read> { +pub struct BufferedRead<'buf, T> { inner: T, buf: &'buf mut [u8], offset: usize, available: usize, } -impl<'buf, T: Read> BufferedRead<'buf, T> { +impl<'buf, T> BufferedRead<'buf, T> { /// Create a new buffered reader pub fn new(inner: T, buf: &'buf mut [u8]) -> Self { Self { @@ -61,33 +64,33 @@ impl<'buf, T: Read> BufferedRead<'buf, T> { } } -impl embedded_io::ErrorType for BufferedRead<'_, T> { +impl ErrorType for BufferedRead<'_, T> { type Error = T::Error; } impl Write for BufferedRead<'_, T> { - async fn write(&mut self, buf: &[u8]) -> Result { - self.inner.write(buf).await + fn write(&mut self, buf: &[u8]) -> Result { + self.inner.write(buf) } - async fn write_all(&mut self, buf: &[u8]) -> Result<(), Self::Error> { - self.inner.write_all(buf).await + fn write_all(&mut self, buf: &[u8]) -> Result<(), Self::Error> { + self.inner.write_all(buf) } - async fn flush(&mut self) -> Result<(), Self::Error> { - self.inner.flush().await + fn flush(&mut self) -> Result<(), Self::Error> { + self.inner.flush() } } impl Read for BufferedRead<'_, T> { - async fn read(&mut self, buf: &mut [u8]) -> Result { + fn read(&mut self, buf: &mut [u8]) -> Result { if self.available == 0 { if buf.len() >= self.buf.len() { // Fast path - bypass local buffer - return self.inner.read(buf).await; + return self.inner.read(buf); } self.offset = 0; - self.available = self.inner.read(self.buf).await?; + self.available = self.inner.read(self.buf)?; } let len = usize::min(self.available, buf.len()); @@ -106,10 +109,10 @@ impl Read for BufferedRead<'_, T> { } impl BufRead for BufferedRead<'_, T> { - async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> { + fn fill_buf(&mut self) -> Result<&[u8], Self::Error> { if self.available == 0 { self.offset = 0; - self.available = self.inner.read(self.buf).await?; + self.available = self.inner.read(self.buf)?; } Ok(&self.buf[self.offset..self.offset + self.available]) @@ -124,71 +127,70 @@ impl BufRead for BufferedRead<'_, T> { #[cfg(test)] mod tests { - use super::*; + use embedded_io::{BufRead, Read}; + + use super::BufferedRead; - #[tokio::test] - async fn can_read_to_buffer() { + #[test] + fn can_read_to_buffer() { let inner = [1, 2, 3, 4, 5, 6, 7, 8]; let mut buf = [0; 8]; let mut buffered = BufferedRead::new(inner.as_slice(), &mut buf); let mut read_buf = [0; 2]; - assert_eq!(2, buffered.read(&mut read_buf).await.unwrap()); + assert_eq!(2, buffered.read(&mut read_buf).unwrap()); assert_eq!(2, buffered.offset); assert_eq!(6, buffered.available); assert_eq!(&[1, 2], read_buf.as_slice()); let mut read_buf = [0; 2]; - assert_eq!(2, buffered.read(&mut read_buf).await.unwrap()); + assert_eq!(2, buffered.read(&mut read_buf).unwrap()); assert_eq!(4, buffered.offset); assert_eq!(4, buffered.available); assert_eq!(&[3, 4], read_buf.as_slice()); let mut read_buf = [0; 8]; - assert_eq!(4, buffered.read(&mut read_buf).await.unwrap()); + assert_eq!(4, buffered.read(&mut read_buf).unwrap()); assert_eq!(4, buffered.offset); assert_eq!(0, buffered.available); assert_eq!(&[5, 6, 7, 8], &read_buf[..4]); } - #[tokio::test] - async fn bypass_on_large_buf() { + #[test] + fn bypass_on_large_buf() { let inner = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; let mut buf = [0; 8]; let mut buffered = BufferedRead::new(inner.as_slice(), &mut buf); let mut read_buf = [0; 10]; - assert_eq!(10, buffered.read(&mut read_buf).await.unwrap()); + assert_eq!(10, buffered.read(&mut read_buf).unwrap()); assert_eq!(0, buffered.offset); assert_eq!(0, buffered.available); assert_eq!(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], read_buf.as_slice()); } - #[tokio::test] - async fn can_buf_read() { + #[test] + fn can_buf_read() { let inner = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; let mut buf = [0; 8]; let mut buffered = BufferedRead::new(inner.as_slice(), &mut buf); assert_eq!(0, buffered.offset); assert_eq!(0, buffered.available); - assert_eq!( - &[1, 2, 3, 4, 5, 6, 7, 8], - buffered.fill_buf().await.unwrap() - ); + assert_eq!(&[1, 2, 3, 4, 5, 6, 7, 8], buffered.fill_buf().unwrap()); assert_eq!(0, buffered.offset); assert_eq!(8, buffered.available); buffered.consume(2); assert_eq!(2, buffered.offset); assert_eq!(6, buffered.available); - assert_eq!(&[3, 4, 5, 6, 7, 8], buffered.fill_buf().await.unwrap()); + assert_eq!(&[3, 4, 5, 6, 7, 8], buffered.fill_buf().unwrap()); buffered.consume(6); assert_eq!(8, buffered.offset); assert_eq!(0, buffered.available); - assert_eq!(&[9, 10], buffered.fill_buf().await.unwrap()); + assert_eq!(&[9, 10], buffered.fill_buf().unwrap()); assert_eq!(0, buffered.offset); assert_eq!(2, buffered.available); diff --git a/src/read/asynch.rs b/src/read/asynch.rs new file mode 100644 index 0000000..e04f6d4 --- /dev/null +++ b/src/read/asynch.rs @@ -0,0 +1,136 @@ +use embedded_io_async::{BufRead, Read, Write}; + +use super::BufferedRead; + +impl Write for BufferedRead<'_, T> { + async fn write(&mut self, buf: &[u8]) -> Result { + self.inner.write(buf).await + } + + async fn write_all(&mut self, buf: &[u8]) -> Result<(), Self::Error> { + self.inner.write_all(buf).await + } + + async fn flush(&mut self) -> Result<(), Self::Error> { + self.inner.flush().await + } +} + +impl Read for BufferedRead<'_, T> { + async fn read(&mut self, buf: &mut [u8]) -> Result { + if self.available == 0 { + if buf.len() >= self.buf.len() { + // Fast path - bypass local buffer + return self.inner.read(buf).await; + } + self.offset = 0; + self.available = self.inner.read(self.buf).await?; + } + + let len = usize::min(self.available, buf.len()); + buf[..len].copy_from_slice(&self.buf[self.offset..self.offset + len]); + if len < self.available { + // There are still bytes left + self.offset += len; + self.available -= len; + } else { + // The buffer is drained + self.available = 0; + } + + Ok(len) + } +} + +impl BufRead for BufferedRead<'_, T> { + async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> { + if self.available == 0 { + self.offset = 0; + self.available = self.inner.read(self.buf).await?; + } + + Ok(&self.buf[self.offset..self.offset + self.available]) + } + + fn consume(&mut self, amt: usize) { + assert!(amt <= self.available); + self.offset += amt; + self.available -= amt; + } +} + +#[cfg(test)] +mod async_tests { + use super::*; + + #[tokio::test] + async fn can_read_to_buffer() { + let inner = [1, 2, 3, 4, 5, 6, 7, 8]; + let mut buf = [0; 8]; + let mut buffered = BufferedRead::new(inner.as_slice(), &mut buf); + + let mut read_buf = [0; 2]; + assert_eq!(2, buffered.read(&mut read_buf).await.unwrap()); + assert_eq!(2, buffered.offset); + assert_eq!(6, buffered.available); + assert_eq!(&[1, 2], read_buf.as_slice()); + + let mut read_buf = [0; 2]; + assert_eq!(2, buffered.read(&mut read_buf).await.unwrap()); + assert_eq!(4, buffered.offset); + assert_eq!(4, buffered.available); + assert_eq!(&[3, 4], read_buf.as_slice()); + + let mut read_buf = [0; 8]; + assert_eq!(4, buffered.read(&mut read_buf).await.unwrap()); + assert_eq!(4, buffered.offset); + assert_eq!(0, buffered.available); + assert_eq!(&[5, 6, 7, 8], &read_buf[..4]); + } + + #[tokio::test] + async fn bypass_on_large_buf() { + let inner = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + let mut buf = [0; 8]; + let mut buffered = BufferedRead::new(inner.as_slice(), &mut buf); + + let mut read_buf = [0; 10]; + assert_eq!(10, buffered.read(&mut read_buf).await.unwrap()); + assert_eq!(0, buffered.offset); + assert_eq!(0, buffered.available); + assert_eq!(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], read_buf.as_slice()); + } + + #[tokio::test] + async fn can_buf_read() { + let inner = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + let mut buf = [0; 8]; + let mut buffered = BufferedRead::new(inner.as_slice(), &mut buf); + assert_eq!(0, buffered.offset); + assert_eq!(0, buffered.available); + + assert_eq!( + &[1, 2, 3, 4, 5, 6, 7, 8], + buffered.fill_buf().await.unwrap() + ); + assert_eq!(0, buffered.offset); + assert_eq!(8, buffered.available); + + buffered.consume(2); + assert_eq!(2, buffered.offset); + assert_eq!(6, buffered.available); + assert_eq!(&[3, 4, 5, 6, 7, 8], buffered.fill_buf().await.unwrap()); + + buffered.consume(6); + assert_eq!(8, buffered.offset); + assert_eq!(0, buffered.available); + + assert_eq!(&[9, 10], buffered.fill_buf().await.unwrap()); + assert_eq!(0, buffered.offset); + assert_eq!(2, buffered.available); + + buffered.consume(2); + assert_eq!(2, buffered.offset); + assert_eq!(0, buffered.available); + } +} diff --git a/src/asynch/write.rs b/src/write.rs similarity index 76% rename from src/asynch/write.rs rename to src/write.rs index ef6fdf6..a8b6ec1 100644 --- a/src/asynch/write.rs +++ b/src/write.rs @@ -1,17 +1,20 @@ -use embedded_io_async::{Read, Write}; +#[cfg(feature = "async")] +mod asynch; + +use embedded_io::{ErrorType, Read, Write}; use super::BypassError; /// A buffered [`Write`] /// /// The BufferedWrite will write into the provided buffer to avoid small writes to the inner writer. -pub struct BufferedWrite<'buf, T: Write> { +pub struct BufferedWrite<'buf, T> { inner: T, buf: &'buf mut [u8], pos: usize, } -impl<'buf, T: Write> BufferedWrite<'buf, T> { +impl<'buf, T> BufferedWrite<'buf, T> { /// Create a new buffered writer pub fn new(inner: T, buf: &'buf mut [u8]) -> Self { Self { inner, buf, pos: 0 } @@ -68,31 +71,31 @@ impl<'buf, T: Write> BufferedWrite<'buf, T> { } } -impl embedded_io::ErrorType for BufferedWrite<'_, T> { +impl ErrorType for BufferedWrite<'_, T> { type Error = T::Error; } -impl Read for BufferedWrite<'_, T> { - async fn read(&mut self, buf: &mut [u8]) -> Result { - self.inner.read(buf).await +impl Read for BufferedWrite<'_, T> { + fn read(&mut self, buf: &mut [u8]) -> Result { + self.inner.read(buf) } - async fn read_exact( + fn read_exact( &mut self, buf: &mut [u8], ) -> Result<(), embedded_io::ReadExactError> { - self.inner.read_exact(buf).await + self.inner.read_exact(buf) } } impl Write for BufferedWrite<'_, T> { - async fn write(&mut self, buf: &[u8]) -> Result { + fn write(&mut self, buf: &[u8]) -> Result { if buf.is_empty() { return Ok(0); } if self.pos == 0 && buf.len() >= self.buf.len() { // Fast path - nothing in buffer and the buffer to write is large - return self.inner.write(buf).await; + return self.inner.write(buf); } let buffered = usize::min(buf.len(), self.buf.len() - self.pos); @@ -107,7 +110,7 @@ impl Write for BufferedWrite<'_, T> { self.pos = new_pos; } else { // The buffer is full - let written = self.inner.write(self.buf).await?; + let written = self.inner.write(self.buf)?; // We only assign self.pos _after_ we are sure that the write has completed successfully if written < new_pos { @@ -122,83 +125,83 @@ impl Write for BufferedWrite<'_, T> { Ok(buffered) } - async fn flush(&mut self) -> Result<(), Self::Error> { + fn flush(&mut self) -> Result<(), Self::Error> { if self.pos > 0 { - self.inner.write_all(&self.buf[..self.pos]).await?; + self.inner.write_all(&self.buf[..self.pos])?; self.pos = 0; } - self.inner.flush().await + self.inner.flush() } } #[cfg(test)] mod tests { - use embedded_io::{Error, ErrorKind, ErrorType}; + use embedded_io::{Error, ErrorKind, ErrorType, Write}; use super::*; - #[tokio::test] - async fn can_append_to_buffer() { + #[test] + fn can_append_to_buffer() { let mut inner = Vec::new(); let mut buf = [0; 8]; let mut buffered = BufferedWrite::new(&mut inner, &mut buf); - assert_eq!(2, buffered.write(&[1, 2]).await.unwrap()); + assert_eq!(2, buffered.write(&[1, 2]).unwrap()); assert_eq!(2, buffered.pos); assert_eq!(0, buffered.inner.len()); - assert_eq!(2, buffered.write(&[3, 4]).await.unwrap()); + assert_eq!(2, buffered.write(&[3, 4]).unwrap()); assert_eq!(4, buffered.pos); assert_eq!(0, buffered.inner.len()); - assert_eq!(4, buffered.write(&[5, 6, 7, 8]).await.unwrap()); + assert_eq!(4, buffered.write(&[5, 6, 7, 8]).unwrap()); assert_eq!(0, buffered.pos); assert_eq!(8, buffered.inner.len()); assert_eq!(&[1, 2, 3, 4, 5, 6, 7, 8], buffered.inner.as_slice()); } - #[tokio::test] - async fn bypass_large_write_when_empty() { + #[test] + fn bypass_large_write_when_empty() { let mut inner = Vec::new(); let mut buf = [0; 8]; let mut buffered = BufferedWrite::new(&mut inner, &mut buf); - assert_eq!(8, buffered.write(&[1, 2, 3, 4, 5, 6, 7, 8]).await.unwrap()); + assert_eq!(8, buffered.write(&[1, 2, 3, 4, 5, 6, 7, 8]).unwrap()); assert_eq!(0, buffered.pos); assert_eq!(8, buffered.inner.len()); } - #[tokio::test] - async fn large_write_when_not_empty() { + #[test] + fn large_write_when_not_empty() { let mut inner = Vec::new(); let mut buf = [0; 8]; let mut buffered = BufferedWrite::new(&mut inner, &mut buf); - assert_eq!(1, buffered.write(&[1]).await.unwrap()); + assert_eq!(1, buffered.write(&[1]).unwrap()); assert_eq!(1, buffered.pos); assert_eq!(0, buffered.inner.len()); - assert_eq!(7, buffered.write(&[2, 3, 4, 5, 6, 7, 8, 9]).await.unwrap()); + assert_eq!(7, buffered.write(&[2, 3, 4, 5, 6, 7, 8, 9]).unwrap()); assert_eq!(0, buffered.pos); assert_eq!(8, buffered.inner.len()); } - #[tokio::test] - async fn large_write_when_not_empty_can_handle_write_errors() { + #[test] + fn large_write_when_not_empty_can_handle_write_errors() { let mut inner = UnstableWrite::default(); inner.writeable.push(0); // Return error inner.writeable.push(8); // Write all bytes let mut buf = [0; 8]; let mut buffered = BufferedWrite::new(&mut inner, &mut buf); - assert_eq!(1, buffered.write(&[1]).await.unwrap()); + assert_eq!(1, buffered.write(&[1]).unwrap()); assert_eq!(1, buffered.pos); assert_eq!(0, buffered.inner.written.len()); - assert!(buffered.write(&[2, 3, 4, 5, 6, 7, 8]).await.is_err()); + assert!(buffered.write(&[2, 3, 4, 5, 6, 7, 8]).is_err()); - assert_eq!(7, buffered.write(&[2, 3, 4, 5, 6, 7, 8]).await.unwrap()); + assert_eq!(7, buffered.write(&[2, 3, 4, 5, 6, 7, 8]).unwrap()); assert_eq!(0, buffered.pos); assert_eq!(8, buffered.inner.written.len()); } @@ -232,7 +235,7 @@ mod tests { } impl Write for UnstableWrite { - async fn write(&mut self, buf: &[u8]) -> Result { + fn write(&mut self, buf: &[u8]) -> Result { let written = self.writeable[self.writes]; self.writes += 1; if written > 0 { @@ -243,22 +246,22 @@ mod tests { } } - async fn flush(&mut self) -> Result<(), Self::Error> { + fn flush(&mut self) -> Result<(), Self::Error> { Ok(()) } } - #[tokio::test] - async fn flush_clears_buffer() { + #[test] + fn flush_clears_buffer() { let mut inner = Vec::new(); let mut buf = [0; 8]; let mut buffered = BufferedWrite::new(&mut inner, &mut buf); - assert_eq!(2, buffered.write(&[1, 2]).await.unwrap()); + assert_eq!(2, buffered.write(&[1, 2]).unwrap()); assert_eq!(2, buffered.pos); assert_eq!(0, buffered.inner.len()); - buffered.flush().await.unwrap(); + buffered.flush().unwrap(); assert_eq!(0, buffered.pos); assert_eq!(2, buffered.inner.len()); } diff --git a/src/write/asynch.rs b/src/write/asynch.rs new file mode 100644 index 0000000..afa3afc --- /dev/null +++ b/src/write/asynch.rs @@ -0,0 +1,196 @@ +use embedded_io_async::{Read, Write}; + +use super::BufferedWrite; + +impl Read for BufferedWrite<'_, T> { + async fn read(&mut self, buf: &mut [u8]) -> Result { + self.inner.read(buf).await + } + + async fn read_exact( + &mut self, + buf: &mut [u8], + ) -> Result<(), embedded_io::ReadExactError> { + self.inner.read_exact(buf).await + } +} + +impl Write for BufferedWrite<'_, T> { + async fn write(&mut self, buf: &[u8]) -> Result { + if buf.is_empty() { + return Ok(0); + } + if self.pos == 0 && buf.len() >= self.buf.len() { + // Fast path - nothing in buffer and the buffer to write is large + return self.inner.write(buf).await; + } + + let buffered = usize::min(buf.len(), self.buf.len() - self.pos); + assert!(buffered > 0); + + let mut new_pos = self.pos; + self.buf[new_pos..new_pos + buffered].copy_from_slice(&buf[..buffered]); + new_pos += buffered; + + if new_pos < self.buf.len() { + // The buffer to write could fit in the buffer + self.pos = new_pos; + } else { + // The buffer is full + let written = self.inner.write(self.buf).await?; + + // We only assign self.pos _after_ we are sure that the write has completed successfully + if written < new_pos { + // We only partially wrote the inner buffer + self.buf.copy_within(written..new_pos, 0); + self.pos = new_pos - written; + } else { + self.pos = 0; + } + } + + Ok(buffered) + } + + async fn flush(&mut self) -> Result<(), Self::Error> { + if self.pos > 0 { + self.inner.write_all(&self.buf[..self.pos]).await?; + self.pos = 0; + } + + self.inner.flush().await + } +} + +#[cfg(test)] +mod async_tests { + use embedded_io::{Error, ErrorKind, ErrorType}; + use embedded_io_async::Write; + + use super::BufferedWrite; + + #[tokio::test] + async fn can_append_to_buffer() { + let mut inner = Vec::new(); + let mut buf = [0; 8]; + let mut buffered = BufferedWrite::new(&mut inner, &mut buf); + + assert_eq!(2, buffered.write(&[1, 2]).await.unwrap()); + assert_eq!(2, buffered.pos); + assert_eq!(0, buffered.inner.len()); + + assert_eq!(2, buffered.write(&[3, 4]).await.unwrap()); + assert_eq!(4, buffered.pos); + assert_eq!(0, buffered.inner.len()); + + assert_eq!(4, buffered.write(&[5, 6, 7, 8]).await.unwrap()); + assert_eq!(0, buffered.pos); + assert_eq!(8, buffered.inner.len()); + assert_eq!(&[1, 2, 3, 4, 5, 6, 7, 8], buffered.inner.as_slice()); + } + + #[tokio::test] + async fn bypass_large_write_when_empty() { + let mut inner = Vec::new(); + let mut buf = [0; 8]; + let mut buffered = BufferedWrite::new(&mut inner, &mut buf); + + assert_eq!(8, buffered.write(&[1, 2, 3, 4, 5, 6, 7, 8]).await.unwrap()); + assert_eq!(0, buffered.pos); + assert_eq!(8, buffered.inner.len()); + } + + #[tokio::test] + async fn large_write_when_not_empty() { + let mut inner = Vec::new(); + let mut buf = [0; 8]; + let mut buffered = BufferedWrite::new(&mut inner, &mut buf); + + assert_eq!(1, buffered.write(&[1]).await.unwrap()); + assert_eq!(1, buffered.pos); + assert_eq!(0, buffered.inner.len()); + + assert_eq!(7, buffered.write(&[2, 3, 4, 5, 6, 7, 8, 9]).await.unwrap()); + assert_eq!(0, buffered.pos); + assert_eq!(8, buffered.inner.len()); + } + + #[tokio::test] + async fn large_write_when_not_empty_can_handle_write_errors() { + let mut inner = UnstableWrite::default(); + inner.writeable.push(0); // Return error + inner.writeable.push(8); // Write all bytes + let mut buf = [0; 8]; + let mut buffered = BufferedWrite::new(&mut inner, &mut buf); + + assert_eq!(1, buffered.write(&[1]).await.unwrap()); + assert_eq!(1, buffered.pos); + assert_eq!(0, buffered.inner.written.len()); + + assert!(buffered.write(&[2, 3, 4, 5, 6, 7, 8]).await.is_err()); + + assert_eq!(7, buffered.write(&[2, 3, 4, 5, 6, 7, 8]).await.unwrap()); + assert_eq!(0, buffered.pos); + assert_eq!(8, buffered.inner.written.len()); + } + + #[derive(Default)] + struct UnstableWrite { + written: Vec, + writes: usize, + writeable: Vec, + } + + #[derive(Debug)] + struct UnstableError; + + impl core::fmt::Display for UnstableError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "UnstableError") + } + } + + impl std::error::Error for UnstableError {} + + impl Error for UnstableError { + fn kind(&self) -> ErrorKind { + ErrorKind::Other + } + } + + impl ErrorType for UnstableWrite { + type Error = UnstableError; + } + + impl Write for UnstableWrite { + async fn write(&mut self, buf: &[u8]) -> Result { + let written = self.writeable[self.writes]; + self.writes += 1; + if written > 0 { + self.written.extend_from_slice(&buf[..written]); + Ok(written) + } else { + Err(UnstableError) + } + } + + async fn flush(&mut self) -> Result<(), Self::Error> { + Ok(()) + } + } + + #[tokio::test] + async fn flush_clears_buffer() { + let mut inner = Vec::new(); + let mut buf = [0; 8]; + let mut buffered = BufferedWrite::new(&mut inner, &mut buf); + + assert_eq!(2, buffered.write(&[1, 2]).await.unwrap()); + assert_eq!(2, buffered.pos); + assert_eq!(0, buffered.inner.len()); + + buffered.flush().await.unwrap(); + assert_eq!(0, buffered.pos); + assert_eq!(2, buffered.inner.len()); + } +}