11# Copyright ServiceNow, Inc. 2021 – 2022
22# This source code is licensed under the Apache 2.0 license found in the LICENSE file
33# in the root directory of this source tree.
4- from multiprocessing import Lock
5- from typing import Callable , Dict , Optional
4+ from collections import defaultdict
5+ from typing import Callable , Dict
66
7+ import structlog
78from datasets import DatasetDict
89
910from azimuth .config import AzimuthConfig
1819Hash = int
1920
2021
22+ log = structlog .get_logger ()
23+
24+
25+ class Singleton :
26+ """
27+ A non-thread-safe helper class to ease implementing singletons.
28+ This should be used as a decorator -- not a metaclass -- to the
29+ class that should be a singleton.
30+
31+ To get the singleton instance, use the `instance` method. Trying
32+ to use `__call__` will result in a `TypeError` being raised.
33+
34+ Args:
35+ decorated: Decorated class
36+ """
37+
38+ def __init__ (self , decorated ):
39+ self ._decorated = decorated
40+
41+ def instance (self ):
42+ """
43+ Returns the singleton instance. Upon its first call, it creates a
44+ new instance of the decorated class and calls its `__init__` method.
45+ On all subsequent calls, the already created instance is returned.
46+
47+ Returns:
48+ Instance of the decorated class
49+ """
50+ try :
51+ return self ._instance
52+ except AttributeError :
53+ self ._instance = self ._decorated ()
54+ return self ._instance
55+
56+ def __call__ (self ):
57+ raise TypeError ("Singletons must be accessed through `instance()`." )
58+
59+ def clear_instance (self ):
60+ """For test purposes only"""
61+ if hasattr (self , "_instance" ):
62+ delattr (self , "_instance" )
63+
64+
65+ @Singleton
2166class ArtifactManager :
2267 """This class is a singleton which holds different artifacts.
2368
2469 Artifacts include dataset_split_managers, datasets and models for each config, so they don't
2570 need to be reloaded many times for a same module.
2671 """
2772
28- instance : Optional ["ArtifactManager" ] = None
29-
3073 def __init__ (self ):
3174 # The keys of the dict are a hash of the config.
3275 self .dataset_dict_mapping : Dict [Hash , DatasetDict ] = {}
3376 self .dataset_split_managers_mapping : Dict [
3477 Hash , Dict [DatasetSplitName , DatasetSplitManager ]
35- ] = {}
36- self .models_mapping : Dict [Hash , Dict [int , Callable ]] = {}
37- self .tokenizer = None
78+ ] = defaultdict (dict )
79+ self .models_mapping : Dict [Hash , Dict [int , Callable ]] = defaultdict (dict )
3880 self .metrics = {}
39-
40- @classmethod
41- def get_instance (cls ):
42- with Lock ():
43- if cls .instance is None :
44- cls .instance = cls ()
45- return cls .instance
81+ log .debug (f"Creating new Artifact Manager { id (self )} ." )
4682
4783 def get_dataset_split_manager (
4884 self , config : AzimuthConfig , name : DatasetSplitName
@@ -68,8 +104,6 @@ def get_dataset_split_manager(
68104 f"Found { tuple (dataset_dict .keys ())} ."
69105 )
70106 project_hash : Hash = config .get_project_hash ()
71- if project_hash not in self .dataset_split_managers_mapping :
72- self .dataset_split_managers_mapping [project_hash ] = {}
73107 if name not in self .dataset_split_managers_mapping [project_hash ]:
74108 self .dataset_split_managers_mapping [project_hash ][name ] = DatasetSplitManager (
75109 name = name ,
@@ -78,6 +112,7 @@ def get_dataset_split_manager(
78112 initial_prediction_tags = ALL_PREDICTION_TAGS ,
79113 dataset_split = dataset_dict [name ],
80114 )
115+ log .debug (f"New { name } DM in Artifact Manager { id (self )} " )
81116 return self .dataset_split_managers_mapping [project_hash ][name ]
82117
83118 def get_dataset_dict (self , config ) -> DatasetDict :
@@ -106,17 +141,16 @@ def get_model(self, config: AzimuthConfig, pipeline_idx: int):
106141 Returns:
107142 Loaded model.
108143 """
109-
110- project_hash : Hash = config .get_project_hash ()
111- if project_hash not in self .models_mapping :
112- self .models_mapping [project_hash ] = {}
113- if pipeline_idx not in self .models_mapping [project_hash ]:
144+ # We only need to reload the pipeline if the model contract part of the config is changed.
145+ model_contract_hash : Hash = config .get_model_contract_hash ()
146+ if pipeline_idx not in self .models_mapping [model_contract_hash ]:
147+ log .debug (f"Loading pipeline { pipeline_idx } ." )
114148 pipelines = assert_not_none (config .pipelines )
115- self .models_mapping [project_hash ][pipeline_idx ] = load_custom_object (
149+ self .models_mapping [model_contract_hash ][pipeline_idx ] = load_custom_object (
116150 assert_not_none (pipelines [pipeline_idx ].model ), azimuth_config = config
117151 )
118152
119- return self .models_mapping [project_hash ][pipeline_idx ]
153+ return self .models_mapping [model_contract_hash ][pipeline_idx ]
120154
121155 def get_metric (self , config , name : str , ** kwargs ):
122156 hash : Hash = md5_hash ({"name" : name , ** kwargs })
@@ -125,6 +159,6 @@ def get_metric(self, config, name: str, **kwargs):
125159 return self .metrics [hash ]
126160
127161 @classmethod
128- def clear_cache (cls ) -> None :
129- with Lock ():
130- cls . instance = None
162+ def instance (cls ):
163+ # Implemented in decorator
164+ raise NotImplementedError
0 commit comments