Skip to content

Commit 6f3403e

Browse files
authored
CU-869805t7e alt names fixes (#520)
* CU-869805t7e: Move getting of applicable category name to the config * CU-869805t7e: Use alternative category names in eval method * CU-869805t7e: Reduce indentation * CU-869805t7e: Reduce indentation (again) * CU-869805t7e: Some comment fixing due to rearrangements before * CU-869805t7e: Fix usage of matched class name when encoding category values * CU-869805t7e: Avoid duplicating exception message
1 parent e3913ab commit 6f3403e

File tree

3 files changed

+58
-40
lines changed

3 files changed

+58
-40
lines changed

medcat/config_meta_cat.py

+17
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1+
import logging
12
from typing import Dict, Any, List
3+
from collections.abc import Container
24
from medcat.config import MixingConfig, BaseModel, Optional
35

46

7+
logger = logging.getLogger(__name__)
8+
9+
510
class General(MixingConfig, BaseModel):
611
"""The General part of the MetaCAT config"""
712
device: str = 'cpu'
@@ -78,6 +83,18 @@ class General(MixingConfig, BaseModel):
7883
"""If set, the spacy span group that the metacat model will assign annotations.
7984
Otherwise defaults to doc._.ents or doc.ents per the annotate_overlapping settings"""
8085

86+
def get_applicable_category_name(self, available_names: Container[str]) -> Optional[str]:
87+
if self.category_name in available_names:
88+
return self.category_name
89+
matches = [cat for cat in self.alternative_category_names if cat in available_names]
90+
if len(matches) > 0:
91+
logger.info("The category name provided in the config - '%s' is not present in the data. "
92+
"However, the corresponding name - '%s' from the category_name_mapping has been found. "
93+
"Updating the category name...", self.category_name, *matches)
94+
self.category_name = matches[0]
95+
return self.category_name
96+
return None
97+
8198
class Config:
8299
extra = 'allow'
83100
validate_assignment = True

medcat/meta_cat.py

+9-14
Original file line numberDiff line numberDiff line change
@@ -243,18 +243,13 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data
243243
lowercase=g_config['lowercase'])
244244

245245
# Check is the name present
246-
category_name = g_config['category_name']
247-
category_name_options = g_config['alternative_category_names']
248-
if category_name not in data:
249-
category_matching = [cat for cat in category_name_options if cat in data.keys()]
250-
if len(category_matching) > 0:
251-
logger.info("The category name provided in the config - '%s' is not present in the data. However, the corresponding name - '%s' from the category_name_mapping has been found. Updating the category name...",category_name,*category_matching)
252-
g_config['category_name'] = category_matching[0]
253-
category_name = g_config['category_name']
254-
else:
255-
raise Exception(
256-
"The category name does not exist in this json file. You've provided '{}', while the possible options are: {}. Additionally, ensure the populate the 'alternative_category_names' attribute to accommodate for variations.".format(
257-
category_name, " | ".join(list(data.keys()))))
246+
category_name = g_config.get_applicable_category_name(data)
247+
if category_name is None:
248+
raise Exception(
249+
"The category name does not exist in this json file. You've provided '{}', "
250+
"while the possible options are: {}. Additionally, ensure the populate the "
251+
"'alternative_category_names' attribute to accommodate for variations.".format(
252+
category_name, " | ".join(list(data.keys()))))
258253

259254
data = data[category_name]
260255
if data_oversampled:
@@ -344,8 +339,8 @@ def eval(self, json_path: str) -> Dict:
344339
lowercase=g_config['lowercase'])
345340

346341
# Check is the name there
347-
category_name = g_config['category_name']
348-
if category_name not in data:
342+
category_name = g_config.get_applicable_category_name(data)
343+
if category_name is None:
349344
raise Exception("The category name does not exist in this json file.")
350345

351346
data = data[category_name]

medcat/utils/meta_cat/data_utils.py

+32-26
Original file line numberDiff line numberDiff line change
@@ -188,33 +188,39 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
188188
category_values = set([x[2] for x in data])
189189

190190
# If categoryvalue2id is pre-defined, then making sure it is same as the labels found in the data
191-
if len(category_value2id) != 0:
192-
if set(category_value2id.keys()) != category_values:
193-
# if categoryvalue2id doesn't match the labels in the data, then 'alternative_class_names' has to be defined to check for variations
194-
if len(alternative_class_names) != 0:
195-
updated_category_value2id = {}
196-
for _class in category_value2id.keys():
197-
if _class in category_values:
198-
updated_category_value2id[_class] = category_value2id[_class]
199-
else:
200-
found_in = [sub_map for sub_map in alternative_class_names if _class in sub_map]
201-
if len(found_in) != 0:
202-
class_name_matched = [label for label in found_in[0] if label in category_values]
203-
if len(class_name_matched) != 0:
204-
updated_category_value2id[class_name_matched] = category_value2id[_class]
205-
logger.info("Class name '%s' does not exist in the data; however a variation of it '%s' is present; updating it...",_class,class_name_matched)
206-
else:
207-
raise Exception(
208-
f"The classes set in the config are not the same as the one found in the data. The classes present in the config vs the ones found in the data - {set(category_value2id.keys())}, {category_values}. Additionally, ensure the populate the 'alternative_class_names' attribute to accommodate for variations.")
209-
else:
210-
raise Exception(f"The classes set in the config are not the same as the one found in the data. The classes present in the config vs the ones found in the data - {set(category_value2id.keys())}, {category_values}. Additionally, ensure the populate the 'alternative_class_names' attribute to accommodate for variations.")
211-
category_value2id = copy.deepcopy(updated_category_value2id)
212-
logger.info("Updated categoryvalue2id mapping - %s", category_value2id)
213-
214-
# Else throw an exception since the labels don't match
191+
if len(category_value2id) != 0 and set(category_value2id.keys()) != category_values:
192+
# if categoryvalue2id doesn't match the labels in the data, then 'alternative_class_names' has to be defined to check for variations
193+
if len(alternative_class_names) == 0:
194+
# Raise an exception since the labels don't match
195+
raise Exception(
196+
"The classes set in the config are not the same as the one found in the data. "
197+
"The classes present in the config vs the ones found in the data - "
198+
f"{set(category_value2id.keys())}, {category_values}. Additionally, ensure the populate the "
199+
"'alternative_class_names' attribute to accommodate for variations.")
200+
updated_category_value2id = {}
201+
for _class in category_value2id.keys():
202+
if _class in category_values:
203+
updated_category_value2id[_class] = category_value2id[_class]
215204
else:
216-
raise Exception(
217-
f"The classes set in the config are not the same as the one found in the data. The classes present in the config vs the ones found in the data - {set(category_value2id.keys())}, {category_values}. Additionally, ensure the populate the 'alternative_class_names' attribute to accommodate for variations.")
205+
found_in = [sub_map for sub_map in alternative_class_names if _class in sub_map]
206+
failed_to_find = False
207+
if len(found_in) != 0:
208+
class_name_matched = [label for label in found_in[0] if label in category_values]
209+
if len(class_name_matched) != 0:
210+
updated_category_value2id[class_name_matched[0]] = category_value2id[_class]
211+
logger.info("Class name '%s' does not exist in the data; however a variation of it "
212+
"'%s' is present; updating it...", _class, class_name_matched[0])
213+
else:
214+
failed_to_find = True
215+
else:
216+
failed_to_find = True
217+
if failed_to_find:
218+
raise Exception("The classes set in the config are not the same as the one found in the data. "
219+
"The classes present in the config vs the ones found in the data - "
220+
f"{set(category_value2id.keys())}, {category_values}. Additionally, ensure the "
221+
"populate the 'alternative_class_names' attribute to accommodate for variations.")
222+
category_value2id = copy.deepcopy(updated_category_value2id)
223+
logger.info("Updated categoryvalue2id mapping - %s", category_value2id)
218224

219225
# Else create the mapping from the labels found in the data
220226
else:

0 commit comments

Comments
 (0)