Skip to content

(torchx/runopt) Allow runopt type to be builtin list[str] and dict[str,str] #1093

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions torchx/runner/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,14 +278,14 @@ def dump(
continue

# serialize list elements with `;` delimiter (consistent with torchx cli)
if opt.opt_type == List[str]:
if opt.is_type_list_of_str:
# deal with empty or None default lists
if opt.default:
# pyre-ignore[6] opt.default type checked already as List[str]
val = ";".join(opt.default)
else:
val = _NONE
elif opt.opt_type == Dict[str, str]:
elif opt.is_type_dict_of_str:
# deal with empty or None default lists
if opt.default:
# pyre-ignore[16] opt.default type checked already as Dict[str, str]
Expand Down Expand Up @@ -536,26 +536,26 @@ def load(scheduler: str, f: TextIO, cfg: Dict[str, CfgVal]) -> None:
# this also handles empty or None lists
cfg[name] = None
else:
runopt = runopts.get(name)
opt = runopts.get(name)

if runopt is None:
if opt is None:
log.warning(
f"`{name} = {value}` was declared in the [{section}] section "
f" of the config file but is not a runopt of `{scheduler}` scheduler."
f" Remove the entry from the config file to no longer see this warning"
)
else:
if runopt.opt_type is bool:
if opt.opt_type is bool:
# need to handle bool specially since str -> bool is based on
# str emptiness not value (e.g. bool("False") == True)
cfg[name] = config.getboolean(section, name)
elif runopt.opt_type is List[str]:
elif opt.is_type_list_of_str:
cfg[name] = value.split(";")
elif runopt.opt_type is Dict[str, str]:
elif opt.is_type_dict_of_str:
cfg[name] = {
s.split(":", 1)[0]: s.split(":", 1)[1]
for s in value.replace(",", ";").split(";")
}
else:
# pyre-ignore[29]
cfg[name] = runopt.opt_type(value)
cfg[name] = opt.opt_type(value)
28 changes: 24 additions & 4 deletions torchx/runner/test/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,22 +95,34 @@ def _run_opts(self) -> runopts:
)
opts.add(
"l",
type_=List[str],
type_=list[str],
default=["a", "b", "c"],
help="a list option",
)
opts.add(
"l_none",
"l_typing",
type_=List[str],
default=["a", "b", "c"],
help="a typing.List option",
)
opts.add(
"l_none",
type_=list[str],
default=None,
help="a None list option",
)
opts.add(
"d",
type_=Dict[str, str],
type_=dict[str, str],
default={"foo": "bar"},
help="a dict option",
)
opts.add(
"d_typing",
type_=Dict[str, str],
default={"foo": "bar"},
help="a typing.Dict option",
)
opts.add(
"d_none",
type_=Dict[str, str],
Expand Down Expand Up @@ -151,6 +163,10 @@ def _run_opts(self) -> runopts:
[test]
s = my_default
i = 100
l = abc;def
l_typing = ghi;jkl
d = a:b,c:d
d_typing = e:f,g:h
"""

_MY_CONFIG2 = """#
Expand Down Expand Up @@ -387,6 +403,10 @@ def test_apply_dirs(self, _) -> None:
self.assertEqual("runtime_value", cfg.get("s"))
self.assertEqual(100, cfg.get("i"))
self.assertEqual(1.2, cfg.get("f"))
self.assertEqual({"a": "b", "c": "d"}, cfg.get("d"))
self.assertEqual({"e": "f", "g": "h"}, cfg.get("d_typing"))
self.assertEqual(["abc", "def"], cfg.get("l"))
self.assertEqual(["ghi", "jkl"], cfg.get("l_typing"))

def test_dump_invalid_scheduler(self) -> None:
with self.assertRaises(ValueError):
Expand Down Expand Up @@ -460,7 +480,7 @@ def test_dump_and_load_all_runopt_types(self, _) -> None:

# all runopts in the TestScheduler have defaults, just check against those
for opt_name, opt in TestScheduler("test").run_opts():
self.assertEqual(cfg.get(opt_name), opt.default)
self.assertEqual(opt.default, cfg.get(opt_name))

def test_dump_and_load_all_registered_schedulers(self) -> None:
# dump all the runopts for all registered schedulers
Expand Down
84 changes: 61 additions & 23 deletions torchx/specs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,60 @@ class runopt:
is_required: bool
help: str

@property
def is_type_list_of_str(self) -> bool:
"""
Checks if the option type is a list of strings.

Returns:
bool: True if the option type is either List[str] or list[str], False otherwise.
"""
return self.opt_type in (List[str], list[str])

@property
def is_type_dict_of_str(self) -> bool:
"""
Checks if the option type is a dict of string keys to string values.

Returns:
bool: True if the option type is either Dict[str, str] or dict[str, str], False otherwise.
"""
return self.opt_type in (Dict[str, str], dict[str, str])

def cast_to_type(self, value: str) -> CfgVal:
"""Casts the given `value` (in its string representation) to the type of this run option.
Below are the cast rules for each option type and value literal:

1. opt_type=str, value="foo" -> "foo"
1. opt_type=bool, value="True"/"False" -> True/False
1. opt_type=int, value="1" -> 1
1. opt_type=float, value="1.1" -> 1.1
1. opt_type=list[str]/List[str], value="a,b,c" or value="a;b;c" -> ["a", "b", "c"]
1. opt_type=dict[str,str]/Dict[str,str],
value="key1:val1,key2:val2" or value="key1:val1;key2:val2" -> {"key1": "val1", "key2": "val2"}

NOTE: dict parsing uses ":" as the kv separator (rather than the standard "=") because "=" is used
at the top-level cfg to parse runopts (notice the plural) from the CLI. Originally torchx only supported
primitives and list[str] as CfgVal but dict[str,str] was added in https://github.com/pytorch/torchx/pull/855
"""

if self.opt_type is None:
raise ValueError("runopt's opt_type cannot be `None`")
elif self.opt_type == bool:
return value.lower() == "true"
elif self.opt_type in (List[str], list[str]):
# lists may be ; or , delimited
# also deal with trailing "," by removing empty strings
return [v for v in value.replace(";", ",").split(",") if v]
elif self.opt_type in (Dict[str, str], dict[str, str]):
return {
s.split(":", 1)[0]: s.split(":", 1)[1]
for s in value.replace(";", ",").split(",")
}
else:
assert self.opt_type in (str, int, float)
return self.opt_type(value)


class runopts:
"""
Expand Down Expand Up @@ -948,27 +1002,11 @@ def cfg_from_str(self, cfg_str: str) -> Dict[str, CfgVal]:

"""

def _cast_to_type(value: str, opt_type: Type[CfgVal]) -> CfgVal:
if opt_type == bool:
return value.lower() == "true"
elif opt_type in (List[str], list[str]):
# lists may be ; or , delimited
# also deal with trailing "," by removing empty strings
return [v for v in value.replace(";", ",").split(",") if v]
elif opt_type in (Dict[str, str], dict[str, str]):
return {
s.split(":", 1)[0]: s.split(":", 1)[1]
for s in value.replace(";", ",").split(",")
}
else:
# pyre-ignore[19, 6] type won't be dict here as we handled it above
return opt_type(value)

cfg: Dict[str, CfgVal] = {}
for key, val in to_dict(cfg_str).items():
runopt_ = self.get(key)
if runopt_:
cfg[key] = _cast_to_type(val, runopt_.opt_type)
opt = self.get(key)
if opt:
cfg[key] = opt.cast_to_type(val)
else:
logger.warning(
f"{YELLOW_BOLD}Unknown run option passed to scheduler: {key}={val}{RESET}"
Expand All @@ -982,16 +1020,16 @@ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]:
cfg: Dict[str, CfgVal] = {}
cfg_dict = json.loads(json_repr)
for key, val in cfg_dict.items():
runopt_ = self.get(key)
if runopt_:
opt = self.get(key)
if opt:
# Optional runopt cfg values default their value to None,
# but use `_type` to specify their type when provided.
# Make sure not to treat None's as lists/dictionaries
if val is None:
cfg[key] = val
elif runopt_.opt_type == List[str]:
elif opt.is_type_list_of_str:
cfg[key] = [str(v) for v in val]
elif runopt_.opt_type == Dict[str, str]:
elif opt.is_type_dict_of_str:
cfg[key] = {str(k): str(v) for k, v in val.items()}
else:
cfg[key] = val
Expand Down
11 changes: 11 additions & 0 deletions torchx/specs/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
RetryPolicy,
Role,
RoleStatus,
runopt,
runopts,
)

Expand Down Expand Up @@ -437,6 +438,16 @@ def test_valid_values(self) -> None:
self.assertTrue(cfg.get("preemptible"))
self.assertIsNone(cfg.get("unknown"))

def test_runopt_cast_to_type_typing_list(self) -> None:
opt = runopt(default="", opt_type=List[str], is_required=False, help="help")
self.assertEqual(["a", "b", "c"], opt.cast_to_type("a,b,c"))
self.assertEqual(["abc", "def", "ghi"], opt.cast_to_type("abc;def;ghi"))

def test_runopt_cast_to_type_builtin_list(self) -> None:
opt = runopt(default="", opt_type=list[str], is_required=False, help="help")
self.assertEqual(["a", "b", "c"], opt.cast_to_type("a,b,c"))
self.assertEqual(["abc", "def", "ghi"], opt.cast_to_type("abc;def;ghi"))

def test_runopts_add(self) -> None:
"""
tests for various add option variations
Expand Down
Loading