Skip to content

Commit 302bd57

Browse files
committed
fix: use correct return
1 parent 51e772c commit 302bd57

File tree

2 files changed

+25
-25
lines changed

2 files changed

+25
-25
lines changed

src/enzyme_ad/jax/Passes/LowerEnzymeXLALinalg.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ struct LUFactorizationOpLowering
108108
/*output_operand_aliases=*/outputOperandAliases,
109109
/*xla_side_effect_free=*/rewriter.getUnitAttr());
110110

111-
stablehlo::ReturnOp::create(rewriter, op.getLoc(),
111+
func::ReturnOp::create(rewriter, op.getLoc(),
112112
ValueRange{jitCall.getResult(0),
113113
jitCall.getResult(1),
114114
jitCall.getResult(2)});

test/lit_tests/linalg/lu.mlir

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -75,29 +75,29 @@ module {
7575
// TPU-NEXT: return %0#0, %1, %2, %6 : tensor<64x64xf32>, tensor<64xi32>, tensor<64xi32>, tensor<i32>
7676
// TPU-NEXT: }
7777

78-
// module {
79-
// // CPU: enzymexla.jit_call @enzymexla_lapack_dgetrf_
80-
// // CPU: func.func @main(%arg0: tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor<i32>) {
81-
// func.func @main(%arg0: tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor<i32>) {
82-
// %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor<64xi32>, tensor<i32>)
83-
// return %0#0, %0#1, %0#3 : tensor<64x64xf64>, tensor<64xi32>, tensor<i32>
84-
// }
85-
// }
78+
module {
79+
// CPU: enzymexla.jit_call @enzymexla_lapack_dgetrf_
80+
// CPU: func.func @main(%arg0: tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor<i32>) {
81+
func.func @main(%arg0: tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor<i32>) {
82+
%0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xf64>) -> (tensor<64x64xf64>, tensor<64xi32>, tensor<64xi32>, tensor<i32>)
83+
return %0#0, %0#1, %0#3 : tensor<64x64xf64>, tensor<64xi32>, tensor<i32>
84+
}
85+
}
8686

87-
// module {
88-
// // CPU: enzymexla.jit_call @enzymexla_lapack_zgetrf_
89-
// // CPU: func.func @main(%arg0: tensor<64x64xcomplex<f64>>) -> (tensor<64x64xcomplex<f64>>, tensor<64xi32>, tensor<i32>) {
90-
// func.func @main(%arg0: tensor<64x64xcomplex<f64>>) -> (tensor<64x64xcomplex<f64>>, tensor<64xi32>, tensor<i32>) {
91-
// %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xcomplex<f64>>) -> (tensor<64x64xcomplex<f64>>, tensor<64xi32>, tensor<64xi32>, tensor<i32>)
92-
// return %0#0, %0#1, %0#3 : tensor<64x64xcomplex<f64>>, tensor<64xi32>, tensor<i32>
93-
// }
94-
// }
87+
module {
88+
// CPU: enzymexla.jit_call @enzymexla_lapack_zgetrf_
89+
// CPU: func.func @main(%arg0: tensor<64x64xcomplex<f64>>) -> (tensor<64x64xcomplex<f64>>, tensor<64xi32>, tensor<i32>) {
90+
func.func @main(%arg0: tensor<64x64xcomplex<f64>>) -> (tensor<64x64xcomplex<f64>>, tensor<64xi32>, tensor<i32>) {
91+
%0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xcomplex<f64>>) -> (tensor<64x64xcomplex<f64>>, tensor<64xi32>, tensor<64xi32>, tensor<i32>)
92+
return %0#0, %0#1, %0#3 : tensor<64x64xcomplex<f64>>, tensor<64xi32>, tensor<i32>
93+
}
94+
}
9595

96-
// module {
97-
// // CPU: enzymexla.jit_call @enzymexla_lapack_cgetrf_
98-
// // CPU: func.func @main(%arg0: tensor<64x64xcomplex<f32>>) -> (tensor<64x64xcomplex<f32>>, tensor<64xi32>, tensor<i32>) {
99-
// func.func @main(%arg0: tensor<64x64xcomplex<f32>>) -> (tensor<64x64xcomplex<f32>>, tensor<64xi32>, tensor<i32>) {
100-
// %0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xcomplex<f32>>) -> (tensor<64x64xcomplex<f32>>, tensor<64xi32>, tensor<64xi32>, tensor<i32>)
101-
// return %0#0, %0#1, %0#3 : tensor<64x64xcomplex<f32>>, tensor<64xi32>, tensor<i32>
102-
// }
103-
// }
96+
module {
97+
// CPU: enzymexla.jit_call @enzymexla_lapack_cgetrf_
98+
// CPU: func.func @main(%arg0: tensor<64x64xcomplex<f32>>) -> (tensor<64x64xcomplex<f32>>, tensor<64xi32>, tensor<i32>) {
99+
func.func @main(%arg0: tensor<64x64xcomplex<f32>>) -> (tensor<64x64xcomplex<f32>>, tensor<64xi32>, tensor<i32>) {
100+
%0:4 = enzymexla.linalg.lu %arg0 : (tensor<64x64xcomplex<f32>>) -> (tensor<64x64xcomplex<f32>>, tensor<64xi32>, tensor<64xi32>, tensor<i32>)
101+
return %0#0, %0#1, %0#3 : tensor<64x64xcomplex<f32>>, tensor<64xi32>, tensor<i32>
102+
}
103+
}

0 commit comments

Comments
 (0)