diff --git a/keepercommander/commands/discoveryrotation.py b/keepercommander/commands/discoveryrotation.py index c91574f75..132c3b6d4 100644 --- a/keepercommander/commands/discoveryrotation.py +++ b/keepercommander/commands/discoveryrotation.py @@ -14,6 +14,7 @@ import logging import os.path import queue +import socket import sys import threading from datetime import datetime @@ -1562,18 +1563,13 @@ def execute(self, params, **kwargs): ############################################## TUNNELING ############################################################### class PAMTunnelListCommand(Command): pam_cmd_parser = argparse.ArgumentParser(prog='dr-tunnel-list-command') - pam_cmd_parser.add_argument('--uid', '-u', required=False, dest='record_uid', action='store', - help='Filter list with UID of the PAM record that was used to create the tunnel') - pam_cmd_parser.add_argument('--conversation-id', '-c', required=False, dest='convo_id', action='store', - help='The connection ID of the Tunnel to list') def get_parser(self): return PAMTunnelListCommand.pam_cmd_parser def execute(self, params, **kwargs): def gather_tabel_row_data(thread): - # {"thread": t, "host": host, "port": port, "name": listener_name, "started": datetime.now(), - # "record_uid": record_uid} + # {"thread": t, "host": host, "port": port, "started": datetime.now(), row = [] run_time = None hours = 0 @@ -1606,8 +1602,6 @@ def gather_tabel_row_data(thread): row.append(text_line) return row - convo_id = kwargs.get('convo_id', None) - if not params.tunnel_threads: logging.warning(f"{bcolors.OKBLUE}No Tunnels running{bcolors.ENDC}") return @@ -1615,84 +1609,51 @@ def gather_tabel_row_data(thread): table = [] headers = ['Tunnel ID', 'Host', 'Port', 'Record UID', 'Up Time'] - if convo_id: - if convo_id in params.tunnel_threads: - table.append(gather_tabel_row_data(params.tunnel_threads[convo_id])) - else: - print(f"{bcolors.FAIL}Tunnel {convo_id} not found{bcolors.ENDC}") - return - else: - for i, convo_id in enumerate(params.tunnel_threads): - row = gather_tabel_row_data(params.tunnel_threads[convo_id]) - if row: - table.append(row) + for i, convo_id in enumerate(params.tunnel_threads): + row = gather_tabel_row_data(params.tunnel_threads[convo_id]) + if row: + table.append(row) dump_report_data(table, headers, fmt='table', filename="", row_number=False, column_width=None) class PAMTunnelStopCommand(Command): pam_cmd_parser = argparse.ArgumentParser(prog='dr-tunnel-stop-command') - pam_cmd_parser.add_argument('--conversation-id', '-c', required=False, dest='convo_id', action='store', - help='The connection ID of the Tunnel to stop') - - def tunnel_cleanup(self, params, convo_id): - tunnel_data = params.tunnel_threads.get(convo_id, None) - if not tunnel_data: - return - - for task_name in ["print", "connect"]: - task = tunnel_data.get(task_name) - if task: - task.cancel() - logging.debug(f"Cancelled {task_name} for {convo_id}") - - del params.tunnel_threads[convo_id] - logging.debug(f"Cleaned up data for {convo_id}") - - if convo_id in params.tunnel_threads_queue: - del params.tunnel_threads_queue[convo_id] - logging.debug(f"{bcolors.OKBLUE}{convo_id} Queue cleaned up{bcolors.ENDC}") + pam_cmd_parser.add_argument('uid', type= str, action='store', help='The Tunnel UID') def get_parser(self): return PAMTunnelStopCommand.pam_cmd_parser def execute(self, params, **kwargs): - convo_id = kwargs.get('convo_id') - tunnel_data = params.tunnel_threads.get(convo_id, None) + convo_id = kwargs.get('uid') + if not convo_id: + raise CommandError('tunnel stop', '"uid" argument is required') - if tunnel_data is None: - print(f"{bcolors.WARNING}No data found for conversation ID {convo_id}{bcolors.ENDC}") + tunnel_data = params.tunnel_threads.get(convo_id, None) + if not tunnel_data: + logging.debug(f"{bcolors.WARNING}No tunnel data found for {convo_id}{bcolors.ENDC}") return - # Stop the asyncio event loop - loop = tunnel_data.get("loop", None) - - entrance = tunnel_data.get("entrance", None) - if loop and entrance and not loop._closed: - if loop._closed: - print(f"{bcolors.WARNING}Event loop is closed for conversation ID {convo_id}{bcolors.ENDC}") - else: - # Run the disconnect method in the event loop - loop.create_task(entrance.stop_server()) - print(f"Disconnected entrance for {convo_id}") - - loop.call_soon_threadsafe(self.tunnel_cleanup, params, convo_id) + connect_task = tunnel_data.get("connect_task", None) + if connect_task: + connect_task.cancel() + return class PAMTunnelTailCommand(Command): pam_cmd_parser = argparse.ArgumentParser(prog='dr-tunnel-tail-command') - pam_cmd_parser.add_argument('--conversation-id', '-c', required=False, dest='convo_id', action='store', - help='The connection ID of the Tunnel to tail logs') + pam_cmd_parser.add_argument('uid', type= str, action='store', help='The Tunnel UID') def get_parser(self): return PAMTunnelTailCommand.pam_cmd_parser def execute(self, params, **kwargs): - convo_id = kwargs.get('convo_id') + convo_id = kwargs.get('uid') + if not convo_id: + raise CommandError('tunnel tail', '"uid" argument is required') log_queue = params.tunnel_threads_queue.get(convo_id) - # TODO make this run in a new thread? logger_level = logging.getLogger().getEffectiveLevel() logging.getLogger('aiortc').setLevel(logging.DEBUG) @@ -1761,8 +1722,6 @@ class PAMTunnelStartCommand(Command): type=int, default=0, help='The port number on which the server will be listening for incoming connections. ' 'If not set, random open port on the machine will be used.') - pam_cmd_parser.add_argument('--listener-name', '-l', required=False, dest='listener_name', - action='store', default="Keeper PAM Tunnel", help='The name of the listener.') def get_parser(self): return PAMTunnelStartCommand.pam_cmd_parser @@ -1792,25 +1751,7 @@ def setup_logging(self, convo_id, log_queue, logging_level): logger.debug("Logging setup complete.") return logger - def tunnel_cleanup(self, params, convo_id): - tunnel_data = params.tunnel_threads.get(convo_id, None) - if not tunnel_data: - return - - for task_name in ["print", "entrance"]: - task = tunnel_data.get(task_name) - if task: - task.cancel() - print(f"Cancelled {task_name} for {convo_id}") - - del params.tunnel_threads[convo_id] - print(f"Cleaned up data for {convo_id}") - - if convo_id in params.tunnel_threads_queue: - del params.tunnel_threads_queue[convo_id] - print(f"{bcolors.OKBLUE}{convo_id} Queue cleaned up{bcolors.ENDC}") - - async def connect(self, params, record_uid, convo_id, gateway_uid, host, port, listener_name, + async def connect(self, params, record_uid, convo_id, gateway_uid, host, port, log_queue, gateway_public_key_bytes, client_private_key): # Setup custom logging to put logs into log_queue @@ -1847,11 +1788,17 @@ async def connect(self, params, record_uid, convo_id, gateway_uid, host, port, l # Set up the pc print_ready_event = asyncio.Event() - pc = WebRTCConnection(endpoint_name=listener_name, print_ready_event=print_ready_event, + pc = WebRTCConnection(endpoint_name=convo_id, print_ready_event=print_ready_event, username=response.username, password=response.password, logger=logger) # make webRTC sdp offer - offer = await pc.make_offer() + try: + offer = await pc.make_offer() + except socket.gaierror: + print(f"{bcolors.WARNING}Please upgrade Commander to the latest version to use this feature...{bcolors.ENDC}") + return + except Exception as e: + raise CommandError('tunnel start', f'Error making WebRTC offer: {e}') encrypted_offer = tunnel_encrypt(symmetric_key, offer) logger.debug("-->. SEND START MESSAGE OVER REST TO GATEWAY") @@ -1871,7 +1818,7 @@ async def connect(self, params, record_uid, convo_id, gateway_uid, host, port, l # TODO create objects for WebRTC inputs router_response = router_send_action_to_gateway( params=params, - gateway_action=GatewayActionWebRTCSession(inputs={'listenerName': listener_name, "recordUid": record_uid, + gateway_action=GatewayActionWebRTCSession(inputs={'listenerName': convo_id, "recordUid": record_uid, "offer": encrypted_offer, 'kind': 'start', 'conversationType': 'tunnel'}), message_type=pam_pb2.CMT_GENERAL, @@ -1900,18 +1847,16 @@ async def connect(self, params, record_uid, convo_id, gateway_uid, host, port, l logger.debug("starting private tunnel") - private_tunnel = TunnelEntrance(host=host, port=port, endpoint_name=listener_name, pc=pc, + private_tunnel = TunnelEntrance(host=host, port=port, endpoint_name=convo_id, pc=pc, print_ready_event=print_ready_event, logger=logger) t1 = asyncio.create_task(private_tunnel.start_server()) - params.tunnel_threads[convo_id].update({"print": t1, "entrance": private_tunnel}) + params.tunnel_threads[convo_id].update({"server": t1, "entrance": private_tunnel}) logger.debug("--> START LISTENING FOR MESSAGES FROM GATEWAY --------") await asyncio.gather(t1, private_tunnel.reader_task) - self.tunnel_cleanup(params, convo_id) - - def pre_connect(self, params, record_uid, convo_id, gateway_uid, host, port, listener_name, + def pre_connect(self, params, record_uid, convo_id, gateway_uid, host, port, gateway_public_key_bytes, client_private_key): loop = None try: @@ -1920,7 +1865,8 @@ def pre_connect(self, params, record_uid, convo_id, gateway_uid, host, port, lis asyncio.set_event_loop(loop) output_queue = queue.Queue(maxsize=500) params.tunnel_threads_queue[convo_id] = output_queue - loop.run_until_complete( + # Create a Task from the coroutine + connect_task = loop.create_task( self.connect( params=params, record_uid=record_uid, @@ -1928,22 +1874,50 @@ def pre_connect(self, params, record_uid, convo_id, gateway_uid, host, port, lis gateway_uid=gateway_uid, host=host, port=port, - listener_name=listener_name, log_queue=output_queue, gateway_public_key_bytes=gateway_public_key_bytes, client_private_key=client_private_key ) ) - except asyncio.CancelledError: - print(f"{bcolors.WARNING}Tasks for connection {convo_id} were cancelled.{bcolors.ENDC}") + params.tunnel_threads[convo_id].update({"connect_task": connect_task}) + try: + # Run the task until it is complete + loop.run_until_complete(connect_task) + except asyncio.CancelledError: + pass except SocketNotConnectedException as es: print(f"{bcolors.FAIL}An exception occurred in pre_connect for connection {convo_id}: {es}{bcolors.ENDC}") except Exception as e: print(f"{bcolors.FAIL}An exception occurred in pre_connect for connection {convo_id}: {e}{bcolors.ENDC}") finally: if loop: - loop.call_soon_threadsafe(self.tunnel_cleanup, params, convo_id) - print(f"{bcolors.OKBLUE}Cleanup called for connection {convo_id}.{bcolors.ENDC}") + try: + tunnel_data = params.tunnel_threads.get(convo_id, None) + if not tunnel_data: + logging.debug(f"{bcolors.WARNING}No tunnel data found for {convo_id}{bcolors.ENDC}") + return + + if convo_id in params.tunnel_threads_queue: + del params.tunnel_threads_queue[convo_id] + + entrance = tunnel_data.get("entrance", None) + if entrance: + loop.run_until_complete(entrance.stop_server()) + + del params.tunnel_threads[convo_id] + logging.debug(f"Cleaned up data for {convo_id}") + + try: + for task in asyncio.all_tasks(loop): + task.cancel() + loop.stop() + loop.close() + logging.debug(f"{convo_id} Loop cleaned up") + except Exception as e: + logging.debug(f"{bcolors.WARNING}Exception while stopping event loop: {e}{bcolors.ENDC}") + except Exception as e: + print(f"{bcolors.FAIL}An exception occurred in pre_connect for connection {convo_id}: {e}{bcolors.ENDC}") + print(f"{bcolors.OKBLUE}Tunnel {convo_id} closed.{bcolors.ENDC}") def execute(self, params, **kwargs): version = [3, 11, 0] @@ -1963,7 +1937,6 @@ def execute(self, params, **kwargs): gateway_uid = kwargs.get('gateway') host = kwargs.get('host') port = kwargs.get('port') - listener_name = kwargs.get('listener_name') gateway_public_key_bytes = retrieve_gateway_public_key(gateway_uid, params, api, utils) @@ -1999,10 +1972,11 @@ def execute(self, params, **kwargs): client_private_key_value = client_private_key.get_default_value(str) t = threading.Thread(target=self.pre_connect, args=(params, record_uid, convo_id, gateway_uid, host, port, - listener_name, gateway_public_key_bytes, - client_private_key_value) + gateway_public_key_bytes, client_private_key_value) ) + + # Setting the thread as a daemon thread + t.daemon = True t.start() params.tunnel_threads[convo_id].update({"convo_id": convo_id, "thread": t, "host": host, "port": port, - "name": listener_name, "started": datetime.now(), - "record_uid": record_uid}) + "started": datetime.now(), "record_uid": record_uid}) diff --git a/keepercommander/commands/tunnel/port_forward/endpoint.py b/keepercommander/commands/tunnel/port_forward/endpoint.py index 2d20fc873..6be216c24 100644 --- a/keepercommander/commands/tunnel/port_forward/endpoint.py +++ b/keepercommander/commands/tunnel/port_forward/endpoint.py @@ -6,7 +6,7 @@ import socket import string import time -from typing import Optional, Dict, Tuple, Any, List, Union, Sequence +from typing import Optional, Dict, Tuple from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration, RTCIceServer from cryptography.hazmat.primitives import hashes @@ -129,7 +129,8 @@ def tunnel_decrypt(symmetric_key: AESGCM, encrypted_data: str): data = data_bytes[NONCE_LENGTH:] try: return symmetric_key.decrypt(nonce, data, None) - except: + except Exception as e: + logging.error(f'Error decrypting data: {e}') return None @@ -143,33 +144,26 @@ def __init__(self, endpoint_name: Optional[str] = "Keeper PAM Tunnel", self.print_ready_event = print_ready_event # Define the STUN server URL - # krouter_server_url = 'https://connect.' + params.server # https://connect.dev.keepersecurity.com - # stun_url = "stun:stun.l.google.com:19302" + # To use Google's STUN server + ''' + stun_url = "stun:stun.l.google.com:19302" + # Create an RTCIceServer instance for the TURN server + turn_server = RTCIceServer(urls=turn_url) + config = RTCConfiguration(iceServers=[stun_server]) + ''' + # Using Keeper's STUN and TURN servers + # relay_url = 'relay.' + params.server + '3478' # relay.dev.keepersecurity.com:3478 relay_url = 'relay.keeperpamlab.com' stun_url = f"stun:{relay_url}:3478" - # Create an RTCIceServer instance for the STUN server stun_server = RTCIceServer(urls=stun_url) - # Define the TURN server URL and credentials turn_url = f"turn:{relay_url}?transport=udp" - - ''' - # Define TURN server credentials - username = username - password = password - # Create an RTCIceServer instance for the TURN server with credentials - ''' turn_server = RTCIceServer(urls=turn_url, username=username, credential=password) - - # Create an RTCIceServer instance for the TURN server - # turn_server = RTCIceServer(urls=turn_url) - # Create a new RTCConfiguration with both STUN and TURN servers config = RTCConfiguration(iceServers=[stun_server, turn_server]) - # config = RTCConfiguration(iceServers=[stun_server]) self._pc = RTCPeerConnection(config) self.setup_data_channel() @@ -237,39 +231,35 @@ def send_message(self, message): self.logger.error(f'Endpoint {self.endpoint_name}: Data channel is not open.') async def close_connection(self): + if self.closed: + return # Close the data channel if it's open if self.data_channel and self.data_channel.readyState == "open": - self.data_channel.close() - self.logger.error(f'Endpoint {self.endpoint_name}: Data channel closed') + try: + self.data_channel.close() + self.data_channel = None + self.logger.error(f'Endpoint {self.endpoint_name}: Data channel closed') + except Exception as e: + self.logger.error(f'Endpoint {self.endpoint_name}: Error closing data channel: {e}') # Close the peer connection if self._pc: await self._pc.close() - print("Peer connection closed") + self.logger.error(f'Endpoint {self.endpoint_name}: "Peer connection closed') # Clear the asyncio queue - while not self.web_rtc_queue.empty(): - self.web_rtc_queue.get_nowait() - self.web_rtc_queue = None + if self.web_rtc_queue: + while not self.web_rtc_queue.empty(): + self.web_rtc_queue.get_nowait() + self.web_rtc_queue = None # Reset instance variables self.data_channel = None - self.pc = None + self._pc = None # Set the closed flag self.closed = True - # Close the asyncio event loop if necessary - loop = asyncio.get_event_loop() - if loop.is_running(): - try: - for task in asyncio.all_tasks(loop): - task.cancel() - await loop.create_task(asyncio.sleep(.1)) - loop.stop() - except Exception as e: - self.logger.error(f'Endpoint {self.endpoint_name}: Error stopping loop: {e}') - """ This class is used to set up the tunnel entrance. This is used for the signaling phase and control messages. @@ -338,6 +328,7 @@ def __init__(self, print_ready_event, # type: asyncio.Event logger = None, # type: logging.Logger ): # type: (...) -> None + self.closing = False self.ping_time = None self.to_local_task = None self._ping_attempt = 0 @@ -354,6 +345,7 @@ def __init__(self, self.kill_server_event = asyncio.Event() self.pc = pc self.print_ready_event = print_ready_event + self.server_task = None async def send_to_web_rtc(self, data): if self.pc.is_data_channel_open(): @@ -495,7 +487,7 @@ async def forward_data_to_local(self): except asyncio.TimeoutError as et: if self._ping_attempt > 3: if self.is_connected: - await self.stop_server() + self.kill_server_event.set() raise et self.logger.debug(f'Endpoint {self.endpoint_name}: Tunnel reader timed out') self.logger.debug(f'Endpoint {self.endpoint_name}: Send ping request') @@ -528,7 +520,7 @@ async def forward_data_to_local(self): finally: self.logger.debug(f"Endpoint {self.endpoint_name}: Closing tunnel") - await self.stop_server() + self.kill_server_event.set() async def start_reader(self): # type: () -> None """ @@ -543,6 +535,8 @@ async def start_reader(self): # type: () -> None self.ping_time = time.perf_counter() await self.send_control_message(ControlMessage.Ping) self.logger.debug(f"Endpoint {self.endpoint_name}: Sent ping message to WebRTC connection") + except asyncio.CancelledError: + pass except Exception as e: self.logger.error(f"Endpoint {self.endpoint_name}: Error while establishing WebRTC connection: {e}") failed = True @@ -550,7 +544,7 @@ async def start_reader(self): # type: () -> None if failed: for connection_no in list(self.connections): await self.close_connection(connection_no) - await self.stop_server() + self.kill_server_event.set() self.is_connected = False return @@ -616,8 +610,17 @@ async def start_server(self): # type: (...) -> None """ This server is used to listen for client connections to the local port. """ + if self.server: + return try: self._port = find_open_port(tried_ports=[], preferred_port=self._port, host=self.host) + except asyncio.CancelledError: + self.logger.info(f"Endpoint {self.endpoint_name}: Server has been cancelled. Cleaning up...") + # Perform necessary cleanup here + self.server.close() # Close the server + await self.server.wait_closed() # Wait until the server is closed + return + except Exception as e: self.logger.error(f"Endpoint {self.endpoint_name}: Error while finding open port: {e}") await self.print_not_ready() @@ -632,7 +635,7 @@ async def start_server(self): # type: (...) -> None self.server = await asyncio.start_server(self.handle_connection, family=socket.AF_INET, host=self.host, port=self._port) async with self.server: - await asyncio.create_task(self.print_ready(self.host, self._port, self.print_ready_event)) + self.server_task = await asyncio.create_task(self.print_ready(self.host, self._port, self.print_ready_event)) await self.server.serve_forever() except ConnectionRefusedError as er: self.logger.error(f"Endpoint {self.endpoint_name}: Connection Refused while starting server: {er}") @@ -648,14 +651,13 @@ async def start_server(self): # type: (...) -> None return async def print_not_ready(self): - - print(f'{bcolors.FAIL}+---------------------------------------------------------{bcolors.ENDC}') + print(f'\n{bcolors.FAIL}+---------------------------------------------------------{bcolors.ENDC}') print(f'{bcolors.FAIL}| Endpoint {self.endpoint_name}{bcolors.ENDC} failed to start') print(f'{bcolors.FAIL}+---------------------------------------------------------{bcolors.ENDC}') await self.send_control_message(ControlMessage.CloseConnection, int_to_bytes(0)) for c in list(self.connections): await self.close_connection(c) - await self.stop_server() + self.kill_server_event.set() async def print_ready(self, host, # type: str port, # type: int @@ -678,33 +680,40 @@ async def print_ready(self, host, # type: str # Sleep a little bit to print out last await asyncio.sleep(.5) host = host + ":" if host else '' - print(f'{bcolors.OKGREEN}+---------------------------------------------------------------{bcolors.ENDC}') + print(f'\n{bcolors.OKGREEN}+---------------------------------------------------------------{bcolors.ENDC}') print( f'{bcolors.OKGREEN}| Endpoint {bcolors.ENDC}{bcolors.OKBLUE}{self.endpoint_name}{bcolors.ENDC}' f'{bcolors.OKGREEN}: Listening on port: {bcolors.ENDC}' f'{bcolors.BOLD}{bcolors.OKBLUE}{host}{port}{bcolors.ENDC}') print(f'{bcolors.OKGREEN}+---------------------------------------------------------------{bcolors.ENDC}') print(f'{bcolors.OKGREEN}View all open tunnels : {bcolors.ENDC}{bcolors.OKBLUE}pam tunnel list{bcolors.ENDC}') - print(f'{bcolors.OKGREEN}Tail logs on open tunnel: {bcolors.ENDC}{bcolors.OKBLUE}pam tunnel tail -c="[TUNNELID]"{bcolors.ENDC}') - print(f'{bcolors.OKGREEN}Stop a tunnel : {bcolors.ENDC}{bcolors.OKBLUE}pam tunnel stop -c="[TUNNELID]"{bcolors.ENDC}') - + print(f'{bcolors.OKGREEN}Tail logs on open tunnel: {bcolors.ENDC}' + f'{bcolors.OKBLUE}pam tunnel tail ' + + (f'--' if self.endpoint_name[0] == '-' else '') + + f'{self.endpoint_name}{bcolors.ENDC}') + print(f'{bcolors.OKGREEN}Stop a tunnel : {bcolors.ENDC}' + f'{bcolors.OKBLUE}pam tunnel stop ' + + (f'--' if self.endpoint_name[0] == '-' else '') + + f'{self.endpoint_name}{bcolors.ENDC}') async def stop_server(self): + if self.closing: + return try: await self.send_control_message(ControlMessage.CloseConnection, int_to_bytes(0)) - if self.server: - self.server.close() - await self.server.wait_closed() - self.logger.debug(f"Endpoint {self.endpoint_name}: Local server stopped") - self.server = None - if self.reader_task: - self.reader_task.cancel() - for t in list(self.to_tunnel_tasks): - self.to_tunnel_tasks[t].cancel() - if self.to_local_task: - self.to_local_task.cancel() + except Exception as ex: + self.logger.warning(f'Endpoint {self.endpoint_name}: hit exception sending Close connection {ex}') + + self.kill_server_event.set() + try: + # close aiortc data channel + await self.pc.close_connection() + except Exception as ex: + self.logger.warning(f'Endpoint {self.endpoint_name}: hit exception closing data channel {ex}') + finally: - self.kill_server_event.set() + self.closing = True + self.logger.debug(f"Endpoint {self.endpoint_name}: Tunnel stopped") async def close_connection(self, connection_no): try: