Skip to content

Commit 8f06d2c

Browse files
lmazuelkashifkhan
andauthored
Types in poller (#29228)
* Types in poller * Black it right * Saving this week's work * Move typing * Split base polling in two * Typing fixes * Typing update * Black * Unecessary Generic * Stringify types * Fix import * Spellcheck * Weird typo... * PyLint * More types * Update sdk/core/azure-core/azure/core/polling/async_base_polling.py Co-authored-by: Kashif Khan <[email protected]> * Missing type * Typing of the day * Re-enable verifytypes * Simplify the expectations async pipeline has on the response * Async Cxt Manager * Final Typing? * More covariant * Upside down * Fix tests * Messed up merge * Pylint * Better Typing * Final typing? * Pylint * Simplify translation typing for now * Fix backcompat with azure-mgmt-core * Revert renaming private methods * Black * Feedback from @kristapratico * Docstrings part 1 * Polling pylint part 2 * Black * All LRO impl should use TypeVar * Feedback * Convert some Anyu after feedback * Spellcheck * Black * Update sdk/core/azure-core/azure/core/polling/_async_poller.py * Update sdk/core/azure-core/azure/core/polling/_async_poller.py * Update sdk/core/azure-core/azure/core/polling/_poller.py --------- Co-authored-by: Kashif Khan <[email protected]>
1 parent 715008b commit 8f06d2c

File tree

15 files changed

+642
-279
lines changed

15 files changed

+642
-279
lines changed

.vscode/cspell.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@
215215
"ints",
216216
"iohttp",
217217
"IOHTTP",
218+
"IOLRO",
218219
"inprogress",
219220
"ipconfiguration",
220221
"ipconfigurations",

sdk/core/azure-core/azure/core/_pipeline_client_async.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@
3535
Generic,
3636
Optional,
3737
cast,
38-
TYPE_CHECKING,
3938
)
40-
from typing_extensions import Protocol
4139
from .configuration import Configuration
4240
from .pipeline import AsyncPipeline
4341
from .pipeline.transport._base import PipelineClientBase
@@ -51,17 +49,8 @@
5149
)
5250

5351

54-
if TYPE_CHECKING: # Protocol and non-Protocol can't mix in Python 3.7
55-
56-
class _AsyncContextManagerCloseable(AsyncContextManager, Protocol):
57-
"""Defines a context manager that is closeable at the same time."""
58-
59-
async def close(self):
60-
...
61-
62-
6352
HTTPRequestType = TypeVar("HTTPRequestType")
64-
AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType", bound="_AsyncContextManagerCloseable")
53+
AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType", bound="AsyncContextManager")
6554

6655
_LOGGER = logging.getLogger(__name__)
6756

@@ -80,11 +69,9 @@ class _Coroutine(Awaitable[AsyncHTTPResponseType]):
8069
This allows the dev to either use the "async with" syntax, or simply the object directly.
8170
It's also why "send_request" is not declared as async, since it couldn't be both easily.
8271
83-
"wrapped" must be an awaitable that returns an object that:
84-
- has an async "close()"
85-
- has an "__aexit__" method (IOW, is an async context manager)
72+
"wrapped" must be an awaitable object that returns an object implements the async context manager protocol.
8673
87-
This permits this code to work for both requests.
74+
This permits this code to work for both following requests.
8875
8976
```python
9077
from azure.core import AsyncPipelineClient
@@ -124,9 +111,6 @@ async def __aenter__(self) -> AsyncHTTPResponseType:
124111
async def __aexit__(self, *args) -> None:
125112
await self._response.__aexit__(*args)
126113

127-
async def close(self) -> None:
128-
await self._response.close()
129-
130114

131115
class AsyncPipelineClient(
132116
PipelineClientBase,

sdk/core/azure-core/azure/core/pipeline/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626

2727
from typing import TypeVar, Generic, Dict, Any
2828

29-
HTTPResponseType = TypeVar("HTTPResponseType")
30-
HTTPRequestType = TypeVar("HTTPRequestType")
29+
HTTPResponseType = TypeVar("HTTPResponseType", covariant=True)
30+
HTTPRequestType = TypeVar("HTTPRequestType", covariant=True)
3131

3232

3333
class PipelineContext(Dict[str, Any]):

sdk/core/azure-core/azure/core/pipeline/policies/_universal.py

Lines changed: 8 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -35,46 +35,25 @@
3535
import types
3636
import re
3737
import uuid
38-
from typing import IO, cast, Union, Optional, AnyStr, Dict, MutableMapping, Any, Set, Mapping
38+
from typing import IO, cast, Union, Optional, AnyStr, Dict, Any, Set, Mapping
3939
import urllib.parse
40-
from typing_extensions import Protocol
4140

4241
from azure.core import __version__ as azcore_version
4342
from azure.core.exceptions import DecodeError
4443

4544
from azure.core.pipeline import PipelineRequest, PipelineResponse
4645
from ._base import SansIOHTTPPolicy
4746

47+
from ..transport import HttpRequest as LegacyHttpRequest
48+
from ..transport._base import _HttpResponseBase as LegacySansIOHttpResponse
49+
from ...rest import HttpRequest
50+
from ...rest._rest_py3 import _HttpResponseBase as SansIOHttpResponse
4851

4952
_LOGGER = logging.getLogger(__name__)
5053

51-
52-
class HTTPRequestType(Protocol):
53-
"""Protocol compatible with new rest request and legacy transport request"""
54-
55-
headers: MutableMapping[str, str]
56-
url: str
57-
method: str
58-
body: bytes
59-
60-
61-
class HTTPResponseType(Protocol):
62-
"""Protocol compatible with new rest response and legacy transport response"""
63-
64-
@property
65-
def headers(self) -> MutableMapping[str, str]:
66-
...
67-
68-
@property
69-
def status_code(self) -> int:
70-
...
71-
72-
@property
73-
def content_type(self) -> Optional[str]:
74-
...
75-
76-
def text(self, encoding: Optional[str] = None) -> str:
77-
...
54+
HTTPRequestType = Union[LegacyHttpRequest, HttpRequest]
55+
HTTPResponseType = Union[LegacySansIOHttpResponse, SansIOHttpResponse]
56+
PipelineResponseType = PipelineResponse[HTTPRequestType, HTTPResponseType]
7857

7958

8059
class HeadersPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):

sdk/core/azure-core/azure/core/pipeline/policies/_utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,28 +25,34 @@
2525
# --------------------------------------------------------------------------
2626
import datetime
2727
import email.utils
28+
from typing import Optional, cast
29+
2830
from urllib.parse import urlparse
2931
from ...utils._utils import _FixedOffset, case_insensitive_dict
3032

3133

32-
def _parse_http_date(text):
34+
def _parse_http_date(text: str) -> datetime.datetime:
3335
"""Parse a HTTP date format into datetime.
3436
3537
:param str text: Text containing a date in HTTP format
3638
:rtype: datetime.datetime
3739
:return: The parsed datetime
3840
"""
3941
parsed_date = email.utils.parsedate_tz(text)
40-
return datetime.datetime(*parsed_date[:6], tzinfo=_FixedOffset(parsed_date[9] / 60))
42+
if not parsed_date:
43+
raise ValueError("Invalid HTTP date")
44+
tz_offset = cast(int, parsed_date[9]) # Look at the code, tz_offset is always an int, at worst 0
45+
return datetime.datetime(*parsed_date[:6], tzinfo=_FixedOffset(tz_offset / 60))
4146

4247

43-
def parse_retry_after(retry_after: str):
48+
def parse_retry_after(retry_after: str) -> float:
4449
"""Helper to parse Retry-After and get value in seconds.
4550
4651
:param str retry_after: Retry-After header
4752
:rtype: float
4853
:return: Value of Retry-After in seconds.
4954
"""
55+
delay: float # Using the Mypy recommendation to use float for "int or float"
5056
try:
5157
delay = int(retry_after)
5258
except ValueError:
@@ -56,7 +62,7 @@ def parse_retry_after(retry_after: str):
5662
return max(0, delay)
5763

5864

59-
def get_retry_after(response):
65+
def get_retry_after(response) -> Optional[float]:
6066
"""Get the value of Retry-After in seconds.
6167
6268
:param response: The PipelineResponse object

sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def _iterate_response_content(iterator):
6060
raise _ResponseStopIteration() # pylint: disable=raise-missing-from
6161

6262

63-
class AsyncHttpResponse(_HttpResponseBase): # pylint: disable=abstract-method
63+
class AsyncHttpResponse(_HttpResponseBase, AbstractAsyncContextManager): # pylint: disable=abstract-method
6464
"""An AsyncHttpResponse ABC.
6565
6666
Allows for the asynchronous streaming of data from the response.
@@ -93,6 +93,9 @@ def parts(self) -> AsyncIterator:
9393

9494
return _PartGenerator(self, default_http_response_type=AsyncHttpClientTransportResponse)
9595

96+
async def __aexit__(self, exc_type, exc_value, traceback):
97+
return None
98+
9699

97100
class AsyncHttpClientTransportResponse( # pylint: disable=abstract-method
98101
_HttpClientTransportResponse, AsyncHttpResponse

sdk/core/azure-core/azure/core/polling/_async_poller.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,12 @@ async def run(self): # pylint:disable=invalid-overridden-method
7171
"""
7272

7373

74-
async def async_poller(client, initial_response, deserialization_callback, polling_method):
74+
async def async_poller(
75+
client: Any,
76+
initial_response: Any,
77+
deserialization_callback: Callable[[Any], PollingReturnType_co],
78+
polling_method: AsyncPollingMethod[PollingReturnType_co],
79+
) -> PollingReturnType_co:
7580
"""Async Poller for long running operations.
7681
7782
.. deprecated:: 1.5.0
@@ -86,6 +91,8 @@ async def async_poller(client, initial_response, deserialization_callback, polli
8691
:type deserialization_callback: callable or msrest.serialization.Model
8792
:param polling_method: The polling strategy to adopt
8893
:type polling_method: ~azure.core.polling.PollingMethod
94+
:return: The final resource at the end of the polling.
95+
:rtype: any or None
8996
"""
9097
poller = AsyncLROPoller(client, initial_response, deserialization_callback, polling_method)
9198
return await poller
@@ -109,7 +116,7 @@ def __init__(
109116
self,
110117
client: Any,
111118
initial_response: Any,
112-
deserialization_callback: Callable,
119+
deserialization_callback: Callable[[Any], PollingReturnType_co],
113120
polling_method: AsyncPollingMethod[PollingReturnType_co],
114121
):
115122
self._polling_method = polling_method
@@ -124,7 +131,11 @@ def __init__(
124131
self._polling_method.initialize(client, initial_response, deserialization_callback)
125132

126133
def polling_method(self) -> AsyncPollingMethod[PollingReturnType_co]:
127-
"""Return the polling method associated to this poller."""
134+
"""Return the polling method associated to this poller.
135+
136+
:return: The polling method associated to this poller.
137+
:rtype: ~azure.core.polling.AsyncPollingMethod
138+
"""
128139
return self._polling_method
129140

130141
def continuation_token(self) -> str:
@@ -158,6 +169,7 @@ async def result(self) -> PollingReturnType_co:
158169
"""Return the result of the long running operation.
159170
160171
:returns: The deserialized resource of the long running operation, if one is available.
172+
:rtype: any or None
161173
:raises ~azure.core.exceptions.HttpResponseError: Server problem with the query.
162174
"""
163175
await self.wait()

sdk/core/azure-core/azure/core/polling/_poller.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -64,35 +64,44 @@ def from_continuation_token(cls, continuation_token: str, **kwargs) -> Tuple[Any
6464
raise TypeError("Polling method '{}' doesn't support from_continuation_token".format(cls.__name__))
6565

6666

67-
class NoPolling(PollingMethod):
67+
class NoPolling(PollingMethod[PollingReturnType_co]):
6868
"""An empty poller that returns the deserialized initial response."""
6969

70+
_deserialization_callback: Callable[[Any], PollingReturnType_co]
71+
"""Deserialization callback passed during initialization"""
72+
7073
def __init__(self):
7174
self._initial_response = None
72-
self._deserialization_callback = None
7375

74-
def initialize(self, _: Any, initial_response: Any, deserialization_callback: Callable) -> None:
76+
def initialize(
77+
self,
78+
_: Any,
79+
initial_response: Any,
80+
deserialization_callback: Callable[[Any], PollingReturnType_co],
81+
) -> None:
7582
self._initial_response = initial_response
7683
self._deserialization_callback = deserialization_callback
7784

7885
def run(self) -> None:
7986
"""Empty run, no polling."""
8087

8188
def status(self) -> str:
82-
"""Return the current status as a string.
89+
"""Return the current status.
8390
8491
:rtype: str
92+
:return: The current status
8593
"""
8694
return "succeeded"
8795

8896
def finished(self) -> bool:
8997
"""Is this polling finished?
9098
9199
:rtype: bool
100+
:return: Whether this polling is finished
92101
"""
93102
return True
94103

95-
def resource(self) -> Any:
104+
def resource(self) -> PollingReturnType_co:
96105
return self._deserialization_callback(self._initial_response)
97106

98107
def get_continuation_token(self) -> str:
@@ -105,7 +114,7 @@ def from_continuation_token(cls, continuation_token: str, **kwargs) -> Tuple[Any
105114
try:
106115
deserialization_callback = kwargs["deserialization_callback"]
107116
except KeyError:
108-
raise ValueError("Need kwarg 'deserialization_callback' to be recreated from continuation_token")
117+
raise ValueError("Need kwarg 'deserialization_callback' to be recreated from continuation_token") from None
109118
import pickle
110119

111120
initial_response = pickle.loads(base64.b64decode(continuation_token)) # nosec
@@ -130,7 +139,7 @@ def __init__(
130139
self,
131140
client: Any,
132141
initial_response: Any,
133-
deserialization_callback: Callable,
142+
deserialization_callback: Callable[[Any], PollingReturnType_co],
134143
polling_method: PollingMethod[PollingReturnType_co],
135144
) -> None:
136145
self._callbacks: List[Callable] = []
@@ -147,10 +156,11 @@ def __init__(
147156

148157
# Prepare thread execution
149158
self._thread = None
150-
self._done = None
159+
self._done = threading.Event()
151160
self._exception = None
152-
if not self._polling_method.finished():
153-
self._done = threading.Event()
161+
if self._polling_method.finished():
162+
self._done.set()
163+
else:
154164
self._thread = threading.Thread(
155165
target=with_current_context(self._start),
156166
name="LROPoller({})".format(uuid.uuid4()),
@@ -161,9 +171,6 @@ def __init__(
161171
def _start(self):
162172
"""Start the long running operation.
163173
On completion, runs any callbacks.
164-
165-
:param callable update_cmd: The API request to check the status of
166-
the operation.
167174
"""
168175
try:
169176
self._polling_method.run()
@@ -189,7 +196,11 @@ def _start(self):
189196
callbacks, self._callbacks = self._callbacks, []
190197

191198
def polling_method(self) -> PollingMethod[PollingReturnType_co]:
192-
"""Return the polling method associated to this poller."""
199+
"""Return the polling method associated to this poller.
200+
201+
:return: The polling method
202+
:rtype: ~azure.core.polling.PollingMethod
203+
"""
193204
return self._polling_method
194205

195206
def continuation_token(self) -> str:
@@ -223,8 +234,9 @@ def result(self, timeout: Optional[float] = None) -> PollingReturnType_co:
223234
"""Return the result of the long running operation, or
224235
the result available after the specified timeout.
225236
226-
:returns: The deserialized resource of the long running operation,
227-
if one is available.
237+
:param float timeout: Period of time to wait before getting back control.
238+
:returns: The deserialized resource of the long running operation, if one is available.
239+
:rtype: any or None
228240
:raises ~azure.core.exceptions.HttpResponseError: Server problem with the query.
229241
"""
230242
self.wait(timeout)
@@ -266,7 +278,7 @@ def add_done_callback(self, func: Callable) -> None:
266278
argument, a completed LongRunningOperation.
267279
"""
268280
# Still use "_done" and not "done", since CBs are executed inside the thread.
269-
if self._done is None or self._done.is_set():
281+
if self._done.is_set():
270282
func(self._polling_method)
271283
# Let's add them still, for consistency (if you wish to access to it for some reasons)
272284
self._callbacks.append(func)

0 commit comments

Comments
 (0)