11# This file is part of parallel-ssh.
22#
3- # Copyright (C) 2014-2020 Panos Kittenis.
3+ # Copyright (C) 2014-2022 Panos Kittenis and contributors .
44#
55# This library is free software; you can redistribute it and/or
66# modify it under the terms of the GNU Lesser General Public
2020import logging
2121
2222import gevent .pool
23-
2423from gevent import joinall , spawn , Timeout as GTimeout
2524from gevent .hub import Hub
2625
26+ from ..common import _validate_pkey_path
27+ from ...config import HostConfig
2728from ...constants import DEFAULT_RETRIES , RETRY_DELAY
28- from ...exceptions import HostArgumentError , Timeout , ShellError
29+ from ...exceptions import HostArgumentError , Timeout , ShellError , HostConfigError
2930from ...output import HostOutput
3031
31-
3232Hub .NOT_ERROR = (Exception ,)
3333logger = logging .getLogger (__name__ )
3434
@@ -43,6 +43,19 @@ def __init__(self, hosts, user=None, password=None, port=None, pkey=None,
4343 host_config = None , retry_delay = RETRY_DELAY ,
4444 identity_auth = True ,
4545 ipv6_only = False ,
46+ proxy_host = None ,
47+ proxy_port = None ,
48+ proxy_user = None ,
49+ proxy_password = None ,
50+ proxy_pkey = None ,
51+ keepalive_seconds = None ,
52+ cert_file = None ,
53+ gssapi_auth = False ,
54+ gssapi_server_identity = None ,
55+ gssapi_client_identity = None ,
56+ gssapi_delegate_credentials = False ,
57+ forward_ssh_agent = False ,
58+ _auth_thread_pool = True ,
4659 ):
4760 self .allow_agent = allow_agent
4861 self .pool_size = pool_size
@@ -60,6 +73,19 @@ def __init__(self, hosts, user=None, password=None, port=None, pkey=None,
6073 self .cmds = None
6174 self .identity_auth = identity_auth
6275 self .ipv6_only = ipv6_only
76+ self .proxy_host = proxy_host
77+ self .proxy_port = proxy_port
78+ self .proxy_user = proxy_user
79+ self .proxy_password = proxy_password
80+ self .proxy_pkey = proxy_pkey
81+ self .keepalive_seconds = keepalive_seconds
82+ self .cert_file = cert_file
83+ self .forward_ssh_agent = forward_ssh_agent
84+ self .gssapi_auth = gssapi_auth
85+ self .gssapi_server_identity = gssapi_server_identity
86+ self .gssapi_client_identity = gssapi_client_identity
87+ self .gssapi_delegate_credentials = gssapi_delegate_credentials
88+ self ._auth_thread_pool = _auth_thread_pool
6389 self ._check_host_config ()
6490
6591 def _validate_hosts (self , _hosts ):
@@ -100,7 +126,7 @@ def _check_host_config(self):
100126 def _open_shell (self , host_i , host ,
101127 encoding = 'utf-8' , read_timeout = None ):
102128 try :
103- _client = self ._make_ssh_client (host_i , host )
129+ _client = self ._get_ssh_client (host_i , host )
104130 shell = _client .open_shell (
105131 encoding = encoding , read_timeout = read_timeout )
106132 return shell
@@ -230,36 +256,37 @@ def get_last_output(self, cmds=None):
230256 return self ._get_output_from_cmds (
231257 cmds , raise_error = False )
232258
233- def _get_host_config_values (self , host_i , host ):
259+ def _get_host_config (self , host_i , host ):
234260 if self .host_config is None :
235- return self .user , self .port , self .password , self .pkey , \
236- getattr (self , 'proxy_host' , None ), \
237- getattr (self , 'proxy_port' , None ), getattr (self , 'proxy_user' , None ), \
238- getattr (self , 'proxy_password' , None ), getattr (self , 'proxy_pkey' , None )
239- elif isinstance (self .host_config , list ):
240- config = self .host_config [host_i ]
241- return config .user or self .user , config .port or self .port , \
242- config .password or self .password , config .private_key or self .pkey , \
243- config .proxy_host or getattr (self , 'proxy_host' , None ), \
244- config .proxy_port or getattr (self , 'proxy_port' , None ), \
245- config .proxy_user or getattr (self , 'proxy_user' , None ), \
246- config .proxy_password or getattr (self , 'proxy_password' , None ), \
247- config .proxy_pkey or getattr (self , 'proxy_pkey' , None )
248- elif isinstance (self .host_config , dict ):
249- _user = self .host_config .get (host , {}).get ('user' , self .user )
250- _port = self .host_config .get (host , {}).get ('port' , self .port )
251- _password = self .host_config .get (host , {}).get (
252- 'password' , self .password )
253- _pkey = self .host_config .get (host , {}).get ('private_key' , self .pkey )
254- return _user , _port , _password , _pkey , None , None , None , None , None
261+ config = HostConfig (
262+ user = self .user , port = self .port , password = self .password , private_key = self .pkey ,
263+ allow_agent = self .allow_agent , num_retries = self .num_retries , retry_delay = self .retry_delay ,
264+ timeout = self .timeout , identity_auth = self .identity_auth , proxy_host = self .proxy_host ,
265+ proxy_port = self .proxy_port , proxy_user = self .proxy_user , proxy_password = self .proxy_password ,
266+ proxy_pkey = self .proxy_pkey ,
267+ keepalive_seconds = self .keepalive_seconds ,
268+ ipv6_only = self .ipv6_only ,
269+ cert_file = self .cert_file ,
270+ forward_ssh_agent = self .forward_ssh_agent ,
271+ gssapi_auth = self .gssapi_auth ,
272+ gssapi_server_identity = self .gssapi_server_identity ,
273+ gssapi_client_identity = self .gssapi_client_identity ,
274+ gssapi_delegate_credentials = self .gssapi_delegate_credentials ,
275+ )
276+ return config
277+ elif not isinstance (self .host_config , list ):
278+ raise HostConfigError ("Host configuration of type %s is invalid - valid types are list[HostConfig]" ,
279+ type (self .host_config ))
280+ config = self .host_config [host_i ]
281+ return config
255282
256283 def _run_command (self , host_i , host , command , sudo = False , user = None ,
257284 shell = None , use_pty = False ,
258285 encoding = 'utf-8' , read_timeout = None ):
259286 """Make SSHClient if needed, run command on host"""
260287 logger .debug ("_run_command with read timeout %s" , read_timeout )
261288 try :
262- _client = self ._make_ssh_client (host_i , host )
289+ _client = self ._get_ssh_client (host_i , host )
263290 host_out = _client .run_command (
264291 command , sudo = sudo , user = user , shell = shell ,
265292 use_pty = use_pty , encoding = encoding , read_timeout = read_timeout )
@@ -283,7 +310,7 @@ def connect_auth(self):
283310 :returns: list of greenlets to ``joinall`` with.
284311 :rtype: list(:py:mod:`gevent.greenlet.Greenlet`)
285312 """
286- cmds = [spawn (self ._make_ssh_client , i , host ) for i , host in enumerate (self .hosts )]
313+ cmds = [spawn (self ._get_ssh_client , i , host ) for i , host in enumerate (self .hosts )]
287314 return cmds
288315
289316 def _consume_output (self , stdout , stderr ):
@@ -429,7 +456,7 @@ def copy_file(self, local_file, remote_file, recurse=False, copy_args=None):
429456
430457 def _copy_file (self , host_i , host , local_file , remote_file , recurse = False ):
431458 """Make sftp client, copy file"""
432- client = self ._make_ssh_client (host_i , host )
459+ client = self ._get_ssh_client (host_i , host )
433460 return client .copy_file (
434461 local_file , remote_file , recurse = recurse )
435462
@@ -512,7 +539,7 @@ def copy_remote_file(self, remote_file, local_file, recurse=False,
512539 def _copy_remote_file (self , host_i , host , remote_file , local_file , recurse ,
513540 ** kwargs ):
514541 """Make sftp client, copy file to local"""
515- client = self ._make_ssh_client (host_i , host )
542+ client = self ._get_ssh_client (host_i , host )
516543 return client .copy_remote_file (
517544 remote_file , local_file , recurse = recurse , ** kwargs )
518545
@@ -522,5 +549,26 @@ def _handle_greenlet_exc(self, func, host, *args, **kwargs):
522549 except Exception as ex :
523550 raise ex
524551
525- def _make_ssh_client (self , host_i , host ):
552+ def _get_ssh_client (self , host_i , host ):
553+ logger .debug ("Make client request for host %s, (host_i, host) in clients: %s" ,
554+ host , (host_i , host ) in self ._host_clients )
555+ _client = self ._host_clients .get ((host_i , host ))
556+ if _client is not None :
557+ return _client
558+ cfg = self ._get_host_config (host_i , host )
559+ _pkey = self .pkey if cfg .private_key is None else cfg .private_key
560+ _pkey_data = self ._load_pkey_data (_pkey )
561+ _client = self ._make_ssh_client (host , cfg , _pkey_data )
562+ self ._host_clients [(host_i , host )] = _client
563+ return _client
564+
565+ def _load_pkey_data (self , _pkey ):
566+ if isinstance (_pkey , str ):
567+ _validate_pkey_path (_pkey )
568+ with open (_pkey , 'rb' ) as fh :
569+ _pkey_data = fh .read ()
570+ return _pkey_data
571+ return _pkey
572+
573+ def _make_ssh_client (self , host , cfg , _pkey_data ):
526574 raise NotImplementedError
0 commit comments