Skip to content

Add basic vector transforms equivalent to existing scalar ones, along with vector transform composition #142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/TransformVariables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ include("special_arrays.jl")
include("constant.jl")
include("aggregation.jl")
include("custom.jl")
include("vector.jl")

end # module
179 changes: 179 additions & 0 deletions src/vector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
export VectorIdentity, VecTVExp, VecTVLogistic, VecTVShift, VecTVScale, VecTVNeg
#######
####### identity
#######

"""
$(TYPEDEF)

Identity ``x ↦ x``.
"""
struct VectorIdentity <: VectorTransform
d::Int
function VectorIdentity(d)
new(d)
end
end
Comment on lines +11 to +16
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This transform already exists in a more general form: as(Real, dims...)

If there are performance improvements possible I think it would be better to implement them this existing array transformation.


dimension(t::VectorIdentity) = t.d
transform_with(flag::LogJacFlag, t::VectorIdentity, x::AbstractVector{T}, index::Int) where {T} = x, logjac_zero(flag, T), index + dimension(t)
inverse_eltype(t::VectorIdentity, T::Type) = eltype(T)
function inverse_at!(x::AbstractVector, index::Integer, t::VectorIdentity, y::AbstractVector)
newindex = index + dimension(t)
x[index:newindex-1] .= y
return newindex
end


#######
####### elementary vector transforms
#######

"""
$(TYPEDEF)

Exponential transformation `x ↦ eˣ`. Maps from all reals to the positive reals.
"""
struct VecTVExp <: VectorTransform
d::Int
function VecTVExp(d)
new(d)
end
end
Comment on lines +37 to +42
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, a more general version of this transform already exists.


dimension(t::VecTVExp) = t.d
transform_with(flag::LogJacFlag, t::VecTVExp, x::AbstractVector{T}, index::Int) where {T} = exp.(x), flag isa LogJac ? abs(prod(x)) : logjac_zero(flag, T), index + dimension(t)
inverse_eltype(t::VecTVExp, T::Type) = eltype(T)
function inverse_at!(x::AbstractVector, index::Integer, t::VecTVExp, y::AbstractVector)
newindex = index + dimension(t)
x[index:newindex-1] .= log.(y)
return newindex
end

"""
$(TYPEDEF)

Logistic transformation `x ↦ logit(x)`. Maps from all reals to (0, 1).
"""
struct VecTVLogistic <: VectorTransform
d::Int
function VecTVLogistic(d)
new(d)
end
end
Comment on lines +58 to +63
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, also this already exists.

dimension(t::VecTVLogistic) = t.d
transform_with(flag::LogJacFlag, t::VecTVLogistic, x::AbstractVector{T}, index::Int) where {T} = logistic.(x), flag isa LogJac ? prod(logistic_logjac.(x)) : logjac_zero(flag, T), index + dimension(t)
inverse_eltype(t::VecTVLogistic, T::Type) = eltype(T)
function inverse_at!(x::AbstractVector, index::Integer, t::VecTVLogistic, y::AbstractVector)
newindex = index + dimension(t)
x[index:newindex-1] .= logit.(y)
return newindex
end

"""
$(TYPEDEF)

Shift transformation `x ↦ x + shift`.
"""
struct VecTVShift{T<:Real} <: VectorTransform
Comment on lines +73 to +78
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One could exploit muladd when not decomposing shifting and scaling (which in my experience are usually both needed) in separate steps.

shift::AbstractVector
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type should be a type parameter T<:AbstractVector{<:Real}.

function VecTVShift(shift::AbstractVector{T}) where {T}
return new{T}(shift)
end
end
function VecTVShift(val::Real, dim::Integer)
return VecTVShift(repeat([val;], dim))
end
Comment on lines +84 to +86
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should just be

Suggested change
function VecTVShift(val::Real, dim::Integer)
return VecTVShift(repeat([val;], dim))
end
function VecTVShift(val::Real, dim::Integer)
return VecTVShift(fill(val, dim))
end

Ideally, though, I think users might want to use FillArrays here. To avoid additional dependencies and keep it more flexible, I'd remove this definition:

Suggested change
function VecTVShift(val::Real, dim::Integer)
return VecTVShift(repeat([val;], dim))
end


dimension(t::VecTVShift) = length(t.shift)
transform_with(flag::LogJacFlag, t::VecTVShift, x::AbstractVector{T}, index::Int) where {T} = x .+ t.shift, logjac_zero(flag, T), index + dimension(t)
inverse_eltype(t::VecTVShift, T::Type) = eltype(T)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This must also take into account the element type of the shift. Moreover, currently the compiler won't specialize on T.

Suggested change
inverse_eltype(t::VecTVShift, T::Type) = eltype(T)
inverse_eltype(t::VecTVShift, ::Type{<:AbstractVector{T}}) where {T<:Real} = promote_type(eltype(t.shift), T)

function inverse_at!(x::AbstractVector, index::Integer, t::VecTVShift, y::AbstractVector)
newindex = index + dimension(t)
x[index:newindex-1] .= y .+ t.shift
return newindex
end

"""
$(TYPEDEF)

Scale transformation `x ↦ scale * x`.
"""
struct VecTVScale{T<:Real} <: VectorTransform
scale::AbstractVector
function VecTVScale(scale::AbstractVector{T}) where {T}
return new{T}(scale)
end
end
function VecTVScale(val::Real, dim::Integer)
return VecTVScale(repeat([val;], dim))
end

dimension(t::VecTVScale) = length(t.shift)
transform_with(flag::LogJacFlag, t::VecTVScale, x::AbstractVector{T}, index::Int) where {T} = x .* t.scale, flag isa LogJac ? log(abs(prod(x))) : logjac_zero(flag, T), index + dimension(t)
inverse_eltype(t::VecTVScale, T::Type) = eltype(T)
function inverse_at!(x::AbstractVector, index::Integer, t::VecTVScale, y::AbstractVector)
newindex = index + dimension(t)
x[index:newindex-1] .= y .* t.scale
return newindex
end

"""
$(TYPEDEF)

Negative transformation `x ↦ -x`.
"""
struct VecTVNeg <: VectorTransform
d::Int
function VecTVNeg(d)
new(d)
end
end
dimension(t::VecTVNeg) = t.d
transform_with(flag::LogJacFlag, t::VecTVNeg, x::AbstractVector{T}, index::Int) where {T} = -x, logjac_zero(flag, T), index + dimension(t)
inverse_eltype(t::VecTVNeg, T::Type) = eltype(T)
function inverse_at!(x::AbstractVector, index::Integer, t::VecTVNeg, y::AbstractVector)
newindex = index + dimension(t)
x[index:newindex-1] .= -y
return newindex
end


### TODO composition of vector transforms

#######
####### composite scalar transforms
#######
###"""
###$(TYPEDEF)
###
###A composite scalar transformation, i.e. a sequence of scalar transformations.
###"""
###struct CompositeScalarTransform{Ts <: Tuple} <: ScalarTransform
### transforms::Ts
### function CompositeScalarTransform(transforms::Ts) where {Ts <: Tuple{ScalarTransform,Vararg{ScalarTransform}}}
### new{Ts}(transforms)
### end
###end
###
###transform(t::CompositeScalarTransform, x) = foldr(transform, t.transforms, init=x)
###function transform_and_logjac(ts::CompositeScalarTransform, x)
### foldr(ts.transforms, init=(x, logjac_zero(LogJac(), typeof(x)))) do t, (x, logjac)
### nx, nlogjac = transform_and_logjac(t, x)
### (nx, logjac + nlogjac)
### end
###end
###
###inverse(ts::CompositeScalarTransform, x) = foldl((y, t) -> inverse(t, y), ts.transforms, init=x)
###function inverse_and_logjac(ts::CompositeScalarTransform, x)
### foldl(ts.transforms, init=(x, logjac_zero(LogJac(), typeof(x)))) do (x, logjac), t
### nx, nlogjac = inverse_and_logjac(t, x)
### (nx, logjac + nlogjac)
### end
###end
###
###Base.:∘(t::ScalarTransform, s::ScalarTransform) = CompositeScalarTransform((t, s))
###Base.:∘(t::ScalarTransform, ct::CompositeScalarTransform) = CompositeScalarTransform((t, ct.transforms...))
###Base.:∘(ct::CompositeScalarTransform, t::ScalarTransform) = CompositeScalarTransform((ct.transforms..., t))
###Base.:∘(ct1::CompositeScalarTransform, ct2::CompositeScalarTransform) = CompositeScalarTransform((ct1.transforms..., ct2.transforms...))
###Base.:∘(t::ScalarTransform, tt::Vararg{ScalarTransform}) = foldl(∘, tt; init=t)