Skip to content

Commit

Permalink
Merge pull request #383 from uclh-criu/stopwords-loading-fix
Browse files — browse the repository at this point in the history
Fix stopwords loading bug
  • Loading branch information
tomolopolis authored Jan 3, 2024
2 parents 22e2255 + 45fa0e2 commit f0ef8cd
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 3 deletions.
7 changes: 5 additions & 2 deletions medcat/pipe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import types
import os
import spacy
import gc
import logging
Expand Down Expand Up @@ -41,6 +42,10 @@ class Pipe(object):
"""

def __init__(self, tokenizer: Tokenizer, config: Config) -> None:
if config.preprocessing.stopwords is not None:
lang = os.path.basename(config.general.spacy_model).split('_', 1)[0]
cls = spacy.util.get_lang_class(lang)
cls.Defaults.stop_words = set(config.preprocessing.stopwords)
try:
self._nlp = self._init_nlp(config)
except Exception as e:
Expand All @@ -57,8 +62,6 @@ def __init__(self, tokenizer: Tokenizer, config: Config) -> None:
# medcat.utils.normalizers.TokenNormalizer.__init__
config.general.spacy_model = DEFAULT_SPACY_MODEL
self._nlp = self._init_nlp(config)
if config.preprocessing.stopwords is not None:
self._nlp.Defaults.stop_words = set(config.preprocessing.stopwords)
self._nlp.tokenizer = tokenizer(self._nlp, config)
# Set max document length
self._nlp.max_length = config.preprocessing.max_document_length
Expand Down
42 changes: 42 additions & 0 deletions tests/test_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from medcat.vocab import Vocab
from medcat.cdb import CDB, logger as cdb_logger
from medcat.cat import CAT, logger as cat_logger
from medcat.config import Config
from medcat.pipe import logger as pipe_logger
from medcat.utils.checkpoint import Checkpoint
from medcat.meta_cat import MetaCAT
Expand Down Expand Up @@ -479,6 +480,47 @@ def test_add_and_train_concept_cdb_warns_short_name(self):
self.assertLogsDuringAddAndTrainConcept(cdb_logger, logging.WARNING, name=short_name, name_status='P', nr_of_calls=1)


class GetEntitiesWithStopWords(unittest.TestCase):
    """Check that configured stop words change which entities are captured.

    NB! The order in which the two CDBs (and their CATs) are created is
    important: stop words are applied at the spaCy *class* level, so once
    they have been set they carry over to every pipe built afterwards,
    regardless of whether that pipe asked for them.
    """

    @classmethod
    def setUpClass(cls) -> None:
        # Both CDBs are loaded from the same example model; only the
        # second one gets stop words configured.
        examples_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples")
        cls.cdb1 = CDB.load(os.path.join(examples_dir, "cdb.dat"))
        cls.cdb2 = CDB.load(os.path.join(examples_dir, "cdb.dat"))
        cls.vocab = Vocab.load(os.path.join(examples_dir, "vocab.dat"))
        cls.vocab.make_unigram_table()
        base_config = cls.cdb1.config
        base_config.general.spacy_model = "en_core_web_md"
        base_config.ner.min_name_len = 2
        base_config.ner.upper_case_limit_len = 3
        base_config.general.spell_check = True
        base_config.linking.train_count_threshold = 10
        base_config.linking.similarity_threshold = 0.3
        base_config.linking.train = True
        base_config.linking.disamb_length_limit = 5
        base_config.general.full_unlink = True
        cls.cdb2.config = Config.from_dict(base_config.asdict())
        # the regular CAT without stopwords
        cls.no_stopwords = CAT(cdb=cls.cdb1, config=cls.cdb1.config, vocab=cls.vocab, meta_cats=[])
        # The stop words must be configured *before* constructing the CAT,
        # since constructing the CAT initialises the pipe.
        cls.cdb2.config.preprocessing.stopwords = {"stop", "words"}
        cls.cdb2.config.preprocessing.skip_stopwords = True
        # the CAT that skips the stopwords
        cls.w_stopwords = CAT(cdb=cls.cdb2, config=cls.cdb2.config, vocab=cls.vocab, meta_cats=[])

    def test_stopwords_are_skipped(self, text: str = "second words csv"):
        # Without stop words no entities are captured in this text;
        # with stop words configured, the `second words csv` entity is.
        plain_doc = self.no_stopwords(text)
        skipping_doc = self.w_stopwords(text)
        self.assertGreater(len(skipping_doc._.ents), len(plain_doc._.ents))


class ModelWithTwoConfigsLoadTests(unittest.TestCase):

@classmethod
Expand Down
9 changes: 8 additions & 1 deletion tests/test_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def setUpClass(cls) -> None:
cls.config.ner['max_skip_tokens'] = 1
cls.config.ner['upper_case_limit_len'] = 4
cls.config.linking['disamb_length_limit'] = 2
cls.config.preprocessing.stopwords = {'stop', 'words'}
cls.cdb = CDB(config=cls.config)

downloader = VocabDownloader()
Expand All @@ -42,7 +43,7 @@ def setUpClass(cls) -> None:
_tokenizer = TokenizerWrapperBERT(hf_tokenizers=AutoTokenizer.from_pretrained("bert-base-uncased"))
cls.meta_cat = MetaCAT(tokenizer=_tokenizer)

cls.text = "CDB - I was running and then Movar Virus attacked and CDb"
cls.text = "stop of CDB - I was running and then Movar Virus attacked and CDb"
cls.undertest = Pipe(tokenizer=spacy_split_all, config=cls.config)

@classmethod
Expand Down Expand Up @@ -81,6 +82,12 @@ def test_add_meta_cat(self):
PipeTests.undertest.add_meta_cat(PipeTests.meta_cat)

self.assertEqual(PipeTests.meta_cat.name, Language.get_factory_meta(PipeTests.meta_cat.name).factory)

def test_stopwords_loading(self):
    """Stop words from the config must reach spaCy and mark tokens."""
    expected_stopwords = PipeTests.config.preprocessing.stopwords
    self.assertEqual(PipeTests.undertest._nlp.Defaults.stop_words, expected_stopwords)
    tokens = PipeTests.undertest(PipeTests.text)
    # The text starts with "stop of ..." — "stop" is configured as a
    # stop word, "of" is not.
    self.assertEqual(tokens[0].is_stop, True)
    self.assertEqual(tokens[1].is_stop, False)

def test_batch_multi_process(self):
PipeTests.undertest.add_tagger(tagger=tag_skip_and_punct, additional_fields=["is_punct"])
Expand Down

0 comments on commit f0ef8cd

Please sign in to comment.