Skip to content

Commit

Permalink
add context likelihood test
Browse files Browse the repository at this point in the history
  • Loading branch information
willdumm committed Apr 5, 2024
1 parent 36ea19f commit 75b8fa7
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 10 deletions.
5 changes: 2 additions & 3 deletions gctree/branching_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,7 +1236,6 @@ def filter_trees( # noqa: C901
)
else:
mut_funcs = _context_poisson_likelihood_dagfuncs(

mutability_file=mutability_file,
substitution_file=substitution_file,
splits=[] if chain_split is None else [chain_split],
Expand Down Expand Up @@ -1291,8 +1290,8 @@ def linear_combinator(weighttuple):
ordering_name="LinearCombination",
)
ranking_description = (
"Ranking trees to minimize a linear combination of " +
" + ".join(
"Ranking trees to minimize a linear combination of "
+ " + ".join(
str(coeff) + "(" + fl.weight_funcs.name + ")"
for fl, coeff in dag_filters
)
Expand Down
4 changes: 1 addition & 3 deletions gctree/mutation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,9 +613,7 @@ def distance(seq1, seq2):
return distance


def _context_poisson_likelihood_dagfuncs(
*args, splits: List[int] = [], **kwargs
):
def _context_poisson_likelihood_dagfuncs(*args, splits: List[int] = [], **kwargs):
mutation_model = MutationModel(*args, **kwargs)
distance = _context_poisson_likelihood(mutation_model, splits=splits)

Expand Down
8 changes: 4 additions & 4 deletions tests/test_isotype.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ def test_trim_byisotype():
for node in tdag.preorder():
if node.attr is not None:
node.attr["isotype"] = node._dp_data
kwargs = _isotype_dagfuncs()
c = tdag.weight_count(**kwargs)
dag_filter = _isotype_dagfuncs()
c = tdag.weight_count(**dag_filter)
key = min(c)
count = c[key]
tdag.trim_optimal_weight(**kwargs, optimal_func=min)
assert tdag.weight_count(**kwargs) == {key: count}
tdag.trim_optimal_weight(**dag_filter)
assert tdag.weight_count(**dag_filter) == {key: count}
36 changes: 36 additions & 0 deletions tests/test_likelihoods.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import gctree.branching_processes as bp
import gctree.phylip_parse as pp
import gctree.utils as utils
import gctree.mutation_model as mm
from math import log

import numpy as np
from multiset import FrozenMultiset
Expand Down Expand Up @@ -198,3 +200,37 @@ def test_recursion_depth():
bp.CollapsedTree._max_ll_cache = {}
with np.errstate(all="raise"):
bp.CollapsedTree._ll_genotype(2, 500, 0.4, 0.6)


def test_context_likelihood():
    """Compare the context Poisson log-likelihood with a hand-computed value.

    The parent ``AAGAAA`` and child ``AATCAA`` differ at two sites (G->T at
    index 2 and A->C at index 3).  The expected value is assembled from the
    mutation model's per-fivemer mutabilities and substitution probabilities
    and compared against ``mm._context_poisson_likelihood``, first without
    chain splits and then with both sequences doubled and a split at index 6.

    Floating-point results are compared with a tolerance rather than ``==``:
    the expected value and the library value are accumulated independently,
    so bit-for-bit equality is not guaranteed.
    """
    # These files will be present if pytest is run through `make test`.
    mutation_model = mm.MutationModel(
        mutability_file="HS5F_Mutability.csv", substitution_file="HS5F_Substitution.csv"
    )
    log_likelihood = mm._context_poisson_likelihood(mutation_model, splits=[])

    parent_seq = "AAGAAA"
    child_seq = "AATCAA"

    # Fivemers centred on the two mutated sites, paired with the base each
    # site mutates to in the child.
    mutated_contexts = [("AAGAA", "T"), ("AGAAA", "C")]
    # Log of (mutability * substitution probability) summed over mutated sites.
    term1 = sum(
        log(
            mutation_model.mutability(fivemer)[0]
            * mutation_model.mutability(fivemer)[1][target_base]
        )
        for fivemer, target_base in mutated_contexts
    )
    # Every fivemer context of the parent, padded with N at both ends.
    sum_mutabilities = sum(
        mutation_model.mutability(fivemer)[0]
        for fivemer in ["NNAAG", "NAAGA", "AAGAA", "AGAAA", "GAAAN", "AAANN"]
    )
    true_val = term1 + 2 * log(2 / sum_mutabilities) - 2
    assert np.isclose(true_val, log_likelihood(parent_seq, child_seq), rtol=1e-9)

    # Now test chain split:
    parent_seq = parent_seq + parent_seq
    child_seq = child_seq + child_seq
    # At index 6, the second concatenated sequence starts.
    log_likelihood = mm._context_poisson_likelihood(mutation_model, splits=[6])

    # Twice the mutations (4) and twice the total mutability of one copy.
    true_val = 2 * term1 + 4 * log(4 / (2 * sum_mutabilities)) - 4
    assert np.isclose(true_val, log_likelihood(parent_seq, child_seq), rtol=1e-9)

0 comments on commit 75b8fa7

Please sign in to comment.