Skip to content

Commit b4361af

Browse files
committed
Handle animation export
1 parent 450712b commit b4361af

File tree

3 files changed

+825
-29
lines changed

3 files changed

+825
-29
lines changed

src/tikzplotlib/_save.py

+109-29
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@
33
import enum
44
import tempfile
55
import warnings
6+
from typing import Literal
67
from pathlib import Path
78

89
import matplotlib as mpl
910
import matplotlib.pyplot as plt
11+
import matplotlib.animation as animation
12+
import matplotlib.figure as figure
1013

1114
from . import _axes
1215
from . import _image as img
@@ -17,7 +20,7 @@
1720

1821

1922
def get_tikz_code(
20-
figure="gcf",
23+
figure: Literal["gcf"] | figure.Figure | animation.TimedAnimation = "gcf",
2124
filepath: str | Path | None = None,
2225
axis_width: str | None = None,
2326
axis_height: str | None = None,
@@ -32,6 +35,11 @@ def get_tikz_code(
3235
extra_axis_parameters: list | set | None = None,
3336
extra_groupstyle_parameters: dict = {},
3437
extra_tikzpicture_parameters: list | set | None = None,
38+
extra_animation_parameters: list | set | None = [
39+
"autoplay",
40+
"autoresume",
41+
"controls",
42+
],
3543
extra_lines_start: list | set | None = None,
3644
dpi: int | None = None,
3745
show_info: bool = False,
@@ -44,7 +52,7 @@ def get_tikz_code(
4452
"""Main function. Here, the recursion into the image starts and the
4553
contents are picked up. The actual file gets written in this routine.
4654
47-
:param figure: either a Figure object or 'gcf' (default).
55+
:param figure: either a Figure object or an animation or 'gcf' (default).
4856
4957
:param axis_width: If not ``None``, this will be used as figure width within the
5058
TikZ/PGFPlots output. If ``axis_height`` is not given,
@@ -109,6 +117,10 @@ def get_tikz_code(
109117
(as a set) to pgfplots.
110118
:type extra_tikzpicture_parameters: a set of strings for the pfgplots tikzpicture.
111119
120+
:param extra_animation_parameters: Extra animation options to be passed
121+
(as a set) to animateinline when an animation is passed.
122+
:type extra_animation_parameters: a set of strings for the animateinline animation.
123+
112124
:param dpi: The resolution in dots per inch of the rendered image in case
113125
of QuadMesh plots. If ``None`` it will default to the value
114126
``savefig.dpi`` from matplotlib.rcParams. Default is ``None``.
@@ -150,6 +162,7 @@ def get_tikz_code(
150162
if figure == "gcf":
151163
figure = plt.gcf()
152164
data = {}
165+
data["animation"] = isinstance(figure, animation.Animation)
153166
data["axis width"] = axis_width
154167
data["axis height"] = axis_height
155168
data["rel data path"] = (
@@ -209,39 +222,95 @@ def get_tikz_code(
209222
if show_info:
210223
_print_pgfplot_libs_message(data)
211224

212-
# gather the file content
213-
data, content = _recurse(data, figure)
225+
def get_figure_tikz_code(
226+
data,
227+
figure,
228+
wrap: bool = wrap,
229+
include_disclaimer: bool = include_disclaimer,
230+
include_colordefs: bool = True,
231+
):
232+
# gather the file content
233+
data, content = _recurse(data, figure)
234+
235+
# Check if there is still an open groupplot environment. This occurs if not
236+
# all of the group plot slots are used.
237+
if "is_in_groupplot_env" in data and data["is_in_groupplot_env"]:
238+
content.extend(data["flavor"].end("groupplot") + "\n\n")
239+
data["is_in_groupplot_env"] = False
240+
241+
code = """"""
242+
243+
if include_disclaimer:
244+
disclaimer = f"This file was created with tikzplotlib v{__version__}."
245+
code += _tex_comment(disclaimer)
246+
247+
# write the contents
248+
if wrap and add_axis_environment:
249+
code += data["flavor"].start("tikzpicture")
250+
if extra_tikzpicture_parameters:
251+
code += "[\n" + ",\n".join(extra_tikzpicture_parameters) + "\n]"
252+
code += "\n"
253+
if extra_lines_start:
254+
code += "\n".join(extra_lines_start) + "\n"
255+
code += "\n"
256+
257+
coldefs = _get_color_definitions(data)
258+
if coldefs and include_colordefs:
259+
code += "\n".join(coldefs) + "\n\n"
260+
261+
code += "".join(content)
262+
263+
if wrap and add_axis_environment:
264+
code += data["flavor"].end("tikzpicture") + "\n"
265+
266+
return data, content, code
267+
268+
if isinstance(figure, animation.TimedAnimation):
269+
extra_animation_parameters = list(extra_animation_parameters or [])
270+
if figure._repeat and "loop" not in extra_animation_parameters:
271+
extra_animation_parameters.append("loop")
272+
273+
data["framerate"] = 1000 / figure._interval
274+
275+
frames = []
276+
277+
for frame in figure.new_frame_seq():
278+
figure._draw_frame(frame)
279+
data, content, code = get_figure_tikz_code(
280+
data,
281+
figure._fig,
282+
wrap=True,
283+
include_disclaimer=False,
284+
include_colordefs=False,
285+
)
286+
frames.append(f"% Frame {frame + 1}\n{code}\n")
214287

215-
# Check if there is still an open groupplot environment. This occurs if not
216-
# all of the group plot slots are used.
217-
if "is_in_groupplot_env" in data and data["is_in_groupplot_env"]:
218-
content.extend(data["flavor"].end("groupplot") + "\n\n")
288+
code = """"""
219289

220-
# write disclaimer to the file header
221-
code = """"""
290+
if include_disclaimer:
291+
disclaimer = f"This file was created with tikzplotlib v{__version__}."
292+
code += _tex_comment(disclaimer)
222293

223-
if include_disclaimer:
224-
disclaimer = f"This file was created with tikzplotlib v{__version__}."
225-
code += _tex_comment(disclaimer)
294+
# write the contents
295+
if wrap:
296+
code += data["flavor"].start("animateinline")
297+
if extra_animation_parameters:
298+
code += "[\n" + ",\n".join(extra_animation_parameters) + "\n]"
299+
code += f"{{{data['framerate']}}}"
300+
code += "\n"
301+
code += "\n"
226302

227-
# write the contents
228-
if wrap and add_axis_environment:
229-
code += data["flavor"].start("tikzpicture")
230-
if extra_tikzpicture_parameters:
231-
code += "[\n" + ",\n".join(extra_tikzpicture_parameters) + "\n]"
232-
code += "\n"
233-
if extra_lines_start:
234-
code += "\n".join(extra_lines_start) + "\n"
235-
code += "\n"
303+
coldefs = _get_color_definitions(data)
304+
if coldefs:
305+
code += "\n".join(coldefs) + "\n\n"
236306

237-
coldefs = _get_color_definitions(data)
238-
if coldefs:
239-
code += "\n".join(coldefs) + "\n\n"
307+
code += "\n\\newframe\n".join(frames)
240308

241-
code += "".join(content)
309+
if wrap:
310+
code += data["flavor"].end("animateinline") + "\n"
242311

243-
if wrap and add_axis_environment:
244-
code += data["flavor"].end("tikzpicture") + "\n"
312+
else:
313+
data, content, code = get_figure_tikz_code(data, figure)
245314

246315
if standalone:
247316
# When using pdflatex, \\DeclareUnicodeCharacter is necessary.
@@ -406,6 +475,7 @@ class Flavors(enum.Enum):
406475
\\usetikzlibrary{{{tikzlibs}}}
407476
\\pgfplotsset{{compat=newest}}
408477
""",
478+
r"\usepackage{{{}}}",
409479
)
410480
context = (
411481
r"\start{}",
@@ -422,6 +492,7 @@ class Flavors(enum.Enum):
422492
\\unexpanded\\def\\startgroupplot{{\\groupplot}}
423493
\\unexpanded\\def\\stopgroupplot{{\\endgroupplot}}
424494
""",
495+
r"\usemodule[{}]",
425496
)
426497

427498
def start(self, what):
@@ -430,6 +501,9 @@ def start(self, what):
430501
def end(self, what):
431502
return self.value[1].format(what)
432503

504+
def usepackage(self, *what):
505+
return self.value[4]
506+
433507
def preamble(self, data=None):
434508
if data is None:
435509
data = {
@@ -438,7 +512,13 @@ def preamble(self, data=None):
438512
}
439513
pgfplotslibs = ",".join(data["pgfplots libs"])
440514
tikzlibs = ",".join(data["tikz libs"])
441-
return self.value[3].format(pgfplotslibs=pgfplotslibs, tikzlibs=tikzlibs)
515+
extra_imports = (
516+
self.usepackage("animate") + "\n" if data.get("animation", False) else ""
517+
)
518+
return (
519+
self.value[3].format(pgfplotslibs=pgfplotslibs, tikzlibs=tikzlibs)
520+
+ extra_imports
521+
)
442522

443523
def standalone(self, code):
444524
docenv = self.value[2]

tests/test_animation.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#%%
2+
def plot():
3+
import numpy as np
4+
from matplotlib import pyplot as plt
5+
import matplotlib.animation as animation
6+
7+
fig, ax = plt.subplots()
8+
scat = ax.scatter(range(10), [0] * 10)
9+
ax.set_xlim(0, 9)
10+
ax.set_ylim(0, 10)
11+
12+
def update(frame):
13+
y_data = [yi + frame * 0.2 for yi in range(10)]
14+
scat.set_offsets(list(zip(range(10), y_data)))
15+
ax.set_title(f"Frame {frame+1}/20")
16+
17+
ani = animation.FuncAnimation(fig, update, frames=20, repeat=False)
18+
return ani
19+
20+
21+
def test():
22+
from .helpers import assert_equality
23+
24+
assert_equality(plot, __file__[:-3] + "_reference.tex")
25+
26+
# %%
27+
ani = plot()
28+
# %%

0 commit comments

Comments
 (0)