Skip to content

Commit a988d69

Browse files
committed
Implement p-norm.
1 parent 4de9fbb commit a988d69

File tree

4 files changed

+27
-8
lines changed

4 files changed

+27
-8
lines changed

Project.toml

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,3 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
1414
AbstractFFTs = "0.4, 0.5, 1.0"
1515
Adapt = "2.0, 3.0"
1616
julia = "1.5"
17-
18-
[extras]
19-
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
20-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
21-
22-
[targets]
23-
test = ["Test", "FillArrays"]

src/host/linalg.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,3 +211,16 @@ end
211211
# TODO: implementation without the memory copy
212212
LinearAlgebra.permutedims!(dest::AbstractGPUArray, src::AbstractGPUArray, perm) =
213213
permutedims!(dest, src, Tuple(perm))
214+
215+
216+
## norm
217+
218+
function LinearAlgebra.norm(v::AbstractGPUArray{T}, p::Real=2) where {T}
219+
if p == Inf
220+
maximum(abs.(v))
221+
elseif p == -Inf
222+
minimum(abs.(v))
223+
else
224+
mapreduce(x->abs(x)^p, +, v; init=zero(T))^(1/p)
225+
end
226+
end

test/Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[deps]
2+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
4+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
6+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/testsuite/linalg.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testsuite "linear algebra" AT->begin
2-
@testset "adjoint and transpose" begin
2+
@testset "adjoint and trspose" begin
33
@test compare(adjoint, AT, rand(Float32, 32, 32))
44
@test compare(adjoint!, AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
55
@test compare(transpose, AT, rand(Float32, 32, 32))
@@ -114,4 +114,11 @@
114114
@test compare(rmul!, AT, rand(T, a), Ref(rand(T)))
115115
@test compare(lmul!, AT, Ref(rand(T)), rand(T, b))
116116
end
117+
118+
@testset "$p-norm($sz x $T)" for sz in [(2,), (2,2), (2,2,2)],
119+
p in Any[1, 2, 3, Inf, -Inf],
120+
T in supported_eltypes()
121+
range = T <: Integer ? (T(1):T(10)) : T # prevent integer overflow
122+
@test compare(norm, AT, rand(range, sz), Ref(p))
123+
end
117124
end

0 commit comments

Comments
 (0)