Updates for MetaCAT #515
Changes from 2 commits
f72fc28
874954e
ea8e6fe
c703014
55d2165
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from typing import Dict, Any | ||
from typing import Dict, Any, List | ||
from medcat.config import MixingConfig, BaseModel, Optional | ||
|
||
|
||
|
@@ -27,8 +27,22 @@ class General(MixingConfig, BaseModel): | |
"""What category is this meta_cat model predicting/training. | ||
|
||
NB! For these changes to take effect, the pipe would need to be recreated.""" | ||
category_names_map: List = [] | ||
"""Map that stores the variations of possible category names | ||
Example: For Experiencer, the alternate name is Subject | ||
category_names_map: ['Experiencer','Subject'] | ||
|
||
In the case that one specified in 'category_name' parameter does not match the data, this ensures no error is raised and it is automatically mapped | ||
""" | ||
Review comment: Can you also say that the model output will be the value configured in
Reply: I am not sure I completely understood this
||
|
||
category_value2id: Dict = {} | ||
"""Map from category values to ID, if empty it will be autocalculated during training""" | ||
class_names_map: List[List] = [[]] | ||
"""Map that stores the variations of possible class names for the given category (task) | ||
Review comment: Again, not a map, but a list. At least in its current state.
|
||
Example: For Presence task, the class names vary across NHS sites. | ||
Review comment: Please change to
||
To accommodate for this, class_names_map is populated as: [["Hypothetical (N/A)","Hypothetical"],["Not present (False)","False"],["Present (True)","True"]] | ||
Each sub list contains the possible variations of the given class. | ||
""" | ||
vocab_size: Optional[int] = None | ||
"""Will be set automatically if the tokenizer is provided during meta_cat init""" | ||
lowercase: bool = True | ||
|
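The review comments above point out that `category_names_map` and `class_names_map` are lists rather than mappings. The following is a minimal sketch, in plain Python and independent of MedCAT, of the two shapes shown in the docstrings and of one way a real lookup could be derived from the list of lists. The helper `build_variant_lookup` is hypothetical and not part of this PR.

```python
from typing import Dict, List

# Shapes as given in the docstrings of the diff: a flat list of alternative
# category names, and a list of groups of equivalent class names.
category_names_map: List[str] = ["Experiencer", "Subject"]
class_names_map: List[List[str]] = [
    ["Hypothetical (N/A)", "Hypothetical"],
    ["Not present (False)", "False"],
    ["Present (True)", "True"],
]


def build_variant_lookup(groups: List[List[str]]) -> Dict[str, List[str]]:
    """Hypothetical helper: map every class name to the group it belongs to."""
    lookup: Dict[str, List[str]] = {}
    for group in groups:
        for name in group:
            # A name in two groups would make the lookup ambiguous.
            if name in lookup:
                raise ValueError(f"'{name}' appears in more than one group")
            lookup[name] = group
    return lookup


lookup = build_variant_lookup(class_names_map)
print(lookup["False"])  # ['Not present (False)', 'False']
```

Building the dictionary once also surfaces ambiguous configurations (a name listed in two groups) up front, instead of silently taking the first matching group.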
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -244,10 +244,17 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data | |
|
||
# Check is the name present | ||
category_name = g_config['category_name'] | ||
category_name_options = g_config['category_names_map'] | ||
if category_name not in data: | ||
raise Exception( | ||
"The category name does not exist in this json file. You've provided '{}', while the possible options are: {}".format( | ||
category_name, " | ".join(list(data.keys())))) | ||
category_matching = [cat for cat in category_name_options if cat in data.keys()] | ||
if len(category_matching) > 0: | ||
mart-r marked this conversation as resolved.
|
||
logger.warning("The category name provided in the config - '%s' is not present in the data. However, the corresponding name - '%s' from the category_name_mapping has been found. Updating the category name...",category_name,*category_matching) | ||
Review comment: I'd say this would be an
||
g_config['category_name'] = category_matching[0] | ||
category_name = g_config['category_name'] | ||
else: | ||
raise Exception( | ||
"The category name does not exist in this json file. You've provided '{}', while the possible options are: {}".format( | ||
Review comment: Perhaps mention the possibility of setting
||
category_name, " | ".join(list(data.keys())))) | ||
|
||
data = data[category_name] | ||
if data_oversampled: | ||
|
@@ -258,27 +265,29 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data | |
if not category_value2id: | ||
# Encode the category values | ||
full_data, data_undersampled, category_value2id = encode_category_values(data, | ||
category_undersample=self.config.model.category_undersample) | ||
g_config['category_value2id'] = category_value2id | ||
category_undersample=self.config.model.category_undersample,class_name_map=g_config['class_names_map']) | ||
else: | ||
# We already have everything, just get the data | ||
full_data, data_undersampled, category_value2id = encode_category_values(data, | ||
existing_category_value2id=category_value2id, | ||
category_undersample=self.config.model.category_undersample) | ||
g_config['category_value2id'] = category_value2id | ||
# Make sure the config number of classes is the same as the one found in the data | ||
if len(category_value2id) != self.config.model['nclasses']: | ||
logger.warning( | ||
"The number of classes set in the config is not the same as the one found in the data: %d vs %d",self.config.model['nclasses'], len(category_value2id)) | ||
logger.warning("Auto-setting the nclasses value in config and rebuilding the model.") | ||
self.config.model['nclasses'] = len(category_value2id) | ||
category_undersample=self.config.model.category_undersample,class_name_map=g_config['class_names_map']) | ||
g_config['category_value2id'] = category_value2id | ||
self.config.model['nclasses'] = len(category_value2id) | ||
|
||
# This is now handled in data_utils where an exception is raised when mismatch is found | ||
# Make sure that the categoryvalue2id if present is same as the labels found | ||
Review comment: Are these commented-out lines needed?
||
# if len(category_value2id) != self.config.model['nclasses']: | ||
# logger.warning( | ||
# "The number of classes set in the config is not the same as the one found in the data: %d vs %d",self.config.model['nclasses'], len(category_value2id)) | ||
# logger.warning("Auto-setting the nclasses value in config and rebuilding the model.") | ||
|
||
if self.config.model.phase_number == 2 and save_dir_path is not None: | ||
model_save_path = os.path.join(save_dir_path, 'model.dat') | ||
device = torch.device(g_config['device']) | ||
try: | ||
self.model.load_state_dict(torch.load(model_save_path, map_location=device)) | ||
logger.info("Model state loaded from dict for 2 phase learning") | ||
logger.info("Training model for Phase 2 now...") | ||
Review comment: This log message right after the previous one seems unnecessary. If the logged message is unclear, maybe just amend the previous
||
|
||
except FileNotFoundError: | ||
raise FileNotFoundError(f"\nError: Model file not found at path: {model_save_path}\nPlease run phase 1 training and then run phase 2.") | ||
|
@@ -295,6 +304,7 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data | |
if not t_config['auto_save_model']: | ||
logger.info("For phase 1, model state has to be saved. Saving model...") | ||
t_config['auto_save_model'] = True | ||
logger.info("Training model for Phase 1 now...") | ||
|
||
report = train_model(self.model, data=data, config=self.config, save_dir_path=save_dir_path) | ||
|
||
|
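The category-name fallback in `train_raw` above can be sketched as a standalone function. This is an illustration under stated assumptions, not the PR's code: `resolve_category_name` is a hypothetical name, the `INFO` log level and the wording of the error message are choices made for this sketch, and the function returns the resolved name instead of mutating the config.

```python
import logging
from typing import Iterable, List

logger = logging.getLogger(__name__)


def resolve_category_name(category_name: str,
                          alternatives: List[str],
                          data_keys: Iterable[str]) -> str:
    """Hypothetical helper: pick the category name that exists in the data."""
    available = list(data_keys)
    if category_name in available:
        return category_name
    # Keep the order of `alternatives`, so the first listed alternative wins.
    matches = [name for name in alternatives if name in available]
    if matches:
        chosen = matches[0]
        logger.info("Category name '%s' is not in the data; using alternative '%s'.",
                    category_name, chosen)
        return chosen
    raise ValueError(
        "The category name does not exist in this json file. You've provided "
        f"'{category_name}' (alternatives: {alternatives}), while the possible "
        f"options are: {' | '.join(available)}")


data = {"Subject": [], "Presence": []}
print(resolve_category_name("Experiencer", ["Experiencer", "Subject"], data.keys()))
# Subject
```

Logging only the chosen name, rather than unpacking every match into the format arguments, keeps the message valid when more than one alternative is present in the data.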
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from typing import Dict, Optional, Tuple, Iterable, List | ||
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase | ||
import copy | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
@@ -153,7 +154,7 @@ def prepare_for_oversampled_data(data: List, | |
|
||
|
||
def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict] = None, | ||
category_undersample=None) -> Tuple: | ||
category_undersample=None, class_name_map: List[List] = []) -> Tuple: | ||
"""Converts the category values in the data outputted by `prepare_from_json` | ||
into integer values. | ||
|
||
|
@@ -164,6 +165,8 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict | |
Map from category_value to id (old/existing). | ||
category_undersample: | ||
Name of class that should be used to undersample the data (for 2 phase learning) | ||
class_name_map: | ||
Map that stores the variations of possible class names for the given category (task) | ||
|
||
Returns: | ||
dict: | ||
|
@@ -172,6 +175,9 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict | |
New undersampled data (for 2 phase learning) with integers inplace of strings for category values | ||
dict: | ||
Map from category value to ID for all categories in the data. | ||
|
||
Raises: | ||
Exception: If categoryvalue2id is pre-defined and its labels do not match the labels found in the data | ||
""" | ||
data = list(data) | ||
if existing_category_value2id is not None: | ||
|
@@ -180,9 +186,38 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict | |
category_value2id = {} | ||
|
||
category_values = set([x[2] for x in data]) | ||
for c in category_values: | ||
if c not in category_value2id: | ||
category_value2id[c] = len(category_value2id) | ||
|
||
# If categoryvalue2id is pre-defined, then making sure it is same as the labels found in the data | ||
if len(category_value2id) != 0: | ||
if set(category_value2id.keys()) != category_values: | ||
# if categoryvalue2id doesn't match the labels in the data, then 'class_name_map' has to be defined to check for variations | ||
if len(class_name_map) != 0: | ||
updated_category_value2id = {} | ||
for _class in category_value2id.keys(): | ||
if _class in category_values: | ||
updated_category_value2id[_class] = category_value2id[_class] | ||
else: | ||
found_in = [sub_map for sub_map in class_name_map if _class in sub_map][0] | ||
Review comment: Here as well, we take the first matching "map" that has a matching class in it.
||
if len(found_in) != 0: | ||
mart-r marked this conversation as resolved.
|
||
class_name_matched = [label for label in found_in if label in category_values][0] | ||
Review comment: If there are no existing category names defined in the "map", this will fail with an
||
updated_category_value2id[class_name_matched] = category_value2id[_class] | ||
logger.warning("Class name '%s' does not exist in the data; however a variation of it '%s' is present; updating it...",_class,class_name_matched) | ||
Review comment: Again, I feel like at this point, it'd be expected behaviour and better suited as an
||
else: | ||
raise Exception(f"The classes set in the config are not the same as the one found in the data. The classes present in the config vs the ones found in the data - {set(category_value2id.keys())}, {category_values}") | ||
category_value2id = copy.deepcopy(updated_category_value2id) | ||
logger.info("Updated categoryvalue2id mapping - %s", category_value2id) | ||
|
||
# Else throw an exception since the labels don't match | ||
else: | ||
raise Exception( | ||
f"The classes set in the config are not the same as the one found in the data. The classes present in the config vs the ones found in the data - {set(category_value2id.keys())}, {category_values}") | ||
Review comment: Perhaps also mention the possibility of setting
||
|
||
# Else create the mapping from the labels found in the data | ||
else: | ||
for c in category_values: | ||
if c not in category_value2id: | ||
category_value2id[c] = len(category_value2id) | ||
logger.info("Categoryvalue2id mapping created with labels found in the data - %s", category_value2id) | ||
|
||
# Map values to numbers | ||
for i in range(len(data)): | ||
|
@@ -194,7 +229,7 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict | |
if data[i][2] in category_value2id.values(): | ||
label_data_[data[i][2]] = label_data_[data[i][2]] + 1 | ||
|
||
logger.info("Original label_data: %s",label_data_) | ||
logger.info("Original number of samples per label: %s",label_data_) | ||
# Undersampling data | ||
if category_undersample is None or category_undersample == '': | ||
min_label = min(label_data_.values()) | ||
|
@@ -217,7 +252,7 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict | |
for i in range(len(data_undersampled)): | ||
if data_undersampled[i][2] in category_value2id.values(): | ||
label_data[data_undersampled[i][2]] = label_data[data_undersampled[i][2]] + 1 | ||
logger.info("Updated label_data: %s",label_data) | ||
logger.info("Updated number of samples per label (for 2-phase learning): %s",label_data) | ||
|
||
return data, data_undersampled, category_value2id | ||
|
||
|
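The class-name remapping in `encode_category_values` indexes `[0]` on two list comprehensions, which the review comments flag as a possible failure when nothing matches. A minimal sketch of the same remapping that uses `next(..., None)` and raises one explicit error instead is shown below. The function name `remap_category_value2id`, the `ValueError` type, and the final coverage check are assumptions of this sketch, not part of the PR.

```python
from typing import Dict, List, Optional, Set


def remap_category_value2id(category_value2id: Dict[str, int],
                            category_values: Set[str],
                            class_name_map: List[List[str]]) -> Dict[str, int]:
    """Hypothetical helper: rename config classes to the variants found in the data."""
    updated: Dict[str, int] = {}
    for cls, idx in category_value2id.items():
        if cls in category_values:
            updated[cls] = idx
            continue
        # First group that lists this class, or None if there is no such group.
        group: Optional[List[str]] = next(
            (g for g in class_name_map if cls in g), None)
        # First variant in that group that is present in the data, if any.
        matched: Optional[str] = None
        if group is not None:
            matched = next((label for label in group if label in category_values), None)
        if matched is None:
            raise ValueError(
                f"Class '{cls}' from the config is not in the data and no variation "
                f"of it was found. Config classes: {set(category_value2id)}, "
                f"data classes: {category_values}")
        updated[matched] = idx
    # Extra check (assumption): every label in the data must end up with an ID.
    missing = category_values - set(updated)
    if missing:
        raise ValueError(f"Labels in the data without an ID: {missing}")
    return updated


config_ids = {"True": 0, "False": 1, "Hypothetical": 2}
labels = {"Present (True)", "Not present (False)", "Hypothetical (N/A)"}
groups = [["Hypothetical (N/A)", "Hypothetical"],
          ["Not present (False)", "False"],
          ["Present (True)", "True"]]
print(remap_category_value2id(config_ids, labels, groups))
# {'Present (True)': 0, 'Not present (False)': 1, 'Hypothetical (N/A)': 2}
```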
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -329,12 +329,12 @@ def initialize_model(classifier, data_, batch_size_, lr_, epochs=4): | |
print_report(epoch, running_loss_test, all_logits_test, y=y_test, name='Test') | ||
|
||
_report = classification_report(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1), | ||
output_dict=True) | ||
output_dict=True,zero_division=1) | ||
Review comment: Does it make sense to have it return 1? I.e. if there are no retrieved elements, then precision is undetermined. But if we set it to 1 and don't show the warning, the user might think everything is working great, when in reality nothing was extracted. In an ideal world, I'd love to still show a warning to the user, but maybe limit it to once per 10 minutes or something. But I don't think they have an API for something like that, and I don't think it'd make sense to spend the time implementing it, either.
||
if not winner_report or _report[config.train['metric']['base']][config.train['metric']['score']] > \ | ||
winner_report['report'][config.train['metric']['base']][config.train['metric']['score']]: | ||
|
||
report = classification_report(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1), | ||
output_dict=True) | ||
output_dict=True,zero_division=1) | ||
Review comment: Same question here. Does it make sense to show 1 in case of no retrieved elements? Or would 0 be more appropriate?
||
cm = confusion_matrix(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1), normalize='true') | ||
report_train = classification_report(y_train, np.argmax(np.concatenate(all_logits, axis=0), axis=1), | ||
output_dict=True) | ||
|
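The reviewer's concern about `zero_division=1` can be illustrated without scikit-learn. The sketch below is a simplified, pure-Python stand-in for per-class precision, not scikit-learn's implementation; `precision_for_class` is a hypothetical helper whose `zero_division` argument only mimics what the parameter means: the value reported when a class is never predicted.

```python
from typing import Sequence


def precision_for_class(y_true: Sequence[int], y_pred: Sequence[int],
                        label: int, zero_division: float) -> float:
    """Hypothetical helper: precision for one class, with a fallback value."""
    predicted = sum(1 for p in y_pred if p == label)
    if predicted == 0:
        # Precision is undefined here; report the configured fallback instead.
        return zero_division
    true_positives = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
    return true_positives / predicted


y_true = [0, 0, 1, 1]
y_pred = [0, 0, 0, 0]  # the model never predicts class 1

print(precision_for_class(y_true, y_pred, label=1, zero_division=1.0))  # 1.0
print(precision_for_class(y_true, y_pred, label=1, zero_division=0.0))  # 0.0
print(precision_for_class(y_true, y_pred, label=0, zero_division=1.0))  # 0.5
```

With a fallback of 1, a class the model never predicts reports perfect precision, which is the misleading reading the reviewer describes. A fallback of 0 makes the same situation visible in the report.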
Review comment: This is not a mapping. It's a list. As such, the name doesn't make a lot of sense.
Perhaps this could be named alternative_category_names or something along those lines?