Skip to content

Commit 7ab8ae8

Browse files
committed
Refactor: unit test to use AsyncDb instead of Sync
1 parent 9e90632 commit 7ab8ae8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1121
-984
lines changed

tests/api/conftest.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
from aleph.jobs.process_pending_messages import PendingMessageProcessor
2525
from aleph.storage import StorageService
2626
from aleph.toolkit.timestamp import timestamp_to_datetime, utc_now
27-
from aleph.types.db_session import DbSessionFactory
27+
from aleph.types.db_session import AsyncDbSessionFactory
2828
from aleph.types.message_status import MessageStatus
2929

3030

3131
# TODO: remove the raw parameter, it's just to avoid larger refactorings
3232
async def _load_fixtures(
33-
session_factory: DbSessionFactory, filename: str, raw: bool = True
33+
session_factory: AsyncDbSessionFactory, filename: str, raw: bool = True
3434
) -> Sequence[Dict[str, Any]]:
3535
fixtures_dir = Path(__file__).parent / "fixtures"
3636
fixtures_file = fixtures_dir / filename
@@ -41,7 +41,7 @@ async def _load_fixtures(
4141
messages = []
4242
tx_hashes = set()
4343

44-
with session_factory() as session:
44+
async with session_factory() as session:
4545
for message_dict in messages_json:
4646
message_db = MessageDb.from_message_dict(message_dict)
4747
messages.append(message_db)
@@ -53,8 +53,8 @@ async def _load_fixtures(
5353
tx_hashes.add(tx_hash)
5454
session.add(chain_tx_db)
5555

56-
session.flush()
57-
session.execute(
56+
await session.flush()
57+
await session.execute(
5858
insert(message_confirmations).values(
5959
item_hash=message_db.item_hash, tx_hash=tx_hash
6060
)
@@ -67,14 +67,14 @@ async def _load_fixtures(
6767
)
6868
session.add(message_status)
6969

70-
session.commit()
70+
await session.commit()
7171

7272
return messages_json if raw else messages
7373

7474

7575
@pytest_asyncio.fixture
7676
async def fixture_messages(
77-
session_factory: DbSessionFactory,
77+
session_factory: AsyncDbSessionFactory,
7878
) -> Sequence[Dict[str, Any]]:
7979
return await _load_fixtures(session_factory, "fixture_messages.json")
8080

@@ -94,23 +94,23 @@ def make_aggregate_element(message: MessageDb) -> AggregateElementDb:
9494

9595
@pytest_asyncio.fixture
9696
async def fixture_aggregate_messages(
97-
session_factory: DbSessionFactory,
97+
session_factory: AsyncDbSessionFactory,
9898
) -> Sequence[MessageDb]:
9999
messages = await _load_fixtures(
100100
session_factory, "fixture_aggregates.json", raw=False
101101
)
102102
aggregate_keys = set()
103-
with session_factory() as session:
103+
async with session_factory() as session:
104104
for message in messages:
105105
aggregate_element = make_aggregate_element(message) # type: ignore
106106
session.add(aggregate_element)
107107
aggregate_keys.add((aggregate_element.owner, aggregate_element.key))
108-
session.commit()
108+
await session.commit()
109109

110110
for owner, key in aggregate_keys:
111-
refresh_aggregate(session=session, owner=owner, key=key)
111+
await refresh_aggregate(session=session, owner=owner, key=key)
112112

113-
session.commit()
113+
await session.commit()
114114

115115
return messages # type: ignore
116116

@@ -131,14 +131,14 @@ def make_post_db(message: MessageDb) -> PostDb:
131131

132132
@pytest_asyncio.fixture
133133
async def fixture_posts(
134-
session_factory: DbSessionFactory,
134+
session_factory: AsyncDbSessionFactory,
135135
) -> Sequence[PostDb]:
136136
messages = await _load_fixtures(session_factory, "fixture_posts.json", raw=False)
137137
posts = [make_post_db(message) for message in messages] # type: ignore
138138

139-
with session_factory() as session:
139+
async with session_factory() as session:
140140
session.add_all(posts)
141-
session.commit()
141+
await session.commit()
142142

143143
return posts
144144

@@ -222,7 +222,9 @@ def amended_post_with_refs_and_tags(
222222

223223

224224
@pytest.fixture
225-
def message_processor(mocker, mock_config: Config, session_factory: DbSessionFactory):
225+
def message_processor(
226+
mocker, mock_config: Config, session_factory: AsyncDbSessionFactory
227+
):
226228
storage_engine = InMemoryStorageEngine(files={})
227229
storage_service = StorageService(
228230
storage_engine=storage_engine,

tests/api/test_get_message.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Any, Mapping, Sequence
33

44
import pytest
5+
import pytest_asyncio
56
import pytz
67
from aleph_message.models import Chain, ItemType, MessageType
78

@@ -20,7 +21,7 @@
2021
)
2122
from aleph.toolkit.timestamp import timestamp_to_datetime
2223
from aleph.types.channel import Channel
23-
from aleph.types.db_session import DbSessionFactory
24+
from aleph.types.db_session import AsyncDbSessionFactory
2425
from aleph.types.message_status import ErrorCode, MessageStatus
2526

2627
MESSAGE_URI = "/api/v0/messages/{}"
@@ -30,9 +31,9 @@
3031
RECEPTION_DATETIME = pytz.utc.localize(dt.datetime(2023, 1, 1))
3132

3233

33-
@pytest.fixture
34-
def fixture_messages_with_status(
35-
session_factory: DbSessionFactory,
34+
@pytest_asyncio.fixture
35+
async def fixture_messages_with_status(
36+
session_factory: AsyncDbSessionFactory,
3637
) -> Mapping[MessageStatus, Sequence[Any]]:
3738

3839
pending_messages = [
@@ -171,7 +172,7 @@ def fixture_messages_with_status(
171172
MessageStatus.REJECTED: rejected_messages,
172173
}
173174

174-
with session_factory() as session:
175+
async with session_factory() as session:
175176
for status, messages in messages_dict.items():
176177
for message in messages:
177178
session.add(message)
@@ -182,7 +183,7 @@ def fixture_messages_with_status(
182183
reception_time=RECEPTION_DATETIME,
183184
)
184185
)
185-
session.commit()
186+
await session.commit()
186187

187188
return messages_dict
188189

tests/api/test_list_messages.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from aleph.db.models.messages import MessageStatusDb
2323
from aleph.toolkit.timestamp import timestamp_to_datetime, utc_now
2424
from aleph.types.channel import Channel
25-
from aleph.types.db_session import DbSessionFactory
25+
from aleph.types.db_session import AsyncDbSessionFactory
2626
from aleph.types.message_status import MessageStatus
2727

2828
from .utils import get_messages_by_keys
@@ -172,7 +172,7 @@ async def test_get_messages_multiple_hashes(fixture_messages, ccn_api_client):
172172
async def test_get_messages_filter_by_tags(
173173
fixture_messages,
174174
ccn_api_client,
175-
session_factory: DbSessionFactory,
175+
session_factory: AsyncDbSessionFactory,
176176
post_with_refs_and_tags: Tuple[MessageDb, PostDb, MessageStatusDb],
177177
amended_post_with_refs_and_tags: Tuple[MessageDb, PostDb, MessageStatusDb],
178178
):
@@ -184,11 +184,11 @@ async def test_get_messages_filter_by_tags(
184184
message_db, _, message_status_db = post_with_refs_and_tags
185185
amend_message_db, _, amend_message_status_db = amended_post_with_refs_and_tags
186186

187-
with session_factory() as session:
187+
async with session_factory() as session:
188188
session.add_all(
189189
[message_db, message_status_db, amend_message_db, amend_message_status_db]
190190
)
191-
session.commit()
191+
await session.commit()
192192

193193
# Matching tag for both messages
194194
response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "mainnet"})
@@ -564,13 +564,13 @@ def instance_message_fixture() -> Tuple[MessageDb, MessageStatusDb]:
564564
async def test_get_instance(
565565
ccn_api_client,
566566
instance_message_fixture: Tuple[MessageDb, MessageStatusDb],
567-
session_factory: DbSessionFactory,
567+
session_factory: AsyncDbSessionFactory,
568568
):
569569

570570
message_db, status_db = instance_message_fixture
571-
with session_factory() as session:
571+
async with session_factory() as session:
572572
session.add_all([message_db, status_db])
573-
session.commit()
573+
await session.commit()
574574

575575
response = await ccn_api_client.get(
576576
MESSAGES_URI, params={"hashes": message_db.item_hash}

tests/api/test_new_metric.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55
import pytest_asyncio
66

7-
from aleph.types.db_session import DbSessionFactory
7+
from aleph.types.db_session import AsyncDbSessionFactory
88

99
from .conftest import _load_fixtures
1010

@@ -15,7 +15,7 @@ def _generate_uri(node_type: str, node_id: str) -> str:
1515

1616
@pytest_asyncio.fixture
1717
async def fixture_metrics_messages(
18-
session_factory: DbSessionFactory,
18+
session_factory: AsyncDbSessionFactory,
1919
) -> Sequence[Dict[str, Any]]:
2020
return await _load_fixtures(session_factory, "test-metric.json")
2121

tests/api/test_posts.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from aleph.db.models import MessageDb
77
from aleph.db.models.messages import MessageStatusDb
88
from aleph.db.models.posts import PostDb
9-
from aleph.types.db_session import DbSessionFactory
9+
from aleph.types.db_session import AsyncDbSessionFactory
1010

1111
POSTS_URI = "/api/v1/posts.json"
1212

@@ -57,16 +57,16 @@ async def test_get_posts(ccn_api_client, fixture_posts: Sequence[PostDb]):
5757
@pytest.mark.asyncio
5858
async def test_get_posts_refs(
5959
ccn_api_client,
60-
session_factory: DbSessionFactory,
60+
session_factory: AsyncDbSessionFactory,
6161
fixture_posts: Sequence[PostDb],
6262
post_with_refs_and_tags: Tuple[MessageDb, PostDb, MessageStatusDb],
6363
):
6464
message_db, post_db, message_status_db = post_with_refs_and_tags
6565

66-
with session_factory() as session:
66+
async with session_factory() as session:
6767
session.add_all(fixture_posts)
6868
session.add_all([message_db, post_db, message_status_db])
69-
session.commit()
69+
await session.commit()
7070

7171
# Match the ref
7272
response = await ccn_api_client.get(
@@ -111,7 +111,7 @@ async def test_get_posts_refs(
111111
@pytest.mark.asyncio
112112
async def test_get_amended_posts_refs(
113113
ccn_api_client,
114-
session_factory: DbSessionFactory,
114+
session_factory: AsyncDbSessionFactory,
115115
fixture_posts: Sequence[PostDb],
116116
post_with_refs_and_tags: Tuple[MessageDb, PostDb, MessageStatusDb],
117117
amended_post_with_refs_and_tags: Tuple[MessageDb, PostDb, MessageStatusDb],
@@ -125,13 +125,13 @@ async def test_get_amended_posts_refs(
125125

126126
original_post_db.latest_amend = amend_post_db.item_hash
127127

128-
with session_factory() as session:
128+
async with session_factory() as session:
129129
session.add_all(fixture_posts)
130130
session.add_all(
131131
[original_message_db, original_post_db, original_message_status_db]
132132
)
133133
session.add_all([amend_message_db, amend_post_db, amend_message_status_db])
134-
session.commit()
134+
await session.commit()
135135

136136
# Match the ref
137137
response = await ccn_api_client.get(
@@ -176,16 +176,16 @@ async def test_get_amended_posts_refs(
176176
@pytest.mark.asyncio
177177
async def test_get_posts_tags(
178178
ccn_api_client,
179-
session_factory: DbSessionFactory,
179+
session_factory: AsyncDbSessionFactory,
180180
fixture_posts: Sequence[PostDb],
181181
post_with_refs_and_tags: Tuple[MessageDb, PostDb, MessageStatusDb],
182182
):
183183
message_db, post_db, message_status_db = post_with_refs_and_tags
184184

185-
with session_factory() as session:
185+
async with session_factory() as session:
186186
session.add_all(fixture_posts)
187187
session.add_all([message_db, post_db, message_status_db])
188-
session.commit()
188+
await session.commit()
189189

190190
# Match one tag
191191
response = await ccn_api_client.get(
@@ -245,7 +245,7 @@ async def test_get_posts_tags(
245245
@pytest.mark.asyncio
246246
async def test_get_amended_posts_tags(
247247
ccn_api_client,
248-
session_factory: DbSessionFactory,
248+
session_factory: AsyncDbSessionFactory,
249249
fixture_posts: Sequence[PostDb],
250250
post_with_refs_and_tags: Tuple[MessageDb, PostDb, MessageStatusDb],
251251
amended_post_with_refs_and_tags: Tuple[MessageDb, PostDb, MessageStatusDb],
@@ -259,13 +259,13 @@ async def test_get_amended_posts_tags(
259259

260260
original_post_db.latest_amend = amend_post_db.item_hash
261261

262-
with session_factory() as session:
262+
async with session_factory() as session:
263263
session.add_all(fixture_posts)
264264
session.add_all(
265265
[original_message_db, original_post_db, original_message_status_db]
266266
)
267267
session.add_all([amend_message_db, amend_post_db, amend_message_status_db])
268-
session.commit()
268+
await session.commit()
269269

270270
# Match one tag
271271
response = await ccn_api_client.get("/api/v0/posts.json", params={"tags": "amend"})

0 commit comments

Comments
 (0)