Skip to content

Commit

Permalink
road towards triples!
Browse files Browse the repository at this point in the history
  • Loading branch information
nifets committed May 17, 2023
1 parent a142fd0 commit 18e75e7
Show file tree
Hide file tree
Showing 13 changed files with 1,782 additions and 1,098 deletions.
799 changes: 259 additions & 540 deletions .ipynb_checkpoints/Pair Parameter Inference-checkpoint.ipynb

Large diffs are not rendered by default.

384 changes: 384 additions & 0 deletions .ipynb_checkpoints/Triple Parameter Inference-checkpoint.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,384 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Prepare data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1.1 Load real data example"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[36m\u001b[1m[ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39mFile exists: 1A3N\n",
"\u001b[36m\u001b[1m[ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39mFile exists: 1MBN\n",
"\u001b[36m\u001b[1m[ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39mFile exists: 1A3N\n"
]
},
{
"ename": "LoadError",
"evalue": "UndefVarError: `LED` not defined",
"output_type": "error",
"traceback": [
"UndefVarError: `LED` not defined",
""
]
}
],
"source": [
"using TorusEvol\n",
"using Distributions\n",
"\n",
"# Underlying evolutionary process: times first, then the TKF92 alignment parameters\n",
"t_Y = 0.5\n",
"t_Z = 0.3\n",
"t_W = 0.9\n",
"λ = 0.03\n",
"μ = 0.0308\n",
"r = 0.4\n",
"τ = TKF92([t_Y + t_Z], λ, μ, r)\n",
"\n",
"# Site-level process: WAG substitutions combined with a jumping wrapped diffusion\n",
"S = WAG_SubstitutionProcess()\n",
"μ_𝜙 = -1.0\n",
"μ_𝜓 = -0.8\n",
"σ_𝜙 = 0.8\n",
"σ_𝜓 = 0.8\n",
"α_𝜙 = 0.5\n",
"α_𝜓 = 1.0\n",
"α_cov = 0.1\n",
"γ = 0.2\n",
"Θ = JumpingWrappedDiffusion(μ_𝜙, μ_𝜓, σ_𝜙, σ_𝜓, α_𝜙, α_𝜓, α_cov, γ)\n",
"ξ = ProductProcess(S, Θ)\n",
"Γ = ChainJointDistribution(ξ, τ)\n",
"\n",
"# Load three chains from PDB entries and extract their data\n",
"chainY = from_pdb(\"1A3N\", \"A\")\n",
"Y = data(chainY)\n",
"chainZ = from_pdb(\"1MBN\", \"A\")\n",
"Z = data(chainZ)\n",
"chainW = from_pdb(\"1A3N\", \"B\")\n",
"W = data(chainW)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-1490.1530021439378-24.747921659483893"
]
},
{
"data": {
"text/plain": [
"3×162 Matrix{Integer}:\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1 … 1 1 1 1 1 0 0 0 0 0 0 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Alignment model with times t_Y and t_Z and known_ancestor=false, then the log-density of (Y, Z)\n",
"τ_XYZ = TKF92([t_Y, t_Z], λ, μ, r; known_ancestor=false)\n",
"Γ = ChainJointDistribution(ξ, τ_XYZ)\n",
"α_YZ = get_α(τ_XYZ, (Y, Z))\n",
"lp = logpdfα!(α_YZ, Γ, (Y, Z))\n",
"print(lp)\n",
"\n",
"# Draw one alignment from the distribution conditioned on α_YZ and print its log-density\n",
"c = ConditionedAlignmentDistribution(τ_XYZ, α_YZ)\n",
"M_XYZ_data = rand(ConditionedAlignmentDistribution(τ_XYZ, α_YZ))\n",
"M_XYZ = Alignment(M_XYZ_data)\n",
"print(logpdf(c, M_XYZ_data))\n",
"data(M_XYZ)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"ename": "LoadError",
"evalue": "UndefVarError: `LengthEquilibriumDefinition` not defined",
"output_type": "error",
"traceback": [
"UndefVarError: `LengthEquilibriumDefinition` not defined",
"",
"Stacktrace:",
" [1] top-level scope",
" @ In[8]:2"
]
}
],
"source": [
"#X = hiddenchain_from_alignment(Y, Z, t_Y, t_Z, M_XYZ, ξ)\n",
"LengthEquilibriumDefinition(λ, μ, r)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Parameter Inference Bayesian Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.1 Set up priors for evolutionary processes"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"using Turing, DynamicPPL\n",
"using LinearAlgebra\n",
"using LogExpFunctions \n",
"using Plots, StatsPlots\n",
"using Random\n",
"\n",
"import Base: length, eltype\n",
"import Distributions: _rand!, logpdf\n",
"\n",
"Turing.setprogress!(true)\n",
"\n",
"# Beta distribution rescaled from (0, 1) to (-1, 1): if B ~ Beta(α, β) then 2B - 1 ~ ScaledBeta(α, β).\n",
"struct ScaledBeta <: ContinuousUnivariateDistribution\n",
"    be::Beta\n",
"    function ScaledBeta(α::Real, β::Real)\n",
"        new(Beta(α, β))\n",
"    end\n",
"end\n",
"# Draw from the wrapped Beta with the caller's RNG (so seeded runs are reproducible), then rescale.\n",
"Distributions.rand(rng::AbstractRNG, d::ScaledBeta) = rand(rng, d.be)*2 - 1\n",
"# The change of variables x = 2b - 1 has |db/dx| = 1/2, hence the -log(2) term.\n",
"Distributions.logpdf(d::ScaledBeta, x::Real) = logpdf(d.be, (x+1) / 2) - log(2)\n",
"\n",
"\n",
"# Joint distribution of two i.i.d. exponential draws sorted in increasing order.\n",
"# `rate` is an exponential rate; Distributions.Exponential takes a scale (the mean),\n",
"# so the rate is inverted before it is handed over.\n",
"struct CompetingExponential <: ContinuousMultivariateDistribution\n",
"    ex::Exponential\n",
"    function CompetingExponential(rate::Real)\n",
"        rate > 0 || throw(DomainError(rate, \"CompetingExponential: rate must be positive\"))\n",
"        new(Exponential(inv(rate)))\n",
"    end\n",
"end\n",
"Base.eltype(d::CompetingExponential) = Float64\n",
"Base.length(d::CompetingExponential) = 2\n",
"\n",
"# Fill `x` with two independent draws from `d.ex`, smaller value first.\n",
"function Distributions._rand!(rng::AbstractRNG, d::CompetingExponential, x::AbstractVector{<:Real})\n",
"    λ = rand(rng, d.ex)\n",
"    μ = rand(rng, d.ex)\n",
"    if λ > μ\n",
"        λ, μ = μ, λ\n",
"    end\n",
"    x .= (λ, μ)\n",
"    return x\n",
"end\n",
"\n",
"# Log-density of the sorted pair: -Inf when the pair is out of order, otherwise\n",
"# log(2) (two orderings of the unsorted draws) plus the two exponential log-densities.\n",
"function Distributions._logpdf(d::CompetingExponential, x::AbstractArray)\n",
"    if x[1] > x[2]\n",
"        return -Inf\n",
"    end\n",
"    return log(2) + logpdf(d.ex, x[1]) + logpdf(d.ex, x[2])\n",
"end\n",
"\n",
"# Prior over the TKF92 parameters (λ, μ, r). Returns the tuple (λ, μ, r);\n",
"# λ and μ are replaced by NaN when the draw is invalid so that callers can reject it.\n",
"@model function tkf92_prior()\n",
"    λμ ~ CompetingExponential(1.0)\n",
"    λ = λμ[1]; μ = λμ[2]\n",
"    r ~ Uniform(0.0, 1.0)\n",
"\n",
"    # Require birth rate lower than death rate\n",
"    if λ > μ || λ ≤ 0 || μ ≤ 0 || r ≤ 0 || r ≥ 1\n",
"        μ = NaN; λ = NaN\n",
"    end\n",
"    return λ, μ, r\n",
"end;\n",
"\n",
"# Prior over the jumping wrapped diffusion parameters. Returns\n",
"# (μ[1], μ[2], sqrt(σ²[1]), sqrt(σ²[2]), α[1], α[2], α_cov, γ), with NaN entries\n",
"# marking an invalid draw so that callers can reject it.\n",
"@model function jwndiff_prior()\n",
"    μ ~ filldist(Uniform(-π, π), 2)\n",
"    σ² ~ filldist(Gamma(π * 0.1), 2)\n",
"    α ~ filldist(Gamma(π * 0.1), 2)\n",
"    γ ~ Exponential(1.0) # jumping rate\n",
"    α_corr ~ ScaledBeta(3, 3)\n",
"\n",
"    # Require valid covariance matrices\n",
"    if any(σ² .≤ 0) || any(α .≤ 0) || γ ≤ 0\n",
"        σ² .= NaN; α .= NaN; γ = NaN\n",
"    end\n",
"    # Convert the sampled correlation in (-1, 1) into a covariance\n",
"    α_cov = α_corr * sqrt(α[1] * α[2])\n",
"    if α_cov^2 > α[1]*α[2]\n",
"        α_cov = NaN\n",
"    end\n",
"\n",
"    return μ[1], μ[2], sqrt(σ²[1]), sqrt(σ²[2]), α[1], α[2], α_cov, γ\n",
"end;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.2 Set up sampler"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Proposal for angle pairs: mixture of WrappedNormal(v, I) and WrappedNormal(v, 20*I) with weights 0.8 and 0.2\n",
"torus_proposal(v) = MixtureModel([WrappedNormal(v, I), WrappedNormal(v, 20*I)], [0.8, 0.2])\n",
"# Multivariate Gaussian random walk centred at v with covariance `cov`\n",
"mv_rw_proposal(v::AbstractVector, cov) = MvNormal(v, cov)\n",
"# Univariate Gaussian random walk centred at x with variance `var`.\n",
"# Normal takes a standard deviation as its second argument, hence the sqrt.\n",
"rw_proposal(x, var) = Normal(x, sqrt(var))\n",
"\n",
"\n",
"# One Metropolis-Hastings block per parameter group; the \"Θ.\" and \"τ.\" prefixes match the\n",
"# submodel prefixes used by the inference model.\n",
"sampler = Gibbs(MH(:t => v -> rw_proposal(v, 0.2)),\n",
"                MH(Symbol(\"Θ.μ\") => v -> torus_proposal(v)),\n",
"                MH(Symbol(\"Θ.σ²\") => v -> mv_rw_proposal(v, 0.4*I)),\n",
"                MH(Symbol(\"Θ.α\") => v -> mv_rw_proposal(v, 0.4*I)),\n",
"                MH(Symbol(\"Θ.α_corr\") => x -> rw_proposal(x, 0.5)),\n",
"                MH(Symbol(\"Θ.γ\") => x -> rw_proposal(x, 0.5)),\n",
"                MH(Symbol(\"τ.λμ\") => v -> mv_rw_proposal(v, [0.4 0.1; 0.1 0.6])),\n",
"                MH(Symbol(\"τ.r\") => x -> rw_proposal(x, 0.5))\n",
"               );"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"using Memoization \n",
"\n",
"# Memoized construction of the per-pair α objects that the model later passes to logpdfαB!.\n",
"# NOTE(review): the TKF92 parameters here are fixed placeholders, presumably because only\n",
"# the structure of the result matters - confirm against get_α.\n",
"@memoize function get_αs(pairs)\n",
"    τ_fixed = TKF92([1.0], 0.2, 0.3, 0.4)\n",
"    return get_α.(Ref(τ_fixed), pairs)\n",
"end\n",
"\n",
"# Memoized construction of the per-pair B objects that the model later fills with fulljointlogpdf!.\n",
"@memoize function get_Bs(pairs)\n",
"    return get_B.(pairs)\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.3 Prepare probabilistic model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"using TimerOutputs\n",
"\n",
"# Bayesian model for a collection of chain pairs that share one time parameter, one TKF92\n",
"# alignment process and one dihedral diffusion. Returns the chain-level distribution Γ, or\n",
"# nothing (after adding -Inf to the log-probability) when the sampled parameters are invalid.\n",
"@model function pair_param_inference_simple(pairs)\n",
"    # ------------------------------------------------------------------\n",
"    # Step 1 - Sample prior parameters\n",
"\n",
"    # Time parameter\n",
"    t ~ Exponential(1.0)\n",
"    # Alignment parameters\n",
"    @submodel prefix=\"τ\" Λ = tkf92_prior()\n",
"    # Dihedral parameters\n",
"    @submodel prefix=\"Θ\" Ξ = jwndiff_prior()\n",
"    # Reject the draw when the time is non-positive or a prior marked its output with NaN\n",
"    if t ≤ 0 || any(isnan, Ξ) || any(isnan, Λ)\n",
"        Turing.@addlogprob! -Inf\n",
"        return\n",
"    end\n",
"\n",
"    # ------------------------------------------------------------------\n",
"    # Step 2 - Construct processes\n",
"\n",
"    # Substitution process: fully empirical, no free parameters\n",
"    S = WAG_SubstitutionProcess()\n",
"    # Dihedral process\n",
"    Θ = JumpingWrappedDiffusion(Ξ...)\n",
"    # Joint sequence-structure site-level process with a single regime\n",
"    ξ = MixtureProductProcess([1.0], hcat([S, Θ]))\n",
"    # Alignment model\n",
"    τ = TKF92([t], Λ...)\n",
"    # Chain-level model\n",
"    Γ = ChainJointDistribution(ξ, τ)\n",
"\n",
"    # ------------------------------------------------------------------\n",
"    # Step 3 - Score each pair (X, Y) through its joint log-probability,\n",
"    # marginalising over alignments\n",
"    αs = get_αs(pairs)\n",
"    Bs = get_Bs(pairs)\n",
"    for i in eachindex(pairs)\n",
"        X, Y = pairs[i]\n",
"        # Stands in for: (X, Y) ~ ChainJointDistribution(ξ, τ)\n",
"        fulljointlogpdf!(Bs[i], ξ, t, X, Y)\n",
"        Turing.@addlogprob! logpdfαB!(αs[i], Bs[i], Γ, (X, Y))\n",
"    end\n",
"\n",
"    return Γ\n",
"end;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.4 Sample from the model and check results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"num_samples = 200\n",
"num_chains = 3\n",
"model = pair_param_inference_simple(simulated_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chain = sample(model, sampler, MCMCThreads(), num_samples, num_chains)\n",
"p = plot(chain, fontfamily=\"JuliaMono\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chain"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.9.0-rc2",
"language": "julia",
"name": "julia-1.9"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.9.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit 18e75e7

Please sign in to comment.