Skip to content

Commit 7d8076c

Browse files
authored
Merge pull request #212 from robotpy/tunable-empty-seq
tunable: Allow empty default lists when type-hinted
2 parents 66e5424 + 20ea22e commit 7d8076c

File tree

2 files changed

+85
-5
lines changed

2 files changed

+85
-5
lines changed

Diff for: magicbot/magic_tunable.py

+42-4
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def execute(self):
8080
"_ntsubtable",
8181
"_ntwritedefault",
8282
# "__doc__",
83+
"__orig_class__",
8384
"_topic_type",
8485
"_nt",
8586
)
@@ -100,13 +101,48 @@ def __init__(
100101
self._ntwritedefault = writeDefault
101102
# self.__doc__ = doc
102103

103-
self._topic_type = _get_topic_type_for_value(self._ntdefault)
104-
if self._topic_type is None:
105-
checked_type: type = type(self._ntdefault)
104+
# Defer checks for empty sequences to check type hints.
105+
# Report errors here when we can so the error points to the tunable line.
106+
if default or not isinstance(default, collections.abc.Sequence):
107+
topic_type = _get_topic_type_for_value(default)
108+
if topic_type is None:
109+
checked_type: type = type(default)
110+
raise TypeError(
111+
f"tunable is not publishable to NetworkTables, type: {checked_type.__name__}"
112+
)
113+
self._topic_type = topic_type
114+
115+
def __set_name__(self, owner: type, name: str) -> None:
116+
type_hint: Optional[type] = None
117+
# __orig_class__ is set after __init__, check it here.
118+
orig_class = getattr(self, "__orig_class__", None)
119+
if orig_class is not None:
120+
# Accept field = tunable[Sequence[int]]([])
121+
type_hint = typing.get_args(orig_class)[0]
122+
else:
123+
type_hint = typing.get_type_hints(owner).get(name)
124+
origin = typing.get_origin(type_hint)
125+
if origin is typing.ClassVar:
126+
# Accept field: ClassVar[tunable[Sequence[int]]] = tunable([])
127+
type_hint = typing.get_args(type_hint)[0]
128+
origin = typing.get_origin(type_hint)
129+
if origin is tunable:
130+
# Accept field: tunable[Sequence[int]] = tunable([])
131+
type_hint = typing.get_args(type_hint)[0]
132+
133+
if type_hint is not None:
134+
topic_type = _get_topic_type(type_hint)
135+
else:
136+
topic_type = _get_topic_type_for_value(self._ntdefault)
137+
138+
if topic_type is None:
139+
checked_type: type = type_hint or type(self._ntdefault)
106140
raise TypeError(
107141
f"tunable is not publishable to NetworkTables, type: {checked_type.__name__}"
108142
)
109143

144+
self._topic_type = topic_type
145+
110146
@overload
111147
def __get__(self, instance: None, owner=None) -> "tunable[V]": ...
112148

@@ -218,7 +254,7 @@ class MyComponent:
218254
navx: ...
219255
220256
@feedback
221-
def get_angle(self):
257+
def get_angle(self) -> float:
222258
return self.navx.getYaw()
223259
224260
class MyRobot(magicbot.MagicRobot):
@@ -297,6 +333,8 @@ def _get_topic_type(
297333
if hasattr(inner_type, "WPIStruct"):
298334
return lambda topic: ntcore.StructArrayTopic(topic, inner_type)
299335

336+
return None
337+
300338

301339
def collect_feedbacks(component, cname: str, prefix: Optional[str] = "components"):
302340
"""

Diff for: tests/test_magicbot_tunable.py

+43-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import ClassVar, List, Sequence
2+
13
import ntcore
24
import pytest
35
from wpimath import geometry
@@ -25,6 +27,7 @@ class Component:
2527
topic = nt.getTopic(name)
2628
assert topic.getTypeString() == type_str
2729
assert topic.genericSubscribe().get().value() == value
30+
assert getattr(component, name) == value
2831

2932
for name, value in [
3033
("rotation", geometry.Rotation2d()),
@@ -33,13 +36,15 @@ class Component:
3336
assert nt.getTopic(name).getTypeString() == f"struct:{struct_type.__name__}"
3437
topic = nt.getStructTopic(name, struct_type)
3538
assert topic.subscribe(None).get() == value
39+
assert getattr(component, name) == value
3640

3741
for name, struct_type, value in [
3842
("rotations", geometry.Rotation2d, [geometry.Rotation2d()]),
3943
]:
4044
assert nt.getTopic(name).getTypeString() == f"struct:{struct_type.__name__}[]"
4145
topic = nt.getStructArrayTopic(name, struct_type)
4246
assert topic.subscribe([]).get() == value
47+
assert getattr(component, name) == value
4348

4449

4550
def test_tunable_errors():
@@ -50,7 +55,44 @@ class Component:
5055

5156

5257
def test_tunable_errors_with_empty_sequence():
53-
with pytest.raises(ValueError):
58+
with pytest.raises((RuntimeError, ValueError)):
5459

5560
class Component:
5661
empty = tunable([])
62+
63+
64+
def test_type_hinted_empty_sequences() -> None:
65+
class Component:
66+
generic_seq = tunable[Sequence[int]](())
67+
class_var_seq: ClassVar[tunable[Sequence[int]]] = tunable(())
68+
inst_seq: Sequence[int] = tunable(())
69+
70+
generic_typing_list = tunable[List[int]]([])
71+
class_var_typing_list: ClassVar[tunable[List[int]]] = tunable([])
72+
inst_typing_list: List[int] = tunable([])
73+
74+
# TODO(davo): re-enable after py3.8 is dropped
75+
# generic_list = tunable[list[int]]([])
76+
# class_var_list: ClassVar[tunable[list[int]]] = tunable([])
77+
# inst_list: list[int] = tunable([])
78+
79+
component = Component()
80+
setup_tunables(component, "test_type_hinted_sequences")
81+
NetworkTables = ntcore.NetworkTableInstance.getDefault()
82+
nt = NetworkTables.getTable("/components/test_type_hinted_sequences")
83+
84+
for name in [
85+
"generic_seq",
86+
"class_var_seq",
87+
"inst_seq",
88+
"generic_typing_list",
89+
"class_var_typing_list",
90+
"inst_typing_list",
91+
# "generic_list",
92+
# "class_var_list",
93+
# "inst_list",
94+
]:
95+
assert nt.getTopic(name).getTypeString() == "int[]"
96+
entry = nt.getEntry(name)
97+
assert entry.getIntegerArray(None) == []
98+
assert getattr(component, name) == []

0 commit comments

Comments
 (0)