diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index aabce947365fb..b5ba28d7fc258 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -3239,11 +3239,26 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso if (input_type == nullptr) input_type = &null_pointer; // Type error in input model/graph. + // Build a sorted, readable list of the types the operator permits for this input so the + // error message tells the user which types are valid instead of only the offending one. + std::vector permitted_type_strings; + permitted_type_strings.reserve(permitted_types.size()); + for (const auto* permitted_type : permitted_types) { + permitted_type_strings.push_back(permitted_type != nullptr ? *permitted_type : "(null)"); + } + std::sort(permitted_type_strings.begin(), permitted_type_strings.end()); + std::string expected_types_str; + for (size_t t = 0; t < permitted_type_strings.size(); ++t) { + if (t != 0) expected_types_str += ", "; + expected_types_str += permitted_type_strings[t]; + } + Status status(ONNXRUNTIME, INVALID_GRAPH, "This is an invalid model. " "Type Error: Type '" + *input_type + "' of input parameter (" + input_def->Name() + - ") of operator (" + op.Name() + ") in node (" + node_name + ") is invalid."); + ") of operator (" + op.Name() + ") in node (" + node_name + + ") is invalid. Expected one of the following types: " + expected_types_str + "."); return status; } diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index 4716acc01ec42..3a73bbdd68700 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -1710,6 +1710,32 @@ TEST_F(GraphTest, VariadicOutput) { CheckTensorEltType(Z.TypeAsProto(), TensorProto_DataType_FLOAT); } +// When an operator receives an input whose type is not one of the operator's permitted +// types, the resolve error should list the expected types to aid debugging. See issue #4429. +TEST_F(GraphTest, TypeErrorMessageListsExpectedTypes) { + Model model("graph_type_error", false, *logger_); + auto& graph = model.MainGraph(); + + TypeProto int64_tensor; + int64_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64); + + auto& X = graph.GetOrCreateNodeArg("X", &int64_tensor); + auto& Y = graph.GetOrCreateNodeArg("Y", nullptr); + graph.SetInputs({&X}); + + // Sin only supports floating point types, so an int64 input is invalid. + graph.AddNode("node_1", "Sin", "node 1.", {&X}, {&Y}); + + auto status = graph.Resolve(); + ASSERT_FALSE(status.IsOK()); + const std::string msg = status.ErrorMessage(); + // The existing context is preserved. + EXPECT_NE(msg.find("of operator (Sin)"), std::string::npos) << msg; + // The improved message lists the operator's permitted types. + EXPECT_NE(msg.find("Expected one of the following types:"), std::string::npos) << msg; + EXPECT_NE(msg.find("tensor(float)"), std::string::npos) << msg; +} + // test that we prefer the graph input shape for a non-const initializer (initializer with matching graph input) TEST_F(GraphTest, NonConstInitializer) { Model model("graph_1", false, *logger_);