-
Notifications
You must be signed in to change notification settings - Fork 16
Add basic vector transforms equivalent to existing scalar ones, along with vector transform composition #142
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,179 @@ | ||||||||||||||||||||
export VectorIdentity, VecTVExp, VecTVLogistic, VecTVShift, VecTVScale, VecTVNeg | ||||||||||||||||||||
####### | ||||||||||||||||||||
####### identity | ||||||||||||||||||||
####### | ||||||||||||||||||||
|
||||||||||||||||||||
""" | ||||||||||||||||||||
$(TYPEDEF) | ||||||||||||||||||||
|
||||||||||||||||||||
Identity ``x ↦ x``. | ||||||||||||||||||||
""" | ||||||||||||||||||||
struct VectorIdentity <: VectorTransform | ||||||||||||||||||||
d::Int | ||||||||||||||||||||
function VectorIdentity(d) | ||||||||||||||||||||
new(d) | ||||||||||||||||||||
end | ||||||||||||||||||||
end | ||||||||||||||||||||
|
||||||||||||||||||||
dimension(t::VectorIdentity) = t.d | ||||||||||||||||||||
transform_with(flag::LogJacFlag, t::VectorIdentity, x::AbstractVector{T}, index::Int) where {T} = x, logjac_zero(flag, T), index + dimension(t) | ||||||||||||||||||||
inverse_eltype(t::VectorIdentity, T::Type) = eltype(T) | ||||||||||||||||||||
function inverse_at!(x::AbstractVector, index::Integer, t::VectorIdentity, y::AbstractVector) | ||||||||||||||||||||
newindex = index + dimension(t) | ||||||||||||||||||||
x[index:newindex-1] .= y | ||||||||||||||||||||
return newindex | ||||||||||||||||||||
end | ||||||||||||||||||||
|
||||||||||||||||||||
|
||||||||||||||||||||
####### | ||||||||||||||||||||
####### elementary vector transforms | ||||||||||||||||||||
####### | ||||||||||||||||||||
|
||||||||||||||||||||
""" | ||||||||||||||||||||
$(TYPEDEF) | ||||||||||||||||||||
|
||||||||||||||||||||
Exponential transformation `x ↦ eˣ`. Maps from all reals to the positive reals. | ||||||||||||||||||||
""" | ||||||||||||||||||||
struct VecTVExp <: VectorTransform | ||||||||||||||||||||
d::Int | ||||||||||||||||||||
function VecTVExp(d) | ||||||||||||||||||||
new(d) | ||||||||||||||||||||
end | ||||||||||||||||||||
end | ||||||||||||||||||||
Comment on lines
+37
to
+42
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, a more general version of this transform already exists. |
||||||||||||||||||||
|
||||||||||||||||||||
dimension(t::VecTVExp) = t.d | ||||||||||||||||||||
transform_with(flag::LogJacFlag, t::VecTVExp, x::AbstractVector{T}, index::Int) where {T} = exp.(x), flag isa LogJac ? abs(prod(x)) : logjac_zero(flag, T), index + dimension(t) | ||||||||||||||||||||
inverse_eltype(t::VecTVExp, T::Type) = eltype(T) | ||||||||||||||||||||
function inverse_at!(x::AbstractVector, index::Integer, t::VecTVExp, y::AbstractVector) | ||||||||||||||||||||
newindex = index + dimension(t) | ||||||||||||||||||||
x[index:newindex-1] .= log.(y) | ||||||||||||||||||||
return newindex | ||||||||||||||||||||
end | ||||||||||||||||||||
|
||||||||||||||||||||
""" | ||||||||||||||||||||
$(TYPEDEF) | ||||||||||||||||||||
|
||||||||||||||||||||
Logistic transformation `x ↦ logit(x)`. Maps from all reals to (0, 1). | ||||||||||||||||||||
""" | ||||||||||||||||||||
struct VecTVLogistic <: VectorTransform | ||||||||||||||||||||
d::Int | ||||||||||||||||||||
function VecTVLogistic(d) | ||||||||||||||||||||
new(d) | ||||||||||||||||||||
end | ||||||||||||||||||||
end | ||||||||||||||||||||
Comment on lines
+58
to
+63
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, also this already exists. |
||||||||||||||||||||
dimension(t::VecTVLogistic) = t.d | ||||||||||||||||||||
transform_with(flag::LogJacFlag, t::VecTVLogistic, x::AbstractVector{T}, index::Int) where {T} = logistic.(x), flag isa LogJac ? prod(logistic_logjac.(x)) : logjac_zero(flag, T), index + dimension(t) | ||||||||||||||||||||
inverse_eltype(t::VecTVLogistic, T::Type) = eltype(T) | ||||||||||||||||||||
function inverse_at!(x::AbstractVector, index::Integer, t::VecTVLogistic, y::AbstractVector) | ||||||||||||||||||||
newindex = index + dimension(t) | ||||||||||||||||||||
x[index:newindex-1] .= logit.(y) | ||||||||||||||||||||
return newindex | ||||||||||||||||||||
end | ||||||||||||||||||||
|
||||||||||||||||||||
""" | ||||||||||||||||||||
$(TYPEDEF) | ||||||||||||||||||||
|
||||||||||||||||||||
Shift transformation `x ↦ x + shift`. | ||||||||||||||||||||
""" | ||||||||||||||||||||
struct VecTVShift{T<:Real} <: VectorTransform | ||||||||||||||||||||
Comment on lines
+73
to
+78
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One could exploit |
||||||||||||||||||||
shift::AbstractVector | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The type should be a type parameter |
||||||||||||||||||||
function VecTVShift(shift::AbstractVector{T}) where {T} | ||||||||||||||||||||
return new{T}(shift) | ||||||||||||||||||||
end | ||||||||||||||||||||
end | ||||||||||||||||||||
function VecTVShift(val::Real, dim::Integer) | ||||||||||||||||||||
return VecTVShift(repeat([val;], dim)) | ||||||||||||||||||||
end | ||||||||||||||||||||
Comment on lines
+84
to
+86
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should just be
Suggested change
Ideally, though, I think users might want to use FillArrays here. To avoid additional dependencies and keep it more flexible, I'd remove this definition:
Suggested change
|
||||||||||||||||||||
|
||||||||||||||||||||
dimension(t::VecTVShift) = length(t.shift) | ||||||||||||||||||||
transform_with(flag::LogJacFlag, t::VecTVShift, x::AbstractVector{T}, index::Int) where {T} = x .+ t.shift, logjac_zero(flag, T), index + dimension(t) | ||||||||||||||||||||
inverse_eltype(t::VecTVShift, T::Type) = eltype(T) | ||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This must also take into account the element type of the shift. Moreover, currently the compiler won't specialize on
Suggested change
|
||||||||||||||||||||
function inverse_at!(x::AbstractVector, index::Integer, t::VecTVShift, y::AbstractVector) | ||||||||||||||||||||
newindex = index + dimension(t) | ||||||||||||||||||||
x[index:newindex-1] .= y .+ t.shift | ||||||||||||||||||||
return newindex | ||||||||||||||||||||
end | ||||||||||||||||||||
|
||||||||||||||||||||
""" | ||||||||||||||||||||
$(TYPEDEF) | ||||||||||||||||||||
|
||||||||||||||||||||
Scale transformation `x ↦ scale * x`. | ||||||||||||||||||||
""" | ||||||||||||||||||||
struct VecTVScale{T<:Real} <: VectorTransform | ||||||||||||||||||||
scale::AbstractVector | ||||||||||||||||||||
function VecTVScale(scale::AbstractVector{T}) where {T} | ||||||||||||||||||||
return new{T}(scale) | ||||||||||||||||||||
end | ||||||||||||||||||||
end | ||||||||||||||||||||
function VecTVScale(val::Real, dim::Integer) | ||||||||||||||||||||
return VecTVScale(repeat([val;], dim)) | ||||||||||||||||||||
end | ||||||||||||||||||||
|
||||||||||||||||||||
dimension(t::VecTVScale) = length(t.shift) | ||||||||||||||||||||
transform_with(flag::LogJacFlag, t::VecTVScale, x::AbstractVector{T}, index::Int) where {T} = x .* t.scale, flag isa LogJac ? log(abs(prod(x))) : logjac_zero(flag, T), index + dimension(t) | ||||||||||||||||||||
inverse_eltype(t::VecTVScale, T::Type) = eltype(T) | ||||||||||||||||||||
function inverse_at!(x::AbstractVector, index::Integer, t::VecTVScale, y::AbstractVector) | ||||||||||||||||||||
newindex = index + dimension(t) | ||||||||||||||||||||
x[index:newindex-1] .= y .* t.scale | ||||||||||||||||||||
return newindex | ||||||||||||||||||||
end | ||||||||||||||||||||
|
||||||||||||||||||||
""" | ||||||||||||||||||||
$(TYPEDEF) | ||||||||||||||||||||
|
||||||||||||||||||||
Negative transformation `x ↦ -x`. | ||||||||||||||||||||
""" | ||||||||||||||||||||
struct VecTVNeg <: VectorTransform | ||||||||||||||||||||
d::Int | ||||||||||||||||||||
function VecTVNeg(d) | ||||||||||||||||||||
new(d) | ||||||||||||||||||||
end | ||||||||||||||||||||
end | ||||||||||||||||||||
dimension(t::VecTVNeg) = t.d | ||||||||||||||||||||
transform_with(flag::LogJacFlag, t::VecTVNeg, x::AbstractVector{T}, index::Int) where {T} = -x, logjac_zero(flag, T), index + dimension(t) | ||||||||||||||||||||
inverse_eltype(t::VecTVNeg, T::Type) = eltype(T) | ||||||||||||||||||||
function inverse_at!(x::AbstractVector, index::Integer, t::VecTVNeg, y::AbstractVector) | ||||||||||||||||||||
newindex = index + dimension(t) | ||||||||||||||||||||
x[index:newindex-1] .= -y | ||||||||||||||||||||
return newindex | ||||||||||||||||||||
end | ||||||||||||||||||||
|
||||||||||||||||||||
|
||||||||||||||||||||
### TODO composition of vector transforms | ||||||||||||||||||||
|
||||||||||||||||||||
####### | ||||||||||||||||||||
####### composite scalar transforms | ||||||||||||||||||||
####### | ||||||||||||||||||||
###""" | ||||||||||||||||||||
###$(TYPEDEF) | ||||||||||||||||||||
### | ||||||||||||||||||||
###A composite scalar transformation, i.e. a sequence of scalar transformations. | ||||||||||||||||||||
###""" | ||||||||||||||||||||
###struct CompositeScalarTransform{Ts <: Tuple} <: ScalarTransform | ||||||||||||||||||||
### transforms::Ts | ||||||||||||||||||||
### function CompositeScalarTransform(transforms::Ts) where {Ts <: Tuple{ScalarTransform,Vararg{ScalarTransform}}} | ||||||||||||||||||||
### new{Ts}(transforms) | ||||||||||||||||||||
### end | ||||||||||||||||||||
###end | ||||||||||||||||||||
### | ||||||||||||||||||||
###transform(t::CompositeScalarTransform, x) = foldr(transform, t.transforms, init=x) | ||||||||||||||||||||
###function transform_and_logjac(ts::CompositeScalarTransform, x) | ||||||||||||||||||||
### foldr(ts.transforms, init=(x, logjac_zero(LogJac(), typeof(x)))) do t, (x, logjac) | ||||||||||||||||||||
### nx, nlogjac = transform_and_logjac(t, x) | ||||||||||||||||||||
### (nx, logjac + nlogjac) | ||||||||||||||||||||
### end | ||||||||||||||||||||
###end | ||||||||||||||||||||
### | ||||||||||||||||||||
###inverse(ts::CompositeScalarTransform, x) = foldl((y, t) -> inverse(t, y), ts.transforms, init=x) | ||||||||||||||||||||
###function inverse_and_logjac(ts::CompositeScalarTransform, x) | ||||||||||||||||||||
### foldl(ts.transforms, init=(x, logjac_zero(LogJac(), typeof(x)))) do (x, logjac), t | ||||||||||||||||||||
### nx, nlogjac = inverse_and_logjac(t, x) | ||||||||||||||||||||
### (nx, logjac + nlogjac) | ||||||||||||||||||||
### end | ||||||||||||||||||||
###end | ||||||||||||||||||||
### | ||||||||||||||||||||
###Base.:∘(t::ScalarTransform, s::ScalarTransform) = CompositeScalarTransform((t, s)) | ||||||||||||||||||||
###Base.:∘(t::ScalarTransform, ct::CompositeScalarTransform) = CompositeScalarTransform((t, ct.transforms...)) | ||||||||||||||||||||
###Base.:∘(ct::CompositeScalarTransform, t::ScalarTransform) = CompositeScalarTransform((ct.transforms..., t)) | ||||||||||||||||||||
###Base.:∘(ct1::CompositeScalarTransform, ct2::CompositeScalarTransform) = CompositeScalarTransform((ct1.transforms..., ct2.transforms...)) | ||||||||||||||||||||
###Base.:∘(t::ScalarTransform, tt::Vararg{ScalarTransform}) = foldl(∘, tt; init=t) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This transform already exists in a more general form:
as(Real, dims...)
If there are performance improvements possible I think it would be better to implement them this existing array transformation.