@@ -3316,6 +3316,7 @@ end
33163316 x:: TracedRArray{T,N} ,
33173317 :: Type{iT} = Int32;
33183318 full:: Bool = false ,
3319+ algorithm:: String = " DEFAULT" ,
33193320 location= mlir_stacktrace (" svd" , @__FILE__ , @__LINE__ ),
33203321) where {T,iT,N}
33213322 @assert N >= 2
@@ -3329,13 +3330,26 @@ end
33293330 Vt_size = (batch_sizes... , full ? n : r, n)
33303331 info_size = batch_sizes
33313332
3333+ if algorithm == " DEFAULT"
3334+ algint = 0
3335+ elseif algorithm == " QRIteration"
3336+ algint = 1
3337+ elseif algorithm == " DivideAndConquer"
3338+ algint = 2
3339+ elseif algorithm == " Jacobi"
3340+ algint = 3
3341+ else
3342+ error (" Unsupported SVD algorithm: $algorithm " )
3343+ end
3344+
33323345 svd_op = enzymexla. linalg_svd (
33333346 x. mlir_data;
33343347 U= mlir_type (TracedRArray{T,N}, U_size),
33353348 S= mlir_type (TracedRArray{Base. real (T),N - 1 }, S_size),
33363349 Vt= mlir_type (TracedRArray{T,N}, Vt_size),
33373350 info= mlir_type (TracedRArray{iT,N - 2 }, info_size),
33383351 full= full,
3352+ algorithm= MLIR. API. enzymexlaSVDAlgorithmAttrGet (MLIR. IR. context (), algint),
33393353 location,
33403354 )
33413355
0 commit comments