13
13
# limitations under the License.
14
14
15
15
import logging
16
- from typing import Optional , Tuple
16
+ from typing import List , Optional , Tuple , Union
17
17
18
18
import attr
19
19
20
20
from synapse .api .constants import RelationTypes
21
21
from synapse .events import EventBase
22
22
from synapse .storage ._base import SQLBaseStore
23
+ from synapse .storage .database import LoggingTransaction
23
24
from synapse .storage .databases .main .stream import generate_pagination_where_clause
24
25
from synapse .storage .relations import (
25
26
AggregationPaginationToken ,
@@ -63,7 +64,7 @@ async def get_relations_for_event(
63
64
"""
64
65
65
66
where_clause = ["relates_to_id = ?" ]
66
- where_args = [event_id ]
67
+ where_args : List [ Union [ str , int ]] = [event_id ]
67
68
68
69
if relation_type is not None :
69
70
where_clause .append ("relation_type = ?" )
@@ -80,8 +81,8 @@ async def get_relations_for_event(
80
81
pagination_clause = generate_pagination_where_clause (
81
82
direction = direction ,
82
83
column_names = ("topological_ordering" , "stream_ordering" ),
83
- from_token = attr .astuple (from_token ) if from_token else None ,
84
- to_token = attr .astuple (to_token ) if to_token else None ,
84
+ from_token = attr .astuple (from_token ) if from_token else None , # type: ignore[arg-type]
85
+ to_token = attr .astuple (to_token ) if to_token else None , # type: ignore[arg-type]
85
86
engine = self .database_engine ,
86
87
)
87
88
@@ -106,7 +107,9 @@ async def get_relations_for_event(
106
107
order ,
107
108
)
108
109
109
- def _get_recent_references_for_event_txn (txn ):
110
+ def _get_recent_references_for_event_txn (
111
+ txn : LoggingTransaction ,
112
+ ) -> PaginationChunk :
110
113
txn .execute (sql , where_args + [limit + 1 ])
111
114
112
115
last_topo_id = None
@@ -160,7 +163,7 @@ async def get_aggregation_groups_for_event(
160
163
"""
161
164
162
165
where_clause = ["relates_to_id = ?" , "relation_type = ?" ]
163
- where_args = [event_id , RelationTypes .ANNOTATION ]
166
+ where_args : List [ Union [ str , int ]] = [event_id , RelationTypes .ANNOTATION ]
164
167
165
168
if event_type :
166
169
where_clause .append ("type = ?" )
@@ -169,8 +172,8 @@ async def get_aggregation_groups_for_event(
169
172
having_clause = generate_pagination_where_clause (
170
173
direction = direction ,
171
174
column_names = ("COUNT(*)" , "MAX(stream_ordering)" ),
172
- from_token = attr .astuple (from_token ) if from_token else None ,
173
- to_token = attr .astuple (to_token ) if to_token else None ,
175
+ from_token = attr .astuple (from_token ) if from_token else None , # type: ignore[arg-type]
176
+ to_token = attr .astuple (to_token ) if to_token else None , # type: ignore[arg-type]
174
177
engine = self .database_engine ,
175
178
)
176
179
@@ -199,7 +202,9 @@ async def get_aggregation_groups_for_event(
199
202
having_clause = having_clause ,
200
203
)
201
204
202
- def _get_aggregation_groups_for_event_txn (txn ):
205
+ def _get_aggregation_groups_for_event_txn (
206
+ txn : LoggingTransaction ,
207
+ ) -> PaginationChunk :
203
208
txn .execute (sql , where_args + [limit + 1 ])
204
209
205
210
next_batch = None
@@ -254,11 +259,12 @@ async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
254
259
LIMIT 1
255
260
"""
256
261
257
- def _get_applicable_edit_txn (txn ) :
262
+ def _get_applicable_edit_txn (txn : LoggingTransaction ) -> Optional [ str ] :
258
263
txn .execute (sql , (event_id , RelationTypes .REPLACE ))
259
264
row = txn .fetchone ()
260
265
if row :
261
266
return row [0 ]
267
+ return None
262
268
263
269
edit_id = await self .db_pool .runInteraction (
264
270
"get_applicable_edit" , _get_applicable_edit_txn
@@ -267,7 +273,7 @@ def _get_applicable_edit_txn(txn):
267
273
if not edit_id :
268
274
return None
269
275
270
- return await self .get_event (edit_id , allow_none = True )
276
+ return await self .get_event (edit_id , allow_none = True ) # type: ignore[attr-defined]
271
277
272
278
@cached ()
273
279
async def get_thread_summary (
@@ -283,7 +289,9 @@ async def get_thread_summary(
283
289
The number of items in the thread and the most recent response, if any.
284
290
"""
285
291
286
- def _get_thread_summary_txn (txn ) -> Tuple [int , Optional [str ]]:
292
+ def _get_thread_summary_txn (
293
+ txn : LoggingTransaction ,
294
+ ) -> Tuple [int , Optional [str ]]:
287
295
# Fetch the count of threaded events and the latest event ID.
288
296
# TODO Should this only allow m.room.message events.
289
297
sql = """
@@ -312,7 +320,7 @@ def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
312
320
AND relation_type = ?
313
321
"""
314
322
txn .execute (sql , (event_id , RelationTypes .THREAD ))
315
- count = txn .fetchone ()[0 ]
323
+ count = txn .fetchone ()[0 ] # type: ignore[index]
316
324
317
325
return count , latest_event_id
318
326
@@ -322,7 +330,7 @@ def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
322
330
323
331
latest_event = None
324
332
if latest_event_id :
325
- latest_event = await self .get_event (latest_event_id , allow_none = True )
333
+ latest_event = await self .get_event (latest_event_id , allow_none = True ) # type: ignore[attr-defined]
326
334
327
335
return count , latest_event
328
336
@@ -354,7 +362,7 @@ async def has_user_annotated_event(
354
362
LIMIT 1;
355
363
"""
356
364
357
- def _get_if_user_has_annotated_event (txn ) :
365
+ def _get_if_user_has_annotated_event (txn : LoggingTransaction ) -> bool :
358
366
txn .execute (
359
367
sql ,
360
368
(
0 commit comments