Skip to content

Commit 5842144

Browse files
dakenfguschmue
andauthored
[js/web] JSEP Gemm for opset 13 (microsoft#16936)
### Description Added JSEP Gemm registration for opset 13. It was falling back to CPU provider as CPU has it for 13 --------- Co-authored-by: Guenther Schmuelling <[email protected]>
1 parent edac3ef commit 5842144

File tree

3 files changed

+14
-4
lines changed

3 files changed

+14
-4
lines changed

js/web/docs/webgpu-operators.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ Do not modify directly.*
3838
| Floor | ai.onnx(6-12,13+) | |
3939
| Gather | ai.onnx(1-10,11-12,13+) | |
4040
| Gelu | com.microsoft(1+) | |
41-
| Gemm | ai.onnx(7-8,9-10,11+) | |
41+
| Gemm | ai.onnx(7-8,9-10,11-12,13+) | |
4242
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
4343
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
4444
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |

onnxruntime/core/providers/js/js_execution_provider.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnn
231231
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose);
232232
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, float, Gemm);
233233
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, float, Gemm);
234-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Gemm);
234+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Gemm);
235+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Gemm);
235236
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul);
236237
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul);
237238

@@ -464,7 +465,8 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
464465
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose)>,
465466
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, float, Gemm)>,
466467
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, float, Gemm)>,
467-
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Gemm)>,
468+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Gemm)>,
469+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Gemm)>,
468470
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul)>,
469471
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul)>,
470472

onnxruntime/core/providers/js/operators/gemm.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,15 @@ namespace js {
1212
ONNX_OPERATOR_TYPED_KERNEL_EX( \
1313
Gemm, \
1414
kOnnxDomain, \
15-
11, \
15+
13, \
16+
T, \
17+
kJsExecutionProvider, \
18+
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
19+
Gemm<T>); \
20+
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
21+
Gemm, \
22+
kOnnxDomain, \
23+
11, 12, \
1624
T, \
1725
kJsExecutionProvider, \
1826
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \

0 commit comments

Comments
 (0)