Skip to content

Commit 178161e

Browse files
authored
Implement insert many functions (#22)
Here, finish the implementations on `#insert_many` and `#insert_many_tx`, which were previously marked as not yet finished. The Python plugin for sqlc doesn't implement `copyfrom`, so we instead fall back on an alternative approach from an experimental `database/sql` branch I have, which implements an insert many operation using a combination of Postgres arrays and `unnest`. We remove some `Optional` annotations on `JobInsertParams` for properties that are always set by the client. (Necessary to make MyPy's checks pass properly.)
1 parent 536051c commit 178161e

11 files changed

+330
-14
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Added
11+
12+
- Implement `insert_many` and `insert_many_tx`. [PR #22](https://github.com/riverqueue/river/pull/22).
13+
1014
## [0.2.0] - 2024-07-04
1115

1216
### Changed

examples/all.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,21 @@
1111
import asyncio
1212

1313
from examples import async_client_insert_example
14+
from examples import async_client_insert_many_example
1415
from examples import async_client_insert_tx_example
1516
from examples import client_insert_example
17+
from examples import client_insert_many_example
18+
from examples import client_insert_many_insert_opts_example
1619
from examples import client_insert_tx_example
1720

1821
if __name__ == "__main__":
1922
asyncio.set_event_loop(asyncio.new_event_loop())
2023

2124
asyncio.run(async_client_insert_example.example())
25+
asyncio.run(async_client_insert_many_example.example())
2226
asyncio.run(async_client_insert_tx_example.example())
2327

2428
client_insert_example.example()
29+
client_insert_many_example.example()
30+
client_insert_many_insert_opts_example.example()
2531
client_insert_tx_example.example()
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#
2+
# Run with:
3+
#
4+
# rye run python3 -m examples.async_client_insert_many_example
5+
#
6+
7+
import asyncio
8+
from dataclasses import dataclass
9+
import json
10+
import riverqueue
11+
import sqlalchemy
12+
13+
from examples.helpers import dev_database_url
14+
from riverqueue.driver import riversqlalchemy
15+
16+
17+
@dataclass
class CountArgs:
    """Example job arguments holding a single integer counter."""

    # The value serialized into the job's JSON payload.
    count: int

    # Job kind string reported for these arguments.
    # NOTE(review): the kind is "sort" although the class is named CountArgs;
    # confirm this is intended.
    kind: str = "sort"

    def to_json(self) -> str:
        """Return the arguments encoded as a JSON object string."""
        payload = {"count": self.count}
        return json.dumps(payload)
25+
26+
27+
async def example():
    """Insert two `CountArgs` jobs in one batch through the async client.

    Prints the value returned by `AsyncClient.insert_many`.
    """
    # Import the submodule explicitly: a bare `import sqlalchemy` does not by
    # itself load `sqlalchemy.ext.asyncio`, so attribute access on it only
    # works when some other import happened to load it first.
    from sqlalchemy.ext.asyncio import create_async_engine

    engine = create_async_engine(dev_database_url(is_async=True))
    try:
        client = riverqueue.AsyncClient(riversqlalchemy.AsyncDriver(engine))

        num_inserted = await client.insert_many(
            [
                CountArgs(count=1),
                CountArgs(count=2),
            ]
        )
        print(num_inserted)
    finally:
        # Release pooled connections while the event loop is still running.
        await engine.dispose()
38+
39+
40+
if __name__ == "__main__":
41+
asyncio.set_event_loop(asyncio.new_event_loop())
42+
asyncio.run(example())
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#
2+
# Run with:
3+
#
4+
# rye run python3 -m examples.client_insert_many_example
5+
#
6+
7+
from dataclasses import dataclass
8+
import json
9+
import riverqueue
10+
import sqlalchemy
11+
12+
from examples.helpers import dev_database_url
13+
from riverqueue.driver import riversqlalchemy
14+
15+
16+
@dataclass
class CountArgs:
    """Example job arguments wrapping one integer counter."""

    # Counter value emitted in the JSON payload.
    count: int

    # Job kind string associated with these arguments.
    # NOTE(review): "sort" does not match the class name; confirm intended.
    kind: str = "sort"

    def to_json(self) -> str:
        """Encode the arguments as a JSON object string."""
        encoded = json.dumps({"count": self.count})
        return encoded
24+
25+
26+
def example():
    """Insert two `CountArgs` jobs in one batch through the sync client.

    Prints the value returned by `Client.insert_many`.
    """
    engine = sqlalchemy.create_engine(dev_database_url())
    try:
        client = riverqueue.Client(riversqlalchemy.Driver(engine))

        num_inserted = client.insert_many(
            [
                CountArgs(count=1),
                CountArgs(count=2),
            ]
        )
        print(num_inserted)
    finally:
        # Close the engine's connection pool instead of leaving it open until
        # interpreter shutdown.
        engine.dispose()
37+
38+
39+
if __name__ == "__main__":
40+
example()
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#
2+
# Run with:
3+
#
4+
# rye run python3 -m examples.client_insert_many_insert_opts_example
5+
#
6+
7+
from dataclasses import dataclass
8+
import json
9+
import riverqueue
10+
import sqlalchemy
11+
12+
from examples.helpers import dev_database_url
13+
from riverqueue.driver import riversqlalchemy
14+
15+
16+
@dataclass
class CountArgs:
    """Example job arguments made of a single integer counter."""

    # Counter written to the JSON payload.
    count: int

    # Job kind string for these arguments.
    # NOTE(review): kind is "sort" while the class is CountArgs; confirm.
    kind: str = "sort"

    def to_json(self) -> str:
        """Serialize the arguments to a JSON object string."""
        fields = {"count": self.count}
        return json.dumps(fields)
24+
25+
26+
def example():
    """Insert two jobs in one batch, each with its own `InsertOpts`.

    The first job overrides `max_attempts`; the second targets the
    "alternate_queue" queue. Prints the value returned by
    `Client.insert_many`.
    """
    engine = sqlalchemy.create_engine(dev_database_url())
    try:
        client = riverqueue.Client(riversqlalchemy.Driver(engine))

        num_inserted = client.insert_many(
            [
                riverqueue.InsertManyParams(
                    CountArgs(count=1),
                    insert_opts=riverqueue.InsertOpts(max_attempts=5),
                ),
                riverqueue.InsertManyParams(
                    CountArgs(count=2),
                    insert_opts=riverqueue.InsertOpts(queue="alternate_queue"),
                ),
            ]
        )
        print(num_inserted)
    finally:
        # Close the engine's connection pool instead of leaving it open until
        # interpreter shutdown.
        engine.dispose()
43+
44+
45+
if __name__ == "__main__":
46+
example()

src/riverqueue/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def _make_insert_params(
320320
queue=insert_opts.queue or args_insert_opts.queue or QUEUE_DEFAULT,
321321
scheduled_at=scheduled_at and scheduled_at.astimezone(timezone.utc),
322322
state="scheduled" if scheduled_at else "available",
323-
tags=insert_opts.tags or args_insert_opts.tags,
323+
tags=insert_opts.tags or args_insert_opts.tags or [],
324324
)
325325

326326
return insert_params, unique_opts

src/riverqueue/driver/driver_protocol.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,16 @@ class GetParams:
2727
@dataclass
2828
class JobInsertParams:
2929
kind: str
30-
args: Optional[Any] = None
30+
args: Any = None
3131
created_at: Optional[datetime] = None
3232
finalized_at: Optional[datetime] = None
3333
metadata: Optional[Any] = None
34-
max_attempts: Optional[int] = field(default=25)
35-
priority: Optional[int] = field(default=1)
36-
queue: Optional[str] = field(default="default")
34+
max_attempts: int = field(default=25)
35+
priority: int = field(default=1)
36+
queue: str = field(default="default")
3737
scheduled_at: Optional[datetime] = None
38-
state: Optional[str] = field(default="available")
39-
tags: Optional[List[str]] = field(default_factory=list)
38+
state: str = field(default="available")
39+
tags: list[str] = field(default_factory=list)
4040

4141

4242
class AsyncExecutorProtocol(Protocol):

src/riverqueue/driver/riversqlalchemy/dbsqlc/river_job.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,46 @@ class JobInsertFastParams:
8181
tags: List[str]
8282

8383

84+
JOB_INSERT_FAST_MANY = """-- name: job_insert_fast_many \\:execrows
85+
INSERT INTO river_job(
86+
args,
87+
kind,
88+
max_attempts,
89+
metadata,
90+
priority,
91+
queue,
92+
scheduled_at,
93+
state,
94+
tags
95+
) SELECT
96+
unnest(:p1\\:\\:jsonb[]),
97+
unnest(:p2\\:\\:text[]),
98+
unnest(:p3\\:\\:smallint[]),
99+
unnest(:p4\\:\\:jsonb[]),
100+
unnest(:p5\\:\\:smallint[]),
101+
unnest(:p6\\:\\:text[]),
102+
unnest(:p7\\:\\:timestamptz[]),
103+
unnest(:p8\\:\\:river_job_state[]),
104+
105+
-- Had trouble getting multi-dimensional arrays to play nicely with sqlc,
106+
-- but it might be possible. For now, join tags into a single string.
107+
string_to_array(unnest(:p9\\:\\:text[]), ',')
108+
"""
109+
110+
111+
@dataclasses.dataclass()
class JobInsertFastManyParams:
    """Column-wise bind parameters for the `job_insert_fast_many` query.

    Each attribute is one list that the query passes through `unnest`, so the
    lists are presumably expected to be the same length, with element *i* of
    every list describing the same job.
    """

    # Bound as a jsonb[] array.
    args: List[Any]
    # Bound as a text[] array.
    kind: List[str]
    # Bound as a smallint[] array.
    max_attempts: List[int]
    # Bound as a jsonb[] array.
    metadata: List[Any]
    # Bound as a smallint[] array.
    priority: List[int]
    # Bound as a text[] array.
    queue: List[str]
    # Bound as a timestamptz[] array.
    scheduled_at: List[datetime.datetime]
    # Bound as a river_job_state[] array.
    state: List[models.RiverJobState]
    # One string per job; the query splits each on "," via string_to_array
    # to obtain that job's tags.
    tags: List[str]
122+
123+
84124
class Querier:
85125
def __init__(self, conn: sqlalchemy.engine.Connection):
86126
self._conn = conn
@@ -154,6 +194,20 @@ def job_insert_fast(self, arg: JobInsertFastParams) -> Optional[models.RiverJob]
154194
tags=row[15],
155195
)
156196

197+
def job_insert_fast_many(self, arg: JobInsertFastManyParams) -> int:
    """Execute the bulk-insert statement and return the cursor's row count.

    The fields of `arg` are bound to the positional placeholders `p1`..`p9`
    of `JOB_INSERT_FAST_MANY`.
    """
    bind_params = {
        "p1": arg.args,
        "p2": arg.kind,
        "p3": arg.max_attempts,
        "p4": arg.metadata,
        "p5": arg.priority,
        "p6": arg.queue,
        "p7": arg.scheduled_at,
        "p8": arg.state,
        "p9": arg.tags,
    }
    statement = sqlalchemy.text(JOB_INSERT_FAST_MANY)
    return self._conn.execute(statement, bind_params).rowcount
210+
157211

158212
class AsyncQuerier:
159213
def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection):
@@ -227,3 +281,17 @@ async def job_insert_fast(self, arg: JobInsertFastParams) -> Optional[models.Riv
227281
scheduled_at=row[14],
228282
tags=row[15],
229283
)
284+
285+
async def job_insert_fast_many(self, arg: JobInsertFastManyParams) -> int:
    """Execute the bulk-insert statement and return the cursor's row count.

    The fields of `arg` are bound to the positional placeholders `p1`..`p9`
    of `JOB_INSERT_FAST_MANY`.
    """
    bind_params = {
        "p1": arg.args,
        "p2": arg.kind,
        "p3": arg.max_attempts,
        "p4": arg.metadata,
        "p5": arg.priority,
        "p6": arg.queue,
        "p7": arg.scheduled_at,
        "p8": arg.state,
        "p9": arg.tags,
    }
    statement = sqlalchemy.text(JOB_INSERT_FAST_MANY)
    cursor = await self._conn.execute(statement, bind_params)
    return cursor.rowcount

src/riverqueue/driver/riversqlalchemy/dbsqlc/river_job.sql

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,4 +69,29 @@ INSERT INTO river_job(
6969
coalesce(sqlc.narg('scheduled_at')::timestamptz, now()),
7070
@state::river_job_state,
7171
coalesce(@tags::varchar(255)[], '{}')
72-
) RETURNING *;
72+
) RETURNING *;
73+
74+
-- name: JobInsertFastMany :execrows
75+
INSERT INTO river_job(
76+
args,
77+
kind,
78+
max_attempts,
79+
metadata,
80+
priority,
81+
queue,
82+
scheduled_at,
83+
state,
84+
tags
85+
) SELECT
86+
unnest(@args::jsonb[]),
87+
unnest(@kind::text[]),
88+
unnest(@max_attempts::smallint[]),
89+
unnest(@metadata::jsonb[]),
90+
unnest(@priority::smallint[]),
91+
unnest(@queue::text[]),
92+
unnest(@scheduled_at::timestamptz[]),
93+
unnest(@state::river_job_state[]),
94+
95+
-- Had trouble getting multi-dimensional arrays to play nicely with sqlc,
96+
-- but it might be possible. For now, join tags into a single string.
97+
string_to_array(unnest(@tags::text[]), ',');

src/riverqueue/driver/riversqlalchemy/sql_alchemy_driver.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
asynccontextmanager,
33
contextmanager,
44
)
5+
from datetime import datetime, timezone
56
from riverqueue.driver.driver_protocol import AsyncDriverProtocol, AsyncExecutorProtocol
67
from sqlalchemy import Engine
78
from sqlalchemy.engine import Connection
@@ -16,7 +17,7 @@
1617

1718
from ...driver import DriverProtocol, ExecutorProtocol, GetParams, JobInsertParams
1819
from ...model import Job
19-
from .dbsqlc import river_job, pg_misc
20+
from .dbsqlc import models, river_job, pg_misc
2021

2122

2223
class AsyncExecutor(AsyncExecutorProtocol):
@@ -36,8 +37,11 @@ async def job_insert(self, insert_params: JobInsertParams) -> Job:
3637
),
3738
)
3839

39-
async def job_insert_many(self, all_params) -> int:
40-
raise NotImplementedError("sqlc doesn't implement copy in python yet")
40+
async def job_insert_many(self, all_params: list[JobInsertParams]) -> int:
    """Insert every job in `all_params` with a single bulk statement.

    Returns the number of parameter sets supplied (`len(all_params)`); the
    row count reported by the querier is not used.
    """
    many_params = _build_insert_many_params(all_params)
    await self.job_querier.job_insert_fast_many(many_params)
    return len(all_params)
4145

4246
async def job_get_by_kind_and_unique_properties(
4347
self, get_params: GetParams
@@ -94,8 +98,9 @@ def job_insert(self, insert_params: JobInsertParams) -> Job:
9498
),
9599
)
96100

97-
def job_insert_many(self, all_params) -> int:
98-
raise NotImplementedError("sqlc doesn't implement copy in python yet")
101+
def job_insert_many(self, all_params: list[JobInsertParams]) -> int:
    """Insert every job in `all_params` with a single bulk statement.

    Returns the number of parameter sets supplied (`len(all_params)`); the
    row count reported by the querier is not used.
    """
    many_params = _build_insert_many_params(all_params)
    self.job_querier.job_insert_fast_many(many_params)
    return len(all_params)
99104

100105
def job_get_by_kind_and_unique_properties(
101106
self, get_params: GetParams
@@ -133,3 +138,34 @@ def executor(self) -> Iterator[ExecutorProtocol]:
133138

134139
def unwrap_executor(self, tx) -> ExecutorProtocol:
135140
return Executor(tx)
141+
142+
143+
def _build_insert_many_params(
    all_params: list[JobInsertParams],
) -> river_job.JobInsertFastManyParams:
    """Pivot per-job insert params into the column-wise bulk-insert params.

    Every list in the result has one entry per element of `all_params`, in
    the same order. Falsy metadata becomes the string "{}", and a missing
    `scheduled_at` becomes the current UTC time.

    Each job's tags are joined into one comma-separated string because the
    bulk-insert query splits them again with `string_to_array(..., ',')`.
    NOTE(review): a tag that itself contains a comma would therefore be split
    into several tags; confirm tags are validated upstream.
    """
    return river_job.JobInsertFastManyParams(
        args=[params.args for params in all_params],
        kind=[params.kind for params in all_params],
        max_attempts=[params.max_attempts for params in all_params],
        metadata=[params.metadata or "{}" for params in all_params],
        priority=[params.priority for params in all_params],
        queue=[params.queue for params in all_params],
        scheduled_at=[
            params.scheduled_at or datetime.now(timezone.utc)
            for params in all_params
        ],
        state=[
            cast(models.RiverJobState, params.state) for params in all_params
        ],
        tags=[",".join(params.tags) for params in all_params],
    )

0 commit comments

Comments
 (0)