Skip to content

Commit efd4fb0

Browse files
authored
Local branching metrics fixes (#96)
* black format tests * revise author info for pypi * revise `local_branching` method, and add a unit test
1 parent fcc4fa8 commit efd4fb0

File tree

5 files changed

+122
-32
lines changed

5 files changed

+122
-32
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ test:
99
gctree test
1010

1111
format:
12-
black gctree
12+
black gctree tests
1313
docformatter --in-place gctree/*.py
1414

1515
lint:

gctree/branching_processes.py

+28-26
Original file line numberDiff line numberDiff line change
@@ -851,8 +851,7 @@ def support(
851851
compatibility_ += weights[i] if weights is not None else 1
852852
node.support = compatibility_ if compatibility else support
853853

854-
@np.errstate(all="raise")
855-
def local_branching(self, tau=1, tau0=0.1):
854+
def local_branching(self, tau=1, tau0=0.1, infinite_root_branch=True):
856855
r"""Add local branching statistics (Neher et al. 2014) as tree node
857856
features to the ETE tree attribute.
858857
After execution, all nodes will have new features ``LBI``
@@ -862,45 +861,48 @@ def local_branching(self, tau=1, tau0=0.1):
862861
Args:
863862
tau: decay timescale for exponential filter
864863
tau0: effective branch length for branches with zero mutations
864+
infinite_root_branch: calculate assuming the root node has an infinite branch
865865
"""
866866
# the fixed integral contribution for clonal cells indicated by abundance annotations
867867
clone_contribution = tau * (1 - np.exp(-tau0 / tau))
868868

869-
# post-order traversal to populate downward integral for each node
869+
# post-order traversal to populate downward integrals for each node
870870
for node in self.tree.traverse(strategy="postorder"):
871871
if node.is_leaf():
872-
node.add_feature(
873-
"LB_down",
874-
node.abundance * clone_contribution if node.abundance > 1 else 0,
875-
)
872+
node.LB_down = {
873+
node: node.abundance * clone_contribution
874+
if node.abundance > 1
875+
else 0
876+
}
876877
else:
877-
node.add_feature(
878-
"LB_down",
879-
node.abundance * clone_contribution
880-
+ sum(
881-
tau * (1 - np.exp(-child.dist / tau))
882-
+ np.exp(-child.dist / tau) * child.LB_down
883-
for child in node.children
884-
),
885-
)
878+
node.LB_down = {node: node.abundance * clone_contribution}
879+
for child in node.children:
880+
node.LB_down[child] = tau * (
881+
1 - np.exp(-child.dist / tau)
882+
) + np.exp(-child.dist / tau) * sum(child.LB_down.values())
886883

887884
# pre-order traversal to populate upward integral for each node
888885
for node in self.tree.traverse(strategy="preorder"):
889886
if node.is_root():
890-
node.add_feature("LB_up", 0)
887+
# integral corresponding to infinite branch above root node
888+
node.LB_up = tau if infinite_root_branch else 0
891889
else:
892-
node.add_feature(
893-
"LB_up",
894-
tau * (1 - np.exp(-node.dist / tau))
895-
+ np.exp(-node.dist / tau) * (node.up.LB_up + node.up.LB_down),
890+
node.LB_up = tau * (1 - np.exp(-node.dist / tau)) + np.exp(
891+
-node.dist / tau
892+
) * (
893+
node.up.LB_up
894+
+ sum(
895+
node.up.LB_down[message]
896+
for message in node.up.LB_down
897+
if message != node
898+
)
896899
)
897900

898-
# finally, compute LBI (LBR) as the sum (ratio) of upward and downward integrals at each node
901+
# finally, compute LBI (LBR) as the sum (ratio) of downward and upward integrals at each node
899902
for node in self.tree.traverse():
900-
node.add_feature("LBI", node.LB_down + node.LB_up)
901-
node.add_feature(
902-
"LBR", node.LB_down / node.LB_up if not node.is_root() else np.nan
903-
)
903+
node_LB_down_total = sum(node.LB_down.values())
904+
node.LBI = node_LB_down_total + node.LB_up
905+
node.LBR = node_LB_down_total / node.LB_up
904906

905907

906908
def _requires_dag(func):

setup.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
name="gctree",
1010
version=versioneer.get_version(),
1111
cmdclass=versioneer.get_cmdclass(),
12-
author="William DeWitt",
13-
author_email="wsdewitt@gmail.com",
12+
author="Matsen Group",
13+
author_email="ematsen@gmail.com",
1414
description="phylogenetic inference of genotype-collapsed trees",
1515
long_description=long_description,
1616
long_description_content_type="text/markdown",

tests/test_likelihoods.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,10 @@ def test_newlikelihoods():
149149

150150
assert ll_isclose(oldfll, newfll)
151151
assert ll_isclose(oldfll, newfctreell)
152-
for dagll, treell in zip(sorted(newforest._forest.weight_count(**ll_dagfuncs).elements()), sorted(ctree.ll(p, q) for ctree in newforest)):
152+
for dagll, treell in zip(
153+
sorted(newforest._forest.weight_count(**ll_dagfuncs).elements()),
154+
sorted(ctree.ll(p, q) for ctree in newforest),
155+
):
153156
assert np.isclose(dagll, treell)
154157

155158

@@ -180,7 +183,7 @@ def test_validate_ll_genotype():
180183
for c in range(c_max):
181184
for m in range(m_max):
182185
if c > 0 or m > 1:
183-
with np.errstate(all='raise'):
186+
with np.errstate(all="raise"):
184187
true_res = OldCollapsedTree._ll_genotype(c, m, *params)
185188
res = bp.CollapsedTree._ll_genotype(c, m, *params)
186189
assert np.isclose(true_res[0], res[0])
@@ -193,5 +196,5 @@ def test_recursion_depth():
193196
recursion depth issues"""
194197
bp.CollapsedTree._ll_genotype.cache_clear()
195198
bp.CollapsedTree._max_ll_cache = {}
196-
with np.errstate(all='raise'):
199+
with np.errstate(all="raise"):
197200
bp.CollapsedTree._ll_genotype(2, 500, 0.4, 0.6)

tests/test_local_branching.py

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import ete3
2+
import gctree
3+
import numpy as np
4+
5+
tree = ete3.TreeNode(name="naive", dist=0)
6+
tree.abundance = 0
7+
tree.sequence = "A"
8+
tree.isotype = {}
9+
10+
child1 = ete3.TreeNode(name="seq1", dist=1)
11+
child1.abundance = 1
12+
child1.sequence = "C"
13+
child1.isotype = {}
14+
15+
child2 = ete3.TreeNode(name="seq2", dist=2)
16+
child2.abundance = 2
17+
child2.sequence = "G"
18+
child2.isotype = {}
19+
20+
grandchild = ete3.TreeNode(name="seq3", dist=3)
21+
grandchild.abundance = 3
22+
grandchild.sequence = "T"
23+
grandchild.isotype = {}
24+
25+
tree.add_child(child1)
26+
tree.add_child(child2)
27+
child1.add_child(grandchild)
28+
29+
ctree = gctree.CollapsedTree()
30+
ctree.tree = tree
31+
32+
τ = 1
33+
τ0 = 0.1
34+
35+
ctree.local_branching(tau=τ, tau0=τ0)
36+
37+
# dummy integrated branch length for exploded clonal abundances
38+
clone_contribution = τ * (1 - np.exp(-τ0 / τ))
39+
40+
LB_down = {
41+
tree: {
42+
tree: 0,
43+
child1: τ * (1 - np.exp(-(child1.dist + grandchild.dist) / τ))
44+
+ np.exp(-child1.dist / τ) * child1.abundance * clone_contribution
45+
+ np.exp(-(child1.dist + grandchild.dist) / τ)
46+
* grandchild.abundance
47+
* clone_contribution,
48+
child2: τ * (1 - np.exp(-child2.dist / τ))
49+
+ np.exp(-child2.dist / τ) * child2.abundance * clone_contribution,
50+
},
51+
child1: {
52+
child1: child1.abundance * clone_contribution,
53+
grandchild: τ * (1 - np.exp(-grandchild.dist / τ))
54+
+ np.exp(-grandchild.dist / τ) * grandchild.abundance * clone_contribution,
55+
},
56+
child2: {child2: child2.abundance * clone_contribution},
57+
grandchild: {grandchild: grandchild.abundance * clone_contribution},
58+
}
59+
60+
LB_up = {tree: τ}
61+
LB_up[child1] = τ * (1 - np.exp(-child1.dist / τ)) + np.exp(-child1.dist / τ) * (
62+
LB_up[tree]
63+
+ sum(LB_down[tree][message] for message in LB_down[tree] if message != child1)
64+
)
65+
LB_up[child2] = τ * (1 - np.exp(-child2.dist / τ)) + np.exp(-child2.dist / τ) * (
66+
LB_up[tree]
67+
+ sum(LB_down[tree][message] for message in LB_down[tree] if message != child2)
68+
)
69+
LB_up[grandchild] = τ * (1 - np.exp(-grandchild.dist / τ)) + np.exp(
70+
-grandchild.dist / τ
71+
) * (
72+
LB_up[child1]
73+
+ sum(
74+
LB_down[child1][message] for message in LB_down[child1] if message != grandchild
75+
)
76+
)
77+
78+
LBI = {node: LB_up[node] + sum(LB_down[node].values()) for node in tree.traverse()}
79+
LBR = {node: sum(LB_down[node].values()) / LB_up[node] for node in tree.traverse()}
80+
81+
for node in ctree.tree.traverse():
82+
assert LB_up[node] == node.LB_up
83+
assert LB_down[node] == node.LB_down
84+
assert LBI[node] == node.LBI
85+
assert LBR[node] == node.LBR

0 commit comments

Comments
 (0)