Skip to content

Commit

Permalink
ancestor sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
nifets committed May 19, 2023
1 parent 516d99d commit c7c888f
Show file tree
Hide file tree
Showing 14 changed files with 1,530 additions and 68 deletions.
308 changes: 308 additions & 0 deletions .ipynb_checkpoints/Simulated Data Experiments-checkpoint.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "54f56c5a",
"metadata": {},
"source": [
"## 1. Experiment setup"
]
},
{
"cell_type": "markdown",
"id": "5d67afe0",
"metadata": {},
"source": [
"### 1.1. Sample evolutionary process parameters"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "dc9cf9bb",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[36m\u001b[1m[ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39m[Turing]: progress logging is enabled globally\n",
"\u001b[36m\u001b[1m[ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39m[AdvancedVI]: global PROGRESS is set as true\n",
"\u001b[36m\u001b[1m┌ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39m(2.0793844375569694, 0.005773570704397102, 0.009389661226149627, 0.769318191229367, 0.0004517633068183017, 0.3867126231221906, 0.0008207706544378668, 0.8074528459487708)\n",
"\u001b[36m\u001b[1m└ \u001b[22m\u001b[39m Ξ_1 = \"Ξ_1\"\n",
"\u001b[36m\u001b[1m┌ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39m(1.1344905145087365, -0.6151056108872228, 0.5045661989378393, 1.1942245564097882, 0.001300041441663701, 0.02682643626914298, 0.0005803159622610678, 1.3748720746700442)\n",
"\u001b[36m\u001b[1m└ \u001b[22m\u001b[39m Ξ_2 = \"Ξ_2\"\n",
"\u001b[36m\u001b[1m┌ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39m(0.029162685377525838, 0.061302174093350016, 0.5520709365048283)\n",
"\u001b[36m\u001b[1m└ \u001b[22m\u001b[39m Λ = \"Λ\"\n"
]
},
{
"data": {
"text/plain": [
"10-element Vector{Float64}:\n",
" 0.31426686000171133\n",
" 0.15546045409280682\n",
" 0.06330746782915772\n",
" 0.11425781020314718\n",
" 0.1902670150680404\n",
" 0.1413388813018777\n",
" 0.3871452754000275\n",
" 0.04522098943588153\n",
" 0.04587982654249431\n",
" 0.01649363662998924"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using TorusEvol \n",
"using Distributions\n",
"using Turing\n",
"using Random\n",
"\n",
"# Fix the RNG seed so every parameter sampled below is reproducible.\n",
"Random.seed!(10203)\n",
"Turing.setprogress!(true)\n",
"\n",
"# Number of descendants, plus storage for the descendant chains (Y)\n",
"# and their alignments (M), filled in by a later cell.\n",
"D = 10\n",
"Y = Vector{ObservedChain}(undef, D)\n",
"M = Vector{Alignment}(undef, D)\n",
"\n",
"# Number of evolutionary regimes and their mixture weights.\n",
"E = 2\n",
"weights = rand(Dirichlet(E, 1.0))\n",
"\n",
"# Site-level substitution process, used in both columns of the process matrix.\n",
"S = WAG_SubstitutionProcess()\n",
"\n",
"# One jumping wrapped diffusion per regime; parameters are drawn from the prior\n",
"# and logged so the run can be inspected afterwards.\n",
"Ξ_1 = jwndiff_prior()()\n",
"Θ_1 = JumpingWrappedDiffusion(Ξ_1...)\n",
"@info Ξ_1 \"Ξ_1\"\n",
"\n",
"Ξ_2 = jwndiff_prior()()\n",
"Θ_2 = JumpingWrappedDiffusion(Ξ_2...)\n",
"@info Ξ_2 \"Ξ_2\"\n",
"\n",
"# Mixture over regimes: each column of the matrix pairs S with one diffusion.\n",
"ξ = MixtureProductProcess(weights, [S S; Θ_1 Θ_2])\n",
"\n",
"# Alignment parameters, drawn from the TKF92 prior.\n",
"Λ = tkf92_prior()()\n",
"@info Λ \"Λ\"\n",
"\n",
"# Evolutionary distances, one per descendant; displayed as the cell output.\n",
"ts = rand(Exponential(0.1), D)\n",
"ts"
]
},
{
"cell_type": "markdown",
"id": "14f740f0",
"metadata": {},
"source": [
"### 1.2. Sample ancestor and aligned descendants"
]
},
{
"cell_type": "code",
"execution_count": 61,
"id": "5bdfc105",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The log pdf of X and Y_1 is -83.51918048243054"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[36m\u001b[1m[ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39mSuperimposing based on a sequence alignment between 7 residues\n",
"\u001b[36m\u001b[1m[ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39mSuperimposing based on 7 atoms\n",
"\u001b[36m\u001b[1m[ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39mModel 1 with 2 chains (1,2), 14 residues, 54 atoms\n",
"\u001b[36m\u001b[1m┌ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39m\n",
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m-#---#----##-#----#-#\n",
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m-####--------#----#-#\n",
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m-#---#----##-#----#-#\n",
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m-#---##----##-----#-#\n",
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m-#---#-#---#-#----#-#\n",
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m#----#----##-#----#-#\n",
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m-#------------#####-#\n",
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m--------####-#----##-\n",
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m-#---#----##-#----#-#\n",
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m-#---#----##-#----#-#\n",
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m-#---#----##-#----#-#\n",
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m\n",
"\u001b[36m\u001b[1m└ \u001b[22m\u001b[39m\n"
]
}
],
"source": [
"# Draw the ancestor X from the joint chain distribution at distance ts[1];\n",
"# the chain sampled alongside it is discarded.\n",
"pair_chain_dist = ChainJointDistribution(ξ, TKF92([ts[1]], Λ...))\n",
"(X, _) = rand(pair_chain_dist)\n",
"\n",
"# For every descendant: sample Y[d] given X, then sample an X-Y[d] alignment\n",
"# conditioned on both chains.\n",
"for d ∈ 1:D\n",
"    τ = TKF92([ts[d]], Λ...)\n",
"    Y[d] = rand(ChainTransitionDistribution(ξ, τ, X))\n",
"    Γ = ChainJointDistribution(ξ, τ)\n",
"\n",
"    # Fill the table α_XY in place; the alignment sampler below reads it.\n",
"    α_XY = get_α(τ, (X, Y[d]))\n",
"    logpdfα!(α_XY, Γ, (X, Y[d]))\n",
"\n",
"    # Reuse τ instead of constructing a second, identical TKF92 model.\n",
"    M_pair = Alignment(rand(ConditionedAlignmentDistribution(τ, α_XY)), τ)\n",
"\n",
"    # Re-label the rows: row 1 is the ancestor, row d+1 is descendant d.\n",
"    M[d] = Alignment(data(M_pair), [1, d+1])\n",
"end\n",
"\n",
"# Rebuild both chains from their data and report the joint log density of\n",
"# (X, Y[1]) under the distribution X was drawn from.\n",
"chainX = from_primary_dihedrals(Int.(data(X)[1]), data(X)[2])\n",
"chainY = from_primary_dihedrals(Int.(data(Y[1])[1]), data(Y[1])[2])\n",
"lp = logpdf(pair_chain_dist, (X, Y[1]))\n",
"print(\"The log pdf of X and Y_1 is $lp\")\n",
"render(chainX, chainY; aligned=true)\n",
"\n",
"# Merge all pairwise alignments into one alignment through their shared row 1.\n",
"# foldl starts from M[1], so this also works when D == 1.\n",
"M_full = foldl((acc, M_d) -> combine(1, acc, M_d), M)\n",
"\n",
"@info M_full"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "310f058b",
"metadata": {},
"outputs": [],
"source": [
"# Turing model relating one latent state `x` to a collection of states `descs`\n",
"# under process `p` at distance `t`.\n",
"#\n",
"# The latent variable must keep the name `x`: the samplers and the chain\n",
"# look-ups in this notebook address it as `:x`.\n",
"@model function bisection_sampler(p, descs, t)\n",
"    # The first entry follows the stationary distribution of p.\n",
"    descs[1] ~ statdist(p)\n",
"\n",
"    # x is reached from the first entry by a transition of length t ...\n",
"    x = Vector{Real}(undef, length(p))\n",
"    x ~ transdist(p, t, descs[1])\n",
"\n",
"    # ... and every remaining entry is reached from x by a transition of length t.\n",
"    for k ∈ 2:length(descs)\n",
"        descs[k] ~ transdist(p, t, x)\n",
"    end\n",
"end\n",
"\n",
"# Sample ancestral coordinates column by column with Metropolis-Hastings.\n",
"#\n",
"# For each column i of the matrices in `descs`, a `bisection_sampler` model is\n",
"# run for `burn_in + 1` iterations and the final draw of `x` is stored in\n",
"# rows 1 and 2 of column i of the result.\n",
"#\n",
"# Returns a `length(p)` × `size(descs[1], 2)` `Matrix{Real}`.\n",
"function sample_anc_coords_wn(p, descs, t; burn_in=300)\n",
"    # Proposal centred on the current value v: a narrow and a wide component.\n",
"    proposal = v -> MixtureModel([WrappedNormal(v, I), WrappedNormal(v, 20*I)], [0.8, 0.2])\n",
"    sampler = MH(:x => proposal)\n",
"\n",
"    n_cols = size(descs[1], 2)\n",
"    n_draws = burn_in + 1\n",
"    X = Matrix{Real}(undef, length(p), n_cols)\n",
"\n",
"    for i ∈ 1:n_cols\n",
"        column_obs = [desc[:, i] for desc ∈ descs]\n",
"        chn = sample(bisection_sampler(p, column_obs, t), sampler, n_draws)\n",
"\n",
"        # Keep only the last draw; everything before it is burn-in.\n",
"        draws = get(chn, :x).x\n",
"        X[1, i] = draws[1][n_draws]\n",
"        X[2, i] = draws[2][n_draws]\n",
"    end\n",
"    return X\n",
"end\n",
"\n",
"# Sample ancestral states column by column with a particle Gibbs sampler.\n",
"#\n",
"# For each column i of the matrices in `descs`, a `bisection_sampler` model is\n",
"# run for `burn_in + 1` iterations and the final draw of the first component\n",
"# of `x` is stored in row 1 of column i of the result.\n",
"#\n",
"# Returns a `length(p)` × `size(descs[1], 2)` `Matrix{Real}`; only row 1 is written.\n",
"function sample_anc_coords_sub(p, descs, t; burn_in=300)\n",
"    sampler = PG(60, :x)\n",
"\n",
"    n_cols = size(descs[1], 2)\n",
"    n_draws = burn_in + 1\n",
"    X = Matrix{Real}(undef, length(p), n_cols)\n",
"\n",
"    for i ∈ 1:n_cols\n",
"        column_obs = [desc[:, i] for desc ∈ descs]\n",
"        chn = sample(bisection_sampler(p, column_obs, t), sampler, n_draws)\n",
"\n",
"        # Keep only the last draw; everything before it is burn-in.\n",
"        X[1, i] = get(chn, :x).x[1][n_draws]\n",
"    end\n",
"    return X\n",
"end\n",
"\n",
"# Dispatch to the sampler that matches the element type of process `p`:\n",
"# integer-valued processes use `sample_anc_coords_sub`, all others use\n",
"# `sample_anc_coords_wn`.\n",
"function sample_anc_coords(p, descs, t)\n",
"    coord_sampler = eltype(p) <: Integer ? sample_anc_coords_sub : sample_anc_coords_wn\n",
"    return coord_sampler(p, descs, t)\n",
"end\n",
"\n",
"# Sample an ancestor chain X for two aligned chains Y and Z.\n",
"#\n",
"# Arguments:\n",
"#   M_YZ - alignment of Y and Z\n",
"#   Y, Z - the two observed chains\n",
"#   t    - distance passed to the alignment and coordinate samplers\n",
"#   ξ    - mixture product process; `processes(ξ)[c, e]` is the process for\n",
"#          coordinate c under regime e\n",
"#\n",
"# Returns an `ObservedChain` built from the sampled data of X.\n",
"#\n",
"# NOTE(review): the regime of every alignment column is a placeholder (all 1),\n",
"# so processes with e > 1 currently receive no columns.\n",
"function ancestor_sampling(M_YZ::Alignment, Y::ObservedChain, Z::ObservedChain, \n",
"                           t::Real, ξ::MixtureProductProcess)\n",
"    # step 1 - sample XYZ alignment\n",
"    M_XYZ = sample_anc_alignment(M_YZ, Y, Z, t, ξ)\n",
"\n",
"    # step 2 - sample coordinates of X\n",
"    # Column masks over the full alignment, one per sequence, and the\n",
"    # sub-alignments restricted to those columns.\n",
"    X_mask = mask(M_XYZ, [[1], [0,1], [0,1]])\n",
"    alignmentX = slice(M_XYZ, X_mask)\n",
"    Y_mask = mask(M_XYZ, [[0,1], [1], [0,1]])\n",
"    alignmentY = slice(M_XYZ, Y_mask)\n",
"    Z_mask = mask(M_XYZ, [[0,1], [0,1], [1]])\n",
"    alignmentZ = slice(M_XYZ, Z_mask)\n",
"\n",
"    # Take the number of coordinates (C) and regimes (E) from the process matrix\n",
"    # itself instead of relying on notebook globals.\n",
"    # NOTE(review): assumes processes(ξ) is a C × E matrix - confirm.\n",
"    C, E = size(processes(ξ))\n",
"\n",
"    # Regime of every alignment column, restricted to the columns of each sequence.\n",
"    # NOTE(review): assumes the masks are Bool vectors with one entry per\n",
"    # alignment column - confirm against `mask`.\n",
"    regimes = ones(Int, length(M_XYZ))\n",
"    regimesX = regimes[X_mask]\n",
"    regimesY = regimes[Y_mask]\n",
"    regimesZ = regimes[Z_mask]\n",
"\n",
"    # Per-sequence masks for each alignment pattern that contains X.\n",
"    X_maskX = mask(alignmentX, [[1], [0], [0]])\n",
"\n",
"    XY_maskX = mask(alignmentX, [[1], [1], [0]])\n",
"    XY_maskY = mask(alignmentY, [[1], [1], [0]])\n",
"\n",
"    XZ_maskX = mask(alignmentX, [[1], [0], [1]])\n",
"    XZ_maskZ = mask(alignmentZ, [[1], [0], [1]])\n",
"\n",
"    XYZ_maskX = mask(alignmentX, [[1], [1], [1]])\n",
"    XYZ_maskY = mask(alignmentY, [[1], [1], [1]])\n",
"    XYZ_maskZ = mask(alignmentZ, [[1], [1], [1]])\n",
"\n",
"    dataY = data(Y)\n",
"    dataZ = data(Z)\n",
"\n",
"    # Initialise internal coordinates of X\n",
"    N = sequence_lengths(M_XYZ)[1]\n",
"    dataX = [similar(dataY[c], size(dataY[c], 1), N) for c ∈ 1:C]\n",
"\n",
"    for c ∈ 1:C, e ∈ 1:E\n",
"        p = processes(ξ)[c, e]\n",
"\n",
"        # `&` binds tighter than `==` in Julia, so the regime test is computed\n",
"        # first and only then combined with the pattern masks.\n",
"        in_regimeX = regimesX .== e\n",
"        in_regimeY = regimesY .== e\n",
"        in_regimeZ = regimesZ .== e\n",
"\n",
"        # [1, 0, 0] - sample from stationary distribution\n",
"        dataX100 = @view dataX[c][:, X_maskX .& in_regimeX]\n",
"        n100 = size(dataX100, 2)\n",
"        dataX100 .= rand(statdist(p), n100)\n",
"\n",
"        # [1, 1, 0] - observe Y, then sample X from Y\n",
"        dataY110 = @view dataY[c][:, XY_maskY .& in_regimeY]\n",
"        dataX110 = @view dataX[c][:, XY_maskX .& in_regimeX]\n",
"        dataX110 .= sample_anc_coords(p, [dataY110], t)\n",
"\n",
"        # [1, 0, 1] - observe Z, then sample X from Z\n",
"        # The X view must use the X-side mask; the Z-side mask indexes Z's columns.\n",
"        dataZ101 = @view dataZ[c][:, XZ_maskZ .& in_regimeZ]\n",
"        dataX101 = @view dataX[c][:, XZ_maskX .& in_regimeX]\n",
"        dataX101 .= sample_anc_coords(p, [dataZ101], t)\n",
"\n",
"        # [1, 1, 1] - observe Y, sample X from Y, then observe Z from X\n",
"        dataY111 = @view dataY[c][:, XYZ_maskY .& in_regimeY]\n",
"        dataZ111 = @view dataZ[c][:, XYZ_maskZ .& in_regimeZ]\n",
"        dataX111 = @view dataX[c][:, XYZ_maskX .& in_regimeX]\n",
"        dataX111 .= sample_anc_coords(p, [dataY111, dataZ111], t)\n",
"    end\n",
"\n",
"    return ObservedChain(dataX)\n",
"end"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.9.0-rc2",
"language": "julia",
"name": "julia-1.9"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.9.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading

0 comments on commit c7c888f

Please sign in to comment.