Skip to content

Commit 6db50cc

Browse files
stellaprinswillGraham01niksirbi
authored
Plotting wrappers: Head Trajectory (#394)
* add plot module with trajectory function * update example * add plot kwargs and defaults * add tests for trajectory plotting * parametrize test_trajectory Co-authored-by: Will Graham <[email protected]> * add docstring to trajectory, remove test code Co-authored-by: Will Graham <[email protected]> * improve code coverage, make colorbar alpha resistant Co-authored-by: Will Graham <[email protected]> * Add Niko's plot trajectory suggestion, fix tests * update example * adjust example, make lines causing '\d' SyntaxWarning raw * change trajectory inputs, replace keypoint and individual with selection, remove title * change trajectory inputs, replace keypoint and individual with selection, remove title * use selection dict for individuals and keypoints * update trajectory examples, allow user to set marker colour * fix examples, fix colorbar * improve code coverage * remove image_path from plot input, adjust example * fix test after removing image_path input * add more tests * test trajectory without individuals and/or keypoints dimension * fix logic test_trajectory_dropped_dim * add tests, update examples, change selection to individuals and keypoints Co-authored-by: Niko Sirmpilatze <[email protected]> * change trajectory input in examples as well * fix input load and explore poses example * change folder structure * add init file Co-authored-by: Will Graham <[email protected]> * fix test * deal with drop deprecation (-> drop_vars) * process review suggestions, adjust drop dimension test * reorder plot_trajectory inputs * fix typo example --------- Co-authored-by: Will Graham <[email protected]> Co-authored-by: Niko Sirmpilatze <[email protected]>
1 parent 55b2561 commit 6db50cc

File tree

7 files changed

+400
-84
lines changed

7 files changed

+400
-84
lines changed

examples/compute_kinematics.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from matplotlib import pyplot as plt
1717

1818
from movement import sample_data
19+
from movement.plots import plot_trajectory
1920
from movement.utils.vector import compute_norm
2021

2122
# %%
@@ -48,27 +49,36 @@
4849
# Visualise the data
4950
# ---------------------------
5051
# First, let's visualise the trajectories of the mice in the XY plane,
51-
# colouring them by individual.
52+
# colouring them by individual. We use the ``plot_trajectory`` function from
53+
# ``movement.plots`` which is a wrapper around
54+
# ``matplotlib.pyplot.scatter`` that simplifies plotting the trajectories of
55+
# individuals in the dataset. The fig and ax objects returned can be used to
56+
# further customise the plot.
5257

58+
# Create a single figure and axes
5359
fig, ax = plt.subplots(1, 1)
60+
# Invert y-axis so (0,0) is in the top-left,
61+
# matching typical image coordinate systems
62+
ax.invert_yaxis()
63+
# Plot trajectories for each mouse on the same axes
5464
for mouse_name, col in zip(
55-
position.individuals.values, ["r", "g", "b"], strict=False
65+
position.individuals.values,
66+
["r", "g", "b"], # colours
67+
strict=False,
5668
):
57-
ax.plot(
58-
position.sel(individuals=mouse_name, space="x"),
59-
position.sel(individuals=mouse_name, space="y"),
60-
linestyle="-",
61-
marker=".",
62-
markersize=2,
63-
linewidth=0.5,
69+
plot_trajectory(
70+
position,
71+
individual=mouse_name,
72+
ax=ax, # Use the same axes for all plots
6473
c=col,
74+
marker="o",
75+
s=10,
76+
alpha=0.2,
6577
label=mouse_name,
6678
)
67-
ax.invert_yaxis()
68-
ax.set_xlabel("x (pixels)")
69-
ax.set_ylabel("y (pixels)")
70-
ax.axis("equal")
71-
ax.legend()
79+
ax.legend().set_alpha(1)
80+
ax.title.set_text("Trajectories of three mice")
81+
fig.show()
7282

7383
# %%
7484
# We can see that the trajectories of the three mice are close to a circular
@@ -77,23 +87,19 @@
7787
# follows the convention for SLEAP and most image processing tools.
7888

7989
# %%
80-
# We can also color the data points based on their timestamps:
90+
# By default the ``plot_trajectory`` function in ``movement.plots`` colours
91+
# data points based on their timestamps:
8192
fig, axes = plt.subplots(3, 1, sharey=True)
8293
for mouse_name, ax in zip(position.individuals.values, axes, strict=False):
83-
sc = ax.scatter(
84-
position.sel(individuals=mouse_name, space="x"),
85-
position.sel(individuals=mouse_name, space="y"),
94+
ax.invert_yaxis()
95+
fig, ax = plot_trajectory(
96+
position,
97+
individual=mouse_name,
98+
ax=ax,
8699
s=2,
87-
c=position.time,
88-
cmap="viridis",
89100
)
90-
ax.invert_yaxis()
91-
ax.set_title(mouse_name)
92-
ax.set_xlabel("x (pixels)")
93-
ax.set_ylabel("y (pixels)")
94-
ax.axis("equal")
95-
fig.colorbar(sc, ax=ax, label="time (s)")
96101
fig.tight_layout()
102+
fig.show()
97103

98104
# %%
99105
# These plots show that for this snippet of the data,

examples/compute_polar_coordinates.py

Lines changed: 28 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from movement import sample_data
1818
from movement.io import load_poses
19+
from movement.plots import plot_trajectory
1920
from movement.utils.vector import cart2pol, pol2cart
2021

2122
# %%
@@ -69,65 +70,51 @@
6970
# We can plot the data to check that our computation of the head vector is
7071
# correct.
7172
#
72-
# We can start by plotting the trajectory of the midpoint between the ears. We
73-
# will refer to this as the head trajectory.
74-
75-
fig, ax = plt.subplots(1, 1)
76-
mouse_name = ds.individuals.values[0]
77-
78-
sc = ax.scatter(
79-
midpoint_ears.sel(individuals=mouse_name, space="x"),
80-
midpoint_ears.sel(individuals=mouse_name, space="y"),
81-
s=15,
82-
c=midpoint_ears.time,
83-
cmap="viridis",
84-
marker="o",
85-
)
86-
87-
ax.axis("equal")
88-
ax.set_xlabel("x (pixels)")
89-
ax.set_ylabel("y (pixels)")
73+
# We can start by plotting the head trajectory with the ``plot_trajectory``
74+
# from ``movement.plots`` which creates a plot of the centroid of
75+
# the selected keypoints, for the head trajectory, we will use the midpoint
76+
# between the ears. By default, the trajectory of the first listed individual
77+
# is shown.
78+
79+
fig, ax = plot_trajectory(position, keypoints=["left_ear", "right_ear"])
80+
# Invert y-axis so (0,0) is in the top-left,
81+
# matching typical image coordinate systems
9082
ax.invert_yaxis()
91-
ax.set_title(f"Head trajectory ({mouse_name})")
92-
fig.colorbar(sc, ax=ax, label=f"time ({ds.attrs['time_unit']})")
9383
fig.show()
9484

85+
9586
# %%
87+
# Overlay trajectory on Elevated Plus Maze
88+
# ----------------------------------------
9689
# We can see that the majority of the head trajectory data is within a
9790
# cruciform shape. This is because the dataset is of a mouse moving on an
9891
# `Elevated Plus Maze <https://en.wikipedia.org/wiki/Elevated_plus_maze>`_.
9992
# We can actually verify this is the case by overlaying the head
10093
# trajectory on the sample frame of the dataset.
10194

102-
# read sample frame
95+
# Read sample frame
10396
frame_path = sample_data.fetch_dataset_paths(
10497
"SLEAP_single-mouse_EPM.analysis.h5"
10598
)["frame"]
106-
im = plt.imread(frame_path)
10799

108-
109-
# plot sample frame
100+
# Create figure and axis
110101
fig, ax = plt.subplots(1, 1)
111-
ax.imshow(im)
112-
113-
# plot head trajectory with semi-transparent markers
114-
sc = ax.scatter(
115-
midpoint_ears.sel(individuals=mouse_name, space="x"),
116-
midpoint_ears.sel(individuals=mouse_name, space="y"),
117-
s=15,
118-
c=midpoint_ears.time,
102+
# Plot the frame using imshow
103+
ax.imshow(plt.imread(frame_path))
104+
# No need to invert the y-axis now, since the image is plotted
105+
# using a pixel coordinate system with origin on the top left of the image
106+
fig, ax = plot_trajectory(
107+
ds.position,
108+
individual="individual_0",
109+
keypoints=["left_ear", "right_ear"],
110+
ax=ax,
111+
s=10,
119112
cmap="viridis",
120113
marker="o",
121-
alpha=0.05, # transparency
114+
alpha=0.05,
122115
)
123-
124-
ax.axis("equal")
125-
ax.set_xlabel("x (pixels)")
126-
ax.set_ylabel("y (pixels)")
127-
# No need to invert the y-axis now, since the image is plotted
128-
# using a pixel coordinate system with origin on the top left of the image
129-
ax.set_title(f"Head trajectory ({mouse_name})")
130-
116+
# Adjust title
117+
ax.set_title("Head trajectory (individual_0)")
131118
fig.show()
132119

133120
# %%

examples/load_and_explore_poses.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
# %%
88
# Imports
99
# -------
10-
from matplotlib import pyplot as plt
1110

1211
from movement import sample_data
1312
from movement.io import load_poses
13+
from movement.plots import plot_trajectory
1414

1515
# %%
1616
# Define the file path
@@ -70,19 +70,10 @@
7070
# Trajectory plots
7171
# ----------------
7272
# We are not limited to ``xarray``'s built-in plots.
73-
# For example, we can use ``matplotlib`` to plot trajectories
74-
# (using scatter plots):
73+
# The ``movement.plots`` module provides some additional
74+
# visualisations, like ``plot_trajectory()``.
7575

76-
mouse_name = "AEON3B_TP1"
7776

78-
plt.scatter(
79-
da.sel(individuals=mouse_name, space="x"),
80-
da.sel(individuals=mouse_name, space="y"),
81-
s=2,
82-
c=da.time,
83-
cmap="viridis",
84-
)
85-
plt.title(f"Trajectory of {mouse_name}")
86-
plt.xlabel("x")
87-
plt.ylabel("y")
88-
plt.colorbar(label="time (sec)")
77+
mouse_name = "AEON3B_TP1"
78+
fig, ax = plot_trajectory(position, individual=mouse_name)
79+
fig.show()

movement/plots/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from movement.plots.trajectory import plot_trajectory
2+
3+
__all__ = ["plot_trajectory"]

movement/plots/trajectory.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""Wrappers to plot movement data."""
2+
3+
import xarray as xr
4+
from matplotlib import pyplot as plt
5+
6+
DEFAULT_PLOTTING_ARGS = {
7+
"s": 15,
8+
"marker": "o",
9+
"alpha": 1.0,
10+
}
11+
12+
13+
def plot_trajectory(
14+
da: xr.DataArray,
15+
individual: str | None = None,
16+
keypoints: str | list[str] | None = None,
17+
ax: plt.Axes | None = None,
18+
**kwargs,
19+
) -> tuple[plt.Figure, plt.Axes]:
20+
"""Plot trajectory.
21+
22+
This function plots the trajectory of a specified keypoint or the centroid
23+
of multiple keypoints for a given individual. By default, the first
24+
is colored by time (using the default colormap). Pass a different colormap
25+
through ``cmap`` if desired.
26+
27+
Parameters
28+
----------
29+
da : xr.DataArray
30+
A data array containing position information, with `time` and `space`
31+
as required dimensions. Optionally, it may have `individuals` and/or
32+
`keypoints` dimensions.
33+
individual : str, optional
34+
The name of the individual to be plotted. By default, the first
35+
individual is plotted.
36+
keypoints : str, list[str], optional
37+
The name of the keypoint to be plotted, or a list of keypoint names
38+
(their centroid will be plotted). By default, the centroid of all
39+
keypoints is plotted.
40+
ax : matplotlib.axes.Axes or None, optional
41+
Axes object on which to draw the trajectory. If None, a new
42+
figure and axes are created.
43+
**kwargs : dict
44+
Additional keyword arguments passed to
45+
``matplotlib.axes.Axes.scatter()``.
46+
47+
Returns
48+
-------
49+
(figure, axes) : tuple of (matplotlib.pyplot.Figure, matplotlib.axes.Axes)
50+
The figure and axes containing the trajectory plot.
51+
52+
"""
53+
if isinstance(individual, list):
54+
raise ValueError("Only one individual can be selected.")
55+
56+
selection = {}
57+
58+
if "individuals" in da.dims:
59+
if individual is None:
60+
selection["individuals"] = da.individuals.values[0]
61+
else:
62+
selection["individuals"] = individual
63+
64+
title_suffix = f" of {individual}" if "individuals" in da.dims else ""
65+
66+
if "keypoints" in da.dims:
67+
if keypoints is None:
68+
selection["keypoints"] = da.keypoints.values
69+
else:
70+
selection["keypoints"] = keypoints
71+
72+
plot_point = da.sel(**selection)
73+
74+
# If there are multiple selected keypoints, calculate the centroid
75+
plot_point = (
76+
plot_point.mean(dim="keypoints", skipna=True)
77+
if "keypoints" in plot_point.dims and plot_point.sizes["keypoints"] > 1
78+
else plot_point
79+
)
80+
81+
plot_point = plot_point.squeeze() # Only space and time should remain
82+
83+
fig, ax = plt.subplots(figsize=(6, 6)) if ax is None else (ax.figure, ax)
84+
85+
# Merge default plotting args with user-provided kwargs
86+
for key, value in DEFAULT_PLOTTING_ARGS.items():
87+
kwargs.setdefault(key, value)
88+
89+
colorbar = False
90+
if "c" not in kwargs:
91+
kwargs["c"] = plot_point.time
92+
colorbar = True
93+
94+
# Plot the scatter, colouring by time or user-provided colour
95+
sc = ax.scatter(
96+
plot_point.sel(space="x"),
97+
plot_point.sel(space="y"),
98+
**kwargs,
99+
)
100+
101+
space_unit = da.attrs.get("space_unit", "pixels")
102+
ax.set_xlabel(f"x ({space_unit})")
103+
ax.set_ylabel(f"y ({space_unit})")
104+
ax.axis("equal")
105+
ax.set_title(f"Trajectory{title_suffix}")
106+
107+
# Add 'colorbar' for time dimension if no colour was provided by user
108+
time_unit = da.attrs.get("time_unit")
109+
time_label = f"time ({time_unit})" if time_unit else "time steps (frames)"
110+
fig.colorbar(sc, ax=ax, label=time_label).solids.set(
111+
alpha=1.0
112+
) if colorbar else None
113+
114+
return fig, ax

tests/test_unit/test_load_bboxes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def test_from_via_tracks_file(
295295
r"_(0\d*)_$",
296296
AttributeError,
297297
"/crab_1/00000.jpg (row 0): "
298-
"The provided frame regexp (_(0\d*)_$) did not return any "
298+
r"The provided frame regexp (_(0\d*)_$) did not return any "
299299
"matches and a frame number could not be extracted from "
300300
"the filename.",
301301
),
@@ -304,7 +304,7 @@ def test_from_via_tracks_file(
304304
ValueError,
305305
"/crab_1/00000.jpg (row 0): "
306306
"The frame number extracted from the filename "
307-
"using the provided regexp ((0\d*\.\w+)$) "
307+
r"using the provided regexp ((0\d*\.\w+)$) "
308308
"could not be cast as an integer.",
309309
),
310310
],

0 commit comments

Comments
 (0)