Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dataset similarity #122

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 213 additions & 1 deletion molpipeline/estimators/nearest_neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,22 @@

from __future__ import annotations

import multiprocessing
from typing import Any, Callable, Literal, Sequence, Union

from joblib import Parallel, delayed

try:
from typing import Self
except ImportError:
from typing_extensions import Self


import numpy as np
import numpy.typing as npt
from scipy import sparse
from sklearn.neighbors import NearestNeighbors

from molpipeline.utils.kernel import tanimoto_similarity_sparse
from molpipeline.utils.value_checks import get_length

__all__ = ["NamedNearestNeighbors"]
Expand Down Expand Up @@ -202,3 +205,212 @@ def fit_predict(
"""
self.fit(X, y)
return self.predict(X, return_distance=return_distance, n_neighbors=n_neighbors)


class NearestNeighborsRetrieverTanimoto: # pylint: disable=too-few-public-methods
"""k-nearest neighbors between data sets using Tanimoto similarity.

This class uses the Tanimoto similarity to find the k-nearest neighbors of a query set in a target set.
The full similarity matrix is computed and reduced to the k-nearest neighbors. A dot-product based
algorithm is used, which is faster than using the RDKit native Tanimoto function.

For handling larger datasets, the computation can be batched to reduce memory usage. In addition,
the batches can be processed in parallel using joblib.
"""

def __init__(
self,
target_fingerprints: sparse.csr_matrix,
k: int | None = None,
batch_size: int = 1000,
n_jobs: int = 1,
):
"""Initialize NearestNeighborsRetrieverTanimoto.

Parameters
----------
target_fingerprints: sparse.csr_matrix
Fingerprints of target molecules. Must be a binary sparse matrix.
k: int, optional (default=None)
Number of nearest neighbors to find. If None, all neighbors are returned.
batch_size: int, optional (default=1000)
Size of the batches for parallel processing.
n_jobs: int, optional (default=1)
Number of parallel jobs to run for neighbors search.
"""
self.target_fingerprints = target_fingerprints
if k is None:
self.k = self.target_fingerprints.shape[0]
else:
self.k = k
self.batch_size = batch_size
if n_jobs == -1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.n_jobs = multiprocessing.cpu_count()
else:
self.n_jobs = n_jobs
if self.k == 1:
self.knn_reduce_function = self._reduce_k_equals_1
elif self.k < self.target_fingerprints.shape[0]:
self.knn_reduce_function = self._reduce_k_greater_1_less_n
else:
self.knn_reduce_function = self._reduct_to_indices_k_equals_n

@staticmethod
def _reduce_k_equals_1(
similarity_matrix: npt.NDArray[np.float64],
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
"""Reduce similarity matrix to k=1 nearest neighbors.

Uses argmax to find the index of the nearest neighbor in the target fingerprints.
This function has therefore O(n) time complexity.

Parameters
----------
similarity_matrix: npt.NDArray[np.float64]
Similarity matrix of Tanimoto scores between query and target fingerprints.

Returns
-------
npt.NDArray[np.int64]
Indices of the query's nearest neighbors in the target fingerprints.
"""
topk_indices = np.argmax(similarity_matrix, axis=1)
topk_similarities = np.take_along_axis(
similarity_matrix, topk_indices.reshape(-1, 1), axis=1
).squeeze()
return topk_indices, topk_similarities

def _reduce_k_greater_1_less_n(
self,
similarity_matrix: npt.NDArray[np.float64],
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
"""Reduce similarity matrix to k>1 and k<n nearest neighbors.

Uses argpartition to find the k-nearest neighbors in the target fingerprints, which uses a linear
partial sort algorithm. The top k hits must be sorted afterward to get the k-nearest neighbors
in descending order. This function has therefore O(n + k log k) time complexity.

The indices are sorted descending by similarity.

Parameters
----------
similarity_matrix: npt.NDArray[np.float64]
Similarity matrix of Tanimoto scores between query and target fingerprints.

Returns
-------
tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]
Indices of the query's k-nearest neighbors in the target fingerprints and
the corresponding similarities.
"""
# Get the indices of the k-nearest neighbors. argpartition returns them unsorted.
topk_indices = np.argpartition(similarity_matrix, kth=-self.k, axis=1)[
:, -self.k :
]
topk_similarities = np.take_along_axis(similarity_matrix, topk_indices, axis=1)
# sort the topk_indices descending by similarity
topk_indices_sorted = np.take_along_axis(
topk_indices,
np.fliplr(topk_similarities.argsort(axis=1, kind="stable")),
axis=1,
)
topk_similarities_sorted = np.take_along_axis(
similarity_matrix, topk_indices_sorted, axis=1
)
return topk_indices_sorted, topk_similarities_sorted

@staticmethod
def _reduct_to_indices_k_equals_n(
similarity_matrix: npt.NDArray[np.float64],
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
"""Reduce similarity matrix to k=n nearest neighbors.

Parameters
----------
similarity_matrix: npt.NDArray[np.float64]
Similarity matrix of Tanimoto scores between query and target fingerprints.

Returns
-------
tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]
Indices of the query's k-nearest neighbors in the target fingerprints and
the corresponding similarities.
"""
indices = np.fliplr(similarity_matrix.argsort(axis=1, kind="stable"))
similarities = np.take_along_axis(similarity_matrix, indices, axis=1)
return indices, similarities

def _process_batch(
self, query_batch: sparse.csr_matrix
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
"""Process a batch of query fingerprints.

Parameters
----------
query_batch: sparse.csr_matrix
Batch of query fingerprints.

Returns
-------
tuple
Indices of the k-nearest neighbors in the target fingerprints and the corresponding similarities.
"""
# compute full similarity matrix for the query batch
similarity_mat_chunk = tanimoto_similarity_sparse(
query_batch, self.target_fingerprints
)

# reduce the similarity matrix to the k nearest neighbors
return self.knn_reduce_function(similarity_mat_chunk)

def predict(
self, query_fingerprints: sparse.csr_matrix
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
"""Predict the k-nearest neighbors of the query fingerprints.

Parameters
----------
query_fingerprints: sparse.csr_matrix
Query fingerprints.

Returns
-------
tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]
Indices of the k-nearest neighbors in the target fingerprints and the corresponding similarities.
"""
if query_fingerprints.shape[1] != self.target_fingerprints.shape[1]:
raise ValueError(
"The number of features in the query fingerprints does not match the number of features in the target fingerprints."
)
if self.n_jobs > 1:
# parallel execution
with Parallel(n_jobs=self.n_jobs) as parallel:
# the parallelization is not optimal: the self.target_fingerprints (and query_fingerprints) are copied to each child process worker
# -> joblib does some behind the scenes mmapping but copying the full matrices is probably a memory bottleneck.
# If Python removes the GIL this here would be a good use case for threading with zero copies.
res = parallel(
delayed(self._process_batch)(
query_fingerprints[i : i + self.batch_size, :]
)
for i in range(0, query_fingerprints.shape[0], self.batch_size)
)
result_indices_tmp, result_similarities_tmp = zip(*res)
result_indices = np.concatenate(result_indices_tmp)
result_similarities = np.concatenate(result_similarities_tmp)
else:
# single process execution
result_shape = (
(query_fingerprints.shape[0], self.k)
if self.k > 1
else (query_fingerprints.shape[0],)
)
result_indices = np.full(result_shape, -1, dtype=np.int64)
result_similarities = np.full(result_shape, np.nan, dtype=np.float64)
for i in range(0, query_fingerprints.shape[0], self.batch_size):
query_batch = query_fingerprints[i : i + self.batch_size, :]
(
result_indices[i : i + self.batch_size],
result_similarities[i : i + self.batch_size],
) = self._process_batch(query_batch)

return result_indices, result_similarities
8 changes: 6 additions & 2 deletions molpipeline/utils/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@ def tanimoto_similarity_sparse(
Matrix of similarity values between instances of A (rows/first dim) , and instances of B (columns/second dim).
"""
intersection = matrix_a.dot(matrix_b.transpose()).toarray()
norm_1 = np.array(matrix_a.multiply(matrix_a).sum(axis=1))
norm_2 = np.array(matrix_b.multiply(matrix_b).sum(axis=1))
norm_1 = np.array(matrix_a.sum(axis=1))
if matrix_a is matrix_b:
# avoid calculating the same norm twice
norm_2 = norm_1
else:
norm_2 = np.array(matrix_b.sum(axis=1))
union = norm_1 + norm_2.T - intersection
# avoid division by zero https://stackoverflow.com/a/37977222
return np.divide(
Expand Down
Loading