diff --git a/onnxruntime/contrib_ops/cpu/activations.cc b/onnxruntime/contrib_ops/cpu/activations.cc index 3e0533dd8b9e5..43c84319e94d5 100644 --- a/onnxruntime/contrib_ops/cpu/activations.cc +++ b/onnxruntime/contrib_ops/cpu/activations.cc @@ -4,6 +4,8 @@ #include "core/providers/cpu/activation/activations.h" #include "contrib_ops/cpu/activations.h" +#include "core/framework/allocator.h" + namespace onnxruntime { namespace contrib { @@ -26,13 +28,72 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()), ThresholdedRelu); -ONNX_OPERATOR_KERNEL_EX( - QuickGelu, - kMSDomain, - 1, - kCpuExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - QuickGelu); +// QuickGelu for MLFloat16 is computed in fp32 and converted back to fp16. This keeps the +// Swish/SiLU activation fused into a single kernel (instead of running as separate Sigmoid + Mul +// nodes), which is meaningfully faster on ARMv8.2-A CPUs, while remaining correct on CPUs without +// native fp16 support. +template <> +Status QuickGelu::Compute(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const MLFloat16* input_data = input->Data(); + Tensor* output = context->Output(0, input->Shape()); + MLFloat16* output_data = output->MutableData(); + concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); + int64_t elem_count = input->Shape().Size(); + if (elem_count == 0) { + return Status::OK(); + } + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + const size_t count = onnxruntime::narrow(elem_count); + auto input_fp32 = IAllocator::MakeUniquePtr(allocator, count); + auto output_fp32 = IAllocator::MakeUniquePtr(allocator, count); + + MlasConvertHalfToFloatBufferInParallel(input_data, input_fp32.get(), count, tp); + + const float alpha = alpha_; + float* input_fp32_data = input_fp32.get(); + float* output_fp32_data = output_fp32.get(); + constexpr int64_t length_per_task = 4096; // this number comes from FastGelu. + int64_t task_count = (elem_count + length_per_task - 1) / length_per_task; + concurrency::ThreadPool::TryBatchParallelFor( + tp, static_cast(task_count), + [&](ptrdiff_t task_idx) { + const auto start = task_idx * length_per_task; + const float* p_input = input_fp32_data + start; + float* p_output = output_fp32_data + start; + int64_t task_elems = std::min(length_per_task, elem_count - start); + + if (alpha == 1.0f) { + MlasComputeSilu(p_input, p_output, onnxruntime::narrow(task_elems)); + return; + } + + for (int64_t i = 0; i < task_elems; i++) { + p_output[i] = p_input[i] * alpha; + } + + MlasComputeLogistic(p_output, p_output, onnxruntime::narrow(task_elems)); + + MlasEltwiseMul(p_input, p_output, p_output, onnxruntime::narrow(task_elems)); + }, + 0); + + MlasConvertFloatToHalfBufferInParallel(output_fp32_data, output_data, count, tp); + + return Status::OK(); +} + +#define REGISTER_QUICKGELU_KERNEL(data_type) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + QuickGelu, kMSDomain, 1, data_type, kCpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + QuickGelu); + +REGISTER_QUICKGELU_KERNEL(float); +REGISTER_QUICKGELU_KERNEL(MLFloat16); } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 0749457f5a182..f9c3e6417a0b6 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -75,7 +75,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BiasG class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGramRepeatBlock); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuickGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); // ******** Start: Quantization ******************* // @@ -374,7 +375,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // These ops were experimental ops in onnx domain which have been removed now. We add them here as // contrib ops to main backward compatibility diff --git a/onnxruntime/test/contrib_ops/activation_op_test.cc b/onnxruntime/test/contrib_ops/activation_op_test.cc index 061fffa572be2..881e35e6e5a49 100644 --- a/onnxruntime/test/contrib_ops/activation_op_test.cc +++ b/onnxruntime/test/contrib_ops/activation_op_test.cc @@ -152,5 +152,50 @@ TEST_F(ActivationOpTest, QuickGelu) { } } +TEST_F(ActivationOpTest, QuickGelu_fp16) { + // Use enough elements to cross the 4096-element chunk boundary used by the + // QuickGelu::Compute() specialization. 8205 = 2 * 4096 + 13 exercises + // the multi-task path as well as a final partial (tail) chunk. + constexpr int64_t element_count = 2 * 4096 + 13; + std::vector input_values; + input_values.reserve(element_count); + // Seed with corner values, then fill the remainder with a varied ramp. + const std::vector seed_values{-1.0f, 0.0f, 1.0f, 2.5f, -2.5f, 5.0f, -5.0f, 0.3f}; + input_values.insert(input_values.end(), seed_values.begin(), seed_values.end()); + for (int64_t i = static_cast(seed_values.size()); i < element_count; ++i) { + // Range roughly [-6, 6] to cover both saturation tails and the linear region. + input_values.push_back(static_cast(((i % 121) - 60)) * 0.1f); + } + std::vector dims{static_cast(input_values.size())}; + + auto quick_gelu = [](float x, float alpha) { + auto tmp = x * alpha; + auto y = 1.f / (1.f + std::exp(-std::abs(tmp))); // safe sigmoid + y = tmp >= 0 ? y : 1 - y; + return x * y; + }; + + for (float alpha : {1.702f, 1.0f, -1.702f}) { + std::vector input_fp16; + std::vector output_fp16; + input_fp16.reserve(input_values.size()); + output_fp16.reserve(input_values.size()); + for (float x : input_values) { + input_fp16.push_back(MLFloat16(x)); + output_fp16.push_back(MLFloat16(quick_gelu(x, alpha))); + } + + OpTester test("QuickGelu", 1, kMSDomain); + test.AddAttribute("alpha", alpha); + test.AddInput("X", dims, input_fp16); + test.AddOutput("Y", dims, output_fp16); + // Relax tolerance because the reference is computed in fp32. + test.SetOutputTolerance(0.005f); + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + } // namespace test } // namespace onnxruntime