Skip to content

Commit 4c9c15f

Browse files
Apply add/remove has_side_effect in MPMD pipelines
PiperOrigin-RevId: 802989064
1 parent a95b171 commit 4c9c15f

13 files changed

+288
-1
lines changed

shardy/dialect/mpmd/transforms/common/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,14 @@ cc_library(
3131
name = "passes",
3232
srcs = [
3333
"absorb_inferred_fragments.cc",
34+
"add_side_effect_to_avoid_cse.cc",
3435
"call_rewrites.cc",
3536
"copy_constants.cc",
3637
"fragment_dce.cc",
3738
"fragment_dedup.cc",
3839
"merge_fragments.cc",
3940
"merge_transfers.cc",
41+
"remove_side_effect_after_cse.cc",
4042
"remove_transfer_cycles.cc",
4143
"rule_based_merge.cc",
4244
"split_bwd_fragments.cc",
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/* Copyright 2025 The MPMD Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <memory>
17+
18+
#include "mlir/Dialect/Func/IR/FuncOps.h"
19+
#include "mlir/IR/BuiltinAttributes.h"
20+
#include "shardy/dialect/mpmd/transforms/common/passes.h"
21+
#include "shardy/dialect/mpmd/transforms/common/utils.h"
22+
#include "stablehlo/dialect/StablehloOps.h"
23+
#include "mlir/Pass/Pass.h"
24+
25+
namespace mlir::mpmd {
26+
27+
namespace {
28+
29+
#define GEN_PASS_DEF_ADDSIDEEFFECTTOAVOIDCSEPASS
30+
#include "shardy/dialect/mpmd/transforms/common/passes.h.inc"
31+
32+
// Function pass that sets `has_side_effect = true` on every
// `stablehlo.custom_call` carrying the `mhlo.no_cse` attribute. Per the pass
// description, this keeps MLIR's CSE pass from eliminating those ops, because
// CSE skips operations with side effects.
struct AddSideEffectToAvoidCSEPass
33+
: public impl::AddSideEffectToAvoidCSEPassBase<
34+
AddSideEffectToAvoidCSEPass> {
35+
// Inherit the constructors of the generated pass base.
using impl::AddSideEffectToAvoidCSEPassBase<
36+
AddSideEffectToAvoidCSEPass>::AddSideEffectToAvoidCSEPassBase;
37+
38+
// Walks all custom calls nested in the operation this pass runs on and marks
// the ones tagged with `kMhloNoCseAttr` as side-effecting. Custom calls
// without that attribute, and all other ops, are left untouched.
void runOnOperation() override {
39+
getOperation().walk([](stablehlo::CustomCallOp customCallOp) {
40+
if (customCallOp->hasAttr(kMhloNoCseAttr)) {
41+
customCallOp.setHasSideEffect(true);
42+
}
43+
});
44+
}
45+
};
46+
47+
} // namespace
48+
49+
// Factory returning a newly allocated AddSideEffectToAvoidCSEPass, owned by
// the caller.
std::unique_ptr<Pass> createAddSideEffectToAvoidCSEPass() {
50+
return std::make_unique<AddSideEffectToAvoidCSEPass>();
51+
}
52+
53+
} // namespace mlir::mpmd

shardy/dialect/mpmd/transforms/common/passes.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,30 @@ limitations under the License.
1515

1616
include "mlir/Pass/PassBase.td"
1717

18+
def AddSideEffectToAvoidCSEPass :
19+
PassBase<"mpmd-add-side-effect-to-avoid-cse", "OperationPass<func::FuncOp>"> {
20+
let summary = "Adds a side effect attribute to custom_call ops with "
21+
"{mhlo.no_cse} to avoid CSE.";
22+
let description = [{
23+
For `stablehlo.custom_call` operations that have the `{mhlo.no_cse}`
24+
attribute, this pass adds a `{has_side_effect = true}` attribute.
25+
This prevents MLIR's CSE pass from eliminating these operations, because
26+
CSE skips operations with side effects.
27+
}];
28+
}
29+
30+
def RemoveSideEffectAfterCSEPass :
31+
PassBase<"mpmd-remove-side-effect-after-cse", "OperationPass<func::FuncOp>"> {
32+
let summary = "Removes side effect attribute from custom_call ops with "
33+
"{mhlo.no_cse}.";
34+
let description = [{
35+
For `stablehlo.custom_call` operations that have the `{mhlo.no_cse}`
36+
attribute, this pass removes the `{has_side_effect = true}` attribute if
37+
it exists. This is useful to run after CSE to remove the attribute that
38+
is no longer needed.
39+
}];
40+
}
41+
1842
// TODO: b/374694825 - This pass is not complete yet. In particular, we also
1943
// need to consider: (a) side-ways merging. We need to be careful with this as
2044
// it may have performance and jitting time implications. (b) relax the
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/* Copyright 2025 The MPMD Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <memory>
17+
#include <optional>
18+
19+
#include "mlir/Dialect/Func/IR/FuncOps.h"
20+
#include "mlir/IR/BuiltinAttributes.h"
21+
#include "shardy/dialect/mpmd/transforms/common/passes.h"
22+
#include "shardy/dialect/mpmd/transforms/common/utils.h"
23+
#include "stablehlo/dialect/StablehloOps.h"
24+
#include "mlir/Pass/Pass.h"
25+
26+
namespace mlir::mpmd {
27+
28+
namespace {
29+
30+
#define GEN_PASS_DEF_REMOVESIDEEFFECTAFTERCSEPASS
31+
#include "shardy/dialect/mpmd/transforms/common/passes.h.inc"
32+
33+
// Function pass that clears `has_side_effect` on every
// `stablehlo.custom_call` carrying the `mhlo.no_cse` attribute. Per the pass
// description, it is meant to run after CSE, once the attribute is no longer
// needed.
//
// NOTE(review): the only condition checked is the presence of `mhlo.no_cse`,
// so a custom call that already had `has_side_effect = true` before the
// matching add pass ran would presumably lose it here as well — confirm this
// is intended.
struct RemoveSideEffectAfterCSEPass
34+
: public impl::RemoveSideEffectAfterCSEPassBase<
35+
RemoveSideEffectAfterCSEPass> {
36+
// Inherit the constructors of the generated pass base.
using impl::RemoveSideEffectAfterCSEPassBase<
37+
RemoveSideEffectAfterCSEPass>::RemoveSideEffectAfterCSEPassBase;
38+
39+
// Walks all custom calls nested in the operation this pass runs on and
// resets `has_side_effect` on the ones tagged with `kMhloNoCseAttr`. Custom
// calls without that attribute, and all other ops, are left untouched.
void runOnOperation() override {
40+
getOperation().walk([&](stablehlo::CustomCallOp customCallOp) {
41+
if (customCallOp->hasAttr(kMhloNoCseAttr)) {
42+
// Passing std::nullopt presumably drops the attribute instead of storing
// `false` (the pass description says the attribute is removed) — TODO
// confirm against the generated setter.
customCallOp.setHasSideEffect(std::nullopt);
43+
}
44+
});
45+
}
46+
};
47+
48+
} // namespace
49+
50+
// Factory returning a newly allocated RemoveSideEffectAfterCSEPass, owned by
// the caller.
std::unique_ptr<Pass> createRemoveSideEffectAfterCSEPass() {
51+
return std::make_unique<RemoveSideEffectAfterCSEPass>();
52+
}
53+
54+
} // namespace mlir::mpmd
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: mpmd_opt %s -mpmd-add-side-effect-to-avoid-cse | FileCheck %s
2+
3+
// CHECK-LABEL: func @custom_call_with_no_cse_should_add_side_effect
4+
// CHECK-SAME: (%arg0: tensor<f32>) -> tensor<f32>
5+
func.func @custom_call_with_no_cse_should_add_side_effect(%arg0: tensor<f32>) -> tensor<f32> {
6+
// CHECK: %[[RES0:.*]] = stablehlo.custom_call @Sharding(%arg0)
7+
// CHECK-SAME: has_side_effect = true
8+
// CHECK-SAME: mhlo.no_cse
9+
// CHECK-SAME: : (tensor<f32>) -> tensor<f32>
10+
%0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.no_cse} : (tensor<f32>) -> tensor<f32>
11+
func.return %0 : tensor<f32>
12+
}
13+
14+
// CHECK-LABEL: func @custom_call_without_no_cse_should_not_add_side_effect
15+
// CHECK-SAME: (%arg0: tensor<f32>) -> tensor<f32>
16+
func.func @custom_call_without_no_cse_should_not_add_side_effect(%arg0: tensor<f32>) -> tensor<f32> {
17+
// CHECK-NOT: has_side_effect
18+
// CHECK: stablehlo.custom_call @Sharding(%arg0) : (tensor<f32>) -> tensor<f32>
19+
%0 = stablehlo.custom_call @Sharding(%arg0) : (tensor<f32>) -> tensor<f32>
20+
func.return %0 : tensor<f32>
21+
}
22+
23+
// CHECK-LABEL: func @other_op_with_no_cse_should_not_add_side_effect
24+
// CHECK-SAME: (%arg0: tensor<f32>) -> tensor<f32>
25+
func.func @other_op_with_no_cse_should_not_add_side_effect(%arg0: tensor<f32>) -> tensor<f32> {
26+
// CHECK-NOT: has_side_effect
27+
// CHECK: stablehlo.add %arg0, %arg0 {mhlo.no_cse} : tensor<f32>
28+
%0 = stablehlo.add %arg0, %arg0 {mhlo.no_cse} : tensor<f32>
29+
func.return %0 : tensor<f32>
30+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: mpmd_opt %s -mpmd-add-side-effect-to-avoid-cse -cse -mpmd-remove-side-effect-after-cse | FileCheck %s
2+
3+
// Both custom calls carry `mhlo.no_cse`, so after add-side-effect -> cse ->
// remove-side-effect they must survive as two distinct ops, without any
// leftover `has_side_effect` attribute.
// CHECK-LABEL: func @duplicate_custom_call_with_no_cse_should_not_be_csed
4+
// CHECK-SAME: (%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>)
5+
func.func @duplicate_custom_call_with_no_cse_should_not_be_csed(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
6+
// CHECK: %[[RES0:.*]] = stablehlo.custom_call @Sharding(%arg0)
7+
// CHECK-NOT: has_side_effect
8+
// CHECK-SAME: mhlo.no_cse
9+
// CHECK: %[[RES1:.*]] = stablehlo.custom_call @Sharding(%arg0)
10+
// CHECK-NOT: has_side_effect
11+
// CHECK-SAME: mhlo.no_cse
12+
// CHECK: return %[[RES0]], %[[RES1]]
13+
%0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.no_cse} : (tensor<f32>) -> tensor<f32>
14+
%1 = stablehlo.custom_call @Sharding(%arg0) {mhlo.no_cse} : (tensor<f32>) -> tensor<f32>
15+
func.return %0, %1 : tensor<f32>, tensor<f32>
16+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: mpmd_opt %s -mpmd-remove-side-effect-after-cse | FileCheck %s
2+
3+
// CHECK-LABEL: func @custom_call_with_no_cse_should_remove_side_effect
4+
// CHECK-SAME: (%arg0: tensor<f32>) -> tensor<f32>
5+
func.func @custom_call_with_no_cse_should_remove_side_effect(%arg0: tensor<f32>) -> tensor<f32> {
6+
// The CHECK-NOT must sit between the CHECK and the CHECK-SAME so that it
// scans the op's attribute dictionary; placed before the CHECK it would only
// scan the text preceding the op.
// CHECK: %[[RES0:.*]] = stablehlo.custom_call @Sharding(%arg0)
7+
// CHECK-NOT: has_side_effect
8+
// CHECK-SAME: mhlo.no_cse
9+
// CHECK-SAME: : (tensor<f32>) -> tensor<f32>
10+
%0 = stablehlo.custom_call @Sharding(%arg0) {has_side_effect = true,mhlo.no_cse} : (tensor<f32>) -> tensor<f32>
11+
func.return %0 : tensor<f32>
12+
}
13+
14+
// CHECK-LABEL: func @custom_call_without_no_cse_should_do_nothing
15+
// CHECK-SAME: (%arg0: tensor<f32>) -> tensor<f32>
16+
func.func @custom_call_without_no_cse_should_do_nothing(%arg0: tensor<f32>) -> tensor<f32> {
17+
// CHECK: stablehlo.custom_call @Sharding(%arg0) {has_side_effect = true}
18+
%0 = stablehlo.custom_call @Sharding(%arg0) {has_side_effect = true}: (tensor<f32>) -> tensor<f32>
19+
func.return %0 : tensor<f32>
20+
}
21+
22+
// CHECK-LABEL: func @other_op_with_no_cse_should_do_nothing
23+
// CHECK-SAME: (%arg0: tensor<f32>) -> tensor<f32>
24+
func.func @other_op_with_no_cse_should_do_nothing(%arg0: tensor<f32>) -> tensor<f32> {
25+
// CHECK: stablehlo.add %arg0, %arg0 {has_side_effect = true, mhlo.no_cse} : tensor<f32>
26+
%0 = stablehlo.add %arg0, %arg0 {has_side_effect = true, mhlo.no_cse} : tensor<f32>
27+
func.return %0 : tensor<f32>
28+
}

shardy/dialect/mpmd/transforms/common/utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ limitations under the License.
3434

3535
namespace mlir::mpmd {
3636

37+
// The name of the attribute that marks an op as not to be CSE'd.
38+
inline constexpr StringRef kMhloNoCseAttr = "mhlo.no_cse";
39+
3740
// The name of the attribute that keeps track of how many times a loop has been
3841
// unrolled.
3942
constexpr StringRef kUnrollCounterAttrName = "unroll_counter";

shardy/dialect/mpmd/transforms/export/export_pipeline.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ void addExportPipeline(OpPassManager& pm, const ExportOptions& options) {
7979
// and by DCE'ing the fragment bodies.
8080
pm.addNestedPass<FuncOp>(createFragmentDcePass());
8181

82+
// Now that all CSE is done, we can remove the side effect from custom calls that
83+
// have the no_cse attribute.
84+
pm.addNestedPass<FuncOp>(createRemoveSideEffectAfterCSEPass());
85+
8286
// Must be applied after the last -mpmd-fragment-dedup, as it may add
8387
// duplicated fragment results and after -canonicalize, as it may add
8488
// identity fragments, which would be canonicalized away.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: mpmd_opt %s -mpmd-import-pipeline='name-to-mesh-assignment=f1@m1,f2@m2' -mpmd-optimize-pipeline -mpmd-sharding-propagation-pipeline -mpmd-export-pipeline 2>&1 | FileCheck %s
2+
3+
#topology = #mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>>
4+
5+
// CHECK-LABEL: func.func @main
6+
func.func @main(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> attributes {
7+
"topology"=#topology} {
8+
// CHECK: %[[FRAGMENT_CALL:.*]] = mpmd.fragment_call<mesh="m1", origin=["f1"]> @p0_f1_fwd.main(%arg0)
9+
%1:2 = mpmd.named_computation<"f1"> (%arg0, %arg0) (%arg3: tensor<4x8xf32>, %arg4: tensor<4x8xf32>) {
10+
%2 = stablehlo.custom_call @sdy_testonly(%arg3) {mhlo.no_cse} : (tensor<4x8xf32>) -> tensor<4x8xf32>
11+
%3 = stablehlo.custom_call @sdy_testonly(%arg4) {mhlo.no_cse} : (tensor<4x8xf32>) -> tensor<4x8xf32>
12+
mpmd.return %2, %3 : tensor<4x8xf32>, tensor<4x8xf32>
13+
} : (tensor<4x8xf32>, tensor<4x8xf32>) -> (tensor<4x8xf32>, tensor<4x8xf32>)
14+
func.return %1#0 : tensor<4x8xf32>
15+
}
16+
// CHECK-LABEL: func.func @p0_f1_fwd.main
17+
// CHECK: %[[CUSTOM_CALL:.*]] = stablehlo.custom_call @sdy_testonly
18+
// CHECK-NOT: has_side_effect
19+
// CHECK-SAME: {mhlo.no_cse}
20+
// CHECK-NEXT: %[[CUSTOM_CALL_2:.*]] = stablehlo.custom_call @sdy_testonly
21+
// CHECK-NOT: has_side_effect
22+
// CHECK-SAME: {mhlo.no_cse}

0 commit comments

Comments
 (0)