Skip to content

Commit 08e8f3e

Browse files
authored
Classical orthogonal polynomials v0.3 (#27)
* ClassicalOrthogonalPolynomials v0.3 * REORG * Update Project.toml * add getindex for tests * increase coverage
1 parent 68e6281 commit 08e8f3e

File tree

6 files changed

+209
-159
lines changed

6 files changed

+209
-159
lines changed

Project.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "HarmonicOrthogonalPolynomials"
22
uuid = "e416a80e-9640-42f3-8df8-80a93ca01ea5"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "0.0.2"
4+
version = "0.0.3"
55

66
[deps]
77
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
@@ -18,13 +18,13 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1818
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1919

2020
[compat]
21-
BlockArrays = "0.14"
21+
BlockArrays = "0.14, 0.15"
2222
BlockBandedMatrices = "0.10"
23-
ClassicalOrthogonalPolynomials = "0.1, 0.2"
24-
ContinuumArrays = "0.5, 0.6"
23+
ClassicalOrthogonalPolynomials = "0.2, 0.3"
24+
ContinuumArrays = "0.6"
2525
DomainSets = "0.4"
2626
FastTransforms = "0.11, 0.12"
27-
InfiniteArrays = "0.9, 0.10"
27+
InfiniteArrays = "0.10"
2828
IntervalSets = "0.5"
2929
QuasiArrays = "0.4"
3030
SpecialFunctions = "0.10, 1"

src/HarmonicOrthogonalPolynomials.jl

Lines changed: 6 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -12,150 +12,11 @@ import BlockBandedMatrices: BlockRange1
1212
import FastTransforms: Plan, interlace
1313
import QuasiArrays: LazyQuasiMatrix, LazyQuasiArrayStyle
1414

15-
export SphericalHarmonic, UnitSphere, SphericalCoordinate, Block, associatedlegendre, RealSphericalHarmonic, sphericalharmonicy, Laplacian
15+
export SphericalHarmonic, UnitSphere, SphericalCoordinate, RadialCoordinate, Block, associatedlegendre, RealSphericalHarmonic, sphericalharmonicy, Laplacian
1616

1717
include("multivariateops.jl")
18-
19-
20-
###
21-
# SphereTrav
22-
###
23-
24-
25-
"""
26-
SphereTrav(A::AbstractMatrix)
27-
28-
is an anlogue of `DiagTrav` but for coefficients stored according to
29-
FastTransforms.jl spherical harmonics layout
30-
"""
31-
struct SphereTrav{T, AA<:AbstractMatrix{T}} <: AbstractBlockVector{T}
32-
matrix::AA
33-
function SphereTrav{T, AA}(matrix::AA) where {T,AA<:AbstractMatrix{T}}
34-
n,m = size(matrix)
35-
m == 2n-1 || throw(ArgumentError("size must match"))
36-
new{T,AA}(matrix)
37-
end
38-
end
39-
40-
SphereTrav{T}(matrix::AbstractMatrix{T}) where T = SphereTrav{T,typeof(matrix)}(matrix)
41-
SphereTrav(matrix::AbstractMatrix{T}) where T = SphereTrav{T}(matrix)
42-
43-
axes(A::SphereTrav) = (blockedrange(range(1; step=2, length=size(A.matrix,1))),)
44-
45-
function getindex(A::SphereTrav, K::Block{1})
46-
k = Int(K)
47-
m = size(A.matrix,1)
48-
st = stride(A.matrix,2)
49-
# nonnegative terms
50-
p = A.matrix[range(k; step=2*st-1, length=k)]
51-
k == 1 && return p
52-
# negative terms
53-
n = A.matrix[range(k+st-1; step=2*st-1, length=k-1)]
54-
[reverse!(n); p]
55-
end
56-
57-
getindex(A::SphereTrav, k::Int) = A[findblockindex(axes(A,1), k)]
58-
59-
"""
60-
RealSphereTrav(A::AbstractMatrix)
61-
62-
takes coefficients as provided by the spherical harmonics layout of FastTransforms.jl and
63-
makes them accessible sorted such that in each block the m=0 entries are always in first place,
64-
followed by alternating sin and cos terms of increasing |m|.
65-
"""
66-
struct RealSphereTrav{T, AA<:AbstractMatrix{T}} <: AbstractBlockVector{T}
67-
matrix::AA
68-
function RealSphereTrav{T, AA}(matrix::AA) where {T,AA<:AbstractMatrix{T}}
69-
n,m = size(matrix)
70-
m == 2n-1 || throw(ArgumentError("size must match"))
71-
new{T,AA}(matrix)
72-
end
73-
end
74-
75-
RealSphereTrav{T}(matrix::AbstractMatrix{T}) where T = RealSphereTrav{T,typeof(matrix)}(matrix)
76-
RealSphereTrav(matrix::AbstractMatrix{T}) where T = RealSphereTrav{T}(matrix)
77-
78-
axes(A::RealSphereTrav) = (blockedrange(range(1; step=2, length=size(A.matrix,1))),)
79-
80-
function getindex(A::RealSphereTrav, K::Block{1})
81-
k = Int(K)
82-
m = size(A.matrix,1)
83-
st = stride(A.matrix,2)
84-
# nonnegative terms
85-
p = A.matrix[range(k; step=2*st-1, length=k)]
86-
k == 1 && return p
87-
# negative terms
88-
n = A.matrix[range(k+st-1; step=2*st-1, length=k-1)]
89-
interlace(p,n)
90-
end
91-
92-
getindex(A::RealSphereTrav, k::Int) = A[findblockindex(axes(A,1), k)]
93-
94-
###
95-
# SphericalCoordinate
96-
###
97-
98-
abstract type AbstractSphericalCoordinate{T} <: StaticVector{3,T} end
99-
norm(::AbstractSphericalCoordinate{T}) where T = real(one(T))
100-
Base.in(::AbstractSphericalCoordinate, ::UnitSphere{T}) where T = true
101-
"""
102-
SphericalCoordinate(θ, φ)
103-
104-
represents a point in the unit sphere as a `StaticVector{3}` in
105-
spherical coordinates where the pole is `SphericalCoordinate(0,φ) == SVector(0,0,1)`
106-
and `SphericalCoordinate(π/2,0) == SVector(1,0,0)`.
107-
"""
108-
struct SphericalCoordinate{T} <: AbstractSphericalCoordinate{T}
109-
θ::T
110-
φ::T
111-
end
112-
113-
SphericalCoordinate(θ, φ) = SphericalCoordinate(promote(θ, φ)...)
114-
115-
"""
116-
ZSphericalCoordinate(φ, z)
117-
118-
represents a point in the unit sphere as a `StaticVector{3}` in
119-
where `z` is specified while the angle coordinate is given by spherical coordinates where the pole is `SVector(0,0,1)`.
120-
"""
121-
struct ZSphericalCoordinate{T} <: AbstractSphericalCoordinate{T}
122-
φ::T
123-
z::T
124-
function ZSphericalCoordinate{T}::T, z::T) where T
125-
-1  z  1 || throw(ArgumentError("z must be between -1 and 1"))
126-
new{T}(φ, z)
127-
end
128-
end
129-
ZSphericalCoordinate::T, z::V) where {T,V} = ZSphericalCoordinate{promote_type(T,V)}(φ,z)
130-
ZSphericalCoordinate(S::SphericalCoordinate) = ZSphericalCoordinate(S.φ, cos(S.θ))
131-
ZSphericalCoordinate{T}(S::SphericalCoordinate) where T = ZSphericalCoordinate{T}(S.φ, cos(S.θ))
132-
133-
SphericalCoordinate(S::ZSphericalCoordinate) = SphericalCoordinate(acos(S.z), S.φ)
134-
SphericalCoordinate{T}(S::ZSphericalCoordinate) where T = SphericalCoordinate{T}(acos(S.z), S.φ)
135-
136-
137-
function getindex(S::SphericalCoordinate, k::Int)
138-
k == 1 && return sin(S.θ) * cos(S.φ)
139-
k == 2 && return sin(S.θ) * sin(S.φ)
140-
k == 3 && return cos(S.θ)
141-
throw(BoundsError(S, k))
142-
end
143-
function getindex(S::ZSphericalCoordinate, k::Int)
144-
k == 1 && return sqrt(1-S.z^2) * cos(S.φ)
145-
k == 2 && return sqrt(1-S.z^2) * sin(S.φ)
146-
k == 3 && return S.z
147-
throw(BoundsError(S, k))
148-
end
149-
150-
convert(::Type{SVector{3,T}}, S::SphericalCoordinate) where T = SVector{3,T}(sin(S.θ)*cos(S.φ), sin(S.θ)*sin(S.φ), cos(S.θ))
151-
convert(::Type{SVector{3,T}}, S::ZSphericalCoordinate) where T = SVector{3,T}(sqrt(1-S.z^2)*cos(S.φ), sqrt(1-S.z^2)*sin(S.φ), S.z)
152-
convert(::Type{SVector{3}}, S::SphericalCoordinate) = SVector(sin(S.θ)*cos(S.φ), sin(S.θ)*sin(S.φ), cos(S.θ))
153-
convert(::Type{SVector{3}}, S::ZSphericalCoordinate) = SVector(sqrt(1-S.z^2)*cos(S.φ), sqrt(1-S.z^2)*sin(S.φ), S.z)
154-
155-
convert(::Type{SphericalCoordinate}, S::ZSphericalCoordinate) = SphericalCoordinate(S)
156-
convert(::Type{SphericalCoordinate{T}}, S::ZSphericalCoordinate) where T = SphericalCoordinate{T}(S)
157-
convert(::Type{ZSphericalCoordinate}, S::SphericalCoordinate) = ZSphericalCoordinate(S)
158-
convert(::Type{ZSphericalCoordinate{T}}, S::SphericalCoordinate) where T = ZSphericalCoordinate{T}(S)
18+
include("spheretrav.jl")
19+
include("coordinates.jl")
15920

16021

16122
checkpoints(::UnitSphere{T}) where T = [SphericalCoordinate{T}(0.1,0.2), SphericalCoordinate{T}(0.3,0.4)]
@@ -217,7 +78,10 @@ function getindex(S::RealSphericalHarmonic{T}, x::SphericalCoordinate, K::BlockI
21778
end
21879

21980
getindex(S::AbstractSphericalHarmonic, x::StaticVector{3}, K::BlockIndex{1}) = S[SphericalCoordinate(x), K]
81+
getindex(S::AbstractSphericalHarmonic, x::StaticVector{3}, K::Block{1}) = S[x, axes(S,2)[K]]
82+
getindex(S::AbstractSphericalHarmonic, x::StaticVector{3}, KR::BlockOneTo) = mortar([S[x, K] for K in KR])
22083
getindex(S::AbstractSphericalHarmonic, x::StaticVector{3}, k::Int) = S[x, findblockindex(axes(S,2), k)]
84+
getindex(S::AbstractSphericalHarmonic, x::StaticVector{3}, kr::AbstractUnitRange{Int}) = [S[x, k] for k in kr]
22185

22286
# @simplify *(Ac::QuasiAdjoint{<:Any,<:SphericalHarmonic}, B::SphericalHarmonic) =
22387

src/coordinates.jl

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
###
2+
# RadialCoordinate
3+
####
4+
5+
"""
6+
RadialCoordinate(r, θ)
7+
8+
represents the 2-vector [r*cos(θ),r*sin(θ)]
9+
"""
10+
struct RadialCoordinate{T} <: StaticVector{2,T}
11+
r::T
12+
θ::T
13+
RadialCoordinate{T}(r::T, θ::T) where T = new{T}(r, θ)
14+
end
15+
16+
RadialCoordinate{T}(r, θ) where T = RadialCoordinate{T}(convert(T,r), convert(T,θ))
17+
RadialCoordinate(r::T, θ::V) where {T<:Real,V<:Real} = RadialCoordinate{float(promote_type(T,V))}(r, θ)
18+
19+
function RadialCoordinate(xy::StaticVector{2})
20+
x,y = xy
21+
RadialCoordinate(norm(xy), atan(y,x))
22+
end
23+
24+
StaticArrays.SVector(rθ::RadialCoordinate) = SVector(rθ.r * cos(rθ.θ), rθ.r * sin(rθ.θ))
25+
getindex(R::RadialCoordinate, k::Int) = SVector(R)[k]
26+
27+
###
28+
# SphericalCoordinate
29+
###
30+
31+
abstract type AbstractSphericalCoordinate{T} <: StaticVector{3,T} end
32+
norm(::AbstractSphericalCoordinate{T}) where T = real(one(T))
33+
Base.in(::AbstractSphericalCoordinate, ::UnitSphere{T}) where T = true
34+
"""
35+
SphericalCoordinate(θ, φ)
36+
37+
represents a point in the unit sphere as a `StaticVector{3}` in
38+
spherical coordinates where the pole is `SphericalCoordinate(0,φ) == SVector(0,0,1)`
39+
and `SphericalCoordinate(π/2,0) == SVector(1,0,0)`.
40+
"""
41+
struct SphericalCoordinate{T} <: AbstractSphericalCoordinate{T}
42+
θ::T
43+
φ::T
44+
SphericalCoordinate{T}::T, φ::T) where T = new{T}(θ, φ)
45+
end
46+
47+
SphericalCoordinate{T}(θ, φ) where T = SphericalCoordinate{T}(convert(T,θ), convert(T,φ))
48+
SphericalCoordinate::V, φ::T) where {T<:Real,V<:Real} = SphericalCoordinate{float(promote_type(T,V))}(θ, φ)
49+
SphericalCoordinate(S::SphericalCoordinate) = S
50+
51+
"""
52+
ZSphericalCoordinate(φ, z)
53+
54+
represents a point in the unit sphere as a `StaticVector{3}` in
55+
where `z` is specified while the angle coordinate is given by spherical coordinates where the pole is `SVector(0,0,1)`.
56+
"""
57+
struct ZSphericalCoordinate{T} <: AbstractSphericalCoordinate{T}
58+
φ::T
59+
z::T
60+
function ZSphericalCoordinate{T}::T, z::T) where T
61+
-1  z  1 || throw(ArgumentError("z must be between -1 and 1"))
62+
new{T}(φ, z)
63+
end
64+
end
65+
ZSphericalCoordinate::T, z::V) where {T,V} = ZSphericalCoordinate{promote_type(T,V)}(φ,z)
66+
ZSphericalCoordinate(S::SphericalCoordinate) = ZSphericalCoordinate(S.φ, cos(S.θ))
67+
ZSphericalCoordinate{T}(S::SphericalCoordinate) where T = ZSphericalCoordinate{T}(S.φ, cos(S.θ))
68+
69+
SphericalCoordinate(S::ZSphericalCoordinate) = SphericalCoordinate(acos(S.z), S.φ)
70+
SphericalCoordinate{T}(S::ZSphericalCoordinate) where T = SphericalCoordinate{T}(acos(S.z), S.φ)
71+
72+
73+
function getindex(S::SphericalCoordinate, k::Int)
74+
k == 1 && return sin(S.θ) * cos(S.φ)
75+
k == 2 && return sin(S.θ) * sin(S.φ)
76+
k == 3 && return cos(S.θ)
77+
throw(BoundsError(S, k))
78+
end
79+
function getindex(S::ZSphericalCoordinate, k::Int)
80+
k == 1 && return sqrt(1-S.z^2) * cos(S.φ)
81+
k == 2 && return sqrt(1-S.z^2) * sin(S.φ)
82+
k == 3 && return S.z
83+
throw(BoundsError(S, k))
84+
end
85+
86+
convert(::Type{SVector{3,T}}, S::SphericalCoordinate) where T = SVector{3,T}(sin(S.θ)*cos(S.φ), sin(S.θ)*sin(S.φ), cos(S.θ))
87+
convert(::Type{SVector{3,T}}, S::ZSphericalCoordinate) where T = SVector{3,T}(sqrt(1-S.z^2)*cos(S.φ), sqrt(1-S.z^2)*sin(S.φ), S.z)
88+
convert(::Type{SVector{3}}, S::SphericalCoordinate) = SVector(sin(S.θ)*cos(S.φ), sin(S.θ)*sin(S.φ), cos(S.θ))
89+
convert(::Type{SVector{3}}, S::ZSphericalCoordinate) = SVector(sqrt(1-S.z^2)*cos(S.φ), sqrt(1-S.z^2)*sin(S.φ), S.z)
90+
91+
convert(::Type{SphericalCoordinate}, S::ZSphericalCoordinate) = SphericalCoordinate(S)
92+
convert(::Type{SphericalCoordinate{T}}, S::ZSphericalCoordinate) where T = SphericalCoordinate{T}(S)
93+
convert(::Type{ZSphericalCoordinate}, S::SphericalCoordinate) = ZSphericalCoordinate(S)
94+
convert(::Type{ZSphericalCoordinate{T}}, S::SphericalCoordinate) where T = ZSphericalCoordinate{T}(S)

src/multivariateops.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,12 @@ const BivariateOrthogonalPolynomial{T} = MultivariateOrthogonalPolynomial{2,T}
2424
const BlockOneTo = BlockRange{1,Tuple{OneTo{Int}}}
2525

2626

27-
getindex(P::MultivariateOrthogonalPolynomial{<:Any,D}, xy::SVector{D}, JR::BlockOneTo) where D =
28-
error("Overload")
29-
getindex(P::MultivariateOrthogonalPolynomial{D}, xy::SVector{D}, J::Block{1}) where D = P[xy, Block.(OneTo(Int(J)))][J]
30-
getindex(P::MultivariateOrthogonalPolynomial{D}, xy::SVector{D}, JR::BlockRange{1}) where D = P[xy, Block.(OneTo(Int(maximum(JR))))][JR]
31-
getindex(P::MultivariateOrthogonalPolynomial{D}, xy::SVector{D}, Jj::BlockIndex{1}) where D = P[xy, block(Jj)][blockindex(Jj)]
32-
getindex(P::MultivariateOrthogonalPolynomial{D}, xy::SVector{D}, j::Integer) where D = P[xy, findblockindex(axes(P,2), j)]
33-
getindex(P::MultivariateOrthogonalPolynomial{D}, xy::SVector{D}, jr::AbstractVector) where D = P[xy, Block.(OneTo(Int(findblock(axes(P,2), maximum(jr)))))][jr]
27+
getindex(P::MultivariateOrthogonalPolynomial{D}, xy::StaticVector{D}, JR::BlockOneTo) where D = error("Overload")
28+
getindex(P::MultivariateOrthogonalPolynomial{D}, xy::StaticVector{D}, J::Block{1}) where D = P[xy, Block.(OneTo(Int(J)))][J]
29+
getindex(P::MultivariateOrthogonalPolynomial{D}, xy::StaticVector{D}, JR::BlockRange{1}) where D = P[xy, Block.(OneTo(Int(maximum(JR))))][JR]
30+
getindex(P::MultivariateOrthogonalPolynomial{D}, xy::StaticVector{D}, Jj::BlockIndex{1}) where D = P[xy, block(Jj)][blockindex(Jj)]
31+
getindex(P::MultivariateOrthogonalPolynomial{D}, xy::StaticVector{D}, j::Integer) where D = P[xy, findblockindex(axes(P,2), j)]
32+
getindex(P::MultivariateOrthogonalPolynomial{D}, xy::StaticVector{D}, jr::AbstractVector{<:Integer}) where D = P[xy, Block.(OneTo(Int(findblock(axes(P,2), maximum(jr)))))][jr]
3433

3534
const FirstInclusion = BroadcastQuasiVector{<:Any, typeof(first), <:Tuple{Inclusion}}
3635
const LastInclusion = BroadcastQuasiVector{<:Any, typeof(last), <:Tuple{Inclusion}}

src/spheretrav.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
2+
###
3+
# SphereTrav
4+
###
5+
6+
7+
"""
8+
SphereTrav(A::AbstractMatrix)
9+
10+
is an anlogue of `DiagTrav` but for coefficients stored according to
11+
FastTransforms.jl spherical harmonics layout
12+
"""
13+
struct SphereTrav{T, AA<:AbstractMatrix{T}} <: AbstractBlockVector{T}
14+
matrix::AA
15+
function SphereTrav{T, AA}(matrix::AA) where {T,AA<:AbstractMatrix{T}}
16+
n,m = size(matrix)
17+
m == 2n-1 || throw(ArgumentError("size must match"))
18+
new{T,AA}(matrix)
19+
end
20+
end
21+
22+
SphereTrav{T}(matrix::AbstractMatrix{T}) where T = SphereTrav{T,typeof(matrix)}(matrix)
23+
SphereTrav(matrix::AbstractMatrix{T}) where T = SphereTrav{T}(matrix)
24+
25+
axes(A::SphereTrav) = (blockedrange(range(1; step=2, length=size(A.matrix,1))),)
26+
27+
function getindex(A::SphereTrav, K::Block{1})
28+
k = Int(K)
29+
m = size(A.matrix,1)
30+
st = stride(A.matrix,2)
31+
# nonnegative terms
32+
p = A.matrix[range(k; step=2*st-1, length=k)]
33+
k == 1 && return p
34+
# negative terms
35+
n = A.matrix[range(k+st-1; step=2*st-1, length=k-1)]
36+
[reverse!(n); p]
37+
end
38+
39+
getindex(A::SphereTrav, k::Int) = A[findblockindex(axes(A,1), k)]
40+
41+
"""
42+
RealSphereTrav(A::AbstractMatrix)
43+
44+
takes coefficients as provided by the spherical harmonics layout of FastTransforms.jl and
45+
makes them accessible sorted such that in each block the m=0 entries are always in first place,
46+
followed by alternating sin and cos terms of increasing |m|.
47+
"""
48+
struct RealSphereTrav{T, AA<:AbstractMatrix{T}} <: AbstractBlockVector{T}
49+
matrix::AA
50+
function RealSphereTrav{T, AA}(matrix::AA) where {T,AA<:AbstractMatrix{T}}
51+
n,m = size(matrix)
52+
m == 2n-1 || throw(ArgumentError("size must match"))
53+
new{T,AA}(matrix)
54+
end
55+
end
56+
57+
RealSphereTrav{T}(matrix::AbstractMatrix{T}) where T = RealSphereTrav{T,typeof(matrix)}(matrix)
58+
RealSphereTrav(matrix::AbstractMatrix{T}) where T = RealSphereTrav{T}(matrix)
59+
60+
axes(A::RealSphereTrav) = (blockedrange(range(1; step=2, length=size(A.matrix,1))),)
61+
62+
function getindex(A::RealSphereTrav, K::Block{1})
63+
k = Int(K)
64+
m = size(A.matrix,1)
65+
st = stride(A.matrix,2)
66+
# nonnegative terms
67+
p = A.matrix[range(k; step=2*st-1, length=k)]
68+
k == 1 && return p
69+
# negative terms
70+
n = A.matrix[range(k+st-1; step=2*st-1, length=k-1)]
71+
interlace(p,n)
72+
end
73+
74+
getindex(A::RealSphereTrav, k::Int) = A[findblockindex(axes(A,1), k)]

0 commit comments

Comments
 (0)