Skip to content

Commit 753c9ab

Browse files
first part Christians comments
1 parent 4a3700f commit 753c9ab

File tree

3 files changed

+135
-50
lines changed

3 files changed

+135
-50
lines changed

molpipeline/estimators/nearest_neighbor.py

+79-22
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22

33
from __future__ import annotations
44

5-
import multiprocessing
65
from typing import Any, Callable, Literal, Sequence, Union
76

87
from joblib import Parallel, delayed
8+
from scipy.sparse import csr_matrix
9+
from sklearn.base import BaseEstimator
10+
11+
from molpipeline.utils.multi_proc import check_available_cores
912

1013
try:
1114
from typing import Self
@@ -207,8 +210,8 @@ def fit_predict(
207210
return self.predict(X, return_distance=return_distance, n_neighbors=n_neighbors)
208211

209212

210-
class NearestNeighborsRetrieverTanimoto: # pylint: disable=too-few-public-methods
211-
"""k-nearest neighbors between data sets using Tanimoto similarity.
213+
class TanimotoKNN(BaseEstimator): # pylint: disable=too-few-public-methods
214+
"""k-nearest neighbors (KNN) between data sets using Tanimoto similarity.
212215
213216
This class uses the Tanimoto similarity to find the k-nearest neighbors of a query set in a target set.
214217
The full similarity matrix is computed and reduced to the k-nearest neighbors. A dot-product based
@@ -218,42 +221,85 @@ class NearestNeighborsRetrieverTanimoto: # pylint: disable=too-few-public-metho
218221
the batches can be processed in parallel using joblib.
219222
"""
220223

224+
target_indices_mapping_: npt.NDArray[np.int64] | None
225+
221226
def __init__(
222227
self,
223-
target_fingerprints: sparse.csr_matrix,
224-
k: int | None = None,
228+
*,
229+
k: int | None,
225230
batch_size: int = 1000,
226231
n_jobs: int = 1,
227232
):
228-
"""Initialize NearestNeighborsRetrieverTanimoto.
233+
"""Initialize TanimotoKNN.
229234
230235
Parameters
231236
----------
232-
target_fingerprints: sparse.csr_matrix
233-
Fingerprints of target molecules. Must be a binary sparse matrix.
234-
k: int, optional (default=None)
237+
k: int | None
235238
Number of nearest neighbors to find. If None, all neighbors are returned.
236239
batch_size: int, optional (default=1000)
237240
Size of the batches for parallel processing.
238241
n_jobs: int, optional (default=1)
239242
Number of parallel jobs to run for neighbors search.
240243
"""
241-
self.target_fingerprints = target_fingerprints
242-
if k is None:
243-
self.k = self.target_fingerprints.shape[0]
244-
else:
245-
self.k = k
244+
self.target_fingerprints: csr_matrix | None = None
245+
self.k = k
246246
self.batch_size = batch_size
247-
if n_jobs == -1:
248-
self.n_jobs = multiprocessing.cpu_count()
249-
else:
250-
self.n_jobs = n_jobs
247+
self.n_jobs = check_available_cores(n_jobs)
248+
self.knn_reduce_function: (
249+
Callable[
250+
[npt.NDArray[np.float64]],
251+
tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]],
252+
]
253+
| None
254+
) = None
255+
256+
def fit(
257+
self,
258+
X: sparse.csr_matrix, # pylint: disable=invalid-name
259+
y: Sequence[Any] | None = None, # pylint: disable=invalid-name
260+
) -> Self:
261+
"""Fit the estimator using X as target fingerprint data set.
262+
263+
Parameters
264+
----------
265+
X : sparse.csr_matrix
266+
The target fingerprint data set. By calling `predict`, searches are performed
267+
against this target data set.
268+
y : Sequence[Any] | None, optional (default=None)
269+
Target values. Here values are used as returned nearest neighbors.
270+
Must have the same length as X.
271+
Will be stored as the target_indices_mapping_ attribute as a numpy array.
272+
273+
Returns
274+
-------
275+
Self
276+
The instance itself.
277+
278+
Raises
279+
------
280+
ValueError
281+
If the input arrays have different lengths or have neither a shape nor a len attribute.
282+
"""
283+
if y is None:
284+
y = list(range(X.shape[0]))
285+
if X.shape[0] != get_length(y):
286+
raise ValueError("X and y must have the same length.")
287+
288+
if self.k is None:
289+
# set k to the number of target fingerprints if k is None
290+
self.k = X.shape[0]
291+
292+
# determine the reduce function depending on the value of k
251293
if self.k == 1:
252294
self.knn_reduce_function = self._reduce_k_equals_1
253-
elif self.k < self.target_fingerprints.shape[0]:
295+
elif self.k < X.shape[0]:
254296
self.knn_reduce_function = self._reduce_k_greater_1_less_n
255297
else:
256-
self.knn_reduce_function = self._reduct_to_indices_k_equals_n
298+
self.knn_reduce_function = self._reduce_k_equals_n
299+
300+
self.target_indices_mapping_ = np.array(y)
301+
self.target_fingerprints = X
302+
return self
257303

258304
@staticmethod
259305
def _reduce_k_equals_1(
@@ -320,7 +366,7 @@ def _reduce_k_greater_1_less_n(
320366
return topk_indices_sorted, topk_similarities_sorted
321367

322368
@staticmethod
323-
def _reduct_to_indices_k_equals_n(
369+
def _reduce_k_equals_n(
324370
similarity_matrix: npt.NDArray[np.float64],
325371
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
326372
"""Reduce similarity matrix to k=n nearest neighbors.
@@ -360,11 +406,16 @@ def _process_batch(
360406
query_batch, self.target_fingerprints
361407
)
362408

409+
if self.knn_reduce_function is None:
410+
raise AssertionError(
411+
"The knn_reduce_function has not been set. This should happen in the fit function."
412+
)
363413
# reduce the similarity matrix to the k nearest neighbors
364414
return self.knn_reduce_function(similarity_mat_chunk)
365415

366416
def predict(
367-
self, query_fingerprints: sparse.csr_matrix
417+
self,
418+
query_fingerprints: sparse.csr_matrix,
368419
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
369420
"""Predict the k-nearest neighbors of the query fingerprints.
370421
@@ -378,6 +429,12 @@ def predict(
378429
tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]
379430
Indices of the k-nearest neighbors in the target fingerprints and the corresponding similarities.
380431
"""
432+
if self.target_fingerprints is None:
433+
raise ValueError("The model has not been fitted yet.")
434+
if self.k is None:
435+
raise AssertionError(
436+
"The number of neighbors k has not been set. This should happen in the fit function."
437+
)
381438
if query_fingerprints.shape[1] != self.target_fingerprints.shape[1]:
382439
raise ValueError(
383440
"The number of features in the query fingerprints does not match the number of features in the target fingerprints."

molpipeline/utils/multi_proc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,6 @@ def check_available_cores(n_requested_cores: int) -> int:
3535
)
3636
return n_available_cores
3737
if n_requested_cores < 0:
38-
return n_available_cores
38+
return n_available_cores + 1 + n_requested_cores
3939

4040
return n_requested_cores

tests/test_estimators/test_nearest_neighbors.py

+55-27
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from molpipeline import ErrorFilter, FilterReinserter, Pipeline, PostPredictionWrapper
1010
from molpipeline.any2mol import SmilesToMol
1111
from molpipeline.estimators import NamedNearestNeighbors, TanimotoToTraining
12-
from molpipeline.estimators.nearest_neighbor import NearestNeighborsRetrieverTanimoto
12+
from molpipeline.estimators.nearest_neighbor import TanimotoKNN
1313
from molpipeline.mol2any import MolToMorganFP
1414
from molpipeline.utils.kernel import tanimoto_distance_sparse
1515

@@ -222,8 +222,8 @@ def test_fit_and_predict_invalid_with_distance(self) -> None:
222222
)
223223

224224

225-
class TestNearestNeighborsRetrieverTanimoto(TestCase):
226-
"""Test nearest neighbors retriever with tanimoto."""
225+
class TestTanimotoKNN(TestCase):
226+
"""Test TanimotoKNN estimator."""
227227

228228
example_fingerprints: csr_matrix
229229

@@ -243,16 +243,16 @@ def test_k_equals_1(self) -> None:
243243
target_fps = self.example_fingerprints
244244
query_fps = self.example_fingerprints
245245

246-
retriever = NearestNeighborsRetrieverTanimoto(target_fps, k=1)
247-
indices, similarities = retriever.predict(query_fps)
246+
knn = TanimotoKNN(k=1)
247+
knn.fit(target_fps)
248+
indices, similarities = knn.predict(query_fps)
248249
self.assertTrue(np.array_equal(indices, np.array([0, 1, 2, 3])))
249250
self.assertTrue(np.allclose(similarities, np.array([1, 1, 1, 1])))
250251

251252
# test parallel
252-
retriever = NearestNeighborsRetrieverTanimoto(
253-
target_fps, k=1, n_jobs=2, batch_size=2
254-
)
255-
indices, similarities = retriever.predict(query_fps)
253+
knn = TanimotoKNN(k=1, n_jobs=2, batch_size=2)
254+
knn.fit(target_fps)
255+
indices, similarities = knn.predict(query_fps)
256256
self.assertTrue(np.array_equal(indices, np.array([0, 1, 2, 3])))
257257
self.assertTrue(np.allclose(similarities, np.array([1, 1, 1, 1])))
258258

@@ -261,18 +261,18 @@ def test_k_greater_1_less_n(self) -> None:
261261
target_fps = self.example_fingerprints
262262
query_fps = self.example_fingerprints
263263

264-
retriever = NearestNeighborsRetrieverTanimoto(target_fps, k=2)
265-
indices, similarities = retriever.predict(query_fps)
264+
knn = TanimotoKNN(k=2)
265+
knn.fit(target_fps)
266+
indices, similarities = knn.predict(query_fps)
266267
self.assertTrue(
267268
np.array_equal(indices, np.array([[0, 1], [1, 0], [2, 3], [3, 2]]))
268269
)
269270
self.assertTrue(np.allclose(similarities, TWO_NN_SIMILARITIES))
270271

271272
# test parallel
272-
retriever = NearestNeighborsRetrieverTanimoto(
273-
target_fps, k=2, n_jobs=2, batch_size=2
274-
)
275-
indices, similarities = retriever.predict(query_fps)
273+
knn = TanimotoKNN(k=2, n_jobs=2, batch_size=2)
274+
knn.fit(target_fps)
275+
indices, similarities = knn.predict(query_fps)
276276
self.assertTrue(
277277
np.array_equal(indices, np.array([[0, 1], [1, 0], [2, 3], [3, 2]]))
278278
)
@@ -283,8 +283,9 @@ def test_k_equals_n(self) -> None:
283283
target_fps = self.example_fingerprints
284284
query_fps = self.example_fingerprints
285285

286-
retriever = NearestNeighborsRetrieverTanimoto(target_fps, k=target_fps.shape[0])
287-
indices, similarities = retriever.predict(query_fps)
286+
knn = TanimotoKNN(k=target_fps.shape[0])
287+
knn.fit(target_fps)
288+
indices, similarities = knn.predict(query_fps)
288289
self.assertTrue(
289290
np.array_equal(
290291
indices,
@@ -294,10 +295,9 @@ def test_k_equals_n(self) -> None:
294295
self.assertTrue(np.allclose(similarities, FOUR_NN_SIMILARITIES))
295296

296297
# test parallel
297-
retriever = NearestNeighborsRetrieverTanimoto(
298-
target_fps, k=target_fps.shape[0], n_jobs=2, batch_size=2
299-
)
300-
indices, similarities = retriever.predict(query_fps)
298+
knn = TanimotoKNN(k=target_fps.shape[0], n_jobs=2, batch_size=2)
299+
knn.fit(target_fps)
300+
indices, similarities = knn.predict(query_fps)
301301
self.assertTrue(
302302
np.array_equal(
303303
indices,
@@ -306,9 +306,37 @@ def test_k_equals_n(self) -> None:
306306
)
307307
self.assertTrue(np.allclose(similarities, FOUR_NN_SIMILARITIES))
308308

309-
# [
310-
# [1.0, 3 / 14, 0.0, 0.0],
311-
# [1.0, 3 / 14, 0.038461538461538464, 0.0],
312-
# [1.0, 4 / 9, 0.0, 0.0],
313-
# [1.0, 4 / 9, 0.038461538461538464, 0.0],
314-
# ]
309+
def test_pipeline(self) -> None:
310+
"""Test TanimotoKNN in a pipeline."""
311+
# test normal pipeline
312+
pipeline = Pipeline(
313+
[
314+
("mol", SmilesToMol()),
315+
("fingerprint", MolToMorganFP()),
316+
("knn", TanimotoKNN(k=1)),
317+
]
318+
)
319+
pipeline.fit(TEST_SMILES)
320+
indices, similarities = pipeline.predict(TEST_SMILES)
321+
self.assertTrue(np.array_equal(indices, np.array([0, 1, 2, 3])))
322+
self.assertTrue(np.allclose(similarities, np.array([1, 1, 1, 1])))
323+
324+
# test pipeline with failing smiles
325+
test_smiles = [
326+
"c1ccccc1",
327+
"c1cc(-C(=O)O)ccc1",
328+
"I am a failing smiles :)",
329+
"CCCCCCN",
330+
"CCCCCCO",
331+
]
332+
pipeline = Pipeline(
333+
[
334+
("mol", SmilesToMol()),
335+
("error_filter", ErrorFilter(filter_everything=True)),
336+
("fingerprint", MolToMorganFP()),
337+
("knn", TanimotoKNN(k=1)),
338+
]
339+
)
340+
pipeline.fit(test_smiles)
341+
indices, similarities = pipeline.predict(test_smiles)
342+
# TODO: assert the expected indices and similarities for the pipeline with a failing SMILES

0 commit comments

Comments
 (0)