Skip to content

Commit

Permalink
Add type annotations to mobject/graphing/functions.py
Browse files Browse the repository at this point in the history
  • Loading branch information
henrikmidtiby committed Jan 20, 2025
1 parent 398cd26 commit 6b8f27d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 12 deletions.
37 changes: 25 additions & 12 deletions manim/mobject/graphing/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
from manim.mobject.types.vectorized_mobject import VMobject

if TYPE_CHECKING:
from typing import Any

from typing_extensions import Self

from manim.typing import Point3D, Point3DLike
from manim.utils.color import ParsableManimColor

from manim.utils.color import YELLOW

Expand Down Expand Up @@ -111,7 +114,7 @@ def __init__(
discontinuities: Iterable[float] | None = None,
use_smoothing: bool = True,
use_vectorized: bool = False,
**kwargs,
**kwargs: Any,
):
def internal_parametric_function(t: float) -> Point3D:
"""Wrap ``function``'s output inside a NumPy array."""

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute function, which was previously defined in subclass
FunctionGraph
.
Expand Down Expand Up @@ -143,13 +146,13 @@ def generate_points(self) -> Self:
lambda t: self.t_min <= t <= self.t_max,
self.discontinuities,
)
discontinuities = np.array(list(discontinuities))
discontinuities_array = np.array(list(discontinuities))
boundary_times = np.array(
[
self.t_min,
self.t_max,
*(discontinuities - self.dt),
*(discontinuities + self.dt),
*(discontinuities_array - self.dt),
*(discontinuities_array + self.dt),
],
)
boundary_times.sort()
Expand Down Expand Up @@ -211,19 +214,29 @@ def construct(self):
self.add(cos_func, sin_func_1, sin_func_2)
"""

def __init__(self, function, x_range=None, color=YELLOW, **kwargs):
def __init__(
self,
function: Callable[[float], Any],
x_range: np.array | None = None,
color: ParsableManimColor = YELLOW,
**kwargs: Any,
) -> None:
if x_range is None:
x_range = np.array([-config["frame_x_radius"], config["frame_x_radius"]])

self.x_range = x_range
self.parametric_function = lambda t: np.array([t, function(t), 0])
self.function = function
self.parametric_function: Callable[[float], np.array] = lambda t: np.array(
[t, function(t), 0]
)
# TODO:
# error: Incompatible types in assignment (expression has type "Callable[[float], Any]", variable has type "Callable[[Arg(float, 't')], Any]") [assignment]
self.function = function # type: ignore[assignment]
super().__init__(self.parametric_function, self.x_range, color=color, **kwargs)

def get_function(self):
def get_function(self) -> Callable[[float], Any]:
return self.function

def get_point_from_function(self, x):
def get_point_from_function(self, x: float) -> np.array:
return self.parametric_function(x)


Expand All @@ -236,8 +249,8 @@ def __init__(
min_depth: int = 5,
max_quads: int = 1500,
use_smoothing: bool = True,
**kwargs,
):
**kwargs: Any,
) -> None:
"""An implicit function.
Parameters
Expand Down Expand Up @@ -295,7 +308,7 @@ def construct(self):

super().__init__(**kwargs)

def generate_points(self):
def generate_points(self) -> Self:
p_min, p_max = (
np.array([self.x_range[0], self.y_range[0]]),
np.array([self.x_range[1], self.y_range[1]]),
Expand Down
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ ignore_errors = False
[mypy-manim.mobject.graphing.scale.*]
ignore_errors = False

[mypy-manim.mobject.graphing.functions.*]
ignore_errors = False

[mypy-manim.mobject.graphing.number_line.*]
ignore_errors = False

Expand Down

0 comments on commit 6b8f27d

Please sign in to comment.