Skip to content

Commit 31c2d43

Browse files
author
Andrew Jackson
committed
Implement connection service file functionality
1 parent d0797f1 commit 31c2d43

File tree

3 files changed

+199
-5
lines changed

3 files changed

+199
-5
lines changed

asyncpg/connect_utils.py

+125-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from __future__ import annotations
88

99
import asyncio
10+
import configparser
1011
import collections
1112
from collections.abc import Callable
1213
import enum
@@ -87,6 +88,9 @@ class SSLNegotiation(compat.StrEnum):
8788
PGPASSFILE = '.pgpass'
8889

8990

91+
PG_SERVICEFILE = '.pg_service.conf'
92+
93+
9094
def _read_password_file(passfile: pathlib.Path) \
9195
-> typing.List[typing.Tuple[str, ...]]:
9296

@@ -268,7 +272,7 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
268272

269273

270274
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
271-
password, passfile, database, ssl,
275+
password, passfile, database, ssl, service,
272276
direct_tls, server_settings,
273277
target_session_attrs, krbsrvname, gsslib):
274278
# `auth_hosts` is the version of host information for the purposes
@@ -278,6 +282,120 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
278282
ssl_min_protocol_version = ssl_max_protocol_version = None
279283
sslnegotiation = None
280284

285+
if dsn:
286+
parsed = urllib.parse.urlparse(dsn)
287+
if parsed.query:
288+
query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
289+
for key, val in query.items():
290+
if isinstance(val, list):
291+
query[key] = val[-1]
292+
293+
if 'service' in query:
294+
val = query.pop('service')
295+
if not service and val:
296+
service = val
297+
298+
connection_service_file = os.getenv('PGSERVICEFILE')
299+
if connection_service_file is None:
300+
homedir = compat.get_pg_home_directory()
301+
if homedir:
302+
connection_service_file = homedir / PG_SERVICEFILE
303+
else:
304+
connection_service_file = None
305+
else:
306+
connection_service_file = pathlib.Path(connection_service_file)
307+
308+
if connection_service_file is not None and service is not None:
309+
# TODO Open and parse connection service file
310+
pg_service = configparser.ConfigParser()
311+
pg_service.read(connection_service_file)
312+
if service in pg_service.sections():
313+
service_params = pg_service[service]
314+
if 'port' in service_params:
315+
val = service_params.pop('port')
316+
if not port and val:
317+
port = [int(p) for p in val.split(',')]
318+
319+
if 'host' in service_params:
320+
val = service_params.pop('host')
321+
if not host and val:
322+
host, port = _parse_hostlist(val, port)
323+
324+
if 'dbname' in service_params:
325+
val = service_params.pop('dbname')
326+
if database is None:
327+
database = val
328+
329+
if 'database' in service_params:
330+
val = service_params.pop('database')
331+
if database is None:
332+
database = val
333+
334+
if 'user' in service_params:
335+
val = service_params.pop('user')
336+
if user is None:
337+
user = val
338+
339+
if 'password' in service_params:
340+
val = service_params.pop('password')
341+
if password is None:
342+
password = val
343+
344+
if 'passfile' in service_params:
345+
val = service_params.pop('passfile')
346+
if passfile is None:
347+
passfile = val
348+
349+
if 'sslmode' in service_params:
350+
val = service_params.pop('sslmode')
351+
if ssl is None:
352+
ssl = val
353+
354+
if 'sslcert' in service_params:
355+
sslcert = service_params.pop('sslcert')
356+
357+
if 'sslkey' in service_params:
358+
sslkey = service_params.pop('sslkey')
359+
360+
if 'sslrootcert' in service_params:
361+
sslrootcert = service_params.pop('sslrootcert')
362+
363+
if 'sslnegotiation' in service_params:
364+
sslnegotiation = service_params.pop('sslnegotiation')
365+
366+
if 'sslcrl' in service_params:
367+
sslcrl = service_params.pop('sslcrl')
368+
369+
if 'sslpassword' in service_params:
370+
sslpassword = service_params.pop('sslpassword')
371+
372+
if 'ssl_min_protocol_version' in service_params:
373+
ssl_min_protocol_version = service_params.pop(
374+
'ssl_min_protocol_version'
375+
)
376+
377+
if 'ssl_max_protocol_version' in service_params:
378+
ssl_max_protocol_version = service_params.pop(
379+
'ssl_max_protocol_version'
380+
)
381+
382+
if 'target_session_attrs' in service_params:
383+
dsn_target_session_attrs = service_params.pop(
384+
'target_session_attrs'
385+
)
386+
if target_session_attrs is None:
387+
target_session_attrs = dsn_target_session_attrs
388+
389+
if 'krbsrvname' in service_params:
390+
val = service_params.pop('krbsrvname')
391+
if krbsrvname is None:
392+
krbsrvname = val
393+
394+
if 'gsslib' in service_params:
395+
val = service_params.pop('gsslib')
396+
if gsslib is None:
397+
gsslib = val
398+
281399
if dsn:
282400
parsed = urllib.parse.urlparse(dsn)
283401

@@ -406,6 +524,9 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
406524
if gsslib is None:
407525
gsslib = val
408526

527+
if 'service' in query:
528+
val = query.pop('service')
529+
409530
if query:
410531
if server_settings is None:
411532
server_settings = query
@@ -491,6 +612,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
491612
database=database, user=user,
492613
passfile=passfile)
493614

615+
494616
addrs = []
495617
have_tcp_addrs = False
496618
for h, p in zip(host, port):
@@ -724,7 +846,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
724846
max_cached_statement_lifetime,
725847
max_cacheable_statement_size,
726848
ssl, direct_tls, server_settings,
727-
target_session_attrs, krbsrvname, gsslib):
849+
target_session_attrs, krbsrvname, gsslib, service):
728850
local_vars = locals()
729851
for var_name in {'max_cacheable_statement_size',
730852
'max_cached_statement_lifetime',
@@ -754,7 +876,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
754876
direct_tls=direct_tls, database=database,
755877
server_settings=server_settings,
756878
target_session_attrs=target_session_attrs,
757-
krbsrvname=krbsrvname, gsslib=gsslib)
879+
krbsrvname=krbsrvname, gsslib=gsslib, service=service)
758880

759881
config = _ClientConfiguration(
760882
command_timeout=command_timeout,

asyncpg/connection.py

+6
Original file line numberDiff line numberDiff line change
@@ -2074,6 +2074,7 @@ async def _do_execute(
20742074
async def connect(dsn=None, *,
20752075
host=None, port=None,
20762076
user=None, password=None, passfile=None,
2077+
service=None,
20772078
database=None,
20782079
loop=None,
20792080
timeout=60,
@@ -2183,6 +2184,10 @@ async def connect(dsn=None, *,
21832184
(defaults to ``~/.pgpass``, or ``%APPDATA%\postgresql\pgpass.conf``
21842185
on Windows).
21852186
2187+
:param service:
2188+
The name of the postgres connection service stored in the postgres
2189+
connection service file.
2190+
21862191
:param loop:
21872192
An asyncio event loop instance. If ``None``, the default
21882193
event loop will be used.
@@ -2428,6 +2433,7 @@ async def connect(dsn=None, *,
24282433
user=user,
24292434
password=password,
24302435
passfile=passfile,
2436+
service=service,
24312437
ssl=ssl,
24322438
direct_tls=direct_tls,
24332439
database=database,

tests/test_connect.py

+68-2
Original file line numberDiff line numberDiff line change
@@ -1116,7 +1116,8 @@ def run_testcase(self, testcase):
11161116
env = testcase.get('env', {})
11171117
test_env = {'PGHOST': None, 'PGPORT': None,
11181118
'PGUSER': None, 'PGPASSWORD': None,
1119-
'PGDATABASE': None, 'PGSSLMODE': None}
1119+
'PGDATABASE': None, 'PGSSLMODE': None,
1120+
'PGSERVICE': None, }
11201121
test_env.update(env)
11211122

11221123
dsn = testcase.get('dsn')
@@ -1132,6 +1133,7 @@ def run_testcase(self, testcase):
11321133
target_session_attrs = testcase.get('target_session_attrs')
11331134
krbsrvname = testcase.get('krbsrvname')
11341135
gsslib = testcase.get('gsslib')
1136+
service = testcase.get('service')
11351137

11361138
expected = testcase.get('result')
11371139
expected_error = testcase.get('error')
@@ -1157,7 +1159,7 @@ def run_testcase(self, testcase):
11571159
direct_tls=direct_tls,
11581160
server_settings=server_settings,
11591161
target_session_attrs=target_session_attrs,
1160-
krbsrvname=krbsrvname, gsslib=gsslib)
1162+
krbsrvname=krbsrvname, gsslib=gsslib, service=service)
11611163

11621164
params = {
11631165
k: v for k, v in params._asdict().items()
@@ -1236,6 +1238,70 @@ def test_connect_params(self):
12361238
for testcase in self.TESTS:
12371239
self.run_testcase(testcase)
12381240

1241+
def test_connect_connection_service_file(self):
1242+
connection_service_file = tempfile.NamedTemporaryFile('w+t', delete=False)
1243+
connection_service_file.write(textwrap.dedent(f'''
1244+
[test_service_dbname]
1245+
port=5433
1246+
host=somehost
1247+
dbname=test_dbname
1248+
user=admin
1249+
password=test_password
1250+
target_session_attrs=primary
1251+
krbsrvname=fakekrbsrvname
1252+
gsslib=sspi
1253+
1254+
[test_service_database]
1255+
port=5433
1256+
host=somehost
1257+
database=test_dbname
1258+
user=admin
1259+
password=test_password
1260+
target_session_attrs=primary
1261+
krbsrvname=fakekrbsrvname
1262+
gsslib=sspi
1263+
'''))
1264+
connection_service_file.close()
1265+
os.chmod(connection_service_file.name, stat.S_IWUSR | stat.S_IRUSR)
1266+
try:
1267+
# passfile path in env
1268+
self.run_testcase({
1269+
'dsn': 'postgresql://?service=test_service_dbname',
1270+
'env': {
1271+
'PGSERVICEFILE': connection_service_file.name
1272+
},
1273+
'result': (
1274+
[('somehost', 5433)],
1275+
{
1276+
'user': 'admin',
1277+
'password': 'test_password',
1278+
'database': 'test_dbname',
1279+
'target_session_attrs': 'primary',
1280+
'krbsrvname': 'fakekrbsrvname',
1281+
'gsslib': 'sspi',
1282+
}
1283+
)
1284+
})
1285+
self.run_testcase({
1286+
'dsn': 'postgresql://?service=test_service_database',
1287+
'env': {
1288+
'PGSERVICEFILE': connection_service_file.name
1289+
},
1290+
'result': (
1291+
[('somehost', 5433)],
1292+
{
1293+
'user': 'admin',
1294+
'password': 'test_password',
1295+
'database': 'test_dbname',
1296+
'target_session_attrs': 'primary',
1297+
'krbsrvname': 'fakekrbsrvname',
1298+
'gsslib': 'sspi',
1299+
}
1300+
)
1301+
})
1302+
finally:
1303+
os.unlink(connection_service_file.name)
1304+
12391305
def test_connect_pgpass_regular(self):
12401306
passfile = tempfile.NamedTemporaryFile('w+t', delete=False)
12411307
passfile.write(textwrap.dedent(R'''

0 commit comments

Comments
 (0)