Skip to content

Commit 9eadb4f

Browse files
committed
working implementation
1 parent e270fe8 commit 9eadb4f

File tree

4 files changed

+114
-9
lines changed

4 files changed

+114
-9
lines changed

mypy/checker.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import itertools
44
import fnmatch
55
from contextlib import contextmanager
6+
import sys
67

78
from typing import (
89
Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator
@@ -380,7 +381,8 @@ def get_generator_yield_type(self, return_type: Type, is_coroutine: bool) -> Typ
380381
"""Given the declared return type of a generator (t), return the type it yields (ty)."""
381382
if isinstance(return_type, AnyType):
382383
return AnyType()
383-
elif not self.is_generator_return_type(return_type, is_coroutine):
384+
elif (not self.is_generator_return_type(return_type, is_coroutine)
385+
and not self.is_async_generator_return_type(return_type)):
384386
# If the function doesn't have a proper Generator (or
385387
# Awaitable) return type, anything is permissible.
386388
return AnyType()
@@ -411,7 +413,8 @@ def get_generator_receive_type(self, return_type: Type, is_coroutine: bool) -> T
411413
"""Given a declared generator return type (t), return the type its yield receives (tc)."""
412414
if isinstance(return_type, AnyType):
413415
return AnyType()
414-
elif not self.is_generator_return_type(return_type, is_coroutine):
416+
elif (not self.is_generator_return_type(return_type, is_coroutine)
417+
and not self.is_async_generator_return_type(return_type)):
415418
# If the function doesn't have a proper Generator (or
416419
# Awaitable) return type, anything is permissible.
417420
return AnyType()
@@ -425,6 +428,8 @@ def get_generator_receive_type(self, return_type: Type, is_coroutine: bool) -> T
425428
and len(return_type.args) >= 3):
426429
# Generator: tc is args[1].
427430
return return_type.args[1]
431+
elif return_type.type.fullname() == 'typing.AsyncGenerator' and len(return_type.args) >= 2:
432+
return return_type.args[1]
428433
else:
429434
# `return_type` is a supertype of Generator, so callers won't be able to send it
430435
# values. IOW, tc is None.

mypy/semanal.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,9 +1068,7 @@ def normalize_type_alias(self, node: SymbolTableNode,
10681068
ctx: Context) -> SymbolTableNode:
10691069
if node.fullname in type_aliases:
10701070
# Node refers to an aliased type such as typing.List; normalize.
1071-
old_node = node
10721071
node = self.lookup_qualified(type_aliases[node.fullname], ctx)
1073-
assert node is not None, (type_aliases, old_node, old_node.fullname, ctx)
10741072
if node.fullname == 'typing.DefaultDict':
10751073
self.add_module_symbol('collections', '__mypy_collections__', False, ctx)
10761074
node = self.lookup_qualified('__mypy_collections__.defaultdict', ctx)

test-data/unit/check-async-await.test

Lines changed: 105 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -303,25 +303,125 @@ def f() -> Generator[int, str, int]:
303303
[builtins fixtures/async_await.pyi]
304304
[out]
305305

306+
-- Async generators (PEP 525), some test cases adapted from the PEP text
307+
-- ---------------------------------------------------------------------
308+
306309
[case testAsyncGenerator]
307310
# flags: --fast-parser --python-version 3.6
308311
from mypy_extensions import AsyncGenerator
309312

310313
async def f() -> int:
311314
return 42
312315

313-
async def g() -> AsyncGenerator[int, str]:
316+
async def g() -> AsyncGenerator[int, None]:
314317
value = await f()
315-
reveal_type(value) # E: Revealed type is 'builtins.int'
316-
x = yield value
317-
reveal_type(x) # E: Revealed type is 'builtins.str'
318+
reveal_type(value) # E: Revealed type is 'builtins.int*'
319+
yield value
320+
# return without a value is fine
321+
return
322+
reveal_type(g) # E: Revealed type is 'def () -> typing.AsyncGenerator[builtins.int, void]'
323+
reveal_type(g()) # E: Revealed type is 'typing.AsyncGenerator[builtins.int, void]'
318324

319325
async def h() -> None:
320326
async for item in g():
321-
reveal_type(item) # E: Revealed type is 'builtins.int'
327+
reveal_type(item) # E: Revealed type is 'builtins.int*'
328+
329+
[builtins fixtures/dict.pyi]
330+
331+
[case testAsyncGeneratorManualIter]
332+
# flags: --fast-parser --python-version 3.6
333+
from mypy_extensions import AsyncGenerator
334+
335+
async def genfunc() -> AsyncGenerator[int, None]:
336+
yield 1
337+
yield 2
338+
339+
async def user() -> None:
340+
gen = genfunc()
341+
342+
reveal_type(gen.__aiter__()) # E: Revealed type is 'typing.AsyncGenerator[builtins.int*, void]'
343+
344+
reveal_type(await gen.__anext__()) # E: Revealed type is 'builtins.int*'
345+
346+
[builtins fixtures/dict.pyi]
347+
348+
[case testAsyncGeneratorAsend]
349+
# flags: --fast-parser --python-version 3.6
350+
from mypy_extensions import AsyncGenerator
351+
352+
async def f() -> None:
353+
pass
354+
355+
async def gen() -> AsyncGenerator[int, str]:
356+
await f()
357+
v = yield 42
358+
reveal_type(v) # E: Revealed type is 'builtins.str'
359+
await f()
360+
361+
async def h() -> None:
362+
g = gen()
363+
await g.asend(()) # E: Argument 1 to "asend" of "AsyncGenerator" has incompatible type "Tuple[]"; expected "str"
364+
reveal_type(await g.asend('hello')) # E: Revealed type is 'builtins.int*'
322365

323366
[builtins fixtures/dict.pyi]
324367

368+
[case testAsyncGeneratorAthrow]
369+
# flags: --fast-parser --python-version 3.6
370+
from mypy_extensions import AsyncGenerator
371+
372+
async def gen() -> AsyncGenerator[str, int]:
373+
try:
374+
yield 'hello'
375+
except BaseException:
376+
yield 'world'
377+
378+
async def h() -> None:
379+
g = gen()
380+
v = await g.asend(1)
381+
reveal_type(v) # E: Revealed type is 'builtins.str*'
382+
await g.athrow(BaseException)
383+
384+
[builtins fixtures/dict.pyi]
385+
386+
[case testAsyncGeneratorNoSyncIteration]
387+
# flags: --fast-parser --python-version 3.6
388+
from mypy_extensions import AsyncGenerator
389+
390+
async def gen() -> AsyncGenerator[int, None]:
391+
for i in (1, 2, 3):
392+
yield i
393+
394+
def h() -> None:
395+
for i in gen():
396+
pass
397+
398+
[builtins fixtures/dict.pyi]
399+
400+
[out]
401+
main:9: error: Iterable expected
402+
main:9: error: AsyncGenerator[int, None] has no attribute "__iter__"; maybe "__aiter__"?
403+
404+
[case testAsyncGeneratorNoYieldFrom]
405+
# flags: --fast-parser --python-version 3.6
406+
from mypy_extensions import AsyncGenerator
407+
408+
async def f() -> AsyncGenerator[int, None]:
409+
pass
410+
411+
async def gen() -> AsyncGenerator[int, None]:
412+
yield from f() # E: 'yield from' in async function
413+
414+
[builtins fixtures/dict.pyi]
415+
416+
[case testAsyncGeneratorNoReturnWithValue]
417+
# flags: --fast-parser --python-version 3.6
418+
from mypy_extensions import AsyncGenerator
419+
420+
async def gen() -> AsyncGenerator[int, None]:
421+
yield 1
422+
return 42
423+
424+
[builtins fixtures/dict.pyi]
325425

326426
-- The full matrix of coroutine compatibility
327427
-- ------------------------------------------

test-data/unit/fixtures/dict.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,5 @@ class list(Iterable[T], Generic[T]): # needed by some test cases
3333
class tuple: pass
3434
class function: pass
3535
class float: pass
36+
37+
class BaseException: pass

0 commit comments

Comments
 (0)