Skip to content

Commit 6b8f27d

Browse files
committed
Add type annotations to mobject/graphing/functions.py
1 parent 398cd26 commit 6b8f27d

File tree

2 files changed

+28
-12
lines changed

2 files changed

+28
-12
lines changed

manim/mobject/graphing/functions.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717
from manim.mobject.types.vectorized_mobject import VMobject
1818

1919
if TYPE_CHECKING:
20+
from typing import Any
21+
2022
from typing_extensions import Self
2123

2224
from manim.typing import Point3D, Point3DLike
25+
from manim.utils.color import ParsableManimColor
2326

2427
from manim.utils.color import YELLOW
2528

@@ -111,7 +114,7 @@ def __init__(
111114
discontinuities: Iterable[float] | None = None,
112115
use_smoothing: bool = True,
113116
use_vectorized: bool = False,
114-
**kwargs,
117+
**kwargs: Any,
115118
):
116119
def internal_parametric_function(t: float) -> Point3D:
117120
"""Wrap ``function``'s output inside a NumPy array."""
@@ -143,13 +146,13 @@ def generate_points(self) -> Self:
143146
lambda t: self.t_min <= t <= self.t_max,
144147
self.discontinuities,
145148
)
146-
discontinuities = np.array(list(discontinuities))
149+
discontinuities_array = np.array(list(discontinuities))
147150
boundary_times = np.array(
148151
[
149152
self.t_min,
150153
self.t_max,
151-
*(discontinuities - self.dt),
152-
*(discontinuities + self.dt),
154+
*(discontinuities_array - self.dt),
155+
*(discontinuities_array + self.dt),
153156
],
154157
)
155158
boundary_times.sort()
@@ -211,19 +214,29 @@ def construct(self):
211214
self.add(cos_func, sin_func_1, sin_func_2)
212215
"""
213216

214-
def __init__(self, function, x_range=None, color=YELLOW, **kwargs):
217+
def __init__(
218+
self,
219+
function: Callable[[float], Any],
220+
x_range: np.array | None = None,
221+
color: ParsableManimColor = YELLOW,
222+
**kwargs: Any,
223+
) -> None:
215224
if x_range is None:
216225
x_range = np.array([-config["frame_x_radius"], config["frame_x_radius"]])
217226

218227
self.x_range = x_range
219-
self.parametric_function = lambda t: np.array([t, function(t), 0])
220-
self.function = function
228+
self.parametric_function: Callable[[float], np.array] = lambda t: np.array(
229+
[t, function(t), 0]
230+
)
231+
# TODO:
232+
# error: Incompatible types in assignment (expression has type "Callable[[float], Any]", variable has type "Callable[[Arg(float, 't')], Any]") [assignment]
233+
self.function = function # type: ignore[assignment]
221234
super().__init__(self.parametric_function, self.x_range, color=color, **kwargs)
222235

223-
def get_function(self):
236+
def get_function(self) -> Callable[[float], Any]:
224237
return self.function
225238

226-
def get_point_from_function(self, x):
239+
def get_point_from_function(self, x: float) -> np.array:
227240
return self.parametric_function(x)
228241

229242

@@ -236,8 +249,8 @@ def __init__(
236249
min_depth: int = 5,
237250
max_quads: int = 1500,
238251
use_smoothing: bool = True,
239-
**kwargs,
240-
):
252+
**kwargs: Any,
253+
) -> None:
241254
"""An implicit function.
242255
243256
Parameters
@@ -295,7 +308,7 @@ def construct(self):
295308

296309
super().__init__(**kwargs)
297310

298-
def generate_points(self):
311+
def generate_points(self) -> Self:
299312
p_min, p_max = (
300313
np.array([self.x_range[0], self.y_range[0]]),
301314
np.array([self.x_range[1], self.y_range[1]]),

mypy.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ ignore_errors = False
7373
[mypy-manim.mobject.graphing.scale.*]
7474
ignore_errors = False
7575

76+
[mypy-manim.mobject.graphing.functions.*]
77+
ignore_errors = False
78+
7679
[mypy-manim.mobject.graphing.number_line.*]
7780
ignore_errors = False
7881

0 commit comments

Comments
 (0)