File tree Expand file tree Collapse file tree 2 files changed +19
-5
lines changed Expand file tree Collapse file tree 2 files changed +19
-5
lines changed Original file line number Diff line number Diff line change @@ -34,6 +34,20 @@ class Profiling:
3434 profile_freq : int = 10
3535 """How often to collect profile traces, in iterations"""
3636
37+ profiler_active : int = 1
38+ """
39+ The number of steps the profiler is active for in each profiling cycle.
40+
41+ This is used to configure torch.profiler.schedule.
42+ """
43+
44+ profiler_warmup : int = 3
45+ """
46+ The number of warmup steps before the active step(s) in each profiling cycle.
47+
48+ This is used to configure torch.profiler.schedule.
49+ """
50+
3751 enable_memory_snapshot : bool = False
3852 """Whether to dump memory snapshot"""
3953
Original file line number Diff line number Diff line change 1414from torchtitan .config import Profiling as ProfilingConfig
1515from torchtitan .tools .logging import logger
1616
17- # the number of warmup steps before the active step in each profiling cycle
18- WARMUP = 3
19-
2017# how much memory allocation/free ops to record in memory snapshots
2118MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
2219
@@ -34,7 +31,11 @@ def maybe_enable_profiling(
3431
3532 if enable_profiling :
3633 trace_dir = os .path .join (base_folder , profiling_config .save_traces_folder )
37- profile_freq = profiling_config .profile_freq
34+ profile_freq , warmup , active = (
35+ profiling_config .profile_freq ,
36+ profiling_config .profiler_warmup ,
37+ profiling_config .profiler_active ,
38+ )
3839
3940 rank = torch .distributed .get_rank ()
4041
@@ -58,7 +59,6 @@ def trace_handler(prof):
5859 if not os .path .exists (trace_dir ):
5960 os .makedirs (trace_dir , exist_ok = True )
6061
61- warmup , active = WARMUP , 1
6262 wait = profile_freq - (active + warmup )
6363 assert (
6464 wait >= 0
You can’t perform that action at this time.
0 commit comments