Skip to content

Commit 046e816

Browse files
rascanifacebook-github-bot
authored andcommitted
Add null pointer check for evalues (pytorch#16162)
Summary: The to* functions could de-reference a NULL pointer. Adding a check. Reviewed By: larryliu0820, mergennachin Differential Revision: D83742246
1 parent c9f6df1 commit 046e816

File tree

2 files changed

+194
-3
lines changed

2 files changed

+194
-3
lines changed

runtime/core/evalue.h

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,26 @@ class BoxedEvalueList {
6363
* unwrapped vals.
6464
*/
6565
BoxedEvalueList(EValue** wrapped_vals, T* unwrapped_vals, int size)
66-
: wrapped_vals_(wrapped_vals, size), unwrapped_vals_(unwrapped_vals) {}
66+
: wrapped_vals_(checkWrappedVals(wrapped_vals, size), size),
67+
unwrapped_vals_(checkUnwrappedVals(unwrapped_vals)) {}
68+
6769
/*
6870
* Constructs and returns the list of T specified by the EValue pointers
6971
*/
7072
executorch::aten::ArrayRef<T> get() const;
7173

7274
private:
75+
static EValue** checkWrappedVals(EValue** wrapped_vals, int size) {
76+
ET_CHECK_MSG(wrapped_vals != nullptr, "wrapped_vals cannot be null");
77+
ET_CHECK_MSG(size >= 0, "size cannot be negative");
78+
return wrapped_vals;
79+
}
80+
81+
static T* checkUnwrappedVals(T* unwrapped_vals) {
82+
ET_CHECK_MSG(unwrapped_vals != nullptr, "unwrapped_vals cannot be null");
83+
return unwrapped_vals;
84+
}
85+
7386
// Source of truth for the list
7487
executorch::aten::ArrayRef<EValue*> wrapped_vals_;
7588
// Same size as wrapped_vals
@@ -280,6 +293,7 @@ struct EValue {
280293

281294
/****** String Type ******/
282295
/*implicit*/ EValue(executorch::aten::ArrayRef<char>* s) : tag(Tag::String) {
296+
ET_CHECK_MSG(s != nullptr, "ArrayRef<char> pointer cannot be null");
283297
payload.copyable_union.as_string_ptr = s;
284298
}
285299

@@ -289,13 +303,18 @@ struct EValue {
289303

290304
std::string_view toString() const {
291305
ET_CHECK_MSG(isString(), "EValue is not a String.");
306+
ET_CHECK_MSG(
307+
payload.copyable_union.as_string_ptr != nullptr,
308+
"EValue string pointer is null.");
292309
return std::string_view(
293310
payload.copyable_union.as_string_ptr->data(),
294311
payload.copyable_union.as_string_ptr->size());
295312
}
296313

297314
/****** Int List Type ******/
298315
/*implicit*/ EValue(BoxedEvalueList<int64_t>* i) : tag(Tag::ListInt) {
316+
ET_CHECK_MSG(
317+
i != nullptr, "BoxedEvalueList<int64_t> pointer cannot be null");
299318
payload.copyable_union.as_int_list_ptr = i;
300319
}
301320

@@ -305,12 +324,16 @@ struct EValue {
305324

306325
executorch::aten::ArrayRef<int64_t> toIntList() const {
307326
ET_CHECK_MSG(isIntList(), "EValue is not an Int List.");
327+
ET_CHECK_MSG(
328+
payload.copyable_union.as_int_list_ptr != nullptr,
329+
"EValue int list pointer is null.");
308330
return (payload.copyable_union.as_int_list_ptr)->get();
309331
}
310332

311333
/****** Bool List Type ******/
312334
/*implicit*/ EValue(executorch::aten::ArrayRef<bool>* b)
313335
: tag(Tag::ListBool) {
336+
ET_CHECK_MSG(b != nullptr, "ArrayRef<bool> pointer cannot be null");
314337
payload.copyable_union.as_bool_list_ptr = b;
315338
}
316339

@@ -320,12 +343,16 @@ struct EValue {
320343

321344
executorch::aten::ArrayRef<bool> toBoolList() const {
322345
ET_CHECK_MSG(isBoolList(), "EValue is not a Bool List.");
346+
ET_CHECK_MSG(
347+
payload.copyable_union.as_bool_list_ptr != nullptr,
348+
"EValue bool list pointer is null.");
323349
return *(payload.copyable_union.as_bool_list_ptr);
324350
}
325351

326352
/****** Double List Type ******/
327353
/*implicit*/ EValue(executorch::aten::ArrayRef<double>* d)
328354
: tag(Tag::ListDouble) {
355+
ET_CHECK_MSG(d != nullptr, "ArrayRef<double> pointer cannot be null");
329356
payload.copyable_union.as_double_list_ptr = d;
330357
}
331358

@@ -335,12 +362,17 @@ struct EValue {
335362

336363
executorch::aten::ArrayRef<double> toDoubleList() const {
337364
ET_CHECK_MSG(isDoubleList(), "EValue is not a Double List.");
365+
ET_CHECK_MSG(
366+
payload.copyable_union.as_double_list_ptr != nullptr,
367+
"EValue double list pointer is null.");
338368
return *(payload.copyable_union.as_double_list_ptr);
339369
}
340370

341371
/****** Tensor List Type ******/
342372
/*implicit*/ EValue(BoxedEvalueList<executorch::aten::Tensor>* t)
343373
: tag(Tag::ListTensor) {
374+
ET_CHECK_MSG(
375+
t != nullptr, "BoxedEvalueList<Tensor> pointer cannot be null");
344376
payload.copyable_union.as_tensor_list_ptr = t;
345377
}
346378

@@ -350,13 +382,19 @@ struct EValue {
350382

351383
executorch::aten::ArrayRef<executorch::aten::Tensor> toTensorList() const {
352384
ET_CHECK_MSG(isTensorList(), "EValue is not a Tensor List.");
385+
ET_CHECK_MSG(
386+
payload.copyable_union.as_tensor_list_ptr != nullptr,
387+
"EValue tensor list pointer is null.");
353388
return payload.copyable_union.as_tensor_list_ptr->get();
354389
}
355390

356391
/****** List Optional Tensor Type ******/
357392
/*implicit*/ EValue(
358393
BoxedEvalueList<std::optional<executorch::aten::Tensor>>* t)
359394
: tag(Tag::ListOptionalTensor) {
395+
ET_CHECK_MSG(
396+
t != nullptr,
397+
"BoxedEvalueList<optional<Tensor>> pointer cannot be null");
360398
payload.copyable_union.as_list_optional_tensor_ptr = t;
361399
}
362400

@@ -366,6 +404,11 @@ struct EValue {
366404

367405
executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>
368406
toListOptionalTensor() const {
407+
ET_CHECK_MSG(
408+
isListOptionalTensor(), "EValue is not a List Optional Tensor.");
409+
ET_CHECK_MSG(
410+
payload.copyable_union.as_list_optional_tensor_ptr != nullptr,
411+
"EValue list optional tensor pointer is null.");
369412
return payload.copyable_union.as_list_optional_tensor_ptr->get();
370413
}
371414

@@ -445,11 +488,15 @@ struct EValue {
445488
// minor performance bump for a code maintainability hit
446489
if (isTensor()) {
447490
payload.as_tensor.~Tensor();
448-
} else if (isTensorList()) {
491+
} else if (
492+
isTensorList() &&
493+
payload.copyable_union.as_tensor_list_ptr != nullptr) {
449494
for (auto& tensor : toTensorList()) {
450495
tensor.~Tensor();
451496
}
452-
} else if (isListOptionalTensor()) {
497+
} else if (
498+
isListOptionalTensor() &&
499+
payload.copyable_union.as_list_optional_tensor_ptr != nullptr) {
453500
for (auto& optional_tensor : toListOptionalTensor()) {
454501
optional_tensor.~optional();
455502
}

runtime/core/test/evalue_test.cpp

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,3 +281,147 @@ TEST_F(EValueTest, ConstructFromNullPtrAborts) {
281281

282282
ET_EXPECT_DEATH({ EValue evalue(null_ptr); }, "");
283283
}
284+
285+
TEST_F(EValueTest, StringConstructorNullCheck) {
286+
executorch::aten::ArrayRef<char>* null_string_ptr = nullptr;
287+
ET_EXPECT_DEATH(
288+
{ EValue evalue(null_string_ptr); },
289+
"ArrayRef<char> pointer cannot be null");
290+
}
291+
292+
TEST_F(EValueTest, BoolListConstructorNullCheck) {
293+
executorch::aten::ArrayRef<bool>* null_bool_list_ptr = nullptr;
294+
ET_EXPECT_DEATH(
295+
{ EValue evalue(null_bool_list_ptr); },
296+
"ArrayRef<bool> pointer cannot be null");
297+
}
298+
299+
TEST_F(EValueTest, DoubleListConstructorNullCheck) {
300+
executorch::aten::ArrayRef<double>* null_double_list_ptr = nullptr;
301+
ET_EXPECT_DEATH(
302+
{ EValue evalue(null_double_list_ptr); },
303+
"ArrayRef<double> pointer cannot be null");
304+
}
305+
306+
TEST_F(EValueTest, IntListConstructorNullCheck) {
307+
BoxedEvalueList<int64_t>* null_int_list_ptr = nullptr;
308+
ET_EXPECT_DEATH(
309+
{ EValue evalue(null_int_list_ptr); },
310+
"BoxedEvalueList<int64_t> pointer cannot be null");
311+
}
312+
313+
TEST_F(EValueTest, TensorListConstructorNullCheck) {
314+
BoxedEvalueList<executorch::aten::Tensor>* null_tensor_list_ptr = nullptr;
315+
ET_EXPECT_DEATH(
316+
{ EValue evalue(null_tensor_list_ptr); },
317+
"BoxedEvalueList<Tensor> pointer cannot be null");
318+
}
319+
320+
TEST_F(EValueTest, OptionalTensorListConstructorNullCheck) {
321+
BoxedEvalueList<std::optional<executorch::aten::Tensor>>*
322+
null_optional_tensor_list_ptr = nullptr;
323+
ET_EXPECT_DEATH(
324+
{ EValue evalue(null_optional_tensor_list_ptr); },
325+
"BoxedEvalueList<optional<Tensor>> pointer cannot be null");
326+
}
327+
328+
TEST_F(EValueTest, BoxedEvalueListConstructorNullChecks) {
329+
int64_t storage[3] = {0, 0, 0};
330+
EValue values[3] = {
331+
EValue((int64_t)1), EValue((int64_t)2), EValue((int64_t)3)};
332+
EValue* values_p[3] = {&values[0], &values[1], &values[2]};
333+
334+
// Test null wrapped_vals
335+
ET_EXPECT_DEATH(
336+
{ BoxedEvalueList<int64_t> list(nullptr, storage, 3); },
337+
"wrapped_vals cannot be null");
338+
339+
// Test null unwrapped_vals
340+
ET_EXPECT_DEATH(
341+
{ BoxedEvalueList<int64_t> list(values_p, nullptr, 3); },
342+
"unwrapped_vals cannot be null");
343+
344+
// Test negative size
345+
ET_EXPECT_DEATH(
346+
{ BoxedEvalueList<int64_t> list(values_p, storage, -1); },
347+
"size cannot be negative");
348+
}
349+
350+
TEST_F(EValueTest, toListOptionalTensorTypeCheck) {
351+
// Create an EValue that's not a ListOptionalTensor
352+
EValue e((int64_t)42);
353+
EXPECT_TRUE(e.isInt());
354+
EXPECT_FALSE(e.isListOptionalTensor());
355+
356+
// Should fail type check
357+
ET_EXPECT_DEATH(
358+
{ e.toListOptionalTensor(); }, "EValue is not a List Optional Tensor");
359+
}
360+
361+
TEST_F(EValueTest, toStringNullPointerCheck) {
362+
// Create an EValue with String tag but null pointer
363+
EValue e;
364+
e.tag = Tag::String;
365+
e.payload.copyable_union.as_string_ptr = nullptr;
366+
367+
// Should pass isString() check but fail null pointer check
368+
EXPECT_TRUE(e.isString());
369+
ET_EXPECT_DEATH({ e.toString(); }, "EValue string pointer is null");
370+
}
371+
372+
TEST_F(EValueTest, toIntListNullPointerCheck) {
373+
// Create an EValue with ListInt tag but null pointer
374+
EValue e;
375+
e.tag = Tag::ListInt;
376+
e.payload.copyable_union.as_int_list_ptr = nullptr;
377+
378+
// Should pass isIntList() check but fail null pointer check
379+
EXPECT_TRUE(e.isIntList());
380+
ET_EXPECT_DEATH({ e.toIntList(); }, "EValue int list pointer is null");
381+
}
382+
383+
TEST_F(EValueTest, toBoolListNullPointerCheck) {
384+
// Create an EValue with ListBool tag but null pointer
385+
EValue e;
386+
e.tag = Tag::ListBool;
387+
e.payload.copyable_union.as_bool_list_ptr = nullptr;
388+
389+
// Should pass isBoolList() check but fail null pointer check
390+
EXPECT_TRUE(e.isBoolList());
391+
ET_EXPECT_DEATH({ e.toBoolList(); }, "EValue bool list pointer is null");
392+
}
393+
394+
TEST_F(EValueTest, toDoubleListNullPointerCheck) {
395+
// Create an EValue with ListDouble tag but null pointer
396+
EValue e;
397+
e.tag = Tag::ListDouble;
398+
e.payload.copyable_union.as_double_list_ptr = nullptr;
399+
400+
// Should pass isDoubleList() check but fail null pointer check
401+
EXPECT_TRUE(e.isDoubleList());
402+
ET_EXPECT_DEATH({ e.toDoubleList(); }, "EValue double list pointer is null");
403+
}
404+
405+
TEST_F(EValueTest, toTensorListNullPointerCheck) {
406+
// Create an EValue with ListTensor tag but null pointer
407+
EValue e;
408+
e.tag = Tag::ListTensor;
409+
e.payload.copyable_union.as_tensor_list_ptr = nullptr;
410+
411+
// Should pass isTensorList() check but fail null pointer check
412+
EXPECT_TRUE(e.isTensorList());
413+
ET_EXPECT_DEATH({ e.toTensorList(); }, "EValue tensor list pointer is null");
414+
}
415+
416+
TEST_F(EValueTest, toListOptionalTensorNullPointerCheck) {
417+
// Create an EValue with ListOptionalTensor tag but null pointer
418+
EValue e;
419+
e.tag = Tag::ListOptionalTensor;
420+
e.payload.copyable_union.as_list_optional_tensor_ptr = nullptr;
421+
422+
// Should pass isListOptionalTensor() check but fail null pointer check
423+
EXPECT_TRUE(e.isListOptionalTensor());
424+
ET_EXPECT_DEATH(
425+
{ e.toListOptionalTensor(); },
426+
"EValue list optional tensor pointer is null");
427+
}

0 commit comments

Comments
 (0)