From 831bc404afcdd8cd38fbdd86eb349b8edf2940f8 Mon Sep 17 00:00:00 2001 From: Brian Gunnarson <49216024+bgunnar5@users.noreply.github.com> Date: Tue, 11 Jun 2024 07:56:14 -0700 Subject: [PATCH] bugfix/flux-nodes-prior-versions (#487) * add a version check for flux when getting node count * update CHANGELOG * add major version check for flux --- CHANGELOG.md | 1 + merlin/study/batch.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1e1869e1..27933351 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Link to Merlin banner in readme - Issue with escape sequences in ascii art (caught by python 3.12) - Bug where Flux wasn't identifying total number of nodes on an allocation + - Not supporting Flux versions below 0.17.0 ## [1.12.1] diff --git a/merlin/study/batch.py b/merlin/study/batch.py index b8c7cc95..16482f39 100644 --- a/merlin/study/batch.py +++ b/merlin/study/batch.py @@ -40,7 +40,7 @@ import subprocess from typing import Dict, Optional, Union -from merlin.utils import convert_timestring, get_flux_alloc, get_yaml_var +from merlin.utils import convert_timestring, get_flux_alloc, get_flux_version, get_yaml_var LOG = logging.getLogger(__name__) @@ -126,7 +126,7 @@ def get_batch_type(scheduler_legend, default=None): return default -def get_node_count(default=1): +def get_node_count(parsed_batch: Dict, default=1): """ Determine a default node count based on the environment. @@ -135,6 +135,12 @@ def get_node_count(default=1): :param returns: (int) The number of nodes to use. """ + # Flux version check + flux_ver = get_flux_version(parsed_batch["flux exe"], no_errors=True) + major, minor, _ = map(int, flux_ver.split(".")) + if major < 1 and minor < 17: + raise ValueError("Flux version is too old. Supported versions are 0.17.0+.") + # If flux is the scheduler, we can get the size of the allocation with this try: get_size_proc = subprocess.run("flux getattr size", shell=True, capture_output=True, text=True) @@ -254,7 +260,7 @@ def batch_worker_launch( # Get the number of nodes from the environment if unset if nodes is None or nodes == "all": - nodes = get_node_count(default=1) + nodes = get_node_count(parsed_batch, default=1) elif not isinstance(nodes, int): raise TypeError("Nodes was passed into batch_worker_launch with an invalid type (likely a string other than 'all').")