Skip to content

Commit b9d0ad9

Browse files
committed
refactored napari layer styles into dataclasses
1 parent 1e74f56 commit b9d0ad9

File tree

2 files changed

+108
-50
lines changed

2 files changed

+108
-50
lines changed

movement/napari/layer_styles.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Dataclasses containing layer styles for napari."""
2+
3+
from dataclasses import dataclass, field
4+
from typing import Optional
5+
6+
import numpy as np
7+
import pandas as pd
8+
from napari.utils.colormaps import ensure_colormap
9+
10+
DEFAULT_COLORMAP = "turbo"
11+
12+
13+
@dataclass
14+
class LayerStyle:
15+
"""Base class for napari layer styles."""
16+
17+
name: str
18+
properties: pd.DataFrame
19+
visible: bool = True
20+
blending: str = "translucent"
21+
22+
def as_kwargs(self) -> dict:
23+
"""Return the style properties as a dictionary of kwargs."""
24+
return self.__dict__
25+
26+
27+
@dataclass
28+
class PointsStyle(LayerStyle):
29+
"""Style properties for a napari Points layer."""
30+
31+
name: str
32+
properties: pd.DataFrame
33+
visible: bool = True
34+
blending: str = "translucent"
35+
symbol: str = "disc"
36+
size: int = 10
37+
edge_width: int = 0
38+
face_color: Optional[str] = None
39+
face_color_cycle: Optional[list[tuple]] = None
40+
face_colormap: str = DEFAULT_COLORMAP
41+
text: dict = field(default_factory=lambda: {"visible": False})
42+
43+
@staticmethod
44+
def _sample_colormap(n: int, cmap_name: str) -> list[tuple]:
45+
"""Sample n equally-spaced colors from a napari colormap,
46+
including the endpoints."""
47+
cmap = ensure_colormap(cmap_name)
48+
samples = np.linspace(0, len(cmap.colors) - 1, n).astype(int)
49+
return [tuple(cmap.colors[i]) for i in samples]
50+
51+
def set_color_by(self, prop: str, cmap: str) -> None:
52+
"""Set the face_color to a column in the properties DataFrame."""
53+
self.face_color = prop
54+
self.text["string"] = prop
55+
n_colors = len(self.properties[prop].unique())
56+
self.face_color_cycle = self._sample_colormap(n_colors, cmap)
57+
58+
59+
@dataclass
60+
class TracksStyle(LayerStyle):
61+
"""Style properties for a napari Tracks layer."""
62+
63+
name: str
64+
properties: pd.DataFrame
65+
tail_width: int = 5
66+
tail_length: int = 60
67+
head_length: int = 0
68+
color_by: str = "track_id"
69+
colormap: str = DEFAULT_COLORMAP
70+
visible: bool = True
71+
blending: str = "translucent"
72+
73+
def set_color_by(self, prop: str, cmap: str) -> None:
74+
"""Set the color_by to a column in the properties DataFrame."""
75+
self.color_by = prop
76+
self.colormap = cmap

movement/napari/loader_widgets.py

Lines changed: 32 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import logging
22
from pathlib import Path
33

4-
import numpy as np
5-
from napari.utils.colormaps import ensure_colormap
4+
import pandas as pd
65
from napari.viewer import Viewer
7-
from pandas.api.types import CategoricalDtype
86
from qtpy.QtWidgets import (
97
QComboBox,
108
QFileDialog,
@@ -18,16 +16,21 @@
1816

1917
from movement.io import load_poses
2018
from movement.napari.convert import ds_to_napari_tracks
19+
from movement.napari.layer_styles import PointsStyle, TracksStyle
2120

2221
logger = logging.getLogger(__name__)
2322

2423

25-
def sample_colormap(n: int, cmap_name: str) -> list[tuple]:
26-
"""Sample n equally-spaced colors from a napari colormap,
27-
including the endpoints."""
28-
cmap = ensure_colormap(cmap_name)
29-
samples = np.linspace(0, len(cmap.colors) - 1, n).astype(int)
30-
return [tuple(cmap.colors[i]) for i in samples]
24+
def columns_to_categorical(df: pd.DataFrame, cols: list[str]) -> pd.DataFrame:
25+
"""Convert columns in a DataFrame to ordered categorical data type. The
26+
categories are the unique values in the column, ordered by appearance."""
27+
new_df = df.copy()
28+
for col in cols:
29+
cat_dtype = pd.api.types.CategoricalDtype(
30+
categories=df[col].unique().tolist(), ordered=True
31+
)
32+
new_df[col] = df[col].astype(cat_dtype).cat.codes
33+
return new_df
3134

3235

3336
class FileLoader(QWidget):
@@ -120,49 +123,28 @@ def load_file(self, file_path):
120123

121124
def add_layers(self):
122125
"""Add the predicted pose tracks and keypoints to the napari viewer."""
123-
124-
common_kwargs = {"visible": True, "blending": "translucent"}
125126
n_individuals = len(self.props["individual"].unique())
126127
color_by = "individual" if n_individuals > 1 else "keypoint"
127-
n_colors = len(self.props[color_by].unique())
128-
129-
# kwargs for the napari Points layer
130-
points_kwargs = {
131-
**common_kwargs,
132-
"name": f"Keypoints - {self.file_name}",
133-
"properties": self.props,
134-
"symbol": "disc",
135-
"size": 10,
136-
"edge_width": 0,
137-
"face_color": color_by,
138-
"face_color_cycle": sample_colormap(n_colors, "turbo"),
139-
"face_colormap": "turbo",
140-
"text": {"string": color_by, "visible": False},
141-
}
142-
143-
# Modify properties for the napari Tracks layer
144-
tracks_props = self.props.copy()
128+
129+
# Style properties for the napari Points layer
130+
points_style = PointsStyle(
131+
name=f"Keypoints - {self.file_name}",
132+
properties=self.props,
133+
)
134+
points_style.set_color_by(prop=color_by, cmap="turbo")
135+
145136
# Track properties must be numeric, so convert str to categorical codes
146-
for col in ["individual", "keypoint"]:
147-
cat_dtype = CategoricalDtype(
148-
categories=tracks_props[col].unique(), ordered=True
149-
)
150-
tracks_props[col] = tracks_props[col].astype(cat_dtype).cat.codes
137+
tracks_props = columns_to_categorical(
138+
self.props, ["individual", "keypoint"]
139+
)
151140

152141
# kwargs for the napari Tracks layer
153-
tracks_kwargs = {
154-
**common_kwargs,
155-
"name": f"Tracks - {self.file_name}",
156-
"properties": tracks_props,
157-
"tail_width": 5,
158-
"tail_length": 60,
159-
"head_length": 0,
160-
"color_by": color_by,
161-
"colormap": "turbo",
162-
}
163-
164-
# Add the napari Tracks layer to the viewer
165-
self.viewer.add_tracks(self.data, **tracks_kwargs)
166-
167-
# Add the napari Points layer to the viewer
168-
self.viewer.add_points(self.data[:, 1:], **points_kwargs)
142+
tracks_style = TracksStyle(
143+
name=f"Tracks - {self.file_name}",
144+
properties=tracks_props,
145+
)
146+
tracks_style.set_color_by(prop=color_by, cmap="turbo")
147+
148+
# Add the new layers to the napari viewer
149+
self.viewer.add_tracks(self.data, **tracks_style.as_kwargs())
150+
self.viewer.add_points(self.data[:, 1:], **points_style.as_kwargs())

0 commit comments

Comments
 (0)