Skip to content

Support exporting animations via animateinline #614

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 104 additions & 30 deletions src/tikzplotlib/_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
import enum
import tempfile
import warnings
from typing import Literal
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.figure as figure

from . import _axes
from . import _image as img
Expand All @@ -17,7 +20,7 @@


def get_tikz_code(
figure="gcf",
figure: Literal["gcf"] | figure.Figure | animation.TimedAnimation = "gcf",
filepath: str | Path | None = None,
axis_width: str | None = None,
axis_height: str | None = None,
Expand All @@ -32,6 +35,11 @@ def get_tikz_code(
extra_axis_parameters: list | set | None = None,
extra_groupstyle_parameters: dict = {},
extra_tikzpicture_parameters: list | set | None = None,
extra_animation_parameters: list | set | None = [
"autoplay",
"autoresume",
"controls",
],
extra_lines_start: list | set | None = None,
dpi: int | None = None,
show_info: bool = False,
Expand All @@ -44,7 +52,7 @@ def get_tikz_code(
"""Main function. Here, the recursion into the image starts and the
contents are picked up. The actual file gets written in this routine.

:param figure: either a Figure object or 'gcf' (default).
:param figure: either a Figure object or an animation or 'gcf' (default).

:param axis_width: If not ``None``, this will be used as figure width within the
TikZ/PGFPlots output. If ``axis_height`` is not given,
Expand Down Expand Up @@ -109,6 +117,10 @@ def get_tikz_code(
(as a set) to pgfplots.
:type extra_tikzpicture_parameters: a set of strings for the pfgplots tikzpicture.

:param extra_animation_parameters: Extra animation options to be passed
(as a set) to animateinline when an animation is passed.
:type extra_animation_parameters: a set of strings for the animateinline animation.

:param dpi: The resolution in dots per inch of the rendered image in case
of QuadMesh plots. If ``None`` it will default to the value
``savefig.dpi`` from matplotlib.rcParams. Default is ``None``.
Expand Down Expand Up @@ -150,6 +162,7 @@ def get_tikz_code(
if figure == "gcf":
figure = plt.gcf()
data = {}
data["animation"] = isinstance(figure, animation.Animation)
data["axis width"] = axis_width
data["axis height"] = axis_height
data["rel data path"] = (
Expand Down Expand Up @@ -209,39 +222,89 @@ def get_tikz_code(
if show_info:
_print_pgfplot_libs_message(data)

# gather the file content
data, content = _recurse(data, figure)

# Check if there is still an open groupplot environment. This occurs if not
# all of the group plot slots are used.
if "is_in_groupplot_env" in data and data["is_in_groupplot_env"]:
content.extend(data["flavor"].end("groupplot") + "\n\n")
def get_figure_tikz_code(
data,
figure,
wrap: bool = wrap,
include_disclaimer: bool = include_disclaimer,
):
# gather the file content
data, content = _recurse(data, figure)

# Check if there is still an open groupplot environment. This occurs if not
# all of the group plot slots are used.
if "is_in_groupplot_env" in data and data["is_in_groupplot_env"]:
content.extend(data["flavor"].end("groupplot") + "\n\n")
data["is_in_groupplot_env"] = False

code = """"""

if include_disclaimer:
disclaimer = f"This file was created with tikzplotlib v{__version__}."
code += _tex_comment(disclaimer)

# write the contents
if wrap and add_axis_environment:
code += data["flavor"].start("tikzpicture")
if extra_tikzpicture_parameters:
code += "[\n" + ",\n".join(extra_tikzpicture_parameters) + "\n]"
code += "\n"
if extra_lines_start:
code += "\n".join(extra_lines_start) + "\n"
code += "\n"

coldefs = _get_color_definitions(data)
if coldefs:
code += "\n".join(coldefs) + "\n\n"

code += "".join(content)

if wrap and add_axis_environment:
code += data["flavor"].end("tikzpicture") + "\n"

return data, content, code

if isinstance(figure, animation.TimedAnimation):
extra_animation_parameters = list(extra_animation_parameters or [])
if figure._repeat and "loop" not in extra_animation_parameters:
extra_animation_parameters.append("loop")

data["framerate"] = 1000 / figure._interval

frames = []

for frame in figure.new_frame_seq():
figure._draw_frame(frame)
data, content, code = get_figure_tikz_code(
data,
figure._fig,
wrap=True,
include_disclaimer=False,
)
frames.append(f"% Frame {frame + 1}\n{code}\n")

# write disclaimer to the file header
code = """"""
code = """"""

if include_disclaimer:
disclaimer = f"This file was created with tikzplotlib v{__version__}."
code += _tex_comment(disclaimer)
if include_disclaimer:
disclaimer = f"This file was created with tikzplotlib v{__version__}."
code += _tex_comment(disclaimer)

# write the contents
if wrap and add_axis_environment:
code += data["flavor"].start("tikzpicture")
if extra_tikzpicture_parameters:
code += "[\n" + ",\n".join(extra_tikzpicture_parameters) + "\n]"
code += "\n"
if extra_lines_start:
code += "\n".join(extra_lines_start) + "\n"
code += "\n"
# write the contents
if wrap:
code += data["flavor"].start("animateinline")
if extra_animation_parameters:
code += "[\n" + ",\n".join(extra_animation_parameters) + "\n]"
code += f"{{{data['framerate']}}}"
code += "\n"
code += "\n"

coldefs = _get_color_definitions(data)
if coldefs:
code += "\n".join(coldefs) + "\n\n"
code += "\n\\newframe\n".join(frames)

code += "".join(content)
if wrap:
code += data["flavor"].end("animateinline") + "\n"

if wrap and add_axis_environment:
code += data["flavor"].end("tikzpicture") + "\n"
else:
data, content, code = get_figure_tikz_code(data, figure)

if standalone:
# When using pdflatex, \\DeclareUnicodeCharacter is necessary.
Expand Down Expand Up @@ -406,6 +469,7 @@ class Flavors(enum.Enum):
\\usetikzlibrary{{{tikzlibs}}}
\\pgfplotsset{{compat=newest}}
""",
r"\usepackage{{{}}}",
)
context = (
r"\start{}",
Expand All @@ -422,6 +486,7 @@ class Flavors(enum.Enum):
\\unexpanded\\def\\startgroupplot{{\\groupplot}}
\\unexpanded\\def\\stopgroupplot{{\\endgroupplot}}
""",
r"\usemodule[{}]",
)

def start(self, what):
Expand All @@ -430,6 +495,9 @@ def start(self, what):
def end(self, what):
return self.value[1].format(what)

def usepackage(self, *what):
return self.value[4]

def preamble(self, data=None):
if data is None:
data = {
Expand All @@ -438,7 +506,13 @@ def preamble(self, data=None):
}
pgfplotslibs = ",".join(data["pgfplots libs"])
tikzlibs = ",".join(data["tikz libs"])
return self.value[3].format(pgfplotslibs=pgfplotslibs, tikzlibs=tikzlibs)
extra_imports = (
self.usepackage("animate") + "\n" if data.get("animation", False) else ""
)
return (
self.value[3].format(pgfplotslibs=pgfplotslibs, tikzlibs=tikzlibs)
+ extra_imports
)

def standalone(self, code):
docenv = self.value[2]
Expand Down
28 changes: 28 additions & 0 deletions tests/test_animation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#%%
def plot():
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.animation as animation

fig, ax = plt.subplots()
scat = ax.scatter(range(10), [0] * 10)
ax.set_xlim(0, 9)
ax.set_ylim(0, 10)

def update(frame):
y_data = [yi + frame * 0.2 for yi in range(10)]
scat.set_offsets(list(zip(range(10), y_data)))
ax.set_title(f"Frame {frame+1}/20")

ani = animation.FuncAnimation(fig, update, frames=20, repeat=False)
return ani


def test():
from .helpers import assert_equality

assert_equality(plot, __file__[:-3] + "_reference.tex")

# %%
ani = plot()
# %%
Loading