Skip to content

Commit f6ec750

Browse files
committed
Added basic test code for the profiled PID subsystem.
1 parent 2921008 commit f6ec750

File tree

1 file changed

+119
-0
lines changed

1 file changed

+119
-0
lines changed

tests/test_profiledpidsubsystem.py

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
from types import MethodType
2+
from typing import Any
3+
4+
import pytest
5+
from wpimath.controller import ProfiledPIDController, ProfiledPIDControllerRadians
6+
from wpimath.trajectory import TrapezoidProfile, TrapezoidProfileRadians
7+
8+
from commands2 import ProfiledPIDSubsystem
9+
10+
MAX_VELOCITY = 30 # Radians per second
11+
MAX_ACCELERATION = 500 # Radians per sec squared
12+
PID_KP = 50
13+
14+
15+
class EvalSubsystem(ProfiledPIDSubsystem):
16+
def __init__(self, controller, state_factory):
17+
self._state_factory = state_factory
18+
super().__init__(controller, 0)
19+
20+
21+
def simple_use_output(self, output: float, setpoint: Any):
22+
"""A simple _useOutput method that saves the current state of the controller."""
23+
self._output = output
24+
self._setpoint = setpoint
25+
26+
27+
def simple_get_measurement(self) -> float:
28+
"""A simple _getMeasurement method that returns zero (frozen or stuck plant)."""
29+
return 0.0
30+
31+
32+
controller_types = [
33+
(
34+
ProfiledPIDControllerRadians,
35+
TrapezoidProfileRadians.Constraints,
36+
TrapezoidProfileRadians.State,
37+
),
38+
(ProfiledPIDController, TrapezoidProfile.Constraints, TrapezoidProfile.State),
39+
]
40+
controller_ids = ["radians", "dimensionless"]
41+
42+
43+
@pytest.fixture(params=controller_types, ids=controller_ids)
44+
def subsystem(request):
45+
"""
46+
Fixture that returns an EvalSubsystem object for each type of controller.
47+
"""
48+
controller, profile_factory, state_factory = request.param
49+
profile = profile_factory(MAX_VELOCITY, MAX_ACCELERATION)
50+
pid = controller(PID_KP, 0, 0, profile)
51+
return EvalSubsystem(pid, state_factory)
52+
53+
54+
def test_profiled_pid_subsystem_init(subsystem):
55+
"""
56+
Verify that the ProfiledPIDSubsystem can be initialized using
57+
all supported profiled PID controller / trapezoid profile types.
58+
"""
59+
assert isinstance(subsystem, EvalSubsystem)
60+
61+
62+
def test_profiled_pid_subsystem_not_implemented_get_measurement(subsystem):
63+
"""
64+
Verify that the ProfiledPIDSubsystem._getMeasurement method
65+
raises NotImplementedError.
66+
"""
67+
with pytest.raises(NotImplementedError):
68+
subsystem._getMeasurement()
69+
70+
71+
def test_profiled_pid_subsystem_not_implemented_use_output(subsystem):
72+
"""
73+
Verify that the ProfiledPIDSubsystem._useOutput method raises
74+
NotImplementedError.
75+
"""
76+
with pytest.raises(NotImplementedError):
77+
subsystem._useOutput(0, subsystem._state_factory())
78+
79+
80+
@pytest.mark.parametrize("use_float", [True, False])
81+
def test_profiled_pid_subsystem_set_goal(subsystem, use_float):
82+
"""
83+
Verify that the ProfiledPIDSubsystem.setGoal method sets the goal.
84+
"""
85+
if use_float:
86+
subsystem.setGoal(1.0)
87+
assert subsystem.getController().getGoal().position == 1.0
88+
assert subsystem.getController().getGoal().velocity == 0.0
89+
else:
90+
subsystem.setGoal(subsystem._state_factory(1.0, 2.0))
91+
assert subsystem.getController().getGoal().position == 1.0
92+
assert subsystem.getController().getGoal().velocity == 2.0
93+
94+
95+
def test_profiled_pid_subsystem_enable_subsystem(subsystem):
96+
"""
97+
Verify the subsystem can be enabled.
98+
"""
99+
# Dynamically add _useOutput and _getMeasurement methods so the
100+
# system can be enabled
101+
setattr(subsystem, "_useOutput", MethodType(simple_use_output, subsystem))
102+
setattr(subsystem, "_getMeasurement", MethodType(simple_get_measurement, subsystem))
103+
# Enable the subsystem
104+
subsystem.enable()
105+
assert subsystem.isEnabled()
106+
107+
108+
def test_profiled_pid_subsystem_disable_subsystem(subsystem):
109+
"""
110+
Verify the subsystem can be disabled.
111+
"""
112+
# Dynamically add _useOutput and _getMeasurement methods so the
113+
# system can be enabled
114+
setattr(subsystem, "_useOutput", MethodType(simple_use_output, subsystem))
115+
setattr(subsystem, "_getMeasurement", MethodType(simple_get_measurement, subsystem))
116+
# Enable and then disable the subsystem
117+
subsystem.enable()
118+
subsystem.disable()
119+
assert not subsystem.isEnabled()

0 commit comments

Comments
 (0)