Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 80 additions & 1 deletion src/napari_spatialdata/_sdata_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,22 @@
from packaging.version import parse as parse_version
from qtpy.QtCore import QThread, Signal
from qtpy.QtGui import QIcon
from qtpy.QtWidgets import QLabel, QListWidget, QListWidgetItem, QProgressBar, QVBoxLayout, QWidget
from qtpy.QtWidgets import (
QCheckBox,
QHBoxLayout,
QLabel,
QListWidget,
QListWidgetItem,
QProgressBar,
QVBoxLayout,
QWidget,
)
from spatialdata import SpatialData
from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM
from superqt import QDoubleRangeSlider

from napari_spatialdata._viewer import SpatialDataViewer
from napari_spatialdata.constants import config
from napari_spatialdata.constants.config import N_CIRCLES_WARNING_THRESHOLD, N_SHAPES_WARNING_THRESHOLD
from napari_spatialdata.utils._utils import _get_sdata_key, get_duplicate_element_names, get_elements_meta_mapping

Expand Down Expand Up @@ -174,11 +185,45 @@ def __init__(self, viewer: Viewer, sdata: EventedList):
self.slider.setRange(0, 0)
self.slider.setVisible(False)

self.enable_3d_points = QCheckBox("Enable 3D points")
self.enable_3d_points.setChecked(not config.PROJECT_3D_POINTS_TO_2D)
self.enable_3d_points.setToolTip("When checked, points with a z coordinate are displayed in 3D.")
self.enable_3d_points.toggled.connect(self._on_3d_points_toggled)

self.enable_2_5d_shapes = QCheckBox("Enable 2.5D shapes")
self.enable_2_5d_shapes.setChecked(not config.PROJECT_2_5D_SHAPES_TO_2D)
self.enable_2_5d_shapes.setToolTip("When checked, shapes with a z coordinate are displayed in 2.5D.")
self.enable_2_5d_shapes.toggled.connect(self._on_2_5d_shapes_toggled)

self.z_range_label = QLabel("Z range:")
self.z_range_value_label = QLabel("")
z_range_header = QHBoxLayout()
z_range_header.addWidget(self.z_range_label)
z_range_header.addWidget(self.z_range_value_label)
z_range_header.addStretch()
self.z_range_header_widget = QWidget()
self.z_range_header_widget.setLayout(z_range_header)

self.z_range_slider = QDoubleRangeSlider()
self.z_range_slider.setRange(0.0, 1.0)
self.z_range_slider.setValue((0.0, 1.0))
self.z_range_slider.setToolTip("Filter visible points and shapes by z coordinate range.")
self.z_range_slider.valueChanged.connect(self._on_z_range_changed)

self._z_slider_visible = False
self.z_range_header_widget.setVisible(False)
self.z_range_slider.setVisible(False)

self.layout().addWidget(self.slider)
self.layout().addWidget(QLabel("Coordinate System:"))
self.layout().addWidget(self.coordinate_system_widget)
self.layout().addWidget(QLabel("Elements:"))
self.layout().addWidget(self.elements_widget)
self.layout().addWidget(QLabel("3D Settings:"))
self.layout().addWidget(self.enable_3d_points)
self.layout().addWidget(self.enable_2_5d_shapes)
self.layout().addWidget(self.z_range_header_widget)
self.layout().addWidget(self.z_range_slider)
self.elements_widget.itemDoubleClicked.connect(self._on_click_item)
self.coordinate_system_widget.currentItemChanged.connect(
lambda item: self.elements_widget._onItemChange(item.text())
Expand All @@ -196,12 +241,14 @@ def __init__(self, viewer: Viewer, sdata: EventedList):
def _on_insert_layer(self, event: Event) -> None:
layer = event.value
layer.events.visible.connect(self._update_visible_in_coordinate_system)
self._update_z_slider()

def _on_click_item(self, item: QListWidgetItem) -> None:
self._onClick(item.text())

def _hide_slider(self) -> None:
self.slider.setVisible(False)
self._update_z_slider()

def _onClick(self, text: str) -> None:
selected_cs = self.coordinate_system_widget._system
Expand Down Expand Up @@ -258,6 +305,38 @@ def _update_layers_visibility(self) -> None:
layer.metadata["_active_in_cs"].add(coordinate_system)
layer.metadata["_current_cs"] = coordinate_system

def _on_3d_points_toggled(self, checked: bool) -> None:
config.PROJECT_3D_POINTS_TO_2D = not checked

def _on_2_5d_shapes_toggled(self, checked: bool) -> None:
config.PROJECT_2_5D_SHAPES_TO_2D = not checked

def _update_z_slider(self) -> None:
"""Show the z-range slider when layers with z data are present and update its range."""
z_range = self.viewer_model.get_z_range()
if z_range is None:
self.z_range_header_widget.setVisible(False)
self.z_range_slider.setVisible(False)
self._z_slider_visible = False
return

z_min, z_max = z_range
if z_min == z_max:
z_max = z_min + 1.0

self.z_range_slider.setRange(z_min, z_max)
if not self._z_slider_visible:
self.z_range_slider.setValue((z_min, z_max))
self._z_slider_visible = True
self.z_range_value_label.setText(f"[{z_min:.1f}, {z_max:.1f}]")
self.z_range_header_widget.setVisible(True)
self.z_range_slider.setVisible(True)

def _on_z_range_changed(self, value: tuple[float, float]) -> None:
z_min, z_max = value
self.z_range_value_label.setText(f"[{z_min:.1f}, {z_max:.1f}]")
self.viewer_model.filter_layers_by_z_range(z_min, z_max)

def _get_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi: bool) -> Shapes | Points:
original_name = key[: key.rfind("_")] if multi else key

Expand Down
139 changes: 126 additions & 13 deletions src/napari_spatialdata/_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,6 @@ def _save_points_to_sdata(
raise ValueError("Cannot export a points element with no points")
transformed_data = np.array([layer_to_save.data_to_world(xy) for xy in layer_to_save.data])
swap_data = np.fliplr(transformed_data)
# ignore z axis if present
if swap_data.shape[1] == 3:
swap_data = swap_data[:, :2]
parsed = PointsModel.parse(swap_data, transformations=transformation)

# saving to disk of points temporarily disabled until the interface update that will unify the view widget,
Expand Down Expand Up @@ -261,14 +258,21 @@ def _save_shapes_to_sdata(
for shape in layer_to_save._data_view.shapes
]

def _fix_coords(coords: ArrayLike) -> ArrayLike:
remove_z = coords.shape[1] == 3
first_index = 1 if remove_z else 0
coords = coords[:, first_index::]
return np.fliplr(coords)
has_z = coords[0].shape[1] == 3

polygons: list[Polygon] = [Polygon(_fix_coords(p)) for p in coords]
gdf = GeoDataFrame({"geometry": polygons})
def _fix_coords(coords: ArrayLike) -> tuple[ArrayLike, float | None]:
if coords.shape[1] == 3:
z_val = float(coords[0, 0])
yx = coords[:, 1:]
return np.fliplr(yx), z_val
return np.fliplr(coords), None

fixed = [_fix_coords(p) for p in coords]
polygons: list[Polygon] = [Polygon(xy) for xy, _ in fixed]
gdf_dict: dict[str, Any] = {"geometry": polygons}
if has_z:
gdf_dict["z"] = [z_val for _, z_val in fixed]
gdf = GeoDataFrame(gdf_dict)

force_2d(gdf)
parsed = ShapesModel.parse(gdf, transformations=transformation)
Expand Down Expand Up @@ -514,11 +518,15 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult
original_name = original_name[: original_name.rfind("_")]

df = sdata.shapes[original_name]
affine = _get_transform(sdata.shapes[original_name], selected_cs)
axes = get_axes_names(df)
include_z = "z" in axes and not config.PROJECT_2_5D_SHAPES_TO_2D
affine = _get_transform(sdata.shapes[original_name], selected_cs, include_z=include_z)

# 2.5D circles not supported yet
xy = np.array([df.geometry.x, df.geometry.y]).T
yx = np.fliplr(xy)
if include_z:
z_vals = df["z"].to_numpy()
yx = np.column_stack([z_vals, yx])
radii = df.radius.to_numpy()

adata, table_name, table_names = self._get_table_data(sdata, original_name)
Expand Down Expand Up @@ -804,8 +812,113 @@ def _affine_transform_layers(self, coordinate_system: str) -> None:
sdata = metadata["sdata"]
element_name = metadata["name"]
element_data = sdata[element_name]
affine = _get_transform(element_data, coordinate_system)
include_z = self._should_include_z(element_data)
affine = _get_transform(element_data, coordinate_system, include_z=include_z)
if affine is not None:
layer.affine = affine
if layer._type_string == "points":
self._adjust_radii_of_points_layer(layer, affine)

@staticmethod
def _should_include_z(element: DaskDataFrame | GeoDataFrame) -> bool:
"""Determine whether to include the z axis for a given spatial element.

For raster data (images, labels) z is always included when present.
For vector data (points, shapes) z inclusion depends on the user-facing
projection config flags.
"""
from xarray import DataArray, DataTree

if isinstance(element, DataArray | DataTree):
return True
axes = get_axes_names(element)
if "z" not in axes:
return False
if isinstance(element, DaskDataFrame):
return not config.PROJECT_3D_POINTS_TO_2D
return not config.PROJECT_2_5D_SHAPES_TO_2D

def get_z_range(self) -> tuple[float, float] | None:
"""Return the global (min, max) z range across all visible layers, or ``None`` if no z data exists."""
z_min, z_max = float("inf"), float("-inf")
found = False
for layer in self.viewer.layers:
metadata = layer.metadata
if not metadata.get("sdata"):
continue
sdata = metadata["sdata"]
element_name = metadata["name"]
element_data = sdata[element_name]
axes = get_axes_names(element_data)
if "z" not in axes:
continue
if isinstance(element_data, DaskDataFrame):
z_vals = element_data["z"].compute().values
elif isinstance(element_data, GeoDataFrame):
if "z" not in element_data.columns:
continue
z_vals = element_data["z"].values
else:
continue
if len(z_vals) == 0:
continue
found = True
z_min = min(z_min, float(z_vals.min()))
z_max = max(z_max, float(z_vals.max()))
if not found:
return None
return z_min, z_max

def filter_layers_by_z_range(self, z_min: float, z_max: float) -> None:
"""Hide points/shapes outside the given z range.

For :class:`~napari.layers.Points` layers the ``shown`` property is
used. For :class:`~napari.layers.Shapes` layers the face and edge
color alpha channels are set to 0 for shapes outside the range while
preserving the original alpha for visible shapes.
"""
for layer in self.viewer.layers:
metadata = layer.metadata
if not metadata.get("sdata"):
continue
sdata = metadata["sdata"]
element_name = metadata["name"]
element_data = sdata[element_name]
axes = get_axes_names(element_data)
if "z" not in axes:
continue

if isinstance(layer, Points):
if layer.data.shape[1] == 3:
z_vals = layer.data[:, 0]
else:
continue
mask = (z_vals >= z_min) & (z_vals <= z_max)
layer.shown = mask

elif isinstance(layer, Shapes):
n_shapes = len(layer.data)
if n_shapes == 0:
continue

if layer.data[0].shape[1] == 3:
z_vals = np.array([float(s[0, 0]) for s in layer.data])
elif isinstance(element_data, GeoDataFrame) and "z" in element_data.columns:
z_raw = element_data["z"].values
if len(z_raw) != n_shapes:
continue
z_vals = z_raw
else:
continue

if "_original_face_color" not in metadata:
metadata["_original_face_color"] = layer.face_color.copy()
metadata["_original_edge_color"] = layer.edge_color.copy()

mask = (z_vals >= z_min) & (z_vals <= z_max)
face_colors = metadata["_original_face_color"].copy()
edge_colors = metadata["_original_edge_color"].copy()
face_colors[~mask, 3] = 0.0
edge_colors[~mask, 3] = 0.0
layer.face_color = face_colors
layer.edge_color = edge_colors
31 changes: 24 additions & 7 deletions src/napari_spatialdata/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,10 +462,13 @@ def generate_random_color_hex() -> str:
def _get_ellipses_from_circles(yx: ArrayLike, radii: ArrayLike) -> ArrayLike:
"""Convert circles to ellipses.

Supports both 2D (y, x) and 2.5D (z, y, x) centroids. For 2.5D input the radius is
applied only to y and x while z is kept constant across the four corner vertices.

Parameters
----------
yx
Centroids of the circles.
Centroids of the circles with shape ``(N, 2)`` or ``(N, 3)``.
radii
Radii of the circles.

Expand All @@ -475,13 +478,27 @@ def _get_ellipses_from_circles(yx: ArrayLike, radii: ArrayLike) -> ArrayLike:
Ellipses.
"""
ndim = yx.shape[1]
assert ndim == 2
r = np.stack([radii] * ndim, axis=1)
lower_left = yx - r
upper_right = yx + r
assert ndim in (2, 3)

if ndim == 3:
z = yx[:, :1]
yx_2d = yx[:, 1:]
else:
yx_2d = yx

r = np.stack([radii, radii], axis=1)
lower_left = yx_2d - r
upper_right = yx_2d + r
r[:, 0] = -r[:, 0]
lower_right = yx - r
upper_left = yx + r
lower_right = yx_2d - r
upper_left = yx_2d + r

if ndim == 3:
lower_left = np.column_stack([z, lower_left])
lower_right = np.column_stack([z, lower_right])
upper_right = np.column_stack([z, upper_right])
upper_left = np.column_stack([z, upper_left])

ellipses = np.stack([lower_left, lower_right, upper_right, upper_left], axis=1)
assert isinstance(ellipses, np.ndarray)
return ellipses
Expand Down
Loading
Loading