-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
21,892 additions
and
487 deletions.
There are no files selected for viewing
850 changes: 607 additions & 243 deletions
850
.ipynb_checkpoints/Pair Parameter Inference-checkpoint.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
4,639 changes: 4,534 additions & 105 deletions
4,639
.ipynb_checkpoints/Triple Parameter Inference-checkpoint.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,275 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"using TorusEvol \n", | ||
"using TimerOutputs\n", | ||
"using Distributions \n", | ||
"\n", | ||
"diff = WrappedDiffusion(-1.0, -0.8, 0.1, 0.1, 1.0, 1.0, 0.2)\n", | ||
"jdiff = jumping(diff, 10.0)\n", | ||
"t = 1.0\n", | ||
"num = 100000\n", | ||
"transd = transdist(diff, t, [-1.0, 1.0]);" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[0m\u001b[1m ────────────────────────────────────────────────────────────────────────────────\u001b[22m\n", | ||
"\u001b[0m\u001b[1m \u001b[22m Time Allocations \n", | ||
" ─────────────────────── ────────────────────────\n", | ||
" Tot / % measured: 56.7s / 66.3% 19.9GiB / 99.4% \n", | ||
"\n", | ||
" Section ncalls time %tot avg alloc %tot avg\n", | ||
" ────────────────────────────────────────────────────────────────────────────────\n", | ||
" wd transdist logpdf 1 33.4s 88.7% 33.4s 17.7GiB 89.9% 17.7GiB\n", | ||
" logpdf wn drift 9 31.3s 83.2% 3.48s 17.0GiB 86.3% 1.89GiB\n", | ||
" logaddexp wrappe... 9 1.77s 4.7% 197ms 695MiB 3.4% 77.2MiB\n", | ||
" wd statdist logpdf 1 4.22s 11.2% 4.22s 1.98GiB 10.0% 1.98GiB\n", | ||
" wd transdist rand 1 27.7ms 0.1% 27.7ms 4.73MiB 0.0% 4.73MiB\n", | ||
" rand 1 25.2ms 0.1% 25.2ms 2.29MiB 0.0% 2.29MiB\n", | ||
" wd statdist rand 1 4.74ms 0.0% 4.74ms 1.53MiB 0.0% 1.53MiB\n", | ||
"\u001b[0m\u001b[1m ────────────────────────────────────────────────────────────────────────────────\u001b[22m" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"@timeit to \"wd statdist rand\" y = rand(statdist(diff), num)\n", | ||
"@timeit to \"wd transdist rand\" x = rand(transd, num)\n", | ||
"print(to)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[0m\u001b[1m ────────────────────────────────────────────────────────────────────────────────\u001b[22m\n", | ||
"\u001b[0m\u001b[1m \u001b[22m Time Allocations \n", | ||
" ─────────────────────── ────────────────────────\n", | ||
" Tot / % measured: 3.90s / 100.0% 1.96GiB / 100.0% \n", | ||
"\n", | ||
" Section ncalls time %tot avg alloc %tot avg\n", | ||
" ────────────────────────────────────────────────────────────────────────────────\n", | ||
" wd transdist logpdf 1 3.38s 86.8% 3.38s 1.77GiB 90.4% 1.77GiB\n", | ||
" logpdf wn drift 9 2.96s 75.9% 329ms 1.70GiB 86.9% 194MiB\n", | ||
" logaddexp wrappe... 9 424ms 10.9% 47.1ms 68.6MiB 3.4% 7.62MiB\n", | ||
" wd statdist logpdf 1 514ms 13.2% 514ms 193MiB 9.6% 193MiB\n", | ||
"\u001b[0m\u001b[1m ────────────────────────────────────────────────────────────────────────────────\u001b[22m" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"reset_timer!(to)\n", | ||
"@timeit to \"wd statdist logpdf\" logpdf(statdist(diff), y)\n", | ||
"@timeit to \"wd transdist logpdf\" logpdf(transd, x)\n", | ||
"print(to)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"using Turing, Zygote, Distributions, ReverseDiff\n", | ||
"using FillArrays\n", | ||
"using StatsPlots\n", | ||
"using Random\n", | ||
"using LinearAlgebra\n", | ||
"include(\"src/EvolutionSimulator.jl\")\n", | ||
"\n", | ||
"Turing.setadbackend(:reversediff);\n", | ||
"Turing.setrdcache(true)\n", | ||
"Random.seed!(3);\n", | ||
"Threads.nthreads()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Define Gaussian mixture model.\n", | ||
"w = [0.5, 0.5]\n", | ||
"μ = [-3.5, 0.5]\n", | ||
"mixturemodel = MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ], w)\n", | ||
"\n", | ||
"# We draw the data points.\n", | ||
"N = 60\n", | ||
"x = rand(mixturemodel, N);\n", | ||
"scatter(x[1, :], x[2, :]; legend=false, title=\"Synthetic Dataset\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"@model function gaussian_mixture_model(x)\n", | ||
" # Draw the parameters for each of the K=2 clusters from a standard normal distribution.\n", | ||
" K = 2\n", | ||
" μ ~ MvNormal(Zeros(K), I)\n", | ||
"\n", | ||
" # Draw the weights for the K clusters from a Dirichlet distribution with parameters αₖ = 1.\n", | ||
" w ~ Dirichlet(K, 1.0)\n", | ||
" # Alternatively, one could use a fixed set of weights.\n", | ||
" # w = fill(1/K, K)\n", | ||
"\n", | ||
" # Construct categorical distribution of assignments.\n", | ||
" distribution_assignments = Categorical(w)\n", | ||
"\n", | ||
" # Construct multivariate normal distributions of each cluster.\n", | ||
" D, N = size(x)\n", | ||
" distribution_clusters = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]\n", | ||
"\n", | ||
" # Draw assignments for each datum and generate it from the multivariate normal distribution.\n", | ||
" k = Vector{Int}(undef, N)\n", | ||
" for i in 1:N\n", | ||
" k[i] ~ distribution_assignments\n", | ||
" x[:, i] ~ distribution_clusters[k[i]]\n", | ||
" end\n", | ||
"\n", | ||
" return k\n", | ||
"end\n", | ||
"\n", | ||
"model = gaussian_mixture_model(x);" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"sampler = Gibbs(PG(50, :k), HMC(0.05, 10, :μ, :w))\n", | ||
"nsamples = 60\n", | ||
"nchains = 5\n", | ||
"chains = sample(model, sampler, MCMCThreads(), nsamples, nchains);" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"plot(chains[[\"μ[1]\", \"μ[2]\"]]; colordim=:parameter, legend=true)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# prepare data\n", | ||
"human = retrievepdb(\"1A3N\", dir=\"data/pdb\")\n", | ||
"chainX, chainY = human[\"A\"], human[\"B\"]\n", | ||
"X = vcat(reshape(sequence(chainX), 1, :), angles(chainX))\n", | ||
"X[isnan.(X)] .= 0.0\n", | ||
"Y = vcat(reshape(sequence(chainY), 1, :), angles(chainY))\n", | ||
"Y[isnan.(Y)] .= 0.0\n", | ||
"\n", | ||
"l = 10\n", | ||
"\n", | ||
"X = X[:, 1:end]\n", | ||
"Y = Y[:, 1:end]\n", | ||
"n = size(X, 2)\n", | ||
"m = size(Y, 2)\n", | ||
"data = Array{Float64}(undef, 6, max(n, m))\n", | ||
"data[1:3, 1:n] .= X\n", | ||
"data[4:6, 1:m] .= Y\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"include(\"src/PairEvol.jl\")\n", | ||
"using StatsPlots" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"pair_align(chainX, chainY; t=0.1)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model = PairEvol(data, n, m)\n", | ||
"#alg = HMC(0.05, 10, :t, :mean, :var, :γ, :r, :λ, :seq_length)\n", | ||
"alg = Gibbs(HMC(0.05, 10, :t), MH(:mean, :var, :γ, :r, :λ, :seq_length))\n", | ||
"nchains = 3\n", | ||
"nsamples = 5\n", | ||
"@time chains = sample(model, alg, MCMCThreads(), nsamples, nchains);\n", | ||
"\n", | ||
"# Index the chain with the persistence probabilities.\n", | ||
"plot(chains[[\"t\", \"r\", \"γ\"]]; colordim=:parameter, legend=true)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Julia 1.9.0-rc2", | ||
"language": "julia", | ||
"name": "julia-1.9" | ||
}, | ||
"language_info": { | ||
"file_extension": ".jl", | ||
"mimetype": "application/julia", | ||
"name": "julia", | ||
"version": "1.9.0" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Large diffs are not rendered by default.
Oops, something went wrong.
Oops, something went wrong.