Skip to content

Commit

Permalink
Merge pull request #89 from Genentech/fix-ism
Browse files Browse the repository at this point in the history
fix bug in ism_predict on multiple sequences
  • Loading branch information
avantikalal authored Jan 31, 2025
2 parents 292aa93 + 23aa16f commit ab65884
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/grelu/interpret/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


def ISM_predict(
seqs: Union[pd.DataFrame, np.ndarray, str],
seqs: Union[pd.DataFrame, np.ndarray, str, List[str]],
model: Callable,
genome: Optional[str] = None,
prediction_transform: Optional[Callable] = None,
Expand Down Expand Up @@ -87,11 +87,16 @@ def ISM_predict(
)
# B, L, 4, T, L

# Calculate log ratio w.r.t reference sequence
if compare_func is not None:

# Slice the prediction corresponding to each reference sequence
ref_bases = [BASE_TO_INDEX_HASH[seq[start_pos]] for seq in seqs]
ref_pred = preds[:, [0], [ref_bases], :] # B, 1, 1, T, L
preds = get_compare_func(compare_func, tensor=False)(preds, ref_pred)
ref_preds = np.concatenate(
[preds[None, None, None, i, 0, x] for i, x in enumerate(ref_bases)]
) # B, L, 1, T, L

# Compare all predictions to the prediction for the corresponding reference sequence
preds = get_compare_func(compare_func, tensor=False)(preds, ref_preds)

# Convert into a dataframe
if return_df:
Expand Down
28 changes: 28 additions & 0 deletions tests/test_interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,41 @@ def test_marginalize_patterns():


def test_ISM_predict():

# Single sequence
seq = "AA"
expected_preds = np.array([[4.0, 1.0, 2.0, 2.0], [4.0, 1.0, 2.0, 2.0]]).T
preds = ISM_predict(seq, model, compare_func=None)
assert np.allclose(preds.values, expected_preds)
preds = ISM_predict(seq, model, compare_func="log2FC")
assert np.allclose(preds.values, np.log2(expected_preds / 4))

# Multiple sequences
seqs = ["AAA", "CCC"]
expected_preds = np.expand_dims(
np.array(
[
[
[4.0, 2.0, 2.6666667, 2.6666667],
[4.0, 2.0, 2.6666667, 2.6666667],
[4.0, 2.0, 2.6666667, 2.6666667],
],
[
[0.0, -2.0, -1.3333334, -1.3333334],
[0.0, -2.0, -1.3333334, -1.3333334],
[0.0, -2.0, -1.3333334, -1.3333334],
],
]
),
(3, 4),
)
preds = ISM_predict(seqs, model, compare_func=None, return_df=False)
assert np.allclose(preds, expected_preds)
preds = ISM_predict(seqs, model, compare_func="log2FC", return_df=False)
assert np.allclose(
preds, np.log2(np.stack([expected_preds[0] / 4, -expected_preds[1] / 2]))
)


def test_get_attributions():
seq = generate_random_sequences(n=1, seq_len=50, seed=0, output_format="strings")[0]
Expand Down

0 comments on commit ab65884

Please sign in to comment.