Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2018,17 +2018,21 @@ static ORT_STATUS_PTR OrtGetValueImplSeqOfMap(const OrtValue* p_ml_value, int in
}
#endif

ORT_STATUS_PTR PopulateTensorWithData(Tensor& tensor, bool is_string, _In_ const void* data_elem, size_t num_elems,
size_t elem_size) {
ORT_STATUS_PTR PopulateTensorWithData(Tensor& tensor, _In_ const void* data_elem, size_t num_elems) {
auto len = narrow<size_t>(tensor.Shape().Size());
if (num_elems < len) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "input array is too short");
}
if (!is_string) {
memcpy(tensor.MutableDataRaw(), data_elem, elem_size * num_elems);
if (!tensor.IsDataTypeString()) {
// Use the tensor's actual storage size in bytes rather than elem_size * num_elems.
// For packed sub-byte types (e.g., int4/uint4) multiple elements share a storage byte,
// so the naive product over-counts and would over-read the source / overflow the destination.
memcpy(tensor.MutableDataRaw(), data_elem, tensor.SizeInBytes());
} else {
const std::string* strings = reinterpret_cast<const std::string*>(data_elem);
auto str_span = gsl::make_span(strings, num_elems);
// Copy exactly the tensor's element count (len), not num_elems, to avoid writing past
// the destination when the source is larger than the tensor.
auto str_span = gsl::make_span(strings, len);
auto* dst = tensor.MutableData<std::string>();
std::copy(str_span.begin(), str_span.end(), dst);
}
Expand All @@ -2038,8 +2042,7 @@ ORT_STATUS_PTR PopulateTensorWithData(Tensor& tensor, bool is_string, _In_ const
// Creates a tensor of `element_type` with the given shape through CreateTensorImpl (passing
// `result` as the destination), then fills the Tensor held by `result` from `data` through
// PopulateTensorWithData. Returns nullptr once both steps have completed.
// Each step is wrapped in ORT_API_RETURN_IF_ERROR; presumably that macro returns the failing
// status early -- it is defined outside this view, confirm there.
// NOTE(review): this text is a diff rendering without +/- markers. The 5-argument
// PopulateTensorWithData call below looks like the removed line and the 3-argument call like
// the added one, so only one of the two should exist in the real file -- confirm against the
// repository before relying on this block.
ORT_STATUS_PTR CreateTensorAndPopulate(MLDataType element_type, const int64_t* shape, size_t shape_len,
const void* data, size_t num_elements, _Inout_ OrtAllocator* allocator, OrtValue& result) {
ORT_API_RETURN_IF_ERROR(CreateTensorImpl(element_type, shape, shape_len, allocator, result));
ORT_API_RETURN_IF_ERROR(PopulateTensorWithData(*result.GetMutable<Tensor>(), utils::IsDataTypeString(element_type),
data, num_elements, element_type->Size()));
ORT_API_RETURN_IF_ERROR(PopulateTensorWithData(*result.GetMutable<Tensor>(), data, num_elements));
return nullptr;
}

Expand Down
45 changes: 45 additions & 0 deletions onnxruntime/test/shared_lib/test_nontensor_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <functional>
#include <iostream>
#include <set>
#include <array>

#include "core/common/common.h"
#include "core/session/onnxruntime_cxx_api.h"
Expand Down Expand Up @@ -277,6 +278,50 @@ TEST(CApiTest, CreateGetSeqStringTensors) {
ASSERT_EQ(string_set, std::set<std::string>(std::begin(string_input_data), std::end(string_input_data)));
}

// Test - GetValue() on a sequence of packed sub-byte tensors
// (int4/uint4) must copy only the packed storage bytes.
TEST(CApiTest, CreateGetSeqSubByteTensors) {
auto default_allocator = std::make_unique<MockedOrtAllocator>();
Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);

auto run_for_type = [&](ONNXTensorElementDataType elem_type, std::array<uint8_t, 4> packed) {
const std::vector<int64_t> dims{7}; // 7 4-bit elements -> 4 packed bytes
constexpr int N = 2;

std::vector<Ort::Value> in;
for (int i = 0; i < N; ++i) {
Ort::Value tensor = Ort::Value::CreateTensor(info, packed.data(), packed.size(),
dims.data(), dims.size(), elem_type);
in.push_back(std::move(tensor));
}

Ort::Value seq_ort = Ort::Value::CreateSequence(in);

for (int idx = 0; idx < N; ++idx) {
Ort::Value out = seq_ort.GetValue(idx, default_allocator.get());

auto type_info = out.GetTypeInfo();
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
ASSERT_EQ(tensor_info.GetElementType(), elem_type);
ASSERT_EQ(tensor_info.GetShape(), dims);

// Compare the packed bytes directly. GetTensorData<T>() does not support sub-byte
// types, so use the raw pointer and the packing-aware byte size.
const size_t out_bytes = out.GetTensorSizeInBytes();
ASSERT_EQ(out_bytes, packed.size());
const auto* ret = static_cast<const uint8_t*>(out.GetTensorRawData());
for (size_t i = 0; i < out_bytes; ++i) {
ASSERT_EQ(ret[i], packed[i]);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this compare bytes at the same address? I think the tensor is created in such a way that its data refers directly to `packed`'s bytes.

}
}
};

// {0, 1, 2, 3, -8, 7, 6, pad_0}
run_for_type(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, {0x10, 0x32, 0x78, 0x06});
// {0, 1, 2, 3, 4, 5, 15, pad_0}
run_for_type(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, {0x10, 0x32, 0x54, 0x0F});
}

TEST(CApiTest, TypeInfoSequence) {
// Creation
auto default_allocator = std::make_unique<MockedOrtAllocator>();
Expand Down
Loading