Skip to content

Commit f70f808

Browse files
committed
Make overloads respect keyword-only args
This commit resolves #1907. Specifically, it turned out that support for non-positional args in overload was never implemented to begin with. Thankfully, it also turned out the bulk of the logic we wanted was already implemented within `mypy.subtypes.is_callable_subtype`. Rather then re-implementing that code, this commit refactors that method to support any kind of check, instead of specifically subtype checks. This, as a side-effect, ended up making some partial progress towards #4159 -- this is because unlike the existing checks, `mypy.subtypes.is_callable_subtype` *doesn't* erase types and has better support for typevars in general. The reason this commit does not fully remove type erasure from overload checks is because the new implementation still calls `mypy.meet.is_overlapping_types` which *does* perform erasure. But fixing that seemed out-of-scope for this commit, so I stopped here.
1 parent a2fd962 commit f70f808

File tree

5 files changed

+139
-90
lines changed

5 files changed

+139
-90
lines changed

mypy/checker.py

Lines changed: 37 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
Context, Decorator, PrintStmt, BreakStmt, PassStmt, ContinueStmt,
2121
ComparisonExpr, StarExpr, EllipsisExpr, RefExpr, PromoteExpr,
2222
Import, ImportFrom, ImportAll, ImportBase,
23-
ARG_POS, ARG_STAR, LITERAL_TYPE, MDEF, GDEF,
23+
ARG_POS, ARG_STAR, ARG_NAMED, ARG_NAMED_OPT, LITERAL_TYPE, MDEF, GDEF,
2424
CONTRAVARIANT, COVARIANT, INVARIANT,
2525
)
2626
from mypy import nodes
@@ -39,7 +39,7 @@
3939
from mypy import messages
4040
from mypy.subtypes import (
4141
is_subtype, is_equivalent, is_proper_subtype, is_more_precise,
42-
restrict_subtype_away, is_subtype_ignoring_tvars, is_callable_subtype,
42+
restrict_subtype_away, is_subtype_ignoring_tvars, is_callable_compatible,
4343
unify_generic_callable, find_member
4444
)
4545
from mypy.maptype import map_instance_to_supertype
@@ -437,7 +437,8 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
437437

438438
assert isinstance(impl_type, CallableType)
439439
assert isinstance(sig1, CallableType)
440-
if not is_callable_subtype(impl_type, sig1, ignore_return=True):
440+
if not is_callable_compatible(impl_type, sig1,
441+
is_compat=is_subtype, ignore_return=True):
441442
self.msg.overloaded_signatures_arg_specific(i + 1, defn.impl)
442443
impl_type_subst = impl_type
443444
if impl_type.variables:
@@ -3524,37 +3525,45 @@ def is_unsafe_overlapping_signatures(signature: Type, other: Type) -> bool:
35243525
"""
35253526
if isinstance(signature, CallableType):
35263527
if isinstance(other, CallableType):
3527-
# TODO varargs
3528-
# TODO keyword args
3529-
# TODO erasure
35303528
# TODO allow to vary covariantly
3529+
35313530
# Check if the argument counts are overlapping.
35323531
min_args = max(signature.min_args, other.min_args)
3533-
max_args = min(len(signature.arg_types), len(other.arg_types))
3532+
max_args = min(signature.max_positional_args(), other.max_positional_args())
35343533
if min_args > max_args:
35353534
# Argument counts are not overlapping.
35363535
return False
3537-
# Signatures are overlapping iff if they are overlapping for the
3538-
# smallest common argument count.
3539-
for i in range(min_args):
3540-
t1 = signature.arg_types[i]
3541-
t2 = other.arg_types[i]
3542-
if not is_overlapping_types(t1, t2):
3543-
return False
3536+
3537+
# If one of the corresponding argument do NOT overlap,
3538+
# then the signatures are not overlapping.
3539+
if not is_callable_compatible(signature, other,
3540+
is_compat=is_overlapping_types,
3541+
ignore_return=True,
3542+
check_args_covariantly=True):
3543+
# TODO: this check (unlike the others) will erase types due to
3544+
# how is_overlapping_type is implemented. This should be
3545+
# fixed to make this check consistent with the others.
3546+
return False
3547+
35443548
# All arguments types for the smallest common argument count are
35453549
# overlapping => the signature is overlapping. The overlapping is
35463550
# safe if the return types are identical.
35473551
if is_same_type(signature.ret_type, other.ret_type):
35483552
return False
3553+
35493554
# If the first signature has more general argument types, the
35503555
# latter will never be called
35513556
if is_more_general_arg_prefix(signature, other):
35523557
return False
3558+
35533559
# Special case: all args are subtypes, and returns are subtypes
3554-
if (all(is_proper_subtype(s, o)
3555-
for (s, o) in zip(signature.arg_types, other.arg_types)) and
3556-
is_proper_subtype(signature.ret_type, other.ret_type)):
3560+
if is_callable_compatible(signature, other,
3561+
is_compat=is_proper_subtype,
3562+
check_args_covariantly=True):
35573563
return False
3564+
3565+
# If the first signature is NOT more precise then the second,
3566+
# then the overlap is unsafe.
35583567
return not is_more_precise_signature(signature, other)
35593568
return True
35603569

@@ -3563,12 +3572,11 @@ def is_more_general_arg_prefix(t: FunctionLike, s: FunctionLike) -> bool:
35633572
"""Does t have wider arguments than s?"""
35643573
# TODO should an overload with additional items be allowed to be more
35653574
# general than one with fewer items (or just one item)?
3566-
# TODO check argument kinds and otherwise make more general
35673575
if isinstance(t, CallableType):
35683576
if isinstance(s, CallableType):
3569-
t, s = unify_generic_callables(t, s)
3570-
return all(is_proper_subtype(args, argt)
3571-
for argt, args in zip(t.arg_types, s.arg_types))
3577+
return is_callable_compatible(t, s,
3578+
is_compat=is_proper_subtype,
3579+
ignore_return=True)
35723580
elif isinstance(t, FunctionLike):
35733581
if isinstance(s, FunctionLike):
35743582
if len(t.items()) == len(s.items()):
@@ -3577,29 +3585,6 @@ def is_more_general_arg_prefix(t: FunctionLike, s: FunctionLike) -> bool:
35773585
return False
35783586

35793587

3580-
def unify_generic_callables(t: CallableType,
3581-
s: CallableType) -> Tuple[CallableType,
3582-
CallableType]:
3583-
"""Make type variables in generic callables the same if possible.
3584-
3585-
Return updated callables. If we can't unify the type variables,
3586-
return the unmodified arguments.
3587-
"""
3588-
# TODO: Use this elsewhere when comparing generic callables.
3589-
if t.is_generic() and s.is_generic():
3590-
t_substitutions = {}
3591-
s_substitutions = {}
3592-
for tv1, tv2 in zip(t.variables, s.variables):
3593-
# Are these something we can unify?
3594-
if tv1.id != tv2.id and is_equivalent_type_var_def(tv1, tv2):
3595-
newdef = TypeVarDef.new_unification_variable(tv2)
3596-
t_substitutions[tv1.id] = TypeVarType(newdef)
3597-
s_substitutions[tv2.id] = TypeVarType(newdef)
3598-
return (cast(CallableType, expand_type(t, t_substitutions)),
3599-
cast(CallableType, expand_type(s, s_substitutions)))
3600-
return t, s
3601-
3602-
36033588
def is_equivalent_type_var_def(tv1: TypeVarDef, tv2: TypeVarDef) -> bool:
36043589
"""Are type variable definitions equivalent?
36053590
@@ -3615,26 +3600,22 @@ def is_equivalent_type_var_def(tv1: TypeVarDef, tv2: TypeVarDef) -> bool:
36153600

36163601

36173602
def is_same_arg_prefix(t: CallableType, s: CallableType) -> bool:
3618-
# TODO check argument kinds
3619-
return all(is_same_type(argt, args)
3620-
for argt, args in zip(t.arg_types, s.arg_types))
3603+
return is_callable_compatible(t, s,
3604+
is_compat=is_same_type,
3605+
ignore_return=True,
3606+
check_args_covariantly=True,
3607+
ignore_pos_arg_names=True)
36213608

36223609

36233610
def is_more_precise_signature(t: CallableType, s: CallableType) -> bool:
36243611
"""Is t more precise than s?
36253612
36263613
A signature t is more precise than s if all argument types and the return
36273614
type of t are more precise than the corresponding types in s.
3628-
3629-
Assume that the argument kinds and names are compatible, and that the
3630-
argument counts are overlapping.
36313615
"""
3632-
# TODO generic function types
3633-
# Only consider the common prefix of argument types.
3634-
for argt, args in zip(t.arg_types, s.arg_types):
3635-
if not is_more_precise(argt, args):
3636-
return False
3637-
return is_more_precise(t.ret_type, s.ret_type)
3616+
return is_callable_compatible(t, s,
3617+
is_compat=is_more_precise,
3618+
check_args_covariantly=True)
36383619

36393620

36403621
def infer_operator_assignment_method(typ: Type, operator: str) -> Tuple[bool, str]:

mypy/constraints.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,9 @@ def find_matching_overload_item(overloaded: Overloaded, template: CallableType)
525525
for item in items:
526526
# Return type may be indeterminate in the template, so ignore it when performing a
527527
# subtype check.
528-
if mypy.subtypes.is_callable_subtype(item, template, ignore_return=True):
528+
if mypy.subtypes.is_callable_compatible(item, template,
529+
is_compat=mypy.subtypes.is_subtype,
530+
ignore_return=True):
529531
return item
530532
# Fall back to the first item if we can't find a match. This is totally arbitrary --
531533
# maybe we should just bail out at this point.

mypy/subtypes.py

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,9 @@ def visit_type_var(self, left: TypeVarType) -> bool:
203203
def visit_callable_type(self, left: CallableType) -> bool:
204204
right = self.right
205205
if isinstance(right, CallableType):
206-
return is_callable_subtype(
206+
return is_callable_compatible(
207207
left, right,
208+
is_compat=is_subtype,
208209
ignore_pos_arg_names=self.ignore_pos_arg_names)
209210
elif isinstance(right, Overloaded):
210211
return all(is_subtype(left, item, self.check_type_parameter,
@@ -310,10 +311,12 @@ def visit_overloaded(self, left: Overloaded) -> bool:
310311
else:
311312
# If this one overlaps with the supertype in any way, but it wasn't
312313
# an exact match, then it's a potential error.
313-
if (is_callable_subtype(left_item, right_item, ignore_return=True,
314-
ignore_pos_arg_names=self.ignore_pos_arg_names) or
315-
is_callable_subtype(right_item, left_item, ignore_return=True,
316-
ignore_pos_arg_names=self.ignore_pos_arg_names)):
314+
if (is_callable_compatible(left_item, right_item,
315+
is_compat=is_subtype, ignore_return=True,
316+
ignore_pos_arg_names=self.ignore_pos_arg_names) or
317+
is_callable_compatible(right_item, left_item,
318+
is_compat=is_subtype, ignore_return=True,
319+
ignore_pos_arg_names=self.ignore_pos_arg_names)):
317320
# If this is an overload that's already been matched, there's no
318321
# problem.
319322
if left_item not in matched_overloads:
@@ -568,16 +571,22 @@ def non_method_protocol_members(tp: TypeInfo) -> List[str]:
568571
return result
569572

570573

571-
def is_callable_subtype(left: CallableType, right: CallableType,
572-
ignore_return: bool = False,
573-
ignore_pos_arg_names: bool = False,
574-
use_proper_subtype: bool = False) -> bool:
575-
"""Is left a subtype of right?"""
574+
def is_callable_compatible(left: CallableType, right: CallableType,
575+
*,
576+
is_compat: Callable[[Type, Type], bool],
577+
ignore_return: bool = False,
578+
ignore_pos_arg_names: bool = False,
579+
check_args_covariantly: bool = False) -> bool:
580+
"""Is the left compatible with the right, using the provided compatibility check?
576581
577-
if use_proper_subtype:
578-
is_compat = is_proper_subtype
579-
else:
580-
is_compat = is_subtype
582+
If 'check_args_covariantly' is set to True, check if the left's args is
583+
compatible with the right's instead of the other way around (contravariantly).
584+
585+
This function is mostly used to check if the left is a subtype of the right which
586+
is why the default is to check the args covariantly. However, it's occasionally
587+
useful to check the args using some other check, so we leave the variance
588+
configurable.
589+
"""
581590

582591
# If either function is implicitly typed, ignore positional arg names too
583592
if left.implicit or right.implicit:
@@ -610,6 +619,9 @@ def is_callable_subtype(left: CallableType, right: CallableType,
610619
if not ignore_return and not is_compat(left.ret_type, right.ret_type):
611620
return False
612621

622+
if check_args_covariantly:
623+
is_compat = flip_compat_check(is_compat)
624+
613625
if right.is_ellipsis_args:
614626
return True
615627

@@ -664,7 +676,7 @@ def is_callable_subtype(left: CallableType, right: CallableType,
664676
right_by_position = right.argument_by_position(j)
665677
assert right_by_position is not None
666678
if not are_args_compatible(left_by_position, right_by_position,
667-
ignore_pos_arg_names, use_proper_subtype):
679+
ignore_pos_arg_names, is_compat):
668680
return False
669681
j += 1
670682
continue
@@ -687,7 +699,7 @@ def is_callable_subtype(left: CallableType, right: CallableType,
687699
right_by_name = right.argument_by_name(name)
688700
assert right_by_name is not None
689701
if not are_args_compatible(left_by_name, right_by_name,
690-
ignore_pos_arg_names, use_proper_subtype):
702+
ignore_pos_arg_names, is_compat):
691703
return False
692704
continue
693705

@@ -696,7 +708,7 @@ def is_callable_subtype(left: CallableType, right: CallableType,
696708
if left_arg is None:
697709
return False
698710

699-
if not are_args_compatible(left_arg, right_arg, ignore_pos_arg_names, use_proper_subtype):
711+
if not are_args_compatible(left_arg, right_arg, ignore_pos_arg_names, is_compat):
700712
return False
701713

702714
done_with_positional = False
@@ -748,7 +760,7 @@ def are_args_compatible(
748760
left: FormalArgument,
749761
right: FormalArgument,
750762
ignore_pos_arg_names: bool,
751-
use_proper_subtype: bool) -> bool:
763+
is_compat: Callable[[Type, Type], bool]) -> bool:
752764
# If right has a specific name it wants this argument to be, left must
753765
# have the same.
754766
if right.name is not None and left.name != right.name:
@@ -759,18 +771,20 @@ def are_args_compatible(
759771
if right.pos is not None and left.pos != right.pos:
760772
return False
761773
# Left must have a more general type
762-
if use_proper_subtype:
763-
if not is_proper_subtype(right.typ, left.typ):
764-
return False
765-
else:
766-
if not is_subtype(right.typ, left.typ):
767-
return False
774+
if not is_compat(right.typ, left.typ):
775+
return False
768776
# If right's argument is optional, left's must also be.
769777
if not right.required and left.required:
770778
return False
771779
return True
772780

773781

782+
def flip_compat_check(is_compat: Callable[[Type, Type], bool]) -> Callable[[Type, Type], bool]:
783+
def new_is_compat(left: Type, right: Type) -> bool:
784+
return is_compat(right, left)
785+
return new_is_compat
786+
787+
774788
def unify_generic_callable(type: CallableType, target: CallableType,
775789
ignore_return: bool) -> Optional[CallableType]:
776790
"""Try to unify a generic callable type with another callable type.
@@ -913,10 +927,7 @@ def visit_type_var(self, left: TypeVarType) -> bool:
913927
def visit_callable_type(self, left: CallableType) -> bool:
914928
right = self.right
915929
if isinstance(right, CallableType):
916-
return is_callable_subtype(
917-
left, right,
918-
ignore_pos_arg_names=False,
919-
use_proper_subtype=True)
930+
return is_callable_compatible(left, right, is_compat=is_proper_subtype)
920931
elif isinstance(right, Overloaded):
921932
return all(is_proper_subtype(left, item)
922933
for item in right.items())

mypy/types.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,13 @@ def max_fixed_args(self) -> int:
789789
n -= 1
790790
return n
791791

792+
def max_positional_args(self) -> int:
793+
"""Returns the number of positional args.
794+
795+
This includes *arg and **kwargs but excludes keyword-only args."""
796+
blacklist = (ARG_NAMED, ARG_NAMED_OPT)
797+
return len([kind not in blacklist for kind in self.arg_kinds])
798+
792799
def corresponding_argument(self, model: FormalArgument) -> Optional[FormalArgument]:
793800
"""Return the argument in this function that corresponds to `model`"""
794801

0 commit comments

Comments
 (0)