Skip to content

Commit 14990d0

Browse files
committed
Generalize implementation to support Base.oneunit.
1 parent 7055806 commit 14990d0

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

src/host/construction.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,20 @@ function Base.fill!(A::AnyGPUArray{T}, x) where T
1111
end
1212

1313

14-
## uniform scaling
14+
## identity matrices
1515

16-
function uniformscaling_kernel(ctx::AbstractKernelContext, res::AbstractArray{T}, stride, s::UniformScaling) where T
16+
function identity_kernel(ctx::AbstractKernelContext, res::AbstractArray{T}, stride, val) where T
1717
i = linear_index(ctx)
1818
i > stride && return
1919
ilin = (stride * (i - 1)) + i
20-
@inbounds res[ilin] = s.λ
20+
@inbounds res[ilin] = val
2121
return
2222
end
2323

2424
function (T::Type{<: AnyGPUArray{U}})(s::UniformScaling, dims::Dims{2}) where {U}
2525
res = similar(T, dims)
2626
fill!(res, zero(U))
27-
gpu_call(uniformscaling_kernel, res, size(res, 1), s; total_threads=minimum(dims))
27+
gpu_call(identity_kernel, res, size(res, 1), s.λ; total_threads=minimum(dims))
2828
res
2929
end
3030

@@ -34,16 +34,22 @@ end
3434

3535
function Base.copyto!(A::AbstractGPUMatrix{T}, s::UniformScaling) where T
3636
fill!(A, zero(T))
37-
gpu_call(uniformscaling_kernel, A, size(A, 1), s; total_threads=minimum(size(A)))
37+
gpu_call(identity_kernel, A, size(A, 1), s.λ; total_threads=minimum(size(A)))
3838
A
3939
end
4040

41-
function Base.one(x::AbstractGPUMatrix)
42-
size(x,1)==size(x,2) ||
43-
throw(DimensionMismatch("multiplicative identity defined only for square matrices"))
44-
typeof(x)(I, size(x))
41+
function _one(unit::T, x::AbstractGPUMatrix) where {T}
42+
m,n = size(x)
43+
m==n || throw(DimensionMismatch("multiplicative identity defined only for square matrices"))
44+
I = similar(x, T)
45+
fill!(I, zero(T))
46+
gpu_call(identity_kernel, I, m, unit; total_threads=m)
47+
I
4548
end
4649

50+
Base.one(x::AbstractGPUMatrix{T}) where {T} = _one(one(T), x)
51+
Base.oneunit(x::AbstractGPUMatrix{T}) where {T} = _one(oneunit(T), x)
52+
4753

4854
## collect & convert
4955

test/testsuite/construction.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,10 @@ end
121121
A = one(AT(rand(T, 2, 2)))
122122
@test A isa AT{T,2}
123123
@test Array(A) == one(rand(T, 2, 2))
124+
125+
A = oneunit(AT(rand(T, 2, 2)))
126+
@test A isa AT{T,2}
127+
@test Array(A) == oneunit(rand(T, 2, 2))
124128
end
125129
end
126130

@@ -154,7 +158,6 @@ end
154158

155159
@testsuite "construct/uniformscaling" AT->begin
156160
for T in supported_eltypes()
157-
158161
x = Matrix{T}(I, 4, 2)
159162

160163
x1 = AT{T, 2}(I, 4, 2)

0 commit comments

Comments
 (0)