Skip to content

Commit 5172041

Browse files
authored
Merge pull request #358 from JuliaGPU/tb/one
Add and test multiplicative identity based on uniform scaling.
2 parents 5c7e76a + 14990d0 commit 5172041

File tree

3 files changed, with 152 additions and 95 deletions.

src/host/construction.jl

Lines changed: 23 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# constructors and conversions
1+
# convenience and indirect construction
22

33
function Base.fill!(A::AnyGPUArray{T}, x) where T
44
length(A) == 0 && return A
@@ -10,18 +10,21 @@ function Base.fill!(A::AnyGPUArray{T}, x) where T
1010
A
1111
end
1212

13-
function uniformscaling_kernel(ctx::AbstractKernelContext, res::AbstractArray{T}, stride, s::UniformScaling) where T
13+
14+
## identity matrices
15+
16+
function identity_kernel(ctx::AbstractKernelContext, res::AbstractArray{T}, stride, val) where T
1417
i = linear_index(ctx)
1518
i > stride && return
1619
ilin = (stride * (i - 1)) + i
17-
@inbounds res[ilin] = s.λ
20+
@inbounds res[ilin] = val
1821
return
1922
end
2023

2124
function (T::Type{<: AnyGPUArray{U}})(s::UniformScaling, dims::Dims{2}) where {U}
2225
res = similar(T, dims)
2326
fill!(res, zero(U))
24-
gpu_call(uniformscaling_kernel, res, size(res, 1), s; total_threads=minimum(dims))
27+
gpu_call(identity_kernel, res, size(res, 1), s.λ; total_threads=minimum(dims))
2528
res
2629
end
2730

@@ -31,10 +34,25 @@ end
3134

3235
function Base.copyto!(A::AbstractGPUMatrix{T}, s::UniformScaling) where T
3336
fill!(A, zero(T))
34-
gpu_call(uniformscaling_kernel, A, size(A, 1), s; total_threads=minimum(size(A)))
37+
gpu_call(identity_kernel, A, size(A, 1), s.λ; total_threads=minimum(size(A)))
3538
A
3639
end
3740

41+
function _one(unit::T, x::AbstractGPUMatrix) where {T}
42+
m,n = size(x)
43+
m==n || throw(DimensionMismatch("multiplicative identity defined only for square matrices"))
44+
I = similar(x, T)
45+
fill!(I, zero(T))
46+
gpu_call(identity_kernel, I, m, unit; total_threads=m)
47+
I
48+
end
49+
50+
Base.one(x::AbstractGPUMatrix{T}) where {T} = _one(one(T), x)
51+
Base.oneunit(x::AbstractGPUMatrix{T}) where {T} = _one(oneunit(T), x)
52+
53+
54+
## collect & convert
55+
3856
function indexstyle(x::T) where T
3957
style = try
4058
Base.IndexStyle(x)

test/testsuite/construction.jl

Lines changed: 120 additions & 89 deletions
Original file line number | Diff line number | Diff line change
@@ -1,101 +1,134 @@
1-
@testsuite "constructors" AT->begin
2-
@testset "constructors + similar" begin
3-
for T in supported_eltypes()
4-
B = AT{T}(undef, 10)
5-
@test B isa AT{T,1}
6-
@test size(B) == (10,)
7-
@test eltype(B) == T
8-
9-
B = AT{T}(undef, 10, 10)
10-
@test B isa AT{T,2}
11-
@test size(B) == (10, 10)
12-
@test eltype(B) == T
13-
14-
B = AT{T}(undef, (10, 10))
15-
@test B isa AT{T,2}
16-
@test size(B) == (10, 10)
17-
@test eltype(B) == T
18-
19-
B = similar(B, Int32, 11, 15)
20-
@test B isa AT{Int32,2}
21-
@test size(B) == (11, 15)
22-
@test eltype(B) == Int32
1+
@testsuite "construct/direct" AT->begin
2+
for T in supported_eltypes()
3+
B = AT{T}(undef, 10)
4+
@test B isa AT{T,1}
5+
@test size(B) == (10,)
6+
@test eltype(B) == T
237

24-
B = similar(B, T)
25-
@test B isa AT{T,2}
26-
@test size(B) == (11, 15)
27-
@test eltype(B) == T
8+
B = AT{T}(undef, 10, 10)
9+
@test B isa AT{T,2}
10+
@test size(B) == (10, 10)
11+
@test eltype(B) == T
12+
13+
B = AT{T}(undef, (10, 10))
14+
@test B isa AT{T,2}
15+
@test size(B) == (10, 10)
16+
@test eltype(B) == T
17+
end
18+
19+
# compare against Array
20+
for typs in [(), (Int,), (Int,1), (Int,2), (Float32,), (Float32,1), (Float32,2)],
21+
args in [(), (1,), (1,2), ((1,),), ((1,2),),
22+
(undef,), (undef, 1,), (undef, 1,2), (undef, (1,),), (undef, (1,2),),
23+
(Int,), (Int, 1,), (Int, 1,2), (Int, (1,),), (Int, (1,2),),
24+
([1,2],), ([1 2],)]
25+
cpu = try
26+
Array{typs...}(args...)
27+
catch ex
28+
isa(ex, MethodError) || rethrow()
29+
nothing
30+
end
31+
32+
gpu = try
33+
AT{typs...}(args...)
34+
catch ex
35+
isa(ex, MethodError) || rethrow()
36+
cpu == nothing || rethrow()
37+
nothing
38+
end
39+
40+
if cpu == nothing
41+
@test gpu == nothing
42+
else
43+
@test typeof(cpu) == typeof(convert(Array, gpu))
44+
end
45+
end
46+
end
47+
48+
@testsuite "construct/similar" AT->begin
49+
for T in supported_eltypes()
50+
B = AT{T}(undef, 10)
51+
52+
B = similar(B, Int32, 11, 15)
53+
@test B isa AT{Int32,2}
54+
@test size(B) == (11, 15)
55+
@test eltype(B) == Int32
56+
57+
B = similar(B, T)
58+
@test B isa AT{T,2}
59+
@test size(B) == (11, 15)
60+
@test eltype(B) == T
61+
62+
B = similar(B, (5,))
63+
@test B isa AT{T,1}
64+
@test size(B) == (5,)
65+
@test eltype(B) == T
66+
67+
B = similar(B, 7)
68+
@test B isa AT{T,1}
69+
@test size(B) == (7,)
70+
@test eltype(B) == T
2871

29-
B = similar(B, (5,))
30-
@test B isa AT{T,1}
31-
@test size(B) == (5,)
32-
@test eltype(B) == T
72+
B = similar(AT{Int32}, (11, 15))
73+
@test B isa AT{Int32,2}
74+
@test size(B) == (11, 15)
75+
@test eltype(B) == Int32
3376

34-
B = similar(B, 7)
35-
@test B isa AT{T,1}
36-
@test size(B) == (7,)
37-
@test eltype(B) == T
77+
B = similar(AT{T}, (5,))
78+
@test B isa AT{T,1}
79+
@test size(B) == (5,)
80+
@test eltype(B) == T
81+
82+
B = similar(AT{T}, 7)
83+
@test B isa AT{T,1}
84+
@test size(B) == (7,)
85+
@test eltype(B) == T
86+
87+
B = similar(Broadcast.Broadcasted(*, (B, B)), T)
88+
@test B isa AT{T,1}
89+
@test size(B) == (7,)
90+
@test eltype(B) == T
3891

39-
B = similar(AT{Int32}, (11, 15))
92+
if VERSION >= v"1.5"
93+
B = similar(Broadcast.Broadcasted(*, (B, B)), Int32, (11, 15))
4094
@test B isa AT{Int32,2}
4195
@test size(B) == (11, 15)
4296
@test eltype(B) == Int32
43-
44-
B = similar(AT{T}, (5,))
45-
@test B isa AT{T,1}
46-
@test size(B) == (5,)
47-
@test eltype(B) == T
48-
49-
B = similar(AT{T}, 7)
50-
@test B isa AT{T,1}
51-
@test size(B) == (7,)
52-
@test eltype(B) == T
53-
54-
B = similar(Broadcast.Broadcasted(*, (B, B)), T)
55-
@test B isa AT{T,1}
56-
@test size(B) == (7,)
57-
@test eltype(B) == T
58-
59-
if VERSION >= v"1.5"
60-
B = similar(Broadcast.Broadcasted(*, (B, B)), Int32, (11, 15))
61-
@test B isa AT{Int32,2}
62-
@test size(B) == (11, 15)
63-
@test eltype(B) == Int32
64-
end
6597
end
6698
end
99+
end
67100

68-
@testset "comparison against Array" begin
69-
for typs in [(), (Int,), (Int,1), (Int,2), (Float32,), (Float32,1), (Float32,2)],
70-
args in [(), (1,), (1,2), ((1,),), ((1,2),),
71-
(undef,), (undef, 1,), (undef, 1,2), (undef, (1,),), (undef, (1,2),),
72-
(Int,), (Int, 1,), (Int, 1,2), (Int, (1,),), (Int, (1,2),),
73-
([1,2],), ([1 2],)]
74-
cpu = try
75-
Array{typs...}(args...)
76-
catch ex
77-
isa(ex, MethodError) || rethrow()
78-
nothing
79-
end
80-
81-
gpu = try
82-
AT{typs...}(args...)
83-
catch ex
84-
isa(ex, MethodError) || rethrow()
85-
cpu == nothing || rethrow()
86-
nothing
87-
end
88-
89-
if cpu == nothing
90-
@test gpu == nothing
91-
else
92-
@test typeof(cpu) == typeof(convert(Array, gpu))
93-
end
94-
end
101+
@testsuite "construct/convenience" AT->begin
102+
for T in supported_eltypes()
103+
A = AT(rand(T, 3))
104+
b = rand(T)
105+
fill!(A, b)
106+
@test A isa AT{T,1}
107+
@test Array(A) == fill(b, 3)
108+
109+
A = zero(AT(rand(T, 2)))
110+
@test A isa AT{T,1}
111+
@test Array(A) == zero(rand(T, 2))
112+
113+
A = zero(AT(rand(T, 2, 2)))
114+
@test A isa AT{T,2}
115+
@test Array(A) == zero(rand(T, 2, 2))
116+
117+
A = zero(AT(rand(T, 2, 2, 2)))
118+
@test A isa AT{T,3}
119+
@test Array(A) == zero(rand(T, 2, 2, 2))
120+
121+
A = one(AT(rand(T, 2, 2)))
122+
@test A isa AT{T,2}
123+
@test Array(A) == one(rand(T, 2, 2))
124+
125+
A = oneunit(AT(rand(T, 2, 2)))
126+
@test A isa AT{T,2}
127+
@test Array(A) == oneunit(rand(T, 2, 2))
95128
end
96129
end
97130

98-
@testsuite "conversions" AT->begin
131+
@testsuite "construct/conversions" AT->begin
99132
for T in supported_eltypes()
100133
Bc = round.(rand(10, 10) .* 10.0)
101134
B = AT{T}(Bc)
@@ -123,10 +156,8 @@ end
123156
end
124157
end
125158

126-
@testsuite "value constructors" AT->begin
159+
@testsuite "construct/uniformscaling" AT->begin
127160
for T in supported_eltypes()
128-
@test compare((a,b)->fill!(a, b), AT, rand(T, 3), rand(T))
129-
130161
x = Matrix{T}(I, 4, 2)
131162

132163
x1 = AT{T, 2}(I, 4, 2)
@@ -144,7 +175,7 @@ end
144175
end
145176
end
146177

147-
@testsuite "iterator constructors" AT->begin
178+
@testsuite "construct/iterator" AT->begin
148179
for T in supported_eltypes()
149180
@test Array(AT(Fill(T(0), (10,)))) == Array(fill!(similar(AT{T}, (10,)), T(0)))
150181
@test Array(AT(Fill(T(0), (10, 10)))) == Array(fill!(similar(AT{T}, (10, 10)), T(0)))

test/testsuite/math.jl

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
@testsuite "math" AT->begin
1+
@testsuite "math/intrinsics" AT->begin
22
for ET in supported_eltypes()
33
# Skip complex numbers
44
ET in (Complex, ComplexF32, ComplexF64) && continue
@@ -16,3 +16,11 @@
1616
end
1717
end
1818
end
19+
20+
@testsuite "math/power" AT->begin
21+
for ET in supported_eltypes()
22+
for p in 0:5
23+
compare(x->x^p, AT, rand(ET, 2,2))
24+
end
25+
end
26+
end

0 commit comments

Comments (0)