Skip to content

Commit

Permalink
batch: begin fargate support
Browse files Browse the repository at this point in the history
  • Loading branch information
kislyuk committed Dec 24, 2020
1 parent 6ffa804 commit 2f4c791
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 69 deletions.
13 changes: 8 additions & 5 deletions aegea/base_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -572,21 +572,24 @@ batch_queues:
- status
- statusReason

# Valid instance types values: "optimal", family name, type name
batch_create_compute_environment:
type: MANAGED
compute_type: SPOT
instance_types: [m5d, c5d, r5d]
compute_type: FARGATE_SPOT
max_vcpus: 64
# The following parameters are not applicable (and ignored) for Fargate CEs
# Valid instance type values: "optimal", family name, type name
min_vcpus: 0
desired_vcpus: 2
max_vcpus: 64
instance_types: [m5d, c5d, r5d]
# ecs_container_instance_ami_tags:
# AegeaMission: "ecs-container-instance"

batch_submit:
user: "0"
default_memory_mb: 4000
default_memory_mb: 4096
job_role: aegea.batch.worker
platform_capabilities:
- EC2
default_job_role_iam_policies:
- AmazonEC2ReadOnlyAccess
- AmazonS3ReadOnlyAccess
Expand Down
107 changes: 63 additions & 44 deletions aegea/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .util.aws.spot import SpotFleetBuilder
from .util.aws.logs import CloudwatchLogReader
from .util.aws.batch import ensure_job_definition, get_command_and_env, ensure_lambda_helper
from .util.aws.iam import IAMPolicyBuilder, ensure_iam_role, ensure_instance_profile
from .util.aws.iam import IAMPolicyBuilder, ensure_iam_role, ensure_instance_profile, ensure_fargate_execution_role

def complete_queue_name(**kwargs):
    """Shell-completion helper: return the names of all AWS Batch job queues."""
    queue_paginator = clients.batch.get_paginator("describe_job_queues")
    queue_names = []
    for queue in paginate(queue_paginator):
        queue_names.append(queue["jobQueueName"])
    return queue_names
Expand Down Expand Up @@ -76,17 +76,6 @@ def ensure_launch_template(prefix=__name__.replace(".", "_"), **kwargs):
return name

def create_compute_environment(args):
commands = instance_storage_shellcode.strip().format(mountpoint="/mnt", mkfs=get_mkfs_command()).split("\n")
user_data = get_user_data(commands=commands, mime_multipart_archive=True)
if args.ecs_container_instance_ami:
ecs_ami_id = args.ecs_container_instance_ami
elif args.ecs_container_instance_ami_tags:
ecs_ami_id = resolve_ami(**args.ecs_container_instance_ami_tags)
else:
ecs_ami_id = get_ssm_parameter("/aws/service/ecs/optimized-ami/amazon-linux-2/recommended/image_id")
launch_template = ensure_launch_template(ImageId=ecs_ami_id,
# TODO: add configurable BDM for Docker image cache space
UserData=base64.b64encode(user_data).decode())
batch_iam_role = ensure_iam_role(args.service_role, trust=["batch"], policies=["service-role/AWSBatchServiceRole"])
vpc = ensure_vpc()
ssh_key_name = ensure_ssh_key(args.ssh_key_name, base_name=__name__)
Expand All @@ -96,16 +85,29 @@ def create_compute_environment(args):
"AmazonSSMManagedInstanceCore",
IAMPolicyBuilder(action="sts:AssumeRole", resource="*")})
compute_resources = dict(type=args.compute_type,
minvCpus=args.min_vcpus, desiredvCpus=args.desired_vcpus, maxvCpus=args.max_vcpus,
instanceTypes=args.instance_types,
maxvCpus=args.max_vcpus,
subnets=[subnet.id for subnet in vpc.subnets.all()],
securityGroupIds=[ensure_security_group("aegea.launch", vpc).id],
instanceRole=instance_profile.name,
bidPercentage=100,
spotIamFleetRole=SpotFleetBuilder.get_iam_fleet_role().name,
ec2KeyPair=ssh_key_name,
tags=dict(Name=__name__),
launchTemplate=dict(launchTemplateName=launch_template))
securityGroupIds=[ensure_security_group("aegea.launch", vpc).id])
if not args.compute_type.startswith("FARGATE"):
commands = instance_storage_shellcode.strip().format(mountpoint="/mnt", mkfs=get_mkfs_command()).split("\n")
user_data = get_user_data(commands=commands, mime_multipart_archive=True)
if args.ecs_container_instance_ami:
ecs_ami_id = args.ecs_container_instance_ami
elif args.ecs_container_instance_ami_tags:
ecs_ami_id = resolve_ami(**args.ecs_container_instance_ami_tags)
else:
ecs_ami_id = get_ssm_parameter("/aws/service/ecs/optimized-ami/amazon-linux-2/recommended/image_id")
launch_template = ensure_launch_template(ImageId=ecs_ami_id,
# TODO: add configurable BDM for Docker image cache space
UserData=base64.b64encode(user_data).decode())
compute_resources.update(minvCpus=args.min_vcpus,
desiredvCpus=args.desired_vcpus,
ec2KeyPair=ssh_key_name,
instanceRole=instance_profile.name,
instanceTypes=args.instance_types,
launchTemplate=dict(launchTemplateName=launch_template),
spotIamFleetRole=SpotFleetBuilder.get_iam_fleet_role().name,
tags=dict(Name=__name__))
logger.info("Creating compute environment %s in %s", args.name, vpc)
compute_environment = clients.batch.create_compute_environment(computeEnvironmentName=args.name,
type=args.type,
Expand All @@ -119,7 +121,7 @@ def create_compute_environment(args):
cce_parser = register_parser(create_compute_environment, parent=batch_parser, help="Create a Batch compute environment")
cce_parser.add_argument("name")
cce_parser.add_argument("--type", choices={"MANAGED", "UNMANAGED"})
cce_parser.add_argument("--compute-type", choices={"EC2", "SPOT"})
cce_parser.add_argument("--compute-type", choices={"EC2", "SPOT", "FARGATE", "FARGATE_SPOT"})
cce_parser.add_argument("--min-vcpus", type=int)
cce_parser.add_argument("--desired-vcpus", type=int)
cce_parser.add_argument("--max-vcpus", type=int)
Expand Down Expand Up @@ -176,9 +178,11 @@ def submit(args):
if not any([args.command, args.execute, args.wdl]):
raise AegeaException("One of the arguments --command --execute --wdl is required")
elif args.name is None:
raise AegeaException("The argument --name is required")
args.name = os.path.basename(args.job_definition_arn).replace(":", "_")

ensure_log_group("docker")
ensure_log_group("syslog")
args.execution_role_arn = ensure_fargate_execution_role(__name__ + ".fargate_execution").arn
if args.job_definition_arn is None:
command, environment = get_command_and_env(args)
container_overrides = dict(command=command, environment=environment)
Expand All @@ -190,38 +194,53 @@ def submit(args):
))
else:
args.default_job_role_iam_policies = []

jd_res = ensure_job_definition(args)
args.job_definition_arn = jd_res["jobDefinitionArn"]
args.name = args.name or "{}_{}".format(jd_res["jobDefinitionName"], jd_res["revision"])
job_definition_arn, job_name = ensure_job_definition(args)
else:
container_overrides = {}
if args.command:
container_overrides["command"] = args.command
if args.environment:
container_overrides["environment"] = args.environment
job_definition_arn, job_name = args.job_definition_arn, args.name

if args.memory is None:
logger.warn("Specify a memory quota for your job with --memory-mb NNNN.")
logger.warn("The memory quota is required and a hard limit. Setting it to %d MB.", int(args.default_memory_mb))
args.memory = int(args.default_memory_mb)
container_overrides["memory"] = args.memory
submit_args = dict(jobName=args.name,
jobQueue=args.queue,
dependsOn=[dict(jobId=dep) for dep in args.depends_on],
jobDefinition=args.job_definition_arn,
parameters={k: v for k, v in args.parameters},
containerOverrides=container_overrides)
if args.dry_run:
logger.info("The following command would be run:")
sys.stderr.write(json.dumps(submit_args, indent=4) + "\n")
return {"Dry run succeeded": True}
try:
job = clients.batch.submit_job(**submit_args)
except ClientError as e:
if not re.search("JobQueue .+ not found", str(e)):
raise
ensure_queue(args.queue)
job = clients.batch.submit_job(**submit_args)
while True:
submit_args = dict(jobName=job_name,
jobQueue=args.queue,
dependsOn=[dict(jobId=dep) for dep in args.depends_on],
jobDefinition=job_definition_arn,
parameters={k: v for k, v in args.parameters},
containerOverrides=container_overrides)
try:
if args.dry_run:
logger.info("The following command would be run:")
sys.stderr.write(json.dumps(submit_args, indent=4) + "\n")
return {"Dry run succeeded": True}
else:
job = clients.batch.submit_job(**submit_args)
break
except (clients.batch.exceptions.ClientError, clients.batch.exceptions.ClientException) as e:
if re.search("JobQueue .+ not found", str(e)):
ensure_queue(args.queue)
elif "Job Queue is attached to Compute Environment that can not run Jobs with capability EC2" in str(e):
if args.job_definition_arn is not None:
raise AegeaException("To submit a job to a Fargate queue, specify a Fargate job definition")
logger.debug("This job must be dispatched to Fargate, switching to Fargate job definition")
args.platform_capabilities = ["FARGATE"]
job_definition_arn, job_name = ensure_job_definition(args)
container_overrides["resourceRequirements"] = [dict(type="VCPU", value=str(args.vcpus)),
dict(type="MEMORY", value=str(args.memory))]
del container_overrides["memory"]
submit_args.update(jobName=job_name,
jobDefinition=job_definition_arn,
containerOverrides=container_overrides)
else:
raise

if args.watch:
try:
watch(watch_parser.parse_args([job["jobId"]]))
Expand Down
14 changes: 6 additions & 8 deletions aegea/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
ensure_ecs_cluster, expect_error_codes, encode_tags)
from .util.aws.logs import CloudwatchLogReader
from .util.aws.batch import get_command_and_env, set_ulimits, get_volumes_and_mountpoints, get_ecr_image_uri
from .util.aws.iam import ensure_iam_role
from .util.aws.iam import ensure_iam_role, ensure_fargate_execution_role

def complete_cluster_name(**kwargs):
    """Shell-completion helper: return the short names of all ECS clusters.

    Cluster ARNs have a resource of the form "cluster/NAME"; the name is the
    part after the slash.
    """
    cluster_paginator = clients.ecs.get_paginator("list_clusters")
    names = []
    for cluster_arn in paginate(cluster_paginator):
        _, _, short_name = ARN(cluster_arn).resource.partition("/")
        names.append(short_name)
    return names
Expand Down Expand Up @@ -114,9 +114,7 @@ def run(args):
mountPoints=[dict(sourceVolume="scratch", containerPath="/mnt")] + mount_points,
volumesFrom=[])
set_ulimits(args, container_defn)
exec_role = ensure_iam_role(args.execution_role, trust=["ecs-tasks"],
policies=["service-role/AmazonEC2ContainerServiceforEC2Role",
"service-role/AWSBatchServiceRole"])
exec_role = ensure_fargate_execution_role(args.execution_role)
task_role = ensure_iam_role(args.task_role, trust=["ecs-tasks"])

expect_task_defn = dict(containerDefinitions=[container_defn],
Expand Down Expand Up @@ -149,12 +147,12 @@ def run(args):
task_desc = clients.ecs.register_task_definition(family=task_defn_name, **expect_task_defn)["taskDefinition"]

network_config = {
'awsvpcConfiguration': {
'subnets': [
"awsvpcConfiguration": {
"subnets": [
subnet.id for subnet in vpc.subnets.all()
],
'securityGroups': [ensure_security_group(args.security_group, vpc).id],
'assignPublicIp': 'ENABLED'
"securityGroups": [ensure_security_group(args.security_group, vpc).id],
"assignPublicIp": "ENABLED"
}
}
container_overrides = [dict(name=args.task_name, command=command, environment=environment)]
Expand Down
44 changes: 32 additions & 12 deletions aegea/util/aws/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,20 +172,37 @@ def get_volumes_and_mountpoints(args):
return volumes, mount_points

def ensure_job_definition(args):
def get_jd_arn_and_job_name(jd_res):
job_name = args.name or "{}_{}".format(jd_res["jobDefinitionName"], jd_res["revision"])
return jd_res["jobDefinitionArn"], job_name

if args.ecs_image:
args.image = get_ecr_image_uri(args.ecs_image)
container_props = {k: getattr(args, k) for k in ("image", "vcpus", "user", "privileged")}
container_props.update(memory=4, volumes=[], mountPoints=[], environment=[], command=[], resourceRequirements=[])
container_props = dict(image=args.image, user=args.user, privileged=args.privileged)
container_props.update(volumes=[], mountPoints=[], environment=[], command=[], resourceRequirements=[], ulimits=[],
secrets=[])
if args.platform_capabilities == ["FARGATE"]:
container_props["resourceRequirements"].append(dict(type="VCPU", value="0.25"))
container_props["resourceRequirements"].append(dict(type="MEMORY", value="512"))
container_props["executionRoleArn"] = args.execution_role_arn
else:
container_props["vcpus"] = args.vcpus
container_props["memory"] = 4
set_ulimits(args, container_props)
container_props["volumes"], container_props["mountPoints"] = get_volumes_and_mountpoints(args)
set_ulimits(args, container_props)
if args.gpus:
container_props["resourceRequirements"] = [{"type": "GPU", "value": str(args.gpus)}]
container_props["resourceRequirements"].append(dict(type="GPU", value=str(args.gpus)))
iam_role = ensure_iam_role(args.job_role, trust=["ecs-tasks"], policies=args.default_job_role_iam_policies)
container_props.update(jobRoleArn=iam_role.arn)
expect_job_defn = dict(status="ACTIVE", type="container", parameters={},
retryStrategy={'attempts': args.retry_attempts}, containerProperties=container_props)
expect_job_defn = dict(status="ACTIVE", type="container", parameters={}, tags={},
retryStrategy=dict(attempts=args.retry_attempts, evaluateOnExit=[]),
containerProperties=container_props, platformCapabilities=args.platform_capabilities)
job_hash = hashlib.sha256(json.dumps(container_props, sort_keys=True).encode()).hexdigest()[:8]
job_defn_name = __name__.replace(".", "_") + "_jd_" + job_hash
if args.platform_capabilities == ["FARGATE"]:
job_defn_name += "_FARGATE"
container_props["fargatePlatformConfiguration"] = dict(platformVersion="LATEST")
container_props["networkConfiguration"] = dict(assignPublicIp="ENABLED")
describe_job_definitions_paginator = Paginator(method=clients.batch.describe_job_definitions,
pagination_config=dict(result_key="jobDefinitions",
input_token="nextToken",
Expand All @@ -195,12 +212,15 @@ def ensure_job_definition(args):
for job_defn in paginate(describe_job_definitions_paginator, jobDefinitionName=job_defn_name):
job_defn_desc = {k: job_defn.pop(k) for k in ("jobDefinitionName", "jobDefinitionArn", "revision")}
if job_defn == expect_job_defn:
return job_defn_desc
return clients.batch.register_job_definition(jobDefinitionName=job_defn_name,
type="container",
containerProperties=container_props,
retryStrategy=dict(attempts=args.retry_attempts))

logger.info("Found existing Batch job definition %s", job_defn_desc["jobDefinitionArn"])
return get_jd_arn_and_job_name(job_defn_desc)
logger.info("Creating new Batch job definition %s", job_defn_name)
jd_res = clients.batch.register_job_definition(jobDefinitionName=job_defn_name,
type="container",
containerProperties=container_props,
retryStrategy=dict(attempts=args.retry_attempts),
platformCapabilities=args.platform_capabilities)
return get_jd_arn_and_job_name(jd_res)

def ensure_lambda_helper():
awslambda = getattr(clients, "lambda")
Expand Down
5 changes: 5 additions & 0 deletions aegea/util/aws/iam.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,8 @@ def compose_managed_policies(policy_names):
policy.policy["Statement"].append(statement)
policy.policy["Statement"][-1]["Sid"] = policy_name + str(i)
return policy

def ensure_fargate_execution_role(name):
    """Return (creating it if necessary) an IAM role trusted by ECS tasks,
    for use as the Fargate task execution role.

    NOTE(review): the attached managed policies are the EC2 container-instance
    and Batch service-role policies rather than AWS's canonical
    AmazonECSTaskExecutionRolePolicy — confirm this broader grant is intended.
    """
    managed_policies = ["service-role/AmazonEC2ContainerServiceforEC2Role",
                        "service-role/AWSBatchServiceRole"]
    return ensure_iam_role(name, trust=["ecs-tasks"], policies=managed_policies)

0 comments on commit 2f4c791

Please sign in to comment.