Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
71c2110
simplify conditional
melonora Mar 28, 2024
93ca0ed
Merge branch 'main' of https://github.com/melonora/napari-spatialdata
melonora May 21, 2024
f363842
Merge branch 'main' of https://github.com/melonora/napari-spatialdata
melonora Jun 19, 2024
07f60a5
Merge branch 'main' of https://github.com/melonora/napari-spatialdata
melonora Jun 24, 2024
784e95b
Merge branch 'main' of https://github.com/melonora/napari-spatialdata
melonora Jun 24, 2024
79a2a50
Merge branch 'scverse:main' into main
melonora Jul 9, 2024
29b1579
Merge branch 'scverse:main' into main
melonora Jul 22, 2024
5a2de2e
Merge branch 'scverse:main' into main
melonora Aug 9, 2024
7576271
Merge branch 'scverse:main' into main
melonora Sep 7, 2024
85e5c61
Merge branch 'scverse:main' into main
melonora Sep 19, 2024
dcc6b68
Merge branch 'scverse:main' into main
melonora Oct 14, 2024
3036463
Merge branch 'scverse:main' into main
melonora Dec 2, 2024
044bd18
Merge branch 'scverse:main' into main
melonora Dec 17, 2024
788a328
Merge branch 'main' of https://github.com/scverse/napari-spatialdata
melonora Feb 13, 2025
16f8a2e
Merge branch 'main' of https://github.com/melonora/napari-spatialdata
melonora Feb 13, 2025
8bda7d0
Merge branch 'main' of https://github.com/scverse/napari-spatialdata
melonora Mar 16, 2025
851333c
merge main
melonora Apr 18, 2025
bb29dd6
merge main
melonora May 26, 2025
b66f1ab
add channel_widget, remove colorbar
melonora Jun 2, 2025
a3a0761
allow multiscale channel selection
melonora Jun 2, 2025
40b0e86
add test and fix
melonora Jun 2, 2025
2959acd
copy dask pin from spatialdata
melonora Jun 2, 2025
ee067ad
remove quote
melonora Jun 2, 2025
abb01c2
merge three widgets
melonora Jun 2, 2025
3484cf4
update docstring
melonora Jun 2, 2025
d9dd413
remove require_widget wrapper
melonora Jun 3, 2025
4cc8410
remove unused element attribute
melonora Jun 3, 2025
02416a0
change widget type to bool param
melonora Jun 3, 2025
003bc94
readd elements for cache but remove dict
melonora Jun 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ install_requires =
anndata
click
cycler
dask>=2024.4.1
dask>=2024.4.1,<=2024.11.2
geopandas
loguru
matplotlib
Expand Down
390 changes: 356 additions & 34 deletions src/napari_spatialdata/_sdata_widgets.py

Large diffs are not rendered by default.

9 changes: 0 additions & 9 deletions src/napari_spatialdata/_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@
from napari_spatialdata._widgets import (
AListWidget,
AnnDataSaveDialog,
CBarWidget,
ComponentWidget,
RangeSliderWidget,
SaveDialog,
ScatterAnnotationDialog,
)
Expand Down Expand Up @@ -497,13 +495,6 @@ def __init__(self, napari_viewer: Viewer, model: DataModel | None = None) -> Non
self.color_by = QLabel("Colored by:")
self.layout().addWidget(self.color_by)

# scalebar
colorbar = CBarWidget(model=self.model)
self.slider = RangeSliderWidget(self.viewer, self.model, colorbar=colorbar)
self._viewer.window.add_dock_widget(self.slider, area="left", name="slider")
self._viewer.window.add_dock_widget(colorbar, area="left", name="colorbar")
self.viewer.layers.selection.events.active.connect(self.slider._onLayerChange)

if (layer := self.viewer.layers.selection.active) is not None and layer.metadata.get("adata") is not None:
self._on_layer_update()

Expand Down
20 changes: 15 additions & 5 deletions src/napari_spatialdata/_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
_get_ellipses_from_circles,
_get_init_metadata_adata,
_get_transform,
_obtain_channel_image,
_transform_coordinates,
get_duplicate_element_names,
get_napari_version,
Expand Down Expand Up @@ -442,10 +443,14 @@ def clean_worker(self) -> None:
"""Clean the worker."""
self.worker = None

def add_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi: bool) -> None:
self.add_layer(self.get_sdata_image(sdata, key, selected_cs, multi))
def add_sdata_image(
self, sdata: SpatialData, key: str, selected_cs: str, multi: bool, channel_name: str | None = None
) -> None:
self.add_layer(self.get_sdata_image(sdata, key, selected_cs, multi, channel_name))

def get_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi: bool) -> Image:
def get_sdata_image(
self, sdata: SpatialData, key: str, selected_cs: str, multi: bool, channel_name: str | None = None
) -> Image:
"""
Add an image in a spatial data object to the viewer.

Expand All @@ -465,15 +470,20 @@ def get_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi:
original_name = original_name[: original_name.rfind("_")]

affine = _get_transform(sdata.images[original_name], selected_cs)
rgb_image, rgb = _adjust_channels_order(element=sdata.images[original_name])
if channel_name:
image = _obtain_channel_image(element=sdata.images[original_name], channel_name=channel_name)
rgb = False
key = key + f"_ch:{channel_name}"
else:
image, rgb = _adjust_channels_order(element=sdata.images[original_name])

channels = ("RGB(A)",) if rgb else get_channels(sdata.images[original_name])

adata = AnnData(shape=(0, len(channels)), var=pd.DataFrame(index=channels))

# TODO: type check
return Image(
rgb_image,
image,
rgb=rgb,
name=key,
affine=affine,
Expand Down
194 changes: 1 addition & 193 deletions src/napari_spatialdata/_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,12 @@
from qtpy import QtCore, QtWidgets
from qtpy.QtCore import Qt, Signal
from scanpy.plotting._utils import _set_colors_for_categorical_obs
from sklearn.preprocessing import MinMaxScaler
from spatialdata._types import ArrayLike
from superqt import QRangeSlider
from vispy import scene
from vispy.color.colormap import Colormap, MatplotlibColormap
from vispy.scene.widgets import ColorBarWidget

from napari_spatialdata._model import DataModel
from napari_spatialdata.utils._utils import _min_max_norm, get_napari_version

__all__ = ["AListWidget", "CBarWidget", "RangeSliderWidget", "ComponentWidget"]
__all__ = ["AListWidget", "ComponentWidget"]

# label string: attribute name
# TODO(giovp): remove since layer controls private?
Expand Down Expand Up @@ -409,193 +404,6 @@ def attr(self, field: str | None) -> None:
self._attr = field


class CBarWidget(QtWidgets.QWidget):
FORMAT = "{0:0.2f}"

cmapChanged = Signal(str)
climChanged = Signal((float, float))

def __init__(
self,
model: DataModel,
cmap: str = "viridis",
label: str | None = None,
width: int | None = 250,
height: int | None = 50,
**kwargs: Any,
):
super().__init__(**kwargs)

self._model = model

self._clim = (0.0, 1.0)
self._oclim = self._clim

self._width = width
self._height = height
self._label = label

self.__init_UI()

def __init_UI(self) -> None:
self.setFixedWidth(self._width)
self.setFixedHeight(self._height)

# use napari's BG color for dark mode
self._canvas = scene.SceneCanvas(
size=(self._width, self._height), bgcolor="#262930", parent=self, decorate=False, resizable=False, dpi=150
)
self._colorbar = ColorBarWidget(
self._create_colormap(self.cmap),
orientation="top",
label=self._label,
label_color="white",
clim=self.getClim(),
border_width=1.0,
border_color="black",
padding=(0.3, 0.167),
axis_ratio=0.05,
)

self._canvas.central_widget.add_widget(self._colorbar)

self.climChanged.connect(self.onClimChanged)
self.cmapChanged.connect(self.onCmapChanged)

def _create_colormap(self, cmap: str) -> Colormap:
ominn, omaxx = self.getOclim()
delta = omaxx - ominn + 1e-12

minn, maxx = self.getClim()
minn = (minn - ominn) / delta
maxx = (maxx - ominn) / delta

assert 0 <= minn <= 1, f"Expected `min` to be in `[0, 1]`, found `{minn}`"
assert 0 <= maxx <= 1, f"Expected `maxx` to be in `[0, 1]`, found `{maxx}`"

cm = MatplotlibColormap(cmap)

return Colormap(cm[np.linspace(minn, maxx, len(cm.colors))], interpolation="linear")

def getCmap(self) -> str:
return self.cmap

def onCmapChanged(self, value: str) -> None:
# this does not trigger update for some reason...
self._colorbar.cmap = self._create_colormap(value)
self._colorbar._colorbar._update()

def setClim(self, value: tuple[float, float]) -> None:
if value == self._clim:
return

self._clim = value
self.climChanged.emit(*value)

def getClim(self) -> tuple[float, float]:
return self._clim

def getOclim(self) -> tuple[float, float]:
return self._oclim

def setOclim(self, value: tuple[float, float]) -> None:
# original color limit used for 0-1 normalization
self._oclim = value

def onClimChanged(self, minn: float, maxx: float) -> None:
# ticks are not working with vispy's colorbar
self._colorbar.cmap = self._create_colormap(self.cmap)
self._colorbar.clim = (self.FORMAT.format(minn), self.FORMAT.format(maxx))

def getCanvas(self) -> scene.SceneCanvas:
return self._canvas

def getColorBar(self) -> ColorBarWidget:
return self._colorbar

def setLayout(self, layout: QtWidgets.QLayout) -> None:
layout.addWidget(self.getCanvas().native)
super().setLayout(layout)

def update_color(self) -> None:
# when changing selected layers that have the same limit
# could also trigger it as self._colorbar.clim = self.getClim()
# but the above option also updates geometry
# cbarwidget->cbar->cbarvisual
self._colorbar._colorbar._colorbar._update()

@property
def cmap(self) -> str:
return self._model.cmap


class RangeSliderWidget(QRangeSlider):
def __init__(self, viewer: Viewer, model: DataModel, colorbar: CBarWidget, **kwargs: Any):
super().__init__(**kwargs)

self._viewer = viewer
self._model = model
self._colorbar = colorbar
self._cmap = plt.get_cmap(self._colorbar.cmap)
self.setValue((0, 100))
self.setSliderPosition((0, 100))
self.setSingleStep(0.01)
self.setOrientation(Qt.Horizontal)
self.valueChanged.connect(self._onValueChange)

def _onLayerChange(self) -> None:
layer = self.viewer.layers.selection.active
if layer is not None:
self._onValueChange((0, 100))

def _onValueChange(self, percentile: tuple[float, float]) -> None:
layer = self.viewer.layers.selection.active
# TODO(michalk8): use constants
if "data" not in layer.metadata:
return None # noqa: RET501
v = layer.metadata["data"]
# this code is currently not used since the slider is not enabled; so I silenced the mypy error; 2. there is a
# mismatch for this error with the mypy in the CI, so I silenced the unused-ignore from the local mypy.
# when this code is re-enabled, let's fix mypy
clipped = np.clip(v, *np.percentile(v, percentile)) # type: ignore[misc,unused-ignore]

if isinstance(layer, Points):
layer.metadata = {**layer.metadata, "perc": percentile}
layer.face_color = "value"
layer.properties = {"value": clipped}
layer.refresh_colors()
elif isinstance(layer, Labels):
norm_vec = self._scale_vec(clipped)
color_vec = self._cmap(norm_vec)
layer.color = dict(zip(layer.color.keys(), color_vec, strict=False))
layer.properties = {"value": clipped}
layer.refresh()

self._colorbar.setOclim(layer.metadata["minmax"])
self._colorbar.setClim((np.min(layer.properties["value"]), np.max(layer.properties["value"])))
self._colorbar.update_color()

def _scale_vec(self, vec: ArrayLike) -> ArrayLike:
ominn, omaxx = self._colorbar.getOclim()
delta = omaxx - ominn + 1e-12

minn, maxx = self._colorbar.getClim()
minn = (minn - ominn) / delta
maxx = (maxx - ominn) / delta
scaler = MinMaxScaler(feature_range=(minn, maxx))
return scaler.fit_transform(vec.reshape(-1, 1))

@property
def viewer(self) -> napari.Viewer:
""":mod:`napari` viewer."""
return self._viewer

@property
def model(self) -> DataModel:
""":mod:`napari` viewer."""
return self._model


class SaveDialog(QtWidgets.QDialog):
def __init__(self, layer: Layer, table_name: str) -> None:
super().__init__()
Expand Down
46 changes: 33 additions & 13 deletions src/napari_spatialdata/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from contextlib import contextmanager
from functools import wraps
from random import randint
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal, TypeVar

import numpy as np
import packaging.version
Expand Down Expand Up @@ -44,7 +44,7 @@
from napari.utils.events import EventedList
from qtpy.QtWidgets import QListWidgetItem

from napari_spatialdata._sdata_widgets import CoordinateSystemWidget, ElementWidget
from napari_spatialdata._sdata_widgets import ListWidget

Check warning on line 47 in src/napari_spatialdata/utils/_utils.py

View check run for this annotation

Codecov / codecov/patch

src/napari_spatialdata/utils/_utils.py#L47

Added line #L47 was not covered by tests

from spatialdata._types import ArrayLike

Expand Down Expand Up @@ -221,6 +221,31 @@
return out


def _datatree_to_dataarray_list(new_raster: DataArray | DataTree) -> DataArray | list[DataArray]:
if isinstance(new_raster, DataTree):
list_of_xdata = []
for k in new_raster:
v = new_raster[k].values()
assert len(v) == 1
xdata = v.__iter__().__next__()
list_of_xdata.append(xdata)
return list_of_xdata
return new_raster


def _obtain_channel_image(element: DataArray | DataTree, channel_name: str | int) -> DataArray | list[DataArray]:
is_multiscale_int_ch = isinstance(element, DataTree) and np.issubdtype(
element["scale0"].c.to_numpy().dtype, np.integer
)
is_int_ch = isinstance(element, DataArray) and np.issubdtype(element.c.to_numpy().dtype, np.integer)
if isinstance(channel_name, str) and (is_multiscale_int_ch or is_int_ch):
channel_name = int(channel_name)

# works for both DataArray and DataTree
new_raster = element.sel(c=channel_name)
return _datatree_to_dataarray_list(new_raster)


def _adjust_channels_order(element: DataArray | DataTree) -> tuple[DataArray | list[DataArray], bool]:
"""Swap the axes to y, x, c and check if an image supports rgb(a) visualization.

Expand Down Expand Up @@ -264,14 +289,7 @@
rgb = False
new_raster = element

if isinstance(new_raster, DataTree):
list_of_xdata = []
for k in new_raster:
v = new_raster[k].values()
assert len(v) == 1
xdata = v.__iter__().__next__()
list_of_xdata.append(xdata)
new_raster = list_of_xdata
new_raster = _datatree_to_dataarray_list(new_raster)

return new_raster, rgb

Expand Down Expand Up @@ -387,9 +405,7 @@
return adata


def get_itemindex_by_text(
list_widget: CoordinateSystemWidget | ElementWidget, item_text: str
) -> None | QListWidgetItem:
def get_itemindex_by_text(list_widget: ListWidget, item_text: str) -> None | QListWidgetItem:
"""
Get the item in a listwidget based on its text.

Expand Down Expand Up @@ -493,3 +509,7 @@
yield
finally:
widget.blockSignals(False)


WidgetType = Literal["coordinate_system", "element", "channel"]
F = TypeVar("F", bound=Callable[..., Any])
Loading
Loading