Skip to content

Commit 1ac7b2d

Browse files
committed
update task_resume_workflows to also resume processes in CREATED/RESUMED status
- improve start_process to include logic that happens in both executors. - move threadpool executor functions to its own file.
1 parent 3150434 commit 1ac7b2d

File tree

8 files changed

+253
-134
lines changed

8 files changed

+253
-134
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
<mxfile host="65bd71144e">
2+
<diagram id="2PQGSCLdhIAJmseasrSe" name="Page-1">
3+
<mxGraphModel dx="2190" dy="577" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
4+
<root>
5+
<mxCell id="0"/>
6+
<mxCell id="1" parent="0"/>
7+
<mxCell id="2" value="Workflow / Task" style="rounded=1;whiteSpace=wrap;html=1;" parent="1" vertex="1">
8+
<mxGeometry x="-820" y="20" width="120" height="60" as="geometry"/>
9+
</mxCell>
10+
<mxCell id="3" value="Add to DB&lt;div&gt;CREATED&lt;/div&gt;" style="shape=step;perimeter=stepPerimeter;whiteSpace=wrap;html=1;fixedSize=1;rounded=0;" parent="1" vertex="1">
11+
<mxGeometry x="-660" y="20" width="130" height="60" as="geometry"/>
12+
</mxCell>
13+
<mxCell id="4" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;" parent="1" source="2" target="3" edge="1">
14+
<mxGeometry width="50" height="50" relative="1" as="geometry">
15+
<mxPoint x="60" y="300" as="sourcePoint"/>
16+
<mxPoint x="110" y="250" as="targetPoint"/>
17+
</mxGeometry>
18+
</mxCell>
19+
<mxCell id="5" value="Add to broker&lt;div&gt;(default Redis)&lt;/div&gt;" style="shape=step;perimeter=stepPerimeter;whiteSpace=wrap;html=1;fixedSize=1;" parent="1" vertex="1">
20+
<mxGeometry x="-560" y="20" width="130" height="60" as="geometry"/>
21+
</mxCell>
22+
<mxCell id="8" value="Celery worker&amp;nbsp;&lt;div&gt;picks it up&lt;/div&gt;&lt;div&gt;RUNNING&lt;/div&gt;" style="shape=step;perimeter=stepPerimeter;whiteSpace=wrap;html=1;fixedSize=1;" parent="1" vertex="1">
23+
<mxGeometry x="-450" y="20" width="130" height="60" as="geometry"/>
24+
</mxCell>
25+
<mxCell id="10" value="Done&lt;div&gt;COMPLETED&lt;/div&gt;" style="shape=step;perimeter=stepPerimeter;whiteSpace=wrap;html=1;fixedSize=1;" parent="1" vertex="1">
26+
<mxGeometry x="-340" y="20" width="130" height="60" as="geometry"/>
27+
</mxCell>
28+
<mxCell id="12" value="Task Fails&lt;div&gt;FAILED&lt;/div&gt;" style="shape=step;perimeter=stepPerimeter;whiteSpace=wrap;html=1;fixedSize=1;" parent="1" vertex="1">
29+
<mxGeometry x="-340" y="80" width="130" height="60" as="geometry"/>
30+
</mxCell>
31+
<mxCell id="14" value="Celery worker shutdown" style="shape=step;perimeter=stepPerimeter;whiteSpace=wrap;html=1;fixedSize=1;" parent="1" vertex="1">
32+
<mxGeometry x="-340" y="140" width="130" height="60" as="geometry"/>
33+
</mxCell>
34+
<mxCell id="16" value="" style="endArrow=classic;html=1;exitX=0.75;exitY=0;exitDx=0;exitDy=0;entryX=0.432;entryY=1.027;entryDx=0;entryDy=0;entryPerimeter=0;" parent="1" source="18" target="8" edge="1">
35+
<mxGeometry width="50" height="50" relative="1" as="geometry">
36+
<mxPoint x="-550.0000000000001" y="319.99999999999994" as="sourcePoint"/>
37+
<mxPoint x="-549.41" y="365.06" as="targetPoint"/>
38+
</mxGeometry>
39+
</mxCell>
40+
<mxCell id="18" value="CREATED &amp;amp; RESUMED&lt;div&gt;are states where it needs to be picked up by a worker&lt;/div&gt;&lt;div&gt;and can break when celery shuts down forcefully&lt;/div&gt;" style="rounded=1;whiteSpace=wrap;html=1;" parent="1" vertex="1">
41+
<mxGeometry x="-670" y="170" width="180" height="90" as="geometry"/>
42+
</mxCell>
43+
<mxCell id="19" value="Task/Workflow&lt;div&gt;hangs on&lt;/div&gt;&lt;div&gt;CREATED&lt;/div&gt;" style="shape=step;perimeter=stepPerimeter;whiteSpace=wrap;html=1;fixedSize=1;" parent="1" vertex="1">
44+
<mxGeometry x="-230" y="140" width="130" height="60" as="geometry"/>
45+
</mxCell>
46+
<mxCell id="24" value="Resume&lt;div&gt;workflow&lt;br&gt;RESUMED&lt;/div&gt;" style="shape=step;perimeter=stepPerimeter;whiteSpace=wrap;html=1;fixedSize=1;" parent="1" vertex="1">
47+
<mxGeometry x="-230" y="80" width="130" height="60" as="geometry"/>
48+
</mxCell>
49+
<mxCell id="25" value="Celery worker&amp;nbsp;&lt;div&gt;picks it up&lt;/div&gt;&lt;div&gt;RUNNING&lt;/div&gt;" style="shape=step;perimeter=stepPerimeter;whiteSpace=wrap;html=1;fixedSize=1;" parent="1" vertex="1">
50+
<mxGeometry x="-120" y="80" width="130" height="60" as="geometry"/>
51+
</mxCell>
52+
<mxCell id="26" value="Celery worker shutdown" style="shape=step;perimeter=stepPerimeter;whiteSpace=wrap;html=1;fixedSize=1;" parent="1" vertex="1">
53+
<mxGeometry x="-10" y="140" width="130" height="60" as="geometry"/>
54+
</mxCell>
55+
<mxCell id="27" value="Done&lt;div&gt;COMPLETED&lt;/div&gt;" style="shape=step;perimeter=stepPerimeter;whiteSpace=wrap;html=1;fixedSize=1;" parent="1" vertex="1">
56+
<mxGeometry x="-10" y="80" width="130" height="60" as="geometry"/>
57+
</mxCell>
58+
<mxCell id="30" value="Task/Workflow&lt;div&gt;hangs on&lt;/div&gt;&lt;div&gt;RESUMED&lt;/div&gt;" style="shape=step;perimeter=stepPerimeter;whiteSpace=wrap;html=1;fixedSize=1;" parent="1" vertex="1">
59+
<mxGeometry x="100" y="140" width="130" height="60" as="geometry"/>
60+
</mxCell>
61+
<mxCell id="34" value="" style="endArrow=classic;html=1;exitX=0.464;exitY=0.944;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;exitPerimeter=0;" parent="1" source="5" target="18" edge="1">
62+
<mxGeometry width="50" height="50" relative="1" as="geometry">
63+
<mxPoint x="-670" y="140" as="sourcePoint"/>
64+
<mxPoint x="-609" y="97" as="targetPoint"/>
65+
</mxGeometry>
66+
</mxCell>
67+
<mxCell id="41" value="&lt;span style=&quot;color: rgb(0, 0, 0); text-wrap-mode: nowrap;&quot;&gt;Add hanging processes&lt;/span&gt;&lt;div&gt;&lt;span style=&quot;color: rgb(0, 0, 0); text-wrap-mode: nowrap;&quot;&gt;with a scheduler&lt;/span&gt;&lt;/div&gt;" style="rounded=1;whiteSpace=wrap;html=1;" parent="1" vertex="1">
68+
<mxGeometry x="-380" y="270" width="180" height="70" as="geometry"/>
69+
</mxCell>
70+
<mxCell id="42" value="" style="endArrow=classic;html=1;entryX=0.835;entryY=-0.054;entryDx=0;entryDy=0;entryPerimeter=0;exitX=0.311;exitY=0.99;exitDx=0;exitDy=0;exitPerimeter=0;" parent="1" source="19" target="41" edge="1">
71+
<mxGeometry width="50" height="50" relative="1" as="geometry">
72+
<mxPoint x="-180" y="210" as="sourcePoint"/>
73+
<mxPoint x="-150" y="170" as="targetPoint"/>
74+
</mxGeometry>
75+
</mxCell>
76+
<mxCell id="44" style="edgeStyle=none;html=1;exitX=0.25;exitY=0;exitDx=0;exitDy=0;entryX=0.606;entryY=1.03;entryDx=0;entryDy=0;entryPerimeter=0;" parent="1" source="41" target="5" edge="1">
77+
<mxGeometry relative="1" as="geometry"/>
78+
</mxCell>
79+
<mxCell id="47" style="edgeStyle=none;html=1;exitX=0.394;exitY=1.03;exitDx=0;exitDy=0;entryX=1.001;entryY=0.363;entryDx=0;entryDy=0;entryPerimeter=0;exitPerimeter=0;" parent="1" source="30" target="41" edge="1">
80+
<mxGeometry relative="1" as="geometry"/>
81+
</mxCell>
82+
</root>
83+
</mxGraphModel>
84+
</diagram>
85+
</mxfile>
53.7 KB
Loading

docs/reference-docs/app/scaling.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,7 @@ celery.conf.task_routes = {
248248

249249
If you decide to override the queue names in this configuration, you also have to make sure that you also
250250
update the names accordingly after the `-Q` flag.
251+
252+
### Celery Workflow/Task flow
253+
254+
![Celery Workflow/Task flow](celery-flow.png)

orchestrator/services/celery.py renamed to orchestrator/services/executors/celery.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@
1919
from celery.result import AsyncResult
2020
from kombu.exceptions import ConnectionError, OperationalError
2121

22-
from oauth2_lib.fastapi import OIDCUserModel
2322
from orchestrator import app_settings
2423
from orchestrator.api.error_handling import raise_status
2524
from orchestrator.db import ProcessTable, db
2625
from orchestrator.services.input_state import store_input_state
2726
from orchestrator.services.processes import create_process, delete_process
2827
from orchestrator.services.workflows import get_workflow_by_name
29-
from orchestrator.workflows import get_workflow
28+
from orchestrator.workflow import ProcessStat
3029
from pydantic_forms.types import State
3130

3231
SYSTEM_USER = "SYSTEM"
@@ -42,29 +41,17 @@ def _block_when_testing(task_result: AsyncResult) -> None:
4241
raise RuntimeError("Celery worker has failed to resume process")
4342

4443

45-
def _celery_start_process(
46-
workflow_key: str,
47-
user_inputs: list[State] | None,
48-
user: str = SYSTEM_USER,
49-
user_model: OIDCUserModel | None = None,
50-
**kwargs: Any,
51-
) -> UUID:
44+
def _celery_start_process(pstat: ProcessStat, user: str = SYSTEM_USER, **kwargs: Any) -> UUID:
5245
"""Client side call of Celery."""
5346
from orchestrator.services.tasks import NEW_TASK, NEW_WORKFLOW, get_celery_task
5447

55-
workflow = get_workflow(workflow_key)
56-
if not workflow:
57-
raise_status(HTTPStatus.NOT_FOUND, "Workflow does not exist")
58-
59-
wf_table = get_workflow_by_name(workflow.name)
60-
if not wf_table:
48+
if not (wf_table := get_workflow_by_name(pstat.workflow.name)):
6149
raise_status(HTTPStatus.NOT_FOUND, "Workflow in Database does not exist")
6250

6351
task_name = NEW_TASK if wf_table.is_task else NEW_WORKFLOW
6452
trigger_task = get_celery_task(task_name)
65-
pstat = create_process(workflow_key, user_inputs=user_inputs, user=user, user_model=user_model)
6653
try:
67-
result = trigger_task.delay(pstat.process_id, workflow_key, user)
54+
result = trigger_task.delay(pstat.process_id, pstat.workflow.name, user)
6855
_block_when_testing(result)
6956
return pstat.process_id
7057
except (ConnectionError, OperationalError) as e:
@@ -80,7 +67,7 @@ def _celery_resume_process(
8067
user_inputs: list[State] | None = None,
8168
user: str | None = None,
8269
**kwargs: Any,
83-
) -> UUID:
70+
) -> bool:
8471
"""Client side call of Celery."""
8572
from orchestrator.services.processes import load_process
8673
from orchestrator.services.tasks import RESUME_TASK, RESUME_WORKFLOW, get_celery_task
@@ -103,7 +90,7 @@ def _celery_resume_process(
10390
result = trigger_task.delay(pstat.process_id, user)
10491
_block_when_testing(result)
10592

106-
return pstat.process_id
93+
return True
10794
except (ConnectionError, OperationalError) as e:
10895
logger.warning(
10996
"Connection error when submitting task to celery. Resetting process status back",
@@ -135,7 +122,8 @@ def _celery_set_process_status_resumed(process: ProcessTable) -> None:
135122

136123

137124
def _celery_validate(validation_workflow: str, json: list[State] | None) -> None:
138-
_celery_start_process(validation_workflow, user_inputs=json)
125+
pstat = create_process(validation_workflow, user_inputs=json)
126+
_celery_start_process(pstat)
139127

140128

141129
CELERY_EXECUTION_CONTEXT: dict[str, Callable] = {
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Copyright 2019-2025 SURF, GÉANT, ESnet.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
from collections.abc import Callable
14+
from functools import partial
15+
from uuid import UUID
16+
17+
import structlog
18+
19+
from oauth2_lib.fastapi import OIDCUserModel
20+
from orchestrator.db import ProcessTable, db
21+
from orchestrator.services.input_state import store_input_state
22+
from orchestrator.services.processes import (
23+
SYSTEM_USER,
24+
StateMerger,
25+
_get_process,
26+
_run_process_async,
27+
create_process,
28+
load_process,
29+
safe_logstep,
30+
)
31+
from orchestrator.types import BroadcastFunc
32+
from orchestrator.workflow import (
33+
ProcessStat,
34+
ProcessStatus,
35+
runwf,
36+
)
37+
from orchestrator.workflows.removed_workflow import removed_workflow
38+
from pydantic_forms.core import post_form
39+
from pydantic_forms.types import State
40+
41+
logger = structlog.get_logger(__name__)
42+
43+
44+
def thread_start_process(
45+
pstat: ProcessStat,
46+
user: str = SYSTEM_USER,
47+
user_model: OIDCUserModel | None = None,
48+
broadcast_func: BroadcastFunc | None = None,
49+
) -> UUID:
50+
if pstat.workflow == removed_workflow:
51+
raise ValueError("This workflow cannot be started")
52+
53+
process = _get_process(pstat.process_id)
54+
process.last_status = ProcessStatus.RUNNING
55+
db.session.add(process)
56+
db.session.commit()
57+
58+
_safe_logstep_with_func = partial(safe_logstep, broadcast_func=broadcast_func)
59+
return _run_process_async(pstat.process_id, lambda: runwf(pstat, _safe_logstep_with_func))
60+
61+
62+
def thread_resume_process(
63+
process: ProcessTable,
64+
*,
65+
user_inputs: list[State] | None = None,
66+
user: str | None = None,
67+
user_model: OIDCUserModel | None = None,
68+
broadcast_func: BroadcastFunc | None = None,
69+
) -> UUID:
70+
# ATTENTION!! When modifying this function make sure you make similar changes to `resume_workflow` in the test code
71+
72+
if user_inputs is None:
73+
user_inputs = [{}]
74+
75+
pstat = load_process(process)
76+
if pstat.workflow == removed_workflow:
77+
raise ValueError("This workflow cannot be resumed")
78+
79+
form = pstat.log[0].form
80+
81+
user_input = post_form(form, pstat.state.unwrap(), user_inputs)
82+
83+
if user:
84+
pstat.update(current_user=user)
85+
86+
if user_input:
87+
pstat.update(state=pstat.state.map(lambda state: StateMerger.merge(state, user_input)))
88+
store_input_state(pstat.process_id, user_input, "user_input")
89+
# enforce an update to the process status to properly show the process
90+
process.last_status = ProcessStatus.RUNNING
91+
db.session.add(process)
92+
db.session.commit()
93+
94+
_safe_logstep_prep = partial(safe_logstep, broadcast_func=broadcast_func)
95+
_run_process_async(pstat.process_id, lambda: runwf(pstat, _safe_logstep_prep))
96+
return pstat.process_id
97+
98+
99+
def thread_validate_workflow(validation_workflow: str, json: list[State] | None) -> UUID:
100+
pstat = create_process(validation_workflow, user_inputs=json)
101+
return thread_start_process(pstat)
102+
103+
104+
THREADPOOL_EXECUTION_CONTEXT: dict[str, Callable] = {
105+
"start": thread_start_process,
106+
"resume": thread_resume_process,
107+
"validate": thread_validate_workflow,
108+
}

0 commit comments

Comments
 (0)