
Prepare for v0.6 release #876

Draft: wants to merge 43 commits into base: dev
Commits (43)

78a1409  Formatting, ignore `.eggs/` in yapf (fsschneider, Jun 18, 2025)
462e8b7  Replace yapf, pylint, isort with ruff (fsschneider, Jun 18, 2025)
383db7a  Replace pre-commit with ruff (fsschneider, Jun 18, 2025)
2c28136  Use extend-select instead, and reduce lint rules (fsschneider, Jun 18, 2025)
999f7a2  Replace linting GH actions with ruff (fsschneider, Jun 18, 2025)
7b245d6  Add ruff badge (fsschneider, Jun 18, 2025)
277674c  Update style testing with ruff (fsschneider, Jun 18, 2025)
830d2c2  Format submission_runner (fsschneider, Jun 23, 2025)
d84eddf  Format submissions/ (fsschneider, Jun 23, 2025)
fbbeafa  Format scoring/ (fsschneider, Jun 23, 2025)
f4ae9be  Format reference_algorithms/ (fsschneider, Jun 23, 2025)
e5209d1  Format tests/ (fsschneider, Jun 23, 2025)
f026711  Format prize_qualification_baselines/ (fsschneider, Jun 23, 2025)
531c99e  Format datasets/ (fsschneider, Jun 23, 2025)
c34af17  Format algoperf/ (fsschneider, Jun 23, 2025)
9725554  Format docker/ (fsschneider, Jun 23, 2025)
7b18fff  Lint tests/ (fsschneider, Jun 23, 2025)
f34bb6d  Lint submissions/ (fsschneider, Jun 23, 2025)
0aeb545  Remove perf. profile tests as it is only a placeholder (fsschneider, Jun 23, 2025)
4802dfb  Lint scoring/ (fsschneider, Jun 23, 2025)
5e97e78  Lint prize_qualification_baselines/ (fsschneider, Jun 23, 2025)
4ae5418  Lint datasets/ (fsschneider, Jun 23, 2025)
e3f1b74  Lint reference_algorithms/ (fsschneider, Jun 23, 2025)
566b6d9  Lint algoperf/ (fsschneider, Jun 23, 2025)
e846648  Remove unnecessary isort=off commands (fsschneider, Jun 23, 2025)
3e425e0  Update Ruff linting rules in pyproject.toml to include additional opt… (fsschneider, Jun 23, 2025)
8bca401  Add pylint errors to linting rules (fsschneider, Jun 23, 2025)
09aca7f  Fix formatting (fsschneider, Jun 25, 2025)
ca4f4b6  Rework Readme (fsschneider, Jun 24, 2025)
2d2fd37  Remove deprecated rules and call for submissions (fsschneider, Jun 24, 2025)
2544fb9  Clarify automatic versioning (fsschneider, Jun 25, 2025)
c8da1f9  Mention changelog (fsschneider, Jun 25, 2025)
d233344  Backlog changes (fsschneider, Jun 25, 2025)
679e5ec  Increment version + formatting (fsschneider, Jun 25, 2025)
891fdb7  Formatting (fsschneider, Jun 25, 2025)
0390e13  Mention our version policy (fsschneider, Jun 25, 2025)
818710b  More descriptive link text (fsschneider, Jun 25, 2025)
03114bd  Add versioning protocol (fsschneider, Jun 25, 2025)
a2ab920  Document dropout PR (fsschneider, Jun 25, 2025)
64e9961  Compact layout for default dropout values (fsschneider, Jun 25, 2025)
23004a2  Change from 5 to 3 studies (fsschneider, Jun 25, 2025)
576c661  Change from 5 to 3 studies (fsschneider, Jun 25, 2025)
bdfd9d8  Remove held-out workloads (fsschneider, Jun 25, 2025)

39 changes: 10 additions & 29 deletions .github/workflows/linting.yml
@@ -3,54 +3,35 @@ name: Linting
 on: [push, pull_request]
 
 jobs:
-  pylint:
+  ruff-linting:
     runs-on: ubuntu-latest
     steps:
     - uses: actions/checkout@v2
     - name: Set up Python 3.11.10
       uses: actions/setup-python@v2
       with:
         python-version: 3.11.10
-    - name: Install pylint
+    - name: Install ruff
       run: |
         python -m pip install --upgrade pip
-        pip install pylint==2.16.1
-    - name: Run pylint
+        pip install ruff==0.12.0
+    - name: Run ruff linter
       run: |
-        pylint algoperf
-        pylint reference_algorithms
-        pylint prize_qualification_baselines
-        pylint submission_runner.py
-        pylint tests
+        ruff check
 
-  isort:
+  ruff-formatter:
     runs-on: ubuntu-latest
     steps:
     - uses: actions/checkout@v2
     - name: Set up Python 3.11.10
       uses: actions/setup-python@v2
       with:
         python-version: 3.11.10
-    - name: Install isort
+    - name: Install ruff
       run: |
         python -m pip install --upgrade pip
-        pip install isort==5.12.0
-    - name: Run isort
+        pip install ruff==0.12.0
+    - name: Run ruff formatter
       run: |
-        isort . --check --diff
+        ruff format --check
 
-  yapf:
-    runs-on: ubuntu-latest
-    steps:
-    - uses: actions/checkout@v2
-    - name: Set up Python 3.11.10
-      uses: actions/setup-python@v2
-      with:
-        python-version: 3.11.10
-    - name: Install yapf
-      run: |
-        python -m pip install --upgrade pip
-        pip install yapf==0.32 toml
-    - name: Run yapf
-      run: |
-        yapf . --diff --recursive
20 changes: 8 additions & 12 deletions .pre-commit-config.yaml
@@ -1,14 +1,10 @@
 repos:
-  - repo: https://github.com/google/yapf
-    rev: v0.32.0
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.12.0
     hooks:
-      - id: yapf
-        args: ["--in-place", "--parallel", "--verbose", "--recursive"]
-  - repo: https://github.com/pycqa/isort
-    rev: 5.10.1
-    hooks:
-      - id: isort
-  - repo: https://github.com/pycqa/pylint
-    rev: v2.16.1
-    hooks:
-      - id: pylint
+      # Run the linter (don't change files).
+      - id: ruff-check
+      # Run the formatter (don't change files).
+      - id: ruff-format
+        args: ["--check"]
177 changes: 121 additions & 56 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion algoperf/__init__.py
@@ -2,4 +2,4 @@
 
 from ._version import version as __version__
 
-__all__ = ["__version__"]
+__all__ = ['__version__']
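
The `__init__.py` change above only swaps quote style, but it is the module that re-exports `__version__` from `algoperf._version`, which the "Clarify automatic versioning" commit suggests is generated automatically. As a hedged illustration (not part of this PR), downstream code could read that version as shown below; the `importlib.metadata` lookup additionally assumes the installed distribution is named `algoperf`:

# Illustrative sketch only: reading the version re-exported by algoperf/__init__.py.
import algoperf

print(algoperf.__version__)

# Equivalent lookup via package metadata (standard library, Python 3.8+).
# Assumes the installed distribution is actually named "algoperf".
from importlib.metadata import version

print(version('algoperf'))
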
201 changes: 111 additions & 90 deletions algoperf/checkpoint_utils.py
@@ -7,37 +7,41 @@
import os
from typing import Sequence, Tuple

import jax
import numpy as np
import torch
from absl import logging
from flax import jax_utils
from flax.training import checkpoints as flax_checkpoints
from flax.training.checkpoints import latest_checkpoint
import jax
import numpy as np
from tensorflow.io import gfile # pytype: disable=import-error
import torch

from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup

_, _, DEVICE, _ = pytorch_setup()
CheckpointReturn = Tuple[spec.OptimizerState,
spec.ParameterContainer,
spec.ModelAuxiliaryState,
dict,
list,
int,
int]


def maybe_restore_checkpoint(framework: str,
optimizer_state: spec.OptimizerState,
model_params: spec.ParameterContainer,
model_state: spec.ModelAuxiliaryState,
train_state: dict,
eval_results: list,
global_step: int,
preemption_count: int,
checkpoint_dir: str) -> CheckpointReturn:
CheckpointReturn = Tuple[
spec.OptimizerState,
spec.ParameterContainer,
spec.ModelAuxiliaryState,
dict,
list,
int,
int,
]


def maybe_restore_checkpoint(
framework: str,
optimizer_state: spec.OptimizerState,
model_params: spec.ParameterContainer,
model_state: spec.ModelAuxiliaryState,
train_state: dict,
eval_results: list,
global_step: int,
preemption_count: int,
checkpoint_dir: str,
) -> CheckpointReturn:
"""Optionally restores from a checkpoint.

The checkpoint logic is as follows: if there is a checkpoint in
@@ -69,20 +73,22 @@ def maybe_restore_checkpoint(framework: str,
uninitialized_global_step = -1
uninitialized_preemption_count = -1
checkpoint_state = {
'model_params': model_params,
'optimizer_state': opt_state,
'model_state': model_state,
'train_state': train_state,
'eval_results': None,
'global_step': uninitialized_global_step,
'preemption_count': uninitialized_preemption_count,
'model_params': model_params,
'optimizer_state': opt_state,
'model_state': model_state,
'train_state': train_state,
'eval_results': None,
'global_step': uninitialized_global_step,
'preemption_count': uninitialized_preemption_count,
}

if framework == 'jax':
latest_ckpt = flax_checkpoints.restore_checkpoint(
checkpoint_dir, target=checkpoint_state)
save_path = os.path.join(checkpoint_dir,
'checkpoint_' + str(latest_ckpt['global_step']))
checkpoint_dir, target=checkpoint_state
)
save_path = os.path.join(
checkpoint_dir, 'checkpoint_' + str(latest_ckpt['global_step'])
)
else:
latest_ckpt = checkpoint_state
save_path = latest_checkpoint(checkpoint_dir)
@@ -94,55 +100,64 @@ def maybe_restore_checkpoint(framework: str,
found_checkpoint = latest_ckpt['global_step'] != uninitialized_global_step

if not found_checkpoint:
return (optimizer_state,
model_params,
model_state,
train_state,
eval_results,
global_step,
preemption_count)
return (
optimizer_state,
model_params,
model_state,
train_state,
eval_results,
global_step,
preemption_count,
)

# If there's the latest checkpoint in the checkpoint_dir, restore from that.
if framework == 'jax':
checkpoint_state = replicate_checkpoint(
latest_ckpt,
pytree_keys=[
'optimizer_state',
'model_params',
'model_state',
])
checkpoint_state['optimizer_state'] = (checkpoint_state['optimizer_state'],
opt_update_fn)
latest_ckpt,
pytree_keys=[
'optimizer_state',
'model_params',
'model_state',
],
)
checkpoint_state['optimizer_state'] = (
checkpoint_state['optimizer_state'],
opt_update_fn,
)
checkpoint_state['eval_results'] = [
(value, key) for key, value in latest_ckpt['eval_results'].items()
(value, key) for key, value in latest_ckpt['eval_results'].items()
]

else:
checkpoint_state = latest_ckpt
if isinstance(
model_params,
(torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
model_params,
(torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel),
):
model_params = model_params.module
model_params.load_state_dict(checkpoint_state['model_params'])
checkpoint_state['model_params'] = model_params
for key in optimizer_state.keys():
optimizer_state[key].load_state_dict(
checkpoint_state['optimizer_state'][key])
checkpoint_state['optimizer_state'][key]
)
checkpoint_state['optimizer_state'][key] = optimizer_state[key]

logging.info(f'Loaded checkpoint from {save_path}.')
return (checkpoint_state['optimizer_state'],
checkpoint_state['model_params'],
checkpoint_state['model_state'],
checkpoint_state['train_state'],
list(checkpoint_state['eval_results']),
checkpoint_state['global_step'],
checkpoint_state['preemption_count'] + 1)


def replicate_checkpoint(latest: dict,
pytree_keys: Sequence[str],
replicate: bool = True) -> dict:
return (
checkpoint_state['optimizer_state'],
checkpoint_state['model_params'],
checkpoint_state['model_state'],
checkpoint_state['train_state'],
list(checkpoint_state['eval_results']),
checkpoint_state['global_step'],
checkpoint_state['preemption_count'] + 1,
)


def replicate_checkpoint(
latest: dict, pytree_keys: Sequence[str], replicate: bool = True
) -> dict:
"""Restores from the provided checkpoint.

Args:
Expand All @@ -163,16 +178,18 @@ def replicate_checkpoint(latest: dict,
return pytree


def save_checkpoint(framework: str,
optimizer_state: spec.OptimizerState,
model_params: spec.ParameterContainer,
model_state: spec.ModelAuxiliaryState,
train_state: dict,
eval_results: list,
global_step: int,
preemption_count: int,
checkpoint_dir: str,
save_intermediate_checkpoints: bool) -> None:
def save_checkpoint(
framework: str,
optimizer_state: spec.OptimizerState,
model_params: spec.ParameterContainer,
model_state: spec.ModelAuxiliaryState,
train_state: dict,
eval_results: list,
global_step: int,
preemption_count: int,
checkpoint_dir: str,
save_intermediate_checkpoints: bool,
) -> None:
"""Save the checkpoint in `checkpoint_dir`.

Args:
@@ -199,8 +216,9 @@ def save_checkpoint(framework: str,
model_state = jax.device_get(jax_utils.unreplicate(model_state))
else:
if isinstance(
model_params,
(torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
model_params,
(torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel),
):
model_params = model_params.module
model_params = model_params.state_dict()
optimizer_state_dict = {}
@@ -209,33 +227,36 @@
optimizer_state_dict[key] = optimizer_state[key].state_dict()
else:
logging.warning(
f'The optimizer state for key {key} is not saved, because '
f'{type(optimizer_state[key])} has not implemented a state_dict() '
'method.')
f'The optimizer state for key {key} is not saved, because '
f'{type(optimizer_state[key])} has not implemented a state_dict() '
'method.'
)
opt_state = optimizer_state_dict

checkpoint_state = {
'model_params': model_params,
'optimizer_state': opt_state,
'model_state': model_state,
'train_state': train_state,
'eval_results': tuple(eval_results),
'global_step': global_step,
'preemption_count': preemption_count,
'model_params': model_params,
'optimizer_state': opt_state,
'model_state': model_state,
'train_state': train_state,
'eval_results': tuple(eval_results),
'global_step': global_step,
'preemption_count': preemption_count,
}

save_path = os.path.join(checkpoint_dir, f'checkpoint_{global_step}')
if framework == 'jax':
flax_checkpoints.save_checkpoint(
checkpoint_dir,
target=checkpoint_state,
step=global_step,
overwrite=True,
keep=np.inf if save_intermediate_checkpoints else 1)
checkpoint_dir,
target=checkpoint_state,
step=global_step,
overwrite=True,
keep=np.inf if save_intermediate_checkpoints else 1,
)
else:
if not save_intermediate_checkpoints:
checkpoint_files = gfile.glob(
os.path.join(checkpoint_dir, 'checkpoint_*'))
os.path.join(checkpoint_dir, 'checkpoint_*')
)
for path in checkpoint_files:
logging.info('Removing checkpoint at %s', path)
gfile.rmtree(path)
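
As a closing illustration of the reformatted checkpoint API, the following is a hedged sketch of how `maybe_restore_checkpoint` and `save_checkpoint` might be wired into a PyTorch run. The toy model, the layout of the `optimizer_state` dict, and the `'pytorch'` framework string are placeholder assumptions for illustration; they are not code from this PR.

# Hedged usage sketch for algoperf.checkpoint_utils (illustrative, not from this PR).
# Assumes a PyTorch run; the model and optimizer below are toy placeholders.
import os

import torch

from algoperf import checkpoint_utils

model = torch.nn.Linear(8, 1)
optimizer_state = {'optimizer': torch.optim.SGD(model.parameters(), lr=0.1)}
model_state = None
train_state = {}
eval_results = []
checkpoint_dir = '/tmp/algoperf_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Resume from the most recent checkpoint in checkpoint_dir, if one exists.
(
  optimizer_state,
  model,
  model_state,
  train_state,
  eval_results,
  global_step,
  preemption_count,
) = checkpoint_utils.maybe_restore_checkpoint(
  'pytorch',
  optimizer_state,
  model,
  model_state,
  train_state,
  eval_results,
  global_step=0,
  preemption_count=0,
  checkpoint_dir=checkpoint_dir,
)

# ... one or more training steps would run here ...
global_step += 1

# Persist the new state; with save_intermediate_checkpoints=False only the
# latest checkpoint is kept on disk.
checkpoint_utils.save_checkpoint(
  'pytorch',
  optimizer_state,
  model,
  model_state,
  train_state,
  eval_results,
  global_step,
  preemption_count,
  checkpoint_dir,
  save_intermediate_checkpoints=False,
)
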