Skip to content

Commit 5289e51

Browse files
committed
feat: svd op
1 parent 5ef447d commit 5289e51

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

src/Ops.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3312,6 +3312,46 @@ Compute the row maximum pivoted LU factorization of `x` and return the factors `
33123312
return (res, ipiv, perm, info)
33133313
end
33143314

3315+
@noinline function svd(
3316+
x::TracedRArray{T,N},
3317+
::Type{iT}=Int32;
3318+
full::Bool=false,
3319+
location=mlir_stacktrace("svd", @__FILE__, @__LINE__),
3320+
) where {T,iT,N}
3321+
@assert N >= 2
3322+
3323+
batch_sizes = size(x)[1:(end - 2)]
3324+
m, n = size(x)[(end - 1):end]
3325+
r = min(m, n)
3326+
3327+
U_size = (batch_sizes..., m, full ? m : r)
3328+
S_size = (batch_sizes..., r)
3329+
Vt_size = (batch_sizes..., full ? n : r, n)
3330+
info_size = batch_sizes
3331+
3332+
svd_op = enzymexla.linalg_svd(
3333+
x.mlir_data;
3334+
U=mlir_type(TracedRArray{T,N}, U_size),
3335+
S=mlir_type(TracedRArray{Base.real(T),N - 1}, S_size),
3336+
Vt=mlir_type(TracedRArray{T,N}, Vt_size),
3337+
info=mlir_type(TracedRArray{iT,N - 2}, info_size),
3338+
full=full,
3339+
location,
3340+
)
3341+
3342+
U = TracedRArray{T,N}((), MLIR.IR.result(svd_op, 1), U_size)
3343+
S = TracedRArray{Base.real(T),N - 1}((), MLIR.IR.result(svd_op, 2), S_size)
3344+
Vt = TracedRArray{T,N}((), MLIR.IR.result(svd_op, 3), Vt_size)
3345+
3346+
if N == 2
3347+
info = TracedRNumber{iT}((), MLIR.IR.result(svd_op, 4))
3348+
else
3349+
info = TracedRArray{iT,N - 2}((), MLIR.IR.result(svd_op, 4), info_size)
3350+
end
3351+
3352+
return U, S, Vt, info
3353+
end
3354+
33153355
@noinline function reduce_window(
33163356
f::F,
33173357
inputs::Vector{TracedRArray{T,N}},

src/stdlibs/LinearAlgebra.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ function __init__()
2727
(BLAS.@blasfunc(dgetrf_), :enzymexla_lapack_dgetrf_),
2828
(BLAS.@blasfunc(cgetrf_), :enzymexla_lapack_cgetrf_),
2929
(BLAS.@blasfunc(zgetrf_), :enzymexla_lapack_zgetrf_),
30+
(BLAS.@blasfunc(sgesvd_), :enzymexla_lapack_sgesvd_),
31+
(BLAS.@blasfunc(dgesvd_), :enzymexla_lapack_dgesvd_),
32+
(BLAS.@blasfunc(cgesvd_), :enzymexla_lapack_cgesvd_),
33+
(BLAS.@blasfunc(zgesvd_), :enzymexla_lapack_zgesvd_),
3034
]
3135
sym = Libdl.dlsym(libblastrampoline_handle, cname)
3236
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(

0 commit comments

Comments
 (0)