Skip to content
This repository was archived by the owner on Jul 28, 2025. It is now read-only.

Commit 22e2255

Browse files
authored
CU-8693bpq82 fallback spacy model (#384)
* CU-8693bpq82: Add fallback spacy model along with test
* CU-8693bpq82: Remove debug output
* CU-8693bpq82: Add exception info to warning upon spacy model load failure and fallback
1 parent 9e5fca1 commit 22e2255

File tree

3 files changed

+52
-2
lines changed

3 files changed

+52
-2
lines changed

medcat/cdb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ def load_config(self, config_path: str) -> None:
526526
if not os.path.exists(config_path):
527527
if not self._config_from_file:
528528
# if there's no config defined anywhere
529-
raise ValueError("Could not find a config in the CDB nor ",
529+
raise ValueError("Could not find a config in the CDB nor "
530530
"in the config.json for this model "
531531
f"({os.path.dirname(config_path)})",
532532
)

medcat/pipe.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
logger = logging.getLogger(__name__) # different logger from the package-level one
2323

2424

25+
DEFAULT_SPACY_MODEL = 'en_core_web_md'
26+
27+
2528
class Pipe(object):
2629
"""A wrapper around the standard spacy pipeline.
2730
@@ -38,7 +41,22 @@ class Pipe(object):
3841
"""
3942

4043
def __init__(self, tokenizer: Tokenizer, config: Config) -> None:
41-
self._nlp = spacy.load(config.general.spacy_model, disable=config.general.spacy_disabled_components)
44+
try:
45+
self._nlp = self._init_nlp(config)
46+
except Exception as e:
47+
if config.general.spacy_model == DEFAULT_SPACY_MODEL:
48+
raise e
49+
logger.warning("Could not load spacy model from '%s'. "
50+
"Falling back to installed en_core_web_md. "
51+
"For best compatibility, we'd recommend "
52+
"packaging and using your model pack with "
53+
"the spacy model it was designed for",
54+
config.general.spacy_model, exc_info=e)
55+
# we're changing the config value so that this propagates
56+
# to other places that try to load the model. E.g:
57+
# medcat.utils.normalizers.TokenNormalizer.__init__
58+
config.general.spacy_model = DEFAULT_SPACY_MODEL
59+
self._nlp = self._init_nlp(config)
4260
if config.preprocessing.stopwords is not None:
4361
self._nlp.Defaults.stop_words = set(config.preprocessing.stopwords)
4462
self._nlp.tokenizer = tokenizer(self._nlp, config)
@@ -48,6 +66,9 @@ def __init__(self, tokenizer: Tokenizer, config: Config) -> None:
4866
# Set log level
4967
logger.setLevel(self.config.general.log_level)
5068

69+
def _init_nlp(selef, config: Config) -> Language:
70+
return spacy.load(config.general.spacy_model, disable=config.general.spacy_disabled_components)
71+
5172
def add_tagger(self, tagger: Callable, name: Optional[str] = None, additional_fields: List[str] = []) -> None:
5273
"""Add any kind of a tagger for tokens.
5374

tests/test_cat.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from medcat.vocab import Vocab
1111
from medcat.cdb import CDB, logger as cdb_logger
1212
from medcat.cat import CAT, logger as cat_logger
13+
from medcat.pipe import logger as pipe_logger
1314
from medcat.utils.checkpoint import Checkpoint
1415
from medcat.meta_cat import MetaCAT
1516
from medcat.config_meta_cat import ConfigMetaCAT
@@ -499,6 +500,34 @@ def test_loading_model_pack_with_cdb_config_and_config_json_raises_exception(sel
499500
CAT.load_model_pack(self.model_path)
500501

501502

503+
class ModelLoadsUnreadableSpacy(unittest.TestCase):
504+
505+
@classmethod
506+
def setUpClass(cls) -> None:
507+
cls.temp_dir = tempfile.TemporaryDirectory()
508+
model_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples")
509+
cdb = CDB.load(os.path.join(model_path, 'cdb.dat'))
510+
cdb.config.general.spacy_model = os.path.join(cls.temp_dir.name, "en_core_web_md")
511+
# save CDB in new location
512+
cdb.save(os.path.join(cls.temp_dir.name, 'cdb.dat'))
513+
# save config in new location
514+
cdb.config.save(os.path.join(cls.temp_dir.name, 'config.json'))
515+
# copy vocab into new location
516+
vocab_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab.dat")
517+
cls.vocab_path = os.path.join(cls.temp_dir.name, 'vocab.dat')
518+
shutil.copyfile(vocab_path, cls.vocab_path)
519+
520+
@classmethod
521+
def tearDownClass(cls) -> None:
522+
# REMOVE temp dir
523+
cls.temp_dir.cleanup()
524+
525+
def test_loads_without_specified_spacy_model(self):
526+
with self.assertLogs(logger=pipe_logger, level=logging.WARNING):
527+
cat = CAT.load_model_pack(self.temp_dir.name)
528+
self.assertTrue(isinstance(cat, CAT))
529+
530+
502531
class ModelWithZeroConfigsLoadTests(unittest.TestCase):
503532

504533
@classmethod

0 commit comments

Comments
 (0)