-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
1,530 additions
and
68 deletions.
There are no files selected for viewing
308 changes: 308 additions & 0 deletions
308
.ipynb_checkpoints/Simulated Data Experiments-checkpoint.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,308 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "54f56c5a", | ||
"metadata": {}, | ||
"source": [ | ||
"## 1. Experiment setup" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "5d67afe0", | ||
"metadata": {}, | ||
"source": [ | ||
"### 1.1. Sample evolutionary process parameters" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 53, | ||
"id": "dc9cf9bb", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[36m\u001b[1m[ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39m[Turing]: progress logging is enabled globally\n", | ||
"\u001b[36m\u001b[1m[ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39m[AdvancedVI]: global PROGRESS is set as true\n", | ||
"\u001b[36m\u001b[1m┌ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39m(2.0793844375569694, 0.005773570704397102, 0.009389661226149627, 0.769318191229367, 0.0004517633068183017, 0.3867126231221906, 0.0008207706544378668, 0.8074528459487708)\n", | ||
"\u001b[36m\u001b[1m└ \u001b[22m\u001b[39m Ξ_1 = \"Ξ_1\"\n", | ||
"\u001b[36m\u001b[1m┌ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39m(1.1344905145087365, -0.6151056108872228, 0.5045661989378393, 1.1942245564097882, 0.001300041441663701, 0.02682643626914298, 0.0005803159622610678, 1.3748720746700442)\n", | ||
"\u001b[36m\u001b[1m└ \u001b[22m\u001b[39m Ξ_2 = \"Ξ_2\"\n", | ||
"\u001b[36m\u001b[1m┌ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39m(0.029162685377525838, 0.061302174093350016, 0.5520709365048283)\n", | ||
"\u001b[36m\u001b[1m└ \u001b[22m\u001b[39m Λ = \"Λ\"\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"10-element Vector{Float64}:\n", | ||
" 0.31426686000171133\n", | ||
" 0.15546045409280682\n", | ||
" 0.06330746782915772\n", | ||
" 0.11425781020314718\n", | ||
" 0.1902670150680404\n", | ||
" 0.1413388813018777\n", | ||
" 0.3871452754000275\n", | ||
" 0.04522098943588153\n", | ||
" 0.04587982654249431\n", | ||
" 0.01649363662998924" | ||
] | ||
}, | ||
"execution_count": 53, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"using TorusEvol \n", | ||
"using Distributions\n", | ||
"using Turing\n", | ||
"using Random\n", | ||
"\n", | ||
"# Fix the RNG seed so that every value sampled below is reproducible.\n", | ||
"Random.seed!(10203)\n", | ||
"Turing.setprogress!(true)\n", | ||
"\n", | ||
"# Number of descendants\n", | ||
"D = 10\n", | ||
"Y = Vector{ObservedChain}(undef, D) # descendants (filled in section 1.2)\n", | ||
"M = Vector{Alignment}(undef, D)     # alignments (filled in section 1.2)\n", | ||
"\n", | ||
"# Evolutionary regimes and their mixture weights\n", | ||
"E = 2\n", | ||
"weights = rand(Dirichlet(E, 1.0))\n", | ||
"\n", | ||
"# Site level process\n", | ||
"S = WAG_SubstitutionProcess()\n", | ||
"\n", | ||
"# One JumpingWrappedDiffusion per regime, with parameters drawn from jwndiff_prior().\n", | ||
"# `@info` takes the message first; the values to display follow it.\n", | ||
"Ξ_1 = jwndiff_prior()()\n", | ||
"Θ_1 = JumpingWrappedDiffusion(Ξ_1...)\n", | ||
"@info \"Ξ_1\" Ξ_1\n", | ||
"Ξ_2 = jwndiff_prior()()\n", | ||
"Θ_2 = JumpingWrappedDiffusion(Ξ_2...)\n", | ||
"@info \"Ξ_2\" Ξ_2\n", | ||
"\n", | ||
"# Process matrix: S in row 1, Θ in row 2, one column per regime.\n", | ||
"ξ = MixtureProductProcess(weights, [S S; Θ_1 Θ_2])\n", | ||
"\n", | ||
"# Alignment parameters\n", | ||
"Λ = tkf92_prior()()\n", | ||
"@info \"Λ\" Λ\n", | ||
"\n", | ||
"# Evolutionary distances, one per descendant\n", | ||
"ts = rand(Exponential(0.1), D)\n", | ||
"ts" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "14f740f0", | ||
"metadata": {}, | ||
"source": [ | ||
"### 1.2. Sample ancestor and aligned descendants" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 61, | ||
"id": "5bdfc105", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"The log pdf of X and Y_1 is -83.51918048243054" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[36m\u001b[1m[ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39mSuperimposing based on a sequence alignment between 7 residues\n", | ||
"\u001b[36m\u001b[1m[ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39mSuperimposing based on 7 atoms\n", | ||
"\u001b[36m\u001b[1m[ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39mModel 1 with 2 chains (1,2), 14 residues, 54 atoms\n", | ||
"\u001b[36m\u001b[1m┌ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39m\n", | ||
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m-#---#----##-#----#-#\n", | ||
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m-####--------#----#-#\n", | ||
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m-#---#----##-#----#-#\n", | ||
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m-#---##----##-----#-#\n", | ||
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m-#---#-#---#-#----#-#\n", | ||
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m#----#----##-#----#-#\n", | ||
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m-#------------#####-#\n", | ||
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m--------####-#----##-\n", | ||
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m-#---#----##-#----#-#\n", | ||
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m-#---#----##-#----#-#\n", | ||
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m-#---#----##-#----#-#\n", | ||
"\u001b[36m\u001b[1m│ \u001b[22m\u001b[39m\n", | ||
"\u001b[36m\u001b[1m└ \u001b[22m\u001b[39m\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Sample a pair of chains from the joint distribution and keep the first one as the ancestor X.\n", | ||
"pair_chain_dist = ChainJointDistribution(ξ, TKF92([ts[1]], Λ...))\n", | ||
"(X, _) = rand(pair_chain_dist)\n", | ||
"\n", | ||
"for d ∈ 1:D\n", | ||
"    # Build the TKF92 model for this distance once and reuse it everywhere below.\n", | ||
"    τ = TKF92([ts[d]], Λ...)\n", | ||
"    Y[d] = rand(ChainTransitionDistribution(ξ, τ, X))\n", | ||
"    Γ = ChainJointDistribution(ξ, τ)\n", | ||
"    α_XY = get_α(τ, (X, Y[d]))\n", | ||
"    logpdfα!(α_XY, Γ, (X, Y[d]))\n", | ||
"    M_XY = Alignment(rand(ConditionedAlignmentDistribution(τ, α_XY)), τ)\n", | ||
"    # NOTE(review): the second argument presumably holds sequence ids\n", | ||
"    # (1 for X, d+1 for Y[d]) - confirm against the Alignment constructor.\n", | ||
"    M[d] = Alignment(data(M_XY), [1, d+1])\n", | ||
"end\n", | ||
"\n", | ||
"chainX = from_primary_dihedrals(Int.(data(X)[1]), data(X)[2])\n", | ||
"chainY = from_primary_dihedrals(Int.(data(Y[1])[1]), data(Y[1])[2])\n", | ||
"lp = logpdf(pair_chain_dist, (X, Y[1]))\n", | ||
"print(\"The log pdf of X and Y_1 is $lp\")\n", | ||
"render(chainX, chainY; aligned=true)\n", | ||
"\n", | ||
"# Merge the pairwise alignments left to right; unlike an explicit loop starting\n", | ||
"# from combine(1, M[1], M[2]), foldl also works when D == 1.\n", | ||
"M_full = foldl((acc, m) -> combine(1, acc, m), M)\n", | ||
"\n", | ||
"@info M_full" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "310f058b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Turing model for a single site: descs[1] ~ statdist(p), the latent x ~ transdist(p, t, descs[1]),\n", | ||
"# and every remaining descendant descs[d] ~ transdist(p, t, x).\n", | ||
"@model function bisection_sampler(p, descs, t)\n", | ||
"    n_descs = length(descs)\n", | ||
"    descs[1] ~ statdist(p)\n", | ||
"\n", | ||
"    # Latent variable; the samplers in this notebook address it by the symbol :x.\n", | ||
"    x = Vector{Real}(undef, length(p))\n", | ||
"    x ~ transdist(p, t, descs[1])\n", | ||
"    for d in 2:n_descs\n", | ||
"        descs[d] ~ transdist(p, t, x)\n", | ||
"    end\n", | ||
"end\n", | ||
"\n", | ||
"# Sample the latent x of `bisection_sampler` separately for every column of the descendant\n", | ||
"# matrices with an MH sampler, keeping the draw at iteration burn_in+1.\n", | ||
"#\n", | ||
"# p       - process passed to `bisection_sampler`; length(p) is the number of rows of the result\n", | ||
"# descs   - vector of matrices, one per descendant; column i of each is the data of site i\n", | ||
"# t       - time passed to `bisection_sampler`\n", | ||
"# burn_in - number of draws discarded before the one that is kept\n", | ||
"#\n", | ||
"# Returns a length(p) × n Matrix{Real}, where n is the number of columns of descs[1].\n", | ||
"function sample_anc_coords_wn(p, descs, t; burn_in=300)\n", | ||
"    # Proposal for x: a mixture of two WrappedNormals centred at the current value.\n", | ||
"    # NOTE(review): `I` is assumed to be in scope (e.g. via LinearAlgebra) - confirm.\n", | ||
"    sampler = MH(:x => v -> MixtureModel([WrappedNormal(v, I), WrappedNormal(v, 20*I)], [0.8, 0.2]))\n", | ||
"    n = size(descs[1], 2)\n", | ||
"    D = length(descs)\n", | ||
"    X = Matrix{Real}(undef, length(p), n)\n", | ||
"    for i ∈ 1:n\n", | ||
"        descs_i = [descs[d][:, i] for d ∈ 1:D]\n", | ||
"        model = bisection_sampler(p, descs_i, t)\n", | ||
"        chn = sample(model, sampler, burn_in+1)\n", | ||
"        # Look the parameter up once and copy every component of x,\n", | ||
"        # so that no row of the result is left unassigned.\n", | ||
"        draws = get(chn, :x).x\n", | ||
"        for k ∈ axes(X, 1)\n", | ||
"            X[k, i] = draws[k][burn_in+1]\n", | ||
"        end\n", | ||
"    end\n", | ||
"    return X\n", | ||
"end\n", | ||
"\n", | ||
"# Per-site sampler for the latent x of `bisection_sampler` that uses PG(60, :x).\n", | ||
"# Returns a length(p) × n Matrix{Real} (n = number of columns of descs[1]);\n", | ||
"# only its first row is assigned.\n", | ||
"function sample_anc_coords_sub(p, descs, t; burn_in=300)\n", | ||
"    pg = PG(60, :x)\n", | ||
"    n_sites = size(descs[1], 2)\n", | ||
"    X = Matrix{Real}(undef, length(p), n_sites)\n", | ||
"    for site in 1:n_sites\n", | ||
"        # Column `site` of every descendant matrix.\n", | ||
"        site_descs = [desc[:, site] for desc in descs]\n", | ||
"        chn = sample(bisection_sampler(p, site_descs, t), pg, burn_in + 1)\n", | ||
"        X[1, site] = get(chn, :x).x[1][burn_in + 1]\n", | ||
"    end\n", | ||
"    return X\n", | ||
"end\n", | ||
"\n", | ||
"# Choose the coordinate sampler from the element type of p: `sample_anc_coords_sub`\n", | ||
"# when eltype(p) is an Integer type, `sample_anc_coords_wn` otherwise.\n", | ||
"function sample_anc_coords(p, descs, t)\n", | ||
"    chosen = eltype(p) <: Integer ? sample_anc_coords_sub : sample_anc_coords_wn\n", | ||
"    return chosen(p, descs, t)\n", | ||
"end\n", | ||
"\n", | ||
"# Sample an ancestor chain X for the aligned chains Y and Z.\n", | ||
"#\n", | ||
"# M_YZ - alignment of Y and Z\n", | ||
"# Y, Z - observed chains\n", | ||
"# t    - time passed on to `sample_anc_alignment` and `sample_anc_coords`\n", | ||
"# ξ    - mixture product process; processes(ξ)[c, e] is used for coordinate c in regime e\n", | ||
"#\n", | ||
"# Returns ObservedChain(dataX), built from the sampled coordinates of X.\n", | ||
"function ancestor_sampling(M_YZ::Alignment, Y::ObservedChain, Z::ObservedChain,\n", | ||
"                           t::Real, ξ::MixtureProductProcess)\n", | ||
"    # step 1 - sample XYZ alignment\n", | ||
"    M_XYZ = sample_anc_alignment(M_YZ, Y, Z, t, ξ)\n", | ||
"\n", | ||
"    # step 2 - sample coordinates of X\n", | ||
"    alignment = M_XYZ\n", | ||
"    X_mask = mask(alignment, [[1], [0,1], [0,1]])\n", | ||
"    alignmentX = slice(alignment, X_mask)\n", | ||
"    Y_mask = mask(alignment, [[0,1], [1], [0,1]])\n", | ||
"    alignmentY = slice(alignment, Y_mask)\n", | ||
"    Z_mask = mask(alignment, [[0,1], [0,1], [1]])\n", | ||
"    alignmentZ = slice(alignment, Z_mask)\n", | ||
"\n", | ||
"    X_maskX = mask(alignmentX, [[1], [0], [0]])\n", | ||
"\n", | ||
"    XY_maskX = mask(alignmentX, [[1], [1], [0]])\n", | ||
"    XY_maskY = mask(alignmentY, [[1], [1], [0]])\n", | ||
"\n", | ||
"    XZ_maskX = mask(alignmentX, [[1], [0], [1]])\n", | ||
"    XZ_maskZ = mask(alignmentZ, [[1], [0], [1]])\n", | ||
"\n", | ||
"    XYZ_maskX = mask(alignmentX, [[1], [1], [1]])\n", | ||
"    XYZ_maskY = mask(alignmentY, [[1], [1], [1]])\n", | ||
"    XYZ_maskZ = mask(alignmentZ, [[1], [1], [1]])\n", | ||
"\n", | ||
"    # Regime of every site of X, Y and Z, sized to match the masks they are combined with.\n", | ||
"    # Placeholder: every site is assigned to regime 1. TODO: sample the regimes.\n", | ||
"    regimesX = ones(Int, length(X_maskX))\n", | ||
"    regimesY = ones(Int, length(XY_maskY))\n", | ||
"    regimesZ = ones(Int, length(XZ_maskZ))\n", | ||
"\n", | ||
"    dataY = data(Y)\n", | ||
"    dataZ = data(Z)\n", | ||
"\n", | ||
"    # Number of coordinates and regimes, read off the process matrix.\n", | ||
"    # NOTE(review): assumes processes(ξ) is a C × E matrix, as the [c, e] indexing below suggests.\n", | ||
"    C, E = size(processes(ξ))\n", | ||
"\n", | ||
"    # Initialise internal coordinates of X\n", | ||
"    N = sequence_lengths(M_XYZ)[1]\n", | ||
"    dataX = [similar(dataY[c], size(dataY[c], 1), N) for c ∈ 1:C]\n", | ||
"\n", | ||
"    for c ∈ 1:C, e ∈ 1:E\n", | ||
"        p = processes(ξ)[c, e]\n", | ||
"\n", | ||
"        # `.&` binds tighter than `.==`, so the regime test is evaluated on its own first.\n", | ||
"        in_eX = regimesX .== e\n", | ||
"        in_eY = regimesY .== e\n", | ||
"        in_eZ = regimesZ .== e\n", | ||
"\n", | ||
"        # [1, 0, 0] - sample from stationary distribution\n", | ||
"        dataX100 = @view dataX[c][:, X_maskX .& in_eX]\n", | ||
"        n100 = size(dataX100, 2)\n", | ||
"        dataX100 .= rand(statdist(p), n100)\n", | ||
"\n", | ||
"        # [1, 1, 0] - observe Y, then sample X from Y\n", | ||
"        dataY110 = @view dataY[c][:, XY_maskY .& in_eY]\n", | ||
"        dataX110 = @view dataX[c][:, XY_maskX .& in_eX]\n", | ||
"        dataX110 .= sample_anc_coords(p, [dataY110], t)\n", | ||
"\n", | ||
"        # [1, 0, 1] - observe Z, then sample X from Z\n", | ||
"        dataZ101 = @view dataZ[c][:, XZ_maskZ .& in_eZ]\n", | ||
"        # Columns of X are selected with the X-side mask, not the Z-side one.\n", | ||
"        dataX101 = @view dataX[c][:, XZ_maskX .& in_eX]\n", | ||
"        dataX101 .= sample_anc_coords(p, [dataZ101], t)\n", | ||
"\n", | ||
"        # [1, 1, 1] - observe Y, sample X from Y, then observe Z from X\n", | ||
"        dataY111 = @view dataY[c][:, XYZ_maskY .& in_eY]\n", | ||
"        dataZ111 = @view dataZ[c][:, XYZ_maskZ .& in_eZ]\n", | ||
"        dataX111 = @view dataX[c][:, XYZ_maskX .& in_eX]\n", | ||
"        dataX111 .= sample_anc_coords(p, [dataY111, dataZ111], t)\n", | ||
"    end\n", | ||
"\n", | ||
"    return ObservedChain(dataX)\n", | ||
"end" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Julia 1.9.0-rc2", | ||
"language": "julia", | ||
"name": "julia-1.9" | ||
}, | ||
"language_info": { | ||
"file_extension": ".jl", | ||
"mimetype": "application/julia", | ||
"name": "julia", | ||
"version": "1.9.0" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.