Skip to content

Add some minimal support for PEP 612 in type stubs. #879

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions pytype/load_pytd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,11 +401,12 @@ def f() -> List[int]: ...
loader = load_pytd.Loader(None, self.python_version, pythonpath=[d.path])
foo = loader.import_name("foo")
bar = loader.import_name("bar")
self.assertEqual(pytd_utils.Print(foo), "foo.List = list")
self.assertEqual(pytd_utils.Print(foo),
"from builtins import list as List")
self.assertEqual(pytd_utils.Print(bar), textwrap.dedent("""
from typing import List

bar.List = list
from builtins import list as List

def bar.f() -> List[int]: ...
""").strip())
Expand Down Expand Up @@ -569,11 +570,7 @@ def test_star_import(self):
self._pickle_modules(loader, d, foo, bar)
loaded_ast = self._load_pickled_module(d, bar)
loaded_ast.Visit(visitors.VerifyLookup())
self.assertMultiLineEqual(pytd_utils.Print(loaded_ast),
textwrap.dedent("""
import foo

bar.A = foo.A""").lstrip())
self.assertEqual(pytd_utils.Print(loaded_ast), "from foo import A")

def test_function_alias(self):
with file_utils.Tempdir() as d:
Expand Down
24 changes: 23 additions & 1 deletion pytype/pyi/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def __init__(self, module_info):
self.constants = []
self.aliases = collections.OrderedDict()
self.type_params = []
self.param_specs = []
self.generated_classes = collections.defaultdict(list)
self.module_path_map = {}

Expand Down Expand Up @@ -342,6 +343,16 @@ def add_type_var(self, name, typevar):
self.type_params.append(pytd.TypeParameter(
name=name, constraints=constraints, bound=bound))

def add_param_spec(self, name, paramspec):
if name != paramspec.name:
raise ParseError("ParamSpec name needs to be %r (not %r)" % (
paramspec.name, name))
# ParamSpec should probably be represented with its own pytd class, like
# TypeVar. This is just a quick, hacky way for us to keep track of which
# names refer to ParamSpecs so we can replace them with Any in
# _parameterized_type().
self.param_specs.append(pytd.NamedType(name))

def add_import(self, from_package, import_list):
"""Add an import.

Expand Down Expand Up @@ -419,7 +430,18 @@ def _parameterized_type(self, base_type, parameters):
if self._is_tuple_base_type(base_type):
return pytdgen.heterogeneous_tuple(base_type, parameters)
elif self._is_callable_base_type(base_type):
return pytdgen.pytd_callable(base_type, parameters)
callable_parameters = []
for p in parameters:
# We do not yet support PEP 612, Parameter Specification Variables.
# To avoid blocking typeshed from adopting this PEP, we convert new
# features to Any.
if p in self.param_specs or (
isinstance(p, pytd.GenericType) and self._matches_full_name(
p, ("typing.Concatenate", "typing_extensions.Concatenate"))):
callable_parameters.append(pytd.AnythingType())
else:
callable_parameters.append(p)
return pytdgen.pytd_callable(base_type, tuple(callable_parameters))
else:
assert parameters
return pytd.GenericType(base_type=base_type, parameters=parameters)
Expand Down
64 changes: 47 additions & 17 deletions pytype/pyi/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@
ParseError = types.ParseError


_TYPEVAR_IDS = ("TypeVar", "typing.TypeVar")
_PARAMSPEC_IDS = (
"ParamSpec", "typing.ParamSpec", "typing_extensions.ParamSpec")
_TYPING_NAMEDTUPLE_IDS = ("NamedTuple", "typing.NamedTuple")
_COLL_NAMEDTUPLE_IDS = ("namedtuple", "collections.namedtuple")
_TYPEDDICT_IDS = (
"TypedDict", "typing.TypedDict", "typing_extensions.TypedDict")
_NEWTYPE_IDS = ("NewType", "typing.NewType")


#------------------------------------------------------
# imports

Expand Down Expand Up @@ -86,6 +96,18 @@ def from_call(cls, node: ast3.AST) -> "_TypeVar":
return cls(name, bound, constraints)


@dataclasses.dataclass
class _ParamSpec:
"""Internal representation of ParamSpecs."""

name: str

@classmethod
def from_call(cls, node: ast3.AST) -> "_ParamSpec":
name, = node.args
return cls(name)


#------------------------------------------------------
# pytd utils

Expand Down Expand Up @@ -340,10 +362,13 @@ def visit_Assign(self, node):
target = targets[0]
name = target.id

# Record and erase typevar definitions.
# Record and erase TypeVar and ParamSpec definitions.
if isinstance(node.value, _TypeVar):
self.defs.add_type_var(name, node.value)
return Splice([])
elif isinstance(node.value, _ParamSpec):
self.defs.add_param_spec(name, node.value)
return Splice([])

if node.type_comment:
# TODO(mdemello): can pyi files have aliases with typecomments?
Expand Down Expand Up @@ -412,15 +437,15 @@ def visit_ImportFrom(self, node):
self.defs.add_import(module, imports)
return Splice([])

def _convert_newtype_args(self, node):
def _convert_newtype_args(self, node: ast3.AST):
if len(node.args) != 2:
msg = "Wrong args: expected NewType(name, [(field, type), ...])"
raise ParseError(msg)
name, typ = node.args
typ = self.convert_node(typ)
node.args = [name.s, typ]

def _convert_typing_namedtuple_args(self, node):
def _convert_typing_namedtuple_args(self, node: ast3.AST):
# TODO(mdemello): handle NamedTuple("X", a=int, b=str, ...)
if len(node.args) != 2:
msg = "Wrong args: expected NamedTuple(name, [(field, type), ...])"
Expand All @@ -430,7 +455,7 @@ def _convert_typing_namedtuple_args(self, node):
fields = [(types.string_value(n), t) for (n, t) in fields]
node.args = [name.s, fields]

def _convert_collections_namedtuple_args(self, node):
def _convert_collections_namedtuple_args(self, node: ast3.AST):
if len(node.args) != 2:
msg = "Wrong args: expected namedtuple(name, [field, ...])"
raise ParseError(msg)
Expand All @@ -454,7 +479,11 @@ def _convert_typevar_args(self, node):
val = types.string_value(kw.value, context="TypeVar bound")
kw.value = self.annotation_visitor.convert_late_annotation(val)

def _convert_typed_dict_args(self, node):
def _convert_paramspec_args(self, node):
name, = node.args
node.args = [name.s]

def _convert_typed_dict_args(self, node: ast3.AST):
# TODO(b/157603915): new_typed_dict currently doesn't do anything with the
# args, so we don't bother converting them fully.
msg = "Wrong args: expected TypedDict(name, {field: type, ...})"
Expand All @@ -473,30 +502,31 @@ def enter_Call(self, node):
# passing them to internal functions directly in visit_Call.
if isinstance(node.func, ast3.Attribute):
node.func = _attribute_to_name(node.func)
if node.func.id in ("TypeVar", "typing.TypeVar"):
if node.func.id in _TYPEVAR_IDS:
self._convert_typevar_args(node)
elif node.func.id in ("NamedTuple", "typing.NamedTuple"):
elif node.func.id in _PARAMSPEC_IDS:
self._convert_paramspec_args(node)
elif node.func.id in _TYPING_NAMEDTUPLE_IDS:
self._convert_typing_namedtuple_args(node)
elif node.func.id in ("namedtuple", "collections.namedtuple"):
elif node.func.id in _COLL_NAMEDTUPLE_IDS:
self._convert_collections_namedtuple_args(node)
elif node.func.id in ("TypedDict", "typing.TypedDict",
"typing_extensions.TypedDict"):
elif node.func.id in _TYPEDDICT_IDS:
self._convert_typed_dict_args(node)
elif node.func.id in ("NewType", "typing.NewType"):
elif node.func.id in _NEWTYPE_IDS:
return self._convert_newtype_args(node)

def visit_Call(self, node):
if node.func.id in ("TypeVar", "typing.TypeVar"):
if node.func.id in _TYPEVAR_IDS:
if self.level > 0:
raise ParseError("TypeVars need to be defined at module level")
return _TypeVar.from_call(node)
elif node.func.id in ("NamedTuple", "typing.NamedTuple",
"namedtuple", "collections.namedtuple"):
elif node.func.id in _PARAMSPEC_IDS:
return _ParamSpec.from_call(node)
elif node.func.id in _TYPING_NAMEDTUPLE_IDS + _COLL_NAMEDTUPLE_IDS:
return self.defs.new_named_tuple(*node.args)
elif node.func.id in ("TypedDict", "typing.TypedDict",
"typing_extensions.TypedDict"):
elif node.func.id in _TYPEDDICT_IDS:
return self.defs.new_typed_dict(*node.args, total=False)
elif node.func.id in ("NewType", "typing.NewType"):
elif node.func.id in _NEWTYPE_IDS:
return self.defs.new_new_type(*node.args)
# Convert all other calls to NamedTypes; for example:
# * typing.pyi uses things like
Expand Down
124 changes: 124 additions & 0 deletions pytype/pyi/parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2715,5 +2715,129 @@ def test_feature_version(self):
self.assertEqual(actual, expected)


class ParamSpecTest(_ParserTestBase):

def test_from_typing(self):
self.check("""
from typing import Awaitable, Callable, ParamSpec, TypeVar

P = ParamSpec('P')
R = TypeVar('R')

def f(x: Callable[P, R]) -> Callable[P, Awaitable[R]]: ...
""", """
from typing import Awaitable, Callable, TypeVar

R = TypeVar('R')

def f(x: Callable[..., R]) -> Callable[..., Awaitable[R]]: ...
""")

def test_from_typing_extensions(self):
self.check("""
from typing import Awaitable, Callable, TypeVar

from typing_extensions import ParamSpec

P = ParamSpec('P')
R = TypeVar('R')

def f(x: Callable[P, R]) -> Callable[P, Awaitable[R]]: ...
""", """
from typing import Awaitable, Callable, TypeVar

from typing_extensions import ParamSpec

R = TypeVar('R')

def f(x: Callable[..., R]) -> Callable[..., Awaitable[R]]: ...
""")

@test_base.skip("ParamSpec in custom generic classes not supported yet")
def test_custom_generic(self):
self.check("""
from typing import Callable, Generic, ParamSpec, TypeVar

P = ParamSpec('P')
T = TypeVar('T')

class X(Generic[T, P]):
f: Callable[P, int]
x: T
""")

@test_base.skip("ParamSpec in custom generic classes not supported yet")
def test_double_brackets(self):
# Double brackets can be omitted when instantiating a class parameterized
# with only a single ParamSpec.
self.check("""
from typing import Generic, ParamSpec

P = ParamSpec('P')

class X(Generic[P]): ...

def f1(x: X[int, str]) -> None: ...
def f2(x: X[[int, str]]) -> None: ...
""", """
from typing import Generic, ParamSpec

P = ParamSpec('P')

class X(Generic[P]): ...

def f1(x: X[int, str]) -> None: ...
def f2(x: X[int, str]) -> None: ...
""")


class ConcatenateTest(_ParserTestBase):

def test_from_typing(self):
self.check("""
from typing import Callable, Concatenate, ParamSpec, TypeVar

P = ParamSpec('P')
R = TypeVar('R')

class X: ...

def f(x: Callable[Concatenate[X, P], R]) -> Callable[P, R]: ...
""", """
from typing import Callable, TypeVar

R = TypeVar('R')

class X: ...

def f(x: Callable[..., R]) -> Callable[..., R]: ...
""")

def test_from_typing_extensions(self):
self.check("""
from typing import Callable, TypeVar

from typing_extensions import Concatenate, ParamSpec

P = ParamSpec('P')
R = TypeVar('R')

class X: ...

def f(x: Callable[Concatenate[X, P], R]) -> Callable[P, R]: ...
""", """
from typing import Callable, TypeVar

from typing_extensions import Concatenate
from typing_extensions import ParamSpec

R = TypeVar('R')

class X: ...

def f(x: Callable[..., R]) -> Callable[..., R]: ...
""")


if __name__ == "__main__":
unittest.main()
9 changes: 6 additions & 3 deletions pytype/pytd/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,16 @@ def EnterAlias(self, _):

def VisitAlias(self, node):
"""Convert an import or alias to a string."""
if isinstance(self.old_node.type, pytd.NamedType):
if isinstance(self.old_node.type, (pytd.NamedType, pytd.ClassType)):
full_name = self.old_node.type.name
suffix = ""
module, _, name = full_name.rpartition(".")
if module:
if name not in ("*", self.old_node.name):
suffix += " as " + self.old_node.name
alias_name = self.old_node.name
if alias_name.startswith(f"{self._unit_name}."):
alias_name = alias_name[len(self._unit_name)+1:]
if name not in ("*", alias_name):
suffix += " as " + alias_name
self.imports = self.old_imports # undo unnecessary imports change
return "from " + module + " import " + name + suffix
elif isinstance(self.old_node.type, (pytd.Constant, pytd.Function)):
Expand Down
6 changes: 1 addition & 5 deletions pytype/pytd/visitors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,7 @@ class A: ...
ast2 = ast2.Visit(visitors.LookupExternalTypes(
{"foo": ast1}, self_name=None))
self.assertEqual(name, ast2.name)
self.assertMultiLineEqual(pytd_utils.Print(ast2), textwrap.dedent("""
import foo

A = foo.A
""").strip())
self.assertEqual(pytd_utils.Print(ast2), "from foo import A")

def test_lookup_two_star_aliases(self):
src1 = "class A: ..."
Expand Down