- 
                Notifications
    You must be signed in to change notification settings 
- Fork 674
Add neighbors_from_distance for computing neighborhood graphs from precomputed distance matrices #3627
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add neighbors_from_distance for computing neighborhood graphs from precomputed distance matrices #3627
Changes from all commits
f76dc7b
              f092469
              7ffa1ec
              c0d0c52
              68652a7
              948319a
              6a64330
              793351f
              92d8e26
              198c4fb
              e7fb67a
              14cb441
              0ce8c15
              914b87d
              d285203
              50705b3
              4730667
              040b8b7
              c03b863
              473a437
              ec586df
              43dcfc0
              8a3588c
              293f568
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1 @@ | ||
| Added `neighbors_from_distance`, function for computing graphs from a precoputing distance matrix using UMAP or Gaussian methods. {smaller}`A. Karesh` | 
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -21,15 +21,17 @@ | |
| from .._utils import NeighborsView, _doc_params, get_literal_vals | ||
| from . import _connectivity | ||
| from ._common import ( | ||
| _get_indices_distances_from_dense_matrix, | ||
| _get_indices_distances_from_sparse_matrix, | ||
| _get_sparse_matrix_from_indices_distances, | ||
| ) | ||
| from ._connectivity import umap | ||
| from ._doc import doc_n_pcs, doc_use_rep | ||
| from ._types import _KnownTransformer, _Method | ||
|  | ||
| if TYPE_CHECKING: | ||
| from collections.abc import Callable, MutableMapping | ||
| from typing import Any, Literal, NotRequired | ||
| from typing import Any, Literal, NotRequired, Unpack | ||
|  | ||
| from anndata import AnnData | ||
| from igraph import Graph | ||
|  | @@ -58,11 +60,18 @@ class KwdsForTransformer(TypedDict): | |
| random_state: _LegacyRandom | ||
|  | ||
|  | ||
| class NeighborsDict(TypedDict): # noqa: D101 | ||
| connectivities_key: str | ||
| distances_key: str | ||
| params: NeighborsParams | ||
| rp_forest: NotRequired[RPForestDict] | ||
|  | ||
|  | ||
| class NeighborsParams(TypedDict): # noqa: D101 | ||
| n_neighbors: int | ||
| method: _Method | ||
| random_state: _LegacyRandom | ||
| metric: _Metric | _MetricFn | ||
| metric: _Metric | _MetricFn | None | ||
| metric_kwds: NotRequired[Mapping[str, Any]] | ||
| use_rep: NotRequired[str] | ||
| n_pcs: NotRequired[int] | ||
|  | @@ -74,11 +83,12 @@ def neighbors( # noqa: PLR0913 | |
| n_neighbors: int = 15, | ||
| n_pcs: int | None = None, | ||
| *, | ||
| distances: np.ndarray | SpBase | None = None, | ||
| use_rep: str | None = None, | ||
| knn: bool = True, | ||
| method: _Method = "umap", | ||
| transformer: KnnTransformerLike | _KnownTransformer | None = None, | ||
| metric: _Metric | _MetricFn = "euclidean", | ||
| metric: _Metric | _MetricFn | None = None, | ||
| metric_kwds: Mapping[str, Any] = MappingProxyType({}), | ||
| random_state: _LegacyRandom = 0, | ||
| key_added: str | None = None, | ||
|  | @@ -135,6 +145,8 @@ def neighbors( # noqa: PLR0913 | |
| Use :func:`rapids_singlecell.pp.neighbors` instead. | ||
| metric | ||
| A known metric’s name or a callable that returns a distance. | ||
| If `distances` is given, this parameter is simply stored in `.uns` (see below), | ||
| otherwise defaults to `'euclidean'`. | ||
|  | ||
| *ignored if ``transformer`` is an instance.* | ||
| metric_kwds | ||
|  | @@ -186,6 +198,20 @@ def neighbors( # noqa: PLR0913 | |
| :doc:`/how-to/knn-transformers` | ||
|  | ||
| """ | ||
| if distances is not None: | ||
| if callable(metric): | ||
| msg = "`metric` must be a string if `distances` is given." | ||
| raise TypeError(msg) | ||
| # if a precomputed distance matrix is provided, skip the PCA and distance computation | ||
| return neighbors_from_distance( | ||
| adata, | ||
| distances, | ||
| n_neighbors=n_neighbors, | ||
| metric=metric, | ||
| method=method, | ||
| ) | ||
| if metric is None: | ||
| metric = "euclidean" | ||
| start = logg.info("computing neighbors") | ||
| adata = adata.copy() if copy else adata | ||
| if adata.is_view: # we shouldn't need this here... | ||
|  | @@ -203,51 +229,124 @@ def neighbors( # noqa: PLR0913 | |
| random_state=random_state, | ||
| ) | ||
|  | ||
| if key_added is None: | ||
| key_added = "neighbors" | ||
| conns_key = "connectivities" | ||
| dists_key = "distances" | ||
| else: | ||
| conns_key = key_added + "_connectivities" | ||
| dists_key = key_added + "_distances" | ||
|  | ||
| adata.uns[key_added] = {} | ||
|  | ||
| neighbors_dict = adata.uns[key_added] | ||
|  | ||
| neighbors_dict["connectivities_key"] = conns_key | ||
| neighbors_dict["distances_key"] = dists_key | ||
|  | ||
| neighbors_dict["params"] = NeighborsParams( | ||
| key_added, neighbors_dict = _get_metadata( | ||
| key_added, | ||
| n_neighbors=neighbors.n_neighbors, | ||
| method=method, | ||
| random_state=random_state, | ||
| metric=metric, | ||
| **({} if not metric_kwds else dict(metric_kwds=metric_kwds)), | ||
| **({} if use_rep is None else dict(use_rep=use_rep)), | ||
| **({} if n_pcs is None else dict(n_pcs=n_pcs)), | ||
| ) | ||
| if metric_kwds: | ||
| neighbors_dict["params"]["metric_kwds"] = metric_kwds | ||
| if use_rep is not None: | ||
| neighbors_dict["params"]["use_rep"] = use_rep | ||
| if n_pcs is not None: | ||
| neighbors_dict["params"]["n_pcs"] = n_pcs | ||
|  | ||
| adata.obsp[dists_key] = neighbors.distances | ||
| adata.obsp[conns_key] = neighbors.connectivities | ||
|  | ||
| if neighbors.rp_forest is not None: | ||
| neighbors_dict["rp_forest"] = neighbors.rp_forest | ||
|  | ||
| adata.uns[key_added] = neighbors_dict | ||
| adata.obsp[neighbors_dict["distances_key"]] = neighbors.distances | ||
| adata.obsp[neighbors_dict["connectivities_key"]] = neighbors.connectivities | ||
|  | ||
| logg.info( | ||
| " finished", | ||
| time=start, | ||
| deep=( | ||
| f"added to `.uns[{key_added!r}]`\n" | ||
| f" `.obsp[{dists_key!r}]`, distances for each pair of neighbors\n" | ||
| f" `.obsp[{conns_key!r}]`, weighted adjacency matrix" | ||
| f" `.obsp[{neighbors_dict['distances_key']!r}]`, distances for each pair of neighbors\n" | ||
| f" `.obsp[{neighbors_dict['connectivities_key']!r}]`, weighted adjacency matrix" | ||
| ), | ||
| ) | ||
| return adata if copy else None | ||
|  | ||
|  | ||
| def neighbors_from_distance( | ||
| adata: AnnData, | ||
| distances: np.ndarray | SpBase, | ||
| *, | ||
| n_neighbors: int = 15, | ||
| metric: _Metric | None = None, | ||
| method: _Method = "umap", # default to umap | ||
| key_added: str | None = None, | ||
| 
      Comment on lines
    
      +266
     to 
      +269
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please remove the defaults here, and fix the bug that gets uncovered by this action. | ||
| ) -> AnnData: | ||
| """Compute neighbors from a precomputer distance matrix. | ||
|  | ||
| Parameters | ||
| ---------- | ||
| adata | ||
| Annotated data matrix. | ||
| distances | ||
| Precomputed dense or sparse distance matrix. | ||
| n_neighbors | ||
| Number of nearest neighbors to use in the graph. | ||
| metric | ||
| Name of metric used to compute `distances`. | ||
| method | ||
| Method to use for computing the graph. Currently only `'umap'` is supported. | ||
| key_added | ||
| Optional key under which to store the results. Default is 'neighbors'. | ||
|  | ||
| Returns | ||
| ------- | ||
| adata | ||
| Annotated data with computed distances and connectivities. | ||
| """ | ||
| if isinstance(distances, SpBase): | ||
| distances = sparse.csr_matrix(distances) # noqa: TID251 | ||
| distances.setdiag(0) | ||
| distances.eliminate_zeros() | ||
| else: | ||
| distances = np.asarray(distances) | ||
| np.fill_diagonal(distances, 0) | ||
|  | ||
| if method == "umap": | ||
| if isinstance(distances, CSRBase): | ||
| knn_indices, knn_distances = _get_indices_distances_from_sparse_matrix( | ||
| distances, n_neighbors | ||
| ) | ||
| else: | ||
| knn_indices, knn_distances = _get_indices_distances_from_dense_matrix( | ||
| distances, n_neighbors | ||
| ) | ||
| connectivities = umap( | ||
| knn_indices, knn_distances, n_obs=adata.n_obs, n_neighbors=n_neighbors | ||
| ) | ||
| elif method == "gauss": | ||
| distances = sparse.csr_matrix(distances) # noqa: TID251 | ||
| connectivities = _connectivity.gauss(distances, n_neighbors, knn=True) | ||
| else: | ||
| msg = f"Method {method} not implemented." | ||
| raise NotImplementedError(msg) | ||
|  | ||
| key_added, neighbors_dict = _get_metadata( | ||
| key_added, | ||
| n_neighbors=n_neighbors, | ||
| method=method, | ||
| random_state=0, | ||
| metric=metric, | ||
| ) | ||
| adata.uns[key_added] = neighbors_dict | ||
| adata.obsp[neighbors_dict["distances_key"]] = distances | ||
| adata.obsp[neighbors_dict["connectivities_key"]] = connectivities | ||
| return adata | ||
|  | ||
|  | ||
| def _get_metadata( | ||
| key_added: str | None, | ||
| **params: Unpack[NeighborsParams], | ||
| ) -> tuple[str, NeighborsDict]: | ||
| if key_added is None: | ||
| return "neighbors", NeighborsDict( | ||
| connectivities_key="connectivities", | ||
| distances_key="distances", | ||
| params=params, | ||
| ) | ||
| return key_added, NeighborsDict( | ||
| connectivities_key=f"{key_added}_connectivities", | ||
| distances_key=f"{key_added}_distances", | ||
| params=params, | ||
| ) | ||
|  | ||
|  | ||
| class FlatTree(NamedTuple): # noqa: D101 | ||
| hyperplanes: None | ||
| offsets: None | ||
|  | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -13,6 +13,7 @@ | |
| from scanpy import Neighbors | ||
| from scanpy._compat import CSBase | ||
| from testing.scanpy._helpers import anndata_v0_8_constructor_compat | ||
| from testing.scanpy._helpers.data import pbmc68k_reduced | ||
|  | ||
| if TYPE_CHECKING: | ||
| from typing import Literal | ||
|  | @@ -241,3 +242,26 @@ def test_restore_n_neighbors(neigh, conv): | |
| ad.uns["neighbors"] = dict(connectivities=conv(neigh.connectivities)) | ||
| neigh_restored = Neighbors(ad) | ||
| assert neigh_restored.n_neighbors == 1 | ||
|  | ||
|  | ||
| def test_neighbors_distance_equivalence(): | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please parametrize this test over  | ||
| adata = pbmc68k_reduced() | ||
| adata_d = adata.copy() | ||
|  | ||
| sc.pp.neighbors(adata) | ||
| # reusing the same distances | ||
| sc.pp.neighbors(adata_d, distances=adata.obsp["distances"]) | ||
| np.testing.assert_allclose( | ||
| adata.obsp["connectivities"].toarray(), | ||
| adata_d.obsp["connectivities"].toarray(), | ||
| rtol=1e-5, | ||
| ) | ||
| np.testing.assert_allclose( | ||
| adata.obsp["distances"].toarray(), | ||
| adata_d.obsp["distances"].toarray(), | ||
| rtol=1e-5, | ||
| ) | ||
| p, p_d = (ad.uns["neighbors"]["params"].copy() for ad in (adata, adata_d)) | ||
| assert p.pop("metric") == "euclidean" | ||
| assert p_d.pop("metric") is None | ||
| assert p == p_d | ||
Uh oh!
There was an error while loading. Please reload this page.