diff --git a/etherparse/Cargo.toml b/etherparse/Cargo.toml index c8179526..c9bcc003 100644 --- a/etherparse/Cargo.toml +++ b/etherparse/Cargo.toml @@ -21,7 +21,8 @@ rust-version = "1.83.0" [features] default = ["std"] -std = ["arrayvec/std"] +alloc = [] +std = ["alloc", "arrayvec/std"] [dependencies] arrayvec = { version = "0.7.2", default-features = false } diff --git a/etherparse/src/err/packet/build_slice_write_error.rs b/etherparse/src/err/packet/build_slice_write_error.rs new file mode 100644 index 00000000..1369cdea --- /dev/null +++ b/etherparse/src/err/packet/build_slice_write_error.rs @@ -0,0 +1,110 @@ +use crate::err::{ipv4_exts, ipv6_exts, SliceWriteSpaceError, ValueTooBigError}; + +/// Error while serializing a packet into a byte slice. +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub enum BuildSliceWriteError { + /// Not enough space is available in the target slice. + /// Contains the minimum required length. + Space(usize), + + /// Error if the length of the payload is too + /// big to be representable by the length fields. + PayloadLen(ValueTooBigError), + + /// Error if the IPv4 extensions can not be serialized + /// because of internal consistency errors (i.e. a header + /// is never). + Ipv4Exts(ipv4_exts::ExtsWalkError), + + /// Error if the IPv6 extensions can not be serialized + /// because of internal consistency errors. + Ipv6Exts(ipv6_exts::ExtsWalkError), + + /// Error if ICMPv6 is packaged in an IPv4 packet (it is undefined + /// how to calculate the checksum). + Icmpv6InIpv4, + + /// Address size defined in the ARP header does not match the actual size. + ArpHeaderNotMatch, +} + +impl From> for BuildSliceWriteError { + fn from(value: ValueTooBigError) -> Self { + BuildSliceWriteError::PayloadLen(value) + } +} + +impl From for BuildSliceWriteError { + fn from(value: SliceWriteSpaceError) -> Self { + BuildSliceWriteError::Space(value.required_len) + } +} + +impl From for BuildSliceWriteError { + fn from(value: super::TransportChecksumError) -> Self { + match value { + super::TransportChecksumError::PayloadLen(err) => BuildSliceWriteError::PayloadLen(err), + super::TransportChecksumError::Icmpv6InIpv4 => BuildSliceWriteError::Icmpv6InIpv4, + } + } +} + +impl From> + for BuildSliceWriteError +{ + fn from(value: crate::WriteError) -> Self { + match value { + crate::WriteError::Io(err) => BuildSliceWriteError::Space(err.required_len), + crate::WriteError::Content(err) => BuildSliceWriteError::Ipv4Exts(err), + } + } +} + +impl From> + for BuildSliceWriteError +{ + fn from(value: crate::WriteError) -> Self { + match value { + crate::WriteError::Io(err) => BuildSliceWriteError::Space(err.required_len), + crate::WriteError::Content(err) => BuildSliceWriteError::Ipv6Exts(err), + } + } +} + +impl core::fmt::Display for BuildSliceWriteError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + use BuildSliceWriteError::*; + match self { + Space(required_len) => write!( + f, + "Not enough space to write packet to slice. Needed {} byte(s).", + required_len + ), + PayloadLen(err) => err.fmt(f), + Ipv4Exts(err) => err.fmt(f), + Ipv6Exts(err) => err.fmt(f), + ArpHeaderNotMatch => write!( + f, + "address size defined in the ARP header does not match the actual size" + ), + Icmpv6InIpv4 => write!( + f, + "Error: ICMPv6 can not be combined with an IPv4 headers (checksum can not be calculated)." + ), + } + } +} + +impl core::error::Error for BuildSliceWriteError { + fn source(&self) -> Option<&(dyn core::error::Error + 'static)> { + use BuildSliceWriteError::*; + match self { + Space(_) => None, + PayloadLen(err) => Some(err), + Ipv4Exts(err) => Some(err), + Ipv6Exts(err) => Some(err), + Icmpv6InIpv4 => None, + ArpHeaderNotMatch => None, + } + } +} diff --git a/etherparse/src/err/packet/build_vec_write_error.rs b/etherparse/src/err/packet/build_vec_write_error.rs new file mode 100644 index 00000000..6174d57f --- /dev/null +++ b/etherparse/src/err/packet/build_vec_write_error.rs @@ -0,0 +1,97 @@ +use crate::err::{ipv4_exts, ipv6_exts, ValueTooBigError}; +use core::convert::Infallible; + +/// Error while serializing a packet into a [`alloc::vec::Vec`]. +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub enum BuildVecWriteError { + /// Error if the length of the payload is too + /// big to be representable by the length fields. + PayloadLen(ValueTooBigError), + + /// Error if the IPv4 extensions can not be serialized + /// because of internal consistency errors (i.e. a header + /// is never). + Ipv4Exts(ipv4_exts::ExtsWalkError), + + /// Error if the IPv6 extensions can not be serialized + /// because of internal consistency errors. + Ipv6Exts(ipv6_exts::ExtsWalkError), + + /// Error if ICMPv6 is packaged in an IPv4 packet (it is undefined + /// how to calculate the checksum). + Icmpv6InIpv4, + + /// Address size defined in the ARP header does not match the actual size. + ArpHeaderNotMatch, +} + +impl From for BuildVecWriteError { + fn from(value: Infallible) -> Self { + match value {} + } +} + +impl From> for BuildVecWriteError { + fn from(value: ValueTooBigError) -> Self { + BuildVecWriteError::PayloadLen(value) + } +} + +impl From for BuildVecWriteError { + fn from(value: super::TransportChecksumError) -> Self { + match value { + super::TransportChecksumError::PayloadLen(err) => BuildVecWriteError::PayloadLen(err), + super::TransportChecksumError::Icmpv6InIpv4 => BuildVecWriteError::Icmpv6InIpv4, + } + } +} + +impl From> for BuildVecWriteError { + fn from(value: crate::WriteError) -> Self { + match value { + crate::WriteError::Io(err) => match err {}, + crate::WriteError::Content(err) => BuildVecWriteError::Ipv4Exts(err), + } + } +} + +impl From> for BuildVecWriteError { + fn from(value: crate::WriteError) -> Self { + match value { + crate::WriteError::Io(err) => match err {}, + crate::WriteError::Content(err) => BuildVecWriteError::Ipv6Exts(err), + } + } +} + +impl core::fmt::Display for BuildVecWriteError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + use BuildVecWriteError::*; + match self { + PayloadLen(err) => err.fmt(f), + Ipv4Exts(err) => err.fmt(f), + Ipv6Exts(err) => err.fmt(f), + ArpHeaderNotMatch => write!( + f, + "address size defined in the ARP header does not match the actual size" + ), + Icmpv6InIpv4 => write!( + f, + "Error: ICMPv6 can not be combined with an IPv4 headers (checksum can not be calculated)." + ), + } + } +} + +impl core::error::Error for BuildVecWriteError { + fn source(&self) -> Option<&(dyn core::error::Error + 'static)> { + use BuildVecWriteError::*; + match self { + PayloadLen(err) => Some(err), + Ipv4Exts(err) => Some(err), + Ipv6Exts(err) => Some(err), + Icmpv6InIpv4 => None, + ArpHeaderNotMatch => None, + } + } +} diff --git a/etherparse/src/err/packet/build_write_error.rs b/etherparse/src/err/packet/build_write_error.rs index 455cb4a1..32d8ad75 100644 --- a/etherparse/src/err/packet/build_write_error.rs +++ b/etherparse/src/err/packet/build_write_error.rs @@ -105,6 +105,50 @@ impl core::error::Error for BuildWriteError { } } +#[cfg(feature = "std")] +impl From for BuildWriteError { + fn from(value: std::io::Error) -> Self { + BuildWriteError::Io(value) + } +} + +#[cfg(feature = "std")] +impl From> for BuildWriteError { + fn from(value: ValueTooBigError) -> Self { + BuildWriteError::PayloadLen(value) + } +} + +#[cfg(feature = "std")] +impl From for BuildWriteError { + fn from(value: super::TransportChecksumError) -> Self { + match value { + super::TransportChecksumError::PayloadLen(err) => BuildWriteError::PayloadLen(err), + super::TransportChecksumError::Icmpv6InIpv4 => BuildWriteError::Icmpv6InIpv4, + } + } +} + +#[cfg(feature = "std")] +impl From> for BuildWriteError { + fn from(value: crate::WriteError) -> Self { + match value { + crate::WriteError::Io(err) => BuildWriteError::Io(err), + crate::WriteError::Content(err) => BuildWriteError::Ipv4Exts(err), + } + } +} + +#[cfg(feature = "std")] +impl From> for BuildWriteError { + fn from(value: crate::WriteError) -> Self { + match value { + crate::WriteError::Io(err) => BuildWriteError::Io(err), + crate::WriteError::Content(err) => BuildWriteError::Ipv6Exts(err), + } + } +} + #[cfg(test)] mod tests { use super::{BuildWriteError::*, *}; diff --git a/etherparse/src/err/packet/mod.rs b/etherparse/src/err/packet/mod.rs index fd8bacbe..b9ffc14e 100644 --- a/etherparse/src/err/packet/mod.rs +++ b/etherparse/src/err/packet/mod.rs @@ -3,6 +3,14 @@ mod build_write_error; #[cfg(feature = "std")] pub use build_write_error::*; +#[cfg(feature = "alloc")] +mod build_vec_write_error; +#[cfg(feature = "alloc")] +pub use build_vec_write_error::*; + +mod build_slice_write_error; +pub use build_slice_write_error::*; + mod slice_error; pub use slice_error::*; diff --git a/etherparse/src/lib.rs b/etherparse/src/lib.rs index bd227581..04babfaf 100644 --- a/etherparse/src/lib.rs +++ b/etherparse/src/lib.rs @@ -303,7 +303,7 @@ // for docs.rs #![cfg_attr(docsrs, feature(doc_cfg))] -#[cfg(test)] +#[cfg(any(feature = "alloc", test))] extern crate alloc; #[cfg(test)] extern crate proptest; @@ -343,6 +343,9 @@ mod compositions_tests; mod helpers; pub(crate) use helpers::*; +mod writer; +pub(crate) use writer::*; + mod lax_packet_headers; pub use lax_packet_headers::*; @@ -358,9 +361,7 @@ pub(crate) use lax_sliced_packet_cursor::*; mod len_source; pub use len_source::*; -#[cfg(feature = "std")] mod packet_builder; -#[cfg(feature = "std")] pub use crate::packet_builder::*; mod packet_headers; diff --git a/etherparse/src/link/single_vlan_header_slice.rs b/etherparse/src/link/single_vlan_header_slice.rs index 2406cc6d..6bfbeb96 100644 --- a/etherparse/src/link/single_vlan_header_slice.rs +++ b/etherparse/src/link/single_vlan_header_slice.rs @@ -43,6 +43,7 @@ impl<'a> SingleVlanHeaderSlice<'a> { /// The caller must ensured that the given slice has the length of /// [`SingleVlanHeader::LEN`] #[inline] + #[cfg(feature = "std")] pub(crate) unsafe fn from_slice_unchecked(slice: &[u8]) -> SingleVlanHeaderSlice { SingleVlanHeaderSlice { slice } } diff --git a/etherparse/src/net/ipv4_exts.rs b/etherparse/src/net/ipv4_exts.rs index 1c9ff7e6..beb9a4b6 100644 --- a/etherparse/src/net/ipv4_exts.rs +++ b/etherparse/src/net/ipv4_exts.rs @@ -160,21 +160,19 @@ impl Ipv4Extensions { } /// Write the extensions to the writer. - #[cfg(feature = "std")] - #[cfg_attr(docsrs, doc(cfg(feature = "std")))] - pub fn write( + pub(crate) fn write_internal( &self, writer: &mut T, start_ip_number: IpNumber, - ) -> Result<(), err::ipv4_exts::HeaderWriteError> { - use err::ipv4_exts::{ExtsWalkError::*, HeaderWriteError::*}; + ) -> Result<(), WriteError> { + use err::ipv4_exts::ExtsWalkError::*; use ip_number::*; match self.auth { Some(ref header) => { if AUTH == start_ip_number { - header.write(writer).map_err(Io) + writer.write_all(&header.to_bytes()).map_err(WriteError::Io) } else { - Err(Content(ExtNotReferenced { + Err(WriteError::Content(ExtNotReferenced { missing_ext: IpNumber::AUTHENTICATION_HEADER, })) } @@ -183,6 +181,21 @@ impl Ipv4Extensions { } } + /// Write the extensions to the writer. + #[cfg(feature = "std")] + #[cfg_attr(docsrs, doc(cfg(feature = "std")))] + pub fn write( + &self, + writer: &mut T, + start_ip_number: IpNumber, + ) -> Result<(), err::ipv4_exts::HeaderWriteError> { + self.write_internal(&mut IoWriter(writer), start_ip_number) + .map_err(|err| match err { + WriteError::Io(err) => err::ipv4_exts::HeaderWriteError::Io(err), + WriteError::Content(err) => err::ipv4_exts::HeaderWriteError::Content(err), + }) + } + ///Length of the all present headers in bytes. pub fn header_len(&self) -> usize { if let Some(ref header) = self.auth { diff --git a/etherparse/src/net/ipv6_exts.rs b/etherparse/src/net/ipv6_exts.rs index afbd104d..c814924f 100644 --- a/etherparse/src/net/ipv6_exts.rs +++ b/etherparse/src/net/ipv6_exts.rs @@ -639,17 +639,13 @@ impl Ipv6Extensions { /// /// It is required that all next header are correctly set in the headers /// and no other ipv6 header extensions follow this header. If this is not - /// the case an [`err::ipv6_exts::HeaderWriteError::Content`] error is - /// returned. - #[cfg(feature = "std")] - #[cfg_attr(docsrs, doc(cfg(feature = "std")))] - pub fn write( + /// the case a [`err::ipv6_exts::ExtsWalkError`] is returned. + pub(crate) fn write_internal( &self, writer: &mut T, first_header: IpNumber, - ) -> Result<(), err::ipv6_exts::HeaderWriteError> { + ) -> Result<(), WriteError> { use err::ipv6_exts::ExtsWalkError::*; - use err::ipv6_exts::HeaderWriteError::*; use ip_number::*; /// Struct flagging if a header needs to be written. @@ -681,7 +677,9 @@ impl Ipv6Extensions { // check if hop by hop header should be written first if IPV6_HOP_BY_HOP == next_header { let header = &self.hop_by_hop_options.as_ref().unwrap(); - header.write(writer).map_err(Io)?; + writer + .write_all(&header.to_bytes()) + .map_err(WriteError::Io)?; next_header = header.next_header; needs_write.hop_by_hop_options = false; } @@ -698,7 +696,7 @@ impl Ipv6Extensions { // by hop if it is not part of this extensions struct. if needs_write.hop_by_hop_options { // the hop by hop header is only allowed at the start - return Err(Content(HopByHopNotAtStart)); + return Err(WriteError::Content(HopByHopNotAtStart)); } else { break; } @@ -715,7 +713,9 @@ impl Ipv6Extensions { .final_destination_options .as_ref() .unwrap(); - header.write(writer).map_err(Io)?; + writer + .write_all(&header.to_bytes()) + .map_err(WriteError::Io)?; next_header = header.next_header; needs_write.final_destination_options = false; } else { @@ -723,7 +723,9 @@ impl Ipv6Extensions { } } else if needs_write.destination_options { let header = &self.destination_options.as_ref().unwrap(); - header.write(writer).map_err(Io)?; + writer + .write_all(&header.to_bytes()) + .map_err(WriteError::Io)?; next_header = header.next_header; needs_write.destination_options = false; } else { @@ -733,7 +735,9 @@ impl Ipv6Extensions { IPV6_ROUTE => { if needs_write.routing { let header = &self.routing.as_ref().unwrap().routing; - header.write(writer).map_err(Io)?; + writer + .write_all(&header.to_bytes()) + .map_err(WriteError::Io)?; next_header = header.next_header; needs_write.routing = false; // for destination options @@ -745,7 +749,9 @@ impl Ipv6Extensions { IPV6_FRAG => { if needs_write.fragment { let header = &self.fragment.as_ref().unwrap(); - header.write(writer).map_err(Io)?; + writer + .write_all(&header.to_bytes()) + .map_err(WriteError::Io)?; next_header = header.next_header; needs_write.fragment = false; } else { @@ -755,7 +761,9 @@ impl Ipv6Extensions { AUTH => { if needs_write.auth { let header = &self.auth.as_ref().unwrap(); - header.write(writer).map_err(Io)?; + writer + .write_all(&header.to_bytes()) + .map_err(WriteError::Io)?; next_header = header.next_header; needs_write.auth = false; } else { @@ -771,27 +779,27 @@ impl Ipv6Extensions { // check that all header have been written if needs_write.hop_by_hop_options { - Err(Content(ExtNotReferenced { + Err(WriteError::Content(ExtNotReferenced { missing_ext: IpNumber::IPV6_HEADER_HOP_BY_HOP, })) } else if needs_write.destination_options { - Err(Content(ExtNotReferenced { + Err(WriteError::Content(ExtNotReferenced { missing_ext: IpNumber::IPV6_DESTINATION_OPTIONS, })) } else if needs_write.routing { - Err(Content(ExtNotReferenced { + Err(WriteError::Content(ExtNotReferenced { missing_ext: IpNumber::IPV6_ROUTE_HEADER, })) } else if needs_write.fragment { - Err(Content(ExtNotReferenced { + Err(WriteError::Content(ExtNotReferenced { missing_ext: IpNumber::IPV6_FRAGMENTATION_HEADER, })) } else if needs_write.auth { - Err(Content(ExtNotReferenced { + Err(WriteError::Content(ExtNotReferenced { missing_ext: IpNumber::AUTHENTICATION_HEADER, })) } else if needs_write.final_destination_options { - Err(Content(ExtNotReferenced { + Err(WriteError::Content(ExtNotReferenced { missing_ext: IpNumber::IPV6_DESTINATION_OPTIONS, })) } else { @@ -799,6 +807,28 @@ impl Ipv6Extensions { } } + /// Writes the given headers to a writer based on the order defined in + /// the next_header fields of the headers and the first header_id + /// passed to this function. + /// + /// It is required that all next header are correctly set in the headers + /// and no other ipv6 header extensions follow this header. If this is not + /// the case an [`err::ipv6_exts::HeaderWriteError::Content`] error is + /// returned. + #[cfg(feature = "std")] + #[cfg_attr(docsrs, doc(cfg(feature = "std")))] + pub fn write( + &self, + writer: &mut T, + first_header: IpNumber, + ) -> Result<(), err::ipv6_exts::HeaderWriteError> { + self.write_internal(&mut IoWriter(writer), first_header) + .map_err(|err| match err { + WriteError::Io(err) => err::ipv6_exts::HeaderWriteError::Io(err), + WriteError::Content(err) => err::ipv6_exts::HeaderWriteError::Content(err), + }) + } + /// Length of the all present headers in bytes. pub fn header_len(&self) -> usize { let mut result = 0; diff --git a/etherparse/src/packet_builder.rs b/etherparse/src/packet_builder.rs index 0036dfd1..9a926eed 100644 --- a/etherparse/src/packet_builder.rs +++ b/etherparse/src/packet_builder.rs @@ -1,8 +1,58 @@ +use crate::err::packet::BuildSliceWriteError; +#[cfg(feature = "alloc")] +use crate::err::packet::BuildVecWriteError; +#[cfg(feature = "std")] use crate::err::packet::BuildWriteError; use super::*; -use std::{io, marker}; +use core::marker; + +#[cfg(feature = "alloc")] +use alloc::vec::Vec; + +impl From for BuildSliceWriteError { + fn from(value: SliceCoreWriteError) -> Self { + BuildSliceWriteError::Space(value.required_len) + } +} + +impl From> + for BuildSliceWriteError +{ + fn from(value: crate::WriteError) -> Self { + match value { + crate::WriteError::Io(err) => err.into(), + crate::WriteError::Content(err) => BuildSliceWriteError::Ipv4Exts(err), + } + } +} + +impl From> + for BuildSliceWriteError +{ + fn from(value: crate::WriteError) -> Self { + match value { + crate::WriteError::Io(err) => err.into(), + crate::WriteError::Content(err) => BuildSliceWriteError::Ipv6Exts(err), + } + } +} + +fn final_write_to_slice( + builder: PacketBuilderStep, + buffer: &mut [u8], + payload: &[u8], +) -> Result { + let required = final_size(&builder, payload.len()); + let slice = buffer + .get_mut(..required) + .ok_or(BuildSliceWriteError::Space(required))?; + + let mut writer = SliceCoreWrite::new(slice); + final_write_with_net::<_, _, BuildSliceWriteError>(builder, &mut writer, payload)?; + Ok(required) +} /// Helper for building packets. /// @@ -100,7 +150,6 @@ use std::{io, marker}; /// * [`PacketBuilderStep::write`] /// * [`PacketBuilderStep::size`] /// -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] pub struct PacketBuilder {} impl PacketBuilder { @@ -375,13 +424,11 @@ struct PacketImpl { } ///An unfinished packet that is build with the packet builder -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] pub struct PacketBuilderStep { state: PacketImpl, _marker: marker::PhantomData, } -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] impl PacketBuilderStep { /// Add an IPv4 header /// @@ -730,7 +777,6 @@ impl PacketBuilderStep { } } -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] impl PacketBuilderStep { /// Add an ip header (length, protocol/next_header & checksum fields will be overwritten based on the rest of the packet). /// @@ -914,7 +960,6 @@ impl PacketBuilderStep { } } -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] impl PacketBuilderStep { ///Add an ip header (length, protocol/next_header & checksum fields will be overwritten based on the rest of the packet). /// @@ -1098,7 +1143,6 @@ impl PacketBuilderStep { } } -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] impl PacketBuilderStep { /// Adds an ICMPv4 header of the given [`Icmpv4Type`] to the packet. /// @@ -1636,7 +1680,8 @@ impl PacketBuilderStep { /// `last_next_header_ip_number` will be set in the last extension header /// or if no extension header exists the ip header as the "next header" or /// "protocol number". - pub fn write( + #[cfg(feature = "std")] + pub fn write( mut self, writer: &mut T, last_next_header_ip_number: IpNumber, @@ -1651,7 +1696,50 @@ impl PacketBuilderStep { } _ => {} } - final_write_with_net(self, writer, payload) + final_write_with_net::<_, _, BuildWriteError>(self, &mut IoWriter(writer), payload) + } + + /// Write all the headers and the payload to a [`Vec`]. + #[cfg(feature = "alloc")] + #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))] + pub fn write_to_vec( + mut self, + buffer: &mut Vec, + last_next_header_ip_number: IpNumber, + payload: &[u8], + ) -> Result<(), BuildVecWriteError> { + match &mut (self.state.net_header) { + Some(NetHeaders::Ipv4(ref mut ip, ref mut exts)) => { + ip.protocol = exts.set_next_headers(last_next_header_ip_number); + } + Some(NetHeaders::Ipv6(ref mut ip, ref mut exts)) => { + ip.next_header = exts.set_next_headers(last_next_header_ip_number); + } + _ => {} + } + final_write_with_net::<_, _, BuildVecWriteError>(self, &mut VecWriter(buffer), payload) + } + + /// Write all the headers and the payload to a byte slice. + /// + /// Returns the number of bytes written. + pub fn write_to_slice( + mut self, + buffer: &mut [u8], + last_next_header_ip_number: IpNumber, + payload: &[u8], + ) -> Result { + match &mut (self.state.net_header) { + Some(NetHeaders::Ipv4(ref mut ip, ref mut exts)) => { + ip.protocol = exts.set_next_headers(last_next_header_ip_number); + } + Some(NetHeaders::Ipv6(ref mut ip, ref mut exts)) => { + ip.next_header = exts.set_next_headers(last_next_header_ip_number); + } + _ => {} + } + + final_write_to_slice(self, buffer, payload) } ///Returns the size of the packet when it is serialized @@ -1660,15 +1748,37 @@ impl PacketBuilderStep { } } -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] impl PacketBuilderStep { /// Write all the headers and the payload. - pub fn write( + #[cfg(feature = "std")] + pub fn write( self, writer: &mut T, payload: &[u8], ) -> Result<(), BuildWriteError> { - final_write_with_net(self, writer, payload) + final_write_with_net::<_, _, BuildWriteError>(self, &mut IoWriter(writer), payload) + } + + /// Write all the headers and the payload to a [`Vec`]. + #[cfg(feature = "alloc")] + #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))] + pub fn write_to_vec( + self, + buffer: &mut Vec, + payload: &[u8], + ) -> Result<(), BuildVecWriteError> { + final_write_with_net::<_, _, BuildVecWriteError>(self, &mut VecWriter(buffer), payload) + } + + /// Write all the headers and the payload to a byte slice. + /// + /// Returns the number of bytes written. + pub fn write_to_slice( + self, + buffer: &mut [u8], + payload: &[u8], + ) -> Result { + final_write_to_slice(self, buffer, payload) } /// Returns the size of the packet when it is serialized @@ -1677,15 +1787,37 @@ impl PacketBuilderStep { } } -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] impl PacketBuilderStep { ///Write all the headers and the payload. - pub fn write( + #[cfg(feature = "std")] + pub fn write( self, writer: &mut T, payload: &[u8], ) -> Result<(), BuildWriteError> { - final_write_with_net(self, writer, payload) + final_write_with_net::<_, _, BuildWriteError>(self, &mut IoWriter(writer), payload) + } + + /// Write all the headers and the payload to a [`Vec`]. + #[cfg(feature = "alloc")] + #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))] + pub fn write_to_vec( + self, + buffer: &mut Vec, + payload: &[u8], + ) -> Result<(), BuildVecWriteError> { + final_write_with_net::<_, _, BuildVecWriteError>(self, &mut VecWriter(buffer), payload) + } + + /// Write all the headers and the payload to a byte slice. + /// + /// Returns the number of bytes written. + pub fn write_to_slice( + self, + buffer: &mut [u8], + payload: &[u8], + ) -> Result { + final_write_to_slice(self, buffer, payload) } ///Returns the size of the packet when it is serialized @@ -1694,15 +1826,37 @@ impl PacketBuilderStep { } } -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] impl PacketBuilderStep { ///Write all the headers and the payload. - pub fn write( + #[cfg(feature = "std")] + pub fn write( self, writer: &mut T, payload: &[u8], ) -> Result<(), BuildWriteError> { - final_write_with_net(self, writer, payload) + final_write_with_net::<_, _, BuildWriteError>(self, &mut IoWriter(writer), payload) + } + + /// Write all the headers and the payload to a [`Vec`]. + #[cfg(feature = "alloc")] + #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))] + pub fn write_to_vec( + self, + buffer: &mut Vec, + payload: &[u8], + ) -> Result<(), BuildVecWriteError> { + final_write_with_net::<_, _, BuildVecWriteError>(self, &mut VecWriter(buffer), payload) + } + + /// Write all the headers and the payload to a byte slice. + /// + /// Returns the number of bytes written. + pub fn write_to_slice( + self, + buffer: &mut [u8], + payload: &[u8], + ) -> Result { + final_write_to_slice(self, buffer, payload) } ///Returns the size of the packet when it is serialized @@ -1711,7 +1865,6 @@ impl PacketBuilderStep { } } -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] impl PacketBuilderStep { ///Set ns flag (ECN-nonce - concealment protection; experimental: see RFC 3540) pub fn ns(mut self) -> PacketBuilderStep { @@ -1858,12 +2011,35 @@ impl PacketBuilderStep { } ///Write all the headers and the payload. - pub fn write( + #[cfg(feature = "std")] + pub fn write( self, writer: &mut T, payload: &[u8], ) -> Result<(), BuildWriteError> { - final_write_with_net(self, writer, payload) + final_write_with_net::<_, _, BuildWriteError>(self, &mut IoWriter(writer), payload) + } + + /// Write all the headers and the payload to a [`Vec`]. + #[cfg(feature = "alloc")] + #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))] + pub fn write_to_vec( + self, + buffer: &mut Vec, + payload: &[u8], + ) -> Result<(), BuildVecWriteError> { + final_write_with_net::<_, _, BuildVecWriteError>(self, &mut VecWriter(buffer), payload) + } + + /// Write all the headers and the payload to a byte slice. + /// + /// Returns the number of bytes written. + pub fn write_to_slice( + self, + buffer: &mut [u8], + payload: &[u8], + ) -> Result { + final_write_to_slice(self, buffer, payload) } ///Returns the size of the packet when it is serialized @@ -1872,25 +2048,47 @@ impl PacketBuilderStep { } } -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] impl PacketBuilderStep { - pub fn write(self, writer: &mut T) -> Result<(), BuildWriteError> { - final_write_with_net(self, writer, &[])?; + #[cfg(feature = "std")] + pub fn write(self, writer: &mut T) -> Result<(), BuildWriteError> { + final_write_with_net::<_, _, BuildWriteError>(self, &mut IoWriter(writer), &[])?; Ok(()) } + /// Write all the headers to a [`Vec`]. + #[cfg(feature = "alloc")] + #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))] + pub fn write_to_vec(self, buffer: &mut Vec) -> Result<(), BuildVecWriteError> { + final_write_with_net::<_, _, BuildVecWriteError>(self, &mut VecWriter(buffer), &[])?; + Ok(()) + } + + /// Write all the headers to a byte slice. + /// + /// Returns the number of bytes written. + pub fn write_to_slice(self, buffer: &mut [u8]) -> Result { + final_write_to_slice(self, buffer, &[]) + } + pub fn size(&self) -> usize { final_size(self, 0) } } /// Write all the headers and the payload. -fn final_write_with_net( +fn final_write_with_net( builder: PacketBuilderStep, - writer: &mut T, + writer: &mut W, payload: &[u8], -) -> Result<(), BuildWriteError> { - use BuildWriteError::*; +) -> Result<(), E> +where + W: CoreWrite + Sized, + E: From + + From> + + From> + + From> + + From, +{ use NetHeaders::*; // unpack builder (makes things easier with the borrow checker) @@ -1921,7 +2119,7 @@ fn final_write_with_net( None => net_ether_type, } }; - eth.write(writer).map_err(Io)?; + writer.write_all(ð.to_bytes()).map_err(E::from)?; } LinkHeader::LinuxSll(mut linux_sll) => { // Assumes that next layers are ether based. If more types of @@ -1929,7 +2127,7 @@ fn final_write_with_net( debug_assert_eq!(linux_sll.arp_hrd_type, ArpHardwareId::ETHERNET); linux_sll.protocol_type.change_value(net_ether_type.into()); - linux_sll.write(writer).map_err(Io)?; + writer.write_all(&linux_sll.to_bytes()).map_err(E::from)?; } } } @@ -1941,15 +2139,15 @@ fn final_write_with_net( //set ether types value.ether_type = net_ether_type; //serialize - value.write(writer).map_err(Io)?; + writer.write_all(&value.to_bytes()).map_err(E::from)?; } Some(Double(mut value)) => { //set ether types value.outer.ether_type = ether_type::VLAN_TAGGED_FRAME; value.inner.ether_type = net_ether_type; //serialize - value.outer.write(writer).map_err(Io)?; - value.inner.write(writer).map_err(Io)?; + writer.write_all(&value.outer.to_bytes()).map_err(E::from)?; + writer.write_all(&value.inner.to_bytes()).map_err(E::from)?; } None => {} } @@ -1976,7 +2174,7 @@ fn final_write_with_net( + transport.as_ref().map(|v| v.header_len()).unwrap_or(0) + payload.len(), ) - .map_err(PayloadLen)?; + .map_err(E::from)?; if let Some(transport) = &transport { ip.protocol = ip_exts.set_next_headers(match &transport { @@ -1988,24 +2186,15 @@ fn final_write_with_net( } // write ip header & extensions - ip.write(writer).map_err(Io)?; - ip_exts.write(writer, ip.protocol).map_err(|err| { - use err::ipv4_exts::HeaderWriteError as I; - match err { - I::Io(err) => Io(err), - I::Content(err) => Ipv4Exts(err), - } - })?; + ip.header_checksum = ip.calc_header_checksum(); + writer.write_all(&ip.to_bytes()).map_err(E::from)?; + ip_exts + .write_internal(writer, ip.protocol) + .map_err(E::from)?; // update the transport layer checksum if let Some(t) = &mut transport { - t.update_checksum_ipv4(&ip, payload).map_err(|err| { - use err::packet::TransportChecksumError as I; - match err { - I::PayloadLen(err) => PayloadLen(err), - I::Icmpv6InIpv4 => Icmpv6InIpv4, - } - })?; + t.update_checksum_ipv4(&ip, payload).map_err(E::from)?; } } Some(NetHeaders::Ipv6(mut ip, mut ip_exts)) => { @@ -2015,7 +2204,7 @@ fn final_write_with_net( + transport.as_ref().map(|v| v.header_len()).unwrap_or(0) + payload.len(), ) - .map_err(PayloadLen)?; + .map_err(E::from)?; if let Some(transport) = &transport { ip.next_header = ip_exts.set_next_headers(match &transport { @@ -2027,37 +2216,41 @@ fn final_write_with_net( } // write ip header & extensions - ip.write(writer).map_err(Io)?; - ip_exts.write(writer, ip.next_header).map_err(|err| { - use err::ipv6_exts::HeaderWriteError as I; - match err { - I::Io(err) => Io(err), - I::Content(err) => Ipv6Exts(err), - } - })?; + writer.write_all(&ip.to_bytes()).map_err(E::from)?; + ip_exts + .write_internal(writer, ip.next_header) + .map_err(E::from)?; // update the transport layer checksum if let Some(t) = &mut transport { - t.update_checksum_ipv6(&ip, payload).map_err(PayloadLen)?; + t.update_checksum_ipv6(&ip, payload).map_err(E::from)?; } } Some(NetHeaders::Arp(arp)) => { - writer.write_all(&arp.to_bytes()).map_err(Io)?; + writer.write_all(&arp.to_bytes()).map_err(E::from)?; } None => {} } // write transport header if let Some(transport) = transport { - transport.write(writer).map_err(Io)?; + match transport { + TransportHeader::Icmpv4(value) => { + writer.write_all(&value.to_bytes()).map_err(E::from)? + } + TransportHeader::Icmpv6(value) => { + writer.write_all(&value.to_bytes()).map_err(E::from)? + } + TransportHeader::Udp(value) => writer.write_all(&value.to_bytes()).map_err(E::from)?, + TransportHeader::Tcp(value) => writer.write_all(&value.to_bytes()).map_err(E::from)?, + } } // and finally the payload - writer.write_all(payload).map_err(Io)?; + writer.write_all(payload).map_err(E::from)?; Ok(()) } - ///Returns the size of the packet when it is serialized fn final_size(builder: &PacketBuilderStep, payload_size: usize) -> usize { use crate::NetHeaders::*; @@ -2111,7 +2304,7 @@ mod white_box_tests { #[should_panic] fn final_write_panic_missing_ip() { let mut writer = Vec::new(); - final_write_with_net( + final_write_with_net::<_, _, BuildVecWriteError>( PacketBuilderStep:: { state: PacketImpl { link_header: None, @@ -2121,7 +2314,7 @@ mod white_box_tests { }, _marker: marker::PhantomData:: {}, }, - &mut writer, + &mut VecWriter(&mut writer), &[], ) .unwrap(); @@ -2296,6 +2489,80 @@ mod test { assert_eq!(actual_payload, in_payload); } + #[test] + fn write_to_vec_empty() { + let payload = [1, 2, 3, 4]; + let mut written = Vec::new(); + let builder = PacketBuilder::ipv4([13, 14, 15, 16], [17, 18, 19, 20], 21).udp(22, 23); + let expected_len = builder.size(payload.len()); + + builder.write_to_vec(&mut written, &payload).unwrap(); + + assert_eq!(written.len(), expected_len); + + let mut expected = Vec::new(); + PacketBuilder::ipv4([13, 14, 15, 16], [17, 18, 19, 20], 21) + .udp(22, 23) + .write(&mut expected, &payload) + .unwrap(); + assert_eq!(written, expected); + } + + #[test] + fn write_to_vec_with_existing_content() { + let payload = [1, 2, 3, 4]; + let prefix = vec![0xaa, 0xbb, 0xcc]; + let mut written = prefix.clone(); + + PacketBuilder::ipv4([13, 14, 15, 16], [17, 18, 19, 20], 21) + .udp(22, 23) + .write_to_vec(&mut written, &payload) + .unwrap(); + + assert_eq!(&written[..prefix.len()], prefix.as_slice()); + + let mut expected_packet = Vec::new(); + PacketBuilder::ipv4([13, 14, 15, 16], [17, 18, 19, 20], 21) + .udp(22, 23) + .write(&mut expected_packet, &payload) + .unwrap(); + assert_eq!(&written[prefix.len()..], expected_packet.as_slice()); + } + + #[test] + fn write_to_slice_success() { + let payload = [1, 2, 3, 4]; + + let mut expected = Vec::new(); + PacketBuilder::ipv4([13, 14, 15, 16], [17, 18, 19, 20], 21) + .udp(22, 23) + .write(&mut expected, &payload) + .unwrap(); + let required = expected.len(); + + let mut buffer = vec![0x55; required + 4]; + let written = PacketBuilder::ipv4([13, 14, 15, 16], [17, 18, 19, 20], 21) + .udp(22, 23) + .write_to_slice(&mut buffer, &payload) + .unwrap(); + + assert_eq!(written, required); + assert_eq!(&buffer[..written], expected.as_slice()); + assert_eq!(&buffer[written..], &[0x55, 0x55, 0x55, 0x55]); + } + + #[test] + fn write_to_slice_too_small() { + let payload = [1, 2, 3, 4]; + let builder = PacketBuilder::ipv4([13, 14, 15, 16], [17, 18, 19, 20], 21).udp(22, 23); + let required = builder.size(payload.len()); + let mut buffer = vec![0u8; required - 1]; + + let actual = builder.write_to_slice(&mut buffer, &payload).unwrap_err(); + let expected = BuildSliceWriteError::Space(required); + assert_eq!(actual, expected); + } + #[test] fn linuxsll_ipv4_udp() { //generate diff --git a/etherparse/src/writer.rs b/etherparse/src/writer.rs new file mode 100644 index 00000000..0732db32 --- /dev/null +++ b/etherparse/src/writer.rs @@ -0,0 +1,85 @@ +#[cfg(feature = "alloc")] +use alloc::vec::Vec; +#[cfg(feature = "alloc")] +use core::convert::Infallible; + +/// Internal writer abstraction used to share serialization code between +/// `std` and `no_std` code paths. +pub(crate) trait CoreWrite { + type Error; + + fn write_all(&mut self, slice: &[u8]) -> Result<(), Self::Error>; +} + +/// Internal generic write error that separates transport errors (`Io`) from +/// semantic/content errors (`Content`). +pub(crate) enum WriteError { + Io(IO), + Content(Content), +} + +#[cfg(feature = "std")] +pub(crate) struct IoWriter<'a, T: std::io::Write + ?Sized>(pub(crate) &'a mut T); + +#[cfg(feature = "alloc")] +pub(crate) struct VecWriter<'a>(pub(crate) &'a mut Vec); + +pub(crate) struct SliceCoreWrite<'a> { + buf: &'a mut [u8], + pos: usize, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +pub(crate) struct SliceCoreWriteError { + pub(crate) required_len: usize, + pub(crate) len: usize, +} + +impl<'a> SliceCoreWrite<'a> { + #[inline] + pub(crate) fn new(buf: &'a mut [u8]) -> Self { + SliceCoreWrite { buf, pos: 0 } + } +} + +impl CoreWrite for SliceCoreWrite<'_> { + type Error = SliceCoreWriteError; + + #[inline] + fn write_all(&mut self, slice: &[u8]) -> Result<(), Self::Error> { + let buf_len = self.buf.len(); + + let required_len = self.pos.saturating_add(slice.len()); + self.buf + .get_mut(self.pos..) + .and_then(|tail| tail.get_mut(..slice.len())) + .ok_or(SliceCoreWriteError { + required_len, + len: buf_len, + })? + .copy_from_slice(slice); + self.pos = required_len; + Ok(()) + } +} + +#[cfg(feature = "std")] +impl CoreWrite for IoWriter<'_, T> { + type Error = std::io::Error; + + #[inline] + fn write_all(&mut self, slice: &[u8]) -> Result<(), Self::Error> { + std::io::Write::write_all(self.0, slice) + } +} + +#[cfg(feature = "alloc")] +impl CoreWrite for VecWriter<'_> { + type Error = Infallible; + + #[inline] + fn write_all(&mut self, slice: &[u8]) -> Result<(), Self::Error> { + self.0.extend_from_slice(slice); + Ok(()) + } +}