From 83ac041181aad2b803de027730360534e4e2e9d5 Mon Sep 17 00:00:00 2001
From: Andrey Kislyuk
Date: Sun, 16 Jun 2024 15:08:00 -0700
Subject: [PATCH] Fix mypy errors, ruff format

---
 Makefile                                      |   6 +-
 aegea/__init__.py                             |  87 +++--
 aegea/aegea_config.py                         |  11 +
 aegea/alarms.py                               |  35 +-
 aegea/audit.py                                | 138 +++++---
 aegea/batch.py                                | 325 ++++++++++++------
 aegea/billing.py                              |  53 ++-
 aegea/build_ami.py                            |  48 ++-
 aegea/build_docker_image.py                   |  79 +++--
 aegea/cloudtrail.py                           |  14 +-
 aegea/cost.py                                 | 121 +++++--
 aegea/ebs.py                                  |  75 +++-
 aegea/ecr.py                                  |  17 +-
 aegea/ecs.py                                  | 163 +++++----
 aegea/efs.py                                  |  26 +-
 aegea/elb.py                                  | 110 ++++--
 aegea/flow_logs.py                            |  25 +-
 aegea/iam.py                                  |  29 +-
 aegea/instance_ctl.py                         |  11 +-
 aegea/lambda.py                               |  34 +-
 aegea/launch.py                               | 235 +++++++++----
 aegea/logs.py                                 |  43 ++-
 aegea/ls.py                                   |  73 +++-
 aegea/missions/arvados-worker/Makefile        |  11 -
 aegea/missions/arvados-worker/config.yml      |  16 -
 aegea/missions/arvados-worker/environment     |   1 -
 .../rootfs.skel.in/etc/default/munge          |   1 -
 .../rootfs.skel.in/etc/munge/munge.key        |   1 -
 .../rootfs.skel.in/etc/slurm-llnl/nodes.conf  |   1 -
 .../rootfs.skel.in/etc/slurm-llnl/slurm.conf  |   1 -
 .../usr/bin/aegea-set-slurm-nodes             |   1 -
 aegea/pricing.py                              |  32 +-
 aegea/rds.py                                  |  70 ++--
 aegea/rm.py                                   |  37 +-
 aegea/rootfs.skel.build_ami/root/.aws/config  |   7 +-
 aegea/secrets.py                              |  61 +++-
 aegea/sfn.py                                  |  58 +++-
 aegea/ssh.py                                  | 156 ++++++---
 aegea/top.py                                  |   4 +-
 aegea/util/__init__.py                        |  22 +-
 aegea/util/aws/__init__.py                    |  88 ++++-
 aegea/util/aws/_boto3_loader.py               |   8 +-
 aegea/util/aws/batch.py                       | 127 ++++---
 aegea/util/aws/batch_events_lambda/app.py     |   2 +
 aegea/util/aws/dns.py                         |  32 +-
 aegea/util/aws/iam.py                         |  35 +-
 aegea/util/aws/logs.py                        |  47 ++-
 aegea/util/aws/spot.py                        |  42 ++-
 aegea/util/cloudinit.py                       |  51 ++-
 aegea/util/constants.py                       |   3 +
 aegea/util/crypto.py                          |  18 +-
 aegea/util/exceptions.py                      |   1 +
 aegea/util/printing.py                        |  57 ++-
 aegea/zones.py                                |   7 +
 mypy.ini                                      |  20 --
 pyproject.toml                                |  20 +-
 scripts/aegea                                 |   7 +-
 scripts/aegea-build-image-for-mission         |   9 +-
 scripts/aegea-rebuild-public-elb-sg           |  19 +-
 scripts/aegea-ssh                             |   7 +-
 scripts/pypi-apt-freeze                       |  11 +-
 61 files changed, 1972 insertions(+), 877 deletions(-)
 delete mode 100644 aegea/missions/arvados-worker/Makefile
 delete mode 100644 aegea/missions/arvados-worker/config.yml
 delete mode 120000 aegea/missions/arvados-worker/environment
 delete mode 120000 aegea/missions/arvados-worker/rootfs.skel.in/etc/default/munge
 delete mode 120000 aegea/missions/arvados-worker/rootfs.skel.in/etc/munge/munge.key
 delete mode 120000 aegea/missions/arvados-worker/rootfs.skel.in/etc/slurm-llnl/nodes.conf
 delete mode 120000 aegea/missions/arvados-worker/rootfs.skel.in/etc/slurm-llnl/slurm.conf
 delete mode 120000 aegea/missions/arvados-worker/rootfs.skel.in/usr/bin/aegea-set-slurm-nodes
 mode change 120000 => 100644 aegea/rootfs.skel.build_ami/root/.aws/config
 delete mode 100644 mypy.ini

diff --git a/Makefile b/Makefile
index ba470f55..ffae856d 100644
--- a/Makefile
+++ b/Makefile
@@ -6,9 +6,9 @@ aegea/constants.json:
 	python3 -c "import aegea; aegea.initialize(); from aegea.util.constants import write; write()"
 
 lint:
-	for dir in $$(dirname */__init__.py); do ruff $$dir; done
-	for script in $$(grep -r -l '/usr/bin/env python3' aegea/missions aegea/rootfs.skel scripts); do ruff $$script; done
-	mypy --check-untyped-defs --no-strict-optional $$(python3 setup.py --name)
+	for dir in $$(dirname */__init__.py); do ruff check $$dir; done
+	for script in $$(grep -r -l '/usr/bin/env python3' aegea/missions aegea/rootfs.skel scripts); do ruff check $$script; done
+	
mypy --install-types --non-interactive test: coverage run --source=$$(python3 setup.py --name) -m unittest discover --start-directory test --top-level-directory . --verbose diff --git a/aegea/__init__.py b/aegea/__init__.py index 335b4967..fbab6547 100644 --- a/aegea/__init__.py +++ b/aegea/__init__.py @@ -19,7 +19,7 @@ import warnings from io import open from textwrap import fill -from typing import Any, Dict +from typing import Any, Dict, Optional import boto3 import botocore @@ -33,8 +33,6 @@ logger = logging.getLogger(__name__) -config, parser = None, None # type: AegeaConfig, argparse.ArgumentParser -_subparsers = {} # type: Dict[Any, Any] class AegeaConfig(tweak.Config): base_config_file = os.path.join(os.path.dirname(__file__), "base_config.yml") @@ -59,6 +57,17 @@ def __doc__(self): doc += f"\n- {config_file} ({sources.get(i, 'set by AEGEA_CONFIG_FILE')})" return doc + +class _PlaceholderAegeaConfig(AegeaConfig): + def __init__(self, *args, **kwargs): + pass + + +config: AegeaConfig = _PlaceholderAegeaConfig() +parser: argparse.ArgumentParser = argparse.ArgumentParser() +_subparsers: Dict[Any, Any] = {} + + class AegeaHelpFormatter(argparse.RawTextHelpFormatter): def _get_help_string(self, action): default = _get_config_for_prog(self._prog).get(action.dest) @@ -67,9 +76,11 @@ def _get_help_string(self, action): return action.help + f" (default: {default})" return action.help + def initialize(): global config, parser from .util.printing import BOLD, ENDC, RED + config = AegeaConfig(__name__, use_yaml=True, save_on_exit=False) if not os.path.exists(config.user_config_file): config_dir = os.path.dirname(os.path.abspath(config.user_config_file)) @@ -83,26 +94,32 @@ def initialize(): parser = argparse.ArgumentParser( description=f"{BOLD() + RED() + __name__.capitalize() + ENDC()}: {fill(__doc__.strip())}", - formatter_class=AegeaHelpFormatter + formatter_class=AegeaHelpFormatter, + ) + parser.add_argument( + "--version", + action="version", + version="%(prog)s {}\n{}\n{}\n{} {}\n{}\n{}".format( + __version__, + "boto3 " + boto3.__version__, + "botocore " + botocore.__version__, + platform.python_implementation(), + platform.python_version(), + platform.platform(), + config.__doc__, + ), ) - parser.add_argument("--version", action="version", version="%(prog)s {}\n{}\n{}\n{} {}\n{}\n{}".format( - __version__, - "boto3 " + boto3.__version__, - "botocore " + botocore.__version__, - platform.python_implementation(), - platform.python_version(), - platform.platform(), - config.__doc__, - )) def help(args): parser.print_help() + register_parser(help) + def main(args=None): parsed_args = parser.parse_args(args=args) logger.setLevel(parsed_args.log_level) - has_attrs = (getattr(parsed_args, "sort_by", None) and getattr(parsed_args, "columns", None)) + has_attrs = getattr(parsed_args, "sort_by", None) and getattr(parsed_args, "columns", None) if has_attrs and parsed_args.sort_by not in parsed_args.columns: parsed_args.columns.append(parsed_args.sort_by) try: @@ -133,13 +150,16 @@ def main(args=None): del result["ResponseMetadata"] print(json.dumps(result, indent=2, default=str)) + def _get_config_for_prog(prog): command = prog.split(" ", 1)[-1].replace("-", "_").replace(" ", "_") return config.get(command, {}) + def register_parser(function, parent=None, name=None, **add_parser_args): def get_aws_profiles(**kwargs): from botocore.session import Session + return list(Session().full_config["profiles"]) def set_aws_profile(profile_name): @@ -148,6 +168,7 @@ def set_aws_profile(profile_name): def 
get_region_names(**kwargs): from botocore.loaders import create_loader + for partition_data in create_loader().load_data("endpoints")["partitions"]: if partition_data["partition"] == config.partition: return partition_data["regions"].keys() @@ -157,13 +178,15 @@ def set_aws_region(region_name): def set_endpoint_url(endpoint_url): from .util.aws._boto3_loader import Loader + Loader.client_kwargs["default"].update(endpoint_url=endpoint_url) def set_client_kwargs(client_kwargs): from .util.aws._boto3_loader import Loader + Loader.client_kwargs.update(json.loads(client_kwargs)) - if config is None: + if isinstance(config, _PlaceholderAegeaConfig): initialize() if parent is None: parent = parser @@ -177,17 +200,29 @@ def set_client_kwargs(client_kwargs): add_parser_args["help"] = add_parser_args["description"].strip().splitlines()[0].rstrip(".") add_parser_args.setdefault("formatter_class", AegeaHelpFormatter) subparser = _subparsers[parent.prog].add_parser(parser_name.replace("_", "-"), **add_parser_args) - subparser.add_argument("--max-col-width", "-w", type=int, default=32, - help="When printing tables, truncate column contents to this width. Set to 0 for auto fit.") - subparser.add_argument("--json", action="store_true", - help="Output tabular data as a JSON-formatted list of objects") - subparser.add_argument("--log-level", default=config.get("log_level"), type=str.upper, - help=str([logging.getLevelName(i) for i in range(10, 60, 10)]), - choices={logging.getLevelName(i) for i in range(10, 60, 10)}) - subparser.add_argument("--profile", help="Profile to use from the AWS CLI configuration file", - type=set_aws_profile).completer = get_aws_profiles - subparser.add_argument("--region", help="Region to use (overrides environment variable)", - type=set_aws_region).completer = get_region_names + subparser.add_argument( + "--max-col-width", + "-w", + type=int, + default=32, + help="When printing tables, truncate column contents to this width. Set to 0 for auto fit.", + ) + subparser.add_argument( + "--json", action="store_true", help="Output tabular data as a JSON-formatted list of objects" + ) + subparser.add_argument( + "--log-level", + default=config.get("log_level"), + type=str.upper, + help=str([logging.getLevelName(i) for i in range(10, 60, 10)]), + choices={logging.getLevelName(i) for i in range(10, 60, 10)}, + ) + subparser.add_argument( + "--profile", help="Profile to use from the AWS CLI configuration file", type=set_aws_profile + ).completer = get_aws_profiles + subparser.add_argument( + "--region", help="Region to use (overrides environment variable)", type=set_aws_region + ).completer = get_region_names subparser.add_argument("--endpoint-url", metavar="URL", help="Service endpoint URL to use", type=set_endpoint_url) subparser.add_argument("--client-kwargs", help=argparse.SUPPRESS, type=set_client_kwargs) subparser.set_defaults(entry_point=function) diff --git a/aegea/aegea_config.py b/aegea/aegea_config.py index 0453ca54..5229e432 100644 --- a/aegea/aegea_config.py +++ b/aegea/aegea_config.py @@ -25,8 +25,10 @@ def configure(args): configure_parser.print_help() + configure_parser = register_parser(configure) + def ls(args): from . import config, tweak @@ -36,22 +38,28 @@ def collect_kv(d, path, collector): collect_kv(d[k], path + "." + k, collector) else: collector.append([path.lstrip(".") + "." 
+ k, repr(v)]) + collector = [] # type: List[List] collect_kv(config, "", collector) page_output(format_table(collector)) + ls_parser = register_listing_parser(ls, parent=configure_parser) + def get(args): """Get an Aegea configuration parameter by name""" from . import config + for key in args.key.split("."): config = getattr(config, key) print(json.dumps(config)) + get_parser = register_parser(get, parent=configure_parser) get_parser.add_argument("key") + def set(args): """Set an Aegea configuration parameter to a given value""" from . import config, tweak @@ -72,12 +80,15 @@ def config_files(self): c[args.key.split(".")[-1]] = json.loads(args.value) if args.json else args.value config_saver.save() + set_parser = register_parser(set, parent=configure_parser) set_parser.add_argument("key") set_parser.add_argument("value") + def sync(args): """Save Aegea configuration to your AWS IAM account, or retrieve a previously saved configuration""" raise NotImplementedError() + sync_parser = register_listing_parser(sync, parent=configure_parser) diff --git a/aegea/alarms.py b/aegea/alarms.py index 54de74cf..c263d35d 100644 --- a/aegea/alarms.py +++ b/aegea/alarms.py @@ -10,29 +10,34 @@ def alarms(args): page_output(tabulate(resources.cloudwatch.alarms.all(), args)) + parser = register_listing_parser(alarms, help="List CloudWatch alarms") + def put_alarm(args): sns = resources.sns logs = clients.logs cloudwatch = clients.cloudwatch topic = sns.create_topic(Name=args.alarm_name) topic.subscribe(Protocol="email", Endpoint=args.email) - logs.put_metric_filter(logGroupName=args.log_group_name, - filterName=args.alarm_name, - filterPattern=args.pattern, - metricTransformations=[dict(metricName=args.alarm_name, - metricNamespace=__name__, - metricValue="1")]) - cloudwatch.put_metric_alarm(AlarmName=args.alarm_name, - MetricName=args.alarm_name, - Namespace=__name__, - Statistic="Sum", - Period=300, - Threshold=1, - ComparisonOperator="GreaterThanOrEqualToThreshold", - EvaluationPeriods=1, - AlarmActions=[topic.arn]) + logs.put_metric_filter( + logGroupName=args.log_group_name, + filterName=args.alarm_name, + filterPattern=args.pattern, + metricTransformations=[dict(metricName=args.alarm_name, metricNamespace=__name__, metricValue="1")], + ) + cloudwatch.put_metric_alarm( + AlarmName=args.alarm_name, + MetricName=args.alarm_name, + Namespace=__name__, + Statistic="Sum", + Period=300, + Threshold=1, + ComparisonOperator="GreaterThanOrEqualToThreshold", + EvaluationPeriods=1, + AlarmActions=[topic.arn], + ) + parser = register_parser(put_alarm, help="Configure a CloudWatch alarm") parser.add_argument("--log-group-name", required=True) diff --git a/aegea/audit.py b/aegea/audit.py index 863ff758..65d5e39c 100644 --- a/aegea/audit.py +++ b/aegea/audit.py @@ -149,13 +149,16 @@ def audit_2_1(self): def audit_2_2(self): """2.2 Ensure CloudTrail log file validation is enabled (Scored)""" self.assertGreater(len(self.trails), 0, "No CloudTrail trails configured") - self.assertTrue(all(trail["LogFileValidationEnabled"] for trail in self.trails), - "Some CloudTrail trails don't have log file validation enabled") + self.assertTrue( + all(trail["LogFileValidationEnabled"] for trail in self.trails), + "Some CloudTrail trails don't have log file validation enabled", + ) def audit_2_3(self): """2.3 Ensure the S3 bucket CloudTrail logs to is not publicly accessible (Scored)""" raise NotImplementedError() import boto3 + s3 = boto3.session.Session(region_name="us-east-1").resource("s3") # s3 = boto3.resource("s3") # for trail 
in self.trails: @@ -178,12 +181,14 @@ def audit_2_4(self): for trail in self.trails: self.assertIn("CloudWatchLogsLogGroupArn", trail) trail_status = clients.cloudtrail.get_trail_status(Name=trail["TrailARN"]) - self.assertGreater(trail_status["LatestCloudWatchLogsDeliveryTime"], - datetime.now(tzutc()) - timedelta(days=1)) + self.assertGreater( + trail_status["LatestCloudWatchLogsDeliveryTime"], datetime.now(tzutc()) - timedelta(days=1) + ) def audit_2_5(self): """2.5 Ensure AWS Config is enabled in all regions (Scored)""" import boto3 + for region in boto3.Session().get_available_regions("config"): aws_config = boto3.session.Session(region_name=region).client("config") res = aws_config.describe_configuration_recorder_status() @@ -207,22 +212,24 @@ def ensure_alarm(self, name, pattern, log_group_name): logs = clients.logs cloudwatch = clients.cloudwatch topic = sns.create_topic(Name=name) - topic.subscribe(Protocol='email', Endpoint=self.email) - logs.put_metric_filter(logGroupName=log_group_name, - filterName=name, - filterPattern=pattern, - metricTransformations=[dict(metricName=name, - metricNamespace=__name__, - metricValue="1")]) - cloudwatch.put_metric_alarm(AlarmName=name, - MetricName=name, - Namespace=__name__, - Statistic="Sum", - Period=300, - Threshold=1, - ComparisonOperator="GreaterThanOrEqualToThreshold", - EvaluationPeriods=1, - AlarmActions=[topic.arn]) + topic.subscribe(Protocol="email", Endpoint=self.email) + logs.put_metric_filter( + logGroupName=log_group_name, + filterName=name, + filterPattern=pattern, + metricTransformations=[dict(metricName=name, metricNamespace=__name__, metricValue="1")], + ) + cloudwatch.put_metric_alarm( + AlarmName=name, + MetricName=name, + Namespace=__name__, + Statistic="Sum", + Period=300, + Threshold=1, + ComparisonOperator="GreaterThanOrEqualToThreshold", + EvaluationPeriods=1, + AlarmActions=[topic.arn], + ) def assert_alarm(self, name, pattern, remediate=False): logs = clients.logs @@ -240,81 +247,104 @@ def assert_alarm(self, name, pattern, remediate=False): except Exception: pass if remediate and not alarm_ok: - self.ensure_alarm(name=name, - pattern=pattern, - log_group_name=log_group_name) + self.ensure_alarm(name=name, pattern=pattern, log_group_name=log_group_name) alarm_ok = True self.assertTrue(alarm_ok) def audit_3_1(self): """3.1 Ensure a log metric filter and alarm exist for unauthorized API calls (Scored)""" - self.assert_alarm("UnauthorizedAPICalls", - '{ ($.errorCode = "*UnauthorizedOperation") || ($.errorCode = "AccessDenied*") }') + self.assert_alarm( + "UnauthorizedAPICalls", '{ ($.errorCode = "*UnauthorizedOperation") || ($.errorCode = "AccessDenied*") }' + ) def audit_3_2(self): """3.2 Ensure a log metric filter and alarm exist for Management Console sign-in without MFA (Scored)""" - self.assert_alarm("ConsoleUseWithoutMFA", - '{ $.userIdentity.sessionContext.attributes.mfaAuthenticated != "true" }') + self.assert_alarm( + "ConsoleUseWithoutMFA", '{ $.userIdentity.sessionContext.attributes.mfaAuthenticated != "true" }' + ) def audit_3_3(self): """3.3 Ensure a log metric filter and alarm exist for usage of "root" account (Scored)""" - self.assert_alarm("RootAccountUsed", - '{ $.userIdentity.type = \"Root\" && $.userIdentity.invokedBy NOT EXISTS && $.eventType != \"AwsServiceEvent\" }') # noqa + self.assert_alarm( + "RootAccountUsed", + '{ $.userIdentity.type = "Root" && $.userIdentity.invokedBy NOT EXISTS && $.eventType != "AwsServiceEvent" }', + ) # noqa def audit_3_4(self): """3.4 Ensure a log metric filter and 
alarm exist for IAM policy changes (Scored)""" - self.assert_alarm("IAMPolicyChanged", - '{($.eventName=DeleteGroupPolicy)||($.eventName=DeleteRolePolicy)||($.eventName=DeleteUserPolicy)||($.eventName=PutGroupPolicy)||($.eventName=PutRolePolicy)||($.eventName=PutUserPolicy)||($.eventName=CreatePolicy)||($.eventName=DeletePolicy)||($.eventName=CreatePolicyVersion)||($.eventName=DeletePolicyVersion)||($.eventName=AttachRolePolicy)||($.eventName=DetachRolePolicy)||($.eventName=AttachUserPolicy)||($.eventName=DetachUserPolicy)||($.eventName=AttachGroupPolicy)||($.eventName=DetachGroupPolicy)}') # noqa + self.assert_alarm( + "IAMPolicyChanged", + "{($.eventName=DeleteGroupPolicy)||($.eventName=DeleteRolePolicy)||($.eventName=DeleteUserPolicy)||($.eventName=PutGroupPolicy)||($.eventName=PutRolePolicy)||($.eventName=PutUserPolicy)||($.eventName=CreatePolicy)||($.eventName=DeletePolicy)||($.eventName=CreatePolicyVersion)||($.eventName=DeletePolicyVersion)||($.eventName=AttachRolePolicy)||($.eventName=DetachRolePolicy)||($.eventName=AttachUserPolicy)||($.eventName=DetachUserPolicy)||($.eventName=AttachGroupPolicy)||($.eventName=DetachGroupPolicy)}", + ) # noqa def audit_3_5(self): """3.5 Ensure a log metric filter and alarm exist for CloudTrail configuration changes (Scored)""" - self.assert_alarm("CloudTrailConfigChanged", - '{ ($.eventName = CreateTrail) || ($.eventName = UpdateTrail) || ($.eventName = DeleteTrail) || ($.eventName = StartLogging) || ($.eventName = StopLogging) }') # noqa + self.assert_alarm( + "CloudTrailConfigChanged", + "{ ($.eventName = CreateTrail) || ($.eventName = UpdateTrail) || ($.eventName = DeleteTrail) || ($.eventName = StartLogging) || ($.eventName = StopLogging) }", + ) # noqa def audit_3_6(self): """3.6 Ensure a log metric filter and alarm exist for AWS Management Console authentication failures (Scored)""" - self.assert_alarm("ConsoleLoginFailed", - '{ ($.eventName = ConsoleLogin) && ($.errorMessage = \"Failed authentication\") }') + self.assert_alarm( + "ConsoleLoginFailed", '{ ($.eventName = ConsoleLogin) && ($.errorMessage = "Failed authentication") }' + ) def audit_3_7(self): """3.7 Ensure a log metric filter and alarm exist for disabling or scheduled deletion of customer created CMKs (Scored)""" # noqa - self.assert_alarm("KMSCMKDisabled", - '{($.eventSource = kms.amazonaws.com) && (($.eventName=DisableKey)||($.eventName=ScheduleKeyDeletion))}') # noqa + self.assert_alarm( + "KMSCMKDisabled", + "{($.eventSource = kms.amazonaws.com) && (($.eventName=DisableKey)||($.eventName=ScheduleKeyDeletion))}", + ) # noqa def audit_3_8(self): """3.8 Ensure a log metric filter and alarm exist for S3 bucket policy changes (Scored)""" - self.assert_alarm("S3BucketPolicyChanged", - '{ ($.eventSource = s3.amazonaws.com) && (($.eventName = PutBucketAcl) || ($.eventName = PutBucketPolicy) || ($.eventName = PutBucketCors) || ($.eventName = PutBucketLifecycle) || ($.eventName = PutBucketReplication) || ($.eventName = DeleteBucketPolicy) || ($.eventName = DeleteBucketCors) || ($.eventName = DeleteBucketLifecycle) || ($.eventName = DeleteBucketReplication)) }') # noqa + self.assert_alarm( + "S3BucketPolicyChanged", + "{ ($.eventSource = s3.amazonaws.com) && (($.eventName = PutBucketAcl) || ($.eventName = PutBucketPolicy) || ($.eventName = PutBucketCors) || ($.eventName = PutBucketLifecycle) || ($.eventName = PutBucketReplication) || ($.eventName = DeleteBucketPolicy) || ($.eventName = DeleteBucketCors) || ($.eventName = DeleteBucketLifecycle) || ($.eventName = 
DeleteBucketReplication)) }", + ) # noqa def audit_3_9(self): """3.9 Ensure a log metric filter and alarm exist for AWS Config configuration changes (Scored)""" - self.assert_alarm("AWSConfigServiceChanged", - '{($.eventSource = config.amazonaws.com) && (($.eventName=StopConfigurationRecorder)||($.eventName=DeleteDeliveryChannel)||($.eventName=PutDeliveryChannel)||($.eventName=PutConfigurationRecorder))}') # noqa + self.assert_alarm( + "AWSConfigServiceChanged", + "{($.eventSource = config.amazonaws.com) && (($.eventName=StopConfigurationRecorder)||($.eventName=DeleteDeliveryChannel)||($.eventName=PutDeliveryChannel)||($.eventName=PutConfigurationRecorder))}", + ) # noqa def audit_3_10(self): """3.10 Ensure a log metric filter and alarm exist for security group changes (Scored)""" - self.assert_alarm("EC2SecurityGroupChanged", - '{ ($.eventName = AuthorizeSecurityGroupIngress) || ($.eventName = AuthorizeSecurityGroupEgress) || ($.eventName = RevokeSecurityGroupIngress) || ($.eventName = RevokeSecurityGroupEgress) || ($.eventName = CreateSecurityGroup) || ($.eventName = DeleteSecurityGroup)}') # noqa + self.assert_alarm( + "EC2SecurityGroupChanged", + "{ ($.eventName = AuthorizeSecurityGroupIngress) || ($.eventName = AuthorizeSecurityGroupEgress) || ($.eventName = RevokeSecurityGroupIngress) || ($.eventName = RevokeSecurityGroupEgress) || ($.eventName = CreateSecurityGroup) || ($.eventName = DeleteSecurityGroup)}", + ) # noqa def audit_3_11(self): """3.11 Ensure a log metric filter and alarm exist for changes to Network Access Control Lists (NACL) (Scored)""" # noqa - self.assert_alarm("EC2NACLChanged", - '{ ($.eventName = CreateNetworkAcl) || ($.eventName = CreateNetworkAclEntry) || ($.eventName = DeleteNetworkAcl) || ($.eventName = DeleteNetworkAclEntry) || ($.eventName = ReplaceNetworkAclEntry) || ($.eventName = ReplaceNetworkAclAssociation) }') # noqa + self.assert_alarm( + "EC2NACLChanged", + "{ ($.eventName = CreateNetworkAcl) || ($.eventName = CreateNetworkAclEntry) || ($.eventName = DeleteNetworkAcl) || ($.eventName = DeleteNetworkAclEntry) || ($.eventName = ReplaceNetworkAclEntry) || ($.eventName = ReplaceNetworkAclAssociation) }", + ) # noqa def audit_3_12(self): """3.12 Ensure a log metric filter and alarm exist for changes to network gateways (Scored)""" - self.assert_alarm("EC2NetworkGatewayChanged", - '{ ($.eventName = CreateCustomerGateway) || ($.eventName = DeleteCustomerGateway) || ($.eventName = AttachInternetGateway) || ($.eventName = CreateInternetGateway) || ($.eventName = DeleteInternetGateway) || ($.eventName = DetachInternetGateway) }') # noqa + self.assert_alarm( + "EC2NetworkGatewayChanged", + "{ ($.eventName = CreateCustomerGateway) || ($.eventName = DeleteCustomerGateway) || ($.eventName = AttachInternetGateway) || ($.eventName = CreateInternetGateway) || ($.eventName = DeleteInternetGateway) || ($.eventName = DetachInternetGateway) }", + ) # noqa def audit_3_13(self): """3.13 Ensure a log metric filter and alarm exist for route table changes (Scored)""" - self.assert_alarm("EC2RouteTableChanged", - '{ ($.eventName = CreateRoute) || ($.eventName = CreateRouteTable) || ($.eventName = ReplaceRoute) || ($.eventName = ReplaceRouteTableAssociation) || ($.eventName = DeleteRouteTable) || ($.eventName = DeleteRoute) || ($.eventName = DisassociateRouteTable) }') # noqa + self.assert_alarm( + "EC2RouteTableChanged", + "{ ($.eventName = CreateRoute) || ($.eventName = CreateRouteTable) || ($.eventName = ReplaceRoute) || ($.eventName = ReplaceRouteTableAssociation) || 
($.eventName = DeleteRouteTable) || ($.eventName = DeleteRoute) || ($.eventName = DisassociateRouteTable) }", + ) # noqa def audit_3_14(self): """3.14 Ensure a log metric filter and alarm exist for VPC changes (Scored)""" - self.assert_alarm("EC2VPCChanged", - '{ ($.eventName = CreateVpc) || ($.eventName = DeleteVpc) || ($.eventName = ModifyVpcAttribute) || ($.eventName = AcceptVpcPeeringConnection) || ($.eventName = CreateVpcPeeringConnection) || ($.eventName = DeleteVpcPeeringConnection) || ($.eventName = RejectVpcPeeringConnection) || ($.eventName = AttachClassicLinkVpc) || ($.eventName = DetachClassicLinkVpc) || ($.eventName = DisableVpcClassicLink) || ($.eventName = EnableVpcClassicLink) }') # noqa + self.assert_alarm( + "EC2VPCChanged", + "{ ($.eventName = CreateVpc) || ($.eventName = DeleteVpc) || ($.eventName = ModifyVpcAttribute) || ($.eventName = AcceptVpcPeeringConnection) || ($.eventName = CreateVpcPeeringConnection) || ($.eventName = DeleteVpcPeeringConnection) || ($.eventName = RejectVpcPeeringConnection) || ($.eventName = AttachClassicLinkVpc) || ($.eventName = DetachClassicLinkVpc) || ($.eventName = DisableVpcClassicLink) || ($.eventName = EnableVpcClassicLink) }", + ) # noqa def audit_3_15(self): """3.15 Ensure security contact information is registered (Scored)""" @@ -340,6 +370,7 @@ def audit_4_4(self): """4.4 Ensure the default security group restricts all traffic (Scored)""" raise NotImplementedError() + def audit(args): auditor = Auditor() auditor.__dict__.update(vars(args)) @@ -356,5 +387,6 @@ def audit(args): # TODO: WHITE("NO TEST") page_output(format_table(table, column_names=["Result", "Test"], max_col_width=120)) -parser = register_parser(audit, help='Generate a security report using the CIS AWS Foundations Benchmark') -parser.add_argument('--email', help="Administrative contact email") + +parser = register_parser(audit, help="Generate a security report using the CIS AWS Foundations Benchmark") +parser.add_argument("--email", help="Administrative contact email") diff --git a/aegea/batch.py b/aegea/batch.py index 10f40499..47bb9837 100644 --- a/aegea/batch.py +++ b/aegea/batch.py @@ -50,19 +50,25 @@ def complete_queue_name(**kwargs): return [q["jobQueueName"] for q in paginate(clients.batch.get_paginator("describe_job_queues"))] + def complete_ce_name(**kwargs): return [c["computeEnvironmentName"] for c in paginate(clients.batch.get_paginator("describe_compute_environments"))] + def batch(args): batch_parser.print_help() + batch_parser = register_parser(batch, help="Manage AWS Batch resources", description=__doc__) + def queues(args): page_output(tabulate(paginate(clients.batch.get_paginator("describe_job_queues")), args)) + queues_parser = register_listing_parser(queues, parent=batch_parser, help="List Batch queues") + def create_queue(args): ces = [dict(computeEnvironment=e, order=i) for i, e in enumerate(args.compute_environments)] logger.info("Creating queue %s in %s", args.name, ces) @@ -70,24 +76,30 @@ def create_queue(args): make_waiter(clients.batch.describe_job_queues, "jobQueues[].status", "VALID", "pathAny").wait(jobQueues=[args.name]) return queue + create_queue_parser = register_parser(create_queue, parent=batch_parser, help="Create a Batch queue") create_queue_parser.add_argument("name") create_queue_parser.add_argument("--priority", type=int, default=5) create_queue_parser.add_argument("--compute-environments", nargs="+", required=True) + def delete_queue(args): clients.batch.update_job_queue(jobQueue=args.name, state="DISABLED") 
make_waiter(clients.batch.describe_job_queues, "jobQueues[].status", "VALID", "pathAny").wait(jobQueues=[args.name]) clients.batch.delete_job_queue(jobQueue=args.name) + delete_queue_parser = register_parser(delete_queue, parent=batch_parser, help="Delete a Batch queue") delete_queue_parser.add_argument("name").completer = complete_queue_name + def compute_environments(args): page_output(tabulate(paginate(clients.batch.get_paginator("describe_compute_environments")), args)) + ce_parser = register_listing_parser(compute_environments, parent=batch_parser, help="List Batch compute environments") + def ensure_launch_template(prefix=__name__.replace(".", "_"), **kwargs): name = prefix + "_" + hashlib.sha256(json.dumps(kwargs, sort_keys=True).encode()).hexdigest()[:32] try: @@ -96,19 +108,26 @@ def ensure_launch_template(prefix=__name__.replace(".", "_"), **kwargs): expect_error_codes(e, "InvalidLaunchTemplateName.AlreadyExistsException") return name + def create_compute_environment(args): batch_iam_role = ensure_iam_role(args.service_role, trust=["batch"], policies=["service-role/AWSBatchServiceRole"]) vpc = ensure_vpc() ssh_key_name = ensure_ssh_key(args.ssh_key_name, base_name=__name__) - instance_profile = ensure_instance_profile(args.instance_role, - policies={"service-role/AmazonAPIGatewayPushToCloudWatchLogs", - "service-role/AmazonEC2ContainerServiceforEC2Role", - "AmazonSSMManagedInstanceCore", - IAMPolicyBuilder(action="sts:AssumeRole", resource="*")}) - compute_resources = dict(type=args.compute_type, - maxvCpus=args.max_vcpus, - subnets=[subnet.id for subnet in vpc.subnets.all()], - securityGroupIds=[ensure_security_group("aegea.launch", vpc).id]) + instance_profile = ensure_instance_profile( + args.instance_role, + policies={ + "service-role/AmazonAPIGatewayPushToCloudWatchLogs", + "service-role/AmazonEC2ContainerServiceforEC2Role", + "AmazonSSMManagedInstanceCore", + IAMPolicyBuilder(action="sts:AssumeRole", resource="*"), + }, + ) + compute_resources = dict( + type=args.compute_type, + maxvCpus=args.max_vcpus, + subnets=[subnet.id for subnet in vpc.subnets.all()], + securityGroupIds=[ensure_security_group("aegea.launch", vpc).id], + ) if not args.compute_type.startswith("FARGATE"): commands = instance_storage_shellcode.strip().format(mountpoint="/mnt", mkfs=get_mkfs_command()).split("\n") user_data = get_user_data(commands=commands, mime_multipart_archive=True) @@ -118,27 +137,40 @@ def create_compute_environment(args): ecs_ami_id = resolve_ami(tags=args.ecs_container_instance_ami_tags).id else: ecs_ami_id = get_ssm_parameter("/aws/service/ecs/optimized-ami/amazon-linux-2/recommended/image_id") - launch_template = ensure_launch_template(ImageId=ecs_ami_id, - # TODO: add configurable BDM for Docker image cache space - UserData=base64.b64encode(user_data).decode()) - compute_resources.update(minvCpus=args.min_vcpus, - desiredvCpus=args.desired_vcpus, - ec2KeyPair=ssh_key_name, - instanceRole=instance_profile.name, - instanceTypes=args.instance_types, - launchTemplate=dict(launchTemplateName=launch_template), - spotIamFleetRole=SpotFleetBuilder.get_iam_fleet_role().name, - tags=dict(Name=__name__)) + launch_template = ensure_launch_template( + ImageId=ecs_ami_id, + # TODO: add configurable BDM for Docker image cache space + UserData=base64.b64encode(user_data).decode(), + ) + compute_resources.update( + minvCpus=args.min_vcpus, + desiredvCpus=args.desired_vcpus, + ec2KeyPair=ssh_key_name, + instanceRole=instance_profile.name, + instanceTypes=args.instance_types, + 
launchTemplate=dict(launchTemplateName=launch_template), + spotIamFleetRole=SpotFleetBuilder.get_iam_fleet_role().name, + tags=dict(Name=__name__), + ) logger.info("Creating compute environment %s in %s", args.name, vpc) - compute_environment = clients.batch.create_compute_environment(computeEnvironmentName=args.name, - type=args.type, - computeResources=compute_resources, - serviceRole=batch_iam_role.name) - wtr = make_waiter(clients.batch.describe_compute_environments, "computeEnvironments[].status", "VALID", "pathAny", - delay=2, max_attempts=300) + compute_environment = clients.batch.create_compute_environment( + computeEnvironmentName=args.name, + type=args.type, + computeResources=compute_resources, + serviceRole=batch_iam_role.name, + ) + wtr = make_waiter( + clients.batch.describe_compute_environments, + "computeEnvironments[].status", + "VALID", + "pathAny", + delay=2, + max_attempts=300, + ) wtr.wait(computeEnvironments=[args.name]) return compute_environment + cce_parser = register_parser(create_compute_environment, parent=batch_parser, help="Create a Batch compute environment") cce_parser.add_argument("name") cce_parser.add_argument("--type", choices={"MANAGED", "UNMANAGED"}) @@ -153,6 +185,7 @@ def create_compute_environment(args): cce_parser.add_argument("--ecs-container-instance-ami") cce_parser.add_argument("--ecs-container-instance-ami-tags") + def update_compute_environment(args): update_compute_environment_args = dict(computeEnvironment=args.name, computeResources={}) if args.min_vcpus is not None: @@ -163,21 +196,25 @@ def update_compute_environment(args): update_compute_environment_args["computeResources"].update(maxvCpus=args.max_vcpus) return clients.batch.update_compute_environment(**update_compute_environment_args) + uce_parser = register_parser(update_compute_environment, parent=batch_parser, help="Update a Batch compute environment") uce_parser.add_argument("name").completer = complete_ce_name uce_parser.add_argument("--min-vcpus", type=int) uce_parser.add_argument("--desired-vcpus", type=int) uce_parser.add_argument("--max-vcpus", type=int) + def delete_compute_environment(args): clients.batch.update_compute_environment(computeEnvironment=args.name, state="DISABLED") wtr = make_waiter(clients.batch.describe_compute_environments, "computeEnvironments[].status", "VALID", "pathAny") wtr.wait(computeEnvironments=[args.name]) clients.batch.delete_compute_environment(computeEnvironment=args.name) + dce_parser = register_parser(delete_compute_environment, parent=batch_parser, help="Delete a Batch compute environment") dce_parser.add_argument("name").completer = complete_ce_name + def ensure_queue(name): cq_args = argparse.Namespace(name=name, priority=5, compute_environments=[name]) try: @@ -186,6 +223,7 @@ def ensure_queue(name): create_compute_environment(cce_parser.parse_args(args=[name])) return create_queue(cq_args) + def submit(args): try: ensure_lambda_helper() @@ -209,10 +247,12 @@ def submit(args): container_overrides = dict(command=command, environment=environment) if args.job_role == config.batch_submit.job_role: - args.default_job_role_iam_policies.append(IAMPolicyBuilder( - action=["s3:List*", "s3:HeadObject*", "s3:GetObject*", "s3:PutObject*"], - resource=["arn:aws:s3:::aegea-*", "arn:aws:s3:::aegea-*/*"] - )) + args.default_job_role_iam_policies.append( + IAMPolicyBuilder( + action=["s3:List*", "s3:HeadObject*", "s3:GetObject*", "s3:PutObject*"], + resource=["arn:aws:s3:::aegea-*", "arn:aws:s3:::aegea-*/*"], + ) + ) else: 
args.default_job_role_iam_policies = [] job_definition_arn, job_name = ensure_job_definition(args) @@ -230,12 +270,14 @@ def submit(args): args.memory = int(args.default_memory_mb) container_overrides["memory"] = args.memory while True: - submit_args = dict(jobName=job_name, - jobQueue=args.queue, - dependsOn=[dict(jobId=dep) for dep in args.depends_on], - jobDefinition=job_definition_arn, - parameters={k: v for k, v in args.parameters}, - containerOverrides=container_overrides) + submit_args = dict( + jobName=job_name, + jobQueue=args.queue, + dependsOn=[dict(jobId=dep) for dep in args.depends_on], + jobDefinition=job_definition_arn, + parameters={k: v for k, v in args.parameters}, + containerOverrides=container_overrides, + ) try: if args.dry_run: logger.info("The following command would be run:") @@ -253,12 +295,14 @@ def submit(args): logger.debug("This job must be dispatched to Fargate, switching to Fargate job definition") args.platform_capabilities = ["FARGATE"] job_definition_arn, job_name = ensure_job_definition(args) - container_overrides["resourceRequirements"] = [dict(type="VCPU", value=str(args.vcpus)), - dict(type="MEMORY", value=str(args.memory))] + container_overrides["resourceRequirements"] = [ + dict(type="VCPU", value=str(args.vcpus)), + dict(type="MEMORY", value=str(args.memory)), + ] del container_overrides["memory"] - submit_args.update(jobName=job_name, - jobDefinition=job_definition_arn, - containerOverrides=container_overrides) + submit_args.update( + jobName=job_name, jobDefinition=job_definition_arn, containerOverrides=container_overrides + ) else: raise @@ -278,74 +322,131 @@ def submit(args): raise NotImplementedError() return job + submit_parser = register_parser(submit, parent=batch_parser, help="Submit a job to a Batch queue") submit_parser.add_argument("--name") submit_parser.add_argument("--queue", default=__name__.replace(".", "_")).completer = complete_queue_name submit_parser.add_argument("--depends-on", nargs="+", metavar="JOB_ID", default=[]) submit_parser.add_argument("--job-definition-arn") + def add_command_args(parser): group = parser.add_mutually_exclusive_group() group.add_argument("--watch", action="store_true", help="Monitor submitted job, stream log until job completes") - group.add_argument("--wait", action="store_true", - help="Block on job. Exit with code 0 if job succeeded, 1 if failed") + group.add_argument( + "--wait", action="store_true", help="Block on job. 
Exit with code 0 if job succeeded, 1 if failed" + ) group = parser.add_mutually_exclusive_group() group.add_argument("--command", nargs="+", help="Run these commands as the job (using " + BOLD("bash -c") + ")") - group.add_argument("--execute", type=argparse.FileType("rb"), metavar="EXECUTABLE", - help="Read this executable file and run it as the job") - group.add_argument("--wdl", type=argparse.FileType("rb"), metavar="WDL_WORKFLOW", - help="Read this WDL workflow file and run it as the job") - parser.add_argument("--wdl-input", type=argparse.FileType("r"), metavar="WDL_INPUT_JSON", default=sys.stdin, - help="With --wdl, use this JSON file as the WDL job input (default: stdin)") - parser.add_argument("--environment", nargs="+", metavar="NAME=VALUE", - type=lambda x: dict(zip(["name", "value"], x.split("=", 1))), default=[]) + group.add_argument( + "--execute", + type=argparse.FileType("rb"), + metavar="EXECUTABLE", + help="Read this executable file and run it as the job", + ) + group.add_argument( + "--wdl", + type=argparse.FileType("rb"), + metavar="WDL_WORKFLOW", + help="Read this WDL workflow file and run it as the job", + ) + parser.add_argument( + "--wdl-input", + type=argparse.FileType("r"), + metavar="WDL_INPUT_JSON", + default=sys.stdin, + help="With --wdl, use this JSON file as the WDL job input (default: stdin)", + ) + parser.add_argument( + "--environment", + nargs="+", + metavar="NAME=VALUE", + type=lambda x: dict(zip(["name", "value"], x.split("=", 1))), + default=[], + ) parser.add_argument("--staging-s3-bucket", help=argparse.SUPPRESS) + def add_job_defn_args(parser): - parser.add_argument("--ulimits", nargs="*", - help="Separate ulimit name and value with colon, for example: --ulimits nofile:20000", - default=["nofile:100000"]) + parser.add_argument( + "--ulimits", + nargs="*", + help="Separate ulimit name and value with colon, for example: --ulimits nofile:20000", + default=["nofile:100000"], + ) img_group = parser.add_mutually_exclusive_group() - img_group.add_argument("--image", default="ubuntu", metavar="DOCKER_IMAGE", - help="Docker image URL to use for running job/task") + img_group.add_argument( + "--image", default="ubuntu", metavar="DOCKER_IMAGE", help="Docker image URL to use for running job/task" + ) ecs_img_help = "Name of Docker image residing in this account's Elastic Container Registry" ecs_img_arg = img_group.add_argument("--ecs-image", "--ecr-image", "-i", metavar="REPO[:TAG]", help=ecs_img_help) ecs_img_arg.completer = ecr_image_name_completer - parser.add_argument("--volumes", nargs="+", metavar="HOST_PATH=GUEST_PATH", type=lambda x: x.split("=", 1), - default=[]) + parser.add_argument( + "--volumes", nargs="+", metavar="HOST_PATH=GUEST_PATH", type=lambda x: x.split("=", 1), default=[] + ) # Note: ECS (but not Batch) also supports memoryReservation, a way to specify a soft memory limit for packing, # but that is only useful when co-locating multiple ECS containers together (which we don't do here) - parser.add_argument("--memory-mb", dest="memory", type=int, - help="Memory to allocate to the Docker container (this is both a soft and a hard limit)") + parser.add_argument( + "--memory-mb", + dest="memory", + type=int, + help="Memory to allocate to the Docker container (this is both a soft and a hard limit)", + ) parser.add_argument("--user", help="Name or ID of user to use in the Docker container") parser.add_argument("--log-driver", help="Custom log driver") - parser.add_argument("--log-options", help="Custom log configuration options", nargs="+", 
metavar="NAME=VALUE", - type=lambda x: x.split("=", 1), default=[]) + parser.add_argument( + "--log-options", + help="Custom log configuration options", + nargs="+", + metavar="NAME=VALUE", + type=lambda x: x.split("=", 1), + default=[], + ) + add_command_args(submit_parser) -group = submit_parser.add_argument_group(title="job definition parameters", description=""" -See http://docs.aws.amazon.com/batch/latest/userguide/job_definitions.html""") +group = submit_parser.add_argument_group( + title="job definition parameters", + description=""" +See http://docs.aws.amazon.com/batch/latest/userguide/job_definitions.html""", +) add_job_defn_args(group) group.add_argument("--vcpus", type=int, default=1) group.add_argument("--gpus", type=int, default=0) group.add_argument("--privileged", action="store_true", default=False) -group.add_argument("--volume-type", choices={"standard", "io1", "gp2", "sc1", "st1"}, - help="io1, PIOPS SSD; gp2, general purpose SSD; sc1, cold HDD; st1, throughput optimized HDD") +group.add_argument( + "--volume-type", + choices={"standard", "io1", "gp2", "sc1", "st1"}, + help="io1, PIOPS SSD; gp2, general purpose SSD; sc1, cold HDD; st1, throughput optimized HDD", +) group.add_argument("--parameters", nargs="+", metavar="NAME=VALUE", type=lambda x: x.split("=", 1), default=[]) group.add_argument("--job-role", metavar="IAM_ROLE", help="Name of IAM role to grant to the job") -group.add_argument("--storage", nargs="+", metavar="MOUNTPOINT=SIZE_GB", - type=lambda x: x.rstrip("GBgb").split("=", 1), default=[]) -group.add_argument("--efs-storage", action="store", dest="efs_storage", default=False, - help="Mount EFS network filesystem to the mount point specified. Example: --efs-storage /mnt") -group.add_argument("--mount-instance-storage", nargs="?", const="/mnt", - help="Assemble (MD RAID0), format and mount ephemeral instance storage on this mount point") -submit_parser.add_argument("--timeout", - help="Terminate (and possibly restart) the job after this time (use suffix s, m, h, d, w)") -submit_parser.add_argument("--retry-attempts", type=int, default=1, - help="Number of times to restart the job upon failure") +group.add_argument( + "--storage", nargs="+", metavar="MOUNTPOINT=SIZE_GB", type=lambda x: x.rstrip("GBgb").split("=", 1), default=[] +) +group.add_argument( + "--efs-storage", + action="store", + dest="efs_storage", + default=False, + help="Mount EFS network filesystem to the mount point specified. 
Example: --efs-storage /mnt", +) +group.add_argument( + "--mount-instance-storage", + nargs="?", + const="/mnt", + help="Assemble (MD RAID0), format and mount ephemeral instance storage on this mount point", +) +submit_parser.add_argument( + "--timeout", help="Terminate (and possibly restart) the job after this time (use suffix s, m, h, d, w)" +) +submit_parser.add_argument( + "--retry-attempts", type=int, default=1, help="Number of times to restart the job upon failure" +) submit_parser.add_argument("--dry-run", action="store_true", help="Gather arguments and stop short of submitting job") + def terminate(args): def terminate_one(job_id): return clients.batch.terminate_job(jobId=job_id, reason=args.reason) @@ -357,10 +458,12 @@ def terminate_one(job_id): result += list(executor.map(terminate_one, args.job_id[1:])) logger.info("Sent termination requests for %d jobs", len(result)) + terminate_parser = register_parser(terminate, parent=batch_parser, help="Terminate Batch jobs") terminate_parser.add_argument("job_id", nargs="+") terminate_parser.add_argument("--reason", help="A message to attach to the job conveying the reason for canceling it") + def ls(args, page_size=100): queues = args.queues or [q["jobQueueName"] for q in clients.batch.describe_job_queues()["jobQueues"]] @@ -372,19 +475,27 @@ def list_jobs_worker(list_jobs_worker_args): job_ids = sum(executor.map(list_jobs_worker, itertools.product(queues, args.status)), []) # type: List def describe_jobs_worker(start_index): - return clients.batch.describe_jobs(jobs=job_ids[start_index:start_index + page_size])["jobs"] + return clients.batch.describe_jobs(jobs=job_ids[start_index : start_index + page_size])["jobs"] table = sum(executor.map(describe_jobs_worker, range(0, len(job_ids), page_size)), []) # type: List page_output(tabulate(table, args, cell_transforms={"createdAt": Timestamp})) -job_status_colors = dict(SUBMITTED=YELLOW(), PENDING=YELLOW(), RUNNABLE=BOLD() + YELLOW(), - STARTING=GREEN(), RUNNING=GREEN(), - SUCCEEDED=BOLD() + GREEN(), FAILED=BOLD() + RED()) + +job_status_colors = dict( + SUBMITTED=YELLOW(), + PENDING=YELLOW(), + RUNNABLE=BOLD() + YELLOW(), + STARTING=GREEN(), + RUNNING=GREEN(), + SUCCEEDED=BOLD() + GREEN(), + FAILED=BOLD() + RED(), +) job_states = job_status_colors.keys() ls_parser = register_listing_parser(ls, parent=batch_parser, help="List Batch jobs") ls_parser.add_argument("--queues", nargs="+").completer = complete_queue_name ls_parser.add_argument("--status", nargs="+", default=job_states, choices=job_states) + def get_job_desc(job_id): try: return clients.batch.describe_jobs(jobs=[job_id])["jobs"][0] @@ -392,25 +503,30 @@ def get_job_desc(job_id): bucket = resources.s3.Bucket(f"aegea-batch-jobs-{ARN.get_account_id()}") return json.loads(bucket.Object(f"job_descriptions/{job_id}").get()["Body"].read()) + def describe(args): return get_job_desc(args.job_id) + describe_parser = register_parser(describe, parent=batch_parser, help="Describe a Batch job") describe_parser.add_argument("job_id") + def format_job_status(status): return job_status_colors[status] + status + ENDC() + def print_event(event): print(str(Timestamp(event["timestamp"])) + " " + event["message"]) + def get_logs(args, print_event_fn=print_event): - for event in CloudwatchLogReader(log_group_name=args.log_group_name, - log_stream_name=args.log_stream_name, - head=args.head, - tail=args.tail): + for event in CloudwatchLogReader( + log_group_name=args.log_group_name, log_stream_name=args.log_stream_name, head=args.head, tail=args.tail + ): 
print_event_fn(event) + def watch(args, print_event_fn=print_event): job_desc = get_job_desc(args.job_id) args.job_name = job_desc["jobName"] @@ -432,10 +548,12 @@ def watch(args, print_event_fn=print_event): if "logStreamName" in job_desc.get("container", {}): args.log_stream_name = job_desc["container"]["logStreamName"] if log_reader is None: - log_reader = CloudwatchLogReader(log_group_name=log_group_name, - log_stream_name=args.log_stream_name, - head=args.head, - tail=args.tail) + log_reader = CloudwatchLogReader( + log_group_name=log_group_name, + log_stream_name=args.log_stream_name, + head=args.head, + tail=args.tail, + ) for event in log_reader: print_event_fn(event) if "statusReason" in job_desc: @@ -451,6 +569,7 @@ def watch(args, print_event_fn=print_event): job_done = True time.sleep(1) + get_logs_parser = register_parser(get_logs, parent=batch_parser, help="Retrieve logs for a Batch job") get_logs_parser.add_argument("log_group_name", default="/aws/batch/job") get_logs_parser.add_argument("log_stream_name") @@ -458,10 +577,21 @@ def watch(args, print_event_fn=print_event): watch_parser.add_argument("job_id") for parser in get_logs_parser, watch_parser: lines_group = parser.add_mutually_exclusive_group() - lines_group.add_argument("--head", type=int, nargs="?", const=10, - help="Retrieve this number of lines from the beginning of the log (default 10)") - lines_group.add_argument("--tail", type=int, nargs="?", const=10, - help="Retrieve this number of lines from the end of the log (default 10)") + lines_group.add_argument( + "--head", + type=int, + nargs="?", + const=10, + help="Retrieve this number of lines from the beginning of the log (default 10)", + ) + lines_group.add_argument( + "--tail", + type=int, + nargs="?", + const=10, + help="Retrieve this number of lines from the end of the log (default 10)", + ) + def ssh(args): if not args.ssh_args: @@ -474,8 +604,9 @@ def ssh(args): if "containerInstanceArn" not in job_desc["container"]: raise AegeaException(f"Job {args.job_id} has not been dispatched to a container instance") ecs_ci_arn = job_desc["container"]["containerInstanceArn"] - ecs_ci_desc = clients.ecs.describe_container_instances(cluster=ce_desc["ecsClusterArn"], - containerInstances=[ecs_ci_arn])["containerInstances"][0] + ecs_ci_desc = clients.ecs.describe_container_instances( + cluster=ce_desc["ecsClusterArn"], containerInstances=[ecs_ci_arn] + )["containerInstances"][0] ecs_ci_ec2_id = ecs_ci_desc["ec2InstanceId"] logger.info(f"Job {args.job_id} is on EC2 instance {ecs_ci_ec2_id}") ecs_task_arn = job_desc["container"]["taskArn"] @@ -484,8 +615,10 @@ def ssh(args): raise AegeaException(f"No ECS task found for job {args.job_id}") container_id = res["tasks"][0]["containers"][0]["runtimeId"] logger.info(f"Job {args.job_id} is in container {container_id}") - ssh_to_ecs_container(instance_id=ecs_ci_ec2_id, container_id=container_id, ssh_args=args.ssh_args, - use_ssm=args.use_ssm) + ssh_to_ecs_container( + instance_id=ecs_ci_ec2_id, container_id=container_id, ssh_args=args.ssh_args, use_ssm=args.use_ssm + ) + ssh_parser = register_parser(ssh, parent=batch_parser, help="Log in to a running Batch job via SSH") ssh_parser.add_argument("job_id") diff --git a/aegea/billing.py b/aegea/billing.py index 6fcc3570..8ee29d5b 100644 --- a/aegea/billing.py +++ b/aegea/billing.py @@ -31,32 +31,46 @@ def billing(args): billing_parser.print_help() + billing_parser = register_parser(billing, help="Configure and view AWS cost and usage reports", description=__doc__) + def 
configure(args): bucket_name = args.billing_reports_bucket.format(account_id=ARN.get_account_id()) - bucket_policy = IAMPolicyBuilder(principal="arn:aws:iam::386209384616:root", - action=["s3:GetBucketAcl", "s3:GetBucketPolicy"], - resource="arn:aws:s3:::{}".format(bucket_name)) - bucket_policy.add_statement(principal="arn:aws:iam::386209384616:root", - action=["s3:PutObject"], - resource="arn:aws:s3:::{}/*".format(bucket_name)) + bucket_policy = IAMPolicyBuilder( + principal="arn:aws:iam::386209384616:root", + action=["s3:GetBucketAcl", "s3:GetBucketPolicy"], + resource="arn:aws:s3:::{}".format(bucket_name), + ) + bucket_policy.add_statement( + principal="arn:aws:iam::386209384616:root", + action=["s3:PutObject"], + resource="arn:aws:s3:::{}/*".format(bucket_name), + ) bucket = ensure_s3_bucket(bucket_name, policy=bucket_policy) try: - clients.cur.put_report_definition(ReportDefinition=dict(ReportName=__name__, - TimeUnit="HOURLY", - Format="textORcsv", - Compression="GZIP", - S3Bucket=bucket.name, - S3Prefix="aegea", - S3Region=clients.cur.meta.region_name, - AdditionalSchemaElements=["RESOURCES"])) + clients.cur.put_report_definition( + ReportDefinition=dict( + ReportName=__name__, + TimeUnit="HOURLY", + Format="textORcsv", + Compression="GZIP", + S3Bucket=bucket.name, + S3Prefix="aegea", + S3Region=clients.cur.meta.region_name, + AdditionalSchemaElements=["RESOURCES"], + ) + ) except clients.cur.exceptions.DuplicateReportNameException: pass - print("Configured cost and usage reports. Enable cost allocation tags: http://docs.aws.amazon.com/awsaccountbilling/latest/aboutv2/activate-built-in-tags.html.") # noqa + print( + "Configured cost and usage reports. Enable cost allocation tags: http://docs.aws.amazon.com/awsaccountbilling/latest/aboutv2/activate-built-in-tags.html." + ) # noqa + parser = register_parser(configure, parent=billing_parser) + def filter_line_items(items, args): for item in items: if args.min_cost and float(item["lineItem/BlendedCost"]) < args.min_cost: @@ -67,6 +81,7 @@ def filter_line_items(items, args): continue yield item + def ls(args): bucket = resources.s3.Bucket(args.billing_reports_bucket.format(account_id=ARN.get_account_id())) now = datetime.utcnow() @@ -88,12 +103,16 @@ def ls(args): msg = 'Unable to get report {} from {}: {}. Run "aegea billing configure" to enable reports.' raise AegeaException(msg.format(manifest_name, bucket, e)) + parser = register_parser(ls, parent=billing_parser, help="List contents of AWS cost and usage reports") parser.add_argument("--columns", nargs="+") parser.add_argument("--year", type=int, help="Year to get billing reports for. Defaults to current year") parser.add_argument("--month", type=int, help="Month (numeral) to get billing reports for. 
Defaults to current month") -parser.add_argument("--billing-reports-bucket", help="Name of S3 bucket to retrieve billing reports from", - default=config.billing_configure.billing_reports_bucket) # type: ignore +parser.add_argument( + "--billing-reports-bucket", + help="Name of S3 bucket to retrieve billing reports from", + default=config.billing_configure.billing_reports_bucket, +) # type: ignore parser.add_argument("--min-cost", type=float, help="Omit billing line items below this cost") parser.add_argument("--days", type=float, help="Only look at line items from this many past days") parser.add_argument("--by-user", action="store_true") diff --git a/aegea/build_ami.py b/aegea/build_ami.py index b007eca6..50a3fa8a 100644 --- a/aegea/build_ami.py +++ b/aegea/build_ami.py @@ -47,8 +47,9 @@ def wait(): for i in range(args.cloud_init_timeout_seconds // args.cloud_init_poll_interval_seconds): try: - run_command("sudo jq --exit-status .v1.errors==[] /var/lib/cloud/data/result.json", - instance_ids=[instance.id]) + run_command( + "sudo jq --exit-status .v1.errors==[] /var/lib/cloud/data/result.json", instance_ids=[instance.id] + ) break except clients.ssm.exceptions.InvalidInstanceId: wait() @@ -67,8 +68,13 @@ def wait(): image = instance.create_image(Name=args.name, Description=description, BlockDeviceMappings=get_bdm()) tags = dict(args.tags) base_ami = resources.ec2.Image(args.ami) - tags.update(Owner=ARN.get_iam_username(), AegeaVersion=__version__, - Base=base_ami.id, BaseName=base_ami.name, BaseDescription=base_ami.description or "") + tags.update( + Owner=ARN.get_iam_username(), + AegeaVersion=__version__, + Base=base_ami.id, + BaseName=base_ami.name, + BaseDescription=base_ami.description or "", + ) add_tags(image, **tags) logger.info("Waiting for %s to become available...", image.id) clients.ec2.get_waiter("image_available").wait(ImageIds=[image.id], WaiterConfig=dict(Delay=10, MaxAttempts=120)) @@ -79,24 +85,38 @@ def wait(): instance.terminate() return dict(ImageID=image.id, **tags) + parser = register_parser(build_ami, help="Build an EC2 AMI") parser.add_argument("name", help="Default: aegea-ARCH-YYYY-MM-DD-HH-MM", nargs="?") parser.add_argument("--snapshot-existing-host", type=str, metavar="HOST") parser.add_argument("--wait-for-ami", action="store_true") parser.add_argument("--ssh-key-name") parser.add_argument("--no-verify-ssh-key-pem-file", dest="verify_ssh_key_pem_file", action="store_false") -parser.add_argument("--instance-type", default=None, - help="Instance type to use for building AMI (default: c5.xlarge for x86_64, c6gd.xlarge for arm64)") -parser.add_argument("--architecture", default="x86_64", choices={"x86_64", "arm64"}, - help="CPU architecture for building the AMI") +parser.add_argument( + "--instance-type", + default=None, + help="Instance type to use for building AMI (default: c5.xlarge for x86_64, c6gd.xlarge for arm64)", +) +parser.add_argument( + "--architecture", default="x86_64", choices={"x86_64", "arm64"}, help="CPU architecture for building the AMI" +) parser.add_argument("--security-groups", nargs="+") parser.add_argument("--base-ami") -parser.add_argument("--base-ami-distribution", - help="Use AMI for this distribution (examples: Ubuntu:20.04, Amazon Linux:2") +parser.add_argument( + "--base-ami-distribution", help="Use AMI for this distribution (examples: Ubuntu:20.04, Amazon Linux:2" +) parser.add_argument("--dry-run", "--dryrun", action="store_true") -parser.add_argument("--tags", nargs="+", metavar="NAME=VALUE", type=lambda x: x.split("=", 1), - 
help="Tag the resulting AMI with these tags") +parser.add_argument( + "--tags", + nargs="+", + metavar="NAME=VALUE", + type=lambda x: x.split("=", 1), + help="Tag the resulting AMI with these tags", +) parser.add_argument("--cloud-config-data", type=json.loads) -parser.add_argument("--cloud-init-timeout-seconds", type=int, - help="Approximate time in seconds to wait for cloud-init to finish before aborting.") +parser.add_argument( + "--cloud-init-timeout-seconds", + type=int, + help="Approximate time in seconds to wait for cloud-init to finish before aborting.", +) parser.add_argument("--iam-role", default=__name__) diff --git a/aegea/build_docker_image.py b/aegea/build_docker_image.py index 72910e05..9f311b14 100644 --- a/aegea/build_docker_image.py +++ b/aegea/build_docker_image.py @@ -26,6 +26,7 @@ RUN {run} """ + def get_dockerfile(args): if args.dockerfile: return io.open(args.dockerfile, "rb").read() @@ -36,13 +37,16 @@ def get_dockerfile(args): "echo $CLOUD_CONFIG_B64 | base64 --decode > /etc/cloud/cloud.cfg.d/99_aegea.cfg", "cloud-init init", "cloud-init modules --mode=config", - "cloud-init modules --mode=final" + "cloud-init modules --mode=final", ] - return dockerfile.format(base_image=args.base_image, - maintainer=ARN.get_iam_username(), - label=" ".join(args.tags), - cloud_config_b64=base64.b64encode(get_cloud_config(args)).decode(), - run=json.dumps(cmd)).encode() + return dockerfile.format( + base_image=args.base_image, + maintainer=ARN.get_iam_username(), + label=" ".join(args.tags), + cloud_config_b64=base64.b64encode(get_cloud_config(args)).decode(), + run=json.dumps(cmd), + ).encode() + def encode_dockerfile(args): with io.BytesIO() as buf: @@ -51,33 +55,32 @@ def encode_dockerfile(args): gz.close() return base64.b64encode(buf.getvalue()).decode() + def get_cloud_config(args): - cloud_config_data = OrderedDict(packages=args.packages, - write_files=get_bootstrap_files(get_rootfs_skel_dirs(args)), - runcmd=args.commands) + cloud_config_data = OrderedDict( + packages=args.packages, write_files=get_bootstrap_files(get_rootfs_skel_dirs(args)), runcmd=args.commands + ) cloud_config_data.update(dict(args.cloud_config_data)) cloud_cfg_d = { "datasource_list": ["None"], - "datasource": { - "None": { - "userdata_raw": encode_cloud_config_payload(cloud_config_data, gzip=False) - } - } + "datasource": {"None": {"userdata_raw": encode_cloud_config_payload(cloud_config_data, gzip=False)}}, } return json.dumps(cloud_cfg_d).encode() + def ensure_ecr_repo(name, read_access=None): try: clients.ecr.create_repository(repositoryName=name) except clients.ecr.exceptions.RepositoryAlreadyExistsException: pass - policy = IAMPolicyBuilder(principal=dict(AWS=read_access), - action=["ecr:GetDownloadUrlForLayer", - "ecr:BatchGetImage", - "ecr:BatchCheckLayerAvailability"]) + policy = IAMPolicyBuilder( + principal=dict(AWS=read_access), + action=["ecr:GetDownloadUrlForLayer", "ecr:BatchGetImage", "ecr:BatchCheckLayerAvailability"], + ) if read_access: clients.ecr.set_repository_policy(repositoryName=name, policyText=str(policy)) + build_docker_image_shellcode = """#!/bin/bash set -euo pipefail apt-get update -qq @@ -94,15 +97,22 @@ def ensure_ecr_repo(name, read_access=None): docker build $CACHE_FROM -t "$TAG" . 
docker push "$TAG" """ # noqa + + def build_docker_image(args): for key, value in config.build_image.items(): getattr(args, key).extend(value) - args.tags += ["AegeaVersion={}".format(__version__), - 'description="Built by {} for {}"'.format(__name__, ARN.get_iam_username())] + args.tags += [ + "AegeaVersion={}".format(__version__), + 'description="Built by {} for {}"'.format(__name__, ARN.get_iam_username()), + ] ensure_ecr_repo(args.name, read_access=args.read_access) with tempfile.NamedTemporaryFile(mode="wt") as exec_fh: - exec_fh.write(build_docker_image_shellcode.format(dockerfile=encode_dockerfile(args), - use_cache=json.dumps(args.use_cache))) + exec_fh.write( + build_docker_image_shellcode.format( + dockerfile=encode_dockerfile(args), use_cache=json.dumps(args.use_cache) + ) + ) exec_fh.flush() submit_args = submit_parser.parse_args(["--execute", exec_fh.name]) submit_args.volumes = [["/var/run/docker.sock", "/var/run/docker.sock"]] @@ -114,24 +124,35 @@ def build_docker_image(args): dict(name="TAG", value=args.tag), dict(name="REPO", value=args.name), dict(name="AWS_DEFAULT_REGION", value=ARN.get_region()), - dict(name="AWS_ACCOUNT_ID", value=ARN.get_account_id()) + dict(name="AWS_ACCOUNT_ID", value=ARN.get_account_id()), ] builder_iam_role = ensure_iam_role(__name__, trust=["ecs-tasks"], policies=args.builder_iam_policies) submit_args.job_role = builder_iam_role.name job = submit(submit_args) return dict(job=job) + parser = register_parser(build_docker_image, help="Build an Elastic Container Registry Docker image") parser.add_argument("name") parser.add_argument("--tag", help="Docker image tag", default="latest") -parser.add_argument("--read-access", nargs="*", - help="AWS account IDs or IAM principal ARNs to grant read access. Use '*' to grant to all.") +parser.add_argument( + "--read-access", + nargs="*", + help="AWS account IDs or IAM principal ARNs to grant read access. Use '*' to grant to all.", +) parser.add_argument("--builder-image", default="ubuntu:18.04", help=argparse.SUPPRESS) -parser.add_argument("--builder-iam-policies", nargs="+", - default=["AmazonEC2FullAccess", "AmazonS3FullAccess", "AmazonEC2ContainerRegistryPowerUser"]) +parser.add_argument( + "--builder-iam-policies", + nargs="+", + default=["AmazonEC2FullAccess", "AmazonS3FullAccess", "AmazonEC2ContainerRegistryPowerUser"], +) parser.add_argument("--tags", nargs="+", default=[], metavar="NAME=VALUE", help="Tag resulting image with these tags") parser.add_argument("--cloud-config-data", type=json.loads) parser.add_argument("--dockerfile") -parser.add_argument("--no-cache", dest="use_cache", action="store_false", - help="Build image from scratch without re-using layers from build steps of prior versions") +parser.add_argument( + "--no-cache", + dest="use_cache", + action="store_false", + help="Build image from scratch without re-using layers from build steps of prior versions", +) parser.add_argument("--dry-run", action="store_true", help="Gather arguments and stop short of building the image") diff --git a/aegea/cloudtrail.py b/aegea/cloudtrail.py index d28ac91c..4f603c95 100644 --- a/aegea/cloudtrail.py +++ b/aegea/cloudtrail.py @@ -1,6 +1,7 @@ """ List CloudTrail trails. Query, filter, and print trail events. 
""" + import json from datetime import datetime @@ -13,14 +14,19 @@ def cloudtrail(args): cloudtrail_parser.print_help() -cloudtrail_parser = register_parser(cloudtrail, help="List CloudTrail trails and print trail events", - description=__doc__) + +cloudtrail_parser = register_parser( + cloudtrail, help="List CloudTrail trails and print trail events", description=__doc__ +) + def ls(args): page_output(tabulate(clients.cloudtrail.describe_trails()["trailList"], args)) + parser = register_parser(ls, parent=cloudtrail_parser, help="List CloudTrail trails") + def print_cloudtrail_event(event): log_record = json.loads(event["CloudTrailEvent"]) user_identity = log_record["userIdentity"] @@ -31,6 +37,7 @@ def print_cloudtrail_event(event): request_params = json.dumps(log_record.get("requestParameters")) print(event["EventTime"], user_identity, log_record["eventType"], log_record["eventName"], request_params) + def lookup(args): lookup_args = dict(LookupAttributes=[{"AttributeKey": k, "AttributeValue": v} for k, v in args.attributes]) if args.start_time: @@ -39,9 +46,10 @@ def lookup(args): lookup_args.update(EndTime=args.end_time) if args.category: lookup_args.update(EventCategory=args.category) - for event in paginate(clients.cloudtrail.get_paginator('lookup_events'), **lookup_args): + for event in paginate(clients.cloudtrail.get_paginator("lookup_events"), **lookup_args): print_cloudtrail_event(event) + parser = register_parser(lookup, parent=cloudtrail_parser, help="Query and print CloudTrail events") parser.add_argument("--attributes", nargs="+", metavar="NAME=VALUE", type=lambda x: x.split("=", 1), default=[]) parser.add_argument("--category") diff --git a/aegea/cost.py b/aegea/cost.py index b5d17779..59fc6cb5 100644 --- a/aegea/cost.py +++ b/aegea/cost.py @@ -19,10 +19,13 @@ def format_float(f): except Exception: return f + def get_common_method_args(args): - return dict(Granularity=args.granularity, - TimePeriod=dict(Start=args.time_period_start.date().isoformat(), - End=args.time_period_end.date().isoformat())) + return dict( + Granularity=args.granularity, + TimePeriod=dict(Start=args.time_period_start.date().isoformat(), End=args.time_period_end.date().isoformat()), + ) + def cost(args): if not (args.group_by or args.group_by_tag): @@ -53,40 +56,106 @@ def cost(args): rows = sorted(rows, key=lambda row: -row["TOTAL"]) # type: ignore page_output(tabulate(rows, args, cell_transforms=cell_transforms)) + parser_cost = register_parser(cost, help="List AWS costs") -parser_cost.add_argument("--time-period-start", type=Timestamp, default=Timestamp("-7d"), - help="Time to start cost history." + Timestamp.__doc__) # type: ignore -parser_cost.add_argument("--time-period-end", type=Timestamp, default=Timestamp("-0d"), - help="Time to end cost history." + Timestamp.__doc__) # type: ignore +parser_cost.add_argument( + "--time-period-start", + type=Timestamp, + default=Timestamp("-7d"), + help="Time to start cost history." + Timestamp.__doc__, # type: ignore +) +parser_cost.add_argument( + "--time-period-end", + type=Timestamp, + default=Timestamp("-0d"), + help="Time to end cost history." 
+ Timestamp.__doc__, # type: ignore +) parser_cost.add_argument("--granularity", choices={"HOURLY", "DAILY", "MONTHLY"}, help="AWS cost granularity") -parser_cost.add_argument("--metrics", nargs="+", default=["AmortizedCost"], - choices={"AmortizedCost", "BlendedCost", "NetAmortizedCost", "NetUnblendedCost", - "NormalizedUsageAmount", "UnblendedCost", "UsageQuantity"}) -parser_cost.add_argument("--group-by", nargs="+", default=[], - choices={"AZ", "INSTANCE_TYPE", "LINKED_ACCOUNT", "OPERATION", "PURCHASE_TYPE", "SERVICE", - "REGION", "USAGE_TYPE", "PLATFORM", "TENANCY", "RECORD_TYPE", "LEGAL_ENTITY_NAME", - "DEPLOYMENT_OPTION", "DATABASE_ENGINE", "CACHE_ENGINE", "INSTANCE_TYPE_FAMILY", - "BILLING_ENTITY", "RESERVATION_ID", "SAVINGS_PLANS_TYPE", "SAVINGS_PLAN_ARN"}) +parser_cost.add_argument( + "--metrics", + nargs="+", + default=["AmortizedCost"], + choices={ + "AmortizedCost", + "BlendedCost", + "NetAmortizedCost", + "NetUnblendedCost", + "NormalizedUsageAmount", + "UnblendedCost", + "UsageQuantity", + }, +) +parser_cost.add_argument( + "--group-by", + nargs="+", + default=[], + choices={ + "AZ", + "INSTANCE_TYPE", + "LINKED_ACCOUNT", + "OPERATION", + "PURCHASE_TYPE", + "SERVICE", + "REGION", + "USAGE_TYPE", + "PLATFORM", + "TENANCY", + "RECORD_TYPE", + "LEGAL_ENTITY_NAME", + "DEPLOYMENT_OPTION", + "DATABASE_ENGINE", + "CACHE_ENGINE", + "INSTANCE_TYPE_FAMILY", + "BILLING_ENTITY", + "RESERVATION_ID", + "SAVINGS_PLANS_TYPE", + "SAVINGS_PLAN_ARN", + }, +) parser_cost.add_argument("--group-by-tag", nargs="+", default=[]) parser_cost.add_argument("--min-total", type=int, help="Omit rows that total below this number") + def cost_forecast(args): get_cost_forecast_args = dict(get_common_method_args(args), Metric=args.metric, PredictionIntervalLevel=75) res = clients.ce.get_cost_forecast(**get_cost_forecast_args) args.columns = ["TimePeriod.Start", "MeanValue", "PredictionIntervalLowerBound", "PredictionIntervalUpperBound"] - cell_transforms = {col: format_float - for col in ["MeanValue", "PredictionIntervalLowerBound", "PredictionIntervalUpperBound"]} + cell_transforms = { + col: format_float for col in ["MeanValue", "PredictionIntervalLowerBound", "PredictionIntervalUpperBound"] + } title = "TOTAL ({})".format(boto3.session.Session().profile_name) table = res["ForecastResultsByTime"] + [{"TimePeriod": {"Start": title}, "MeanValue": res["Total"]["Amount"]}] page_output(tabulate(table, args, cell_transforms=cell_transforms)) + parser_cost_forecast = register_parser(cost_forecast, help="List AWS cost forecasts") -parser_cost_forecast.add_argument("--time-period-start", type=Timestamp, default=Timestamp("1d"), - help="Time to start cost forecast." + Timestamp.__doc__) # type: ignore -parser_cost_forecast.add_argument("--time-period-end", type=Timestamp, default=Timestamp("7d"), - help="Time to end cost forecast." + Timestamp.__doc__) # type: ignore -parser_cost_forecast.add_argument("--granularity", choices={"HOURLY", "DAILY", "MONTHLY"}, - help="Up to 3 months of DAILY forecasts or 12 months of MONTHLY forecasts") -parser_cost_forecast.add_argument("--metric", help="Which metric Cost Explorer uses to create your forecast", - choices={"USAGE_QUANTITY", "UNBLENDED_COST", "NET_UNBLENDED_COST", "AMORTIZED_COST", - "NET_AMORTIZED_COST", "BLENDED_COST", "NORMALIZED_USAGE_AMOUNT"}) +parser_cost_forecast.add_argument( + "--time-period-start", + type=Timestamp, + default=Timestamp("1d"), + help="Time to start cost forecast." 
+ Timestamp.__doc__, # type: ignore +) +parser_cost_forecast.add_argument( + "--time-period-end", + type=Timestamp, + default=Timestamp("7d"), + help="Time to end cost forecast." + Timestamp.__doc__, # type: ignore +) +parser_cost_forecast.add_argument( + "--granularity", + choices={"HOURLY", "DAILY", "MONTHLY"}, + help="Up to 3 months of DAILY forecasts or 12 months of MONTHLY forecasts", +) +parser_cost_forecast.add_argument( + "--metric", + help="Which metric Cost Explorer uses to create your forecast", + choices={ + "USAGE_QUANTITY", + "UNBLENDED_COST", + "NET_UNBLENDED_COST", + "AMORTIZED_COST", + "NET_AMORTIZED_COST", + "BLENDED_COST", + "NORMALIZED_USAGE_AMOUNT", + }, +) diff --git a/aegea/ebs.py b/aegea/ebs.py index 9d7e1070..68f25230 100644 --- a/aegea/ebs.py +++ b/aegea/ebs.py @@ -24,28 +24,36 @@ def complete_volume_id(**kwargs): return [i["VolumeId"] for i in clients.ec2.describe_volumes()["Volumes"]] + def ebs(args): ebs_parser.print_help() + ebs_parser = register_parser(ebs, help="Manage Elastic Block Store resources", description=__doc__) + def ls(args): @lru_cache() def instance_id_to_name(i): return add_name(resources.ec2.Instance(i)).name + table = [{f: get_cell(i, f) for f in args.columns} for i in filter_collection(resources.ec2.volumes, args)] if "attachments" in args.columns: for row in table: row["attachments"] = ", ".join(instance_id_to_name(a["InstanceId"]) for a in row["attachments"]) page_output(tabulate(table, args)) + parser = register_filtering_parser(ls, parent=ebs_parser, help="List EC2 EBS volumes") + def snapshots(args): page_output(filter_and_tabulate(resources.ec2.snapshots.filter(OwnerIds=[ARN.get_account_id()]), args)) + parser = register_filtering_parser(snapshots, parent=ebs_parser, help="List EC2 EBS snapshots") + def create(args): if (args.format or args.mount) and not args.attach: raise SystemExit("Arguments --format and --mount require --attach") @@ -72,25 +80,31 @@ def create(args): raise return res + parser_create = register_parser(create, parent=ebs_parser, help="Create an EBS volume") parser_create.add_argument("--dry-run", action="store_true") parser_create.add_argument("--snapshot-id") parser_create.add_argument("--availability-zone") parser_create.add_argument("--kms-key-id") parser_create.add_argument("--tags", nargs="+", metavar="TAG_NAME=VALUE") -parser_create.add_argument("--attach", action="store_true", - help="Attach volume to this instance (only valid when running on EC2)") +parser_create.add_argument( + "--attach", action="store_true", help="Attach volume to this instance (only valid when running on EC2)" +) + def snapshot(args): return clients.ec2.create_snapshot(DryRun=args.dry_run, VolumeId=args.volume_id) + + parser_snapshot = register_parser(snapshot, parent=ebs_parser, help="Create an EBS snapshot") parser_snapshot.add_argument("volume_id").completer = complete_volume_id + def attach_volume(args): - return clients.ec2.attach_volume(DryRun=args.dry_run, - VolumeId=args.volume_id, - InstanceId=args.instance, - Device=args.device) + return clients.ec2.attach_volume( + DryRun=args.dry_run, VolumeId=args.volume_id, InstanceId=args.instance, Device=args.device + ) + def find_volume_id(mountpoint): with open("/proc/mounts") as fh: @@ -105,7 +119,10 @@ def find_volume_id(mountpoint): break else: raise Exception(f"EBS volume ID not found for mountpoint {mountpoint} (devnode {devnode})") - return re.search(r"Elastic_Block_Store_(vol[\w]+)", devnode_link).group(1).replace("vol", "vol-") + ebs_vol_id_match = 
re.search(r"Elastic_Block_Store_(vol[\w]+)", devnode_link) + assert ebs_vol_id_match is not None + return ebs_vol_id_match.group(1).replace("vol", "vol-") + def find_devnode(volume_id): if os.path.exists("/dev/disk/by-id"): @@ -119,9 +136,11 @@ def find_devnode(volume_id): return "/dev/" + attachment["Device"] raise Exception(f"Could not find devnode for {volume_id}") + def get_fs_label(volume_id): return "aegv" + volume_id[4:12] + def attach(args): if args.instance is None: args.instance = get_metadata("instance-id") @@ -158,16 +177,26 @@ def attach(args): logger.info("Mounting %s at %s", args.volume_id, args.mount) subprocess.check_call(["mount", find_devnode(args.volume_id), args.mount], stdout=sys.stderr.buffer) return res + + parser_attach = register_parser(attach, parent=ebs_parser, help="Attach an EBS volume to an EC2 instance") parser_attach.add_argument("volume_id").completer = complete_volume_id parser_attach.add_argument("instance", type=resolve_instance_id, nargs="?") -parser_attach.add_argument("--device", choices=["xvd" + chr(i + 1) for i in range(ord("a"), ord("z"))], - help="Device node to attach volume to. Default: auto-select the first available node") +parser_attach.add_argument( + "--device", + choices=["xvd" + chr(i + 1) for i in range(ord("a"), ord("z"))], + help="Device node to attach volume to. Default: auto-select the first available node", +) for parser in parser_create, parser_attach: - parser.add_argument("--format", nargs="?", const="xfs", - help="Use this command and arguments to format volume after attaching (only valid on EC2)") + parser.add_argument( + "--format", + nargs="?", + const="xfs", + help="Use this command and arguments to format volume after attaching (only valid on EC2)", + ) parser.add_argument("--mount", nargs="?", const="/mnt", help="Mount volume on given mountpoint (only valid on EC2)") + def detach(args): """ Detach an EBS volume from an EC2 instance. 
@@ -184,22 +213,27 @@ def detach(args): cmd = "umount {devnode} || (kill -9 $(lsof -t +f -- $(readlink -f {devnode}) | sort | uniq); umount {devnode} || umount -l {devnode})" # noqa subprocess.call(cmd.format(devnode=find_devnode(volume_id)), shell=True) attachment = resources.ec2.Volume(volume_id).attachments[0] - res = clients.ec2.detach_volume(DryRun=args.dry_run, - VolumeId=volume_id, - InstanceId=attachment["InstanceId"], - Device=attachment["Device"], - Force=args.force) + res = clients.ec2.detach_volume( + DryRun=args.dry_run, + VolumeId=volume_id, + InstanceId=attachment["InstanceId"], + Device=attachment["Device"], + Force=args.force, + ) clients.ec2.get_waiter("volume_available").wait(VolumeIds=[volume_id]) if args.delete: logger.info("Deleting EBS volume %s", volume_id) clients.ec2.delete_volume(VolumeId=volume_id, DryRun=args.dry_run) return res + + parser_detach = register_parser(detach, parent=ebs_parser) parser_detach.add_argument("volume_id", help="EBS volume ID or mountpoint").completer = complete_volume_id parser_detach.add_argument("--unmount", action="store_true", help="Unmount the volume before detaching") parser_detach.add_argument("--delete", action="store_true", help="Delete the volume after detaching") parser_detach.add_argument("--force", action="store_true") + def modify(args): modify_args = dict(VolumeId=args.volume_id, DryRun=args.dry_run) if args.size: @@ -214,13 +248,18 @@ def modify(args): # "optimizing", "pathAny") # waiter.wait(VolumeIds=[args.volume_id]) return res + + parser_modify = register_parser(modify, parent=ebs_parser, help="Change the size, type, or IOPS of an EBS volume") parser_modify.add_argument("volume_id").completer = complete_volume_id for parser in parser_create, parser_modify: parser.add_argument("--size-gb", dest="size", type=int, help="Volume size in gigabytes") - parser.add_argument("--volume-type", choices={"standard", "io1", "gp2", "sc1", "st1"}, - help="io1, PIOPS SSD; gp2, general purpose SSD; sc1, cold HDD; st1, throughput optimized HDD") + parser.add_argument( + "--volume-type", + choices={"standard", "io1", "gp2", "sc1", "st1"}, + help="io1, PIOPS SSD; gp2, general purpose SSD; sc1, cold HDD; st1, throughput optimized HDD", + ) parser.add_argument("--iops", type=int) for parser in parser_snapshot, parser_attach, parser_detach, parser_modify: diff --git a/aegea/ecr.py b/aegea/ecr.py index 3bef869a..c4864353 100644 --- a/aegea/ecr.py +++ b/aegea/ecr.py @@ -16,8 +16,10 @@ def ecr(args): ecr_parser.print_help() + ecr_parser = register_parser(ecr, help="Manage Elastic Container Registry resources", description=__doc__) + def ls(args): table = [] # type: List[Dict] describe_repositories_args = dict(repositoryNames=args.repositories) if args.repositories else {} @@ -35,16 +37,20 @@ def ls(args): table = sorted(table, key=lambda r: r["repositoryName"] + str(r.get("imagePushedAt"))) page_output(tabulate(table, args)) + ls_parser = register_listing_parser(ls, parent=ecr_parser, help="List ECR repos and images") ls_parser.add_argument("repositories", nargs="*") + def ecr_image_name_completer(**kwargs): return (r["repositoryName"] for r in paginate(clients.ecr.get_paginator("describe_repositories"))) + def retag(args): if "dkr.ecr" in args.repository and "amazonaws.com" in args.repository: - if not args.repository.startswith("{}.dkr.ecr.{}.amazonaws.com/".format(ARN.get_account_id(), - clients.ecr.meta.region_name)): + if not args.repository.startswith( + "{}.dkr.ecr.{}.amazonaws.com/".format(ARN.get_account_id(), 
clients.ecr.meta.region_name) + ): raise AegeaException("Unexpected repository ID {}".format(args.repository)) args.repository = args.repository.split("/", 1)[1] image_id_key = "imageDigest" if len(args.existing_tag_or_digest) == 64 else "imageTag" @@ -54,9 +60,10 @@ def retag(args): break else: raise AegeaException("No image found for tag or digest {}".format(args.existing_tag_or_digest)) - return clients.ecr.put_image(repositoryName=args.repository, - imageManifest=image["imageManifest"], - imageTag=args.new_tag) + return clients.ecr.put_image( + repositoryName=args.repository, imageManifest=image["imageManifest"], imageTag=args.new_tag + ) + retag_parser = register_parser(retag, parent=ecr_parser, help="Add a new tag to an existing image") retag_parser.add_argument("repository").completer = ecr_image_name_completer diff --git a/aegea/ecs.py b/aegea/ecs.py index f1bc30ff..0d4871fe 100644 --- a/aegea/ecs.py +++ b/aegea/ecs.py @@ -39,20 +39,25 @@ def complete_cluster_name(**kwargs): return [ARN(c).resource.partition("/")[2] for c in paginate(clients.ecs.get_paginator("list_clusters"))] + def ecs(args): ecs_parser.print_help() + ecs_parser = register_parser(ecs, help="Manage Elastic Container Service resources", description=__doc__) + def clusters(args): if not args.clusters: args.clusters = list(paginate(clients.ecs.get_paginator("list_clusters"))) cluster_desc = clients.ecs.describe_clusters(clusters=args.clusters)["clusters"] page_output(tabulate(cluster_desc, args)) + clusters_parser = register_listing_parser(clusters, parent=ecs_parser, help="List ECS clusters") clusters_parser.add_argument("clusters", nargs="*").completer = complete_cluster_name + def get_task_descs(cluster_names, task_names=None, desired_status=frozenset(["RUNNING", "STOPPED"])): list_tasks = clients.ecs.get_paginator("list_tasks") @@ -70,10 +75,11 @@ def describe_tasks_worker(t, cluster=None): with ThreadPoolExecutor() as executor: for cluster, status, tasks in executor.map(list_tasks_worker, product(cluster_names, desired_status)): worker = partial(describe_tasks_worker, cluster=cluster) - descs = executor.map(worker, (tasks[pos:pos + 100] for pos in range(0, len(tasks), 100))) + descs = executor.map(worker, (tasks[pos : pos + 100] for pos in range(0, len(tasks), 100))) task_descs += sum(descs, []) return task_descs + def tasks(args): list_clusters = clients.ecs.get_paginator("list_clusters") if args.clusters is None: @@ -81,12 +87,14 @@ def tasks(args): task_descs = get_task_descs(cluster_names=args.clusters, task_names=args.tasks, desired_status=args.desired_status) page_output(tabulate(task_descs, args)) + tasks_parser = register_listing_parser(tasks, parent=ecs_parser, help="List ECS tasks") tasks_parser.add_argument("tasks", nargs="*") tasks_parser.add_argument("--clusters", nargs="*").completer = complete_cluster_name tasks_parser.add_argument("--desired-status", nargs=1, choices={"RUNNING", "STOPPED"}, default=["RUNNING", "STOPPED"]) tasks_parser.add_argument("--launch-type", nargs=1, choices={"EC2", "FARGATE"}, default=["EC2", "FARGATE"]) + def run(args): args.storage = args.efs_storage = args.mount_instance_storage = None command, environment = get_command_and_env(args) @@ -97,8 +105,8 @@ def run(args): "options": { "awslogs-region": clients.ecs.meta.region_name, "awslogs-group": args.task_name, - "awslogs-stream-prefix": args.task_name - } + "awslogs-stream-prefix": args.task_name, + }, } ensure_log_group(log_config["options"]["awslogs-group"]) # type: ignore @@ -108,34 +116,38 @@ def run(args): volumes, 
mount_points = get_volumes_and_mountpoints(args) if args.memory is None: if args.fargate_memory.endswith("GB"): - args.memory = int(args.fargate_memory[:-len("GB")]) * 1024 + args.memory = int(args.fargate_memory[: -len("GB")]) * 1024 else: args.memory = int(args.fargate_memory) - container_defn = dict(name=args.task_name, - image=args.image, - cpu=0, - memory=args.memory, - user=args.user, - command=[], - environment=[], - portMappings=[], - essential=True, - logConfiguration=log_config, - mountPoints=[dict(sourceVolume="scratch", containerPath="/mnt")] + mount_points, - volumesFrom=[]) + container_defn = dict( + name=args.task_name, + image=args.image, + cpu=0, + memory=args.memory, + user=args.user, + command=[], + environment=[], + portMappings=[], + essential=True, + logConfiguration=log_config, + mountPoints=[dict(sourceVolume="scratch", containerPath="/mnt")] + mount_points, + volumesFrom=[], + ) set_ulimits(args, container_defn) exec_role = ensure_fargate_execution_role(args.execution_role) task_role = ensure_iam_role(args.task_role, trust=["ecs-tasks"]) - expect_task_defn = dict(containerDefinitions=[container_defn], - requiresCompatibilities=["FARGATE"], - taskRoleArn=task_role.arn, - executionRoleArn=exec_role.arn, - networkMode="awsvpc", - cpu=args.fargate_cpu, - memory=args.fargate_memory, - volumes=[dict(name="scratch", host={})] + volumes) + expect_task_defn = dict( + containerDefinitions=[container_defn], + requiresCompatibilities=["FARGATE"], + taskRoleArn=task_role.arn, + executionRoleArn=exec_role.arn, + networkMode="awsvpc", + cpu=args.fargate_cpu, + memory=args.fargate_memory, + volumes=[dict(name="scratch", host={})] + volumes, + ) task_hash = hashlib.sha256(json.dumps(expect_task_defn, sort_keys=True).encode()).hexdigest()[:8] task_defn_name = __name__.replace(".", "_") + "_" + task_hash @@ -144,13 +156,20 @@ def run(args): task_defn = clients.ecs.describe_task_definition(taskDefinition=task_defn_name)["taskDefinition"] assert task_defn["status"] == "ACTIVE" assert "FARGATE" in task_defn["compatibilities"] - desc_keys = ["family", "revision", "taskDefinitionArn", "status", "compatibilities", "placementConstraints", - "requiresAttributes"] + desc_keys = [ + "family", + "revision", + "taskDefinitionArn", + "status", + "compatibilities", + "placementConstraints", + "requiresAttributes", + ] task_desc = {key: task_defn.pop(key) for key in desc_keys} if expect_task_defn["cpu"].endswith(" vCPU"): - expect_task_defn["cpu"] = str(int(expect_task_defn["cpu"][:-len(" vCPU")]) * 1024) + expect_task_defn["cpu"] = str(int(expect_task_defn["cpu"][: -len(" vCPU")]) * 1024) if expect_task_defn["memory"].endswith("GB"): - expect_task_defn["memory"] = str(int(expect_task_defn["memory"][:-len("GB")]) * 1024) + expect_task_defn["memory"] = str(int(expect_task_defn["memory"][: -len("GB")]) * 1024) for key in expect_task_defn: assert expect_task_defn[key] == task_defn[key] logger.info("Reusing task definition %s", task_desc["taskDefinitionArn"]) @@ -160,20 +179,20 @@ def run(args): network_config = { "awsvpcConfiguration": { - "subnets": [ - subnet.id for subnet in vpc.subnets.all() - ], + "subnets": [subnet.id for subnet in vpc.subnets.all()], "securityGroups": [ensure_security_group(args.security_group, vpc).id], - "assignPublicIp": "ENABLED" + "assignPublicIp": "ENABLED", } } container_overrides = [dict(name=args.task_name, command=command, environment=environment)] - run_args = dict(cluster=args.cluster, - taskDefinition=task_desc["taskDefinitionArn"], - launchType="FARGATE", - 
platformVersion=args.fargate_platform_version, - networkConfiguration=network_config, - overrides=dict(containerOverrides=container_overrides)) + run_args = dict( + cluster=args.cluster, + taskDefinition=task_desc["taskDefinitionArn"], + launchType="FARGATE", + platformVersion=args.fargate_platform_version, + networkConfiguration=network_config, + overrides=dict(containerOverrides=container_overrides), + ) if args.tags: run_args["tags"] = encode_tags(args.tags, case="lower") if args.dry_run: @@ -193,6 +212,7 @@ def run(args): else: return res["tasks"][0] + register_parser_args = dict(parent=ecs_parser, help="Run a Fargate task") register_parser_args["aliases"] = ["launch"] @@ -208,32 +228,44 @@ def run(args): run_parser.add_argument("--dry-run", action="store_true", help="Gather arguments and stop short of running task") fargate_group = run_parser.add_argument_group( - description=("Resource allocation for the Fargate task VM, which runs the task Docker container(s): " - "(See also https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-cpu-memory-error.html)") + description=( + "Resource allocation for the Fargate task VM, which runs the task Docker container(s): " + "(See also https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-cpu-memory-error.html)" + ) ) fargate_group.add_argument( - "--fargate-cpu", help="vCPUs to allocate to the Fargate task", - choices=[".25 vCPU", ".5 vCPU", "1 vCPU", "2 vCPU", "4 vCPU", "256", "512", "1024", "2048", "4096"] + "--fargate-cpu", + help="vCPUs to allocate to the Fargate task", + choices=[".25 vCPU", ".5 vCPU", "1 vCPU", "2 vCPU", "4 vCPU", "256", "512", "1024", "2048", "4096"], ) fargate_group.add_argument( - "--fargate-memory", help="Memory to allocate to the Fargate task", - choices=["0.5GB"] + ["{}GB".format(i) for i in range(1, 31)] + ["512"] + list(map(str, range(1024, 30721, 1024))) + "--fargate-memory", + help="Memory to allocate to the Fargate task", + choices=["0.5GB"] + ["{}GB".format(i) for i in range(1, 31)] + ["512"] + list(map(str, range(1024, 30721, 1024))), +) + +task_status_colors = dict( + PROVISIONING=YELLOW(), + PENDING=BOLD() + YELLOW(), + ACTIVATING=BOLD() + YELLOW(), + RUNNING=GREEN(), + DEACTIVATING=BOLD() + GREEN(), + STOPPING=BOLD() + GREEN(), + DEPROVISIONING=BOLD() + GREEN(), + STOPPED=BOLD() + GREEN(), ) -task_status_colors = dict(PROVISIONING=YELLOW(), PENDING=BOLD() + YELLOW(), ACTIVATING=BOLD() + YELLOW(), - RUNNING=GREEN(), - DEACTIVATING=BOLD() + GREEN(), STOPPING=BOLD() + GREEN(), DEPROVISIONING=BOLD() + GREEN(), - STOPPED=BOLD() + GREEN()) def format_task_status(status): return task_status_colors[status] + status + ENDC() + def watch(args): logger.info("Watching task %s (%s)", args.task_id, args.cluster) last_status, events_received = None, 0 log_reader = CloudwatchLogReader( log_group_name=args.task_name, - log_stream_name="/".join([args.task_name, args.task_name, os.path.basename(args.task_id)]) + log_stream_name="/".join([args.task_name, args.task_name, os.path.basename(args.task_id)]), ) while last_status != "STOPPED": res = clients.ecs.describe_tasks(cluster=args.cluster, tasks=[args.task_id]) @@ -252,25 +284,33 @@ def watch(args): break # Logs retrieved successfully but task record is no longer in ECS time.sleep(1) + watch_parser = register_parser(watch, parent=ecs_parser, help="Monitor a running ECS Fargate task and stream its logs") watch_parser.add_argument("task_id") watch_parser.add_argument("--cluster", default=__name__.replace(".", "_")) watch_parser.add_argument("--task-name", 
default=__name__.replace(".", "_")) lines_group = watch_parser.add_mutually_exclusive_group() -lines_group.add_argument("--head", type=int, nargs="?", const=10, - help="Retrieve this number of lines from the beginning of the log (default 10)") -lines_group.add_argument("--tail", type=int, nargs="?", const=10, - help="Retrieve this number of lines from the end of the log (default 10)") +lines_group.add_argument( + "--head", + type=int, + nargs="?", + const=10, + help="Retrieve this number of lines from the beginning of the log (default 10)", +) +lines_group.add_argument( + "--tail", type=int, nargs="?", const=10, help="Retrieve this number of lines from the end of the log (default 10)" +) + def stop(args): - return clients.ecs.stop_task(cluster=args.cluster, - task=args.task_id, - reason="Stopped by {}".format(__name__)) + return clients.ecs.stop_task(cluster=args.cluster, task=args.task_id, reason="Stopped by {}".format(__name__)) + stop_parser = register_parser(stop, parent=ecs_parser, help="Stop a running ECS Fargate task") stop_parser.add_argument("task_id") stop_parser.add_argument("--cluster", default=__name__.replace(".", "_")) + def ssh(args): if not args.ssh_args: args.ssh_args = ["/bin/bash", "-l"] @@ -284,14 +324,17 @@ def ssh(args): raise AegeaException('No task found with name "{}" in cluster "{}"'.format(args.task_name, args.cluster_name)) ecs_ci_arn = task_desc["containerInstanceArn"] - ecs_ci_desc = clients.ecs.describe_container_instances(cluster=task_desc["clusterArn"], - containerInstances=[ecs_ci_arn])["containerInstances"][0] + ecs_ci_desc = clients.ecs.describe_container_instances( + cluster=task_desc["clusterArn"], containerInstances=[ecs_ci_arn] + )["containerInstances"][0] ecs_ci_ec2_id = ecs_ci_desc["ec2InstanceId"] logger.info("Task {} is on EC2 instance {}".format(args.task_name, ecs_ci_ec2_id)) container_id = task_desc["containers"][0]["runtimeId"] logger.info("Task {} is in container {}".format(args.task_name, container_id)) - ssh_to_ecs_container(instance_id=ecs_ci_ec2_id, container_id=container_id, ssh_args=args.ssh_args, - use_ssm=args.use_ssm) + ssh_to_ecs_container( + instance_id=ecs_ci_ec2_id, container_id=container_id, ssh_args=args.ssh_args, use_ssm=args.use_ssm + ) + ssh_parser = register_parser(ssh, parent=ecs_parser, help="Log in to a running ECS container via SSH") ssh_parser.add_argument("cluster_name") diff --git a/aegea/efs.py b/aegea/efs.py index e819f363..4aa22d90 100644 --- a/aegea/efs.py +++ b/aegea/efs.py @@ -27,8 +27,10 @@ def efs(args): efs_parser.print_help() + efs_parser = register_parser(efs, help="Manage Elastic Filesystem resources", description=__doc__) + def ls(args): table = [] for filesystem in clients.efs.describe_file_systems()["FileSystems"]: @@ -38,21 +40,26 @@ def ls(args): args.columns += args.mount_target_columns page_output(tabulate(table, args, cell_transforms={"SizeInBytes": lambda x, r: x.get("Value") if x else None})) + parser = register_listing_parser(ls, parent=efs_parser, help="List EFS filesystems") parser.add_argument("--mount-target-columns", nargs="+") + def create(args): vpc = resources.ec2.Vpc(args.vpc) if args.vpc else ensure_vpc() if args.security_groups is None: args.security_groups = [__name__] - ensure_security_group(__name__, vpc, tcp_ingress=[dict(port=socket.getservbyname("nfs"), - source_security_group_name=__name__)]) + ensure_security_group( + __name__, vpc, tcp_ingress=[dict(port=socket.getservbyname("nfs"), source_security_group_name=__name__)] + ) creation_token = 
base64.b64encode(bytearray(os.urandom(24))).decode() args.tags.append("Name=" + args.name) - create_file_system_args = dict(CreationToken=creation_token, - PerformanceMode=args.performance_mode, - ThroughputMode=args.throughput_mode, - Tags=encode_tags(args.tags)) + create_file_system_args = dict( + CreationToken=creation_token, + PerformanceMode=args.performance_mode, + ThroughputMode=args.throughput_mode, + Tags=encode_tags(args.tags), + ) if args.throughput_mode == "provisioned": create_file_system_args.update(ProvisionedThroughputInMibps=args.provisioned_throughput_in_mibps) fs = clients.efs.create_file_system(**create_file_system_args) @@ -61,12 +68,13 @@ def create(args): waiter.wait(FileSystemId=fs["FileSystemId"]) security_groups = [resolve_security_group(g, vpc).id for g in args.security_groups] for subnet in vpc.subnets.all(): - mount_target = clients.efs.create_mount_target(FileSystemId=fs["FileSystemId"], - SubnetId=subnet.id, - SecurityGroups=security_groups) + mount_target = clients.efs.create_mount_target( + FileSystemId=fs["FileSystemId"], SubnetId=subnet.id, SecurityGroups=security_groups + ) logger.info("Created EFS mount target %s in %s", mount_target["MountTargetId"], mount_target["SubnetId"]) return fs + parser_create = register_parser(create, parent=efs_parser, help="Create an EFS filesystem") parser_create.add_argument("name") parser_create.add_argument("--performance-mode", choices={"generalPurpose", "maxIO"}, default="generalPurpose") diff --git a/aegea/elb.py b/aegea/elb.py index 15084c71..422f692a 100644 --- a/aegea/elb.py +++ b/aegea/elb.py @@ -30,12 +30,15 @@ def elb(args): elb_parser.print_help() + elb_parser = register_parser(elb, help="Manage Elastic Load Balancers", description=__doc__) + def ls(args): @lru_cache() def sgid_to_name(i): return resources.ec2.SecurityGroup(i).group_name + table = [] dns_aliases = get_elb_dns_aliases() for row in paginate(clients.elb.get_paginator("describe_load_balancers")): @@ -54,9 +57,11 @@ def sgid_to_name(i): table.extend([dict(row, **target) for target in targets] if targets else [row]) page_output(tabulate(table, args, cell_transforms={"SecurityGroups": lambda x, r: ", ".join(map(sgid_to_name, x))})) + parser = register_listing_parser(ls, parent=elb_parser, help="List ELBs") parser.add_argument("elbs", nargs="*") + def get_target_group(alb_name, target_group_name): alb = clients.elbv2.describe_load_balancers(Names=[alb_name])["LoadBalancers"][0] target_groups = clients.elbv2.describe_target_groups(LoadBalancerArn=alb["LoadBalancerArn"])["TargetGroups"] @@ -66,10 +71,12 @@ def get_target_group(alb_name, target_group_name): m = "Target group {} not found in {} (target groups found: {})" raise AegeaException(m.format(target_group_name, alb_name, ", ".join(t["TargetGroupName"] for t in target_groups))) + def get_targets(target_group): res = clients.elbv2.describe_target_health(TargetGroupArn=target_group["TargetGroupArn"]) return res["TargetHealthDescriptions"] + def register(args): if args.type == "ELB": instances = [dict(InstanceId=i) for i in args.instances] @@ -81,8 +88,10 @@ def register(args): clients.elbv2.register_targets(TargetGroupArn=target_group["TargetGroupArn"], Targets=instances) return dict(registered=instances, current=[t["Target"] for t in get_targets(target_group)]) + parser_register = register_parser(register, parent=elb_parser, help="Add EC2 instances to an ELB") + def deregister(args): if args.type == "ELB": instances = [dict(InstanceId=i) for i in args.instances] @@ -94,8 +103,10 @@ def 
deregister(args): clients.elbv2.deregister_targets(TargetGroupArn=target_group["TargetGroupArn"], Targets=instances) return dict(deregistered=instances, current=[t["Target"] for t in get_targets(target_group)]) + parser_deregister = register_parser(deregister, parent=elb_parser, help="Remove EC2 instances from an ELB") + def replace(args): result = register(args) old_instances = set(hashabledict(d) for d in result["current"]) - set(hashabledict(d) for d in result["registered"]) @@ -104,8 +115,11 @@ def replace(args): result.update(deregister(args)) return result -parser_replace = register_parser(replace, parent=elb_parser, - help="Replace all EC2 instances in an ELB with the ones given") + +parser_replace = register_parser( + replace, parent=elb_parser, help="Replace all EC2 instances in an ELB with the ones given" +) + def find_acm_cert(dns_name): for cert in paginate(clients.acm.get_paginator("list_certificates")): @@ -115,6 +129,7 @@ def find_acm_cert(dns_name): return cert raise AegeaException("Unable to find ACM certificate for {}".format(dns_name)) + def ensure_target_group(name, **kwargs): # TODO: delete and re-create action and TG if settings don't match try: @@ -125,6 +140,7 @@ def ensure_target_group(name, **kwargs): res = clients.elbv2.create_target_group(Name=name, **kwargs) return res["TargetGroups"][0] + def create(args): for zone in paginate(clients.route53.get_paginator("list_hosted_zones")): if args.dns_alias.endswith("." + zone["Name"].rstrip(".")): @@ -133,33 +149,43 @@ def create(args): raise AegeaException("Unable to find Route53 DNS zone for {}".format(args.dns_alias)) cert = find_acm_cert(args.dns_alias) if args.type == "ELB": - listener = dict(Protocol="https", - LoadBalancerPort=443, - SSLCertificateId=cert["CertificateArn"], - InstanceProtocol="http", - InstancePort=args.instance_port or 80) - elb = clients.elb.create_load_balancer(LoadBalancerName=args.elb_name, - Listeners=[listener], - AvailabilityZones=list(availability_zones()), - SecurityGroups=[sg.id for sg in args.security_groups]) + listener = dict( + Protocol="https", + LoadBalancerPort=443, + SSLCertificateId=cert["CertificateArn"], + InstanceProtocol="http", + InstancePort=args.instance_port or 80, + ) + elb = clients.elb.create_load_balancer( + LoadBalancerName=args.elb_name, + Listeners=[listener], + AvailabilityZones=list(availability_zones()), + SecurityGroups=[sg.id for sg in args.security_groups], + ) elif args.type == "ALB": vpc = ensure_vpc() - res = clients.elbv2.create_load_balancer(Name=args.elb_name, - Subnets=[subnet.id for subnet in vpc.subnets.all()], - SecurityGroups=[sg.id for sg in args.security_groups]) + res = clients.elbv2.create_load_balancer( + Name=args.elb_name, + Subnets=[subnet.id for subnet in vpc.subnets.all()], + SecurityGroups=[sg.id for sg in args.security_groups], + ) elb = res["LoadBalancers"][0] - target_group = ensure_target_group(args.target_group.format(elb_name=args.elb_name), - Protocol="HTTP", - Port=args.instance_port, - VpcId=vpc.id, - HealthCheckProtocol=args.health_check_protocol, - HealthCheckPort=args.health_check_port, - HealthCheckPath=args.health_check_path, - Matcher=dict(HttpCode=args.ok_http_codes)) - listener_params = dict(Protocol="HTTPS", - Port=443, - Certificates=[dict(CertificateArn=cert["CertificateArn"])], - DefaultActions=[dict(Type="forward", TargetGroupArn=target_group["TargetGroupArn"])]) + target_group = ensure_target_group( + args.target_group.format(elb_name=args.elb_name), + Protocol="HTTP", + Port=args.instance_port, + VpcId=vpc.id, 
+ HealthCheckProtocol=args.health_check_protocol, + HealthCheckPort=args.health_check_port, + HealthCheckPath=args.health_check_path, + Matcher=dict(HttpCode=args.ok_http_codes), + ) + listener_params = dict( + Protocol="HTTPS", + Port=443, + Certificates=[dict(CertificateArn=cert["CertificateArn"])], + DefaultActions=[dict(Type="forward", TargetGroupArn=target_group["TargetGroupArn"])], + ) res = clients.elbv2.describe_listeners(LoadBalancerArn=elb["LoadBalancerArn"]) if res["Listeners"]: res = clients.elbv2.modify_listener(ListenerArn=res["Listeners"][0]["ListenerArn"], **listener_params) @@ -168,25 +194,38 @@ def create(args): listener = res["Listeners"][0] if args.path_pattern: rules = clients.elbv2.describe_rules(ListenerArn=listener["ListenerArn"])["Rules"] - clients.elbv2.create_rule(ListenerArn=listener["ListenerArn"], - Conditions=[dict(Field="path-pattern", Values=[args.path_pattern])], - Actions=[dict(Type="forward", TargetGroupArn=target_group["TargetGroupArn"])], - Priority=len(rules)) + clients.elbv2.create_rule( + ListenerArn=listener["ListenerArn"], + Conditions=[dict(Field="path-pattern", Values=[args.path_pattern])], + Actions=[dict(Type="forward", TargetGroupArn=target_group["TargetGroupArn"])], + Priority=len(rules), + ) replace(args) DNSZone(zone["Name"]).update(args.dns_alias.replace("." + zone["Name"].rstrip("."), ""), elb["DNSName"]) return dict(elb_name=args.elb_name, dns_name=elb["DNSName"], dns_alias=args.dns_alias) + parser_create = register_parser(create, parent=elb_parser, help="Create a new ELB") -parser_create.add_argument("--security-groups", nargs="+", type=resolve_security_group, required=True, help=""" +parser_create.add_argument( + "--security-groups", + nargs="+", + type=resolve_security_group, + required=True, + help=""" Security groups to assign the ELB. 
You must allow TCP traffic to flow between clients and the ELB on ports 80/443 -and allow TCP traffic to flow between the ELB and the instances on INSTANCE_PORT.""") +and allow TCP traffic to flow between the ELB and the instances on INSTANCE_PORT.""", +) parser_create.add_argument("--dns-alias", required=True, help="Fully qualified DNS name that will point to the ELB") parser_create.add_argument("--path-pattern") parser_create.add_argument("--health-check-protocol", default="HTTP", choices={"HTTP", "HTTPS"}) parser_create.add_argument("--health-check-port", default="traffic-port", help="Port to be queried by ELB health check") parser_create.add_argument("--health-check-path", default="/", help="Path to be queried by ELB health check") -parser_create.add_argument("--ok-http-codes", default="200-399", - help="Comma or dash-separated HTTP response codes considered healthy by ELB health check") +parser_create.add_argument( + "--ok-http-codes", + default="200-399", + help="Comma or dash-separated HTTP response codes considered healthy by ELB health check", +) + def delete(args): if args.type == "ELB": @@ -196,13 +235,16 @@ def delete(args): assert len(elbs) == 1 clients.elbv2.delete_load_balancer(LoadBalancerArn=elbs[0]["LoadBalancerArn"]) + parser_delete = register_parser(delete, parent=elb_parser, help="Delete an ELB") + def list_load_balancers(): elbs = paginate(clients.elb.get_paginator("describe_load_balancers")) albs = paginate(clients.elbv2.get_paginator("describe_load_balancers")) return list(elbs) + list(albs) + for parser in parser_register, parser_deregister, parser_replace, parser_create, parser_delete: parser.add_argument("elb_name").completer = lambda **kw: [i["LoadBalancerName"] for i in list_load_balancers()] parser.add_argument("--type", choices={"ELB", "ALB"}, default="ALB") diff --git a/aegea/flow_logs.py b/aegea/flow_logs.py index 311a0a9b..ea9d9a7b 100644 --- a/aegea/flow_logs.py +++ b/aegea/flow_logs.py @@ -15,8 +15,10 @@ def flow_logs(args): flow_logs_parser.print_help() + flow_logs_parser = register_parser(flow_logs, help="Manage EC2 VPC flow logs", description=__doc__) + def create(args): if args.resource and args.resource.startswith("vpc-"): resource_type = "VPC" @@ -29,35 +31,42 @@ def create(args): else: args.resource = ensure_vpc().id resource_type = "VPC" - flow_logs_iam_role = ensure_iam_role(__name__, - policies=["service-role/AmazonAPIGatewayPushToCloudWatchLogs"], - trust=["vpc-flow-logs"]) + flow_logs_iam_role = ensure_iam_role( + __name__, policies=["service-role/AmazonAPIGatewayPushToCloudWatchLogs"], trust=["vpc-flow-logs"] + ) try: - return clients.ec2.create_flow_logs(ResourceIds=[args.resource], - ResourceType=resource_type, - TrafficType=args.traffic_type, - LogGroupName=__name__, - DeliverLogsPermissionArn=flow_logs_iam_role.arn) + return clients.ec2.create_flow_logs( + ResourceIds=[args.resource], + ResourceType=resource_type, + TrafficType=args.traffic_type, + LogGroupName=__name__, + DeliverLogsPermissionArn=flow_logs_iam_role.arn, + ) except ClientError as e: expect_error_codes(e, "FlowLogAlreadyExists") return dict(FlowLogAlreadyExists=True) + parser = register_parser(create, parent=flow_logs_parser, help="Create VPC flow logs") parser.add_argument("--resource") parser.add_argument("--traffic_type", choices=["ACCEPT", "REJECT", "ALL"], default="ALL") + def ls(args): describe_flow_logs_args = dict(Filters=[dict(Name="resource-id", Values=[args.resource])]) if args.resource else {} 
page_output(tabulate(clients.ec2.describe_flow_logs(**describe_flow_logs_args)["FlowLogs"], args)) + parser = register_listing_parser(ls, parent=flow_logs_parser, help="List VPC flow logs") parser.add_argument("--resource") + def get(args): args.log_group, args.pattern = __name__, None args.log_stream = "-".join([args.network_interface, args.traffic_type]) if args.network_interface else None grep(args) + parser = register_parser(get, parent=flow_logs_parser, help="Get VPC flow logs") parser.add_argument("--network-interface") parser.add_argument("--traffic_type", choices=["ACCEPT", "REJECT", "ALL"], default="ALL") diff --git a/aegea/iam.py b/aegea/iam.py index 8df99001..44aef547 100644 --- a/aegea/iam.py +++ b/aegea/iam.py @@ -21,18 +21,22 @@ def iam(args): iam_parser.print_help() + iam_parser = register_parser(iam) + def configure(args): for group, policies in config.managed_iam_groups.items(): print("Creating group", group) formatted_policies = [(IAMPolicyBuilder(**p) if isinstance(p, Mapping) else p) for p in policies] ensure_iam_group(group, policies=formatted_policies) - msg = 'Created group {g}. Use the AWS console or "aws iam add-user-to-group --user-name USER --group-name {g}" to add users to it.' # noqa + msg = 'Created group {g}. Use the AWS console or "aws iam add-user-to-group --user-name USER --group-name {g}" to add users to it.' # noqa print(BOLD(msg.format(g=group))) + parser_configure = register_parser(configure, parent=iam_parser, help="Set up aegea-specific IAM groups and policies") + def get_policies_for_principal(cell, row): try: policies = [p.policy_name for p in row.policies.all()] + [p.policy_name for p in row.attached_policies.all()] @@ -42,6 +46,7 @@ def get_policies_for_principal(cell, row): return "[Access denied]" raise + def users(args): try: current_user_id = resources.iam.CurrentUser().user_id @@ -67,27 +72,35 @@ def describe_access_keys(cell): "cur": mark_cur_user, "policies": get_policies_for_principal, "mfa": describe_mfa, - "access_keys": describe_access_keys + "access_keys": describe_access_keys, } page_output(tabulate(users, args, cell_transforms=cell_transforms)) + parser = register_listing_parser(users, parent=iam_parser, help="List IAM users") + def groups(args): page_output(tabulate(resources.iam.groups.all(), args, cell_transforms={"policies": get_policies_for_principal})) + parser = register_listing_parser(groups, parent=iam_parser, help="List IAM groups") + def roles(args): page_output(tabulate(resources.iam.roles.all(), args, cell_transforms={"policies": get_policies_for_principal})) + parser = register_listing_parser(roles, parent=iam_parser, help="List IAM roles") + def policies(args): page_output(tabulate(resources.iam.policies.all(), args)) + parser = register_listing_parser(policies, parent=iam_parser, help="List IAM policies") + def generate_password(length=16): while True: password = [random.SystemRandom().choice(string.ascii_letters + string.digits) for _ in range(length)] @@ -98,17 +111,19 @@ def generate_password(length=16): continue if not any(c in string.digits for c in password): continue - return ''.join(password) + return "".join(password) + def create_user(args): if args.prompt_for_password: from getpass import getpass + args.password = getpass(prompt=f"Password for IAM user {args.username}:") else: args.password = generate_password() try: user = resources.iam.create_user(UserName=args.username) - clients.iam.get_waiter('user_exists').wait(UserName=args.username) + 
clients.iam.get_waiter("user_exists").wait(UserName=args.username) logger.info("Created new IAM user %s", user) print(BOLD(f"Generated new password for IAM user {args.username}: {args.password}")) except resources.iam.meta.client.exceptions.EntityAlreadyExistsException: @@ -129,9 +144,11 @@ def create_user(args): user.add_group(GroupName=group.name) logger.info("Added %s to %s", user, group) + parser = register_listing_parser(create_user, parent=iam_parser, help="Create a new IAM user") parser.add_argument("username") parser.add_argument("--reset-password", action="store_true") -parser.add_argument("--prompt-for-password", - help="Display an interactive prompt for new user password instead of autogenerating") +parser.add_argument( + "--prompt-for-password", help="Display an interactive prompt for new user password instead of autogenerating" +) parser.add_argument("--groups", nargs="*", default=[], help="IAM groups to add the user to") diff --git a/aegea/instance_ctl.py b/aegea/instance_ctl.py index 31502059..596773b6 100644 --- a/aegea/instance_ctl.py +++ b/aegea/instance_ctl.py @@ -18,18 +18,22 @@ def resolve_instance_ids(input_names): raise Exception("Unable to resolve one or more of the instance names") return ids, names + def start(args): ids, names = resolve_instance_ids(args.names) clients.ec2.start_instances(InstanceIds=ids, DryRun=args.dry_run) + def stop(args): ids, names = resolve_instance_ids(args.names) clients.ec2.stop_instances(InstanceIds=ids, DryRun=args.dry_run) + def reboot(args): ids, names = resolve_instance_ids(args.names) clients.ec2.reboot_instances(InstanceIds=ids, DryRun=args.dry_run) + def terminate(args): ids, names = resolve_instance_ids(args.names) clients.ec2.terminate_instances(InstanceIds=ids, DryRun=args.dry_run) @@ -38,13 +42,16 @@ def terminate(args): if not args.dry_run: DNSZone().delete(name) + def rename(args): """Supply two names: Existing instance name or ID, and new name to assign to the instance.""" old_name, new_name = args.names add_tags(resources.ec2.Instance(resolve_instance_id(old_name)), Name=new_name, dry_run=args.dry_run) + for action in (start, stop, reboot, terminate, rename): - parser = register_parser(action, help="{} EC2 instances".format(action.__name__.capitalize()), - description=action.__doc__) + parser = register_parser( + action, help="{} EC2 instances".format(action.__name__.capitalize()), description=action.__doc__ + ) parser.add_argument("--dry-run", "--dryrun", action="store_true") parser.add_argument("names", nargs="+") diff --git a/aegea/lambda.py b/aegea/lambda.py index cf08caa5..0732ce65 100644 --- a/aegea/lambda.py +++ b/aegea/lambda.py @@ -17,20 +17,27 @@ def _lambda(args): lambda_parser.print_help() + lambda_parser = register_parser(_lambda, name="lambda") + def ls(args): paginator = getattr(clients, "lambda").get_paginator("list_functions") page_output(tabulate(paginate(paginator), args, cell_transforms={"LastModified": Timestamp})) + parser_ls = register_listing_parser(ls, parent=lambda_parser, help="List AWS Lambda functions") + def event_source_mappings(args): paginator = getattr(clients, "lambda").get_paginator("list_event_source_mappings") page_output(tabulate(paginate(paginator), args)) -parser_event_source_mappings = register_listing_parser(event_source_mappings, parent=lambda_parser, - help="List event source mappings") + +parser_event_source_mappings = register_listing_parser( + event_source_mappings, parent=lambda_parser, help="List event source mappings" +) + def update_code(args): with open(args.zip_file, 
"rb") as fh: @@ -40,10 +47,12 @@ def update_code(args): assert base64.b64decode(res["CodeSha256"]) == payload_sha return res + update_code_parser = register_parser(update_code, parent=lambda_parser, help="Update function code") update_code_parser.add_argument("function_name") update_code_parser.add_argument("zip_file") + def update_config(args): update_args = dict(FunctionName=args.function_name) if args.role: @@ -58,15 +67,24 @@ def update_config(args): update_args.update(Environment=cfg["Environment"]) return getattr(clients, "lambda").update_function_configuration(**update_args) + def role_name_completer(**kwargs): return [r.name for r in resources.iam.roles.all()] + update_config_parser = register_parser(update_config, parent=lambda_parser, help="Update function configuration") update_config_parser.add_argument("function_name") update_config_parser.add_argument("--role", help="IAM role for the function").completer = role_name_completer -update_config_parser.add_argument("--timeout", type=int, - help="The amount of time that Lambda allows a function to run before stopping it") -update_config_parser.add_argument("--memory-size", type=int, - help="The amount of memory that your function has access to") -update_config_parser.add_argument("--environment", nargs="+", metavar="NAME=VALUE", type=lambda x: x.split("=", 1), - help="Read environment variables for function, update with given values, write back") +update_config_parser.add_argument( + "--timeout", type=int, help="The amount of time that Lambda allows a function to run before stopping it" +) +update_config_parser.add_argument( + "--memory-size", type=int, help="The amount of memory that your function has access to" +) +update_config_parser.add_argument( + "--environment", + nargs="+", + metavar="NAME=VALUE", + type=lambda x: x.split("=", 1), + help="Read environment variables for function, update with given values, write back", +) diff --git a/aegea/launch.py b/aegea/launch.py index e4974705..9f295d80 100644 --- a/aegea/launch.py +++ b/aegea/launch.py @@ -72,6 +72,7 @@ def get_spot_bid_price(instance_type, ondemand_multiplier=1.2): ondemand_price = get_ondemand_price_usd(clients.ec2.meta.region_name, instance_type) return float(ondemand_price) * ondemand_multiplier + def get_startup_commands(args): hostname = ".".join([args.hostname, config.dns.private_zone.rstrip(".")]) if args.use_dns else args.hostname return [ @@ -80,24 +81,29 @@ def get_startup_commands(args): "echo tsc > /sys/devices/system/clocksource/clocksource0/current_clocksource", ] + args.commands + def get_ssh_ca_keys(bless_config): for lambda_regional_config in bless_config["lambda_config"]["regions"]: if lambda_regional_config["aws_region"] == clients.ec2.meta.region_name: break - ca_keys_secret_arn = ARN(service="secretsmanager", - region=lambda_regional_config["aws_region"], - account_id=ARN(bless_config["lambda_config"]["role_arn"]).account_id, - resource="secret:" + bless_config["lambda_config"]["function_name"]) + ca_keys_secret_arn = ARN( + service="secretsmanager", + region=lambda_regional_config["aws_region"], + account_id=ARN(bless_config["lambda_config"]["role_arn"]).account_id, + resource="secret:" + bless_config["lambda_config"]["function_name"], + ) ca_keys_secret = clients.secretsmanager.get_secret_value(SecretId=str(ca_keys_secret_arn)) ca_keys = json.loads(ca_keys_secret["SecretString"])["ssh_ca_keys"] return "\n".join(ca_keys) + def infer_architecture(instance_type): instance_family = instance_type.split(".")[0] if "g" in instance_family or 
instance_family == "a1": return "arm64" return "x86_64" + def ensure_efs_home(subnet): for fs in clients.efs.describe_file_systems()["FileSystems"]: if {"Key": "mountpoint", "Value": "/home"} in fs["Tags"]: @@ -108,17 +114,22 @@ def ensure_efs_home(subnet): create_efs_args = ["aegea_home", "--tags", "mountpoint=/home", "managedBy=aegea", "--vpc", subnet.vpc_id] return create_efs(parser_create_efs.parse_args(create_efs_args)) + def launch(args): + # FIXME: running `systemctl mask mnt.mount` may be needed to disable systemd's "management" of mounts + # Test by rebooting an instance with an EBS volume attached and confirming that the mount comes back up + # See https://unix.stackexchange.com/questions/563300/how-to-stop-systemd-from-immediately-unmounting-degraded-btrfs-volume args.storage = dict(args.storage) # Allow storage to be specified as either a list (argparse) or mapping (YAML) args.storage = {k: str(v).rstrip("GBgb") for k, v in args.storage.items()} - logger.debug('Using %s for storage', ', '.join('='.join(_) for _ in args.storage.items())) + logger.debug("Using %s for storage", ", ".join("=".join(_) for _ in args.storage.items())) if args.spot_price or args.duration_hours or args.cores or args.min_mem_per_core_gb: args.spot = True if args.use_dns: dns_zone = DNSZone() - ssh_key_name = ensure_ssh_key(name=args.ssh_key_name, base_name=__name__, - verify_pem_file=args.verify_ssh_key_pem_file) + ssh_key_name = ensure_ssh_key( + name=args.ssh_key_name, base_name=__name__, verify_pem_file=args.verify_ssh_key_pem_file + ) user_info = get_user_info() # TODO: move all account init checks into init helper with region-specific semaphore on s3 try: @@ -146,10 +157,12 @@ def launch(args): logger.info("Using %s (%s)", args.ami, args.ami.name) except AegeaException as e: if args.ami is None and len(ami_tags) == 0 and "Could not resolve AMI" in str(e): - raise AegeaException("No AMI was given, and no " + arch + " AMIs were found in this account. " - "To build an aegea AMI, use aegea build-ami --architecture " + arch + ". " - "To use the default Ubuntu Linux LTS AMI, use --ubuntu-linux-ami. " - "To use the default Amazon Linux 2 AMI, use --amazon-linux-ami. ") + raise AegeaException( + "No AMI was given, and no " + arch + " AMIs were found in this account. " + "To build an aegea AMI, use aegea build-ami --architecture " + arch + ". " + "To use the default Ubuntu Linux LTS AMI, use --ubuntu-linux-ami. " + "To use the default Amazon Linux 2 AMI, use --amazon-linux-ami. 
" + ) raise if args.subnet: subnet = resources.ec2.Subnet(args.subnet) @@ -159,10 +172,13 @@ def launch(args): if args.spot and not args.availability_zone: # Select the optimal availability zone by scanning the price history for the given spot instance type best_spot_price_desc = dict(SpotPrice=sys.maxsize, AvailabilityZone=None) - for spot_price_desc in paginate(clients.ec2.get_paginator("describe_spot_price_history"), - InstanceTypes=[args.instance_type], - ProductDescriptions=["Linux/UNIX (Amazon VPC)", "Linux/Unix"], - StartTime=datetime.datetime.utcnow() - datetime.timedelta(hours=1)): + for spot_price_desc in paginate( + clients.ec2.get_paginator("describe_spot_price_history"), + InstanceTypes=[args.instance_type], + ProductDescriptions=["Linux/UNIX (Amazon VPC)", "Linux/Unix"], + StartTime=datetime.datetime.utcnow() - datetime.timedelta(hours=1), + ): + assert isinstance(best_spot_price_desc["SpotPrice"], (str, int, float)) if float(spot_price_desc["SpotPrice"]) < float(best_spot_price_desc["SpotPrice"]): best_spot_price_desc = spot_price_desc args.availability_zone = best_spot_price_desc["AvailabilityZone"] @@ -178,17 +194,23 @@ def launch(args): security_groups.append(resolve_security_group(efs_security_group_name, vpc)) ssh_host_key = new_ssh_key() - user_data_args = dict(host_key=ssh_host_key, - commands=get_startup_commands(args), - packages=args.packages, - storage=args.storage, - rootfs_skel_dirs=get_rootfs_skel_dirs(args)) + user_data_args = dict( + host_key=ssh_host_key, + commands=get_startup_commands(args), + packages=args.packages, + storage=args.storage, + rootfs_skel_dirs=get_rootfs_skel_dirs(args), + ) if args.provision_user: - user_data_args["provision_users"] = [dict(name=user_info["linux_username"], - uid=user_info["linux_user_id"], - sudo="ALL=(ALL) NOPASSWD:ALL", - groups="docker", - shell="/bin/bash")] + user_data_args["provision_users"] = [ + dict( + name=user_info["linux_username"], + uid=user_info["linux_user_id"], + sudo="ALL=(ALL) NOPASSWD:ALL", + groups="docker", + shell="/bin/bash", + ) + ] elif args.bless_config: with open(args.bless_config) as fh: bless_config = yaml.safe_load(fh) @@ -196,17 +218,24 @@ def launch(args): user_data_args["provision_users"] = bless_config["client_config"]["remote_users"] hkl = hostkey_line(hostnames=[], key=ssh_host_key).strip() - instance_tags = dict(Name=args.hostname, Owner=user_info["iam_username"], - SSHHostPublicKeyPart1=hkl[:255], SSHHostPublicKeyPart2=hkl[255:], - OwnerSSHKeyName=ssh_key_name, **dict(args.tags)) + instance_tags = dict( + Name=args.hostname, + Owner=user_info["iam_username"], + SSHHostPublicKeyPart1=hkl[:255], + SSHHostPublicKeyPart2=hkl[255:], + OwnerSSHKeyName=ssh_key_name, + **dict(args.tags), + ) user_data_args.update(dict(args.cloud_config_data)) - launch_spec = dict(ImageId=args.ami.id, - KeyName=ssh_key_name, - SubnetId=subnet.id, - SecurityGroupIds=[sg.id for sg in security_groups], - InstanceType=args.instance_type, - BlockDeviceMappings=get_bdm(ami=args.ami.id, ebs_storage=args.storage), - UserData=get_user_data(**user_data_args)) + launch_spec = dict( + ImageId=args.ami.id, + KeyName=ssh_key_name, + SubnetId=subnet.id, + SecurityGroupIds=[sg.id for sg in security_groups], + InstanceType=args.instance_type, + BlockDeviceMappings=get_bdm(ami=args.ami.id, ebs_storage=args.storage), + UserData=get_user_data(**user_data_args), + ) tag_spec = dict(ResourceType="instance", Tags=encode_tags(instance_tags)) logger.info("Launch spec user data is %i bytes long", len(launch_spec["UserData"])) if 
args.iam_role: @@ -217,8 +246,10 @@ def launch(args): instance_profile = ensure_instance_profile(args.iam_role, policies=[umbrella_policy]) except ClientError as e: expect_error_codes(e, "AccessDenied") - raise AegeaException('Unable to validate IAM permissions for launch. If you have only iam:PassRole ' - 'access, try --no-manage-iam. If you have no IAM access, try --iam-role="".') + raise AegeaException( + "Unable to validate IAM permissions for launch. If you have only iam:PassRole " + 'access, try --no-manage-iam. If you have no IAM access, try --iam-role="".' + ) else: instance_profile = resources.iam.InstanceProfile(args.iam_role) launch_spec["IamInstanceProfile"] = dict(Arn=instance_profile.arn) @@ -231,21 +262,25 @@ def launch(args): if args.spot: launch_spec["UserData"] = base64.b64encode(launch_spec["UserData"]).decode() if args.duration_hours or args.cores or args.min_mem_per_core_gb: - spot_fleet_args = dict(launch_spec=dict(launch_spec, TagSpecifications=[tag_spec]), - client_token=args.client_token) + spot_fleet_args = dict( + launch_spec=dict(launch_spec, TagSpecifications=[tag_spec]), client_token=args.client_token + ) for arg in "cores", "min_mem_per_core_gb", "spot_price", "duration_hours", "dry_run": if getattr(args, arg, None): spot_fleet_args[arg] = getattr(args, arg) if "cores" in spot_fleet_args: spot_fleet_args["min_cores_per_instance"] = spot_fleet_args["cores"] if args.instance_type != parser.get_default("instance_type"): - msg = ("Using --instance-type with spot fleet may unnecessarily constrain available instances. " - "Consider using --cores and --min-mem-per-core-gb instead") + msg = ( + "Using --instance-type with spot fleet may unnecessarily constrain available instances. " + "Consider using --cores and --min-mem-per-core-gb instead" + ) logger.warn(msg) class InstanceSpotFleetBuilder(SpotFleetBuilder): def instance_types(self, **kwargs): yield args.instance_type, 1 + spot_fleet_builder = InstanceSpotFleetBuilder(**spot_fleet_args) # type: SpotFleetBuilder else: spot_fleet_builder = SpotFleetBuilder(**spot_fleet_args) @@ -268,7 +303,7 @@ def instance_types(self, **kwargs): ValidUntil=datetime.datetime.utcnow() + datetime.timedelta(hours=1), LaunchSpecification=launch_spec, ClientToken=args.client_token, - DryRun=args.dry_run + DryRun=args.dry_run, ) sir_id = res["SpotInstanceRequests"][0]["SpotInstanceRequestId"] clients.ec2.get_waiter("spot_instance_request_fulfilled").wait(SpotInstanceRequestIds=[sir_id]) @@ -277,8 +312,9 @@ def instance_types(self, **kwargs): add_tags(instance, **instance_tags) else: launch_spec = dict(launch_spec, TagSpecifications=[tag_spec]) - instances = resources.ec2.create_instances(MinCount=1, MaxCount=1, ClientToken=args.client_token, - DryRun=args.dry_run, **launch_spec) + instances = resources.ec2.create_instances( + MinCount=1, MaxCount=1, ClientToken=args.client_token, DryRun=args.dry_run, **launch_spec + ) instance = instances[0] except (KeyboardInterrupt, WaiterError): if sir_id is not None and instance is None: @@ -300,55 +336,106 @@ def instance_types(self, **kwargs): logger.info("Launched %s %s in %s using %s (%s)", instance.instance_type, instance, subnet, args.ami, args.ami.name) return dict(instance_id=instance.id) + parser = register_parser(launch) parser.add_argument("hostname") -parser.add_argument("--storage", nargs="+", metavar="MOUNTPOINT=SIZE_GB", type=lambda x: x.split("=", 1), - help="At launch time, attach EBS volume(s) of this size, format and mount them.") -parser.add_argument("--efs-home", 
action="store_true", - help="Create and manage an EFS filesystem that the instance will use for user home directories") +parser.add_argument( + "--storage", + nargs="+", + metavar="MOUNTPOINT=SIZE_GB", + type=lambda x: x.split("=", 1), + help="At launch time, attach EBS volume(s) of this size, format and mount them.", +) +parser.add_argument( + "--efs-home", + action="store_true", + help="Create and manage an EFS filesystem that the instance will use for user home directories", +) parser.add_argument("--commands", nargs="+", metavar="COMMAND", help="Commands to run on host upon startup") parser.add_argument("--packages", nargs="+", metavar="PACKAGE", help="APT packages to install on host upon startup") parser.add_argument("--ssh-key-name") parser.add_argument("--no-verify-ssh-key-pem-file", dest="verify_ssh_key_pem_file", action="store_false") parser.add_argument("--no-provision-user", dest="provision_user", action="store_false") -parser.add_argument("--bless-config", default=os.environ.get("BLESS_CONFIG"), - help="Path to a Bless configuration file (or pass via the BLESS_CONFIG environment variable)") +parser.add_argument( + "--bless-config", + default=os.environ.get("BLESS_CONFIG"), + help="Path to a Bless configuration file (or pass via the BLESS_CONFIG environment variable)", +) parser.add_argument("--ami", help="AMI to use for the instance. Default: " + resolve_ami.__doc__) # type: ignore parser.add_argument("--ami-tags", nargs="+", metavar="NAME=VALUE", help="Use the most recent AMI with these tags") -parser.add_argument("--ami-tag-keys", nargs="+", default=[], metavar="TAG_NAME", - help="Use the most recent AMI with these tag names") +parser.add_argument( + "--ami-tag-keys", nargs="+", default=[], metavar="TAG_NAME", help="Use the most recent AMI with these tag names" +) parser.add_argument("--ubuntu-linux-ami", action="store_true", help="Use the most recent Ubuntu Linux LTS AMI") parser.add_argument("--amazon-linux-ami", action="store_true", help="Use the most recent Amazon Linux AMI") parser.add_argument("--amazon-linux-release", help="Use a specific Amazon Linux release", choices={"2", "2022", "2023"}) -parser.add_argument("--spot", action="store_true", - help="Launch a preemptible spot instance, which is cheaper but could be forced to shut down") +parser.add_argument( + "--spot", + action="store_true", + help="Launch a preemptible spot instance, which is cheaper but could be forced to shut down", +) parser.add_argument("--duration-hours", type=float, help="Terminate the spot instance after this number of hours") parser.add_argument("--cores", type=int, help="Minimum number of cores to request (spot fleet API)") parser.add_argument("--min-mem-per-core-gb", type=float) parser.add_argument("--instance-type", "-t", help="See https://ec2instances.info/").completer = instance_type_completer -parser.add_argument("--spot-price", type=float, - help="Maximum bid price for spot instances. Defaults to 1.2x the ondemand price.") -parser.add_argument("--no-dns", dest="use_dns", action="store_false", help=""" -Skip registering instance name in private DNS (if you don't want launching principal to have Route53 write access)""") +parser.add_argument( + "--spot-price", type=float, help="Maximum bid price for spot instances. Defaults to 1.2x the ondemand price." 
+) +parser.add_argument( + "--no-dns", + dest="use_dns", + action="store_false", + help=""" +Skip registering instance name in private DNS (if you don't want launching principal to have Route53 write access)""", +) parser.add_argument("--client-token", help="Token used to identify your instance, SIR or SFR") parser.add_argument("--subnet") parser.add_argument("--availability-zone", "--az") parser.add_argument("--security-groups", nargs="+", metavar="SECURITY_GROUP") -parser.add_argument("--tags", nargs="+", metavar="NAME=VALUE", type=lambda x: x.split("=", 1), - help="Tags to apply to launched instances.") -parser.add_argument("--wait-for-ssh", action="store_true", - help=("Wait for launched instance to begin accepting SSH connections. " - "Security groups and NACLs must permit SSH from launching host.")) -parser.add_argument("--iam-role", help=("Pass this IAM role to the launched instance through an instance profile. " - "Role credentials will become available in the instance metadata. " - "To launch an instance without a profile/role, use an empty string here.")) -parser.add_argument("--iam-policies", nargs="+", metavar="IAM_POLICY_NAME", - help="Ensure the default or specified IAM role has the listed IAM managed policies attached") -parser.add_argument("--use-imdsv2", "--metadata-options-http-tokens-required", action="store_true", - help="Configure the instance to use Instance Metadata Service Version 2") -parser.add_argument("--no-manage-iam", action="store_false", dest="manage_iam", - help=("Prevents aegea from creating or managing the IAM role or policies for the instance. The " - "given or default IAM role and instance profile will still be used, raising an error if they " - "are not found.")) +parser.add_argument( + "--tags", + nargs="+", + metavar="NAME=VALUE", + type=lambda x: x.split("=", 1), + help="Tags to apply to launched instances.", +) +parser.add_argument( + "--wait-for-ssh", + action="store_true", + help=( + "Wait for launched instance to begin accepting SSH connections. " + "Security groups and NACLs must permit SSH from launching host." + ), +) +parser.add_argument( + "--iam-role", + help=( + "Pass this IAM role to the launched instance through an instance profile. " + "Role credentials will become available in the instance metadata. " + "To launch an instance without a profile/role, use an empty string here." + ), +) +parser.add_argument( + "--iam-policies", + nargs="+", + metavar="IAM_POLICY_NAME", + help="Ensure the default or specified IAM role has the listed IAM managed policies attached", +) +parser.add_argument( + "--use-imdsv2", + "--metadata-options-http-tokens-required", + action="store_true", + help="Configure the instance to use Instance Metadata Service Version 2", +) +parser.add_argument( + "--no-manage-iam", + action="store_false", + dest="manage_iam", + help=( + "Prevents aegea from creating or managing the IAM role or policies for the instance. The " + "given or default IAM role and instance profile will still be used, raising an error if they " + "are not found." 
+ ), +) parser.add_argument("--cloud-config-data", type=json.loads) parser.add_argument("--dry-run", "--dryrun", action="store_true") diff --git a/aegea/logs.py b/aegea/logs.py index 7d0fa289..0c33ed28 100644 --- a/aegea/logs.py +++ b/aegea/logs.py @@ -50,6 +50,7 @@ def log_group_completer(prefix, **kwargs): for group in paginate(clients.logs.get_paginator("describe_log_groups"), **describe_log_groups_args): yield group["logGroupName"] + def logs(args): if args.log_group and (args.log_stream or args.start_time or args.end_time): if args.export and args.print_s3_urls: @@ -66,8 +67,12 @@ def logs(args): if args.log_group and group["logGroupName"] != args.log_group: continue n = 0 - for stream in paginate(clients.logs.get_paginator("describe_log_streams"), - logGroupName=group["logGroupName"], orderBy="LastEventTime", descending=True): + for stream in paginate( + clients.logs.get_paginator("describe_log_streams"), + logGroupName=group["logGroupName"], + orderBy="LastEventTime", + descending=True, + ): now = datetime.utcnow().replace(microsecond=0) stream["lastIngestionTime"] = now - datetime.utcfromtimestamp(stream.get("lastIngestionTime", 0) // 1000) table.append(dict(group, **stream)) @@ -76,6 +81,7 @@ def logs(args): break page_output(tabulate(table, args)) + logs_parser = register_parser(logs) logs_parser.add_argument("--max-streams-per-group", "-n", type=int, default=8) logs_parser.add_argument("--sort-by", default="lastIngestionTime:reverse") @@ -85,6 +91,7 @@ def logs(args): logs_parser.add_argument("log_stream", nargs="?", help="CloudWatch log stream") add_time_bound_args(logs_parser, snap=2, start="-24h") + def filter(args): filter_args = dict(logGroupName=args.log_group) if args.log_stream: @@ -107,24 +114,32 @@ def filter(args): else: return SystemExit(os.EX_OK if num_results > 0 else os.EX_DATAERR) + filter_parser = register_parser(filter, help="Filter and print events in a CloudWatch Logs stream or group of streams") -filter_parser.add_argument("pattern", help="""CloudWatch filter pattern to use. Case-sensitive. See -http://docs.aws.amazon.com/AmazonCloudWatch/latest/DeveloperGuide/FilterAndPatternSyntax.html""") +filter_parser.add_argument( + "pattern", + help="""CloudWatch filter pattern to use. Case-sensitive. 
See +http://docs.aws.amazon.com/AmazonCloudWatch/latest/DeveloperGuide/FilterAndPatternSyntax.html""", +) filter_parser.add_argument("log_group", help="CloudWatch log group").completer = log_group_completer filter_parser.add_argument("log_stream", nargs="?", help="CloudWatch log stream") -filter_parser.add_argument("--follow", "-f", help="Repeat search continuously instead of running once", - action="store_true") +filter_parser.add_argument( + "--follow", "-f", help="Repeat search continuously instead of running once", action="store_true" +) add_time_bound_args(filter_parser) + def grep(args): if args.context: args.before_context = args.after_context = args.context if not args.end_time: args.end_time = Timestamp("-0s") - query = clients.logs.start_query(logGroupName=args.log_group, - startTime=int(datetime.timestamp(args.start_time) * 1000), - endTime=int(datetime.timestamp(args.end_time) * 1000), - queryString=args.query) + query = clients.logs.start_query( + logGroupName=args.log_group, + startTime=int(datetime.timestamp(args.start_time) * 1000), + endTime=int(datetime.timestamp(args.end_time) * 1000), + queryString=args.query, + ) seen_results = {} # type: Dict[str, Dict] print_with_context = partial(print_log_event_with_context, before=args.before_context, after=args.after_context) try: @@ -157,9 +172,13 @@ def grep(args): logger.debug("Query %s: %s", query["queryId"], res["statistics"]) return SystemExit(os.EX_OK if seen_results else os.EX_DATAERR) + grep_parser = register_parser(grep, help="Run a CloudWatch Logs Insights query (similar to filter, but faster)") -grep_parser.add_argument("query", help="""CloudWatch Logs Insights query to use. See -https://docs.aws.amazon.com/AmazonCloudWatch/latest/logs/AnalyzingLogData.html""") +grep_parser.add_argument( + "query", + help="""CloudWatch Logs Insights query to use. See +https://docs.aws.amazon.com/AmazonCloudWatch/latest/logs/AnalyzingLogData.html""", +) grep_parser.add_argument("log_group", help="CloudWatch log group").completer = log_group_completer grep_parser.add_argument("--before-context", "-B", type=int, default=0) grep_parser.add_argument("--after-context", "-A", type=int, default=0) diff --git a/aegea/ls.py b/aegea/ls.py index aca9d4bb..d5f3803a 100644 --- a/aegea/ls.py +++ b/aegea/ls.py @@ -14,6 +14,7 @@ def column_completer(parser, **kwargs): subresource = getattr(resource, parser.get_default("subresource")) return [attr for attr in dir(subresource("")) if not attr.startswith("_")] + def register_listing_parser(function, **kwargs): col_def = dict(default=kwargs.pop("column_defaults")) if "column_defaults" in kwargs else {} parser = register_parser(function, **kwargs) @@ -22,14 +23,23 @@ def register_listing_parser(function, **kwargs): col_arg.completer = column_completer return parser + def register_filtering_parser(function, **kwargs): parser = register_listing_parser(function, **kwargs) - parser.add_argument("-f", "--filter", nargs="+", default=[], metavar="FILTER_NAME=VALUE", - help="Filter(s) to apply to output, e.g. --filter state=available") - parser.add_argument("-t", "--tag", nargs="+", default=[], metavar="TAG_NAME=VALUE", - help="Tag(s) to filter output by") + parser.add_argument( + "-f", + "--filter", + nargs="+", + default=[], + metavar="FILTER_NAME=VALUE", + help="Filter(s) to apply to output, e.g. 
--filter state=available", + ) + parser.add_argument( + "-t", "--tag", nargs="+", default=[], metavar="TAG_NAME=VALUE", help="Tag(s) to filter output by" + ) return parser + def filter_collection(collection, args): filters = [] # TODO: shlex? @@ -45,9 +55,11 @@ def filter_collection(collection, args): filters.append(dict(Name="tag:" + name, Values=[value])) return collection.filter(Filters=filters) + def filter_and_tabulate(collection, args, **kwargs): return tabulate(filter_collection(collection, args), args, **kwargs) + def add_name(instance): instance.name = instance.id for tag in instance.tags or []: @@ -55,6 +67,7 @@ def add_name(instance): instance.name = tag["Value"] return instance + def ls(args): for col in "tags", "launch_time": if col not in args.columns: @@ -65,26 +78,33 @@ def ls(args): "state": lambda x, r: x["Name"], "security_groups": lambda x, r: ", ".join(sg["GroupName"] for sg in x), "iam_instance_profile": lambda x, r: x.get("Arn", "").split("/")[-1] if x else None, - "instance_lifecycle": lambda x, r: "" if x is None else x + "instance_lifecycle": lambda x, r: "" if x is None else x, } page_output(tabulate(instances, args, cell_transforms=cell_transforms)) + parser = register_filtering_parser(ls, help="List EC2 instances") + def console(args): instance_id = resolve_instance_id(args.instance) err = "[No console output received for {}. Console output may lag by several minutes.]".format(instance_id) page_output(resources.ec2.Instance(instance_id).console_output().get("Output", err)) + parser = register_parser(console, help="Get console output for an EC2 instance") parser.add_argument("instance") + def images(args): page_output(filter_and_tabulate(resources.ec2.images.filter(Owners=["self"]), args)) + parser = register_filtering_parser(images, help="List EC2 AMIs") peer_desc_cache = {} # type: Dict[str, Any] + + def describe_peer(peer): if "CidrIp" in peer: if peer["CidrIp"] not in peer_desc_cache: @@ -95,6 +115,7 @@ def describe_peer(peer): peer_desc_cache[peer["GroupId"]] = resources.ec2.SecurityGroup(peer["GroupId"]) return peer_desc_cache[peer["GroupId"]].group_name, peer_desc_cache[peer["GroupId"]].description + def security_groups(args): def format_rule(row, perm, peer, egress=False): peer_desc, row.peer_description = describe_peer(peer) @@ -103,6 +124,7 @@ def format_rule(row, perm, peer, egress=False): row.rule += GREEN("▶") if egress else GREEN("◀") row.rule += peer_desc + ":" + (port_range if egress else "*") row.proto = "*" if perm["IpProtocol"] == "-1" else perm["IpProtocol"] + table = [] for sg in resources.ec2.security_groups.all(): for i, perm in enumerate(sg.ip_permissions + sg.ip_permissions_egress): @@ -111,21 +133,27 @@ def format_rule(row, perm, peer, egress=False): format_rule(table[-1], perm, peer, egress=True if i > len(sg.ip_permissions) - 1 else False) page_output(tabulate(table, args)) + parser = register_filtering_parser(security_groups, help="List EC2 security groups") + def acls(args): page_output(filter_and_tabulate(resources.ec2.network_acls, args)) + parser = register_filtering_parser(acls, help="List EC2 network ACLs") + def clusters(args): - cluster_arns = sum([p["clusterArns"] for p in clients.ecs.get_paginator("list_clusters").paginate()], []) # type: List[Dict] # noqa + cluster_arns = sum([p["clusterArns"] for p in clients.ecs.get_paginator("list_clusters").paginate()], []) # type: List[Dict] # noqa page_output(tabulate(clients.ecs.describe_clusters(clusters=cluster_arns)["clusters"], args)) + parser = register_listing_parser(clusters, 
help="List ECS clusters") + def tasks(args): - cluster_arns = sum([p["clusterArns"] for p in clients.ecs.get_paginator("list_clusters").paginate()], []) # type: List[Dict] # noqa + cluster_arns = sum([p["clusterArns"] for p in clients.ecs.get_paginator("list_clusters").paginate()], []) # type: List[Dict] # noqa table = [] for cluster_arn in cluster_arns: list_tasks_args = dict(cluster=cluster_arn, desiredStatus=args.desired_status) @@ -136,49 +164,67 @@ def tasks(args): table.append(task) page_output(tabulate(table, args)) + parser = register_listing_parser(tasks, help="List ECS tasks") parser.add_argument("--desired-status", choices={"RUNNING", "PENDING", "STOPPED"}, default="RUNNING") + def taskdefs(args): table = [] for taskdef_arn in clients.ecs.list_task_definitions()["taskDefinitionArns"]: table.append(clients.ecs.describe_task_definition(taskDefinition=taskdef_arn)["taskDefinition"]) page_output(tabulate(table, args)) -parser = register_listing_parser(taskdefs, help="List ECS task definitions", - column_defaults=["family", "revision", "containerDefinitions"]) + +parser = register_listing_parser( + taskdefs, help="List ECS task definitions", column_defaults=["family", "revision", "containerDefinitions"] +) + def sirs(args): page_output(tabulate(clients.ec2.describe_spot_instance_requests()["SpotInstanceRequests"], args)) + parser = register_listing_parser(sirs, help="List EC2 spot instance requests") + def sfrs(args): page_output(tabulate(paginate(clients.ec2.get_paginator("describe_spot_fleet_requests")), args)) + parser = register_listing_parser(sfrs, help="List EC2 spot fleet requests") parser.add_argument("--trim-col-names", nargs="+", default=["SpotFleetRequestConfig.", "SpotFleetRequest"]) + def key_pairs(args): page_output(tabulate(resources.ec2.key_pairs.all(), args)) + parser = register_listing_parser(key_pairs, help="List EC2 SSH key pairs", column_defaults=["name", "key_fingerprint"]) + def subnets(args): page_output(filter_and_tabulate(resources.ec2.subnets, args)) + parser = register_filtering_parser(subnets, help="List EC2 VPCs and subnets") + def tables(args): page_output(tabulate(resources.dynamodb.tables.all(), args)) + parser = register_listing_parser(tables, help="List DynamoDB tables") + def subscriptions(args): page_output(tabulate(paginate(clients.sns.get_paginator("list_subscriptions")), args)) -parser = register_listing_parser(subscriptions, help="List SNS subscriptions", - column_defaults=["SubscriptionArn", "Protocol", "Endpoint"]) + +parser = register_listing_parser( + subscriptions, help="List SNS subscriptions", column_defaults=["SubscriptionArn", "Protocol", "Endpoint"] +) + def limits(args): """ @@ -193,8 +239,10 @@ def limits(args): table = clients.ec2.describe_account_attributes(AttributeNames=attrs)["AccountAttributes"] page_output(tabulate(table, args)) + parser = register_parser(limits) + def cmks(args): aliases = {alias.get("TargetKeyId"): alias for alias in paginate(clients.kms.get_paginator("list_aliases"))} table = [] @@ -203,9 +251,12 @@ def cmks(args): table.append(key) page_output(tabulate(table, args)) + parser = register_parser(cmks, help="List KMS Customer Master Keys") + def certificates(args): page_output(tabulate(paginate(clients.acm.get_paginator("list_certificates")), args)) + parser = register_parser(certificates, help="List Amazon Certificate Manager SSL certificates") diff --git a/aegea/missions/arvados-worker/Makefile b/aegea/missions/arvados-worker/Makefile deleted file mode 100644 index 442b6c79..00000000 --- 
a/aegea/missions/arvados-worker/Makefile +++ /dev/null @@ -1,11 +0,0 @@ -SHELL=/bin/bash - -rootfs.skel: %: %.in environment - -rm -rf $@ - cp -R $@.in $@ - source environment; vars=" $$(compgen -A variable ARVADOS)"; for i in $$(find $@ -type f); do cat $$i | envsubst "$${vars//[[:space:]]/ $$}" | sponge $$i; done - -clean: - -rm -rf rootfs.skel - -.PHONY: rootfs.skel diff --git a/aegea/missions/arvados-worker/config.yml b/aegea/missions/arvados-worker/config.yml deleted file mode 100644 index 536d0dc2..00000000 --- a/aegea/missions/arvados-worker/config.yml +++ /dev/null @@ -1,16 +0,0 @@ -build_ami: - base_ami_product: com.ubuntu.cloud:server:14.04:amd64 - rootfs_skel_dirs: - $append: rootfs.skel - commands: - $extend: - - "apt-add-repository --yes ppa:brightbox/ruby-ng" - - "apt-key adv --keyserver pool.sks-keyservers.net --recv 571659111078ECD7 AC40B2F7 58118E89F3A912897C070ADBF76221572C52609D" - - "echo deb http://apt.arvados.org/ trusty main > /etc/apt/sources.list.d/arvados.list" - - "echo deb https://apt.dockerproject.org/repo ubuntu-trusty main > /etc/apt/sources.list.d/docker.list" - - "apt-get update" - - "apt-get install --yes ruby2.1 ruby2.1-dev libgmp3-dev" - - "apt-get install --yes python-arvados-python-client crunch-run arvados-docker-cleaner python-arvados-fuse slurm-llnl munge" - - "gem install arvados-cli" - - "chown munge:munge /etc/munge/munge.key" - - "chmod 0400 /etc/munge/munge.key" diff --git a/aegea/missions/arvados-worker/environment b/aegea/missions/arvados-worker/environment deleted file mode 120000 index 0859d055..00000000 --- a/aegea/missions/arvados-worker/environment +++ /dev/null @@ -1 +0,0 @@ -../arvados/environment \ No newline at end of file diff --git a/aegea/missions/arvados-worker/rootfs.skel.in/etc/default/munge b/aegea/missions/arvados-worker/rootfs.skel.in/etc/default/munge deleted file mode 120000 index df185fbc..00000000 --- a/aegea/missions/arvados-worker/rootfs.skel.in/etc/default/munge +++ /dev/null @@ -1 +0,0 @@ -../../../../arvados/rootfs.skel.in/etc/default/munge \ No newline at end of file diff --git a/aegea/missions/arvados-worker/rootfs.skel.in/etc/munge/munge.key b/aegea/missions/arvados-worker/rootfs.skel.in/etc/munge/munge.key deleted file mode 120000 index 1ad6f437..00000000 --- a/aegea/missions/arvados-worker/rootfs.skel.in/etc/munge/munge.key +++ /dev/null @@ -1 +0,0 @@ -../../../../arvados/rootfs.skel.in/etc/munge/munge.key \ No newline at end of file diff --git a/aegea/missions/arvados-worker/rootfs.skel.in/etc/slurm-llnl/nodes.conf b/aegea/missions/arvados-worker/rootfs.skel.in/etc/slurm-llnl/nodes.conf deleted file mode 120000 index 92be3120..00000000 --- a/aegea/missions/arvados-worker/rootfs.skel.in/etc/slurm-llnl/nodes.conf +++ /dev/null @@ -1 +0,0 @@ -../../../../arvados/rootfs.skel.in/etc/slurm-llnl/nodes.conf \ No newline at end of file diff --git a/aegea/missions/arvados-worker/rootfs.skel.in/etc/slurm-llnl/slurm.conf b/aegea/missions/arvados-worker/rootfs.skel.in/etc/slurm-llnl/slurm.conf deleted file mode 120000 index 1e8c636b..00000000 --- a/aegea/missions/arvados-worker/rootfs.skel.in/etc/slurm-llnl/slurm.conf +++ /dev/null @@ -1 +0,0 @@ -../../../../arvados/rootfs.skel.in/etc/slurm-llnl/slurm.conf \ No newline at end of file diff --git a/aegea/missions/arvados-worker/rootfs.skel.in/usr/bin/aegea-set-slurm-nodes b/aegea/missions/arvados-worker/rootfs.skel.in/usr/bin/aegea-set-slurm-nodes deleted file mode 120000 index 7f272624..00000000 --- 
a/aegea/missions/arvados-worker/rootfs.skel.in/usr/bin/aegea-set-slurm-nodes +++ /dev/null @@ -1 +0,0 @@ -../../../../arvados/rootfs.skel.in/usr/bin/aegea-set-slurm-nodes \ No newline at end of file diff --git a/aegea/pricing.py b/aegea/pricing.py index 031f54ec..b7d7ea20 100644 --- a/aegea/pricing.py +++ b/aegea/pricing.py @@ -18,6 +18,7 @@ def describe_services(): client = boto3.client("pricing", region_name="us-east-1") return paginate(client.get_paginator("describe_services")) + def get_instance_type_sort_key(args): def instance_type_sort_key(row): instance_type = re.match(r"(.+)\.(\d*)(.+)", row[args.columns.index("instanceType")]) @@ -25,9 +26,11 @@ def instance_type_sort_key(row): if size == "metal": # type: ignore mx = sys.maxsize size_order = ["nano", "micro", "small", "medium", "large", "xlarge"] - return family, int(mx) if mx else 1, size_order.index(size) if size in size_order else sys.maxsize # type: ignore # noqa + return family, int(mx) if mx else 1, size_order.index(size) if size in size_order else sys.maxsize # type: ignore # noqa + return instance_type_sort_key + def pricing(args): if args.spot: args.columns = args.columns_spot @@ -43,22 +46,37 @@ def pricing(args): if args.sort_by == "instanceType": args.sort_by = get_instance_type_sort_key(args) filters = [("location", region_name(args.region))] + args.filters - table = get_products(args.service_code, region=args.region, filters=filters, terms=args.terms, - max_cache_age_days=args.max_cache_age_days) + table = get_products( + args.service_code, + region=args.region, + filters=filters, + terms=args.terms, + max_cache_age_days=args.max_cache_age_days, + ) page_output(tabulate(table, args)) else: args.columns = ["ServiceCode", "AttributeNames"] page_output(tabulate(describe_services(), args)) + parser = register_parser(pricing, help="List AWS prices") -pricing_arg = parser.add_argument("service_code", nargs="?", help=""" -AWS product offer to list prices for. Run without this argument to see the list of available product service codes.""") +pricing_arg = parser.add_argument( + "service_code", + nargs="?", + help=""" +AWS product offer to list prices for. Run without this argument to see the list of available product service codes.""", +) pricing_arg.completer = lambda **kwargs: [service["ServiceCode"] for service in describe_services()] parser.add_argument("--columns", nargs="+") parser.add_argument("--filters", nargs="+", metavar="NAME=VALUE", type=lambda x: x.split("=", 1), default=[]) parser.add_argument("--terms", nargs="+", default=["OnDemand"]) parser.add_argument("--sort-by") parser.add_argument("--spot", action="store_true", help="Display AWS EC2 Spot Instance pricing history") -parser.add_argument("--spot-start-time", type=Timestamp, default=Timestamp("-1h"), metavar="START", - help="Time to start spot price history." + Timestamp.__doc__) # type: ignore +parser.add_argument( + "--spot-start-time", + type=Timestamp, + default=Timestamp("-1h"), + metavar="START", + help="Time to start spot price history." 
+ Timestamp.__doc__, # type: ignore +) parser.add_argument("--columns-spot") diff --git a/aegea/rds.py b/aegea/rds.py index ca75256e..8289eb31 100644 --- a/aegea/rds.py +++ b/aegea/rds.py @@ -17,53 +17,66 @@ def rds(args): rds_parser.print_help() + rds_parser = register_parser(rds, help="Manage RDS resources", description=__doc__) + def add_tags(resource, prefix, key): resource_id = ":".join([prefix, resource[key]]) arn = ARN(service="rds", resource=resource_id) resource["tags"] = clients.rds.list_tags_for_resource(ResourceName=str(arn))["TagList"] return resource + def list_rds_clusters(): paginator = clients.rds.get_paginator("describe_db_clusters") return [add_tags(i, "cluster", "DBClusterIdentifier") for i in paginate(paginator)] + def list_rds_instances(): paginator = clients.rds.get_paginator("describe_db_instances") return [add_tags(i, "db", "DBInstanceIdentifier") for i in paginate(paginator)] + def ls(args): page_output(tabulate(list_rds_clusters(), args)) + parser = register_parser(ls, parent=rds_parser, help="List RDS clusters") + def instances(args): page_output(tabulate(list_rds_instances(), args)) + parser = register_parser(instances, parent=rds_parser, help="List RDS instances") + def snapshots(args): paginator = clients.rds.get_paginator("describe_db_snapshots") table = [add_tags(i, "snapshot", "DBSnapshotIdentifier") for i in paginate(paginator)] page_output(tabulate(table, args)) + parser = register_parser(snapshots, parent=rds_parser, help="List RDS snapshots") + def create(args): - create_args = dict(DBInstanceIdentifier=args.name, - AllocatedStorage=args.storage, - Engine=args.engine, - StorageType=args.storage_type, - StorageEncrypted=True, - AutoMinorVersionUpgrade=True, - MultiAZ=False, - MasterUsername=args.master_username or getpass.getuser(), - MasterUserPassword=args.master_user_password, - VpcSecurityGroupIds=args.security_groups, - DBInstanceClass=args.db_instance_class, - Tags=encode_tags(args.tags), - CopyTagsToSnapshot=True) + create_args = dict( + DBInstanceIdentifier=args.name, + AllocatedStorage=args.storage, + Engine=args.engine, + StorageType=args.storage_type, + StorageEncrypted=True, + AutoMinorVersionUpgrade=True, + MultiAZ=False, + MasterUsername=args.master_username or getpass.getuser(), + MasterUserPassword=args.master_user_password, + VpcSecurityGroupIds=args.security_groups, + DBInstanceClass=args.db_instance_class, + Tags=encode_tags(args.tags), + CopyTagsToSnapshot=True, + ) if args.db_name: create_args.update(DBName=args.db_name) if args.engine_version: @@ -73,6 +86,7 @@ def create(args): instance = clients.rds.describe_db_instances(DBInstanceIdentifier=args.name)["DBInstances"][0] return {k: instance[k] for k in ("Endpoint", "DbiResourceId", "DBInstanceStatus")} + parser = register_parser(create, parent=rds_parser, help="Create an RDS instance") parser.add_argument("name") parser.add_argument("--db-name") @@ -86,36 +100,44 @@ def create(args): parser.add_argument("--tags", nargs="+", default=[], metavar="TAG_NAME=VALUE") parser.add_argument("--security-groups", nargs="+", default=[]) + def delete(args): clients.rds.delete_db_instance(DBInstanceIdentifier=args.name, SkipFinalSnapshot=True) + parser = register_parser(delete, parent=rds_parser, help="Delete an RDS instance") parser.add_argument("name").completer = lambda **kw: [i["DBInstanceIdentifier"] for i in list_rds_instances()] + def snapshot(args): - return clients.rds.create_db_snapshot(DBInstanceIdentifier=args.instance_name, - DBSnapshotIdentifier=args.snapshot_name, - 
Tags=encode_tags(args.tags)) + return clients.rds.create_db_snapshot( + DBInstanceIdentifier=args.instance_name, DBSnapshotIdentifier=args.snapshot_name, Tags=encode_tags(args.tags) + ) + parser = register_parser(snapshot, parent=rds_parser, help="Create an RDS snapshot") parser.add_argument("instance_name") parser.add_argument("snapshot_name") parser.add_argument("--tags", nargs="+", default=[]) + def restore(args): tags = dict(tag.split("=", 1) for tag in args.tags) - clients.rds.restore_db_instance_from_db_snapshot(DBInstanceIdentifier=args.instance_name, - DBSnapshotIdentifier=args.snapshot_name, - StorageType=args.storage_type, - AutoMinorVersionUpgrade=True, - MultiAZ=False, - DBInstanceClass=args.db_instance_class, - Tags=[dict(Key=k, Value=v) for k, v in tags.items()], - CopyTagsToSnapshot=True) + clients.rds.restore_db_instance_from_db_snapshot( + DBInstanceIdentifier=args.instance_name, + DBSnapshotIdentifier=args.snapshot_name, + StorageType=args.storage_type, + AutoMinorVersionUpgrade=True, + MultiAZ=False, + DBInstanceClass=args.db_instance_class, + Tags=[dict(Key=k, Value=v) for k, v in tags.items()], + CopyTagsToSnapshot=True, + ) clients.rds.get_waiter("db_instance_available").wait(DBInstanceIdentifier=args.instance_name) instance = clients.rds.describe_db_instances(DBInstanceIdentifier=args.instance_name)["DBInstances"][0] return {k: instance[k] for k in ("Endpoint", "DbiResourceId")} + parser = register_parser(restore, parent=rds_parser, help="Restore an RDS instance from a snapshot") parser.add_argument("snapshot_name") parser.add_argument("instance_name") diff --git a/aegea/rm.py b/aegea/rm.py index 73e2a27b..31d05939 100644 --- a/aegea/rm.py +++ b/aegea/rm.py @@ -22,13 +22,12 @@ def delete_vpc(name, args): vpc = resources.ec2.Vpc(name) - for eigw in paginate(clients.ec2.get_paginator('describe_egress_only_internet_gateways')): + for eigw in paginate(clients.ec2.get_paginator("describe_egress_only_internet_gateways")): for attachment in eigw["Attachments"]: if attachment.get("VpcId") == vpc.id: logger.info("Will delete %s", eigw["EgressOnlyInternetGatewayId"]) clients.ec2.delete_egress_only_internet_gateway( - EgressOnlyInternetGatewayId=eigw["EgressOnlyInternetGatewayId"], - DryRun=not args.force + EgressOnlyInternetGatewayId=eigw["EgressOnlyInternetGatewayId"], DryRun=not args.force ) for igw in vpc.internet_gateways.all(): logger.info("Will delete %s", igw) @@ -59,6 +58,7 @@ def delete_vpc(name, args): logger.info("Will delete %s", vpc) vpc.delete(DryRun=not args.force) + def rm(args): for name in args.names: try: @@ -105,9 +105,9 @@ def rm(args): elif name.startswith("sir-"): clients.ec2.cancel_spot_instance_requests(SpotInstanceRequestIds=[name], DryRun=not args.force) elif name.startswith("sfr-"): - clients.ec2.cancel_spot_fleet_requests(SpotFleetRequestIds=[name], - TerminateInstances=False, - DryRun=not args.force) + clients.ec2.cancel_spot_fleet_requests( + SpotFleetRequestIds=[name], TerminateInstances=False, DryRun=not args.force + ) elif name.startswith("fs-"): efs = clients.efs for mount_target in efs.describe_mount_targets(FileSystemId=name)["MountTargets"]: @@ -143,12 +143,25 @@ def rm(args): if not args.force: logger.info("Dry run succeeded on %s. 
Run %s again with --force (-f) to actually remove.", args.names, __name__) + parser = register_parser(rm, help="Remove or deprovision resources", description=__doc__) parser.add_argument("names", nargs="+") parser.add_argument("-f", "--force", action="store_true") -parser.add_argument("--key-pair", action="store_true", help=""" -Assume input names are EC2 SSH key pair names (required when deleting key pairs, since they have no ID or ARN)""") -parser.add_argument("--elb", action="store_true", help=""" -Assume input names are Elastic Load Balancer names (required when deleting ELBs, since they have no ID or ARN)""") -parser.add_argument("--lambda", action="store_true", help=""" -Assume input names are Lambda function names (required when deleting Lambdas, since they have no ID or ARN)""") +parser.add_argument( + "--key-pair", + action="store_true", + help=""" +Assume input names are EC2 SSH key pair names (required when deleting key pairs, since they have no ID or ARN)""", +) +parser.add_argument( + "--elb", + action="store_true", + help=""" +Assume input names are Elastic Load Balancer names (required when deleting ELBs, since they have no ID or ARN)""", +) +parser.add_argument( + "--lambda", + action="store_true", + help=""" +Assume input names are Lambda function names (required when deleting Lambdas, since they have no ID or ARN)""", +) diff --git a/aegea/rootfs.skel.build_ami/root/.aws/config b/aegea/rootfs.skel.build_ami/root/.aws/config deleted file mode 120000 index e9d9dfeb..00000000 --- a/aegea/rootfs.skel.build_ami/root/.aws/config +++ /dev/null @@ -1 +0,0 @@ -../../etc/aws.conf \ No newline at end of file diff --git a/aegea/rootfs.skel.build_ami/root/.aws/config b/aegea/rootfs.skel.build_ami/root/.aws/config new file mode 100644 index 00000000..b8575368 --- /dev/null +++ b/aegea/rootfs.skel.build_ami/root/.aws/config @@ -0,0 +1,6 @@ +[plugins] +cwlogs = cwlogs +[preview] +efs=true +cloudfront=true +[default] diff --git a/aegea/secrets.py b/aegea/secrets.py index bf5394f2..91777547 100644 --- a/aegea/secrets.py +++ b/aegea/secrets.py @@ -66,33 +66,44 @@ def parse_principal(args): elif args.iam_user: return resources.iam.User(args.iam_user) else: - logger.warn('You did not specify anyone to grant access to this secret. ' - 'You can specify a principal with "--instance-profile" or "--iam-{role,user,group}".') + logger.warn( + "You did not specify anyone to grant access to this secret. " + 'You can specify a principal with "--instance-profile" or "--iam-{role,user,group}".' 
+ ) + def ensure_policy(principal, secret_arn): - policy_name = "{}.{}.{}".format(__name__, - ARN(principal.arn).resource.replace("/", "."), - ARN(secret_arn).resource.split(":")[1].replace("/", ".")) + policy_name = "{}.{}.{}".format( + __name__, + ARN(principal.arn).resource.replace("/", "."), + ARN(secret_arn).resource.split(":")[1].replace("/", "."), + ) policy_doc = IAMPolicyBuilder(action="secretsmanager:GetSecretValue", resource=secret_arn) policy = ensure_iam_policy(policy_name, policy_doc) principal.attach_policy(PolicyArn=policy.arn) + def secrets(args): secrets_parser.print_help() + secrets_parser = register_parser(secrets, help="Manage application credentials (secrets)", description=__doc__) + def ls(args): - list_secrets_paginator = Paginator(method=clients.secretsmanager.list_secrets, - pagination_config=dict(result_key="SecretList", - input_token="NextToken", - output_token="NextToken", - limit_key="MaxResults"), - model=None) + list_secrets_paginator = Paginator( + method=clients.secretsmanager.list_secrets, + pagination_config=dict( + result_key="SecretList", input_token="NextToken", output_token="NextToken", limit_key="MaxResults" + ), + model=None, + ) page_output(tabulate(paginate(list_secrets_paginator), args)) + ls_parser = register_listing_parser(ls, parent=secrets_parser) + def put(args): if args.generate_ssh_key: ssh_key = new_ssh_key() @@ -110,28 +121,40 @@ def put(args): if parse_principal(args): ensure_policy(parse_principal(args), res["ARN"]) if args.generate_ssh_key: - return dict(ssh_public_key=hostkey_line(hostnames=[], key=ssh_key).strip(), - ssh_key_fingerprint=key_fingerprint(ssh_key)) + return dict( + ssh_public_key=hostkey_line(hostnames=[], key=ssh_key).strip(), ssh_key_fingerprint=key_fingerprint(ssh_key) + ) + put_parser = register_parser(put, parent=secrets_parser) -put_parser.add_argument("--generate-ssh-key", action="store_true", - help="Generate a new SSH key pair and write the private key as the secret value; write the public key to stdout") # noqa +put_parser.add_argument( + "--generate-ssh-key", + action="store_true", + help="Generate a new SSH key pair and write the private key as the secret value; write the public key to stdout", +) # noqa + def get(args): sys.stdout.write(clients.secretsmanager.get_secret_value(SecretId=args.secret_name)["SecretString"]) + get_parser = register_parser(get, parent=secrets_parser) + def delete(args): return clients.secretsmanager.delete_secret(SecretId=args.secret_name) + delete_parser = register_parser(delete, parent=secrets_parser) for parser in put_parser, get_parser, delete_parser: - parser.add_argument("secret_name", - help="List the secret name. For put, pass the secret value on stdin, or via an environment variable with the same name as the secret.") # noqa + parser.add_argument( + "secret_name", + help="Name of the secret. 
For put, pass the secret value on stdin, or via an environment variable with the same name as the secret.", + ) # noqa parser.add_argument("--instance-profile") parser.add_argument("--iam-role") parser.add_argument("--iam-group") - parser.add_argument("--iam-user", - help="Name of IAM instance profile, role, group, or user who will be granted access to secret") + parser.add_argument( + "--iam-user", help="Name of IAM instance profile, role, group, or user who will be granted access to secret" + ) diff --git a/aegea/sfn.py b/aegea/sfn.py index d4b7405e..c1b8a375 100644 --- a/aegea/sfn.py +++ b/aegea/sfn.py @@ -18,22 +18,29 @@ from .util.aws import ARN, clients from .util.printing import BOLD, ENDC, GREEN, RED, YELLOW, page_output, tabulate -sfn_status_colors = dict(RUNNING=GREEN(), SUCCEEDED=BOLD() + GREEN(), - FAILED=BOLD() + RED(), TIMED_OUT=BOLD() + RED(), ABORTED=BOLD() + RED()) +sfn_status_colors = dict( + RUNNING=GREEN(), SUCCEEDED=BOLD() + GREEN(), FAILED=BOLD() + RED(), TIMED_OUT=BOLD() + RED(), ABORTED=BOLD() + RED() +) + def complete_state_machine_name(**kwargs): return [c["name"] for c in paginate(clients.stepfunctions.get_paginator("list_state_machines"))] + def sfn(args): sfn_parser.print_help() + sfn_parser = register_parser(sfn, help="Manage AWS Step Functions", description=__doc__) + def state_machines(args): page_output(tabulate(paginate(clients.stepfunctions.get_paginator("list_state_machines")), args)) + state_machines_parser = register_listing_parser(state_machines, parent=sfn_parser, help="List state machines") + def ls(args): if args.state_machine: sm_arn = ARN(service="states", resource="stateMachine:" + args.state_machine) @@ -53,10 +60,12 @@ def list_executions(state_machine): page_output(tabulate(executions, args)) + ls_parser = register_listing_parser(ls, parent=sfn_parser, help="List executions for state machines in this account") ls_parser.add_argument("--state-machine").completer = complete_state_machine_name ls_parser.add_argument("--status", choices=list(sfn_status_colors)) + def describe(args): if ARN(args.resource_arn).resource.startswith("execution"): desc = clients.stepfunctions.describe_execution(executionArn=args.resource_arn) @@ -67,9 +76,11 @@ def describe(args): desc["definition"] = json.loads(desc.get("definition", "null")) return desc + describe_parser = register_parser(describe, parent=sfn_parser, help="Describe a state machine or execution") describe_parser.add_argument("resource_arn") + def watch(args, print_event_fn=batch.print_event): seen_events = set() # type: Set[str] previous_status = None @@ -79,8 +90,11 @@ def watch(args, print_event_fn=batch.print_event): sys.stderr.write(".") sys.stderr.flush() else: - logger.info("%s %s", exec_desc["executionArn"], - sfn_status_colors[exec_desc["status"]] + exec_desc["status"] + ENDC()) + logger.info( + "%s %s", + exec_desc["executionArn"], + sfn_status_colors[exec_desc["status"]] + exec_desc["status"] + ENDC(), + ) previous_status = exec_desc["status"] history = clients.stepfunctions.get_execution_history(executionArn=str(args.execution_arn)) for event in sorted(history["events"], key=lambda x: x["id"]): @@ -89,9 +103,15 @@ def watch(args, print_event_fn=batch.print_event): for key in event.keys(): if key.endswith("EventDetails") and event[key]: details = event[key] - logger.info("%s %s %s %s %s %s", event["timestamp"], event["type"], - details.get("resourceType", ""), details.get("resource", ""), details.get("name", ""), - json.loads(details.get("parameters", "{}")).get("FunctionName", "")) + 
logger.info( + "%s %s %s %s %s %s", + event["timestamp"], + event["type"], + details.get("resourceType", ""), + details.get("resource", ""), + details.get("name", ""), + json.loads(details.get("parameters", "{}")).get("FunctionName", ""), + ) if "taskSubmittedEventDetails" in event: if event.get("taskSubmittedEventDetails", {}).get("resourceType") == "batch": job_id = json.loads(event["taskSubmittedEventDetails"]["output"])["JobId"] @@ -111,12 +131,20 @@ def watch(args, print_event_fn=batch.print_event): return SystemExit(json.dumps(last_event, indent=4, default=str)) -watch_parser = register_parser(watch, parent=sfn_parser, - help="Monitor a state machine execution and stream its execution history") +watch_parser = register_parser( + watch, parent=sfn_parser, help="Monitor a state machine execution and stream its execution history" +) watch_parser.add_argument("execution_arn") -event_colors = dict(ExecutionStarted=GREEN(), ExecutionSucceeded=BOLD() + GREEN(), ExecutionFailed=BOLD() + RED(), - ExecutionAborted=BOLD() + RED(), TaskSucceeded=GREEN(), TaskFailed=RED()) +event_colors = dict( + ExecutionStarted=GREEN(), + ExecutionSucceeded=BOLD() + GREEN(), + ExecutionFailed=BOLD() + RED(), + ExecutionAborted=BOLD() + RED(), + TaskSucceeded=GREEN(), + TaskFailed=RED(), +) + def history(args): history = clients.stepfunctions.get_execution_history(executionArn=str(args.execution_arn)) @@ -134,8 +162,9 @@ def history(args): for key in list(event): if key.endswith("EventDetails") and event[key]: event["details"] = event[key] - event["name"] = event["details"].get("name", ":".join([event["details"].get(k, "") - for k in ["resourceType", "resource"]])) + event["name"] = event["details"].get( + "name", ":".join([event["details"].get(k, "") for k in ["resourceType", "resource"]]) + ) if event["name"] == ":": event["name"] = ARN(args.execution_arn).resource.split(":", 1)[-1] elif "FunctionName" in event["details"].get("parameters", ""): @@ -145,11 +174,14 @@ def history(args): events.append(event) page_output(tabulate(events, args)) + history_parser = register_listing_parser(history, parent=sfn_parser, help="List event history for a given execution") history_parser.add_argument("execution_arn") + def stop(args): return clients.stepfunctions.stop_execution(executionArn=args.execution_arn) + stop_parser = register_listing_parser(stop, parent=sfn_parser, help="Stop an execution") stop_parser.add_argument("execution_arn") diff --git a/aegea/ssh.py b/aegea/ssh.py index 4288d6e5..c7e540d5 100644 --- a/aegea/ssh.py +++ b/aegea/ssh.py @@ -53,18 +53,23 @@ opts_by_nargs = { "ssh": {0: "46AaCfGgKkMNnqsTtVvXxYy", 1: "BbcDEeFIiJLlmOopQRSW"}, - "scp": {0: "346BCpqrv", 1: "cFiloPS"} + "scp": {0: "346BCpqrv", 1: "cFiloPS"}, } + def add_bless_and_passthrough_opts(parser, program): - parser.add_argument("--bless-config", default=os.environ.get("BLESS_CONFIG"), - help="Path to a Bless configuration file (or pass via the BLESS_CONFIG environment variable)") + parser.add_argument( + "--bless-config", + default=os.environ.get("BLESS_CONFIG"), + help="Path to a Bless configuration file (or pass via the BLESS_CONFIG environment variable)", + ) parser.add_argument("--use-kms-auth", help=argparse.SUPPRESS) for opt in opts_by_nargs[program][0]: parser.add_argument("-" + opt, action="store_true", help=argparse.SUPPRESS) for opt in opts_by_nargs[program][1]: parser.add_argument("-" + opt, action="append", help=argparse.SUPPRESS) + def extract_passthrough_opts(args, program): opts = [] for opt in opts_by_nargs[program][0]: @@ 
-75,10 +80,12 @@ def extract_passthrough_opts(args, program): opts.extend(["-" + opt, value]) return opts + @lru_cache(8) def get_instance(name): return resources.ec2.Instance(resolve_instance_id(name)) + def save_instance_public_key(name, use_ssm=False): instance = get_instance(name) tags = {tag["Key"]: tag["Value"] for tag in instance.tags or []} @@ -89,6 +96,7 @@ def save_instance_public_key(name, use_ssm=False): hostname = instance.id if use_ssm else instance.public_dns_name add_ssh_host_key_to_known_hosts(hostname + " " + ssh_host_key + "\n") + def resolve_instance_public_dns(name): instance = get_instance(name) if not getattr(instance, "public_dns_name", None): @@ -96,6 +104,7 @@ def resolve_instance_public_dns(name): raise AegeaException(msg.format(instance, getattr(instance, "state", {}).get("Name"))) return instance.public_dns_name + def get_user_info(): iam_username = ARN.get_iam_username() linux_username, at, domain = iam_username.partition("@") @@ -103,29 +112,37 @@ def get_user_info(): linux_user_id = str(2000 + (int.from_bytes(user_id_bytes, byteorder=sys.byteorder) // 2)) return dict(iam_username=iam_username, linux_username=linux_username, linux_user_id=linux_user_id) + def get_kms_auth_token(session, bless_config, lambda_regional_config): logger.info("Requesting new KMS auth token in %s", lambda_regional_config["aws_region"]) token_not_before = datetime.datetime.utcnow() - datetime.timedelta(minutes=1) token_not_after = token_not_before + datetime.timedelta(hours=1) - token = dict(not_before=token_not_before.strftime("%Y%m%dT%H%M%SZ"), - not_after=token_not_after.strftime("%Y%m%dT%H%M%SZ")) + token = dict( + not_before=token_not_before.strftime("%Y%m%dT%H%M%SZ"), not_after=token_not_after.strftime("%Y%m%dT%H%M%SZ") + ) encryption_context = { "from": session.resource("iam").CurrentUser().user_name, "to": bless_config["lambda_config"]["function_name"], - "user_type": "user" + "user_type": "user", } - kms = session.client('kms', region_name=lambda_regional_config["aws_region"]) - res = kms.encrypt(KeyId=lambda_regional_config["kms_auth_key_id"], - Plaintext=json.dumps(token), - EncryptionContext=encryption_context) + kms = session.client("kms", region_name=lambda_regional_config["aws_region"]) + res = kms.encrypt( + KeyId=lambda_regional_config["kms_auth_key_id"], + Plaintext=json.dumps(token), + EncryptionContext=encryption_context, + ) return base64.b64encode(res["CiphertextBlob"]).decode() + def get_awslambda_client(region_name, credentials): - return boto3.client("lambda", - region_name=region_name, - aws_access_key_id=credentials['AccessKeyId'], - aws_secret_access_key=credentials['SecretAccessKey'], - aws_session_token=credentials['SessionToken']) + return boto3.client( + "lambda", + region_name=region_name, + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + def ensure_bless_ssh_cert(ssh_key_name, bless_config, use_kms_auth, max_cert_age=1800): ssh_key = ensure_local_ssh_key(ssh_key_name) @@ -142,33 +159,44 @@ def ensure_bless_ssh_cert(ssh_key_name, bless_config, use_kms_auth, max_cert_age if "oidc_client_id" in bless_config["client_config"]: from cryptography.hazmat.primitives import serialization - aws_oidc_args = ["--client-id", bless_config["client_config"]["oidc_client_id"], - "--issuer-url", bless_config["client_config"]["oidc_issuer_url"]] + + aws_oidc_args = [ + "--client-id", + bless_config["client_config"]["oidc_client_id"], + "--issuer-url", + 
bless_config["client_config"]["oidc_issuer_url"], + ] aws_role_arn_arg = ["--aws-role-arn", bless_config["client_config"]["role_arn"]] token = json.loads(subprocess.check_output(["aws-oidc", "token"] + aws_oidc_args))["access_token"] creds = json.loads(subprocess.check_output(["aws-oidc", "creds-process"] + aws_oidc_args + aws_role_arn_arg)) awslambda = get_awslambda_client(region_name=lambda_regional_config["aws_region"], credentials=creds) - public_key = ssh_key.key.public_key().public_bytes(encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo) - bless_input = dict(public_key_to_sign=dict(publicKey="".join(public_key.decode().splitlines()[1:-1])), - identity=dict(okta_identity=dict(AccessToken=token))) + public_key = ssh_key.key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + bless_input = dict( + public_key_to_sign=dict(publicKey="".join(public_key.decode().splitlines()[1:-1])), + identity=dict(okta_identity=dict(AccessToken=token)), + ) else: session = boto3.Session(profile_name=bless_config["client_config"]["aws_user_profile"]) iam = session.resource("iam") sts = session.client("sts") assume_role_res = sts.assume_role(RoleArn=bless_config["lambda_config"]["role_arn"], RoleSessionName=__name__) - awslambda = get_awslambda_client(region_name=lambda_regional_config["aws_region"], - credentials=assume_role_res["Credentials"]) - bless_input = dict(bastion_user=iam.CurrentUser().user_name, - bastion_user_ip="0.0.0.0/0", - bastion_ips=",".join(bless_config["client_config"]["bastion_ips"]), - remote_usernames=",".join(bless_config["client_config"]["remote_users"]), - public_key_to_sign=get_public_key_from_pair(ssh_key), - command="*") + awslambda = get_awslambda_client( + region_name=lambda_regional_config["aws_region"], credentials=assume_role_res["Credentials"] + ) + bless_input = dict( + bastion_user=iam.CurrentUser().user_name, + bastion_user_ip="0.0.0.0/0", + bastion_ips=",".join(bless_config["client_config"]["bastion_ips"]), + remote_usernames=",".join(bless_config["client_config"]["remote_users"]), + public_key_to_sign=get_public_key_from_pair(ssh_key), + command="*", + ) if use_kms_auth: - bless_input["kmsauth_token"] = get_kms_auth_token(session=session, - bless_config=bless_config, - lambda_regional_config=lambda_regional_config) + bless_input["kmsauth_token"] = get_kms_auth_token( + session=session, bless_config=bless_config, lambda_regional_config=lambda_regional_config + ) res = awslambda.invoke(FunctionName=bless_config["lambda_config"]["function_name"], Payload=json.dumps(bless_input)) bless_output = json.loads(res["Payload"].read().decode()) @@ -181,6 +209,7 @@ def ensure_bless_ssh_cert(ssh_key_name, bless_config, use_kms_auth, max_cert_age fh.write(bless_output["certificate"]) return ssh_cert_filename + def match_instance_to_bastion(instance, bastions): for bastion_config in bastions: for ipv4_pattern in bastion_config["hosts"]: @@ -188,8 +217,16 @@ def match_instance_to_bastion(instance, bastions): logger.info("Using %s to connect to %s", bastion_config["pattern"], instance) return bastion_config -def prepare_ssh_host_opts(username, hostname, bless_config_filename=None, ssh_key_name=__name__, use_kms_auth=True, - use_ssm=True, use_ec2_instance_connect=True): + +def prepare_ssh_host_opts( + username, + hostname, + bless_config_filename=None, + ssh_key_name=__name__, + use_kms_auth=True, + use_ssm=True, + use_ec2_instance_connect=True, +): instance = 
get_instance(hostname) if not getattr(get_instance(hostname), "subnet", None): msg = "Unable to resolve subnet for {} (state: {})" @@ -197,9 +234,7 @@ def prepare_ssh_host_opts(username, hostname, bless_config_filename=None, ssh_ke if bless_config_filename: with open(bless_config_filename) as fh: bless_config = yaml.safe_load(fh) - ensure_bless_ssh_cert(ssh_key_name=ssh_key_name, - bless_config=bless_config, - use_kms_auth=use_kms_auth) + ensure_bless_ssh_cert(ssh_key_name=ssh_key_name, bless_config=bless_config, use_kms_auth=use_kms_auth) add_ssh_key_to_agent(ssh_key_name) bastion_config = match_instance_to_bastion(instance=instance, bastions=bless_config["ssh_config"]["bastions"]) if not username: @@ -227,15 +262,17 @@ def prepare_ssh_host_opts(username, hostname, bless_config_filename=None, ssh_ke InstanceId=instance.id, InstanceOSUser=username, SSHPublicKey=ssh_public_key, - AvailabilityZone=instance.placement["AvailabilityZone"] + AvailabilityZone=instance.placement["AvailabilityZone"], ) return [], username + "@" + (instance.id if use_ssm else resolve_instance_public_dns(hostname)) + def init_ssm(instance_id): ssm_plugin_path = ensure_session_manager_plugin() os.environ["PATH"] = os.environ["PATH"] + ":" + os.path.dirname(ssm_plugin_path) return ["-o", "ProxyCommand=aws ssm start-session --document-name AWS-StartSSHSession --target " + instance_id] + def ssh(args): ssh_opts = ["-o", f"ServerAliveInterval={args.server_alive_interval}"] ssh_opts += ["-o", f"ServerAliveCountMax={args.server_alive_count_max}"] @@ -245,23 +282,31 @@ def ssh(args): if args.use_ssm: ssh_opts += init_ssm(get_instance(name).id) - host_opts, hostname = prepare_ssh_host_opts(username=prefix, hostname=name, - bless_config_filename=args.bless_config, - use_kms_auth=args.use_kms_auth, - use_ssm=args.use_ssm, - use_ec2_instance_connect=args.use_ec2_instance_connect) + host_opts, hostname = prepare_ssh_host_opts( + username=prefix, + hostname=name, + bless_config_filename=args.bless_config, + use_kms_auth=args.use_kms_auth, + use_ssm=args.use_ssm, + use_ec2_instance_connect=args.use_ec2_instance_connect, + ) os.execvp("ssh", ["ssh"] + ssh_opts + host_opts + [hostname] + args.ssh_args) + ssh_parser = register_parser(ssh, help="Connect to an EC2 instance", description=__doc__) ssh_parser.add_argument("name") -ssh_parser.add_argument("ssh_args", nargs=argparse.REMAINDER, - help="Arguments to pass to ssh; please see " + BOLD("man ssh") + " for details") +ssh_parser.add_argument( + "ssh_args", + nargs=argparse.REMAINDER, + help="Arguments to pass to ssh; please see " + BOLD("man ssh") + " for details", +) ssh_parser.add_argument("--server-alive-interval", help=argparse.SUPPRESS) ssh_parser.add_argument("--server-alive-count-max", help=argparse.SUPPRESS) ssh_parser.add_argument("--no-ssm", action="store_false", dest="use_ssm") ssh_parser.add_argument("--no-ec2-instance-connect", action="store_false", dest="use_ec2_instance_connect") add_bless_and_passthrough_opts(ssh_parser, "ssh") + def scp(args): """ Transfer files to or from EC2 instance. 
@@ -276,25 +321,36 @@ def scp(args): if args.use_ssm and not ssm_init_complete: scp_opts += init_ssm(get_instance(hostname).id) ssm_init_complete = True - host_opts, hostname = prepare_ssh_host_opts(username=username, hostname=hostname, - bless_config_filename=args.bless_config, - use_kms_auth=args.use_kms_auth, use_ssm=args.use_ssm) + host_opts, hostname = prepare_ssh_host_opts( + username=username, + hostname=hostname, + bless_config_filename=args.bless_config, + use_kms_auth=args.use_kms_auth, + use_ssm=args.use_ssm, + ) args.scp_args[i] = hostname + colon + path os.execvp("scp", ["scp"] + scp_opts + host_opts + args.scp_args) + scp_parser = register_parser(scp, help="Transfer files to or from EC2 instance", description=scp.__doc__) -scp_parser.add_argument("scp_args", nargs=argparse.REMAINDER, - help="Arguments to pass to scp; please see " + BOLD("man scp") + " for details") +scp_parser.add_argument( + "scp_args", + nargs=argparse.REMAINDER, + help="Arguments to pass to scp; please see " + BOLD("man scp") + " for details", +) scp_parser.add_argument("--no-ssm", action="store_false", dest="use_ssm") add_bless_and_passthrough_opts(scp_parser, "scp") + def run(args): run_command(args.command, instance_ids=[get_instance(args.instance).id]) + run_parser = register_parser(run, help="Run a command on an EC2 instance", description=run_command.__doc__) run_parser.add_argument("instance") run_parser.add_argument("command") + def ssh_to_ecs_container(instance_id, container_id, ssh_args, use_ssm): ssh_args = ["-t", instance_id, "sudo", "docker", "exec", "--interactive", "--tty", container_id] + ssh_args if "BLESS_CONFIG" not in os.environ: # bless will provide the username, otherwise use Amazon Linux default diff --git a/aegea/top.py b/aegea/top.py index 5a84bf67..59f9f09b 100644 --- a/aegea/top.py +++ b/aegea/top.py @@ -23,6 +23,7 @@ def get_stats_for_region(region): num_instances, num_amis, num_vpcs, num_enis, num_volumes = ["Access denied"] * 5 # type: ignore return [region, num_instances, num_amis, num_vpcs, num_enis, num_volumes] + def top(args): table = [] # type: List[List] columns = ["Region", "Instances", "AMIs", "VPCs", "Network interfaces", "EBS volumes"] @@ -30,4 +31,5 @@ def top(args): table = list(executor.map(get_stats_for_region, boto3.Session().get_available_regions("ec2"))) page_output(format_table(table, column_names=columns, max_col_width=args.max_col_width)) -parser = register_parser(top, help='Show an overview of AWS resources per region') + +parser = register_parser(top, help="Show an overview of AWS resources per region") diff --git a/aegea/util/__init__.py b/aegea/util/__init__.py index 8f89b5c1..ef5f2b59 100644 --- a/aegea/util/__init__.py +++ b/aegea/util/__init__.py @@ -19,6 +19,7 @@ logger = logging.getLogger(__name__) + def wait_for_port(host, port, timeout=600, print_progress=True): if print_progress: sys.stderr.write("Waiting for {}:{}...".format(host, port)) @@ -38,6 +39,7 @@ def wait_for_port(host, port, timeout=600, print_progress=True): if time.time() - start_time > timeout: raise + def validate_hostname(hostname): if len(hostname) > 255: raise Exception("Hostname {} is longer than 255 characters".format(hostname)) @@ -47,29 +49,31 @@ def validate_hostname(hostname): if not all(allowed.match(x) for x in hostname.split(".")): raise Exception("Hostname {} is not RFC 1123 compliant".format(hostname)) + class VerboseRepr: def __repr__(self): return "<{module}.{classname} object at 0x{mem_loc:x}: {dict}>".format( - module=self.__module__, - 
classname=self.__class__.__name__, - mem_loc=id(self), - dict=Repr().repr(self.__dict__) + module=self.__module__, classname=self.__class__.__name__, mem_loc=id(self), dict=Repr().repr(self.__dict__) ) + def natural_sort(i): return sorted(i, key=lambda s: [int(t) if t.isdigit() else t.lower() for t in re.split(r"(\d+)", s)]) + def paginate(boto3_paginator, *args, **kwargs): for page in boto3_paginator.paginate(*args, **kwargs): for result_key in boto3_paginator.result_keys: yield from page.get(result_key.parsed.get("value"), []) + class Timestamp(datetime): """ Integer inputs are interpreted as milliseconds since the epoch. Sub-second precision is discarded. Suffixes (s, m, h, d, w) are supported. Negative inputs (e.g. -5m) are interpreted as relative to the current date. Other inputs (e.g. 2020-01-01, 15:20) are parsed using the dateutil parser. """ + _precision = {} # type: Dict[Any, Any] def __new__(cls, t, snap=0): @@ -77,13 +81,14 @@ def __new__(cls, t, snap=0): t = int(t) if not isinstance(t, (str, bytes)): from dateutil.tz import tzutc + return datetime.fromtimestamp(t // 1000, tz=tzutc()) try: units = ["weeks", "days", "hours", "minutes", "seconds"] diffs = {u: float(t[:-1]) for u in units if u.startswith(t[-1])} # type: ignore if len(diffs) == 1: # Snap > 0 governs the rounding of units (hours, minutes and seconds) to 0 to improve cache performance - snap_units = {u.rstrip("s"): 0 for u in units[units.index(list(diffs)[0]) + snap:]} if snap else {} + snap_units = {u.rstrip("s"): 0 for u in units[units.index(list(diffs)[0]) + snap :]} if snap else {} snap_units.pop("day", None) snap_units.update(microsecond=0) ts = datetime.now().replace(**snap_units) + relativedelta(**diffs) # type: ignore @@ -99,20 +104,24 @@ def match_precision(cls, timestamp, precision_source): logger.debug("Discarding timestamp %s %s precision", timestamp, ", ".join(cls._precision[precision_source])) return timestamp.replace(**cls._precision.get(precision_source, dict(microsecond=0))) + def add_time_bound_args(p, snap=0, start="-7d"): t = partial(Timestamp, snap=snap) p.add_argument("--start-time", type=t, default=Timestamp(start, snap=snap), help=Timestamp.__doc__, metavar="START") p.add_argument("--end-time", type=t, help=Timestamp.__doc__, metavar="END") + class hashabledict(dict): def __hash__(self): return hash(tuple(sorted(self.items()))) + def describe_cidr(cidr): import ipaddress import socket import ipwhois + address = ipaddress.ip_network(str(cidr)).network_address try: whois = ipwhois.IPWhois(address).lookup_rdap() @@ -125,12 +134,14 @@ def describe_cidr(cidr): whois_names = [cidr] return ", ".join(str(n) for n in whois_names) + def gzip_compress_bytes(payload): buf = io.BytesIO() with gzip.GzipFile(fileobj=buf, mode="w", mtime=0) as gzfh: gzfh.write(payload) return buf.getvalue() + def get_mkfs_command(fs_type="xfs", label="aegveph"): if fs_type == "xfs": return "mkfs.xfs -L {} -f ".format(label) @@ -139,6 +150,7 @@ def get_mkfs_command(fs_type="xfs", label="aegveph"): else: raise Exception("unknown fs_type: {}".format(fs_type)) + class ThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor): def __init__(self, **kwargs): max_workers = kwargs.pop("max_workers", min(8, (os.cpu_count() or 1) + 4)) diff --git a/aegea/util/aws/__init__.py b/aegea/util/aws/__init__.py index cd30caa8..8bd2cd5a 100644 --- a/aegea/util/aws/__init__.py +++ b/aegea/util/aws/__init__.py @@ -28,6 +28,7 @@ def get_ssm_parameter(name): return clients.ssm.get_parameter(Name=name)["Parameter"]["Value"] + def 
locate_ami(distribution, release, architecture): """ Examples:: @@ -54,6 +55,7 @@ def locate_ami(distribution, release, architecture): return resources.ec2.Image(ami_id) raise AegeaException(f"No AMI found for {distribution} {release} {architecture}") + def ensure_vpc(): """ If a default VPC exists in the current account/region, return it; otherwise, return the first VPC managed by aegea, @@ -66,6 +68,7 @@ def ensure_vpc(): break else: from ... import config + logger.info("Creating VPC with CIDR %s", config.vpc.cidr[ARN.get_region()]) tags = dict(Name="aegea-vpc", managedBy="aegea") tag_spec = dict(ResourceType="vpc", Tags=encode_tags(tags)) @@ -93,10 +96,12 @@ def ensure_vpc(): route_table.create_route(DestinationIpv6CidrBlock="::/0", EgressOnlyInternetGatewayId=eigw_id) return vpc + def availability_zones(): for az in clients.ec2.describe_availability_zones()["AvailabilityZones"]: yield az["ZoneName"] + def ensure_subnet(vpc, availability_zone=None, assign_ipv6_cidr_blocks=False): """ Find and return a subnet in the given VPC, or create one subnet per AZ and return one of them if no suitable subnet @@ -112,18 +117,21 @@ def ensure_subnet(vpc, availability_zone=None, assign_ipv6_cidr_blocks=False): break else: from ... import config + subnet_cidrs = ip_network(str(config.vpc.cidr[ARN.get_region()])).subnets(new_prefix=config.vpc.subnet_prefix) subnets = {} for az, subnet_cidr in zip(availability_zones(), subnet_cidrs): logger.info("Creating subnet with CIDR %s in %s, %s", subnet_cidr, vpc, az) tags = dict(Name="aegea-subnet", managedBy="aegea") tag_spec = dict(ResourceType="subnet", Tags=encode_tags(tags)) - subnets[az] = resources.ec2.create_subnet(VpcId=vpc.id, CidrBlock=str(subnet_cidr), AvailabilityZone=az, - TagSpecifications=[tag_spec]) + subnets[az] = resources.ec2.create_subnet( + VpcId=vpc.id, CidrBlock=str(subnet_cidr), AvailabilityZone=az, TagSpecifications=[tag_spec] + ) time.sleep(1) clients.ec2.get_waiter("subnet_available").wait(SubnetIds=[subnets[az].id]) - clients.ec2.modify_subnet_attribute(SubnetId=subnets[az].id, - MapPublicIpOnLaunch=dict(Value=config.vpc.map_public_ip_on_launch)) + clients.ec2.modify_subnet_attribute( + SubnetId=subnets[az].id, MapPublicIpOnLaunch=dict(Value=config.vpc.map_public_ip_on_launch) + ) if assign_ipv6_cidr_blocks: vpc_cidr_block = vpc.ipv6_cidr_block_association_set[0]["Ipv6CidrBlock"] subnet_cidr_blocks = list(ip_network(vpc_cidr_block).subnets(new_prefix=64)) @@ -132,11 +140,12 @@ def ensure_subnet(vpc, availability_zone=None, assign_ipv6_cidr_blocks=False): clients.ec2.associate_subnet_cidr_block(SubnetId=subnet.id, Ipv6CidrBlock=str(subnet_cidr_blocks[i])) clients.ec2.modify_subnet_attribute( SubnetId=subnet.id, - AssignIpv6AddressOnCreation=dict(Value=config.vpc.assign_ipv6_address_on_creation) + AssignIpv6AddressOnCreation=dict(Value=config.vpc.assign_ipv6_address_on_creation), ) subnet = subnets[availability_zone] if availability_zone is not None else list(subnets.values())[0] return subnet + def ensure_ingress_rule(security_group, **kwargs): cidr_ip, source_security_group_id = kwargs.pop("CidrIp"), kwargs.pop("SourceSecurityGroupId") for rule in security_group.ip_permissions: @@ -153,6 +162,7 @@ def ensure_ingress_rule(security_group, **kwargs): authorize_ingress_args["IpPermissions"][0]["UserIdGroupPairs"] = [dict(GroupId=source_security_group_id)] security_group.authorize_ingress(**authorize_ingress_args) + def resolve_security_group(name, vpc=None): if vpc is None: vpc = ensure_vpc() @@ -162,6 +172,7 @@ def 
resolve_security_group(name, vpc=None): return security_group raise KeyError(name) + def ensure_security_group(name, vpc, tcp_ingress=frozenset()): try: security_group = resolve_security_group(name, vpc) @@ -177,10 +188,17 @@ def ensure_security_group(name, vpc, tcp_ingress=frozenset()): source_security_group_id = None if "source_security_group_name" in rule: source_security_group_id = resolve_security_group(rule["source_security_group_name"], vpc).id - ensure_ingress_rule(security_group, IpProtocol="tcp", FromPort=rule["port"], ToPort=rule["port"], - CidrIp=rule.get("cidr"), SourceSecurityGroupId=source_security_group_id) + ensure_ingress_rule( + security_group, + IpProtocol="tcp", + FromPort=rule["port"], + ToPort=rule["port"], + CidrIp=rule.get("cidr"), + SourceSecurityGroupId=source_security_group_id, + ) return security_group + class S3BucketLifecycleBuilder: def __init__(self, **kwargs): self.rules = [] @@ -203,8 +221,10 @@ def add_rule(self, prefix="", tags=None, expiration=None, transitions=None, abor def __iter__(self): yield ("Rules", self.rules) + def ensure_s3_bucket(name=None, policy=None, lifecycle=None, encryption=None): from ... import config + if name is None: name = f"aegea-assets-{ARN.get_account_id()}" bucket = resources.s3.Bucket(name) @@ -233,6 +253,7 @@ def ensure_s3_bucket(name=None, policy=None, lifecycle=None, encryption=None): bucket.LifecycleConfiguration().put(LifecycleConfiguration=dict(lifecycle)) return bucket + class ARN: arn = partition = service = region = account_id = resource = "" fields = "arn partition service region account_id resource".split() @@ -275,6 +296,7 @@ def get_iam_username(cls): def __str__(self): return ":".join(getattr(self, field) for field in self.fields) + def encode_tags(tags, case="title"): if isinstance(tags, (list, tuple)): tags = dict(tag.split("=", 1) for tag in tags) @@ -284,18 +306,23 @@ def encode_tags(tags, case="title"): elif case == "lower": return [dict(key=k, value=v) for k, v in tags.items()] + def decode_tags(tags): return {tag["Key"]: tag["Value"] for tag in tags} + def add_tags(resource, dry_run=False, **tags): return resource.create_tags(Tags=encode_tags(tags), DryRun=dry_run) + def filter_by_tags(collection, **tags): return collection.filter(Filters=[dict(Name="tag:" + k, Values=[v]) for k, v in tags.items()]) + def filter_by_tag_keys(collection, *tag_keys): return collection.filter(Filters=[dict(Name="tag-key", Values=[k]) for k in tag_keys]) + def resolve_instance_id(name): filter_name = "dns-name" if name.startswith("ec2") and name.endswith("compute.amazonaws.com") else "tag:Name" if name.startswith("i-"): @@ -306,6 +333,7 @@ def resolve_instance_id(name): except IndexError: raise AegeaException(f'Could not resolve "{name}" to a known instance') + def get_bdm(ami=None, max_devices=12, ebs_storage=None): # Note: d2.8xl and hs1.8xl have 24 devices bdm = [dict(VirtualName="ephemeral" + str(i), DeviceName="xvd" + chr(ord("b") + i)) for i in range(max_devices)] @@ -324,20 +352,24 @@ def get_bdm(ami=None, max_devices=12, ebs_storage=None): bdm.extend(ebs_bdm) return bdm + def get_metadata(category): imds = IMDS() token = imds._fetch_metadata_token() return imds._get_request(url_path=f"latest/meta-data/{category}", retry_func=None, token=token).text + def get_ecs_task_metadata(path="/task"): res = requests.get(os.environ["ECS_CONTAINER_METADATA_URI"] + path) res.raise_for_status() return res.content.decode() + def expect_error_codes(exception, *codes): if getattr(exception, "response", None) and 
exception.response.get("Error", {}).get("Code", {}) not in codes: raise + def resolve_ami(ami=None, arch="x86_64", tags=frozenset(), tag_keys=frozenset()): """ Find an AMI by ID, name, or tags. @@ -353,8 +385,10 @@ def resolve_ami(ami=None, arch="x86_64", tags=frozenset(), tag_keys=frozenset()) return resources.ec2.Image(ami) else: if ami is None: - filters = dict(Owners=["self"], - Filters=[dict(Name="state", Values=["available"]), dict(Name="architecture", Values=[arch])]) + filters = dict( + Owners=["self"], + Filters=[dict(Name="state", Values=["available"]), dict(Name="architecture", Values=[arch])], + ) else: filters = dict(Owners=["self"], Filters=[dict(Name="name", Values=[ami])]) all_amis = resources.ec2.images.filter(**filters) @@ -374,24 +408,29 @@ def resolve_ami(ami=None, arch="x86_64", tags=frozenset(), tag_keys=frozenset()) raise AegeaException("Could not resolve AMI {}".format(dict(tags, ami=ami))) return amis[-1] + offers_api = "https://pricing.us-east-1.amazonaws.com/offers/v1.0" + def region_name(region_id): region_names, region_ids = {}, {} from botocore import loaders + for partition_data in loaders.create_loader().load_data("endpoints")["partitions"]: region_names.update({k: v["description"] for k, v in partition_data["regions"].items()}) region_ids.update({v: k for k, v in region_names.items()}) return region_names[region_id] + def get_pricing_data(service_code, filters=None, max_cache_age_days=30): from ... import config if filters is None: filters = [("location", region_name(clients.ec2.meta.region_name))] - get_products_args = dict(ServiceCode=service_code, - Filters=[dict(Type="TERM_MATCH", Field=k, Value=v) for k, v in filters]) + get_products_args = dict( + ServiceCode=service_code, Filters=[dict(Type="TERM_MATCH", Field=k, Value=v) for k, v in filters] + ) cache_key = hashlib.sha256(json.dumps(get_products_args, sort_keys=True).encode()).hexdigest()[:32] service_code_filename = os.path.join(config.user_config_dir, f"pricing_cache_{cache_key}.json.gz") try: @@ -412,6 +451,7 @@ def get_pricing_data(service_code, filters=None, max_cache_age_days=30): print(e, file=sys.stderr) return pricing_data + def get_products(service_code, region=None, filters=None, terms=None, max_cache_age_days=30): from ... import config @@ -432,6 +472,7 @@ def get_products(service_code, region=None, filters=None, terms=None, max_cache_ for price_dimension in term["priceDimensions"].values(): yield dict(dict(product, **term["termAttributes"]), **price_dimension) + def get_ondemand_price_usd(region, instance_type, **kwargs): from ... 
import config @@ -441,12 +482,14 @@ def get_ondemand_price_usd(region, instance_type, **kwargs): continue return product["pricePerUnit"]["USD"] + def get_iam_role_for_instance(instance): instance = resources.ec2.Instance(resolve_instance_id(instance)) profile = resources.iam.InstanceProfile(ARN(instance.iam_instance_profile["Arn"]).resource.split("/")[1]) assert len(profile.roles) <= 1 return profile.roles[0] if profile.roles else None + def get_elb_dns_aliases(): dns_aliases = {} for zone in paginate(clients.route53.get_paginator("list_hosted_zones")): @@ -457,20 +500,25 @@ def get_elb_dns_aliases(): dns_aliases[value.rstrip(".").replace("dualstack.", "")] = rrs["Name"] return dns_aliases + ip_ranges_api = "https://ip-ranges.amazonaws.com/ip-ranges.json" + def get_public_ip_ranges(service="AMAZON", region=None): if region is None: region = ARN.get_region() ranges = requests.get(ip_ranges_api).json()["prefixes"] return [r for r in ranges if r["service"] == service and r["region"] == region] + def make_waiter(op, path, expected, matcher="path", delay=1, max_attempts=30): from botocore.waiter import SingleWaiterConfig, Waiter + acceptor = dict(matcher=matcher, argument=path, expected=expected, state="success") waiter_cfg = dict(operation=op.__name__, delay=delay, maxAttempts=max_attempts, acceptors=[acceptor]) return Waiter(op.__name__, SingleWaiterConfig(waiter_cfg), op) + def resolve_log_group(name): for log_group in paginate(clients.logs.get_paginator("describe_log_groups"), logGroupNamePrefix=name): if log_group["logGroupName"] == name: @@ -478,6 +526,7 @@ def resolve_log_group(name): else: raise AegeaException(f"Log group {name} not found") + def ensure_log_group(name): try: return resolve_log_group(name) @@ -488,6 +537,7 @@ def ensure_log_group(name): pass return resolve_log_group(name) + def ensure_ecs_cluster(name): res = clients.ecs.describe_clusters(clusters=[name]) if res.get("failures"): @@ -497,21 +547,29 @@ def ensure_ecs_cluster(name): raise AegeaException(res) return res["clusters"][0] -def get_cloudwatch_metric_stats(namespace, name, start_time=None, end_time=None, period=None, statistic="Average", - resource=None, **kwargs): + +def get_cloudwatch_metric_stats( + namespace, name, start_time=None, end_time=None, period=None, statistic="Average", resource=None, **kwargs +): start_time = datetime.utcnow() - period * 60 if start_time is None else start_time end_time = datetime.utcnow() if end_time is None else end_time cloudwatch = resources.cloudwatch if resource is None else resource metric = cloudwatch.Metric(namespace, name) - get_stats_args = dict(StartTime=start_time, EndTime=end_time, Statistics=[statistic], - Dimensions=[dict(Name=k, Value=v) for k, v in kwargs.items()]) + get_stats_args = dict( + StartTime=start_time, + EndTime=end_time, + Statistics=[statistic], + Dimensions=[dict(Name=k, Value=v) for k, v in kwargs.items()], + ) if period is not None: get_stats_args.update(Period=period) return metric.get_statistics(**get_stats_args) + def instance_type_completer(max_cache_age_days=30, **kwargs): return [p["instanceType"] for p in get_products("AmazonEC2")] + instance_storage_shellcode = """ aegea_bd=( $(shopt -s nullglob; readlink -f /dev/disk/by-id/nvme-Amazon_EC2_NVMe_Instance_Storage_AWS{{?????????????????,?????????????????-ns-?}} | sort | uniq) ) if [ ! 
-e /dev/md0 ]; then mdadm --create /dev/md0 --force --auto=yes --level=0 --chunk=256 --raid-devices=${{#aegea_bd[@]}} ${{aegea_bd[@]}}; {mkfs} /dev/md0; fi diff --git a/aegea/util/aws/_boto3_loader.py b/aegea/util/aws/_boto3_loader.py index f85f29d1..91866824 100644 --- a/aegea/util/aws/_boto3_loader.py +++ b/aegea/util/aws/_boto3_loader.py @@ -12,7 +12,7 @@ def __getattr__(self, attr): if attr == "__name__": return "Loader" if attr == "__bases__": - return (object, ) + return (object,) if attr == "__all__": return list(self.cache[self.factory]) if attr == "__file__": @@ -26,7 +26,9 @@ def __getattr__(self, attr): self.cache["client"][attr] = self.cache["resource"][attr].meta.client else: import boto3 + factory = getattr(boto3, self.factory) - self.cache[self.factory][attr] = factory(attr.replace("_", "-"), - **self.client_kwargs.get(attr, self.client_kwargs["default"])) + self.cache[self.factory][attr] = factory( + attr.replace("_", "-"), **self.client_kwargs.get(attr, self.client_kwargs["default"]) + ) return self.cache[self.factory][attr] diff --git a/aegea/util/aws/batch.py b/aegea/util/aws/batch.py index 210bb68e..38be2921 100644 --- a/aegea/util/aws/batch.py +++ b/aegea/util/aws/batch.py @@ -35,7 +35,9 @@ sed -i -e "s|/archive.ubuntu.com|/{region}.ec2.archive.ubuntu.com|g" /etc/apt/sources.list apt-get update -qq""" -ebs_vol_mgr_shellcode = apt_mgr_shellcode + """ +ebs_vol_mgr_shellcode = ( + apt_mgr_shellcode + + """ apt-get install -qqy --no-install-suggests --no-install-recommends httpie awscli jq lsof python3-virtualenv > /dev/null python3 -m virtualenv -q --python=python3 /opt/aegea-venv /opt/aegea-venv/bin/pip install -q argcomplete requests boto3 tweak pyyaml @@ -43,7 +45,8 @@ aegea_ebs_cleanup() {{ echo Detaching EBS volume $aegea_ebs_vol_id; cd /; /opt/aegea-venv/bin/aegea ebs detach --unmount --force --delete $aegea_ebs_vol_id; }} trap aegea_ebs_cleanup EXIT aegea_ebs_vol_id=$(/opt/aegea-venv/bin/aegea ebs create --size-gb {size_gb} --volume-type {volume_type} --tags managedBy=aegea batchJobId=$AWS_BATCH_JOB_ID --attach --format ext4 --mount {mountpoint} | jq -r .VolumeId) -""" # noqa +""" +) # noqa efs_vol_shellcode = """mkdir -p {efs_mountpoint} IMDS=http://169.254.169.254/latest @@ -53,17 +56,22 @@ NFS_ENDPOINT=$(echo "$AEGEA_EFS_DESC" | jq -r ".[] | select(.SubnetId == env.SUBNET_ID) | .IpAddress") mount -t nfs -o nfsvers=4.1,rsize=1048576,wsize=1048576,hard,timeo=600,retrans=2 $NFS_ENDPOINT:/ {efs_mountpoint}""" # noqa -instance_storage_mgr_shellcode = apt_mgr_shellcode + """ -apt-get install -qqy --no-install-suggests --no-install-recommends mdadm""" + instance_storage_shellcode +instance_storage_mgr_shellcode = ( + apt_mgr_shellcode + + """ +apt-get install -qqy --no-install-suggests --no-install-recommends mdadm""" + + instance_storage_shellcode +) + def ensure_dynamodb_table(name, hash_key_name, read_capacity_units=5, write_capacity_units=5): try: - table = resources.dynamodb.create_table(TableName=name, - KeySchema=[dict(AttributeName=hash_key_name, KeyType="HASH")], - AttributeDefinitions=[dict(AttributeName=hash_key_name, - AttributeType="S")], - ProvisionedThroughput=dict(ReadCapacityUnits=read_capacity_units, - WriteCapacityUnits=write_capacity_units)) + table = resources.dynamodb.create_table( + TableName=name, + KeySchema=[dict(AttributeName=hash_key_name, KeyType="HASH")], + AttributeDefinitions=[dict(AttributeName=hash_key_name, AttributeType="S")], + ProvisionedThroughput=dict(ReadCapacityUnits=read_capacity_units, 
WriteCapacityUnits=write_capacity_units), + ) except ClientError as e: expect_error_codes(e, "ResourceInUseException") table = resources.dynamodb.Table(name) @@ -78,19 +86,29 @@ def get_command_and_env(args): args.privileged = True args.volumes.append(["/dev", "/dev"]) if args.mount_instance_storage: - shellcode += instance_storage_mgr_shellcode.strip().format(region=ARN.get_region(), - mountpoint=args.mount_instance_storage, - mkfs=get_mkfs_command(fs_type="ext4")).splitlines() + shellcode += ( + instance_storage_mgr_shellcode.strip() + .format( + region=ARN.get_region(), mountpoint=args.mount_instance_storage, mkfs=get_mkfs_command(fs_type="ext4") + ) + .splitlines() + ) if args.storage: for mountpoint, size_gb in args.storage: volume_type = "st1" if args.volume_type: volume_type = args.volume_type - shellcode += ebs_vol_mgr_shellcode.strip().format(region=ARN.get_region(), - aegea_version=__version__, - size_gb=size_gb, - volume_type=volume_type, - mountpoint=mountpoint).splitlines() + shellcode += ( + ebs_vol_mgr_shellcode.strip() + .format( + region=ARN.get_region(), + aegea_version=__version__, + size_gb=size_gb, + volume_type=volume_type, + mountpoint=mountpoint, + ) + .splitlines() + ) elif args.efs_storage: args.privileged = True if "=" in args.efs_storage: @@ -116,17 +134,17 @@ def get_command_and_env(args): args.execute.seek(0) bucket.upload_fileobj(args.execute, key_name) payload_url = clients.s3.generate_presigned_url( - ClientMethod='get_object', - Params=dict(Bucket=bucket.name, Key=key_name), - ExpiresIn=3600 * 24 * 7 + ClientMethod="get_object", Params=dict(Bucket=bucket.name, Key=key_name), ExpiresIn=3600 * 24 * 7 ) tmpdir_fmt = "${AWS_BATCH_CE_NAME:-$AWS_EXECUTION_ENV}.${AWS_BATCH_JQ_NAME:-}.${AWS_BATCH_JOB_ID:-}.XXXXX" - shellcode += ['BATCH_SCRIPT=$(mktemp --tmpdir "{tmpdir_fmt}")'.format(tmpdir_fmt=tmpdir_fmt), - "apt-get update -qq", - "apt-get install -qqy --no-install-suggests --no-install-recommends curl ca-certificates gnupg", - "curl -L '{payload_url}' > $BATCH_SCRIPT".format(payload_url=payload_url), - "chmod +x $BATCH_SCRIPT", - "$BATCH_SCRIPT"] + shellcode += [ + 'BATCH_SCRIPT=$(mktemp --tmpdir "{tmpdir_fmt}")'.format(tmpdir_fmt=tmpdir_fmt), + "apt-get update -qq", + "apt-get install -qqy --no-install-suggests --no-install-recommends curl ca-certificates gnupg", + "curl -L '{payload_url}' > $BATCH_SCRIPT".format(payload_url=payload_url), + "chmod +x $BATCH_SCRIPT", + "$BATCH_SCRIPT", + ] elif args.wdl: bucket = ensure_s3_bucket(args.staging_s3_bucket) wdl_key_name = "{}.wdl".format(hashlib.sha256(args.wdl.read()).hexdigest()) @@ -143,18 +161,23 @@ def get_command_and_env(args): "cd /mnt", "aws s3 cp s3://{bucket}/{key} .".format(bucket=bucket.name, key=wdl_key_name), "aws s3 cp s3://{bucket}/{key} wdl_input.json".format(bucket=bucket.name, key=wdl_input_key_name), - "miniwdl run --dir /mnt --verbose --error-json {} --input wdl_input.json > wdl_output.json".format(wdl_key_name), # noqa - "aws s3 cp wdl_output.json s3://{bucket}/wdl_output/${{AWS_BATCH_JOB_ID}}.json".format(bucket=bucket.name) + "miniwdl run --dir /mnt --verbose --error-json {} --input wdl_input.json > wdl_output.json".format( + wdl_key_name + ), # noqa + "aws s3 cp wdl_output.json s3://{bucket}/wdl_output/${{AWS_BATCH_JOB_ID}}.json".format(bucket=bucket.name), ] args.command = bash_cmd_preamble + shellcode + (args.command or []) return args.command, args.environment + def get_ecr_image_uri(tag): return f"{ARN.get_account_id()}.dkr.ecr.{ARN.get_region()}.amazonaws.com/{tag}" + def 
ensure_ecr_image(tag): pass + def set_ulimits(args, container_props): if args.ulimits: container_props.setdefault("ulimits", []) @@ -162,6 +185,7 @@ def set_ulimits(args, container_props): name, value = ulimit.split(":", 1) container_props["ulimits"].append(dict(name=name, hardLimit=int(value), softLimit=int(value))) + def get_volumes_and_mountpoints(args): volumes, mount_points = [], [] if args.wdl and ["/var/run/docker.sock", "/var/run/docker.sock"] not in args.volumes: @@ -181,6 +205,7 @@ def get_volumes_and_mountpoints(args): mount_points.append(mount_spec) return volumes, mount_points + def ensure_job_definition(args): def get_jd_arn_and_job_name(jd_res): job_name = args.name or f"{jd_res['jobDefinitionName']}_{jd_res['revision']}" @@ -189,8 +214,9 @@ def get_jd_arn_and_job_name(jd_res): if args.ecs_image: args.image = get_ecr_image_uri(args.ecs_image) container_props = dict(image=args.image, user=args.user, privileged=args.privileged) - container_props.update(volumes=[], mountPoints=[], environment=[], command=[], resourceRequirements=[], ulimits=[], - secrets=[]) + container_props.update( + volumes=[], mountPoints=[], environment=[], command=[], resourceRequirements=[], ulimits=[], secrets=[] + ) if args.platform_capabilities == ["FARGATE"]: container_props["resourceRequirements"].append(dict(type="VCPU", value="0.25")) container_props["resourceRequirements"].append(dict(type="MEMORY", value="512")) @@ -210,34 +236,44 @@ def get_jd_arn_and_job_name(jd_res): container_props["logConfiguration"]["options"] = {k: v for k, v in args.log_options} iam_role = ensure_iam_role(args.job_role, trust=["ecs-tasks"], policies=args.default_job_role_iam_policies) container_props.update(jobRoleArn=iam_role.arn) - expect_job_defn = dict(status="ACTIVE", type="container", parameters={}, tags={}, - retryStrategy=dict(attempts=args.retry_attempts, evaluateOnExit=[]), - containerProperties=container_props, platformCapabilities=args.platform_capabilities) + expect_job_defn = dict( + status="ACTIVE", + type="container", + parameters={}, + tags={}, + retryStrategy=dict(attempts=args.retry_attempts, evaluateOnExit=[]), + containerProperties=container_props, + platformCapabilities=args.platform_capabilities, + ) job_hash = hashlib.sha256(json.dumps(container_props, sort_keys=True).encode()).hexdigest()[:8] job_defn_name = __name__.replace(".", "_") + "_jd_" + job_hash if args.platform_capabilities == ["FARGATE"]: job_defn_name += "_FARGATE" container_props["fargatePlatformConfiguration"] = dict(platformVersion="LATEST") container_props["networkConfiguration"] = dict(assignPublicIp="ENABLED") - describe_job_definitions_paginator = Paginator(method=clients.batch.describe_job_definitions, - pagination_config=dict(result_key="jobDefinitions", - input_token="nextToken", - output_token="nextToken", - limit_key="maxResults"), - model=None) + describe_job_definitions_paginator = Paginator( + method=clients.batch.describe_job_definitions, + pagination_config=dict( + result_key="jobDefinitions", input_token="nextToken", output_token="nextToken", limit_key="maxResults" + ), + model=None, + ) for job_defn in paginate(describe_job_definitions_paginator, jobDefinitionName=job_defn_name): job_defn_desc = {k: job_defn.pop(k) for k in ("jobDefinitionName", "jobDefinitionArn", "revision")} if job_defn == expect_job_defn: logger.info("Found existing Batch job definition %s", job_defn_desc["jobDefinitionArn"]) return get_jd_arn_and_job_name(job_defn_desc) logger.info("Creating new Batch job definition %s", job_defn_name) - jd_res 
= clients.batch.register_job_definition(jobDefinitionName=job_defn_name, - type="container", - containerProperties=container_props, - retryStrategy=dict(attempts=args.retry_attempts), - platformCapabilities=args.platform_capabilities) + jd_res = clients.batch.register_job_definition( + jobDefinitionName=job_defn_name, + type="container", + containerProperties=container_props, + retryStrategy=dict(attempts=args.retry_attempts), + platformCapabilities=args.platform_capabilities, + ) return get_jd_arn_and_job_name(jd_res) + def ensure_lambda_helper(): awslambda = getattr(clients, "lambda") try: @@ -246,6 +282,7 @@ def ensure_lambda_helper(): except awslambda.exceptions.ResourceNotFoundException: logger.info("Batch helper Lambda not found, installing") import chalice.cli # type: ignore + orig_argv = sys.argv orig_wd = os.getcwd() try: diff --git a/aegea/util/aws/batch_events_lambda/app.py b/aegea/util/aws/batch_events_lambda/app.py index 02d3b85c..2f5866d4 100644 --- a/aegea/util/aws/batch_events_lambda/app.py +++ b/aegea/util/aws/batch_events_lambda/app.py @@ -4,6 +4,7 @@ Fields like "command" and "environment" are redacted to avoid storing potentially sensitive information. """ + import json import os @@ -14,6 +15,7 @@ app = Chalice(app_name="aegea-batch-events") + @app.on_cw_event({"source": ["aws.batch"]}) def process_batch_event(event): job_id = event.detail["jobId"] diff --git a/aegea/util/aws/dns.py b/aegea/util/aws/dns.py index ea7738e1..a66305bd 100644 --- a/aegea/util/aws/dns.py +++ b/aegea/util/aws/dns.py @@ -10,9 +10,11 @@ def get_client_token(iam_username, service): from getpass import getuser from socket import gethostname + tok = "{}.{}.{}:{}@{}".format(iam_username, service, int(time.time()), getuser(), gethostname().split(".")[0]) return tok[:64] + class DNSZone(VerboseRepr): def __init__(self, zone_name=None, create_default_private_zone=True): if zone_name is None: @@ -24,10 +26,12 @@ def __init__(self, zone_name=None, create_default_private_zone=True): vpc = ensure_vpc() vpc.modify_attribute(EnableDnsSupport=dict(Value=True)) vpc.modify_attribute(EnableDnsHostnames=dict(Value=True)) - res = clients.route53.create_hosted_zone(Name=config.dns.private_zone, - CallerReference=get_client_token(None, "route53"), - HostedZoneConfig=dict(PrivateZone=True), - VPC=dict(VPCRegion=ARN.get_region(), VPCId=vpc.vpc_id)) + res = clients.route53.create_hosted_zone( + Name=config.dns.private_zone, + CallerReference=get_client_token(None, "route53"), + HostedZoneConfig=dict(PrivateZone=True), + VPC=dict(VPCRegion=ARN.get_region(), VPCId=vpc.vpc_id), + ) self.zone = res["HostedZone"] else: raise @@ -42,21 +46,23 @@ def find(zone_name): def update(self, names, values, action="UPSERT", record_type="CNAME", ttl=60): def format_rrs(name, value): - return dict(Name=name + "." + self.zone["Name"], - Type=record_type, - TTL=ttl, - ResourceRecords=value if isinstance(value, (list, tuple)) else [{"Value": value}]) + return dict( + Name=name + "." 
+ self.zone["Name"], + Type=record_type, + TTL=ttl, + ResourceRecords=value if isinstance(value, (list, tuple)) else [{"Value": value}], + ) + if not isinstance(names, (list, tuple)): names, values = [names], [values] updates = [dict(Action=action, ResourceRecordSet=format_rrs(k, v)) for k, v in zip(names, values)] - return clients.route53.change_resource_record_sets(HostedZoneId=self.zone_id, - ChangeBatch=dict(Changes=updates)) + return clients.route53.change_resource_record_sets(HostedZoneId=self.zone_id, ChangeBatch=dict(Changes=updates)) def delete(self, name, value=None, record_type="CNAME", missing_ok=True): if value is None: - res = clients.route53.list_resource_record_sets(HostedZoneId=self.zone_id, - StartRecordName=name + "." + self.zone["Name"], - StartRecordType=record_type) + res = clients.route53.list_resource_record_sets( + HostedZoneId=self.zone_id, StartRecordName=name + "." + self.zone["Name"], StartRecordType=record_type + ) for rrs in res["ResourceRecordSets"]: if rrs["Name"] == name + "." + self.zone["Name"] and rrs["Type"] == record_type: value = rrs["ResourceRecords"] diff --git a/aegea/util/aws/iam.py b/aegea/util/aws/iam.py index 1d349ace..98ae717f 100644 --- a/aegea/util/aws/iam.py +++ b/aegea/util/aws/iam.py @@ -53,10 +53,10 @@ def add_statement(self, principal=None, action=None, effect="Allow", resource=No statement["Principal"] = principal self.policy["Statement"].append(statement) if action: - for action in (action if isinstance(action, list) else [action]): + for action in action if isinstance(action, list) else [action]: self.add_action(action) if resource: - for resource in (resource if isinstance(resource, list) else [resource]): + for resource in resource if isinstance(resource, list) else [resource]: self.add_resource(resource) def add_action(self, action): @@ -79,12 +79,18 @@ def add_assume_role_principals(self, principals): def __str__(self): return json.dumps(self.policy) + def ensure_iam_role(name, policies=frozenset(), trust=frozenset()): assume_role_policy = IAMPolicyBuilder() assume_role_policy.add_assume_role_principals(trust) - role = ensure_iam_entity(name, policies=policies, collection=resources.iam.roles, - constructor=resources.iam.create_role, RoleName=name, - AssumeRolePolicyDocument=str(assume_role_policy)) + role = ensure_iam_entity( + name, + policies=policies, + collection=resources.iam.roles, + constructor=resources.iam.create_role, + RoleName=name, + AssumeRolePolicyDocument=str(assume_role_policy), + ) trust_policy = IAMPolicyBuilder(role.assume_role_policy_document) trust_policy.add_assume_role_principals(trust) if trust_policy.policy != role.assume_role_policy_document: @@ -92,9 +98,12 @@ def ensure_iam_role(name, policies=frozenset(), trust=frozenset()): role.AssumeRolePolicy().update(PolicyDocument=str(trust_policy)) return role + def ensure_iam_group(name, policies=frozenset()): - return ensure_iam_entity(name, policies=policies, collection=resources.iam.groups, - constructor=resources.iam.create_group, GroupName=name) + return ensure_iam_entity( + name, policies=policies, collection=resources.iam.groups, constructor=resources.iam.create_group, GroupName=name + ) + def ensure_iam_entity(iam_entity_name, policies, collection, constructor, **constructor_args): for entity in collection.all(): @@ -116,6 +125,7 @@ def ensure_iam_entity(iam_entity_name, policies, collection, constructor, **cons # TODO: accommodate IAM eventual consistency return entity + def ensure_instance_profile(iam_role_name, policies=frozenset()): for 
instance_profile in resources.iam.instance_profiles.all(): if instance_profile.name == iam_role_name: @@ -131,6 +141,7 @@ def ensure_instance_profile(iam_role_name, policies=frozenset()): instance_profile.add_role(RoleName=role.name) return instance_profile + def ensure_iam_policy(name, doc): try: return resources.iam.create_policy(PolicyName=name, PolicyDocument=str(doc)) @@ -143,6 +154,7 @@ def ensure_iam_policy(name, doc): version.delete() return policy + def compose_managed_policies(policy_names): policy = IAMPolicyBuilder() for policy_name in policy_names: @@ -152,7 +164,10 @@ def compose_managed_policies(policy_names): policy.policy["Statement"][-1]["Sid"] = policy_name + str(i) return policy + def ensure_fargate_execution_role(name): - return ensure_iam_role(name, trust=["ecs-tasks"], - policies=["service-role/AmazonEC2ContainerServiceforEC2Role", - "service-role/AWSBatchServiceRole"]) + return ensure_iam_role( + name, + trust=["ecs-tasks"], + policies=["service-role/AmazonEC2ContainerServiceforEC2Role", "service-role/AWSBatchServiceRole"], + ) diff --git a/aegea/util/aws/logs.py b/aegea/util/aws/logs.py index 6251f9fb..ddfeccaf 100644 --- a/aegea/util/aws/logs.py +++ b/aegea/util/aws/logs.py @@ -25,8 +25,11 @@ def __init__(self, log_stream_name, log_group_name="/aws/batch/job", head=None, def __iter__(self): page = None - get_args = dict(logGroupName=self.log_group_name, logStreamName=self.log_stream_name, - limit=min(self.head or 10000, self.tail or 10000)) + get_args = dict( + logGroupName=self.log_group_name, + logStreamName=self.log_stream_name, + limit=min(self.head or 10000, self.tail or 10000), + ) get_args["startFromHead"] = True if self.tail is None else False if self.next_page_token: get_args["nextToken"] = self.next_page_token @@ -41,6 +44,7 @@ def __iter__(self): if page: self.next_page_token = page[self.next_page_key] + def export_log_files(args): bucket_name = "aegea-cloudwatch-log-export-{}-{}".format(ARN.get_account_id(), clients.logs.meta.region_name) bucket_arn = ARN(service="s3", region="", account_id="", resource=bucket_name) @@ -52,10 +56,12 @@ def export_log_files(args): bucket = ensure_s3_bucket(bucket_name, policy=policy, lifecycle=lifecycle) if not args.end_time: args.end_time = Timestamp.match_precision(Timestamp("-0s"), args.start_time) - export_task_args = dict(logGroupName=args.log_group, - fromTime=int(datetime.timestamp(args.start_time) * 1000), - to=int(datetime.timestamp(args.end_time) * 1000), - destination=bucket.name) + export_task_args = dict( + logGroupName=args.log_group, + fromTime=int(datetime.timestamp(args.start_time) * 1000), + to=int(datetime.timestamp(args.end_time) * 1000), + destination=bucket.name, + ) if args.log_stream: export_task_args.update(logStreamNamePrefix=args.log_stream) cache_key = hashlib.sha256(json.dumps(export_task_args, sort_keys=True).encode()).hexdigest()[:32] @@ -85,6 +91,7 @@ def export_log_files(args): pass return bucket.objects.filter(Prefix=cache_key) + def get_lines_for_log_file(log_file): if not log_file.key.endswith(".gz"): return [] @@ -95,12 +102,14 @@ def get_lines_for_log_file(log_file): log_lines.append(line) return log_lines + def export_and_print_log_events(args): with ThreadPoolExecutor() as executor: for lines in executor.map(get_lines_for_log_file, export_log_files(args)): for line in lines: sys.stdout.write(line) + def print_log_event(event): if "@timestamp" in event: print(str(Timestamp(event["@timestamp"])), event["@message"]) @@ -109,6 +118,7 @@ def print_log_event(event): else: 
print(json.dumps(event, indent=4)) + def print_log_events(args): streams = [] if args.log_stream: @@ -140,22 +150,27 @@ def print_log_events(args): break get_log_events_args.update(nextToken=page["nextForwardToken"], limit=10000) + def print_log_event_with_context(log_record_pointer, before=10, after=10): res = clients.logs.get_log_record(logRecordPointer=log_record_pointer) log_record = res["logRecord"] account_id, log_group_name = log_record["@log"].split(":") - before_ctx = clients.logs.get_log_events(logGroupName=log_group_name, - logStreamName=log_record["@logStream"], - endTime=int(log_record["@timestamp"]), - limit=before, - startFromHead=False) + before_ctx = clients.logs.get_log_events( + logGroupName=log_group_name, + logStreamName=log_record["@logStream"], + endTime=int(log_record["@timestamp"]), + limit=before, + startFromHead=False, + ) for event in before_ctx["events"]: print_log_event(event) - after_ctx = clients.logs.get_log_events(logGroupName=log_group_name, - logStreamName=log_record["@logStream"], - startTime=int(log_record["@timestamp"]), - limit=after, - startFromHead=True) + after_ctx = clients.logs.get_log_events( + logGroupName=log_group_name, + logStreamName=log_record["@logStream"], + startTime=int(log_record["@timestamp"]), + limit=after, + startFromHead=True, + ) for event in after_ctx["events"]: print_log_event(event) print("---") diff --git a/aegea/util/aws/spot.py b/aegea/util/aws/spot.py index d375f023..569f9959 100644 --- a/aegea/util/aws/spot.py +++ b/aegea/util/aws/spot.py @@ -9,9 +9,20 @@ class SpotFleetBuilder(VerboseRepr): # TODO: vivify from toolspec; vivify from SFR ID; update with incremental cores/memory requirements - def __init__(self, launch_spec, cores=1, min_cores_per_instance=1, min_mem_per_core_gb=1.5, gpus_per_instance=0, - min_ephemeral_storage_gb=0, spot_price=None, duration_hours=None, client_token=None, - instance_type_prefixes=None, dry_run=False): + def __init__( + self, + launch_spec, + cores=1, + min_cores_per_instance=1, + min_mem_per_core_gb=1.5, + gpus_per_instance=0, + min_ephemeral_storage_gb=0, + spot_price=None, + duration_hours=None, + client_token=None, + instance_type_prefixes=None, + dry_run=False, + ): if spot_price is None: spot_price = 1 if "SecurityGroupIds" in launch_spec: @@ -28,21 +39,20 @@ def __init__(self, launch_spec, cores=1, min_cores_per_instance=1, min_mem_per_c self.instance_type_prefixes = instance_type_prefixes self.dry_run = dry_run self.iam_fleet_role = self.get_iam_fleet_role() - self.spot_fleet_request_config = dict(SpotPrice=str(spot_price), - TargetCapacity=cores, - IamFleetRole=self.iam_fleet_role.arn) + self.spot_fleet_request_config = dict( + SpotPrice=str(spot_price), TargetCapacity=cores, IamFleetRole=self.iam_fleet_role.arn + ) if client_token: self.spot_fleet_request_config.update(ClientToken=client_token) if duration_hours: deadline = datetime.utcnow().replace(microsecond=0) + timedelta(hours=duration_hours) - self.spot_fleet_request_config.update(ValidUntil=deadline, - TerminateInstancesWithExpiration=True) + self.spot_fleet_request_config.update(ValidUntil=deadline, TerminateInstancesWithExpiration=True) @classmethod def get_iam_fleet_role(cls): - return ensure_iam_role("SpotFleet", - policies=["service-role/AmazonEC2SpotFleetTaggingRole"], - trust=["spotfleet"]) + return ensure_iam_role( + "SpotFleet", policies=["service-role/AmazonEC2SpotFleetTaggingRole"], trust=["spotfleet"] + ) def instance_types(self, max_overprovision=3): def compute_ephemeral_storage_gb(instance_data): @@ -71,14 
+81,12 @@ def compute_ephemeral_storage_gb(instance_data): def launch_specs(self, **kwargs): for instance_type, weighted_capacity in self.instance_types(**kwargs): - yield dict(self.launch_spec, - InstanceType=instance_type, - WeightedCapacity=weighted_capacity) + yield dict(self.launch_spec, InstanceType=instance_type, WeightedCapacity=weighted_capacity) def __call__(self, **kwargs): self.spot_fleet_request_config["LaunchSpecifications"] = list(self.launch_specs()) logger.debug(self.spot_fleet_request_config) - res = clients.ec2.request_spot_fleet(DryRun=self.dry_run, - SpotFleetRequestConfig=self.spot_fleet_request_config, - **kwargs) + res = clients.ec2.request_spot_fleet( + DryRun=self.dry_run, SpotFleetRequestConfig=self.spot_fleet_request_config, **kwargs + ) return res["SpotFleetRequestId"] diff --git a/aegea/util/cloudinit.py b/aegea/util/cloudinit.py index de9a1f1d..35d293f0 100644 --- a/aegea/util/cloudinit.py +++ b/aegea/util/cloudinit.py @@ -19,20 +19,22 @@ def add_file_to_cloudinit_manifest(src_path, path, manifest): with open(src_path, "rb") as fh: content = fh.read() - manifest[path] = dict(path=path, permissions='0' + oct(os.stat(src_path).st_mode)[-3:]) + manifest[path] = dict(path=path, permissions="0" + oct(os.stat(src_path).st_mode)[-3:]) try: manifest[path].update(content=content.decode()) except UnicodeDecodeError: manifest[path].update(content=base64.b64encode(gzip_compress_bytes(content)), encoding="gz+b64") + def get_rootfs_skel_dirs(args): # Build a list of rootfs_skel_dirs to build from. The arg element 'auto' is # expanded to the default aegea skel as well as rootfs.skel. directories in # the same paths as config files. rootfs_skel_dirs = OrderedDict() # type: OrderedDict[str, None] for arg in args.rootfs_skel_dirs: - if arg == 'auto': + if arg == "auto": from .. import config + dirs_to_scan = [os.path.dirname(p) for p in config.config_files] for path in dirs_to_scan: path = os.path.join(os.path.expanduser(path), "rootfs.skel." 
+ args.entry_point.__name__) @@ -41,11 +43,12 @@ def get_rootfs_skel_dirs(args): else: rootfs_skel_dirs[os.path.expanduser(arg)] = None if rootfs_skel_dirs: - logger.info('Adding rootfs skel files from these paths: %s', ', '.join(rootfs_skel_dirs)) + logger.info("Adding rootfs skel files from these paths: %s", ", ".join(rootfs_skel_dirs)) else: - logger.debug('No rootfs skel files.') + logger.debug("No rootfs skel files.") return list(rootfs_skel_dirs) + def get_bootstrap_files(rootfs_skel_dirs, dest="cloudinit"): manifest = OrderedDict() # type: OrderedDict[str, Dict] targz = io.BytesIO() @@ -63,22 +66,35 @@ def get_bootstrap_files(rootfs_skel_dirs, dest="cloudinit"): if dest == "cloudinit": add_file_to_cloudinit_manifest(os.path.join(root, file_), path, manifest) elif dest == "tarfile": + assert tar is not None tar.add(os.path.join(root, file_), path) if dest == "cloudinit": return list(manifest.values()) elif dest == "tarfile": + assert tar is not None tar.close() return targz.getvalue() -def get_user_data(host_key=None, commands=None, packages=None, rootfs_skel_dirs=None, storage=None, - mime_multipart_archive=False, ssh_ca_keys=None, provision_users=None, **kwargs): + +def get_user_data( + host_key=None, + commands=None, + packages=None, + rootfs_skel_dirs=None, + storage=None, + mime_multipart_archive=False, + ssh_ca_keys=None, + provision_users=None, + **kwargs, +): """ provision_users can be either a list of Linux usernames or a list of dicts as described in https://cloudinit.readthedocs.io/en/latest/topics/modules.html#module-cloudinit.config.cc_users_groups """ cloud_config_data = OrderedDict() # type: OrderedDict[str, Any] cloud_config_data["bootcmd"] = [ + # systemd interferes with aliases "for d in /dev/disk/by-id/nvme-Amazon_Elastic_Block_Store_vol?????????????????; do " "a=/dev/$(nvme id-ctrl --raw-binary $d 2>/dev/null | dd skip=3072 bs=1 count=4 status=none); " "[[ -e $a ]] || ln -s $d $a; " @@ -95,10 +111,12 @@ def get_user_data(host_key=None, commands=None, packages=None, rootfs_skel_dirs= cloud_config_data["runcmd"] = commands or [] cloud_config_data["write_files"] = get_bootstrap_files(rootfs_skel_dirs or []) if ssh_ca_keys: - cloud_config_data["write_files"] += [dict(path="/etc/ssh/sshd_ca.pem", permissions='0644', content=ssh_ca_keys)] - cloud_config_data["runcmd"].append("grep -q TrustedUserCAKeys /etc/ssh/sshd_config || " - "(echo 'TrustedUserCAKeys /etc/ssh/sshd_ca.pem' >> /etc/ssh/sshd_config;" - " service sshd reload)") + cloud_config_data["write_files"] += [dict(path="/etc/ssh/sshd_ca.pem", permissions="0644", content=ssh_ca_keys)] + cloud_config_data["runcmd"].append( + "grep -q TrustedUserCAKeys /etc/ssh/sshd_config || " + "(echo 'TrustedUserCAKeys /etc/ssh/sshd_ca.pem' >> /etc/ssh/sshd_config;" + " service sshd reload)" + ) if provision_users: cloud_config_data["users"] = [] for user in provision_users: @@ -111,8 +129,7 @@ def get_user_data(host_key=None, commands=None, packages=None, rootfs_skel_dirs= if host_key is not None: buf = io.StringIO() host_key.write_private_key(buf) - cloud_config_data["ssh_keys"] = dict(rsa_private=buf.getvalue(), - rsa_public=get_public_key_from_pair(host_key)) + cloud_config_data["ssh_keys"] = dict(rsa_private=buf.getvalue(), rsa_public=get_public_key_from_pair(host_key)) payload = encode_cloud_config_payload(cloud_config_data, mime_multipart_archive=mime_multipart_archive) if len(payload) >= 16384: logger.warn("Cloud-init payload is too large to be passed in user data, extracting rootfs.skel") @@ -120,6 +137,7 @@ def 
get_user_data(host_key=None, commands=None, packages=None, rootfs_skel_dirs= payload = encode_cloud_config_payload(cloud_config_data, mime_multipart_archive=mime_multipart_archive) return payload + mime_multipart_archive_template = """Content-Type: multipart/mixed; boundary="==BOUNDARY==" MIME-Version: 1.0 @@ -131,6 +149,7 @@ def get_user_data(host_key=None, commands=None, packages=None, rootfs_skel_dirs= --==BOUNDARY==-- """ + def encode_cloud_config_payload(cloud_config_data, mime_multipart_archive=False, gzip=True): # TODO: default=dict is for handling tweak.Config objects in the hierarchy. # TODO: Should subclass dict, not MutableMapping @@ -141,16 +160,18 @@ def encode_cloud_config_payload(cloud_config_data, mime_multipart_archive=False, slug = "#cloud-config\n" + cloud_config_json return gzip_compress_bytes(slug.encode()) if gzip else slug + def upload_bootstrap_asset(cloud_config_data, rootfs_skel_dirs): key_name = "".join(random.choice(string.ascii_letters) for x in range(32)) enc_key = "".join(random.choice(string.ascii_letters) for x in range(32)) logger.info("Uploading bootstrap asset %s to S3", key_name) bucket = ensure_s3_bucket() - cipher = subprocess.Popen(["openssl", "aes-256-cbc", "-e", "-k", enc_key], - stdin=subprocess.PIPE, stdout=subprocess.PIPE) + cipher = subprocess.Popen( + ["openssl", "aes-256-cbc", "-e", "-k", enc_key], stdin=subprocess.PIPE, stdout=subprocess.PIPE + ) encrypted_tarfile = cipher.communicate(get_bootstrap_files(rootfs_skel_dirs, dest="tarfile"))[0] bucket.upload_fileobj(io.BytesIO(encrypted_tarfile), key_name) - url = clients.s3.generate_presigned_url(ClientMethod='get_object', Params=dict(Bucket=bucket.name, Key=key_name)) + url = clients.s3.generate_presigned_url(ClientMethod="get_object", Params=dict(Bucket=bucket.name, Key=key_name)) cmd = "curl -s '{url}' | openssl aes-256-cbc -d -k {key} | tar -xz --no-same-owner -C /" cloud_config_data["runcmd"].insert(0, cmd.format(url=url, key=enc_key)) del cloud_config_data["write_files"] diff --git a/aegea/util/constants.py b/aegea/util/constants.py index ef37439a..25058b5a 100644 --- a/aegea/util/constants.py +++ b/aegea/util/constants.py @@ -5,8 +5,10 @@ _constants_filename = os.path.join(os.path.dirname(__file__), "..", "constants.json") _constants = {} # type: Dict[str, Any] + def write(): from . 
import aws + raise NotImplementedError() """ constants = {"instance_types": {}} @@ -21,6 +23,7 @@ def write(): json.dump(constants, fh) """ + def get(field): if not _constants: with open(_constants_filename) as fh: diff --git a/aegea/util/crypto.py b/aegea/util/crypto.py index e1c413c5..85e058d6 100644 --- a/aegea/util/crypto.py +++ b/aegea/util/crypto.py @@ -11,24 +11,31 @@ def new_ssh_key(bits=2048): from paramiko import RSAKey + return RSAKey.generate(bits=bits) + def get_public_key_from_pair(key): return key.get_name() + " " + key.get_base64() + def key_fingerprint(key): hex_fp = binascii.hexlify(key.get_fingerprint()).decode() - return key.get_name() + " " + ":".join(hex_fp[i:i + 2] for i in range(0, len(hex_fp), 2)) + return key.get_name() + " " + ":".join(hex_fp[i : i + 2] for i in range(0, len(hex_fp), 2)) + def get_ssh_key_path(name): return os.path.expanduser("~/.ssh/{}.pem".format(name)) + def get_ssh_id(): for line in subprocess.check_output(["ssh-add", "-L"]).decode().splitlines(): return line + def ensure_local_ssh_key(name): from paramiko import RSAKey + if os.path.exists(get_ssh_key_path(name)): ssh_key = RSAKey.from_private_key_file(get_ssh_key_path(name)) else: @@ -38,16 +45,19 @@ def ensure_local_ssh_key(name): ssh_key.write_private_key_file(get_ssh_key_path(name)) return ssh_key + def add_ssh_key_to_agent(name): try: subprocess.check_call(["ssh-add", get_ssh_key_path(name)], timeout=5) except Exception as e: logger.warn("Failed to add %s to SSH keychain: %s. Connections may fail", get_ssh_key_path(name), e) + def ensure_ssh_key(name=None, base_name=__name__, verify_pem_file=True): if name is None: from getpass import getuser from socket import gethostname + name = base_name + "." + getuser() + "." + gethostname().split(".")[0] try: @@ -62,16 +72,18 @@ def ensure_ssh_key(name=None, base_name=__name__, verify_pem_file=True): if not ec2_key_pairs: ssh_key = ensure_local_ssh_key(name) - resources.ec2.import_key_pair(KeyName=name, - PublicKeyMaterial=get_public_key_from_pair(ssh_key)) + resources.ec2.import_key_pair(KeyName=name, PublicKeyMaterial=get_public_key_from_pair(ssh_key)) logger.info("Imported SSH key %s", get_ssh_key_path(name)) add_ssh_key_to_agent(name) return name + def hostkey_line(hostnames, key): from paramiko import hostkeys + return hostkeys.HostKeyEntry(hostnames=hostnames, key=key).to_line() + def add_ssh_host_key_to_known_hosts(host_key_line): ssh_known_hosts_path = os.path.expanduser("~/.ssh/known_hosts") with open(ssh_known_hosts_path, "a") as fh: diff --git a/aegea/util/exceptions.py b/aegea/util/exceptions.py index f8f3d83c..1d795817 100644 --- a/aegea/util/exceptions.py +++ b/aegea/util/exceptions.py @@ -3,5 +3,6 @@ class AegeaException(Exception): Base class for exceptions in this package. 
""" + class GetFieldError(AegeaException): pass diff --git a/aegea/util/printing.py b/aegea/util/printing.py index 7a853da4..c3e5238a 100644 --- a/aegea/util/printing.py +++ b/aegea/util/printing.py @@ -19,59 +19,71 @@ def CYAN(message=None): else: return CYAN() + message + ENDC() + def BLUE(message=None): if message is None: return "\033[34m" if sys.stdout.isatty() else "" else: return BLUE() + message + ENDC() + def YELLOW(message=None): if message is None: return "\033[33m" if sys.stdout.isatty() else "" else: return YELLOW() + message + ENDC() + def GREEN(message=None): if message is None: return "\033[32m" if sys.stdout.isatty() else "" else: return GREEN() + message + ENDC() + def RED(message=None): if message is None: return "\033[31m" if sys.stdout.isatty() else "" else: return RED() + message + ENDC() + def WHITE(message=None): if message is None: return "\033[37m" if sys.stdout.isatty() else "" else: return WHITE() + message + ENDC() + def UNDERLINE(message=None): if message is None: return "\033[4m" if sys.stdout.isatty() else "" else: return UNDERLINE() + message + ENDC() + def BOLD(message=None): if message is None: return "\033[1m" if sys.stdout.isatty() else "" else: return BOLD() + message + ENDC() + def ENDC(): return "\033[0m" if sys.stdout.isatty() else "" + def border(i): return WHITE() + i + ENDC() + ansi_pattern = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]") + def strip_ansi_codes(i): return re.sub(ansi_pattern, "", i) + def ansi_truncate(s, max_len): ansi_total_len = 0 for ansi_code in ansi_pattern.finditer(s): @@ -80,9 +92,10 @@ def ansi_truncate(s, max_len): break ansi_total_len += ansi_code_end - ansi_code_start if len(s) > max_len + ansi_total_len: - return s[:max_len + ansi_total_len - 1] + "…" + return s[: max_len + ansi_total_len - 1] + "…" return s + def format_table(table, column_names=None, column_specs=None, max_col_width=32, auto_col_width=False): """ Table pretty printer. 
Expects tables to be given as arrays of arrays:: @@ -117,10 +130,7 @@ def format_table(table, column_names=None, column_specs=None, max_col_width=32, col_widths[i] = max(col_widths[i], len(strip_ansi_codes(my_item))) trunc_table.append(my_row) - type_colormap = {"boolean": BLUE(), - "integer": YELLOW(), - "float": WHITE(), - "string": GREEN()} + type_colormap = {"boolean": BLUE(), "integer": YELLOW(), "float": WHITE(), "string": GREEN()} for t in "uint8", "int16", "uint16", "int32", "uint32", "int64": type_colormap[t] = type_colormap["integer"] type_colormap["double"] = type_colormap["float"] @@ -133,8 +143,9 @@ def col_head(i): formatted_table = [border("┌") + border("┬").join(border("─") * i for i in col_widths) + border("┐")] if len(my_col_names) > 0: - padded_column_names = [col_head(i) + " " * (col_widths[i] - len(my_col_names[i])) - for i in range(len(my_col_names))] + padded_column_names = [ + col_head(i) + " " * (col_widths[i] - len(my_col_names[i])) for i in range(len(my_col_names)) + ] formatted_table.append(border("│") + border("│").join(padded_column_names) + border("│")) formatted_table.append(border("├") + border("┼").join(border("─") * i for i in col_widths) + border("┤")) @@ -152,6 +163,7 @@ def col_head(i): return format_table(table, max_col_width=max_col_width - 1, auto_col_width=True, **orig_col_args) return "\n".join(formatted_table) + def page_output(content, pager=None, file=None): if file is None: file = sys.stdout @@ -175,8 +187,10 @@ def page_output(content, pager=None, file=None): if tty_rows > content_rows and tty_cols > content_cols: raise AegeaException() - pager_process = subprocess.Popen(pager or os.environ.get("PAGER", "less -RS"), shell=True, - stdin=subprocess.PIPE, stdout=file) + pager_process = subprocess.Popen( + pager or os.environ.get("PAGER", "less -RS"), shell=True, stdin=subprocess.PIPE, stdout=file + ) + assert pager_process.stdin is not None pager_process.stdin.write(content.encode("utf-8")) pager_process.stdin.close() pager_process.wait() @@ -187,10 +201,12 @@ def page_output(content, pager=None, file=None): file.write(content) finally: try: + assert pager_process is not None pager_process.terminate() except BaseException: pass + def get_field(item, field): for element in field.split("."): try: @@ -202,24 +218,29 @@ def get_field(item, field): raise GetFieldError('Unable to access field or attribute "{}" of {}'.format(field, item)) return item + def format_datetime(d): from babel import dates from dateutil.tz import tzutc + d = d.replace(microsecond=0) # Switch from UTC to local TZ d = d.astimezone(tz=None) return dates.format_timedelta(d - datetime.now(tzutc()), add_direction=True) + def format_cell(cell): if isinstance(cell, datetime): cell = format_datetime(cell) if isinstance(cell, timedelta): from babel import dates + cell = dates.format_timedelta(-cell, add_direction=True) if isinstance(cell, (list, dict)): cell = json.dumps(cell, default=lambda x: str(x)) return cell + def get_cell(resource, field, transform=None): cell = get_field(resource, field) if transform: @@ -234,17 +255,20 @@ def get_cell(resource, field, transform=None): return "[Access denied]" raise + def format_tags(cell, row): tags = {tag["Key"]: tag["Value"] for tag in cell} if cell else {} return ", ".join("{}={}".format(k, v) for k, v in tags.items()) + def trim_names(names, *prefixes): for name in names: for prefix in prefixes: if name.startswith(prefix): - name = name[len(prefix):] + name = name[len(prefix) :] yield name + def format_number(n, fractional_digits=2): B = n 
KB = float(1024) @@ -253,15 +277,16 @@ def format_number(n, fractional_digits=2): TB = float(GB * 1024) if B < KB: - return '{0}'.format(B) + return "{0}".format(B) elif KB <= B < MB: - return '{0:.{precision}f}K'.format(B / KB, precision=fractional_digits) + return "{0:.{precision}f}K".format(B / KB, precision=fractional_digits) elif MB <= B < GB: - return '{0:.{precision}f}M'.format(B / MB, precision=fractional_digits) + return "{0:.{precision}f}M".format(B / MB, precision=fractional_digits) elif GB <= B < TB: - return '{0:.{precision}f}G'.format(B / GB, precision=fractional_digits) + return "{0:.{precision}f}G".format(B / GB, precision=fractional_digits) elif TB <= B: - return '{0:.{precision}f}T'.format(B / TB, precision=fractional_digits) + return "{0:.{precision}f}T".format(B / TB, precision=fractional_digits) + def tabulate(collection, args, cell_transforms=None): if cell_transforms is None: @@ -280,7 +305,7 @@ def tabulate(collection, args, cell_transforms=None): else: if args.sort_by.endswith(":reverse"): reverse = True - args.sort_by = args.sort_by[:-len(":reverse")] + args.sort_by = args.sort_by[: -len(":reverse")] table = sorted(table, key=lambda x: x[args.columns.index(args.sort_by)], reverse=reverse) table = [[format_cell(c) for c in row] for row in table] # type: ignore args.columns = list(trim_names(args.columns, *getattr(args, "trim_col_names", []))) diff --git a/aegea/zones.py b/aegea/zones.py index 95131794..c02000e7 100644 --- a/aegea/zones.py +++ b/aegea/zones.py @@ -16,8 +16,10 @@ def zones(args): zones_parser.print_help() + zones_parser = register_parser(zones, help="Manage Route53 DNS zones", description=__doc__) + def ls(args): table = [] rrs_cols = ["Name", "Type", "TTL"] @@ -34,20 +36,25 @@ def ls(args): column_names = rrs_cols + record_cols + ["Private", "Id"] page_output(format_table(table, column_names=column_names, max_col_width=args.max_col_width)) + parser = register_parser(ls, parent=zones_parser, help="List Route53 DNS zones and records") parser.add_argument("zones", nargs="*") + def update(args): return DNSZone(args.zone).update(*zip(*args.updates), record_type=args.record_type) # type: ignore + parser = register_parser(update, parent=zones_parser, help="Update Route53 DNS records") parser.add_argument("zone") parser.add_argument("updates", nargs="+", metavar="NAME=VALUE", type=lambda x: x.split("=", 1)) parser.add_argument("--record-type", default="CNAME") + def delete(args): return DNSZone(args.zone).delete(name=args.name, record_type=args.record_type, missing_ok=False) + parser = register_parser(delete, parent=zones_parser, help="Delete Route53 DNS records") parser.add_argument("zone") parser.add_argument("name", help=r'Enter a "\052" literal to represent a wildcard.') diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index bee99dbf..00000000 --- a/mypy.ini +++ /dev/null @@ -1,20 +0,0 @@ -[mypy] -cache_dir=/dev/null -[mypy-babel.*] -ignore_missing_imports=True -[mypy-boto3.*] -ignore_missing_imports=True -[mypy-botocore.*] -ignore_missing_imports=True -[mypy-ipwhois.*] -ignore_missing_imports=True -[mypy-paramiko.*] -ignore_missing_imports=True -[mypy-tweak.*] -ignore_missing_imports=True -[mypy-uritemplate.*] -ignore_missing_imports=True -[mypy-aegea.packages.*] -ignore_errors=True -[mypy-aegea.util.compat.*] -ignore_errors=True diff --git a/pyproject.toml b/pyproject.toml index 0efc84ca..4a15bfeb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,29 @@ [tool.black] line-length = 120 + [tool.isort] profile = "black" line_length = 
120 skip = ["scripts", "aegea/missions", "aegea/packages", "aegea/rootfs.skel.build_ami"] skip_gitignore = true + [tool.ruff] line-length=120 -ignore=["E401", "F401"] exclude=["aegea/packages", "aegea/lambdas"] -[tool.ruff.per-file-ignores] + +[tool.ruff.lint] +ignore=["E401", "F401"] + +[tool.ruff.lint.per-file-ignores] "test/test.py" = ["E402"] + +[tool.mypy] +files = [ + "aegea" +] +check_untyped_defs = true +disallow_incomplete_defs = true + +[[tool.mypy.overrides]] +module = ["babel.*", "boto3.*", "botocore.*", "ipwhois.*", "paramiko.*", "tweak.*", "uritemplate.*"] +ignore_missing_imports = true diff --git a/scripts/aegea b/scripts/aegea index e2d1103d..7019d10c 100755 --- a/scripts/aegea +++ b/scripts/aegea @@ -1,12 +1,17 @@ #!/usr/bin/env python3 # PYTHON_ARGCOMPLETE_OK -import os, sys, logging, pkgutil, importlib +import importlib +import logging +import os +import pkgutil +import sys logging.basicConfig(level=logging.ERROR) logging.getLogger("botocore.vendored.requests").setLevel(logging.ERROR) import argcomplete # noqa + import aegea # noqa for importer, modname, is_pkg in pkgutil.iter_modules(aegea.__path__): diff --git a/scripts/aegea-build-image-for-mission b/scripts/aegea-build-image-for-mission index 93ffe364..0b0c7b90 100755 --- a/scripts/aegea-build-image-for-mission +++ b/scripts/aegea-build-image-for-mission @@ -1,8 +1,15 @@ #!/usr/bin/env python3 # PYTHON_ARGCOMPLETE_OK -import os, sys, subprocess, argparse, base64, shutil +import argparse +import base64 +import os +import shutil +import subprocess +import sys + import argcomplete + import aegea from aegea.util.aws import ARN from aegea.util.compat import TemporaryDirectory diff --git a/scripts/aegea-rebuild-public-elb-sg b/scripts/aegea-rebuild-public-elb-sg index 06e13ae4..05606438 100755 --- a/scripts/aegea-rebuild-public-elb-sg +++ b/scripts/aegea-rebuild-public-elb-sg @@ -1,11 +1,24 @@ #!/usr/bin/env python3 # PYTHON_ARGCOMPLETE_OK -import os, sys, subprocess, argparse, base64, logging +import argparse +import base64 +import logging +import os +import subprocess +import sys + from botocore.exceptions import ClientError + from aegea import logger -from aegea.util.aws import (ensure_vpc, ensure_security_group, ensure_ingress_rule, resources, clients, - expect_error_codes) +from aegea.util.aws import ( + clients, + ensure_ingress_rule, + ensure_security_group, + ensure_vpc, + expect_error_codes, + resources, +) logging.basicConfig(level=logging.INFO) diff --git a/scripts/aegea-ssh b/scripts/aegea-ssh index 11fbfdd2..bb9a0050 100644 --- a/scripts/aegea-ssh +++ b/scripts/aegea-ssh @@ -1,6 +1,11 @@ #!/usr/bin/env python3 -import os, sys, logging, pkgutil, importlib, subprocess +import importlib +import logging +import os +import pkgutil +import subprocess +import sys logging.basicConfig(level=logging.ERROR) logging.getLogger("botocore.vendored.requests").setLevel(logging.ERROR) diff --git a/scripts/pypi-apt-freeze b/scripts/pypi-apt-freeze index 8fb5d803..e609a41a 100755 --- a/scripts/pypi-apt-freeze +++ b/scripts/pypi-apt-freeze @@ -11,8 +11,15 @@ Example: """ -import os, sys, logging, pkgutil, importlib, argparse -import argcomplete, requests +import argparse +import importlib +import logging +import os +import pkgutil +import sys + +import argcomplete +import requests logging.basicConfig(level=logging.WARNING)
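# For reference, a minimal illustrative sketch (an assumption-laden aside, not code
# from this patch) of how the new [tool.mypy] and [[tool.mypy.overrides]] tables added
# to pyproject.toml above (replacing the deleted mypy.ini) can be inspected. Assumes
# Python 3.11+ for the stdlib tomllib module and that it is run from the repo root.
import tomllib

with open("pyproject.toml", "rb") as fh:
    mypy_cfg = tomllib.load(fh)["tool"]["mypy"]

print("checked files:", mypy_cfg["files"])
for override in mypy_cfg.get("overrides", []):
    # Each override table covers the module patterns that were previously
    # per-module [mypy-...] sections in the old mypy.ini.
    print(override["module"], "ignore_missing_imports =", override.get("ignore_missing_imports", False))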