Skip to content

Commit e9d4342

Browse files
committed
Merge branch 'dev' of github.com:fsschneider/algorithmic-efficiency into dev
2 parents 45a9b9a + 4345e8b commit e9d4342

File tree

61 files changed

+1134
-288
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+1134
-288
lines changed

.github/workflows/CI.yml

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ jobs:
77
runs-on: ubuntu-latest
88
steps:
99
- uses: actions/checkout@v3
10-
- name: Set up Python 3.9
10+
- name: Set up Python 3.11.10
1111
uses: actions/setup-python@v4
1212
with:
13-
python-version: 3.9
13+
python-version: 3.11.10
1414
cache: 'pip' # Cache pip dependencies\.
1515
cache-dependency-path: '**/setup.py'
1616
- name: Install Modules and Run
@@ -25,10 +25,10 @@ jobs:
2525
runs-on: ubuntu-latest
2626
steps:
2727
- uses: actions/checkout@v3
28-
- name: Set up Python 3.9
28+
- name: Set up Python 3.11.10
2929
uses: actions/setup-python@v4
3030
with:
31-
python-version: 3.9
31+
python-version: 3.11.10
3232
cache: 'pip' # Cache pip dependencies\.
3333
cache-dependency-path: '**/setup.py'
3434
- name: Install Modules and Run
@@ -42,10 +42,10 @@ jobs:
4242
runs-on: ubuntu-latest
4343
steps:
4444
- uses: actions/checkout@v3
45-
- name: Set up Python 3.9
45+
- name: Set up Python 3.11.10
4646
uses: actions/setup-python@v4
4747
with:
48-
python-version: 3.9
48+
python-version: 3.11.10
4949
cache: 'pip' # Cache pip dependencies\.
5050
cache-dependency-path: '**/setup.py'
5151
- name: Install Modules and Run
@@ -59,10 +59,10 @@ jobs:
5959
runs-on: ubuntu-latest
6060
steps:
6161
- uses: actions/checkout@v3
62-
- name: Set up Python 3.9
62+
- name: Set up Python 3.11.10
6363
uses: actions/setup-python@v4
6464
with:
65-
python-version: 3.9
65+
python-version: 3.11.10
6666
cache: 'pip' # Cache pip dependencies\.
6767
cache-dependency-path: '**/setup.py'
6868
- name: Install Modules and Run
@@ -77,10 +77,10 @@ jobs:
7777
runs-on: ubuntu-latest
7878
steps:
7979
- uses: actions/checkout@v3
80-
- name: Set up Python 3.9
80+
- name: Set up Python 3.11.10
8181
uses: actions/setup-python@v4
8282
with:
83-
python-version: 3.9
83+
python-version: 3.11.10
8484
cache: 'pip' # Cache pip dependencies\.
8585
cache-dependency-path: '**/setup.py'
8686
- name: Install Modules and Run
@@ -96,10 +96,10 @@ jobs:
9696
runs-on: ubuntu-latest
9797
steps:
9898
- uses: actions/checkout@v3
99-
- name: Set up Python 3.9
99+
- name: Set up Python 3.11.10
100100
uses: actions/setup-python@v4
101101
with:
102-
python-version: 3.9
102+
python-version: 3.11.10
103103
cache: 'pip' # Cache pip dependencies\.
104104
cache-dependency-path: '**/setup.py'
105105
- name: Install Modules and Run
@@ -113,10 +113,10 @@ jobs:
113113
runs-on: ubuntu-latest
114114
steps:
115115
- uses: actions/checkout@v3
116-
- name: Set up Python 3.9
116+
- name: Set up Python 3.11.10
117117
uses: actions/setup-python@v4
118118
with:
119-
python-version: 3.9
119+
python-version: 3.11.10
120120
cache: 'pip' # Cache pip dependencies\.
121121
cache-dependency-path: '**/setup.py'
122122
- name: Install Modules and Run
@@ -130,10 +130,10 @@ jobs:
130130
runs-on: ubuntu-latest
131131
steps:
132132
- uses: actions/checkout@v3
133-
- name: Set up Python 3.9
133+
- name: Set up Python 3.11.10
134134
uses: actions/setup-python@v4
135135
with:
136-
python-version: 3.9
136+
python-version: 3.11.10
137137
cache: 'pip' # Cache pip dependencies\.
138138
cache-dependency-path: '**/setup.py'
139139
- name: Install Modules and Run
@@ -148,10 +148,10 @@ jobs:
148148
runs-on: ubuntu-latest
149149
steps:
150150
- uses: actions/checkout@v3
151-
- name: Set up Python 3.9
151+
- name: Set up Python 3.11.10
152152
uses: actions/setup-python@v4
153153
with:
154-
python-version: 3.9
154+
python-version: 3.11.10
155155
cache: 'pip' # Cache pip dependencies\.
156156
cache-dependency-path: '**/setup.py'
157157
- name: Install Modules and Run
@@ -166,10 +166,10 @@ jobs:
166166
runs-on: ubuntu-latest
167167
steps:
168168
- uses: actions/checkout@v3
169-
- name: Set up Python 3.9
169+
- name: Set up Python 3.11.10
170170
uses: actions/setup-python@v4
171171
with:
172-
python-version: 3.9
172+
python-version: 3.11.10
173173
cache: 'pip' # Cache pip dependencies\.
174174
cache-dependency-path: '**/setup.py'
175175
- name: Install Modules and Run
@@ -184,10 +184,10 @@ jobs:
184184
runs-on: ubuntu-latest
185185
steps:
186186
- uses: actions/checkout@v3
187-
- name: Set up Python 3.9
187+
- name: Set up Python 3.11.10
188188
uses: actions/setup-python@v4
189189
with:
190-
python-version: 3.9
190+
python-version: 3.11.10
191191
cache: 'pip' # Cache pip dependencies\.
192192
cache-dependency-path: '**/setup.py'
193193
- name: Install pytest
@@ -208,10 +208,10 @@ jobs:
208208
runs-on: ubuntu-latest
209209
steps:
210210
- uses: actions/checkout@v3
211-
- name: Set up Python 3.9
211+
- name: Set up Python 3.11.10
212212
uses: actions/setup-python@v4
213213
with:
214-
python-version: 3.9
214+
python-version: 3.11.10
215215
cache: 'pip' # Cache pip dependencies\.
216216
cache-dependency-path: '**/setup.py'
217217
- name: Install pytest

.github/workflows/linting.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ jobs:
77
runs-on: ubuntu-latest
88
steps:
99
- uses: actions/checkout@v2
10-
- name: Set up Python 3.9
10+
- name: Set up Python 3.11.10
1111
uses: actions/setup-python@v2
1212
with:
13-
python-version: 3.9
13+
python-version: 3.11.10
1414
- name: Install pylint
1515
run: |
1616
python -m pip install --upgrade pip
@@ -27,10 +27,10 @@ jobs:
2727
runs-on: ubuntu-latest
2828
steps:
2929
- uses: actions/checkout@v2
30-
- name: Set up Python 3.9
30+
- name: Set up Python 3.11.10
3131
uses: actions/setup-python@v2
3232
with:
33-
python-version: 3.9
33+
python-version: 3.11.10
3434
- name: Install isort
3535
run: |
3636
python -m pip install --upgrade pip
@@ -43,10 +43,10 @@ jobs:
4343
runs-on: ubuntu-latest
4444
steps:
4545
- uses: actions/checkout@v2
46-
- name: Set up Python 3.9
46+
- name: Set up Python 3.11.10
4747
uses: actions/setup-python@v2
4848
with:
49-
python-version: 3.9
49+
python-version: 3.11.10
5050
- name: Install yapf
5151
run: |
5252
python -m pip install --upgrade pip

algoperf/checkpoint_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def save_checkpoint(framework: str,
231231
target=checkpoint_state,
232232
step=global_step,
233233
overwrite=True,
234-
keep=np.Inf if save_intermediate_checkpoints else 1)
234+
keep=np.inf if save_intermediate_checkpoints else 1)
235235
else:
236236
if not save_intermediate_checkpoints:
237237
checkpoint_files = gfile.glob(

algoperf/data_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def _prepare(x):
6565
# Assumes that `global_batch_size % local_device_count == 0`.
6666
return x.reshape((local_device_count, -1, *x.shape[1:]))
6767

68-
return jax.tree_map(_prepare, batch)
68+
return jax.tree.map(_prepare, batch)
6969

7070

7171
def pad(tensor: np.ndarray,

algoperf/halton.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
import functools
1111
import itertools
1212
import math
13-
from typing import Any, Callable, Dict, List, Sequence, Text, Tuple, Union
13+
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
1414

1515
from absl import logging
1616
from numpy import random
1717

18-
_SweepSequence = List[Dict[Text, Any]]
19-
_GeneratorFn = Callable[[float], Tuple[Text, float]]
18+
_SweepSequence = List[Dict[str, Any]]
19+
_GeneratorFn = Callable[[float], Tuple[str, float]]
2020

2121

2222
def generate_primes(n: int) -> List[int]:
@@ -195,10 +195,10 @@ def generate_sequence(num_samples: int,
195195
return halton_sequence
196196

197197

198-
def _generate_double_point(name: Text,
198+
def _generate_double_point(name: str,
199199
min_val: float,
200200
max_val: float,
201-
scaling: Text,
201+
scaling: str,
202202
halton_point: float) -> Tuple[str, float]:
203203
"""Generate a float hyperparameter value from a Halton sequence point."""
204204
if scaling not in ['linear', 'log']:
@@ -234,7 +234,7 @@ def interval(start: int, end: int) -> Tuple[int, int]:
234234
return start, end
235235

236236

237-
def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn:
237+
def loguniform(name: str, range_endpoints: Tuple[int, int]) -> _GeneratorFn:
238238
min_val, max_val = range_endpoints
239239
return functools.partial(_generate_double_point,
240240
name,
@@ -244,8 +244,8 @@ def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn:
244244

245245

246246
def uniform(
247-
name: Text, search_points: Union[_DiscretePoints,
248-
Tuple[int, int]]) -> _GeneratorFn:
247+
name: str, search_points: Union[_DiscretePoints,
248+
Tuple[int, int]]) -> _GeneratorFn:
249249
if isinstance(search_points, _DiscretePoints):
250250
return functools.partial(_generate_discrete_point,
251251
name,

algoperf/logger_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def _get_system_software_info() -> Dict:
211211
system_software_info['os_platform'] = \
212212
platform.platform() # Ex. 'Linux-5.4.48-x86_64-with-glibc2.29'
213213
system_software_info['python_version'] = platform.python_version(
214-
) # Ex. '3.8.10'
214+
) # Ex. '3.11.10'
215215
system_software_info['python_compiler'] = platform.python_compiler(
216216
) # Ex. 'GCC 9.3.0'
217217
# Note: do not store hostname as that may be sensitive

algoperf/param_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def pytorch_param_types(
6666

6767
def jax_param_shapes(
6868
params: spec.ParameterContainer) -> spec.ParameterShapeTree:
69-
return jax.tree_map(lambda x: spec.ShapeTuple(x.shape), params)
69+
return jax.tree.map(lambda x: spec.ShapeTuple(x.shape), params)
7070

7171

7272
def jax_param_types(param_shapes: spec.ParameterShapeTree,

algoperf/workloads/cifar/cifar_jax/workload.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from flax import jax_utils
77
from flax import linen as nn
8+
from flax.core import pop
89
import jax
910
from jax import lax
1011
import jax.numpy as jnp
@@ -74,8 +75,8 @@ def sync_batch_stats(
7475
# In this case each device has its own version of the batch statistics
7576
# and we average them.
7677
avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
77-
new_model_state = model_state.copy(
78-
{'batch_stats': avg_fn(model_state['batch_stats'])})
78+
new_model_state = model_state.copy()
79+
new_model_state['batch_stats'] = avg_fn(model_state['batch_stats'])
7980
return new_model_state
8081

8182
def init_model_fn(
@@ -92,7 +93,7 @@ def init_model_fn(
9293
input_shape = (1, 32, 32, 3)
9394
variables = jax.jit(model.init)({'params': rng},
9495
jnp.ones(input_shape, model.dtype))
95-
model_state, params = variables.pop('params')
96+
model_state, params = pop(variables, 'params')
9697
self._param_shapes = param_utils.jax_param_shapes(params)
9798
self._param_types = param_utils.jax_param_types(self._param_shapes)
9899
model_state = jax_utils.replicate(model_state)
@@ -205,4 +206,4 @@ def _normalize_eval_metrics(
205206
self, num_examples: int, total_metrics: Dict[str,
206207
Any]) -> Dict[str, float]:
207208
"""Normalize eval metrics."""
208-
return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics)
209+
return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics)

algoperf/workloads/cifar/cifar_pytorch/workload.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def _build_dataset(
8181
}
8282
if split == 'eval_train':
8383
train_indices = indices_split['train']
84-
random.Random(data_rng[0]).shuffle(train_indices)
84+
random.Random(int(data_rng[0])).shuffle(train_indices)
8585
indices_split['eval_train'] = train_indices[:self.num_eval_train_examples]
8686
if split in indices_split:
8787
dataset = torch.utils.data.Subset(dataset, indices_split[split])

0 commit comments

Comments
 (0)