
Commit f652ff1

crepererum and Copilot committed
feat: disallow compressed IPC data
- avoids ZIP bombs, i.e. data that consumes A LOT of memory on the host
- avoids CPU-based denial-of-service attacks
- assuming that arrays within the guest require the same number of bytes as within the host is helpful, since it means that memory consumption stays somewhat in check

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 8c0ac44 commit f652ff1
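
The first bullet is easiest to see with a small demonstration. The sketch below is hypothetical (not part of this commit) and assumes arrow's ipc_compression feature, which this commit enables for dev-dependencies: a highly repetitive ~8 MiB array compresses to a tiny IPC stream, so the byte count crossing the host/guest boundary says little about the memory needed to decode it.

use std::sync::Arc;

use arrow::{
    array::{ArrayRef, Int64Array, RecordBatch},
    datatypes::{DataType, Field, Schema},
    ipc::{
        CompressionType,
        writer::{IpcWriteOptions, StreamWriter},
    },
};

fn main() {
    // 1M identical values: ~8 MiB of array data in host memory...
    let array: ArrayRef = Arc::new(Int64Array::from(vec![0i64; 1_000_000]));
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
    let batch = RecordBatch::try_new(Arc::clone(&schema), vec![array]).unwrap();

    let mut writer = StreamWriter::try_new_with_options(
        Vec::new(),
        &schema,
        IpcWriteOptions::default()
            .try_with_compression(Some(CompressionType::ZSTD))
            .unwrap(),
    )
    .unwrap();
    writer.write(&batch).unwrap();
    let bytes = writer.into_inner().unwrap();

    // ...is orders of magnitude smaller on the wire; decoding it
    // re-materializes the full ~8 MiB on the receiving side.
    println!("in-memory: ~8 MiB, serialized: {} bytes", bytes.len());
}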

File tree

5 files changed: +233 −17 lines

Cargo.lock

Lines changed: 47 additions & 0 deletions
Some generated files are not rendered by default.

arrow2bytes/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ license.workspace = true
 arrow.workspace = true

 [dev-dependencies]
+arrow = { workspace = true, features = ["ipc_compression"] }
 insta.workspace = true

 [lints]
arrow2bytes/src/compression_check.rs

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+//! Scan IPC data for compressed data.
+//!
+//! This is a workaround until <https://github.com/apache/arrow-rs/issues/8917> is implemented.
+
+use std::io::{Cursor, Read, Seek};
+
+use arrow::{
+    error::ArrowError,
+    ipc::{BodyCompression, MessageHeader, root_as_message},
+};
+
+/// Detect and fail if there's compressed data.
+pub(crate) fn detect_compressed_data(bytes: &[u8]) -> Result<(), ArrowError> {
+    let mut reader = Cursor::new(bytes);
+
+    loop {
+        let Some(meta_len) = read_meta_len(&mut reader)? else {
+            break;
+        };
+        let mut meta = vec![0; meta_len];
+        reader.read_exact(&mut meta)?;
+        let msg = root_as_message(&meta).map_err(|err| {
+            ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
+        })?;
+
+        match msg.header_type() {
+            MessageHeader::Schema => {
+                // never compressed
+            }
+            MessageHeader::DictionaryBatch => {
+                if let Some(batch) = msg.header_as_dictionary_batch()
+                    && let Some(batch) = batch.data()
+                    && let Some(compression) = batch.compression()
+                {
+                    return Err(compression_err("dictionary batch", compression));
+                }
+            }
+            MessageHeader::RecordBatch => {
+                if let Some(batch) = msg.header_as_record_batch()
+                    && let Some(compression) = batch.compression()
+                {
+                    return Err(compression_err("record batch", compression));
+                }
+            }
+            x => {
+                return Err(ArrowError::ParseError(format!(
+                    "Unsupported message header type in IPC stream: '{x:?}'"
+                )));
+            }
+        }
+
+        let body_len = msg.bodyLength();
+        if body_len < 0 {
+            return Err(ArrowError::ParseError(format!(
+                "Invalid body length: {body_len}"
+            )));
+        }
+        reader.seek_relative(body_len)?;
+    }
+
+    Ok(())
+}
+
+/// Read the metadata length for the next message from the underlying stream.
+///
+/// # Returns
+/// - `Ok(None)` if the reader signals the end of stream with EOF on
+///   the first read
+/// - `Err(_)` if the reader returns an error other than EOF on the first
+///   read, or if the metadata length is less than 0.
+/// - `Ok(Some(_))` with the length otherwise.
+fn read_meta_len(reader: &mut Cursor<&[u8]>) -> Result<Option<usize>, ArrowError> {
+    const CONTINUATION_MARKER: [u8; 4] = [0xff; 4];
+    let mut meta_len: [u8; 4] = [0; 4];
+    match reader.read_exact(&mut meta_len) {
+        Ok(_) => {}
+        Err(e) => {
+            return if e.kind() == std::io::ErrorKind::UnexpectedEof {
+                // Handle EOF without the "0xFFFFFFFF 0x00000000"
+                // valid according to:
+                // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format
+                Ok(None)
+            } else {
+                Err(ArrowError::from(e))
+            };
+        }
+    }
+
+    let meta_len = {
+        // If a continuation marker is encountered, skip over it and read
+        // the size from the next four bytes.
+        if meta_len == CONTINUATION_MARKER {
+            reader.read_exact(&mut meta_len)?;
+        }
+
+        i32::from_le_bytes(meta_len)
+    };
+
+    if meta_len == 0 {
+        return Ok(None);
+    }
+
+    let meta_len = usize::try_from(meta_len)
+        .map_err(|_| ArrowError::ParseError(format!("Invalid metadata length: {meta_len}")))?;
+
+    Ok(Some(meta_len))
+}
+
+/// Generate error for encountered compression.
+fn compression_err(what: &'static str, compression: BodyCompression<'_>) -> ArrowError {
+    ArrowError::IpcError(format!(
+        "IPC {what} is compressed using {}, but compressed data MUST NOT cross the security boundary. If you want to handle compressed data, please decompress it within the guest.",
+        compression.codec().variant_name().unwrap_or("<unknown>")
+    ))
+}
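
For orientation, the framing that read_meta_len walks can be sketched by hand. The encoder below is illustrative only and not part of this commit (it skips the 8-byte metadata padding a real IPC writer emits): each message is a 0xFFFFFFFF continuation marker, a little-endian i32 metadata length, the flatbuffer-encoded Message, and then bodyLength bytes of buffers. A conforming stream ends either at EOF or with an explicit zero length, both of which read_meta_len maps to Ok(None).

/// Illustrative only: frame one IPC stream message the way
/// `read_meta_len` and `detect_compressed_data` expect to find it.
fn frame_message(meta: &[u8], body: &[u8]) -> Vec<u8> {
    let mut out = Vec::new();
    out.extend_from_slice(&[0xff; 4]); // continuation marker
    out.extend_from_slice(&(meta.len() as i32).to_le_bytes()); // metadata length
    out.extend_from_slice(meta); // flatbuffer-encoded `Message`
    out.extend_from_slice(body); // `bodyLength` bytes of buffers
    out
}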

arrow2bytes/src/lib.rs

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,8 @@ use arrow::{
 #[cfg(test)]
 use insta as _;

+mod compression_check;
+
 /// Convert an [`Array`] to bytes.
 ///
 /// This is done by encoding this as a [`RecordBatch`] with a single [`Field`].
@@ -48,6 +50,8 @@ pub fn array2bytes(array: ArrayRef) -> Vec<u8> {
 ///
 /// See [`array2bytes`] for the reverse method and the format description.
 pub fn bytes2array(bytes: &[u8]) -> Result<ArrayRef, ArrowError> {
+    compression_check::detect_compressed_data(bytes)?;
+
     let cursor = Cursor::new(bytes);
     let mut reader = StreamReader::try_new(cursor, None)?;
     let Some(res) = reader.next() else {
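
From the caller's perspective, the change looks as follows. This is a minimal sketch against the public API (crate name taken from the tests below): uncompressed payloads round-trip exactly as before, while compressed ones now fail in detect_compressed_data before StreamReader decodes anything.

use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array};
use datafusion_udf_wasm_arrow2bytes::{array2bytes, bytes2array};

fn main() {
    let array: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
    let bytes = array2bytes(Arc::clone(&array));

    // Uncompressed data passes the pre-scan and decodes normally.
    let decoded = bytes2array(&bytes).unwrap();
    assert_eq!(&array, &decoded);
}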

arrow2bytes/tests/array.rs

Lines changed: 66 additions & 17 deletions
@@ -8,20 +8,18 @@ use arrow::{
         ArrayRef, Int64Array, ListArray, RecordBatch, RecordBatchOptions, StringDictionaryBuilder,
     },
     datatypes::{DataType, Field, Int32Type, Schema},
-    ipc::writer::StreamWriter,
+    error::ArrowError,
+    ipc::{
+        CompressionType,
+        writer::{IpcWriteOptions, StreamWriter},
+    },
 };
 use datafusion_udf_wasm_arrow2bytes::{array2bytes, bytes2array};

 #[test]
 fn test_roundtrip() {
-    roundtrip(Arc::new(Int64Array::from_iter([Some(1), None, Some(3)])));
-
-    let mut builder = StringDictionaryBuilder::<Int32Type>::new();
-    builder.append("foo").unwrap();
-    builder.append_null();
-    builder.append("bar").unwrap();
-    builder.append("foo").unwrap();
-    roundtrip(Arc::new(builder.finish()));
+    roundtrip(int64_array());
+    roundtrip(string_dict_array());
 }

 #[test]
@@ -89,14 +87,8 @@ fn test_err_multiple_columns() {
         Field::new("a", DataType::Int64, true),
         Field::new("b", DataType::Int64, true),
     ]));
-    let batch = RecordBatch::try_new(
-        Arc::clone(&schema),
-        vec![
-            Arc::new(Int64Array::new_null(0)),
-            Arc::new(Int64Array::new_null(0)),
-        ],
-    )
-    .unwrap();
+    let batch =
+        RecordBatch::try_new(Arc::clone(&schema), vec![int64_array(), int64_array()]).unwrap();
     let mut writer =
         StreamWriter::try_new(Vec::new(), &schema).expect("writing to buffer never fails");
     writer.write(&batch).unwrap();
@@ -136,9 +128,66 @@ fn test_deeply_nested() {
     );
 }

+#[test]
+fn test_err_compression() {
+    insta::assert_snapshot!(
+        compression_err(int64_array(), CompressionType::LZ4_FRAME),
+        @"Ipc error: IPC record batch is compressed using LZ4_FRAME, but compressed data MUST NOT cross the security boundary. If you want to handle compressed data, please decompress it within the guest.",
+    );
+    insta::assert_snapshot!(
+        compression_err(int64_array(), CompressionType::ZSTD),
+        @"Ipc error: IPC record batch is compressed using ZSTD, but compressed data MUST NOT cross the security boundary. If you want to handle compressed data, please decompress it within the guest.",
+    );
+    insta::assert_snapshot!(
+        compression_err(string_dict_array(), CompressionType::LZ4_FRAME),
+        @"Ipc error: IPC dictionary batch is compressed using LZ4_FRAME, but compressed data MUST NOT cross the security boundary. If you want to handle compressed data, please decompress it within the guest.",
+    );
+    insta::assert_snapshot!(
+        compression_err(string_dict_array(), CompressionType::ZSTD),
+        @"Ipc error: IPC dictionary batch is compressed using ZSTD, but compressed data MUST NOT cross the security boundary. If you want to handle compressed data, please decompress it within the guest.",
+    );
+}
+
 #[track_caller]
 fn roundtrip(array: ArrayRef) {
     let bytes = array2bytes(Arc::clone(&array));
     let array2 = bytes2array(&bytes).unwrap();
     assert_eq!(&array, &array2);
 }
+
+/// Create a non-empty int64 array.
+fn int64_array() -> ArrayRef {
+    Arc::new(Int64Array::from_iter([Some(1), None, Some(3)]))
+}
+
+/// Create a non-empty dict-encoded string array.
+fn string_dict_array() -> ArrayRef {
+    let mut builder = StringDictionaryBuilder::<Int32Type>::new();
+    builder.append("foo").unwrap();
+    builder.append_null();
+    builder.append("bar").unwrap();
+    builder.append("foo").unwrap();
+    Arc::new(builder.finish())
+}
+
+#[track_caller]
+fn compression_err(array: ArrayRef, compression: CompressionType) -> ArrowError {
+    let schema = Arc::new(Schema::new(vec![Field::new(
+        "a",
+        array.data_type().clone(),
+        true,
+    )]));
+    let batch = RecordBatch::try_new(Arc::clone(&schema), vec![array]).unwrap();
+    let mut writer = StreamWriter::try_new_with_options(
+        Vec::new(),
+        &schema,
+        IpcWriteOptions::default()
+            .try_with_compression(Some(compression))
+            .unwrap(),
+    )
+    .expect("writing to buffer never fails");
+    writer.write(&batch).unwrap();
+    let bytes = writer.into_inner().unwrap();

+    bytes2array(&bytes).unwrap_err()
+}

0 commit comments
