"""
Module for evaluating timing consistency in MNIST workload training.

This script runs timing-consistency tests for the PyTorch and JAX
implementations of the MNIST training workload. It checks that the total
reported training time matches the sum of the submission, evaluation, and
logging times to within a small tolerance.
"""

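# Timing identity exercised by the test below (up to a small tolerance),
# using the keys reported in the eval results:
#   total_duration ≈ accumulated_submission_time
#                    + accumulated_eval_time
#                    + accumulated_logging_time
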
import os
import sys
import copy

from absl.testing import parameterized
from absl import logging
from collections import namedtuple
from algoperf import halton
from algoperf import random_utils as prng
from algoperf.profiler import PassThroughProfiler

import reference_algorithms.development_algorithms.mnist.mnist_pytorch.submission as submission_pytorch
import reference_algorithms.development_algorithms.mnist.mnist_jax.submission as submission_jax
import jax.random as jax_rng
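
# Parse absl flags from the command line so that modules that read FLAGS see
# initialized values when this test module is executed directly.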
FLAGS = flags.FLAGS
FLAGS(sys.argv)


class Hyperparameters:
  """Defines hyperparameters for training."""

  def __init__(self):
    self.learning_rate = 0.0005
    self.one_minus_beta_1 = 0.05
    self.dropout_rate = 0.1


class CheckTime(parameterized.TestCase):
  """Verifies timing consistency for MNIST workload training.

  Ensures that submission time, evaluation time, and logging time sum to
  approximately the total wall-clock time.
  """

  rng_seed = 0

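  # One named test case per framework; each supplies that framework's
  # submission functions and a PRNG key built with the matching RNG helper.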
  @parameterized.named_parameters(
      dict(
          testcase_name='mnist_pytorch',
          framework='pytorch',
          init_optimizer_state=submission_pytorch.init_optimizer_state,
          update_params=submission_pytorch.update_params,
          data_selection=submission_pytorch.data_selection,
          rng=prng.PRNGKey(rng_seed)
      ),
      dict(
          testcase_name='mnist_jax',
          framework='jax',
          init_optimizer_state=submission_jax.init_optimizer_state,
          update_params=submission_jax.update_params,
          data_selection=submission_jax.data_selection,
          rng=jax_rng.PRNGKey(rng_seed)
      )
  )
  def test_train_once_time_consistency(self, framework, init_optimizer_state,
                                       update_params, data_selection, rng):
    """Tests the consistency of timing metrics in the training process.

    Ensures that:
      - The total logged time is approximately the sum of submission,
        evaluation, and logging times.
      - The expected number of evaluations occurred within the training period.
    """
    workload_metadata = copy.deepcopy(workloads.WORKLOADS["mnist"])
    workload_metadata['workload_path'] = os.path.join(
        workloads.BASE_WORKLOADS_DIR,
        workload_metadata['workload_path'] + '_' + framework,
        'workload.py'
    )
    workload = workloads.import_workload(
        workload_path=workload_metadata['workload_path'],
        workload_class_name=workload_metadata['workload_class_name'],
        workload_init_kwargs={}
    )

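    # Fixed hyperparameter point for this timing smoke test
    # (beta2 ≈ 0.9955, weight_decay ≈ 0.0812, epsilon = 1e-25).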
    Hp = namedtuple(
        "Hp",
        ["dropout_rate", "learning_rate", "one_minus_beta_1", "beta2",
         "weight_decay", "warmup_factor", "epsilon"])
    hp1 = Hp(0.1, 0.0017486387539278373, 0.06733926164, 0.9955159689799007,
             0.08121616522670176, 0.02, 1e-25)

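    # Run a short training loop (at most 500 steps) and collect the timing
    # metrics reported by the runner.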
    accumulated_submission_time, metrics = submission_runner.train_once(
        workload=workload,
        workload_name="mnist",
        global_batch_size=32,
        global_eval_batch_size=256,
        data_dir='~/tensorflow_datasets',  # Dataset location
        imagenet_v2_data_dir=None,
        hyperparameters=hp1,
        init_optimizer_state=init_optimizer_state,
        update_params=update_params,
        data_selection=data_selection,
        rng=rng,
        rng_seed=0,
        profiler=PassThroughProfiler(),
        max_global_steps=500,
        prepare_for_eval=None
    )

    # Residual between the total wall-clock duration and the sum of the
    # submission, evaluation, and logging times; it should be close to zero.
    total_logged_time = (
        metrics['eval_results'][-1][1]['total_duration']
        - (accumulated_submission_time +
           metrics['eval_results'][-1][1]['accumulated_logging_time'] +
           metrics['eval_results'][-1][1]['accumulated_eval_time'])
    )

    # Allow a few seconds of slack for time spent outside the measured
    # submission, evaluation, and logging phases.
    tolerance = 10
    self.assertAlmostEqual(
        total_logged_time, 0, delta=tolerance,
        msg="Total wallclock time does not match the sum of submission, eval, and logging times.")

    # Verify that roughly the expected number of evaluations occurred.
    expected_evals = int(
        accumulated_submission_time // workload.eval_period_time_sec)
    self.assertTrue(
        expected_evals <= len(metrics['eval_results']) + 2,
        f"Expected at least {expected_evals - 2} evaluations, but only "
        f"{len(metrics['eval_results'])} were recorded.")