Skip to content

Commit 31bf934

Browse files
authored
[MLIR][NVVM] Add an explicit mask operand to elect.sync (#145509)
This patch adds an explicit, optional mask operand to elect.sync. When provided, the operand overrides the default membermask of 0xffffffff. Signed-off-by: Durgadoss R <[email protected]>
1 parent 77618a9 commit 31bf934

File tree

3 files changed

+28
-15
lines changed

3 files changed

+28
-15
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -965,19 +965,21 @@ def NVVM_ElectSyncOp : NVVM_Op<"elect.sync">
965965
let summary = "Elect one leader thread";
966966
let description = [{
967967
The `elect.sync` instruction elects one predicated active leader
968-
thread from among a set of threads specified in membermask.
969-
The membermask is set to `0xFFFFFFFF` for the current version
970-
of this Op. The predicate result is set to `True` for the
971-
leader thread, and `False` for all other threads.
968+
thread from among a set of threads specified in the `membermask`.
969+
When the `membermask` is not provided explicitly, a default value
970+
of `0xFFFFFFFF` is used. The predicate result is set to `True` for
971+
the leader thread, and `False` for all other threads.
972972

973973
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-elect-sync)
974974
}];
975975

976+
let arguments = (ins Optional<I32>:$membermask);
976977
let results = (outs I1:$pred);
977-
let assemblyFormat = "attr-dict `->` type(results)";
978+
let assemblyFormat = "($membermask^)? attr-dict `->` type(results)";
978979
string llvmBuilder = [{
979980
auto *resultTuple = createIntrinsicCall(builder,
980-
llvm::Intrinsic::nvvm_elect_sync, {builder.getInt32(0xFFFFFFFF)});
981+
llvm::Intrinsic::nvvm_elect_sync,
982+
{$membermask ? $membermask : builder.getInt32(0xFFFFFFFF)});
981983
// Extract the second value into $pred
982984
$pred = builder.CreateExtractValue(resultTuple, 1);
983985
}];
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// CHECK-LABEL: @test_nvvm_elect_sync
4+
llvm.func @test_nvvm_elect_sync() -> i1 {
5+
// CHECK: %[[RES:.*]] = call { i32, i1 } @llvm.nvvm.elect.sync(i32 -1)
6+
// CHECK-NEXT: %[[PRED:.*]] = extractvalue { i32, i1 } %[[RES]], 1
7+
// CHECK-NEXT: ret i1 %[[PRED]]
8+
%0 = nvvm.elect.sync -> i1
9+
llvm.return %0 : i1
10+
}
11+
12+
// CHECK-LABEL: @test_nvvm_elect_sync_mask
13+
llvm.func @test_nvvm_elect_sync_mask(%mask : i32) -> i1 {
14+
// CHECK: %[[RES:.*]] = call { i32, i1 } @llvm.nvvm.elect.sync(i32 %0)
15+
// CHECK-NEXT: %[[PRED:.*]] = extractvalue { i32, i1 } %[[RES]], 1
16+
// CHECK-NEXT: ret i1 %[[PRED]]
17+
%0 = nvvm.elect.sync %mask -> i1
18+
llvm.return %0 : i1
19+
}
20+

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -265,15 +265,6 @@ llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 {
265265
llvm.return %3 : i32
266266
}
267267

268-
// CHECK-LABEL: @nvvm_elect_sync
269-
llvm.func @nvvm_elect_sync() -> i1 {
270-
// CHECK: %[[RES:.*]] = call { i32, i1 } @llvm.nvvm.elect.sync(i32 -1)
271-
// CHECK-NEXT: %[[PRED:.*]] = extractvalue { i32, i1 } %[[RES]], 1
272-
// CHECK-NEXT: ret i1 %[[PRED]]
273-
%0 = nvvm.elect.sync -> i1
274-
llvm.return %0 : i1
275-
}
276-
277268
// CHECK-LABEL: @nvvm_mma_mn8n8k4_row_col_f32_f32
278269
llvm.func @nvvm_mma_mn8n8k4_row_col_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
279270
%b0 : vector<2xf16>, %b1 : vector<2xf16>,

0 commit comments

Comments
 (0)