Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from flytekit.models import common as _common_models
from flytekit.models import interface as interface_models
from flytekit.models import launch_plan as _launch_plan_models
from flytekit.models import security as _security_models
from flytekit.models.admin import workflow as admin_workflow_models
from flytekit.models.admin.workflow import WorkflowSpec
from flytekit.models.concurrency import ConcurrencyPolicy
Expand Down Expand Up @@ -317,6 +318,26 @@ def get_serializable_workflow(
)


def _merge_security_context(
entity_sc: Optional[_security_models.SecurityContext],
options_sc: Optional[_security_models.SecurityContext],
) -> Optional[_security_models.SecurityContext]:
"""Merge the launch plan's authored security context with the one supplied via registration options.

Registration options override the authored launch plan per field (not wholesale), so e.g. registering with
only a service account does not drop secrets/tokens that were authored on the launch plan.
"""
if options_sc is None:
return entity_sc
if entity_sc is None:
return options_sc
return _security_models.SecurityContext(
run_as=options_sc.run_as or entity_sc.run_as,
secrets=options_sc.secrets or entity_sc.secrets,
tokens=options_sc.tokens or entity_sc.tokens,
)


def get_serializable_launch_plan(
entity_mapping: OrderedDict,
settings: SerializationSettings,
Expand Down Expand Up @@ -379,7 +400,7 @@ def get_serializable_launch_plan(
auth_role=None,
raw_output_data_config=raw_prefix_config,
max_parallelism=options.max_parallelism or entity.max_parallelism,
security_context=options.security_context or entity.security_context,
security_context=_merge_security_context(entity.security_context, options.security_context),
overwrite_cache=options.overwrite_cache or entity.overwrite_cache,
concurrency_policy=concurrency_policy,
)
Expand Down
38 changes: 32 additions & 6 deletions tests/flytekit/unit/core/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import re
import dataclasses
import os
import typing
import dataclasses
from collections import OrderedDict

import mock
Expand All @@ -10,17 +9,17 @@
import flytekit.configuration
from flytekit import ContainerTask, ImageSpec, kwtypes
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core.array_node_map_task import map_task
from flytekit.core.condition import conditional
from flytekit.core.dynamic_workflow_task import dynamic
from flytekit.core.python_auto_container import get_registerable_container_image
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.core.dynamic_workflow_task import dynamic
from flytekit.core.array_node_map_task import map_task
from flytekit.core.task import eager, task
from flytekit.core.workflow import workflow
from flytekit.exceptions.user import FlyteAssertion, FlyteMissingTypeException
from flytekit.image_spec.image_spec import ImageBuildEngine
from flytekit.models import task as task_models
from flytekit.models.admin.workflow import WorkflowSpec
from flytekit.models.annotation import TypeAnnotation
from flytekit.models.literals import (
BindingData,
BindingDataCollection,
Expand All @@ -32,7 +31,6 @@
Void,
)
from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType
from flytekit.models import task as task_models
from flytekit.tools.translator import get_serializable
from flytekit.types.error.error import FlyteError

Expand Down Expand Up @@ -1272,3 +1270,31 @@ async def t1_eager(a: int) -> int:
],
limits=[]
)


def test_launch_plan_security_context_merge():
# Registration options override the authored launch plan per field; registering with only a
# service account must not drop secrets/tokens that were authored on the launch plan.
from flytekit import LaunchPlan
from flytekit.core.options import Options
from flytekit.models.security import Identity, Secret, SecurityContext

@task
def t_sc() -> int:
return 1

@workflow
def wf_sc() -> int:
return t_sc()

lp = LaunchPlan.get_or_create(
workflow=wf_sc,
name="lp_sc_merge",
security_context=SecurityContext(secrets=[Secret(group="g", key="k")]),
)
opts = Options(security_context=SecurityContext(run_as=Identity(k8s_service_account="my-sa")))

serialized = get_serializable(OrderedDict(), serialization_settings, lp, options=opts)
sc = serialized.spec.security_context
assert sc.run_as.k8s_service_account == "my-sa"
assert [(s.group, s.key) for s in sc.secrets] == [("g", "k")]
Loading