Skip to content

Commit f895279

Browse files
authored
fix the problems introduced in previous PRs in operator registration (onnx#1878)
* fix the problems introduced in previous PRs in operator registration * update the doc * fix the doc
1 parent f6f8065 commit f895279

File tree

5 files changed

+57
-67
lines changed

5 files changed

+57
-67
lines changed

docs/Changelog.md

+45
Original file line numberDiff line numberDiff line change
@@ -9486,6 +9486,51 @@ This version of the operator has been available since version 10 of the default
94869486
<dd>Constrain input and output types to float tensors.</dd>
94879487
</dl>
94889488

9489+
### <a name="Dropout-10"></a>**Dropout-10**</a>
9490+
9491+
Dropout takes one input floating tensor and produces two tensor outputs,
9492+
output (floating tensor) and mask (`Tensor<bool>`). Depending on whether it is
9493+
in test mode or not, the output Y will either be a random dropout, or a simple
9494+
copy of the input. Note that our implementation of Dropout does scaling in
9495+
the training phase, so during testing nothing needs to be done.
9496+
This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted.
9497+
9498+
#### Version
9499+
9500+
This version of the operator has been available since version 10 of the default ONNX operator set.
9501+
9502+
#### Attributes
9503+
9504+
<dl>
9505+
<dt><tt>ratio</tt> : float (default is 0.5)</dt>
9506+
<dd>The ratio of random dropout</dd>
9507+
</dl>
9508+
9509+
#### Inputs
9510+
9511+
<dl>
9512+
<dt><tt>data</tt> : T</dt>
9513+
<dd>The input data as Tensor.</dd>
9514+
</dl>
9515+
9516+
#### Outputs (1 - 2)
9517+
9518+
<dl>
9519+
<dt><tt>output</tt> : T</dt>
9520+
<dd>The output.</dd>
9521+
<dt><tt>mask</tt> (optional) : T1</dt>
9522+
<dd>The output mask.</dd>
9523+
</dl>
9524+
9525+
#### Type Constraints
9526+
9527+
<dl>
9528+
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double)</dt>
9529+
<dd>Constrain input and output types to float tensors.</dd>
9530+
<dt><tt>T1</tt> : tensor(bool)</dt>
9531+
<dd>Constrain output mask types to boolean tensors.</dd>
9532+
</dl>
9533+
94899534
### <a name="MaxPool-10"></a>**MaxPool-10**</a>
94909535

94919536
MaxPool consumes an input tensor X and applies max pooling across

docs/Operators.md

+8-6
Original file line numberDiff line numberDiff line change
@@ -3093,18 +3093,18 @@ expect(node, inputs=[x, y], outputs=[z],
30933093

30943094
### <a name="Dropout"></a><a name="dropout">**Dropout**</a>
30953095

3096-
Dropout takes one input data (Tensor<float>) and produces two Tensor outputs,
3097-
output (Tensor<float>) and mask (Tensor<bool>). Depending on whether it is in
3098-
test mode or not, the output Y will either be a random dropout, or a simple
3096+
Dropout takes one input floating tensor and produces two tensor outputs,
3097+
output (floating tensor) and mask (`Tensor<bool>`). Depending on whether it is
3098+
in test mode or not, the output Y will either be a random dropout, or a simple
30993099
copy of the input. Note that our implementation of Dropout does scaling in
31003100
the training phase, so during testing nothing needs to be done.
31013101
This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted.
31023102

31033103
#### Version
31043104

3105-
This version of the operator has been available since version 7 of the default ONNX operator set.
3105+
This version of the operator has been available since version 10 of the default ONNX operator set.
31063106

3107-
Other versions of this operator: <a href="Changelog.md#Dropout-1">Dropout-1</a>, <a href="Changelog.md#Dropout-6">Dropout-6</a>
3107+
Other versions of this operator: <a href="Changelog.md#Dropout-1">Dropout-1</a>, <a href="Changelog.md#Dropout-6">Dropout-6</a>, <a href="Changelog.md#Dropout-7">Dropout-7</a>
31083108

31093109
#### Attributes
31103110

@@ -3125,7 +3125,7 @@ Other versions of this operator: <a href="Changelog.md#Dropout-1">Dropout-1</a>,
31253125
<dl>
31263126
<dt><tt>output</tt> : T</dt>
31273127
<dd>The output.</dd>
3128-
<dt><tt>mask</tt> (optional) : T</dt>
3128+
<dt><tt>mask</tt> (optional) : T1</dt>
31293129
<dd>The output mask.</dd>
31303130
</dl>
31313131

@@ -3134,6 +3134,8 @@ Other versions of this operator: <a href="Changelog.md#Dropout-1">Dropout-1</a>,
31343134
<dl>
31353135
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double)</dt>
31363136
<dd>Constrain input and output types to float tensors.</dd>
3137+
<dt><tt>T1</tt> : tensor(bool)</dt>
3138+
<dd>Constrain output mask types to boolean tensors.</dd>
31373139
</dl>
31383140

31393141

onnx/defs/experiments/defs.cc

-59
Original file line numberDiff line numberDiff line change
@@ -149,63 +149,4 @@ ONNX_OPERATOR_SET_SCHEMA(
149149
"tensor(double)"},
150150
"Constrain output types to bool, int32, int64, float16, float, double tensors."));
151151

152-
static const char* ImageScaler_ver1_doc =
153-
R"DOC(Scale and bias the input image. Bias values are stored in
154-
the same ordering as the image pixel format.)DOC";
155-
156-
ONNX_OPERATOR_SET_SCHEMA(
157-
ImageScaler,
158-
1,
159-
OpSchema()
160-
.SetSupportLevel(SupportType::EXPERIMENTAL)
161-
.SetDoc(ImageScaler_ver1_doc)
162-
.Attr(
163-
"bias",
164-
"Bias applied to each channel, same size as C.",
165-
AttributeProto::FLOATS,
166-
OPTIONAL)
167-
.Attr(
168-
"scale",
169-
"The scale to apply.",
170-
AttributeProto::FLOAT,
171-
1.0f)
172-
.Input(0, "input", "Input tensor of shape [N,C,H,W]", "T")
173-
.Output(0, "output", "Result, has same shape and type as input", "T")
174-
.TypeConstraint(
175-
"T",
176-
{"tensor(float16)", "tensor(float)", "tensor(double)"},
177-
"Constrain input and output types to float tensors.")
178-
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
179-
180-
static const char* Crop_ver1_doc =
181-
R"DOC(Crop and image to the specified spatial dimensions. If scale is given,
182-
then optionally start the crop offset by the left/top border amounts.
183-
If scale is not provided, crop the borders as provided.)DOC";
184-
185-
ONNX_OPERATOR_SET_SCHEMA(
186-
Crop,
187-
1,
188-
OpSchema()
189-
.SetSupportLevel(SupportType::EXPERIMENTAL)
190-
.SetDoc(Crop_ver1_doc)
191-
.Attr(
192-
"border",
193-
"A 1-D values of (leftBorder, topBorder, rightBorder, bottomBorder).",
194-
AttributeProto::INTS,
195-
OPTIONAL)
196-
.Attr(
197-
"scale",
198-
"A 1-D values of (height, width).",
199-
AttributeProto::INTS,
200-
OPTIONAL)
201-
.Input(0, "input", "Input tensor of shape [N,C,H,W]", "T")
202-
.Output(
203-
0,
204-
"output",
205-
"Result, has same type as input, with H and W dimensions reduced.",
206-
"T")
207-
.TypeConstraint(
208-
"T",
209-
{"tensor(float16)", "tensor(float)", "tensor(double)"},
210-
"Constrain input and output types to float tensors."));
211152
} // namespace ONNX_NAMESPACE

onnx/defs/operator_sets.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 10, MaxPool);
530530
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 10, AveragePool);
531531
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 10, Slice);
532532
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 10, ThresholdedRelu);
533+
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 10, Dropout);
533534

534535
// Iterate over schema from ai.onnx version 10
535536
class OpSet_Onnx_ver10 {
@@ -544,11 +545,13 @@ class OpSet_Onnx_ver10 {
544545
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
545546
Onnx, 10, MaxPool)>());
546547
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
547-
Onnx, 10, AveragePool)>());
548+
Onnx, 10, AveragePool)>());
548549
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
549550
Onnx, 10, Slice)>());
550551
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
551552
Onnx, 10, ThresholdedRelu)>());
553+
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
554+
Onnx, 10, Dropout)>());
552555
}
553556
};
554557

onnx/defs/schema.h

-1
Original file line numberDiff line numberDiff line change
@@ -715,7 +715,6 @@ class OpSchemaRegistry final : public ISchemaRegistry {
715715
fail_schema(err.str());
716716
}
717717

718-
719718
m[op_name][op_domain].insert(std::pair<int, OpSchema&&>(ver, std::move(op_schema)));
720719

721720
} catch (const std::exception& e) {

0 commit comments

Comments
 (0)