Skip to content

Commit 04e1c25

Browse files
committed
Fix comparison
1 parent ba3a9f8 commit 04e1c25

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

docs/source/scripts/check_notebooks.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import argparse
22
import copy
3-
import multiprocess
43
import os
54
import shutil
65
import string
@@ -9,6 +8,7 @@
98
from functools import partial
109
from pathlib import Path
1110

11+
import multiprocess
1212
import nbformat
1313
from nbconvert.preprocessors import ExecutePreprocessor
1414

@@ -20,7 +20,7 @@
2020
TRANS_TABLE = str.maketrans(dict.fromkeys(string.whitespace))
2121

2222

23-
def inject_sst():
23+
def inject_shared_download():
2424
delim = "&" if os.name == "nt" else ";"
2525
subprocess.call(
2626
delim.join([INSTALL_SOURCE_VERSION_COMMAND[4:], INSTALL_SST_COMMAND]),
@@ -108,7 +108,7 @@ def check_notebook_output(notebook_path, env="python3", ignore_whitespace=False)
108108
new_cell_stdout_ = new_cell_stdout
109109

110110
if ignore_whitespace:
111-
original_cell = original_cell_stdout.translate(TRANS_TABLE)
111+
original_cell_stdout = original_cell_stdout.translate(TRANS_TABLE)
112112
new_cell_stdout = new_cell_stdout.translate(TRANS_TABLE)
113113
else:
114114
if new_cell_stdout[-1] == "\n" and original_cell_stdout[-1] != "\n":
@@ -150,8 +150,8 @@ def check_notebook_output(notebook_path, env="python3", ignore_whitespace=False)
150150
report = check_notebook_output(notebook_path, env=args.env, ignore_whitespace=args.ignore_whitespace)
151151
reports.append(report)
152152
else:
153-
# inject the SST dataset to prevent parallel download
154-
inject_sst()
153+
# predownload datasets/vectorizers to prevent parallel download
154+
inject_shared_download()
155155
with multiprocess.Pool(num_proc) as pool:
156156
reports = pool.map(partial(check_notebook_output, env=args.env, ignore_whitespace=args.ignore_whitespace), notebook_paths)
157157

@@ -168,7 +168,7 @@ def check_notebook_output(notebook_path, env="python3", ignore_whitespace=False)
168168
for i, original_output, new_output in report),
169169
" " * 4,
170170
)
171-
for notebook, report in reports
171+
for notebook, report in reports if len(report) > 0
172172
])
173173
raise Exception(
174174
"❌❌ Mismatches found in the outputs of the notebooks:\n\n" + reports_str

0 commit comments

Comments
 (0)