Skip to content

Commit c1cf16e

Browse files
Conv node bug, cached state was incoherent (microsoft#10041)
* Moved the init earlier to keep the cache coherent * Move setting of w_desc later, and zero shape check later to catch all cacheable changes. * Add comment
1 parent f4b2d3a commit c1cf16e

File tree

1 file changed

+6
-4
lines changed
  • onnxruntime/core/providers/cuda/nn

1 file changed

+6
-4
lines changed

onnxruntime/core/providers/cuda/nn/conv.cc

+6-4
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,6 @@ Status Conv<T>::UpdateState(OpKernelContext* context, bool bias_expected) const
176176
s_.slice_axes = slice_axes;
177177

178178
s_.Y = context->Output(0, TensorShape(s_.y_dims));
179-
if (s_.Y->Shape().Size() == 0) {
180-
return Status::OK();
181-
}
182179
if (post_slicing_required) {
183180
// Post slicing needed. Create and fill in the Conv results in an intermediate buffer.
184181
s_.memory_for_cudnn_conv_results = GetScratchBuffer<void>(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size);
@@ -206,9 +203,14 @@ Status Conv<T>::UpdateState(OpKernelContext* context, bool bias_expected) const
206203
dilations.push_back(1);
207204
}
208205

209-
if (w_dims_changed) {
206+
if (w_dims_changed)
210207
ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType<CudaT>()));
208+
209+
// We must delay returning early until here so that the weight dims have been cached properly
210+
if (s_.Y->Shape().Size() == 0) {
211+
return Status::OK();
211212
}
213+
212214
ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
213215
ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
214216
ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,

0 commit comments

Comments
 (0)