From f2bd35b1e5d9c333ae7f1d63f52ad97b6ad12466 Mon Sep 17 00:00:00 2001 From: Bikouo Aubin <79859644+abikouo@users.noreply.github.com> Date: Thu, 13 Feb 2025 17:33:06 +0100 Subject: [PATCH] aws_ssm - refactor _prepare_terminal() method (#2229) SUMMARY Refer to https://issues.redhat.com/browse/ACA-2094 Refactor _prepare_terminal() and add unit tests ISSUE TYPE Feature Pull Request COMPONENT NAME connection/aws_ssm Reviewed-by: Helen Bailey Reviewed-by: Alina Buzachis --- ...0204-aws_ssm-refactor-prepare_terminal.yml | 3 + plugins/connection/aws_ssm.py | 111 ++++++++-------- .../plugins/connection/aws_ssm/conftest.py | 2 + .../aws_ssm/test_prepare_terminal.py | 122 ++++++++++++++++++ 4 files changed, 183 insertions(+), 55 deletions(-) create mode 100644 changelogs/fragments/20250204-aws_ssm-refactor-prepare_terminal.yml create mode 100644 tests/unit/plugins/connection/aws_ssm/test_prepare_terminal.py diff --git a/changelogs/fragments/20250204-aws_ssm-refactor-prepare_terminal.yml b/changelogs/fragments/20250204-aws_ssm-refactor-prepare_terminal.yml new file mode 100644 index 00000000000..2da558ac08d --- /dev/null +++ b/changelogs/fragments/20250204-aws_ssm-refactor-prepare_terminal.yml @@ -0,0 +1,3 @@ +--- +minor_changes: + - aws_ssm - Refactor ``_prepare_terminal()`` Method for Improved Clarity and Efficiency (https://github.com/ansible-collections/community.aws/pull/). diff --git a/plugins/connection/aws_ssm.py b/plugins/connection/aws_ssm.py index 8300a8f8084..a743bc83553 100644 --- a/plugins/connection/aws_ssm.py +++ b/plugins/connection/aws_ssm.py @@ -327,8 +327,8 @@ import string import subprocess import time -from typing import Optional from typing import NoReturn +from typing import Optional from typing import Tuple try: @@ -609,7 +609,7 @@ def instance_id(self) -> str: return self._instance_id @instance_id.setter - def instance_id(self, instance_id: str) -> NoReturn: + def instance_id(self, instance_id: str) -> None: self._instance_id = instance_id def start_session(self): @@ -646,7 +646,7 @@ def start_session(self): self._vvvv(f"SSM COMMAND: {to_text(cmd)}") stdout_r, stdout_w = pty.openpty() - session = subprocess.Popen( + self._session = subprocess.Popen( cmd, stdin=subprocess.PIPE, stdout=stdout_w, @@ -657,14 +657,13 @@ def start_session(self): os.close(stdout_w) self._stdout = os.fdopen(stdout_r, "rb", 0) - self._session = session - # Disable command echo and prompt. + # For non-windows Hosts: Ensure the session has started, and disable command echo and prompt. self._prepare_terminal() - self._vvvv(f"SSM CONNECTION ID: {self._session_id}") + self._vvvv(f"SSM CONNECTION ID: {self._session_id}") # pylint: disable=unreachable - return session + return self._session def poll_stdout(self, timeout: int = 1000) -> bool: """Polls the stdout file descriptor. @@ -767,72 +766,74 @@ def exec_command(self, cmd: str, in_data: bool = None, sudoable: bool = True) -> return self.exec_communicate(cmd, mark_start, mark_begin, mark_end) - def _prepare_terminal(self): - """perform any one-time terminal settings""" - # No windows setup for now - if self.is_windows: - return - - # *_complete variables are 3 valued: - # - None: not started - # - False: started - # - True: complete - - startup_complete = False - disable_echo_complete = None - disable_echo_cmd = to_bytes("stty -echo\n", errors="surrogate_or_strict") + def _ensure_ssm_session_has_started(self) -> None: + """Ensure the SSM session has started on the host. We poll stdout + until we match the following string 'Starting session with SessionId' + """ + stdout = "" + for poll_result in self.poll("START SSM SESSION", "start_session"): + if poll_result: + stdout += to_text(self._stdout.read(1024)) + self._vvvv(f"START SSM SESSION stdout line: \n{to_bytes(stdout)}") + match = str(stdout).find("Starting session with SessionId") + if match != -1: + self._vvvv("START SSM SESSION startup output received") + break - disable_prompt_complete = None - end_mark = self.generate_mark() + def _disable_prompt_command(self) -> None: + """Disable prompt command from the host""" + end_mark = "".join([random.choice(string.ascii_letters) for i in range(self.MARK_LENGTH)]) disable_prompt_cmd = to_bytes( "PS1='' ; bind 'set enable-bracketed-paste off'; printf '\\n%s\\n' '" + end_mark + "'\n", errors="surrogate_or_strict", ) disable_prompt_reply = re.compile(r"\r\r\n" + re.escape(end_mark) + r"\r\r\n", re.MULTILINE) + # Send command + self._vvvv(f"DISABLE PROMPT Disabling Prompt: \n{disable_prompt_cmd}") + self._session.stdin.write(disable_prompt_cmd) + stdout = "" - # Custom command execution for when we're waiting for startup - for poll_result in self.poll("PRE", "start_session"): - if disable_prompt_complete: - break + for poll_result in self.poll("DISABLE PROMPT", disable_prompt_cmd): if poll_result: stdout += to_text(self._stdout.read(1024)) - self._vvvv(f"PRE stdout line: \n{to_bytes(stdout)}") + self._vvvv(f"DISABLE PROMPT stdout line: \n{to_bytes(stdout)}") + if disable_prompt_reply.search(stdout): + break - # wait til prompt is ready - if startup_complete is False: - match = str(stdout).find("Starting session with SessionId") - if match != -1: - self._vvvv("PRE startup output received") - startup_complete = True + def _disable_echo_command(self) -> None: + """Disable echo command from the host""" + disable_echo_cmd = to_bytes("stty -echo\n", errors="surrogate_or_strict") - # disable echo - if startup_complete and (disable_echo_complete is None): - self._vvvv(f"PRE Disabling Echo: {disable_echo_cmd}") - self._session.stdin.write(disable_echo_cmd) - disable_echo_complete = False + # Send command + self._vvvv(f"DISABLE ECHO Disabling Prompt: \n{disable_echo_cmd}") + self._session.stdin.write(disable_echo_cmd) - if disable_echo_complete is False: + stdout = "" + for poll_result in self.poll("DISABLE ECHO", disable_echo_cmd): + if poll_result: + stdout += to_text(self._stdout.read(1024)) + self._vvvv(f"DISABLE ECHO stdout line: \n{to_bytes(stdout)}") match = str(stdout).find("stty -echo") if match != -1: - disable_echo_complete = True + break - # disable prompt - if disable_echo_complete and disable_prompt_complete is None: - self._vvvv(f"PRE Disabling Prompt: \n{disable_prompt_cmd}") - self._session.stdin.write(disable_prompt_cmd) - disable_prompt_complete = False + def _prepare_terminal(self) -> None: + """perform any one-time terminal settings""" + # No Windows setup for now + if self.is_windows: + return - if disable_prompt_complete is False: - match = disable_prompt_reply.search(stdout) - if match: - stdout = stdout[match.end():] # fmt: skip - disable_prompt_complete = True + # Ensure SSM Session has started + self._ensure_ssm_session_has_started() - # see https://github.com/pylint-dev/pylint/issues/8909) - if not disable_prompt_complete: # pylint: disable=unreachable - raise AnsibleConnectionFailure(f"SSM process closed during _prepare_terminal on host: {self.instance_id}") - self._vvvv("PRE Terminal configured") + # Disable echo command + self._disable_echo_command() # pylint: disable=unreachable + + # Disable prompt command + self._disable_prompt_command() # pylint: disable=unreachable + + self._vvvv("PRE Terminal configured") # pylint: disable=unreachable def _wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str: """Wrap command so stdout and status can be extracted""" diff --git a/tests/unit/plugins/connection/aws_ssm/conftest.py b/tests/unit/plugins/connection/aws_ssm/conftest.py index 3d1b2c874fa..a35feb06e2f 100644 --- a/tests/unit/plugins/connection/aws_ssm/conftest.py +++ b/tests/unit/plugins/connection/aws_ssm/conftest.py @@ -32,6 +32,8 @@ def connection_init(*args, **kwargs): connection._session = MagicMock() connection._session.poll = MagicMock() connection._session.poll.side_effect = lambda: None + connection._session.stdin = MagicMock() + connection._session.stdin.write = MagicMock() connection._stdout = MagicMock() connection._flush_stderr = MagicMock() diff --git a/tests/unit/plugins/connection/aws_ssm/test_prepare_terminal.py b/tests/unit/plugins/connection/aws_ssm/test_prepare_terminal.py new file mode 100644 index 00000000000..fe6f2361402 --- /dev/null +++ b/tests/unit/plugins/connection/aws_ssm/test_prepare_terminal.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- + +# This file is part of Ansible +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# While it may seem appropriate to import our custom fixtures here, the pytest_ansible pytest plugin +# isn't as agressive as the ansible_test._util.target.pytest.plugins.ansible_pytest_collections plugin +# when it comes to rewriting the import paths and as such we can't import fixtures via their +# absolute import path or across collections. + + +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3 + +if not HAS_BOTO3: + pytestmark = pytest.mark.skip("test_poll.py requires the python modules 'boto3' and 'botocore'") + + +def poll_mock(x, y): + while poll_mock.results: + yield poll_mock.results.pop(0) + raise TimeoutError("-- poll_stdout_mock() --- Process has timeout...") + + +@pytest.mark.parametrize( + "stdout_lines,timeout_failure", + [ + (["Starting ", "session ", "with SessionId"], False), + (["Starting session", " with SessionId"], False), + (["Init - Starting", " session", " with SessionId"], False), + (["Starting", " session", " with SessionId "], False), + (["Starting ", "session"], True), + (["Starting ", "session with Session"], True), + (["session ", "with SessionId"], True), + ], +) +@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_bytes") +@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_text") +def test_ensure_ssm_session_has_started(m_to_text, m_to_bytes, connection_aws_ssm, stdout_lines, timeout_failure): + m_to_text.side_effect = str + m_to_bytes.side_effect = str + connection_aws_ssm._stdout.read = MagicMock() + + connection_aws_ssm._stdout.read.side_effect = stdout_lines + + poll_mock.results = [True for i in range(len(stdout_lines))] + connection_aws_ssm.poll = MagicMock() + connection_aws_ssm.poll.side_effect = poll_mock + + if timeout_failure: + with pytest.raises(TimeoutError): + connection_aws_ssm._ensure_ssm_session_has_started() + else: + connection_aws_ssm._ensure_ssm_session_has_started() + + +@pytest.mark.parametrize( + "stdout_lines,timeout_failure", + [ + (["stty -echo"], False), + (["stty ", "-echo"], False), + (["stty"], True), + (["stty ", "-ech"], True), + ], +) +@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_bytes") +@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_text") +def test_disable_echo_command(m_to_text, m_to_bytes, connection_aws_ssm, stdout_lines, timeout_failure): + m_to_text.side_effect = str + m_to_bytes.side_effect = lambda x, **kw: str(x) + connection_aws_ssm._stdout.read = MagicMock() + + connection_aws_ssm._stdout.read.side_effect = stdout_lines + + poll_mock.results = [True for i in range(len(stdout_lines))] + connection_aws_ssm.poll = MagicMock() + connection_aws_ssm.poll.side_effect = poll_mock + + if timeout_failure: + with pytest.raises(TimeoutError): + connection_aws_ssm._disable_echo_command() + else: + connection_aws_ssm._disable_echo_command() + + connection_aws_ssm._session.stdin.write.assert_called_once_with("stty -echo\n") + + +@pytest.mark.parametrize("timeout_failure", [True, False]) +@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.random") +@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_bytes") +@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_text") +def test_disable_prompt_command(m_to_text, m_to_bytes, m_random, connection_aws_ssm, timeout_failure): + m_to_text.side_effect = str + m_to_bytes.side_effect = lambda x, **kw: str(x) + connection_aws_ssm._stdout.read = MagicMock() + + connection_aws_ssm.poll = MagicMock() + connection_aws_ssm.poll.side_effect = poll_mock + + m_random.choice = MagicMock() + m_random.choice.side_effect = lambda x: "a" + + end_mark = "".join(["a" for i in range(connection_aws_ssm.MARK_LENGTH)]) + + connection_aws_ssm._stdout.read.return_value = ( + f"\r\r\n{end_mark}\r\r\n" if not timeout_failure else "unmatching value" + ) + poll_mock.results = [True] + + prompt_cmd = f"PS1='' ; bind 'set enable-bracketed-paste off'; printf '\\n%s\\n' '{end_mark}'\n" + + if timeout_failure: + with pytest.raises(TimeoutError): + connection_aws_ssm._disable_prompt_command() + else: + connection_aws_ssm._disable_prompt_command() + + connection_aws_ssm._session.stdin.write.assert_called_once_with(prompt_cmd)