Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/MatrixProductStates/MatrixProductStates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module MatrixProductStates

using TensorTrains
import TensorTrains: _reshape1, accumulate_L, accumulate_R, accumulate_M,
sample_noalloc,
sample_noalloc, normalize_eachmatrix!,
normalize!, _merge_tensors, _split_tensor, LeftOrRight, Left, Right,
precompute_left_environments, precompute_right_environments, update_environments!,
_two_site_dmrg_generic!, two_site_dmrg!
Expand All @@ -16,9 +16,10 @@ import Optim

export MPS
export rand_mps
export nparams, evaluate
export grad_normalization_canonical, grad_normalization_two_site_canonical,
loglikelihood, grad_loglikelihood, grad_loglikelihood_two_site,
two_site_dmrg!
two_site_dmrg!, empirical_distribution_mps

include("mps.jl")
include("derivatives.jl")
Expand Down
14 changes: 8 additions & 6 deletions src/MatrixProductStates/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,20 +67,20 @@ function grad_loglikelihood_two_site(p::MPS, k::Integer, X;
prodA_left = [precompute_left_environments(p.ψ, x) for x in X],
prodA_right = [precompute_right_environments(p.ψ, x) for x in X],
Aᵏᵏ⁺¹ =_merge_tensors(p[k], p[k+1]),
weights = ones(length(X)))
weights = ones(length(X))/length(X))

Zprime, Z = grad_normalization_two_site_canonical(p, k; Aᵏᵏ⁺¹)
ll = -log(Z) * mean(weights)
ll = -log(Z) * sum(weights)
T = length(X)
gA = - Zprime / Z * mean(weights)
gA = - Zprime / Z * sum(weights)

# TODO: this operation is in principle parallelizable
for (n,x) in enumerate(X)
gr, val = grad_evaluate_two_site(p.ψ, k, x;
Ax_left = prodA_left[n][k-1], Ax_right = prodA_right[n][k+2], Aᵏᵏ⁺¹
)
@inbounds gA[:,:,x[k]...,x[k+1]...] .+= 2/T * gr / val * weights[n]
ll += 1/T * log(abs2(val)) * weights[n]
@inbounds gA[:,:,x[k]...,x[k+1]...] .+= 2 * gr / val * weights[n]
ll += log(abs2(val)) * weights[n]
end
return gA, ll
end
Expand All @@ -90,7 +90,9 @@ end

Fit a MPS to data `X` using a MPS ansatz and the 2site-DMRG-like gradient descent.
"""
function TensorTrains.two_site_dmrg!(p::MPS, X, nsweeps; weights = ones(length(X)), kw...)
function TensorTrains.two_site_dmrg!(p::MPS, X, nsweeps;
weights = ones(length(X))/length(X), kw...)

function func(p, k, data; _kw...)
grad, val = grad_loglikelihood_two_site(p, k, data...; weights, _kw...)
# must return function to be *minimized*, so the *negative* log-likelihood
Expand Down
56 changes: 50 additions & 6 deletions src/MatrixProductStates/mps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ end

@forward MPS.ψ TensorTrains.bond_dims, Base.iterate, Base.firstindex, Base.lastindex,
Base.setindex!, Base.getindex, check_bond_dims, Base.length, Base.eachindex, Base.eltype,
TensorTrains.nparams,
TensorTrains.nparams, TensorTrains.normalize_eachmatrix!,
TensorTrains.precompute_left_environments, TensorTrains.precompute_right_environments

Base.:(==)(A::T, B::T) where {T<:MPS} = isequal(A.ψ, B.ψ)
Expand Down Expand Up @@ -316,11 +316,55 @@ end
# TODO: maybe since (p.ψ.z) is both at the numerator and denominator, ignore it to avoid cancellations with the subtraction?

"""
StatsBase.loglikelihood(p::MPS, X)
StatsBase.loglikelihood(p::MPS, X; weights)

Compute the loglikelihood of the data `X` under the MPS distribution `p`.
Compute the average loglikelihood of the data `X` under the MPS distribution `p`.
Optionally re-weight the log-probability of each datapoint.
"""
function StatsBase.loglikelihood(p::MPS, X)
logz = log(normalization(p))
return mean(log(evaluate(p, x)) for x in X) - logz
function loglikelihood(p::MPS, X; weights=ones(length(X))/length(X))
logz = log(normalization(p)) * sum(weights)
return sum(log(evaluate(p, x)) * w for (x,w) in zip(X,weights)) - logz
end

"""
empirical_distribution_mps(X; qs, weights)

Return a MPS encoding the empirical probability distribution of dataset ``X=\\{x^{(1)}, \\ldots, x^{(M)}\\}``.
The resulting MPS evaluates to ``p(x)=\\frac{1}{M}\\sum_{\\mu \\in 1:M}\\delta(x, x^{(\\mu)})``.

## Optional arguments
- `qs`: the number of states for each variable. Inferred from the data if not provided
- `weights`: allows to re-weight the probability of each sample. It should sum to one (not checked)
"""
function empirical_distribution_mps(X; qs=maximum(maximum.(X)),
weights=ones(length(X))/length(X))

@assert all(>=(0), weights)
M = length(X)
L = length(X[1])
@assert all(x -> length(x)==L, X)

tensors = map(1:L) do i
if i == 1
A = zeros(1, M, qs...)
for n in 1:M
A[1,n,X[n][i]...] =sqrt(weights[n])
end
A
elseif i == L
A = zeros(M, 1, qs...)
for n in 1:M
A[n,1,X[n][i]...] = 1
end
A
else
A = zeros(M, M, qs...)
for n in 1:M
A[n,n,X[n][i]...] = 1
end
A
end
end

return MPS(tensors)
end
11 changes: 11 additions & 0 deletions test/mps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,14 @@ end
@test loglikelihood(q, X) > ll
end
end

@testset "Empirical distribution" begin
X = [[[rand(1:q) for q in (2, 3)] for _ in 1:50] for _ in 1:100]
unique!(X)
weights = rand(length(X))
weights ./= sum(weights)
p = empirical_distribution_mps(X; weights)
for (x, w) in zip(X, weights)
@test evaluate(p, x) ≈ w
end
end