Skip to content

Treat NewTypes like normal subclasses #1301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
@@ -16,6 +16,11 @@ Release date: TBA

Refs PyCQA/pylint#2567

* Treat ``typing.NewType()`` values as normal subclasses.

Closes PyCQA/pylint#2296
Closes PyCQA/pylint#3162

What's New in astroid 2.12.13?
==============================
Release date: 2022-11-19
176 changes: 162 additions & 14 deletions astroid/brain/brain_typing.py
Original file line number Diff line number Diff line change
@@ -10,10 +10,11 @@
from collections.abc import Iterator
from functools import partial

from astroid import context, extract_node, inference_tip
from astroid import context, extract_node, inference_tip, nodes
from astroid.builder import _extract_single_node
from astroid.const import PY38_PLUS, PY39_PLUS
from astroid.exceptions import (
AstroidImportError,
AttributeInferenceError,
InferenceError,
UseInferenceDefault,
@@ -35,8 +36,6 @@
from astroid.util import Uninferable

TYPING_NAMEDTUPLE_BASENAMES = {"NamedTuple", "typing.NamedTuple"}
TYPING_TYPEVARS = {"TypeVar", "NewType"}
TYPING_TYPEVARS_QUALIFIED = {"typing.TypeVar", "typing.NewType"}
TYPING_TYPE_TEMPLATE = """
class Meta(type):
def __getitem__(self, item):
@@ -49,6 +48,13 @@ def __args__(self):
class {0}(metaclass=Meta):
pass
"""
# PEP484 suggests NewType is equivalent to this for typing purposes
# https://www.python.org/dev/peps/pep-0484/#newtype-helper-function
TYPING_NEWTYPE_TEMPLATE = """
class {derived}({base}):
def __init__(self, val: {base}) -> None:
...
"""
TYPING_MEMBERS = set(getattr(typing, "__all__", []))

TYPING_ALIAS = frozenset(
@@ -103,24 +109,33 @@ def __class_getitem__(cls, item):
"""


def looks_like_typing_typevar_or_newtype(node):
def looks_like_typing_typevar(node: nodes.Call) -> bool:
func = node.func
if isinstance(func, Attribute):
return func.attrname in TYPING_TYPEVARS
return func.attrname == "TypeVar"
if isinstance(func, Name):
return func.name in TYPING_TYPEVARS
return func.name == "TypeVar"
return False


def infer_typing_typevar_or_newtype(node, context_itton=None):
"""Infer a typing.TypeVar(...) or typing.NewType(...) call"""
def looks_like_typing_newtype(node: nodes.Call) -> bool:
func = node.func
if isinstance(func, Attribute):
return func.attrname == "NewType"
if isinstance(func, Name):
return func.name == "NewType"
return False


def infer_typing_typevar(
node: nodes.Call, ctx: context.InferenceContext | None = None
) -> Iterator[nodes.ClassDef]:
"""Infer a typing.TypeVar(...) call"""
try:
func = next(node.func.infer(context=context_itton))
next(node.func.infer(context=ctx))
except (InferenceError, StopIteration) as exc:
raise UseInferenceDefault from exc

if func.qname() not in TYPING_TYPEVARS_QUALIFIED:
raise UseInferenceDefault
if not node.args:
raise UseInferenceDefault
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

L127-128 has a drop in coverage. Could you re-create a test for it?

# Cannot infer from a dynamic class name (f-string)
@@ -129,7 +144,135 @@ def infer_typing_typevar_or_newtype(node, context_itton=None):

typename = node.args[0].as_string().strip("'")
node = extract_node(TYPING_TYPE_TEMPLATE.format(typename))
return node.infer(context=context_itton)
return node.infer(context=ctx)


def infer_typing_newtype(
node: nodes.Call, ctx: context.InferenceContext | None = None
) -> Iterator[nodes.ClassDef]:
"""Infer a typing.NewType(...) call"""
try:
next(node.func.infer(context=ctx))
except (InferenceError, StopIteration) as exc:
raise UseInferenceDefault from exc

if len(node.args) != 2:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you create a test for this? It is currently uncovered.

raise UseInferenceDefault

# Cannot infer from a dynamic class name (f-string)
if isinstance(node.args[0], JoinedStr) or isinstance(node.args[1], JoinedStr):
raise UseInferenceDefault

derived, base = node.args
derived_name = derived.as_string().strip("'")
base_name = base.as_string().strip("'")

new_node: ClassDef = extract_node(
TYPING_NEWTYPE_TEMPLATE.format(derived=derived_name, base=base_name)
)
new_node.parent = node.parent

new_bases: list[NodeNG] = []

if not isinstance(base, nodes.Const):
# Base type arg is a normal reference, so no need to do special lookups
new_bases = [base]
elif isinstance(base, nodes.Const) and isinstance(base.value, str):
# If the base type is given as a string (e.g. for a forward reference),
# make a naive attempt to find the corresponding node.
_, resolved_base = node.frame().lookup(base_name)
if resolved_base:
base_node = resolved_base[0]

# If the value is from an "import from" statement, follow the import chain
if isinstance(base_node, nodes.ImportFrom):
ctx = ctx.clone() if ctx else context.InferenceContext()
ctx.lookupname = base_name
base_node = next(base_node.infer(context=ctx))

new_bases = [base_node]
elif "." in base.value:
possible_base = _try_find_imported_object_from_str(node, base.value, ctx)
if possible_base:
new_bases = [possible_base]

if new_bases:
new_node.postinit(
bases=new_bases, body=new_node.body, decorators=new_node.decorators
)
Comment on lines +200 to +202
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

new_node is fully constructed already. It's enough to set bases manually.

Suggested change
new_node.postinit(
bases=new_bases, body=new_node.body, decorators=new_node.decorators
)
new_node.bases = new_bases


return new_node.infer(context=ctx)


def _try_find_imported_object_from_str(
node: nodes.Call,
name: str,
ctx: context.InferenceContext | None,
) -> nodes.NodeNG | None:
for statement_mod_name, _ in _possible_module_object_splits(name):
# Find import statements that may pull in the appropriate modules
# The name used to find this statement may not correspond to the name of the module actually being imported
# For example, "import email.charset" is found by lookup("email")
_, resolved_bases = node.frame().lookup(statement_mod_name)
if not resolved_bases:
continue

resolved_base = resolved_bases[0]
if isinstance(resolved_base, nodes.Import):
# Extract the names of the module as they are accessed from actual code
scope_names = {(alias or name) for (name, alias) in resolved_base.names}
aliases = {alias: name for (name, alias) in resolved_base.names if alias}

# Find potential mod_name, obj_name splits that work with the available names
# for the module in this scope
import_targets = [
(mod_name, obj_name)
for (mod_name, obj_name) in _possible_module_object_splits(name)
if mod_name in scope_names
]
if not import_targets:
continue

import_target, name_in_mod = import_targets[0]
import_target = aliases.get(import_target, import_target)

# Try to import the module and find the object in it
try:
resolved_mod: nodes.Module = resolved_base.do_import_module(
import_target
)
except AstroidImportError:
# If the module doesn't actually exist, try the next option
continue

# Try to find the appropriate ClassDef or other such node in the target module
_, object_results_in_mod = resolved_mod.lookup(name_in_mod)
if not object_results_in_mod:
continue

base_node = object_results_in_mod[0]

# If the value is from an "import from" statement, follow the import chain
if isinstance(base_node, nodes.ImportFrom):
ctx = ctx.clone() if ctx else context.InferenceContext()
ctx.lookupname = name_in_mod
base_node = next(base_node.infer(context=ctx))

return base_node

return None


def _possible_module_object_splits(
dot_str: str,
) -> Iterator[tuple[str, str]]:
components = dot_str.split(".")
popped = []

while components:
popped.append(components.pop())

yield ".".join(components), ".".join(reversed(popped))


def _looks_like_typing_subscript(node):
@@ -404,8 +547,13 @@ def infer_typing_cast(

AstroidManager().register_transform(
Call,
inference_tip(infer_typing_typevar_or_newtype),
looks_like_typing_typevar_or_newtype,
inference_tip(infer_typing_typevar),
looks_like_typing_typevar,
)
AstroidManager().register_transform(
Call,
inference_tip(infer_typing_newtype),
looks_like_typing_newtype,
)
AstroidManager().register_transform(
Subscript, inference_tip(infer_typing_attr), _looks_like_typing_subscript
352 changes: 351 additions & 1 deletion tests/unittest_brain.py
Original file line number Diff line number Diff line change
@@ -1718,6 +1718,26 @@ def test_typing_types(self) -> None:
inferred = next(node.infer())
self.assertIsInstance(inferred, nodes.ClassDef, node.as_string())

def test_typing_typevar_bad_args(self) -> None:
ast_nodes = builder.extract_node(
"""
from typing import TypeVar
T = TypeVar()
T #@
U = TypeVar(f"U")
U #@
"""
)
assert isinstance(ast_nodes, list)

no_args_node = ast_nodes[0]
assert list(no_args_node.infer()) == [util.Uninferable]

fstr_node = ast_nodes[1]
assert list(fstr_node.infer()) == [util.Uninferable]

def test_typing_type_without_tip(self):
"""Regression test for https://github.com/PyCQA/pylint/issues/5770"""
node = builder.extract_node(
@@ -1729,7 +1749,337 @@ def make_new_type(t):
"""
)
with self.assertRaises(UseInferenceDefault):
astroid.brain.brain_typing.infer_typing_typevar_or_newtype(node.value)
astroid.brain.brain_typing.infer_typing_newtype(node.value)

def test_typing_newtype_attrs(self) -> None:
ast_nodes = builder.extract_node(
"""
from typing import NewType
import decimal
from decimal import Decimal
NewType("Foo", str) #@
NewType("Bar", "int") #@
NewType("Baz", Decimal) #@
NewType("Qux", decimal.Decimal) #@
"""
)
assert isinstance(ast_nodes, list)

# Base type given by reference
foo_node = ast_nodes[0]

# Should be unambiguous
foo_inferred_all = list(foo_node.infer())
assert len(foo_inferred_all) == 1

foo_inferred = foo_inferred_all[0]
assert isinstance(foo_inferred, astroid.ClassDef)

# Check base type method is inferred by accessing one of its methods
foo_base_class_method = foo_inferred.getattr("endswith")[0]
assert isinstance(foo_base_class_method, astroid.FunctionDef)
assert foo_base_class_method.qname() == "builtins.str.endswith"

# Base type given by string (i.e. "int")
bar_node = ast_nodes[1]
bar_inferred_all = list(bar_node.infer())
assert len(bar_inferred_all) == 1
bar_inferred = bar_inferred_all[0]
assert isinstance(bar_inferred, astroid.ClassDef)

bar_base_class_method = bar_inferred.getattr("bit_length")[0]
assert isinstance(bar_base_class_method, astroid.FunctionDef)
assert bar_base_class_method.qname() == "builtins.int.bit_length"

# Decimal may be reexported from an implementation-defined module. For
# example, in CPython 3.10 this is _decimal, but in PyPy 7.3 it's
# _pydecimal. So the expected qname needs to be grabbed dynamically.
decimal_quant_node = builder.extract_node(
"""
from decimal import Decimal
Decimal.quantize #@
"""
)
assert isinstance(decimal_quant_node, nodes.NodeNG)

# Just grab the first result, since infer() may return values for both
# _decimal and _pydecimal
decimal_quant_qname = next(decimal_quant_node.infer()).qname()

# Base type is from an "import from"
baz_node = ast_nodes[2]
baz_inferred_all = list(baz_node.infer())
assert len(baz_inferred_all) == 1
baz_inferred = baz_inferred_all[0]
assert isinstance(baz_inferred, astroid.ClassDef)

baz_base_class_method = baz_inferred.getattr("quantize")[0]
assert isinstance(baz_base_class_method, astroid.FunctionDef)
assert decimal_quant_qname == baz_base_class_method.qname()

# Base type is from an import
qux_node = ast_nodes[3]
qux_inferred_all = list(qux_node.infer())
qux_inferred = qux_inferred_all[0]
assert isinstance(qux_inferred, astroid.ClassDef)

qux_base_class_method = qux_inferred.getattr("quantize")[0]
assert isinstance(qux_base_class_method, astroid.FunctionDef)
assert decimal_quant_qname == qux_base_class_method.qname()

def test_typing_newtype_bad_args(self) -> None:
ast_nodes = builder.extract_node(
"""
from typing import NewType
NoArgs = NewType()
NoArgs #@
OneArg = NewType("OneArg")
OneArg #@
ThreeArgs = NewType("ThreeArgs", int, str)
ThreeArgs #@
DynamicArg = NewType(f"DynamicArg", int)
DynamicArg #@
DynamicBase = NewType("DynamicBase", f"int")
DynamicBase #@
"""
)
assert isinstance(ast_nodes, list)

node: nodes.NodeNG
for node in ast_nodes:
assert list(node.infer()) == [util.Uninferable]

def test_typing_newtype_user_defined(self) -> None:
ast_nodes = builder.extract_node(
"""
from typing import NewType
class A:
def __init__(self, value: int):
self.value = value
a = A(5)
a #@
B = NewType("B", A)
b = B(5)
b #@
Comment on lines +1871 to +1872
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
b = B(5)
b #@
b = B(A(5))
b #@

What is the inference result of b? At runtime it should be an instance of A.

"""
)
assert isinstance(ast_nodes, list)

for node in ast_nodes:
self._verify_node_has_expected_attr(node)

def test_typing_newtype_forward_reference(self) -> None:
# Similar to the test above, but using a forward reference for "A"
ast_nodes = builder.extract_node(
"""
from typing import NewType
B = NewType("B", "A")
class A:
def __init__(self, value: int):
self.value = value
a = A(5)
a #@
b = B(5)
b #@
"""
)
assert isinstance(ast_nodes, list)

for node in ast_nodes:
self._verify_node_has_expected_attr(node)

def _verify_node_has_expected_attr(self, node: nodes.NodeNG) -> None:
inferred_all = list(node.infer())
assert len(inferred_all) == 1
inferred = inferred_all[0]
assert isinstance(inferred, astroid.Instance)

# Should be able to infer that the "value" attr is present on both types
val = inferred.getattr("value")[0]
assert isinstance(val, astroid.AssignAttr)

# Sanity check: nonexistent attr is not inferred
with self.assertRaises(AttributeInferenceError):
inferred.getattr("bad_attr")

def test_typing_newtype_forward_reference_imported(self) -> None:
all_ast_nodes = builder.extract_node(
"""
from typing import NewType
A = NewType("A", "decimal.Decimal")
B = NewType("B", "decimal_mod_alias.Decimal")
C = NewType("C", "Decimal")
D = NewType("D", "DecimalAlias")
import decimal
import decimal as decimal_mod_alias
from decimal import Decimal
from decimal import Decimal as DecimalAlias
Decimal #@
a = A(decimal.Decimal(2))
a #@
b = B(decimal_mod_alias.Decimal(2))
b #@
c = C(Decimal(2))
c #@
d = D(DecimalAlias(2))
d #@
"""
)
assert isinstance(all_ast_nodes, list)

real_dec, *ast_nodes = all_ast_nodes

real_quantize = next(real_dec.infer()).getattr("quantize")

for node in ast_nodes:
all_inferred = list(node.infer())
assert len(all_inferred) == 1
inferred = all_inferred[0]
assert isinstance(inferred, astroid.Instance)

assert inferred.getattr("quantize") == real_quantize

def test_typing_newtype_forward_ref_bad_base(self) -> None:
ast_nodes = builder.extract_node(
"""
from typing import NewType
A = NewType("A", "DoesntExist")
a = A()
a #@
# Valid name, but not actually imported
B = NewType("B", "decimal.Decimal")
b = B()
b #@
# AST works out, but can't import the module
import not_a_real_module
C = NewType("C", "not_a_real_module.SomeClass")
c = C()
c #@
# Real module, fake base class name
import email.charset
D = NewType("D", "email.charset.BadClassRef")
d = D()
d #@
# Real module, but aliased differently than used
import email.header as header_mod
E = NewType("E", "email.header.Header")
e = E(header_mod.Header())
e #@
"""
)
assert isinstance(ast_nodes, list)

for ast_node in ast_nodes:
inferred = next(ast_node.infer())

with self.assertRaises(astroid.AttributeInferenceError):
inferred.getattr("value")

def test_typing_newtype_forward_ref_nested_module(self) -> None:
ast_nodes = builder.extract_node(
"""
from typing import NewType
A = NewType("A", "email.charset.Charset")
B = NewType("B", "charset.Charset")
# header is unused in both cases, but verifies that module name is properly checked
import email.header, email.charset
from email import header, charset
real = charset.Charset()
real #@
a = A(email.charset.Charset())
a #@
b = B(charset.Charset())
"""
)
assert isinstance(ast_nodes, list)

real, *newtypes = ast_nodes

real_inferred_all = list(real.infer())
assert len(real_inferred_all) == 1
real_inferred = real_inferred_all[0]

real_method = real_inferred.getattr("get_body_encoding")

for newtype_node in newtypes:
newtype_inferred_all = list(newtype_node.infer())
assert len(newtype_inferred_all) == 1
newtype_inferred = newtype_inferred_all[0]

newtype_method = newtype_inferred.getattr("get_body_encoding")

assert real_method == newtype_method

def test_typing_newtype_forward_ref_nested_class(self) -> None:
ast_nodes = builder.extract_node(
"""
from typing import NewType
A = NewType("A", "SomeClass.Nested")
class SomeClass:
class Nested:
def method(self) -> None:
pass
real = SomeClass.Nested()
real #@
a = A(SomeClass.Nested())
a #@
"""
)
assert isinstance(ast_nodes, list)

real, newtype = ast_nodes

real_all_inferred = list(real.infer())
assert len(real_all_inferred) == 1
real_inferred = real_all_inferred[0]
real_method = real_inferred.getattr("method")

newtype_all_inferred = list(newtype.infer())
assert len(newtype_all_inferred) == 1
newtype_inferred = newtype_all_inferred[0]

# This could theoretically work, but for now just here to check that
# the "forward-declared module" inference doesn't totally break things
with self.assertRaises(astroid.AttributeInferenceError):
newtype_method = newtype_inferred.getattr("method")

assert real_method == newtype_method

def test_namedtuple_nested_class(self):
result = builder.extract_node(