Skip to content

Commit

Permalink
Remove custom Command dataclass, misc. edits to eliminate TypeErrors
Browse files Browse the repository at this point in the history
  • Loading branch information
beeankha committed Feb 25, 2025
1 parent f0a88c4 commit 02cf388
Showing 1 changed file with 25 additions and 36 deletions.
61 changes: 25 additions & 36 deletions plugins/connection/aws_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,17 +436,6 @@ def filter_ansi(line: str, is_windows: bool) -> str:
return line


@dataclass
class Command:
"""
Custom dataclass for the generated command dictionaries.
"""

command: str
method: str # 'get' or 'put'
headers: Dict[str, str]


@dataclass
class CommandResult:
"""
Expand Down Expand Up @@ -985,7 +974,7 @@ def _generate_commands(
s3_path: str,
in_path: str,
out_path: str,
) -> Tuple[List[Command], Optional[Dict]]:
) -> Tuple[List[Dict], dict]:
"""
Generate commands for the specified bucket, S3 path, input path, and output path.
Expand All @@ -995,7 +984,7 @@ def _generate_commands(
:param out_path: Output path
:param method: The request method to use for the command (can be "get" or "put").
:returns: List of Command dictionaries containing the command string and metadata.
:returns: A tuple containing a list of command dictionaries along with any ``put_args`` dictionaries.
"""

put_args, put_headers = self._generate_encryption_settings()
Expand Down Expand Up @@ -1029,7 +1018,7 @@ def _generate_commands(
),
# The "method" key indicates to _file_transport_command which commands are put_commands
"method": "put",
"headers": put_headers
"headers": put_headers,
}) # fmt: skip
else:
put_command_headers = " ".join([f"-H '{h}: {v}'" for h, v in put_headers.items()])
Expand All @@ -1055,6 +1044,8 @@ def _generate_commands(
"touch "
f"'{out_path}'"
),
"method": "get",
"headers": {},
}) # fmt: skip
commands.append({
"command":
Expand All @@ -1066,12 +1057,12 @@ def _generate_commands(
),
# The "method" key indicates to _file_transport_command which commands are put_commands
"method": "put",
"headers": put_headers
"headers": put_headers,
}) # fmt: skip

return commands, put_args

def _exec_transport_commands(self, in_path: str, out_path: str, commands: List[Command]) -> CommandResult:
def _exec_transport_commands(self, in_path: str, out_path: str, commands: List[dict]) -> Tuple[int, str, str]:
"""
Execute the provided transport commands.
Expand All @@ -1084,10 +1075,13 @@ def _exec_transport_commands(self, in_path: str, out_path: str, commands: List[C

stdout_combined, stderr_combined = "", ""
for command in commands:
(returncode, stdout, stderr) = self.exec_command(command["command"], in_data=None, sudoable=False)
result = self.exec_command(command["command"], in_data=None, sudoable=False)

returncode = result[0]
stdout, stderr = result[1], result[2]

# Check the return code
if returncode != 0:
if result[0] != 0:
raise AnsibleError(f"failed to transfer file to {in_path} {out_path}:\n{stdout}\n{stderr}")

stdout_combined += stdout
Expand All @@ -1101,7 +1095,7 @@ def _file_transport_command(
in_path: str,
out_path: str,
ssm_action: str,
) -> CommandResult:
) -> Tuple[int, str, str]:
"""
Transfer file(s) to/from host using an intermediate S3 bucket and then delete the file(s).
Expand All @@ -1117,30 +1111,25 @@ def _file_transport_command(

client = self._s3_client

commands, put_args = self._generate_commands(
bucket_name,
s3_path,
in_path,
out_path,
)

try:
if ssm_action == "get":
put_commands, put_args = self._generate_commands(
bucket_name,
s3_path,
in_path,
out_path,
)
put_commands = [cmd["command"] for cmd in put_commands if cmd.get("method") == "put"]
(returncode, stdout, stderr) = self._exec_transport_commands(in_path, out_path, put_commands)
put_commands = [cmd for cmd in commands if cmd.get("method") == "put"]
result = self._exec_transport_commands(in_path, out_path, put_commands)
with open(to_bytes(out_path, errors="surrogate_or_strict"), "wb") as data:
client.download_fileobj(bucket_name, s3_path, data)
else:
get_commands, put_args = self._generate_commands(
bucket_name,
s3_path,
in_path,
out_path,
)
get_commands = [cmd["command"] for cmd in get_commands if cmd.get("method") == "get"]
get_commands = [cmd for cmd in commands if cmd.get("method") == "get"]
with open(to_bytes(in_path, errors="surrogate_or_strict"), "rb") as data:
client.upload_fileobj(data, bucket_name, s3_path, ExtraArgs=put_args)
(returncode, stdout, stderr) = self._exec_transport_commands(in_path, out_path, get_commands)
return CommandResult(returncode, stdout, stderr)
result = self._exec_transport_commands(in_path, out_path, get_commands)
return CommandResult(result)
finally:
# Remove the files from the bucket after they've been transferred
client.delete_object(Bucket=bucket_name, Key=s3_path)
Expand Down

0 comments on commit 02cf388

Please sign in to comment.