Skip to content

Commit ecdc192

Browse files
committed
Add user input into the pstat in thread start and resume
1 parent 9f23d77 commit ecdc192

File tree

3 files changed

+57
-3
lines changed

3 files changed

+57
-3
lines changed

orchestrator/services/executors/threadpool.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818

1919
from oauth2_lib.fastapi import OIDCUserModel
2020
from orchestrator.db import ProcessTable, db
21+
from orchestrator.services.input_state import retrieve_input_state
2122
from orchestrator.services.processes import (
2223
SYSTEM_USER,
24+
StateMerger,
2325
_get_process,
2426
_run_process_async,
2527
create_process,
@@ -52,6 +54,10 @@ def thread_start_process(
5254
db.session.add(process)
5355
db.session.commit()
5456

57+
pstat = load_process(process)
58+
input_data = retrieve_input_state(process.process_id, "initial_state")
59+
pstat.update(state=pstat.state.map(lambda state: StateMerger.merge(state, input_data.input_state)))
60+
5561
_safe_logstep_with_func = partial(safe_logstep, broadcast_func=broadcast_func)
5662
return _run_process_async(pstat.process_id, lambda: runwf(pstat, _safe_logstep_with_func))
5763

@@ -71,6 +77,9 @@ def thread_resume_process(
7177
if user:
7278
pstat.update(current_user=user)
7379

80+
input_data = retrieve_input_state(process.process_id, "user_input")
81+
pstat.update(state=pstat.state.map(lambda state: StateMerger.merge(state, input_data.input_state)))
82+
7483
# enforce an update to the process status to properly show the process
7584
process.last_status = ProcessStatus.RUNNING
7685
db.session.add(process)

orchestrator/services/processes.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,29 @@ def start_process(
486486
return start_func(pstat, user=user, user_model=user_model, broadcast_func=broadcast_func)
487487

488488

489+
def restart_process(
490+
process: ProcessTable,
491+
*,
492+
user: str | None = None,
493+
broadcast_func: BroadcastFunc | None = None,
494+
) -> UUID:
495+
"""Start a process for workflow.
496+
497+
Args:
498+
process: Process from database
499+
user: user who resumed this process
500+
broadcast_func: Optional function to broadcast process data
501+
502+
Returns:
503+
process id
504+
505+
"""
506+
pstat = load_process(process)
507+
508+
start_func = get_execution_context()["start"]
509+
return start_func(pstat, user=user, broadcast_func=broadcast_func)
510+
511+
489512
def resume_process(
490513
process: ProcessTable,
491514
*,

orchestrator/workflows/tasks/resume_workflows.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,10 @@ def resume_found_workflows(
5050
created_state_process_ids: list[UUIDstr],
5151
resumed_state_process_ids: list[UUIDstr],
5252
) -> State:
53-
all_processes = waiting_process_ids + created_state_process_ids + resumed_state_process_ids
53+
resume_processes = waiting_process_ids + resumed_state_process_ids
54+
5455
resumed_process_ids = []
55-
for process_id in all_processes:
56+
for process_id in resume_processes:
5657
try:
5758
process = db.session.get(ProcessTable, process_id)
5859
if not process:
@@ -67,7 +68,28 @@ def resume_found_workflows(
6768
# Make sure to turn it on again
6869
db.session.info["disabled"] = True
6970

70-
return {"number_of_resumed_process_ids": len(resumed_process_ids), "resumed_process_ids": resumed_process_ids}
71+
started_process_ids = []
72+
for process_id in created_state_process_ids:
73+
try:
74+
process = db.session.get(ProcessTable, process_id)
75+
if not process:
76+
continue
77+
# Workaround the commit disable function
78+
db.session.info["disabled"] = False
79+
processes.restart_process(process)
80+
started_process_ids.append(process_id)
81+
except Exception:
82+
logger.exception()
83+
finally:
84+
# Make sure to turn it on again
85+
db.session.info["disabled"] = True
86+
87+
return {
88+
"number_of_resumed_process_ids": len(resumed_process_ids),
89+
"resumed_process_ids": resumed_process_ids,
90+
"number_of_started_process_ids": len(started_process_ids),
91+
"started_process_ids": started_process_ids,
92+
}
7193

7294

7395
@workflow("Resume all workflows that are stuck on tasks with the status 'waiting'", target=Target.SYSTEM)

0 commit comments

Comments
 (0)