-
Notifications
You must be signed in to change notification settings - Fork 62
Plotting wrappers: Head Trajectory #394
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #394 +/- ##
=======================================
Coverage 99.82% 99.83%
=======================================
Files 20 22 +2
Lines 1169 1207 +38
=======================================
+ Hits 1167 1205 +38
Misses 2 2 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some stylistic comments here from me.
As for testing, I'll copy what I put in our chat here so it's visible for everyone:
You can't really test "images" per se, so I'd follow the advice of the StackOverflow hive mind like in this article here: use trajectory
to create a figure on some (small) custom dataset, then examine the data stored in the returned Figure
object.
Maybe make a DataArray with 3 keypoints (left
, centre
, and right
) that all move in parallel lines, with centre
being the midpoint of the other two.
Then you can check that if you pass the left
and right
keypoints, the figure contains data that matches the centre
points? Likewise you can also check that if you only pass one of the keypoints, the data in the returned plot matches the raw data for that keypoint?
May be able to use some of the |
Possibly, if they're simple enough to compute and work with then definitely re-use them. My original idea was something like 3 "parallel lines" for the keypoints:
but if something like this already exists then go with that! |
Co-authored-by: Will Graham <[email protected]>
Co-authored-by: Will Graham <[email protected]>
Co-authored-by: Will Graham <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @stellaprins,
Thanks for starting the work on the plotting wrappers!
I dived deep into this one and left lots of comments. My rationale is that since this is the first of many such functions, it's important to get it right the first time. The patterns we establish here can be the easily generalised to the other wrappers as well.
Feel free to push back on any of my suggestions. I am quite opinionated about matplotlib and how it should be wrapped, but my opinions can be often subjective.
To make your life somewhat easier, I've written my own version of this function containing all my various suggestions, so you can see how they all come together.
Beware that I haven't tested it rigorously (I just made sure it runs in several places in the examples).
My version of plot.trajectory
def trajectory(
data: xr.DataArray,
individual: int | str = 0,
keypoint: None | str | list[str] = None,
image_path: None | Path = None,
title: str | None = None,
ax: plt.Axes | None = None,
**kwargs,
) -> tuple[plt.Figure, plt.Axes]:
"""Plot trajectory.
This function plots the trajectory of a specified keypoint or the centroid
between multiple keypoints for a given individual. The individual can be
specified by their index or name. By default, the first individual is
selected. The trajectory is colored by time (using the default colormap).
Pass a different colormap through ``cmap`` if desired.
Parameters
----------
data : xr.DataArray
A data array containing position information, with `time` and `space`
as required dimensions. Optionally, it may have `individuals` and/or
`keypoints` dimensions.
individual : int or str, default=0
Individual index or name. By default, the first individual is chosen.
If there is no `individuals` dimension, this argument is ignored.
keypoint : None, str or list of str, optional
- If None, the centroid of **all** keypoints is plotted (default).
- If str, that single keypoint's trajectory is plotted.
- If a list of keypoints, their centroid is plotted.
If there is no `keypoints` dimension, this argument is ignored.
image_path : None or Path, optional
Path to an image over which the trajectory data can be overlaid,
e.g., a reference video frame.
title : str or None, optional
Title of the plot. If not provided, one is generated based on
the individual's name (if applicable).
ax : matplotlib.axes.Axes or None, optional
Axes object on which to draw the trajectory. If None, a new
figure and axes are created.
**kwargs : dict
Additional keyword arguments passed to
``matplotlib.axes.Axes.scatter()``.
Returns
-------
(figure, axes) : tuple of (matplotlib.pyplot.Figure, matplotlib.axes.Axes)
The figure and axes containing the trajectory plot.
"""
# Construct selection dict for individuals and keypoints
selection = {}
# Determine which individual to select (if any)
if "individuals" in data.dims:
# Convert int index to actual individual name if needed
selection["individuals"] = chosen_ind = (
data.individuals.values[individual]
if isinstance(individual, int)
else str(individual)
)
title_suffix = f" of {chosen_ind}"
else:
title_suffix = ""
# Determine which keypoint(s) to select (if any)
if "keypoints" in data.dims:
if keypoint is None:
selection["keypoints"] = data.keypoints.values
elif isinstance(keypoint, str):
selection["keypoints"] = [keypoint]
elif isinstance(keypoint, list):
selection["keypoints"] = keypoint
# Select the data for the specified individual and keypoint(s)
plot_point = data.sel(**selection)
# If there are multiple selected keypoints, calculate the centroid
if "keypoints" in plot_point.dims and plot_point.sizes["keypoints"] > 1:
plot_point = plot_point.mean(dim="keypoints", skipna=True)
# Squeeze all dimensions with size 1 (only time and space should remain)
plot_point = plot_point.squeeze()
# Create a new Figure/Axes if none is passed
if ax is None:
fig, ax = plt.subplots(figsize=(6, 6))
else:
fig = ax.figure
# Merge default plotting args with user-provided kwargs
for key, value in DEFAULT_PLOTTING_ARGS.items():
kwargs.setdefault(key, value)
# Plot the scatter, coloring by time
sc = ax.scatter(
plot_point.sel(space="x"),
plot_point.sel(space="y"),
c=plot_point.time,
**kwargs,
)
# Handle axis labeling
space_unit = data.attrs.get("space_unit", "pixels")
ax.set_xlabel(f"x ({space_unit})")
ax.set_ylabel(f"y ({space_unit})")
ax.axis("equal")
# By default, invert y-axis so (0,0) is in the top-left,
# matching typical image coordinate systems
ax.invert_yaxis()
# Generate default title if none provided
if title is None:
title = f"Trajectory{title_suffix}"
ax.set_title(title)
# Add colorbar for time dimension
time_unit = data.attrs.get("time_unit")
time_label = f"time ({time_unit})" if time_unit else "time steps (frames)"
colorbar = fig.colorbar(sc, ax=ax, label=time_label)
# Ensure colorbar is fully opaque
colorbar.solids.set(alpha=1.0)
if image_path is not None:
frame = plt.imread(image_path)
# Invert the y-axis back again since the the image is plotted
# using a coordinate system with origin on the top left of the image
ax.invert_yaxis()
ax.imshow(frame)
return fig, ax
After updating the function, you'd have to modify the tests accordingly (I haven't commented on them directly) and also update the examples. Beware that the modified trajectory function can be used in multiple places in our examples, and I suggest going through all of them and using the opportunity yo do so where appropriate. Our users will benefit from seeing this function be used in various contexts.
Apologies for putting you through all this work. If you prefer, I can also take on the job of updating the examples when you are done with the function and tests. Let me know.
Hi @niksirbi!
No no. I really appreciate it! Totally agree with you it's better to get it right the first time and I am learning a lot from this, so thanks for all the comments. Will go through them and implement the changes, update tests, and then update the other plot wrapper PR #402 acoordingly.
I don't really have a strong opinion on matplotlib (yet?), but maybe @willGraham01 has?
Thanks! |
I've started a zulip thread to discuss general standards for wrapping matplotlib. |
My personal opinion is to just delegate everything that isn't the barebones functionality you want to wrap to some |
Agreed, so for this specific function, that would also mean not exposing the |
…atics-unit/movement into sp/282-plot-wrappers
…ion, remove title
…ion, remove title
…atics-unit/movement into sp/282-plot-wrappers
…ints Co-authored-by: Niko Sirmpilatze <[email protected]>
Co-authored-by: Will Graham <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we finally have it @stellaprins.
I left some minor comments, mostly about wording.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approved and ready to merge (after fixing my 2 remaining comments, which I had missed before).
|
Description
What is this PR
Why is this PR needed?
To conveniently plot trajectories of keypoints or midpoints between two keypoints (e.g.
left_ear
andright_ear
for head trajectory, see first two plots in computes polar coordinates example) with a single line of code.What does this PR do?
Introduces a
plot
module with atrajectory
function.References
#387
How has this PR been tested?
Checked the updated documentation locally.
Is this a breaking change?
No.
Does this PR require an update to the documentation?
Yes. I've updated the first part of the Express 2D vectors in polar coordinates example.
Checklist: