Updates for MetaCAT #515

Open · wants to merge 5 commits into base: master
16 changes: 15 additions & 1 deletion medcat/config_meta_cat.py
@@ -1,4 +1,4 @@
from typing import Dict, Any
from typing import Dict, Any, List
from medcat.config import MixingConfig, BaseModel, Optional


@@ -27,8 +27,22 @@ class General(MixingConfig, BaseModel):
"""What category is this meta_cat model predicting/training.

NB! For these changes to take effect, the pipe would need to be recreated."""
alternative_category_names: List = []
"""List that stores the variations of possible category names
Example: For Experiencer, the alternate name is Subject
alternative_category_names: ['Experiencer','Subject']

In the case that one specified in self.general.category_name parameter does not match the data, this ensures no error is raised and it is automatically mapped
"""
category_value2id: Dict = {}
"""Map from category values to ID, if empty it will be autocalculated during training"""
alternative_class_names: List[List] = [[]]
"""List of lists that stores the variations of possible class names for each class mentioned in self.general.category_value2id

Example: For Presence task, the class names vary across NHS sites.
To accommodate for this, alternative_class_names is populated as: [["Hypothetical (N/A)","Hypothetical"],["Not present (False)","False"],["Present (True)","True"]]
Each sub list contains the possible variations of the given class.
"""
vocab_size: Optional[int] = None
"""Will be set automatically if the tokenizer is provided during meta_cat init"""
lowercase: bool = True
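For reference, a minimal sketch of how these new fields might be populated on a MetaCAT config, using the example values from the docstrings above. The ConfigMetaCAT class name and the dict-style access pattern mirror the existing module; the snippet itself is illustrative rather than part of the PR:

from medcat.config_meta_cat import ConfigMetaCAT

config = ConfigMetaCAT()

# Primary category name plus its known variations across datasets.
config.general['category_name'] = 'Experiencer'
config.general['alternative_category_names'] = ['Experiencer', 'Subject']

# Pre-defined class-to-id map, with the per-class name variations used at different NHS sites.
config.general['category_value2id'] = {'Hypothetical (N/A)': 0, 'Not present (False)': 1, 'Present (True)': 2}
config.general['alternative_class_names'] = [['Hypothetical (N/A)', 'Hypothetical'],
                                             ['Not present (False)', 'False'],
                                             ['Present (True)', 'True']]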
30 changes: 16 additions & 14 deletions medcat/meta_cat.py
@@ -244,10 +244,17 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data

# Check if the name is present
category_name = g_config['category_name']
category_name_options = g_config['alternative_category_names']
if category_name not in data:
raise Exception(
"The category name does not exist in this json file. You've provided '{}', while the possible options are: {}".format(
category_name, " | ".join(list(data.keys()))))
category_matching = [cat for cat in category_name_options if cat in data.keys()]
if len(category_matching) > 0:
logger.info("The category name provided in the config - '%s' is not present in the data. However, the corresponding name - '%s' from the category_name_mapping has been found. Updating the category name...",category_name,*category_matching)
g_config['category_name'] = category_matching[0]
category_name = g_config['category_name']
else:
raise Exception(
"The category name does not exist in this json file. You've provided '{}', while the possible options are: {}. Additionally, ensure the populate the 'alternative_category_names' attribute to accommodate for variations.".format(
category_name, " | ".join(list(data.keys()))))

data = data[category_name]
if data_oversampled:
@@ -258,27 +265,21 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data
if not category_value2id:
# Encode the category values
full_data, data_undersampled, category_value2id = encode_category_values(data,
category_undersample=self.config.model.category_undersample)
g_config['category_value2id'] = category_value2id
category_undersample=self.config.model.category_undersample, alternative_class_names=g_config['alternative_class_names'])
else:
# We already have everything, just get the data
full_data, data_undersampled, category_value2id = encode_category_values(data,
existing_category_value2id=category_value2id,
category_undersample=self.config.model.category_undersample)
g_config['category_value2id'] = category_value2id
# Make sure the config number of classes is the same as the one found in the data
if len(category_value2id) != self.config.model['nclasses']:
logger.warning(
"The number of classes set in the config is not the same as the one found in the data: %d vs %d",self.config.model['nclasses'], len(category_value2id))
logger.warning("Auto-setting the nclasses value in config and rebuilding the model.")
self.config.model['nclasses'] = len(category_value2id)
category_undersample=self.config.model.category_undersample, alternative_class_names=g_config['alternative_class_names'])
g_config['category_value2id'] = category_value2id
self.config.model['nclasses'] = len(category_value2id)

if self.config.model.phase_number == 2 and save_dir_path is not None:
model_save_path = os.path.join(save_dir_path, 'model.dat')
device = torch.device(g_config['device'])
try:
self.model.load_state_dict(torch.load(model_save_path, map_location=device))
logger.info("Model state loaded from dict for 2 phase learning")
logger.info("Training model for Phase 2, with model dict loaded from disk")

except FileNotFoundError:
raise FileNotFoundError(f"\nError: Model file not found at path: {model_save_path}\nPlease run phase 1 training and then run phase 2.")
@@ -295,6 +296,7 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data
if not t_config['auto_save_model']:
logger.info("For phase 1, model state has to be saved. Saving model...")
t_config['auto_save_model'] = True
logger.info("Training model for Phase 1 now...")

report = train_model(self.model, data=data, config=self.config, save_dir_path=save_dir_path)

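The category-name fallback added to train_raw above boils down to the following standalone sketch; the helper name resolve_category_name is hypothetical and only restates the logic in the diff:

import logging
from typing import Dict, List

logger = logging.getLogger(__name__)

def resolve_category_name(data: Dict, category_name: str, alternative_category_names: List[str]) -> str:
    # Use the configured name if it is present in the data.
    if category_name in data:
        return category_name
    # Otherwise fall back to the first configured alternative that is present.
    matching = [cat for cat in alternative_category_names if cat in data]
    if matching:
        logger.info("'%s' is not present in the data; using the alternative name '%s' instead", category_name, matching[0])
        return matching[0]
    raise Exception("The category name does not exist in this json file. You've provided '{}', while the possible options are: {}".format(category_name, " | ".join(list(data.keys()))))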
51 changes: 45 additions & 6 deletions medcat/utils/meta_cat/data_utils.py
@@ -1,5 +1,6 @@
from typing import Dict, Optional, Tuple, Iterable, List
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase
import copy
import logging

logger = logging.getLogger(__name__)
@@ -153,7 +154,7 @@ def prepare_for_oversampled_data(data: List,


def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict] = None,
category_undersample=None) -> Tuple:
category_undersample=None, alternative_class_names: List[List] = []) -> Tuple:
"""Converts the category values in the data outputted by `prepare_from_json`
into integer values.

@@ -164,6 +165,8 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
Map from category_value to id (old/existing).
category_undersample:
Name of class that should be used to undersample the data (for 2 phase learning)
alternative_class_names:
List of lists that stores the variations of possible class names for the given category (task)

Returns:
dict:
@@ -172,6 +175,9 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
New undersampled data (for 2 phase learning) with integers inplace of strings for category values
dict:
Map from category value to ID for all categories in the data.

Raises:
Exception: If category_value2id is pre-defined and its labels do not match the labels found in the data
"""
data = list(data)
if existing_category_value2id is not None:
@@ -180,9 +186,42 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
category_value2id = {}

category_values = set([x[2] for x in data])
for c in category_values:
if c not in category_value2id:
category_value2id[c] = len(category_value2id)

# If category_value2id is pre-defined, make sure its labels match those found in the data
if len(category_value2id) != 0:
if set(category_value2id.keys()) != category_values:
# if category_value2id doesn't match the labels in the data, then 'alternative_class_names' has to be defined to check for variations
if len(alternative_class_names) != 0:
updated_category_value2id = {}
for _class in category_value2id.keys():
if _class in category_values:
updated_category_value2id[_class] = category_value2id[_class]
else:
found_in = [sub_map for sub_map in alternative_class_names if _class in sub_map]
if len(found_in) != 0:
class_name_matched = [label for label in found_in[0] if label in category_values]
if len(class_name_matched) != 0:
updated_category_value2id[class_name_matched[0]] = category_value2id[_class]
logger.info("Class name '%s' does not exist in the data; however a variation of it '%s' is present; updating it...", _class, class_name_matched[0])
else:
raise Exception(
f"The classes set in the config are not the same as the one found in the data. The classes present in the config vs the ones found in the data - {set(category_value2id.keys())}, {category_values}. Additionally, ensure the populate the 'alternative_class_names' attribute to accommodate for variations.")
else:
raise Exception(f"The classes set in the config are not the same as the one found in the data. The classes present in the config vs the ones found in the data - {set(category_value2id.keys())}, {category_values}. Additionally, ensure the populate the 'alternative_class_names' attribute to accommodate for variations.")
category_value2id = copy.deepcopy(updated_category_value2id)
logger.info("Updated categoryvalue2id mapping - %s", category_value2id)

# Else throw an exception since the labels don't match
else:
raise Exception(
f"The classes set in the config are not the same as the one found in the data. The classes present in the config vs the ones found in the data - {set(category_value2id.keys())}, {category_values}. Additionally, ensure the populate the 'alternative_class_names' attribute to accommodate for variations.")

# Else create the mapping from the labels found in the data
else:
for c in category_values:
if c not in category_value2id:
category_value2id[c] = len(category_value2id)
logger.info("Categoryvalue2id mapping created with labels found in the data - %s", category_value2id)

# Map values to numbers
for i in range(len(data)):
@@ -194,7 +233,7 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
if data[i][2] in category_value2id.values():
label_data_[data[i][2]] = label_data_[data[i][2]] + 1

logger.info("Original label_data: %s",label_data_)
logger.info("Original number of samples per label: %s",label_data_)
# Undersampling data
if category_undersample is None or category_undersample == '':
min_label = min(label_data_.values())
@@ -217,7 +256,7 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
for i in range(len(data_undersampled)):
if data_undersampled[i][2] in category_value2id.values():
label_data[data_undersampled[i][2]] = label_data[data_undersampled[i][2]] + 1
logger.info("Updated label_data: %s",label_data)
logger.info("Updated number of samples per label (for 2-phase learning): %s",label_data)

return data, data_undersampled, category_value2id

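In isolation, the remapping that encode_category_values now applies when a pre-defined category_value2id does not match the labels in the data looks roughly like this; the helper name remap_class_names and the example values are assumptions for illustration:

from typing import Dict, List, Set

def remap_class_names(category_value2id: Dict[str, int], category_values: Set[str],
                      alternative_class_names: List[List[str]]) -> Dict[str, int]:
    updated = {}
    for _class, _id in category_value2id.items():
        if _class in category_values:
            # Configured class name already matches the data.
            updated[_class] = _id
            continue
        # Find the sub-list of variations that contains the configured name.
        found_in = [sub for sub in alternative_class_names if _class in sub]
        matched = [label for label in found_in[0] if label in category_values] if found_in else []
        if not matched:
            raise Exception("Class '{}' is not in the data and no configured variation matches it".format(_class))
        updated[matched[0]] = _id
    return updated

# Config uses the long NHS-style labels, the data uses the short forms:
print(remap_class_names({'Not present (False)': 0, 'Present (True)': 1},
                        {'False', 'True'},
                        [['Not present (False)', 'False'], ['Present (True)', 'True']]))
# {'False': 0, 'True': 1}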
4 changes: 2 additions & 2 deletions medcat/utils/meta_cat/ml_utils.py
@@ -329,12 +329,12 @@ def initialize_model(classifier, data_, batch_size_, lr_, epochs=4):
print_report(epoch, running_loss_test, all_logits_test, y=y_test, name='Test')

_report = classification_report(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1),
output_dict=True)
output_dict=True, zero_division=0)
if not winner_report or _report[config.train['metric']['base']][config.train['metric']['score']] > \
winner_report['report'][config.train['metric']['base']][config.train['metric']['score']]:

report = classification_report(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1),
output_dict=True)
output_dict=True, zero_division=0)
cm = confusion_matrix(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1), normalize='true')
report_train = classification_report(y_train, np.argmax(np.concatenate(all_logits, axis=0), axis=1),
output_dict=True)
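The zero_division=0 argument added to classification_report silences scikit-learn's undefined-metric warnings when a class never appears in the predictions and reports the affected precision/recall as 0 instead; a small standalone illustration with made-up labels:

from sklearn.metrics import classification_report

y_true = [0, 0, 1, 2]
y_pred = [0, 0, 1, 1]  # class 2 is never predicted

# Precision for class 2 is 0/0; zero_division=0 reports it as 0.0 instead of warning.
report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
print(report['2']['precision'])  # 0.0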