1
- import logging
2
1
import os
2
+ import socket
3
3
import subprocess
4
4
import tempfile
5
5
import platform
6
+ import time
6
7
7
8
# we support both pg8000 and psycopg2
8
9
try :
@@ -48,10 +49,10 @@ def __init__(self, conn_params: ConnectionParams):
48
49
self .ssh_key = conn_params .ssh_key
49
50
self .port = conn_params .port
50
51
self .ssh_cmd = ["-o StrictHostKeyChecking=no" ]
51
- if self .port :
52
- self .ssh_cmd += ["-p" , self .port ]
53
52
if self .ssh_key :
54
53
self .ssh_cmd += ["-i" , self .ssh_key ]
54
+ if self .port :
55
+ self .ssh_cmd += ["-p" , self .port ]
55
56
self .remote = True
56
57
self .username = conn_params .username or self .get_user ()
57
58
self .tunnel_process = None
@@ -62,17 +63,36 @@ def __enter__(self):
62
63
def __exit__ (self , exc_type , exc_val , exc_tb ):
63
64
self .close_ssh_tunnel ()
64
65
66
+ @staticmethod
67
+ def is_port_open (host , port ):
68
+ with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as sock :
69
+ sock .settimeout (1 ) # Таймаут для попытки соединения
70
+ try :
71
+ sock .connect ((host , port ))
72
+ return True
73
+ except socket .error :
74
+ return False
75
+
65
76
def establish_ssh_tunnel (self , local_port , remote_port ):
66
77
"""
67
78
Establish an SSH tunnel from a local port to a remote PostgreSQL port.
68
79
"""
69
80
ssh_cmd = ['-N' , '-L' , f"{ local_port } :localhost:{ remote_port } " ]
70
81
self .tunnel_process = self .exec_command (ssh_cmd , get_process = True , timeout = 300 )
82
+ timeout = 10
83
+ start_time = time .time ()
84
+ while time .time () - start_time < timeout :
85
+ if self .is_port_open ('localhost' , local_port ):
86
+ print ("SSH tunnel established." )
87
+ return
88
+ time .sleep (0.5 )
89
+ raise Exception ("Failed to establish SSH tunnel within the timeout period." )
71
90
72
91
def close_ssh_tunnel (self ):
73
- if hasattr ( self , ' tunnel_process' ) :
92
+ if self . tunnel_process :
74
93
self .tunnel_process .terminate ()
75
94
self .tunnel_process .wait ()
95
+ print ("SSH tunnel closed." )
76
96
del self .tunnel_process
77
97
else :
78
98
print ("No active tunnel to close." )
@@ -238,9 +258,9 @@ def mkdtemp(self, prefix=None):
238
258
- prefix (str): The prefix of the temporary directory name.
239
259
"""
240
260
if prefix :
241
- command = ["ssh" ] + self . ssh_cmd + [ f"{ self .username } @{ self .host } " , f"mktemp -d { prefix } XXXXX" ]
261
+ command = ["ssh" + f"{ self .username } @{ self .host } " ] + self . ssh_cmd + [ f"mktemp -d { prefix } XXXXX" ]
242
262
else :
243
- command = ["ssh" ] + self . ssh_cmd + [ f"{ self .username } @{ self .host } " , "mktemp -d" ]
263
+ command = ["ssh" , f"{ self .username } @{ self .host } " ] + self . ssh_cmd + [ "mktemp -d" ]
244
264
245
265
result = subprocess .run (command , stdout = subprocess .PIPE , stderr = subprocess .PIPE , text = True )
246
266
@@ -283,7 +303,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
283
303
mode = "r+b" if binary else "r+"
284
304
285
305
with tempfile .NamedTemporaryFile (mode = mode , delete = False ) as tmp_file :
286
- # Because in scp we set up port using -P option instead -p
306
+ # Because in scp we set up port using -P option
287
307
scp_ssh_cmd = ['-P' if x == '-p' else x for x in self .ssh_cmd ]
288
308
289
309
if not truncate :
@@ -305,9 +325,9 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
305
325
tmp_file .flush ()
306
326
scp_cmd = ['scp' ] + scp_ssh_cmd + [tmp_file .name , f"{ self .username } @{ self .host } :{ filename } " ]
307
327
subprocess .run (scp_cmd , check = True )
308
- remote_directory = os .path .dirname (filename )
309
328
310
- mkdir_cmd = ['ssh' ] + self .ssh_cmd + [f"{ self .username } @{ self .host } " , f'mkdir -p { remote_directory } ' ]
329
+ remote_directory = os .path .dirname (filename )
330
+ mkdir_cmd = ['ssh' , f"{ self .username } @{ self .host } " ] + self .ssh_cmd + [f"mkdir -p { remote_directory } " ]
311
331
subprocess .run (mkdir_cmd , check = True )
312
332
313
333
os .remove (tmp_file .name )
@@ -372,7 +392,7 @@ def get_pid(self):
372
392
return int (self .exec_command ("echo $$" , encoding = get_default_encoding ()))
373
393
374
394
def get_process_children (self , pid ):
375
- command = ["ssh" ] + self . ssh_cmd + [ f"{ self .username } @{ self .host } " , f"pgrep -P { pid } " ]
395
+ command = ["ssh" , f"{ self .username } @{ self .host } " ] + self . ssh_cmd + [ f"pgrep -P { pid } " ]
376
396
377
397
result = subprocess .run (command , stdout = subprocess .PIPE , stderr = subprocess .PIPE , text = True )
378
398
@@ -387,15 +407,16 @@ def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
387
407
"""
388
408
Establish SSH tunnel and connect to a PostgreSQL database.
389
409
"""
390
- self . establish_ssh_tunnel ( local_port = port , remote_port = self . conn_params . port )
391
-
410
+ local_port = reserve_port ( )
411
+ self . establish_ssh_tunnel ( local_port = local_port , remote_port = port )
392
412
try :
393
413
conn = pglib .connect (
394
414
host = host ,
395
- port = port ,
415
+ port = local_port ,
396
416
database = dbname ,
397
417
user = user ,
398
418
password = password ,
419
+ timeout = 10
399
420
)
400
421
print ("Database connection established successfully." )
401
422
return conn
0 commit comments