Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 68 additions & 7 deletions onnxruntime/contrib_ops/cpu/activations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include "core/providers/cpu/activation/activations.h"
#include "contrib_ops/cpu/activations.h"

#include <algorithm>

#include "core/framework/allocator.h"

namespace onnxruntime {
namespace contrib {

Expand All @@ -26,13 +28,72 @@
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
ThresholdedRelu<float>);

// Untyped registration of the CPU QuickGelu kernel in kMSDomain (version 1), constrained to
// float tensors and bound to QuickGelu<float>.
// NOTE(review): this looks like the pre-change (removed) side of the diff. The typed
// REGISTER_QUICKGELU_KERNEL(float) further down registers the same op for the same type, so
// confirm that only one of the two registrations remains in the final file.
ONNX_OPERATOR_KERNEL_EX(
QuickGelu,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
QuickGelu<float>);
// QuickGelu for MLFloat16 is computed in fp32 and converted back to fp16. This keeps the
// Swish/SiLU activation fused into a single kernel (instead of running as separate Sigmoid + Mul
// nodes), which is meaningfully faster on ARMv8.2-A CPUs, while remaining correct on CPUs without
// native fp16 support.
//
// Steps:
//   1. Widen the whole fp16 input into a temporary fp32 buffer.
//   2. Process the fp32 data in chunks of `length_per_task` elements on the operator thread pool:
//        - alpha == 1: a single MlasComputeSilu call per chunk;
//        - otherwise:  out = in * alpha, MlasComputeLogistic(out) in place, out = in * out.
//   3. Narrow the fp32 result into the fp16 output tensor.
//
// Returns Status::OK() on success (an empty input is a no-op), or the error reported by
// GetTempSpaceAllocator() when no temporary allocator is available.
template <>
Status QuickGelu<MLFloat16>::Compute(OpKernelContext* context) const {
  const Tensor* input = context->Input<Tensor>(0);
  const MLFloat16* input_data = input->Data<MLFloat16>();
  Tensor* output = context->Output(0, input->Shape());
  MLFloat16* output_data = output->MutableData<MLFloat16>();
  concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
  const int64_t elem_count = input->Shape().Size();
  if (elem_count == 0) {
    return Status::OK();
  }

  AllocatorPtr allocator;
  ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

  // fp32 scratch buffers covering the whole tensor: one for the widened input, one for the result.
  const size_t count = onnxruntime::narrow<size_t>(elem_count);
  auto input_fp32 = IAllocator::MakeUniquePtr<float>(allocator, count);
  auto output_fp32 = IAllocator::MakeUniquePtr<float>(allocator, count);

  MlasConvertHalfToFloatBufferInParallel(input_data, input_fp32.get(), count, tp);

  const float alpha = alpha_;
  const float* input_fp32_data = input_fp32.get();
  float* output_fp32_data = output_fp32.get();
  constexpr int64_t length_per_task = 4096;  // this number comes from FastGelu.
  const int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;
  concurrency::ThreadPool::TryBatchParallelFor(
      tp,
      // Checked narrowing: an unchecked cast to a 32-bit int could silently truncate the task count.
      onnxruntime::narrow<ptrdiff_t>(task_count),
      [&](ptrdiff_t task_idx) {
        const int64_t start = static_cast<int64_t>(task_idx) * length_per_task;
        const float* p_input = input_fp32_data + start;
        float* p_output = output_fp32_data + start;
        // The last chunk may be shorter than length_per_task.
        const size_t task_elems =
            onnxruntime::narrow<size_t>(std::min(length_per_task, elem_count - start));

        if (alpha == 1.0f) {
          MlasComputeSilu(p_input, p_output, task_elems);
          return;
        }

        // General alpha: scale, apply the logistic function in place, then multiply by the input.
        for (size_t i = 0; i < task_elems; ++i) {
          p_output[i] = p_input[i] * alpha;
        }

        MlasComputeLogistic(p_output, p_output, task_elems);

        MlasEltwiseMul<float>(p_input, p_output, p_output, task_elems);
      },
      0);

  MlasConvertFloatToHalfBufferInParallel(output_fp32_data, output_data, count, tp);

  return Status::OK();
}

// Registers the CPU QuickGelu kernel (kMSDomain, version 1) for a single tensor element type.
// `elem_type` selects both the "T" type constraint and the QuickGelu<> instantiation.
#define REGISTER_QUICKGELU_KERNEL(elem_type)                                            \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                                        \
      QuickGelu,                                                                        \
      kMSDomain,                                                                        \
      1,                                                                                \
      elem_type,                                                                        \
      kCpuExecutionProvider,                                                            \
      KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<elem_type>()), \
      QuickGelu<elem_type>);

// Element types with a CPU QuickGelu kernel.
REGISTER_QUICKGELU_KERNEL(float);
REGISTER_QUICKGELU_KERNEL(MLFloat16);

} // namespace contrib
} // namespace onnxruntime
6 changes: 4 additions & 2 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BiasG
// Forward declarations of CPU contrib kernel classes, so they can be named in the
// BuildKernelCreateInfo<> entries of RegisterCpuContribKernels().
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGramRepeatBlock);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector);
// NOTE(review): the untyped QuickGelu declaration below appears to be the removed side of the
// diff; the typed float / MLFloat16 declarations that follow replace it. Confirm it is gone.
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGelu);
// QuickGelu is declared once per supported element type.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention);

// ******** Start: Quantization ******************* //
Expand Down Expand Up @@ -374,7 +375,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGramRepeatBlock)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuickGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention)>,
// These ops were experimental ops in onnx domain which have been removed now. We add them here as
// contrib ops to main backward compatibility
Expand Down
45 changes: 45 additions & 0 deletions onnxruntime/test/contrib_ops/activation_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,5 +152,50 @@
}
}

// Verifies the MLFloat16 CPU QuickGelu kernel against an fp32 reference for several alpha values
// (including alpha == 1 and a negative alpha).
TEST_F(ActivationOpTest, QuickGelu_fp16) {
  // Use enough elements to cross the 4096-element chunk boundary used by the
  // QuickGelu<MLFloat16>::Compute() specialization. 8205 = 2 * 4096 + 13 exercises
  // the multi-task path as well as a final partial (tail) chunk.
  constexpr int64_t element_count = 2 * 4096 + 13;

  // Deliberately not named `input_values`: that name hides a member of the ActivationOpTest
  // fixture, which is reported as a warning-as-error on several CI builds.
  std::vector<float> x_values;
  x_values.reserve(static_cast<size_t>(element_count));
  // Seed with corner values, then fill the remainder with a varied ramp.
  const std::vector<float> seed_values{-1.0f, 0.0f, 1.0f, 2.5f, -2.5f, 5.0f, -5.0f, 0.3f};
  x_values.insert(x_values.end(), seed_values.begin(), seed_values.end());
  for (int64_t i = static_cast<int64_t>(seed_values.size()); i < element_count; ++i) {
    // Range [-6, 6] in steps of 0.1 to cover both saturation tails and the linear region.
    x_values.push_back(static_cast<float>((i % 121) - 60) * 0.1f);
  }
  std::vector<int64_t> dims{static_cast<int64_t>(x_values.size())};

  // fp32 reference: x * sigmoid(alpha * x), with the sigmoid evaluated on |alpha * x| so that
  // std::exp never receives a large positive argument.
  auto quick_gelu = [](float x, float alpha) {
    auto tmp = x * alpha;
    auto y = 1.f / (1.f + std::exp(-std::abs(tmp)));  // safe sigmoid
    y = tmp >= 0 ? y : 1 - y;
    return x * y;
  };

  for (float alpha : {1.702f, 1.0f, -1.702f}) {
    std::vector<MLFloat16> input_fp16;
    std::vector<MLFloat16> output_fp16;
    input_fp16.reserve(x_values.size());
    output_fp16.reserve(x_values.size());
    for (float x : x_values) {
      const MLFloat16 x_fp16(x);
      input_fp16.push_back(x_fp16);
      // Evaluate the reference on the fp16-rounded value the kernel actually receives, so that
      // input quantization error does not eat into the output tolerance.
      output_fp16.push_back(MLFloat16(quick_gelu(x_fp16.ToFloat(), alpha)));
    }

    OpTester test("QuickGelu", 1, kMSDomain);
    test.AddAttribute("alpha", alpha);
    test.AddInput<MLFloat16>("X", dims, input_fp16);
    test.AddOutput<MLFloat16>("Y", dims, output_fp16);
    // Relax tolerance because the reference is computed in fp32.
    test.SetOutputTolerance(0.005f);
    // Run on the CPU execution provider only; this test targets the CPU kernel.
    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
    execution_providers.push_back(DefaultCpuExecutionProvider());
    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
  }
}

} // namespace test
} // namespace onnxruntime
Loading