Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 56e281b

Browse files
authored
Additional type hints for relations database class. (#11205)
1 parent 0e16b41 commit 56e281b

File tree

3 files changed

+25
-15
lines changed

3 files changed

+25
-15
lines changed

changelog.d/11205.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improve type hints for the relations datastore.

mypy.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ files =
5353
synapse/storage/databases/main/keys.py,
5454
synapse/storage/databases/main/pusher.py,
5555
synapse/storage/databases/main/registration.py,
56+
synapse/storage/databases/main/relations.py,
5657
synapse/storage/databases/main/session.py,
5758
synapse/storage/databases/main/stream.py,
5859
synapse/storage/databases/main/ui_auth.py,

synapse/storage/databases/main/relations.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from typing import Optional, Tuple
16+
from typing import List, Optional, Tuple, Union
1717

1818
import attr
1919

2020
from synapse.api.constants import RelationTypes
2121
from synapse.events import EventBase
2222
from synapse.storage._base import SQLBaseStore
23+
from synapse.storage.database import LoggingTransaction
2324
from synapse.storage.databases.main.stream import generate_pagination_where_clause
2425
from synapse.storage.relations import (
2526
AggregationPaginationToken,
@@ -63,7 +64,7 @@ async def get_relations_for_event(
6364
"""
6465

6566
where_clause = ["relates_to_id = ?"]
66-
where_args = [event_id]
67+
where_args: List[Union[str, int]] = [event_id]
6768

6869
if relation_type is not None:
6970
where_clause.append("relation_type = ?")
@@ -80,8 +81,8 @@ async def get_relations_for_event(
8081
pagination_clause = generate_pagination_where_clause(
8182
direction=direction,
8283
column_names=("topological_ordering", "stream_ordering"),
83-
from_token=attr.astuple(from_token) if from_token else None,
84-
to_token=attr.astuple(to_token) if to_token else None,
84+
from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
85+
to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
8586
engine=self.database_engine,
8687
)
8788

@@ -106,7 +107,9 @@ async def get_relations_for_event(
106107
order,
107108
)
108109

109-
def _get_recent_references_for_event_txn(txn):
110+
def _get_recent_references_for_event_txn(
111+
txn: LoggingTransaction,
112+
) -> PaginationChunk:
110113
txn.execute(sql, where_args + [limit + 1])
111114

112115
last_topo_id = None
@@ -160,7 +163,7 @@ async def get_aggregation_groups_for_event(
160163
"""
161164

162165
where_clause = ["relates_to_id = ?", "relation_type = ?"]
163-
where_args = [event_id, RelationTypes.ANNOTATION]
166+
where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
164167

165168
if event_type:
166169
where_clause.append("type = ?")
@@ -169,8 +172,8 @@ async def get_aggregation_groups_for_event(
169172
having_clause = generate_pagination_where_clause(
170173
direction=direction,
171174
column_names=("COUNT(*)", "MAX(stream_ordering)"),
172-
from_token=attr.astuple(from_token) if from_token else None,
173-
to_token=attr.astuple(to_token) if to_token else None,
175+
from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
176+
to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
174177
engine=self.database_engine,
175178
)
176179

@@ -199,7 +202,9 @@ async def get_aggregation_groups_for_event(
199202
having_clause=having_clause,
200203
)
201204

202-
def _get_aggregation_groups_for_event_txn(txn):
205+
def _get_aggregation_groups_for_event_txn(
206+
txn: LoggingTransaction,
207+
) -> PaginationChunk:
203208
txn.execute(sql, where_args + [limit + 1])
204209

205210
next_batch = None
@@ -254,11 +259,12 @@ async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
254259
LIMIT 1
255260
"""
256261

257-
def _get_applicable_edit_txn(txn):
262+
def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
258263
txn.execute(sql, (event_id, RelationTypes.REPLACE))
259264
row = txn.fetchone()
260265
if row:
261266
return row[0]
267+
return None
262268

263269
edit_id = await self.db_pool.runInteraction(
264270
"get_applicable_edit", _get_applicable_edit_txn
@@ -267,7 +273,7 @@ def _get_applicable_edit_txn(txn):
267273
if not edit_id:
268274
return None
269275

270-
return await self.get_event(edit_id, allow_none=True)
276+
return await self.get_event(edit_id, allow_none=True) # type: ignore[attr-defined]
271277

272278
@cached()
273279
async def get_thread_summary(
@@ -283,7 +289,9 @@ async def get_thread_summary(
283289
The number of items in the thread and the most recent response, if any.
284290
"""
285291

286-
def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
292+
def _get_thread_summary_txn(
293+
txn: LoggingTransaction,
294+
) -> Tuple[int, Optional[str]]:
287295
# Fetch the count of threaded events and the latest event ID.
288296
# TODO Should this only allow m.room.message events.
289297
sql = """
@@ -312,7 +320,7 @@ def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
312320
AND relation_type = ?
313321
"""
314322
txn.execute(sql, (event_id, RelationTypes.THREAD))
315-
count = txn.fetchone()[0]
323+
count = txn.fetchone()[0] # type: ignore[index]
316324

317325
return count, latest_event_id
318326

@@ -322,7 +330,7 @@ def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
322330

323331
latest_event = None
324332
if latest_event_id:
325-
latest_event = await self.get_event(latest_event_id, allow_none=True)
333+
latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined]
326334

327335
return count, latest_event
328336

@@ -354,7 +362,7 @@ async def has_user_annotated_event(
354362
LIMIT 1;
355363
"""
356364

357-
def _get_if_user_has_annotated_event(txn):
365+
def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
358366
txn.execute(
359367
sql,
360368
(

0 commit comments

Comments
 (0)