diff --git a/tensorflow/lite/micro/kernels/micro_ops.h b/tensorflow/lite/micro/kernels/micro_ops.h
index 4f8c3c068f1..3f6798f7309 100644
--- a/tensorflow/lite/micro/kernels/micro_ops.h
+++ b/tensorflow/lite/micro/kernels/micro_ops.h
@@ -95,6 +95,7 @@ TFLMRegistration Register_MIRROR_PAD();
 TFLMRegistration Register_MUL();
 TFLMRegistration Register_NEG();
 TFLMRegistration Register_NOT_EQUAL();
+TFLMRegistration Register_ONE_HOT();
 TFLMRegistration Register_PACK();
 TFLMRegistration Register_PAD();
 TFLMRegistration Register_PADV2();
diff --git a/tensorflow/lite/micro/kernels/one_hot.cc b/tensorflow/lite/micro/kernels/one_hot.cc
new file mode 100644
index 00000000000..a9a7eb2d183
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/one_hot.cc
@@ -0,0 +1,243 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdint>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_common.h"
+
+namespace tflite {
+namespace {
+
+constexpr int kIndicesTensor = 0;
+constexpr int kDepthTensor = 1;
+constexpr int kOnValueTensor = 2;
+constexpr int kOffValueTensor = 3;
+constexpr int kOutputTensor = 0;
+
+// Local utility: number of elements in an eval tensor.
+inline int NumElements(const TfLiteEvalTensor* t) {
+  int count = 1;
+  for (int i = 0; i < t->dims->size; ++i) {
+    count *= t->dims->data[i];
+  }
+  return count;
+}
+
+// Gathers the input tensors (indices, depth, on_value, off_value) and the
+// output tensor from the TfLiteNode, and reads params->axis to compute the
+// actual position (axis) at which the depth dimension is inserted.
+// The context is built on the stack inside Prepare and Eval and destroyed on
+// return, so no persistent allocation is needed.
+struct OneHotContext {
+  OneHotContext(TfLiteContext* context, TfLiteNode* node) {
+    indices = tflite::micro::GetEvalInput(context, node, kIndicesTensor);
+    depth = tflite::micro::GetEvalInput(context, node, kDepthTensor);
+    on_value = tflite::micro::GetEvalInput(context, node, kOnValueTensor);
+    off_value = tflite::micro::GetEvalInput(context, node, kOffValueTensor);
+    output = tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+
+    const auto* params =
+        reinterpret_cast<const TfLiteOneHotParams*>(node->builtin_data);
+    const int indices_dims = indices->dims->size;
+    axis = (params->axis == -1) ? indices_dims : params->axis;
+    output_dims = indices_dims + 1;
+    dtype = on_value->type;
+  }
+
+  const TfLiteEvalTensor* indices;
+  const TfLiteEvalTensor* depth;
+  const TfLiteEvalTensor* on_value;
+  const TfLiteEvalTensor* off_value;
+  TfLiteEvalTensor* output;
+
+  int axis;
+  int output_dims;
+  TfLiteType dtype;
+};
+
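+// The computation below views the indices as a matrix of size
+//     prefix_dim_size x suffix_dim_size
+// and the output as a matrix of size
+//     prefix_dim_size x depth x suffix_dim_size,
+// where prefix/suffix collect the index dimensions before/after `axis`:
+//     output(i, j, k) = (indices(i, k) == j) ? on_value : off_value
+// For example, indices = [1, 2] with depth = 3 and axis = -1 yields
+//     output = [[0, 1, 0], [0, 0, 1]].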
+// Operation function.
+template <typename T, typename TI>
+void OneHotComputeImpl(const OneHotContext& op_context) {
+  int prefix_dim_size = 1;
+  for (int i = 0; i < op_context.axis; ++i) {
+    prefix_dim_size *= op_context.indices->dims->data[i];
+  }
+  if (prefix_dim_size == 0) {
+    // A degenerate indices tensor produces a degenerate output.
+    return;
+  }
+
+  const RuntimeShape indices_shape =
+      tflite::micro::GetTensorShape(op_context.indices);
+  const int suffix_dim_size = indices_shape.FlatSize() / prefix_dim_size;
+
+  const int32_t* depth_ptr =
+      tflite::micro::GetTensorData<int32_t>(op_context.depth);
+  if (depth_ptr == nullptr) return;
+  const int depth = *depth_ptr;
+
+  const T on_value = *tflite::micro::GetTensorData<T>(op_context.on_value);
+  const T off_value = *tflite::micro::GetTensorData<T>(op_context.off_value);
+
+  T* output_data = tflite::micro::GetTensorData<T>(op_context.output);
+  const TI* indices_data =
+      tflite::micro::GetTensorData<TI>(op_context.indices);
+
+  for (int i = 0; i < prefix_dim_size; ++i) {
+    for (int j = 0; j < depth; ++j) {
+      for (int k = 0; k < suffix_dim_size; ++k, ++output_data) {
+        *output_data =
+            static_cast<int>(indices_data[i * suffix_dim_size + k]) == j
+                ? on_value
+                : off_value;
+      }
+    }
+  }
+}
+
+// Dispatches on the index type; indices may be int32 or int64 (enforced in
+// Prepare).
+template <typename T>
+void OneHotCompute(const OneHotContext& op_context) {
+  if (op_context.indices->type == kTfLiteInt64) {
+    OneHotComputeImpl<T, int64_t>(op_context);
+  } else {
+    OneHotComputeImpl<T, int32_t>(op_context);
+  }
+}
+
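+// The expected output shape is the indices shape with `depth` inserted at
+// position `axis`; for example, indices of shape [2, 3] with axis == 1 and
+// depth == 4 must have a pre-allocated output of shape [2, 4, 3].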
+TfLiteStatus EnsureOutputDimsMatchExpected(TfLiteContext* context,
+                                           const OneHotContext& op_context) {
+  // Read the depth value.
+  const int32_t* depth_ptr =
+      tflite::micro::GetTensorData<int32_t>(op_context.depth);
+  TF_LITE_ENSURE(context, depth_ptr != nullptr);
+
+  const int depth_val = *depth_ptr;
+  TF_LITE_ENSURE(context, depth_val >= 0);
+
+  // Validate the output tensor.
+  TF_LITE_ENSURE(context, op_context.output != nullptr);
+  TF_LITE_ENSURE(context, op_context.output->dims != nullptr);
+
+  // TFLM assumes that the output tensor's dims are already allocated.
+  const int expected_dims_size = op_context.output_dims;
+  TF_LITE_ENSURE_EQ(context, op_context.output->dims->size,
+                    expected_dims_size);
+
+  for (int i = 0; i < expected_dims_size; ++i) {
+    int expected_dim_i;
+    if (i < op_context.axis) {
+      expected_dim_i = op_context.indices->dims->data[i];
+    } else if (i == op_context.axis) {
+      expected_dim_i = depth_val;
+    } else {
+      expected_dim_i = op_context.indices->dims->data[i - 1];
+    }
+
+    // Raise an error if the shape pre-allocated by the TFLM memory planner
+    // does not match the computed shape.
+    TF_LITE_ENSURE_EQ(context, op_context.output->dims->data[i],
+                      expected_dim_i);
+  }
+
+  return kTfLiteOk;
+}
+
+void* OneHotInit(TfLiteContext* context, const char* buffer, size_t length) {
+  (void)context;
+  (void)buffer;
+  (void)length;
+  // This kernel does not require persistent op data.
+  return nullptr;
+}
+
+TfLiteStatus OneHotPrepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
+  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+  OneHotContext op_context{context, node};
+  TF_LITE_ENSURE(context, op_context.output != nullptr);
+
+  switch (op_context.dtype) {
+    case kTfLiteFloat32:
+    case kTfLiteInt16:
+    case kTfLiteInt32:
+    case kTfLiteInt64:
+    case kTfLiteInt8:
+    case kTfLiteUInt8:
+    case kTfLiteBool:
+      op_context.output->type = op_context.dtype;
+      break;
+    default:
+      TF_LITE_KERNEL_LOG(context, "Unknown output data type: %s",
+                         TfLiteTypeGetName(op_context.dtype));
+      return kTfLiteError;
+  }
+
+  TF_LITE_ENSURE(context, op_context.indices->type == kTfLiteInt32 ||
+                              op_context.indices->type == kTfLiteInt64);
+  TF_LITE_ENSURE(context, op_context.axis >= 0 &&
+                              op_context.axis < op_context.output_dims);
+  TF_LITE_ENSURE_EQ(context, NumElements(op_context.depth), 1);
+  TF_LITE_ENSURE_EQ(context, NumElements(op_context.on_value), 1);
+  TF_LITE_ENSURE_EQ(context, NumElements(op_context.off_value), 1);
+  TF_LITE_ENSURE_TYPES_EQ(context, op_context.on_value->type,
+                          op_context.dtype);
+  TF_LITE_ENSURE_TYPES_EQ(context, op_context.off_value->type,
+                          op_context.dtype);
+
+  // Even when the depth tensor is not constant, the tests predefine the
+  // output shape, so only validation is performed here.
+  return EnsureOutputDimsMatchExpected(context, op_context);
+}
+
+TfLiteStatus OneHotEval(TfLiteContext* context, TfLiteNode* node) {
+  OneHotContext op_context{context, node};
+
+  switch (op_context.output->type) {
+    case kTfLiteFloat32:
+      OneHotCompute<float>(op_context);
+      break;
+    case kTfLiteInt32:
+      OneHotCompute<int32_t>(op_context);
+      break;
+    case kTfLiteInt64:
+      OneHotCompute<int64_t>(op_context);
+      break;
+    case kTfLiteInt8:
+      OneHotCompute<int8_t>(op_context);
+      break;
+    case kTfLiteUInt8:
+      OneHotCompute<uint8_t>(op_context);
+      break;
+    case kTfLiteBool:
+      OneHotCompute<bool>(op_context);
+      break;
+    default:
+      return kTfLiteError;
+  }
+
+  return kTfLiteOk;
+}
+
+}  // namespace
+
+TFLMRegistration Register_ONE_HOT() {
+  return tflite::micro::RegisterOp(OneHotInit, OneHotPrepare, OneHotEval);
+}
+
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/one_hot_test.cc b/tensorflow/lite/micro/one_hot_test.cc
new file mode 100644
index 00000000000..c3a6cd97924
--- /dev/null
+++ b/tensorflow/lite/micro/one_hot_test.cc
@@ -0,0 +1,95 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/micro/kernels/kernel_runner.h"
+#include "tensorflow/lite/micro/kernels/micro_ops.h"
+#include "tensorflow/lite/micro/test_helpers.h"
+#include "tensorflow/lite/micro/testing/micro_test.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+// Helper function for the OneHot operation tests.
+template <typename T>
+void TestOneHot(const int* indices_dims, const int32_t* indices_data,
+                const int* depth_dims, const int32_t* depth_data,
+                const int* on_dims, const T* on_data, const int* off_dims,
+                const T* off_data, const int* output_dims,
+                const T* expected_output_data, T* output_data, int axis = -1) {
+  // 1. Shape setup.
+  TfLiteIntArray* in_dims = IntArrayFromInts(indices_dims);
+  TfLiteIntArray* d_dims = IntArrayFromInts(depth_dims);
+  TfLiteIntArray* on_val_dims = IntArrayFromInts(on_dims);
+  TfLiteIntArray* off_val_dims = IntArrayFromInts(off_dims);
+  TfLiteIntArray* out_dims = IntArrayFromInts(output_dims);
+
+  const int output_dims_count = ElementCount(*out_dims);
+
+  // 2. Create the input and output tensors.
+  constexpr int inputs_size = 4;
+  constexpr int outputs_size = 1;
+  constexpr int tensors_size = inputs_size + outputs_size;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateTensor(indices_data, in_dims), CreateTensor(depth_data, d_dims),
+      CreateTensor(on_data, on_val_dims), CreateTensor(off_data, off_val_dims),
+      CreateTensor(output_data, out_dims),  // Output tensor.
+  };
+
+  // 3. Parameter setup.
+  TfLiteOneHotParams builtin_data = {axis};
+
+  // 4. KernelRunner execution.
+  int inputs_array_data[] = {4, 0, 1, 2, 3};  // indices, depth, on, off
+  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+  int outputs_array_data[] = {1, 4};  // output tensor index
+  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+
+  const TFLMRegistration registration = tflite::Register_ONE_HOT();
+  micro::KernelRunner runner(registration, tensors, tensors_size,
+                             inputs_array, outputs_array,
+                             reinterpret_cast<void*>(&builtin_data));
+
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
+  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
+
+  // 5. Compare against the expected output.
+  for (int i = 0; i < output_dims_count; ++i) {
+    TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]);
+  }
+}
+
+}  // namespace
+}  // namespace testing
+}  // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(OneHot_BasicInt32) {
+  // Indices: [0, 1, 2]
+  const int indices_dims[] = {1, 3};
+  const int32_t indices_data[] = {0, 1, 2};
+
+  // Depth: 3
+  const int depth_dims[] = {1, 1};
+  const int32_t depth_data[] = {3};
+
+  // On: 1, Off: 0
+  const int on_dims[] = {1, 1};
+  const int32_t on_data[] = {1};
+  const int off_dims[] = {1, 1};
+  const int32_t off_data[] = {0};
+
+  // Output: [3, 3] -> identity matrix.
+  const int output_dims[] = {2, 3, 3};
+  const int32_t expected_output[] = {1, 0, 0, 0, 1, 0, 0, 0, 1};
+
+  int32_t output_data[9];
+
+  tflite::testing::TestOneHot(indices_dims, indices_data, depth_dims,
+                              depth_data, on_dims, on_data, off_dims,
+                              off_data, output_dims, expected_output,
+                              output_data);
+}
+
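+// Additional sketch exercising the explicit-axis path (axis = 0). The
+// expected values follow from the kernel's (indices(i, k) == j) rule: with
+// axis 0 the depth dimension is inserted in front, so output(j, k) is
+// on_value exactly when indices[k] == j.
+TF_LITE_MICRO_TEST(OneHot_AxisZeroInt32) {
+  // Indices: [2, 1, 0]
+  const int indices_dims[] = {1, 3};
+  const int32_t indices_data[] = {2, 1, 0};
+
+  // Depth: 3
+  const int depth_dims[] = {1, 1};
+  const int32_t depth_data[] = {3};
+
+  // On: 1, Off: 0
+  const int on_dims[] = {1, 1};
+  const int32_t on_data[] = {1};
+  const int off_dims[] = {1, 1};
+  const int32_t off_data[] = {0};
+
+  // Output: [3, 3] -> anti-diagonal matrix.
+  const int output_dims[] = {2, 3, 3};
+  const int32_t expected_output[] = {0, 0, 1, 0, 1, 0, 1, 0, 0};
+
+  int32_t output_data[9];
+
+  tflite::testing::TestOneHot(indices_dims, indices_data, depth_dims,
+                              depth_data, on_dims, on_data, off_dims,
+                              off_data, output_dims, expected_output,
+                              output_data, /*axis=*/0);
+}
+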
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile
index 21f21a1ce05..eff82997264 100644
--- a/tensorflow/lite/micro/tools/make/Makefile
+++ b/tensorflow/lite/micro/tools/make/Makefile
@@ -361,7 +361,8 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/arena_allocator/single_arena_buffer_allo
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/testing_helpers_test.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/memory_planner/greedy_memory_planner_test.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/memory_planner/linear_memory_planner_test.cc \
-$(TENSORFLOW_ROOT)tensorflow/lite/micro/memory_planner/non_persistent_buffer_planner_shim_test.cc
+$(TENSORFLOW_ROOT)tensorflow/lite/micro/memory_planner/non_persistent_buffer_planner_shim_test.cc \
+$(TENSORFLOW_ROOT)tensorflow/lite/micro/one_hot_test.cc
 
 MICROLITE_CC_KERNEL_SRCS := \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/activations.cc \
@@ -437,6 +438,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/mirror_pad.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/mul.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/mul_common.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/neg.cc \
+$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/one_hot.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/pack.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/pad.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/pad_common.cc \