Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

napari widget for loading and visualising pose tracks #112

Draft
wants to merge 28 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
630f65f
added napari-related dependencies
niksirbi Jan 23, 2024
a95f4ac
[WIP] first draft of napari reader for pose tracks
niksirbi Jan 25, 2024
51851d6
changed napari-related dependencies
niksirbi Feb 6, 2024
d38a3cf
WIP converting reader plugin to widget
niksirbi Feb 6, 2024
3de9980
converted reader plugin to widget
niksirbi Feb 8, 2024
9929d67
remove some redundant code
niksirbi Feb 9, 2024
c7b3928
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 9, 2024
a15a75d
constrained brainglobe-utils to >= 0.4
niksirbi Feb 12, 2024
8f0355b
fixed color ordering in points and tracks colormaps
niksirbi Feb 19, 2024
3878062
refactored napari layer styles into dataclasses
niksirbi Feb 19, 2024
78e7446
use the generic from_file function to load the pose data
niksirbi Feb 20, 2024
bf64375
simplify the way napari dependencies are included in dev extras
niksirbi Feb 22, 2024
cd74549
refactored ds_to_napari_tracks function
niksirbi Feb 26, 2024
00dd611
fixed typo
niksirbi Feb 27, 2024
04d73b5
wrote test for ds_to_napari_tracks
niksirbi Feb 27, 2024
6816e17
test replacement of NaNs in confidence
niksirbi Feb 27, 2024
f89a2ea
WIP smoke test for napari widget
niksirbi Feb 27, 2024
d430b64
refactor and test function for converting pd column to categorical codes
niksirbi Feb 28, 2024
c691fcb
set napari playback speed based on fps
niksirbi Feb 28, 2024
43fc9a0
rename widgets
niksirbi Feb 28, 2024
e4a51af
pose file is loaded upon "Load" button click
niksirbi Feb 28, 2024
ac31448
test instantiation of meta and loader widgets
niksirbi Feb 29, 2024
583a9f0
add teardown to metawidget fixture
niksirbi Mar 17, 2024
ccb7f1d
Revert "add teardown to metawidget fixture"
niksirbi Mar 25, 2024
b647fb3
add workflow step to enable qt testing on linux
niksirbi Mar 25, 2024
c726d9a
use make_napari_viewer_proxy
niksirbi Mar 25, 2024
21c68f5
updated tox config to enable headless qt tests on ubuntu
niksirbi Apr 16, 2024
f438744
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 7, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/test_and_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ jobs:
python-version: "3.10"

steps:
- name: enable Qt testing on Linux
uses: pyvista/setup-headless-display-action@v2
with:
qt: true
- name: Cache Test Data
uses: actions/cache@v4
with:
Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
include LICENSE
include *.md
include movement/napari/napari.yaml
exclude .pre-commit-config.yaml
exclude .cruft.json

Expand Down
Empty file added movement/napari/__init__.py
Empty file.
94 changes: 94 additions & 0 deletions movement/napari/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import logging

import numpy as np
import pandas as pd
import xarray as xr

# get logger
logger = logging.getLogger(__name__)


def _replace_nans_with_zeros(
ds: xr.Dataset, data_vars: list[str]
) -> xr.Dataset:
"""Replace NaN values in specified data variables with zeros."""
for data_var in data_vars:
if ds[data_var].isnull().any():
logger.warning(
f"NaNs found in {data_var}, will be replaced with zeros."
)
ds[data_var] = ds[data_var].fillna(0)
return ds


def _construct_properties_dataframe(ds: xr.Dataset) -> pd.DataFrame:
"""Construct a pandas DataFrame with properties from the dataset."""
properties = pd.DataFrame(
{
"individual": ds.coords["individuals"].values,
"keypoint": ds.coords["keypoints"].values,
"time": ds.coords["time"].values,
"confidence": ds["confidence"].values.flatten(),
}
)
return properties


def ds_to_napari_tracks(ds: xr.Dataset) -> tuple[np.ndarray, pd.DataFrame]:
"""Converts movement xarray dataset to napari tracks array and properties.
For reference, see the napari tracks array documentation [1]_.

Parameters
----------
ds : xr.Dataset
Movement dataset with pose tracks and confidence data variables.

Returns
-------
data : np.ndarray
Napari Tracks array with shape (N, 4),
where N is n_keypoints * n_individuals * n_frames
and the 4 columns are (track_id, frame_idx, y, x).
properties : pd.DataFrame
DataFrame with properties (individual, keypoint, time, confidence).

Notes
-----
A corresponding napari Points array can be derived from the Tracks array
by taking its last 3 columns: (frame_idx, y, x). See the napari Points
array documentation [2]_.

References
----------
.. [1] https://napari.org/stable/howtos/layers/tracks.html
.. [2] https://napari.org/stable/howtos/layers/points.html

"""
ds_ = ds.copy() # make a copy to avoid modifying the original dataset

n_frames = ds_.sizes["time"]
n_individuals = ds_.sizes["individuals"]
n_keypoints = ds_.sizes["keypoints"]
n_tracks = n_individuals * n_keypoints

ds_ = _replace_nans_with_zeros(ds_, ["confidence"])
# assign unique integer ids to individuals and keypoints
ds_.coords["individual_ids"] = ("individuals", range(n_individuals))
ds_.coords["keypoint_ids"] = ("keypoints", range(n_keypoints))

# Convert 4D to 2D array by stacking
ds_ = ds_.stack(tracks=("individuals", "keypoints", "time"))
# Track ids are unique ints (individual_id * n_keypoints + keypoint_id)
individual_ids = ds_.coords["individual_ids"].values
keypoint_ids = ds_.coords["keypoint_ids"].values
track_ids = (individual_ids * n_keypoints + keypoint_ids).reshape(-1, 1)

# Construct the napari Tracks array
yx_columns = np.fliplr(ds_["pose_tracks"].values.T)
time_column = np.tile(range(n_frames), n_tracks).reshape(-1, 1)
data = np.hstack((track_ids, time_column, yx_columns))

# Construct the properties DataFrame
properties = _construct_properties_dataframe(ds_)

return data, properties
77 changes: 77 additions & 0 deletions movement/napari/layer_styles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Dataclasses containing layer styles for napari."""

from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import pandas as pd
from napari.utils.colormaps import ensure_colormap

DEFAULT_COLORMAP = "turbo"


@dataclass
class LayerStyle:
"""Base class for napari layer styles."""

name: str
properties: pd.DataFrame
visible: bool = True
blending: str = "translucent"

def as_kwargs(self) -> dict:
"""Return the style properties as a dictionary of kwargs."""
return self.__dict__


@dataclass
class PointsStyle(LayerStyle):
"""Style properties for a napari Points layer."""

name: str
properties: pd.DataFrame
visible: bool = True
blending: str = "translucent"
symbol: str = "disc"
size: int = 10
edge_width: int = 0
face_color: Optional[str] = None
face_color_cycle: Optional[list[tuple]] = None
face_colormap: str = DEFAULT_COLORMAP
text: dict = field(default_factory=lambda: {"visible": False})

@staticmethod
def _sample_colormap(n: int, cmap_name: str) -> list[tuple]:
"""Sample n equally-spaced colors from a napari colormap,
including the endpoints.
"""
cmap = ensure_colormap(cmap_name)
samples = np.linspace(0, len(cmap.colors) - 1, n).astype(int)
return [tuple(cmap.colors[i]) for i in samples]

def set_color_by(self, prop: str, cmap: str) -> None:
"""Set the face_color to a column in the properties DataFrame."""
self.face_color = prop
self.text["string"] = prop
n_colors = len(self.properties[prop].unique())
self.face_color_cycle = self._sample_colormap(n_colors, cmap)


@dataclass
class TracksStyle(LayerStyle):
"""Style properties for a napari Tracks layer."""

name: str
properties: pd.DataFrame
tail_width: int = 5
tail_length: int = 60
head_length: int = 0
color_by: str = "track_id"
colormap: str = DEFAULT_COLORMAP
visible: bool = True
blending: str = "translucent"

def set_color_by(self, prop: str, cmap: str) -> None:
"""Set the color_by to a column in the properties DataFrame."""
self.color_by = prop
self.colormap = cmap
140 changes: 140 additions & 0 deletions movement/napari/loader_widget.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import logging
from pathlib import Path

from napari.viewer import Viewer
from qtpy.QtWidgets import (
QComboBox,
QFileDialog,
QFormLayout,
QHBoxLayout,
QLineEdit,
QPushButton,
QSpinBox,
QWidget,
)

from movement.io import load_poses
from movement.napari.convert import ds_to_napari_tracks
from movement.napari.layer_styles import PointsStyle, TracksStyle
from movement.napari.utils import (
columns_to_categorical_codes,
set_playback_fps,
)

logger = logging.getLogger(__name__)


class Loader(QWidget):
"""Widget for loading data from files."""

file_suffix_map = {
"DeepLabCut": "Files containing predicted poses (*.h5 *.csv)",
"LightningPose": "Files containing predicted poses (*.csv)",
"SLEAP": "Files containing predicted poses (*.h5 *.slp)",
}

def __init__(self, napari_viewer: Viewer, parent=None):
super().__init__(parent=parent)
self.viewer = napari_viewer
self.setLayout(QFormLayout())
# Create widgets
self.create_source_software_widget()
self.create_fps_widget()
self.create_file_path_widget()
self.create_load_button()

def create_source_software_widget(self):
"""Create a combo box for selecting the source software."""
self.source_software_combo = QComboBox()
self.source_software_combo.addItems(
["SLEAP", "DeepLabCut", "LightningPose"]
)
self.layout().addRow("source software:", self.source_software_combo)

def create_fps_widget(self):
"""Create a spinbox for selecting the frames per second (fps)."""
self.fps_spinbox = QSpinBox()
self.fps_spinbox.setMinimum(1)
self.fps_spinbox.setMaximum(1000)
self.fps_spinbox.setValue(50)
self.layout().addRow("fps:", self.fps_spinbox)

def create_file_path_widget(self):
"""Create a line edit and browse button for selecting the file path.
This allows the user to either browse the file system,
or type the path directly into the line edit.
"""
# File path line edit and browse button
self.file_path_edit = QLineEdit()
self.browse_button = QPushButton("browse")
self.browse_button.clicked.connect(self.open_file_dialog)
# Layout for line edit and button
self.file_path_layout = QHBoxLayout()
self.file_path_layout.addWidget(self.file_path_edit)
self.file_path_layout.addWidget(self.browse_button)
self.layout().addRow("pose file:", self.file_path_layout)

def create_load_button(self):
"""Create a button to load the file and add layers to the viewer."""
self.load_button = QPushButton("Load")
self.load_button.clicked.connect(lambda: self.load_file())
self.layout().addRow(self.load_button)

def open_file_dialog(self):
dlg = QFileDialog()
dlg.setFileMode(QFileDialog.ExistingFile)
# Allowed file suffixes based on the source software
dlg.setNameFilter(
self.file_suffix_map[self.source_software_combo.currentText()]
)
if dlg.exec_():
file_paths = dlg.selectedFiles()
# Set the file path in the line edit
self.file_path_edit.setText(file_paths[0])

def load_file(self):
fps = self.fps_spinbox.value()
source_software = self.source_software_combo.currentText()
file_path = self.file_path_edit.text()
if file_path == "":
logger.warning("No file path specified.")
return
ds = load_poses.from_file(file_path, source_software, fps)

self.data, self.props = ds_to_napari_tracks(ds)
logger.info("Converted pose tracks to a napari Tracks array.")
logger.debug(f"Tracks data shape: {self.data.shape}")

self.file_name = Path(file_path).name
self.add_layers()

set_playback_fps(fps)
logger.debug(f"Set napari playback speed to {fps} fps.")

def add_layers(self):
"""Add the predicted pose tracks and keypoints to the napari viewer."""
n_individuals = len(self.props["individual"].unique())
color_by = "individual" if n_individuals > 1 else "keypoint"

# Style properties for the napari Points layer
points_style = PointsStyle(
name=f"Keypoints - {self.file_name}",
properties=self.props,
)
points_style.set_color_by(prop=color_by, cmap="turbo")

# Track properties must be numeric, so convert str to categorical codes
tracks_props = columns_to_categorical_codes(
self.props, ["individual", "keypoint"]
)

# kwargs for the napari Tracks layer
tracks_style = TracksStyle(
name=f"Tracks - {self.file_name}",
properties=tracks_props,
)
tracks_style.set_color_by(prop=color_by, cmap="turbo")

# Add the new layers to the napari viewer
self.viewer.add_tracks(self.data, **tracks_style.as_kwargs())
self.viewer.add_points(self.data[:, 1:], **points_style.as_kwargs())
24 changes: 24 additions & 0 deletions movement/napari/meta_widget.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from brainglobe_utils.qtpy.collapsible_widget import CollapsibleWidgetContainer
from napari.viewer import Viewer

from movement.napari.loader_widget import Loader


class MovementMetaWidget(CollapsibleWidgetContainer):
"""The widget to rule all movement napari widgets.

This is a container of collapsible widgets, each responsible
for handing specific tasks in the movement napari workflow.
"""

def __init__(self, napari_viewer: Viewer, parent=None):
super().__init__()

self.add_widget(
Loader(napari_viewer, parent=self),
collapsible=True,
widget_title="Load",
)

self.loader = self.collapsible_widgets[0]
self.loader.expand() # expand the loader widget by default
10 changes: 10 additions & 0 deletions movement/napari/napari.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
name: movement
display_name: movement
contributions:
commands:
- id: movement.make_widget
python_name: movement.napari.meta_widget:MovementMetaWidget
title: movement
widgets:
- command: movement.make_widget
display_name: movement
Loading
Loading