Skip to content

Commit 609476f

Browse files
separate 5-args constructor
1 parent 55e0273 commit 609476f

11 files changed

+116
-62
lines changed

src/TimedOperators.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ mutable struct TimedLinearOperator{T, OP <: AbstractLinearOperator{T}, F, Ft, Fc
1111
ctprod!::Fct
1212
end
1313

14-
TimedLinearOperator{T}(
14+
TimedLinearOperator(
1515
timer::TimerOutput,
1616
op::AbstractLinearOperator{T},
1717
prod!::F,
@@ -29,7 +29,7 @@ function TimedLinearOperator(op::AbstractLinearOperator{T}) where {T}
2929
prod!(res, x, α, β) = @timeit timer "prod" op.prod!(res, x, α, β)
3030
tprod!(res, x, α, β) = @timeit timer "tprod" op.tprod!(res, x, α, β)
3131
ctprod!(res, x, α, β) = @timeit timer "ctprod" op.ctprod!(res, x, α, β)
32-
TimedLinearOperator{T}(timer, op, prod!, tprod!, ctprod!)
32+
TimedLinearOperator(timer, op, prod!, tprod!, ctprod!)
3333
end
3434

3535
TimedLinearOperator(op::AdjointLinearOperator) = adjoint(TimedLinearOperator(op.parent))

src/abstract.jl

+39-14
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ export AbstractLinearOperator,
22
AbstractQuasiNewtonOperator,
33
AbstractDiagonalQuasiNewtonOperator,
44
LinearOperator,
5+
LinearOperator5,
56
LinearOperatorException,
67
hermitian,
78
ishermitian,
@@ -62,7 +63,8 @@ mutable struct LinearOperator{T, I <: Integer, F, Ft, Fct, S} <: AbstractLinearO
6263
allocated5::Bool # true for 5-args mul!, false for 3-args mul! until the vectors are allocated
6364
end
6465

65-
function LinearOperator{T}(
66+
function LinearOperator(
67+
::Type{T},
6668
nrow::I,
6769
ncol::I,
6870
symmetric::Bool,
@@ -76,11 +78,9 @@ function LinearOperator{T}(
7678
S::DataType = Vector{T},
7779
) where {T, I <: Integer, F, Ft, Fct}
7880
Mv5, Mtu5 = S(undef, 0), S(undef, 0)
79-
nargs = get_nargs(prod!)
80-
args5 = (nargs == 4)
81-
(args5 == false) || (nargs != 2) || throw(LinearOperatorException("Invalid number of arguments"))
82-
allocated5 = args5 ? true : false
83-
use_prod5! = args5 ? true : false
81+
args5 = false
82+
allocated5 = false
83+
use_prod5! = false
8484
return LinearOperator{T, I, F, Ft, Fct, S}(
8585
nrow,
8686
ncol,
@@ -100,21 +100,46 @@ function LinearOperator{T}(
100100
)
101101
end
102102

103-
LinearOperator{T}(
103+
function LinearOperator5(
104+
::Type{T},
104105
nrow::I,
105106
ncol::I,
106107
symmetric::Bool,
107108
hermitian::Bool,
108-
prod!,
109-
tprod!,
110-
ctprod!;
109+
prod!::F,
110+
tprod!::Ft,
111+
ctprod!::Fct,
112+
nprod::I,
113+
ntprod::I,
114+
nctprod::I;
111115
S::DataType = Vector{T},
112-
) where {T, I <: Integer} =
113-
LinearOperator{T}(nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, 0, 0, 0, S = S)
116+
) where {T, I <: Integer, F, Ft, Fct}
117+
Mv5, Mtu5 = S(undef, 0), S(undef, 0)
118+
args5 = true
119+
allocated5 = true
120+
use_prod5! = true
121+
return LinearOperator{T, I, F, Ft, Fct, S}(
122+
nrow,
123+
ncol,
124+
symmetric,
125+
hermitian,
126+
prod!,
127+
tprod!,
128+
ctprod!,
129+
nprod,
130+
ntprod,
131+
nctprod,
132+
args5,
133+
use_prod5!,
134+
Mv5,
135+
Mtu5,
136+
allocated5,
137+
)
138+
end
114139

115140
# create operator from other operators with +, *, vcat,...
116141
function CompositeLinearOperator(
117-
T::DataType,
142+
::Type{T},
118143
nrow::I,
119144
ncol::I,
120145
symmetric::Bool,
@@ -124,7 +149,7 @@ function CompositeLinearOperator(
124149
ctprod!::Fct,
125150
args5::Bool;
126151
S::DataType = Vector{T},
127-
) where {I <: Integer, F, Ft, Fct}
152+
) where {T, I <: Integer, F, Ft, Fct}
128153
Mv5, Mtu5 = S(undef, 0), S(undef, 0)
129154
allocated5 = true
130155
use_prod5! = true

src/constructors.jl

+47-18
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function LinearOperator(
1616
prod! = @closure (res, v, α, β) -> mul!(res, M, v, α, β)
1717
tprod! = @closure (res, u, α, β) -> mul!(res, transpose(M), u, α, β)
1818
ctprod! = @closure (res, w, α, β) -> mul!(res, adjoint(M), w, α, β)
19-
LinearOperator{T}(nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, S = S)
19+
LinearOperator5(T, nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, S = S)
2020
end
2121

2222
"""
@@ -58,6 +58,47 @@ end
5858
[tprod!=nothing, ctprod!=nothing],
5959
S = Vector{T}) where {T}
6060
61+
Construct a linear operator from functions where the type is specified as the first argument.
62+
Change `S` to use LinearOperators on GPU.
63+
```
64+
A = rand(2, 2)
65+
op = LinearOperator(Float64, 2, 2, false, false,
66+
(res, v) -> mul!(res, A, v),
67+
(res, w) -> mul!(res, A', w))
68+
```
69+
70+
Notice that the linear operator does not enforce the type, so using a wrong type can
71+
result in errors. For instance,
72+
```
73+
A = [im 1.0; 0.0 1.0] # Complex matrix
74+
op = LinearOperator5(Float64, 2, 2, false, false,
75+
(res, v) -> mul!(res, A, v),
76+
(res, u) -> mul!(res, transpose(A), u),
77+
(res, w) -> mul!(res, A', w))
78+
Matrix(op) # InexactError
79+
```
80+
The error is caused because `Matrix(op)` tries to create a Float64 matrix with the
81+
contents of the complex matrix `A`.
82+
"""
83+
function LinearOperator(
84+
::Type{T},
85+
nrow::I,
86+
ncol::I,
87+
symmetric::Bool,
88+
hermitian::Bool,
89+
prod!,
90+
tprod! = nothing,
91+
ctprod! = nothing;
92+
S = Vector{T},
93+
) where {T, I <: Integer}
94+
return LinearOperator(T, nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, 0, 0, 0, S = S)
95+
end
96+
97+
"""
98+
LinearOperator5(::Type{T}, nrow, ncol, symmetric, hermitian, prod!,
99+
[tprod!=nothing, ctprod!=nothing],
100+
S = Vector{T}) where {T}
101+
61102
Construct a linear operator from functions where the type is specified as the first argument.
62103
Change `S` to use LinearOperators on GPU.
63104
Notice that the linear operator does not enforce the type, so using a wrong type can
@@ -67,7 +108,7 @@ A = [im 1.0; 0.0 1.0] # Complex matrix
67108
function mulOp!(res, M, v, α, β)
68109
mul!(res, M, v, α, β)
69110
end
70-
op = LinearOperator(Float64, 2, 2, false, false,
111+
op = LinearOperator5(Float64, 2, 2, false, false,
71112
(res, v, α, β) -> mulOp!(res, A, v, α, β),
72113
(res, u, α, β) -> mulOp!(res, transpose(A), u, α, β),
73114
(res, w, α, β) -> mulOp!(res, A', w, α, β))
@@ -77,8 +118,6 @@ The error is caused because `Matrix(op)` tries to create a Float64 matrix with t
77118
contents of the complex matrix `A`.
78119
79120
Using `*` may generate a vector that contains `NaN` values.
80-
This can also happen if you use the 3-args `mul!` function with a preallocated vector such as
81-
`Vector{Float64}(undef, n)`.
82121
To fix this issue you will have to deal with the cases `β == 0` and `β != 0` separately:
83122
```
84123
d1 = [2.0; 3.0]
@@ -89,21 +128,11 @@ function mulSquareOpDiagonal!(res, d, v, α, β::T) where T
89128
res .= α .* d .* v .+ β .* res
90129
end
91130
end
92-
op = LinearOperator(Float64, 2, 2, true, true,
131+
op = LinearOperator5(Float64, 2, 2, true, true,
93132
(res, v, α, β) -> mulSquareOpDiagonal!(res, d, v, α, β))
94133
```
95-
96-
It is possible to create an operator with the 3-args `mul!`.
97-
In this case, using the 5-args `mul!` will generate storage vectors.
98-
99-
```
100-
A = rand(2, 2)
101-
op = LinearOperator(Float64, 2, 2, false, false,
102-
(res, v) -> mul!(res, A, v),
103-
(res, w) -> mul!(res, A', w))
104-
```
105134
"""
106-
function LinearOperator(
135+
function LinearOperator5(
107136
::Type{T},
108137
nrow::I,
109138
ncol::I,
@@ -114,5 +143,5 @@ function LinearOperator(
114143
ctprod! = nothing;
115144
S = Vector{T},
116145
) where {T, I <: Integer}
117-
return LinearOperator{T}(nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, S = S)
118-
end
146+
return LinearOperator5(T, nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, 0, 0, 0, S = S)
147+
end

src/kron.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ function kron(A::AbstractLinearOperator, B::AbstractLinearOperator)
4141
symm = issymmetric(A) && issymmetric(B)
4242
herm = ishermitian(A) && ishermitian(B)
4343
nrow, ncol = m * p, n * q
44-
return LinearOperator{T}(nrow, ncol, symm, herm, prod!, tprod!, ctprod!)
44+
return LinearOperator5(T, nrow, ncol, symm, herm, prod!, tprod!, ctprod!)
4545
end
4646

4747
kron(A::AbstractMatrix, B::AbstractLinearOperator) = kron(LinearOperator(A), B)

src/lbfgs.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ mutable struct LBFGSOperator{T, I <: Integer, F, Ft, Fct} <: AbstractQuasiNewton
6464
nctprod::I
6565
end
6666

67-
LBFGSOperator{T}(
67+
LBFGSOperator(
6868
nrow::I,
6969
ncol::I,
7070
symmetric::Bool,
@@ -149,7 +149,7 @@ function InverseLBFGSOperator(T::DataType, n::I; kwargs...) where {I <: Integer}
149149
end
150150

151151
prod! = @closure (res, x, α, β) -> lbfgs_multiply(res, lbfgs_data, x, α, β)
152-
return LBFGSOperator{T}(n, n, true, true, prod!, prod!, prod!, true, lbfgs_data)
152+
return LBFGSOperator(n, n, true, true, prod!, prod!, prod!, true, lbfgs_data)
153153
end
154154

155155
InverseLBFGSOperator(n::Int; kwargs...) = InverseLBFGSOperator(Float64, n; kwargs...)
@@ -199,7 +199,7 @@ function LBFGSOperator(T::DataType, n::I; kwargs...) where {I <: Integer}
199199
end
200200

201201
prod! = @closure (res, x, α, β) -> lbfgs_multiply(res, lbfgs_data, x, α, β)
202-
return LBFGSOperator{T}(n, n, true, true, prod!, prod!, prod!, false, lbfgs_data)
202+
return LBFGSOperator(n, n, true, true, prod!, prod!, prod!, false, lbfgs_data)
203203
end
204204

205205
LBFGSOperator(n::I; kwargs...) where {I <: Integer} = LBFGSOperator(Float64, n; kwargs...)

src/linalg.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function opInverse(M::AbstractMatrix{T}; symm = false, herm = false) where {T}
2828
prod! = @closure (res, v, α, β) -> mulFact!(res, M, v, α, β)
2929
tprod! = @closure (res, u, α, β) -> mulFact!(res, transpose(M), u, α, β)
3030
ctprod! = @closure (res, w, α, β) -> mulFact!(res, adjoint(M), w, α, β)
31-
LinearOperator{T}(size(M, 2), size(M, 1), symm, herm, prod!, tprod!, ctprod!)
31+
LinearOperator5(T, size(M, 2), size(M, 1), symm, herm, prod!, tprod!, ctprod!)
3232
end
3333

3434
"""
@@ -53,7 +53,7 @@ function opCholesky(M::AbstractMatrix; check::Bool = false)
5353
tprod! = @closure (res, u, α, β) -> tmulFact!(res, LL, u, α, β) # M.' = conj(M)
5454
ctprod! = @closure (res, w, α, β) -> mulFact!(res, LL, w, α, β)
5555
S = eltype(LL)
56-
LinearOperator{S}(m, m, isreal(M), true, prod!, tprod!, ctprod!)
56+
LinearOperator5(S, m, m, isreal(M), true, prod!, tprod!, ctprod!)
5757
#TODO: use iterative refinement.
5858
end
5959

@@ -82,7 +82,7 @@ function opLDL(M::AbstractMatrix; check::Bool = false)
8282
tprod! = @closure (res, u, α, β) -> tmulFact!(res, LDL, u, α, β) # M.' = conj(M)
8383
ctprod! = @closure (res, w, α, β) -> mulFact!(res, LDL, w, α, β)
8484
S = eltype(LDL)
85-
return LinearOperator{S}(m, m, isreal(M), true, prod!, tprod!, ctprod!)
85+
return LinearOperator5(S, m, m, isreal(M), true, prod!, tprod!, ctprod!)
8686
#TODO: use iterative refinement.
8787
end
8888

@@ -97,7 +97,7 @@ function opLDL(M::Symmetric{T, SparseMatrixCSC{T, Int}}; check::Bool = false) wh
9797
tprod! = @closure (res, u) -> ldiv!(res, LDL, u) # M.' = conj(M)
9898
ctprod! = @closure (res, w) -> ldiv!(res, LDL, w)
9999
S = eltype(LDL)
100-
return LinearOperator{S}(m, m, isreal(M), true, prod!, tprod!, ctprod!)
100+
return LinearOperator(S, m, m, isreal(M), true, prod!, tprod!, ctprod!)
101101
end
102102

103103
function mulHouseholder!(res, h, v, α, β::T) where {T}
@@ -117,7 +117,7 @@ The result is `x -> (I - 2 h hᵀ) x`.
117117
function opHouseholder(h::AbstractVector{T}) where {T}
118118
n = length(h)
119119
prod! = @closure (res, v, α, β) -> mulHouseholder!(res, h, v, α, β) # tprod will be inferred
120-
LinearOperator{T}(n, n, isreal(h), true, prod!, nothing, prod!)
120+
LinearOperator5(T, n, n, isreal(h), true, prod!, nothing, prod!)
121121
end
122122

123123
function mulHermitian!(res, d, L, v, α, β::T) where {T}
@@ -139,7 +139,7 @@ function opHermitian(d::AbstractVector{S}, A::AbstractMatrix{T}) where {S, T}
139139
L = tril(A, -1)
140140
U = promote_type(S, T)
141141
prod! = @closure (res, v, α, β) -> mulHermitian!(res, d, L, v, α, β)
142-
LinearOperator{U}(m, m, isreal(A), true, prod!, nothing, nothing)
142+
LinearOperator5(U, m, m, isreal(A), true, prod!, nothing, nothing)
143143
end
144144

145145
"""

src/lsr1.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ mutable struct LSR1Operator{T, I <: Integer, F, Ft, Fct} <: AbstractQuasiNewtonO
5454
nctprod::I
5555
end
5656

57-
LSR1Operator{T}(
57+
LSR1Operator(
5858
nrow::I,
5959
ncol::I,
6060
symmetric::Bool,
@@ -114,7 +114,7 @@ function LSR1Operator(T::DataType, n::I; kwargs...) where {I <: Integer}
114114
end
115115

116116
prod! = @closure (res, x, α, β) -> lsr1_multiply(res, lsr1_data, x, α, β)
117-
return LSR1Operator{T}(n, n, true, true, prod!, nothing, nothing, false, lsr1_data)
117+
return LSR1Operator(n, n, true, true, prod!, nothing, nothing, false, lsr1_data)
118118
end
119119

120120
LSR1Operator(n::I; kwargs...) where {I <: Integer} = LSR1Operator(Float64, n; kwargs...)

src/special-operators.jl

+7-7
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ Change `S` to use LinearOperators on GPU.
4848
"""
4949
function opEye(T::DataType, n::Int; S = Vector{T})
5050
prod! = @closure (res, v, α, β) -> mulOpEye!(res, v, α, β, n)
51-
LinearOperator{T}(n, n, true, true, prod!, prod!, prod!, S = S)
51+
LinearOperator5(T, n, n, true, true, prod!, prod!, prod!, S = S)
5252
end
5353

5454
opEye(n::Int) = opEye(Float64, n)
@@ -67,7 +67,7 @@ function opEye(T::DataType, nrow::I, ncol::I; S = Vector{T}) where {I <: Integer
6767
return opEye(T, nrow)
6868
end
6969
prod! = @closure (res, v, α, β) -> mulOpEye!(res, v, α, β, min(nrow, ncol))
70-
return LinearOperator{T}(nrow, ncol, false, false, prod!, prod!, prod!, S = S)
70+
return LinearOperator5(T, nrow, ncol, false, false, prod!, prod!, prod!, S = S)
7171
end
7272

7373
opEye(nrow::I, ncol::I) where {I <: Integer} = opEye(Float64, nrow, ncol)
@@ -90,7 +90,7 @@ Change `S` to use LinearOperators on GPU.
9090
"""
9191
function opOnes(T::DataType, nrow::I, ncol::I; S = Vector{T}) where {I <: Integer}
9292
prod! = @closure (res, v, α, β) -> mulOpOnes!(res, v, α, β)
93-
LinearOperator{T}(nrow, ncol, nrow == ncol, nrow == ncol, prod!, prod!, prod!, S = S)
93+
LinearOperator5(T, nrow, ncol, nrow == ncol, nrow == ncol, prod!, prod!, prod!, S = S)
9494
end
9595

9696
opOnes(nrow::I, ncol::I) where {I <: Integer} = opOnes(Float64, nrow, ncol)
@@ -113,7 +113,7 @@ Change `S` to use LinearOperators on GPU.
113113
"""
114114
function opZeros(T::DataType, nrow::I, ncol::I; S = Vector{T}) where {I <: Integer}
115115
prod! = @closure (res, v, α, β) -> mulOpZeros!(res, v, α, β)
116-
LinearOperator{T}(nrow, ncol, nrow == ncol, nrow == ncol, prod!, prod!, prod!, S = S)
116+
LinearOperator5(T, nrow, ncol, nrow == ncol, nrow == ncol, prod!, prod!, prod!, S = S)
117117
end
118118

119119
opZeros(nrow::I, ncol::I) where {I <: Integer} = opZeros(Float64, nrow, ncol)
@@ -134,7 +134,7 @@ Diagonal operator with the vector `d` on its main diagonal.
134134
function opDiagonal(d::AbstractVector{T}) where {T}
135135
prod! = @closure (res, v, α, β) -> mulSquareOpDiagonal!(res, d, v, α, β)
136136
ctprod! = @closure (res, w, α, β) -> mulSquareOpDiagonal!(res, conj.(d), w, α, β)
137-
LinearOperator{T}(length(d), length(d), true, isreal(d), prod!, prod!, ctprod!, S = typeof(d))
137+
LinearOperator5(T, length(d), length(d), true, isreal(d), prod!, prod!, ctprod!, S = typeof(d))
138138
end
139139

140140
function mulOpDiagonal!(res, d, v, α, β::T, n_min) where {T}
@@ -157,7 +157,7 @@ function opDiagonal(nrow::I, ncol::I, d::AbstractVector{T}) where {T, I <: Integ
157157
prod! = @closure (res, v, α, β) -> mulOpDiagonal!(res, d, v, α, β, n_min)
158158
tprod! = @closure (res, u, α, β) -> mulOpDiagonal!(res, d, u, α, β, n_min)
159159
ctprod! = @closure (res, w, α, β) -> mulOpDiagonal!(res, conj.(d), w, α, β, n_min)
160-
LinearOperator{T}(nrow, ncol, false, false, prod!, tprod!, ctprod!, S = typeof(d))
160+
LinearOperator5(T, nrow, ncol, false, false, prod!, tprod!, ctprod!, S = typeof(d))
161161
end
162162

163163
function mulRestrict!(res, I, v, α, β)
@@ -185,7 +185,7 @@ function opRestriction(Idx::LinearOperatorIndexType{I}, ncol::I) where {I <: Int
185185
nrow = length(Idx)
186186
prod! = @closure (res, v, α, β) -> mulRestrict!(res, Idx, v, α, β)
187187
tprod! = @closure (res, u, α, β) -> multRestrict!(res, Idx, u, α, β)
188-
return LinearOperator{I}(nrow, ncol, false, false, prod!, tprod!, tprod!)
188+
return LinearOperator5(I, nrow, ncol, false, false, prod!, tprod!, tprod!)
189189
end
190190

191191
opRestriction(::Colon, ncol::I) where {I <: Integer} = opEye(I, ncol)

0 commit comments

Comments
 (0)