From 2f4c7918b9d6946bea43983068a8b6adfaf020ce Mon Sep 17 00:00:00 2001 From: Andrey Kislyuk Date: Thu, 24 Dec 2020 11:35:22 -0800 Subject: [PATCH] batch: begin fargate support --- aegea/base_config.yml | 13 +++-- aegea/batch.py | 107 +++++++++++++++++++++++----------------- aegea/ecs.py | 14 +++--- aegea/util/aws/batch.py | 44 ++++++++++++----- aegea/util/aws/iam.py | 5 ++ 5 files changed, 114 insertions(+), 69 deletions(-) diff --git a/aegea/base_config.yml b/aegea/base_config.yml index ef8931c..28995ae 100644 --- a/aegea/base_config.yml +++ b/aegea/base_config.yml @@ -572,21 +572,24 @@ batch_queues: - status - statusReason -# Valid instance types values: "optimal", family name, type name batch_create_compute_environment: type: MANAGED - compute_type: SPOT - instance_types: [m5d, c5d, r5d] + compute_type: FARGATE_SPOT + max_vcpus: 64 + # The following parameters are not applicable (and ignored) for Fargate CEs + # Valid instance types values: "optimal", family name, type name min_vcpus: 0 desired_vcpus: 2 - max_vcpus: 64 + instance_types: [m5d, c5d, r5d] # ecs_container_instance_ami_tags: # AegeaMission: "ecs-container-instance" batch_submit: user: "0" - default_memory_mb: 4000 + default_memory_mb: 4096 job_role: aegea.batch.worker + platform_capabilities: + - EC2 default_job_role_iam_policies: - AmazonEC2ReadOnlyAccess - AmazonS3ReadOnlyAccess diff --git a/aegea/batch.py b/aegea/batch.py index 699b704..d187fc6 100755 --- a/aegea/batch.py +++ b/aegea/batch.py @@ -24,7 +24,7 @@ from .util.aws.spot import SpotFleetBuilder from .util.aws.logs import CloudwatchLogReader from .util.aws.batch import ensure_job_definition, get_command_and_env, ensure_lambda_helper -from .util.aws.iam import IAMPolicyBuilder, ensure_iam_role, ensure_instance_profile +from .util.aws.iam import IAMPolicyBuilder, ensure_iam_role, ensure_instance_profile, ensure_fargate_execution_role def complete_queue_name(**kwargs): return [q["jobQueueName"] for q in paginate(clients.batch.get_paginator("describe_job_queues"))] @@ -76,17 +76,6 @@ def ensure_launch_template(prefix=__name__.replace(".", "_"), **kwargs): return name def create_compute_environment(args): - commands = instance_storage_shellcode.strip().format(mountpoint="/mnt", mkfs=get_mkfs_command()).split("\n") - user_data = get_user_data(commands=commands, mime_multipart_archive=True) - if args.ecs_container_instance_ami: - ecs_ami_id = args.ecs_container_instance_ami - elif args.ecs_container_instance_ami_tags: - ecs_ami_id = resolve_ami(**args.ecs_container_instance_ami_tags) - else: - ecs_ami_id = get_ssm_parameter("/aws/service/ecs/optimized-ami/amazon-linux-2/recommended/image_id") - launch_template = ensure_launch_template(ImageId=ecs_ami_id, - # TODO: add configurable BDM for Docker image cache space - UserData=base64.b64encode(user_data).decode()) batch_iam_role = ensure_iam_role(args.service_role, trust=["batch"], policies=["service-role/AWSBatchServiceRole"]) vpc = ensure_vpc() ssh_key_name = ensure_ssh_key(args.ssh_key_name, base_name=__name__) @@ -96,16 +85,29 @@ def create_compute_environment(args): "AmazonSSMManagedInstanceCore", IAMPolicyBuilder(action="sts:AssumeRole", resource="*")}) compute_resources = dict(type=args.compute_type, - minvCpus=args.min_vcpus, desiredvCpus=args.desired_vcpus, maxvCpus=args.max_vcpus, - instanceTypes=args.instance_types, + maxvCpus=args.max_vcpus, subnets=[subnet.id for subnet in vpc.subnets.all()], - securityGroupIds=[ensure_security_group("aegea.launch", vpc).id], - instanceRole=instance_profile.name, - bidPercentage=100, - spotIamFleetRole=SpotFleetBuilder.get_iam_fleet_role().name, - ec2KeyPair=ssh_key_name, - tags=dict(Name=__name__), - launchTemplate=dict(launchTemplateName=launch_template)) + securityGroupIds=[ensure_security_group("aegea.launch", vpc).id]) + if not args.compute_type.startswith("FARGATE"): + commands = instance_storage_shellcode.strip().format(mountpoint="/mnt", mkfs=get_mkfs_command()).split("\n") + user_data = get_user_data(commands=commands, mime_multipart_archive=True) + if args.ecs_container_instance_ami: + ecs_ami_id = args.ecs_container_instance_ami + elif args.ecs_container_instance_ami_tags: + ecs_ami_id = resolve_ami(**args.ecs_container_instance_ami_tags) + else: + ecs_ami_id = get_ssm_parameter("/aws/service/ecs/optimized-ami/amazon-linux-2/recommended/image_id") + launch_template = ensure_launch_template(ImageId=ecs_ami_id, + # TODO: add configurable BDM for Docker image cache space + UserData=base64.b64encode(user_data).decode()) + compute_resources.update(minvCpus=args.min_vcpus, + desiredvCpus=args.desired_vcpus, + ec2KeyPair=ssh_key_name, + instanceRole=instance_profile.name, + instanceTypes=args.instance_types, + launchTemplate=dict(launchTemplateName=launch_template), + spotIamFleetRole=SpotFleetBuilder.get_iam_fleet_role().name, + tags=dict(Name=__name__)) logger.info("Creating compute environment %s in %s", args.name, vpc) compute_environment = clients.batch.create_compute_environment(computeEnvironmentName=args.name, type=args.type, @@ -119,7 +121,7 @@ def create_compute_environment(args): cce_parser = register_parser(create_compute_environment, parent=batch_parser, help="Create a Batch compute environment") cce_parser.add_argument("name") cce_parser.add_argument("--type", choices={"MANAGED", "UNMANAGED"}) -cce_parser.add_argument("--compute-type", choices={"EC2", "SPOT"}) +cce_parser.add_argument("--compute-type", choices={"EC2", "SPOT", "FARGATE", "FARGATE_SPOT"}) cce_parser.add_argument("--min-vcpus", type=int) cce_parser.add_argument("--desired-vcpus", type=int) cce_parser.add_argument("--max-vcpus", type=int) @@ -176,9 +178,11 @@ def submit(args): if not any([args.command, args.execute, args.wdl]): raise AegeaException("One of the arguments --command --execute --wdl is required") elif args.name is None: - raise AegeaException("The argument --name is required") + args.name = os.path.basename(args.job_definition_arn).replace(":", "_") + ensure_log_group("docker") ensure_log_group("syslog") + args.execution_role_arn = ensure_fargate_execution_role(__name__ + ".fargate_execution").arn if args.job_definition_arn is None: command, environment = get_command_and_env(args) container_overrides = dict(command=command, environment=environment) @@ -190,38 +194,53 @@ def submit(args): )) else: args.default_job_role_iam_policies = [] - - jd_res = ensure_job_definition(args) - args.job_definition_arn = jd_res["jobDefinitionArn"] - args.name = args.name or "{}_{}".format(jd_res["jobDefinitionName"], jd_res["revision"]) + job_definition_arn, job_name = ensure_job_definition(args) else: container_overrides = {} if args.command: container_overrides["command"] = args.command if args.environment: container_overrides["environment"] = args.environment + job_definition_arn, job_name = args.job_definition_arn, args.name + if args.memory is None: logger.warn("Specify a memory quota for your job with --memory-mb NNNN.") logger.warn("The memory quota is required and a hard limit. Setting it to %d MB.", int(args.default_memory_mb)) args.memory = int(args.default_memory_mb) container_overrides["memory"] = args.memory - submit_args = dict(jobName=args.name, - jobQueue=args.queue, - dependsOn=[dict(jobId=dep) for dep in args.depends_on], - jobDefinition=args.job_definition_arn, - parameters={k: v for k, v in args.parameters}, - containerOverrides=container_overrides) - if args.dry_run: - logger.info("The following command would be run:") - sys.stderr.write(json.dumps(submit_args, indent=4) + "\n") - return {"Dry run succeeded": True} - try: - job = clients.batch.submit_job(**submit_args) - except ClientError as e: - if not re.search("JobQueue .+ not found", str(e)): - raise - ensure_queue(args.queue) - job = clients.batch.submit_job(**submit_args) + while True: + submit_args = dict(jobName=job_name, + jobQueue=args.queue, + dependsOn=[dict(jobId=dep) for dep in args.depends_on], + jobDefinition=job_definition_arn, + parameters={k: v for k, v in args.parameters}, + containerOverrides=container_overrides) + try: + if args.dry_run: + logger.info("The following command would be run:") + sys.stderr.write(json.dumps(submit_args, indent=4) + "\n") + return {"Dry run succeeded": True} + else: + job = clients.batch.submit_job(**submit_args) + break + except (clients.batch.exceptions.ClientError, clients.batch.exceptions.ClientException) as e: + if re.search("JobQueue .+ not found", str(e)): + ensure_queue(args.queue) + elif "Job Queue is attached to Compute Environment that can not run Jobs with capability EC2" in str(e): + if args.job_definition_arn is not None: + raise AegeaException("To submit a job to a Fargate queue, specify a Fargate job definition") + logger.debug("This job must be dispatched to Fargate, switching to Fargate job definition") + args.platform_capabilities = ["FARGATE"] + job_definition_arn, job_name = ensure_job_definition(args) + container_overrides["resourceRequirements"] = [dict(type="VCPU", value=str(args.vcpus)), + dict(type="MEMORY", value=str(args.memory))] + del container_overrides["memory"] + submit_args.update(jobName=job_name, + jobDefinition=job_definition_arn, + containerOverrides=container_overrides) + else: + raise + if args.watch: try: watch(watch_parser.parse_args([job["jobId"]])) diff --git a/aegea/ecs.py b/aegea/ecs.py index ce7fd0c..19e047a 100644 --- a/aegea/ecs.py +++ b/aegea/ecs.py @@ -23,7 +23,7 @@ ensure_ecs_cluster, expect_error_codes, encode_tags) from .util.aws.logs import CloudwatchLogReader from .util.aws.batch import get_command_and_env, set_ulimits, get_volumes_and_mountpoints, get_ecr_image_uri -from .util.aws.iam import ensure_iam_role +from .util.aws.iam import ensure_iam_role, ensure_fargate_execution_role def complete_cluster_name(**kwargs): return [ARN(c).resource.partition("/")[2] for c in paginate(clients.ecs.get_paginator("list_clusters"))] @@ -114,9 +114,7 @@ def run(args): mountPoints=[dict(sourceVolume="scratch", containerPath="/mnt")] + mount_points, volumesFrom=[]) set_ulimits(args, container_defn) - exec_role = ensure_iam_role(args.execution_role, trust=["ecs-tasks"], - policies=["service-role/AmazonEC2ContainerServiceforEC2Role", - "service-role/AWSBatchServiceRole"]) + exec_role = ensure_fargate_execution_role(args.execution_role) task_role = ensure_iam_role(args.task_role, trust=["ecs-tasks"]) expect_task_defn = dict(containerDefinitions=[container_defn], @@ -149,12 +147,12 @@ def run(args): task_desc = clients.ecs.register_task_definition(family=task_defn_name, **expect_task_defn)["taskDefinition"] network_config = { - 'awsvpcConfiguration': { - 'subnets': [ + "awsvpcConfiguration": { + "subnets": [ subnet.id for subnet in vpc.subnets.all() ], - 'securityGroups': [ensure_security_group(args.security_group, vpc).id], - 'assignPublicIp': 'ENABLED' + "securityGroups": [ensure_security_group(args.security_group, vpc).id], + "assignPublicIp": "ENABLED" } } container_overrides = [dict(name=args.task_name, command=command, environment=environment)] diff --git a/aegea/util/aws/batch.py b/aegea/util/aws/batch.py index 9146641..c2aa3e1 100644 --- a/aegea/util/aws/batch.py +++ b/aegea/util/aws/batch.py @@ -172,20 +172,37 @@ def get_volumes_and_mountpoints(args): return volumes, mount_points def ensure_job_definition(args): + def get_jd_arn_and_job_name(jd_res): + job_name = args.name or "{}_{}".format(jd_res["jobDefinitionName"], jd_res["revision"]) + return jd_res["jobDefinitionArn"], job_name + if args.ecs_image: args.image = get_ecr_image_uri(args.ecs_image) - container_props = {k: getattr(args, k) for k in ("image", "vcpus", "user", "privileged")} - container_props.update(memory=4, volumes=[], mountPoints=[], environment=[], command=[], resourceRequirements=[]) + container_props = dict(image=args.image, user=args.user, privileged=args.privileged) + container_props.update(volumes=[], mountPoints=[], environment=[], command=[], resourceRequirements=[], ulimits=[], + secrets=[]) + if args.platform_capabilities == ["FARGATE"]: + container_props["resourceRequirements"].append(dict(type="VCPU", value="0.25")) + container_props["resourceRequirements"].append(dict(type="MEMORY", value="512")) + container_props["executionRoleArn"] = args.execution_role_arn + else: + container_props["vcpus"] = args.vcpus + container_props["memory"] = 4 + set_ulimits(args, container_props) container_props["volumes"], container_props["mountPoints"] = get_volumes_and_mountpoints(args) - set_ulimits(args, container_props) if args.gpus: - container_props["resourceRequirements"] = [{"type": "GPU", "value": str(args.gpus)}] + container_props["resourceRequirements"].append(dict(type="GPU", value=str(args.gpus))) iam_role = ensure_iam_role(args.job_role, trust=["ecs-tasks"], policies=args.default_job_role_iam_policies) container_props.update(jobRoleArn=iam_role.arn) - expect_job_defn = dict(status="ACTIVE", type="container", parameters={}, - retryStrategy={'attempts': args.retry_attempts}, containerProperties=container_props) + expect_job_defn = dict(status="ACTIVE", type="container", parameters={}, tags={}, + retryStrategy=dict(attempts=args.retry_attempts, evaluateOnExit=[]), + containerProperties=container_props, platformCapabilities=args.platform_capabilities) job_hash = hashlib.sha256(json.dumps(container_props, sort_keys=True).encode()).hexdigest()[:8] job_defn_name = __name__.replace(".", "_") + "_jd_" + job_hash + if args.platform_capabilities == ["FARGATE"]: + job_defn_name += "_FARGATE" + container_props["fargatePlatformConfiguration"] = dict(platformVersion="LATEST") + container_props["networkConfiguration"] = dict(assignPublicIp="ENABLED") describe_job_definitions_paginator = Paginator(method=clients.batch.describe_job_definitions, pagination_config=dict(result_key="jobDefinitions", input_token="nextToken", @@ -195,12 +212,15 @@ def ensure_job_definition(args): for job_defn in paginate(describe_job_definitions_paginator, jobDefinitionName=job_defn_name): job_defn_desc = {k: job_defn.pop(k) for k in ("jobDefinitionName", "jobDefinitionArn", "revision")} if job_defn == expect_job_defn: - return job_defn_desc - return clients.batch.register_job_definition(jobDefinitionName=job_defn_name, - type="container", - containerProperties=container_props, - retryStrategy=dict(attempts=args.retry_attempts)) - + logger.info("Found existing Batch job definition %s", job_defn_desc["jobDefinitionArn"]) + return get_jd_arn_and_job_name(job_defn_desc) + logger.info("Creating new Batch job definition %s", job_defn_name) + jd_res = clients.batch.register_job_definition(jobDefinitionName=job_defn_name, + type="container", + containerProperties=container_props, + retryStrategy=dict(attempts=args.retry_attempts), + platformCapabilities=args.platform_capabilities) + return get_jd_arn_and_job_name(jd_res) def ensure_lambda_helper(): awslambda = getattr(clients, "lambda") diff --git a/aegea/util/aws/iam.py b/aegea/util/aws/iam.py index 5dab2cb..d85dbdd 100644 --- a/aegea/util/aws/iam.py +++ b/aegea/util/aws/iam.py @@ -149,3 +149,8 @@ def compose_managed_policies(policy_names): policy.policy["Statement"].append(statement) policy.policy["Statement"][-1]["Sid"] = policy_name + str(i) return policy + +def ensure_fargate_execution_role(name): + return ensure_iam_role(name, trust=["ecs-tasks"], + policies=["service-role/AmazonEC2ContainerServiceforEC2Role", + "service-role/AWSBatchServiceRole"])