Skip to content

Commit 531b9a3

Browse files
ethanbwaitefacebook-github-bot
authored andcommitted
(torchx/runopt) Allow runopt type to be builtin list[str] and dict[str,str]
Summary: As title, supports `list[str]` and `dict[str, str]` types in runopt. Updated testcases to better cover potential type casting issues. Differential Revision: D78767495
1 parent 4adf7f6 commit 531b9a3

File tree

4 files changed

+92
-35
lines changed

4 files changed

+92
-35
lines changed

torchx/runner/config.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -278,14 +278,14 @@ def dump(
278278
continue
279279

280280
# serialize list elements with `;` delimiter (consistent with torchx cli)
281-
if opt.opt_type == List[str]:
281+
if opt.is_type_list:
282282
# deal with empty or None default lists
283283
if opt.default:
284284
# pyre-ignore[6] opt.default type checked already as List[str]
285285
val = ";".join(opt.default)
286286
else:
287287
val = _NONE
288-
elif opt.opt_type == Dict[str, str]:
288+
elif opt.is_type_dict:
289289
# deal with empty or None default lists
290290
if opt.default:
291291
# pyre-ignore[16] opt.default type checked already as Dict[str, str]
@@ -536,26 +536,26 @@ def load(scheduler: str, f: TextIO, cfg: Dict[str, CfgVal]) -> None:
536536
# this also handles empty or None lists
537537
cfg[name] = None
538538
else:
539-
runopt = runopts.get(name)
539+
opt = runopts.get(name)
540540

541-
if runopt is None:
541+
if opt is None:
542542
log.warning(
543543
f"`{name} = {value}` was declared in the [{section}] section "
544544
f" of the config file but is not a runopt of `{scheduler}` scheduler."
545545
f" Remove the entry from the config file to no longer see this warning"
546546
)
547547
else:
548-
if runopt.opt_type is bool:
548+
if opt.opt_type is bool:
549549
# need to handle bool specially since str -> bool is based on
550550
# str emptiness not value (e.g. bool("False") == True)
551551
cfg[name] = config.getboolean(section, name)
552-
elif runopt.opt_type is List[str]:
552+
elif opt.is_type_list:
553553
cfg[name] = value.split(";")
554-
elif runopt.opt_type is Dict[str, str]:
554+
elif opt.is_type_dict:
555555
cfg[name] = {
556556
s.split(":", 1)[0]: s.split(":", 1)[1]
557557
for s in value.replace(",", ";").split(";")
558558
}
559559
else:
560560
# pyre-ignore[29]
561-
cfg[name] = runopt.opt_type(value)
561+
cfg[name] = opt.opt_type(value)

torchx/runner/test/config_test.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,22 +95,34 @@ def _run_opts(self) -> runopts:
9595
)
9696
opts.add(
9797
"l",
98-
type_=List[str],
98+
type_=list[str],
9999
default=["a", "b", "c"],
100100
help="a list option",
101101
)
102102
opts.add(
103-
"l_none",
103+
"l_typing",
104104
type_=List[str],
105+
default=["a", "b", "c"],
106+
help="a typing.List option",
107+
)
108+
opts.add(
109+
"l_none",
110+
type_=list[str],
105111
default=None,
106112
help="a None list option",
107113
)
108114
opts.add(
109115
"d",
110-
type_=Dict[str, str],
116+
type_=dict[str, str],
111117
default={"foo": "bar"},
112118
help="a dict option",
113119
)
120+
opts.add(
121+
"d_typing",
122+
type_=Dict[str, str],
123+
default={"foo": "bar"},
124+
help="a typing.Dict option",
125+
)
114126
opts.add(
115127
"d_none",
116128
type_=Dict[str, str],
@@ -151,6 +163,10 @@ def _run_opts(self) -> runopts:
151163
[test]
152164
s = my_default
153165
i = 100
166+
l = abc;def
167+
l_typing = ghi;jkl
168+
d = a:b,c:d
169+
d_typing = e:f,g:h
154170
"""
155171

156172
_MY_CONFIG2 = """#
@@ -387,6 +403,10 @@ def test_apply_dirs(self, _) -> None:
387403
self.assertEqual("runtime_value", cfg.get("s"))
388404
self.assertEqual(100, cfg.get("i"))
389405
self.assertEqual(1.2, cfg.get("f"))
406+
self.assertEqual({"a": "b", "c": "d"}, cfg.get("d"))
407+
self.assertEqual({"e": "f", "g": "h"}, cfg.get("d_typing"))
408+
self.assertEqual(["abc", "def"], cfg.get("l"))
409+
self.assertEqual(["ghi", "jkl"], cfg.get("l_typing"))
390410

391411
def test_dump_invalid_scheduler(self) -> None:
392412
with self.assertRaises(ValueError):
@@ -460,7 +480,7 @@ def test_dump_and_load_all_runopt_types(self, _) -> None:
460480

461481
# all runopts in the TestScheduler have defaults, just check against those
462482
for opt_name, opt in TestScheduler("test").run_opts():
463-
self.assertEqual(cfg.get(opt_name), opt.default)
483+
self.assertEqual(opt.default, cfg.get(opt_name))
464484

465485
def test_dump_and_load_all_registered_schedulers(self) -> None:
466486
# dump all the runopts for all registered schedulers

torchx/specs/api.py

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,48 @@ class runopt:
789789
is_required: bool
790790
help: str
791791

792+
@property
793+
def is_type_list(self) -> bool:
794+
return self.opt_type in (List[str], list[str])
795+
796+
@property
797+
def is_type_dict(self) -> bool:
798+
return self.opt_type in (Dict[str, str], dict[str, str])
799+
800+
def cast_to_type(self, value: str) -> CfgVal:
801+
"""Casts the given `value` (in its string representation) to the type of this run option.
802+
Below are the cast rules for each option type and value literal:
803+
804+
1. opt_type=str, value="foo" -> "foo"
805+
1. opt_type=bool, value="True"/"False" -> True/False
806+
1. opt_type=int, value="1" -> 1
807+
1. opt_type=float, value="1.1" -> 1.1
808+
1. opt_type=list[str]/List[str], value="a,b,c" or value="a;b;c" -> ["a", "b", "c"]
809+
1. opt_type=dict[str,str]/Dict[str,str],
810+
value="key1:val1,key2:val2" or value="key1:val1;key2:val2" -> {"key1": "val1", "key2": "val2"}
811+
812+
NOTE: dict parsing uses ":" as the kv separator (rather than the standard "=") because "=" is used
813+
at the top-level cfg to parse runopts (notice the plural) from the CLI. Originally torchx only supported
814+
primitives and list[str] as CfgVal but dict[str,str] was added in https://github.com/pytorch/torchx/pull/855
815+
"""
816+
817+
if self.opt_type is None:
818+
raise ValueError("runopt's opt_type cannot be `None`")
819+
elif self.opt_type == bool:
820+
return value.lower() == "true"
821+
elif self.opt_type in (List[str], list[str]):
822+
# lists may be ; or , delimited
823+
# also deal with trailing "," by removing empty strings
824+
return [v for v in value.replace(";", ",").split(",") if v]
825+
elif self.opt_type in (Dict[str, str], dict[str, str]):
826+
return {
827+
s.split(":", 1)[0]: s.split(":", 1)[1]
828+
for s in value.replace(";", ",").split(",")
829+
}
830+
else:
831+
assert self.opt_type in (str, int, float)
832+
return self.opt_type(value)
833+
792834

793835
class runopts:
794836
"""
@@ -948,27 +990,11 @@ def cfg_from_str(self, cfg_str: str) -> Dict[str, CfgVal]:
948990
949991
"""
950992

951-
def _cast_to_type(value: str, opt_type: Type[CfgVal]) -> CfgVal:
952-
if opt_type == bool:
953-
return value.lower() == "true"
954-
elif opt_type in (List[str], list[str]):
955-
# lists may be ; or , delimited
956-
# also deal with trailing "," by removing empty strings
957-
return [v for v in value.replace(";", ",").split(",") if v]
958-
elif opt_type in (Dict[str, str], dict[str, str]):
959-
return {
960-
s.split(":", 1)[0]: s.split(":", 1)[1]
961-
for s in value.replace(";", ",").split(",")
962-
}
963-
else:
964-
# pyre-ignore[19, 6] type won't be dict here as we handled it above
965-
return opt_type(value)
966-
967993
cfg: Dict[str, CfgVal] = {}
968994
for key, val in to_dict(cfg_str).items():
969-
runopt_ = self.get(key)
970-
if runopt_:
971-
cfg[key] = _cast_to_type(val, runopt_.opt_type)
995+
opt = self.get(key)
996+
if opt:
997+
cfg[key] = opt.cast_to_type(val)
972998
else:
973999
logger.warning(
9741000
f"{YELLOW_BOLD}Unknown run option passed to scheduler: {key}={val}{RESET}"
@@ -982,16 +1008,16 @@ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]:
9821008
cfg: Dict[str, CfgVal] = {}
9831009
cfg_dict = json.loads(json_repr)
9841010
for key, val in cfg_dict.items():
985-
runopt_ = self.get(key)
986-
if runopt_:
1011+
opt = self.get(key)
1012+
if opt:
9871013
# Optional runopt cfg values default their value to None,
9881014
# but use `_type` to specify their type when provided.
9891015
# Make sure not to treat None's as lists/dictionaries
9901016
if val is None:
9911017
cfg[key] = val
992-
elif runopt_.opt_type == List[str]:
1018+
elif opt.is_type_list:
9931019
cfg[key] = [str(v) for v in val]
994-
elif runopt_.opt_type == Dict[str, str]:
1020+
elif opt.is_type_dict:
9951021
cfg[key] = {str(k): str(v) for k, v in val.items()}
9961022
else:
9971023
cfg[key] = val

torchx/specs/test/api_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
RetryPolicy,
3939
Role,
4040
RoleStatus,
41+
runopt,
4142
runopts,
4243
)
4344

@@ -437,6 +438,16 @@ def test_valid_values(self) -> None:
437438
self.assertTrue(cfg.get("preemptible"))
438439
self.assertIsNone(cfg.get("unknown"))
439440

441+
def test_runopt_cast_to_type_typing_list(self) -> None:
442+
opt = runopt(default="", opt_type=List[str], is_required=False, help="help")
443+
self.assertEqual(["a", "b", "c"], opt.cast_to_type("a,b,c"))
444+
self.assertEqual(["abc", "def", "ghi"], opt.cast_to_type("abc;def;ghi"))
445+
446+
def test_runopt_cast_to_type_builtin_list(self) -> None:
447+
opt = runopt(default="", opt_type=list[str], is_required=False, help="help")
448+
self.assertEqual(["a", "b", "c"], opt.cast_to_type("a,b,c"))
449+
self.assertEqual(["abc", "def", "ghi"], opt.cast_to_type("abc;def;ghi"))
450+
440451
def test_runopts_add(self) -> None:
441452
"""
442453
tests for various add option variations

0 commit comments

Comments
 (0)