Skip to content

Commit

Permalink
Merge branch 'main' into refactor-exec-transport-command-method
Browse files Browse the repository at this point in the history
  • Loading branch information
beeankha authored Feb 17, 2025
2 parents f0da9b8 + f2bd35b commit b46d9dc
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 55 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
minor_changes:
- aws_ssm - Refactor ``_prepare_terminal()`` Method for Improved Clarity and Efficiency (https://github.com/ansible-collections/community.aws/pull/).
111 changes: 56 additions & 55 deletions plugins/connection/aws_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,8 @@
from dataclasses import dataclass
from typing import Dict
from typing import List
from typing import Optional
from typing import NoReturn
from typing import Optional
from typing import Tuple

try:
Expand Down Expand Up @@ -626,7 +626,7 @@ def instance_id(self) -> str:
return self._instance_id

@instance_id.setter
def instance_id(self, instance_id: str) -> NoReturn:
def instance_id(self, instance_id: str) -> None:
self._instance_id = instance_id

def start_session(self):
Expand Down Expand Up @@ -663,7 +663,7 @@ def start_session(self):
self._vvvv(f"SSM COMMAND: {to_text(cmd)}")

stdout_r, stdout_w = pty.openpty()
session = subprocess.Popen(
self._session = subprocess.Popen(
cmd,
stdin=subprocess.PIPE,
stdout=stdout_w,
Expand All @@ -674,14 +674,13 @@ def start_session(self):

os.close(stdout_w)
self._stdout = os.fdopen(stdout_r, "rb", 0)
self._session = session

# Disable command echo and prompt.
# For non-windows Hosts: Ensure the session has started, and disable command echo and prompt.
self._prepare_terminal()

self._vvvv(f"SSM CONNECTION ID: {self._session_id}")
self._vvvv(f"SSM CONNECTION ID: {self._session_id}") # pylint: disable=unreachable

return session
return self._session

def poll_stdout(self, timeout: int = 1000) -> bool:
"""Polls the stdout file descriptor.
Expand Down Expand Up @@ -784,72 +783,74 @@ def exec_command(self, cmd: str, in_data: bool = None, sudoable: bool = True) ->

return self.exec_communicate(cmd, mark_start, mark_begin, mark_end)

def _prepare_terminal(self):
"""perform any one-time terminal settings"""
# No windows setup for now
if self.is_windows:
return

# *_complete variables are 3 valued:
# - None: not started
# - False: started
# - True: complete

startup_complete = False
disable_echo_complete = None
disable_echo_cmd = to_bytes("stty -echo\n", errors="surrogate_or_strict")
def _ensure_ssm_session_has_started(self) -> None:
"""Ensure the SSM session has started on the host. We poll stdout
until we match the following string 'Starting session with SessionId'
"""
stdout = ""
for poll_result in self.poll("START SSM SESSION", "start_session"):
if poll_result:
stdout += to_text(self._stdout.read(1024))
self._vvvv(f"START SSM SESSION stdout line: \n{to_bytes(stdout)}")
match = str(stdout).find("Starting session with SessionId")
if match != -1:
self._vvvv("START SSM SESSION startup output received")
break

disable_prompt_complete = None
end_mark = self.generate_mark()
def _disable_prompt_command(self) -> None:
"""Disable prompt command from the host"""
end_mark = "".join([random.choice(string.ascii_letters) for i in range(self.MARK_LENGTH)])
disable_prompt_cmd = to_bytes(
"PS1='' ; bind 'set enable-bracketed-paste off'; printf '\\n%s\\n' '" + end_mark + "'\n",
errors="surrogate_or_strict",
)
disable_prompt_reply = re.compile(r"\r\r\n" + re.escape(end_mark) + r"\r\r\n", re.MULTILINE)

# Send command
self._vvvv(f"DISABLE PROMPT Disabling Prompt: \n{disable_prompt_cmd}")
self._session.stdin.write(disable_prompt_cmd)

stdout = ""
# Custom command execution for when we're waiting for startup
for poll_result in self.poll("PRE", "start_session"):
if disable_prompt_complete:
break
for poll_result in self.poll("DISABLE PROMPT", disable_prompt_cmd):
if poll_result:
stdout += to_text(self._stdout.read(1024))
self._vvvv(f"PRE stdout line: \n{to_bytes(stdout)}")
self._vvvv(f"DISABLE PROMPT stdout line: \n{to_bytes(stdout)}")
if disable_prompt_reply.search(stdout):
break

# wait til prompt is ready
if startup_complete is False:
match = str(stdout).find("Starting session with SessionId")
if match != -1:
self._vvvv("PRE startup output received")
startup_complete = True
def _disable_echo_command(self) -> None:
"""Disable echo command from the host"""
disable_echo_cmd = to_bytes("stty -echo\n", errors="surrogate_or_strict")

# disable echo
if startup_complete and (disable_echo_complete is None):
self._vvvv(f"PRE Disabling Echo: {disable_echo_cmd}")
self._session.stdin.write(disable_echo_cmd)
disable_echo_complete = False
# Send command
self._vvvv(f"DISABLE ECHO Disabling Prompt: \n{disable_echo_cmd}")
self._session.stdin.write(disable_echo_cmd)

if disable_echo_complete is False:
stdout = ""
for poll_result in self.poll("DISABLE ECHO", disable_echo_cmd):
if poll_result:
stdout += to_text(self._stdout.read(1024))
self._vvvv(f"DISABLE ECHO stdout line: \n{to_bytes(stdout)}")
match = str(stdout).find("stty -echo")
if match != -1:
disable_echo_complete = True
break

# disable prompt
if disable_echo_complete and disable_prompt_complete is None:
self._vvvv(f"PRE Disabling Prompt: \n{disable_prompt_cmd}")
self._session.stdin.write(disable_prompt_cmd)
disable_prompt_complete = False
def _prepare_terminal(self) -> None:
"""perform any one-time terminal settings"""
# No Windows setup for now
if self.is_windows:
return

if disable_prompt_complete is False:
match = disable_prompt_reply.search(stdout)
if match:
stdout = stdout[match.end():] # fmt: skip
disable_prompt_complete = True
# Ensure SSM Session has started
self._ensure_ssm_session_has_started()

# see https://github.com/pylint-dev/pylint/issues/8909)
if not disable_prompt_complete: # pylint: disable=unreachable
raise AnsibleConnectionFailure(f"SSM process closed during _prepare_terminal on host: {self.instance_id}")
self._vvvv("PRE Terminal configured")
# Disable echo command
self._disable_echo_command() # pylint: disable=unreachable

# Disable prompt command
self._disable_prompt_command() # pylint: disable=unreachable

self._vvvv("PRE Terminal configured") # pylint: disable=unreachable

def _wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str:
"""Wrap command so stdout and status can be extracted"""
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/plugins/connection/aws_ssm/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def connection_init(*args, **kwargs):
connection._session = MagicMock()
connection._session.poll = MagicMock()
connection._session.poll.side_effect = lambda: None
connection._session.stdin = MagicMock()
connection._session.stdin.write = MagicMock()
connection._stdout = MagicMock()
connection._flush_stderr = MagicMock()

Expand Down
122 changes: 122 additions & 0 deletions tests/unit/plugins/connection/aws_ssm/test_prepare_terminal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# -*- coding: utf-8 -*-

# This file is part of Ansible
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)

# While it may seem appropriate to import our custom fixtures here, the pytest_ansible pytest plugin
# isn't as agressive as the ansible_test._util.target.pytest.plugins.ansible_pytest_collections plugin
# when it comes to rewriting the import paths and as such we can't import fixtures via their
# absolute import path or across collections.


from unittest.mock import MagicMock
from unittest.mock import patch

import pytest

from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3

if not HAS_BOTO3:
pytestmark = pytest.mark.skip("test_poll.py requires the python modules 'boto3' and 'botocore'")


def poll_mock(x, y):
while poll_mock.results:
yield poll_mock.results.pop(0)
raise TimeoutError("-- poll_stdout_mock() --- Process has timeout...")


@pytest.mark.parametrize(
"stdout_lines,timeout_failure",
[
(["Starting ", "session ", "with SessionId"], False),
(["Starting session", " with SessionId"], False),
(["Init - Starting", " session", " with SessionId"], False),
(["Starting", " session", " with SessionId "], False),
(["Starting ", "session"], True),
(["Starting ", "session with Session"], True),
(["session ", "with SessionId"], True),
],
)
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_bytes")
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_text")
def test_ensure_ssm_session_has_started(m_to_text, m_to_bytes, connection_aws_ssm, stdout_lines, timeout_failure):
m_to_text.side_effect = str
m_to_bytes.side_effect = str
connection_aws_ssm._stdout.read = MagicMock()

connection_aws_ssm._stdout.read.side_effect = stdout_lines

poll_mock.results = [True for i in range(len(stdout_lines))]
connection_aws_ssm.poll = MagicMock()
connection_aws_ssm.poll.side_effect = poll_mock

if timeout_failure:
with pytest.raises(TimeoutError):
connection_aws_ssm._ensure_ssm_session_has_started()
else:
connection_aws_ssm._ensure_ssm_session_has_started()


@pytest.mark.parametrize(
"stdout_lines,timeout_failure",
[
(["stty -echo"], False),
(["stty ", "-echo"], False),
(["stty"], True),
(["stty ", "-ech"], True),
],
)
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_bytes")
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_text")
def test_disable_echo_command(m_to_text, m_to_bytes, connection_aws_ssm, stdout_lines, timeout_failure):
m_to_text.side_effect = str
m_to_bytes.side_effect = lambda x, **kw: str(x)
connection_aws_ssm._stdout.read = MagicMock()

connection_aws_ssm._stdout.read.side_effect = stdout_lines

poll_mock.results = [True for i in range(len(stdout_lines))]
connection_aws_ssm.poll = MagicMock()
connection_aws_ssm.poll.side_effect = poll_mock

if timeout_failure:
with pytest.raises(TimeoutError):
connection_aws_ssm._disable_echo_command()
else:
connection_aws_ssm._disable_echo_command()

connection_aws_ssm._session.stdin.write.assert_called_once_with("stty -echo\n")


@pytest.mark.parametrize("timeout_failure", [True, False])
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.random")
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_bytes")
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_text")
def test_disable_prompt_command(m_to_text, m_to_bytes, m_random, connection_aws_ssm, timeout_failure):
m_to_text.side_effect = str
m_to_bytes.side_effect = lambda x, **kw: str(x)
connection_aws_ssm._stdout.read = MagicMock()

connection_aws_ssm.poll = MagicMock()
connection_aws_ssm.poll.side_effect = poll_mock

m_random.choice = MagicMock()
m_random.choice.side_effect = lambda x: "a"

end_mark = "".join(["a" for i in range(connection_aws_ssm.MARK_LENGTH)])

connection_aws_ssm._stdout.read.return_value = (
f"\r\r\n{end_mark}\r\r\n" if not timeout_failure else "unmatching value"
)
poll_mock.results = [True]

prompt_cmd = f"PS1='' ; bind 'set enable-bracketed-paste off'; printf '\\n%s\\n' '{end_mark}'\n"

if timeout_failure:
with pytest.raises(TimeoutError):
connection_aws_ssm._disable_prompt_command()
else:
connection_aws_ssm._disable_prompt_command()

connection_aws_ssm._session.stdin.write.assert_called_once_with(prompt_cmd)

0 comments on commit b46d9dc

Please sign in to comment.