Skip to content

Commit 0bdc306

Browse files
rascanifacebook-github-bot
authored andcommitted
Add null pointer check for evalues (pytorch#16162)
Summary: The to* functions could de-reference a NULL pointer. Adding a check. Reviewed By: larryliu0820, mergennachin Differential Revision: D83742246
1 parent d39d64b commit 0bdc306

File tree

2 files changed

+174
-3
lines changed

2 files changed

+174
-3
lines changed

runtime/core/evalue.h

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,26 @@ class BoxedEvalueList {
6363
* unwrapped vals.
6464
*/
6565
BoxedEvalueList(EValue** wrapped_vals, T* unwrapped_vals, int size)
66-
: wrapped_vals_(wrapped_vals, size), unwrapped_vals_(unwrapped_vals) {}
66+
: wrapped_vals_(checkWrappedVals(wrapped_vals, size), size),
67+
unwrapped_vals_(checkUnwrappedVals(unwrapped_vals)) {}
68+
6769
/*
6870
* Constructs and returns the list of T specified by the EValue pointers
6971
*/
7072
executorch::aten::ArrayRef<T> get() const;
7173

7274
private:
75+
static EValue** checkWrappedVals(EValue** wrapped_vals, int size) {
76+
ET_CHECK_MSG(wrapped_vals != nullptr, "wrapped_vals cannot be null");
77+
ET_CHECK_MSG(size >= 0, "size cannot be negative");
78+
return wrapped_vals;
79+
}
80+
81+
static T* checkUnwrappedVals(T* unwrapped_vals) {
82+
ET_CHECK_MSG(unwrapped_vals != nullptr, "unwrapped_vals cannot be null");
83+
return unwrapped_vals;
84+
}
85+
7386
// Source of truth for the list
7487
executorch::aten::ArrayRef<EValue*> wrapped_vals_;
7588
// Same size as wrapped_vals
@@ -280,6 +293,7 @@ struct EValue {
280293

281294
/****** String Type ******/
282295
/*implicit*/ EValue(executorch::aten::ArrayRef<char>* s) : tag(Tag::String) {
296+
ET_CHECK_MSG(s != nullptr, "ArrayRef<char> pointer cannot be null");
283297
payload.copyable_union.as_string_ptr = s;
284298
}
285299

@@ -289,13 +303,18 @@ struct EValue {
289303

290304
std::string_view toString() const {
291305
ET_CHECK_MSG(isString(), "EValue is not a String.");
306+
ET_CHECK_MSG(
307+
payload.copyable_union.as_string_ptr != nullptr,
308+
"EValue string pointer is null.");
292309
return std::string_view(
293310
payload.copyable_union.as_string_ptr->data(),
294311
payload.copyable_union.as_string_ptr->size());
295312
}
296313

297314
/****** Int List Type ******/
298315
/*implicit*/ EValue(BoxedEvalueList<int64_t>* i) : tag(Tag::ListInt) {
316+
ET_CHECK_MSG(
317+
i != nullptr, "BoxedEvalueList<int64_t> pointer cannot be null");
299318
payload.copyable_union.as_int_list_ptr = i;
300319
}
301320

@@ -305,12 +324,16 @@ struct EValue {
305324

306325
executorch::aten::ArrayRef<int64_t> toIntList() const {
307326
ET_CHECK_MSG(isIntList(), "EValue is not an Int List.");
327+
ET_CHECK_MSG(
328+
payload.copyable_union.as_int_list_ptr != nullptr,
329+
"EValue int list pointer is null.");
308330
return (payload.copyable_union.as_int_list_ptr)->get();
309331
}
310332

311333
/****** Bool List Type ******/
312334
/*implicit*/ EValue(executorch::aten::ArrayRef<bool>* b)
313335
: tag(Tag::ListBool) {
336+
ET_CHECK_MSG(b != nullptr, "ArrayRef<bool> pointer cannot be null");
314337
payload.copyable_union.as_bool_list_ptr = b;
315338
}
316339

@@ -320,12 +343,16 @@ struct EValue {
320343

321344
executorch::aten::ArrayRef<bool> toBoolList() const {
322345
ET_CHECK_MSG(isBoolList(), "EValue is not a Bool List.");
346+
ET_CHECK_MSG(
347+
payload.copyable_union.as_bool_list_ptr != nullptr,
348+
"EValue bool list pointer is null.");
323349
return *(payload.copyable_union.as_bool_list_ptr);
324350
}
325351

326352
/****** Double List Type ******/
327353
/*implicit*/ EValue(executorch::aten::ArrayRef<double>* d)
328354
: tag(Tag::ListDouble) {
355+
ET_CHECK_MSG(d != nullptr, "ArrayRef<double> pointer cannot be null");
329356
payload.copyable_union.as_double_list_ptr = d;
330357
}
331358

@@ -335,12 +362,17 @@ struct EValue {
335362

336363
executorch::aten::ArrayRef<double> toDoubleList() const {
337364
ET_CHECK_MSG(isDoubleList(), "EValue is not a Double List.");
365+
ET_CHECK_MSG(
366+
payload.copyable_union.as_double_list_ptr != nullptr,
367+
"EValue double list pointer is null.");
338368
return *(payload.copyable_union.as_double_list_ptr);
339369
}
340370

341371
/****** Tensor List Type ******/
342372
/*implicit*/ EValue(BoxedEvalueList<executorch::aten::Tensor>* t)
343373
: tag(Tag::ListTensor) {
374+
ET_CHECK_MSG(
375+
t != nullptr, "BoxedEvalueList<Tensor> pointer cannot be null");
344376
payload.copyable_union.as_tensor_list_ptr = t;
345377
}
346378

@@ -350,13 +382,19 @@ struct EValue {
350382

351383
executorch::aten::ArrayRef<executorch::aten::Tensor> toTensorList() const {
352384
ET_CHECK_MSG(isTensorList(), "EValue is not a Tensor List.");
385+
ET_CHECK_MSG(
386+
payload.copyable_union.as_tensor_list_ptr != nullptr,
387+
"EValue tensor list pointer is null.");
353388
return payload.copyable_union.as_tensor_list_ptr->get();
354389
}
355390

356391
/****** List Optional Tensor Type ******/
357392
/*implicit*/ EValue(
358393
BoxedEvalueList<std::optional<executorch::aten::Tensor>>* t)
359394
: tag(Tag::ListOptionalTensor) {
395+
ET_CHECK_MSG(
396+
t != nullptr,
397+
"BoxedEvalueList<optional<Tensor>> pointer cannot be null");
360398
payload.copyable_union.as_list_optional_tensor_ptr = t;
361399
}
362400

@@ -366,6 +404,11 @@ struct EValue {
366404

367405
executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>
368406
toListOptionalTensor() const {
407+
ET_CHECK_MSG(
408+
isListOptionalTensor(), "EValue is not a List Optional Tensor.");
409+
ET_CHECK_MSG(
410+
payload.copyable_union.as_list_optional_tensor_ptr != nullptr,
411+
"EValue list optional tensor pointer is null.");
369412
return payload.copyable_union.as_list_optional_tensor_ptr->get();
370413
}
371414

@@ -445,11 +488,15 @@ struct EValue {
445488
// minor performance bump for a code maintainability hit
446489
if (isTensor()) {
447490
payload.as_tensor.~Tensor();
448-
} else if (isTensorList()) {
491+
} else if (
492+
isTensorList() &&
493+
payload.copyable_union.as_tensor_list_ptr != nullptr) {
449494
for (auto& tensor : toTensorList()) {
450495
tensor.~Tensor();
451496
}
452-
} else if (isListOptionalTensor()) {
497+
} else if (
498+
isListOptionalTensor() &&
499+
payload.copyable_union.as_list_optional_tensor_ptr != nullptr) {
453500
for (auto& optional_tensor : toListOptionalTensor()) {
454501
optional_tensor.~optional();
455502
}

runtime/core/test/evalue_test.cpp

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,3 +281,127 @@ TEST_F(EValueTest, ConstructFromNullPtrAborts) {
281281

282282
ET_EXPECT_DEATH({ EValue evalue(null_ptr); }, "");
283283
}
284+
285+
TEST_F(EValueTest, StringConstructorNullCheck) {
286+
executorch::aten::ArrayRef<char>* null_string_ptr = nullptr;
287+
ET_EXPECT_DEATH({ EValue evalue(null_string_ptr); }, "");
288+
}
289+
290+
TEST_F(EValueTest, BoolListConstructorNullCheck) {
291+
executorch::aten::ArrayRef<bool>* null_bool_list_ptr = nullptr;
292+
ET_EXPECT_DEATH({ EValue evalue(null_bool_list_ptr); }, "");
293+
}
294+
295+
TEST_F(EValueTest, DoubleListConstructorNullCheck) {
296+
executorch::aten::ArrayRef<double>* null_double_list_ptr = nullptr;
297+
ET_EXPECT_DEATH({ EValue evalue(null_double_list_ptr); }, "");
298+
}
299+
300+
TEST_F(EValueTest, IntListConstructorNullCheck) {
301+
BoxedEvalueList<int64_t>* null_int_list_ptr = nullptr;
302+
ET_EXPECT_DEATH({ EValue evalue(null_int_list_ptr); }, "");
303+
}
304+
305+
TEST_F(EValueTest, TensorListConstructorNullCheck) {
306+
BoxedEvalueList<executorch::aten::Tensor>* null_tensor_list_ptr = nullptr;
307+
ET_EXPECT_DEATH({ EValue evalue(null_tensor_list_ptr); }, "");
308+
}
309+
310+
TEST_F(EValueTest, OptionalTensorListConstructorNullCheck) {
311+
BoxedEvalueList<std::optional<executorch::aten::Tensor>>*
312+
null_optional_tensor_list_ptr = nullptr;
313+
ET_EXPECT_DEATH({ EValue evalue(null_optional_tensor_list_ptr); }, "");
314+
}
315+
316+
TEST_F(EValueTest, BoxedEvalueListConstructorNullChecks) {
317+
int64_t storage[3] = {0, 0, 0};
318+
EValue values[3] = {
319+
EValue((int64_t)1), EValue((int64_t)2), EValue((int64_t)3)};
320+
EValue* values_p[3] = {&values[0], &values[1], &values[2]};
321+
322+
// Test null wrapped_vals
323+
ET_EXPECT_DEATH({ BoxedEvalueList<int64_t> list(nullptr, storage, 3); }, "");
324+
325+
// Test null unwrapped_vals
326+
ET_EXPECT_DEATH({ BoxedEvalueList<int64_t> list(values_p, nullptr, 3); }, "");
327+
328+
// Test negative size
329+
ET_EXPECT_DEATH(
330+
{ BoxedEvalueList<int64_t> list(values_p, storage, -1); }, "");
331+
}
332+
333+
TEST_F(EValueTest, toListOptionalTensorTypeCheck) {
334+
// Create an EValue that's not a ListOptionalTensor
335+
EValue e((int64_t)42);
336+
EXPECT_TRUE(e.isInt());
337+
EXPECT_FALSE(e.isListOptionalTensor());
338+
339+
// Should fail type check
340+
ET_EXPECT_DEATH({ e.toListOptionalTensor(); }, "");
341+
}
342+
343+
TEST_F(EValueTest, toStringNullPointerCheck) {
344+
// Create an EValue with String tag but null pointer
345+
EValue e;
346+
e.tag = Tag::String;
347+
e.payload.copyable_union.as_string_ptr = nullptr;
348+
349+
// Should pass isString() check but fail null pointer check
350+
EXPECT_TRUE(e.isString());
351+
ET_EXPECT_DEATH({ e.toString(); }, "");
352+
}
353+
354+
TEST_F(EValueTest, toIntListNullPointerCheck) {
355+
// Create an EValue with ListInt tag but null pointer
356+
EValue e;
357+
e.tag = Tag::ListInt;
358+
e.payload.copyable_union.as_int_list_ptr = nullptr;
359+
360+
// Should pass isIntList() check but fail null pointer check
361+
EXPECT_TRUE(e.isIntList());
362+
ET_EXPECT_DEATH({ e.toIntList(); }, "");
363+
}
364+
365+
TEST_F(EValueTest, toBoolListNullPointerCheck) {
366+
// Create an EValue with ListBool tag but null pointer
367+
EValue e;
368+
e.tag = Tag::ListBool;
369+
e.payload.copyable_union.as_bool_list_ptr = nullptr;
370+
371+
// Should pass isBoolList() check but fail null pointer check
372+
EXPECT_TRUE(e.isBoolList());
373+
ET_EXPECT_DEATH({ e.toBoolList(); }, "");
374+
}
375+
376+
TEST_F(EValueTest, toDoubleListNullPointerCheck) {
377+
// Create an EValue with ListDouble tag but null pointer
378+
EValue e;
379+
e.tag = Tag::ListDouble;
380+
e.payload.copyable_union.as_double_list_ptr = nullptr;
381+
382+
// Should pass isDoubleList() check but fail null pointer check
383+
EXPECT_TRUE(e.isDoubleList());
384+
ET_EXPECT_DEATH({ e.toDoubleList(); }, "");
385+
}
386+
387+
TEST_F(EValueTest, toTensorListNullPointerCheck) {
388+
// Create an EValue with ListTensor tag but null pointer
389+
EValue e;
390+
e.tag = Tag::ListTensor;
391+
e.payload.copyable_union.as_tensor_list_ptr = nullptr;
392+
393+
// Should pass isTensorList() check but fail null pointer check
394+
EXPECT_TRUE(e.isTensorList());
395+
ET_EXPECT_DEATH({ e.toTensorList(); }, "");
396+
}
397+
398+
TEST_F(EValueTest, toListOptionalTensorNullPointerCheck) {
399+
// Create an EValue with ListOptionalTensor tag but null pointer
400+
EValue e;
401+
e.tag = Tag::ListOptionalTensor;
402+
e.payload.copyable_union.as_list_optional_tensor_ptr = nullptr;
403+
404+
// Should pass isListOptionalTensor() check but fail null pointer check
405+
EXPECT_TRUE(e.isListOptionalTensor());
406+
ET_EXPECT_DEATH({ e.toListOptionalTensor(); }, "");
407+
}

0 commit comments

Comments
 (0)