Skip to content
Merged
3 changes: 3 additions & 0 deletions include/onnxruntime/core/graph/node_arg.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ class NodeArg {
/** Stores `value` as the inferred shape scalar value of this NodeArg, overwriting any earlier one. */
void SetInferredShapeScalarValue(int64_t value) noexcept {
  inferred_scalar_value_ = value;
}

/** Discards the inferred shape scalar value held by this NodeArg, if there is one. */
void ClearInferredShapeScalarValue() noexcept {
  inferred_scalar_value_.reset();
}

/** Gets a flag indicating whether this NodeArg exists or not.
Optional inputs are allowed in ONNX and an empty #Name represents a non-existent input argument. */
bool Exists() const noexcept;
Expand Down
15 changes: 11 additions & 4 deletions onnxruntime/core/graph/data_propagation/add_op_data_propagation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "core/graph/node_arg.h"
#include "core/graph/onnx_protobuf.h"
#include "core/providers/common.h"
#include "core/graph/data_propagation/data_propagation_value_utils.h"

namespace onnxruntime {

Expand All @@ -20,10 +21,16 @@ Status AddOpDataPropagation::infer() {
return Status::OK();
}

if (input_0->GetInferredShapeScalarValue().has_value() && input_1->GetInferredShapeScalarValue().has_value()) {
output_def_.SetInferredShapeScalarValue(
input_0->GetInferredShapeScalarValue().value() +
input_1->GetInferredShapeScalarValue().value());
int64_t lhs = 0;
int64_t rhs = 0;
bool lhs_is_rank1 = false;
bool rhs_is_rank1 = false;
if (TryGetSinglePropagatedShapeValue(*input_0, lhs, lhs_is_rank1) &&
TryGetSinglePropagatedShapeValue(*input_1, rhs, rhs_is_rank1)) {
// Single-element operands may be carried as a rank-0 scalar or a rank-1 [1] value. Per ONNX
// broadcasting, the result is rank-1 if either operand is rank-1, otherwise a scalar; keep
// the propagated value's rank consistent with that so downstream consumers see the right rank.
SetSinglePropagatedShapeValue(output_def_, lhs + rhs, lhs_is_rank1 || rhs_is_rank1);
}

return Status::OK();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class AddOpDataPropagation : public CustomDataPropagationBase {
public:
AddOpDataPropagation(const Node& node,
NodeArg& output_def,
std::function<Status(const std::string&, TensorShapeVector&)> func,
GetInitializedInputValuesFunc func,
const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation,
const logging::Logger& logger) noexcept
: CustomDataPropagationBase(node, output_def, func, output_from_onnx_op_data_propagation, logger) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace onnxruntime {

std::unique_ptr<CustomDataPropagationBase> CreateCustomDataPropagation(const Node& node,
NodeArg& output_def,
std::function<Status(const std::string&, TensorShapeVector&)> func,
GetInitializedInputValuesFunc func,
const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation,
const logging::Logger& logger) {
int dim_size = 0;
Expand Down
21 changes: 18 additions & 3 deletions onnxruntime/core/graph/data_propagation/custom_data_propagation.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,21 @@

namespace onnxruntime {

/**
 * @brief Signature of the helper used to read a constant-initializer input's values during
 * data propagation.
 *
 * @param input_name   Name of the input to read.
 * @param input_values Receives the initializer's flattened int64 values (empty if the input is
 *                     not a constant initializer).
 * @param num_dims     Receives the rank (number of dimensions) of the initializer's shape, or is
 *                     left unchanged when the input is not a constant initializer. A value of 0
 *                     denotes a scalar (rank-0) initializer; this lets callers distinguish a 0-D
 *                     scalar from a rank-1 single-element initializer (both have one element).
 *                     Because the value is left unchanged in the non-initializer case, callers
 *                     should seed it with a sentinel (e.g. -1) before the call if they need to
 *                     detect that case.
 * @return Status of the read. NOTE(review): the conditions that produce a non-OK status are
 *         decided by the bound implementation, which is not visible here -- confirm there.
 */
using GetInitializedInputValuesFunc =
std::function<Status(const std::string& input_name, TensorShapeVector& input_values, int& num_dims)>;

Check warning on line 26 in onnxruntime/core/graph/data_propagation/custom_data_propagation.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/graph/data_propagation/custom_data_propagation.h:26: Add #include <string> for string [build/include_what_you_use] [4]

/**
* @class CustomDataPropagationBase
* Custom data propagation for the operator to help enhance shape inference.
Expand All @@ -27,7 +42,7 @@
protected:
CustomDataPropagationBase(const Node& node,
NodeArg& output_def,
std::function<Status(const std::string&, TensorShapeVector&)> func,
GetInitializedInputValuesFunc func,
const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation,
const logging::Logger& logger) noexcept
: node_(node),
Expand All @@ -38,7 +53,7 @@

const Node& node_;
NodeArg& output_def_;
std::function<Status(const std::string&, TensorShapeVector&)> get_initialized_input_values_func_;
GetInitializedInputValuesFunc get_initialized_input_values_func_;
const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation_;
const logging::Logger& logger_;
};
Expand Down Expand Up @@ -68,7 +83,7 @@
std::unique_ptr<CustomDataPropagationBase> CreateCustomDataPropagation(
const Node& node,
NodeArg& output_def,
std::function<Status(const std::string&, TensorShapeVector&)> func,
GetInitializedInputValuesFunc func,
const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation,
const logging::Logger& logger);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <cstdint>

#include "core/graph/node_arg.h"

namespace onnxruntime {

// Data propagation carries a small "shape value" in one of two non-interchangeable channels
// on a NodeArg: a rank-0 scalar (inferred_scalar_value_) or a rank>=1 list of values
// (inferred_shape_values_). The helpers below let custom data-propagation ops read and write a
// single-element value while preserving its rank (rank-0 scalar vs rank-1 [1]), so a producer
// and its consumers cannot silently disagree on rank (e.g. Gather feeding Mul feeding TopK).

// Extracts the one int64 shape value that data propagation attached to `input_def`, if any.
// Two sources are consulted in order: the rank-0 scalar value first, then a rank-1 value list
// that holds exactly one dimension with a concrete value. When a value is found it is written to
// `value`, `is_rank1` records which source supplied it (false = scalar, true = rank-1 [1]), and
// the function returns true. Otherwise it returns false and leaves both out-parameters untouched.
inline bool TryGetSinglePropagatedShapeValue(const NodeArg& input_def, int64_t& value, bool& is_rank1) {
  // The scalar source wins when both happen to be populated.
  const auto& scalar_value = input_def.GetInferredShapeScalarValue();
  if (scalar_value.has_value()) {
    is_rank1 = false;
    value = scalar_value.value();
    return true;
  }

  const auto& shape_values = input_def.GetInferredShapeValues();
  const bool has_single_concrete_dim = shape_values.has_value() &&
                                       shape_values->dim_size() == 1 &&
                                       shape_values->dim(0).has_dim_value();
  if (!has_single_concrete_dim) {
    return false;
  }

  is_rank1 = true;
  value = shape_values->dim(0).dim_value();
  return true;
}
Comment thread
titaiwangms marked this conversation as resolved.

// Writes one int64 shape value onto `output_def`, choosing the representation from `is_rank1`:
//   * false -> the rank-0 scalar value;
//   * true  -> a rank-1 value list with exactly one dimension (mirroring how
//              Graph::getInputData() rebuilds a TensorProto with dims=[1] from
//              inferred_shape_values_).
// Exactly one of the two representations is populated on return; the other is cleared. This keeps
// the scalar-first reader (TryGetSinglePropagatedShapeValue) and the values-first getInputData()
// in agreement on rank even when `output_def` previously carried a value in the other one.
inline void SetSinglePropagatedShapeValue(NodeArg& output_def, int64_t value, bool is_rank1) {
  auto& shape_values = output_def.GetMutableInferredShapeValues();

  if (is_rank1) {
    // Rebuild the value list from scratch so it holds only the new single dimension.
    if (!shape_values.has_value()) {
      shape_values.emplace();
    }
    shape_values->clear_dim();
    shape_values->add_dim()->set_dim_value(value);
    // Remove a leftover scalar, which the scalar-first reader would otherwise return instead.
    output_def.ClearInferredShapeScalarValue();
  } else {
    output_def.SetInferredShapeScalarValue(value);
    // Remove a leftover value list, which getInputData() would otherwise prefer over the scalar.
    shape_values.reset();
  }
}
Comment thread
titaiwangms marked this conversation as resolved.

} // namespace onnxruntime
16 changes: 12 additions & 4 deletions onnxruntime/core/graph/data_propagation/div_op_data_propagation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "core/graph/node_arg.h"
#include "core/graph/onnx_protobuf.h"
#include "core/providers/common.h"
#include "core/graph/data_propagation/data_propagation_value_utils.h"

namespace onnxruntime {

Expand All @@ -20,10 +21,17 @@ Status DivOpDataPropagation::infer() {
return Status::OK();
}

if (input_0->GetInferredShapeScalarValue().has_value() && input_1->GetInferredShapeScalarValue().has_value()) {
output_def_.SetInferredShapeScalarValue(
input_0->GetInferredShapeScalarValue().value() /
input_1->GetInferredShapeScalarValue().value());
int64_t lhs = 0;
int64_t rhs = 0;
bool lhs_is_rank1 = false;
bool rhs_is_rank1 = false;
if (TryGetSinglePropagatedShapeValue(*input_0, lhs, lhs_is_rank1) &&
TryGetSinglePropagatedShapeValue(*input_1, rhs, rhs_is_rank1) &&
rhs != 0) {
// Single-element operands may be carried as a rank-0 scalar or a rank-1 [1] value. Per ONNX
// broadcasting, the result is rank-1 if either operand is rank-1, otherwise a scalar; keep
// the propagated value's rank consistent with that so downstream consumers see the right rank.
SetSinglePropagatedShapeValue(output_def_, lhs / rhs, lhs_is_rank1 || rhs_is_rank1);
}

return Status::OK();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class DivOpDataPropagation : public CustomDataPropagationBase {
public:
DivOpDataPropagation(const Node& node,
NodeArg& output_def,
std::function<Status(const std::string&, TensorShapeVector&)> func,
GetInitializedInputValuesFunc func,
const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation,
const logging::Logger& logger) noexcept
: CustomDataPropagationBase(node, output_def, func, output_from_onnx_op_data_propagation, logger) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "core/graph/node_arg.h"
#include "core/graph/onnx_protobuf.h"
#include "core/providers/common.h"
#include "core/graph/data_propagation/data_propagation_value_utils.h"

namespace onnxruntime {

Expand Down Expand Up @@ -58,15 +59,41 @@ Status GatherOpDataPropagation::infer() {

ORT_TRY {
TensorShapeVector indices;
ORT_RETURN_IF_ERROR(get_initialized_input_values_func_(input_1->Name(), indices));
int indices_num_dims = -1;
ORT_RETURN_IF_ERROR(get_initialized_input_values_func_(input_1->Name(), indices, indices_num_dims));
if (indices.size() == 1) {
// Note: Index value is expected to be within bounds [-s, s-1] along axis of size s
auto index = static_cast<int32_t>(
HandleNegativeAxis(indices[0], tensor_shape_proto.dim_size()));

auto& dim = tensor_shape_proto.dim(index);
if (dim.has_dim_value()) {
output_def_.SetInferredShapeScalarValue(dim.dim_value());
// Gather output rank = data_rank - 1 + indices_rank. The "data" input here is the
// 1-D Shape output, so the output rank equals the indices rank. Route by that rank:
// * a 0-D scalar index -> a scalar output value,
// * a 1-D index -> a rank-1 [1] output value (so consumers that require a
// 1-D tensor, e.g. TopK's K input, still observe the correct rank),
// * a rank >= 2 index -> decline: leave the output value unset so the dimension
// stays symbolic, because the single-value channel cannot represent a rank >= 2
// Gather output and emitting a rank-1 value would fabricate a misleading rank.
//
// The index rank (indices_num_dims) is sourced canonically from the same constant
// initializer the index value came from. Reading that initializer is exactly what makes
// indices.size() == 1 reachable here, and it always reports a concrete rank
// (TensorShape::NumDimensions(), >= 0), so indices_num_dims is always >= 0 here and no
// unknown-rank (< 0) handling is needed.
const int effective_num_dims = indices_num_dims;

if (effective_num_dims == 0) {
// 0-D scalar index -> a scalar output value.
SetSinglePropagatedShapeValue(output_def_, dim.dim_value(), /*is_rank1=*/false);
} else if (effective_num_dims == 1) {
// 1-D index -> a rank-1 [1] output value.
SetSinglePropagatedShapeValue(output_def_, dim.dim_value(), /*is_rank1=*/true);
}
// rank >= 2 index: leave the output value unset so the dimension remains symbolic --
// the correct decline behavior, since the single-value channel cannot represent a
// rank >= 2 Gather output and there is no ONNX fallback once this custom propagator runs.
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class GatherOpDataPropagation : public CustomDataPropagationBase {
public:
GatherOpDataPropagation(const Node& node,
NodeArg& output_def,
std::function<Status(const std::string&, TensorShapeVector&)> func,
GetInitializedInputValuesFunc func,
const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation,
const logging::Logger& logger) noexcept
: CustomDataPropagationBase(node, output_def, func, output_from_onnx_op_data_propagation, logger) {}
Expand Down
15 changes: 11 additions & 4 deletions onnxruntime/core/graph/data_propagation/mul_op_data_propagation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "core/graph/node_arg.h"
#include "core/graph/onnx_protobuf.h"
#include "core/providers/common.h"
#include "core/graph/data_propagation/data_propagation_value_utils.h"

namespace onnxruntime {

Expand All @@ -20,10 +21,16 @@ Status MulOpDataPropagation::infer() {
return Status::OK();
}

if (input_0->GetInferredShapeScalarValue().has_value() && input_1->GetInferredShapeScalarValue().has_value()) {
output_def_.SetInferredShapeScalarValue(
input_0->GetInferredShapeScalarValue().value() *
input_1->GetInferredShapeScalarValue().value());
int64_t lhs = 0;
int64_t rhs = 0;
bool lhs_is_rank1 = false;
bool rhs_is_rank1 = false;
if (TryGetSinglePropagatedShapeValue(*input_0, lhs, lhs_is_rank1) &&
TryGetSinglePropagatedShapeValue(*input_1, rhs, rhs_is_rank1)) {
// Single-element operands may be carried as a rank-0 scalar or a rank-1 [1] value. Per ONNX
// broadcasting, the result is rank-1 if either operand is rank-1, otherwise a scalar; keep
// the propagated value's rank consistent with that so downstream consumers see the right rank.
SetSinglePropagatedShapeValue(output_def_, lhs * rhs, lhs_is_rank1 || rhs_is_rank1);
}

return Status::OK();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class MulOpDataPropagation : public CustomDataPropagationBase {
public:
MulOpDataPropagation(const Node& node,
NodeArg& output_def,
std::function<Status(const std::string&, TensorShapeVector&)> func,
GetInitializedInputValuesFunc func,
const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation,
const logging::Logger& logger) noexcept
: CustomDataPropagationBase(node, output_def, func, output_from_onnx_op_data_propagation, logger) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class SizeOpDataPropagation : public CustomDataPropagationBase {
public:
SizeOpDataPropagation(const Node& node,
NodeArg& output_def,
std::function<Status(const std::string&, TensorShapeVector&)> func,
GetInitializedInputValuesFunc func,
const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation,
const logging::Logger& logger) noexcept
: CustomDataPropagationBase(node, output_def, func, output_from_onnx_op_data_propagation, logger) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ Status SqueezeOpDataPropagation::infer() {
if (node_.InputDefs().size() > 1) {
const auto* input_1 = node_.InputDefs()[1];
ORT_TRY {
ORT_RETURN_IF_ERROR(get_initialized_input_values_func_(input_1->Name(), axes));
[[maybe_unused]] int axes_num_dims = -1;
ORT_RETURN_IF_ERROR(get_initialized_input_values_func_(input_1->Name(), axes, axes_num_dims));
}
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&]() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class SqueezeOpDataPropagation : public CustomDataPropagationBase {
public:
SqueezeOpDataPropagation(const Node& node,
NodeArg& output_def,
std::function<Status(const std::string&, TensorShapeVector&)> func,
GetInitializedInputValuesFunc func,
const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation,
const logging::Logger& logger) noexcept
: CustomDataPropagationBase(node, output_def, func, output_from_onnx_op_data_propagation, logger) {}
Expand Down
15 changes: 11 additions & 4 deletions onnxruntime/core/graph/data_propagation/sub_op_data_propagation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "core/graph/node_arg.h"
#include "core/graph/onnx_protobuf.h"
#include "core/providers/common.h"
#include "core/graph/data_propagation/data_propagation_value_utils.h"

namespace onnxruntime {

Expand All @@ -20,10 +21,16 @@ Status SubOpDataPropagation::infer() {
return Status::OK();
}

if (input_0->GetInferredShapeScalarValue().has_value() && input_1->GetInferredShapeScalarValue().has_value()) {
output_def_.SetInferredShapeScalarValue(
input_0->GetInferredShapeScalarValue().value() -
input_1->GetInferredShapeScalarValue().value());
int64_t lhs = 0;
int64_t rhs = 0;
bool lhs_is_rank1 = false;
bool rhs_is_rank1 = false;
if (TryGetSinglePropagatedShapeValue(*input_0, lhs, lhs_is_rank1) &&
TryGetSinglePropagatedShapeValue(*input_1, rhs, rhs_is_rank1)) {
// Single-element operands may be carried as a rank-0 scalar or a rank-1 [1] value. Per ONNX
// broadcasting, the result is rank-1 if either operand is rank-1, otherwise a scalar; keep
// the propagated value's rank consistent with that so downstream consumers see the right rank.
SetSinglePropagatedShapeValue(output_def_, lhs - rhs, lhs_is_rank1 || rhs_is_rank1);
}

return Status::OK();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class SubOpDataPropagation : public CustomDataPropagationBase {
public:
SubOpDataPropagation(const Node& node,
NodeArg& output_def,
std::function<Status(const std::string&, TensorShapeVector&)> func,
GetInitializedInputValuesFunc func,
const ONNX_NAMESPACE::TypeProto& output_from_onnx_op_data_propagation,
const logging::Logger& logger) noexcept
: CustomDataPropagationBase(node, output_def, func, output_from_onnx_op_data_propagation, logger) {}
Expand Down
Loading
Loading