Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorTrains"
uuid = "89893e69-996d-40b1-ba32-8ff5f34c0dd5"
authors = ["stecrotti <[email protected]>", "abraunst <[email protected]"]
version = "0.13.0"
authors = ["stecrotti <[email protected]>", "abraunst <[email protected]"]

[deps]
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
Expand Down
11 changes: 6 additions & 5 deletions experiments/dmrg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@ using Plots

d_original = 3
L = 5
q = rand_mps(ComplexF64, d_original, L, 4)
F = ComplexF64
q = rand_mps(F, d_original, L, 4)
normalize!(q)
nsamples = 5*10^3
nsamples = 10*10^3
X = [sample(q)[1] for _ in 1:nsamples]
nll = -loglikelihood(q, X)
println("Negative Log-Likelihood according to generating distribution q = $nll\n")
mq = marginals(q)

p = rand_mps(ComplexF64, 2, length(q), 4)
p = rand_mps(F, 2, length(q), 4)

function CB()
nlls = zeros(0)
Expand All @@ -35,9 +36,9 @@ function CB()
end

callback = CB()
nsweeps = 40
nsweeps = 20
two_site_dmrg!(p, X, nsweeps;
η=1e-4, ndesc=10, svd_trunc=TruncBond(d_original+2), callback)
η=1e-4, ndesc=10, svd_trunc=TruncBond(d_original), callback)

pl1 = plot(callback.nlls, xlabel="it", ylabel="NLL", label="")
hline!(pl1, [nll], ls=:dash, c=:gray, label="NLL according to generative model")
Expand Down
8 changes: 4 additions & 4 deletions src/MatrixProductStates/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function grad_normalization_canonical(p::MPS, k::Integer)
@tullio zz = Aᵏconj_[m,n,x] * Aᵏ_[m,n,x]
z2 = abs2(float(p.ψ.z))
z = zz / z2
gradz = 2 * Aᵏ / z2
gradz = 2 * conj(Aᵏ) / z2
return gradz, z
end

Expand All @@ -24,7 +24,7 @@ function grad_loglikelihood(p::MPS, k::Integer, X)
gA = - Zprime ./ Z
for x in X
gr, val = grad_evaluate(p.ψ, k, x)
gA[:,:,x[k]...] .+= 2/T * gr / val
gA[:,:,x[k]...] .+= 2 * conj(gr) / (T * val)
ll += 1/T * log(abs2(val))
end
return gA, ll
Expand Down Expand Up @@ -70,7 +70,7 @@ function grad_loglikelihood_two_site(p::MPS, k::Integer, X;
weights = ones(length(X)))

Zprime, Z = grad_normalization_two_site_canonical(p, k; Aᵏᵏ⁺¹)
ll = -log(Z) * mean(weights)
ll = - log(Z) * mean(weights)
T = length(X)
gA = - Zprime / Z * mean(weights)

Expand All @@ -79,7 +79,7 @@ function grad_loglikelihood_two_site(p::MPS, k::Integer, X;
gr, val = grad_evaluate_two_site(p.ψ, k, x;
Ax_left = prodA_left[n][k-1], Ax_right = prodA_right[n][k+2], Aᵏᵏ⁺¹
)
@inbounds gA[:,:,x[k]...,x[k+1]...] .+= 2/T * gr / val * weights[n]
@inbounds gA[:,:,x[k]...,x[k+1]...] .+= 2 * conj(gr) / (T * conj(val)) * weights[n]
ll += 1/T * log(abs2(val)) * weights[n]
end
return gA, ll
Expand Down
10 changes: 5 additions & 5 deletions src/tensor_train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,19 +154,19 @@ function orthogonalize_left!(C::TensorTrain{F}; svd_trunc=TruncThresh(1e-6),
return C
end

function orthogonalize_center!(C::TensorTrain, l::Integer; svd_trunc=TruncThresh(1e-6))
function orthogonalize_center!(C::TensorTrain, l::Integer; svd_trunc=TruncThresh(0.0))
orthogonalize_left!(C; svd_trunc, indices = 1:l-1)
orthogonalize_right!(C; svd_trunc, indices = l+1:length(C))
end

"""
orthogonalize_two_site_center!(C::TensorTrain, k::Integer; svd_trunc=TruncThresh(1e-6))
orthogonalize_two_site_center!(C::TensorTrain, k::Integer; svd_trunc=TruncThresh(0.0))

Orthogonalize the tensor train for a two-site DMRG update at positions k and k+1.
This puts sites 1:k-1 in left-canonical form, sites k+2:N in right-canonical form,
and leaves sites k and k+1 as the non-orthogonal center for merging.
"""
function orthogonalize_two_site_center!(C::TensorTrain, k::Integer; svd_trunc=TruncThresh(1e-6))
function orthogonalize_two_site_center!(C::TensorTrain, k::Integer; svd_trunc=TruncThresh(0.0))
@assert 1 <= k < length(C) "k must be between 1 and length(C)-1 for two-site update"
orthogonalize_left!(C; svd_trunc, indices = 1:k-1)
orthogonalize_right!(C; svd_trunc, indices = k+2:length(C))
Expand Down Expand Up @@ -246,7 +246,7 @@ function grad_evaluate(A::TensorTrain, l::Integer, X)
Ax_center = A[l][:,:,X[l]...]
z = float(A.z)
val = only(prodA_left * Ax_center * prodA_right) / z
gr = (prodA_right * prodA_left)' / z
gr = transpose(prodA_right * prodA_left) / z
return gr, val
end

Expand Down Expand Up @@ -281,7 +281,7 @@ function grad_evaluate_two_site(A::TensorTrain, k::Integer, X;
Ax_center = Aᵏᵏ⁺¹[:,:,X[k]...,X[k+1]...]
z = float(A.z)
val = only(Ax_left * Ax_center * Ax_right) / z
gr = (Ax_right * Ax_left)' / z
gr = transpose(Ax_right * Ax_left) / z
return gr, val
end

Expand Down
14 changes: 5 additions & 9 deletions test/mps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,23 +247,19 @@ end
@testset "Complex Derivatives" begin
F = ComplexF64
tensors = [rand(F, 1,5,2,2), rand(F, 5,4,2,2),
rand(F, 4,10,2,2), rand(F, 10,1,2,2)]
rand(F, 4,10,2,2), rand(F, 10,3,2,2), rand(F, 3, 1, 2, 2)]
ψ = TensorTrain(tensors)
p = MPS(ψ)

@testset "Gradient of loglikelihood - 2-site" begin
X = [sample(p)[1] for _ in 1:10]
X = [sample(rng, p)[1] for _ in 1:10]
for l in 1:length(p)-1
orthogonalize_two_site_center!(p, l)
p_cp = deepcopy(p)
A = _merge_tensors(p_cp[l], p_cp[l+1])
dlldA, ll = grad_loglikelihood_two_site(p_cp, l, X)
@test ll ≈ loglikelihood(p_cp, X)
η = 1e-3
lls = map(1:100) do _
orthogonalize_two_site_center!(p_cp, l)
η = 1e-4
lls = map(1:100) do it
A = _merge_tensors(p_cp[l], p_cp[l+1])
dlldA, ll = grad_loglikelihood_two_site(p_cp, l, X)
@test ll ≈ loglikelihood(p_cp, X)
p_cp[l], p_cp[l+1] = TensorTrains._split_tensor(A + η*dlldA)
ll
end
Expand Down