Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,21 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid
continue;
}

// Skip if unknown rank
auto shape = api_graph->GetValueInfo(node->Inputs()[0])->Shape();
if (!shape.has_value()) {
// The NCHW<->NHWC permutation depends only on rank. For Conv/ConvTranspose (and FusedConv, which is treated as Conv
// here) the data input and the weight share the same rank, so an unknown input[0] rank can be recovered from the
// weight at input[1].
std::optional<size_t> input_rank = api_graph->GetValueInfo(node->Inputs()[0])->ShapeRank();
if (!input_rank.has_value() && (op_type == "Conv" || op_type == "ConvTranspose")) {
input_rank = api_graph->GetValueInfo(node->Inputs()[1])->ShapeRank();
}

// Skip if rank is still unknown.
if (!input_rank.has_value()) {
continue;
}

// Convert to channels last
size_t rank = shape->size();
size_t rank = *input_rank;

bool has_channel_last_attr = node->GetAttributeInt("channels_last").has_value() ? true : false;
if (has_channel_last_attr) {
Expand Down
61 changes: 61 additions & 0 deletions onnxruntime/test/optimizer/transpose_optimizer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4664,6 +4664,67 @@ TEST(TransposeOptimizerTests, LayoutTransformDoesNotRetargetNhwcFusedConv) {
EXPECT_EQ(nhwc_fused_conv_count, 1);
}

// Helper for the "recover rank from weight" layout transformation tests.
//
// Builds a one-node model whose node is `op_type` with a float data input and a float weight initializer of
// `weight_shape`, clears the shape of the data input so its rank is unknown, and loads/initializes the model in an
// InferenceSession that has an NHWC InternalTestingExecutionProvider registered. The test passes if the initialized
// graph contains at least one Transpose node.
//
// op_type:      operator type of the node under test (callers pass "Conv" and "ConvTranspose").
// weight_shape: shape of the weight initializer passed as the node's second input.
static void TestLayoutTransformWithUnknownInputRank(const std::string& op_type,
                                                    const std::vector<int64_t>& weight_shape) {
  std::unordered_map<std::string, int> domain_to_version{{kOnnxDomain, 13}};
  Model model("LayoutTransform_" + op_type + "_RecoverRankFromWeight", false, ModelMetaData(), PathString(),
              IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
              DefaultLoggingManager().DefaultLogger());
  Graph& graph = model.MainGraph();
  ModelTestBuilder builder(graph);

  // Create the data input with a concrete shape, then clear the shape so the graph carries no rank for it.
  auto* input_arg = builder.MakeInput<float>({1, 3, 7, 7}, -1.0f, 1.0f);
  input_arg->ClearShape();

  // The weight initializer keeps the caller-supplied shape, so its rank stays known.
  auto* weight_arg = builder.MakeInitializer<float>(weight_shape, -1.0f, 1.0f);
  auto* output_arg = builder.MakeOutput();

  auto& node = builder.AddNode(op_type, {input_arg, weight_arg}, {output_arg});
  node.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});
  node.AddAttribute("strides", std::vector<int64_t>{1, 1});
  node.AddAttribute("kernel_shape", std::vector<int64_t>{3, 3});

  builder.SetGraphOutputs();
  ASSERT_STATUS_OK(graph.Resolve());

  // Fail right here if serialization fails instead of handing a partial buffer to session.Load().
  std::string model_data;
  ASSERT_TRUE(model.ToProto().SerializeToString(&model_data)) << "Failed to serialize the test model";

  SessionOptions so;
  using InternalTestingEP = internal_testing_ep::InternalTestingExecutionProvider;
  const std::unordered_set<std::string> empty_set;
  auto internal_testing_ep = std::make_unique<InternalTestingEP>(empty_set, empty_set, DataLayout::NHWC);
  internal_testing_ep->EnableStaticKernels().TakeAllNodes();

  InferenceSessionWrapper session{so, GetEnvironment()};
  ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::move(internal_testing_ep)));
  ASSERT_STATUS_OK(session.Load(model_data.data(), static_cast<int>(model_data.size())));
  ASSERT_STATUS_OK(session.Initialize());

  const auto& optimized_graph = session.GetGraph();
  const auto op_to_count = CountOpsInGraph(optimized_graph);
  // The lambda parameter is named `op_name` so it does not shadow this function's `op_type` parameter.
  const auto get_op_count = [&op_to_count](std::string_view op_name) {
    const auto it = op_to_count.find(std::string{op_name});
    return it == op_to_count.end() ? 0 : it->second;
  };

  // Transpose nodes should be inserted, proving that layout transformation proceeded after recovering rank from weight.
  EXPECT_GT(get_op_count("Transpose"), 0)
      << "Layout transformation should insert Transpose nodes for NCHW->NHWC conversion";
}

// Conv case: the data input has no shape, so the layout transformer must take the rank from the weight initializer.
TEST(TransposeOptimizerTests, LayoutTransformConvRecoverRankFromWeight) {
  // Rank-4 weight; presumably [M, C, kH, kW] per the ONNX Conv spec, matching the helper's 3x3 kernel_shape.
  const std::vector<int64_t> conv_weight_shape{8, 3, 3, 3};
  TestLayoutTransformWithUnknownInputRank("Conv", conv_weight_shape);
}

// Verifies that layout transformation recovers ConvTranspose rank from weight when input rank is unknown.
TEST(TransposeOptimizerTests, LayoutTransformConvTransposeRecoverRankFromWeight) {
  // Rank-4 weight; presumably [C, M, kH, kW] per the ONNX ConvTranspose spec, i.e. the first two dims are swapped
  // relative to the Conv test's {8, 3, 3, 3}.
  TestLayoutTransformWithUnknownInputRank("ConvTranspose", {3, 8, 3, 3});
}

TEST(TransposeOptimizerTests, QnnTransposeReshapeQDQ) {
Status status;
auto model_uri = ORT_TSTR("testdata/layout_transform_reshape.qdq.onnx");
Expand Down
Loading