Skip to content

Commit 72a0414

Browse files
committed
Fix: Type Errors in mobject/graphing/functions.py
1 parent a6990bd commit 72a0414

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

manim/mobject/graphing/functions.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def construct(self):
104104

105105
def __init__(
106106
self,
107-
function: Callable[[float], Point3DLike],
107+
function: Callable[[float | np.ndarray], Point3DLike],
108108
t_range: tuple[float, float] | tuple[float, float, float] = (0, 1),
109109
scaling: _ScaleBase = LinearBase(),
110110
dt: float = 1e-8,
@@ -113,7 +113,7 @@ def __init__(
113113
use_vectorized: bool = False,
114114
**kwargs,
115115
):
116-
def internal_parametric_function(t: float) -> Point3D:
116+
def internal_parametric_function(t: float | np.ndarray) -> Point3D:
117117
"""Wrap ``function``'s output inside a NumPy array."""
118118
return np.asarray(function(t))
119119

@@ -143,23 +143,23 @@ def generate_points(self) -> Self:
143143
lambda t: self.t_min <= t <= self.t_max,
144144
self.discontinuities,
145145
)
146-
discontinuities = np.array(list(discontinuities))
146+
discontinuities_array = np.array(list(discontinuities))
147147
boundary_times = np.array(
148148
[
149149
self.t_min,
150150
self.t_max,
151-
*(discontinuities - self.dt),
152-
*(discontinuities + self.dt),
151+
*(discontinuities_array + self.dt),
152+
*(discontinuities_array - self.dt),
153153
],
154154
)
155155
boundary_times.sort()
156156
else:
157-
boundary_times = [self.t_min, self.t_max]
157+
boundary_times = np.array([self.t_min, self.t_max])
158158

159159
for t1, t2 in zip(boundary_times[0::2], boundary_times[1::2]):
160160
t_range = np.array(
161161
[
162-
*self.scaling.function(np.arange(t1, t2, self.t_step)),
162+
*(self.scaling.function(t) for t in np.arange(t1, t2, self.t_step)),
163163
self.scaling.function(t2),
164164
],
165165
)

manim/mobject/graphing/scale.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class _ScaleBase:
2626
def __init__(self, custom_labels: bool = False):
2727
self.custom_labels = custom_labels
2828

29-
def function(self, value: float) -> float:
29+
def function(self, value: float | np.ndarray) -> float:
3030
"""The function that will be used to scale the values.
3131
3232
Parameters

0 commit comments

Comments
 (0)