Skip to content

Commit d13cece

Browse files
rascanifacebook-github-bot
authored andcommitted
Add null pointer check for evalues (pytorch#16162)
Summary: The to* functions could de-reference a NULL pointer. Adding a check. Reviewed By: larryliu0820, mergennachin Differential Revision: D83742246
1 parent d39d64b commit d13cece

File tree

2 files changed

+177
-3
lines changed

2 files changed

+177
-3
lines changed

runtime/core/evalue.h

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,26 @@ class BoxedEvalueList {
6363
* unwrapped vals.
6464
*/
6565
BoxedEvalueList(EValue** wrapped_vals, T* unwrapped_vals, int size)
66-
: wrapped_vals_(wrapped_vals, size), unwrapped_vals_(unwrapped_vals) {}
66+
: wrapped_vals_(checkWrappedVals(wrapped_vals, size), size),
67+
unwrapped_vals_(checkUnwrappedVals(unwrapped_vals)) {}
68+
6769
/*
6870
* Constructs and returns the list of T specified by the EValue pointers
6971
*/
7072
executorch::aten::ArrayRef<T> get() const;
7173

7274
private:
75+
static EValue** checkWrappedVals(EValue** wrapped_vals, int size) {
76+
ET_CHECK_MSG(wrapped_vals != nullptr, "wrapped_vals cannot be null");
77+
ET_CHECK_MSG(size >= 0, "size cannot be negative");
78+
return wrapped_vals;
79+
}
80+
81+
static T* checkUnwrappedVals(T* unwrapped_vals) {
82+
ET_CHECK_MSG(unwrapped_vals != nullptr, "unwrapped_vals cannot be null");
83+
return unwrapped_vals;
84+
}
85+
7386
// Source of truth for the list
7487
executorch::aten::ArrayRef<EValue*> wrapped_vals_;
7588
// Same size as wrapped_vals
@@ -280,6 +293,7 @@ struct EValue {
280293

281294
/****** String Type ******/
282295
/*implicit*/ EValue(executorch::aten::ArrayRef<char>* s) : tag(Tag::String) {
296+
ET_CHECK_MSG(s != nullptr, "ArrayRef<char> pointer cannot be null");
283297
payload.copyable_union.as_string_ptr = s;
284298
}
285299

@@ -289,13 +303,18 @@ struct EValue {
289303

290304
std::string_view toString() const {
291305
ET_CHECK_MSG(isString(), "EValue is not a String.");
306+
ET_CHECK_MSG(
307+
payload.copyable_union.as_string_ptr != nullptr,
308+
"EValue string pointer is null.");
292309
return std::string_view(
293310
payload.copyable_union.as_string_ptr->data(),
294311
payload.copyable_union.as_string_ptr->size());
295312
}
296313

297314
/****** Int List Type ******/
298315
/*implicit*/ EValue(BoxedEvalueList<int64_t>* i) : tag(Tag::ListInt) {
316+
ET_CHECK_MSG(
317+
i != nullptr, "BoxedEvalueList<int64_t> pointer cannot be null");
299318
payload.copyable_union.as_int_list_ptr = i;
300319
}
301320

@@ -305,12 +324,16 @@ struct EValue {
305324

306325
executorch::aten::ArrayRef<int64_t> toIntList() const {
307326
ET_CHECK_MSG(isIntList(), "EValue is not an Int List.");
327+
ET_CHECK_MSG(
328+
payload.copyable_union.as_int_list_ptr != nullptr,
329+
"EValue int list pointer is null.");
308330
return (payload.copyable_union.as_int_list_ptr)->get();
309331
}
310332

311333
/****** Bool List Type ******/
312334
/*implicit*/ EValue(executorch::aten::ArrayRef<bool>* b)
313335
: tag(Tag::ListBool) {
336+
ET_CHECK_MSG(b != nullptr, "ArrayRef<bool> pointer cannot be null");
314337
payload.copyable_union.as_bool_list_ptr = b;
315338
}
316339

@@ -320,12 +343,16 @@ struct EValue {
320343

321344
executorch::aten::ArrayRef<bool> toBoolList() const {
322345
ET_CHECK_MSG(isBoolList(), "EValue is not a Bool List.");
346+
ET_CHECK_MSG(
347+
payload.copyable_union.as_bool_list_ptr != nullptr,
348+
"EValue bool list pointer is null.");
323349
return *(payload.copyable_union.as_bool_list_ptr);
324350
}
325351

326352
/****** Double List Type ******/
327353
/*implicit*/ EValue(executorch::aten::ArrayRef<double>* d)
328354
: tag(Tag::ListDouble) {
355+
ET_CHECK_MSG(d != nullptr, "ArrayRef<double> pointer cannot be null");
329356
payload.copyable_union.as_double_list_ptr = d;
330357
}
331358

@@ -335,12 +362,17 @@ struct EValue {
335362

336363
executorch::aten::ArrayRef<double> toDoubleList() const {
337364
ET_CHECK_MSG(isDoubleList(), "EValue is not a Double List.");
365+
ET_CHECK_MSG(
366+
payload.copyable_union.as_double_list_ptr != nullptr,
367+
"EValue double list pointer is null.");
338368
return *(payload.copyable_union.as_double_list_ptr);
339369
}
340370

341371
/****** Tensor List Type ******/
342372
/*implicit*/ EValue(BoxedEvalueList<executorch::aten::Tensor>* t)
343373
: tag(Tag::ListTensor) {
374+
ET_CHECK_MSG(
375+
t != nullptr, "BoxedEvalueList<Tensor> pointer cannot be null");
344376
payload.copyable_union.as_tensor_list_ptr = t;
345377
}
346378

@@ -350,13 +382,19 @@ struct EValue {
350382

351383
executorch::aten::ArrayRef<executorch::aten::Tensor> toTensorList() const {
352384
ET_CHECK_MSG(isTensorList(), "EValue is not a Tensor List.");
385+
ET_CHECK_MSG(
386+
payload.copyable_union.as_tensor_list_ptr != nullptr,
387+
"EValue tensor list pointer is null.");
353388
return payload.copyable_union.as_tensor_list_ptr->get();
354389
}
355390

356391
/****** List Optional Tensor Type ******/
357392
/*implicit*/ EValue(
358393
BoxedEvalueList<std::optional<executorch::aten::Tensor>>* t)
359394
: tag(Tag::ListOptionalTensor) {
395+
ET_CHECK_MSG(
396+
t != nullptr,
397+
"BoxedEvalueList<optional<Tensor>> pointer cannot be null");
360398
payload.copyable_union.as_list_optional_tensor_ptr = t;
361399
}
362400

@@ -366,6 +404,11 @@ struct EValue {
366404

367405
executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>
368406
toListOptionalTensor() const {
407+
ET_CHECK_MSG(
408+
isListOptionalTensor(), "EValue is not a List Optional Tensor.");
409+
ET_CHECK_MSG(
410+
payload.copyable_union.as_list_optional_tensor_ptr != nullptr,
411+
"EValue list optional tensor pointer is null.");
369412
return payload.copyable_union.as_list_optional_tensor_ptr->get();
370413
}
371414

@@ -445,11 +488,15 @@ struct EValue {
445488
// minor performance bump for a code maintainability hit
446489
if (isTensor()) {
447490
payload.as_tensor.~Tensor();
448-
} else if (isTensorList()) {
491+
} else if (
492+
isTensorList() &&
493+
payload.copyable_union.as_tensor_list_ptr != nullptr) {
449494
for (auto& tensor : toTensorList()) {
450495
tensor.~Tensor();
451496
}
452-
} else if (isListOptionalTensor()) {
497+
} else if (
498+
isListOptionalTensor() &&
499+
payload.copyable_union.as_list_optional_tensor_ptr != nullptr) {
453500
for (auto& optional_tensor : toListOptionalTensor()) {
454501
optional_tensor.~optional();
455502
}

runtime/core/test/evalue_test.cpp

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,3 +281,130 @@ TEST_F(EValueTest, ConstructFromNullPtrAborts) {
281281

282282
ET_EXPECT_DEATH({ EValue evalue(null_ptr); }, "");
283283
}
284+
285+
TEST_F(EValueTest, StringConstructorNullCheck) {
286+
executorch::aten::ArrayRef<char>* null_string_ptr = nullptr;
287+
ET_EXPECT_DEATH({ EValue evalue(null_string_ptr); }, "");
288+
}
289+
290+
TEST_F(EValueTest, BoolListConstructorNullCheck) {
291+
executorch::aten::ArrayRef<bool>* null_bool_list_ptr = nullptr;
292+
ET_EXPECT_DEATH({ EValue evalue(null_bool_list_ptr); }, "");
293+
}
294+
295+
TEST_F(EValueTest, DoubleListConstructorNullCheck) {
296+
executorch::aten::ArrayRef<double>* null_double_list_ptr = nullptr;
297+
ET_EXPECT_DEATH({ EValue evalue(null_double_list_ptr); }, "");
298+
}
299+
300+
TEST_F(EValueTest, IntListConstructorNullCheck) {
301+
BoxedEvalueList<int64_t>* null_int_list_ptr = nullptr;
302+
ET_EXPECT_DEATH({ EValue evalue(null_int_list_ptr); }, "");
303+
}
304+
305+
TEST_F(EValueTest, TensorListConstructorNullCheck) {
306+
BoxedEvalueList<executorch::aten::Tensor>* null_tensor_list_ptr = nullptr;
307+
ET_EXPECT_DEATH({ EValue evalue(null_tensor_list_ptr); }, "");
308+
}
309+
310+
TEST_F(EValueTest, OptionalTensorListConstructorNullCheck) {
311+
BoxedEvalueList<std::optional<executorch::aten::Tensor>>*
312+
null_optional_tensor_list_ptr = nullptr;
313+
ET_EXPECT_DEATH({ EValue evalue(null_optional_tensor_list_ptr); }, "");
314+
}
315+
316+
TEST_F(EValueTest, BoxedEvalueListConstructorNullChecks) {
317+
std::array<int64_t, 3> storage = {0, 0, 0};
318+
std::array<EValue, 3> values = {
319+
EValue((int64_t)1), EValue((int64_t)2), EValue((int64_t)3)};
320+
std::array<EValue*, 3> values_p = {&values[0], &values[1], &values[2]};
321+
322+
// Test null wrapped_vals
323+
ET_EXPECT_DEATH(
324+
{ BoxedEvalueList<int64_t> list(nullptr, storage.data(), 3); }, "");
325+
326+
// Test null unwrapped_vals
327+
ET_EXPECT_DEATH(
328+
{ BoxedEvalueList<int64_t> list(values_p.data(), nullptr, 3); }, "");
329+
330+
// Test negative size
331+
ET_EXPECT_DEATH(
332+
{ BoxedEvalueList<int64_t> list(values_p.data(), storage.data(), -1); },
333+
"");
334+
}
335+
336+
TEST_F(EValueTest, toListOptionalTensorTypeCheck) {
337+
// Create an EValue that's not a ListOptionalTensor
338+
EValue e((int64_t)42);
339+
EXPECT_TRUE(e.isInt());
340+
EXPECT_FALSE(e.isListOptionalTensor());
341+
342+
// Should fail type check
343+
ET_EXPECT_DEATH({ e.toListOptionalTensor(); }, "");
344+
}
345+
346+
TEST_F(EValueTest, toStringNullPointerCheck) {
347+
// Create an EValue with String tag but null pointer
348+
EValue e;
349+
e.tag = Tag::String;
350+
e.payload.copyable_union.as_string_ptr = nullptr;
351+
352+
// Should pass isString() check but fail null pointer check
353+
EXPECT_TRUE(e.isString());
354+
ET_EXPECT_DEATH({ e.toString(); }, "");
355+
}
356+
357+
TEST_F(EValueTest, toIntListNullPointerCheck) {
358+
// Create an EValue with ListInt tag but null pointer
359+
EValue e;
360+
e.tag = Tag::ListInt;
361+
e.payload.copyable_union.as_int_list_ptr = nullptr;
362+
363+
// Should pass isIntList() check but fail null pointer check
364+
EXPECT_TRUE(e.isIntList());
365+
ET_EXPECT_DEATH({ e.toIntList(); }, "");
366+
}
367+
368+
TEST_F(EValueTest, toBoolListNullPointerCheck) {
369+
// Create an EValue with ListBool tag but null pointer
370+
EValue e;
371+
e.tag = Tag::ListBool;
372+
e.payload.copyable_union.as_bool_list_ptr = nullptr;
373+
374+
// Should pass isBoolList() check but fail null pointer check
375+
EXPECT_TRUE(e.isBoolList());
376+
ET_EXPECT_DEATH({ e.toBoolList(); }, "");
377+
}
378+
379+
TEST_F(EValueTest, toDoubleListNullPointerCheck) {
380+
// Create an EValue with ListDouble tag but null pointer
381+
EValue e;
382+
e.tag = Tag::ListDouble;
383+
e.payload.copyable_union.as_double_list_ptr = nullptr;
384+
385+
// Should pass isDoubleList() check but fail null pointer check
386+
EXPECT_TRUE(e.isDoubleList());
387+
ET_EXPECT_DEATH({ e.toDoubleList(); }, "");
388+
}
389+
390+
TEST_F(EValueTest, toTensorListNullPointerCheck) {
391+
// Create an EValue with ListTensor tag but null pointer
392+
EValue e;
393+
e.tag = Tag::ListTensor;
394+
e.payload.copyable_union.as_tensor_list_ptr = nullptr;
395+
396+
// Should pass isTensorList() check but fail null pointer check
397+
EXPECT_TRUE(e.isTensorList());
398+
ET_EXPECT_DEATH({ e.toTensorList(); }, "");
399+
}
400+
401+
TEST_F(EValueTest, toListOptionalTensorNullPointerCheck) {
402+
// Create an EValue with ListOptionalTensor tag but null pointer
403+
EValue e;
404+
e.tag = Tag::ListOptionalTensor;
405+
e.payload.copyable_union.as_list_optional_tensor_ptr = nullptr;
406+
407+
// Should pass isListOptionalTensor() check but fail null pointer check
408+
EXPECT_TRUE(e.isListOptionalTensor());
409+
ET_EXPECT_DEATH({ e.toListOptionalTensor(); }, "");
410+
}

0 commit comments

Comments
 (0)