Skip to content
This repository was archived by the owner on Nov 16, 2023. It is now read-only.

Commit 563dbd6

Browse files
authored
lib: update operator resolve (#108)
1 parent 0432a35 commit 563dbd6

16 files changed

+466
-432
lines changed

lib/backend.ts

+4-4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import {Graph} from './graph';
55
import {Operator} from './operators';
6+
import {OpSet} from './opset';
67
import {Session} from './session';
78

89
export interface InferenceHandler {
@@ -30,12 +31,11 @@ export interface SessionHandler {
3031
dispose(): void;
3132

3233
/**
33-
* Resolves the operator from the name; backend specific
34+
* Resolves the operator from the name and opset version; backend specific
3435
* @param node
35-
* @param domain
36-
* @param version
36+
* @param opsets
3737
*/
38-
resolve(node: Graph.Node, domain: string, version: number): Operator;
38+
resolve(node: Graph.Node, opsets: ReadonlyArray<OpSet>): Operator;
3939
/**
4040
* This method let's the sessionHandler know that the graph initialization is complete
4141
* @param graph the completely initialized graph

lib/backends/cpu/op-resolve-rules.ts

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT license.
3+
4+
import {FLOAT_TYPES, NUMBER_TYPES} from '../../operators';
5+
import {OpSet} from '../../opset';
6+
7+
import {CpuArgMax} from './ops/argMax';
8+
import {CpuBatchNormalization} from './ops/batch-normalization';
9+
import {CpuBinaryOp} from './ops/binary-op';
10+
import {CpuConcat} from './ops/concat';
11+
import {CpuConv} from './ops/conv';
12+
import {CpuDropout} from './ops/dropout';
13+
import {CpuFlatten} from './ops/flatten';
14+
import {CpuGather} from './ops/gather';
15+
import {CpuGemm} from './ops/gemm';
16+
import {CpuImageScaler} from './ops/image-scaler';
17+
import {CpuInstanceNormalization} from './ops/instance-normalization';
18+
import {CpuLrn} from './ops/lrn';
19+
import {CpuMatMul} from './ops/matmul';
20+
import {CpuAveragePool, CpuGlobalAveragePool, CpuGlobalMaxPool, CpuMaxPool} from './ops/pool';
21+
import * as cpuReduce from './ops/reduce';
22+
import {CpuReshape} from './ops/reshape';
23+
import {CpuSlice} from './ops/slice';
24+
import {CpuSoftmax} from './ops/softmax';
25+
import {CpuSqueeze} from './ops/squeeze';
26+
import {CpuSum} from './ops/sum';
27+
import {CpuTile} from './ops/tile';
28+
import {CpuTranspose} from './ops/transpose';
29+
import * as unaryOps from './ops/unary-op';
30+
import {CpuUnsqueeze} from './ops/unsqueeze';
31+
32+
export const CPU_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
33+
['Abs', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.abs)],
34+
['Acos', '', '7+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.acos)],
35+
['Add', '', '7+', () => new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 + e2))],
36+
['And', '', '7+', () => new CpuBinaryOp(['bool'], (e1, e2) => (e1 && e2))],
37+
['ArgMax', '', '1+', () => new CpuArgMax()],
38+
['Asin', '', '7+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.asin)],
39+
['Atan', '', '7+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.atan)],
40+
['AveragePool', '', '7+', () => new CpuAveragePool()], // TODO: support new attributes for AveragePool-10
41+
['BatchNormalization', '', '7+', () => new CpuBatchNormalization()],
42+
['Ceil', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.ceil)],
43+
['Clip', '', '6+', () => new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.clip)],
44+
['Concat', '', '4+', () => new CpuConcat()],
45+
['Conv', '', '1+', () => new CpuConv()],
46+
['Cos', '', '7+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.cos)],
47+
['Div', '', '7+', () => new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 / e2))],
48+
['Dropout', '', '7+', () => new CpuDropout()],
49+
['Elu', '', '6+', () => new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.elu)],
50+
['Exp', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.exp)],
51+
['Flatten', '', '1+', () => new CpuFlatten()],
52+
['Floor', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.floor)],
53+
['Gather', '', '1+', () => new CpuGather()],
54+
['Gemm', '', '7+', () => new CpuGemm()],
55+
['GlobalAveragePool', '', '1+', () => new CpuGlobalAveragePool()],
56+
['GlobalMaxPool', '', '1+', () => new CpuGlobalMaxPool()],
57+
['ImageScaler', '', '1+', () => new CpuImageScaler()],
58+
['InstanceNormalization', '', '6+', () => new CpuInstanceNormalization()],
59+
['LeakyRelu', '', '6+', () => new unaryOps.CpuUnaryOp(FLOAT_TYPES, unaryOps.leakyRelu)],
60+
['Log', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.log)],
61+
['LRN', '', '1+', () => new CpuLrn()],
62+
['MatMul', '', '1+', () => new CpuMatMul()],
63+
['MaxPool', '', '1+', () => new CpuMaxPool()], // TODO: support new attributes for MaxPool-8 and MaxPool-10
64+
['Mul', '', '7+', () => new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 * e2))],
65+
['Neg', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.neg)],
66+
['Or', '', '7+', () => new CpuBinaryOp(['bool'], (e1, e2) => (e1 || e2))],
67+
['PRelu', '', '7+', () => new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 >= 0 ? e1 : e1 * e2))],
68+
['ReduceLogSum', '', '1+', () => new cpuReduce.CpuReduceLogSum()],
69+
['ReduceMax', '', '1+', () => new cpuReduce.CpuReduceMax()],
70+
['ReduceMean', '', '1+', () => new cpuReduce.CpuReduceMean()],
71+
['ReduceMin', '', '1+', () => new cpuReduce.CpuReduceMin()],
72+
['ReduceProd', '', '1+', () => new cpuReduce.CpuReduceProd()],
73+
['ReduceSum', '', '1+', () => new cpuReduce.CpuReduceSum()],
74+
['ReduceSumSquare', '', '1+', () => new cpuReduce.CpuReduceSumSquare()],
75+
['Relu', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.relu)],
76+
['Reshape', '', '5+', () => new CpuReshape()],
77+
['Sigmoid', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.sigmoid)],
78+
['Sin', '', '7+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.sin)],
79+
['Slice', '', '1+', () => new CpuSlice()],
80+
['Softmax', '', '1+', () => new CpuSoftmax()],
81+
['Sqrt', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.sqrt)],
82+
['Squeeze', '', '1+', () => new CpuSqueeze()],
83+
['Sub', '', '7+', () => new CpuBinaryOp(NUMBER_TYPES, (e1, e2) => (e1 - e2))],
84+
['Sum', '', '6+', () => new CpuSum()], // TODO: support multidirectional broadcast for Sum-8
85+
['Tan', '', '7+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.tan)],
86+
['Tanh', '', '6+', () => new unaryOps.CpuUnaryOp(NUMBER_TYPES, unaryOps.tanh)],
87+
['Tile', '', '6+', () => new CpuTile()],
88+
['Transpose', '', '1+', () => new CpuTranspose()],
89+
['Unsqueeze', '', '1+', () => new CpuUnsqueeze()],
90+
['Xor', '', '7+', () => new CpuBinaryOp(['bool'], (e1, e2) => (e1 ^ e2))],
91+
];

lib/backends/cpu/ops-resolve.ts

-167
This file was deleted.

lib/backends/cpu/session-handler.ts

+6-5
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
import {Backend, InferenceHandler, SessionHandler} from '../../backend';
55
import {Graph} from '../../graph';
66
import {Operator} from '../../operators';
7+
import {OpSet, resolveOperator} from '../../opset';
78
import {Session} from '../../session';
89

910
import {CpuInferenceHandler} from './inference-handler';
10-
import {resolve} from './ops-resolve';
11+
import {CPU_OP_RESOLVE_RULES} from './op-resolve-rules';
1112

1213
export class CpuSessionHandler implements SessionHandler {
1314
constructor(readonly backend: Backend, readonly context: Session.Context) {}
@@ -18,9 +19,9 @@ export class CpuSessionHandler implements SessionHandler {
1819

1920
dispose(): void {}
2021

21-
resolve(node: Graph.Node, domain: string, version: number): Operator {
22-
// We have kept the ops resolve logic separately to be leveraged by other components (if needed)
23-
// This is valid only if there is no statefulness associated with the op resolution logic (which is currently true)
24-
return resolve(node, domain, version);
22+
resolve(node: Graph.Node, opsets: ReadonlyArray<OpSet>): Operator {
23+
const op = resolveOperator(node, opsets, CPU_OP_RESOLVE_RULES);
24+
op.initialize(node.attributes);
25+
return op;
2526
}
2627
}

lib/backends/wasm/op-resolve-rules.ts

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT license.
3+
4+
import {OpSet} from '../../opset';
5+
6+
import {WasmBatchNormalization} from './ops/batch-normalization';
7+
import {WasmBinaryOp} from './ops/binary-op';
8+
import {WasmClip} from './ops/clip';
9+
import {WasmConv} from './ops/conv';
10+
import {WasmGemm} from './ops/gemm';
11+
import {WasmInstanceNormalization} from './ops/instance-normalization';
12+
import {WasmMatMul} from './ops/matmul';
13+
import {WasmAveragePool, WasmGlobalAveragePool, WasmGlobalMaxPool, WasmMaxPool} from './ops/pool';
14+
import {WasmSoftmax} from './ops/softmax';
15+
import {WasmSum} from './ops/sum';
16+
17+
export const WASM_OP_RESOLVE_RULES: ReadonlyArray<OpSet.ResolveRule> = [
18+
['Add', '', '7+', () => new WasmBinaryOp(['float32'], 'Add')],
19+
['And', '', '7+', () => new WasmBinaryOp(['bool'], 'And')],
20+
['AveragePool', '', '7+', () => new WasmAveragePool()], // TODO: support new attributes for AveragePool-10
21+
['BatchNormalization', '', '7+', () => new WasmBatchNormalization()],
22+
['Clip', '', '6+', () => new WasmClip()],
23+
['Conv', '', '1+', () => new WasmConv()],
24+
['Div', '', '7+', () => new WasmBinaryOp(['float32'], 'Div')],
25+
['Gemm', '', '7+', () => new WasmGemm()],
26+
['GlobalAveragePool', '', '1+', () => new WasmGlobalAveragePool()],
27+
['GlobalMaxPool', '', '1+', () => new WasmGlobalMaxPool()],
28+
['InstanceNormalization', '', '6+', () => new WasmInstanceNormalization()],
29+
['MatMul', '', '1+', () => new WasmMatMul()],
30+
['MaxPool', '', '1+', () => new WasmMaxPool()], // TODO: support new attributes for MaxPool-8 and MaxPool-10
31+
['Mul', '', '7+', () => new WasmBinaryOp(['float32'], 'Mul')],
32+
['Or', '', '7+', () => new WasmBinaryOp(['bool'], 'Or')],
33+
['PRelu', '', '7+', () => new WasmBinaryOp(['float32'], 'PRelu')],
34+
['Softmax', '', '1+', () => new WasmSoftmax()],
35+
['Sub', '', '7+', () => new WasmBinaryOp(['float32'], 'Sub')],
36+
['Sum', '', '6+', () => new WasmSum()], // TODO: support multidirectional broadcast for Sum-8
37+
['Xor', '', '7+', () => new WasmBinaryOp(['bool'], 'Xor')],
38+
];

0 commit comments

Comments
 (0)