1
1
import argparse
2
2
import copy
3
- import multiprocess
4
3
import os
5
4
import shutil
6
5
import string
9
8
from functools import partial
10
9
from pathlib import Path
11
10
11
+ import multiprocess
12
12
import nbformat
13
13
from nbconvert .preprocessors import ExecutePreprocessor
14
14
20
20
TRANS_TABLE = str .maketrans (dict .fromkeys (string .whitespace ))
21
21
22
22
23
- def inject_sst ():
23
+ def inject_shared_download ():
24
24
delim = "&" if os .name == "nt" else ";"
25
25
subprocess .call (
26
26
delim .join ([INSTALL_SOURCE_VERSION_COMMAND [4 :], INSTALL_SST_COMMAND ]),
@@ -108,7 +108,7 @@ def check_notebook_output(notebook_path, env="python3", ignore_whitespace=False)
108
108
new_cell_stdout_ = new_cell_stdout
109
109
110
110
if ignore_whitespace :
111
- original_cell = original_cell_stdout .translate (TRANS_TABLE )
111
+ original_cell_stdout = original_cell_stdout .translate (TRANS_TABLE )
112
112
new_cell_stdout = new_cell_stdout .translate (TRANS_TABLE )
113
113
else :
114
114
if new_cell_stdout [- 1 ] == "\n " and original_cell_stdout [- 1 ] != "\n " :
@@ -150,8 +150,8 @@ def check_notebook_output(notebook_path, env="python3", ignore_whitespace=False)
150
150
report = check_notebook_output (notebook_path , env = args .env , ignore_whitespace = args .ignore_whitespace )
151
151
reports .append (report )
152
152
else :
153
- # inject the SST dataset to prevent parallel download
154
- inject_sst ()
153
+ # predownload datasets/vectorizers to prevent parallel download
154
+ inject_shared_download ()
155
155
with multiprocess .Pool (num_proc ) as pool :
156
156
reports = pool .map (partial (check_notebook_output , env = args .env , ignore_whitespace = args .ignore_whitespace ), notebook_paths )
157
157
@@ -168,7 +168,7 @@ def check_notebook_output(notebook_path, env="python3", ignore_whitespace=False)
168
168
for i , original_output , new_output in report ),
169
169
" " * 4 ,
170
170
)
171
- for notebook , report in reports
171
+ for notebook , report in reports if len ( report ) > 0
172
172
])
173
173
raise Exception (
174
174
"❌❌ Mismatches found in the outputs of the notebooks:\n \n " + reports_str
0 commit comments