Skip to content

stubtest: Fix crash with numpy array default values #18353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 28, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 25 additions & 17 deletions mypy/stubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ def _verify_arg_default_value(
stub_arg: nodes.Argument, runtime_arg: inspect.Parameter
) -> Iterator[str]:
"""Checks whether argument default values are compatible."""
if runtime_arg.default != inspect.Parameter.empty:
if runtime_arg.default is not inspect.Parameter.empty:
if stub_arg.kind.is_required():
yield (
f'runtime argument "{runtime_arg.name}" '
Expand Down Expand Up @@ -705,18 +705,26 @@ def _verify_arg_default_value(
stub_default is not UNKNOWN
and stub_default is not ...
and runtime_arg.default is not UNREPRESENTABLE
and (
stub_default != runtime_arg.default
# We want the types to match exactly, e.g. in case the stub has
# True and the runtime has 1 (or vice versa).
or type(stub_default) is not type(runtime_arg.default)
)
):
yield (
f'runtime argument "{runtime_arg.name}" '
f"has a default value of {runtime_arg.default!r}, "
f"which is different from stub argument default {stub_default!r}"
)
defaults_match = True
# We want the types to match exactly, e.g. in case the stub has
# True and the runtime has 1 (or vice versa).
if type(stub_default) is not type(runtime_arg.default):
defaults_match = False
else:
try:
defaults_match = bool(stub_default == runtime_arg.default)
except Exception:
# Exception can be raised in eq/ne dunder methods (e.g. numpy arrays)
# At this point, consider the default to be different, it is probably
# too complex to put in a stub anyway.
defaults_match = False
if not defaults_match:
yield (
f'runtime argument "{runtime_arg.name}" '
f"has a default value of {runtime_arg.default!r}, "
f"which is different from stub argument default {stub_default!r}"
)
else:
if stub_arg.kind.is_optional():
yield (
Expand Down Expand Up @@ -758,7 +766,7 @@ def get_type(arg: Any) -> str | None:

def has_default(arg: Any) -> bool:
if isinstance(arg, inspect.Parameter):
return bool(arg.default != inspect.Parameter.empty)
return arg.default is not inspect.Parameter.empty
if isinstance(arg, nodes.Argument):
return arg.kind.is_optional()
raise AssertionError
Expand Down Expand Up @@ -1628,13 +1636,13 @@ def anytype() -> mypy.types.AnyType:
arg_names.append(
None if arg.kind == inspect.Parameter.POSITIONAL_ONLY else arg.name
)
has_default = arg.default == inspect.Parameter.empty
no_default = arg.default is inspect.Parameter.empty
if arg.kind == inspect.Parameter.POSITIONAL_ONLY:
arg_kinds.append(nodes.ARG_POS if has_default else nodes.ARG_OPT)
arg_kinds.append(nodes.ARG_POS if no_default else nodes.ARG_OPT)
elif arg.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
arg_kinds.append(nodes.ARG_POS if has_default else nodes.ARG_OPT)
arg_kinds.append(nodes.ARG_POS if no_default else nodes.ARG_OPT)
elif arg.kind == inspect.Parameter.KEYWORD_ONLY:
arg_kinds.append(nodes.ARG_NAMED if has_default else nodes.ARG_NAMED_OPT)
arg_kinds.append(nodes.ARG_NAMED if no_default else nodes.ARG_NAMED_OPT)
elif arg.kind == inspect.Parameter.VAR_POSITIONAL:
arg_kinds.append(nodes.ARG_STAR)
elif arg.kind == inspect.Parameter.VAR_KEYWORD:
Expand Down
12 changes: 12 additions & 0 deletions mypy/test/teststubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,18 @@ def f11(text=None) -> None: pass
error="f11",
)

# Simulate numpy ndarray.__bool__ that raises an error
yield Case(
stub="def f12(x=1): ...",
runtime="""
class _ndarray:
def __eq__(self, obj): return self
def __bool__(self): raise ValueError
def f12(x=_ndarray()) -> None: pass
""",
error="f12",
)

@collect_cases
def test_static_class_method(self) -> Iterator[Case]:
yield Case(
Expand Down
Loading