Skip to content

Commit cd934ef

Browse files
committed
feat: update to allow algorithms
1 parent 28c9143 commit cd934ef

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

src/Compiler.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1722,6 +1722,8 @@ function compile_mlir!(
17221722

17231723
blas_int_width = sizeof(BlasInt) * 8
17241724
lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \
1725+
blas_int_width=$blas_int_width},\
1726+
lower-enzymexla-lapack{backend=$backend \
17251727
blas_int_width=$blas_int_width}"
17261728

17271729
legalize_chlo_to_stablehlo =

src/Ops.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3316,6 +3316,7 @@ end
33163316
x::TracedRArray{T,N},
33173317
::Type{iT}=Int32;
33183318
full::Bool=false,
3319+
algorithm::String="DEFAULT",
33193320
location=mlir_stacktrace("svd", @__FILE__, @__LINE__),
33203321
) where {T,iT,N}
33213322
@assert N >= 2
@@ -3329,13 +3330,26 @@ end
33293330
Vt_size = (batch_sizes..., full ? n : r, n)
33303331
info_size = batch_sizes
33313332

3333+
if algorithm == "DEFAULT"
3334+
algint = 0
3335+
elseif algorithm == "QRIteration"
3336+
algint = 1
3337+
elseif algorithm == "DivideAndConquer"
3338+
algint = 2
3339+
elseif algorithm == "Jacobi"
3340+
algint = 3
3341+
else
3342+
error("Unsupported SVD algorithm: $algorithm")
3343+
end
3344+
33323345
svd_op = enzymexla.linalg_svd(
33333346
x.mlir_data;
33343347
U=mlir_type(TracedRArray{T,N}, U_size),
33353348
S=mlir_type(TracedRArray{Base.real(T),N - 1}, S_size),
33363349
Vt=mlir_type(TracedRArray{T,N}, Vt_size),
33373350
info=mlir_type(TracedRArray{iT,N - 2}, info_size),
33383351
full=full,
3352+
algorithm=MLIR.API.enzymexlaSVDAlgorithmAttrGet(MLIR.IR.context(), algint),
33393353
location,
33403354
)
33413355

0 commit comments

Comments
 (0)