Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Some Command-Related Methods in aws_ssm.py #2248

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
minor_changes:
- aws_ssm - Refactor ``_exec_transport_commands``, ``_generate_commands``, and ``_exec_transport_commands`` methods for improved clarity (https://github.com/ansible-collections/community.aws/pull/2248).
153 changes: 116 additions & 37 deletions plugins/connection/aws_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,9 @@
import string
import subprocess
import time
from dataclasses import dataclass
from typing import Dict
from typing import List
from typing import NoReturn
from typing import Optional
from typing import Tuple
Expand Down Expand Up @@ -433,6 +436,17 @@ def filter_ansi(line: str, is_windows: bool) -> str:
return line


@dataclass
class CommandResult:
"""
Custom dataclass for the executed command results.
"""

returncode: int
stdout_combined: str
stderr_combined: str


class Connection(ConnectionBase):
"""AWS SSM based connections"""

Expand Down Expand Up @@ -954,15 +968,46 @@ def _generate_encryption_settings(self):
put_headers["x-amz-server-side-encryption-aws-kms-key-id"] = self.get_option("bucket_sse_kms_key_id")
return put_args, put_headers

def _generate_commands(self, bucket_name, s3_path, in_path, out_path):
def _generate_commands(
self,
bucket_name: str,
s3_path: str,
in_path: str,
out_path: str,
) -> Tuple[List[Dict], dict]:
"""
Generate commands for the specified bucket, S3 path, input path, and output path.

:param bucket_name: The name of the S3 bucket used for file transfers.
:param s3_path: The S3 path to the file to be sent.
:param in_path: Input path
:param out_path: Output path
:param method: The request method to use for the command (can be "get" or "put").

:returns: A tuple containing a list of command dictionaries along with any ``put_args`` dictionaries.
"""

put_args, put_headers = self._generate_encryption_settings()
commands = []

put_url = self._get_url("put_object", bucket_name, s3_path, "PUT", extra_args=put_args)
get_url = self._get_url("get_object", bucket_name, s3_path, "GET")

if self.is_windows:
put_command_headers = "; ".join([f"'{h}' = '{v}'" for h, v in put_headers.items()])
put_commands = [
commands.append({
"command":
(
"Invoke-WebRequest "
f"'{get_url}' "
f"-OutFile '{out_path}'"
),
# The "method" key indicates to _file_transport_command which commands are get_commands
"method": "get",
"headers": {},
}) # fmt: skip
commands.append({
"command":
(
"Invoke-WebRequest -Method PUT "
# @{'key' = 'value'; 'key2' = 'value2'}
Expand All @@ -971,47 +1016,66 @@ def _generate_commands(self, bucket_name, s3_path, in_path, out_path):
f"-Uri '{put_url}' "
f"-UseBasicParsing"
),
] # fmt: skip
get_commands = [
(
"Invoke-WebRequest "
f"'{get_url}' "
f"-OutFile '{out_path}'"
),
] # fmt: skip
# The "method" key indicates to _file_transport_command which commands are put_commands
"method": "put",
"headers": put_headers,
}) # fmt: skip
else:
put_command_headers = " ".join([f"-H '{h}: {v}'" for h, v in put_headers.items()])
put_commands = [
(
"curl --request PUT "
f"{put_command_headers} "
f"--upload-file '{in_path}' "
f"'{put_url}'"
),
] # fmt: skip
get_commands = [
commands.append({
"command":
(
"curl "
f"-o '{out_path}' "
f"'{get_url}'"
),
# Due to https://github.com/curl/curl/issues/183 earlier
# versions of curl did not create the output file, when the
# response was empty. Although this issue was fixed in 2015,
# some actively maintained operating systems still use older
# versions of it (e.g. CentOS 7)
# The "method" key indicates to _file_transport_command which commands are get_commands
"method": "get",
"headers": {},
}) # fmt: skip
# Due to https://github.com/curl/curl/issues/183 earlier
# versions of curl did not create the output file, when the
# response was empty. Although this issue was fixed in 2015,
# some actively maintained operating systems still use older
# versions of it (e.g. CentOS 7)
commands.append({
"command":
(
"touch "
f"'{out_path}'"
)
] # fmt: skip
),
"method": "get",
"headers": {},
}) # fmt: skip
commands.append({
"command":
(
"curl --request PUT "
f"{put_command_headers} "
f"--upload-file '{in_path}' "
f"'{put_url}'"
),
# The "method" key indicates to _file_transport_command which commands are put_commands
"method": "put",
"headers": put_headers,
}) # fmt: skip

return commands, put_args

return get_commands, put_commands, put_args
def _exec_transport_commands(self, in_path: str, out_path: str, commands: List[dict]) -> CommandResult:
"""
Execute the provided transport commands.

:param in_path: The input path.
:param out_path: The output path.
:param commands: A list of command dictionaries containing the command string and metadata.

:returns: A tuple containing the return code, stdout, and stderr.
"""

def _exec_transport_commands(self, in_path, out_path, commands):
stdout_combined, stderr_combined = "", ""
for command in commands:
(returncode, stdout, stderr) = self.exec_command(command, in_data=None, sudoable=False)
(returncode, stdout, stderr) = self.exec_command(command["command"], in_data=None, sudoable=False)

# Check the return code
if returncode != 0:
Expand All @@ -1023,31 +1087,46 @@ def _exec_transport_commands(self, in_path, out_path, commands):
return (returncode, stdout_combined, stderr_combined)

@_ssm_retry
def _file_transport_command(self, in_path, out_path, ssm_action):
"""transfer a file to/from host using an intermediate S3 bucket"""
def _file_transport_command(
self,
in_path: str,
out_path: str,
ssm_action: str,
) -> CommandResult:
"""
Transfer file(s) to/from host using an intermediate S3 bucket and then delete the file(s).

:param in_path: The input path.
:param out_path: The output path.
:param ssm_action: The SSM action to perform ("get" or "put").

:returns: The command's return code, stdout, and stderr in a tuple.
"""

bucket_name = self.get_option("bucket_name")
s3_path = self._escape_path(f"{self.instance_id}/{out_path}")

get_commands, put_commands, put_args = self._generate_commands(
client = self._s3_client

commands, put_args = self._generate_commands(
bucket_name,
s3_path,
in_path,
out_path,
)

client = self._s3_client

try:
if ssm_action == "get":
(returncode, stdout, stderr) = self._exec_transport_commands(in_path, out_path, put_commands)
put_commands = [cmd for cmd in commands if cmd.get("method") == "put"]
result = self._exec_transport_commands(in_path, out_path, put_commands)
with open(to_bytes(out_path, errors="surrogate_or_strict"), "wb") as data:
client.download_fileobj(bucket_name, s3_path, data)
else:
get_commands = [cmd for cmd in commands if cmd.get("method") == "get"]
with open(to_bytes(in_path, errors="surrogate_or_strict"), "rb") as data:
client.upload_fileobj(data, bucket_name, s3_path, ExtraArgs=put_args)
(returncode, stdout, stderr) = self._exec_transport_commands(in_path, out_path, get_commands)
return (returncode, stdout, stderr)
result = self._exec_transport_commands(in_path, out_path, get_commands)
return result
finally:
# Remove the files from the bucket after they've been transferred
client.delete_object(Bucket=bucket_name, Key=s3_path)
Expand Down
70 changes: 70 additions & 0 deletions tests/unit/plugins/connection/aws_ssm/test_aws_ssm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)

import sys
from io import StringIO
from unittest.mock import MagicMock
from unittest.mock import patch
Expand Down Expand Up @@ -268,3 +269,72 @@ def test_generate_mark(self):
assert test_a != test_b
assert len(test_a) == Connection.MARK_LENGTH
assert len(test_b) == Connection.MARK_LENGTH

@pytest.mark.skipif(sys.platform == "win", reason="This test is only for non-Windows systems")
def test_generate_commands_non_windows(self):
"""Testing command generation on non-Windows systems"""
pc = PlayContext()
new_stdin = StringIO()
conn = connection_loader.get("community.aws.aws_ssm", pc, new_stdin)
conn.get_option = MagicMock()

mock_s3_client = MagicMock()
mock_s3_client.generate_presigned_url.return_value = "https://test-url"
conn._s3_client = mock_s3_client

test_command_generation = conn._generate_commands(
"test_bucket",
"test/s3/path",
"test/in/path",
"test/out/path",
)

# Ensure data types of command object are as expected
assert isinstance(test_command_generation, tuple)
assert isinstance(test_command_generation[0], list)
assert isinstance(test_command_generation[0][0], dict)

# Three command dictionaries are generated for non-Windows systems
assert len(test_command_generation[0]) == 3

# Check contents of command dictionaries
assert "command" in test_command_generation[0][0]
assert "method" in test_command_generation[0][2]
assert "headers" in test_command_generation[0][2]
assert "curl --request PUT -H" in test_command_generation[0][2]["command"]
assert test_command_generation[0][2]["method"] == "put"

@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.Connection.is_windows")
def test_generate_commands_windows(self, mock_is_windows):
"""Testing command generation on Windows systems"""
pc = PlayContext()
new_stdin = StringIO()
conn = connection_loader.get("community.aws.aws_ssm", pc, new_stdin)

mock_is_windows.return_value = True

mock_s3_client = MagicMock()
mock_s3_client.generate_presigned_url.return_value = "https://test-url"
conn._s3_client = mock_s3_client

test_command_generation = conn._generate_commands(
"test_bucket",
"test/s3/path",
"test/in/path",
"test/out/path",
)

# Ensure data types of command object are as expected
assert isinstance(test_command_generation, tuple)
assert isinstance(test_command_generation[0], list)
assert isinstance(test_command_generation[0][0], dict)

# Two command dictionaries are generated for Windows
assert len(test_command_generation[0]) == 2

# Check contents of command dictionaries
assert "command" in test_command_generation[0][0]
assert "method" in test_command_generation[0][1]
assert "headers" in test_command_generation[0][1]
assert "Invoke-WebRequest" in test_command_generation[0][1]["command"]
assert test_command_generation[0][1]["method"] == "put"
Loading