Skip to content

Commit ff5c1a6

Browse files
committed
Implement p-norm.
1 parent 7f38f28 commit ff5c1a6

File tree

4 files changed

+27
-7
lines changed

4 files changed

+27
-7
lines changed

Project.toml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,3 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1515
AbstractFFTs = "0.4, 0.5, 1.0"
1616
Adapt = "2.0, 3.0"
1717
julia = "1.5"
18-
19-
[extras]
20-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
21-
22-
[targets]
23-
test = ["Test"]

src/host/linalg.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,3 +194,16 @@ end
194194
# TODO: implementation without the memory copy
195195
LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray, perm) =
196196
permutedims!(dest, src, Tuple(perm))
197+
198+
199+
## norm
200+
201+
function LinearAlgebra.norm(v::AbstractGPUArray{T}, p::Real=2) where {T}
202+
if p == Inf
203+
maximum(abs.(v))
204+
elseif p == -Inf
205+
minimum(abs.(v))
206+
else
207+
mapreduce(x->abs(x)^p, +, v; init=zero(T))^(1/p)
208+
end
209+
end

test/Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[deps]
2+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
4+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
5+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
6+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/testsuite/linalg.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testsuite "linalg" AT->begin
2-
@testset "adjoint and transpose" begin
2+
@testset "adjoint and trspose" begin
33
@test compare(adjoint, AT, rand(Float32, 32, 32))
44
@test compare(adjoint!, AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
55
@test compare(transpose, AT, rand(Float32, 32, 32))
@@ -121,4 +121,11 @@ end
121121
@test compare(rmul!, AT, rand(T, a), Ref(rand(T)))
122122
@test compare(lmul!, AT, Ref(rand(T)), rand(T, b))
123123
end
124+
125+
@testset "$p-norm($sz x $T)" for sz in [(2,), (2,2), (2,2,2)],
126+
p in Any[1, 2, 3, Inf, -Inf],
127+
T in supported_eltypes()
128+
range = T <: Integer ? (T(1):T(10)) : T # prevent integer overflow
129+
@test compare(norm, AT, rand(range, sz), Ref(p))
130+
end
124131
end

0 commit comments

Comments
 (0)