Skip to content

Commit 4095a73

Browse files
committed
Refactor artifact manager to real singleton and remove clear_cache
1 parent b984600 commit 4095a73

File tree

19 files changed

+132
-133
lines changed

19 files changed

+132
-133
lines changed

azimuth/app.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,14 @@
2525

2626
from azimuth.config import AzimuthConfig, load_azimuth_config
2727
from azimuth.dataset_split_manager import DatasetSplitManager
28-
from azimuth.modules.base_classes import DaskModule
28+
from azimuth.modules.base_classes import ArtifactManager, DaskModule
2929
from azimuth.modules.utilities.validation import ValidationModule
3030
from azimuth.startup import startup_tasks
3131
from azimuth.task_manager import TaskManager
3232
from azimuth.types import DatasetSplitName, ModuleOptions
3333
from azimuth.utils.cluster import default_cluster
3434
from azimuth.utils.conversion import JSONResponseIgnoreNan
3535
from azimuth.utils.logs import set_logger_config
36-
from azimuth.utils.project import load_dataset_split_managers_from_config
3736
from azimuth.utils.validation import assert_not_none
3837

3938
_dataset_split_managers: Dict[DatasetSplitName, Optional[DatasetSplitManager]] = {}
@@ -297,6 +296,32 @@ def create_app() -> FastAPI:
297296
return app
298297

299298

299+
def load_dataset_split_managers_from_config(
300+
azimuth_config: AzimuthConfig,
301+
) -> Dict[DatasetSplitName, Optional[DatasetSplitManager]]:
302+
"""
303+
Load all dataset splits for the application.
304+
305+
Args:
306+
azimuth_config: Azimuth Configuration.
307+
308+
Returns:
309+
For all DatasetSplitName, None or a dataset_split manager.
310+
311+
"""
312+
artifact_manager = ArtifactManager.instance()
313+
dataset = artifact_manager.get_dataset_dict(azimuth_config)
314+
315+
return {
316+
dataset_split_name: None
317+
if dataset_split_name not in dataset
318+
else artifact_manager.get_dataset_split_manager(
319+
azimuth_config, DatasetSplitName[dataset_split_name]
320+
)
321+
for dataset_split_name in [DatasetSplitName.eval, DatasetSplitName.train]
322+
}
323+
324+
300325
def initialize_managers(azimuth_config: AzimuthConfig, cluster: SpecCluster):
301326
"""Initialize DatasetSplitManagers and TaskManagers.
302327
@@ -348,7 +373,6 @@ def run_validation_module(pipeline_index=None):
348373
else:
349374
for pipeline_index in range(len(config.pipelines)):
350375
run_validation_module(pipeline_index)
351-
task_manager.clear_worker_cache()
352376

353377

354378
def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: SpecCluster):

azimuth/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,15 @@ def check_pipeline_names(cls, pipeline_definitions):
355355
raise ValueError(f"Duplicated pipeline names {pipeline_names}.")
356356
return pipeline_definitions
357357

358+
def get_model_contract_hash(self):
359+
return md5_hash(
360+
self.dict(
361+
include=ModelContractConfig.__fields__.keys()
362+
- CommonFieldsConfig.__fields__.keys(),
363+
by_alias=True,
364+
)
365+
)
366+
358367

359368
class MetricsConfig(ModelContractConfig):
360369
# Custom HuggingFace metrics

azimuth/modules/base_classes/artifact_manager.py

Lines changed: 60 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
# Copyright ServiceNow, Inc. 2021 – 2022
22
# This source code is licensed under the Apache 2.0 license found in the LICENSE file
33
# in the root directory of this source tree.
4-
from multiprocessing import Lock
5-
from typing import Callable, Dict, Optional
4+
from collections import defaultdict
5+
from typing import Callable, Dict
66

7+
import structlog
78
from datasets import DatasetDict
89

910
from azimuth.config import AzimuthConfig
@@ -18,31 +19,66 @@
1819
Hash = int
1920

2021

22+
log = structlog.get_logger()
23+
24+
25+
class Singleton:
26+
"""
27+
A non-thread-safe helper class to ease implementing singletons.
28+
This should be used as a decorator -- not a metaclass -- to the
29+
class that should be a singleton.
30+
31+
To get the singleton instance, use the `instance` method. Trying
32+
to use `__call__` will result in a `TypeError` being raised.
33+
34+
Args:
35+
decorated: Decorated class
36+
"""
37+
38+
def __init__(self, decorated):
39+
self._decorated = decorated
40+
41+
def instance(self):
42+
"""
43+
Returns the singleton instance. Upon its first call, it creates a
44+
new instance of the decorated class and calls its `__init__` method.
45+
On all subsequent calls, the already created instance is returned.
46+
47+
Returns:
48+
Instance of the decorated class
49+
"""
50+
try:
51+
return self._instance
52+
except AttributeError:
53+
self._instance = self._decorated()
54+
return self._instance
55+
56+
def __call__(self):
57+
raise TypeError("Singletons must be accessed through `instance()`.")
58+
59+
def clear_instance(self):
60+
"""For test purposes only"""
61+
if hasattr(self, "_instance"):
62+
delattr(self, "_instance")
63+
64+
65+
@Singleton
2166
class ArtifactManager:
2267
"""This class is a singleton which holds different artifacts.
2368
2469
Artifacts include dataset_split_managers, datasets and models for each config, so they don't
2570
need to be reloaded many times for a same module.
2671
"""
2772

28-
instance: Optional["ArtifactManager"] = None
29-
3073
def __init__(self):
3174
# The keys of the dict are a hash of the config.
3275
self.dataset_dict_mapping: Dict[Hash, DatasetDict] = {}
3376
self.dataset_split_managers_mapping: Dict[
3477
Hash, Dict[DatasetSplitName, DatasetSplitManager]
35-
] = {}
36-
self.models_mapping: Dict[Hash, Dict[int, Callable]] = {}
37-
self.tokenizer = None
78+
] = defaultdict(dict)
79+
self.models_mapping: Dict[Hash, Dict[int, Callable]] = defaultdict(dict)
3880
self.metrics = {}
39-
40-
@classmethod
41-
def get_instance(cls):
42-
with Lock():
43-
if cls.instance is None:
44-
cls.instance = cls()
45-
return cls.instance
81+
log.debug(f"Creating new Artifact Manager {id(self)}.")
4682

4783
def get_dataset_split_manager(
4884
self, config: AzimuthConfig, name: DatasetSplitName
@@ -68,8 +104,6 @@ def get_dataset_split_manager(
68104
f"Found {tuple(dataset_dict.keys())}."
69105
)
70106
project_hash: Hash = config.get_project_hash()
71-
if project_hash not in self.dataset_split_managers_mapping:
72-
self.dataset_split_managers_mapping[project_hash] = {}
73107
if name not in self.dataset_split_managers_mapping[project_hash]:
74108
self.dataset_split_managers_mapping[project_hash][name] = DatasetSplitManager(
75109
name=name,
@@ -78,6 +112,7 @@ def get_dataset_split_manager(
78112
initial_prediction_tags=ALL_PREDICTION_TAGS,
79113
dataset_split=dataset_dict[name],
80114
)
115+
log.debug(f"New {name} DM in Artifact Manager {id(self)}")
81116
return self.dataset_split_managers_mapping[project_hash][name]
82117

83118
def get_dataset_dict(self, config) -> DatasetDict:
@@ -106,17 +141,16 @@ def get_model(self, config: AzimuthConfig, pipeline_idx: int):
106141
Returns:
107142
Loaded model.
108143
"""
109-
110-
project_hash: Hash = config.get_project_hash()
111-
if project_hash not in self.models_mapping:
112-
self.models_mapping[project_hash] = {}
113-
if pipeline_idx not in self.models_mapping[project_hash]:
144+
# We only need to reload the pipeline if the model contract part of the config is changed.
145+
model_contract_hash: Hash = config.get_model_contract_hash()
146+
if pipeline_idx not in self.models_mapping[model_contract_hash]:
147+
log.debug(f"Loading pipeline {pipeline_idx}.")
114148
pipelines = assert_not_none(config.pipelines)
115-
self.models_mapping[project_hash][pipeline_idx] = load_custom_object(
149+
self.models_mapping[model_contract_hash][pipeline_idx] = load_custom_object(
116150
assert_not_none(pipelines[pipeline_idx].model), azimuth_config=config
117151
)
118152

119-
return self.models_mapping[project_hash][pipeline_idx]
153+
return self.models_mapping[model_contract_hash][pipeline_idx]
120154

121155
def get_metric(self, config, name: str, **kwargs):
122156
hash: Hash = md5_hash({"name": name, **kwargs})
@@ -125,6 +159,6 @@ def get_metric(self, config, name: str, **kwargs):
125159
return self.metrics[hash]
126160

127161
@classmethod
128-
def clear_cache(cls) -> None:
129-
with Lock():
130-
cls.instance = None
162+
def instance(cls):
163+
# Implemented in decorator
164+
raise NotImplementedError

azimuth/modules/base_classes/module.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def get_indices(self, name: Optional[DatasetSplitName] = None) -> List[int]:
7979
def artifact_manager(self):
8080
"""This is set as a property so the Module always have access to the current version of
8181
the ArtifactManager on the worker."""
82-
return ArtifactManager.get_instance()
82+
return ArtifactManager.instance()
8383

8484
@property
8585
def available_dataset_splits(self) -> Set[DatasetSplitName]:
@@ -215,6 +215,3 @@ def get_pipeline_definition(self) -> PipelineDefinition:
215215
pipeline_index = assert_not_none(self.mod_options.pipeline_index)
216216
current_pipeline = pipelines[pipeline_index]
217217
return current_pipeline
218-
219-
def clear_cache(self):
220-
self.artifact_manager.clear_cache()

azimuth/routers/config.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,6 @@ def patch_config(
9898
HTTP_500_INTERNAL_SERVER_ERROR, detail="Error when loading the new config."
9999
)
100100

101-
# Clear workers so that they load the correct config.
102-
task_manager.clear_worker_cache()
103101
return new_config
104102

105103

azimuth/routers/export.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66
import time
77
from os.path import join as pjoin
8-
from typing import Dict, Generator, List, Optional, cast
8+
from typing import Dict, Generator, List, Optional
99

1010
import pandas as pd
1111
from fastapi import APIRouter, Depends, HTTPException
@@ -155,7 +155,10 @@ def get_export_perturbed_set(
155155

156156
output = list(
157157
make_utterance_level_result(
158-
dataset_split_manager, task_result, pipeline_index=pipeline_index_not_null
158+
dataset_split_manager,
159+
task_result,
160+
pipeline_index=pipeline_index_not_null,
161+
config=config,
159162
)
160163
)
161164
with open(path, "w") as f:
@@ -164,20 +167,23 @@ def get_export_perturbed_set(
164167

165168

166169
def make_utterance_level_result(
167-
dm: DatasetSplitManager, results: List[List[PerturbedUtteranceResult]], pipeline_index: int
170+
dm: DatasetSplitManager,
171+
results: List[List[PerturbedUtteranceResult]],
172+
pipeline_index: int,
173+
config: AzimuthConfig,
168174
) -> Generator[Dict, None, None]:
169175
"""Massage perturbation testing results for the frontend.
170176
171177
Args:
172178
dm: Current DatasetSplitManager.
173179
results: Output of Perturbation Testing.
174180
pipeline_index: Index of the pipeline that made the results.
181+
config: Azimuth config
175182
176183
Returns:
177184
Generator that yield json-able object for the frontend.
178185
179186
"""
180-
config = cast(AzimuthConfig, dm.config)
181187
for idx, (utterance, test_results) in enumerate(
182188
zip(
183189
dm.get_dataset_split(

azimuth/routers/utterances.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,6 @@ def patch_utterances(
229229
utterances: List[UtterancePatch] = Body(...),
230230
config: AzimuthConfig = Depends(get_config),
231231
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
232-
task_manager: TaskManager = Depends(get_task_manager),
233232
ignore_not_found: bool = Query(False),
234233
) -> List[UtterancePatch]:
235234
if ignore_not_found:
@@ -250,7 +249,6 @@ def patch_utterances(
250249

251250
dataset_split_manager.add_tags(data_actions)
252251

253-
task_manager.clear_worker_cache()
254252
updated_tags = dataset_split_manager.get_tags(row_indices)
255253

256254
return [

azimuth/startup.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,6 @@ def on_end(fut: Future, module: DaskModule, dm: DatasetSplitManager, task_manage
150150
# Task is done, save the result.
151151
if isinstance(module, DatasetResultModule):
152152
module.save_result(module.result(), dm)
153-
# We only need to clear cache when the dataset is modified.
154-
task_manager.clear_worker_cache()
155153
else:
156154
log.exception("Error in", module=module, fut=fut, exc_info=fut.exception())
157155

azimuth/task_manager.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from distributed import Client, SpecCluster
99

1010
from azimuth.config import AzimuthConfig
11-
from azimuth.modules.base_classes import ArtifactManager, DaskModule, ExpirableMixin
11+
from azimuth.modules.base_classes import DaskModule, ExpirableMixin
1212
from azimuth.modules.task_mapping import model_contract_methods, modules
1313
from azimuth.types import (
1414
DatasetSplitName,
@@ -67,7 +67,6 @@ def close(self):
6767
mod.future.cancel()
6868
except Exception:
6969
pass
70-
self.clear_worker_cache()
7170
self.client.close()
7271

7372
def register_task(self, name, cls):
@@ -214,9 +213,6 @@ def status(self):
214213
**self.get_all_tasks_status(task=None),
215214
}
216215

217-
def clear_worker_cache(self):
218-
self.client.run(ArtifactManager.clear_cache)
219-
220216
def restart(self):
221217
log.info("Cluster restarted to free memory.")
222218
for task_name, module in self.current_tasks.items():

azimuth/utils/project.py

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright ServiceNow, Inc. 2021 – 2022
22
# This source code is licensed under the Apache 2.0 license found in the LICENSE file
33
# in the root directory of this source tree.
4-
from typing import Dict, Optional
4+
from typing import Dict
55

66
import structlog
77
from datasets import DatasetDict
@@ -12,9 +12,7 @@
1212
PerturbationTestingConfig,
1313
SimilarityConfig,
1414
)
15-
from azimuth.dataset_split_manager import DatasetSplitManager
1615
from azimuth.types import DatasetSplitName, SupportedModelContract
17-
from azimuth.types.tag import ALL_PREDICTION_TAGS, ALL_STANDARD_TAGS
1816
from azimuth.utils.object_loader import load_custom_object
1917

2018
log = structlog.get_logger()
@@ -59,38 +57,6 @@ def update_config(old_config: AzimuthConfig, partial_config: Dict) -> AzimuthCon
5957
return old_config.copy(update=partial_config, deep=True)
6058

6159

62-
def load_dataset_split_managers_from_config(
63-
azimuth_config: AzimuthConfig,
64-
) -> Dict[DatasetSplitName, Optional[DatasetSplitManager]]:
65-
"""
66-
Load all dataset splits for the application.
67-
68-
Args:
69-
azimuth_config: Azimuth Configuration.
70-
71-
Returns:
72-
For all DatasetSplitName, None or a dataset_split manager.
73-
74-
"""
75-
dataset = load_dataset_from_config(azimuth_config)
76-
77-
def make_dataset_split_manager(name: DatasetSplitName):
78-
return DatasetSplitManager(
79-
name=name,
80-
config=azimuth_config,
81-
initial_tags=ALL_STANDARD_TAGS,
82-
initial_prediction_tags=ALL_PREDICTION_TAGS,
83-
dataset_split=dataset[name],
84-
)
85-
86-
return {
87-
dataset_split_name: None
88-
if dataset_split_name not in dataset
89-
else make_dataset_split_manager(DatasetSplitName[dataset_split_name])
90-
for dataset_split_name in [DatasetSplitName.eval, DatasetSplitName.train]
91-
}
92-
93-
9460
def predictions_available(config: ModelContractConfig) -> bool:
9561
return config.pipelines is not None
9662

0 commit comments

Comments
 (0)