Skip to content

Commit f02c12b

Browse files
rascanifacebook-github-bot
authored andcommitted
Add null pointer check for evalues (pytorch#16162)
Summary: The to* functions could de-reference a NULL pointer. Adding a check. Reviewed By: mergennachin Differential Revision: D83742246
1 parent c9f6df1 commit f02c12b

File tree

2 files changed

+196
-3
lines changed

2 files changed

+196
-3
lines changed

runtime/core/evalue.h

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,22 @@ class BoxedEvalueList {
6363
* unwrapped vals.
6464
*/
6565
BoxedEvalueList(EValue** wrapped_vals, T* unwrapped_vals, int size)
66-
: wrapped_vals_(wrapped_vals, size), unwrapped_vals_(unwrapped_vals) {}
66+
: wrapped_vals_(checkWrappedVals(wrapped_vals, size), size),
67+
unwrapped_vals_(checkUnwrappedVals(unwrapped_vals)) {}
68+
69+
private:
70+
static EValue** checkWrappedVals(EValue** wrapped_vals, int size) {
71+
ET_CHECK_MSG(wrapped_vals != nullptr, "wrapped_vals cannot be null");
72+
ET_CHECK_MSG(size >= 0, "size cannot be negative");
73+
return wrapped_vals;
74+
}
75+
76+
static T* checkUnwrappedVals(T* unwrapped_vals) {
77+
ET_CHECK_MSG(unwrapped_vals != nullptr, "unwrapped_vals cannot be null");
78+
return unwrapped_vals;
79+
}
80+
81+
public:
6782
/*
6883
* Constructs and returns the list of T specified by the EValue pointers
6984
*/
@@ -280,6 +295,7 @@ struct EValue {
280295

281296
/****** String Type ******/
282297
/*implicit*/ EValue(executorch::aten::ArrayRef<char>* s) : tag(Tag::String) {
298+
ET_CHECK_MSG(s != nullptr, "ArrayRef<char> pointer cannot be null");
283299
payload.copyable_union.as_string_ptr = s;
284300
}
285301

@@ -289,13 +305,18 @@ struct EValue {
289305

290306
std::string_view toString() const {
291307
ET_CHECK_MSG(isString(), "EValue is not a String.");
308+
ET_CHECK_MSG(
309+
payload.copyable_union.as_string_ptr != nullptr,
310+
"EValue string pointer is null.");
292311
return std::string_view(
293312
payload.copyable_union.as_string_ptr->data(),
294313
payload.copyable_union.as_string_ptr->size());
295314
}
296315

297316
/****** Int List Type ******/
298317
/*implicit*/ EValue(BoxedEvalueList<int64_t>* i) : tag(Tag::ListInt) {
318+
ET_CHECK_MSG(
319+
i != nullptr, "BoxedEvalueList<int64_t> pointer cannot be null");
299320
payload.copyable_union.as_int_list_ptr = i;
300321
}
301322

@@ -305,12 +326,16 @@ struct EValue {
305326

306327
executorch::aten::ArrayRef<int64_t> toIntList() const {
307328
ET_CHECK_MSG(isIntList(), "EValue is not an Int List.");
329+
ET_CHECK_MSG(
330+
payload.copyable_union.as_int_list_ptr != nullptr,
331+
"EValue int list pointer is null.");
308332
return (payload.copyable_union.as_int_list_ptr)->get();
309333
}
310334

311335
/****** Bool List Type ******/
312336
/*implicit*/ EValue(executorch::aten::ArrayRef<bool>* b)
313337
: tag(Tag::ListBool) {
338+
ET_CHECK_MSG(b != nullptr, "ArrayRef<bool> pointer cannot be null");
314339
payload.copyable_union.as_bool_list_ptr = b;
315340
}
316341

@@ -320,12 +345,16 @@ struct EValue {
320345

321346
executorch::aten::ArrayRef<bool> toBoolList() const {
322347
ET_CHECK_MSG(isBoolList(), "EValue is not a Bool List.");
348+
ET_CHECK_MSG(
349+
payload.copyable_union.as_bool_list_ptr != nullptr,
350+
"EValue bool list pointer is null.");
323351
return *(payload.copyable_union.as_bool_list_ptr);
324352
}
325353

326354
/****** Double List Type ******/
327355
/*implicit*/ EValue(executorch::aten::ArrayRef<double>* d)
328356
: tag(Tag::ListDouble) {
357+
ET_CHECK_MSG(d != nullptr, "ArrayRef<double> pointer cannot be null");
329358
payload.copyable_union.as_double_list_ptr = d;
330359
}
331360

@@ -335,12 +364,17 @@ struct EValue {
335364

336365
executorch::aten::ArrayRef<double> toDoubleList() const {
337366
ET_CHECK_MSG(isDoubleList(), "EValue is not a Double List.");
367+
ET_CHECK_MSG(
368+
payload.copyable_union.as_double_list_ptr != nullptr,
369+
"EValue double list pointer is null.");
338370
return *(payload.copyable_union.as_double_list_ptr);
339371
}
340372

341373
/****** Tensor List Type ******/
342374
/*implicit*/ EValue(BoxedEvalueList<executorch::aten::Tensor>* t)
343375
: tag(Tag::ListTensor) {
376+
ET_CHECK_MSG(
377+
t != nullptr, "BoxedEvalueList<Tensor> pointer cannot be null");
344378
payload.copyable_union.as_tensor_list_ptr = t;
345379
}
346380

@@ -350,13 +384,19 @@ struct EValue {
350384

351385
executorch::aten::ArrayRef<executorch::aten::Tensor> toTensorList() const {
352386
ET_CHECK_MSG(isTensorList(), "EValue is not a Tensor List.");
387+
ET_CHECK_MSG(
388+
payload.copyable_union.as_tensor_list_ptr != nullptr,
389+
"EValue tensor list pointer is null.");
353390
return payload.copyable_union.as_tensor_list_ptr->get();
354391
}
355392

356393
/****** List Optional Tensor Type ******/
357394
/*implicit*/ EValue(
358395
BoxedEvalueList<std::optional<executorch::aten::Tensor>>* t)
359396
: tag(Tag::ListOptionalTensor) {
397+
ET_CHECK_MSG(
398+
t != nullptr,
399+
"BoxedEvalueList<optional<Tensor>> pointer cannot be null");
360400
payload.copyable_union.as_list_optional_tensor_ptr = t;
361401
}
362402

@@ -366,6 +406,11 @@ struct EValue {
366406

367407
executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>
368408
toListOptionalTensor() const {
409+
ET_CHECK_MSG(
410+
isListOptionalTensor(), "EValue is not a List Optional Tensor.");
411+
ET_CHECK_MSG(
412+
payload.copyable_union.as_list_optional_tensor_ptr != nullptr,
413+
"EValue list optional tensor pointer is null.");
369414
return payload.copyable_union.as_list_optional_tensor_ptr->get();
370415
}
371416

@@ -445,11 +490,15 @@ struct EValue {
445490
// minor performance bump for a code maintainability hit
446491
if (isTensor()) {
447492
payload.as_tensor.~Tensor();
448-
} else if (isTensorList()) {
493+
} else if (
494+
isTensorList() &&
495+
payload.copyable_union.as_tensor_list_ptr != nullptr) {
449496
for (auto& tensor : toTensorList()) {
450497
tensor.~Tensor();
451498
}
452-
} else if (isListOptionalTensor()) {
499+
} else if (
500+
isListOptionalTensor() &&
501+
payload.copyable_union.as_list_optional_tensor_ptr != nullptr) {
453502
for (auto& optional_tensor : toListOptionalTensor()) {
454503
optional_tensor.~optional();
455504
}

runtime/core/test/evalue_test.cpp

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,3 +281,147 @@ TEST_F(EValueTest, ConstructFromNullPtrAborts) {
281281

282282
ET_EXPECT_DEATH({ EValue evalue(null_ptr); }, "");
283283
}
284+
285+
TEST_F(EValueTest, StringConstructorNullCheck) {
286+
executorch::aten::ArrayRef<char>* null_string_ptr = nullptr;
287+
ET_EXPECT_DEATH(
288+
{ EValue evalue(null_string_ptr); },
289+
"ArrayRef<char> pointer cannot be null");
290+
}
291+
292+
TEST_F(EValueTest, BoolListConstructorNullCheck) {
293+
executorch::aten::ArrayRef<bool>* null_bool_list_ptr = nullptr;
294+
ET_EXPECT_DEATH(
295+
{ EValue evalue(null_bool_list_ptr); },
296+
"ArrayRef<bool> pointer cannot be null");
297+
}
298+
299+
TEST_F(EValueTest, DoubleListConstructorNullCheck) {
300+
executorch::aten::ArrayRef<double>* null_double_list_ptr = nullptr;
301+
ET_EXPECT_DEATH(
302+
{ EValue evalue(null_double_list_ptr); },
303+
"ArrayRef<double> pointer cannot be null");
304+
}
305+
306+
TEST_F(EValueTest, IntListConstructorNullCheck) {
307+
BoxedEvalueList<int64_t>* null_int_list_ptr = nullptr;
308+
ET_EXPECT_DEATH(
309+
{ EValue evalue(null_int_list_ptr); },
310+
"BoxedEvalueList<int64_t> pointer cannot be null");
311+
}
312+
313+
TEST_F(EValueTest, TensorListConstructorNullCheck) {
314+
BoxedEvalueList<executorch::aten::Tensor>* null_tensor_list_ptr = nullptr;
315+
ET_EXPECT_DEATH(
316+
{ EValue evalue(null_tensor_list_ptr); },
317+
"BoxedEvalueList<Tensor> pointer cannot be null");
318+
}
319+
320+
TEST_F(EValueTest, OptionalTensorListConstructorNullCheck) {
321+
BoxedEvalueList<std::optional<executorch::aten::Tensor>>*
322+
null_optional_tensor_list_ptr = nullptr;
323+
ET_EXPECT_DEATH(
324+
{ EValue evalue(null_optional_tensor_list_ptr); },
325+
"BoxedEvalueList<optional<Tensor>> pointer cannot be null");
326+
}
327+
328+
TEST_F(EValueTest, BoxedEvalueListConstructorNullChecks) {
329+
int64_t storage[3] = {0, 0, 0};
330+
EValue values[3] = {
331+
EValue((int64_t)1), EValue((int64_t)2), EValue((int64_t)3)};
332+
EValue* values_p[3] = {&values[0], &values[1], &values[2]};
333+
334+
// Test null wrapped_vals
335+
ET_EXPECT_DEATH(
336+
{ BoxedEvalueList<int64_t> list(nullptr, storage, 3); },
337+
"wrapped_vals cannot be null");
338+
339+
// Test null unwrapped_vals
340+
ET_EXPECT_DEATH(
341+
{ BoxedEvalueList<int64_t> list(values_p, nullptr, 3); },
342+
"unwrapped_vals cannot be null");
343+
344+
// Test negative size
345+
ET_EXPECT_DEATH(
346+
{ BoxedEvalueList<int64_t> list(values_p, storage, -1); },
347+
"size cannot be negative");
348+
}
349+
350+
TEST_F(EValueTest, toListOptionalTensorTypeCheck) {
351+
// Create an EValue that's not a ListOptionalTensor
352+
EValue e((int64_t)42);
353+
EXPECT_TRUE(e.isInt());
354+
EXPECT_FALSE(e.isListOptionalTensor());
355+
356+
// Should fail type check
357+
ET_EXPECT_DEATH(
358+
{ e.toListOptionalTensor(); }, "EValue is not a List Optional Tensor");
359+
}
360+
361+
TEST_F(EValueTest, toStringNullPointerCheck) {
362+
// Create an EValue with String tag but null pointer
363+
EValue e;
364+
e.tag = Tag::String;
365+
e.payload.copyable_union.as_string_ptr = nullptr;
366+
367+
// Should pass isString() check but fail null pointer check
368+
EXPECT_TRUE(e.isString());
369+
ET_EXPECT_DEATH({ e.toString(); }, "EValue string pointer is null");
370+
}
371+
372+
TEST_F(EValueTest, toIntListNullPointerCheck) {
373+
// Create an EValue with ListInt tag but null pointer
374+
EValue e;
375+
e.tag = Tag::ListInt;
376+
e.payload.copyable_union.as_int_list_ptr = nullptr;
377+
378+
// Should pass isIntList() check but fail null pointer check
379+
EXPECT_TRUE(e.isIntList());
380+
ET_EXPECT_DEATH({ e.toIntList(); }, "EValue int list pointer is null");
381+
}
382+
383+
TEST_F(EValueTest, toBoolListNullPointerCheck) {
384+
// Create an EValue with ListBool tag but null pointer
385+
EValue e;
386+
e.tag = Tag::ListBool;
387+
e.payload.copyable_union.as_bool_list_ptr = nullptr;
388+
389+
// Should pass isBoolList() check but fail null pointer check
390+
EXPECT_TRUE(e.isBoolList());
391+
ET_EXPECT_DEATH({ e.toBoolList(); }, "EValue bool list pointer is null");
392+
}
393+
394+
TEST_F(EValueTest, toDoubleListNullPointerCheck) {
395+
// Create an EValue with ListDouble tag but null pointer
396+
EValue e;
397+
e.tag = Tag::ListDouble;
398+
e.payload.copyable_union.as_double_list_ptr = nullptr;
399+
400+
// Should pass isDoubleList() check but fail null pointer check
401+
EXPECT_TRUE(e.isDoubleList());
402+
ET_EXPECT_DEATH({ e.toDoubleList(); }, "EValue double list pointer is null");
403+
}
404+
405+
TEST_F(EValueTest, toTensorListNullPointerCheck) {
406+
// Create an EValue with ListTensor tag but null pointer
407+
EValue e;
408+
e.tag = Tag::ListTensor;
409+
e.payload.copyable_union.as_tensor_list_ptr = nullptr;
410+
411+
// Should pass isTensorList() check but fail null pointer check
412+
EXPECT_TRUE(e.isTensorList());
413+
ET_EXPECT_DEATH({ e.toTensorList(); }, "EValue tensor list pointer is null");
414+
}
415+
416+
TEST_F(EValueTest, toListOptionalTensorNullPointerCheck) {
417+
// Create an EValue with ListOptionalTensor tag but null pointer
418+
EValue e;
419+
e.tag = Tag::ListOptionalTensor;
420+
e.payload.copyable_union.as_list_optional_tensor_ptr = nullptr;
421+
422+
// Should pass isListOptionalTensor() check but fail null pointer check
423+
EXPECT_TRUE(e.isListOptionalTensor());
424+
ET_EXPECT_DEATH(
425+
{ e.toListOptionalTensor(); },
426+
"EValue list optional tensor pointer is null");
427+
}

0 commit comments

Comments
 (0)