Skip to content

Commit d084e49

Browse files
Gautham GanapathyFrederik Mellbye
authored andcommitted
Fix to prevent preapply transform for inplace elementwise op
Summary: Do not perform preapply transform for reduce if the elementwise op is inplace. REF T68013 Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, alfiee, zigmasb Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, zigmasb Maniphest Tasks: T68013 Differential Revision: https://phabricator.sourcevertex.net/D75145
1 parent d907997 commit d084e49

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

tensorflow/compiler/plugin/poplar/driver/passes/elementwise_preapply.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ StatusOr<bool> TryHandle(HloInstruction* inst) {
8282
// elementwise.
8383
return false;
8484
}
85+
if (user->metadata().op_type() == "AssignAddVariableOp" ||
86+
user->metadata().op_type() == "AssignSubVariableOp") {
87+
return false;
88+
}
8589
for (const HloInstruction* operand : user->operands()) {
8690
// An elementwise op is uniform on inst if all of its operands are one of
8791
// inst, have size 1, or are a broadcast of a variable with size 1.

tensorflow/compiler/plugin/poplar/tests/elementwise_preapply_test.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,27 @@ static std::string hlo_binary_bool() {
788788
"constant({{true, false, true}, {false, true, false}})"}});
789789
}
790790

791+
static std::string hlo_assign_add_sub(const std::string& assign_op) {
792+
return absl::StrReplaceAll(R"(
793+
HloModule module
794+
795+
function_to_apply {
796+
p0 = f32[] parameter(0)
797+
p1 = f32[] parameter(1)
798+
ROOT output = f32[] $FUNC1(p0, p1)
799+
}
800+
801+
ENTRY f {
802+
reduce_param = f32[2, 3] constant({{1, 2, 3}, {4, 5, 6}})
803+
reduce_init = f32[] constant(5)
804+
reduce = f32[] reduce(reduce_param, reduce_init), dimensions={0, 1}, to_apply=function_to_apply
805+
second_elementwise_arg = f32[] parameter(0)
806+
ROOT output = f32[] $FUNC2(reduce, second_elementwise_arg), metadata={op_type="$ASSIGN_OP"}
807+
}
808+
)",
809+
{{"$ASSIGN_OP", assign_op}});
810+
}
811+
791812
static std::string hlo_unary_float() {
792813
return absl::StrReplaceAll(
793814
hlo_binary_float(),
@@ -848,6 +869,9 @@ INSTANTIATE_TEST_SUITE_P(
848869
{"minimum", "copy", hlo_unary_float()},
849870
{"add", "subtract", hlo_binary_float_swapped()},
850871
{"multiply", "divide", hlo_binary_float_swapped()},
872+
// tf.assign_add(), tf.assign_sub
873+
{"add", "add", hlo_assign_add_sub("AssignAddVariableOp")},
874+
{"add", "subtract", hlo_assign_add_sub("AssignSubVariableOp")},
851875
// scalar result (no broadcast for constant)
852876
{"minimum", "maximum", hlo_scalar_result()},
853877
{"add", "multiply", hlo_scalar_result()},

0 commit comments

Comments
 (0)