9
9
from molpipeline import ErrorFilter , FilterReinserter , Pipeline , PostPredictionWrapper
10
10
from molpipeline .any2mol import SmilesToMol
11
11
from molpipeline .estimators import NamedNearestNeighbors , TanimotoToTraining
12
- from molpipeline .estimators .nearest_neighbor import NearestNeighborsRetrieverTanimoto
12
+ from molpipeline .estimators .nearest_neighbor import TanimotoKNN
13
13
from molpipeline .mol2any import MolToMorganFP
14
14
from molpipeline .utils .kernel import tanimoto_distance_sparse
15
15
@@ -222,8 +222,8 @@ def test_fit_and_predict_invalid_with_distance(self) -> None:
222
222
)
223
223
224
224
225
- class TestNearestNeighborsRetrieverTanimoto (TestCase ):
226
- """Test nearest neighbors retriever with tanimoto ."""
225
+ class TestTanimotoKNN (TestCase ):
226
+ """Test TanimotoKNN estimator ."""
227
227
228
228
example_fingerprints : csr_matrix
229
229
@@ -243,16 +243,16 @@ def test_k_equals_1(self) -> None:
243
243
target_fps = self .example_fingerprints
244
244
query_fps = self .example_fingerprints
245
245
246
- retriever = NearestNeighborsRetrieverTanimoto (target_fps , k = 1 )
247
- indices , similarities = retriever .predict (query_fps )
246
+ knn = TanimotoKNN (k = 1 )
247
+ knn .fit (target_fps )
248
+ indices , similarities = knn .predict (query_fps )
248
249
self .assertTrue (np .array_equal (indices , np .array ([0 , 1 , 2 , 3 ])))
249
250
self .assertTrue (np .allclose (similarities , np .array ([1 , 1 , 1 , 1 ])))
250
251
251
252
# test parallel
252
- retriever = NearestNeighborsRetrieverTanimoto (
253
- target_fps , k = 1 , n_jobs = 2 , batch_size = 2
254
- )
255
- indices , similarities = retriever .predict (query_fps )
253
+ knn = TanimotoKNN (k = 1 , n_jobs = 2 , batch_size = 2 )
254
+ knn .fit (target_fps )
255
+ indices , similarities = knn .predict (query_fps )
256
256
self .assertTrue (np .array_equal (indices , np .array ([0 , 1 , 2 , 3 ])))
257
257
self .assertTrue (np .allclose (similarities , np .array ([1 , 1 , 1 , 1 ])))
258
258
@@ -261,18 +261,18 @@ def test_k_greater_1_less_n(self) -> None:
261
261
target_fps = self .example_fingerprints
262
262
query_fps = self .example_fingerprints
263
263
264
- retriever = NearestNeighborsRetrieverTanimoto (target_fps , k = 2 )
265
- indices , similarities = retriever .predict (query_fps )
264
+ knn = TanimotoKNN (k = 2 )
265
+ knn .fit (target_fps )
266
+ indices , similarities = knn .predict (query_fps )
266
267
self .assertTrue (
267
268
np .array_equal (indices , np .array ([[0 , 1 ], [1 , 0 ], [2 , 3 ], [3 , 2 ]]))
268
269
)
269
270
self .assertTrue (np .allclose (similarities , TWO_NN_SIMILARITIES ))
270
271
271
272
# test parallel
272
- retriever = NearestNeighborsRetrieverTanimoto (
273
- target_fps , k = 2 , n_jobs = 2 , batch_size = 2
274
- )
275
- indices , similarities = retriever .predict (query_fps )
273
+ knn = TanimotoKNN (k = 2 , n_jobs = 2 , batch_size = 2 )
274
+ knn .fit (target_fps )
275
+ indices , similarities = knn .predict (query_fps )
276
276
self .assertTrue (
277
277
np .array_equal (indices , np .array ([[0 , 1 ], [1 , 0 ], [2 , 3 ], [3 , 2 ]]))
278
278
)
@@ -283,8 +283,9 @@ def test_k_equals_n(self) -> None:
283
283
target_fps = self .example_fingerprints
284
284
query_fps = self .example_fingerprints
285
285
286
- retriever = NearestNeighborsRetrieverTanimoto (target_fps , k = target_fps .shape [0 ])
287
- indices , similarities = retriever .predict (query_fps )
286
+ knn = TanimotoKNN (k = target_fps .shape [0 ])
287
+ knn .fit (target_fps )
288
+ indices , similarities = knn .predict (query_fps )
288
289
self .assertTrue (
289
290
np .array_equal (
290
291
indices ,
@@ -294,10 +295,9 @@ def test_k_equals_n(self) -> None:
294
295
self .assertTrue (np .allclose (similarities , FOUR_NN_SIMILARITIES ))
295
296
296
297
# test parallel
297
- retriever = NearestNeighborsRetrieverTanimoto (
298
- target_fps , k = target_fps .shape [0 ], n_jobs = 2 , batch_size = 2
299
- )
300
- indices , similarities = retriever .predict (query_fps )
298
+ knn = TanimotoKNN (k = target_fps .shape [0 ], n_jobs = 2 , batch_size = 2 )
299
+ knn .fit (target_fps )
300
+ indices , similarities = knn .predict (query_fps )
301
301
self .assertTrue (
302
302
np .array_equal (
303
303
indices ,
@@ -306,9 +306,37 @@ def test_k_equals_n(self) -> None:
306
306
)
307
307
self .assertTrue (np .allclose (similarities , FOUR_NN_SIMILARITIES ))
308
308
309
- # [
310
- # [1.0, 3 / 14, 0.0, 0.0],
311
- # [1.0, 3 / 14, 0.038461538461538464, 0.0],
312
- # [1.0, 4 / 9, 0.0, 0.0],
313
- # [1.0, 4 / 9, 0.038461538461538464, 0.0],
314
- # ]
309
+ def test_pipeline (self ) -> None :
310
+ """Test TanimotoKNN in a pipeline."""
311
+ # test normal pipeline
312
+ pipeline = Pipeline (
313
+ [
314
+ ("mol" , SmilesToMol ()),
315
+ ("fingerprint" , MolToMorganFP ()),
316
+ ("knn" , TanimotoKNN (k = 1 )),
317
+ ]
318
+ )
319
+ pipeline .fit (TEST_SMILES )
320
+ indices , similarities = pipeline .predict (TEST_SMILES )
321
+ self .assertTrue (np .array_equal (indices , np .array ([0 , 1 , 2 , 3 ])))
322
+ self .assertTrue (np .allclose (similarities , np .array ([1 , 1 , 1 , 1 ])))
323
+
324
+ # test pipeline with failing smiles
325
+ test_smiles = [
326
+ "c1ccccc1" ,
327
+ "c1cc(-C(=O)O)ccc1" ,
328
+ "I am a failing smiles :)" ,
329
+ "CCCCCCN" ,
330
+ "CCCCCCO" ,
331
+ ]
332
+ pipeline = Pipeline (
333
+ [
334
+ ("mol" , SmilesToMol ()),
335
+ ("error_filter" , ErrorFilter (filter_everything = True )),
336
+ ("fingerprint" , MolToMorganFP ()),
337
+ ("knn" , TanimotoKNN (k = 1 )),
338
+ ]
339
+ )
340
+ pipeline .fit (test_smiles )
341
+ indices , similarities = pipeline .predict (test_smiles )
342
+ todo assert right result
0 commit comments