Skip to content

Commit a2db029

Browse files
More pickle tests...
1 parent f12640d commit a2db029

File tree

2 files changed

+70
-27
lines changed

2 files changed

+70
-27
lines changed

matplotview/tests/test_view_obj.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
11
import matplotlib.pyplot as plt
2-
import pickle
2+
from matplotview.tests.utils import plotting_test, matches_post_pickle
33
from matplotview import view, view_wrapper, inset_zoom_axes
44
import numpy as np
55

6-
def to_image(figure):
7-
figure.canvas.draw()
8-
img = np.frombuffer(figure.canvas.tostring_rgb(), dtype=np.uint8)
9-
return img.reshape(figure.canvas.get_width_height()[::-1] + (3,))
10-
116

127
def test_obj_comparison():
138
from matplotlib.axes import Subplot, Axes
@@ -21,12 +16,13 @@ def test_obj_comparison():
2116
assert view_class2 != view_class3
2217

2318

24-
def test_subplot_view_pickle():
19+
@plotting_test()
20+
def test_subplot_view_pickle(fig_test):
2521
np.random.seed(1)
2622
im_data = np.random.rand(30, 30)
2723

2824
# Test case...
29-
fig_test, (ax_test1, ax_test2) = plt.subplots(1, 2)
25+
ax_test1, ax_test2 = fig_test.subplots(1, 2)
3026

3127
ax_test1.plot([i for i in range(10)], "r")
3228
ax_test1.add_patch(plt.Circle((3, 3), 1, ec="black", fc="blue"))
@@ -38,24 +34,15 @@ def test_subplot_view_pickle():
3834
ax_test2.set_xlim(ax_test1.get_xlim())
3935
ax_test2.set_ylim(ax_test1.get_ylim())
4036

41-
img_expected = to_image(fig_test)
42-
43-
saved_fig = pickle.dumps(fig_test)
44-
plt.clf()
45-
46-
fig_test = pickle.loads(saved_fig)
47-
img_result = to_image(fig_test)
37+
assert matches_post_pickle(fig_test)
4838

49-
assert np.all(img_expected == img_result)
50-
51-
52-
def test_zoom_plot_pickle():
39+
@plotting_test()
40+
def test_zoom_plot_pickle(fig_test):
5341
np.random.seed(1)
54-
plt.clf()
5542
im_data = np.random.rand(30, 30)
43+
arrow_s = dict(arrowstyle="->")
5644

5745
# Test Case...
58-
fig_test = plt.gcf()
5946
ax_test = fig_test.gca()
6047
ax_test.plot([i for i in range(10)], "r")
6148
ax_test.add_patch(plt.Circle((3, 3), 1, ec="black", fc="blue"))
@@ -65,14 +52,29 @@ def test_zoom_plot_pickle():
6552
axins_test.set_linescaling(False)
6653
axins_test.set_xlim(1, 5)
6754
axins_test.set_ylim(1, 5)
55+
axins_test.annotate(
56+
"Interesting", (3, 3), (0, 0),
57+
textcoords="axes fraction", arrowprops=arrow_s
58+
)
6859
ax_test.indicate_inset_zoom(axins_test, edgecolor="black")
6960

70-
img_expected = to_image(fig_test)
61+
assert matches_post_pickle(fig_test)
62+
7163

72-
saved_fig = pickle.dumps(fig_test)
73-
plt.clf()
64+
@plotting_test()
65+
def test_3d_view_pickle(fig_test):
66+
X = Y = np.arange(-5, 5, 0.25)
67+
X, Y = np.meshgrid(X, Y)
68+
Z = np.sin(np.sqrt(X ** 2 + Y ** 2))
7469

75-
fig_test = pickle.loads(saved_fig)
76-
img_result = to_image(fig_test)
70+
ax1_test, ax2_test = fig_test.subplots(
71+
1, 2, subplot_kw=dict(projection="3d")
72+
)
73+
ax1_test.plot_surface(X, Y, Z, cmap="plasma")
74+
view(ax2_test, ax1_test)
75+
ax2_test.view_init(elev=80)
76+
ax2_test.set_xlim(-10, 10)
77+
ax2_test.set_ylim(-10, 10)
78+
ax2_test.set_zlim(-2, 2)
7779

78-
assert np.all(img_expected == img_result)
80+
assert matches_post_pickle(fig_test)

matplotview/tests/utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import functools
2+
3+
import numpy as np
4+
import matplotlib.pyplot as plt
5+
6+
7+
def figure_to_image(figure):
8+
figure.canvas.draw()
9+
img = np.frombuffer(figure.canvas.tostring_rgb(), dtype=np.uint8)
10+
return img.reshape(figure.canvas.get_width_height()[::-1] + (3,))
11+
12+
13+
def matches_post_pickle(figure):
14+
import pickle
15+
img_expected = figure_to_image(figure)
16+
17+
saved_fig = pickle.dumps(figure)
18+
plt.close("all")
19+
20+
figure = pickle.loads(saved_fig)
21+
img_result = figure_to_image(figure)
22+
23+
return np.all(img_expected == img_result)
24+
25+
26+
def plotting_test(num_figs = 1, *args, **kwargs):
27+
def plotting_decorator(function):
28+
def test_plotting():
29+
plt.close("all")
30+
res = function(
31+
*(plt.figure(*args, **kwargs) for __ in range(num_figs))
32+
)
33+
plt.close("all")
34+
return res
35+
36+
return test_plotting
37+
38+
return plotting_decorator
39+
40+
41+

0 commit comments

Comments
 (0)