Skip to content

Add VectorNDLike type aliases #4068

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions docs/source/contributing/docs/types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,8 @@ typed as a :class:`~.Point3D`, because it represents a direction along
which to shift a :class:`~.Mobject`, not a position in space.

As a general rule, if a parameter is called ``direction`` or ``axis``,
it should be type hinted as some form of :class:`~.VectorND`.

.. warning::

This is not always true. For example, as of Manim 0.18.0, the direction
parameter of the :class:`.Vector` Mobject should be
``Point2DLike | Point3DLike``, as it can also accept ``tuple[float, float]``
and ``tuple[float, float, float]``.
it should be type hinted as some form of :class:`~.VectorND` or
:class:`~.VectorNDLike`.

Colors
------
Expand Down
5 changes: 3 additions & 2 deletions manim/mobject/geometry/arc.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def construct(self):
Point3DLike,
QuadraticSpline,
Vector3D,
Vector3DLike,
)


Expand All @@ -99,12 +100,12 @@ class TipableVMobject(VMobject, metaclass=ConvertToOpenGL):
def __init__(
self,
tip_length: float = DEFAULT_ARROW_TIP_LENGTH,
normal_vector: Vector3D = OUT,
normal_vector: Vector3DLike = OUT,
tip_style: dict = {},
**kwargs: Any,
) -> None:
self.tip_length: float = tip_length
self.normal_vector: Vector3D = normal_vector
self.normal_vector: Vector3D = np.asarray(normal_vector)
self.tip_style: dict = tip_style
super().__init__(**kwargs)

Expand Down
6 changes: 3 additions & 3 deletions manim/mobject/geometry/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from typing_extensions import Literal, Self, TypeAlias

from manim.typing import Point2DLike, Point3D, Point3DLike, Vector3D
from manim.typing import Point3D, Point3DLike, Vector2DLike, Vector3D, Vector3DLike
from manim.utils.color import ParsableManimColor

from ..matrix import Matrix # Avoid circular import
Expand Down Expand Up @@ -153,7 +153,7 @@ def _set_start_and_end_attrs(
def _pointify(
self,
mob_or_point: Mobject | Point3DLike,
direction: Vector3D | None = None,
direction: Vector3DLike | None = None,
) -> Point3D:
"""Transforms a mobject into its corresponding point. Does nothing if a point is passed.

Expand Down Expand Up @@ -716,7 +716,7 @@ def construct(self):

def __init__(
self,
direction: Point2DLike | Point3DLike = RIGHT,
direction: Vector2DLike | Vector3DLike = RIGHT,
buff: float = 0,
**kwargs: Any,
) -> None:
Expand Down
21 changes: 11 additions & 10 deletions manim/mobject/graphing/coordinate_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
Point3D,
Point3DLike,
Vector3D,
Vector3DLike,
)

LineType = TypeVar("LineType", bound=Line)
Expand Down Expand Up @@ -349,8 +350,8 @@ def _get_axis_label(
self,
label: float | str | Mobject,
axis: Mobject,
edge: Sequence[float],
direction: Sequence[float],
edge: Vector3DLike,
direction: Vector3DLike,
buff: float = SMALL_BUFF,
) -> Mobject:
"""Gets the label for an axis.
Expand Down Expand Up @@ -2414,7 +2415,7 @@ def __init__(
y_length: float | None = config.frame_height + 2.5,
z_length: float | None = config.frame_height - 1.5,
z_axis_config: dict[str, Any] | None = None,
z_normal: Vector3D = DOWN,
z_normal: Vector3DLike = DOWN,
num_axis_pieces: int = 20,
light_source: Sequence[float] = 9 * DOWN + 7 * LEFT + 10 * OUT,
# opengl stuff (?)
Expand All @@ -2440,7 +2441,7 @@ def __init__(
self.z_axis_config,
)

self.z_normal = z_normal
self.z_normal: Vector3D = np.asarray(z_normal)
self.num_axis_pieces = num_axis_pieces

self.light_source = light_source
Expand Down Expand Up @@ -2501,11 +2502,11 @@ def make_func(axis):
def get_y_axis_label(
self,
label: float | str | Mobject,
edge: Sequence[float] = UR,
direction: Sequence[float] = UR,
edge: Vector3DLike = UR,
direction: Vector3DLike = UR,
buff: float = SMALL_BUFF,
rotation: float = PI / 2,
rotation_axis: Vector3D = OUT,
rotation_axis: Vector3DLike = OUT,
**kwargs,
) -> Mobject:
"""Generate a y-axis label.
Expand Down Expand Up @@ -2551,11 +2552,11 @@ def construct(self):
def get_z_axis_label(
self,
label: float | str | Mobject,
edge: Vector3D = OUT,
direction: Vector3D = RIGHT,
edge: Vector3DLike = OUT,
direction: Vector3DLike = RIGHT,
buff: float = SMALL_BUFF,
rotation: float = PI / 2,
rotation_axis: Vector3D = RIGHT,
rotation_axis: Vector3DLike = RIGHT,
**kwargs: Any,
) -> Mobject:
"""Generate a z-axis label.
Expand Down
69 changes: 37 additions & 32 deletions manim/mobject/mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
Point3D,
Point3DLike,
Point3DLike_Array,
Vector3D,
Vector3DLike,
)

from ..animation.animation import Animation
Expand Down Expand Up @@ -1198,7 +1198,7 @@ def apply_to_family(self, func: Callable[[Mobject], None]) -> None:
for mob in self.family_members_with_points():
func(mob)

def shift(self, *vectors: Vector3D) -> Self:
def shift(self, *vectors: Vector3DLike) -> Self:
"""Shift by the given vectors.

Parameters
Expand Down Expand Up @@ -1269,14 +1269,16 @@ def construct(self):
)
return self

def rotate_about_origin(self, angle: float, axis: Vector3D = OUT, axes=[]) -> Self:
def rotate_about_origin(
self, angle: float, axis: Vector3DLike = OUT, axes=[]
) -> Self:
"""Rotates the :class:`~.Mobject` about the ORIGIN, which is at [0,0,0]."""
return self.rotate(angle, axis, about_point=ORIGIN)

def rotate(
self,
angle: float,
axis: Vector3D = OUT,
axis: Vector3DLike = OUT,
about_point: Point3DLike | None = None,
**kwargs,
) -> Self:
Expand All @@ -1287,7 +1289,7 @@ def rotate(
)
return self

def flip(self, axis: Vector3D = UP, **kwargs) -> Self:
def flip(self, axis: Vector3DLike = UP, **kwargs) -> Self:
"""Flips/Mirrors an mobject about its center.

Examples
Expand Down Expand Up @@ -1407,7 +1409,7 @@ def apply_points_function_about_point(
self,
func: MultiMappingFunction,
about_point: Point3DLike | None = None,
about_edge: Vector3D | None = None,
about_edge: Vector3DLike | None = None,
) -> Self:
if about_point is None:
if about_edge is None:
Expand Down Expand Up @@ -1437,7 +1439,7 @@ def center(self) -> Self:
return self

def align_on_border(
self, direction: Vector3D, buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER
self, direction: Vector3DLike, buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER
) -> Self:
"""Direction just needs to be a vector pointing towards side or
corner in the 2d plane.
Expand All @@ -1454,7 +1456,7 @@ def align_on_border(
return self

def to_corner(
self, corner: Vector3D = DL, buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER
self, corner: Vector3DLike = DL, buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER
) -> Self:
"""Moves this :class:`~.Mobject` to the given corner of the screen.

Expand Down Expand Up @@ -1482,7 +1484,7 @@ def construct(self):
return self.align_on_border(corner, buff)

def to_edge(
self, edge: Vector3D = LEFT, buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER
self, edge: Vector3DLike = LEFT, buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER
) -> Self:
"""Moves this :class:`~.Mobject` to the given edge of the screen,
without affecting its position in the other dimension.
Expand Down Expand Up @@ -1514,12 +1516,12 @@ def construct(self):
def next_to(
self,
mobject_or_point: Mobject | Point3DLike,
direction: Vector3D = RIGHT,
direction: Vector3DLike = RIGHT,
buff: float = DEFAULT_MOBJECT_TO_MOBJECT_BUFFER,
aligned_edge: Vector3D = ORIGIN,
aligned_edge: Vector3DLike = ORIGIN,
submobject_to_align: Mobject | None = None,
index_of_submobject_to_align: int | None = None,
coor_mask: Vector3D = np.array([1, 1, 1]),
coor_mask: Vector3DLike = np.array([1, 1, 1]),
) -> Self:
"""Move this :class:`~.Mobject` next to another's :class:`~.Mobject` or Point3D.

Expand All @@ -1541,6 +1543,9 @@ def construct(self):
self.add(d, c, s, t)

"""
direction = np.asarray(direction)
aligned_edge = np.asarray(aligned_edge)

if isinstance(mobject_or_point, Mobject):
mob = mobject_or_point
if index_of_submobject_to_align is not None:
Expand Down Expand Up @@ -1703,22 +1708,22 @@ def stretch_to_fit_depth(self, depth: float, **kwargs) -> Self:
"""Stretches the :class:`~.Mobject` to fit a depth, not keeping width/height proportional."""
return self.rescale_to_fit(depth, 2, stretch=True, **kwargs)

def set_coord(self, value, dim: int, direction: Vector3D = ORIGIN) -> Self:
def set_coord(self, value, dim: int, direction: Vector3DLike = ORIGIN) -> Self:
curr = self.get_coord(dim, direction)
shift_vect = np.zeros(self.dim)
shift_vect[dim] = value - curr
self.shift(shift_vect)
return self

def set_x(self, x: float, direction: Vector3D = ORIGIN) -> Self:
def set_x(self, x: float, direction: Vector3DLike = ORIGIN) -> Self:
"""Set x value of the center of the :class:`~.Mobject` (``int`` or ``float``)"""
return self.set_coord(x, 0, direction)

def set_y(self, y: float, direction: Vector3D = ORIGIN) -> Self:
def set_y(self, y: float, direction: Vector3DLike = ORIGIN) -> Self:
"""Set y value of the center of the :class:`~.Mobject` (``int`` or ``float``)"""
return self.set_coord(y, 1, direction)

def set_z(self, z: float, direction: Vector3D = ORIGIN) -> Self:
def set_z(self, z: float, direction: Vector3DLike = ORIGIN) -> Self:
"""Set z value of the center of the :class:`~.Mobject` (``int`` or ``float``)"""
return self.set_coord(z, 2, direction)

Expand All @@ -1731,16 +1736,16 @@ def space_out_submobjects(self, factor: float = 1.5, **kwargs) -> Self:
def move_to(
self,
point_or_mobject: Point3DLike | Mobject,
aligned_edge: Vector3D = ORIGIN,
coor_mask: Vector3D = np.array([1, 1, 1]),
aligned_edge: Vector3DLike = ORIGIN,
coor_mask: Vector3DLike = np.array([1, 1, 1]),
) -> Self:
"""Move center of the :class:`~.Mobject` to certain Point3D."""
if isinstance(point_or_mobject, Mobject):
target = point_or_mobject.get_critical_point(aligned_edge)
else:
target = point_or_mobject
point_to_align = self.get_critical_point(aligned_edge)
self.shift((target - point_to_align) * coor_mask)
self.shift((target - point_to_align) * np.asarray(coor_mask))
return self

def replace(
Expand Down Expand Up @@ -2051,7 +2056,7 @@ def get_extremum_along_dim(
else:
return np.max(values)

def get_critical_point(self, direction: Vector3D) -> Point3D:
def get_critical_point(self, direction: Vector3DLike) -> Point3D:
"""Picture a box bounding the :class:`~.Mobject`. Such a box has
9 'critical points': 4 corners, 4 edge center, the
center. This returns one of them, along the given direction.
Expand Down Expand Up @@ -2080,11 +2085,11 @@ def get_critical_point(self, direction: Vector3D) -> Point3D:

# Pseudonyms for more general get_critical_point method

def get_edge_center(self, direction: Vector3D) -> Point3D:
def get_edge_center(self, direction: Vector3DLike) -> Point3D:
"""Get edge Point3Ds for certain direction."""
return self.get_critical_point(direction)

def get_corner(self, direction: Vector3D) -> Point3D:
def get_corner(self, direction: Vector3DLike) -> Point3D:
"""Get corner Point3Ds for certain direction."""
return self.get_critical_point(direction)

Expand All @@ -2095,9 +2100,9 @@ def get_center(self) -> Point3D:
def get_center_of_mass(self) -> Point3D:
return np.apply_along_axis(np.mean, 0, self.get_all_points())

def get_boundary_point(self, direction: Vector3D) -> Point3D:
def get_boundary_point(self, direction: Vector3DLike) -> Point3D:
all_points = self.get_points_defining_boundary()
index = np.argmax(np.dot(all_points, np.array(direction).T))
index = np.argmax(np.dot(all_points, np.asarray(direction)))
return all_points[index]

def get_midpoint(self) -> Point3D:
Expand Down Expand Up @@ -2154,19 +2159,19 @@ def length_over_dim(self, dim: int) -> float:
dim,
) - self.reduce_across_dimension(min, dim)

def get_coord(self, dim: int, direction: Vector3D = ORIGIN):
def get_coord(self, dim: int, direction: Vector3DLike = ORIGIN) -> float:
"""Meant to generalize ``get_x``, ``get_y`` and ``get_z``"""
return self.get_extremum_along_dim(dim=dim, key=direction[dim])

def get_x(self, direction: Vector3D = ORIGIN) -> float:
def get_x(self, direction: Vector3DLike = ORIGIN) -> float:
"""Returns x Point3D of the center of the :class:`~.Mobject` as ``float``"""
return self.get_coord(0, direction)

def get_y(self, direction: Vector3D = ORIGIN) -> float:
def get_y(self, direction: Vector3DLike = ORIGIN) -> float:
"""Returns y Point3D of the center of the :class:`~.Mobject` as ``float``"""
return self.get_coord(1, direction)

def get_z(self, direction: Vector3D = ORIGIN) -> float:
def get_z(self, direction: Vector3DLike = ORIGIN) -> float:
"""Returns z Point3D of the center of the :class:`~.Mobject` as ``float``"""
return self.get_coord(2, direction)

Expand Down Expand Up @@ -2237,7 +2242,7 @@ def match_depth(self, mobject: Mobject, **kwargs) -> Self:
return self.match_dim_size(mobject, 2, **kwargs)

def match_coord(
self, mobject: Mobject, dim: int, direction: Vector3D = ORIGIN
self, mobject: Mobject, dim: int, direction: Vector3DLike = ORIGIN
) -> Self:
"""Match the Point3Ds with the Point3Ds of another :class:`~.Mobject`."""
return self.set_coord(
Expand All @@ -2261,7 +2266,7 @@ def match_z(self, mobject: Mobject, direction=ORIGIN) -> Self:
def align_to(
self,
mobject_or_point: Mobject | Point3DLike,
direction: Vector3D = ORIGIN,
direction: Vector3DLike = ORIGIN,
) -> Self:
"""Aligns mobject to another :class:`~.Mobject` in a certain direction.

Expand Down Expand Up @@ -2316,7 +2321,7 @@ def family_members_with_points(self) -> list[Self]:

def arrange(
self,
direction: Vector3D = RIGHT,
direction: Vector3DLike = RIGHT,
buff: float = DEFAULT_MOBJECT_TO_MOBJECT_BUFFER,
center: bool = True,
**kwargs,
Expand Down Expand Up @@ -2349,7 +2354,7 @@ def arrange_in_grid(
rows: int | None = None,
cols: int | None = None,
buff: float | tuple[float, float] = MED_SMALL_BUFF,
cell_alignment: Vector3D = ORIGIN,
cell_alignment: Vector3DLike = ORIGIN,
row_alignments: str | None = None, # "ucd"
col_alignments: str | None = None, # "lcr"
row_heights: Iterable[float | None] | None = None,
Expand Down
Loading
Loading