Skip to content

Add type hints to graphing #4118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 26 additions & 19 deletions manim/mobject/graphing/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Any, Callable

import numpy as np
from isosurfaces import plot_isoline
Expand All @@ -21,7 +21,7 @@

from manim.typing import Point3D, Point3DLike

from manim.utils.color import YELLOW
from manim.utils.color import YELLOW, ManimColor


class ParametricFunction(VMobject, metaclass=ConvertToOpenGL):
Expand Down Expand Up @@ -111,13 +111,13 @@
discontinuities: Iterable[float] | None = None,
use_smoothing: bool = True,
use_vectorized: bool = False,
**kwargs,
):
**kwargs: Any,
) -> None:
def internal_parametric_function(t: float) -> Point3D:
"""Wrap ``function``'s output inside a NumPy array."""
return np.asarray(function(t))

self.function = internal_parametric_function

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute function, which was previously defined in subclass
FunctionGraph
.
if len(t_range) == 2:
t_range = (*t_range, 0.01)

Expand All @@ -139,11 +139,11 @@

def generate_points(self) -> Self:
if self.discontinuities is not None:
discontinuities = filter(
discontinuities_filter = filter(
lambda t: self.t_min <= t <= self.t_max,
self.discontinuities,
)
discontinuities = np.array(list(discontinuities))
discontinuities = np.array(list(discontinuities_filter))
boundary_times = np.array(
[
self.t_min,
Expand All @@ -154,15 +154,14 @@
)
boundary_times.sort()
else:
boundary_times = [self.t_min, self.t_max]
boundary_times = np.array([self.t_min, self.t_max])

for t1, t2 in zip(boundary_times[0::2], boundary_times[1::2]):
t_range = np.array(
[
*self.scaling.function(np.arange(t1, t2, self.t_step)),
self.scaling.function(t2),
],
[self.scaling.function(t) for t in np.arange(t1, t2, self.t_step)]
)
if t_range[-1] != self.scaling.function(t2):
t_range = np.append(t_range, self.scaling.function(t2))

if self.use_vectorized:
x, y, z = self.function(t_range)
Expand Down Expand Up @@ -211,19 +210,27 @@
self.add(cos_func, sin_func_1, sin_func_2)
"""

def __init__(self, function, x_range=None, color=YELLOW, **kwargs):
def __init__(
self,
function: Callable[[float], Point3D],
x_range: tuple[float, float] | None = None,
color: ManimColor = YELLOW,
**kwargs: Any,
) -> None:
if x_range is None:
x_range = np.array([-config["frame_x_radius"], config["frame_x_radius"]])
x_range = (-config["frame_x_radius"], config["frame_x_radius"])

self.x_range = x_range
self.parametric_function = lambda t: np.array([t, function(t), 0])
self.function = function
self.parametric_function: Callable[[float], Point3D] = lambda t: np.array(
[t, function(t), 0]
)
self.function = function # type: ignore[assignment]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't know how to resolve this.

super().__init__(self.parametric_function, self.x_range, color=color, **kwargs)

def get_function(self):
def get_function(self) -> Callable[[float], Point3D]:
return self.function

def get_point_from_function(self, x):
def get_point_from_function(self, x: float) -> Point3D:
return self.parametric_function(x)


Expand All @@ -236,7 +243,7 @@
min_depth: int = 5,
max_quads: int = 1500,
use_smoothing: bool = True,
**kwargs,
**kwargs: Any,
):
"""An implicit function.

Expand Down Expand Up @@ -295,7 +302,7 @@

super().__init__(**kwargs)

def generate_points(self):
def generate_points(self) -> Self:
p_min, p_max = (
np.array([self.x_range[0], self.y_range[0]]),
np.array([self.x_range[1], self.y_range[1]]),
Expand Down
121 changes: 79 additions & 42 deletions manim/mobject/graphing/probability.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


from collections.abc import Iterable, MutableSequence, Sequence
from typing import TYPE_CHECKING, Any

import numpy as np

Expand All @@ -16,20 +17,29 @@
from manim.mobject.mobject import Mobject
from manim.mobject.opengl.opengl_mobject import OpenGLMobject
from manim.mobject.svg.brace import Brace
from manim.mobject.text.tex_mobject import MathTex, Tex
from manim.mobject.types.vectorized_mobject import VGroup, VMobject
from manim.mobject.text.tex_mobject import MathTex, SingleStringMathTex, Tex
from manim.mobject.text.text_mobject import MarkupText, Text
from manim.mobject.types.vectorized_mobject import VGroup
from manim.utils.color import (
BLUE_E,
DARK_GREY,
GREEN_E,
LIGHT_GREY,
MAROON_B,
YELLOW,
ManimColor,
ParsableManimColor,
color_gradient,
)
from manim.utils.iterables import tuplify

if TYPE_CHECKING:
from typing_extensions import TypeAlias

from manim.typing import Vector3D

TextLike: TypeAlias = SingleStringMathTex | Text | MathTex | MarkupText

EPSILON = 0.0001


Expand All @@ -54,13 +64,13 @@ def construct(self):

def __init__(
self,
height=3,
width=3,
fill_color=DARK_GREY,
fill_opacity=1,
stroke_width=0.5,
stroke_color=LIGHT_GREY,
default_label_scale_val=1,
height: float = 3,
width: float = 3,
fill_color: ParsableManimColor = DARK_GREY,
fill_opacity: float = 1,
stroke_width: float = 0.5,
stroke_color: ParsableManimColor = LIGHT_GREY,
default_label_scale_val: float = 1,
):
super().__init__(
height=height,
Expand All @@ -72,7 +82,9 @@ def __init__(
)
self.default_label_scale_val = default_label_scale_val

def add_title(self, title="Sample space", buff=MED_SMALL_BUFF):
def add_title(
self, title: str = "Sample space", buff: float = MED_SMALL_BUFF
) -> None:
# TODO, should this really exist in SampleSpaceScene
title_mob = Tex(title)
if title_mob.width > self.width:
Expand All @@ -81,23 +93,31 @@ def add_title(self, title="Sample space", buff=MED_SMALL_BUFF):
self.title = title_mob
self.add(title_mob)

def add_label(self, label):
def add_label(self, label: str) -> None:
self.label = label

def complete_p_list(self, p_list):
def complete_p_list(self, p_list: list) -> list:
new_p_list = list(tuplify(p_list))
remainder = 1.0 - sum(new_p_list)
if abs(remainder) > EPSILON:
new_p_list.append(remainder)
return new_p_list

def get_division_along_dimension(self, p_list, dim, colors, vect):
def get_division_along_dimension(
self,
p_list: list,
dim: int,
reference_colors: Sequence[ParsableManimColor],
vect: Vector3D,
) -> VGroup:
p_list = self.complete_p_list(p_list)
colors = color_gradient(colors, len(p_list))
colors = color_gradient(reference_colors, len(p_list))

last_point = self.get_edge_center(-vect)
parts = VGroup()
for factor, color in zip(p_list, colors):
for factor, color in zip(
p_list, colors if isinstance(colors, list) else [colors]
):
part = SampleSpace()
part.set_fill(color, 1)
part.replace(self, stretch=True)
Expand All @@ -107,28 +127,38 @@ def get_division_along_dimension(self, p_list, dim, colors, vect):
parts.add(part)
return parts

def get_horizontal_division(self, p_list, colors=[GREEN_E, BLUE_E], vect=DOWN):
def get_horizontal_division(
self,
p_list: list,
colors: list[ManimColor] = [GREEN_E, BLUE_E],
vect: Vector3D = DOWN,
) -> VGroup:
return self.get_division_along_dimension(p_list, 1, colors, vect)

def get_vertical_division(self, p_list, colors=[MAROON_B, YELLOW], vect=RIGHT):
def get_vertical_division(
self,
p_list: list,
colors: list[ManimColor] = [MAROON_B, YELLOW],
vect: Vector3D = RIGHT,
) -> VGroup:
return self.get_division_along_dimension(p_list, 0, colors, vect)

def divide_horizontally(self, *args, **kwargs):
def divide_horizontally(self, *args: Any, **kwargs: Any) -> None:
self.horizontal_parts = self.get_horizontal_division(*args, **kwargs)
self.add(self.horizontal_parts)

def divide_vertically(self, *args, **kwargs):
def divide_vertically(self, *args: Any, **kwargs: Any) -> None:
self.vertical_parts = self.get_vertical_division(*args, **kwargs)
self.add(self.vertical_parts)

def get_subdivision_braces_and_labels(
self,
parts,
labels,
direction,
buff=SMALL_BUFF,
min_num_quads=1,
):
parts: VGroup,
labels: list[str],
direction: Vector3D,
buff: float = SMALL_BUFF,
min_num_quads: int = 1,
) -> VGroup:
label_mobs = VGroup()
braces = VGroup()
for label, part in zip(labels, parts):
Expand All @@ -151,24 +181,26 @@ def get_subdivision_braces_and_labels(
}
return VGroup(parts.braces, parts.labels)

def get_side_braces_and_labels(self, labels, direction=LEFT, **kwargs):
def get_side_braces_and_labels(
self, labels: list[str], direction: Vector3D = LEFT, **kwargs: Any
) -> VGroup:
assert hasattr(self, "horizontal_parts")
parts = self.horizontal_parts
return self.get_subdivision_braces_and_labels(
parts, labels, direction, **kwargs
)

def get_top_braces_and_labels(self, labels, **kwargs):
def get_top_braces_and_labels(self, labels: list[str], **kwargs: Any) -> VGroup:
assert hasattr(self, "vertical_parts")
parts = self.vertical_parts
return self.get_subdivision_braces_and_labels(parts, labels, UP, **kwargs)

def get_bottom_braces_and_labels(self, labels, **kwargs):
def get_bottom_braces_and_labels(self, labels: list[str], **kwargs: Any) -> VGroup:
assert hasattr(self, "vertical_parts")
parts = self.vertical_parts
return self.get_subdivision_braces_and_labels(parts, labels, DOWN, **kwargs)

def add_braces_and_labels(self):
def add_braces_and_labels(self) -> None:
for attr in "horizontal_parts", "vertical_parts":
if not hasattr(self, attr):
continue
Expand All @@ -177,11 +209,11 @@ def add_braces_and_labels(self):
if hasattr(parts, subattr):
self.add(getattr(parts, subattr))

def __getitem__(self, index):
def __getitem__(self, index: int) -> Mobject:
if hasattr(self, "horizontal_parts"):
return self.horizontal_parts[index]
return self.horizontal_parts[index] # type: ignore [no-any-return]
elif hasattr(self, "vertical_parts"):
return self.vertical_parts[index]
return self.vertical_parts[index] # type: ignore [no-any-return]
return self.split()[index]


Expand Down Expand Up @@ -243,7 +275,7 @@ def __init__(
y_range: Sequence[float] | None = None,
x_length: float | None = None,
y_length: float | None = None,
bar_colors: Iterable[str] = [
bar_colors: Iterable[ParsableManimColor] = [
"#003f5c",
"#58508d",
"#bc5090",
Expand All @@ -253,8 +285,8 @@ def __init__(
bar_width: float = 0.6,
bar_fill_opacity: float = 0.7,
bar_stroke_width: float = 3,
**kwargs,
):
**kwargs: Any,
) -> None:
if isinstance(bar_colors, str):
logger.warning(
"Passing a string to `bar_colors` has been deprecated since v0.15.2 and will be removed after v0.17.0, the parameter must be a list. "
Expand Down Expand Up @@ -311,7 +343,7 @@ def __init__(

self.y_axis.add_numbers()

def _update_colors(self):
def _update_colors(self) -> None:
"""Initialize the colors of the bars of the chart.

Sets the color of ``self.bars`` via ``self.bar_colors``.
Expand All @@ -321,13 +353,16 @@ def _update_colors(self):
"""
self.bars.set_color_by_gradient(*self.bar_colors)

def _add_x_axis_labels(self):
def _add_x_axis_labels(self) -> None:
"""Essentially :meth`:~.NumberLine.add_labels`, but differs in that
the direction of the label with respect to the x_axis changes to UP or DOWN
depending on the value.

UP for negative values and DOWN for positive values.
"""
if self.bar_names is None:
return

val_range = np.arange(
0.5, len(self.bar_names), 1
) # 0.5 shifted so that labels are centered, not on ticks
Expand Down Expand Up @@ -398,8 +433,8 @@ def get_bar_labels(
color: ParsableManimColor | None = None,
font_size: float = 24,
buff: float = MED_SMALL_BUFF,
label_constructor: type[VMobject] = Tex,
):
label_constructor: type[TextLike] = Tex,
) -> VGroup:
"""Annotates each bar with its corresponding value. Use ``self.bar_labels`` to access the
labels after creation.

Expand All @@ -408,9 +443,9 @@ def get_bar_labels(
color
The color of each label. By default ``None`` and is based on the parent's bar color.
font_size
The font size of each label.
The font size of each label. By default 24.
buff
The distance from each label to its bar. By default 0.4.
The distance from each label to its bar. By default 0.25.
label_constructor
The Mobject class to construct the labels, by default :class:`~.Tex`.

Expand Down Expand Up @@ -446,7 +481,9 @@ def construct(self):

return bar_labels

def change_bar_values(self, values: Iterable[float], update_colors: bool = True):
def change_bar_values(
self, values: MutableSequence[float], update_colors: bool = True
) -> None:
"""Updates the height of the bars of the chart.

Parameters
Expand Down
Loading
Loading