Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions metaflow/plugins/aws/batch/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def create_job(
ephemeral_storage=None,
log_driver=None,
log_options=None,
container_secrets=None,
offload_command_to_s3=False,
):
job_name = self._job_name(
Expand Down Expand Up @@ -303,6 +304,7 @@ def create_job(
ephemeral_storage=ephemeral_storage,
log_driver=log_driver,
log_options=log_options,
container_secrets=container_secrets,
)
.task_id(attrs.get("metaflow.task_id"))
.environment_variable("AWS_DEFAULT_REGION", self._client.region())
Expand Down Expand Up @@ -427,6 +429,7 @@ def launch_job(
ephemeral_storage=None,
log_driver=None,
log_options=None,
container_secrets=None,
):
if queue is None:
queue = next(self._client.active_job_queues(), None)
Expand Down Expand Up @@ -469,6 +472,7 @@ def launch_job(
ephemeral_storage=ephemeral_storage,
log_driver=log_driver,
log_options=log_options,
container_secrets=container_secrets,
)
self.num_parallel = num_parallel
self.job = job.execute()
Expand Down
20 changes: 20 additions & 0 deletions metaflow/plugins/aws/batch/batch_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,25 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs):
if env_deco:
env.update(env_deco[0].attributes["vars"])

# Collect ECS-style container secrets from @secrets decorator, if any.
# These entries have shape: {"name": <ENV>, "value_from": <SecretsManager ARN>} (snake_case input),
# and will be injected at container startup by AWS Batch/ECS via job definition.
container_secrets = []
secrets_deco = [deco for deco in node.decorators if deco.name == "secrets"]
if secrets_deco:
try:
for s in secrets_deco[0].attributes.get("sources", []) or []:
if isinstance(s, dict):
name = s.get("name")
value_from = s.get("value_from")
if isinstance(name, str) and isinstance(value_from, str):
container_secrets.append(
{"name": name, "value_from": value_from}
)
except Exception:
# best-effort only; ignore malformed entries silently to avoid breaking launches
pass

# Add the environment variables related to the input-paths argument
if split_vars:
env.update(split_vars)
Expand Down Expand Up @@ -366,6 +385,7 @@ def _sync_metadata():
log_driver=log_driver,
log_options=log_options,
num_parallel=num_parallel,
container_secrets=container_secrets,
)
except Exception:
traceback.print_exc()
Expand Down
16 changes: 16 additions & 0 deletions metaflow/plugins/aws/batch/batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def _register_job_definition(
ephemeral_storage,
log_driver,
log_options,
container_secrets=None,
):
# identify platform from any compute environment associated with the
# queue
Expand Down Expand Up @@ -199,6 +200,19 @@ def _register_job_definition(
"propagateTags": True,
}

# Inject ECS secrets for container-start environment variables, if provided.
if container_secrets:
norm = []
for item in container_secrets:
if not isinstance(item, dict):
continue
name = item.get("name")
value_from = item.get("value_from")
if isinstance(name, str) and isinstance(value_from, str):
norm.append({"name": name, "valueFrom": value_from})
if norm:
job_definition["containerProperties"]["secrets"] = norm

log_options_dict = {}
if log_options:
if isinstance(log_options, str):
Expand Down Expand Up @@ -480,6 +494,7 @@ def job_def(
ephemeral_storage,
log_driver,
log_options,
container_secrets=None,
):
self.payload["jobDefinition"] = self._register_job_definition(
image,
Expand All @@ -502,6 +517,7 @@ def job_def(
ephemeral_storage,
log_driver,
log_options,
container_secrets,
)
return self

Expand Down