-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtrack_task_progress.py
123 lines (107 loc) · 4.13 KB
/
track_task_progress.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""Celery background task to process task asynchronously."""
import logging
from time import sleep
from foca.database.register_mongodb import _create_mongo_client # type: ignore
from foca.models.config import Config # type: ignore
from flask import Flask
from flask import current_app
import tes # type: ignore
from pro_tes.ga4gh.tes.models import TesState, TesTask
from pro_tes.utils.db import DbDocumentConnector
from pro_tes.ga4gh.tes.states import States
from pro_tes.celery_worker import celery
from pro_tes.utils.models import TaskModelConverter
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Module-wide pylint suppressions. ``disable`` is the supported pragma
# keyword; the older ``disable-msg`` spelling is reported as deprecated.
# pylint: disable=too-many-locals
# pylint: disable=unsubscriptable-object
# pylint: disable=too-many-arguments
# pylint: disable=too-many-positional-arguments
# pylint: disable=unused-argument
@celery.task(
    name="tasks.track_run_progress",
    bind=True,
    ignore_result=True,
    track_started=True,
)
def task__track_task_progress(
    self,
    worker_id: str,
    remote_host: str,
    remote_base_path: str,
    remote_task_id: str,
    user: str,
    password: str,
) -> None:
    """Track task run progress on a remote TES and mirror it locally.

    This task does not submit anything to the remote TES; it only reads the
    task identified by ``remote_task_id`` (via ``get_task``). It polls the
    remote task until its state is in ``States.FINISHED``, writing each
    observed state change to the local ``taskStore.tasks`` collection, and
    finally copies the remote task logs and outputs into the stored
    document.

    Args:
        self: Bound Celery task instance (required by ``bind=True``; unused).
        worker_id: Worker identifier.
        remote_host: Host at which the TES API is served that is processing
            this request; note that this should include the path information
            but *not* the base path defined in the TES API specification;
            e.g., specify https://my.tes.com/api if the actual API is hosted at
            https://my.tes.com/api/ga4gh/tes/v1.
        remote_base_path: Override the default path suffix defined in the TES
            API specification, i.e., `/ga4gh/tes/v1`.
        remote_task_id: Task run identifier on the remote TES service.
        user: User name for basic authentication.
        password: Password for basic authentication.

    Raises:
        Exception: Any error raised while creating the TES client, while
            fetching the task the first time, or while polling once the
            configured retry budget is used up. The local task state is set
            to ``SYSTEM_ERROR`` before the error is re-raised.
    """
    foca_config: Config = current_app.config.foca  # type: ignore
    controller_config: dict = foca_config.controllers["post_task"]

    # create database client
    collection = _create_mongo_client(
        app=Flask(__name__),
        host=foca_config.db.host,
        port=foca_config.db.port,
        db="taskStore",
    ).db["tasks"]
    db_client = DbDocumentConnector(
        collection=collection,
        worker_id=worker_id,
    )

    # update state: INITIALIZING
    db_client.update_task_state(state=TesState.INITIALIZING.value)
    url = f"{remote_host.strip('/')}/{remote_base_path.strip('/')}"

    # Create the TES client and fetch the remote task once up front, so that
    # any failure here marks the local task SYSTEM_ERROR before polling
    # starts. ``response`` is kept in case the polling loop below never
    # executes its body.
    try:
        cli = tes.HTTPClient(
            url,
            timeout=5,
            user=user,
            password=password,
        )
        response = cli.get_task(task_id=remote_task_id)
    except Exception:
        db_client.update_task_state(state=TesState.SYSTEM_ERROR.value)
        raise

    # track task progress
    task_state: TesState = TesState.UNKNOWN
    attempt: int = 1
    while task_state not in States.FINISHED:
        sleep(controller_config["polling"]["wait"])
        try:
            response = cli.get_task(
                task_id=remote_task_id,
            )
        except Exception as exc:  # pylint: disable=broad-except
            # ``attempt`` is never reset after a successful poll, so the
            # configured retry budget is shared by all polls of this task.
            if attempt <= controller_config["polling"]["attempts"]:
                attempt += 1
                logger.warning(exc, exc_info=True)
                continue
            db_client.update_task_state(state=TesState.SYSTEM_ERROR.value)
            raise
        # Only write to the database when the remote state has changed.
        if response.state != task_state:
            task_state = response.state
            db_client.update_task_state(state=str(task_state))

    task_model_converter = TaskModelConverter(task=response)
    task_converted: TesTask = task_model_converter.convert_task()
    document = db_client.get_document()

    # updating task after task is finished
    document.task.state = task_converted.state
    remote_logs = task_converted.logs
    local_logs = document.task.logs
    if remote_logs is None or local_logs is None:
        # Nothing to merge; still persist the final state below instead of
        # aborting (the previous ``assert`` checks crashed here).
        logger.warning(
            "Skipping log update for remote task '%s': remote logs "
            "present: %s, local logs present: %s.",
            remote_task_id,
            remote_logs is not None,
            local_logs is not None,
        )
    else:
        if len(remote_logs) > len(local_logs):
            logger.warning(
                "Remote task '%s' reports %d log entries but only %d are "
                "stored locally; extra remote entries are not copied.",
                remote_task_id,
                len(remote_logs),
                len(local_logs),
            )
        # Pair entries positionally; remote entry i updates local entry i.
        for local_log, remote_log in zip(local_logs, remote_logs):
            local_log.logs = remote_log.logs
            local_log.outputs = remote_log.outputs

    # updating the database
    db_client.upsert_fields_in_root_object(root="task", **document.task.dict())