Skip to content

Commit

Permalink
Refactored repeated function calls
Browse files (browse the repository at this point in the history)
  • Loading branch information
jkvis committed May 29, 2024
1 parent d143987 commit 3ce4b00
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 25 deletions.
26 changes: 15 additions & 11 deletions algebra/lcs/lcs_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ def expand(idx):
col = row + idx
end = max(diagonals[offset + idx - 1], diagonals[offset + idx + 1])

remaining = abs((len_reference - row) - (len_observed - col))
matching = False
match_row = 0
match_col = 0
Expand All @@ -267,7 +268,7 @@ def expand(idx):
match_col = col
matching = True
elif matching:
lcs_pos = ((row + col) - (abs(delta) + 2 * it - abs((len(reference) - row) - (len(observed) - col)))) // 2 - 1
lcs_pos = ((row + col) - (abs_delta + 2 * it - remaining)) // 2 - 1
max_lcs_pos = max(lcs_pos, max_lcs_pos)
lcs_nodes[lcs_pos].append(LCSgraph.Node(match_row + shift, match_col + shift, row - match_row))
matching = False
Expand All @@ -278,24 +279,27 @@ def expand(idx):
if not matching:
match_row = row
match_col = col
while row < len(reference) and col < len(observed) and reference[row] == observed[col]:
while row < len_reference and col < len_observed and reference[row] == observed[col]:
matching = True
row += 1
col += 1
steps += 1
if matching:
lcs_pos = ((row + col) - (abs(delta) + 2 * it - abs((len(reference) - row) - (len(observed) - col)))) // 2 - 1
lcs_pos = ((row + col) - (abs_delta + 2 * it - remaining)) // 2 - 1
max_lcs_pos = max(lcs_pos, max_lcs_pos)
lcs_nodes[lcs_pos].append(LCSgraph.Node(match_row + shift, match_col + shift, row - match_row))

return steps

lcs_nodes = [[] for _ in range(min(len(reference), len(observed)))]
len_reference = len(reference)
len_observed = len(observed)
lcs_nodes = [[] for _ in range(min(len_reference, len_observed))]
max_lcs_pos = 0

delta = len(observed) - len(reference)
offset = len(reference) + 1
diagonals = [0] * (len(reference) + len(observed) + 3)
delta = len_observed - len_reference
abs_delta = abs(delta)
offset = len_reference + 1
diagonals = [0] * (len_reference + len_observed + 3)
it = 0

if delta >= 0:
Expand All @@ -305,7 +309,7 @@ def expand(idx):
lower = delta
upper = 0

while diagonals[offset + delta] <= max(len(reference), len(observed)) - abs(delta):
while diagonals[offset + delta] <= max(len_reference, len_observed) - abs_delta:
for idx in range(lower - it, delta):
diagonals[offset + idx] = expand(idx)

Expand All @@ -315,10 +319,10 @@ def expand(idx):
diagonals[offset + delta] = expand(delta)
it += 1

if max_distance and abs(delta) + 2 * (it - 1) > max_distance:
if max_distance and abs_delta + 2 * (it - 1) > max_distance:
raise ValueError("maximum distance exceeded")

return abs(delta) + 2 * (it - 1), lcs_nodes[:max_lcs_pos + 1]
return abs_delta + 2 * (it - 1), lcs_nodes[:max_lcs_pos + 1]


def _build_graph(reference, observed, lcs_nodes, shift=0):
Expand Down Expand Up @@ -394,7 +398,7 @@ def _build_graph(reference, observed, lcs_nodes, shift=0):
lcs_nodes[-2].insert(idx_parent, node)

del lcs_nodes[-1]
len_lcs_nodes = len(lcs_nodes)
len_lcs_nodes -= 1

source = lcs_nodes[0][0]
if lcs_nodes[0][0].row == lcs_nodes[0][0].col == shift:
Expand Down
20 changes: 10 additions & 10 deletions algebra/variants/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def match_variant(reference):
raise ValueError(f"inconsistent duplicated length at {pos}")
if reference is not None and sequence != reference[start:end]:
raise ValueError(f"'{sequence}' not found in reference at {start}")
return Variant(start, end, 2 * sequence)
return Variant.create_safe(start, end, 2 * sequence)

if match_optional("inv"):
try:
Expand All @@ -137,7 +137,7 @@ def match_variant(reference):
raise ValueError(f"inconsistent inversion length at {ctx_pos + 1}")
if reference is not None and sequence != reverse_complement(reference[start:end]):
raise ValueError(f"'{sequence}' not found in reference at {start}")
return Variant(start, end, sequence)
return Variant.create_safe(start, end, sequence)

if match_optional("del"):
if start == end:
Expand All @@ -152,13 +152,13 @@ def match_variant(reference):
if reference is not None and sequence != reference[start:end]:
raise ValueError(f"'{sequence}' not found in reference at {start}")
if match_optional("ins"):
return Variant(start, end, match_insertion())
return Variant(start, end, "")
return Variant.create_safe(start, end, match_insertion())
return Variant.create_safe(start, end, "")

if match_optional("ins"):
if end - start != 2:
raise ValueError(f"invalid inserted range at {pos}")
return Variant(start + 1, start + 1, match_insertion())
return Variant.create_safe(start + 1, start + 1, match_insertion())

try:
sequence = match_sequence()
Expand All @@ -171,10 +171,10 @@ def match_variant(reference):
raise ValueError(f"inconstistent deletion length at {ctx_pos + 1}")
if reference is not None and sequence != reference[start:end]:
raise ValueError(f"'{sequence}' not found in reference at {start}")
return Variant(start, end, match_sequence())
return Variant.create_safe(start, end, match_sequence())

if match_optional("="):
return Variant(0, 0, "")
return Variant.create_safe(0, 0, "")

if match_optional("["):
repeat = match_number()
Expand All @@ -188,10 +188,10 @@ def match_variant(reference):
found += 1
if found == 0:
raise ValueError(f"'{sequence}' not found in reference at {start}")
return Variant(start, start + found * len(sequence), repeat * sequence)
return Variant.create_safe(start, start + found * len(sequence), repeat * sequence)

# HGVS style repeat
return Variant(start, end, repeat * sequence)
return Variant.create_safe(start, end, repeat * sequence)

raise NotImplementedError(f"unsupported variant at {ctx_pos + 1}")

Expand Down Expand Up @@ -254,4 +254,4 @@ def parse_spdi(expression):
length = int(deletion)
except ValueError:
length = len(deletion)
return [Variant(start, start + length, insertion)]
return [Variant.create_safe(start, start + length, insertion)]
10 changes: 7 additions & 3 deletions algebra/variants/variant.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ class Variant:
"""Variant class for deletion/insertions."""

def __init__(self, start, end, sequence):
self.start = start
self.end = end
self.sequence = sequence

@classmethod
def create_safe(cls, start, end, sequence):
"""Create a variant.
Parameters
Expand Down Expand Up @@ -59,9 +65,7 @@ def __init__(self, start, end, sequence):
if start > end:
raise ValueError("start must not be after end")

self.start = start
self.end = end
self.sequence = sequence
return cls(start, end, sequence)

def __eq__(self, other):
return (self.start == other.start and self.end == other.end and
Expand Down
2 changes: 1 addition & 1 deletion tests/variants/test_variant.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_variant(args, expected):
])
def test_variant_fail(args, exception, message):
with pytest.raises(exception) as exc:
Variant(*args)
Variant.create_safe(*args)
assert str(exc.value) == message


Expand Down

0 comments on commit 3ce4b00

Please sign in to comment.