Skip to content

Commit 3813c06

Browse files
authored
fix: actually fix dask (#114)
1 parent 17e3b35 commit 3813c06

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ overrides.matrix.extras.dependencies = [
9797
overrides.matrix.resolution.features = [
9898
{ if = [ "lowest" ], value = "min-reqs" }, # feature added by hatch-min-requirements
9999
]
100+
overrides.matrix.resolution.dependencies = [
101+
# TODO: move to min dep once this is fixed: https://github.com/tlambert03/hatch-min-requirements/issues/5
102+
{ if = [ "lowest" ], value = "dask==2023.5.1" },
103+
]
100104

101105
[[tool.hatch.envs.hatch-test.matrix]]
102106
python = [ "3.13", "3.11" ]

src/fast_array_utils/_plugins/dask.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
# SPDX-License-Identifier: MPL-2.0
22
from __future__ import annotations
33

4-
from dask.array.dispatch import concatenate_lookup, take_lookup, tensordot_lookup
4+
from dask.array.dispatch import concatenate_lookup, tensordot_lookup
55
from scipy.sparse import sparray, spmatrix
66

77

8+
try:
9+
from dask.array.dispatch import take_lookup
10+
except ImportError:
11+
take_lookup = None
12+
13+
814
# TODO(flying-sheep): upstream
915
# https://github.com/dask/dask/issues/11749
1016
def patch() -> None: # pragma: no cover
@@ -13,9 +19,10 @@ def patch() -> None: # pragma: no cover
1319
See <https://github.com/dask/dask/blob/d9b5c5b0256208f1befe94b26bfa8eaabcd0536d/dask/array/backends.py#L239-L241>
1420
"""
1521
# Avoid patch if already patched or upstream support has been added
16-
if concatenate_lookup.dispatch(sparray) is not concatenate_lookup.dispatch(spmatrix):
22+
if concatenate_lookup.dispatch(sparray) is concatenate_lookup.dispatch(spmatrix):
1723
return
1824

1925
concatenate_lookup.register(sparray, concatenate_lookup.dispatch(spmatrix))
2026
tensordot_lookup.register(sparray, tensordot_lookup.dispatch(spmatrix))
21-
take_lookup.register(sparray, take_lookup.dispatch(spmatrix))
27+
if take_lookup is not None:
28+
take_lookup.register(sparray, take_lookup.dispatch(spmatrix))

0 commit comments

Comments
 (0)