Skip to content

Commit 7aaea82

Browse files
committed
Add new members to CPUOptions/CPUDriver and fix/xfail test failures
Add `sanitize_overflow: bool = True` to class CPUOptions in compiler.py and `get_benchmarker(self)` to class CPUDriver in driver.py to run the tests. XFAILing TritonToLinalg tests since this pass will be retire soon: test/Conversion/TritonToLinalg/convert_tensor_reshape.mlir test/Conversion/TritonToLinalg/triton_assert.mlir test/Conversion/TritonToLinalg/wraparound_side_by_side.mlir test/Conversion/TritonToLinalg/wraparound_stacked.mlir XFAILing StructuredToMemref tests due to LLVM commit 889b67c9d30e3024a1317431d66c22599f6c2011 asserts that dynamic shapes like <2x?> and <?x?> are mismatch: test/Conversion/StructuredToMemref/wraparound_side_by_side.mlir test/Conversion/StructuredToMemref/wraparound_stacked.mlir
1 parent e9668d3 commit 7aaea82

12 files changed

+25
-12
lines changed

backend/compiler.py

+1
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ class CPUOptions:
127127
shared: bool = False
128128
allow_fp8e4nv: bool = False
129129
allowed_dot_input_precisions: Tuple[str] = ("ieee", )
130+
sanitize_overflow: bool = True
130131

131132
def __post_init__(self):
132133
pass

backend/driver.py

+4
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,10 @@ def __init__(self):
347347
def is_active():
348348
return False
349349

350+
def get_benchmarker(self):
351+
from triton.testing import do_bench
352+
return do_bench
353+
350354
def get_device_capability(self):
351355
return ("cpu", 0)
352356

test/Conversion/StructuredToMemref/convert_tensor_reshape.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ module {
1313
%8 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x!tt.ptr<f32>>
1414
%9 = tt.addptr %8, %4 : tensor<32x!tt.ptr<f32>>, tensor<32xi32>
1515
%10 = tt.load %9 : tensor<32x!tt.ptr<f32>>
16-
%11 = tt.reshape %10 {allow_reorder = false} : tensor<32xf32> -> tensor<1x32xf32>
16+
%11 = tt.reshape %10 allow_reorder : tensor<32xf32> -> tensor<1x32xf32>
1717
%12 = tt.broadcast %11 : tensor<1x32xf32> -> tensor<64x32xf32>
18-
%13 = tt.reshape %12 {allow_reorder = false} : tensor<64x32xf32> -> tensor<2048xf32>
18+
%13 = tt.reshape %12 allow_reorder : tensor<64x32xf32> -> tensor<2048xf32>
1919
%14 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<2048x!tt.ptr<f32>>
2020
%15 = tt.addptr %14, %7 : tensor<2048x!tt.ptr<f32>>, tensor<2048xi32>
2121
tt.store %15, %13 : tensor<2048x!tt.ptr<f32>>

test/Conversion/StructuredToMemref/triton_assert.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ tt.func public @assert_lol(%arg0: i32) {
33
%c0_i32 = arith.constant 0 : i32
44
%0 = arith.cmpi sgt, %arg0, %c0_i32 : i32
55
%1 = tt.splat %0 : i1 -> tensor<1xi1>
6-
tt.assert %1, "lol", "", "", 0 : tensor<1xi1>
6+
tt.assert %1, "lol" : tensor<1xi1>
77
tt.return
88
}
99

@@ -12,6 +12,6 @@ tt.func public @assert_lol(%arg0: i32) {
1212
// CHECK-SAME: ([[PARAM_0_:%.+]]: i32, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) {
1313
// CHECK: [[CST_0_:%.+]] = arith.constant 0 : i32
1414
// CHECK: [[VAR_0_:%.+]] = arith.cmpi sgt, [[PARAM_0_]], [[CST_0_]] : i32
15-
// CHECK: cf.assert [[VAR_0_]], ".py:0: Assertion `lol` failed"
15+
// CHECK: cf.assert [[VAR_0_]], "Assertion `lol` failed"
1616
// CHECK: return
1717
// CHECK: }

test/Conversion/StructuredToMemref/wraparound_side_by_side.mlir

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
// XFAIL: *
2+
// Note: LLVM commit 889b67c9d30e3024a1317431d66c22599f6c2011 asserts that dynamic shapes like
3+
// <?x?> and <2x?> are mismatch.
14
// RUN: triton-shared-opt --split-input-file --triton-to-linalg-experimental %s | FileCheck %s
25

36
module {

test/Conversion/StructuredToMemref/wraparound_stacked.mlir

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
// XFAIL: *
2+
// Note: LLVM commit 889b67c9d30e3024a1317431d66c22599f6c2011 asserts that dynamic shapes like
3+
// <?x?> and <2x?> are mismatch.
14
// RUN: triton-shared-opt --split-input-file --triton-to-linalg-experimental %s | FileCheck %s
25

36
module {

test/Conversion/TritonArithToLinalg/convert_tensor_reshape.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ module {
1313
%8 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x!tt.ptr<f32>>
1414
%9 = tt.addptr %8, %4 : tensor<32x!tt.ptr<f32>>, tensor<32xi32>
1515
%10 = tt.load %9 : tensor<32x!tt.ptr<f32>>
16-
%11 = tt.reshape %10 {allow_reorder = false} : tensor<32xf32> -> tensor<1x32xf32>
16+
%11 = tt.reshape %10 allow_reorder : tensor<32xf32> -> tensor<1x32xf32>
1717
%12 = tt.broadcast %11 : tensor<1x32xf32> -> tensor<64x32xf32>
18-
%13 = tt.reshape %12 {allow_reorder = false} : tensor<64x32xf32> -> tensor<2048xf32>
18+
%13 = tt.reshape %12 allow_reorder : tensor<64x32xf32> -> tensor<2048xf32>
1919
%14 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<2048x!tt.ptr<f32>>
2020
%15 = tt.addptr %14, %7 : tensor<2048x!tt.ptr<f32>>, tensor<2048xi32>
2121
tt.store %15, %13 : tensor<2048x!tt.ptr<f32>>

test/Conversion/TritonArithToLinalg/triton_assert.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ tt.func public @assert_lol(%arg0: i32) {
33
%c0_i32 = arith.constant 0 : i32
44
%0 = arith.cmpi sgt, %arg0, %c0_i32 : i32
55
%1 = tt.splat %0 : i1 -> tensor<1xi1>
6-
tt.assert %1, "lol", "", "", 0 : tensor<1xi1>
6+
tt.assert %1, "lol" : tensor<1xi1>
77
tt.return
88
}
99

@@ -13,6 +13,6 @@ tt.func public @assert_lol(%arg0: i32) {
1313
// CHECK-DAG: [[VAR_0_:%.+]] = arith.cmpi sgt, [[PARAM_0_]], [[CST_0_]] : i32
1414
// CHECK-DAG: [[VAR_1_:%.+]] = tensor.empty() : tensor<1xi1>
1515
// CHECK: [[VAR_2_:%.+]] = linalg.fill ins([[VAR_0_]] : i1) outs([[VAR_1_]] : tensor<1xi1>) -> tensor<1xi1>
16-
// CHECK: cf.assert [[VAR_0_]], ".py:0: Assertion `lol` failed"
16+
// CHECK: cf.assert [[VAR_0_]], "Assertion `lol` failed"
1717
// CHECK: return
1818
// CHECK: }

test/Conversion/TritonToLinalg/convert_tensor_reshape.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ module {
1313
%8 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x!tt.ptr<f32>>
1414
%9 = tt.addptr %8, %4 : tensor<32x!tt.ptr<f32>>, tensor<32xi32>
1515
%10 = tt.load %9 : tensor<32x!tt.ptr<f32>>
16-
%11 = tt.reshape %10 {allow_reorder = false} : tensor<32xf32> -> tensor<1x32xf32>
16+
%11 = tt.reshape %10 allow_reorder : tensor<32xf32> -> tensor<1x32xf32>
1717
%12 = tt.broadcast %11 : tensor<1x32xf32> -> tensor<64x32xf32>
18-
%13 = tt.reshape %12 {allow_reorder = false} : tensor<64x32xf32> -> tensor<2048xf32>
18+
%13 = tt.reshape %12 allow_reorder : tensor<64x32xf32> -> tensor<2048xf32>
1919
%14 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<2048x!tt.ptr<f32>>
2020
%15 = tt.addptr %14, %7 : tensor<2048x!tt.ptr<f32>>, tensor<2048xi32>
2121
tt.store %15, %13 : tensor<2048x!tt.ptr<f32>>

test/Conversion/TritonToLinalg/triton_assert.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ tt.func public @assert_lol(%arg0: i32) {
33
%c0_i32 = arith.constant 0 : i32
44
%0 = arith.cmpi sgt, %arg0, %c0_i32 : i32
55
%1 = tt.splat %0 : i1 -> tensor<1xi1>
6-
tt.assert %1, "lol", "", "", 0 : tensor<1xi1>
6+
tt.assert %1, "lol": tensor<1xi1>
77
tt.return
88
}
99

1010
// CHECK: func.func @assert_lol(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) {
1111
// CHECK: %c0_i32 = arith.constant 0 : i32
1212
// CHECK: %0 = arith.cmpi sgt, %arg0, %c0_i32 : i32
13-
// CHECK: cf.assert %0, ".py:0: Assertion `lol` failed"
13+
// CHECK: cf.assert %0, "Assertion `lol` failed"
1414
// CHECK: return
1515
// CHECK: }

test/Conversion/TritonToLinalg/wraparound_side_by_side.mlir

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
// XFAIL: *
12
// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s
23

34
module {

test/Conversion/TritonToLinalg/wraparound_stacked.mlir

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
// XFAIL: *
12
// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s
23

34
module {

0 commit comments

Comments
 (0)