@@ -3312,6 +3312,46 @@ Compute the row maximum pivoted LU factorization of `x` and return the factors `
33123312 return (res, ipiv, perm, info)
33133313end
33143314
3315+ @noinline function svd (
3316+ x:: TracedRArray{T,N} ,
3317+ :: Type{iT} = Int32;
3318+ full:: Bool = false ,
3319+ location= mlir_stacktrace (" svd" , @__FILE__ , @__LINE__ ),
3320+ ) where {T,iT,N}
3321+ @assert N >= 2
3322+
3323+ batch_sizes = size (x)[1 : (end - 2 )]
3324+ m, n = size (x)[(end - 1 ): end ]
3325+ r = min (m, n)
3326+
3327+ U_size = (batch_sizes... , m, full ? m : r)
3328+ S_size = (batch_sizes... , r)
3329+ Vt_size = (batch_sizes... , full ? n : r, n)
3330+ info_size = batch_sizes
3331+
3332+ svd_op = enzymexla. linalg_svd (
3333+ x. mlir_data;
3334+ U= mlir_type (TracedRArray{T,N}, U_size),
3335+ S= mlir_type (TracedRArray{Base. real (T),N - 1 }, S_size),
3336+ Vt= mlir_type (TracedRArray{T,N}, Vt_size),
3337+ info= mlir_type (TracedRArray{iT,N - 2 }, info_size),
3338+ full= full,
3339+ location,
3340+ )
3341+
3342+ U = TracedRArray {T,N} ((), MLIR. IR. result (svd_op, 1 ), U_size)
3343+ S = TracedRArray {Base.real(T),N - 1} ((), MLIR. IR. result (svd_op, 2 ), S_size)
3344+ Vt = TracedRArray {T,N} ((), MLIR. IR. result (svd_op, 3 ), Vt_size)
3345+
3346+ if N == 2
3347+ info = TracedRNumber {iT} ((), MLIR. IR. result (svd_op, 4 ))
3348+ else
3349+ info = TracedRArray {iT,N - 2} ((), MLIR. IR. result (svd_op, 4 ), info_size)
3350+ end
3351+
3352+ return U, S, Vt, info
3353+ end
3354+
33153355@noinline function reduce_window (
33163356 f:: F ,
33173357 inputs:: Vector{TracedRArray{T,N}} ,
0 commit comments