Skip to content

Commit

Permalink
PAIR IFNERENCE
Browse files Browse the repository at this point in the history
  • Loading branch information
nifets committed May 18, 2023
1 parent 4cc76d8 commit 6bbd0d2
Show file tree
Hide file tree
Showing 11 changed files with 21,892 additions and 487 deletions.
850 changes: 607 additions & 243 deletions .ipynb_checkpoints/Pair Parameter Inference-checkpoint.ipynb

Large diffs are not rendered by default.

4,639 changes: 4,534 additions & 105 deletions .ipynb_checkpoints/Triple Parameter Inference-checkpoint.ipynb

Large diffs are not rendered by default.

275 changes: 275 additions & 0 deletions .ipynb_checkpoints/test-checkpoint.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"using TorusEvol \n",
"using TimerOutputs\n",
"using Distributions \n",
"\n",
"diff = WrappedDiffusion(-1.0, -0.8, 0.1, 0.1, 1.0, 1.0, 0.2)\n",
"jdiff = jumping(diff, 10.0)\n",
"t = 1.0\n",
"num = 100000\n",
"transd = transdist(diff, t, [-1.0, 1.0]);"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[0m\u001b[1m ────────────────────────────────────────────────────────────────────────────────\u001b[22m\n",
"\u001b[0m\u001b[1m \u001b[22m Time Allocations \n",
" ─────────────────────── ────────────────────────\n",
" Tot / % measured: 56.7s / 66.3% 19.9GiB / 99.4% \n",
"\n",
" Section ncalls time %tot avg alloc %tot avg\n",
" ────────────────────────────────────────────────────────────────────────────────\n",
" wd transdist logpdf 1 33.4s 88.7% 33.4s 17.7GiB 89.9% 17.7GiB\n",
" logpdf wn drift 9 31.3s 83.2% 3.48s 17.0GiB 86.3% 1.89GiB\n",
" logaddexp wrappe... 9 1.77s 4.7% 197ms 695MiB 3.4% 77.2MiB\n",
" wd statdist logpdf 1 4.22s 11.2% 4.22s 1.98GiB 10.0% 1.98GiB\n",
" wd transdist rand 1 27.7ms 0.1% 27.7ms 4.73MiB 0.0% 4.73MiB\n",
" rand 1 25.2ms 0.1% 25.2ms 2.29MiB 0.0% 2.29MiB\n",
" wd statdist rand 1 4.74ms 0.0% 4.74ms 1.53MiB 0.0% 1.53MiB\n",
"\u001b[0m\u001b[1m ────────────────────────────────────────────────────────────────────────────────\u001b[22m"
]
}
],
"source": [
"@timeit to \"wd statdist rand\" y = rand(statdist(diff), num)\n",
"@timeit to \"wd transdist rand\" x = rand(transd, num)\n",
"print(to)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[0m\u001b[1m ────────────────────────────────────────────────────────────────────────────────\u001b[22m\n",
"\u001b[0m\u001b[1m \u001b[22m Time Allocations \n",
" ─────────────────────── ────────────────────────\n",
" Tot / % measured: 3.90s / 100.0% 1.96GiB / 100.0% \n",
"\n",
" Section ncalls time %tot avg alloc %tot avg\n",
" ────────────────────────────────────────────────────────────────────────────────\n",
" wd transdist logpdf 1 3.38s 86.8% 3.38s 1.77GiB 90.4% 1.77GiB\n",
" logpdf wn drift 9 2.96s 75.9% 329ms 1.70GiB 86.9% 194MiB\n",
" logaddexp wrappe... 9 424ms 10.9% 47.1ms 68.6MiB 3.4% 7.62MiB\n",
" wd statdist logpdf 1 514ms 13.2% 514ms 193MiB 9.6% 193MiB\n",
"\u001b[0m\u001b[1m ────────────────────────────────────────────────────────────────────────────────\u001b[22m"
]
}
],
"source": [
"reset_timer!(to)\n",
"@timeit to \"wd statdist logpdf\" logpdf(statdist(diff), y)\n",
"@timeit to \"wd transdist logpdf\" logpdf(transd, x)\n",
"print(to)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"using Turing, Zygote, Distributions, ReverseDiff\n",
"using FillArrays\n",
"using StatsPlots\n",
"using Random\n",
"using LinearAlgebra\n",
"include(\"src/EvolutionSimulator.jl\")\n",
"\n",
"Turing.setadbackend(:reversediff);\n",
"Turing.setrdcache(true)\n",
"Random.seed!(3);\n",
"Threads.nthreads()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define Gaussian mixture model.\n",
"w = [0.5, 0.5]\n",
"μ = [-3.5, 0.5]\n",
"mixturemodel = MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ], w)\n",
"\n",
"# We draw the data points.\n",
"N = 60\n",
"x = rand(mixturemodel, N);\n",
"scatter(x[1, :], x[2, :]; legend=false, title=\"Synthetic Dataset\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@model function gaussian_mixture_model(x)\n",
" # Draw the parameters for each of the K=2 clusters from a standard normal distribution.\n",
" K = 2\n",
" μ ~ MvNormal(Zeros(K), I)\n",
"\n",
" # Draw the weights for the K clusters from a Dirichlet distribution with parameters αₖ = 1.\n",
" w ~ Dirichlet(K, 1.0)\n",
" # Alternatively, one could use a fixed set of weights.\n",
" # w = fill(1/K, K)\n",
"\n",
" # Construct categorical distribution of assignments.\n",
" distribution_assignments = Categorical(w)\n",
"\n",
" # Construct multivariate normal distributions of each cluster.\n",
" D, N = size(x)\n",
" distribution_clusters = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]\n",
"\n",
" # Draw assignments for each datum and generate it from the multivariate normal distribution.\n",
" k = Vector{Int}(undef, N)\n",
" for i in 1:N\n",
" k[i] ~ distribution_assignments\n",
" x[:, i] ~ distribution_clusters[k[i]]\n",
" end\n",
"\n",
" return k\n",
"end\n",
"\n",
"model = gaussian_mixture_model(x);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sampler = Gibbs(PG(50, :k), HMC(0.05, 10, :μ, :w))\n",
"nsamples = 60\n",
"nchains = 5\n",
"chains = sample(model, sampler, MCMCThreads(), nsamples, nchains);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot(chains[[\"μ[1]\", \"μ[2]\"]]; colordim=:parameter, legend=true)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# prepare data\n",
"human = retrievepdb(\"1A3N\", dir=\"data/pdb\")\n",
"chainX, chainY = human[\"A\"], human[\"B\"]\n",
"X = vcat(reshape(sequence(chainX), 1, :), angles(chainX))\n",
"X[isnan.(X)] .= 0.0\n",
"Y = vcat(reshape(sequence(chainY), 1, :), angles(chainY))\n",
"Y[isnan.(Y)] .= 0.0\n",
"\n",
"l = 10\n",
"\n",
"X = X[:, 1:end]\n",
"Y = Y[:, 1:end]\n",
"n = size(X, 2)\n",
"m = size(Y, 2)\n",
"data = Array{Float64}(undef, 6, max(n, m))\n",
"data[1:3, 1:n] .= X\n",
"data[4:6, 1:m] .= Y\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"include(\"src/PairEvol.jl\")\n",
"using StatsPlots"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pair_align(chainX, chainY; t=0.1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = PairEvol(data, n, m)\n",
"#alg = HMC(0.05, 10, :t, :mean, :var, :γ, :r, :λ, :seq_length)\n",
"alg = Gibbs(HMC(0.05, 10, :t), MH(:mean, :var, :γ, :r, :λ, :seq_length))\n",
"nchains = 3\n",
"nsamples = 5\n",
"@time chains = sample(model, alg, MCMCThreads(), nsamples, nchains);\n",
"\n",
"# Index the chain with the persistence probabilities.\n",
"plot(chains[[\"t\", \"r\", \"γ\"]]; colordim=:parameter, legend=true)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.9.0-rc2",
"language": "julia",
"name": "julia-1.9"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.9.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
4,620 changes: 4,531 additions & 89 deletions Triple Parameter Inference.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 6bbd0d2

Please sign in to comment.