Skip to content

perf: try to cache inner contexts of overloads #19408

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
88 changes: 65 additions & 23 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@
"builtins.memoryview",
}

POISON_KEY: Final = (-1,)


class TooManyUnions(Exception):
"""Indicates that we need to stop splitting unions in an attempt
Expand Down Expand Up @@ -356,7 +358,12 @@ def __init__(

self._arg_infer_context_cache = None

self.overload_stack_depth = 0
self._args_cache: dict[tuple[int, ...], list[Type]] = {}

def reset(self) -> None:
assert self.overload_stack_depth == 0
assert not self._args_cache
self.resolved_type = {}

def visit_name_expr(self, e: NameExpr) -> Type:
Expand Down Expand Up @@ -1613,9 +1620,10 @@ def check_call(
object_type,
)
elif isinstance(callee, Overloaded):
return self.check_overload_call(
callee, args, arg_kinds, arg_names, callable_name, object_type, context
)
with self.overload_context():
return self.check_overload_call(
callee, args, arg_kinds, arg_names, callable_name, object_type, context
)
elif isinstance(callee, AnyType) or not self.chk.in_checked_function():
return self.check_any_type_call(args, callee)
elif isinstance(callee, UnionType):
Expand Down Expand Up @@ -1678,6 +1686,14 @@ def check_call(
else:
return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error)

@contextmanager
def overload_context(self) -> Iterator[None]:
self.overload_stack_depth += 1
yield
self.overload_stack_depth -= 1
if self.overload_stack_depth == 0:
self._args_cache.clear()

def check_callable_call(
self,
callee: CallableType,
Expand Down Expand Up @@ -1935,20 +1951,40 @@ def analyze_type_type_callee(self, item: ProperType, context: Context) -> Type:
self.msg.unsupported_type_type(item, context)
return AnyType(TypeOfAny.from_error)

def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]:
def infer_arg_types_in_empty_context(
self, args: list[Expression], *, allow_cache: bool
) -> list[Type]:
"""Infer argument expression types in an empty context.

In short, we basically recurse on each argument without considering
in what context the argument was called.
"""
# We can only use this hack locally while checking a single nested overloaded
# call. This saves a lot of rechecking, but is not generally safe. Cache is
# pruned upon leaving the outermost overload.
can_cache = (
allow_cache
and POISON_KEY not in self._args_cache
and not any(isinstance(t, TempNode) for t in args)
)
key = tuple(map(id, args))
if can_cache and key in self._args_cache:
return self._args_cache[key]
res: list[Type] = []

for arg in args:
arg_type = self.accept(arg)
if has_erased_component(arg_type):
res.append(NoneType())
else:
res.append(arg_type)
with self.msg.filter_errors(filter_errors=True, save_filtered_errors=True) as w:
for arg in args:
arg_type = self.accept(arg)
if has_erased_component(arg_type):
res.append(NoneType())
else:
res.append(arg_type)

if w.has_new_errors():
self.msg.add_errors(w.filtered_errors())
elif can_cache:
# Do not cache if new diagnostics were emitted: they may impact parent overload
self._args_cache[key] = res
return res

def infer_more_unions_for_recursive_type(self, type_context: Type) -> bool:
Expand Down Expand Up @@ -2712,7 +2748,7 @@ def check_overload_call(
"""Checks a call to an overloaded function."""
# Normalize unpacked kwargs before checking the call.
callee = callee.with_unpacked_kwargs()
arg_types = self.infer_arg_types_in_empty_context(args)
arg_types = self.infer_arg_types_in_empty_context(args, allow_cache=True)
# Step 1: Filter call targets to remove ones where the argument counts don't match
plausible_targets = self.plausible_overload_call_targets(
arg_types, arg_kinds, arg_names, callee
Expand Down Expand Up @@ -2921,17 +2957,16 @@ def infer_overload_return_type(

for typ in plausible_targets:
assert self.msg is self.chk.msg
with self.msg.filter_errors() as w:
with self.chk.local_type_map() as m:
ret_type, infer_type = self.check_call(
callee=typ,
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
callable_name=callable_name,
object_type=object_type,
)
with self.msg.filter_errors() as w, self.chk.local_type_map() as m:
ret_type, infer_type = self.check_call(
callee=typ,
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
callable_name=callable_name,
object_type=object_type,
)
is_match = not w.has_new_errors()
if is_match:
# Return early if possible; otherwise record info, so we can
Expand Down Expand Up @@ -3307,7 +3342,7 @@ def apply_generic_arguments(
)

def check_any_type_call(self, args: list[Expression], callee: Type) -> tuple[Type, Type]:
self.infer_arg_types_in_empty_context(args)
self.infer_arg_types_in_empty_context(args, allow_cache=False)
callee = get_proper_type(callee)
if isinstance(callee, AnyType):
return (
Expand Down Expand Up @@ -3478,6 +3513,7 @@ def visit_op_expr(self, e: OpExpr) -> Type:
return self.strfrm_checker.check_str_interpolation(e.left, e.right)
if isinstance(e.left, StrExpr):
return self.strfrm_checker.check_str_interpolation(e.left, e.right)

left_type = self.accept(e.left)

proper_left_type = get_proper_type(left_type)
Expand Down Expand Up @@ -4350,6 +4386,9 @@ def check_list_multiply(self, e: OpExpr) -> Type:
return result

def visit_assignment_expr(self, e: AssignmentExpr) -> Type:
if self.overload_stack_depth > 0:
# Poison cache when we encounter assignments in overloads - they affect the binder.
self._args_cache[POISON_KEY] = []
value = self.accept(e.value)
self.chk.check_assignment(e.target, e.value)
self.chk.check_final(e)
Expand Down Expand Up @@ -5405,6 +5444,9 @@ def find_typeddict_context(

def visit_lambda_expr(self, e: LambdaExpr) -> Type:
"""Type check lambda expression."""
if self.overload_stack_depth > 0:
# Poison cache when we encounter lambdas - it isn't safe to cache their types.
self._args_cache[POISON_KEY] = []
self.chk.check_default_args(e, body_is_trivial=False)
inferred_type, type_override = self.infer_lambda_type_using_context(e)
if not inferred_type:
Expand Down