Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 33a02f0

Browse files
authored
Fix additional type hints from Twisted upgrade. (#9518)
1 parent 4db07f9 commit 33a02f0

File tree

12 files changed

+96
-61
lines changed

12 files changed

+96
-61
lines changed

changelog.d/9518.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix incorrect type hints.

synapse/http/federation/matrix_federation_agent.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
import logging
1616
import urllib.parse
17-
from typing import List, Optional
17+
from typing import Any, Generator, List, Optional
1818

1919
from netaddr import AddrFormatError, IPAddress, IPSet
2020
from zope.interface import implementer
@@ -116,7 +116,7 @@ def request(
116116
uri: bytes,
117117
headers: Optional[Headers] = None,
118118
bodyProducer: Optional[IBodyProducer] = None,
119-
) -> defer.Deferred:
119+
) -> Generator[defer.Deferred, Any, defer.Deferred]:
120120
"""
121121
Args:
122122
method: HTTP method: GET/POST/etc
@@ -177,17 +177,17 @@ def request(
177177
# We need to make sure the host header is set to the netloc of the
178178
# server and that a user-agent is provided.
179179
if headers is None:
180-
headers = Headers()
180+
request_headers = Headers()
181181
else:
182-
headers = headers.copy()
182+
request_headers = headers.copy()
183183

184-
if not headers.hasHeader(b"host"):
185-
headers.addRawHeader(b"host", parsed_uri.netloc)
186-
if not headers.hasHeader(b"user-agent"):
187-
headers.addRawHeader(b"user-agent", self.user_agent)
184+
if not request_headers.hasHeader(b"host"):
185+
request_headers.addRawHeader(b"host", parsed_uri.netloc)
186+
if not request_headers.hasHeader(b"user-agent"):
187+
request_headers.addRawHeader(b"user-agent", self.user_agent)
188188

189189
res = yield make_deferred_yieldable(
190-
self._agent.request(method, uri, headers, bodyProducer)
190+
self._agent.request(method, uri, request_headers, bodyProducer)
191191
)
192192

193193
return res

synapse/http/matrixfederationclient.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,14 +1049,14 @@ def check_content_type_is_json(headers: Headers) -> None:
10491049
RequestSendFailed: if the Content-Type header is missing or isn't JSON
10501050
10511051
"""
1052-
c_type = headers.getRawHeaders(b"Content-Type")
1053-
if c_type is None:
1052+
content_type_headers = headers.getRawHeaders(b"Content-Type")
1053+
if content_type_headers is None:
10541054
raise RequestSendFailed(
10551055
RuntimeError("No Content-Type header received from remote server"),
10561056
can_retry=False,
10571057
)
10581058

1059-
c_type = c_type[0].decode("ascii") # only the first header
1059+
c_type = content_type_headers[0].decode("ascii") # only the first header
10601060
val, options = cgi.parse_header(c_type)
10611061
if val != "application/json":
10621062
raise RequestSendFailed(

synapse/http/server.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import types
2222
import urllib
2323
from http import HTTPStatus
24+
from inspect import isawaitable
2425
from io import BytesIO
2526
from typing import (
2627
Any,
@@ -30,6 +31,7 @@
3031
Iterable,
3132
Iterator,
3233
List,
34+
Optional,
3335
Pattern,
3436
Tuple,
3537
Union,
@@ -79,10 +81,12 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
7981
"""Sends a JSON error response to clients."""
8082

8183
if f.check(SynapseError):
82-
error_code = f.value.code
83-
error_dict = f.value.error_dict()
84+
# mypy doesn't understand that f.check asserts the type.
85+
exc = f.value # type: SynapseError # type: ignore
86+
error_code = exc.code
87+
error_dict = exc.error_dict()
8488

85-
logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
89+
logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
8690
else:
8791
error_code = 500
8892
error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
@@ -91,7 +95,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
9195
"Failed handle request via %r: %r",
9296
request.request_metrics.name,
9397
request,
94-
exc_info=(f.type, f.value, f.getTracebackObject()),
98+
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
9599
)
96100

97101
# Only respond with an error response if we haven't already started writing,
@@ -128,7 +132,8 @@ def return_html_error(
128132
`{msg}` placeholders), or a jinja2 template
129133
"""
130134
if f.check(CodeMessageException):
131-
cme = f.value
135+
# mypy doesn't understand that f.check asserts the type.
136+
cme = f.value # type: CodeMessageException # type: ignore
132137
code = cme.code
133138
msg = cme.msg
134139

@@ -142,7 +147,7 @@ def return_html_error(
142147
logger.error(
143148
"Failed handle request %r",
144149
request,
145-
exc_info=(f.type, f.value, f.getTracebackObject()),
150+
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
146151
)
147152
else:
148153
code = HTTPStatus.INTERNAL_SERVER_ERROR
@@ -151,7 +156,7 @@ def return_html_error(
151156
logger.error(
152157
"Failed handle request %r",
153158
request,
154-
exc_info=(f.type, f.value, f.getTracebackObject()),
159+
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
155160
)
156161

157162
if isinstance(error_template, str):
@@ -278,7 +283,7 @@ async def _async_render(self, request: Request):
278283
raw_callback_return = method_handler(request)
279284

280285
# Is it synchronous? We'll allow this for now.
281-
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
286+
if isawaitable(raw_callback_return):
282287
callback_return = await raw_callback_return
283288
else:
284289
callback_return = raw_callback_return # type: ignore
@@ -399,8 +404,10 @@ def _get_handler_for_request(
399404
A tuple of the callback to use, the name of the servlet, and the
400405
key word arguments to pass to the callback
401406
"""
407+
# At this point the path must be bytes.
408+
request_path_bytes = request.path # type: bytes # type: ignore
409+
request_path = request_path_bytes.decode("ascii")
402410
# Treat HEAD requests as GET requests.
403-
request_path = request.path.decode("ascii")
404411
request_method = request.method
405412
if request_method == b"HEAD":
406413
request_method = b"GET"
@@ -551,7 +558,7 @@ def __init__(
551558
request: Request,
552559
iterator: Iterator[bytes],
553560
):
554-
self._request = request
561+
self._request = request # type: Optional[Request]
555562
self._iterator = iterator
556563
self._paused = False
557564

@@ -563,7 +570,7 @@ def _send_data(self, data: List[bytes]) -> None:
563570
"""
564571
Send a list of bytes as a chunk of a response.
565572
"""
566-
if not data:
573+
if not data or not self._request:
567574
return
568575
self._request.write(b"".join(data))
569576

synapse/http/site.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import contextlib
1515
import logging
1616
import time
17-
from typing import Optional, Union
17+
from typing import Optional, Type, Union
1818

1919
import attr
2020
from zope.interface import implementer
@@ -57,7 +57,7 @@ class SynapseRequest(Request):
5757

5858
def __init__(self, channel, *args, **kw):
5959
Request.__init__(self, channel, *args, **kw)
60-
self.site = channel.site
60+
self.site = channel.site # type: SynapseSite
6161
self._channel = channel # this is used by the tests
6262
self.start_time = 0.0
6363

@@ -96,25 +96,34 @@ def __repr__(self):
9696
def get_request_id(self):
9797
return "%s-%i" % (self.get_method(), self.request_seq)
9898

99-
def get_redacted_uri(self):
100-
uri = self.uri
99+
def get_redacted_uri(self) -> str:
100+
"""Gets the redacted URI associated with the request (or placeholder if the URI
101+
has not yet been received).
102+
103+
Note: This is necessary as the placeholder value in twisted is str
104+
rather than bytes, so we need to sanitise `self.uri`.
105+
106+
Returns:
107+
The redacted URI as a string.
108+
"""
109+
uri = self.uri # type: Union[bytes, str]
101110
if isinstance(uri, bytes):
102-
uri = self.uri.decode("ascii", errors="replace")
111+
uri = uri.decode("ascii", errors="replace")
103112
return redact_uri(uri)
104113

105-
def get_method(self):
106-
"""Gets the method associated with the request (or placeholder if not
107-
method has yet been received).
114+
def get_method(self) -> str:
115+
"""Gets the method associated with the request (or placeholder if method
116+
has not yet been received).
108117
109118
Note: This is necessary as the placeholder value in twisted is str
110119
rather than bytes, so we need to sanitise `self.method`.
111120
112121
Returns:
113-
str
122+
The request method as a string.
114123
"""
115-
method = self.method
124+
method = self.method # type: Union[bytes, str]
116125
if isinstance(method, bytes):
117-
method = self.method.decode("ascii")
126+
return self.method.decode("ascii")
118127
return method
119128

120129
def render(self, resrc):
@@ -432,7 +441,9 @@ def __init__(
432441

433442
assert config.http_options is not None
434443
proxied = config.http_options.x_forwarded
435-
self.requestFactory = XForwardedForRequest if proxied else SynapseRequest
444+
self.requestFactory = (
445+
XForwardedForRequest if proxied else SynapseRequest
446+
) # type: Type[Request]
436447
self.access_logger = logging.getLogger(logger_name)
437448
self.server_version_string = server_version_string.encode("ascii")
438449

synapse/logging/_remote.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
TCP4ClientEndpoint,
3333
TCP6ClientEndpoint,
3434
)
35-
from twisted.internet.interfaces import IPushProducer, ITransport
35+
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
3636
from twisted.internet.protocol import Factory, Protocol
3737
from twisted.python.failure import Failure
3838

@@ -121,7 +121,9 @@ def __init__(
121121
try:
122122
ip = ip_address(self.host)
123123
if isinstance(ip, IPv4Address):
124-
endpoint = TCP4ClientEndpoint(_reactor, self.host, self.port)
124+
endpoint = TCP4ClientEndpoint(
125+
_reactor, self.host, self.port
126+
) # type: IStreamClientEndpoint
125127
elif isinstance(ip, IPv6Address):
126128
endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port)
127129
else:

synapse/metrics/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ def collect(self):
527527
REGISTRY.register(ReactorLastSeenMetric())
528528

529529

530-
def runUntilCurrentTimer(func):
530+
def runUntilCurrentTimer(reactor, func):
531531
@functools.wraps(func)
532532
def f(*args, **kwargs):
533533
now = reactor.seconds()
@@ -590,13 +590,14 @@ def f(*args, **kwargs):
590590

591591
try:
592592
# Ensure the reactor has all the attributes we expect
593-
reactor.runUntilCurrent
594-
reactor._newTimedCalls
595-
reactor.threadCallQueue
593+
reactor.seconds # type: ignore
594+
reactor.runUntilCurrent # type: ignore
595+
reactor._newTimedCalls # type: ignore
596+
reactor.threadCallQueue # type: ignore
596597

597598
# runUntilCurrent is called when we have pending calls. It is called once
598599
# per iteratation after fd polling.
599-
reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
600+
reactor.runUntilCurrent = runUntilCurrentTimer(reactor, reactor.runUntilCurrent) # type: ignore
600601

601602
# We manually run the GC each reactor tick so that we can get some metrics
602603
# about time spent doing GC,

synapse/module_api/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
import logging
17-
from typing import TYPE_CHECKING, Iterable, Optional, Tuple
17+
from typing import TYPE_CHECKING, Any, Generator, Iterable, Optional, Tuple
1818

1919
from twisted.internet import defer
2020

@@ -307,7 +307,7 @@ async def complete_sso_login_async(
307307
@defer.inlineCallbacks
308308
def get_state_events_in_room(
309309
self, room_id: str, types: Iterable[Tuple[str, Optional[str]]]
310-
) -> defer.Deferred:
310+
) -> Generator[defer.Deferred, Any, defer.Deferred]:
311311
"""Gets current state events for the given room.
312312
313313
(This is exposed for compatibility with the old SpamCheckerApi. We should

synapse/push/httppusher.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
# limitations under the License.
1616
import logging
1717
import urllib.parse
18-
from typing import TYPE_CHECKING, Any, Dict, Iterable, Union
18+
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
1919

2020
from prometheus_client import Counter
2121

2222
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
23+
from twisted.internet.interfaces import IDelayedCall
2324

2425
from synapse.api.constants import EventTypes
2526
from synapse.events import EventBase
@@ -71,7 +72,7 @@ def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
7172
self.data = pusher_config.data
7273
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
7374
self.failing_since = pusher_config.failing_since
74-
self.timed_call = None
75+
self.timed_call = None # type: Optional[IDelayedCall]
7576
self._is_processing = False
7677
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
7778
self._pusherpool = hs.get_pusherpool()

synapse/replication/tcp/client.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,7 @@ def __init__(self, hs: "HomeServer"):
108108

109109
# Map from stream to list of deferreds waiting for the stream to
110110
# arrive at a particular position. The lists are sorted by stream position.
111-
self._streams_to_waiters = (
112-
{}
113-
) # type: Dict[str, List[Tuple[int, Deferred[None]]]]
111+
self._streams_to_waiters = {} # type: Dict[str, List[Tuple[int, Deferred]]]
114112

115113
async def on_rdata(
116114
self, stream_name: str, instance_name: str, token: int, rows: list

synapse/server.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939
import twisted.internet.base
4040
import twisted.internet.tcp
41+
from twisted.internet import defer
4142
from twisted.mail.smtp import sendmail
4243
from twisted.web.iweb import IPolicyForHTTPS
4344

@@ -403,7 +404,7 @@ def get_room_shutdown_handler(self) -> RoomShutdownHandler:
403404
return RoomShutdownHandler(self)
404405

405406
@cache_in_self
406-
def get_sendmail(self) -> sendmail:
407+
def get_sendmail(self) -> Callable[..., defer.Deferred]:
407408
return sendmail
408409

409410
@cache_in_self

0 commit comments

Comments
 (0)