Skip to content

Commit 1c63024

Browse files
committed
simpler refresh reference files
1 parent c834d7c commit 1c63024

File tree

1 file changed

+27
-20
lines changed

1 file changed

+27
-20
lines changed

test/refresh_reference_files.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,28 +8,35 @@
88

99

1010
def _main():
11-
parser = argparse.ArgumentParser(description="Refresh the reference TeX files.")
12-
parser.add_argument("files", nargs="+", help="Files to refresh")
13-
args = parser.parse_args()
11+
parser = argparse.ArgumentParser(description="Refresh all reference TeX files.")
12+
parser.parse_args()
1413

1514
this_dir = os.path.dirname(os.path.abspath(__file__))
16-
exclude_list = ["test_rotated_labels.py", "test_deterministic_output.py"]
17-
18-
for filename in args.files:
19-
if filename in exclude_list:
20-
continue
21-
if filename.startswith("test_") and filename.endswith(".py"):
22-
spec = importlib.util.spec_from_file_location("plot", filename)
23-
module = importlib.util.module_from_spec(spec)
24-
spec.loader.exec_module(module)
25-
module.plot()
26-
27-
code = tpl.get_tikz_code(include_disclaimer=False, float_format=".8g")
28-
plt.close()
29-
30-
tex_filename = filename[:-3] + "_reference.tex"
31-
with open(os.path.join(this_dir, tex_filename), "w", encoding="utf8") as f:
32-
f.write(code)
15+
16+
test_files = [
17+
f
18+
for f in os.listdir(this_dir)
19+
if os.path.isfile(os.path.join(this_dir, f))
20+
and f[:5] == "test_"
21+
and f[-3:] == ".py"
22+
]
23+
test_modules = [f[:-3] for f in test_files]
24+
25+
# remove some edge cases
26+
test_modules.remove("test_rotated_labels")
27+
test_modules.remove("test_deterministic_output")
28+
test_modules.remove("test_cleanfigure")
29+
30+
for mod in test_modules:
31+
module = importlib.import_module(mod)
32+
module.plot()
33+
34+
code = tpl.get_tikz_code(include_disclaimer=False, float_format=".8g")
35+
plt.close()
36+
37+
tex_filename = mod + "_reference.tex"
38+
with open(os.path.join(this_dir, tex_filename), "w", encoding="utf8") as f:
39+
f.write(code)
3340

3441

3542
if __name__ == "__main__":

0 commit comments

Comments
 (0)