diff --git a/keepercommander/commands/discoveryrotation.py b/keepercommander/commands/discoveryrotation.py index a7cb1866b..d108e414c 100644 --- a/keepercommander/commands/discoveryrotation.py +++ b/keepercommander/commands/discoveryrotation.py @@ -1733,7 +1733,7 @@ def retrieve_gateway_public_key(gateway_uid, params, api, utils) -> bytes: class PAMTunnelStartCommand(Command): pam_cmd_parser = argparse.ArgumentParser(prog='dr-port-forward-command') - pam_cmd_parser.add_argument('--gateway', '-g', required=False, dest='gateway', action='store', + pam_cmd_parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', help='Used to list all tunnels for the given Gateway UID') pam_cmd_parser.add_argument('--record', '-r', required=True, dest='record_uid', action='store', help='The Record UID of the PAM resource record with network information to use for ' diff --git a/keepercommander/commands/tunnel/port_forward/endpoint.py b/keepercommander/commands/tunnel/port_forward/endpoint.py index 8d2448f99..9c8028465 100644 --- a/keepercommander/commands/tunnel/port_forward/endpoint.py +++ b/keepercommander/commands/tunnel/port_forward/endpoint.py @@ -54,6 +54,7 @@ class ControlMessage(enum.IntEnum): OpenConnection = 101 CloseConnection = 102 ConnectionOpened = 103 + SendEOF = 104 def generate_random_bytes(pass_length=RANDOM_LENGTH): # type: (int) -> bytes @@ -463,6 +464,7 @@ def __init__(self, self.pc = pc self.print_ready_event = print_ready_event self.connect_task = connect_task + self.eof_sent = False @property def port(self): @@ -579,6 +581,14 @@ async def process_control_message(self, message_no, data): # type: (ControlMess self.logger.error(f"Endpoint {self.endpoint_name}: Error in forwarding data task: {e}") else: self.logger.error(f"Endpoint {self.endpoint_name}: Invalid open connection message") + elif message_no == ControlMessage.SendEOF: + if len(data) >= CONNECTION_NO_LENGTH: + con_no = int.from_bytes(data[:CONNECTION_NO_LENGTH], byteorder='big') + if con_no in self.connections: + self.logger.debug(f'Endpoint {self.endpoint_name}: Sending EOF to {con_no}') + self.connections[con_no].writer.write_eof() + else: + self.logger.error(f'Endpoint {self.endpoint_name}: Connection for EOF {con_no} not found') else: self.logger.warning(f'Endpoint {self.endpoint_name} Unknown tunnel control message: {message_no}') @@ -726,10 +736,15 @@ async def forward_data_to_tunnel(self, con_no): break if isinstance(data, bytes): if c.reader.at_eof() and len(data) == 0: + if not self.eof_sent: + await self.send_control_message(ControlMessage.SendEOF, + int_to_bytes(con_no, CONNECTION_NO_LENGTH)) + self.eof_sent = True # Yield control back to the event loop for other tasks to execute await asyncio.sleep(0) continue else: + self.eof_sent = False buffer = int.to_bytes(con_no, CONNECTION_NO_LENGTH, byteorder='big') buffer += int.to_bytes(len(data), DATA_LENGTH, byteorder='big') + data + TERMINATOR await self.send_to_web_rtc(buffer) @@ -849,7 +864,8 @@ async def stop_server(self): self.logger.warning(f'Endpoint {self.endpoint_name}: hit exception closing data channel {ex}') try: - self.connect_task.cancel() + if self.connect_task is not None: + self.connect_task.cancel() finally: self.closing = True self.logger.info(f"Endpoint {self.endpoint_name}: Tunnel stopped") @@ -873,7 +889,8 @@ async def close_connection(self, connection_no): if connection_no in self.connections: try: - self.connections[connection_no].to_tunnel_task.cancel() + if self.connections[connection_no].to_tunnel_task is not None: + self.connections[connection_no].to_tunnel_task.cancel() except Exception as ex: self.logger.warning(f'Endpoint {self.endpoint_name}: hit exception canceling tasks {ex}') del self.connections[connection_no]