-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
1,782 additions
and
1,098 deletions.
There are no files selected for viewing
799 changes: 259 additions & 540 deletions
799
.ipynb_checkpoints/Pair Parameter Inference-checkpoint.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
384 changes: 384 additions & 0 deletions
384
.ipynb_checkpoints/Triple Parameter Inference-checkpoint.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,384 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## 1. Prepare data" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### 1.1 Load real data example" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": { | ||
"scrolled": false | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[36m\u001b[1m[ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39mFile exists: 1A3N\n", | ||
"\u001b[36m\u001b[1m[ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39mFile exists: 1MBN\n", | ||
"\u001b[36m\u001b[1m[ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39mFile exists: 1A3N\n" | ||
] | ||
}, | ||
{ | ||
"ename": "LoadError", | ||
"evalue": "UndefVarError: `LED` not defined", | ||
"output_type": "error", | ||
"traceback": [ | ||
"UndefVarError: `LED` not defined", | ||
"" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"using TorusEvol\n", | ||
"using Distributions\n", | ||
"\n", | ||
"# Underlying evolutionary process\n", | ||
"# Times associated with chains Y, Z and W (presumably branch lengths to a common ancestor; TODO confirm)\n", | ||
"t_Y = 0.5; t_Z=0.3; t_W=0.9\n", | ||
"# TKF92 parameters: λ and μ are the birth and death rates (λ < μ); r is passed as the fourth TKF92 argument\n", | ||
"λ=0.03; μ=0.0308; r=0.4\n", | ||
"τ = TKF92([t_Y+t_Z], λ, μ, r)\n", | ||
"S = WAG_SubstitutionProcess()\n", | ||
"# Parameters of the dihedral-angle process, in the positional order expected by JumpingWrappedDiffusion\n", | ||
"μ_𝜙=-1.0; μ_𝜓=-0.8; σ_𝜙=0.8; σ_𝜓=0.8; α_𝜙=0.5; α_𝜓=1.0; α_cov=0.1; γ=0.2\n", | ||
"Θ = JumpingWrappedDiffusion(μ_𝜙, μ_𝜓, σ_𝜙, σ_𝜓, α_𝜙, α_𝜓, α_cov, γ)\n", | ||
"# Site-level process combining substitution (S) and dihedral (Θ) components\n", | ||
"ξ = ProductProcess(S, Θ)\n", | ||
"# NOTE: Γ is rebound in the next code cell using τ_XYZ\n", | ||
"Γ = ChainJointDistribution(ξ, τ)\n", | ||
"\n", | ||
"# Load chains from PDB entries 1A3N (chains A and B) and 1MBN (chain A); `data` extracts the payload used by the models below\n", | ||
"chainY = from_pdb(\"1A3N\", \"A\"); Y = data(chainY)\n", | ||
"chainZ = from_pdb(\"1MBN\", \"A\"); Z = data(chainZ)\n", | ||
"chainW = from_pdb(\"1A3N\", \"B\"); W = data(chainW)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"-1490.1530021439378-24.747921659483893" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"3×162 Matrix{Integer}:\n", | ||
" 1 1 1 1 1 1 1 1 1 1 1 1 1 … 1 1 1 1 1 0 0 0 0 0 0 1\n", | ||
" 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 1\n", | ||
" 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1" | ||
] | ||
}, | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"# Alignment model over Y and Z built with known_ancestor=false.\n", | ||
"τ_XYZ = TKF92([t_Y, t_Z], λ, μ, r; known_ancestor=false)\n", | ||
"Γ = ChainJointDistribution(ξ, τ_XYZ)\n", | ||
"α_YZ = get_α(τ_XYZ, (Y, Z))\n", | ||
"lp = logpdfα!(α_YZ, Γ, (Y, Z)); print(lp)\n", | ||
"\n", | ||
"# Sample an alignment and score it under the very same distribution object,\n", | ||
"# instead of constructing ConditionedAlignmentDistribution a second time.\n", | ||
"c = ConditionedAlignmentDistribution(τ_XYZ, α_YZ)\n", | ||
"M_XYZ_data = rand(c); M_XYZ = Alignment(M_XYZ_data)\n", | ||
"print(logpdf(c, M_XYZ_data))\n", | ||
"data(M_XYZ)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"ename": "LoadError", | ||
"evalue": "UndefVarError: `LengthEquilibriumDefinition` not defined", | ||
"output_type": "error", | ||
"traceback": [ | ||
"UndefVarError: `LengthEquilibriumDefinition` not defined", | ||
"", | ||
"Stacktrace:", | ||
" [1] top-level scope", | ||
" @ In[8]:2" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"#X = hiddenchain_from_alignment(Y, Z, t_Y, t_Z, M_XYZ, ξ)\n", | ||
"# NOTE(review): the saved output of this cell is an UndefVarError for this name; confirm it is defined/exported before relying on it.\n", | ||
"LengthEquilibriumDefinition(λ, μ, r)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## 2. Parameter Inference Bayesian Model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### 2.1 Set up priors for evolutionary processes" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"using Turing, DynamicPPL\n", | ||
"using LinearAlgebra\n", | ||
"using LogExpFunctions \n", | ||
"using Plots, StatsPlots\n", | ||
"using Random\n", | ||
"\n", | ||
"import Base: length, eltype\n", | ||
"import Distributions: _rand!, logpdf\n", | ||
"\n", | ||
"Turing.setprogress!(true)\n", | ||
"\n", | ||
"# Beta(α, β) distribution mapped affinely from [0, 1] onto [-1, 1] via y = 2x - 1.\n", | ||
"struct ScaledBeta <: ContinuousUnivariateDistribution\n", | ||
"    be::Beta\n", | ||
"    function ScaledBeta(α::Real, β::Real)\n", | ||
"        new(Beta(α, β))\n", | ||
"    end\n", | ||
"end\n", | ||
"# Draw from the wrapped Beta with the caller's RNG (so seeding is honoured), then rescale to [-1, 1].\n", | ||
"Distributions.rand(rng::AbstractRNG, d::ScaledBeta) = rand(rng, d.be)*2 - 1\n", | ||
"# Change of variables x = (y + 1) / 2 has Jacobian 1/2, hence the -log(2) term that makes the density integrate to 1.\n", | ||
"Distributions.logpdf(d::ScaledBeta, x::Real) = logpdf(d.be, (x+1) / 2) - log(2)\n", | ||
"\n", | ||
"\n", | ||
"# Bivariate distribution of two i.i.d. Exponential draws ordered so that x[1] ≤ x[2].\n", | ||
"struct CompetingExponential <: ContinuousMultivariateDistribution\n", | ||
"    ex::Exponential\n", | ||
"    function CompetingExponential(rate::Real)\n", | ||
"        # Distributions.Exponential is parameterised by its scale θ = 1/rate, not by the rate itself.\n", | ||
"        # Identical to the previous behaviour for rate == 1.\n", | ||
"        new(Exponential(inv(rate)))\n", | ||
"    end\n", | ||
"end\n", | ||
"# Samples are length-2 Float64 vectors.\n", | ||
"Base.eltype(d::CompetingExponential) = Float64\n", | ||
"Base.length(d::CompetingExponential) = 2\n", | ||
"\n", | ||
"# Fill x with two independent draws from d.ex, smaller value first.\n", | ||
"function Distributions._rand!(rng::AbstractRNG, d::CompetingExponential, x::AbstractVector{<:Real})\n", | ||
"    first_draw = rand(rng, d.ex)\n", | ||
"    second_draw = rand(rng, d.ex)\n", | ||
"    # Order the pair without a temporary swap variable\n", | ||
"    ordered = first_draw > second_draw ? (second_draw, first_draw) : (first_draw, second_draw)\n", | ||
"    x .= ordered\n", | ||
"    return x\n", | ||
"end\n", | ||
"\n", | ||
"# Log-density of the ordered pair: zero mass when x[1] > x[2], otherwise twice the product of the two Exponential densities.\n", | ||
"function Distributions._logpdf(d::CompetingExponential, x::AbstractArray)\n", | ||
"    smaller, larger = x[1], x[2]\n", | ||
"    smaller > larger && return -Inf\n", | ||
"    return log(2) + logpdf(d.ex, smaller) + logpdf(d.ex, larger)\n", | ||
"end\n", | ||
"\n", | ||
"# Prior over the TKF92 parameters. Returns the tuple (λ, μ, r);\n", | ||
"# λ and μ are returned as NaN when the sampled values violate the constraints checked below.\n", | ||
"@model function tkf92_prior()\n", | ||
" # Ordered pair of rates drawn jointly, so that λμ[1] ≤ λμ[2] under the prior\n", | ||
" λμ ~ CompetingExponential(1.0)\n", | ||
" λ = λμ[1]; μ = λμ[2]\n", | ||
" r ~ Uniform(0.0, 1.0)\n", | ||
"\n", | ||
" # Require birth rate lower than death rate\n", | ||
" # (values outside the support can still be proposed by the random-walk sampler; mark them with NaN for the caller to reject)\n", | ||
" if λ > μ || λ ≤ 0 || μ ≤ 0 || r ≤ 0 || r ≥ 1\n", | ||
" μ = NaN; λ = NaN\n", | ||
" end\n", | ||
" return λ, μ, r\n", | ||
"end;\n", | ||
"\n", | ||
"# Prior over the parameters of the jumping wrapped diffusion.\n", | ||
"# Returns (μ[1], μ[2], sqrt(σ²[1]), sqrt(σ²[2]), α[1], α[2], α_cov, γ), i.e. the positional\n", | ||
"# argument order used when calling JumpingWrappedDiffusion. Invalid draws are signalled with NaN entries.\n", | ||
"@model function jwndiff_prior()\n", | ||
" # Two angular means, each uniform on [-π, π]\n", | ||
" μ ~ filldist(Uniform(-π, π), 2)\n", | ||
" # Variances and α coefficients share the same single-argument Gamma prior\n", | ||
" σ² ~ filldist(Gamma(π * 0.1), 2)\n", | ||
" α ~ filldist(Gamma(π * 0.1), 2)\n", | ||
" γ ~ Exponential(1.0) # jumping rate\n", | ||
" # Correlation-like coefficient in [-1, 1], converted to α_cov below\n", | ||
" α_corr ~ ScaledBeta(3, 3)\n", | ||
" \n", | ||
" # Require valid covariance matrices\n", | ||
" if any(σ² .≤ 0) || any(α .≤ 0) || γ ≤ 0 \n", | ||
" σ² .= NaN; α .= NaN; γ = NaN\n", | ||
" end\n", | ||
" # Scale the correlation by the geometric mean of the two α entries\n", | ||
" α_cov = α_corr * sqrt(α[1] * α[2])\n", | ||
" if α_cov^2 > α[1]*α[2]\n", | ||
" α_cov = NaN\n", | ||
" end\n", | ||
" \n", | ||
" return μ[1], μ[2], sqrt(σ²[1]), sqrt(σ²[2]), α[1], α[2], α_cov, γ\n", | ||
"end;" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### 2.2 Set up sampler" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Proposal for angular parameters: mixture of a narrow (weight 0.8) and a wide (weight 0.2) WrappedNormal centred at v.\n", | ||
"torus_proposal(v) = MixtureModel([WrappedNormal(v, I), WrappedNormal(v, 20*I)], [0.8, 0.2])\n", | ||
"# Multivariate random-walk proposal centred at v with covariance matrix cov.\n", | ||
"mv_rw_proposal(v::AbstractVector, cov) = MvNormal(v, cov)\n", | ||
"# Univariate random-walk proposal centred at x. Normal's second argument is a standard deviation,\n", | ||
"# so it is named σ here (it was misleadingly called var); the values passed by callers are unchanged.\n", | ||
"rw_proposal(x, σ) = Normal(x, σ)\n", | ||
"\n", | ||
"\n", | ||
"# Gibbs sampler with one Metropolis-Hastings step per parameter block.\n", | ||
"# The Θ.* and τ.* symbols presumably match the variable names produced by the prefixed submodels in the model cell - TODO confirm.\n", | ||
"sampler = Gibbs(MH(:t => v -> rw_proposal(v, 0.2)),\n", | ||
" MH(Symbol(\"Θ.μ\") => v -> torus_proposal(v)),\n", | ||
" MH(Symbol(\"Θ.σ²\") => v -> mv_rw_proposal(v, 0.4*I)),\n", | ||
" MH(Symbol(\"Θ.α\") => v -> mv_rw_proposal(v, 0.4*I)),\n", | ||
" MH(Symbol(\"Θ.α_corr\") => x -> rw_proposal(x, 0.5)),\n", | ||
" MH(Symbol(\"Θ.γ\") => x -> rw_proposal(x, 0.5)),\n", | ||
" MH(Symbol(\"τ.λμ\") => v -> mv_rw_proposal(v, [0.4 0.1; 0.1 0.6])),\n", | ||
" MH(Symbol(\"τ.r\") => x -> rw_proposal(x, 0.5))\n", | ||
" );" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"using Memoization \n", | ||
"\n", | ||
"# Memoized per-pair work buffers, computed once per distinct `pairs` argument and reused across model evaluations.\n", | ||
"# The TKF92 instance passed to get_α uses fixed placeholder parameters; presumably only the buffer shape matters here - TODO confirm.\n", | ||
"@memoize get_αs(pairs) = get_α.(Ref(TKF92([1.0], 0.2, 0.3, 0.4)), pairs)\n", | ||
"@memoize get_Bs(pairs) = get_B.(pairs)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### 2.3 Prepare probabilistic model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"using TimerOutputs\n", | ||
"\n", | ||
"# Bayesian model for the evolutionary parameters given a collection of chain pairs.\n", | ||
"# `pairs` is indexable and each element destructures into two chains (X, Y).\n", | ||
"# A single time t is shared by every pair. Returns the ChainJointDistribution built from the sampled parameters.\n", | ||
"@model function pair_param_inference_simple(pairs)\n", | ||
" # ____________________________________________________________________________________________________\n", | ||
" # Step 1 - Sample prior parameters\n", | ||
" \n", | ||
" # Time parameter\n", | ||
" t ~ Exponential(1.0) \n", | ||
" # Alignment parameters\n", | ||
" @submodel prefix=\"τ\" Λ = tkf92_prior()\n", | ||
" # Dihedral parameters \n", | ||
" @submodel prefix=\"Θ\" Ξ = jwndiff_prior()\n", | ||
" # Check parameter validity \n", | ||
" # (the priors flag invalid draws with NaN; reject them by assigning zero probability)\n", | ||
" if t ≤ 0 || any(isnan.(Ξ)) || any(isnan.(Λ))\n", | ||
" Turing.@addlogprob! -Inf; return\n", | ||
" end\n", | ||
" \n", | ||
" # ____________________________________________________________________________________________________\n", | ||
" # Step 2 - Construct processes \n", | ||
" \n", | ||
" # Substitution Process - no parameters for simplicity, use fully empirical model\n", | ||
" S = WAG_SubstitutionProcess()\n", | ||
" # Dihedral Process\n", | ||
" Θ = JumpingWrappedDiffusion(Ξ...)\n", | ||
" # Joint sequence-structure site level process with one regime\n", | ||
" ξ = MixtureProductProcess([1.0], hcat([S, Θ]))\n", | ||
" \n", | ||
" # Alignment model\n", | ||
" τ = TKF92([t], Λ...)\n", | ||
" \n", | ||
" # Chain level model\n", | ||
" Γ = ChainJointDistribution(ξ, τ)\n", | ||
" \n", | ||
" # ____________________________________________________________________________________________________\n", | ||
" # Step 3 - Observe each pair X, Y by proxy of their joint probability, marginalising over alignments\n", | ||
" # α and B are memoized buffers that are overwritten in place on every evaluation\n", | ||
" α = get_αs(pairs)\n", | ||
" B = get_Bs(pairs)\n", | ||
" for i ∈ eachindex(pairs)\n", | ||
" X, Y = pairs[i]\n", | ||
" # (X, Y) ~ ChainJointDistribution(ξ, τ)\n", | ||
" # Refresh B[i] for the current parameters before it is consumed by logpdfαB!\n", | ||
" fulljointlogpdf!(B[i], ξ, t, X, Y)\n", | ||
" Turing.@addlogprob! logpdfαB!(α[i], B[i], Γ, (X, Y))\n", | ||
" end\n", | ||
" \n", | ||
" return Γ\n", | ||
"end;" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### 2.4 Sample from the model and check results" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"num_samples = 200\n", | ||
"num_chains = 3\n", | ||
"# `simulated_data` is never defined in this notebook, so a fresh kernel fails here;\n", | ||
"# condition on the real chain pair loaded in section 1.1 instead.\n", | ||
"observed_pairs = [(Y, Z)]\n", | ||
"model = pair_param_inference_simple(observed_pairs)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Draw num_samples samples in each of num_chains chains using the MCMCThreads() ensemble, then plot the result.\n", | ||
"chain = sample(model, sampler, MCMCThreads(), num_samples, num_chains)\n", | ||
"p = plot(chain, fontfamily=\"JuliaMono\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"chain" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Julia 1.9.0-rc2", | ||
"language": "julia", | ||
"name": "julia-1.9" | ||
}, | ||
"language_info": { | ||
"file_extension": ".jl", | ||
"mimetype": "application/julia", | ||
"name": "julia", | ||
"version": "1.9.0" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.