@@ -11,20 +11,20 @@ function Base.fill!(A::AnyGPUArray{T}, x) where T
11
11
end
12
12
13
13
14
- # # uniform scaling
14
+ # # identity matrices
15
15
16
- function uniformscaling_kernel (ctx:: AbstractKernelContext , res:: AbstractArray{T} , stride, s :: UniformScaling ) where T
16
+ function identity_kernel (ctx:: AbstractKernelContext , res:: AbstractArray{T} , stride, val ) where T
17
17
i = linear_index (ctx)
18
18
i > stride && return
19
19
ilin = (stride * (i - 1 )) + i
20
- @inbounds res[ilin] = s . λ
20
+ @inbounds res[ilin] = val
21
21
return
22
22
end
23
23
24
24
function (T:: Type{<: AnyGPUArray{U}} )(s:: UniformScaling , dims:: Dims{2} ) where {U}
25
25
res = similar (T, dims)
26
26
fill! (res, zero (U))
27
- gpu_call (uniformscaling_kernel , res, size (res, 1 ), s; total_threads= minimum (dims))
27
+ gpu_call (identity_kernel , res, size (res, 1 ), s. λ ; total_threads= minimum (dims))
28
28
res
29
29
end
30
30
34
34
35
35
function Base. copyto! (A:: AbstractGPUMatrix{T} , s:: UniformScaling ) where T
36
36
fill! (A, zero (T))
37
- gpu_call (uniformscaling_kernel , A, size (A, 1 ), s; total_threads= minimum (size (A)))
37
+ gpu_call (identity_kernel , A, size (A, 1 ), s. λ ; total_threads= minimum (size (A)))
38
38
A
39
39
end
40
40
41
- function Base. one (x:: AbstractGPUMatrix )
42
- size (x,1 )== size (x,2 ) ||
43
- throw (DimensionMismatch (" multiplicative identity defined only for square matrices" ))
44
- typeof (x)(I, size (x))
41
+ function _one (unit:: T , x:: AbstractGPUMatrix ) where {T}
42
+ m,n = size (x)
43
+ m== n || throw (DimensionMismatch (" multiplicative identity defined only for square matrices" ))
44
+ I = similar (x, T)
45
+ fill! (I, zero (T))
46
+ gpu_call (identity_kernel, I, m, unit; total_threads= m)
47
+ I
45
48
end
46
49
50
+ Base. one (x:: AbstractGPUMatrix{T} ) where {T} = _one (one (T), x)
51
+ Base. oneunit (x:: AbstractGPUMatrix{T} ) where {T} = _one (oneunit (T), x)
52
+
47
53
48
54
# # collect & convert
49
55
0 commit comments