
Commit 34e62c1

committed
Added the test_evals_time test, which checks timing consistency, and changed the int seed constants and dtypes to uint
1 parent c8cc1ec commit 34e62c1

File tree

2 files changed (+75, -75 lines)


algoperf/random_utils.py (+7, -7)

@@ -18,30 +18,30 @@

 # Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 31 - 1] (an
 # unsigned int), while RandomState.randint only accepts and returns signed ints.
-MAX_INT32 = 2**31 - 1
-MIN_INT32 = 0
+MAX_UINT32 = 2**31 - 1
+MIN_UINT32 = 0

 SeedType = Union[int, list, np.ndarray]


 def _signed_to_unsigned(seed: SeedType) -> SeedType:
   if isinstance(seed, int):
-    return seed % MAX_INT32
+    return seed % MAX_UINT32
   if isinstance(seed, list):
-    return [s % MAX_INT32 for s in seed]
+    return [s % MAX_UINT32 for s in seed]
   if isinstance(seed, np.ndarray):
-    return np.array([s % MAX_INT32 for s in seed.tolist()])
+    return np.array([s % MAX_UINT32 for s in seed.tolist()])


 def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
   rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
-  new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
+  new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32)
   return [new_seed, data]


 def _split(seed: SeedType, num: int = 2) -> SeedType:
   rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
-  return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2])
+  return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2])


 def _PRNGKey(seed: SeedType) -> SeedType:  # pylint: disable=invalid-name
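
For reference, a minimal usage sketch of what these renamed helpers do (illustrative only; the helper name to_unsigned and the example seed value are hypothetical, not part of this commit):

# Illustrative sketch: fold an arbitrary int seed into the range accepted by
# np.random.RandomState, then draw a fresh seed the way _fold_in/_split now do.
import numpy as np

MAX_UINT32 = 2**31 - 1  # upper bound accepted by RandomState(seed), per the comment above
MIN_UINT32 = 0

def to_unsigned(seed):
    # Hypothetical helper mirroring _signed_to_unsigned for a plain int seed.
    return seed % MAX_UINT32

rng = np.random.RandomState(seed=to_unsigned(-12345))
new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32)
print(new_seed)  # a non-negative seed that can itself be passed back to RandomState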

tests/test_evals_time.py (+68, -68)

@@ -1,3 +1,10 @@
+"""
+Module for evaluating timing consistency in MNIST workload training.
+
+This script runs timing consistency tests for PyTorch and JAX implementations of an MNIST training workload.
+It ensures that the total reported training time aligns with the sum of submission, evaluation, and logging times.
+"""
+
 import os
 import sys
 import copy
@@ -6,8 +13,6 @@
 from absl.testing import parameterized
 from absl import logging
 from collections import namedtuple
-import json
-import jax
 from algoperf import halton
 from algoperf import random_utils as prng
 from algoperf.profiler import PassThroughProfiler
@@ -16,18 +21,14 @@
 import reference_algorithms.development_algorithms.mnist.mnist_pytorch.submission as submission_pytorch
 import reference_algorithms.development_algorithms.mnist.mnist_jax.submission as submission_jax
 import jax.random as jax_rng
-# try:
-#   import jax.random as jax_rng
-# except (ImportError, ModuleNotFoundError):
-#   logging.warning(
-#       'Could not import jax.random for the submission runner, falling back to '
-#       'numpy random_utils.')
-#   jax_rng = None

 FLAGS = flags.FLAGS
 FLAGS(sys.argv)

 class Hyperparameters:
+  """
+  Defines hyperparameters for training.
+  """
   def __init__(self):
     self.learning_rate = 0.0005
     self.one_minus_beta_1 = 0.05
@@ -38,87 +39,86 @@ def __init__(self):
     self.dropout_rate = 0.1

 class CheckTime(parameterized.TestCase):
-  """Tests to check if submission_time + eval_time + logging_time ~ total _wallclock_time """
+  """
+  Test class to verify timing consistency in MNIST workload training.
+
+  Ensures that submission time, evaluation time, and logging time sum up to approximately the total wall-clock time.
+  """
   rng_seed = 0

   @parameterized.named_parameters(
-      *[ dict(
-          testcase_name = 'mnist_pytorch',
-          framework = 'pytorch',
-          init_optimizer_state=submission_pytorch.init_optimizer_state,
-          update_params=submission_pytorch.update_params,
-          data_selection=submission_pytorch.data_selection,
-          rng = prng.PRNGKey(rng_seed))],
-
-      *[
-          dict(
-              testcase_name = 'mnist_jax',
-              framework = 'jax',
+      dict(
+          testcase_name='mnist_pytorch',
+          framework='pytorch',
+          init_optimizer_state=submission_pytorch.init_optimizer_state,
+          update_params=submission_pytorch.update_params,
+          data_selection=submission_pytorch.data_selection,
+          rng=prng.PRNGKey(rng_seed)
+      ),
+      dict(
+          testcase_name='mnist_jax',
+          framework='jax',
           init_optimizer_state=submission_jax.init_optimizer_state,
           update_params=submission_jax.update_params,
           data_selection=submission_jax.data_selection,
-          #rng = jax.random.PRNGKey(rng_seed),),
-          rng = prng.PRNGKey(rng_seed),),
-      ]
+          rng=jax_rng.PRNGKey(rng_seed)
+      )
   )
   def test_train_once_time_consistency(self, framework, init_optimizer_state, update_params, data_selection, rng):
-    """Test to check the consistency of timing metrics."""
-    rng_seed = 0
-    #rng = jax.random.PRNGKey(rng_seed)
-    #rng, _ = prng.split(rng, 2)
+    """
+    Tests the consistency of timing metrics in the training process.
+
+    Ensures that:
+    - The total logged time is approximately the sum of submission, evaluation, and logging times.
+    - The expected number of evaluations occurred within the training period.
+    """
     workload_metadata = copy.deepcopy(workloads.WORKLOADS["mnist"])
     workload_metadata['workload_path'] = os.path.join(
-        workloads.BASE_WORKLOADS_DIR,
-        workload_metadata['workload_path'] + '_' + framework,
-        'workload.py')
+        workloads.BASE_WORKLOADS_DIR,
+        workload_metadata['workload_path'] + '_' + framework,
+        'workload.py'
+    )
     workload = workloads.import_workload(
         workload_path=workload_metadata['workload_path'],
         workload_class_name=workload_metadata['workload_class_name'],
-        workload_init_kwargs={})
+        workload_init_kwargs={}
+    )

-    Hp = namedtuple("Hp",["dropout_rate", "learning_rate", "one_minus_beta_1", "weight_decay", "beta2", "warmup_factor", "epsilon" ])
-    hp1 = Hp(0.1,0.0017486387539278373,0.06733926164,0.9955159689799007,0.08121616522670176, 0.02, 1e-25)
-    # HPARAMS = {
-    #   "dropout_rate": 0.1,
-    #   "learning_rate": 0.0017486387539278373,
-    #   "one_minus_beta_1": 0.06733926164,
-    #   "beta2": 0.9955159689799007,
-    #   "weight_decay": 0.08121616522670176,
-    #   "warmup_factor": 0.02,
-    #   "epsilon" : 1e-25
-    # }
+    Hp = namedtuple("Hp", ["dropout_rate", "learning_rate", "one_minus_beta_1", "weight_decay", "beta2", "warmup_factor", "epsilon"])
+    hp1 = Hp(0.1, 0.0017486387539278373, 0.06733926164, 0.9955159689799007, 0.08121616522670176, 0.02, 1e-25)

-
     accumulated_submission_time, metrics = submission_runner.train_once(
-        workload = workload,
+        workload=workload,
         workload_name="mnist",
-        global_batch_size = 32,
-        global_eval_batch_size = 256,
-        data_dir = '~/tensorflow_datasets', # not sure
-        imagenet_v2_data_dir = None,
-        hyperparameters= hp1,
-        init_optimizer_state = init_optimizer_state,
-        update_params = update_params,
-        data_selection = data_selection,
-        rng = rng,
-        rng_seed = 0,
-        profiler= PassThroughProfiler(),
+        global_batch_size=32,
+        global_eval_batch_size=256,
+        data_dir='~/tensorflow_datasets',  # Dataset location
+        imagenet_v2_data_dir=None,
+        hyperparameters=hp1,
+        init_optimizer_state=init_optimizer_state,
+        update_params=update_params,
+        data_selection=data_selection,
+        rng=rng,
+        rng_seed=0,
+        profiler=PassThroughProfiler(),
         max_global_steps=500,
-        prepare_for_eval = None)
-
-
-    # Example: Check if total time roughly equals to submission_time + eval_time + logging_time
-    total_logged_time = (metrics['eval_results'][-1][1]['total_duration']
-                         - (accumulated_submission_time +
-                            metrics['eval_results'][-1][1]['accumulated_logging_time'] +
-                            metrics['eval_results'][-1][1]['accumulated_eval_time']))
+        prepare_for_eval=None
+    )
+
+    # Calculate total logged time
+    total_logged_time = (
+        metrics['eval_results'][-1][1]['total_duration']
+        - (accumulated_submission_time +
+           metrics['eval_results'][-1][1]['accumulated_logging_time'] +
+           metrics['eval_results'][-1][1]['accumulated_eval_time'])
+    )

-    # Use a tolerance for floating-point arithmetic
+    # Set tolerance for floating-point precision errors
     tolerance = 10
     self.assertAlmostEqual(total_logged_time, 0, delta=tolerance,
                            msg="Total wallclock time does not match the sum of submission, eval, and logging times.")

-    # Check if the expected number of evaluations occurred
+    # Verify expected number of evaluations
     expected_evals = int(accumulated_submission_time // workload.eval_period_time_sec)
     self.assertTrue(expected_evals <= len(metrics['eval_results']) + 2,
                     f"Number of evaluations {len(metrics['eval_results'])} exceeded the expected number {expected_evals + 2}.")
