|
1 | 1 | import logging
|
2 | 2 | from pathlib import Path
|
3 | 3 |
|
4 |
| -import numpy as np |
5 |
| -from napari.utils.colormaps import ensure_colormap |
| 4 | +import pandas as pd |
6 | 5 | from napari.viewer import Viewer
|
7 |
| -from pandas.api.types import CategoricalDtype |
8 | 6 | from qtpy.QtWidgets import (
|
9 | 7 | QComboBox,
|
10 | 8 | QFileDialog,
|
|
18 | 16 |
|
19 | 17 | from movement.io import load_poses
|
20 | 18 | from movement.napari.convert import ds_to_napari_tracks
|
| 19 | +from movement.napari.layer_styles import PointsStyle, TracksStyle |
21 | 20 |
|
22 | 21 | logger = logging.getLogger(__name__)
|
23 | 22 |
|
24 | 23 |
|
25 |
| -def sample_colormap(n: int, cmap_name: str) -> list[tuple]: |
26 |
| - """Sample n equally-spaced colors from a napari colormap, |
27 |
| - including the endpoints.""" |
28 |
| - cmap = ensure_colormap(cmap_name) |
29 |
| - samples = np.linspace(0, len(cmap.colors) - 1, n).astype(int) |
30 |
| - return [tuple(cmap.colors[i]) for i in samples] |
| 24 | +def columns_to_categorical(df: pd.DataFrame, cols: list[str]) -> pd.DataFrame: |
| 25 | + """Convert columns in a DataFrame to ordered categorical data type. The |
| 26 | + categories are the unique values in the column, ordered by appearance.""" |
| 27 | + new_df = df.copy() |
| 28 | + for col in cols: |
| 29 | + cat_dtype = pd.api.types.CategoricalDtype( |
| 30 | + categories=df[col].unique().tolist(), ordered=True |
| 31 | + ) |
| 32 | + new_df[col] = df[col].astype(cat_dtype).cat.codes |
| 33 | + return new_df |
31 | 34 |
|
32 | 35 |
|
33 | 36 | class FileLoader(QWidget):
|
@@ -120,49 +123,28 @@ def load_file(self, file_path):
|
120 | 123 |
|
121 | 124 | def add_layers(self):
|
122 | 125 | """Add the predicted pose tracks and keypoints to the napari viewer."""
|
123 |
| - |
124 |
| - common_kwargs = {"visible": True, "blending": "translucent"} |
125 | 126 | n_individuals = len(self.props["individual"].unique())
|
126 | 127 | color_by = "individual" if n_individuals > 1 else "keypoint"
|
127 |
| - n_colors = len(self.props[color_by].unique()) |
128 |
| - |
129 |
| - # kwargs for the napari Points layer |
130 |
| - points_kwargs = { |
131 |
| - **common_kwargs, |
132 |
| - "name": f"Keypoints - {self.file_name}", |
133 |
| - "properties": self.props, |
134 |
| - "symbol": "disc", |
135 |
| - "size": 10, |
136 |
| - "edge_width": 0, |
137 |
| - "face_color": color_by, |
138 |
| - "face_color_cycle": sample_colormap(n_colors, "turbo"), |
139 |
| - "face_colormap": "turbo", |
140 |
| - "text": {"string": color_by, "visible": False}, |
141 |
| - } |
142 |
| - |
143 |
| - # Modify properties for the napari Tracks layer |
144 |
| - tracks_props = self.props.copy() |
| 128 | + |
| 129 | + # Style properties for the napari Points layer |
| 130 | + points_style = PointsStyle( |
| 131 | + name=f"Keypoints - {self.file_name}", |
| 132 | + properties=self.props, |
| 133 | + ) |
| 134 | + points_style.set_color_by(prop=color_by, cmap="turbo") |
| 135 | + |
145 | 136 | # Track properties must be numeric, so convert str to categorical codes
|
146 |
| - for col in ["individual", "keypoint"]: |
147 |
| - cat_dtype = CategoricalDtype( |
148 |
| - categories=tracks_props[col].unique(), ordered=True |
149 |
| - ) |
150 |
| - tracks_props[col] = tracks_props[col].astype(cat_dtype).cat.codes |
| 137 | + tracks_props = columns_to_categorical( |
| 138 | + self.props, ["individual", "keypoint"] |
| 139 | + ) |
151 | 140 |
|
152 | 141 | # kwargs for the napari Tracks layer
|
153 |
| - tracks_kwargs = { |
154 |
| - **common_kwargs, |
155 |
| - "name": f"Tracks - {self.file_name}", |
156 |
| - "properties": tracks_props, |
157 |
| - "tail_width": 5, |
158 |
| - "tail_length": 60, |
159 |
| - "head_length": 0, |
160 |
| - "color_by": color_by, |
161 |
| - "colormap": "turbo", |
162 |
| - } |
163 |
| - |
164 |
| - # Add the napari Tracks layer to the viewer |
165 |
| - self.viewer.add_tracks(self.data, **tracks_kwargs) |
166 |
| - |
167 |
| - # Add the napari Points layer to the viewer |
168 |
| - self.viewer.add_points(self.data[:, 1:], **points_kwargs) |
| 142 | + tracks_style = TracksStyle( |
| 143 | + name=f"Tracks - {self.file_name}", |
| 144 | + properties=tracks_props, |
| 145 | + ) |
| 146 | + tracks_style.set_color_by(prop=color_by, cmap="turbo") |
| 147 | + |
| 148 | + # Add the new layers to the napari viewer |
| 149 | + self.viewer.add_tracks(self.data, **tracks_style.as_kwargs()) |
| 150 | + self.viewer.add_points(self.data[:, 1:], **points_style.as_kwargs()) |
0 commit comments