Skip to content

Commit a9be6b7

Browse files
authored
[webgpu] Implement Split operator (#23198)
Test: onnxruntime_test_all.exe --gtest_filter=SplitOperatorTest.* ### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent 377165f commit a9be6b7

File tree

3 files changed

+229
-5
lines changed

3 files changed

+229
-5
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/webgpu/tensor/split.h"
5+
#include "core/providers/webgpu/shader_helper.h"
6+
#include "core/providers/webgpu/webgpu_supported_types.h"
7+
8+
namespace onnxruntime {
9+
namespace webgpu {
10+
11+
namespace {
12+
13+
// Helper function to calculate the output index based on the input index and the sizes of the splits.
14+
void CalculateOutputIndex(std::ostream& os, size_t output_count) {
15+
os << "fn calculate_output_index(index: u32) -> u32 {\n"
16+
<< " for (var i: u32 = 0u; i < " << output_count << "u; i += 1u ) {\n"
17+
<< " if (index < " << GetElementAt("uniforms.sizes_in_split_axis", "i", output_count) << ") {\n"
18+
<< " return i;\n"
19+
<< " }\n"
20+
<< " }\n"
21+
<< " return " << output_count << "u;\n"
22+
<< "}\n";
23+
}
24+
25+
// Helper function to write the buffer data for each output.
26+
void WriteBufferData(std::ostream& os, const ShaderVariableHelper& input,
27+
gsl::span<const ShaderVariableHelper*> outputs) {
28+
os << "fn write_buffer_data(output_number: u32, global_idx: u32, indices: output_0_indices_t) {\n";
29+
for (size_t i = 0; i < outputs.size(); ++i) {
30+
const auto buffer_write = outputs[i]->SetByIndices("indices", input.GetByOffset("global_idx"));
31+
if (outputs.size() == 1) {
32+
os << buffer_write;
33+
} else if (i == 0) {
34+
os << " if (output_number == 0u) {\n"
35+
<< " " << buffer_write << "\n";
36+
} else if (i == outputs.size() - 1) {
37+
os << " } else {\n"
38+
<< " " << buffer_write << "\n";
39+
} else {
40+
os << " } else if (output_number == " << i << "u) {\n"
41+
<< " " << buffer_write << "\n";
42+
}
43+
}
44+
os << " }\n"
45+
<< "}\n";
46+
}
47+
48+
} // namespace
49+
50+
Status SplitProgram::GenerateShaderCode(ShaderHelper& shader) const {
51+
const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
52+
53+
size_t output_count = Outputs().size();
54+
std::vector<const ShaderVariableHelper*> outputs;
55+
outputs.reserve(output_count);
56+
for (size_t i = 0; i < output_count; ++i) {
57+
outputs.push_back(
58+
&shader.AddOutput("output_" + std::to_string(i), ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias));
59+
}
60+
61+
// Add implementation of fn calculate_output_index.
62+
CalculateOutputIndex(shader.AdditionalImplementation(), output_count);
63+
// Add implementation of fn write_buffer_data.
64+
WriteBufferData(shader.AdditionalImplementation(), input, outputs);
65+
66+
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.input_size")
67+
<< " var indices = " << input.OffsetToIndices("global_idx") << ";\n"
68+
<< " var index = indices[" << axis_ << "];\n"
69+
<< " let output_number = calculate_output_index(index);\n"
70+
<< " if (output_number != 0u) {\n"
71+
<< " index -= uniforms.sizes_in_split_axis[output_number - 1u];\n"
72+
<< " indices[" << axis_ << "] = index;\n"
73+
<< " }\n"
74+
<< " write_buffer_data(output_number, global_idx, indices);\n";
75+
76+
return Status::OK();
77+
}
78+
79+
Status Split::ComputeInternal(ComputeContext& context) const {
80+
const Tensor* input = context.Input<Tensor>(0);
81+
auto& input_shape = input->Shape();
82+
auto num_outputs = context.OutputCount();
83+
84+
int64_t axis = axis_;
85+
std::vector<int64_t> split_sizes;
86+
87+
split_sizes.assign(split_sizes_.begin(), split_sizes_.end());
88+
// Compute split_sizes from the 'split' input tensor.
89+
if (split_sizes_.size() == 0 && context.InputCount() > 1) {
90+
const Tensor* split_tensor = context.Input<Tensor>(1);
91+
// Check if split_tensor is valid.
92+
if (split_tensor != nullptr) {
93+
ORT_ENFORCE(split_tensor->Shape().NumDimensions() == 1, "The split tensor must be a vector tensor.");
94+
// Get split_sizes from the input tensor.
95+
auto nDims = static_cast<size_t>(split_tensor->Shape()[0]);
96+
const auto* data = split_tensor->Data<int64_t>();
97+
split_sizes.assign(data, data + nDims);
98+
}
99+
}
100+
101+
// The variables below are not actually used in the current implementation.
102+
int before_dims = 0;
103+
int after_dims_including_split_axis = 0;
104+
int after_dims_excluding_split = 0;
105+
// This handles the case where the axis is negative. It also splits outputs evenly according to num_ouputs if
106+
// split_sizes is empty.
107+
ORT_RETURN_IF_ERROR(PrepareForCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis,
108+
after_dims_excluding_split, split_sizes));
109+
110+
SplitProgram program{gsl::narrow_cast<uint32_t>(axis)};
111+
program.AddInput({input, ProgramTensorMetadataDependency::TypeAndRank});
112+
113+
auto output_dimensions = input_shape.AsShapeVector();
114+
for (int i = 0; i < num_outputs; ++i) {
115+
// Update the size of dimension for axis we're splitting on.
116+
auto split_size = narrow<int>(split_sizes[i]);
117+
output_dimensions[narrow<size_t>(axis)] = split_size;
118+
119+
Tensor* output = context.Output(i, TensorShape{output_dimensions});
120+
program.AddOutput({output, ProgramTensorMetadataDependency::Rank});
121+
}
122+
123+
uint32_t input_size = gsl::narrow<uint32_t>(input_shape.Size());
124+
// Early return if the input tensor is empty.
125+
if (input_size == 0) {
126+
return Status::OK();
127+
}
128+
129+
uint32_t previous_sum = 0;
130+
std::vector<uint32_t> sizes_in_split_axis;
131+
// sizes_in_split_axis are the cumulative sizes of the splits in the split axis.
132+
for (auto split_size : split_sizes) {
133+
previous_sum += gsl::narrow<uint32_t>(split_size);
134+
sizes_in_split_axis.push_back(previous_sum);
135+
}
136+
137+
program
138+
.SetDispatchGroupSize((input_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
139+
.CacheHint(std::to_string(axis))
140+
.AddUniformVariables(
141+
{input_size, gsl::span<const uint32_t>(sizes_in_split_axis.data(), sizes_in_split_axis.size())});
142+
return context.RunProgram(program);
143+
}
144+
145+
#define WEBGPU_SPLIT_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \
146+
ONNX_OPERATOR_KERNEL_EX(OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \
147+
KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \
148+
KERNEL_CLASS);
149+
150+
#define WEBGPU_SPLIT_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \
151+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \
152+
KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \
153+
KERNEL_CLASS);
154+
155+
WEBGPU_SPLIT_VERSIONED_KERNEL(Split, 1, 1, Split_1, WebGpuSupportedNumberTypes())
156+
WEBGPU_SPLIT_VERSIONED_KERNEL(Split, 2, 10, Split_2_10, WebGpuSupportedNumberTypes())
157+
WEBGPU_SPLIT_VERSIONED_KERNEL(Split, 11, 12, Split_11_12, WebGpuSupportedNumberTypes())
158+
WEBGPU_SPLIT_VERSIONED_KERNEL(Split, 13, 17, Split_13_17, WebGpuSupportedNumberTypes())
159+
WEBGPU_SPLIT_KERNEL(Split, 18, Split_18, WebGpuSupportedNumberTypes());
160+
161+
} // namespace webgpu
162+
} // namespace onnxruntime
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/providers/webgpu/program.h"
7+
#include "core/providers/webgpu/webgpu_kernel.h"
8+
#include "core/providers/common.h"
9+
#include "core/providers/cpu/tensor/split.h"
10+
11+
namespace onnxruntime {
12+
namespace webgpu {
13+
14+
class SplitProgram final : public Program<SplitProgram> {
15+
public:
16+
SplitProgram(const uint32_t axis) : Program{"Split"}, axis_{axis} {}
17+
18+
Status GenerateShaderCode(ShaderHelper& sh) const override;
19+
20+
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32},
21+
{"sizes_in_split_axis", ProgramUniformVariableDataType::Uint32});
22+
23+
private:
24+
uint32_t axis_;
25+
};
26+
27+
class Split : public WebGpuKernel, public SplitBase {
28+
public:
29+
Split(const OpKernelInfo& info, uint32_t opset) : WebGpuKernel(info), SplitBase(info, opset) {}
30+
31+
protected:
32+
Status ComputeInternal(ComputeContext& context) const override;
33+
};
34+
35+
class Split_1 final : public Split {
36+
public:
37+
Split_1(const OpKernelInfo& info) : Split(info, 1) {}
38+
};
39+
40+
class Split_2_10 final : public Split {
41+
public:
42+
Split_2_10(const OpKernelInfo& info) : Split(info, 2) {}
43+
};
44+
45+
class Split_11_12 final : public Split {
46+
public:
47+
Split_11_12(const OpKernelInfo& info) : Split(info, 11) {}
48+
};
49+
50+
class Split_13_17 final : public Split {
51+
public:
52+
Split_13_17(const OpKernelInfo& info) : Split(info, 13) {}
53+
};
54+
55+
class Split_18 final : public Split {
56+
public:
57+
Split_18(const OpKernelInfo& info) : Split(info, 18) {}
58+
};
59+
60+
} // namespace webgpu
61+
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

+6-5
Original file line numberDiff line numberDiff line change
@@ -637,11 +637,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
637637
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Concat)>,
638638
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Concat)>,
639639

640-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 1, Split)>,
641-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 2, 10, Split)>,
642-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Split)>,
643-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Split)>,
644-
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, Split)>,
640+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 1, Split)>,
641+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 2, 10, Split)>,
642+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Split)>,
643+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Split)>,
644+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, Split)>,
645+
645646
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 8, 12, Expand)>,
646647
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Expand)>,
647648

0 commit comments

Comments
 (0)