Commit c8cc1ec

..

1 parent 1d81455 commit c8cc1ec

File tree: 1 file changed, +127 −0 lines changed

tests/test_evals_time.py (+127)
@@ -0,0 +1,127 @@
import copy
import os
import sys
from collections import namedtuple

from absl import flags
from absl.testing import absltest
from absl.testing import parameterized

from algoperf import random_utils as prng
from algoperf.profiler import PassThroughProfiler
from algoperf.workloads import workloads
import submission_runner
import reference_algorithms.development_algorithms.mnist.mnist_pytorch.submission as submission_pytorch
import reference_algorithms.development_algorithms.mnist.mnist_jax.submission as submission_jax

FLAGS = flags.FLAGS
# Parse command-line flags once so that absl flags are defined before the
# tests below are instantiated.
FLAGS(sys.argv)


# Unused by the test below, which builds its hyperparameters with a
# namedtuple instead.
class Hyperparameters:

  def __init__(self):
    self.learning_rate = 0.0005
    self.one_minus_beta_1 = 0.05
    self.beta2 = 0.999
    self.weight_decay = 0.01
    self.epsilon = 1e-25
    self.label_smoothing = 0.1
    self.dropout_rate = 0.1


class CheckTime(parameterized.TestCase):
  """Checks that submission_time + eval_time + logging_time ~= total_wallclock_time."""

  rng_seed = 0

  @parameterized.named_parameters(
      dict(
          testcase_name='mnist_pytorch',
          framework='pytorch',
          init_optimizer_state=submission_pytorch.init_optimizer_state,
          update_params=submission_pytorch.update_params,
          data_selection=submission_pytorch.data_selection,
          rng=prng.PRNGKey(rng_seed)),
      dict(
          testcase_name='mnist_jax',
          framework='jax',
          init_optimizer_state=submission_jax.init_optimizer_state,
          update_params=submission_jax.update_params,
          data_selection=submission_jax.data_selection,
          rng=prng.PRNGKey(rng_seed)),
  )
  def test_train_once_time_consistency(self, framework, init_optimizer_state,
                                       update_params, data_selection, rng):
    """Checks the consistency of the timing metrics reported by train_once."""
    # Resolve and import the MNIST workload for the requested framework.
    workload_metadata = copy.deepcopy(workloads.WORKLOADS['mnist'])
    workload_metadata['workload_path'] = os.path.join(
        workloads.BASE_WORKLOADS_DIR,
        workload_metadata['workload_path'] + '_' + framework,
        'workload.py')
    workload = workloads.import_workload(
        workload_path=workload_metadata['workload_path'],
        workload_class_name=workload_metadata['workload_class_name'],
        workload_init_kwargs={})

    # Hyperparameters for the MNIST workload, passed as keyword arguments so
    # that each value maps unambiguously to its field.
    Hp = namedtuple('Hp', [
        'dropout_rate',
        'learning_rate',
        'one_minus_beta_1',
        'weight_decay',
        'beta2',
        'warmup_factor',
        'epsilon',
    ])
    hp1 = Hp(
        dropout_rate=0.1,
        learning_rate=0.0017486387539278373,
        one_minus_beta_1=0.06733926164,
        beta2=0.9955159689799007,
        weight_decay=0.08121616522670176,
        warmup_factor=0.02,
        epsilon=1e-25)

    accumulated_submission_time, metrics = submission_runner.train_once(
        workload=workload,
        workload_name='mnist',
        global_batch_size=32,
        global_eval_batch_size=256,
        data_dir='~/tensorflow_datasets',  # Default TFDS data location.
        imagenet_v2_data_dir=None,
        hyperparameters=hp1,
        init_optimizer_state=init_optimizer_state,
        update_params=update_params,
        data_selection=data_selection,
        rng=rng,
        rng_seed=0,
        profiler=PassThroughProfiler(),
        max_global_steps=500,
        prepare_for_eval=None)
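
    # `metrics['eval_results']` is a list of (global_step, eval_result) tuples;
    # the last entry's dict is assumed to carry the accumulated timing counters
    # ('total_duration', 'accumulated_eval_time', 'accumulated_logging_time')
    # used below.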

    # Check that the total wallclock time roughly equals
    # submission_time + eval_time + logging_time.
    total_logged_time = (
        metrics['eval_results'][-1][1]['total_duration'] -
        (accumulated_submission_time +
         metrics['eval_results'][-1][1]['accumulated_logging_time'] +
         metrics['eval_results'][-1][1]['accumulated_eval_time']))

    # Tolerance (in seconds) for scheduling jitter and floating-point error.
    tolerance = 10
    self.assertAlmostEqual(
        total_logged_time,
        0,
        delta=tolerance,
        msg='Total wallclock time does not match the sum of submission, '
        'eval, and logging times.')

    # Check that roughly the expected number of evaluations occurred.
    expected_evals = int(accumulated_submission_time //
                         workload.eval_period_time_sec)
    self.assertLessEqual(
        expected_evals,
        len(metrics['eval_results']) + 2,
        msg=f'Expected roughly {expected_evals} evaluations but only '
        f"{len(metrics['eval_results'])} were recorded.")


if __name__ == '__main__':
  absltest.main()
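
Since the file ends with a standard `absltest.main()` entry point, the test can be run directly with `python tests/test_evals_time.py`, which exercises both the `mnist_pytorch` and `mnist_jax` parameterizations.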
