Updates for MetaCAT #515

Open
wants to merge 5 commits into base: master

Changes from 2 commits
16 changes: 15 additions & 1 deletion medcat/config_meta_cat.py
@@ -1,4 +1,4 @@
from typing import Dict, Any
from typing import Dict, Any, List
from medcat.config import MixingConfig, BaseModel, Optional


@@ -27,8 +27,22 @@ class General(MixingConfig, BaseModel):
"""What category is this meta_cat model predicting/training.

NB! For these changes to take effect, the pipe would need to be recreated."""
category_names_map: List = []
"""Map that stores the variations of possible category names
Collaborator:
This is not a mapping. It's a list. As such, the name doesn't make a lot of sense.
Perhaps this could be named alternative_category_names or something along those lines?

Example: For Experiencer, the alternate name is Subject
category_names_map: ['Experiencer','Subject']

In the case that the one specified in the 'category_name' parameter does not match the data, this ensures no error is raised and it is automatically mapped
"""
Member:
Can you also say that the model output will be the value configured in <whatever the config property> is?

Collaborator (author):
I am not sure I completely understood this
Can you please explain this?


category_value2id: Dict = {}
"""Map from category values to ID, if empty it will be autocalculated during training"""
class_names_map: List[List] = [[]]
"""Map that stores the variations of possible class names for the given category (task)
Collaborator:
Again, not a map, but a list. At least in its current state.
Perhaps making it a Dict[str, str] could make sense? I.e. a mapping from the presented class name to the expected one.
The example would then be:

{"Hypothetical (N/A)": "Hypothetical", "Not present (False)": "False", "Present (True)": "True"}

Example: For Presence task, the class names vary across NHS sites.
Member:
Please change to List of lists that stores..., and again specify that the final output of the model will be the class name label specified in the config <property>

To accommodate for this, class_names_map is populated as: [["Hypothetical (N/A)","Hypothetical"],["Not present (False)","False"],["Present (True)","True"]]
Each sub list contains the possible variations of the given class.
"""
vocab_size: Optional[int] = None
"""Will be set automatically if the tokenizer is provided during meta_cat init"""
lowercase: bool = True
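
To make the intended usage concrete, here is a minimal sketch (not part of the diff) of how the two new General fields might be populated, reusing the examples from the docstrings above; the ConfigMetaCAT class and dict-style access to its general section are assumed to behave as in the existing config module.

from medcat.config_meta_cat import ConfigMetaCAT

config = ConfigMetaCAT()

# Primary task name plus its accepted variations (see the category_names_map docstring).
config.general['category_name'] = 'Experiencer'
config.general['category_names_map'] = ['Experiencer', 'Subject']

# Each sub-list groups the variations of one class label for this task
# (see the class_names_map docstring); the reviewer's alternative would be a
# Dict[str, str] from presented class name to expected class name instead.
config.general['class_names_map'] = [
    ["Hypothetical (N/A)", "Hypothetical"],
    ["Not present (False)", "False"],
    ["Present (True)", "True"],
]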
36 changes: 23 additions & 13 deletions medcat/meta_cat.py
@@ -244,10 +244,17 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data

# Check if the name is present
category_name = g_config['category_name']
category_name_options = g_config['category_names_map']
if category_name not in data:
raise Exception(
"The category name does not exist in this json file. You've provided '{}', while the possible options are: {}".format(
category_name, " | ".join(list(data.keys()))))
category_matching = [cat for cat in category_name_options if cat in data.keys()]
if len(category_matching) > 0:
logger.warning("The category name provided in the config - '%s' is not present in the data. However, the corresponding name - '%s' from the category_name_mapping has been found. Updating the category name...",category_name,*category_matching)
Collaborator:
I'd say this would be an info message since after this change it'll be expected behaviour.

g_config['category_name'] = category_matching[0]
category_name = g_config['category_name']
else:
raise Exception(
"The category name does not exist in this json file. You've provided '{}', while the possible options are: {}".format(
Collaborator:
Perhaps mention the possibility of setting config.general.category_names_map for alternatives in here?

category_name, " | ".join(list(data.keys()))))

data = data[category_name]
if data_oversampled:
Expand All @@ -258,27 +265,29 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data
if not category_value2id:
# Encode the category values
full_data, data_undersampled, category_value2id = encode_category_values(data,
category_undersample=self.config.model.category_undersample)
g_config['category_value2id'] = category_value2id
category_undersample=self.config.model.category_undersample,class_name_map=g_config['class_names_map'])
else:
# We already have everything, just get the data
full_data, data_undersampled, category_value2id = encode_category_values(data,
existing_category_value2id=category_value2id,
category_undersample=self.config.model.category_undersample)
g_config['category_value2id'] = category_value2id
# Make sure the config number of classes is the same as the one found in the data
if len(category_value2id) != self.config.model['nclasses']:
logger.warning(
"The number of classes set in the config is not the same as the one found in the data: %d vs %d",self.config.model['nclasses'], len(category_value2id))
logger.warning("Auto-setting the nclasses value in config and rebuilding the model.")
self.config.model['nclasses'] = len(category_value2id)
category_undersample=self.config.model.category_undersample,class_name_map=g_config['class_names_map'])
g_config['category_value2id'] = category_value2id
self.config.model['nclasses'] = len(category_value2id)

# This is now handled in data_utils where an exception is raised when mismatch is found
# Make sure that the categoryvalue2id if present is same as the labels found
Member:
Are these commented lines needed?

# if len(category_value2id) != self.config.model['nclasses']:
# logger.warning(
# "The number of classes set in the config is not the same as the one found in the data: %d vs %d",self.config.model['nclasses'], len(category_value2id))
# logger.warning("Auto-setting the nclasses value in config and rebuilding the model.")

if self.config.model.phase_number == 2 and save_dir_path is not None:
model_save_path = os.path.join(save_dir_path, 'model.dat')
device = torch.device(g_config['device'])
try:
self.model.load_state_dict(torch.load(model_save_path, map_location=device))
logger.info("Model state loaded from dict for 2 phase learning")
logger.info("Training model for Phase 2 now...")
Collaborator:
This logged message right after the one before seems unnecessary. If the logged message is unclear, maybe just amend the previous info message? Or remove the one before if that seems more appropriate.


except FileNotFoundError:
raise FileNotFoundError(f"\nError: Model file not found at path: {model_save_path}\nPlease run phase 1 training and then run phase 2.")
Expand All @@ -295,6 +304,7 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data
if not t_config['auto_save_model']:
logger.info("For phase 1, model state has to be saved. Saving model...")
t_config['auto_save_model'] = True
logger.info("Training model for Phase 1 now...")

report = train_model(self.model, data=data, config=self.config, save_dir_path=save_dir_path)
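
For readers skimming the diff, the category-name fallback added to train_raw boils down to the following standalone sketch (illustrative only; data is the loaded MedCATtrainer export and g_config the general MetaCAT config, as used above).

def resolve_category_name(data: dict, g_config: dict) -> str:
    category_name = g_config['category_name']
    if category_name in data:
        return category_name

    # Fall back to the alternative names listed in category_names_map.
    matching = [cat for cat in g_config['category_names_map'] if cat in data]
    if matching:
        # Reviewers suggest logging this at INFO level, since after this PR it is expected behaviour.
        g_config['category_name'] = matching[0]
        return matching[0]

    raise Exception(
        "The category name does not exist in this json file. You've provided '{}', "
        "while the possible options are: {}".format(category_name, " | ".join(data.keys())))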

47 changes: 41 additions & 6 deletions medcat/utils/meta_cat/data_utils.py
@@ -1,5 +1,6 @@
from typing import Dict, Optional, Tuple, Iterable, List
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase
import copy
import logging

logger = logging.getLogger(__name__)
@@ -153,7 +154,7 @@ def prepare_for_oversampled_data(data: List,


def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict] = None,
category_undersample=None) -> Tuple:
category_undersample=None, class_name_map: List[List] = []) -> Tuple:
"""Converts the category values in the data outputted by `prepare_from_json`
into integer values.

@@ -164,6 +165,8 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
Map from category_value to id (old/existing).
category_undersample:
Name of class that should be used to undersample the data (for 2 phase learning)
class_name_map:
Map that stores the variations of possible class names for the given category (task)

Returns:
dict:
@@ -172,6 +175,9 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
New undersampled data (for 2 phase learning) with integers inplace of strings for category values
dict:
Map from category value to ID for all categories in the data.

Raises:
Exception: If categoryvalue2id is pre-defined and its labels do not match the labels found in the data
"""
data = list(data)
if existing_category_value2id is not None:
@@ -180,9 +186,38 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
category_value2id = {}

category_values = set([x[2] for x in data])
for c in category_values:
if c not in category_value2id:
category_value2id[c] = len(category_value2id)

# If categoryvalue2id is pre-defined, then making sure it is same as the labels found in the data
if len(category_value2id) != 0:
if set(category_value2id.keys()) != category_values:
# if categoryvalue2id doesn't match the labels in the data, then 'class_name_map' has to be defined to check for variations
if len(class_name_map) != 0:
updated_category_value2id = {}
for _class in category_value2id.keys():
if _class in category_values:
updated_category_value2id[_class] = category_value2id[_class]
else:
found_in = [sub_map for sub_map in class_name_map if _class in sub_map][0]
Collaborator:
Here as well, we take the first matching "map" that has a matching class in it.
If there isn't a matching "sub map", then this will fail with an IndexError.
What if there is more than one? To me that sounds like a configuration mismatch and should lead to an exception.

if len(found_in) != 0:
class_name_matched = [label for label in found_in if label in category_values][0]
Collaborator:
If there are no existing category names defined in the "map", this will fail with an IndexError.
If there's more than one existing category name in this "map", I think an exception should be raised as it'd be a misconfiguration.

updated_category_value2id[class_name_matched] = category_value2id[_class]
logger.warning("Class name '%s' does not exist in the data; however a variation of it '%s' is present; updating it...",_class,class_name_matched)
Collaborator:
Again, I feel like at this point, it'd be expected behaviour and better suited as an info message.

else:
raise Exception(f"The classes set in the config are not the same as the one found in the data. The classes present in the config vs the ones found in the data - {set(category_value2id.keys())}, {category_values}")
category_value2id = copy.deepcopy(updated_category_value2id)
logger.info("Updated categoryvalue2id mapping - %s", category_value2id)

# Else throw an exception since the labels don't match
else:
raise Exception(
f"The classes set in the config are not the same as the one found in the data. The classes present in the config vs the ones found in the data - {set(category_value2id.keys())}, {category_values}")
Collaborator:
Perhaps also mention the possibility of setting config.general.category_names_map for alternatives?


# Else create the mapping from the labels found in the data
else:
for c in category_values:
if c not in category_value2id:
category_value2id[c] = len(category_value2id)
logger.info("Categoryvalue2id mapping created with labels found in the data - %s", category_value2id)

# Map values to numbers
for i in range(len(data)):
@@ -194,7 +229,7 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
if data[i][2] in category_value2id.values():
label_data_[data[i][2]] = label_data_[data[i][2]] + 1

logger.info("Original label_data: %s",label_data_)
logger.info("Original number of samples per label: %s",label_data_)
# Undersampling data
if category_undersample is None or category_undersample == '':
min_label = min(label_data_.values())
@@ -217,7 +252,7 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
for i in range(len(data_undersampled)):
if data_undersampled[i][2] in category_value2id.values():
label_data[data_undersampled[i][2]] = label_data[data_undersampled[i][2]] + 1
logger.info("Updated label_data: %s",label_data)
logger.info("Updated number of samples per label (for 2-phase learning): %s",label_data)

return data, data_undersampled, category_value2id
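
One possible way to address the IndexError and multiple-match concerns raised in the comments above is sketched below; this is an illustration of the suggested defensive checks, not code from the PR (the helper name match_class_name is made up here).

def match_class_name(_class: str, category_values: set, class_name_map: list) -> str:
    """Find the variation of `_class` that is actually present in the data, failing loudly."""
    candidate_maps = [sub_map for sub_map in class_name_map if _class in sub_map]
    if not candidate_maps:
        raise Exception(f"No entry in class_names_map contains the configured class '{_class}'.")
    if len(candidate_maps) > 1:
        raise Exception(f"Class '{_class}' appears in more than one class_names_map entry; "
                        "this looks like a misconfiguration.")

    matches = [label for label in candidate_maps[0] if label in category_values]
    if not matches:
        raise Exception(f"None of the variations of '{_class}' ({candidate_maps[0]}) were found in the data.")
    if len(matches) > 1:
        raise Exception(f"Multiple variations of '{_class}' were found in the data: {matches}.")
    return matches[0]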

4 changes: 2 additions & 2 deletions medcat/utils/meta_cat/ml_utils.py
@@ -329,12 +329,12 @@ def initialize_model(classifier, data_, batch_size_, lr_, epochs=4):
print_report(epoch, running_loss_test, all_logits_test, y=y_test, name='Test')

_report = classification_report(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1),
output_dict=True)
output_dict=True,zero_division=1)
Collaborator:
Does it make sense to have it return 1? I.e. if there are no retrieved elements, then precision is undetermined. But if we set it to 1 and don't show the warning, the user might think everything is working great when, in reality, nothing was extracted.
Perhaps setting it to 0 would make more sense?

Though in an ideal world, I'd love to still show a warning to the user, maybe limited to once per 10 minutes or something. But I don't think they have an API for something like that, and I don't think it'd make sense to spend the time implementing it, either.

if not winner_report or _report[config.train['metric']['base']][config.train['metric']['score']] > \
winner_report['report'][config.train['metric']['base']][config.train['metric']['score']]:

report = classification_report(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1),
output_dict=True)
output_dict=True,zero_division=1)
Collaborator:
Same question here. Does it make sense to show 1 in case of no retrieved elements? Or would 0 be more appropriate?

cm = confusion_matrix(y_test, np.argmax(np.concatenate(all_logits_test, axis=0), axis=1), normalize='true')
report_train = classification_report(y_train, np.argmax(np.concatenate(all_logits, axis=0), axis=1),
output_dict=True)
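
As background for the zero_division discussion above, here is a small standalone illustration of what the parameter changes in scikit-learn's classification_report; this is plain scikit-learn usage, not code from the PR.

from sklearn.metrics import classification_report

y_true = [0, 0, 1, 1]
y_pred = [0, 0, 0, 0]  # class 1 is never predicted, so its precision is undefined

# zero_division=1 reports precision 1.0 for class 1 (and suppresses the warning),
# while zero_division=0 reports 0.0, which is the reviewer's suggestion.
print(classification_report(y_true, y_pred, output_dict=True, zero_division=1)['1']['precision'])  # 1.0
print(classification_report(y_true, y_pred, output_dict=True, zero_division=0)['1']['precision'])  # 0.0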