forked from 05bit/peewee-async
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpeewee_async.py
1539 lines (1218 loc) · 44.7 KB
/
peewee_async.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
peewee-async
============
Asynchronous interface for `peewee`_ ORM powered by `asyncio`_:
https://github.com/05bit/peewee-async
.. _peewee: https://github.com/coleifer/peewee
.. _asyncio: https://docs.python.org/3/library/asyncio.html
Licensed under The MIT License (MIT)
Copyright (c) 2014, Alexey Kinëv <[email protected]>
"""
import asyncio
import contextlib
import functools
import logging
import uuid
import warnings
import peewee
from playhouse.db_url import register_database
# Exception types treated as "row already exists" by Manager.create_or_get();
# extended below with driver-specific classes when those drivers import.
IntegrityErrors = (peewee.IntegrityError,)

# Optional PostgreSQL backend: both modules must import, otherwise both
# names are set to None.
try:
    import aiopg
    import psycopg2
except ImportError:
    aiopg = psycopg2 = None
else:
    IntegrityErrors += (psycopg2.IntegrityError,)

# Optional MySQL backend, same all-or-nothing handling.
try:
    import aiomysql
    import pymysql
except ImportError:
    aiomysql = pymysql = None

# ``asyncio.current_task`` may be missing on older Python versions, where
# the ``asyncio.Task.current_task`` classmethod is used instead.
asyncio_current_task = getattr(asyncio, 'current_task', None)
if asyncio_current_task is None:
    asyncio_current_task = asyncio.Task.current_task

__version__ = '0.8.0'

__all__ = [
    # High level API
    'Manager',
    'PostgresqlDatabase',
    'PooledPostgresqlDatabase',
    'MySQLDatabase',
    'PooledMySQLDatabase',

    # Low level API
    'execute',
    'count',
    'scalar',
    'atomic',
    'transaction',
    'savepoint',

    # Deprecated
    'get_object',
    'create_object',
    'delete_object',
    'update_object',
    'sync_unwanted',
    'UnwantedSyncQueryError',
]

# Library logger; a NullHandler keeps it silent unless the application
# configures logging.
__log__ = logging.getLogger('peewee.async')
__log__.addHandler(logging.NullHandler())
#################
# Async manager #
#################
class Manager:
    """Async peewee models manager.

    :param loop: (optional) asyncio event loop
    :param database: (optional) async database driver

    Example::

        class User(peewee.Model):
            username = peewee.CharField(max_length=40, unique=True)

        objects = Manager(PostgresqlDatabase('test'))

        async def my_async_func():
            user0 = await objects.create(User, username='test')
            user1 = await objects.get(User, id=user0.id)
            user2 = await objects.get(User, username='test')
            # All should be the same
            print(user0.id, user1.id, user2.id)

    If you don't pass database to constructor, you should define
    ``database`` as a class member like that::

        database = PostgresqlDatabase('test')

        class MyManager(Manager):
            database = database

        objects = MyManager()
    """
    #: Async database driver for manager. Must be provided
    #: in constructor or as a class member.
    database = None

    def __init__(self, database=None, *, loop=None):
        assert database or self.database, \
            ("Error, database must be provided via "
             "argument or class member.")

        self.database = database or self.database
        self._loop = loop
        # Read the timeout from the *resolved* database: when the database
        # is declared as a class member the ``database`` argument is None.
        self._timeout = getattr(self.database, 'timeout', None)

        # Propagate the loop to the database, deferring via a callback
        # when the database object offers one.
        attach_callback = getattr(self.database, 'attach_callback', None)
        if attach_callback:
            attach_callback(lambda db: setattr(db, '_loop', loop))
        else:
            self.database._loop = loop

    @property
    def loop(self):
        """Get the event loop.

        If no event loop is provided explicitly on creating
        the instance, just return the current event loop.
        """
        return self._loop or asyncio.get_event_loop()

    @property
    def is_connected(self):
        """Check if database is connected.
        """
        return self.database._async_conn is not None

    async def get(self, source_, *args, **kwargs):
        """Get the model instance.

        :param source_: model or base query for lookup
        :param args: extra ``where()`` expressions
        :param kwargs: ``field_name=value`` equality lookups
        :raises model.DoesNotExist: if no row matches

        Example::

            async def my_async_func():
                obj1 = await objects.get(MyModel, id=1)
                obj2 = await objects.get(MyModel, MyModel.id==1)
                obj3 = await objects.get(MyModel.select().where(MyModel.id==1))

        All will return `MyModel` instance with `id = 1`
        """
        await self.connect()

        if isinstance(source_, peewee.Query):
            query = source_
            model = query.model
        else:
            query = source_.select()
            model = source_

        conditions = list(args) + [(getattr(model, k) == v)
                                   for k, v in kwargs.items()]

        # peewee's where() cannot be called with no expressions.
        if conditions:
            query = query.where(*conditions)

        try:
            result = await self.execute(query)
            return list(result)[0]
        except IndexError:
            raise model.DoesNotExist

    async def create(self, model_, **data):
        """Create a new object saved to database.
        """
        inst = model_(**data)
        query = model_.insert(**dict(inst.__data__))

        pk = await self.execute(query)
        if inst._pk is None:
            inst._pk = pk
        return inst

    async def get_or_create(self, model_, defaults=None, **kwargs):
        """Try to get an object or create it with the specified defaults.

        Return 2-tuple containing the model instance and a boolean
        indicating whether the instance was created.

        Lookup keys containing ``__`` are used for the lookup only and are
        not passed to :meth:`create`. The ``defaults`` mapping is never
        modified.
        """
        try:
            return (await self.get(model_, **kwargs)), False
        except model_.DoesNotExist:
            # Copy so the caller's ``defaults`` dict is not mutated.
            data = dict(defaults or {})
            data.update({k: v for k, v in kwargs.items() if '__' not in k})
            return (await self.create(model_, **data)), True

    async def update(self, obj, only=None):
        """Update the object in the database. Optionally, update only
        the specified fields. For creating a new object use :meth:`.create()`

        :param only: (optional) the list/tuple of fields or
                     field names to update
        :return: result of executing the UPDATE query
        """
        field_dict = dict(obj.__data__)
        pk_field = obj._meta.primary_key

        if only:
            self._prune_fields(field_dict, only)

        if obj._meta.only_save_dirty:
            self._prune_fields(field_dict, obj.dirty_fields)

        # Never SET the primary key column(s).
        if obj._meta.composite_key:
            for pk_part_name in pk_field.field_names:
                field_dict.pop(pk_part_name, None)
        else:
            field_dict.pop(pk_field.name, None)

        query = obj.update(**field_dict).where(obj._pk_expr())
        result = await self.execute(query)
        obj._dirty.clear()
        return result

    async def delete(self, obj, recursive=False, delete_nullable=False):
        """Delete object from database.

        With ``recursive=True`` dependent rows are handled first: nullable
        foreign keys are set to NULL (unless ``delete_nullable``), other
        dependents are deleted.
        """
        if recursive:
            dependencies = obj.dependencies(delete_nullable)
            for cond, fk in reversed(list(dependencies)):
                model = fk.model
                if fk.null and not delete_nullable:
                    sq = model.update(**{fk.name: None}).where(cond)
                else:
                    sq = model.delete().where(cond)
                await self.execute(sq)

        query = obj.delete().where(obj._pk_expr())
        return (await self.execute(query))

    async def create_or_get(self, model_, **kwargs):
        """Try to create new object with specified data. If object already
        exists, then try to get it by unique fields.
        """
        try:
            return (await self.create(model_, **kwargs)), True
        except IntegrityErrors:
            query = []
            for field_name, value in kwargs.items():
                field = getattr(model_, field_name)
                if field.unique or field.primary_key:
                    query.append(field == value)
            return (await self.get(model_, *query)), False

    async def execute(self, query):
        """Execute query asynchronously.
        """
        query = self._swap_database(query)
        return (await execute(query))

    async def prefetch(self, query, *subqueries):
        """Asynchronous version of the `prefetch()` from peewee.

        :return: Query that has already cached data for subqueries
        """
        query = self._swap_database(query)
        subqueries = map(self._swap_database, subqueries)
        return (await prefetch(query, *subqueries))

    async def count(self, query, clear_limit=False):
        """Perform *COUNT* aggregated query asynchronously.

        :return: number of objects in ``select()`` query
        """
        query = self._swap_database(query)
        return (await count(query, clear_limit=clear_limit))

    async def scalar(self, query, as_tuple=False):
        """Get single value from ``select()`` query, i.e. for aggregation.

        :return: result is the same as after sync ``query.scalar()`` call
        """
        query = self._swap_database(query)
        return (await scalar(query, as_tuple=as_tuple))

    async def connect(self):
        """Open database async connection if not connected.
        """
        await self.database.connect_async(loop=self.loop, timeout=self._timeout)

    async def close(self):
        """Close database async connection if connected.
        """
        await self.database.close_async()

    def atomic(self):
        """Similar to `peewee.Database.atomic()` method, but returns
        **asynchronous** context manager.

        Example::

            async with objects.atomic():
                await objects.create(
                    PageBlock, key='intro',
                    text="There are more things in heaven and earth, "
                         "Horatio, than are dreamt of in your philosophy.")
                await objects.create(
                    PageBlock, key='signature', text="William Shakespeare")
        """
        return atomic(self.database)

    def transaction(self):
        """Similar to `peewee.Database.transaction()` method, but returns
        **asynchronous** context manager.
        """
        return transaction(self.database)

    def savepoint(self, sid=None):
        """Similar to `peewee.Database.savepoint()` method, but returns
        **asynchronous** context manager.
        """
        return savepoint(self.database, sid=sid)

    def allow_sync(self):
        """Allow sync queries within context. Close the sync
        database connection on exit if connected.

        Example::

            with objects.allow_sync():
                PageBlock.create_table(True)
        """
        return self.database.allow_sync()

    def _swap_database(self, query):
        """Swap database for query if swappable. Return **new query**
        with swapped database.

        This is experimental feature which allows us to have multiple
        managers configured against different databases for single model
        definition.

        The essential limitation though is that database backend have
        to be **the same type** for model and manager!

        :raises AssertionError: if the backends are not of the same type
        """
        database = _query_db(query)

        if database == self.database:
            return query

        can_swap = any(
            self._subclassed(backend, database, self.database)
            for backend in (peewee.PostgresqlDatabase, peewee.MySQLDatabase))

        if can_swap:
            # **Experimental** database swapping!
            query = query.clone()
            query._database = self.database
            return query

        # Raise explicitly rather than ``assert False``: an assert is
        # stripped under ``python -O`` and the method would return None.
        raise AssertionError(
            "Error, query's database and manager's database are "
            "different. Query: %s Manager: %s" % (database, self.database))

    @staticmethod
    def _subclassed(base, *classes):
        """Check if all classes are subclassed from base.
        """
        return all(isinstance(obj, base) for obj in classes)

    @staticmethod
    def _prune_fields(field_dict, only):
        """Filter fields data **in place** with `only` list.

        ``only`` may mix field names (str) and field objects (``.name``).
        Returns the same ``field_dict`` object.

        Example::

            self._prune_fields(field_dict, ['slug', 'text'])
            self._prune_fields(field_dict, [MyModel.slug])
        """
        keep = {f if isinstance(f, str) else f.name for f in only}
        for name in list(field_dict):
            if name not in keep:
                field_dict.pop(name)
        return field_dict
#################
# Async queries #
#################
async def execute(query):
    """Run a *SELECT*, *INSERT*, *UPDATE* or *DELETE* query asynchronously.

    The query is routed to the matching coroutine by type; the first
    matching entry wins, and anything unrecognized goes to ``raw_query``.

    :param query: peewee query instance created with ``Model.select()``,
        ``Model.update()`` etc.
    :return: whatever the routed coroutine returns, which mirrors the sync
        ``query.execute()`` result for that query type
    """
    routes = (
        ((peewee.Select, peewee.ModelCompoundSelectQuery), select),
        (peewee.Update, update),
        (peewee.Insert, insert),
        (peewee.Delete, delete),
    )
    handler = next(
        (coro for query_types, coro in routes
         if isinstance(query, query_types)),
        raw_query)
    return await handler(query)
async def create_object(model, **data):
    """Create a model instance and INSERT it asynchronously.

    Deprecated: use :meth:`Manager.create` instead.

    :param model: model class
    :param data: data for initializing object
    :return: new object saved to database; its primary key is filled from
        the INSERT result when it was not set beforehand
    """
    # Model internals used here: ``instance.__data__`` and ``instance._pk``.
    warnings.warn("create_object() is deprecated, Manager.create() "
                  "should be used instead",
                  DeprecationWarning)

    instance = model(**data)
    insert_query = model.insert(**dict(instance.__data__))
    new_pk = await insert(insert_query)

    if instance._pk is None:
        instance._pk = new_pk
    return instance
async def get_object(source, *args):
    """Get object asynchronously.

    Deprecated: use :meth:`Manager.get` instead.

    :param source: model class or query to get object from
    :param args: lookup expressions passed to ``where()``; may be empty
    :return: first matching model instance
    :raises model.DoesNotExist: if no object is found
    """
    warnings.warn("get_object() is deprecated, Manager.get() "
                  "should be used instead",
                  DeprecationWarning)

    if isinstance(source, peewee.Query):
        query = source
        model = query.model
    else:
        query = source.select()
        model = source

    # peewee's where() reduces its expressions and fails when given none,
    # so only filter when lookup expressions were supplied (as Manager.get
    # does).
    if args:
        query = query.where(*args)

    # Return first object from query
    for obj in (await select(query)):
        return obj

    # No objects found
    raise model.DoesNotExist
async def delete_object(obj, recursive=False, delete_nullable=False):
    """Delete object asynchronously.

    Deprecated: use :meth:`Manager.delete` instead.

    :param obj: object to delete
    :param recursive: if ``True`` also handle all other objects that depend
        on ``obj`` before deleting it
    :param delete_nullable: if ``True`` and delete is recursive then delete
        even 'nullable' dependencies instead of setting their foreign key
        to NULL
    :return: result of the final DELETE for ``obj`` itself

    For details please check out `Model.delete_instance()`_ in peewee docs.

    .. _Model.delete_instance(): http://peewee.readthedocs.io/en/latest/peewee/api.html#Model.delete_instance
    """
    warnings.warn("delete_object() is deprecated, Manager.delete() "
                  "should be used instead",
                  DeprecationWarning)

    # Private API used here: obj._pk_expr()
    if recursive:
        # Materialize every dependency before issuing any query.
        dependents = list(obj.dependencies(delete_nullable))
        for condition, fk in reversed(dependents):
            dependent_model = fk.model
            if not fk.null or delete_nullable:
                await delete(dependent_model.delete().where(condition))
            else:
                await update(
                    dependent_model.update(**{fk.name: None}).where(condition))

    return await delete(obj.delete().where(obj._pk_expr()))
async def update_object(obj, only=None):
    """Update object asynchronously.

    Deprecated: use :meth:`Manager.update` instead.

    :param obj: object to update
    :param only: list or tuple of fields to update; if ``None`` then all
        fields are updated
    :return: result of the UPDATE query

    This function does the same as `Model.save()`_ for an already saved
    object, but it doesn't invoke the ``save()`` method on the model class.
    That is important to know if you overrode ``save()`` for your model.

    .. _Model.save(): http://peewee.readthedocs.io/en/latest/peewee/api.html#Model.save
    """
    # Private API used here: obj.__data__, obj._meta, obj._prune_fields(),
    # obj._pk_expr(), obj._dirty.clear()
    warnings.warn("update_object() is deprecated, Manager.update() "
                  "should be used instead",
                  DeprecationWarning)

    values = dict(obj.__data__)
    primary = obj._meta.primary_key

    if only:
        values = obj._prune_fields(values, only)

    if isinstance(primary, peewee.CompositeKey):
        # NOTE(review): for composite keys this prunes to dirty fields and
        # keeps the key parts, whereas Manager.update() drops each key part
        # from the SET clause. Confirm which behaviour is intended.
        values = obj._prune_fields(values, obj.dirty_fields)
    else:
        values.pop(primary.name, None)

    affected = await update(obj.update(**values).where(obj._pk_expr()))
    obj._dirty.clear()
    return affected
async def _execute_with_returning(query):
    """Execute ``query`` and eagerly fetch all of its result rows.

    :return: :class:`AsyncQueryWrapper` holding every fetched row
    """
    cursor = await _execute_query_async(query)
    try:
        # Build the wrapper inside ``try`` so the cursor is released even
        # if constructing it fails.
        result = AsyncQueryWrapper(cursor=cursor, query=query)
        await result.fetchall()
    finally:
        await cursor.release()
    return result
async def select(query):
    """Perform SELECT query asynchronously.

    :param query: ``peewee.SelectQuery`` instance (e.g. ``Model.select()``)
    :return: :class:`AsyncQueryWrapper` with all rows already fetched
    """
    # Implicit concatenation happens before ``%``, so the whole message is
    # formatted; the trailing space keeps the words apart.
    assert isinstance(query, peewee.SelectQuery), (
        "Error, trying to run select coroutine "
        "with wrong query class %s" % str(query))

    return await _execute_with_returning(query)
async def insert(query):
    """Perform INSERT query asynchronously.

    This function is called by object creation for single objects only.

    :param query: ``peewee.Insert`` query
    :return: with a multi-column ``RETURNING`` clause, an
        :class:`AsyncQueryWrapper` of all returned rows; with a single
        ``RETURNING`` column, the first value of the first row (``None`` if
        no row came back); otherwise the last insert ID from the database
    """
    assert isinstance(query, peewee.Insert), (
        "Error, trying to run insert coroutine "
        "with wrong query class %s" % str(query))

    if query._returning is not None and len(query._returning) > 1:
        return await _execute_with_returning(query)

    cursor = await _execute_query_async(query)
    try:
        if query._returning:
            row = await cursor.fetchone()
            result = row[0] if row is not None else None
        else:
            database = _query_db(query)
            result = await database.last_insert_id_async(cursor)
    finally:
        await cursor.release()
    return result
async def update(query):
    """Perform UPDATE query asynchronously.

    :param query: ``peewee.Update`` query
    :return: number of rows updated, or an :class:`AsyncQueryWrapper` of
        returned rows when the query has a ``RETURNING`` clause
    """
    assert isinstance(query, peewee.Update), (
        "Error, trying to run update coroutine "
        "with wrong query class %s" % str(query))

    if query._returning:
        return await _execute_with_returning(query)

    cursor = await _execute_query_async(query)
    try:
        return cursor.rowcount
    finally:
        # Release the cursor even if reading ``rowcount`` fails.
        await cursor.release()
async def delete(query):
    """Perform DELETE query asynchronously.

    :param query: ``peewee.Delete`` query
    :return: number of rows deleted, or an :class:`AsyncQueryWrapper` of
        returned rows when the query has a ``RETURNING`` clause
    """
    assert isinstance(query, peewee.Delete), (
        "Error, trying to run delete coroutine "
        "with wrong query class %s" % str(query))

    if query._returning:
        return await _execute_with_returning(query)

    cursor = await _execute_query_async(query)
    try:
        return cursor.rowcount
    finally:
        # Release the cursor even if reading ``rowcount`` fails.
        await cursor.release()
async def count(query, clear_limit=False):
    """Perform *COUNT* aggregated query asynchronously.

    A plain query is counted by replacing its selected columns with
    ``COUNT(*)`` and dropping ORDER BY. A query using DISTINCT, GROUP BY,
    LIMIT or OFFSET is instead wrapped as a subquery and counted from
    outside; ``clear_limit=True`` removes LIMIT/OFFSET before wrapping.

    :return: number of objects in ``select()`` query (``0`` if the scalar
        result is falsy)
    """
    clone = query.clone()
    needs_wrapping = (query._distinct or query._group_by or
                      query._limit or query._offset)

    if not needs_wrapping:
        clone._returning = [peewee.fn.Count(peewee.SQL('*'))]
        clone._order_by = None
        return (await scalar(clone)) or 0

    if clear_limit:
        clone._limit = clone._offset = None

    inner_sql, inner_params = clone.sql()
    outer_sql = 'SELECT COUNT(1) FROM (%s) AS wrapped_select' % inner_sql
    return (await scalar(query.model.raw(outer_sql, *inner_params))) or 0
async def scalar(query, as_tuple=False):
    """Get single value from ``select()`` query, i.e. for aggregation.

    Only the first row is fetched; the cursor is always released.

    :param as_tuple: return the whole first row instead of its first value
    :return: result is the same as after sync ``query.scalar()`` call
    """
    cursor = await _execute_query_async(query)
    try:
        first_row = await cursor.fetchone()
    finally:
        await cursor.release()

    # An empty/missing row is returned unchanged, as is any row when the
    # caller asked for the tuple form.
    if as_tuple or not first_row:
        return first_row
    return first_row[0]
async def raw_query(query):
    """Execute a ``peewee.RawQuery`` asynchronously.

    :return: :class:`AsyncQueryWrapper` with all rows already fetched
    """
    assert isinstance(query, peewee.RawQuery), (
        "Error, trying to run raw_query coroutine "
        "with wrong query class %s" % str(query))

    cursor = await _execute_query_async(query)
    try:
        # Construct inside ``try`` so the cursor is released even if
        # building the wrapper fails; fetchall() drains every row.
        result = AsyncQueryWrapper(cursor=cursor, query=query)
        await result.fetchall()
    finally:
        await cursor.release()
    return result
async def prefetch(sq, *subqueries):
    """Asynchronous version of the `prefetch()` from peewee.

    Executes ``sq`` and each subquery, then links the fetched instances
    using peewee's prefetch helpers (``prefetch_add_subquery()``,
    ``store_instance()``, ``populate_instance()``).

    :param sq: base query
    :param subqueries: related queries/models to prefetch
    :return: without subqueries, the result of ``execute(sq)``; otherwise
        the result of the query executed last in the loop below.
        NOTE(review): presumably that is ``sq`` itself because the loop
        walks ``fixed_queries`` in reverse; confirm against
        ``peewee.prefetch_add_subquery()``.
    """
    if not subqueries:
        result = await execute(sq)
        return result

    fixed_queries = peewee.prefetch_add_subquery(sq, subqueries)
    deps = {}
    rel_map = {}

    # Process the query list in reverse order.
    for pq in reversed(fixed_queries):
        query_model = pq.model
        if pq.fields:
            # Register this query to be populated from each related model.
            for rel_model in pq.rel_models:
                rel_map.setdefault(rel_model, [])
                rel_map[rel_model].append(pq)

        # Per-model map filled by store_instance() and read by
        # populate_instance() of dependent queries.
        # NOTE(review): this resets the map if the same model occurs twice
        # in ``fixed_queries``; verify whether that can happen.
        deps[query_model] = {}
        id_map = deps[query_model]
        has_relations = bool(rel_map.get(query_model))

        result = await execute(pq.query)

        for instance in result:
            if pq.fields:
                pq.store_instance(instance, id_map)
            if has_relations:
                for rel in rel_map[query_model]:
                    rel.populate_instance(instance, deps[rel.model])

    return result
###################
# Result wrappers #
###################
class RowsCursor:
    """Minimal synchronous cursor over an in-memory list of rows.

    The ``rows`` list is read by position rather than copied, so rows
    appended to it after construction are still returned by later
    :meth:`fetchone` calls.

    :param rows: list of already-fetched rows
    :param description: value exposed as the ``description`` attribute
    """

    def __init__(self, rows, description):
        self._rows = rows
        self.description = description
        self._idx = 0

    def fetchone(self):
        """Return the next row, or ``None`` when no unread row remains."""
        position = self._idx
        if position < len(self._rows):
            self._idx = position + 1
            return self._rows[position]
        return None

    def close(self):
        """No-op: there is no underlying resource to release."""
class AsyncQueryWrapper:
    """Async results wrapper used by this module's query coroutines.

    Rows are pulled from an async ``cursor`` with :meth:`fetchone` /
    :meth:`fetchall` and buffered in memory. A sync peewee cursor wrapper,
    obtained from ``query._get_cursor_wrapper()`` over a :class:`RowsCursor`
    on that buffer, turns the rows into results. After fetching, iterate
    over this object (or use ``len()`` / indexing) just like a sync peewee
    results wrapper.

    :param cursor: async cursor that has just executed ``query``; its
        ``description`` and awaitable ``fetchone()`` are used
    :param query: peewee query providing ``_get_cursor_wrapper()``
    """

    def __init__(self, *, cursor=None, query=None):
        self._cursor = cursor
        self._rows = []
        self._result_cache = None
        # Must come after ``_rows`` is created: the sync wrapper reads it.
        self._result_wrapper = self._get_result_wrapper(query)

    def __iter__(self):
        return iter(self._result_wrapper)

    def __len__(self):
        """Number of rows fetched into the buffer so far."""
        return len(self._rows)

    def __getitem__(self, idx):
        # NOTE: the first index access materializes results by iterating,
        # so mixing iteration and index access may show side effects.
        cached = self._result_cache
        if cached is None:
            cached = self._result_cache = list(self)
        return cached[idx]

    def _get_result_wrapper(self, query):
        """Build the sync peewee cursor wrapper over the row buffer."""
        buffered_cursor = RowsCursor(self._rows, self._cursor.description)
        return query._get_cursor_wrapper(buffered_cursor)

    async def fetchone(self):
        """Fetch a single row from the async cursor into the buffer.

        :raises GeneratorExit: when the cursor yields no (or an empty) row;
            :meth:`fetchall` relies on this as its end-of-data signal.
        """
        fetched = await self._cursor.fetchone()
        if not fetched:
            raise GeneratorExit
        self._rows.append(fetched)

    async def fetchall(self):
        """Fetch every remaining row from the async cursor into the buffer."""
        with contextlib.suppress(GeneratorExit):
            while True:
                await self.fetchone()
############
# Database #
############
class AsyncDatabase:
    """Adds asyncio connection management on top of a sync database class.

    NOTE(review): presumably mixed into a ``peewee.Database`` subclass. This
    class uses ``self.deferred``, ``self.database``, ``self.close()``,
    ``self.Error`` and ``super().execute_sql()`` without defining them, and
    :meth:`connect_async` expects ``_async_conn_cls`` and
    ``connect_params_async`` to be provided elsewhere.
    """
    _loop = None  # asyncio event loop
    _timeout = None  # connection timeout
    _allow_sync = True  # whether sync queries are allowed
    _async_conn = None  # async connection
    _async_wait = None  # connection waiter
    _task_data = None  # asyncio per-task data

    def __setattr__(self, name, value):
        """Redirect the deprecated ``db.allow_sync = value`` assignment to
        ``_allow_sync`` (emitting a ``DeprecationWarning``); every other
        attribute is set normally.
        """
        if name == 'allow_sync':
            warnings.warn(
                "`.allow_sync` setter is deprecated, use either the "
                "`.allow_sync()` context manager or `.set_allow_sync()` "
                "method.", DeprecationWarning)
            self._allow_sync = value
        else:
            super().__setattr__(name, value)

    @property
    def loop(self):
        """Get the event loop.

        If no event loop is provided explicitly on creating
        the instance, just return the current event loop.
        """
        return self._loop or asyncio.get_event_loop()

    async def connect_async(self, loop=None, timeout=None):
        """Set up async connection on specified event loop or
        on default event loop.

        Does nothing if already connected. If another coroutine is already
        connecting, waits on the shared ``_async_wait`` future instead of
        opening a second connection.

        :param loop: event loop stored on the database (``None`` is stored
            as-is)
        :param timeout: connection timeout; falls back to ``self._timeout``
            when falsy
        :raises Exception: if the database is deferred (not initialized)
        :raises: whatever ``conn.connect()`` raises; the same error is also
            set on the waiter future so concurrent waiters receive it
        """
        if self.deferred:
            raise Exception("Error, database not properly initialized "
                            "before opening connection")

        if self._async_conn:
            return
        elif self._async_wait:
            # A connect is in progress: wait for it (re-raises its error).
            await self._async_wait
        else:
            # NOTE(review): overwrites any previously stored loop, even
            # with None.
            self._loop = loop
            # Publish the waiter before the first await so concurrent
            # callers take the branch above.
            self._async_wait = asyncio.Future(loop=self._loop)

            if not timeout and self._timeout:
                timeout = self._timeout

            conn = self._async_conn_cls(
                database=self.database,
                loop=self._loop,
                timeout=timeout,
                **self.connect_params_async)

            try:
                await conn.connect()
            except Exception as e:
                if not self._async_wait.done():
                    self._async_wait.set_exception(e)
                # Clear the waiter so a later call can retry connecting.
                self._async_wait = None
                raise
            else:
                self._task_data = TaskLocals(loop=self._loop)
                self._async_conn = conn
                self._async_wait.set_result(True)

    async def cursor_async(self):
        """Acquire async cursor.

        Connects first if needed. Inside an async transaction the cursor is
        taken from the transaction's connection. On any failure (the bare
        ``except`` also catches cancellation) the async connection is
        closed before the error is re-raised.
        """
        await self.connect_async(loop=self._loop)

        if self.transaction_depth_async() > 0:
            conn = self.transaction_conn_async()
        else:
            conn = None

        try:
            return (await self._async_conn.cursor(conn=conn))
        except:
            await self.close_async()
            raise

    async def close_async(self):
        """Close async connection.

        Waits for an in-progress connect first. Connection state is cleared
        before awaiting ``conn.close()``.
        """
        if self._async_wait:
            await self._async_wait
        if self._async_conn:
            conn = self._async_conn
            self._async_conn = None
            self._async_wait = None
            self._task_data = None
            await conn.close()

    async def push_transaction_async(self):
        """Increment async transaction depth.

        On the outermost level a dedicated connection is acquired and
        stored in per-task data under ``'conn'``.
        """
        await self.connect_async(loop=self.loop)
        depth = self.transaction_depth_async()
        if not depth:
            conn = await self._async_conn.acquire()
            self._task_data.set('conn', conn)
        self._task_data.set('depth', depth + 1)

    async def pop_transaction_async(self):
        """Decrement async transaction depth.

        When the depth drops to zero the transaction connection is
        released back via ``self._async_conn.release()``.

        :raises ValueError: if the depth is already zero
        """
        depth = self.transaction_depth_async()
        if depth > 0:
            depth -= 1
            self._task_data.set('depth', depth)
            if depth == 0:
                conn = self._task_data.get('conn')
                self._async_conn.release(conn)
        else:
            raise ValueError("Invalid async transaction depth value")

    def transaction_depth_async(self):
        """Get async transaction depth.

        :return: stored depth, or ``0`` when there is no task data yet
        """
        return self._task_data.get('depth', 0) if self._task_data else 0

    def transaction_conn_async(self):
        """Get async transaction connection.

        :return: stored connection, or ``None`` when unavailable
        """
        return self._task_data.get('conn', None) if self._task_data else None

    def transaction_async(self):
        """Similar to peewee `Database.transaction()` method, but returns
        asynchronous context manager.
        """
        return transaction(self)

    def atomic_async(self):
        """Similar to peewee `Database.atomic()` method, but returns
        asynchronous context manager.
        """
        return atomic(self)

    def savepoint_async(self, sid=None):
        """Similar to peewee `Database.savepoint()` method, but returns
        asynchronous context manager.
        """
        return savepoint(self, sid=sid)

    def set_allow_sync(self, value):
        """Allow or forbid sync queries for the database. See also
        the :meth:`.allow_sync()` context manager.

        ``value`` may also be ``logging.ERROR`` / ``logging.WARNING``, in
        which case :meth:`execute_sql` allows the query but logs it.
        """
        self._allow_sync = value

    @contextlib.contextmanager
    def allow_sync(self):
        """Allow sync queries within context. Close sync
        connection on exit if connected.

        The previous ``_allow_sync`` value is restored on exit; ``self.Error``
        raised by ``close()`` is ignored.

        Example::

            with database.allow_sync():
                PageBlock.create_table(True)
        """
        old_allow_sync = self._allow_sync
        self._allow_sync = True

        try:
            yield
        except:
            raise
        finally:
            self._allow_sync = old_allow_sync
            try:
                self.close()
            except self.Error:
                pass  # already closed

    def execute_sql(self, *args, **kwargs):
        """Sync execute SQL query, `allow_sync` must be set to True.

        :raises AssertionError: if sync queries are not allowed
            (note: assertions are stripped under ``python -O``)

        When ``_allow_sync`` is ``logging.ERROR`` or ``logging.WARNING`` the
        query still runs but is logged at that level through the root
        ``logging`` module (not ``__log__``).
        """
        assert self._allow_sync, (
            "Error, sync query is not allowed! Call the `.set_allow_sync()` "
            "or use the `.allow_sync()` context manager.")
        if self._allow_sync in (logging.ERROR, logging.WARNING):
            logging.log(self._allow_sync,
                        "Error, sync query is not allowed: %s %s" %
                        (str(args), str(kwargs)))
        return super().execute_sql(*args, **kwargs)