Skip to content

update task_resume_workflows to also resume processes in CREATED/RESUMED status #984

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/reference-docs/app/celery-flow.drawio.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 8 additions & 0 deletions docs/reference-docs/app/scaling.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,11 @@ celery.conf.task_routes = {

If you decide to override the queue names in this configuration, you also have to make sure that you
update the names accordingly after the `-Q` flag.

### Celery Workflow/Task flow

This diagram shows the current flow for executing a workflow or task with celery.
It illustrates why a workflow/task can get stuck in `CREATED` or `RESUMED` status, and what has been done to fix that.
All step statuses are shown in UPPERCASE for clarity.

![Celery Workflow/Task flow](celery-flow.drawio.png)
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@
from celery.result import AsyncResult
from kombu.exceptions import ConnectionError, OperationalError

from oauth2_lib.fastapi import OIDCUserModel
from orchestrator import app_settings
from orchestrator.api.error_handling import raise_status
from orchestrator.db import ProcessTable, db
from orchestrator.services.input_state import store_input_state
from orchestrator.services.processes import create_process, delete_process
from orchestrator.services.workflows import get_workflow_by_name
from orchestrator.workflows import get_workflow
from orchestrator.workflow import ProcessStat
from pydantic_forms.types import State

SYSTEM_USER = "SYSTEM"
Expand All @@ -42,29 +40,17 @@ def _block_when_testing(task_result: AsyncResult) -> None:
raise RuntimeError("Celery worker has failed to resume process")


def _celery_start_process(
workflow_key: str,
user_inputs: list[State] | None,
user: str = SYSTEM_USER,
user_model: OIDCUserModel | None = None,
**kwargs: Any,
) -> UUID:
def _celery_start_process(pstat: ProcessStat, user: str = SYSTEM_USER, **kwargs: Any) -> UUID:
"""Client side call of Celery."""
from orchestrator.services.tasks import NEW_TASK, NEW_WORKFLOW, get_celery_task

workflow = get_workflow(workflow_key)
if not workflow:
raise_status(HTTPStatus.NOT_FOUND, "Workflow does not exist")

wf_table = get_workflow_by_name(workflow.name)
if not wf_table:
if not (wf_table := get_workflow_by_name(pstat.workflow.name)):
raise_status(HTTPStatus.NOT_FOUND, "Workflow in Database does not exist")

task_name = NEW_TASK if wf_table.is_task else NEW_WORKFLOW
trigger_task = get_celery_task(task_name)
pstat = create_process(workflow_key, user_inputs=user_inputs, user=user, user_model=user_model)
try:
result = trigger_task.delay(pstat.process_id, workflow_key, user)
result = trigger_task.delay(pstat.process_id, pstat.workflow.name, user)
_block_when_testing(result)
return pstat.process_id
except (ConnectionError, OperationalError) as e:
Expand All @@ -77,10 +63,9 @@ def _celery_start_process(
def _celery_resume_process(
process: ProcessTable,
*,
user_inputs: list[State] | None = None,
user: str | None = None,
**kwargs: Any,
) -> UUID:
) -> bool:
"""Client side call of Celery."""
from orchestrator.services.processes import load_process
from orchestrator.services.tasks import RESUME_TASK, RESUME_WORKFLOW, get_celery_task
Expand All @@ -96,14 +81,12 @@ def _celery_resume_process(
task_name = RESUME_TASK if wf_table.is_task else RESUME_WORKFLOW
trigger_task = get_celery_task(task_name)

user_inputs = user_inputs or [{}]
store_input_state(pstat.process_id, user_inputs, "user_input")
try:
_celery_set_process_status_resumed(process)
result = trigger_task.delay(pstat.process_id, user)
_block_when_testing(result)

return pstat.process_id
return True
except (ConnectionError, OperationalError) as e:
logger.warning(
"Connection error when submitting task to celery. Resetting process status back",
Expand Down Expand Up @@ -135,7 +118,8 @@ def _celery_set_process_status_resumed(process: ProcessTable) -> None:


def _celery_validate(validation_workflow: str, json: list[State] | None) -> None:
_celery_start_process(validation_workflow, user_inputs=json)
pstat = create_process(validation_workflow, user_inputs=json)
_celery_start_process(pstat)


CELERY_EXECUTION_CONTEXT: dict[str, Callable] = {
Expand Down
102 changes: 102 additions & 0 deletions orchestrator/services/executors/threadpool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2019-2025 SURF, GÉANT, ESnet.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable
from functools import partial
from uuid import UUID

import structlog

from oauth2_lib.fastapi import OIDCUserModel
from orchestrator.db import ProcessTable, db
from orchestrator.services.input_state import retrieve_input_state
from orchestrator.services.processes import (
SYSTEM_USER,
StateMerger,
_get_process,
_run_process_async,
create_process,
load_process,
safe_logstep,
)
from orchestrator.types import BroadcastFunc
from orchestrator.workflow import (
ProcessStat,
ProcessStatus,
runwf,
)
from orchestrator.workflows.removed_workflow import removed_workflow
from pydantic_forms.types import State

logger = structlog.get_logger(__name__)


def thread_start_process(
    pstat: ProcessStat,
    user: str = SYSTEM_USER,
    user_model: OIDCUserModel | None = None,
    broadcast_func: BroadcastFunc | None = None,
) -> UUID:
    """Start a freshly created process on the threadpool executor.

    Args:
        pstat: ProcessStat of the process to start.
        user: User starting the process (kept for executor-interface parity).
        user_model: OIDC user model (kept for executor-interface parity).
        broadcast_func: Optional function to broadcast process data.

    Returns:
        The process id of the started process.

    Raises:
        ValueError: If the workflow has been removed and cannot be started.
    """
    if pstat.workflow == removed_workflow:
        raise ValueError("This workflow cannot be started")

    # Flip the stored status to RUNNING first so the process shows up as running right away.
    process_row = _get_process(pstat.process_id)
    process_row.last_status = ProcessStatus.RUNNING
    db.session.add(process_row)
    db.session.commit()

    # Reload process state and merge the stored initial input into it.
    pstat = load_process(process_row)
    initial_input = retrieve_input_state(process_row.process_id, "initial_state")
    merged_state = pstat.state.map(lambda state: StateMerger.merge(state, initial_input.input_state))
    pstat.update(state=merged_state)

    step_logger = partial(safe_logstep, broadcast_func=broadcast_func)
    return _run_process_async(pstat.process_id, lambda: runwf(pstat, step_logger))


def thread_resume_process(
    process: ProcessTable,
    *,
    user: str | None = None,
    user_model: OIDCUserModel | None = None,
    broadcast_func: BroadcastFunc | None = None,
) -> UUID:
    """Resume an existing process on the threadpool executor.

    Args:
        process: Process from the database.
        user: User resuming the process (updates the current user when set).
        user_model: OIDC user model (kept for executor-interface parity).
        broadcast_func: Optional function to broadcast process data.

    Returns:
        The process id of the resumed process.

    Raises:
        ValueError: If the workflow has been removed and cannot be resumed.
    """
    # ATTENTION!! When modifying this function make sure you make similar changes to `resume_workflow` in the test code
    pstat = load_process(process)
    if pstat.workflow == removed_workflow:
        raise ValueError("This workflow cannot be resumed because it has been removed")

    if user:
        pstat.update(current_user=user)

    # Merge the previously stored user input into the loaded process state.
    stored_input = retrieve_input_state(process.process_id, "user_input")
    merged_state = pstat.state.map(lambda state: StateMerger.merge(state, stored_input.input_state))
    pstat.update(state=merged_state)

    # enforce an update to the process status to properly show the process
    process.last_status = ProcessStatus.RUNNING
    db.session.add(process)
    db.session.commit()

    step_logger = partial(safe_logstep, broadcast_func=broadcast_func)
    _run_process_async(pstat.process_id, lambda: runwf(pstat, step_logger))
    return pstat.process_id


def thread_validate_workflow(validation_workflow: str, json: list[State] | None) -> UUID:
    """Create a process for the given validation workflow and start it immediately.

    Args:
        validation_workflow: Key of the validation workflow to run.
        json: Optional initial user input for the workflow.

    Returns:
        The process id of the started validation process.
    """
    return thread_start_process(create_process(validation_workflow, user_inputs=json))


# Executor callbacks used when running with the threadpool executor.
# Keys mirror CELERY_EXECUTION_CONTEXT so callers can dispatch on the same names.
THREADPOOL_EXECUTION_CONTEXT: dict[str, Callable] = {
    "start": thread_start_process,
    "resume": thread_resume_process,
    "validate": thread_validate_workflow,
}
84 changes: 26 additions & 58 deletions orchestrator/services/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
Success,
Workflow,
abort_wf,
runwf,
)
from orchestrator.workflow import Process as WFProcess
from orchestrator.workflows import get_workflow
Expand All @@ -69,10 +68,12 @@

def get_execution_context() -> dict[str, Callable]:
if app_settings.EXECUTOR == ExecutorType.WORKER:
from orchestrator.services.celery import CELERY_EXECUTION_CONTEXT
from orchestrator.services.executors.celery import CELERY_EXECUTION_CONTEXT

return CELERY_EXECUTION_CONTEXT

from orchestrator.services.executors.threadpool import THREADPOOL_EXECUTION_CONTEXT

return THREADPOOL_EXECUTION_CONTEXT


Expand Down Expand Up @@ -440,7 +441,6 @@ def create_process(
}

try:

state = post_form(workflow.initial_input_form, initial_state, user_inputs)
except FormValidationError:
logger.exception("Validation errors", user_inputs=user_inputs)
Expand All @@ -460,19 +460,6 @@ def create_process(
return pstat


def thread_start_process(
workflow_key: str,
user_inputs: list[State] | None = None,
user: str = SYSTEM_USER,
user_model: OIDCUserModel | None = None,
broadcast_func: BroadcastFunc | None = None,
) -> UUID:
pstat = create_process(workflow_key, user_inputs=user_inputs, user=user, user_model=user_model)

_safe_logstep_with_func = partial(safe_logstep, broadcast_func=broadcast_func)
return _run_process_async(pstat.process_id, lambda: runwf(pstat, _safe_logstep_with_func))


def start_process(
workflow_key: str,
user_inputs: list[State] | None = None,
Expand All @@ -493,57 +480,33 @@ def start_process(
process id

"""
pstat = create_process(workflow_key, user_inputs=user_inputs, user=user)

start_func = get_execution_context()["start"]
return start_func(
workflow_key, user_inputs=user_inputs, user=user, user_model=user_model, broadcast_func=broadcast_func
)
return start_func(pstat, user=user, user_model=user_model, broadcast_func=broadcast_func)


def thread_resume_process(
def restart_process(
process: ProcessTable,
*,
user_inputs: list[State] | None = None,
user: str | None = None,
broadcast_func: BroadcastFunc | None = None,
) -> UUID:
# ATTENTION!! When modifying this function make sure you make similar changes to `resume_workflow` in the test code

if user_inputs is None:
user_inputs = [{}]

pstat = load_process(process)

if pstat.workflow == removed_workflow:
raise ValueError("This workflow cannot be resumed")

form = pstat.log[0].form

user_input = post_form(form, pstat.state.unwrap(), user_inputs)

if user:
pstat.update(current_user=user)

if user_input:
pstat.update(state=pstat.state.map(lambda state: StateMerger.merge(state, user_input)))
store_input_state(pstat.process_id, user_input, "user_input")
# enforce an update to the process status to properly show the process
process.last_status = ProcessStatus.RUNNING
db.session.add(process)
db.session.commit()

_safe_logstep_prep = partial(safe_logstep, broadcast_func=broadcast_func)
return _run_process_async(pstat.process_id, lambda: runwf(pstat, _safe_logstep_prep))
"""Start a process for workflow.

Args:
process: Process from database
user: user who resumed this process
broadcast_func: Optional function to broadcast process data

def thread_validate_workflow(validation_workflow: str, json: list[State] | None) -> UUID:
return thread_start_process(validation_workflow, user_inputs=json)
Returns:
process id

"""
pstat = load_process(process)

THREADPOOL_EXECUTION_CONTEXT: dict[str, Callable] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe the PR should mention that these dicts were moved. And that the signature of the functions has changed. Just in case someone is using them directly

"start": thread_start_process,
"resume": thread_resume_process,
"validate": thread_validate_workflow,
}
start_func = get_execution_context()["start"]
return start_func(pstat, user=user, broadcast_func=broadcast_func)


def resume_process(
Expand All @@ -552,7 +515,7 @@ def resume_process(
user_inputs: list[State] | None = None,
user: str | None = None,
broadcast_func: BroadcastFunc | None = None,
) -> UUID:
) -> bool:
"""Resume a failed or suspended process.

Args:
Expand All @@ -567,14 +530,19 @@ def resume_process(
"""
pstat = load_process(process)

if pstat.workflow == removed_workflow:
raise ValueError("This workflow cannot be resumed because it has been removed")

try:
post_form(pstat.log[0].form, pstat.state.unwrap(), user_inputs=user_inputs or [])
user_input = post_form(pstat.log[0].form, pstat.state.unwrap(), user_inputs=user_inputs or [{}])
except FormValidationError:
logger.exception("Validation errors", user_inputs=user_inputs)
raise

store_input_state(pstat.process_id, user_input, "user_input")

resume_func = get_execution_context()["resume"]
return resume_func(process, user_inputs=user_inputs, user=user, broadcast_func=broadcast_func)
return resume_func(process, user=user, broadcast_func=broadcast_func)


def ensure_correct_callback_token(pstat: ProcessStat, *, token: str) -> None:
Expand Down
Loading
Loading