Skip to content

Commit

Permalink
Add an option to disable using the distributed scheduler from the diagnostic script (#3787)
Browse files Browse the repository at this point in the history
  • Loading branch information
bouweandela authored Jan 22, 2025
1 parent c84868e commit 5a3d3e7
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 10 deletions.
25 changes: 15 additions & 10 deletions esmvaltool/diag_scripts/shared/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def __init__(self, cfg):
if not os.path.exists(self._log_file):
self.table = {}
else:
with open(self._log_file, 'r') as file:
with open(self._log_file, 'r', encoding='utf-8') as file:
self.table = yaml.safe_load(file)

def log(self, filename, record):
Expand Down Expand Up @@ -212,8 +212,8 @@ def log(self, filename, record):
if isinstance(filename, Path):
filename = str(filename)
if filename in self.table:
raise KeyError(
"Provenance record for {} already exists.".format(filename))
msg = f"Provenance record for {filename} already exists."
raise KeyError(msg)

self.table[filename] = record

Expand All @@ -222,7 +222,7 @@ def _save(self):
dirname = os.path.dirname(self._log_file)
if not os.path.exists(dirname):
os.makedirs(dirname)
with open(self._log_file, 'w') as file:
with open(self._log_file, 'w', encoding='utf-8') as file:
yaml.safe_dump(self.table, file)

def __enter__(self):
Expand Down Expand Up @@ -253,9 +253,8 @@ def select_metadata(metadata, **attributes):
"""
selection = []
for attribs in metadata:
if all(a in attribs and (
attribs[a] == attributes[a] or attributes[a] == '*')
for a in attributes):
if all(a in attribs and v in (attribs[a], '*')
for a, v in attributes.items()):
selection.append(attribs)
return selection

Expand Down Expand Up @@ -424,7 +423,7 @@ def get_cfg(filename=None):
"""Read diagnostic script configuration from settings.yml."""
if filename is None:
filename = sys.argv[1]
with open(filename) as file:
with open(filename, encoding='utf-8') as file:
cfg = yaml.safe_load(file)
return cfg

Expand All @@ -441,7 +440,7 @@ def _get_input_data_files(cfg):

input_files = {}
for filename in metadata_files:
with open(filename) as file:
with open(filename, encoding='utf-8') as file:
metadata = yaml.safe_load(file)
input_files.update(metadata)

Expand Down Expand Up @@ -469,6 +468,10 @@ def main(cfg):
with run_diagnostic() as cfg:
main(cfg)
To prevent the diagnostic script from using the Dask Distributed scheduler,
set ``no_distributed: true`` in the diagnostic script definition in the
recipe or in the resulting settings.yml file.
The `cfg` dict passed to `main` contains the script configuration that
can be used with the other functions in this module.
"""
Expand Down Expand Up @@ -568,7 +571,9 @@ def main(cfg):
logger.info("Removing %s from previous run.", provenance_file)
os.remove(provenance_file)

if not args.no_distributed and 'scheduler_address' in cfg:
use_distributed = not (args.no_distributed
or cfg.get('no_distributed', False))
if use_distributed and 'scheduler_address' in cfg:
try:
client = distributed.Client(cfg['scheduler_address'])
except OSError as exc:
Expand Down
3 changes: 3 additions & 0 deletions esmvaltool/recipes/recipe_eady_growth_rate.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ diagnostics:
scripts:
annual_eady_growth_rate:
script: primavera/eady_growth_rate/eady_growth_rate.py
no_distributed: true
time_statistic: 'annual_mean'


Expand All @@ -63,6 +64,7 @@ diagnostics:
scripts:
summer_eady_growth_rate:
script: primavera/eady_growth_rate/eady_growth_rate.py
no_distributed: true
time_statistic: 'seasonal_mean'

winter_egr:
Expand All @@ -76,5 +78,6 @@ diagnostics:
scripts:
winter_eady_growth_rate:
script: primavera/eady_growth_rate/eady_growth_rate.py
no_distributed: true
time_statistic: 'seasonal_mean'
plot_levels: [70000]
34 changes: 34 additions & 0 deletions tests/unit/diag_scripts/shared/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,40 @@ def test_run_diagnostic(tmp_path, monkeypatch):
assert 'example_setting' in cfg


@pytest.mark.parametrize("no_distributed", [False, True])
def test_run_diagnostic_configures_dask(
    tmp_path,
    monkeypatch,
    mocker,
    no_distributed,
):
    """Test that ``no_distributed`` in the settings controls the Dask client.

    The settings always contain a ``scheduler_address``. A
    ``distributed.Client`` is expected to be created for that address,
    except when the settings also contain ``no_distributed: true``.
    """
    address = "tcp://127.0.0.1:38789"

    settings = create_settings(tmp_path)
    settings["scheduler_address"] = address
    if no_distributed:
        settings["no_distributed"] = True
    settings_file = write_settings(settings)

    monkeypatch.setattr(sys, "argv", ["", settings_file])

    # Create the files that ESMValCore would have created.
    run_dir = Path(settings["run_dir"])
    for name in ("log.txt", "profile.bin", "resource_usage.txt"):
        (run_dir / name).touch()

    # Patch the client class so the calls made to it can be inspected.
    client_cls = mocker.patch.object(shared._base.distributed, "Client")

    with shared.run_diagnostic() as cfg:
        assert "example_setting" in cfg

    if no_distributed:
        client_cls.assert_not_called()
    else:
        client_cls.assert_called_once_with(address)


@pytest.mark.parametrize('flag', ['-l', '--log-level'])
def test_run_diagnostic_log_level(tmp_path, monkeypatch, flag):
"""Test if setting the log level from the command line works."""
Expand Down

0 comments on commit 5a3d3e7

Please sign in to comment.