Skip to content

Commit 32802c7

Browse files
committed
Make overloads respect keyword-only args
This commit resolves python#1907. Specifically, it turned out that support for non-positional args in overload was never implemented to begin with. Thankfully, it also turned out the bulk of the logic we wanted was already implemented within `mypy.subtypes.is_callable_subtype`. Rather then re-implementing that code, this commit refactors that method to support any kind of check, instead of specifically subtype checks. This, as a side-effect, ended up making some partial progress towards python#4159 -- this is because unlike the existing checks, `mypy.subtypes.is_callable_subtype` *doesn't* erase types and has better support for typevars in general. The reason this commit does not fully remove type erasure from overload checks is because the new implementation still calls `mypy.meet.is_overlapping_types` which *does* perform erasure. But fixing that seemed out-of-scope for this commit, so I stopped here.
1 parent af7e834 commit 32802c7

File tree

5 files changed

+139
-90
lines changed

5 files changed

+139
-90
lines changed

mypy/checker.py

Lines changed: 37 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
Context, Decorator, PrintStmt, BreakStmt, PassStmt, ContinueStmt,
2121
ComparisonExpr, StarExpr, EllipsisExpr, RefExpr, PromoteExpr,
2222
Import, ImportFrom, ImportAll, ImportBase,
23-
ARG_POS, ARG_STAR, LITERAL_TYPE, MDEF, GDEF,
23+
ARG_POS, ARG_STAR, ARG_NAMED, ARG_NAMED_OPT, LITERAL_TYPE, MDEF, GDEF,
2424
CONTRAVARIANT, COVARIANT, INVARIANT,
2525
)
2626
from mypy import nodes
@@ -39,7 +39,7 @@
3939
from mypy import messages
4040
from mypy.subtypes import (
4141
is_subtype, is_equivalent, is_proper_subtype, is_more_precise,
42-
restrict_subtype_away, is_subtype_ignoring_tvars, is_callable_subtype,
42+
restrict_subtype_away, is_subtype_ignoring_tvars, is_callable_compatible,
4343
unify_generic_callable, find_member
4444
)
4545
from mypy.maptype import map_instance_to_supertype
@@ -437,7 +437,8 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
437437

438438
assert isinstance(impl_type, CallableType)
439439
assert isinstance(sig1, CallableType)
440-
if not is_callable_subtype(impl_type, sig1, ignore_return=True):
440+
if not is_callable_compatible(impl_type, sig1,
441+
is_compat=is_subtype, ignore_return=True):
441442
self.msg.overloaded_signatures_arg_specific(i + 1, defn.impl)
442443
impl_type_subst = impl_type
443444
if impl_type.variables:
@@ -3530,37 +3531,45 @@ def is_unsafe_overlapping_signatures(signature: Type, other: Type) -> bool:
35303531
"""
35313532
if isinstance(signature, CallableType):
35323533
if isinstance(other, CallableType):
3533-
# TODO varargs
3534-
# TODO keyword args
3535-
# TODO erasure
35363534
# TODO allow to vary covariantly
3535+
35373536
# Check if the argument counts are overlapping.
35383537
min_args = max(signature.min_args, other.min_args)
3539-
max_args = min(len(signature.arg_types), len(other.arg_types))
3538+
max_args = min(signature.max_positional_args(), other.max_positional_args())
35403539
if min_args > max_args:
35413540
# Argument counts are not overlapping.
35423541
return False
3543-
# Signatures are overlapping iff if they are overlapping for the
3544-
# smallest common argument count.
3545-
for i in range(min_args):
3546-
t1 = signature.arg_types[i]
3547-
t2 = other.arg_types[i]
3548-
if not is_overlapping_types(t1, t2):
3549-
return False
3542+
3543+
# If one of the corresponding argument do NOT overlap,
3544+
# then the signatures are not overlapping.
3545+
if not is_callable_compatible(signature, other,
3546+
is_compat=is_overlapping_types,
3547+
ignore_return=True,
3548+
check_args_covariantly=True):
3549+
# TODO: this check (unlike the others) will erase types due to
3550+
# how is_overlapping_type is implemented. This should be
3551+
# fixed to make this check consistent with the others.
3552+
return False
3553+
35503554
# All arguments types for the smallest common argument count are
35513555
# overlapping => the signature is overlapping. The overlapping is
35523556
# safe if the return types are identical.
35533557
if is_same_type(signature.ret_type, other.ret_type):
35543558
return False
3559+
35553560
# If the first signature has more general argument types, the
35563561
# latter will never be called
35573562
if is_more_general_arg_prefix(signature, other):
35583563
return False
3564+
35593565
# Special case: all args are subtypes, and returns are subtypes
3560-
if (all(is_proper_subtype(s, o)
3561-
for (s, o) in zip(signature.arg_types, other.arg_types)) and
3562-
is_proper_subtype(signature.ret_type, other.ret_type)):
3566+
if is_callable_compatible(signature, other,
3567+
is_compat=is_proper_subtype,
3568+
check_args_covariantly=True):
35633569
return False
3570+
3571+
# If the first signature is NOT more precise then the second,
3572+
# then the overlap is unsafe.
35643573
return not is_more_precise_signature(signature, other)
35653574
return True
35663575

@@ -3569,12 +3578,11 @@ def is_more_general_arg_prefix(t: FunctionLike, s: FunctionLike) -> bool:
35693578
"""Does t have wider arguments than s?"""
35703579
# TODO should an overload with additional items be allowed to be more
35713580
# general than one with fewer items (or just one item)?
3572-
# TODO check argument kinds and otherwise make more general
35733581
if isinstance(t, CallableType):
35743582
if isinstance(s, CallableType):
3575-
t, s = unify_generic_callables(t, s)
3576-
return all(is_proper_subtype(args, argt)
3577-
for argt, args in zip(t.arg_types, s.arg_types))
3583+
return is_callable_compatible(t, s,
3584+
is_compat=is_proper_subtype,
3585+
ignore_return=True)
35783586
elif isinstance(t, FunctionLike):
35793587
if isinstance(s, FunctionLike):
35803588
if len(t.items()) == len(s.items()):
@@ -3583,29 +3591,6 @@ def is_more_general_arg_prefix(t: FunctionLike, s: FunctionLike) -> bool:
35833591
return False
35843592

35853593

3586-
def unify_generic_callables(t: CallableType,
3587-
s: CallableType) -> Tuple[CallableType,
3588-
CallableType]:
3589-
"""Make type variables in generic callables the same if possible.
3590-
3591-
Return updated callables. If we can't unify the type variables,
3592-
return the unmodified arguments.
3593-
"""
3594-
# TODO: Use this elsewhere when comparing generic callables.
3595-
if t.is_generic() and s.is_generic():
3596-
t_substitutions = {}
3597-
s_substitutions = {}
3598-
for tv1, tv2 in zip(t.variables, s.variables):
3599-
# Are these something we can unify?
3600-
if tv1.id != tv2.id and is_equivalent_type_var_def(tv1, tv2):
3601-
newdef = TypeVarDef.new_unification_variable(tv2)
3602-
t_substitutions[tv1.id] = TypeVarType(newdef)
3603-
s_substitutions[tv2.id] = TypeVarType(newdef)
3604-
return (cast(CallableType, expand_type(t, t_substitutions)),
3605-
cast(CallableType, expand_type(s, s_substitutions)))
3606-
return t, s
3607-
3608-
36093594
def is_equivalent_type_var_def(tv1: TypeVarDef, tv2: TypeVarDef) -> bool:
36103595
"""Are type variable definitions equivalent?
36113596
@@ -3621,26 +3606,22 @@ def is_equivalent_type_var_def(tv1: TypeVarDef, tv2: TypeVarDef) -> bool:
36213606

36223607

36233608
def is_same_arg_prefix(t: CallableType, s: CallableType) -> bool:
3624-
# TODO check argument kinds
3625-
return all(is_same_type(argt, args)
3626-
for argt, args in zip(t.arg_types, s.arg_types))
3609+
return is_callable_compatible(t, s,
3610+
is_compat=is_same_type,
3611+
ignore_return=True,
3612+
check_args_covariantly=True,
3613+
ignore_pos_arg_names=True)
36273614

36283615

36293616
def is_more_precise_signature(t: CallableType, s: CallableType) -> bool:
36303617
"""Is t more precise than s?
36313618
36323619
A signature t is more precise than s if all argument types and the return
36333620
type of t are more precise than the corresponding types in s.
3634-
3635-
Assume that the argument kinds and names are compatible, and that the
3636-
argument counts are overlapping.
36373621
"""
3638-
# TODO generic function types
3639-
# Only consider the common prefix of argument types.
3640-
for argt, args in zip(t.arg_types, s.arg_types):
3641-
if not is_more_precise(argt, args):
3642-
return False
3643-
return is_more_precise(t.ret_type, s.ret_type)
3622+
return is_callable_compatible(t, s,
3623+
is_compat=is_more_precise,
3624+
check_args_covariantly=True)
36443625

36453626

36463627
def infer_operator_assignment_method(typ: Type, operator: str) -> Tuple[bool, str]:

mypy/constraints.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,9 @@ def find_matching_overload_item(overloaded: Overloaded, template: CallableType)
525525
for item in items:
526526
# Return type may be indeterminate in the template, so ignore it when performing a
527527
# subtype check.
528-
if mypy.subtypes.is_callable_subtype(item, template, ignore_return=True):
528+
if mypy.subtypes.is_callable_compatible(item, template,
529+
is_compat=mypy.subtypes.is_subtype,
530+
ignore_return=True):
529531
return item
530532
# Fall back to the first item if we can't find a match. This is totally arbitrary --
531533
# maybe we should just bail out at this point.

mypy/subtypes.py

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,9 @@ def visit_type_var(self, left: TypeVarType) -> bool:
202202
def visit_callable_type(self, left: CallableType) -> bool:
203203
right = self.right
204204
if isinstance(right, CallableType):
205-
return is_callable_subtype(
205+
return is_callable_compatible(
206206
left, right,
207+
is_compat=is_subtype,
207208
ignore_pos_arg_names=self.ignore_pos_arg_names)
208209
elif isinstance(right, Overloaded):
209210
return all(is_subtype(left, item, self.check_type_parameter,
@@ -309,10 +310,12 @@ def visit_overloaded(self, left: Overloaded) -> bool:
309310
else:
310311
# If this one overlaps with the supertype in any way, but it wasn't
311312
# an exact match, then it's a potential error.
312-
if (is_callable_subtype(left_item, right_item, ignore_return=True,
313-
ignore_pos_arg_names=self.ignore_pos_arg_names) or
314-
is_callable_subtype(right_item, left_item, ignore_return=True,
315-
ignore_pos_arg_names=self.ignore_pos_arg_names)):
313+
if (is_callable_compatible(left_item, right_item,
314+
is_compat=is_subtype, ignore_return=True,
315+
ignore_pos_arg_names=self.ignore_pos_arg_names) or
316+
is_callable_compatible(right_item, left_item,
317+
is_compat=is_subtype, ignore_return=True,
318+
ignore_pos_arg_names=self.ignore_pos_arg_names)):
316319
# If this is an overload that's already been matched, there's no
317320
# problem.
318321
if left_item not in matched_overloads:
@@ -562,16 +565,22 @@ def non_method_protocol_members(tp: TypeInfo) -> List[str]:
562565
return result
563566

564567

565-
def is_callable_subtype(left: CallableType, right: CallableType,
566-
ignore_return: bool = False,
567-
ignore_pos_arg_names: bool = False,
568-
use_proper_subtype: bool = False) -> bool:
569-
"""Is left a subtype of right?"""
568+
def is_callable_compatible(left: CallableType, right: CallableType,
569+
*,
570+
is_compat: Callable[[Type, Type], bool],
571+
ignore_return: bool = False,
572+
ignore_pos_arg_names: bool = False,
573+
check_args_covariantly: bool = False) -> bool:
574+
"""Is the left compatible with the right, using the provided compatibility check?
570575
571-
if use_proper_subtype:
572-
is_compat = is_proper_subtype
573-
else:
574-
is_compat = is_subtype
576+
If 'check_args_covariantly' is set to True, check if the left's args is
577+
compatible with the right's instead of the other way around (contravariantly).
578+
579+
This function is mostly used to check if the left is a subtype of the right which
580+
is why the default is to check the args covariantly. However, it's occasionally
581+
useful to check the args using some other check, so we leave the variance
582+
configurable.
583+
"""
575584

576585
# If either function is implicitly typed, ignore positional arg names too
577586
if left.implicit or right.implicit:
@@ -604,6 +613,9 @@ def is_callable_subtype(left: CallableType, right: CallableType,
604613
if not ignore_return and not is_compat(left.ret_type, right.ret_type):
605614
return False
606615

616+
if check_args_covariantly:
617+
is_compat = flip_compat_check(is_compat)
618+
607619
if right.is_ellipsis_args:
608620
return True
609621

@@ -658,7 +670,7 @@ def is_callable_subtype(left: CallableType, right: CallableType,
658670
right_by_position = right.argument_by_position(j)
659671
assert right_by_position is not None
660672
if not are_args_compatible(left_by_position, right_by_position,
661-
ignore_pos_arg_names, use_proper_subtype):
673+
ignore_pos_arg_names, is_compat):
662674
return False
663675
j += 1
664676
continue
@@ -681,7 +693,7 @@ def is_callable_subtype(left: CallableType, right: CallableType,
681693
right_by_name = right.argument_by_name(name)
682694
assert right_by_name is not None
683695
if not are_args_compatible(left_by_name, right_by_name,
684-
ignore_pos_arg_names, use_proper_subtype):
696+
ignore_pos_arg_names, is_compat):
685697
return False
686698
continue
687699

@@ -690,7 +702,7 @@ def is_callable_subtype(left: CallableType, right: CallableType,
690702
if left_arg is None:
691703
return False
692704

693-
if not are_args_compatible(left_arg, right_arg, ignore_pos_arg_names, use_proper_subtype):
705+
if not are_args_compatible(left_arg, right_arg, ignore_pos_arg_names, is_compat):
694706
return False
695707

696708
done_with_positional = False
@@ -742,7 +754,7 @@ def are_args_compatible(
742754
left: FormalArgument,
743755
right: FormalArgument,
744756
ignore_pos_arg_names: bool,
745-
use_proper_subtype: bool) -> bool:
757+
is_compat: Callable[[Type, Type], bool]) -> bool:
746758
# If right has a specific name it wants this argument to be, left must
747759
# have the same.
748760
if right.name is not None and left.name != right.name:
@@ -753,18 +765,20 @@ def are_args_compatible(
753765
if right.pos is not None and left.pos != right.pos:
754766
return False
755767
# Left must have a more general type
756-
if use_proper_subtype:
757-
if not is_proper_subtype(right.typ, left.typ):
758-
return False
759-
else:
760-
if not is_subtype(right.typ, left.typ):
761-
return False
768+
if not is_compat(right.typ, left.typ):
769+
return False
762770
# If right's argument is optional, left's must also be.
763771
if not right.required and left.required:
764772
return False
765773
return True
766774

767775

776+
def flip_compat_check(is_compat: Callable[[Type, Type], bool]) -> Callable[[Type, Type], bool]:
777+
def new_is_compat(left: Type, right: Type) -> bool:
778+
return is_compat(right, left)
779+
return new_is_compat
780+
781+
768782
def unify_generic_callable(type: CallableType, target: CallableType,
769783
ignore_return: bool) -> Optional[CallableType]:
770784
"""Try to unify a generic callable type with another callable type.
@@ -907,10 +921,7 @@ def visit_type_var(self, left: TypeVarType) -> bool:
907921
def visit_callable_type(self, left: CallableType) -> bool:
908922
right = self.right
909923
if isinstance(right, CallableType):
910-
return is_callable_subtype(
911-
left, right,
912-
ignore_pos_arg_names=False,
913-
use_proper_subtype=True)
924+
return is_callable_compatible(left, right, is_compat=is_proper_subtype)
914925
elif isinstance(right, Overloaded):
915926
return all(is_proper_subtype(left, item)
916927
for item in right.items())

mypy/types.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,13 @@ def max_fixed_args(self) -> int:
773773
n -= 1
774774
return n
775775

776+
def max_positional_args(self) -> int:
777+
"""Returns the number of positional args.
778+
779+
This includes *arg and **kwargs but excludes keyword-only args."""
780+
blacklist = (ARG_NAMED, ARG_NAMED_OPT)
781+
return len([kind not in blacklist for kind in self.arg_kinds])
782+
776783
def corresponding_argument(self, model: FormalArgument) -> Optional[FormalArgument]:
777784
"""Return the argument in this function that corresponds to `model`"""
778785

0 commit comments

Comments
 (0)