Skip to content

Plotting wrappers: Head Trajectory #394

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 37 commits into from
Feb 12, 2025
Merged

Plotting wrappers: Head Trajectory #394

merged 37 commits into from
Feb 12, 2025

Conversation

stellaprins
Copy link
Contributor

@stellaprins stellaprins commented Jan 28, 2025

Description

What is this PR

  • Bug fix
  • Addition of a new feature
  • Other

Why is this PR needed?
To conveniently plot trajectories of keypoints or midpoints between two keypoints (e.g. left_ear and right_ear for head trajectory, see first two plots in computes polar coordinates example) with a single line of code.

What does this PR do?
Introduces a plot module with a trajectory function.

References

#387

How has this PR been tested?

Checked the updated documentation locally.

Is this a breaking change?

No.

Does this PR require an update to the documentation?

Yes. I've updated the first part of the Express 2D vectors in polar coordinates example.

Checklist:

  • The code has been tested locally
  • Tests have been added to cover all new functionality
  • The documentation has been updated to reflect any changes
  • The code has been formatted with pre-commit

Copy link

codecov bot commented Jan 28, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.83%. Comparing base (55b2561) to head (cd2bc99).
Report is 1 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #394   +/-   ##
=======================================
  Coverage   99.82%   99.83%           
=======================================
  Files          20       22    +2     
  Lines        1169     1207   +38     
=======================================
+ Hits         1167     1205   +38     
  Misses          2        2           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@willGraham01 willGraham01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some stylistic comments here from me.

As for testing, I'll copy what I put in our chat here so it's visible for everyone:

You can't really test "images" per se, so I'd follow the advice of the StackOverflow hive mind like in this article here: use trajectory to create a figure on some (small) custom dataset, then examine the data stored in the returned Figure object.

Maybe make a DataArray with 3 keypoints (left, centre, and right) that all move in parallel lines, with centre being the midpoint of the other two.
Then you can check that if you pass the left and right keypoints, the figure contains data that matches the centre points? Likewise you can also check that if you only pass one of the keypoints, the data in the returned plot matches the raw data for that keypoint?

@stellaprins
Copy link
Contributor Author

Maybe make a DataArray with 3 keypoints (left, centre, and right) that all move in parallel lines, with centre being the midpoint of the other two. Then you can check that if you pass the left and right keypoints, the figure contains data that matches the centre points? Likewise you can also check that if you only pass one of the keypoints, the data in the returned plot matches the raw data for that keypoint?

May be able to use some of the valid_poses_dataset_ fixtures that have left, right, and centroid keypoints.

@willGraham01
Copy link
Contributor

May be able to use some of the valid_poses_dataset_ fixtures that have left, right, and centroid keypoints.

Possibly, if they're simple enough to compute and work with then definitely re-use them.

My original idea was something like 3 "parallel lines" for the keypoints:

left: (-1,0) -> (-1,1) -> (-1,2) -> (-1,3) -> (-1,4)
centre: (0,0) -> (0,1) -> (0,2) -> (0,3) -> (0,4)
right: (1,0) -> (1,1) -> (1,2) -> (1,3) -> (1,4)

but if something like this already exists then go with that!

@stellaprins stellaprins marked this pull request as ready for review January 31, 2025 12:13
@stellaprins stellaprins requested a review from niksirbi January 31, 2025 12:14
Copy link
Member

@niksirbi niksirbi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @stellaprins,

Thanks for starting the work on the plotting wrappers!
I dived deep into this one and left lots of comments. My rationale is that since this is the first of many such functions, it's important to get it right the first time. The patterns we establish here can be the easily generalised to the other wrappers as well.

Feel free to push back on any of my suggestions. I am quite opinionated about matplotlib and how it should be wrapped, but my opinions can be often subjective.

To make your life somewhat easier, I've written my own version of this function containing all my various suggestions, so you can see how they all come together.
Beware that I haven't tested it rigorously (I just made sure it runs in several places in the examples).

My version of plot.trajectory
def trajectory(
    data: xr.DataArray,
    individual: int | str = 0,
    keypoint: None | str | list[str] = None,
    image_path: None | Path = None,
    title: str | None = None,
    ax: plt.Axes | None = None,
    **kwargs,
) -> tuple[plt.Figure, plt.Axes]:
    """Plot trajectory.

    This function plots the trajectory of a specified keypoint or the centroid
    between multiple keypoints for a given individual. The individual can be
    specified by their index or name. By default, the first individual is
    selected. The trajectory is colored by time (using the default colormap).
    Pass a different colormap through ``cmap`` if desired.

    Parameters
    ----------
    data : xr.DataArray
        A data array containing position information, with `time` and `space`
        as required dimensions. Optionally, it may have `individuals` and/or
        `keypoints` dimensions.
    individual : int or str, default=0
        Individual index or name. By default, the first individual is chosen.
        If there is no `individuals` dimension, this argument is ignored.
    keypoint : None, str or list of str, optional
        - If None, the centroid of **all** keypoints is plotted (default).
        - If str, that single keypoint's trajectory is plotted.
        - If a list of keypoints, their centroid is plotted.
        If there is no `keypoints` dimension, this argument is ignored.
    image_path : None or Path, optional
        Path to an image over which the trajectory data can be overlaid,
        e.g., a reference video frame.
    title : str or None, optional
        Title of the plot. If not provided, one is generated based on
        the individual's name (if applicable).
    ax : matplotlib.axes.Axes or None, optional
        Axes object on which to draw the trajectory. If None, a new
        figure and axes are created.
    **kwargs : dict
        Additional keyword arguments passed to
        ``matplotlib.axes.Axes.scatter()``.

    Returns
    -------
    (figure, axes) : tuple of (matplotlib.pyplot.Figure, matplotlib.axes.Axes)
        The figure and axes containing the trajectory plot.

    """
    # Construct selection dict for individuals and keypoints
    selection = {}

    # Determine which individual to select (if any)
    if "individuals" in data.dims:
        # Convert int index to actual individual name if needed
        selection["individuals"] = chosen_ind = (
            data.individuals.values[individual]
            if isinstance(individual, int)
            else str(individual)
        )
        title_suffix = f" of {chosen_ind}"
    else:
        title_suffix = ""

    # Determine which keypoint(s) to select (if any)
    if "keypoints" in data.dims:
        if keypoint is None:
            selection["keypoints"] = data.keypoints.values
        elif isinstance(keypoint, str):
            selection["keypoints"] = [keypoint]
        elif isinstance(keypoint, list):
            selection["keypoints"] = keypoint

    # Select the data for the specified individual and keypoint(s)
    plot_point = data.sel(**selection)

    # If there are multiple selected keypoints, calculate the centroid
    if "keypoints" in plot_point.dims and plot_point.sizes["keypoints"] > 1:
        plot_point = plot_point.mean(dim="keypoints", skipna=True)

    # Squeeze all dimensions with size 1 (only time and space should remain)
    plot_point = plot_point.squeeze()

    # Create a new Figure/Axes if none is passed
    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 6))
    else:
        fig = ax.figure

    # Merge default plotting args with user-provided kwargs
    for key, value in DEFAULT_PLOTTING_ARGS.items():
        kwargs.setdefault(key, value)

    # Plot the scatter, coloring by time
    sc = ax.scatter(
        plot_point.sel(space="x"),
        plot_point.sel(space="y"),
        c=plot_point.time,
        **kwargs,
    )

    # Handle axis labeling
    space_unit = data.attrs.get("space_unit", "pixels")
    ax.set_xlabel(f"x ({space_unit})")
    ax.set_ylabel(f"y ({space_unit})")
    ax.axis("equal")

    # By default, invert y-axis so (0,0) is in the top-left,
    # matching typical image coordinate systems
    ax.invert_yaxis()

    # Generate default title if none provided
    if title is None:
        title = f"Trajectory{title_suffix}"
    ax.set_title(title)

    # Add colorbar for time dimension
    time_unit = data.attrs.get("time_unit")
    time_label = f"time ({time_unit})" if time_unit else "time steps (frames)"
    colorbar = fig.colorbar(sc, ax=ax, label=time_label)
    # Ensure colorbar is fully opaque
    colorbar.solids.set(alpha=1.0)

    if image_path is not None:
        frame = plt.imread(image_path)
        # Invert the y-axis back again since the the image is plotted
        # using a coordinate system with origin on the top left of the image
        ax.invert_yaxis()
        ax.imshow(frame)

    return fig, ax

After updating the function, you'd have to modify the tests accordingly (I haven't commented on them directly) and also update the examples. Beware that the modified trajectory function can be used in multiple places in our examples, and I suggest going through all of them and using the opportunity yo do so where appropriate. Our users will benefit from seeing this function be used in various contexts.

Apologies for putting you through all this work. If you prefer, I can also take on the job of updating the examples when you are done with the function and tests. Let me know.

@stellaprins
Copy link
Contributor Author

Hi @niksirbi!

Apologies for putting you through all this work.

No no. I really appreciate it! Totally agree with you it's better to get it right the first time and I am learning a lot from this, so thanks for all the comments. Will go through them and implement the changes, update tests, and then update the other plot wrapper PR #402 acoordingly.

Feel free to push back on any of my suggestions. I am quite opinionated about matplotlib and how it should be wrapped, but my opinions can be often subjective.

I don't really have a strong opinion on matplotlib (yet?), but maybe @willGraham01 has?

To make your life somewhat easier, I've written my own version of this function containing all my various suggestions, so you can see how they all come together.

Thanks!

@niksirbi
Copy link
Member

niksirbi commented Feb 4, 2025

I've started a zulip thread to discuss general standards for wrapping matplotlib.

@willGraham01
Copy link
Contributor

I don't really have a strong opinion on matplotlib (yet?), but maybe @willGraham01 has?

My personal opinion is to just delegate everything that isn't the barebones functionality you want to wrap to some kwargs that you can hand to matplotlib's backend. But I'll move this to the Zulip thread.

@niksirbi
Copy link
Member

niksirbi commented Feb 4, 2025

My personal opinion is to just delegate everything that isn't the barebones functionality you want to wrap to some kwargs that you can hand to matplotlib's backend. But I'll move this to the Zulip thread.

Agreed, so for this specific function, that would also mean not exposing the title parameter.

@stellaprins stellaprins requested a review from niksirbi February 8, 2025 11:09
Copy link
Member

@niksirbi niksirbi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we finally have it @stellaprins.

I left some minor comments, mostly about wording.

Copy link
Member

@niksirbi niksirbi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved and ready to merge (after fixing my 2 remaining comments, which I had missed before).

Copy link

@stellaprins stellaprins added this pull request to the merge queue Feb 12, 2025
Merged via the queue into main with commit 6db50cc Feb 12, 2025
18 checks passed
@niksirbi niksirbi deleted the sp/282-plot-wrappers branch February 12, 2025 10:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants