Skip to content

Commit a3b8c34

Browse files
committed
Add getindex functionality
1 parent e05375f commit a3b8c34

File tree

4 files changed

+237
-0
lines changed

4 files changed

+237
-0
lines changed

src/LinearMaps.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ include("kronecker.jl") # Kronecker product of linear maps
262262
include("fillmap.jl") # linear maps representing constantly filled matrices
263263
include("conversion.jl") # conversion of linear maps to matrices
264264
include("show.jl") # show methods for LinearMap objects
265+
include("getindex.jl") # getindex functionality
265266

266267
"""
267268
LinearMap(A::LinearMap; kwargs...)::WrappedMap

src/getindex.jl

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
# module GetIndex
2+
3+
# using ..LinearMaps: LinearMap, AdjointMap, TransposeMap, FillMap, LinearCombination,
4+
# ScaledMap, UniformScalingMap, WrappedMap
5+
6+
# required in Base.to_indices for [:]-indexing
7+
Base.eachindex(::IndexLinear, A::LinearMap) = (Base.@_inline_meta; Base.OneTo(length(A)))
8+
# Base.IndexStyle(::LinearMap) = IndexCartesian()
9+
# Base.IndexStyle(A::Union{WrappedMap,AdjointMap,TransposeMap,ScaledMap}) = IndexStyle(A.lmap)
10+
11+
function Base.checkbounds(A::LinearMap, i, j)
12+
Base.@_inline_meta
13+
Base.checkbounds_indices(Bool, axes(A), (i, j)) || throw(BoundsError(A, (i, j)))
14+
nothing
15+
end
16+
# Linear indexing is explicitly allowed when there is only one (non-cartesian) index
17+
function Base.checkbounds(A::LinearMap, i)
18+
Base.@_inline_meta
19+
Base.checkindex(Bool, Base.OneTo(length(A)), i) || throw(BoundsError(A, i))
20+
nothing
21+
end
22+
23+
# dispatch hierarchy
24+
# Base.getindex (includes bounds checking)
25+
# -> Base._getindex (conversion of linear indices to cartesian indices)
26+
# -> _unsafe_getindex
27+
# main entry point
28+
Base.@propagate_inbounds function Base.getindex(A::LinearMap, I...)
29+
# TODO: introduce some sort of switch?
30+
Base.@_inline_meta
31+
@boundscheck checkbounds(A, I...)
32+
_getindex(A, Base.to_indices(A, I)...)
33+
end
34+
# quick pass forward
35+
Base.@propagate_inbounds Base.getindex(A::ScaledMap, I...) = A.λ .* getindex(A.lmap, I...)
36+
Base.@propagate_inbounds Base.getindex(A::AdjointMap, i::Integer) =
37+
adjoint(A.lmap[i-1+first(axes(A.lmap)[1])])
38+
Base.@propagate_inbounds Base.getindex(A::AdjointMap, i::Integer, j::Integer) =
39+
adjoint(A.lmap[j, i])
40+
Base.@propagate_inbounds Base.getindex(A::TransposeMap, i::Integer) =
41+
transpose(A.lmap[i-1+first(axes(A.lmap)[1])])
42+
Base.@propagate_inbounds Base.getindex(A::TransposeMap, i::Integer, j::Integer) =
43+
transpose(A.lmap[j, i])
44+
Base.@propagate_inbounds Base.getindex(A::WrappedMap, I...) = A.lmap[I...]
45+
Base.@propagate_inbounds Base.getindex(A::WrappedMap, i::Integer) = A.lmap[i]
46+
Base.@propagate_inbounds Base.getindex(A::WrappedMap, i::Integer, j::Integer) = A.lmap[i,j]
47+
48+
# Base._getindex, IndexLinear
49+
# Base.@propagate_inbounds Base._getindex(::IndexLinear, A::LinearMap, i::Integer) = _unsafe_getindex(A, i)
50+
# Base.@propagate_inbounds function Base._getindex(::IndexLinear, A::LinearMap, i::Integer, j::Integer)
51+
# Base.@_inline_meta
52+
# # @boundscheck checkbounds(A, i, j)
53+
# return _unsafe_getindex(A, Base._sub2ind(axes(A), i, j))
54+
# end
55+
# Base._getindex, IndexCartesian
56+
Base.@propagate_inbounds function _getindex(A::LinearMap, i::Integer)
57+
Base.@_inline_meta
58+
@boundscheck checkbounds(A, i)
59+
i1, i2 = Base._ind2sub(axes(A), i)
60+
@inbounds r = _unsafe_getindex(A, i1, i2)
61+
return r
62+
end
63+
Base.@propagate_inbounds _getindex(A::LinearMap, i::Integer, j::Integer) =
64+
_unsafe_getindex(A, i, j)
65+
66+
########################
67+
# scalar indexing
68+
########################
69+
# fallback via colon-based method
70+
Base.@propagate_inbounds _unsafe_getindex(A::LinearMap, i::Integer, j::Integer) =
71+
(Base.@_inline_meta; _getindex(A, Base.Slice(axes(A)[1]), j)[i])
72+
# specialized methods
73+
_unsafe_getindex(A::FillMap, ::Integer, ::Integer) = A.λ
74+
Base.@propagate_inbounds _unsafe_getindex(A::LinearCombination, i::Integer, j::Integer) =
75+
sum(a -> getindex(A.maps[a], i, j), eachindex(A.maps))
76+
_unsafe_getindex(A::UniformScalingMap, i::Integer, j::Integer) =
77+
ifelse(i == j, A.λ, zero(eltype(A)))
78+
79+
########################
80+
# multidimensional slicing
81+
########################
82+
Base.@propagate_inbounds function _getindex(A::LinearMap, i::Integer, J::AbstractVector{<:Integer})
83+
try
84+
return (basevec(A, i)'A)[J]
85+
catch
86+
x = zeros(eltype(A), size(A, 2))
87+
y = similar(x, eltype(A), size(A, 1))
88+
r = similar(x, eltype(A), length(J))
89+
@inbounds for (ind, j) in enumerate(J)
90+
x[j] = one(eltype(A))
91+
_unsafe_mul!(y, A, x)
92+
r[ind] = y[i]
93+
x[j] = zero(eltype(A))
94+
end
95+
return r
96+
end
97+
end
98+
Base.@propagate_inbounds _getindex(A::LinearMap, I::AbstractVector{<:Integer}, j::Integer) =
99+
(Base.@_inline_meta; _getindex(A, Base.Slice(axes(A)[1]), j)[I])
100+
Base.@propagate_inbounds function _getindex(A::LinearMap, Is::Vararg{AbstractVector{<:Integer},2})
101+
shape = Base.index_shape(Is...)
102+
dest = zeros(eltype(A), shape)
103+
I, J = Is
104+
for (ind, ij) in zip(eachindex(dest), Iterators.product(I, J))
105+
i, j = ij
106+
dest[ind] = _unsafe_getindex(A, i, j)
107+
end
108+
return dest
109+
end
110+
Base.@propagate_inbounds function _getindex(A::LinearMap, I::AbstractVector{<:Integer})
111+
dest = Vector{eltype(A)}(undef, length(I))
112+
for i in eachindex(dest, I)
113+
dest[i] = _getindex(A, I[i])
114+
end
115+
return dest
116+
end
117+
_getindex(A::LinearMap, ::Base.Slice, ::Base.Slice) = Matrix(A)
118+
_getindex(A::LinearMap, ::Base.Slice) = vec(Matrix(A))
119+
function _getindex(A::LinearMap, i::Integer, J::Base.Slice)
120+
try
121+
return vec(basevec(A, i)'A)
122+
catch
123+
return vec(_getindex(A, i:i, J))
124+
end
125+
end
126+
_getindex(A::LinearMap, ::Base.Slice, j::Integer) = A*basevec(A, j)
127+
# Needs to be defined for custom LinearMap subtypes
128+
# Base.@propagate_inbounds function _unsafe_getindex(A::CustomMap, i::Union{Integer,AbstractVector{<:Integer}})
129+
function _getindex(A::LinearMap, I::AbstractVector{<:Integer}, ::Base.Slice)
130+
x = zeros(eltype(A), size(A, 2))
131+
y = similar(x, eltype(A), size(A, 1))
132+
r = similar(x, eltype(A), (length(I), size(A, 2)))
133+
@views @inbounds for j in axes(A)[2]
134+
x[j] = one(eltype(A))
135+
_unsafe_mul!(y, A, x)
136+
r[:,j] .= y[I]
137+
x[j] = zero(eltype(A))
138+
end
139+
return r
140+
end
141+
function _getindex(A::LinearMap, ::Base.Slice, J::AbstractVector{<:Integer})
142+
x = zeros(eltype(A), size(A, 2))
143+
y = similar(x, eltype(A), (size(A, 1), length(J)))
144+
@inbounds for (i, j) in enumerate(J)
145+
x[j] = one(eltype(A))
146+
_unsafe_mul!(selectdim(y, 2, i), A, x)
147+
x[j] = zero(eltype(A))
148+
end
149+
return y
150+
end
151+
152+
# helpers
153+
function basevec(A, i::Integer)
154+
x = zeros(eltype(A), size(A, 2))
155+
@inbounds x[i] = one(eltype(A))
156+
return x
157+
end
158+
159+
nogetindex_error() = error("indexing not allowed for LinearMaps; consider setting `LinearMaps.allowgetindex = true`")
160+
161+
# end # module

test/getindex.jl

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
using BenchmarkTools, LinearAlgebra, LinearMaps, Test
2+
# using LinearMaps.GetIndex
3+
4+
struct TwoMap <: LinearMaps.LinearMap{Float64} end
5+
Base.size(::TwoMap) = (5,5)
6+
Base.IndexStyle(::TwoMap) = IndexLinear()
7+
LinearMaps._unsafe_getindex(::TwoMap, i::Integer) = 2.0
8+
LinearMaps._unsafe_mul!(y::AbstractVector, ::TwoMap, x::AbstractVector) = fill!(y, 2.0*sum(x))
9+
10+
@testset "getindex" begin
11+
A = rand(3,3)
12+
L = LinearMap(A)
13+
@test all((L[i,j] == A[i,j] for i in 1:3, j in 1:3))
14+
@test all((L[i] == A[i] for i in 1:9))
15+
@test L[1,:] == A[1,:]
16+
@btime getindex($A, i) setup=(i = rand(1:9))
17+
@btime getindex($L, i) setup=(i = rand(1:9))
18+
@btime (getindex($A, i, j)) setup=(i = rand(1:3); j = rand(1:3))
19+
@btime (getindex($L, i, j)) setup=(i = rand(1:3); j = rand(1:3))
20+
21+
@testset "minifillmap" begin
22+
T = TwoMap()
23+
@test T[1,1] == 2.0
24+
@test T[:,1] == fill(2.0, 5)
25+
@test T[1,:] == fill(2.0, 5)
26+
@test T[2:3,:] == fill(2.0, 2, 5)
27+
@test T[:,2:3] == fill(2.0, 5, 2)
28+
@test T[2:3,3] == fill(2.0, 2)
29+
@test T[2,2:3] == fill(2.0, 2)
30+
@test_throws BoundsError T[6,1]
31+
@test_throws BoundsError T[1,6]
32+
@test_throws BoundsError T[2,1:6]
33+
@test_throws BoundsError T[1:6,2]
34+
@test_throws BoundsError T[0]
35+
@test_throws BoundsError T[26]
36+
37+
Base.adjoint(A::TwoMap) = A
38+
@test T[1,1] == 2.0
39+
@test T[:,1] == fill(2.0, 5)
40+
@test T[1,:] == fill(2.0, 5)
41+
@test T[2:3,:] == fill(2.0, 2, 5)
42+
@test T[:,2:3] == fill(2.0, 5, 2)
43+
@test T[2:3,3] == fill(2.0, 2)
44+
@test T[2,2:3] == fill(2.0, 2)
45+
@test_throws BoundsError T[6,1]
46+
@test_throws BoundsError T[1,6]
47+
@test_throws BoundsError T[2,1:6]
48+
@test_throws BoundsError T[1:6,2]
49+
@test_throws BoundsError T[0]
50+
@test_throws BoundsError T[26]
51+
end
52+
53+
@testset "function wrap around matrix" begin
54+
MA = rand(ComplexF64, 5, 5)
55+
FA = LinearMap{ComplexF64}((y, x) -> mul!(y, MA, x), (y, x) -> mul!(y, MA', x), 5, 5)
56+
for transform in (identity, transpose, adjoint), (A, F) in ((MA, FA), (3MA, 3FA))
57+
@test transform(F)[1,1] transform(A)[1,1]
58+
@test transform(F)[:] transform(A)[:]
59+
@test transform(F)[1,:] transform(A)[1,:]
60+
@test transform(F)[:,1] transform(A)[:,1]
61+
@test transform(F)[1:4,:] transform(A)[1:4,:]
62+
@test transform(F)[:,1:4] transform(A)[:,1:4]
63+
@test transform(F)[1,1:3] transform(A)[1,1:3]
64+
@test transform(F)[1:3,1] transform(A)[1:3,1]
65+
@test transform(F)[1:2,1:3] transform(A)[1:2,1:3]
66+
@test transform(F)[[2,1],1:3] transform(A)[[2,1],1:3]
67+
@test transform(F)[:,:] transform(A)
68+
@test transform(F)[7] transform(A)[7]
69+
@test_throws BoundsError transform(F)[0]
70+
@test_throws BoundsError transform(F)[26]
71+
end
72+
end
73+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,5 @@ include("fillmap.jl")
4040
if VERSION v"1.1"
4141
include("nontradaxes.jl")
4242
end
43+
44+
include("getindex.jl")

0 commit comments

Comments
 (0)