Skip to content

Commit 3d3ebd6

Browse files
committed
Remove config from task manager and stop killing it
1 parent add5b58 commit 3d3ebd6

21 files changed

+153
-93
lines changed

azimuth/app.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from azimuth.startup import startup_tasks
3030
from azimuth.task_manager import TaskManager
3131
from azimuth.types import DatasetSplitName, ModuleOptions, SupportedModule
32-
from azimuth.utils.cluster import default_cluster
3332
from azimuth.utils.conversion import JSONResponseIgnoreNan
3433
from azimuth.utils.logs import set_logger_config
3534
from azimuth.utils.validation import assert_not_none
@@ -147,9 +146,7 @@ def start_app(config_path: Optional[str], load_config_history: bool, debug: bool
147146
if azimuth_config.dataset is None:
148147
raise ValueError("No dataset has been specified in the config.")
149148

150-
local_cluster = default_cluster(large=azimuth_config.large_dask_cluster)
151-
152-
run_startup_tasks(azimuth_config, local_cluster)
149+
run_startup_tasks(azimuth_config)
153150
task_manager = assert_not_none(_task_manager)
154151
task_manager.client.run(set_logger_config, level)
155152

@@ -321,25 +318,20 @@ def load_dataset_split_managers_from_config(
321318
}
322319

323320

324-
def initialize_managers(azimuth_config: AzimuthConfig, cluster: SpecCluster):
325-
"""Initialize DatasetSplitManagers and TaskManagers.
326-
321+
def initialize_managers_and_config(
322+
azimuth_config: AzimuthConfig, cluster: Optional[SpecCluster] = None
323+
):
324+
"""Initialize DatasetSplitManagers and Config.
327325
328326
Args:
329-
azimuth_config: Configuration
330-
cluster: Dask cluster to use.
327+
azimuth_config: Config
328+
cluster: Dask cluster to use, if different than default.
331329
"""
332330
global _task_manager, _dataset_split_managers, _azimuth_config
333-
_azimuth_config = azimuth_config
334-
if _task_manager is not None:
335-
task_history = _task_manager.current_tasks
336-
else:
337-
task_history = {}
338-
339-
_task_manager = TaskManager(azimuth_config, cluster=cluster)
340-
341-
_task_manager.current_tasks = task_history
331+
if not _task_manager:
332+
_task_manager = TaskManager(cluster, azimuth_config.large_dask_cluster)
342333

334+
_azimuth_config = azimuth_config
343335
_dataset_split_managers = load_dataset_split_managers_from_config(azimuth_config)
344336

345337

@@ -361,6 +353,7 @@ def run_validation_module(pipeline_index=None):
361353
_, task = task_manager.get_task(
362354
task_name=SupportedModule.Validation,
363355
dataset_split_name=dataset_split,
356+
config=config,
364357
mod_options=ModuleOptions(pipeline_index=pipeline_index),
365358
)
366359
# Will raise exceptions as needed.
@@ -373,15 +366,14 @@ def run_validation_module(pipeline_index=None):
373366
run_validation_module(pipeline_index)
374367

375368

376-
def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: SpecCluster):
369+
def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: Optional[SpecCluster] = None):
377370
"""Initialize managers, run validation and startup tasks.
378371
379372
Args:
380373
azimuth_config: Config
381-
cluster: Cluster
382-
374+
cluster: Dask cluster to use, if different than default.
383375
"""
384-
initialize_managers(azimuth_config, cluster)
376+
initialize_managers_and_config(azimuth_config, cluster)
385377

386378
task_manager = assert_not_none(get_task_manager())
387379
# Validate that everything is in order **before** the startup tasks.
@@ -393,5 +385,5 @@ def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: SpecCluster):
393385
azimuth_config.save() # Save only after the validation modules ran successfully
394386

395387
global _startup_tasks, _ready_flag
396-
_startup_tasks = startup_tasks(_dataset_split_managers, task_manager)
388+
_startup_tasks = startup_tasks(_dataset_split_managers, task_manager, azimuth_config)
397389
_ready_flag = Event()

azimuth/routers/app.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def get_perturbation_testing_summary(
165165
SupportedModule.PerturbationTestingMerged,
166166
dataset_split_name=DatasetSplitName.all,
167167
task_manager=task_manager,
168+
config=config,
168169
last_update=last_update,
169170
mod_options=ModuleOptions(pipeline_index=pipeline_index),
170171
)[0]
@@ -180,6 +181,7 @@ def get_perturbation_testing_summary(
180181
SupportedModule.PerturbationTestingSummary,
181182
dataset_split_name=DatasetSplitName.all,
182183
task_manager=task_manager,
184+
config=config,
183185
mod_options=ModuleOptions(pipeline_index=pipeline_index),
184186
)[0]
185187
return PerturbationTestingSummary(

azimuth/routers/class_overlap.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def get_class_overlap_plot(
7272
SupportedModule.ClassOverlap,
7373
dataset_split_name=dataset_split_name,
7474
task_manager=task_manager,
75+
config=config,
7576
last_update=-1,
7677
)[0]
7778
class_overlap_plot_response: ClassOverlapPlotResponse = make_sankey_plot(
@@ -94,6 +95,7 @@ def get_class_overlap_plot(
9495
def get_class_overlap(
9596
dataset_split_name: DatasetSplitName,
9697
task_manager: TaskManager = Depends(get_task_manager),
98+
config: AzimuthConfig = Depends(get_config),
9799
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
98100
dataset_split_managers: Dict[DatasetSplitName, DatasetSplitManager] = Depends(
99101
get_dataset_split_manager_mapping
@@ -106,6 +108,7 @@ def get_class_overlap(
106108
SupportedModule.ClassOverlap,
107109
dataset_split_name=dataset_split_name,
108110
task_manager=task_manager,
111+
config=config,
109112
last_update=-1,
110113
)[0]
111114
dataset_class_count = class_overlap_result.s_matrix.shape[0]
@@ -121,6 +124,7 @@ def get_class_overlap(
121124
SupportedModule.ConfusionMatrix,
122125
DatasetSplitName.eval,
123126
task_manager=task_manager,
127+
config=config,
124128
mod_options=ModuleOptions(
125129
pipeline_index=pipeline_index, cf_normalize=False, cf_reorder_classes=False
126130
),

azimuth/routers/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from azimuth.app import (
1212
get_config,
1313
get_task_manager,
14-
initialize_managers,
14+
initialize_managers_and_config,
1515
require_editable_config,
1616
run_startup_tasks,
1717
)
@@ -89,7 +89,7 @@ def patch_config(
8989
except Exception as e:
9090
log.error("Rollback config update due to error", exc_info=e)
9191
new_config = config
92-
initialize_managers(new_config, task_manager.cluster)
92+
initialize_managers_and_config(new_config, task_manager.cluster)
9393
log.info("Config update cancelled.")
9494
if isinstance(e, (AzimuthValidationError, ValidationError)):
9595
raise HTTPException(HTTP_400_BAD_REQUEST, detail=str(e))

azimuth/routers/custom_utterances.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,13 @@ def get_saliency(
9595
utterances: List[str] = Query([], title="Utterances"),
9696
pipeline_index: int = Depends(require_pipeline_index),
9797
task_manager: TaskManager = Depends(get_task_manager),
98+
config: AzimuthConfig = Depends(get_config),
9899
) -> List[SaliencyResponse]:
99100
task_result: List[SaliencyResponse] = get_custom_task_result(
100101
SupportedMethod.Saliency,
101102
task_manager=task_manager,
102-
custom_query={task_manager.config.columns.text_input: utterances},
103+
config=config,
104+
custom_query={config.columns.text_input: utterances},
103105
mod_options=ModuleOptions(pipeline_index=pipeline_index),
104106
)
105107

azimuth/routers/dataset_warnings.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
from fastapi import APIRouter, Depends
88

9-
from azimuth.app import get_dataset_split_manager_mapping, get_task_manager
9+
from azimuth.app import get_config, get_dataset_split_manager_mapping, get_task_manager
10+
from azimuth.config import AzimuthConfig
1011
from azimuth.dataset_split_manager import DatasetSplitManager
1112
from azimuth.task_manager import TaskManager
1213
from azimuth.types import DatasetSplitName, SupportedModule
@@ -25,6 +26,7 @@
2526
)
2627
def get_dataset_warnings(
2728
task_manager: TaskManager = Depends(get_task_manager),
29+
config: AzimuthConfig = Depends(get_config),
2830
dataset_split_managers: Dict[DatasetSplitName, DatasetSplitManager] = Depends(
2931
get_dataset_split_manager_mapping
3032
),
@@ -35,6 +37,7 @@ def get_dataset_warnings(
3537
dataset_split_name=DatasetSplitName.all,
3638
task_manager=task_manager,
3739
last_update=get_last_update(list(dataset_split_managers.values())),
40+
config=config,
3841
)[0]
3942

4043
return task_result.warning_groups

azimuth/routers/export.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def get_export_perturbation_testing_summary(
108108
SupportedModule.PerturbationTestingSummary,
109109
DatasetSplitName.all,
110110
task_manager=task_manager,
111+
config=config,
111112
last_update=last_update,
112113
mod_options=ModuleOptions(pipeline_index=pipeline_index),
113114
)[0].all_tests_summary
@@ -150,6 +151,7 @@ def get_export_perturbed_set(
150151
SupportedModule.PerturbationTesting,
151152
dataset_split_name,
152153
task_manager,
154+
config=config,
153155
mod_options=ModuleOptions(pipeline_index=pipeline_index_not_null),
154156
)
155157

azimuth/routers/model_performance/confidence_histogram.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
from fastapi import APIRouter, Depends, Query
66

7-
from azimuth.app import get_dataset_split_manager, get_task_manager
7+
from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
8+
from azimuth.config import AzimuthConfig
89
from azimuth.dataset_split_manager import DatasetSplitManager
910
from azimuth.task_manager import TaskManager
1011
from azimuth.types import (
@@ -33,6 +34,7 @@ def get_confidence_histogram(
3334
dataset_split_name: DatasetSplitName,
3435
named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
3536
task_manager: TaskManager = Depends(get_task_manager),
37+
config: AzimuthConfig = Depends(get_config),
3638
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
3739
pipeline_index: int = Depends(require_pipeline_index),
3840
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
@@ -47,6 +49,7 @@ def get_confidence_histogram(
4749
task_name=SupportedModule.ConfidenceHistogram,
4850
dataset_split_name=dataset_split_name,
4951
task_manager=task_manager,
52+
config=config,
5053
mod_options=mod_options,
5154
last_update=dataset_split_manager.last_update,
5255
)[0]

azimuth/routers/model_performance/confusion_matrix.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
from fastapi import APIRouter, Depends, Query
66

7-
from azimuth.app import get_dataset_split_manager, get_task_manager
7+
from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
8+
from azimuth.config import AzimuthConfig
89
from azimuth.dataset_split_manager import DatasetSplitManager
910
from azimuth.task_manager import TaskManager
1011
from azimuth.types import (
@@ -33,6 +34,7 @@ def get_confusion_matrix(
3334
dataset_split_name: DatasetSplitName,
3435
named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
3536
task_manager: TaskManager = Depends(get_task_manager),
37+
config: AzimuthConfig = Depends(get_config),
3638
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
3739
pipeline_index: int = Depends(require_pipeline_index),
3840
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
@@ -51,6 +53,7 @@ def get_confusion_matrix(
5153
SupportedModule.ConfusionMatrix,
5254
dataset_split_name,
5355
task_manager=task_manager,
56+
config=config,
5457
mod_options=mod_options,
5558
last_update=dataset_split_manager.last_update,
5659
)[0]

azimuth/routers/model_performance/metrics.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
from fastapi import APIRouter, Depends, Query
77

8-
from azimuth.app import get_dataset_split_manager, get_task_manager
8+
from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
9+
from azimuth.config import AzimuthConfig
910
from azimuth.dataset_split_manager import DatasetSplitManager
1011
from azimuth.modules.model_performance.metrics import MetricsModule
1112
from azimuth.task_manager import TaskManager
@@ -41,6 +42,7 @@ def get_metrics(
4142
dataset_split_name: DatasetSplitName,
4243
named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
4344
task_manager: TaskManager = Depends(get_task_manager),
45+
config: AzimuthConfig = Depends(get_config),
4446
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
4547
pipeline_index: int = Depends(require_pipeline_index),
4648
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
@@ -55,6 +57,7 @@ def get_metrics(
5557
SupportedModule.Metrics,
5658
dataset_split_name,
5759
task_manager,
60+
config=config,
5861
mod_options=mod_options,
5962
last_update=dataset_split_manager.last_update,
6063
)
@@ -73,6 +76,7 @@ def get_metrics(
7376
def get_metrics_per_filter(
7477
dataset_split_name: DatasetSplitName,
7578
task_manager: TaskManager = Depends(get_task_manager),
79+
config: AzimuthConfig = Depends(get_config),
7680
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
7781
pipeline_index: int = Depends(require_pipeline_index),
7882
) -> MetricsPerFilterAPIResponse:
@@ -81,6 +85,7 @@ def get_metrics_per_filter(
8185
SupportedModule.MetricsPerFilter,
8286
dataset_split_name,
8387
task_manager,
88+
config=config,
8489
mod_options=mod_options,
8590
last_update=dataset_split_manager.last_update,
8691
)[0]
@@ -89,6 +94,7 @@ def get_metrics_per_filter(
8994
SupportedModule.Metrics,
9095
dataset_split_name,
9196
task_manager,
97+
config=config,
9298
mod_options=mod_options,
9399
last_update=dataset_split_manager.last_update,
94100
)[0]

0 commit comments

Comments
 (0)