Skip to content

Commit 71aa18e

Browse files
authored
Merge pull request #3 from tlambert03/updates
support more array types, add `imshow`
2 parents 587a6ef + b626d4b commit 71aa18e

23 files changed

+796
-525
lines changed

.github/workflows/ci.yml

+16-2
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,25 @@ jobs:
3333
fail-fast: false
3434
matrix:
3535
os: [ubuntu-latest, macos-latest]
36-
python-version: ["3.9", "3.10", "3.11", "3.12"]
36+
python-version: ["3.10", "3.11"]
37+
38+
test-array-libs:
39+
uses: pyapp-kit/workflows/.github/workflows/test-pyrepo.yml@v2
40+
with:
41+
os: ${{ matrix.os }}
42+
python-version: ${{ matrix.python-version }}
43+
extras: "test,third_party_arrays"
44+
coverage-upload: artifact
45+
qt: pyqt6
46+
strategy:
47+
fail-fast: false
48+
matrix:
49+
os: [ubuntu-latest, macos-latest]
50+
python-version: ["3.9", "3.12"]
3751

3852
upload_coverage:
3953
if: always()
40-
needs: [test]
54+
needs: [test, test-array-libs]
4155
uses: pyapp-kit/workflows/.github/workflows/upload-coverage.yml@v2
4256
secrets:
4357
codecov_token: ${{ secrets.CODECOV_TOKEN }}

README.md

+21-13
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,12 @@
99
Simple, fast-loading, asynchronous, n-dimensional array viewer for Qt, with minimal dependencies.
1010

1111
```python
12-
from qtpy import QtWidgets
13-
from ndv import NDViewer
14-
from skimage import data # just for example data here
15-
16-
qapp = QtWidgets.QApplication([])
17-
v = NDViewer(data.cells3d())
18-
v.show()
19-
qapp.exec()
12+
import ndv
13+
14+
data = ndv.data.cells3d()
15+
# or ndv.data.nd_sine_wave()
16+
# or *any* arraylike object (see support below)
17+
ndv.imshow(data)
2018
```
2119

2220
![Montage](https://github.com/pyapp-kit/ndv/assets/1609449/712861f7-ddcb-4ecd-9a4c-ba5f0cc1ee2c)
@@ -27,12 +25,22 @@ qapp.exec()
2725
- sliders support integer as well as slice (range)-based slicing
2826
- colormaps provided by [cmap](https://github.com/tlambert03/cmap)
2927
- supports [vispy](https://github.com/vispy/vispy) and [pygfx](https://github.com/pygfx/pygfx) backends
30-
- supports any numpy-like duck arrays, with special support for features in:
31-
- `xarray.DataArray`
28+
- supports any numpy-like duck arrays, including (but not limited to):
29+
- `numpy.ndarray`
30+
- `cupy.ndarray`
3231
- `dask.array.Array`
33-
- `tensorstore.TensorStore`
34-
- `zarr`
35-
- `dask`
32+
- `jax.Array`
33+
- `pyopencl.array.Array`
34+
- `sparse.COO`
35+
- `tensorstore.TensorStore` (supports named dimensions)
36+
- `torch.Tensor` (supports named dimensions)
37+
- `xarray.DataArray` (supports named dimensions)
38+
- `zarr` (supports named dimensions)
39+
- You can add support for your own storage class by subclassing `ndv.DataWrapper`
40+
and implementing a couple methods. (This doesn't require modifying ndv,
41+
but contributions of new wrappers are welcome!)
42+
43+
See examples for each of these array types in [examples](./examples/)
3644

3745
## Installation
3846

examples/custom_store.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any
4+
5+
import numpy as np
6+
7+
import ndv
8+
9+
if TYPE_CHECKING:
10+
from ndv import Indices, Sizes
11+
12+
13+
class MyArrayThing:
14+
def __init__(self, shape: tuple[int, ...]) -> None:
15+
self.shape = shape
16+
self._data = np.random.randint(0, 256, shape)
17+
18+
def __getitem__(self, item: Any) -> np.ndarray:
19+
return self._data[item] # type: ignore [no-any-return]
20+
21+
22+
class MyWrapper(ndv.DataWrapper[MyArrayThing]):
23+
@classmethod
24+
def supports(cls, data: Any) -> bool:
25+
if isinstance(data, MyArrayThing):
26+
return True
27+
return False
28+
29+
def sizes(self) -> Sizes:
30+
"""Return a mapping of {dim: size} for the data"""
31+
return {f"dim_{k}": v for k, v in enumerate(self.data.shape)}
32+
33+
def isel(self, indexers: Indices) -> Any:
34+
"""Convert mapping of {dim: index} to conventional indexing"""
35+
idx = tuple(indexers.get(k, slice(None)) for k in range(len(self.data.shape)))
36+
return self.data[idx]
37+
38+
39+
data = MyArrayThing((10, 3, 512, 512))
40+
ndv.imshow(data)

examples/dask_arr.py

+2-9
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from dask.array.core import map_blocks
77
except ImportError:
88
raise ImportError("Please `pip install dask[array]` to run this example.")
9+
import ndv
910

1011
frame_size = (1024, 1024)
1112

@@ -21,12 +22,4 @@ def _dask_block(block_id: tuple[int, int, int, int, int]) -> np.ndarray | None:
2122
chunks += [(x,) for x in frame_size]
2223
dask_arr = map_blocks(_dask_block, chunks=chunks, dtype=np.uint8)
2324

24-
if __name__ == "__main__":
25-
from qtpy import QtWidgets
26-
27-
from ndv import NDViewer
28-
29-
qapp = QtWidgets.QApplication([])
30-
v = NDViewer(dask_arr)
31-
v.show()
32-
qapp.exec()
25+
v = ndv.imshow(dask_arr)

examples/jax_arr.py

+3-13
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,7 @@
44
import jax.numpy as jnp
55
except ImportError:
66
raise ImportError("Please install jax to run this example")
7-
from numpy_arr import generate_5d_sine_wave
8-
from qtpy import QtWidgets
7+
import ndv
98

10-
from ndv import NDViewer
11-
12-
# Example usage
13-
array_shape = (10, 3, 5, 512, 512) # Specify the desired dimensions
14-
sine_wave_5d = jnp.asarray(generate_5d_sine_wave(array_shape))
15-
16-
if __name__ == "__main__":
17-
qapp = QtWidgets.QApplication([])
18-
v = NDViewer(sine_wave_5d, channel_axis=1)
19-
v.show()
20-
qapp.exec()
9+
jax_arr = jnp.asarray(ndv.data.nd_sine_wave())
10+
v = ndv.imshow(jax_arr)

examples/numpy_arr.py

+6-59
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,11 @@
11
from __future__ import annotations
22

3-
import numpy as np
4-
5-
6-
def generate_5d_sine_wave(
7-
shape: tuple[int, int, int, int, int],
8-
amplitude: float = 240,
9-
base_frequency: float = 5,
10-
) -> np.ndarray:
11-
"""5D dataset."""
12-
# Unpack the dimensions
13-
angle_dim, freq_dim, phase_dim, ny, nx = shape
14-
15-
# Create an empty array to hold the data
16-
output = np.zeros(shape)
17-
18-
# Define spatial coordinates for the last two dimensions
19-
half_per = base_frequency * np.pi
20-
x = np.linspace(-half_per, half_per, nx)
21-
y = np.linspace(-half_per, half_per, ny)
22-
y, x = np.meshgrid(y, x)
23-
24-
# Iterate through each parameter in the higher dimensions
25-
for phase_idx in range(phase_dim):
26-
for freq_idx in range(freq_dim):
27-
for angle_idx in range(angle_dim):
28-
# Calculate phase and frequency
29-
phase = np.pi / phase_dim * phase_idx
30-
frequency = 1 + (freq_idx * 0.1) # Increasing frequency with each step
31-
32-
# Calculate angle
33-
angle = np.pi / angle_dim * angle_idx
34-
# Rotate x and y coordinates
35-
xr = np.cos(angle) * x - np.sin(angle) * y
36-
np.sin(angle) * x + np.cos(angle) * y
37-
38-
# Compute the sine wave
39-
sine_wave = (amplitude * 0.5) * np.sin(frequency * xr + phase)
40-
sine_wave += amplitude * 0.5
41-
42-
# Assign to the output array
43-
output[angle_idx, freq_idx, phase_idx] = sine_wave
44-
45-
return output
46-
3+
import ndv
474

485
try:
49-
from skimage import data
50-
51-
img = data.cells3d()
52-
except Exception:
53-
img = generate_5d_sine_wave((10, 3, 8, 512, 512))
54-
55-
56-
if __name__ == "__main__":
57-
from qtpy import QtWidgets
58-
59-
from ndv import NDViewer
6+
img = ndv.data.cells3d()
7+
except Exception as e:
8+
print(e)
9+
img = ndv.data.nd_sine_wave((10, 3, 8, 512, 512))
6010

61-
qapp = QtWidgets.QApplication([])
62-
v = NDViewer(img)
63-
v.show()
64-
qapp.exec()
11+
ndv.imshow(img)

examples/pyopencl_arr.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from __future__ import annotations
2+
3+
try:
4+
import pyopencl as cl
5+
import pyopencl.array as cl_array
6+
except ImportError:
7+
raise ImportError("Please install pyopencl to run this example")
8+
import ndv
9+
10+
# Set up OpenCL context and queue
11+
context = cl.create_some_context(interactive=False)
12+
queue = cl.CommandQueue(context)
13+
14+
15+
gpu_data = cl_array.to_device(queue, ndv.data.nd_sine_wave())
16+
17+
ndv.imshow(gpu_data)

examples/sparse_arr.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from __future__ import annotations
2+
3+
try:
4+
import sparse
5+
except ImportError:
6+
raise ImportError("Please install sparse to run this example")
7+
8+
import numpy as np
9+
10+
import ndv
11+
12+
shape = (256, 4, 512, 512)
13+
N = int(np.prod(shape) * 0.001)
14+
coords = np.random.randint(low=0, high=shape, size=(N, len(shape))).T
15+
data = np.random.randint(0, 256, N)
16+
17+
18+
# Create the sparse array from the coordinates and data
19+
sparse_array = sparse.COO(coords, data, shape=shape)
20+
21+
ndv.imshow(sparse_array)

examples/tensorstore_arr.py

+20-15
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,28 @@
11
from __future__ import annotations
22

3-
import numpy as np
4-
import tensorstore as ts
5-
from qtpy import QtWidgets
3+
try:
4+
import tensorstore as ts
5+
except ImportError:
6+
raise ImportError("Please install tensorstore to run this example")
67

7-
from ndv import NDViewer
88

9-
shape = (10, 4, 3, 512, 512)
9+
import ndv
10+
11+
data = ndv.data.cells3d()
12+
1013
ts_array = ts.open(
11-
{"driver": "zarr", "kvstore": {"driver": "memory"}},
14+
{
15+
"driver": "zarr",
16+
"kvstore": {"driver": "memory"},
17+
"transform": {
18+
# tensorstore supports labeled dimensions
19+
"input_labels": ["z", "c", "y", "x"],
20+
},
21+
},
1222
create=True,
13-
shape=shape,
14-
dtype=ts.uint8,
23+
shape=data.shape,
24+
dtype=data.dtype,
1525
).result()
16-
ts_array[:] = np.random.randint(0, 255, size=shape, dtype=np.uint8)
17-
ts_array = ts_array[ts.d[:].label["t", "c", "z", "y", "x"]]
26+
ts_array[:] = ndv.data.cells3d()
1827

19-
if __name__ == "__main__":
20-
qapp = QtWidgets.QApplication([])
21-
v = NDViewer(ts_array)
22-
v.show()
23-
qapp.exec()
28+
ndv.imshow(ts_array)

examples/torch_arr.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from __future__ import annotations
2+
3+
try:
4+
import torch
5+
except ImportError:
6+
raise ImportError("Please install torch to run this example")
7+
8+
import warnings
9+
10+
import ndv
11+
12+
warnings.filterwarnings("ignore", "Named tensors") # Named tensors are experimental
13+
14+
# Example usage
15+
try:
16+
torch_data = torch.tensor(ndv.data.nd_sine_wave(), names=("t", "c", "z", "y", "x"))
17+
except TypeError:
18+
print("Named tensors are not supported in your version of PyTorch")
19+
torch_data = torch.tensor(ndv.data.nd_sine_wave())
20+
21+
ndv.imshow(torch_data)

examples/xarray_arr.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
from __future__ import annotations
22

3-
import xarray as xr
4-
from qtpy import QtWidgets
5-
6-
from ndv import NDViewer
3+
try:
4+
import xarray as xr
5+
except ImportError:
6+
raise ImportError("Please install xarray to run this example")
7+
import ndv
78

89
da = xr.tutorial.open_dataset("air_temperature").air
9-
10-
if __name__ == "__main__":
11-
qapp = QtWidgets.QApplication([])
12-
v = NDViewer(da, colormaps=["thermal"], channel_mode="composite")
13-
v.show()
14-
qapp.exec()
10+
ndv.imshow(da, cmap="thermal")

examples/zarr_arr.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
from __future__ import annotations
22

3-
import zarr
4-
import zarr.storage
5-
from qtpy import QtWidgets
3+
import ndv
4+
5+
try:
6+
import zarr
7+
import zarr.storage
8+
except ImportError:
9+
raise ImportError("Please `pip install zarr aiohttp` to run this example")
610

7-
from ndv import NDViewer
811

912
URL = "https://s3.embl.de/i2k-2020/ngff-example-data/v0.4/tczyx.ome.zarr"
1013
zarr_arr = zarr.open(URL, mode="r")
1114

12-
if __name__ == "__main__":
13-
qapp = QtWidgets.QApplication([])
14-
v = NDViewer(zarr_arr["s0"])
15-
v.show()
16-
qapp.exec()
15+
ndv.imshow(zarr_arr["s0"])

0 commit comments

Comments
 (0)