-
Notifications
You must be signed in to change notification settings - Fork 44
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ComponentArrays extension #407
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not familiar with ComponentArrays
. Can @mhauru, @penelopeysm, or @sunxd3 take a look, please?
@@ -61,7 +61,7 @@ struct PhasePoint{T<:AbstractVecOrMat{<:AbstractFloat},V<:DualValue} | |||
ℓπ::V # Cached neg potential energy for the current θ. | |||
ℓκ::V # Cached neg kinect energy for the current r. | |||
function PhasePoint(θ::T, r::T, ℓπ::V, ℓκ::V) where {T,V} | |||
@argcheck length(θ) == length(r) == length(ℓπ.gradient) == length(ℓπ.gradient) | |||
@argcheck length(θ) == length(r) == length(ℓπ.gradient) == length(ℓκ.gradient) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a separate change (which I had fixed in a recent PR but maybe it was reverted accidentally when fixing merge conflicts). IMO it deserves a separate PR + test.
function AdvancedHMC.∂H∂r(h::Hamiltonian{<:UnitEuclideanMetric,<:GaussianKinetic}, r::ComponentArray) | ||
copy(r) | ||
end | ||
function AdvancedHMC.∂H∂r(h::Hamiltonian{<:DiagEuclideanMetric,<:GaussianKinetic}, r::ComponentArray) | ||
out = similar(r) | ||
out .= h.metric.M⁻¹ .* r | ||
return out | ||
end | ||
function AdvancedHMC.∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, r::ComponentArray) | ||
out = similar(r) | ||
mul!(out, h.metric.M⁻¹, r) | ||
return out | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's nothing ComponentArray
specific in any of these methods. IMO this indicates that the we might just want to improve the existing generic definitions in the package.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might be able to use AbstractVector
(i.e. AbstractArray{T, 1} where T
) here
Replace: #345
Fix: #344