44import asyncpg
55from sqlalchemy .engine .interfaces import Dialect
66from sqlalchemy .sql import ClauseElement
7+ from sqlalchemy .sql .ddl import DDLElement
78
89from databases .backends .common .records import Record , create_column_maps
9- from databases .backends .dialects .psycopg import compile_query , get_dialect
10- from databases .core import DatabaseURL
10+ from databases .backends .dialects .psycopg import dialect as psycopg_dialect
11+ from databases .core import LOG_EXTRA , DatabaseURL
1112from databases .interfaces import (
1213 ConnectionBackend ,
1314 DatabaseBackend ,
@@ -24,9 +25,20 @@ def __init__(
2425 ) -> None :
2526 self ._database_url = DatabaseURL (database_url )
2627 self ._options = options
27- self ._dialect = get_dialect ()
28+ self ._dialect = self . _get_dialect ()
2829 self ._pool = None
2930
31+ def _get_dialect (self ) -> Dialect :
32+ dialect = psycopg_dialect (paramstyle = "pyformat" )
33+ dialect .implicit_returning = True
34+ dialect .supports_native_enum = True
35+ dialect .supports_smallserial = True # 9.2+
36+ dialect ._backslash_escapes = False
37+ dialect .supports_sane_multi_rowcount = True # psycopg 2.0.9+
38+ dialect ._has_native_hstore = True
39+ dialect .supports_native_decimal = True
40+ return dialect
41+
3042 def _get_connection_kwargs (self ) -> dict :
3143 url_options = self ._database_url .options
3244
@@ -87,15 +99,15 @@ async def release(self) -> None:
8799
88100 async def fetch_all (self , query : ClauseElement ) -> typing .List [RecordInterface ]:
89101 assert self ._connection is not None , "Connection is not acquired"
90- query_str , args , result_columns = compile_query ( query , self ._dialect )
102+ query_str , args , result_columns = self ._compile ( query )
91103 rows = await self ._connection .fetch (query_str , * args )
92104 dialect = self ._dialect
93105 column_maps = create_column_maps (result_columns )
94106 return [Record (row , result_columns , dialect , column_maps ) for row in rows ]
95107
96108 async def fetch_one (self , query : ClauseElement ) -> typing .Optional [RecordInterface ]:
97109 assert self ._connection is not None , "Connection is not acquired"
98- query_str , args , result_columns = compile_query ( query , self ._dialect )
110+ query_str , args , result_columns = self ._compile ( query )
99111 row = await self ._connection .fetchrow (query_str , * args )
100112 if row is None :
101113 return None
@@ -123,7 +135,7 @@ async def fetch_val(
123135
124136 async def execute (self , query : ClauseElement ) -> typing .Any :
125137 assert self ._connection is not None , "Connection is not acquired"
126- query_str , args , _ = compile_query ( query , self ._dialect )
138+ query_str , args , _ = self ._compile ( query )
127139 return await self ._connection .fetchval (query_str , * args )
128140
129141 async def execute_many (self , queries : typing .List [ClauseElement ]) -> None :
@@ -132,25 +144,55 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
132144 # loop through multiple executes here, which should all end up
133145 # using the same prepared statement.
134146 for single_query in queries :
135- single_query , args , _ = compile_query ( single_query , self ._dialect )
147+ single_query , args , _ = self ._compile ( single_query )
136148 await self ._connection .execute (single_query , * args )
137149
138150 async def iterate (
139151 self , query : ClauseElement
140152 ) -> typing .AsyncGenerator [typing .Any , None ]:
141153 assert self ._connection is not None , "Connection is not acquired"
142- query_str , args , result_columns = compile_query ( query , self ._dialect )
154+ query_str , args , result_columns = self ._compile ( query )
143155 column_maps = create_column_maps (result_columns )
144156 async for row in self ._connection .cursor (query_str , * args ):
145157 yield Record (row , result_columns , self ._dialect , column_maps )
146158
147159 def transaction (self ) -> TransactionBackend :
148160 return AsyncpgTransaction (connection = self )
149161
150- @property
151- def raw_connection (self ) -> asyncpg .connection .Connection :
152- assert self ._connection is not None , "Connection is not acquired"
153- return self ._connection
162+ def _compile (self , query : ClauseElement ) -> typing .Tuple [str , list , tuple ]:
163+ compiled = query .compile (
164+ dialect = self ._dialect , compile_kwargs = {"render_postcompile" : True }
165+ )
166+
167+ if not isinstance (query , DDLElement ):
168+ compiled_params = sorted (compiled .params .items ())
169+
170+ mapping = {
171+ key : "$" + str (i ) for i , (key , _ ) in enumerate (compiled_params , start = 1 )
172+ }
173+ compiled_query = compiled .string % mapping
174+
175+ processors = compiled ._bind_processors
176+ args = [
177+ processors [key ](val ) if key in processors else val
178+ for key , val in compiled_params
179+ ]
180+ result_map = compiled ._result_columns
181+ else :
182+ compiled_query = compiled .string
183+ args = []
184+ result_map = None
185+
186+ query_message = compiled_query .replace (" \n " , " " ).replace ("\n " , " " )
187+ logger .debug (
188+ "Query: %s Args: %s" , query_message , repr (tuple (args )), extra = LOG_EXTRA
189+ )
190+ return compiled_query , args , result_map
191+
192+ @property
193+ def raw_connection (self ) -> asyncpg .connection .Connection :
194+ assert self ._connection is not None , "Connection is not acquired"
195+ return self ._connection
154196
155197
156198class AsyncpgTransaction (TransactionBackend ):
0 commit comments