Skip to content

Commit a5e0738

Browse files
JackCaoGroot
and
root
authored
Truncate python stack when outputting the frames that cause the graph execution (pytorch#5933)
* Truncate python stack when outputting frame that cause the graph execution * add mp tests * move tests to a new dir --------- Co-authored-by: root <[email protected]>
1 parent 402166b commit a5e0738

File tree

4 files changed

+79
-4
lines changed

4 files changed

+79
-4
lines changed
+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import os
2+
3+
import torch
4+
import torch_xla.core.xla_model as xm
5+
import torch_xla.distributed.xla_multiprocessing as xmp
6+
7+
8+
def check_env_flag(name, default=''):
    """Return True when environment variable `name` holds an affirmative value.

    Args:
      name: name of the environment variable to read.
      default: string used when the variable is not set.

    Returns:
      True if the value, upper-cased, is one of ON, 1, YES, TRUE or Y;
      False otherwise.
    """
    return os.getenv(name, default).upper() in ('ON', '1', 'YES', 'TRUE', 'Y')
10+
11+
12+
def extract_execution_cause(lines):
    """Collect the line that follows every 'Execution Cause' marker.

    Args:
      lines: sequence of `bytes` lines, e.g. from `readlines()` on a file
        opened in binary mode.

    Returns:
      A list with the decoded line that immediately follows each line
      containing 'Execution Cause'. A marker on the very last line has no
      following line and is ignored instead of raising IndexError.
    """
    causes = []
    # Pair every line with its successor; zip stops before the final line
    # can be treated as a marker, so there is never an out-of-range lookup.
    for marker, following in zip(lines, lines[1:]):
        if 'Execution Cause' in marker.decode():
            causes.append(following.decode())
    return causes
18+
19+
20+
def extract_python_frames(lines):
    """Group the 'Python Frame Triggered Execution' sections of a debug log.

    Recording starts at a line containing 'Python Frame Triggered Execution'
    (that header line is included) and stops at a line containing
    'Execution Analysis: ----------------' (that line is excluded).

    Args:
      lines: iterable of `bytes` lines from the debug output file.

    Returns:
      A list of strings, one per terminator line seen, each holding the
      decoded lines recorded since the previous terminator. A section that
      is never terminated is not returned.
    """
    frames = []
    current = []
    recording = False
    for raw in lines:
        # Decode once per line rather than once per comparison.
        text = raw.decode()
        if 'Python Frame Triggered Execution' in text:
            recording = True
        elif 'Execution Analysis: ----------------' in text:
            recording = False
            frames.append(''.join(current))
            current = []
        if recording:
            current.append(text)
    return frames
34+
35+
36+
def _mp_fn(index):
    """Run one graph on the XLA device and validate the debug file output.

    Requires PT_XLA_DEBUG to be enabled and PT_XLA_DEBUG_FILE to name the
    file that receives the execution analysis.

    Args:
      index: process index supplied by `xmp.spawn`; only index 0 resets and
        inspects the debug file.
    """
    assert check_env_flag(
        'PT_XLA_DEBUG'), "This test should be run with PT_XLA_DEBUG"
    debug_file_name = os.getenv('PT_XLA_DEBUG_FILE')
    assert debug_file_name, "This test should be run with PT_XLA_DEBUG_FILE"

    if index == 0:
        # Start from an empty debug file.
        open(debug_file_name, 'w').close()

    device = xm.xla_device()
    t1 = torch.randn(10, 10, device=device)
    # Give mark_step a computation to execute.
    t2 = t1 * 100
    xm.mark_step()
    xm.wait_device_ops()

    if index == 0:
        # All processes write to the same PT_XLA_DEBUG_FILE, so there is no
        # need to check it on every process.
        with open(debug_file_name, 'rb') as f:
            lines = f.readlines()
        causes = extract_execution_cause(lines)
        frames = extract_python_frames(lines)
        # only the local master process should dump the execution analysis
        assert (len(causes) == 1)
        assert ('user mark_step' in causes[0])
        # make sure the frames that spawn the processes are skipped
        assert (len(frames) == 1)
        assert ('....' in frames[0])
        assert ('_internal/pjrt.py' not in frames[0])
64+
65+
66+
# Script entry point: launch `_mp_fn` through the torch_xla multiprocessing
# spawner with no extra arguments.
if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())
File renamed without changes.

test/run_tests.sh

+2-1
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ function run_xla_op_tests1 {
161161
run_test "$CDIR/test_grad_checkpoint.py"
162162
run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
163163
run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
164-
run_pt_xla_debug "$CDIR/test_pt_xla_debug.py"
164+
run_pt_xla_debug "$CDIR/debug_tool/test_pt_xla_debug.py"
165165
run_test "$CDIR/test_async_closures.py"
166166
run_test "$CDIR/test_hlo_metadata.py"
167167
run_test "$CDIR/test_profiler.py"
@@ -232,6 +232,7 @@ function run_mp_op_tests {
232232
run_test "$CDIR/test_mp_save.py"
233233
run_test "$CDIR/test_mp_mesh_reduce.py"
234234
run_test "$CDIR/test_mp_sync_batch_norm.py"
235+
run_pt_xla_debug "$CDIR/debug_tool/test_mp_pt_xla_debug.py"
235236
run_xla_backend_mp "$CDIR/test_torch_distributed_all_gather_xla_backend.py"
236237
run_xla_backend_mp "$CDIR/test_torch_distributed_all_reduce_xla_backend.py"
237238
run_xla_backend_mp "$CDIR/test_torch_distributed_multi_all_reduce_xla_backend.py"

torch_xla/csrc/debug_util.cpp

+10-3
Original file line numberDiff line numberDiff line change
@@ -272,11 +272,18 @@ void DebugUtil::analyze_graph_execution_python_frame(
272272
"mark_step\n";
273273
}
274274

275-
// TODO(JackCaoG): make number of frames printed configurable
276275
ss << debug_output_prefix << "Python Frame Triggered Execution: \n";
277276
for (auto& location : frames) {
278-
ss << debug_output_prefix << " " << location.function << " ("
279-
<< location.file << ":" << location.line << ")\n";
277+
// if the current frame is `__call__` in pjrt.py, the stack below will be python
278+
// code to spawn up process, not very useful to the user.
279+
if (location.function == "__call__" &&
280+
endsWith(location.file, "_internal/pjrt.py")) {
281+
ss << debug_output_prefix << " ..........\n";
282+
break;
283+
} else {
284+
ss << debug_output_prefix << " " << location.function << " ("
285+
<< location.file << ":" << location.line << ")\n";
286+
}
280287
}
281288
ss << debug_output_prefix
282289
<< "----------------------------------------------------------------------"

0 commit comments

Comments
 (0)