diff --git a/.gitignore b/.gitignore index 4392703fd..b21811d7b 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ keeper.txt *.csv Makefile *.db +dr-logs \ No newline at end of file diff --git a/keepercommander/cli.py b/keepercommander/cli.py index c3f2ed60d..e0d573ef7 100644 --- a/keepercommander/cli.py +++ b/keepercommander/cli.py @@ -207,6 +207,8 @@ def is_msp(params_local): elif command_line == 'debug': is_debug = logging.getLogger().level <= logging.DEBUG logging.getLogger().setLevel((logging.WARNING if params.batch_mode else logging.INFO) if is_debug else logging.DEBUG) + logging.getLogger('aiortc').setLevel(logging.WARNING if is_debug or params.batch_mode else logging.DEBUG) + logging.getLogger('aioice').setLevel(logging.WARNING if is_debug or params.batch_mode else logging.DEBUG) logging.info('Debug %s', 'OFF' if is_debug else 'ON') else: diff --git a/keepercommander/commands/discoveryrotation.py b/keepercommander/commands/discoveryrotation.py index 597e4589d..34329086a 100644 --- a/keepercommander/commands/discoveryrotation.py +++ b/keepercommander/commands/discoveryrotation.py @@ -12,16 +12,14 @@ import asyncio import json import logging +import os.path import queue -import ssl import sys import threading -import os.path from datetime import datetime from typing import Dict, Optional, Any import requests -import websockets from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import ec @@ -38,16 +36,16 @@ from .pam.pam_dto import GatewayActionGatewayInfo, GatewayActionDiscoverInputs, GatewayActionDiscover, \ GatewayActionRotate, \ GatewayActionRotateInputs, GatewayAction, GatewayActionJobInfoInputs, \ - GatewayActionJobInfo, GatewayActionJobCancel + GatewayActionJobInfo, GatewayActionJobCancel, GatewayActionWebRTCSession from .pam.router_helper import router_send_action_to_gateway, print_router_response, \ router_get_connected_gateways, router_set_record_rotation_information, router_get_rotation_schedules, \ - get_router_url, router_send_message_to_gateway, get_router_ws_url, get_controller_cookie, request_cookie_jar_to_str + get_router_url, router_get_relay_access_creds from .record_edit import RecordEditMixin -from .tunnel.port_forward import tunnel_connected, endpoint -from .. import api, utils, vault_extensions, vault, record_management, attachment, record_facades, rest_api, crypto +from .tunnel.port_forward.endpoint import establish_symmetric_key, tunnel_encrypt, WebRTCConnection, tunnel_decrypt, \ + TunnelEntrance +from .. import api, utils, vault_extensions, vault, record_management, attachment, record_facades from ..display import bcolors from ..error import CommandError -from ..loginv3 import CommonHelperMethods from ..params import KeeperParams, LAST_RECORD_UID from ..proto import pam_pb2, router_pb2, record_pb2 from ..proto.APIRequest_pb2 import GetKsmPublicKeysRequest, GetKsmPublicKeysResponse @@ -1569,16 +1567,31 @@ class PAMTunnelListCommand(Command): pam_cmd_parser.add_argument('--conversation-id', '-c', required=False, dest='convo_id', action='store', help='The connection ID of the Tunnel to list') - pam_cmd_parser.add_argument('--verbose', '-v', required=False, dest='verbose', action='store_true', - help='Print out more details about the tunnel') - def get_parser(self): return PAMTunnelListCommand.pam_cmd_parser def execute(self, params, **kwargs): - def print_thread(thread, verbose): + def gather_tabel_row_data(thread): # {"thread": t, "host": host, "port": port, "name": listener_name, "started": datetime.now(), # "record_uid": record_uid} + if thread.get('entrance') is None: + tunnel_data = thread.get(convo_id, None) + if not tunnel_data: + return + for task_name in ["print", "connect"]: + task = tunnel_data.get(task_name) + if task: + task.cancel() + logging.debug(f"Cancelled {task_name} for {convo_id}") + + del params.tunnel_threads[convo_id] + logging.debug(f"Cleaned up data for {convo_id}") + + if convo_id in params.tunnel_threads_queue: + del params.tunnel_threads_queue[convo_id] + logging.debug(f"{bcolors.OKBLUE}{convo_id} Queue cleaned up{bcolors.ENDC}") + return [] + row = [] run_time = None hours = 0 minutes = 0 @@ -1587,51 +1600,44 @@ def print_thread(thread, verbose): run_time = datetime.now() - thread.get('started') hours, remainder = divmod(run_time.seconds, 3600) minutes, seconds = divmod(remainder, 60) - - text_line = f"{bcolors.OKGREEN}Tunnel {thread.get('name', '')} '{thread.get('convo_id', '')}'" - text_line += f", Host: {thread.get('host')}" if thread.get('host') else '' - text_line += f", Port: {thread.get('port')}" if thread.get('port') else '' - text_line += f", Record UID: {thread.get('record_uid')}" if thread.get('record_uid') else '' - text_line += f", Up time:" + # + # row.append(f"{thread.get('name', '')}") + row.append(f"{bcolors.OKBLUE}{thread.get('convo_id', '')}{bcolors.ENDC}") + row.append(f"{thread.get('host')}" if thread.get('host') else '') + row.append(f"{thread.get('entrance')._port}" if thread.get('entrance')._port else '') + row.append(f"{thread.get('record_uid')}" if thread.get('record_uid') else '') + text_line = "" if run_time: - text_line += f" days {run_time.days}" if run_time.days > 0 else '' - text_line += f" hours {hours}" if hours > 0 or run_time.days > 0 else '' - text_line += f" minutes {minutes}" - text_line += f" seconds {seconds}" - if verbose: - text_line += f", Public Key: {thread.get('entrance').gateway_public_key_bytes}" - text_line += f"{bcolors.ENDC}" - print(text_line) + if run_time.days == 1: + text_line += f"{run_time.days} day " + elif run_time.days > 1: + text_line += f"{run_time.days} days " + text_line += f"{hours} hr " if hours > 0 or run_time.days > 0 else '' + text_line += f"{minutes} min " + text_line += f"{seconds} sec" + row.append(text_line) + return row convo_id = kwargs.get('convo_id', None) - if kwargs.get('verbose'): - verbose = True - else: - verbose = False - record_uid = kwargs.get('record_uid', None) if not params.tunnel_threads: logging.warning(f"{bcolors.OKBLUE}No Tunnels running{bcolors.ENDC}") return + table = [] + headers = ['Tunnel ID', 'Host', 'Port', 'Record UID', 'Up Time'] + if convo_id: if convo_id in params.tunnel_threads: - print_thread(params.tunnel_threads[convo_id], verbose) + table.append(gather_tabel_row_data(params.tunnel_threads[convo_id])) else: print(f"{bcolors.FAIL}Tunnel {convo_id} not found{bcolors.ENDC}") return + else: + for i, convo_id in enumerate(params.tunnel_threads): + table.append(gather_tabel_row_data(params.tunnel_threads[convo_id])) - if record_uid: - # Print out all tunnels for record uid - for convo_id in params.tunnel_threads: - if params.tunnel_threads[convo_id].get("record_uid", "") == record_uid: - print_thread(params.tunnel_threads[convo_id], verbose) - return - print(f"{bcolors.FAIL}Tunnel for record {record_uid} not found{bcolors.ENDC}") - return - - for i, convo_id in enumerate(params.tunnel_threads): - print_thread(params.tunnel_threads[convo_id], verbose) + dump_report_data(table, headers, fmt='table', filename="", row_number=False, column_width=None) class PAMTunnelStopCommand(Command): @@ -1644,7 +1650,7 @@ def tunnel_cleanup(self, params, convo_id): if not tunnel_data: return - for task_name in ["ws_reader", "ws_writer", "connect"]: + for task_name in ["print", "connect"]: task = tunnel_data.get(task_name) if task: task.cancel() @@ -1677,7 +1683,7 @@ def execute(self, params, **kwargs): print(f"{bcolors.WARNING}Event loop is closed for conversation ID {convo_id}{bcolors.ENDC}") else: # Run the disconnect method in the event loop - loop.create_task(entrance.disconnect()) + loop.create_task(entrance.stop_server()) print(f"Disconnected entrance for {convo_id}") loop.call_soon_threadsafe(self.tunnel_cleanup, params, convo_id) @@ -1695,6 +1701,14 @@ def execute(self, params, **kwargs): convo_id = kwargs.get('convo_id') log_queue = params.tunnel_threads_queue.get(convo_id) + + # TODO make this run in a new thread? + logger_level = logging.getLogger().getEffectiveLevel() + + logging.getLogger('aiortc').setLevel(logging.DEBUG) + logging.getLogger('aioice').setLevel(logging.DEBUG) + logging.getLogger(convo_id).setLevel(logging.DEBUG) + if log_queue: try: while True: @@ -1707,6 +1721,10 @@ def execute(self, params, **kwargs): except Exception as e: print(f' {bcolors.WARNING}Exiting due to exception: {e}{bcolors.ENDC}') return + finally: + logging.getLogger('aiortc').setLevel(logger_level) + logging.getLogger('aioice').setLevel(logger_level) + logging.getLogger(convo_id).setLevel(logger_level) else: print(f' {bcolors.FAIL}Invalid conversation ID{bcolors.ENDC}') return @@ -1743,12 +1761,12 @@ class PAMTunnelStartCommand(Command): help='Used to list all tunnels for the given Gateway UID') pam_cmd_parser.add_argument('--uid', '-u', required=True, dest='record_uid', action='store', help='Filter list with UID of the PAM record that was used to create the tunnel') - pam_cmd_parser.add_argument('--host', '-o', required=False, dest='host', action='store', default=None, + pam_cmd_parser.add_argument('--host', '-o', required=False, dest='host', action='store', + default="127.0.0.1", help='The address on which the server will be accepting connections. It could be an ' 'IP address or a hostname. ' - 'Ex. if set to 127.0.0.1 then only connections from the same machine will be ' - 'accepted. By default, if nothing is set, which means that the server will accept ' - 'connections from any IP address.') + 'Ex. set to 127.0.0.1 as default so only connections from the same machine will be' + ' accepted.') pam_cmd_parser.add_argument('--port', '-p', required=False, dest='port', action='store', type=int, default=0, help='The port number on which the server will be listening for incoming connections. ' @@ -1789,7 +1807,7 @@ def tunnel_cleanup(self, params, convo_id): if not tunnel_data: return - for task_name in ["ws_reader", "ws_writer", "connect"]: + for task_name in ["print", "entrance"]: task = tunnel_data.get(task_name) if task: task.cancel() @@ -1808,72 +1826,98 @@ async def connect(self, params, record_uid, convo_id, gateway_uid, host, port, l # Setup custom logging to put logs into log_queue logger = self.setup_logging(convo_id, log_queue, logging.getLogger().getEffectiveLevel()) - transmission_key = utils.generate_aes_key() - server_public_key = rest_api.SERVER_PUBLIC_KEYS[params.rest_context.server_key_id] - - if params.rest_context.server_key_id < 7: - encrypted_transmission_key = crypto.encrypt_rsa(transmission_key, server_public_key) - else: - encrypted_transmission_key = crypto.encrypt_ec(transmission_key, server_public_key) - encrypted_session_token = crypto.encrypt_aes_v2(utils.base64_url_decode(params.session_token), transmission_key) - router_url = get_router_ws_url(params) - connection_url = (f'{router_url}/api/user/tunnel/{convo_id}' - f'?Authorization=KeeperUser%20' - f'{CommonHelperMethods.bytes_to_url_safe_str(encrypted_session_token)}' - f'&TransmissionKey={CommonHelperMethods.bytes_to_url_safe_str(encrypted_transmission_key)}') - - print("--> 1. CONNECT TO WS --------") - cookies = get_controller_cookie(params, gateway_uid) - cookie_str = request_cookie_jar_to_str(cookies) - - extra_headers = { - 'Cookie': cookie_str, - } - - entrance_ws = await websockets.connect(connection_url, ping_interval=10, extra_headers=extra_headers) - - print("--> 2. SEND START MESSAGE OVER REST TO GATEWAY") - - payload_dict = { - 'kind': 'start', - 'conversationType': 'tunnel', - 'value': {'listenerName': listener_name, "recordUid": record_uid} - } + print(f"{bcolors.HIGHINTENSITYWHITE}Establishing tunnel between Commander and Gateway. Please wait...{bcolors.ENDC}") + # get the keys + gateway_public_key = ec.EllipticCurvePublicKey.from_encoded_point(ec.SECP256R1(), gateway_public_key_bytes) + + """ +# Generate an EC private key +private_key = ec.generate_private_key( + ec.SECP256R1(), # Using P-256 curve + backend=default_backend() +) +# Serialize to PEM format +private_key_str = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption() +).decode('utf-8') + """ + + client_private_key_pem = serialization.load_pem_private_key( + client_private_key.encode(), + password=None, + backend=default_backend() + ) - payload_json = json.dumps(payload_dict, default=lambda o: o.__dict__, sort_keys=True, indent=4) - payload_bytes = payload_json.encode('utf-8') + # Get symmetric key + symmetric_key = establish_symmetric_key(client_private_key_pem, gateway_public_key) + + response = router_get_relay_access_creds(params=params) + + # Set up the pc + print_ready_event = asyncio.Event() + pc = WebRTCConnection(endpoint_name=listener_name, print_ready_event=print_ready_event, + username=response.username, password=response.password, logger=logger) + + # make webRTC sdp offer + offer = await pc.make_offer() + encrypted_offer = tunnel_encrypt(symmetric_key, offer) + logger.debug("-->. SEND START MESSAGE OVER REST TO GATEWAY") + + ''' + 'inputs': { + 'conversationType': ['tunnel', 'guacd'] + 'kind': ['start', 'disconnect'], + 'recordUid': record_uid, <-- this is the record UID of the PAM resource record + with Network information + 'listenerName': NAME OF LISTENER, <-- Used in logging (not required) + 'offer': encrypted_WebRTC_sdp_offer, <-- WebRTC SDP offer encrypted with symmetric key + 'allow_control': True, <-- only for guacd, False = readonly session (default True) + 'guacamole_client_id: guacamole_client_id, <-- only for guacd, Connect to an existing guacd session + 'userRecordUid': userRecordUid, <-- only for guacd, User record UID to connect for session + } + ''' + # TODO create objects for WebRTC inputs + router_response = router_send_action_to_gateway( + params=params, + gateway_action=GatewayActionWebRTCSession(inputs={'listenerName': listener_name, "recordUid": record_uid, + "offer": encrypted_offer, 'kind': 'start', + 'conversationType': 'tunnel'}), + message_type=pam_pb2.CMT_GENERAL, + is_streaming=False, + destination_gateway_uid_str=gateway_uid + ) + if not router_response: + return + gateway_response = router_response.get('response', {}) + if not gateway_response: + raise Exception(f"Error getting response from the Gateway: {router_response}") + try: + payload = json.loads(gateway_response.get('payload', None)) + if not payload: + raise Exception(f"Error getting payload from the Gateway response: {gateway_response}") + except Exception as e: + raise Exception(f"Error getting payload from the Gateway response: {e}") - rq_proto = router_pb2.RouterControllerMessage() - rq_proto.messageUid = url_safe_str_to_bytes(convo_id) - rq_proto.controllerUid = url_safe_str_to_bytes(gateway_uid) - rq_proto.messageType = pam_pb2.CMT_STREAM - rq_proto.streamResponse = False - rq_proto.payload = payload_bytes - rq_proto.timeout = 1500000 # Default time out how long the response from the Gateway should be + encrypted_answer = payload.get('data', None) + if not encrypted_answer: + raise Exception(f"Error getting data from the Gateway response payload: {payload}") - rs = router_send_message_to_gateway( - params, - transmission_key, - rq_proto, - gateway_uid, - cookies) + # decrypt the sdp answer + answer = tunnel_decrypt(symmetric_key, encrypted_answer) + await pc.accept_answer(answer) - if b'No socket connection exist to start streaming.' in rs.content: - raise SocketNotConnectedException(f"Commander didn't connect to right instance of the router to connect to " - f'the Controller: {rs.content}') + logger.debug("starting private tunnel") - tunnel = tunnel_connected.ConnectedTunnel(entrance_ws) - entrance = endpoint.TunnelProtocol(tunnel, endpoint_name=listener_name, logger=logger, - gateway_public_key_bytes=gateway_public_key_bytes, - client_private_key=client_private_key, host=host, port=port) + private_tunnel = TunnelEntrance(host=host, port=port, endpoint_name=listener_name, pc=pc, + print_ready_event=print_ready_event, logger=logger) - t1 = asyncio.create_task(tunnel.ws_reader()) - t2 = asyncio.create_task(tunnel.ws_writer()) - t3 = asyncio.create_task(entrance.connect()) - params.tunnel_threads[convo_id].update({"ws_reader": t1, "ws_writer": t2, "connect": t3, "entrance": entrance}) + t1 = asyncio.create_task(private_tunnel.start_server()) + params.tunnel_threads[convo_id].update({"print": t1, "entrance": private_tunnel}) - print("--> 3. START LISTENING FOR MESSAGES FROM GATEWAY --------") - await asyncio.gather(t1, t2, t3) + logger.debug("--> START LISTENING FOR MESSAGES FROM GATEWAY --------") + await asyncio.gather(t1, private_tunnel.reader_task) self.tunnel_cleanup(params, convo_id) @@ -1927,8 +1971,8 @@ def execute(self, params, **kwargs): convo_id = GatewayAction.generate_conversation_id() params.tunnel_threads[convo_id] = {} gateway_uid = kwargs.get('gateway') - host = kwargs.get('host', "127.0.0.1") - port = kwargs.get('port', 0) + host = kwargs.get('host') + port = kwargs.get('port') listener_name = kwargs.get('listener_name') gateway_public_key_bytes = retrieve_gateway_public_key(gateway_uid, params, api, utils) @@ -1957,7 +2001,7 @@ def execute(self, params, **kwargs): encryption_algorithm=serialization.NoEncryption() ).decode('utf-8') client_private_key = vault.TypedField.new_field('secret', - client_private_key_value,"Client Private Key") + client_private_key_value, "Client Private Key") record.custom.append(client_private_key) record_management.update_record(params, record) api.sync_down(params) diff --git a/keepercommander/commands/pam/pam_dto.py b/keepercommander/commands/pam/pam_dto.py index 13ec5f272..1814f4e3e 100644 --- a/keepercommander/commands/pam/pam_dto.py +++ b/keepercommander/commands/pam/pam_dto.py @@ -147,3 +147,12 @@ def __init__(self, conversation_id=None): def toJSON(self): return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + +class GatewayActionWebRTCSession(GatewayAction): + + def __init__(self, inputs: dict,conversation_id=None): + super().__init__('webrtc-session', inputs=inputs, conversation_id=conversation_id, is_scheduled=False) + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) diff --git a/keepercommander/commands/pam/router_helper.py b/keepercommander/commands/pam/router_helper.py index 5708f58d9..95b408ddc 100644 --- a/keepercommander/commands/pam/router_helper.py +++ b/keepercommander/commands/pam/router_helper.py @@ -88,8 +88,25 @@ def router_get_rotation_schedules(params, proto_request): return None +def router_get_relay_access_creds(params, expire_sec=None): + query_params = { + 'expire-sec': expire_sec + } -def _post_request_to_router(params, path, rq_proto=None, method='post', raw_without_status_check_response=False): + rs = _post_request_to_router(params, 'relay_access_creds', query_params=query_params) + + if type(rs) == bytes: + rac = pam_pb2.RelayAccessCreds() + rac.ParseFromString(rs) + if logging.getLogger().level <= logging.DEBUG: + js = google.protobuf.json_format.MessageToJson(rac) + logging.debug('>>> [GW RS] %s: %s', 'get_relay_access_creds', js) + + return rac + + return None + +def _post_request_to_router(params, path, rq_proto=None, method='post', raw_without_status_check_response=False, query_params=None): path = 'api/user/' + path krouter_host = get_router_url(params) @@ -113,6 +130,7 @@ def _post_request_to_router(params, path, rq_proto=None, method='post', raw_with try: rs = requests.request(method, krouter_host + "/" + path, + params=query_params, verify=VERIFY_SSL, headers={ 'TransmissionKey': bytes_to_base64(encrypted_transmission_key), diff --git a/keepercommander/commands/tunnel/port_forward/endpoint.py b/keepercommander/commands/tunnel/port_forward/endpoint.py index b54813498..2d20fc873 100644 --- a/keepercommander/commands/tunnel/port_forward/endpoint.py +++ b/keepercommander/commands/tunnel/port_forward/endpoint.py @@ -1,4 +1,3 @@ -import abc import asyncio import enum import logging @@ -9,30 +8,27 @@ import time from typing import Optional, Dict, Tuple, Any, List, Union, Sequence -from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration, RTCIceServer, RTCDataChannel -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import serialization, hashes +from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration, RTCIceServer +from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.kdf.hkdf import HKDF from cryptography.utils import int_to_bytes -from keeper_secrets_manager_core.utils import bytes_to_base64 +from keeper_secrets_manager_core.utils import bytes_to_base64, base64_to_bytes, bytes_to_string from keepercommander.display import bcolors -from .tunnel import ITunnel -PRIVATE_BUFFER_TRUNCATION_THRESHOLD = 1400 +logging.getLogger('aiortc').setLevel(logging.WARNING) +logging.getLogger('aioice').setLevel(logging.WARNING) + +BUFFER_TRUNCATION_THRESHOLD = 1400 READ_TIMEOUT = 10 -PUBLIC_READ_TIMEOUT = 60 -NON_PARED_READ_TIMEOUT = 5 -CONTROL_MESSAGE_NO_LENGTH = HMAC_MESSAGE_LENGTH = 2 +CONTROL_MESSAGE_NO_LENGTH = 2 CONNECTION_NO_LENGTH = DATA_LENGTH = 4 LATENCY_COUNT = 5 NONCE_LENGTH = 12 SYMMETRIC_KEY_LENGTH = RANDOM_LENGTH = 32 TERMINATOR = b';' -FORWARDER_BUFFER_TRUNCATION_THRESHOLD = (CONNECTION_NO_LENGTH + DATA_LENGTH + PRIVATE_BUFFER_TRUNCATION_THRESHOLD - + len(TERMINATOR)) class ConnectionNotFoundException(Exception): @@ -42,25 +38,9 @@ class ConnectionNotFoundException(Exception): class ControlMessage(enum.IntEnum): Ping = 1 Pong = 2 - ShareWebRTCDescription = 11 OpenConnection = 101 CloseConnection = 102 - - -def track_round_trip_latency(round_trip_latency, ping_time): # type: (List[Any], float) -> List[float] - time_now = time.perf_counter() - if len(round_trip_latency) >= LATENCY_COUNT: - round_trip_latency.pop(0) - # from the time the ping was sent to the time the pong was received - latency = time_now - ping_time - # Store in milliseconds - round_trip_latency.append(latency * 1000) - return round_trip_latency - - -def calc_round_trip_latency_average(round_trip_latency): # type: (Sequence[Union[float, int]]) -> float - return sum(round_trip_latency) / len(round_trip_latency) - # self.logger.debug(f'Endpoint {self.endpoint_name}: Private round trip latency average: {average_latency}') + ConnectionOpened = 103 def generate_random_bytes(pass_length=RANDOM_LENGTH): # type: (int) -> bytes @@ -119,457 +99,247 @@ def is_port_open(host: str, port: int) -> bool: return False -def find_server_public_key(raw_public_key: bytes): - try: - # Gateway public keys use the P-256 curve, so we need to use the same curve - return ec.EllipticCurvePublicKey.from_encoded_point(ec.SECP256R1(), raw_public_key) - except Exception as e: - logging.error(f"Error while loading public key: {e}") - return None - +def establish_symmetric_key(private_key, client_public_key): + # Perform ECDH key agreement + shared_secret = private_key.exchange(ec.ECDH(), client_public_key) -class TunnelProtocol(abc.ABC): - """ - This class is used to set up the public tunnel entrance. This is used for the signaling phase and control messages. + # Derive a symmetric key using HKDF + symmetric_key = HKDF( + algorithm=hashes.SHA256(), + length=SYMMETRIC_KEY_LENGTH, + salt=None, + info=b'encrypt network traffic', + ).derive(shared_secret) + return AESGCM(symmetric_key) - The public tunnel is encrypted using a Private key out of the record and the gateway's public key. - This tunnel is used to send control messages to the gateway: Ping, Pong, CloseConnection - and ShareWebRTCDescription. - There isn't a need for open connection because we send a start command in the discoveryrotation.py file. - There is one connection or channel, 0 for control messages. We have the ability to add more channels if needed +def tunnel_encrypt(symmetric_key: AESGCM, data: bytes): + """ Encrypts data using the symmetric key """ + nonce = os.urandom(NONCE_LENGTH) # 12-byte nonce for AES-GCM + d = nonce + symmetric_key.encrypt(nonce, data, None) + return bytes_to_base64(d) - The private tunnel uses WebRTC to connect to a peer on the gateway. - The flow is as follows: - The public tunnel Part I - 0. User enters a command to start a tunnel - 1. Commander sends a start command to the gateway through krouter - 2. Commander starts the public tunnel entrance and listens for messages from krouter - 2.5. The Gateway: starts the public tunnel exit, listens for messages from krouter - 3. There are ping and pong messages to keep the connection alive, and CloseConnection will close everything. - These are all encrypted using the private key and the gateway's public key. - - Signaling Phase - 4. Commander: makes a WebRTC peer, makes and offer, and sets its setLocalDescription. This offer gets send to the - gateway through the public tunnel using a ShareWebRTCDescription message. - 5. The Gateway gets that sets its setRemoteDescription and makes an answer. This answer gets sent back to - Commander in a ShareWebRTCDescription message. - 6. Commander sets its setRemoteDescription to the answer it got from the gateway and the two peers connect using - STUN and TURN servers. - - Setting up the private tunnel - 7. Commander sets up a local server that listens for connections to a - local port that the user has provided or a random port if none is provided. - 8. Commander sends a private ping message through the private tunnel entrance to the private tunnel exit - 9. The Gateway: receives the private ping message and sends a private pong message back establishing the - connection - 10. Commander waits for a client to connect to the local server. +def tunnel_decrypt(symmetric_key: AESGCM, encrypted_data: str): + """ Decrypts data using the symmetric key """ + data_bytes = base64_to_bytes(encrypted_data) + if len(data_bytes) <= NONCE_LENGTH: + return None + nonce = data_bytes[:NONCE_LENGTH] + data = data_bytes[NONCE_LENGTH:] + try: + return symmetric_key.decrypt(nonce, data, None) + except: + return None - User connects to the target host and port - 11. Client connects to the private tunnel's local server. - 12. Private Tunnel Entrance (In Commander) sends an open connection message to the WebRTC connection and listens - to the client forwarding on any data - 13. Private Tunnel Exit (On The Gateway): receives the open connection message and connects to the target - host and port sending any data back to the WebRTC connection - 14. The session goes on until the CloseConnection message is sent, or the outer tunnel is closed. - 15. The User can repeat steps 10-14 as many times as they want - User closes the public tunnel - 16. The User closes the public tunnel and everything is cleaned up, and we can start back at step 1 - """ - def __init__(self, tunnel, # type: ITunnel - endpoint_name = None, # type: Optional[str] - logger = None, # type: logging.Logger - gateway_public_key_bytes = None, # type: bytes - client_private_key = "", # type: str - host = "127.0.0.1", # type: str - port = None, # type: int - ): # type: (...) -> None - self.symmetric_key_aesgcm = None - self.hmac_message = generate_random_bytes() - self._round_trip_latency = [] - self.ping_time = None - self.tunnel = tunnel - self.endpoint_name = endpoint_name - self.logger = logger - self.target_port = port - self.target_host = host - self.private_tunnel = None - self._ping_attempt = 0 - self.private_tunnel_server = None - self.kill_server_event = asyncio.Event() - self.gateway_public_key_bytes = gateway_public_key_bytes - """ -# Generate an EC private key -private_key = ec.generate_private_key( - ec.SECP256R1(), # Using P-256 curve - backend=default_backend() -) -# Serialize to PEM format -private_key_str = private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption() -).decode('utf-8') - """ - - self.client_private_key_pem = serialization.load_pem_private_key( - client_private_key.encode(), - password=None, - backend=default_backend() - ) - self.server_public_key = None +class WebRTCConnection: + def __init__(self, endpoint_name: Optional[str] = "Keeper PAM Tunnel", + print_ready_event: Optional[asyncio.Event] = None, username: Optional[str] = None, + password: Optional[str] = None, logger: Optional[logging.Logger] = None): self.web_rtc_queue = asyncio.Queue() - self.tunnel_ready = asyncio.Event() + self.closed = False + self.data_channel = None + self.print_ready_event = print_ready_event # Define the STUN server URL + # krouter_server_url = 'https://connect.' + params.server # https://connect.dev.keepersecurity.com # stun_url = "stun:stun.l.google.com:19302" - relay_url = 'relaybeta.keeperpamlab.com' + relay_url = 'relay.keeperpamlab.com' stun_url = f"stun:{relay_url}:3478" # Create an RTCIceServer instance for the STUN server stun_server = RTCIceServer(urls=stun_url) # Define the TURN server URL and credentials - # The transport parameter in the TURN server configuration (such as UDP, TCP, TLS, or DTLS) specifies the - # protocol used for the connection between the peers and the TURN server. - # Peer-to-peer communication through the TURN server is still DTLS encrypted end to end. - turn_url = f"turn:{relay_url}:3478?transport=udp" + turn_url = f"turn:{relay_url}?transport=udp" + ''' # Define TURN server credentials - username = "your_username" - password = "your_password" - - # Define the TURN server URL with credentials - + username = username + password = password + # Create an RTCIceServer instance for the TURN server with credentials - turn_server = RTCIceServer(urls=turn_url, username=username, credential=password) ''' + turn_server = RTCIceServer(urls=turn_url, username=username, credential=password) # Create an RTCIceServer instance for the TURN server - turn_server = RTCIceServer(urls=turn_url) + # turn_server = RTCIceServer(urls=turn_url) # Create a new RTCConfiguration with both STUN and TURN servers config = RTCConfiguration(iceServers=[stun_server, turn_server]) + # config = RTCConfiguration(iceServers=[stun_server]) - self.pc = RTCPeerConnection(config) - self.data_channel = self.pc.createDataChannel("chat", ordered=True, maxPacketLifeTime=None, maxRetransmits=None) - - def on_data_channel_open(): - self.logger.debug("Data channel opened") - data = b'' - buffer = int.to_bytes(0, CONNECTION_NO_LENGTH, byteorder='big') - length = CONTROL_MESSAGE_NO_LENGTH + len(data) - buffer += int.to_bytes(length, DATA_LENGTH, byteorder='big') - buffer += int.to_bytes(ControlMessage.Ping, CONTROL_MESSAGE_NO_LENGTH, byteorder='big') - buffer += data + TERMINATOR - self.data_channel.send(buffer) - - def on_data_channel_message(message): - self.web_rtc_queue.put_nowait(message) - - def on_data_channel(channel): - channel.on("open", on_data_channel_open) - channel.on("message", on_data_channel_message) - self.tunnel_ready.set() - - def on_connection_state_change(): - self.logger.debug(f"Connection State has changed: {self.pc.connectionState}") - - if self.pc.connectionState == "connected": - # Connection is established, you can now send/receive data - pass - elif self.pc.connectionState == "connecting": - pass - elif self.pc.connectionState in ["disconnected", "failed", "closed"]: - # Handle disconnection or failure here - pass - - self.pc.on("datachannel", on_data_channel) - - self.pc.on("connectionstatechange", on_connection_state_change) - - def establish_symmetric_key(self): - # Step 3: Perform ECDH key agreement - shared_secret = self.client_private_key_pem.exchange(ec.ECDH(), self.server_public_key) - - # Step 4: Derive a symmetric key using HKDF - symmetric_key = HKDF( - algorithm=hashes.SHA256(), - length=SYMMETRIC_KEY_LENGTH, - salt=None, - info=b'encrypt network traffic', - ).derive(shared_secret) - self.symmetric_key_aesgcm = AESGCM(symmetric_key) - - def tunnel_encrypt(self, data: bytes): - """ Encrypts data using the symmetric key """ - nonce = os.urandom(NONCE_LENGTH) # 12-byte nonce for AES-GCM - d = nonce + self.symmetric_key_aesgcm.encrypt(nonce, data, None) - return bytes_to_base64(d) - - def tunnel_decrypt(self, encrypted_data: bytes): - """ Decrypts data using the symmetric key """ - if len(encrypted_data) <= NONCE_LENGTH: - self.logger.error(f'Endpoint {self.endpoint_name}: Invalid encrypted data') - return None - nonce = encrypted_data[:NONCE_LENGTH] - data = encrypted_data[NONCE_LENGTH:] - try: - return self.symmetric_key_aesgcm.decrypt(nonce, data, None) - except Exception as e: - self.logger.error(f'Endpoint {self.endpoint_name}: Failed to decrypt data: {e}') - return None - - async def connect(self): - if not self.tunnel.is_connected: - await self.tunnel.connect() - - self.server_public_key = find_server_public_key(self.gateway_public_key_bytes) + self._pc = RTCPeerConnection(config) + self.setup_data_channel() + self.setup_event_handlers() + self.logger = logger + self.endpoint_name = endpoint_name - if self.server_public_key is None: - self.logger.debug(f'Endpoint {self.endpoint_name}: Invalid public key') - await self.disconnect() - raise Exception('Invalid public key') + async def make_offer(self): + offer = await self._pc.createOffer() + await self._pc.setLocalDescription(offer) + return self._pc.localDescription.sdp.encode('utf-8') - self.establish_symmetric_key() + async def accept_answer(self, answer): + if isinstance(answer, bytes): + answer = bytes_to_string(answer) + await self._pc.setRemoteDescription(RTCSessionDescription(answer, "answer")) - t1 = asyncio.create_task(self.start_tunnel_reader()) - tasks = [t1] - self.logger.debug(f'Endpoint {self.endpoint_name}: Private tunnel started, sending HMAC to gateway') + def setup_data_channel(self): + self.data_channel = self._pc.createDataChannel("control", ordered=True) - # Create an offer - offer = await self.pc.createOffer() - await self.pc.setLocalDescription(offer) + def setup_event_handlers(self): + self.data_channel.on("open", self.on_data_channel_open) + self.data_channel.on("message", self.on_data_channel_message) + self._pc.on("datachannel", self.on_data_channel) + self._pc.on("connectionstatechange", self.on_connection_state_change) - # Send the offer to the server - await self.send_control_message(ControlMessage.ShareWebRTCDescription, - self.pc.localDescription.sdp.encode('utf-8')) + def on_data_channel_open(self): + self.print_ready_event.set() - await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + self.logger.debug("Data channel opened") + data = b'' + buffer = int.to_bytes(0, CONNECTION_NO_LENGTH, byteorder='big') + length = CONTROL_MESSAGE_NO_LENGTH + len(data) + buffer += int.to_bytes(length, DATA_LENGTH, byteorder='big') + buffer += int.to_bytes(ControlMessage.Ping, CONTROL_MESSAGE_NO_LENGTH, byteorder='big') + buffer += data + TERMINATOR + self.data_channel.send(buffer) + self.logger.error(f'Endpoint {self.endpoint_name}: Data channel opened') - await self.disconnect() + def on_data_channel_message(self, message): + self.web_rtc_queue.put_nowait(message) - async def disconnect(self): - try: - await self.send_control_message(ControlMessage.CloseConnection) - finally: - tasks = [] - try: - self.tunnel.disconnect() - except Exception as ex: - self.logger.warning(f'Endpoint {self.endpoint_name}: hit exception closing tunnel {ex}') - try: - if self.private_tunnel: - tasks.append(self.private_tunnel.stop_server()) - except Exception as ex: - self.logger.warning(f'Endpoint {self.endpoint_name}: hit exception closing private tunnel {ex}') - try: - if len(tasks) > 0: - await asyncio.gather(*tasks) - except Exception as ex: - self.logger.warning(f'Endpoint {self.endpoint_name}: hit exception gathering tasks {ex}') + def on_data_channel(self, channel): + channel.on("open", self.on_data_channel_open) + channel.on("message", self.on_data_channel_message) - self.kill_server_event.set() + def on_connection_state_change(self): + self.logger.debug(f'Endpoint {self.endpoint_name}: Connection State has changed: {self._pc.connectionState}') + if self._pc.connectionState == "connected": + # Connection is established, you can now send/receive data + pass + elif self._pc.connectionState in ["disconnected", "failed", "closed"]: + # Handle disconnection or failure here + pass - async def start_tunnel_reader(self) -> None: - if not self.tunnel.is_connected: - self.logger.warning(f'Endpoint {self.endpoint_name}: Tunnel reader: not connected') - return + def is_data_channel_open(self): + return (self.data_channel is not None and self.data_channel.readyState == "open" + and self._pc.connectionState == "connected") - self._ping_attempt = 0 - buff = b'' - ''' - Data structure of a packet - +----------------------+----------------+---------------------------------+-------------+ - | Connection Number | Data Length | Data | Terminator | - | (4 bytes) | (4 bytes) | (variable length) | (variable) | - +----------------------+----------------+---------------------------------+-------------+ - | | | | | - | 0 (for | | +-------------------+---------+ | | - | control message) | | | Control Message | Control | | | - | | | | Number (2 bytes) | Data | | | - | | | +-------------------+---------+ | | - | | | | | - +----------------------+----------------+---------------------------------+-------------+ - - Breakdown of Each Part: - Connection Number (4 bytes): - This is the first part of the message. - It's used to identify which type of message is being sent. - In your code, a connection number of 0 signifies a control message. - Data Length (4 bytes): - This follows the connection number. - It specifies the length of the actual data in bytes that follows this field. - Data (variable length): - The content of the message. - Its length is determined by the "Data Length" field. - For control messages, it further contains a control message number and the actual control data. - Terminator (variable): - Marks the end of the message. - Check for this terminator to validate the end of a message. - If the terminator is not found or is incorrect, it indicates a message boundary error. - ''' - while self.tunnel.is_connected and not self.kill_server_event.is_set(): - if len(buff) >= CONNECTION_NO_LENGTH + DATA_LENGTH: - # At this stage we have two connections. 0 is for control messages and 1 is for data - connection_no = int.from_bytes(buff[:CONNECTION_NO_LENGTH], byteorder='big') - length = int.from_bytes(buff[CONNECTION_NO_LENGTH:CONNECTION_NO_LENGTH + DATA_LENGTH], - byteorder='big') - # Wait for the rest of the data if it hasn't arrived yet - if len(buff) >= CONNECTION_NO_LENGTH + DATA_LENGTH + length + len(TERMINATOR): - if buff[CONNECTION_NO_LENGTH + DATA_LENGTH + length: - CONNECTION_NO_LENGTH + DATA_LENGTH + length + len(TERMINATOR)] != TERMINATOR: - self.logger.warning(f'Endpoint {self.endpoint_name}: Invalid terminator') - # if we don't have a valid terminator then we don't know where the message ends or begins - break - s_data = buff[CONNECTION_NO_LENGTH + DATA_LENGTH: CONNECTION_NO_LENGTH + DATA_LENGTH + length] - buff = buff[CONNECTION_NO_LENGTH + DATA_LENGTH + length + len(TERMINATOR):] - if connection_no == 0: - # This is a control message - if len(s_data) >= CONTROL_MESSAGE_NO_LENGTH: - message_no = int.from_bytes(s_data[:CONTROL_MESSAGE_NO_LENGTH], byteorder='big') - s_data = s_data[CONTROL_MESSAGE_NO_LENGTH:] - await self.process_control_message(ControlMessage(message_no), s_data) - else: - # We have the ability to add more channels in the future if needed - self.logger.error(f"Endpoint {self.endpoint_name}: Invalid Public channel {connection_no}") - break - else: - self.logger.debug(f"Endpoint {self.endpoint_name}: Buffer is too short {len(buff)} need " - f"{CONNECTION_NO_LENGTH + DATA_LENGTH + length + len(TERMINATOR)}") - # Yield control back to the event loop for other tasks to execute - await asyncio.sleep(0) - else: - # Yield control back to the event loop for other tasks to execute - await asyncio.sleep(0) + # Example usage of state check in a method + def send_message(self, message): + if self.is_data_channel_open(): + self.data_channel.send(message) + else: + self.logger.error(f'Endpoint {self.endpoint_name}: Data channel is not open.') + + async def close_connection(self): + # Close the data channel if it's open + if self.data_channel and self.data_channel.readyState == "open": + self.data_channel.close() + self.logger.error(f'Endpoint {self.endpoint_name}: Data channel closed') + + # Close the peer connection + if self._pc: + await self._pc.close() + print("Peer connection closed") + + # Clear the asyncio queue + while not self.web_rtc_queue.empty(): + self.web_rtc_queue.get_nowait() + self.web_rtc_queue = None + + # Reset instance variables + self.data_channel = None + self.pc = None + + # Set the closed flag + self.closed = True + + # Close the asyncio event loop if necessary + loop = asyncio.get_event_loop() + if loop.is_running(): try: - buffer = await self.tunnel.read(PUBLIC_READ_TIMEOUT) - self.logger.debug(f"Endpoint {self.endpoint_name}: Received data from tunnel: {len(buffer)}") - if isinstance(buffer, bytes): - decrypt_data = self.tunnel_decrypt(buffer) - buff += decrypt_data - else: - # Yield control back to the event loop for other tasks to execute - await asyncio.sleep(0) - except asyncio.TimeoutError as e: - if self._ping_attempt > 3: - if self.tunnel.is_connected: - self.tunnel.disconnect() - raise e - self.logger.debug(f'Endpoint {self.endpoint_name}: Tunnel reader timed out') - self.logger.debug(f'Endpoint {self.endpoint_name}: Send ping request') - self.ping_time = time.perf_counter() - await self.send_control_message(ControlMessage.Ping) - self._ping_attempt += 1 - continue + for task in asyncio.all_tasks(loop): + task.cancel() + await loop.create_task(asyncio.sleep(.1)) + loop.stop() except Exception as e: - self.logger.warning(f'Endpoint {self.endpoint_name}: Failed to read from tunnel: {e}') - break - - if not self.tunnel.is_connected: - self.logger.info(f'Endpoint {self.endpoint_name}: Exiting public tunnel reader.') - break - - if not isinstance(buffer, bytes): - continue - self.logger.info(f'Endpoint {self.endpoint_name}: Exiting public tunnel reader.') - await self.disconnect() - - async def _send_to_tunnel(self, connection_no, data): # type: (int, bytes) -> None - buffer = int.to_bytes(connection_no, CONNECTION_NO_LENGTH, byteorder='big') - buffer += int.to_bytes(len(data), DATA_LENGTH, byteorder='big') - buffer += data + TERMINATOR - buffer = self.tunnel_encrypt(buffer) - self.logger.debug(f"Sending data to tunnel: {len(buffer)}") - await self.tunnel.write(buffer) - # Yield control back to the event loop for other tasks to execute - await asyncio.sleep(0) - - async def send_control_message(self, message_no: ControlMessage, data: Optional[bytes] = None) -> None: - buffer = int.to_bytes(message_no, CONTROL_MESSAGE_NO_LENGTH, byteorder='big') - buffer += data if data is not None else b'' - # Control messages are sent through connection 0 - await self._send_to_tunnel(0, buffer) - - async def process_control_message(self, message_no, data): # type: (ControlMessage, bytes) -> None - if message_no == ControlMessage.Ping: - self.logger.debug(f'Endpoint {self.endpoint_name}: Received ping request') - self.logger.debug(f'Endpoint {self.endpoint_name}: Send pong request') - await self.send_control_message(ControlMessage.Pong) - self._ping_attempt = 0 - elif message_no == ControlMessage.Pong: - self.logger.debug(f'Endpoint {self.endpoint_name}: Received pong request') - self._ping_attempt = 0 - if self.ping_time is not None: - self._round_trip_latency = track_round_trip_latency(self._round_trip_latency, self.ping_time) - self.logger.debug(f'Endpoint {self.endpoint_name}: Public round trip latency: ' - f'{self._round_trip_latency[-1]} ms') - self.ping_time = None - elif message_no == ControlMessage.ShareWebRTCDescription: - if len(data[CONNECTION_NO_LENGTH:]) > 0: - try: - await self.pc.setRemoteDescription(RTCSessionDescription(data[CONNECTION_NO_LENGTH:].decode(), - "answer")) + self.logger.error(f'Endpoint {self.endpoint_name}: Error stopping loop: {e}') - self.logger.debug("starting private tunnel") - - private_tunnel_event = asyncio.Event() + """ + This class is used to set up the tunnel entrance. This is used for the signaling phase and control messages. - self.private_tunnel = PrivateTunnelEntrance(private_tunnel_event=private_tunnel_event, - host=self.target_host, port=self.target_port, - endpoint_name=self.endpoint_name, - kill_server_event=self.kill_server_event, - data_channel=self.data_channel, - incoming_queue=self.web_rtc_queue, - logger=self.logger) + API calls offer/answer are encrypted using a shared secret derived from a key out of the record and the gateway's + own key. + This tunnel is used to send control messages to the gateway: Ping, Pong, CloseConnection + and ShareWebRTCDescription. + There isn't a need for open connection because we send a start command in the discoveryrotation.py file. + There is one connection or channel, 0 for control messages. We have the ability to add more channels if needed - private_tunnel_started = asyncio.Event() - self.private_tunnel_server = asyncio.create_task(self.private_tunnel.start_server( - private_tunnel_event, private_tunnel_started, self.tunnel_ready)) - await private_tunnel_started.wait() - serving = self.private_tunnel.server.is_serving() if self.private_tunnel.server else False + The tunnel uses WebRTC to connect to a peer on the gateway. - if not serving: - self.logger.debug(f'Endpoint {self.endpoint_name}: Private tunnel failed to start') - await self.disconnect() - raise Exception('Private tunnel failed to start') + The flow is as follows: + The pre tunnel (KRouter API calls/Signaling Phase) + 0. User enters a command to start a tunnel + 1. Commander sends a request to the KRouter API to get TURN server credentials + 2. Commander: makes a WebRTC peer, makes and offer, and sets its setLocalDescription. + 3. Commander sends tunnel start action to the gateway through krouter with offer encrypted using the shared + secret + 4. The Gateway gets that offer decrypts it with the shared secret and sets its setRemoteDescription, makes an + API call to krouter for TURN server credentials and makes an answer and sets its setLocalDescription. + 5. The gateway returns the answer encrypted using the shared secret in the response to the tunnel start action + 5.5 The Gateway sets up a tunnel exit + 6. Commander decrypts the answer and sets its setRemoteDescription to the answer it got from the gateway + 6.5 The Commander sets up a tunnel entrance + 7 The two peers connect using STUN and TURN servers. If a direct connection can be made then the TURN server + is not used. + + Setting up the tunnel + 8. Commander sets up a local server that listens for connections to a local port that the user has provided or a + random port if none is provided. + 9. Commander sends a ping message through the tunnel entrance to the tunnel exit + 10. The Gateway: receives the ping message and sends a pong message back establishing the + connection + 11. Commander waits for a client to connect to the local server. Both sides wait for a set timeout to receive + data if the timeout is reached then a ping is sent. After 3 pings with no pongs the tunnel is closed. - except Exception as e: - self.logger.error(f"Error setting remote description: {e}") + User connects to the target + 12. Client connects to the tunnel's local server (localhost:[PORT]. + 13. Tunnel Entrance (In Commander) sends an open connection message to the WebRTC connection and listens + to the client forwarding on any data + 14. Tunnel Exit (On The Gateway): receives the open connection message and connects to the target + host and port sending any data back to the WebRTC connection + 15. The session goes on until the CloseConnection message is sent, or the outer tunnel is closed. + 16. The User can repeat steps 12-16 as many times as they want - elif message_no == ControlMessage.CloseConnection: - await self.disconnect() - else: - self.logger.info(f'Endpoint {self.endpoint_name}: Unknown control message {message_no}') + User closes the tunnel + 17. The User closes the tunnel and everything is cleaned up, and we can start back at step 1 + """ -class PrivateTunnelEntrance: +class TunnelEntrance: """ This class is used to forward data between a WebRTC connection and a connection to a target. - The Private Tunnel isn't connected to the public tunnel except that the public tunnel can close it. Connection 0 is reserved for control messages. All other connections are for when a client connects - This private tunnel uses four control messages: Ping, Pong, OpenConnection and CloseConnection + This tunnel uses four control messages: Ping, Pong, OpenConnection and CloseConnection Data is broken into three parts: connection number, [message number], and data message number is only used in control messages. (if the connection number is 0 then there is a message number) """ def __init__(self, - private_tunnel_event, # type: asyncio.Event host, # type: str port, # type: int endpoint_name, # type: str - kill_server_event, # type: asyncio.Event - data_channel, # type: RTCDataChannel - incoming_queue, # type: asyncio.Queue + pc, # type: WebRTCConnection + print_ready_event, # type: asyncio.Event logger = None, # type: logging.Logger ): # type: (...) -> None - self._round_trip_latency = [] self.ping_time = None self.to_local_task = None - self.private_tunnel_event = private_tunnel_event self._ping_attempt = 0 self.host = host self.server = None @@ -581,19 +351,20 @@ def __init__(self, self.is_connected = True self.reader_task = asyncio.create_task(self.start_reader()) self.to_tunnel_tasks = {} - self.kill_server_event = kill_server_event - self.incoming_queue = incoming_queue - self.data_channel = data_channel + self.kill_server_event = asyncio.Event() + self.pc = pc + self.print_ready_event = print_ready_event async def send_to_web_rtc(self, data): - if self.data_channel.readyState == "open": - self.data_channel.send(data) + if self.pc.is_data_channel_open(): + self.pc.send_message(data) # Yield control back to the event loop for other tasks to execute await asyncio.sleep(0) else: - print("Data channel is not open. Data not sent.") + if self.print_ready_event.is_set(): + self.logger.error(f'Endpoint {self.endpoint_name}: Data channel is not open. Data not sent.') - async def send_control_message(self, message_no, data = None): # type: (ControlMessage, Optional[bytes]) -> None + async def send_control_message(self, message_no, data=None): # type: (ControlMessage, Optional[bytes]) -> None """ Packet structure Control Message Packets [CONNECTION_NO_LENGTH + DATA_LENGTH + CONTROL_MESSAGE_NO_LENGTH + DATA] @@ -609,53 +380,54 @@ async def send_control_message(self, message_no, data = None): # type: (Control f' to tunnel.') await self.send_to_web_rtc(buffer) except Exception as e: - self.logger.error(f"Endpoint {self.endpoint_name}: Error while sending private control message: {e}") + self.logger.error(f"Endpoint {self.endpoint_name}: Error while sending control message: {e}") async def process_control_message(self, message_no, data): # type: (ControlMessage, Optional[bytes]) -> None if message_no == ControlMessage.CloseConnection: - self.logger.debug(f'Endpoint {self.endpoint_name}: Received private close connection request') + self.logger.debug(f'Endpoint {self.endpoint_name}: Received close connection request') if data and len(data) > 0: target_connection_no = int.from_bytes(data, byteorder='big') if target_connection_no == 0: for c in list(self.connections): await self.close_connection(c) else: - self.logger.debug(f'Endpoint {self.endpoint_name}: Closing private connection ' + self.logger.debug(f'Endpoint {self.endpoint_name}: Closing connection ' f'{target_connection_no}') await self.close_connection(target_connection_no) elif message_no == ControlMessage.Pong: - self.logger.debug(f'Endpoint {self.endpoint_name}: Received private pong request') + self.logger.debug(f'Endpoint {self.endpoint_name}: Received pong request') self._ping_attempt = 0 self.is_connected = True if self.ping_time is not None: - self._round_trip_latency = track_round_trip_latency(self._round_trip_latency, self.ping_time) - self.logger.debug(f'Endpoint {self.endpoint_name}: Private round trip latency: ' - f'{self._round_trip_latency[-1]} ms') + time_now = time.perf_counter() + # from the time the ping was sent to the time the pong was received + latency = time_now - self.ping_time + self.logger.debug(f'Endpoint {self.endpoint_name}: Round trip latency: {latency} ms') self.ping_time = None elif message_no == ControlMessage.Ping: - self.logger.debug(f'Endpoint {self.endpoint_name}: Received private ping request') + self.logger.debug(f'Endpoint {self.endpoint_name}: Received ping request') await self.send_control_message(ControlMessage.Pong) - elif message_no == ControlMessage.OpenConnection: + elif message_no == ControlMessage.ConnectionOpened: if len(data) >= CONNECTION_NO_LENGTH: if len(data) > CONNECTION_NO_LENGTH: - self.logger.debug(f"Endpoint {self.endpoint_name}: Received invalid private open connection message" + self.logger.debug(f"Endpoint {self.endpoint_name}: Received invalid open connection message" f" ({len(data)} bytes)") connection_no = int.from_bytes(data[:CONNECTION_NO_LENGTH], byteorder='big') - self.logger.debug(f"Endpoint {self.endpoint_name}: Starting private reader for connection " + self.logger.debug(f"Endpoint {self.endpoint_name}: Starting reader for connection " f"{connection_no}") try: self.to_tunnel_tasks[connection_no] = asyncio.create_task( self.forward_data_to_tunnel(connection_no)) # From current connection to WebRTC connection self.logger.debug( - f"Endpoint {self.endpoint_name}: Started private reader for connection {connection_no}") + f"Endpoint {self.endpoint_name}: Started reader for connection {connection_no}") except ConnectionNotFoundException as e: self.logger.error(f"Endpoint {self.endpoint_name}: Connection {connection_no} not found: {e}") except Exception as e: - self.logger.error(f"Endpoint {self.endpoint_name}: Error while forwarding private data: {e}") + self.logger.error(f"Endpoint {self.endpoint_name}: Error while forwarding data: {e}") else: self.logger.error(f"Endpoint {self.endpoint_name}: Invalid open connection message") else: - self.logger.warning(f'Endpoint {self.endpoint_name} Unknown private tunnel control message: {message_no}') + self.logger.warning(f'Endpoint {self.endpoint_name} Unknown tunnel control message: {message_no}') async def forward_data_to_local(self): """ @@ -665,8 +437,7 @@ async def forward_data_to_local(self): Data Packets [CONNECTION_NO_LENGTH + DATA_LENGTH + DATA] """ try: - self.private_tunnel_event.set() - self.logger.debug(f"Endpoint {self.endpoint_name}: Forwarding private data to local...") + self.logger.debug(f"Endpoint {self.endpoint_name}: Forwarding data to local...") buff = b'' should_exit = False while not self.kill_server_event.is_set() and not should_exit: @@ -677,11 +448,11 @@ async def forward_data_to_local(self): if len(buff) >= CONNECTION_NO_LENGTH + DATA_LENGTH + length + len(TERMINATOR): if buff[CONNECTION_NO_LENGTH + DATA_LENGTH + length: CONNECTION_NO_LENGTH + DATA_LENGTH + length + len(TERMINATOR)] != TERMINATOR: - self.logger.warning(f'Endpoint {self.endpoint_name}: Private Invalid terminator') + self.logger.warning(f'Endpoint {self.endpoint_name}: Invalid terminator') # if we don't have a valid terminator then we don't know where the message ends or begins should_exit = True break - self.logger.debug(f'Endpoint {self.endpoint_name}: Private buffer data received data') + self.logger.debug(f'Endpoint {self.endpoint_name}: Buffer data received data') send_data = buff[CONNECTION_NO_LENGTH + DATA_LENGTH:CONNECTION_NO_LENGTH + DATA_LENGTH + length] buff = buff[CONNECTION_NO_LENGTH + DATA_LENGTH + length + len(TERMINATOR):] if connection_no == 0: @@ -694,13 +465,13 @@ async def forward_data_to_local(self): await self.process_control_message(control_m, send_data) else: if connection_no not in self.connections: - self.logger.error(f"Endpoint {self.endpoint_name}: Private connection not found: " + self.logger.error(f"Endpoint {self.endpoint_name}: Connection not found: " f"{connection_no}") continue _, con_writer = self.connections[connection_no] try: - self.logger.debug(f"Endpoint {self.endpoint_name}: Forwarding private data to " + self.logger.debug(f"Endpoint {self.endpoint_name}: Forwarding data to " f"local for connection {connection_no} ({len(send_data)})") con_writer.write(send_data) await con_writer.drain() @@ -708,55 +479,55 @@ async def forward_data_to_local(self): await asyncio.sleep(0) except Exception as ex: self.logger.error(f"Endpoint {self.endpoint_name}: Error while forwarding " - f"private data to local: {ex}") + f"data to local: {ex}") # Yield control back to the event loop for other tasks to execute await asyncio.sleep(0) else: self.logger.debug( - f"Endpoint {self.endpoint_name}: Private buffer is too short {len(buff)} need " + f"Endpoint {self.endpoint_name}: Buffer is too short {len(buff)} need " f"{CONNECTION_NO_LENGTH + DATA_LENGTH + length + len(TERMINATOR)}") # Yield control back to the event loop for other tasks to execute await asyncio.sleep(0) break try: - data = await asyncio.wait_for(self.incoming_queue.get(), READ_TIMEOUT) + data = await asyncio.wait_for(self.pc.web_rtc_queue.get(), READ_TIMEOUT) except asyncio.TimeoutError as et: if self._ping_attempt > 3: if self.is_connected: await self.stop_server() raise et - self.logger.debug(f'Endpoint {self.endpoint_name}: Private Tunnel reader timed out') - self.logger.debug(f'Endpoint {self.endpoint_name}: Send Private ping request') + self.logger.debug(f'Endpoint {self.endpoint_name}: Tunnel reader timed out') + self.logger.debug(f'Endpoint {self.endpoint_name}: Send ping request') self.ping_time = time.perf_counter() await self.send_control_message(ControlMessage.Ping) self._ping_attempt += 1 continue - self.incoming_queue.task_done() + self.pc.web_rtc_queue.task_done() if not data or not self.is_connected: - self.logger.info(f"Endpoint {self.endpoint_name}: Exiting forward private data to local") + self.logger.info(f"Endpoint {self.endpoint_name}: Exiting forward data to local") break elif len(data) == 0: # Yield control back to the event loop for other tasks to execute await asyncio.sleep(0) continue elif isinstance(data, bytes): - self.logger.debug(f"Endpoint {self.endpoint_name}: Got private data from WebRTC connection " + self.logger.debug(f"Endpoint {self.endpoint_name}: Got data from WebRTC connection " f"{len(data)} bytes)") buff += data else: # Yield control back to the event loop for other tasks to execute await asyncio.sleep(0) - self.logger.debug(f"Endpoint {self.endpoint_name}: Exiting forward private data successfully.") + self.logger.debug(f"Endpoint {self.endpoint_name}: Exiting forward data successfully.") except asyncio.CancelledError: pass except Exception as ex: - self.logger.error(f"Endpoint {self.endpoint_name}: Error while forwarding private data: {ex}") + self.logger.error(f"Endpoint {self.endpoint_name}: Error while forwarding data: {ex}") finally: - self.logger.debug(f"Endpoint {self.endpoint_name}: Closing private tunnel") + self.logger.debug(f"Endpoint {self.endpoint_name}: Closing tunnel") await self.stop_server() async def start_reader(self): # type: () -> None @@ -771,7 +542,7 @@ async def start_reader(self): # type: () -> None # Send hello world open connection message self.ping_time = time.perf_counter() await self.send_control_message(ControlMessage.Ping) - self.logger.debug(f"Endpoint {self.endpoint_name}: Sent private ping message to WebRTC connection") + self.logger.debug(f"Endpoint {self.endpoint_name}: Sent ping message to WebRTC connection") except Exception as e: self.logger.error(f"Endpoint {self.endpoint_name}: Error while establishing WebRTC connection: {e}") failed = True @@ -794,8 +565,8 @@ async def forward_data_to_tunnel(self, con_no): break reader, _ = c try: - data = await reader.read(PRIVATE_BUFFER_TRUNCATION_THRESHOLD) - self.logger.debug(f"Endpoint {self.endpoint_name}: Forwarding private {len(data)} " + data = await reader.read(BUFFER_TRUNCATION_THRESHOLD) + self.logger.debug(f"Endpoint {self.endpoint_name}: Forwarding {len(data)} " f"bytes to tunnel for connection {con_no}") if not data: self.logger.debug(f"Endpoint {self.endpoint_name}: Connection {con_no} no data") @@ -813,10 +584,10 @@ async def forward_data_to_tunnel(self, con_no): # Yield control back to the event loop for other tasks to execute await asyncio.sleep(0) except Exception as e: - self.logger.debug(f"Endpoint {self.endpoint_name}: Private connection '{con_no}' read failed: {e}") + self.logger.debug(f"Endpoint {self.endpoint_name}: Connection '{con_no}' read failed: {e}") break except Exception as e: - self.logger.error(f"Endpoint {self.endpoint_name}: Error while forwarding private data in connection " + self.logger.error(f"Endpoint {self.endpoint_name}: Error while forwarding data in connection " f"{con_no}: {e}") if con_no not in self.connections: @@ -835,17 +606,13 @@ async def handle_connection(self, reader, writer): # type: (asyncio.StreamReade self.connection_no += 1 self.connections[connection_no] = (reader, writer) - self.logger.debug(f"Endpoint {self.endpoint_name}: Created private local connection {connection_no}") + self.logger.debug(f"Endpoint {self.endpoint_name}: Created local connection {connection_no}") # Send open connection message with con_no. this is required to be sent to start the connection await self.send_control_message(ControlMessage.OpenConnection, int.to_bytes(connection_no, CONNECTION_NO_LENGTH, byteorder='big')) - async def start_server(self, - private_tunnel_event, # type: asyncio.Event - private_tunnel_started, # type: asyncio.Event - tunnel_ready # type: asyncio.Event - ): # type: (...) -> None + async def start_server(self): # type: (...) -> None """ This server is used to listen for client connections to the local port. """ @@ -855,12 +622,17 @@ async def start_server(self, self.logger.error(f"Endpoint {self.endpoint_name}: Error while finding open port: {e}") await self.print_not_ready() return + + if not self._port: + self.logger.error(f"Endpoint {self.endpoint_name}: No open ports found for local server") + await self.print_not_ready() + return + try: self.server = await asyncio.start_server(self.handle_connection, family=socket.AF_INET, host=self.host, port=self._port) async with self.server: - private_tunnel_started.set() - asyncio.create_task(self.print_ready(self.host, self._port, private_tunnel_event, tunnel_ready)) + await asyncio.create_task(self.print_ready(self.host, self._port, self.print_ready_event)) await self.server.serve_forever() except ConnectionRefusedError as er: self.logger.error(f"Endpoint {self.endpoint_name}: Connection Refused while starting server: {er}") @@ -876,6 +648,7 @@ async def start_server(self, return async def print_not_ready(self): + print(f'{bcolors.FAIL}+---------------------------------------------------------{bcolors.ENDC}') print(f'{bcolors.FAIL}| Endpoint {self.endpoint_name}{bcolors.ENDC} failed to start') print(f'{bcolors.FAIL}+---------------------------------------------------------{bcolors.ENDC}') @@ -886,16 +659,15 @@ async def print_not_ready(self): async def print_ready(self, host, # type: str port, # type: int - private_tunnel_event, # type: asyncio.Event - tunnel_ready # type: asyncio.Event + print_ready_event, # type: asyncio.Event ): # type: (...) -> None """ pretty prints the endpoint name and host:port after the tunnels are set up """ + wait_for_server = READ_TIMEOUT * 6 try: - await asyncio.wait_for(private_tunnel_event.wait(), timeout=60) - except asyncio.TimeoutError: - self.logger.debug(f"Endpoint {self.endpoint_name}: Timed out waiting for private tunnel to start") + await asyncio.wait_for(print_ready_event.wait(), wait_for_server) + except TimeoutError: await self.print_not_ready() return @@ -903,22 +675,19 @@ async def print_ready(self, host, # type: str await self.print_not_ready() return - try: - await asyncio.wait_for(tunnel_ready.wait(), timeout=60) - except asyncio.TimeoutError: - self.logger.debug(f"Endpoint {self.endpoint_name}: Timed out waiting for private tunnel to start") - await self.print_not_ready() - return - - # Just sleep a little bit to print out last + # Sleep a little bit to print out last await asyncio.sleep(.5) host = host + ":" if host else '' - print(f'{bcolors.OKGREEN}+---------------------------------------------------------{bcolors.ENDC}') + print(f'{bcolors.OKGREEN}+---------------------------------------------------------------{bcolors.ENDC}') print( f'{bcolors.OKGREEN}| Endpoint {bcolors.ENDC}{bcolors.OKBLUE}{self.endpoint_name}{bcolors.ENDC}' f'{bcolors.OKGREEN}: Listening on port: {bcolors.ENDC}' f'{bcolors.BOLD}{bcolors.OKBLUE}{host}{port}{bcolors.ENDC}') - print(f'{bcolors.OKGREEN}+---------------------------------------------------------{bcolors.ENDC}') + print(f'{bcolors.OKGREEN}+---------------------------------------------------------------{bcolors.ENDC}') + print(f'{bcolors.OKGREEN}View all open tunnels : {bcolors.ENDC}{bcolors.OKBLUE}pam tunnel list{bcolors.ENDC}') + print(f'{bcolors.OKGREEN}Tail logs on open tunnel: {bcolors.ENDC}{bcolors.OKBLUE}pam tunnel tail -c="[TUNNELID]"{bcolors.ENDC}') + print(f'{bcolors.OKGREEN}Stop a tunnel : {bcolors.ENDC}{bcolors.OKBLUE}pam tunnel stop -c="[TUNNELID]"{bcolors.ENDC}') + async def stop_server(self): try: @@ -952,19 +721,19 @@ async def close_connection(self, connection_no): await asyncio.wait_for(writer.wait_closed(), timeout=5.0) except asyncio.TimeoutError: self.logger.warning( - f"Endpoint {self.endpoint_name}: Timed out while trying to close Private connection " + f"Endpoint {self.endpoint_name}: Timed out while trying to close connection " f"{connection_no}") del self.connections[connection_no] - self.logger.info(f"Endpoint {self.endpoint_name}: Closed Private connection {connection_no}") + self.logger.info(f"Endpoint {self.endpoint_name}: Closed connection {connection_no}") else: - self.logger.info(f"Endpoint {self.endpoint_name}: Private Connection {connection_no} not found") + self.logger.info(f"Endpoint {self.endpoint_name}: Connection {connection_no} not found") if connection_no in self.to_tunnel_tasks: try: self.to_tunnel_tasks[connection_no].cancel() except Exception as ex: - self.logger.warning(f'Endpoint {self.endpoint_name}: hit exception canceling private tasks {ex}') + self.logger.warning(f'Endpoint {self.endpoint_name}: hit exception canceling tasks {ex}') del self.to_tunnel_tasks[connection_no] - self.logger.info(f"Endpoint {self.endpoint_name}: Tasks closed for Private connection {connection_no}") + self.logger.info(f"Endpoint {self.endpoint_name}: Tasks closed for connection {connection_no}") else: - self.logger.info(f"Endpoint {self.endpoint_name}: Private tasks for {connection_no} not found") + self.logger.info(f"Endpoint {self.endpoint_name}: Tasks for {connection_no} not found") diff --git a/keepercommander/commands/tunnel/port_forward/tunnel.py b/keepercommander/commands/tunnel/port_forward/tunnel.py deleted file mode 100644 index 558ebc66b..000000000 --- a/keepercommander/commands/tunnel/port_forward/tunnel.py +++ /dev/null @@ -1,53 +0,0 @@ -import abc -import asyncio - - -class ITunnel(abc.ABC): - @abc.abstractmethod - async def connect(self): # type: () -> None - pass - - @abc.abstractmethod - def disconnect(self): # type: () -> None - pass - - @property - @abc.abstractmethod - def is_connected(self): # type: () -> bool - pass - - @abc.abstractmethod - async def read(self, timeout = -1): # type: (int) -> bytes - pass - - @abc.abstractmethod - async def write(self, data): # type: (bytes) -> None - pass - - -class LocalTunnel(ITunnel): - def __init__(self, own, other): # type: (asyncio.Queue, asyncio.Queue) -> None - self._own = own - self._other = other - self._connected = False - - @property - def is_connected(self): - return self._connected - - async def connect(self): - self._connected = True - - def disconnect(self): - self._connected = False - - async def write(self, data): - await self._other.put(data) - - async def read(self, timeout = -1): - if timeout > 0: - buffer = await asyncio.wait_for(self._own.get(), timeout) - else: - buffer = await self._own.get() - self._own.task_done() - return buffer diff --git a/keepercommander/commands/tunnel/port_forward/tunnel_connected.py b/keepercommander/commands/tunnel/port_forward/tunnel_connected.py deleted file mode 100644 index f33429f51..000000000 --- a/keepercommander/commands/tunnel/port_forward/tunnel_connected.py +++ /dev/null @@ -1,59 +0,0 @@ -import asyncio -import json - -from websockets.client import WebSocketClientProtocol - -from keepercommander import utils -from keepercommander.utils import is_json -from .tunnel import ITunnel - - -class ConnectedTunnel(ITunnel): - def __init__(self, ws): - self.ws = ws # type: WebSocketClientProtocol - self.input_queue = asyncio.Queue() - self.output_queue = asyncio.Queue() - self._disconnect_requested = False - - def connect(self): # type: () -> None - pass - - def is_connected(self): - return True - - def disconnect(self): # type: () -> None - self._disconnect_requested = True - - async def ws_reader(self): - ws = self.ws - async for frame in ws: - if isinstance(frame, str): - if is_json(frame): - frame = json.loads(frame) - frame_data = frame.get('data') - else: - data = utils.base64_url_decode(frame) - - await self.output_queue.put(data) - - async def ws_writer(self): - while not self._disconnect_requested: - frame = await self.input_queue.get() - if frame: - if isinstance(frame, bytes): - frame = utils.base64_url_encode(frame) - await self.ws.send(frame) - await self.ws.close() - - async def read(self, timeout = -1): - if timeout > 0: - buffer = await asyncio.wait_for(self.output_queue.get(), timeout) - else: - buffer = await self.output_queue.get() - self.output_queue.task_done() - return buffer - - async def write(self, data): - if self.is_connected: - if len(data) > 0: - await self.input_queue.put(data) diff --git a/keepercommander/proto/pam_pb2.py b/keepercommander/proto/pam_pb2.py index 849c60152..45e03bcb7 100644 --- a/keepercommander/proto/pam_pb2.py +++ b/keepercommander/proto/pam_pb2.py @@ -288,4 +288,4 @@ _CONFIGURATIONADDREQUEST._serialized_end=2235 _RELAYACCESSCREDS._serialized_start=2237 _RELAYACCESSCREDS._serialized_end=2291 -# @@protoc_insertion_point(module_scope) +# @@protoc_insertion_point(module_scope) \ No newline at end of file diff --git a/unit-tests/pam-tunnel/test_pam_tunnel.py b/unit-tests/pam-tunnel/test_pam_tunnel.py index 03097c7cf..d4bffaac6 100644 --- a/unit-tests/pam-tunnel/test_pam_tunnel.py +++ b/unit-tests/pam-tunnel/test_pam_tunnel.py @@ -2,7 +2,7 @@ import unittest from unittest import mock -if sys.version_info >= (3, 15): +if sys.version_info >= (3, 11): import datetime import socket import string diff --git a/unit-tests/pam-tunnel/test_private_tunnel.py b/unit-tests/pam-tunnel/test_private_tunnel.py index 7a89b4841..7df78c382 100644 --- a/unit-tests/pam-tunnel/test_private_tunnel.py +++ b/unit-tests/pam-tunnel/test_private_tunnel.py @@ -10,16 +10,15 @@ from aiortc import RTCDataChannel from cryptography.utils import int_to_bytes from keepercommander import utils - from keepercommander.commands.tunnel.port_forward.endpoint import (PrivateTunnelEntrance, ControlMessage, + from keepercommander.commands.tunnel.port_forward.endpoint import (TunnelEntrance, ControlMessage, CONTROL_MESSAGE_NO_LENGTH, CONNECTION_NO_LENGTH, ConnectionNotFoundException, - TERMINATOR, DATA_LENGTH) + TERMINATOR, DATA_LENGTH, WebRTCConnection) from test_pam_tunnel import new_private_key # Only define the class if Python version is 3.8 or higher class TestPrivateTunnelEntrance(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): - self.event = asyncio.Event() self.host = 'localhost' self.port = 8080 self.endpoint_name = 'TestEndpoint' @@ -28,13 +27,12 @@ async def asyncSetUp(self): self.logger = mock.MagicMock(spec=logging) self.kill_server_event = asyncio.Event() self.tunnel_symmetric_key = utils.generate_aes_key() - self.data_channel = mock.MagicMock(sepc=RTCDataChannel) - self.data_channel.readyState = 'open' + self.pc = mock.MagicMock(sepc=WebRTCConnection) + self.pc.data_channel.readyState = 'open' self.incoming_queue = mock.MagicMock(sepc=asyncio.Queue()) - self.pte = PrivateTunnelEntrance( - self.event, self.host, self.port, self.endpoint_name, self.kill_server_event, self.data_channel, - self.incoming_queue, self.logger - ) + self.print_ready_event = asyncio.Event() + self.pte = TunnelEntrance(self.host, self.port, self.endpoint_name, self.pc, + self.print_ready_event, self.logger) async def set_queue_side_effect(self): data = b'some_data' @@ -46,7 +44,7 @@ async def mock_incoming_queue_get(): yield None # Now use an iterator of this coroutine function as the side effect - self.pte.incoming_queue.get.side_effect = mock_incoming_queue_get().__anext__ + self.pte.pc.web_rtc_queue.get.side_effect = mock_incoming_queue_get().__anext__ async def asyncTearDown(self): await self.pte.stop_server() # ensure the server is stopped after test @@ -56,7 +54,7 @@ async def test_send_control_message(self): self.pte.tls_writer = mock.MagicMock(spec=asyncio.StreamWriter) # Mock write and drain methods - with mock.patch.object(self.pte.data_channel, 'send', new_callable=mock.AsyncMock) as mock_send: + with mock.patch.object(self.pte.pc.data_channel, 'send', new_callable=mock.AsyncMock) as mock_send: # Define the control message and optional data control_message = ControlMessage.Ping @@ -81,7 +79,7 @@ async def test_send_control_message_with_error(self): self.pte.logger = mock.MagicMock() # Set side effect to raise an exception - self.pte.data_channel.send.side_effect = Exception("Mocked Exception") + self.pte.pc.data_channel.send.side_effect = Exception("Mocked Exception") # Define the control message and optional data control_message = ControlMessage.Ping @@ -151,8 +149,7 @@ async def test_process_ping_message(self): async def test_start_server(self): with mock.patch('asyncio.start_server', new_callable=mock.AsyncMock) as mock_open_connection, \ mock.patch.object(self.pte, 'handle_connection', new_callable=mock.AsyncMock) as mock_handle_connection: - await self.pte.start_server(mock.AsyncMock(spec=asyncio.Event), mock.AsyncMock(spec=asyncio.Event), - mock.AsyncMock(spec=asyncio.Event)) + await self.pte.start_server() mock_open_connection.assert_called_with(mock_handle_connection, family=socket.AF_INET, host='localhost', port=self.port) @@ -163,8 +160,7 @@ async def test_start_server_normal(self): self.pte.logger = mock.MagicMock() - await self.pte.start_server(mock.AsyncMock(spec=asyncio.Event), mock.AsyncMock(spec=asyncio.Event), - mock.AsyncMock(spec=asyncio.Event)) + await self.pte.start_server() print_ready.assert_called_once() @@ -174,7 +170,7 @@ async def test_start_server_connection_refused_error(self): mock_start_server.side_effect = ConnectionRefusedError self.pte.logger = mock.MagicMock() - await self.pte.start_server(mock.AsyncMock(), mock.AsyncMock(), mock.AsyncMock()) + await self.pte.start_server() self.pte.logger.error.assert_called_with('Endpoint TestEndpoint: Connection Refused while starting ' 'server: ') @@ -187,7 +183,7 @@ async def test_start_server_timeout_error(self): mock_start_server.side_effect = TimeoutError self.pte.logger = mock.MagicMock() - await self.pte.start_server(mock.AsyncMock(), mock.AsyncMock(), mock.AsyncMock()) + await self.pte.start_server() self.pte.logger.error.assert_called_with('Endpoint TestEndpoint: OS Error while starting server: ') mock_stop.assert_called() @@ -199,7 +195,7 @@ async def test_start_server_os_error(self): mock_start_server.side_effect = OSError("Some OS Error") self.pte.logger = mock.MagicMock() - await self.pte.start_server(mock.AsyncMock(), mock.AsyncMock(), mock.AsyncMock()) + await self.pte.start_server() self.pte.logger.error.assert_called_with('Endpoint TestEndpoint: OS Error while starting server: ' 'Some OS Error') @@ -212,7 +208,7 @@ async def test_start_server_generic_exception(self): mock_start_server.side_effect = Exception("Some generic exception") self.pte.logger = mock.MagicMock() - await self.pte.start_server(mock.AsyncMock(), mock.AsyncMock(), mock.AsyncMock()) + await self.pte.start_server() self.pte.logger.error.assert_called_with('Endpoint TestEndpoint: Error while starting server: ' 'Some generic exception') @@ -242,13 +238,16 @@ async def read_side_effect(*args, **kwargs): self.pte.kill_server_event = mock.MagicMock(spec=asyncio.Event) self.pte.kill_server_event.is_set.side_effect = [False, False, True] + self.pte.pc = mock.MagicMock(spec=WebRTCConnection) + self.pte.pc.data_channel = mock.MagicMock(spec=RTCDataChannel) + self.pte.pc.data_channel.readyState = 'open' # Run the task and wait for it to complete task = asyncio.create_task(self.pte.forward_data_to_tunnel(1)) await asyncio.sleep(.01) # Give some time for the task to run task.cancel() # Cancel the task to stop it from running indefinitely - self.pte.data_channel.send.assert_called_with(b'\x00\x00\x00\x01\x00\x00\x00\x0bhello world;') + self.pte.pc.data_channel.send.assert_called_with(b'\x00\x00\x00\x01\x00\x00\x00\x0bhello world;') # Test Connection Not Found async def test_forward_data_to_tunnel_no_connection(self): @@ -334,31 +333,17 @@ async def test_print_not_ready(self): # Test print_ready async def test_print_ready(self): with mock.patch('builtins.print') as mock_print: - await self.pte.print_ready('localhost', 8080, mock.AsyncMock(), mock.AsyncMock()) + await self.pte.print_ready('localhost', 8080, mock.AsyncMock()) # Check if print was called (optional) mock_print.assert_called() # Test print_ready with TimeoutError async def test_print_ready_timeout_error_forwarder(self): - forwarder_event = mock.AsyncMock(spec=asyncio.Event) - forwarder_event.wait.side_effect = asyncio.TimeoutError() - private_tunnel_event = mock.AsyncMock(spec=asyncio.Event) + print_event = mock.AsyncMock(spec=asyncio.Event) + print_event.wait.side_effect = asyncio.TimeoutError() with mock.patch.object(self.pte, 'print_not_ready', new_callable=mock.AsyncMock) as mock_print_not_ready: - await self.pte.print_ready('localhost', 8080, forwarder_event, private_tunnel_event) - - # Check if logger.debug was called - self.pte.logger.debug.assert_called_with("Endpoint TestEndpoint: Timed out waiting for private tunnel to start") - # Check if print was called (optional) - mock_print_not_ready.assert_called() - - # Test print_ready with TimeoutError - async def test_print_ready_timeout_error_private_tunnel(self): - forwarder_event = mock.AsyncMock(spec=asyncio.Event) - private_tunnel_event = mock.AsyncMock(spec=asyncio.Event) - private_tunnel_event.wait.side_effect = asyncio.TimeoutError() - with mock.patch.object(self.pte, 'print_not_ready', new_callable=mock.AsyncMock) as mock_print_not_ready: - await self.pte.print_ready('localhost', 8080, forwarder_event, private_tunnel_event) + await self.pte.print_ready('localhost', 8080, print_event) # Check if logger.debug was called self.pte.logger.debug.assert_called_with("Endpoint TestEndpoint: Timed out waiting for private tunnel to start") diff --git a/unit-tests/pam-tunnel/test_public_tunnel.py b/unit-tests/pam-tunnel/test_public_tunnel.py deleted file mode 100644 index 06f09d132..000000000 --- a/unit-tests/pam-tunnel/test_public_tunnel.py +++ /dev/null @@ -1,102 +0,0 @@ -import sys -import unittest -from unittest import mock - -if sys.version_info >= (3, 15): - from cryptography.hazmat.primitives.asymmetric import ec - from cryptography.hazmat.backends import default_backend - from cryptography.hazmat.primitives import serialization - from keeper_secrets_manager_core.utils import base64_to_bytes - from keepercommander.commands.tunnel.port_forward.tunnel import ITunnel - from keepercommander.commands.tunnel.port_forward.endpoint import (ControlMessage, CONTROL_MESSAGE_NO_LENGTH, - DATA_LENGTH, CONNECTION_NO_LENGTH, TunnelProtocol, - TERMINATOR, find_server_public_key) - - - # Only define the class if Python version is 3.8 or higher - def make_private_key(): - private_key = ec.generate_private_key( - ec.SECP256R1(), # Using P-256 curve - backend=default_backend() - ) - return private_key - - class TestPublicTunnel(unittest.IsolatedAsyncioTestCase): - - async def asyncSetUp(self): - # Initialize mock objects and test setup - self.mock_tunnel = mock.AsyncMock(spec=ITunnel) - self.mock_logger = mock.Mock() - - self.mock_tunnel.is_connected = True - self.client_private_key = make_private_key() - self.client_private_key_str = self.client_private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption() - ).decode('utf-8') - self.gateway_public_key = make_private_key().public_key().public_bytes(encoding=serialization.Encoding.X962, - format=serialization.PublicFormat. - UncompressedPoint) - - self.tunnel_protocol = TunnelProtocol(self.mock_tunnel, endpoint_name="Test Public", - logger=self.mock_logger, - gateway_public_key_bytes=self.gateway_public_key, - client_private_key=self.client_private_key_str, host="localhost", - port=8080) - - self.tunnel_protocol.private_tunnel_server = mock.AsyncMock() - self.tunnel_protocol.read_connection_task = mock.AsyncMock() - - self.tunnel_protocol.server_public_key = find_server_public_key( - self.tunnel_protocol.gateway_public_key_bytes) - self.tunnel_protocol.establish_symmetric_key() - - async def asyncTearDown(self): - await self.tunnel_protocol.disconnect() - - async def test_connect(self): - self.mock_tunnel.is_connected = False - with mock.patch.object(self.tunnel_protocol, 'start_tunnel_reader', return_value=None) as mock_start_reader, \ - mock.patch.object(self.tunnel_protocol, 'disconnect', return_value=None) as mock_disconnect: - - await self.tunnel_protocol.connect() - - self.mock_tunnel.connect.assert_called_once() - mock_start_reader.assert_called_once() - mock_disconnect.assert_called_once() - - async def test_disconnect(self): - with mock.patch.object(self.tunnel_protocol, 'send_control_message', return_value=None) as mock_send_control: - await self.tunnel_protocol.disconnect() - mock_send_control.assert_called_once_with(ControlMessage.CloseConnection) - self.assertTrue(self.tunnel_protocol.kill_server_event.is_set()) - - async def test_start_tunnel_reader_control(self): - # build data for a ping control message - data = b'' - data1 = int.to_bytes(ControlMessage.Ping, CONTROL_MESSAGE_NO_LENGTH, byteorder='big') + data - buffer = int.to_bytes(0, CONNECTION_NO_LENGTH, byteorder='big') - buffer += int.to_bytes(len(data1), DATA_LENGTH, byteorder='big') - buffer += data1 + TERMINATOR - - self.tunnel_protocol.tunnel.read = mock.AsyncMock() - self.tunnel_protocol.tunnel.read.side_effect = [base64_to_bytes(self.tunnel_protocol.tunnel_encrypt(buffer)), None] - with mock.patch.object(self.tunnel_protocol, 'process_control_message', return_value=None) as mock_process: - await self.tunnel_protocol.start_tunnel_reader() - self.mock_tunnel.read.assert_called() - mock_process.assert_called_with(ControlMessage.Ping, data) - - async def test_send_to_tunnel(self): - await self.tunnel_protocol._send_to_tunnel(1, b'data') - self.mock_tunnel.write.assert_called_once() - - async def test_send_control_message(self): - with mock.patch.object(self.tunnel_protocol, '_send_to_tunnel', return_value=None) as mock_send_to_tunnel: - await self.tunnel_protocol.send_control_message(ControlMessage.Ping) - mock_send_to_tunnel.assert_called_once() - - async def test_process_control_message(self): - with mock.patch.object(self.tunnel_protocol, 'send_control_message', return_value=None) as mock_send_control: - await self.tunnel_protocol.process_control_message(ControlMessage.Ping, b'') - mock_send_control.assert_called_once_with(ControlMessage.Pong)