Skip to content

Commit

Permalink
Port Tunneling interface fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
miroberts committed Dec 11, 2023
1 parent c6305c7 commit 51bccea
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 159 deletions.
174 changes: 74 additions & 100 deletions keepercommander/commands/discoveryrotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import logging
import os.path
import queue
import socket
import sys
import threading
from datetime import datetime
Expand Down Expand Up @@ -1562,18 +1563,13 @@ def execute(self, params, **kwargs):
############################################## TUNNELING ###############################################################
class PAMTunnelListCommand(Command):
pam_cmd_parser = argparse.ArgumentParser(prog='dr-tunnel-list-command')
pam_cmd_parser.add_argument('--uid', '-u', required=False, dest='record_uid', action='store',
help='Filter list with UID of the PAM record that was used to create the tunnel')
pam_cmd_parser.add_argument('--conversation-id', '-c', required=False, dest='convo_id', action='store',
help='The connection ID of the Tunnel to list')

def get_parser(self):
return PAMTunnelListCommand.pam_cmd_parser

def execute(self, params, **kwargs):
def gather_tabel_row_data(thread):
# {"thread": t, "host": host, "port": port, "name": listener_name, "started": datetime.now(),
# "record_uid": record_uid}
# {"thread": t, "host": host, "port": port, "started": datetime.now(),
row = []
run_time = None
hours = 0
Expand Down Expand Up @@ -1606,93 +1602,58 @@ def gather_tabel_row_data(thread):
row.append(text_line)
return row

convo_id = kwargs.get('convo_id', None)

if not params.tunnel_threads:
logging.warning(f"{bcolors.OKBLUE}No Tunnels running{bcolors.ENDC}")
return

table = []
headers = ['Tunnel ID', 'Host', 'Port', 'Record UID', 'Up Time']

if convo_id:
if convo_id in params.tunnel_threads:
table.append(gather_tabel_row_data(params.tunnel_threads[convo_id]))
else:
print(f"{bcolors.FAIL}Tunnel {convo_id} not found{bcolors.ENDC}")
return
else:
for i, convo_id in enumerate(params.tunnel_threads):
row = gather_tabel_row_data(params.tunnel_threads[convo_id])
if row:
table.append(row)
for i, convo_id in enumerate(params.tunnel_threads):
row = gather_tabel_row_data(params.tunnel_threads[convo_id])
if row:
table.append(row)

dump_report_data(table, headers, fmt='table', filename="", row_number=False, column_width=None)


class PAMTunnelStopCommand(Command):
pam_cmd_parser = argparse.ArgumentParser(prog='dr-tunnel-stop-command')
pam_cmd_parser.add_argument('--conversation-id', '-c', required=False, dest='convo_id', action='store',
help='The connection ID of the Tunnel to stop')

def tunnel_cleanup(self, params, convo_id):
tunnel_data = params.tunnel_threads.get(convo_id, None)
if not tunnel_data:
return

for task_name in ["print", "connect"]:
task = tunnel_data.get(task_name)
if task:
task.cancel()
logging.debug(f"Cancelled {task_name} for {convo_id}")

del params.tunnel_threads[convo_id]
logging.debug(f"Cleaned up data for {convo_id}")

if convo_id in params.tunnel_threads_queue:
del params.tunnel_threads_queue[convo_id]
logging.debug(f"{bcolors.OKBLUE}{convo_id} Queue cleaned up{bcolors.ENDC}")
pam_cmd_parser.add_argument('uid', type= str, action='store', help='The Tunnel UID')

def get_parser(self):
return PAMTunnelStopCommand.pam_cmd_parser

def execute(self, params, **kwargs):
convo_id = kwargs.get('convo_id')
tunnel_data = params.tunnel_threads.get(convo_id, None)
convo_id = kwargs.get('uid')
if not convo_id:
raise CommandError('tunnel stop', '"uid" argument is required')

if tunnel_data is None:
print(f"{bcolors.WARNING}No data found for conversation ID {convo_id}{bcolors.ENDC}")
tunnel_data = params.tunnel_threads.get(convo_id, None)
if not tunnel_data:
logging.debug(f"{bcolors.WARNING}No tunnel data found for {convo_id}{bcolors.ENDC}")
return

# Stop the asyncio event loop
loop = tunnel_data.get("loop", None)

entrance = tunnel_data.get("entrance", None)
if loop and entrance and not loop._closed:
if loop._closed:
print(f"{bcolors.WARNING}Event loop is closed for conversation ID {convo_id}{bcolors.ENDC}")
else:
# Run the disconnect method in the event loop
loop.create_task(entrance.stop_server())
print(f"Disconnected entrance for {convo_id}")

loop.call_soon_threadsafe(self.tunnel_cleanup, params, convo_id)
connect_task = tunnel_data.get("connect_task", None)
if connect_task:
connect_task.cancel()
return


class PAMTunnelTailCommand(Command):
pam_cmd_parser = argparse.ArgumentParser(prog='dr-tunnel-tail-command')
pam_cmd_parser.add_argument('--conversation-id', '-c', required=False, dest='convo_id', action='store',
help='The connection ID of the Tunnel to tail logs')
pam_cmd_parser.add_argument('uid', type= str, action='store', help='The Tunnel UID')

def get_parser(self):
return PAMTunnelTailCommand.pam_cmd_parser

def execute(self, params, **kwargs):
convo_id = kwargs.get('convo_id')
convo_id = kwargs.get('uid')
if not convo_id:
raise CommandError('tunnel tail', '"uid" argument is required')

log_queue = params.tunnel_threads_queue.get(convo_id)

# TODO make this run in a new thread?
logger_level = logging.getLogger().getEffectiveLevel()

logging.getLogger('aiortc').setLevel(logging.DEBUG)
Expand Down Expand Up @@ -1761,8 +1722,6 @@ class PAMTunnelStartCommand(Command):
type=int, default=0,
help='The port number on which the server will be listening for incoming connections. '
'If not set, random open port on the machine will be used.')
pam_cmd_parser.add_argument('--listener-name', '-l', required=False, dest='listener_name',
action='store', default="Keeper PAM Tunnel", help='The name of the listener.')

def get_parser(self):
return PAMTunnelStartCommand.pam_cmd_parser
Expand Down Expand Up @@ -1792,25 +1751,7 @@ def setup_logging(self, convo_id, log_queue, logging_level):
logger.debug("Logging setup complete.")
return logger

def tunnel_cleanup(self, params, convo_id):
tunnel_data = params.tunnel_threads.get(convo_id, None)
if not tunnel_data:
return

for task_name in ["print", "entrance"]:
task = tunnel_data.get(task_name)
if task:
task.cancel()
print(f"Cancelled {task_name} for {convo_id}")

del params.tunnel_threads[convo_id]
print(f"Cleaned up data for {convo_id}")

if convo_id in params.tunnel_threads_queue:
del params.tunnel_threads_queue[convo_id]
print(f"{bcolors.OKBLUE}{convo_id} Queue cleaned up{bcolors.ENDC}")

async def connect(self, params, record_uid, convo_id, gateway_uid, host, port, listener_name,
async def connect(self, params, record_uid, convo_id, gateway_uid, host, port,
log_queue, gateway_public_key_bytes, client_private_key):

# Setup custom logging to put logs into log_queue
Expand Down Expand Up @@ -1847,11 +1788,17 @@ async def connect(self, params, record_uid, convo_id, gateway_uid, host, port, l

# Set up the pc
print_ready_event = asyncio.Event()
pc = WebRTCConnection(endpoint_name=listener_name, print_ready_event=print_ready_event,
pc = WebRTCConnection(endpoint_name=convo_id, print_ready_event=print_ready_event,
username=response.username, password=response.password, logger=logger)

# make webRTC sdp offer
offer = await pc.make_offer()
try:
offer = await pc.make_offer()
except socket.gaierror:
print(f"{bcolors.WARNING}Please upgrade Commander to the latest version to use this feature...{bcolors.ENDC}")
return
except Exception as e:
raise CommandError('tunnel start', f'Error making WebRTC offer: {e}')
encrypted_offer = tunnel_encrypt(symmetric_key, offer)
logger.debug("-->. SEND START MESSAGE OVER REST TO GATEWAY")

Expand All @@ -1871,7 +1818,7 @@ async def connect(self, params, record_uid, convo_id, gateway_uid, host, port, l
# TODO create objects for WebRTC inputs
router_response = router_send_action_to_gateway(
params=params,
gateway_action=GatewayActionWebRTCSession(inputs={'listenerName': listener_name, "recordUid": record_uid,
gateway_action=GatewayActionWebRTCSession(inputs={'listenerName': convo_id, "recordUid": record_uid,
"offer": encrypted_offer, 'kind': 'start',
'conversationType': 'tunnel'}),
message_type=pam_pb2.CMT_GENERAL,
Expand Down Expand Up @@ -1900,18 +1847,16 @@ async def connect(self, params, record_uid, convo_id, gateway_uid, host, port, l

logger.debug("starting private tunnel")

private_tunnel = TunnelEntrance(host=host, port=port, endpoint_name=listener_name, pc=pc,
private_tunnel = TunnelEntrance(host=host, port=port, endpoint_name=convo_id, pc=pc,
print_ready_event=print_ready_event, logger=logger)

t1 = asyncio.create_task(private_tunnel.start_server())
params.tunnel_threads[convo_id].update({"print": t1, "entrance": private_tunnel})
params.tunnel_threads[convo_id].update({"server": t1, "entrance": private_tunnel})

logger.debug("--> START LISTENING FOR MESSAGES FROM GATEWAY --------")
await asyncio.gather(t1, private_tunnel.reader_task)

self.tunnel_cleanup(params, convo_id)

def pre_connect(self, params, record_uid, convo_id, gateway_uid, host, port, listener_name,
def pre_connect(self, params, record_uid, convo_id, gateway_uid, host, port,
gateway_public_key_bytes, client_private_key):
loop = None
try:
Expand All @@ -1920,30 +1865,59 @@ def pre_connect(self, params, record_uid, convo_id, gateway_uid, host, port, lis
asyncio.set_event_loop(loop)
output_queue = queue.Queue(maxsize=500)
params.tunnel_threads_queue[convo_id] = output_queue
loop.run_until_complete(
# Create a Task from the coroutine
connect_task = loop.create_task(
self.connect(
params=params,
record_uid=record_uid,
convo_id=convo_id,
gateway_uid=gateway_uid,
host=host,
port=port,
listener_name=listener_name,
log_queue=output_queue,
gateway_public_key_bytes=gateway_public_key_bytes,
client_private_key=client_private_key
)
)
except asyncio.CancelledError:
print(f"{bcolors.WARNING}Tasks for connection {convo_id} were cancelled.{bcolors.ENDC}")
params.tunnel_threads[convo_id].update({"connect_task": connect_task})
try:
# Run the task until it is complete
loop.run_until_complete(connect_task)
except asyncio.CancelledError:
pass
except SocketNotConnectedException as es:
print(f"{bcolors.FAIL}An exception occurred in pre_connect for connection {convo_id}: {es}{bcolors.ENDC}")
except Exception as e:
print(f"{bcolors.FAIL}An exception occurred in pre_connect for connection {convo_id}: {e}{bcolors.ENDC}")
finally:
if loop:
loop.call_soon_threadsafe(self.tunnel_cleanup, params, convo_id)
print(f"{bcolors.OKBLUE}Cleanup called for connection {convo_id}.{bcolors.ENDC}")
try:
tunnel_data = params.tunnel_threads.get(convo_id, None)
if not tunnel_data:
logging.debug(f"{bcolors.WARNING}No tunnel data found for {convo_id}{bcolors.ENDC}")
return

if convo_id in params.tunnel_threads_queue:
del params.tunnel_threads_queue[convo_id]

entrance = tunnel_data.get("entrance", None)
if entrance:
loop.run_until_complete(entrance.stop_server())

del params.tunnel_threads[convo_id]
logging.debug(f"Cleaned up data for {convo_id}")

try:
for task in asyncio.all_tasks(loop):
task.cancel()
loop.stop()
loop.close()
logging.debug(f"{convo_id} Loop cleaned up")
except Exception as e:
logging.debug(f"{bcolors.WARNING}Exception while stopping event loop: {e}{bcolors.ENDC}")
except Exception as e:
print(f"{bcolors.FAIL}An exception occurred in pre_connect for connection {convo_id}: {e}{bcolors.ENDC}")
print(f"{bcolors.OKBLUE}Tunnel {convo_id} closed.{bcolors.ENDC}")

def execute(self, params, **kwargs):
version = [3, 11, 0]
Expand All @@ -1963,7 +1937,6 @@ def execute(self, params, **kwargs):
gateway_uid = kwargs.get('gateway')
host = kwargs.get('host')
port = kwargs.get('port')
listener_name = kwargs.get('listener_name')

gateway_public_key_bytes = retrieve_gateway_public_key(gateway_uid, params, api, utils)

Expand Down Expand Up @@ -1999,10 +1972,11 @@ def execute(self, params, **kwargs):
client_private_key_value = client_private_key.get_default_value(str)

t = threading.Thread(target=self.pre_connect, args=(params, record_uid, convo_id, gateway_uid, host, port,
listener_name, gateway_public_key_bytes,
client_private_key_value)
gateway_public_key_bytes, client_private_key_value)
)

# Setting the thread as a daemon thread
t.daemon = True
t.start()
params.tunnel_threads[convo_id].update({"convo_id": convo_id, "thread": t, "host": host, "port": port,
"name": listener_name, "started": datetime.now(),
"record_uid": record_uid})
"started": datetime.now(), "record_uid": record_uid})
Loading

0 comments on commit 51bccea

Please sign in to comment.