Skip to content

Commit 66e5424

Browse files
authored
Merge pull request #208 from robotpy/typed-feedback
magicbot: Allow typed feedbacks using return type hints
2 parents ca789fc + 0c6b539 commit 66e5424

File tree

5 files changed

+269
-20
lines changed

5 files changed

+269
-20
lines changed

magicbot/magic_tunable.py

Lines changed: 103 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
1+
import collections.abc
12
import functools
23
import inspect
4+
import typing
35
import warnings
4-
from typing import Callable, Generic, Optional, TypeVar, overload
6+
from typing import Callable, Generic, Optional, Sequence, TypeVar, Union, overload
57

6-
from ntcore import NetworkTableInstance, Value
8+
import ntcore
9+
from ntcore import NetworkTableInstance
710
from ntcore.types import ValueT
811

12+
13+
class StructSerializable(typing.Protocol):
14+
"""Any type that is a wpiutil.wpistruct."""
15+
16+
WPIStruct: typing.ClassVar
17+
18+
919
T = TypeVar("T")
10-
V = TypeVar("V", bound=ValueT)
20+
V = TypeVar("V", bound=Union[ValueT, StructSerializable, Sequence[StructSerializable]])
1121

1222

1323
class tunable(Generic[V]):
@@ -50,6 +60,10 @@ def execute(self):
5060
you will want to use setup_tunables to set the object up.
5161
In normal usage, MagicRobot does this for you, so you don't
5262
have to do anything special.
63+
64+
.. versionchanged:: 2024.1.0
65+
Added support for WPILib Struct serializable types.
66+
Integer defaults now create integer topics instead of double topics.
5367
"""
5468

5569
# the way this works is we use a special class to indicate that it
@@ -66,7 +80,7 @@ def execute(self):
6680
"_ntsubtable",
6781
"_ntwritedefault",
6882
# "__doc__",
69-
"_mkv",
83+
"_topic_type",
7084
"_nt",
7185
)
7286

@@ -84,10 +98,15 @@ def __init__(
8498
self._ntdefault = default
8599
self._ntsubtable = subtable
86100
self._ntwritedefault = writeDefault
87-
d = Value.makeValue(default)
88-
self._mkv = Value.getFactoryByType(d.type())
89101
# self.__doc__ = doc
90102

103+
self._topic_type = _get_topic_type_for_value(self._ntdefault)
104+
if self._topic_type is None:
105+
checked_type: type = type(self._ntdefault)
106+
raise TypeError(
107+
f"tunable is not publishable to NetworkTables, type: {checked_type.__name__}"
108+
)
109+
91110
@overload
92111
def __get__(self, instance: None, owner=None) -> "tunable[V]": ...
93112

@@ -96,11 +115,23 @@ def __get__(self, instance, owner=None) -> V: ...
96115

97116
def __get__(self, instance, owner=None):
98117
if instance is not None:
99-
return instance._tunables[self].value
118+
return instance._tunables[self].get()
100119
return self
101120

102121
def __set__(self, instance, value: V) -> None:
103-
instance._tunables[self].setValue(self._mkv(value))
122+
instance._tunables[self].set(value)
123+
124+
125+
def _get_topic_type_for_value(value) -> Optional[Callable[[ntcore.Topic], typing.Any]]:
126+
topic_type = _get_topic_type(type(value))
127+
# bytes and str are Sequences. They must be checked before Sequence.
128+
if topic_type is None and isinstance(value, collections.abc.Sequence):
129+
if not value:
130+
raise ValueError(
131+
f"tunable default cannot be an empty sequence, got {value}"
132+
)
133+
topic_type = _get_topic_type(Sequence[type(value[0])]) # type: ignore [misc]
134+
return topic_type
104135

105136

106137
def setup_tunables(component, cname: str, prefix: Optional[str] = "components") -> None:
@@ -124,7 +155,7 @@ def setup_tunables(component, cname: str, prefix: Optional[str] = "components")
124155

125156
NetworkTables = NetworkTableInstance.getDefault()
126157

127-
tunables = {}
158+
tunables: dict[tunable, ntcore.Topic] = {}
128159

129160
for n in dir(cls):
130161
if n.startswith("_"):
@@ -139,11 +170,12 @@ def setup_tunables(component, cname: str, prefix: Optional[str] = "components")
139170
else:
140171
key = "%s/%s" % (prefix, n)
141172

142-
ntvalue = NetworkTables.getEntry(key)
173+
topic = prop._topic_type(NetworkTables.getTopic(key))
174+
ntvalue = topic.getEntry(prop._ntdefault)
143175
if prop._ntwritedefault:
144-
ntvalue.setValue(prop._ntdefault)
176+
ntvalue.set(prop._ntdefault)
145177
else:
146-
ntvalue.setDefaultValue(prop._ntdefault)
178+
ntvalue.setDefault(prop._ntdefault)
147179
tunables[prop] = ntvalue
148180

149181
component._tunables = tunables
@@ -201,6 +233,10 @@ class MyRobot(magicbot.MagicRobot):
201233
especially if you wish to monitor WPILib objects.
202234
203235
.. versionadded:: 2018.1.0
236+
237+
.. versionchanged:: 2024.1.0
238+
WPILib Struct serializable types are supported when the return type is type hinted.
239+
An ``int`` return type hint now creates an integer topic.
204240
"""
205241
if f is None:
206242
return functools.partial(feedback, key=key)
@@ -222,10 +258,50 @@ class MyRobot(magicbot.MagicRobot):
222258
return f
223259

224260

261+
_topic_types = {
262+
bool: ntcore.BooleanTopic,
263+
int: ntcore.IntegerTopic,
264+
float: ntcore.DoubleTopic,
265+
str: ntcore.StringTopic,
266+
bytes: ntcore.RawTopic,
267+
}
268+
_array_topic_types = {
269+
bool: ntcore.BooleanArrayTopic,
270+
int: ntcore.IntegerArrayTopic,
271+
float: ntcore.DoubleArrayTopic,
272+
str: ntcore.StringArrayTopic,
273+
}
274+
275+
276+
def _get_topic_type(
277+
return_annotation,
278+
) -> Optional[Callable[[ntcore.Topic], typing.Any]]:
279+
if return_annotation in _topic_types:
280+
return _topic_types[return_annotation]
281+
if hasattr(return_annotation, "WPIStruct"):
282+
return lambda topic: ntcore.StructTopic(topic, return_annotation)
283+
284+
# Check for PEP 484 generic types
285+
origin = getattr(return_annotation, "__origin__", None)
286+
args = typing.get_args(return_annotation)
287+
if origin in (list, tuple, collections.abc.Sequence) and args:
288+
# Ensure tuples are tuple[T, ...] or homogenous
289+
if origin is tuple and not (
290+
(len(args) == 2 and args[1] is Ellipsis) or len(set(args)) == 1
291+
):
292+
return None
293+
294+
inner_type = args[0]
295+
if inner_type in _array_topic_types:
296+
return _array_topic_types[inner_type]
297+
if hasattr(inner_type, "WPIStruct"):
298+
return lambda topic: ntcore.StructArrayTopic(topic, inner_type)
299+
300+
225301
def collect_feedbacks(component, cname: str, prefix: Optional[str] = "components"):
226302
"""
227303
Finds all methods decorated with :func:`feedback` on an object
228-
and returns a list of 2-tuples (method, NetworkTables entry).
304+
and returns a list of 2-tuples (method, NetworkTables entry setter).
229305
230306
.. note:: This isn't useful for normal use.
231307
"""
@@ -246,7 +322,19 @@ def collect_feedbacks(component, cname: str, prefix: Optional[str] = "components
246322
else:
247323
key = name
248324

249-
entry = nt.getEntry(key)
250-
feedbacks.append((method, entry))
325+
return_annotation = typing.get_type_hints(method).get("return", None)
326+
if return_annotation is not None:
327+
topic_type = _get_topic_type(return_annotation)
328+
else:
329+
topic_type = None
330+
331+
if topic_type is None:
332+
entry = nt.getEntry(key)
333+
setter = entry.setValue
334+
else:
335+
publisher = topic_type(nt.getTopic(key)).publish()
336+
setter = publisher.set
337+
338+
feedbacks.append((method, setter))
251339

252340
return feedbacks

magicbot/magicrobot.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import hal
1111
import wpilib
1212

13-
from ntcore import NetworkTableInstance, NetworkTableEntry
13+
from ntcore import NetworkTableInstance
1414

1515
# from wpilib.shuffleboard import Shuffleboard
1616

@@ -73,7 +73,7 @@ def __init__(self) -> None:
7373
self.__last_error_report = -10
7474

7575
self._components: List[Tuple[str, Any]] = []
76-
self._feedbacks: List[Tuple[Callable[[], Any], NetworkTableEntry]] = []
76+
self._feedbacks: List[Tuple[Callable[[], Any], Callable[[Any], Any]]] = []
7777
self._reset_components: List[Tuple[Dict[str, Any], Any]] = []
7878

7979
self.__done = False
@@ -720,13 +720,13 @@ def _do_periodics(self) -> None:
720720
"""Run periodic methods which run in every mode."""
721721
watchdog = self.watchdog
722722

723-
for method, entry in self._feedbacks:
723+
for method, setter in self._feedbacks:
724724
try:
725725
value = method()
726726
except:
727727
self.onException()
728728
else:
729-
entry.setValue(value)
729+
setter(value)
730730

731731
watchdog.addEpoch("@magicbot.feedback")
732732

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ zip_safe = False
2424
include_package_data = True
2525
packages = find:
2626
install_requires =
27-
wpilib>=2024.1.1.0,<2025
27+
wpilib>=2024.3.2.1,<2025
2828
setup_requires =
2929
setuptools_scm > 6
3030
python_requires = >=3.8

tests/test_magicbot_feedback.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from typing import Sequence, Tuple
2+
3+
import ntcore
4+
from wpimath import geometry
5+
6+
import magicbot
7+
8+
9+
class BasicComponent:
10+
@magicbot.feedback
11+
def get_number(self):
12+
return 0
13+
14+
@magicbot.feedback
15+
def get_ints(self):
16+
return (0,)
17+
18+
@magicbot.feedback
19+
def get_floats(self):
20+
return (0.0, 0)
21+
22+
def execute(self):
23+
pass
24+
25+
26+
class TypeHintedComponent:
27+
@magicbot.feedback
28+
def get_rotation(self) -> geometry.Rotation2d:
29+
return geometry.Rotation2d()
30+
31+
@magicbot.feedback
32+
def get_rotation_array(self) -> Sequence[geometry.Rotation2d]:
33+
return [geometry.Rotation2d()]
34+
35+
@magicbot.feedback
36+
def get_rotation_2_tuple(self) -> Tuple[geometry.Rotation2d, geometry.Rotation2d]:
37+
return (geometry.Rotation2d(), geometry.Rotation2d())
38+
39+
@magicbot.feedback
40+
def get_int(self) -> int:
41+
return 0
42+
43+
@magicbot.feedback
44+
def get_float(self) -> float:
45+
return 0.5
46+
47+
@magicbot.feedback
48+
def get_ints(self) -> Sequence[int]:
49+
return (0,)
50+
51+
@magicbot.feedback
52+
def get_empty_strings(self) -> Sequence[str]:
53+
return ()
54+
55+
def execute(self):
56+
pass
57+
58+
59+
class Robot(magicbot.MagicRobot):
60+
basic: BasicComponent
61+
type_hinted: TypeHintedComponent
62+
63+
def createObjects(self):
64+
pass
65+
66+
67+
def test_feedbacks_with_type_hints():
68+
robot = Robot()
69+
robot.robotInit()
70+
nt = ntcore.NetworkTableInstance.getDefault().getTable("components")
71+
72+
robot._do_periodics()
73+
74+
for name, type_str, value in (
75+
("basic/number", "double", 0.0),
76+
("basic/ints", "int[]", [0]),
77+
("basic/floats", "double[]", [0.0, 0.0]),
78+
("type_hinted/int", "int", 0),
79+
("type_hinted/float", "double", 0.5),
80+
("type_hinted/ints", "int[]", [0]),
81+
("type_hinted/empty_strings", "string[]", []),
82+
):
83+
topic = nt.getTopic(name)
84+
assert topic.getTypeString() == type_str
85+
assert topic.genericSubscribe().get().value() == value
86+
87+
for name, value in [
88+
("type_hinted/rotation", geometry.Rotation2d()),
89+
]:
90+
struct_type = type(value)
91+
assert nt.getTopic(name).getTypeString() == f"struct:{struct_type.__name__}"
92+
topic = nt.getStructTopic(name, struct_type)
93+
assert topic.subscribe(None).get() == value
94+
95+
for name, struct_type, value in (
96+
("type_hinted/rotation_array", geometry.Rotation2d, [geometry.Rotation2d()]),
97+
(
98+
"type_hinted/rotation_2_tuple",
99+
geometry.Rotation2d,
100+
[geometry.Rotation2d(), geometry.Rotation2d()],
101+
),
102+
):
103+
assert nt.getTopic(name).getTypeString() == f"struct:{struct_type.__name__}[]"
104+
topic = nt.getStructArrayTopic(name, struct_type)
105+
assert topic.subscribe([]).get() == value

0 commit comments

Comments
 (0)