Skip to content

Commit 22d3251

Browse files
author
vshepard
committed
Fix ssh command in remote_ops.py
1 parent 6115461 commit 22d3251

File tree

2 files changed

+36
-14
lines changed

2 files changed

+36
-14
lines changed

Diff for: testgres/node.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,8 @@ def get_auth_method(t):
529529
u"host\treplication\tall\t127.0.0.1/32\t{}\n".format(auth_host),
530530
u"host\treplication\tall\t::1/128\t\t{}\n".format(auth_host),
531531
u"host\treplication\tall\t{}/24\t\t{}\n".format(subnet_base, auth_host),
532-
u"host\tall\tall\t{}/24\t\t{}\n".format(subnet_base, auth_host)
532+
u"host\tall\tall\t{}/24\t\t{}\n".format(subnet_base, auth_host),
533+
u"host\tall\tall\tall\t{}\n".format(auth_host)
533534
] # yapf: disable
534535

535536
# write missing lines

Diff for: testgres/operations/remote_ops.py

+34-13
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
import logging
21
import os
2+
import socket
33
import subprocess
44
import tempfile
55
import platform
6+
import time
67

78
# we support both pg8000 and psycopg2
89
try:
@@ -48,10 +49,10 @@ def __init__(self, conn_params: ConnectionParams):
4849
self.ssh_key = conn_params.ssh_key
4950
self.port = conn_params.port
5051
self.ssh_cmd = ["-o StrictHostKeyChecking=no"]
51-
if self.port:
52-
self.ssh_cmd += ["-p", self.port]
5352
if self.ssh_key:
5453
self.ssh_cmd += ["-i", self.ssh_key]
54+
if self.port:
55+
self.ssh_cmd += ["-p", self.port]
5556
self.remote = True
5657
self.username = conn_params.username or self.get_user()
5758
self.tunnel_process = None
@@ -62,17 +63,36 @@ def __enter__(self):
6263
def __exit__(self, exc_type, exc_val, exc_tb):
6364
self.close_ssh_tunnel()
6465

66+
@staticmethod
67+
def is_port_open(host, port):
68+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
69+
sock.settimeout(1) # Таймаут для попытки соединения
70+
try:
71+
sock.connect((host, port))
72+
return True
73+
except socket.error:
74+
return False
75+
6576
def establish_ssh_tunnel(self, local_port, remote_port):
6677
"""
6778
Establish an SSH tunnel from a local port to a remote PostgreSQL port.
6879
"""
6980
ssh_cmd = ['-N', '-L', f"{local_port}:localhost:{remote_port}"]
7081
self.tunnel_process = self.exec_command(ssh_cmd, get_process=True, timeout=300)
82+
timeout = 10
83+
start_time = time.time()
84+
while time.time() - start_time < timeout:
85+
if self.is_port_open('localhost', local_port):
86+
print("SSH tunnel established.")
87+
return
88+
time.sleep(0.5)
89+
raise Exception("Failed to establish SSH tunnel within the timeout period.")
7190

7291
def close_ssh_tunnel(self):
73-
if hasattr(self, 'tunnel_process'):
92+
if self.tunnel_process:
7493
self.tunnel_process.terminate()
7594
self.tunnel_process.wait()
95+
print("SSH tunnel closed.")
7696
del self.tunnel_process
7797
else:
7898
print("No active tunnel to close.")
@@ -238,9 +258,9 @@ def mkdtemp(self, prefix=None):
238258
- prefix (str): The prefix of the temporary directory name.
239259
"""
240260
if prefix:
241-
command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"]
261+
command = ["ssh" + f"{self.username}@{self.host}"] + self.ssh_cmd + [f"mktemp -d {prefix}XXXXX"]
242262
else:
243-
command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", "mktemp -d"]
263+
command = ["ssh", f"{self.username}@{self.host}"] + self.ssh_cmd + ["mktemp -d"]
244264

245265
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
246266

@@ -283,7 +303,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
283303
mode = "r+b" if binary else "r+"
284304

285305
with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file:
286-
# Because in scp we set up port using -P option instead -p
306+
# Because in scp we set up port using -P option
287307
scp_ssh_cmd = ['-P' if x == '-p' else x for x in self.ssh_cmd]
288308

289309
if not truncate:
@@ -305,9 +325,9 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
305325
tmp_file.flush()
306326
scp_cmd = ['scp'] + scp_ssh_cmd + [tmp_file.name, f"{self.username}@{self.host}:{filename}"]
307327
subprocess.run(scp_cmd, check=True)
308-
remote_directory = os.path.dirname(filename)
309328

310-
mkdir_cmd = ['ssh'] + self.ssh_cmd + [f"{self.username}@{self.host}", f'mkdir -p {remote_directory}']
329+
remote_directory = os.path.dirname(filename)
330+
mkdir_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_cmd + [f"mkdir -p {remote_directory}"]
311331
subprocess.run(mkdir_cmd, check=True)
312332

313333
os.remove(tmp_file.name)
@@ -372,7 +392,7 @@ def get_pid(self):
372392
return int(self.exec_command("echo $$", encoding=get_default_encoding()))
373393

374394
def get_process_children(self, pid):
375-
command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", f"pgrep -P {pid}"]
395+
command = ["ssh", f"{self.username}@{self.host}"] + self.ssh_cmd + [f"pgrep -P {pid}"]
376396

377397
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
378398

@@ -387,15 +407,16 @@ def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
387407
"""
388408
Establish SSH tunnel and connect to a PostgreSQL database.
389409
"""
390-
self.establish_ssh_tunnel(local_port=port, remote_port=self.conn_params.port)
391-
410+
local_port = reserve_port()
411+
self.establish_ssh_tunnel(local_port=local_port, remote_port=port)
392412
try:
393413
conn = pglib.connect(
394414
host=host,
395-
port=port,
415+
port=local_port,
396416
database=dbname,
397417
user=user,
398418
password=password,
419+
timeout=10
399420
)
400421
print("Database connection established successfully.")
401422
return conn

0 commit comments

Comments
 (0)