Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions ec2instanceconnectcli/EC2InstanceConnectCLI.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,18 +81,18 @@ def handle_keys(self):
key_publisher.push_public_key(session, bundle['instance_id'], bundle['username'], self.pub_key, bundle['zone'])
self.logger.debug('Successfully pushed the public key to {0}'.format(bundle['instance_id']))

def run_command(self, command=None):
def run_command(self, args=None):
"""
Runs the given command in a sub-shell
:param command: Command to invoke
:type command: basestring
Runs the given command
:param args: Arguments to invoke
:type args: list of strings
:return: Return code for remote command
:rtype: int
"""
if not command:
if not args:
raise ValueError('Must provide a command')

invocation_proc = Popen(command, shell=True)
invocation_proc = Popen(args)
while invocation_proc.poll() is None: #sub-process not terminated
time.sleep(0.1)
return invocation_proc.returncode
Expand Down
11 changes: 5 additions & 6 deletions ec2instanceconnectcli/EC2InstanceConnectCommand.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,20 @@ def get_command(self):
Generates and returns the generated command
"""
# Start with protocol & identity file
command = '{0} -o "IdentitiesOnly=yes" -i {1}'.format(self.program, self.key_file)
command = [self.program, '-o', 'IdentitiesOnly=yes', '-i', self.key_file]

# Next add command flags if present
if len(self.flags) > 0:
command = "{0} {1}".format(command, self.flags)
command.extend(self.flags)

# Target
command = "{0} {1}".format(command, self._get_target(self.instance_bundles[0]))
command.append(self._get_target(self.instance_bundles[0]))

#program specific command
if len(self.program_command) > 0:
command = "{0} {1}".format(command, self.program_command)
command.append(self.program_command)

if len(self.instance_bundles) > 1:
command = "{0} {1}".format(command, self._get_target(self.instance_bundles[1]))
command.append(self._get_target(self.instance_bundles[1]))

self.logger.debug('Generated command: {0}'.format(command))

Expand Down
6 changes: 2 additions & 4 deletions ec2instanceconnectcli/input_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _parse_command_flags(raw_command, instance_bundles, is_ssh=False):
:return: tuple of flags and final comamnd or file list
:rtype: tuple
"""
flags = ''
flags = []
is_user = False
is_flagged = False
command_index = 0
Expand All @@ -133,7 +133,7 @@ def _parse_command_flags(raw_command, instance_bundles, is_ssh=False):
used += 1

# This is either a flag or a flag value
flags = '{0} {1}'.format(flags, raw_command[command_index])
flags.append(raw_command[command_index])

if raw_command[command_index][0] == '-':
# Flag
Expand All @@ -152,8 +152,6 @@ def _parse_command_flags(raw_command, instance_bundles, is_ssh=False):

command_index += 1

flags = flags.strip()

"""
Target host and command or file list
"""
Expand Down
49 changes: 25 additions & 24 deletions tests/test_EC2ConnectCLI.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_mssh_no_target(self,
mock_push_key,
mock_run):
mock_file = 'identity'
flag = '-f flag'
flags = ['-f', 'flag']
command = 'command arg'
logger = EC2InstanceConnectLogger()
instance_bundles = [{'username': self.default_user, 'instance_id': self.instance_id,
Expand All @@ -41,12 +41,12 @@ def test_mssh_no_target(self,
mock_instance_data.return_value = self.instance_info
mock_push_key.return_value = None

cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flag, command, logger.get_logger())
cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flags, command, logger.get_logger())
cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger())
cli.invoke_command()
expected_command = 'ssh -o "IdentitiesOnly=yes" -i {0} {1} {2}@{3} {4}'.format(mock_file, flag, self.default_user,
self.public_ip, command)

expected_command = ['ssh', '-o', 'IdentitiesOnly=yes', '-i', mock_file, *flags,
'{}@{}'.format(self.default_user, self.public_ip), command]

# Check that we successfully get to the run
self.assertTrue(mock_instance_data.called)
Expand All @@ -62,7 +62,7 @@ def test_mssh_no_target_no_public_ip(self,
mock_push_key,
mock_run):
mock_file = "identity"
flag = '-f flag'
flags = ['-f', 'flag']
command = 'command arg'
logger = EC2InstanceConnectLogger()
instance_bundles = [{'username': self.default_user, 'instance_id': self.instance_id,
Expand All @@ -72,12 +72,12 @@ def test_mssh_no_target_no_public_ip(self,
mock_instance_data.return_value = self.private_instance_info
mock_push_key.return_value = None

cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flag, command, logger.get_logger())
cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flags, command, logger.get_logger())
cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger())
cli.invoke_command()

expected_command = 'ssh -o "IdentitiesOnly=yes" -i {0} {1} {2}@{3} {4}'.format(mock_file, flag, self.default_user,
self.private_ip, command)
expected_command = ['ssh', '-o', 'IdentitiesOnly=yes', '-i', mock_file, *flags,
'{}@{}'.format(self.default_user, self.private_ip), command]

# Check that we successfully get to the run
self.assertTrue(mock_instance_data.called)
Expand All @@ -92,7 +92,7 @@ def test_mssh_with_target(self,
mock_push_key,
mock_run):
mock_file = 'identity'
flag = '-f flag'
flags = ['-f', 'flag']
command = 'command arg'
host = '0.0.0.0'
logger = EC2InstanceConnectLogger()
Expand All @@ -103,12 +103,12 @@ def test_mssh_with_target(self,
mock_instance_data.return_value = self.instance_info
mock_push_key.return_value = None

cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flag, command, logger.get_logger())
cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flags, command, logger.get_logger())
cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger())
cli.invoke_command()

expected_command = 'ssh -o "IdentitiesOnly=yes" -i {0} {1} {2}@{3} {4}'.format(mock_file, flag, self.default_user,
host, command)
expected_command = ['ssh', '-o', 'IdentitiesOnly=yes', '-i', mock_file, *flags,
'{}@{}'.format(self.default_user, host), command]
# Check that we successfully get to the run
# Since both target and availability_zone are provided, mock_instance_data should not be called
self.assertFalse(mock_instance_data.called)
Expand All @@ -123,7 +123,7 @@ def test_msftp(self,
mock_push_key,
mock_run):
mock_file = 'identity'
flag = '-f flag'
flags = ['-f', 'flag']
command = 'file2 file3'
logger = EC2InstanceConnectLogger()
instance_bundles = [{'username': self.default_user, 'instance_id': self.instance_id,
Expand All @@ -133,10 +133,11 @@ def test_msftp(self,
mock_instance_data.return_value = self.instance_info
mock_push_key.return_value = None

expected_command = 'sftp -o "IdentitiesOnly=yes" -i {0} {1} {2}@{3}:{4} {5}'.format(mock_file, flag, self.default_user,
self.public_ip, 'file1', command)
expected_command = ['sftp', '-o', 'IdentitiesOnly=yes', '-i', mock_file, *flags,
'{}@{}:{}'.format(self.default_user, self.public_ip, 'file1'),
command]

cli_command = EC2InstanceConnectCommand("sftp", instance_bundles, mock_file, flag, command, logger.get_logger())
cli_command = EC2InstanceConnectCommand("sftp", instance_bundles, mock_file, flags, command, logger.get_logger())
cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger())
cli.invoke_command()

Expand All @@ -153,7 +154,7 @@ def test_mscp(self,
mock_push_key,
mock_run):
mock_file = 'identity'
flag = '-f flag'
flags = ['-f', 'flag']
command = 'file2 file3'
logger = EC2InstanceConnectLogger()
instance_bundles = [{'username': self.default_user, 'instance_id': self.instance_id,
Expand All @@ -166,12 +167,12 @@ def test_mscp(self,
mock_instance_data.return_value = self.instance_info
mock_push_key.return_value = None

expected_command = 'scp -o "IdentitiesOnly=yes" -i {0} {1} {2}@{3}:{4} {5} {6}@{7}:{8}'.format(mock_file, flag, self.default_user,
self.public_ip, 'file1', command,
self.default_user,
self.public_ip, 'file4')
expected_command = ['scp', '-o', 'IdentitiesOnly=yes', '-i', mock_file, *flags,
'{}@{}:{}'.format(self.default_user, self.public_ip, 'file1'),
command,
'{}@{}:{}'.format(self.default_user, self.public_ip, 'file4')]

cli_command = EC2InstanceConnectCommand("scp", instance_bundles, mock_file, flag, command, logger.get_logger())
cli_command = EC2InstanceConnectCommand("scp", instance_bundles, mock_file, flags, command, logger.get_logger())
cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger())
cli.invoke_command()

Expand All @@ -183,5 +184,5 @@ def test_mscp(self,
def test_status_code(self):
#TODO: Refine test for checking run_command status code
cli = EC2InstanceConnectCLI(None, None, None, None)
code = cli.run_command("echo ok; exit -1;")
code = cli.run_command(["sh", "-c", "echo ok; exit -1;"])
self.assertEqual(code, 255)
12 changes: 6 additions & 6 deletions tests/test_input_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_basic_target(self):

self.assertEqual(bundles, [{'username': self.default_user, 'instance_id': self.instance_id,
'target': None, 'zone': None, 'region': None, 'profile': self.profile}])
self.assertEqual(flags, '')
self.assertEqual(flags, [])
self.assertEqual(command, '')

def test_username(self):
Expand All @@ -51,7 +51,7 @@ def test_username(self):

self.assertEqual(bundles, [{'username': 'myuser', 'instance_id': self.instance_id,
'target': None, 'zone': None, 'region': None, 'profile': self.profile}])
self.assertEqual(flags, '')
self.assertEqual(flags, [])
self.assertEqual(command, '')

def test_dns_name(self):
Expand All @@ -63,7 +63,7 @@ def test_dns_name(self):
self.assertEqual(bundles, [{'username': self.default_user, 'instance_id': self.instance_id,
'target': self.dns_name, 'zone': self.availability_zone,
'region': self.region, 'profile': self.profile}])
self.assertEqual(flags, '')
self.assertEqual(flags, [])
self.assertEqual(command, '')

def test_flags(self):
Expand All @@ -73,7 +73,7 @@ def test_flags(self):

self.assertEqual(bundles, [{'username': 'login', 'instance_id': self.instance_id,
'target': None, 'zone': None, 'region': None, 'profile': self.profile}])
self.assertEqual(flags, '-1 -l login')
self.assertEqual(flags, ['-1', '-l', 'login'])
self.assertEqual(command, '')

def test_command(self):
Expand All @@ -83,7 +83,7 @@ def test_command(self):

self.assertEqual(bundles, [{'username': self.default_user, 'instance_id': self.instance_id,
'target': None, 'zone': None, 'region': None, 'profile': self.profile}])
self.assertEqual(flags, '')
self.assertEqual(flags, [])
self.assertEqual(command, 'uname -a')

def test_sftp(self):
Expand All @@ -95,7 +95,7 @@ def test_sftp(self):
self.assertEqual(bundles, [{'username': self.default_user, 'instance_id': self.instance_id,
'target': None, 'zone': None, 'region': None, 'profile': self.profile,
'file': 'first_file'}])
self.assertEqual(flags, '')
self.assertEqual(flags, [])
self.assertEqual(command, 'second_file')

def test_invalid_username(self):
Expand Down