diff --git a/README.md b/README.md index a6e5988..bc8708c 100644 --- a/README.md +++ b/README.md @@ -1,30 +1,58 @@ # Trove +-- - +[![Documentation Status](https://readthedocs.org/projects/trove/badge/?version=latest)](https://trove.readthedocs.io/en/latest/?badge=latest) [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) -Trove is a framework for training weakly supervised (bio)medical named entity recognition (NER) and other entity attribute classifiers without hand-labeled training data. +Trove is a research framework for building weakly supervised (bio)medical named entity recognition (NER) and other entity attribute classifiers without hand-labeled training data. -We combine a range of supervision signal common medical ontologies such as the Unified Medical Language System (UMLS), clinical text heuristics, and other noisy labeling sources for use with weak supervision frameworks such as [Snorkel](https://github.com/snorkel-team/snorkel). +The COVID-19 pandemic has underlined the need for faster, more flexible ways of building and sharing state-of-the-art NLP/NLU tools to analyze electronic health records (EHR), scientific literature, and social media. Trove provides tools for combining freely available supervision sources such as medical ontologies from the [Unified Medical Language System (UMLS)](https://www.nlm.nih.gov/research/umls/index.html), common text heuristics, and other noisy labeling sources for use as entity *labelers* in weak supervision frameworks such as [Snorkel](https://github.com/snorkel-team/snorkel), [FlyingSquid](https://github.com/HazyResearch/flyingsquid), and others. Technical details are available in our [manuscript](https://www.nature.com/articles/s41467-021-22328-4). -Technical details are available in our [manuscript](https://arxiv.org/abs/2008.01972). +Trove has been used as part of several COVID-19 research efforts at Stanford. -## Installation +- [Continuous symptom profiling of patients screened for SARS-CoV-2](https://med.stanford.edu/covid19/research.html#data-science-and-modeling). We used a daily feed of patient notes from Stanford Health Care emergency departments to generate up-to-date [COVID-19 symptom frequency](https://docs.google.com/spreadsheets/d/1iZZvbv94fpZdC6XaiPosiniMOh18etSPliAXVlLLr1w/edit#gid=344371264) data. Funded by the [Bill & Melinda Gates Foundation](https://www.gatesfoundation.org/about/committed-grants/2020/04/inv017214). +- [Estimating the efficacy of symptom-based screening for COVID-19](https://rdcu.be/chSrv) published in *npj Digital Medicine*. +- Our COVID-19 symptom data was used by CMU's [DELPHI group](https://covidcast.cmu.edu/) to prioritize selection of informative features from [Google's Symptom Search Trends dataset](https://github.com/GoogleCloudPlatform/covid-19-open-data/blob/main/docs/table-search-trends.md). -Requirements: python 3.6, pytorch 1.0+, snorkel 0.9.5+ -## Tutorials +## Getting Started -See `tutorials/` +### Tutorials -## Requirements +See [`tutorials/`](https://github.com/som-shahlab/trove/tree/dev/tutorials) for Jupyter notebooks walking through an example NER application. + +### Installation + +Requirements: Python 3.6 or later. We recommend using `pip` to install: + +`pip install -r requirements.txt` + +## Contributions +We welcome all contributions to the code base! Please submit a pull request and/or start a discussion on GitHub Issues.
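To make the idea of an entity *labeler* described above concrete, the sketch below shows roughly what a dictionary-based labeling function does. It is an illustrative sketch only: the names (`ABSTAIN`, `lf_chemical_dictionary`) and the label conventions are assumptions made for this example, not Trove's actual API; the real labeling-function classes live in `trove/labelers/labeling.py` and are demonstrated in the tutorials.

```python
# Illustrative sketch of a dictionary-based entity labeler (assumed names and
# label conventions; not Trove's actual API, see trove/labelers/labeling.py).
ABSTAIN, NEGATIVE, POSITIVE = -1, 0, 1

chemical_terms = {"aspirin", "lithium", "warfarin"}  # e.g., terms pulled from a UMLS vocabulary
stopwords = {"lead", "iron"}                         # ambiguous terms we decline to label

def lf_chemical_dictionary(words):
    """Vote per word: POSITIVE on a dictionary hit, NEGATIVE on a stopword, else abstain."""
    labels = []
    for w in words:
        w = w.lower()
        if w in stopwords:
            labels.append(NEGATIVE)
        elif w in chemical_terms:
            labels.append(POSITIVE)
        else:
            labels.append(ABSTAIN)
    return labels

print(lf_chemical_dictionary(["Patient", "started", "on", "aspirin", "."]))
# [-1, -1, -1, 1, -1]
```

Many such labelers, each noisy on its own, are then combined by a label model (e.g., Snorkel's) to produce probabilistic training labels.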
+ +Weakly supervised methods for programmatically building and maintaining training labels provide new opportunities for the larger community to participate in the creation of important datasets. This is especially exciting in domains such as medicine, where sharing labeled data is often challenging due to patient privacy concerns. + +Inspired by recent efforts such as [HuggingFace's Datasets](https://github.com/huggingface/datasets) library, +we would love to start a conversation around how to support sharing labelers in service of maintaining an open task library, so that it is easier to create, deploy, and version control weakly supervised models. -Tested on OSX and Linux. ## Citation -If use Trove in your research, please cite [Ontology-driven weak supervision for clinical entity classification in electronic health records]() +If you use Trove in your research, please cite us! + +Fries, J.A., Steinberg, E., Khattar, S. et al. Ontology-driven weak supervision for clinical entity classification in electronic health records. Nat Commun 12, 2017 (2021). https://doi.org/10.1038/s41467-021-22328-4 + +``` +@article{fries2021trove, + title={Ontology-driven weak supervision for clinical entity classification in electronic health records}, + author={Fries, Jason A and Steinberg, Ethan and Khattar, Saelig and Fleming, Scott L and Posada, Jose and Callahan, Alison and Shah, Nigam H}, + journal={Nature Communications}, + volume={12}, + number={1}, + year={2021}, + publisher={Nature Publishing Group} +} +``` -See the `manuscript` branch for the code used diff --git a/applications/README.md b/applications/README.md index 2ebc994..356a893 100644 --- a/applications/README.md +++ b/applications/README.md @@ -5,10 +5,10 @@ Labeling functions for various weakly supervised biomedical classification tasks -| Name | Task | Domain | Type | Source | -|------------------|------------------|------------|------|-----------------------------------------------| -| `bc5cdr/` | Chemical/Disease | Literature | NER | BioCreative V Chemical-Disease Relation (CDR) | -| `i2b2drugs/` | Drug | Clinical | NER | n2c2/i2b2 2009 Medication Challenge | -| `shareclef2014/` | Disorder | Clinical | NER | ShARe/CLEF 2014 | -| `thyme/` | DocRelaTime | Clinical | Span | THYME 2017 | -| `covid19/` | Exposure | Clinical | Span | COVID-19 exposure | \ No newline at end of file +| Name | Task | Domain | Type | Source | Access | +|------------------|------------------|------------|------|-----------------------------------------------|------------| +| `bc5cdr/` | Chemical/Disease | Literature | NER | BioCreative V Chemical-Disease Relation (CDR) | Public | +| `i2b2drugs/` | Drug | Clinical | NER | n2c2/i2b2 2009 Medication Challenge | DUA | +| `shareclef2014/` | Disorder | Clinical | NER | ShARe/CLEF 2014 | DUA | +| `thyme/` | DocRelaTime | Clinical | Span | THYME 2017 | DUA | +| `covid19/` | Exposure | Clinical | Span | COVID-19 exposure | - | \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index f3d724a..cad4fe9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -10,9 +10,9 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here.
# -# import os -# import sys -# sys.path.insert(0, os.path.abspath('.')) +import os +import sys +sys.path.insert(0, os.path.abspath('../..')) # -- Project information ----------------------------------------------------- @@ -31,8 +31,15 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.coverage', + 'sphinx.ext.napoleon', + 'sphinx.ext.autosummary' ] +autosummary_generate = True + + # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] diff --git a/docs/source/index.rst b/docs/source/index.rst index f46c0b0..82f4260 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,17 +1,28 @@ -.. trove documentation master file, created by - sphinx-quickstart on Mon Mar 22 00:23:28 2021. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. - -Welcome to trove's documentation! +Welcome to Trove's documentation! ================================= +Trove is a research framework for building weakly supervised (bio)medical +named entity recognition (NER) and other entity attribute classifiers without hand-labeled training data. + +The COVID-19 pandemic has underlined the need for faster, more flexible ways of building +and sharing state-of-the-art NLP/NLU tools to analyze electronic health records (EHR), +scientific literature, and social media. Trove provides tools for combining freely +available supervision sources such as medical ontologies from the Unified Medical +Language System (UMLS), common text heuristics, and other noisy labeling sources for use +as entity *labelers* in weak supervision frameworks such as Snorkel, FlyingSquid, and +others. Technical details are available in our manuscript. + +.. autosummary:: + :toctree: _autosummary + :recursive: + + trove + .. 
toctree:: - :maxdepth: 2 + :maxdepth: 10 :caption: Contents: - Indices and tables ================== diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..d20b7b8 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,16 @@ +toolz==0.11.1 +tqdm==4.59.0 +torch==1.8.0 +requests==2.25.1 +pandas==1.1.5 +scipy==1.5.2 +lxml==4.6.2 +spacy==3.0.5 +numpy==1.19.2 +joblib==1.0.1 +msgpack_python==0.5.6 +norm==1.6.0 +pytorch_pretrained_bert==0.6.2 +scikit_learn==0.24.1 +seqeval==1.2.2 +stopwords==1.0.0 diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/metrics/__init__.py b/test/metrics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/metrics/test_metrics.py b/test/metrics/test_metrics.py new file mode 100644 index 0000000..b43a061 --- /dev/null +++ b/test/metrics/test_metrics.py @@ -0,0 +1,12 @@ +import unittest +import numpy as np + + +class MetricsTest(unittest.TestCase): + def test_convert_tag_fmt(self): + return True + + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/trove/labelers/abbreviations.py b/trove/labelers/abbreviations.py index 970f58d..913bc7d 100644 --- a/trove/labelers/abbreviations.py +++ b/trove/labelers/abbreviations.py @@ -14,7 +14,6 @@ """ import re import collections -from typing import Set from trove.dataloaders.contexts import Span from trove.labelers.labeling import ( LabelingFunction, @@ -23,13 +22,13 @@ ) from typing import List, Set, Dict -def is_short_form(s, min_length=2): +def is_short_form(text, min_length=2): """ Rule-based function for determining if a token is likely an abbreviation, acronym or other "short form" mention Parameters ---------- - s + text min_length Returns @@ -39,22 +38,21 @@ def is_short_form(s, min_length=2): accept_rgx = '[0-9A-Z-]{2,8}[s]*' reject_rgx = '([0-9]+/[0-9]+|[0-9]+[-][0-7]+)' - keep = re.search(accept_rgx, s) != None - keep &= re.search(reject_rgx, s) == None - keep &= not s.strip("-").isdigit() - keep &= "," not in s - keep &= len(s) < 15 + keep = re.search(accept_rgx, text) is not None + keep &= re.search(reject_rgx, text) is None + keep &= not text.strip("-").isdigit() + keep &= "," not in text + keep &= len(text) < 15 # reject if too short too short or contains lowercase single letters - reject = (len(s) > 3 and not keep) - reject |= (len(s) <= 3 and re.search("[/,+0-9-]", s) != None) - reject |= (len(s) < min_length) - reject |= (len(s) <= min_length and s.islower()) # + reject = (len(text) > 3 and not keep) + reject |= (len(text) <= 3 and re.search("[/,+0-9-]", text) is not None) + reject |= (len(text) < min_length) + reject |= (len(text) <= min_length and text.islower()) return False if reject else True - def get_parenthetical_short_forms(sentence): """Generator that returns indices of all words directly wrapped by parentheses or brackets. 
@@ -67,10 +65,10 @@ def get_parenthetical_short_forms(sentence): ------- """ - for i, w in enumerate(sentence.words): + for i, _ in enumerate(sentence.words): if i > 0 and i < len(sentence.words) - 1: window = sentence.words[i - 1:i + 2] - if (window[0] == "(" and window[-1] == ")"): + if window[0] == "(" and window[-1] == ")": if is_short_form(window[1]): yield i @@ -83,7 +81,7 @@ def extract_long_form(i, sentence, max_dup_chars=2): short_form = sentence.words[i] left_window = [w for w in sentence.words[0:i]] - # strip brackets/parantheses + # strip brackets/parentheses while left_window and left_window[-1] in ["(", "[", ":"]: left_window.pop() diff --git a/trove/labelers/core.py b/trove/labelers/core.py index 5c43f99..1d6cb42 100644 --- a/trove/labelers/core.py +++ b/trove/labelers/core.py @@ -1,11 +1,12 @@ +import logging import itertools import numpy as np from scipy import sparse from functools import partial from toolz import partition_all from joblib import Parallel, delayed -from abc import ABCMeta, abstractmethod +logger = logging.getLogger(__name__) class Distributed: @@ -14,7 +15,7 @@ def __init__(self, num_workers=1, backend='multiprocessing'): backend=backend, prefer="processes") self.num_workers = num_workers - print(self.client) + logger.info(self.client) class SequenceLabelingServer(Distributed): @@ -29,15 +30,15 @@ def apply(self, lfs, Xs, block_size=None): block_size = int( np.ceil(np.sum([len(x) for x in Xs]) / self.num_workers) ) - print(f'auto block size={block_size}') + logger.info("auto block size %s", block_size) if block_size: blocks = list( partition_all(block_size, itertools.chain.from_iterable(Xs)) ) - print(f"Partitioned into {len(blocks)} blocks, " - f"{np.unique([len(x) for x in blocks])} sizes") + lens = np.unique([len(x) for x in blocks]) + logger.info("Partitioned into %s blocks %s sizes ", len(blocks), lens) do = delayed(partial(SequenceLabelingServer.worker, lfs)) jobs = (do(batch) for batch in blocks) @@ -67,15 +68,15 @@ def apply(self, lfs, Xs, block_size=None): block_size = int( np.ceil(np.sum([len(x) for x in Xs]) / self.num_workers) ) - print(f'auto block size={block_size}') + logger.info("auto block size %s", block_size) if block_size: blocks = list( partition_all(block_size, itertools.chain.from_iterable(Xs)) ) - print(f"Partitioned into {len(blocks)} blocks, " - f"{np.unique([len(x) for x in blocks])} sizes") + lens = np.unique([len(x) for x in blocks]) + logger.info("Partitioned into %s blocks %s sizes ", len(blocks), lens) do = delayed(partial(LabelingServer.worker, lfs)) jobs = (do(batch) for batch in blocks) diff --git a/trove/labelers/labeling.py b/trove/labelers/labeling.py index ba278fa..d6dea4a 100644 --- a/trove/labelers/labeling.py +++ b/trove/labelers/labeling.py @@ -87,7 +87,7 @@ def __init__(self, name: str, ontology: Dict[str, np.array], case_sensitive: bool = False, - max_ngrams: int = 4, + max_ngrams: int = 8, stopwords = None) -> None: super().__init__(name, None) @@ -103,9 +103,17 @@ def __init__(self, else int(np.argmax(proba) + 1) self.ontology = frozenset(ontology) - def _get_term_label(self, t): + def _get_term_label(self, term): + """ + Check for term match, given set of simple transformations + (e.g., lowercasing, simple pluralization) - for key in [t, t.lower(), t.rstrip('s'), t + 's']: + TODO: Consider a proper abstraction for handling valid aliases. 
+ + :param term: + :return: + """ + for key in [term, term.lower(), term.rstrip('s'), term + 's']: if key in self.stopwords: return self.stopwords[key] if key in self._labels: @@ -202,17 +210,12 @@ def _get_term_label(self, t): return None def _merge_matches(self, matches): - """ Merge all contiguous spans with the same label. - - Parameters - ---------- - matches - - Returns - ------- - """ + Merge all contiguous spans with the same label. + :param matches: + :return: + """ terms = [m[-1] for m in matches] labels = [self._get_term_label(m[-1]) for m in matches] @@ -387,7 +390,7 @@ def __call__(self, sentence): class SynSetLabelingFunction(LabelingFunction): """ - Given a map of TERM -> {t \in SYNONYMS}, if the TERM AND any t + Given a map of TERM -> {t \\in SYNONYMS}, if the TERM AND any t appear in document, label as a positive instance of the entity. """ def __init__(self, @@ -466,4 +469,4 @@ def __call__(self, sentence): spans = self._get_contiguous_spans(spans) spans = list(itertools.chain.from_iterable([s for s in spans if len(s) >= self.min_length])) - return {i:L[i] for i in spans} \ No newline at end of file + return {i:L[i] for i in spans} diff --git a/trove/labelers/matchers.py b/trove/labelers/matchers.py index bd6e0f5..276abdc 100644 --- a/trove/labelers/matchers.py +++ b/trove/labelers/matchers.py @@ -194,9 +194,7 @@ def match_rgx(rgx: Pattern, sentence: Sentence) -> Dict[Tuple, Span]: def get_longest_matches(matches: Dict[Tuple, Span]) -> Iterable[Span]: - """ - TODO: Hack -- rewrite this - """ + mask = {} for key in sorted(matches.keys(), key=lambda x: x[-1], reverse=1): is_longest = True @@ -208,4 +206,4 @@ def get_longest_matches(matches: Dict[Tuple, Span]) -> Iterable[Span]: else: is_longest = False if is_longest: - yield span \ No newline at end of file + yield span diff --git a/trove/labelers/umls.py b/trove/labelers/umls.py index b2cc01a..ffc7851 100644 --- a/trove/labelers/umls.py +++ b/trove/labelers/umls.py @@ -45,19 +45,30 @@ def __init__(self, **kwargs): kwargs['cache_path'] if 'cache_path' in kwargs else UMLS.cache_root ) - logger.info(f"cache_path= {self.cache_path}") - logger.info(f"backend= {self.backend}") + logger.info("cache_path=%s", self.cache_path) + logger.info("backend=%s", self.backend) self._load_indices() self._apply_filters(**kwargs) @classmethod def config(cls, cache_root, backend): + """ + Assign new defaults to class member variables. + + :param cache_root: + :param backend: + :return: + """ cls.cache_root = cache_root cls.backend = backend def _load_indices(self): + """ + Load various pre-generated indices. + :return: + """ if not UMLS.is_initalized(self.cache_path, self.backend): raise Exception("Error, UMLS not initialized.") @@ -71,7 +82,13 @@ def _load_indices(self): open(f"{self.cache_path}/tui_to_sty.bin", 'rb')) def _load_terminologies(self, filter_sabs, type_mapping='TUI'): + """ + Load pre-generated terminology files. 
+ :param filter_sabs: + :param type_mapping: + :return: + """ if self.backend == 'pandas': df = pd.read_parquet( f"{self.cache_path}/concepts/", @@ -110,6 +127,17 @@ def _apply_filters(self, stopwords=None): """ Load concepts file and create transformed terminology dictionaries + + :param type_mapping: + :param min_char_len: + :param max_tok_len: + :param min_dict_size: + :param languages: + :param transforms: + :param filter_sabs: + :param filter_rgx: + :param stopwords: + :return: """ # defaults filter_sabs = filter_sabs if filter_sabs else {'SNOMEDCT_VET'} @@ -165,7 +193,13 @@ def apply_transforms(term, transforms): @staticmethod def init_sqlite_tables(fpath, dataframe): + """ + Initialize a simple sqlite3 database schema. + :param fpath: + :param dataframe: + :return: + """ conn = sqlite3.connect(fpath) sql = """CREATE TABLE IF NOT EXISTS terminology ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -193,6 +227,13 @@ def init_sqlite_tables(fpath, dataframe): @staticmethod def is_initalized(cache_root=None, backend=None): + """ + Test if the UMLS has been initialized by looking at the cache. + + :param cache_root: + :param backend: + :return: + """ cache_path = UMLS.get_full_cache_path(cache_root) backend = backend if backend else UMLS.backend filelist = ['sabs.bin', 'tui_to_sty.bin', 'concepts'] @@ -207,6 +248,12 @@ def is_initalized(cache_root=None, backend=None): @staticmethod def reset(cache_root=None): + """ + Clear all cached files. + + :param cache_root: + :return: + """ cache_path = UMLS.get_full_cache_path(cache_root) if os.path.exists(cache_path): shutil.rmtree(cache_path) @@ -232,6 +279,10 @@ def init_from_nlm_zip(fpath, complete RRF file set. :param fpath: + :param outdir: + :param backend: + :param use_checksum: + :param keep_original_rrfs: :return: """ assert os.path.exists(fpath) diff --git a/tutorials/4_BERT_End_Model.ipynb b/tutorials/4_BERT_End_Model.ipynb new file mode 100644 index 0000000..140644c --- /dev/null +++ b/tutorials/4_BERT_End_Model.ipynb @@ -0,0 +1,47 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# II. BERT End Model \n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Coming Soon!\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "trove", + "language": "python", + "name": "trove" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.11" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tutorials/Weakly-Supervised-NER.ipynb b/tutorials/Weakly-Supervised-NER.ipynb deleted file mode 100644 index ce9b910..0000000 --- a/tutorials/Weakly-Supervised-NER.ipynb +++ /dev/null @@ -1,1562 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# II. Weakly Supervised Named Entity Recognition (NER)\n", - "\n", - "We'll use the public [BioCreative V Chemical Disease Relation](https://biocreative.bioinformatics.udel.edu/tasks/biocreative-v/track-3-cdr/) (BC5CDR) dataset, focusing on Chemical entities. \n", - "\n", - "See `../applications/BC5CDR/` for the complete labeling function set used in our paper. 
\n", - "\n", - "## Installation Instructions\n", - "\n", - "- Trove requires access to the [Unified Medical Language System (UMLS)](https://www.nlm.nih.gov/research/umls/licensedcontent/umlsknowledgesources.html) which is freely available after signing up for an account with the National Library of Medicine. Visit the link above and download the latest \"UMLS Metathesaurus Files\" release [2020AB](https://download.nlm.nih.gov/umls/kss/2020AB/umls-2020AB-metathesaurus.zip) and run our UMLS install script. \n", - "- Unzip the preprocessed BioCreative V CDR chemical dataset `bc5cdr.zip`" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "import sys\n", - "sys.path.insert(0,'../../trove')\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Load Unlabeled Data & Define Entity Classes\n", - "\n", - "### A. Load Preprocessed Documents\n", - "This notebook assumes documents have already been preprocessed for sentence boundary detection and dumped into JSON format. See `preprocessing/README.md` for details.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tagged Entities: 5203\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Tokenization Error: Token is not a head token Annotation[Chemical](Cl|1240-1242) 19692487\n", - "Tokenization Error: Token is not a head token Annotation[Chemical](Cl|1579-1581) 15075188\n", - "Errors: Span Alignment: 2/5347 (0.0%)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tagged Entities: 5345\n", - "Tagged Entities: 5385\n", - "CPU times: user 29.8 s, sys: 688 ms, total: 30.5 s\n", - "Wall time: 32.4 s\n" - ] - } - ], - "source": [ - "%%time\n", - "import transformers\n", - "from trove.dataloaders import load_json_dataset\n", - "\n", - "tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)\n", - "\n", - "data_dir = \"data/bc5cdr/\"\n", - "dataset = {\n", - " split : load_json_dataset(f'{data_dir}/{split}.cdr.chemical.json', tokenizer)\n", - " for split in ['train', 'dev', 'test']\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### B. Define Entity Categories\n", - "In popular biomedical annotators such as [NCBO BioPortal](https://bioportal.bioontology.org/annotator), we configure the annotator by selecting a set of semantic categories which define our entity class and a corresponding set of ontologies mapped to those types. \n", - "\n", - "Trove uses a similar style of interface in API form. For `CHEMICAL` tagging, we define an entity class consisting of [UMLS Semantic Network](https://semanticnetwork.nlm.nih.gov/) types mapped to $\\{0,1\\}$. The semantic network defines 127 concept categories called _Semantic Types_ (e.g., Disease or Syndrome , Medical Device) which are mappable to 15 coarser-grained _Semantic Groups_ (e.g., Anatomy, Chemicals & Drugs, Disorders). \n", - "\n", - "We use the _Chemicals & Drugs_ (CHEM) semantic group as the basis of our positive class label $1$, removing some categories (e.g., Gene or Genome) that do not match the definition of chemical as outlined in the BC5CDR annotation guidelines. Non-chemical STYs define our negative class label $0$." 
- ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "\n", - "# load the chemical entity definition\n", - "entity_def = pd.read_csv('data/chemical_semantic_types.tsv', sep='\\t')\n", - "class_map = {row.TUI:row.LABEL for row in entity_def.itertuples() if row.LABEL != -1}\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. Load Ontology Labeling Sources\n", - "### A. Unified Medical Language System (UMLS) Metathesaurus\n", - "The UMLS Metathesaurus is a convenient source for deriving labels, since it provides over 200 source vocabularies (terminologies) with consistent entity categorization provided by the UMLS Semantic Network.\n", - "\n", - "The first time this is run, Trove requires access to the installation zip\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 2.85 ms, sys: 3.17 ms, total: 6.03 ms\n", - "Wall time: 11.8 ms\n" - ] - } - ], - "source": [ - "%%time\n", - "from trove.labelers.umls import UMLS\n", - "\n", - "# initialize UMLS\n", - "backend = 'pandas'\n", - "if not UMLS.is_initalized(backend=backend):\n", - " print(f'Please initalize the UMLS before running this notebook. See `umls_install.sh`')\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We apply some minimal preprocessing to each source vocabularies term set, as outlined in the Trove paper. The most important settings are:\n", - "- `SmartLowercase()`, a string matching heuristic for preserving likely abbreviations and acronyms\n", - "- `min_char_len`, `filter_rgx`, filters for terms that are single characters or numbers \n", - "\n", - "Other choices are largely for speed purposes, such as restricting the max token length used for string matching. 
\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 1min 38s, sys: 8.19 s, total: 1min 47s\n", - "Wall time: 1min 43s\n" - ] - } - ], - "source": [ - "%%time\n", - "from trove.labelers.umls import UMLS\n", - "from trove.transforms import SmartLowercase\n", - "\n", - "# english stopwords\n", - "stopwords = set(open('data/stopwords.txt','r').read().splitlines())\n", - "stopwords = stopwords.union(set([t[0].upper() + t[1:] for t in stopwords]))\n", - "\n", - "# options for filtering terms\n", - "config = {\n", - " \"type_mapping\" : \"TUI\", # TUI = semantic types, CUI = concept ids\n", - " 'min_char_len' : 2,\n", - " 'max_tok_len' : 8,\n", - " 'min_dict_size' : 500,\n", - " 'stopwords' : stopwords,\n", - " 'transforms' : [SmartLowercase()],\n", - " 'languages' : {\"ENG\"},\n", - " 'filter_sabs' : {\"SNOMEDCT_VET\"},\n", - " 'filter_rgx' : r'''^[-+]*[0-9]+([.][0-9]+)*$''' # filter numbers\n", - "}\n", - "\n", - "umls = UMLS(backend=backend, **config)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 57.1 s, sys: 1.48 s, total: 58.6 s\n", - "Wall time: 57.9 s\n" - ] - } - ], - "source": [ - "%%time\n", - "import numpy as np\n", - "\n", - "def map_entity_classes(dictionary, class_map):\n", - " \"\"\"\n", - " Given a dictionary, create the term entity class probabilities\n", - " \"\"\"\n", - " k = len([y for y in set(class_map.values()) if y != -1])\n", - " ontology = {}\n", - " for term in dictionary:\n", - " proba = np.zeros(shape=k).astype(np.float32)\n", - " for cls in dictionary[term]:\n", - " # ignore abstains\n", - " idx = class_map[cls] if cls in class_map else -1\n", - " if idx != -1:\n", - " proba[idx - 1] += 1\n", - " # don't include terms that don't map to any classes\n", - " if np.sum(proba) > 0:\n", - " ontology[term] = proba / np.sum(proba)\n", - " return ontology\n", - "\n", - "# These are the top 10 ontologies as ranked by term overlap with the BC5CDR training set\n", - "terminologies = ['CHV', 'SNOMEDCT_US', 'NCI', 'MSH']\n", - "\n", - "ontologies = {\n", - " sab : map_entity_classes(umls.terminologies[sab], class_map)\n", - " for sab in terminologies\n", - "}\n" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 9.06 s, sys: 445 ms, total: 9.5 s\n", - "Wall time: 9.51 s\n" - ] - } - ], - "source": [ - "%%time\n", - "\n", - "# create dictionaries for our Schwartz-Hearst abbreviation detection labelers\n", - "positive, negative = set(), set()\n", - "\n", - "for sab in umls.terminologies:\n", - " for term in umls.terminologies[sab]:\n", - " for tui in umls.terminologies[sab][term]:\n", - " if tui in class_map and class_map[tui] == 1:\n", - " positive.add(term)\n", - " elif tui in class_map and class_map[tui] == 0:\n", - " negative.add(term)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### B. Additional Ontologies: ChEBI Database\n", - "We also want to utilize non-UMLS ontologies. External databases such as ChEBI or CTD typically don't include rich mappings to Semantic Network types, so we treat this as an ontology/dictionary mapping to a single class label." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "from ontologies import ChebiDatabase\n", - "\n", - "config = {\n", - " 'min_char_len' : 2,\n", - " 'max_tok_len' : 8,\n", - " 'min_dict_size' : 1,\n", - " 'stopwords' : stopwords,\n", - " 'transforms' : [SmartLowercase()],\n", - " 'languages' : None,\n", - " 'filter_sources': None,\n", - " 'filter_rgx' : r'''^[-+]*[0-9]+([.][0-9]+)*$''' # filter numbers\n", - "}\n", - "chebi = ChebiDatabase(cache_path=None, **config)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "names.tsv.gz: 0.00B [00:00, ?B/s]\n" - ] - }, - { - "ename": "NameError", - "evalue": "name 'filename' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 29\u001b[0;31m \u001b[0mdownloader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'ftp://ftp.ebi.ac.uk/pub/databases/chebi/Flat_file_tab_delimited/names.tsv.gz'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\".\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36mdownloader\u001b[0;34m(url, save_dir)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0mfname\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0murl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'/'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mProgressBar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0munit\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'B'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0munit_scale\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0munit_divisor\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1024\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mminiters\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdesc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfname\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0murlretrieve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msave_dir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreporthook\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_to\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name 
'filename' is not defined" - ] - } - ], - "source": [ - "import os\n", - "from tqdm import tqdm\n", - "from urllib.request import urlretrieve\n", - "\n", - "\n", - "class ProgressBar(tqdm):\n", - " \"\"\"\n", - " Based on https://gist.github.com/leimao/37ff6e990b3226c2c9670a2cd1e4a6f5\n", - " \"\"\"\n", - " def update_to(self, b=1, bsize=1, tsize=None):\n", - " \"\"\"\n", - " b : int, optional\n", - " Number of blocks transferred so far [default: 1].\n", - " bsize : int, optional\n", - " Size of each block (in tqdm units) [default: 1].\n", - " tsize : int, optional\n", - " Total size (in tqdm units). If [default: None] remains unchanged.\n", - " \"\"\"\n", - " if tsize is not None:\n", - " self.total = tsize\n", - " self.update(b * bsize - self.n) # will also set self.n = b * bsize\n", - "\n", - "def downloader(url, save_dir):\n", - " fname = url.split('/')[-1]\n", - " with ProgressBar(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc=fname) as t:\n", - " urlretrieve(url, filename = os.path.join(save_dir, filename), reporthook = t.update_to)\n", - "\n", - " \n", - "downloader('ftp://ftp.ebi.ac.uk/pub/databases/chebi/Flat_file_tab_delimited/names.tsv.gz', \".\")\n", - " \n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using files at cache/ctd/\n" - ] - } - ], - "source": [ - "import requests\n", - "import re\n", - "import numpy as np\n", - "import pandas as pd\n", - "from tqdm import tqdm\n", - "from abc import ABCMeta, abstractmethod\n", - "\n", - "\n", - "\n", - "# def download(url, outdir, block_size=1024):\n", - "# \"\"\"\n", - "# See https://stackoverflow.com/a/37573701\n", - "# \"\"\"\n", - "# print(url)\n", - "# fname = url.split('/')[-1]\n", - "# response = requests.get(url, stream=True)\n", - "# total_bytes= int(response.headers.get('content-length', 0))\n", - "# progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)\n", - "# with open(fname, 'wb') as file:\n", - "# for data in response.iter_content(block_size):\n", - "# progress_bar.update(len(data))\n", - "# file.write(data)\n", - "# progress_bar.close()\n", - "# if total_bytes != 0 and progress_bar.n != total_bytes:\n", - "# print(\"ERROR downloading file\")\n", - "\n", - "from tqdm import tqdm\n", - "import urllib.request\n", - "\n", - "\n", - "class ProgressBar(tqdm):\n", - " \"\"\"\n", - " Based on https://gist.github.com/leimao/37ff6e990b3226c2c9670a2cd1e4a6f5\n", - " \"\"\"\n", - " def update_to(self, b=1, bsize=1, tsize=None):\n", - " \"\"\"\n", - " b : int, optional\n", - " Number of blocks transferred so far [default: 1].\n", - " bsize : int, optional\n", - " Size of each block (in tqdm units) [default: 1].\n", - " tsize : int, optional\n", - " Total size (in tqdm units). 
If [default: None] remains unchanged.\n", - " \"\"\"\n", - " if tsize is not None:\n", - " self.total = tsize\n", - " self.update(b * bsize - self.n) # will also set self.n = b * bsize\n", - "\n", - "def download(url, save_dir):\n", - " fname = url.split('/')[-1]\n", - " opener = urllib.request.build_opener()\n", - " opener.addheaders = [(\"User-agent\", \"Mozilla/5.0\")]\n", - " urllib.request.install_opener(opener)\n", - " \n", - " with ProgressBar(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc=fname) as t:\n", - " urllib.request.urlretrieve(url, filename=os.path.join(save_dir, fname), reporthook=t.update_to)\n", - "\n", - "def apply_transforms(term, transforms):\n", - " for tf in transforms:\n", - " term = tf(term.strip())\n", - " if not term:\n", - " return None\n", - " return term\n", - " \n", - " \n", - "class KnowledgeBase(metaclass=ABCMeta):\n", - " \"\"\"\n", - " We use Knowledge Base to loosely refer to a structured resource\n", - " that contains terminology information. We are interested in the \n", - " following properties:\n", - " \n", - " - term typing\n", - " - synonomy\n", - " \n", - " When source information is available, we store the above info mapped to source.\n", - " \n", - " \"\"\"\n", - " _cache_path = \"cache/\"\n", - " \n", - " def __init__(self, cache_path, files, force_download=False):\n", - " \n", - " self.cache_path = cache_path\n", - " self.files = files\n", - " \n", - " if not self._check_cache() or force_download:\n", - " self._download()\n", - " else:\n", - " print(f\"Using files at {self.cache_path}\")\n", - " \n", - " def _download(self):\n", - " \n", - " for fname,url in self.files.items():\n", - " download(url, self.cache_path)\n", - " \n", - " def _check_cache(self):\n", - " \"\"\"\n", - " Confirm all file dependencies exist in the cache.\n", - " \"\"\"\n", - " if not os.path.exists(self.cache_path):\n", - " os.makedirs(self.cache_path)\n", - " return False\n", - " \n", - " for fname in self.files:\n", - " if not os.path.exists(f\"{self.cache_path}{fname}\"):\n", - " return False\n", - " return True\n", - " \n", - " @abstractmethod\n", - " def name(self):\n", - " ...\n", - " \n", - " @abstractmethod\n", - " def manifest(self):\n", - " ... 
\n", - " \n", - " @abstractmethod\n", - " def _load(self, **kwargs):\n", - " ...\n", - " \n", - " @abstractmethod\n", - " def get_source_terms(self):\n", - " ...\n", - " \n", - " @abstractmethod\n", - " def get_source_synsets(self):\n", - " ...\n", - " \n", - "##'names.tsv.gz':'ftp://ftp.ebi.ac.uk/pub/databases/chebi/Flat_file_tab_delimited/names.tsv.gz'\n", - " \n", - "\n", - "class CtdDatabase(KnowledgeBase):\n", - " \"\"\"\n", - " TODO: CTD contains additional entity type information we can encode as an Ontology LF\n", - " \"\"\"\n", - " def __init__(self, cache_path=None, **kwargs):\n", - " \n", - " cache_root = cache_path if cache_path else KnowledgeBase._cache_path\n", - " force_download = kwargs['force_download'] if 'force_download' in kwargs else False\n", - " \n", - " super().__init__(\n", - " cache_path = f\"{cache_root}{self.name}/\",\n", - " files = self.manifest,\n", - " force_download = force_download\n", - " )\n", - " \n", - " self.terms = {}\n", - " self.data = self._load()\n", - " \n", - " for name,key in {'disease':'DiseaseName', 'chemical':'ChemicalName'}.items():\n", - " self.terms[name] = self._collapse_terms(self.data[name], key)\n", - " self.terms[name] = self._transform_terminologies(self.terms[name], **kwargs)\n", - " \n", - " # TODO\n", - " self.synset = {}\n", - " \n", - " @property\n", - " def name(self):\n", - " return 'ctd'\n", - " \n", - " @property\n", - " def manifest(self):\n", - " return {\n", - " 'CTD_diseases.csv.gz' : 'http://ctdbase.org/reports/CTD_diseases.csv.gz',\n", - " 'CTD_chemicals.csv.gz' : 'http://ctdbase.org/reports/CTD_chemicals.csv.gz'\n", - " } \n", - " \n", - " def _collapse_terms(self, df, key):\n", - " \"\"\"\n", - " CTD includes ID: terms -> synonyms. We just collapse \n", - " all terms into a single entity dictionary.\n", - " \"\"\"\n", - " terms = set()\n", - " for row in df.itertuples():\n", - " if not pd.isnull(getattr(row, key)):\n", - " terms.add(getattr(row, key))\n", - " if not pd.isnull(row.Synonyms):\n", - " for term in row.Synonyms.split(\"|\"):\n", - " terms.add(term)\n", - " return terms\n", - "\n", - " \n", - " def _load_disease_data(self):\n", - " \n", - " columns = [\n", - " 'DiseaseName',\n", - " 'DiseaseID',\n", - " 'AltDiseaseIDs',\n", - " 'Definition',\n", - " 'ParentIDs',\n", - " 'TreeNumbers',\n", - " 'ParentTreeNumbers',\n", - " 'Synonyms',\n", - " 'SlimMappings'\n", - " ]\n", - " \n", - " fpath = f\"{self.cache_path}/CTD_diseases.csv.gz\"\n", - " return pd.read_csv(\n", - " fpath, \n", - " comment='#', \n", - " sep=',', \n", - " names=columns,\n", - " dtype=str\n", - " )\n", - " \n", - " def _load_chemical_data(self):\n", - " \n", - " columns = [\n", - " 'ChemicalName',\n", - " 'ChemicalID',\n", - " 'CasRN',\n", - " 'Definition',\n", - " 'ParentIDs',\n", - " 'TreeNumbers',\n", - " 'ParentTreeNumbers',\n", - " 'Synonyms'\n", - " ]\n", - " \n", - " fpath = f\"{self.cache_path}/CTD_chemicals.csv.gz\"\n", - " return pd.read_csv(\n", - " fpath, \n", - " comment='#', \n", - " sep=',', \n", - " names=columns,\n", - " dtype=str\n", - " )\n", - " \n", - " def _transform_terminologies(self,\n", - " terms,\n", - " min_char_len=2,\n", - " max_tok_len=100,\n", - " transforms=None,\n", - " filter_rgx=r'''^[0-9]$''',\n", - " stopwords=None,\n", - " **kwargs):\n", - " \n", - " transforms = [] if not transforms else transforms\n", - " filter_rgx = re.compile(filter_rgx) if filter_rgx else None\n", - " stopwords = {} if not stopwords else stopwords\n", - "\n", - " def include(t):\n", - " return t and len(t) >= min_char_len and 
\\\n", - " t.count(' ') <= max_tok_len - 1 and \\\n", - " t not in stopwords and \\\n", - " (filter_rgx and not filter_rgx.search(t))\n", - " \n", - " tmp = set()\n", - " for term in terms:\n", - " term = apply_transforms(term, transforms)\n", - " if include(term):\n", - " tmp.add(term)\n", - " return tmp\n", - " \n", - " def _load(self):\n", - " \n", - " return {\n", - " \"disease\" : self._load_disease_data(),\n", - " \"chemical\" : self._load_chemical_data()\n", - " }\n", - " \n", - " def get_source_terms(self, source):\n", - " assert source in self.data\n", - " return self.terms[source]\n", - " \n", - " def get_source_synsets(self, source):\n", - " pass\n", - " \n", - "ctd = CtdDatabase(**config)\n", - "\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "#ctd_terms = ctd.get_source_terms('disease')\n" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "#class BioPortalDatabase(KnowledgeBase):\n", - "# pass\n", - " \n", - " \n", - "# class CtdDatabase(KnowledgeBase):\n", - "# pass\n", - "\n", - "# class SpecialistLexicon(KnowledgeBase):\n", - "# pass\n", - "\n", - "\n", - "\n", - "# def load_ctd_dictionary(filename, stopwords=None):\n", - "# '''Comparative Toxicogenomics Database'''\n", - "# stopwords = stopwords if stopwords else {}\n", - " \n", - "# d = {}\n", - "# header = ['DiseaseName', 'DiseaseID', 'AltDiseaseIDs', 'Definition', \n", - "# 'ParentIDs', 'TreeNumbers', 'ParentTreeNumbers', 'Synonyms', \n", - "# 'SlimMappings']\n", - " \n", - "# synonyms = {}\n", - "# dnames = {}\n", - "# with open(filename,\"r\") as fp:\n", - "# for i,line in enumerate(fp):\n", - "# line = line.strip()\n", - "# if line[0] == \"#\":\n", - "# continue\n", - "# row = line.split(\"\\t\")\n", - "# if len(row) != 9:\n", - "# continue\n", - "# row = dict(zip(header,row))\n", - " \n", - "# synset = row[\"Synonyms\"].strip().split(\"|\")\n", - "# if synset:\n", - "# synonyms.update(dict.fromkeys(synset))\n", - "# term = row[\"DiseaseName\"].strip()\n", - "# if term:\n", - "# dnames[term] = 1\n", - " \n", - "# terms = {lowercase(t) for t in set(list(synonyms.keys()) + list(dnames.keys())) if t}\n", - "# # filter out stopwords \n", - "# return {t for t in terms if t not in stopwords and not re.search(r'''^[0-9]$''',t)}\n", - "\n", - "# class AdamDictionary\n", - "\n", - "# def get_url(self) -> str:\n", - "# return (\n", - "# \"http://arrowsmith.psych.uic.edu/arrowsmith_uic/download/adam.tar\"\n", - "# )\n", - "\n", - "\n", - "# class CTD:\n", - "# \"\"\"\n", - "# Comparative Toxicogenomics Database\n", - "# \"\"\"\n", - " \n", - "# _cfg = {\n", - "# 'url': 'ftp://ftp.ebi.ac.uk/pub/databases/chebi/Flat_file_tab_delimited/names.tsv.gz'\n", - "# }\n", - "# _cache_path = \"cache/chebi/\"\n", - " \n", - "# def __init__(self, cache_path, **kwargs):\n", - "# self.cache_path = cache_path\n", - "# self.df = self._load_terminologies(**kwargs)\n", - " \n", - "# def terms(self, filter_sources=None):\n", - " \n", - "# filter_sources = filter_sources if filter_sources else {}\n", - "# terms = set()\n", - "# for source in self.terminologies:\n", - "# if source in filter_sources:\n", - "# continue\n", - "# terms = terms.union(self.terminologies[source])\n", - "# return terms\n", - " \n", - "# def _load_terminologies(self,\n", - "# min_char_len=2,\n", - "# max_tok_len=100,\n", - "# min_dict_size=1,\n", - "# languages=None,\n", - "# transforms=None,\n", - "# 
filter_sources=None,\n", - "# filter_rgx=None,\n", - "# stopwords=None):\n", - " \n", - "# # defaults\n", - "# languages = languages if languages else {}\n", - "# transforms = [] if not transforms else transforms\n", - "# filter_sources = filter_sources if filter_sources else {}\n", - "# filter_rgx = re.compile(filter_rgx) if filter_rgx else None\n", - "# stopwords = {} if not stopwords else stopwords\n", - "\n", - "# def include(t):\n", - "# return t and len(t) >= min_char_len and \\\n", - "# t.count(' ') <= max_tok_len - 1 and \\\n", - "# t not in stopwords and \\\n", - "# (filter_rgx and not filter_rgx.search(t))\n", - " \n", - "# df = pd.read_csv('/users/fries/downloads/names.tsv', \n", - "# sep='\\t', \n", - "# na_filter=False, \n", - "# dtype={'NAME':'object', 'COMPOUND_ID':'object'})\n", - " \n", - "# self.terminologies = {}\n", - "# if languages:\n", - "# df = df[df.LANGUAGE.isin(languages)]\n", - " \n", - "# for source, data in df.groupby(['SOURCE']):\n", - "# if source in filter_sources or len(data) < min_dict_size:\n", - "# continue\n", - "# self.terminologies[source] = set()\n", - " \n", - "# for term in data.NAME:\n", - "# term = apply_transforms(term, transforms)\n", - "# if include(term):\n", - "# self.terminologies[source].add(term)\n", - "# self.data = df\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### C. ADAM Biomedical Abbreviations" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [], - "source": [ - "# TBD" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. Create Sequence Labeling Functions\n", - "### A. Guideline Labeling Functions\n", - "\n", - "Annotation guidelines -- the instructions provided to domain experts when labeling training data -- can have a big impact on the generalizability of named enity classifiers. These instructions include seeminly simple choices such as whether to include determiners in entity spans (\"the XXX\") or more complex tagging choices like not labeling negated mentions of drugs. These choices are baked into the dataset and expensive to change. \n", - "\n", - "With weak supervision, many of these annotation assumptions can encoded as labeling functions, making training set changes faster, more flexible, and lower cost. For our `Chemical` labeling functions, we use the instructions provided [here](https://biocreative.bioinformatics.udel.edu/media/store/files/2015/bc5_CDR_data_guidelines.pdf) (pages 5-6) to create small dictionaries encoding some of these guidelines. Note that these can be easily expanded on, and in some cases complex rules (e.g., not annotating polypeptides with more than 15 amino acids) can be coupled with richer structured resources to create more sophisticated rules. \n", - "\n", - "We also fine it useful to include labeling functions that exclude numbers and punctuation tokens, another common flag in online biomedical annotators. 
\n" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [], - "source": [ - "from trove.labelers.labeling import (\n", - " OntologyLabelingFunction,\n", - " DictionaryLabelingFunction, \n", - " RegexEachLabelingFunction\n", - ")\n", - "\n", - "# load our guideline dictionaries\n", - "df = pd.read_csv('data/bc5cdr_guidelines.tsv', sep='\\t',)\n", - "guidelines = {\n", - " t:np.array([1.,0.]) if y==1 else np.array([0.,1.]) \n", - " for t,y in zip(df.TERM, df.LABEL)\n", - "}\n", - "\n", - "# use guideline negative examples as an additional stopword list\n", - "guideline_stopwords = {t:2 for t in df[df.LABEL==0].TERM}\n", - "stopwords = {t:2 for t in stopwords}\n", - "\n", - "guideline_lfs = [\n", - " OntologyLabelingFunction('guidelines', guidelines),\n", - " DictionaryLabelingFunction('stopwords', stopwords, 2),\n", - " DictionaryLabelingFunction('punctuation', set('!\"#$%&*+,./:;<=>?@[\\\\]^_`{|}~'), 2),\n", - " RegexEachLabelingFunction('numbers', [r'''^[-]*[1-9]+[0-9]*([.][0-9]+)*$'''], 2)\n", - "]\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### B. Semantic Type Labeling Functions\n", - "\n", - "The bulk of our supervision comes from structured medical ontologies. " - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 16.1 s, sys: 205 ms, total: 16.3 s\n", - "Wall time: 16.2 s\n" - ] - } - ], - "source": [ - "%%time\n", - "from trove.labelers.abbreviations import SchwartzHearstLabelingFunction\n", - "\n", - "ontology_lfs = [\n", - " OntologyLabelingFunction(\n", - " f'UMLS_{name}', \n", - " ontologies[name], \n", - " stopwords=guideline_stopwords \n", - " )\n", - " for name in ontologies\n", - "]\n", - "\n", - "ontology_lfs += [\n", - " SchwartzHearstLabelingFunction('UMLS_schwartz_hearst_1', positive, 1, stopwords=guideline_stopwords),\n", - " SchwartzHearstLabelingFunction('UMLS_schwartz_hearst_2', negative, 2)\n", - "]\n" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [], - "source": [ - "ext_ontology_lfs = [\n", - " DictionaryLabelingFunction('CHEBI', chebi.terms(), 1, stopwords=guideline_stopwords),\n", - " \n", - " #DictionaryLabelingFunction('DOID', doid.terms(), 1, stopwords=guideline_stopwords),\n", - " #DictionaryLabelingFunction('HP', hp.terms(), 1, stopwords=guideline_stopwords),\n", - " #DictionaryLabelingFunction('AutoNER', autoner.terms(), 1, stopwords=guideline_stopwords)\n", - " \n", - " DictionaryLabelingFunction('CTD_chemical', ctd.get_source_terms('chemical'), 1, stopwords=guideline_stopwords),\n", - " DictionaryLabelingFunction('CTD_disease', ctd.get_source_terms('disease'), 2, stopwords=guideline_stopwords)\n", - " \n", - " \n", - " \n", - "]\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### C. SynSet Labeling Functions\n", - "\n", - "For biomedical concepts, abbreviations and acronymns (more generally \"short forms\") are a large source of ambiguity. \n", - "These can be ambiguous to human readers as well, so authors of PubMed abstract typically define ambiguous terms when they are introduced in text. We can take adavantage of this redundancy to both handle ambiguous mentions and identify out-of-ontology short forms using classic text mining techniques such as the [Schwartz-Hearst algorithm](https://psb.stanford.edu/psb-online/proceedings/psb03/schwartz.pdf)." 
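As a toy illustration of the redundancy the Schwartz-Hearst heuristic exploits, the sketch below pairs a parenthetical short form with a candidate long form drawn from the preceding words. This is a simplified sketch, not Trove's `SchwartzHearstLabelingFunction` (see `trove/labelers/abbreviations.py` for the real implementation); the tokenized input and the matching rules here are assumptions made for this example.

```python
import re

def find_short_long_pairs(words):
    """Sketch: pair a parenthetical short form like '(DHEA)' with a candidate
    long form taken from the words preceding the opening parenthesis."""
    pairs = []
    for i, w in enumerate(words):
        if 0 < i < len(words) - 1 and words[i - 1] == "(" and words[i + 1] == ")":
            if not re.fullmatch(r"[A-Z][A-Za-z0-9-]{1,9}", w):
                continue  # does not look like an abbreviation or acronym
            # naive long-form guess: up to len(w) words before the "(" token
            window = words[max(0, i - 1 - len(w)):i - 1]
            long_form = " ".join(window)
            if long_form and long_form[0].lower() == w[0].lower():
                pairs.append((w, long_form))
    return pairs

tokens = "dehydroepiandrosterone ( DHEA ) levels were measured".split()
print(find_short_long_pairs(tokens))  # [('DHEA', 'dehydroepiandrosterone')]
```

Pairs recovered this way can then drive labelers that vote on later standalone mentions of the short form within the same document, including short forms that never appear in any ontology.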
- ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [], - "source": [ - "# #\n", - "# # TBD\n", - "# #\n", - "\n", - "# synset_lfs = [\n", - "# SynSetLabelingFunction('SPECIALIST_synsets'),\n", - "# SynSetLabelingFunction('ADAM_synsets'),\n", - "# ]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### D. Task-specific Labeling Functions\n", - "\n", - "Ontology-based labeling functions can do suprisingly well on their own, but we can get more performance gains by adding custom labeling functions. For this demo, we focus on simple rules that are easy to create via data exploration but any existing rule-based model can be transformed into a labeling function. " - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": {}, - "outputs": [], - "source": [ - "import re\n", - "from trove.labelers.labeling import RegexLabelingFunction\n", - "\n", - "task_specific_lfs = []\n", - "\n", - "# We noticed parentheses were causing errors so this labeling function \n", - "# identifies negative examples, e.g. (n=100), (10%)\n", - "parens_rgxs = [\n", - " r'''[(](p|n)\\s*([><=]+|(less|great)(er)*)|(ml|mg|kg|g|(year|day|month)[s]*)[)]|[(][0-9]+[%][)]'''\n", - "]\n", - "# case insensitive \n", - "parens_rgxs = [re.compile(rgx, re.I) for rgx in parens_rgxs]\n", - "task_specific_lfs.append(RegexLabelingFunction('LF_parentheses', parens_rgxs, 2))\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [], - "source": [ - "lfs = guideline_lfs + ontology_lfs + ext_ontology_lfs #+ task_specific_lfs " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Construct the Label Matrix $\\Lambda$\n", - "### A. Apply Sequence Labeling Functions" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Parallel(n_jobs=4)\n", - "auto block size=3495\n", - "Partitioned into 4 blocks, [3494 3495] sizes\n", - "CPU times: user 20.8 s, sys: 4.06 s, total: 24.8 s\n", - "Wall time: 1min 28s\n" - ] - } - ], - "source": [ - "%%time\n", - "import itertools\n", - "from trove.labelers.core import SequenceLabelingServer\n", - "\n", - "X_sents = [\n", - " dataset['train'].sentences,\n", - " dataset['dev'].sentences,\n", - " dataset['test'].sentences,\n", - "]\n", - "\n", - "labeler = SequenceLabelingServer(num_workers=4)\n", - "L_sents = labeler.apply(lfs, X_sents)\n", - "\n", - "\n", - "# Wall time: 1min 48s" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [], - "source": [ - "import itertools\n", - "\n", - "splits = ['train', 'dev', 'test']\n", - "tag2idx = {'O':2, 'I-Chemical':1}\n", - "\n", - "X_words = [\n", - " np.array(list(itertools.chain.from_iterable([s.words for s in X_sents[i]]))) \n", - " for i,name in enumerate(splits)\n", - "]\n", - "\n", - "X_seq_lens = [\n", - " np.array([len(s.words) for s in X_sents[i]])\n", - " for i,name in enumerate(splits)\n", - "]\n", - "\n", - "X_doc_seq_lens = [ \n", - " np.array([len(doc.sentences) for doc in dataset[name].documents]) \n", - " for i,name in enumerate(splits)\n", - "]\n", - "\n", - "Y_words = [\n", - " [dataset['train'].tagged(i)[-1] for i in range(len(dataset['train']))],\n", - " [dataset['dev'].tagged(i)[-1] for i in range(len(dataset['dev']))],\n", - " [dataset['test'].tagged(i)[-1] for i in range(len(dataset['test']))],\n", - "]\n", - "\n", - "Y_words[0] = np.array([tag2idx[t] 
for t in list(itertools.chain.from_iterable(Y_words[0]))])\n", - "Y_words[1] = np.array([tag2idx[t] for t in list(itertools.chain.from_iterable(Y_words[1]))])\n", - "Y_words[2] = np.array([tag2idx[t] for t in list(itertools.chain.from_iterable(Y_words[2]))])\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### B. Build the Label Matrix" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 15.6 s, sys: 182 ms, total: 15.8 s\n", - "Wall time: 16.1 s\n" - ] - } - ], - "source": [ - "%%time\n", - "from scipy.sparse import dok_matrix, vstack, csr_matrix\n", - "\n", - "def create_word_lf_mat(Xs, Ls, num_lfs):\n", - " \"\"\"\n", - " Create word-level LF matrix from LFs indexed by sentence/word\n", - " 0 words X lfs\n", - " 1 words X lfs\n", - " 2 words X lfs\n", - " ...\n", - " \n", - " \"\"\"\n", - " Yws = []\n", - " for sent_i in range(len(Xs)):\n", - " ys = dok_matrix((len(Xs[sent_i].words), num_lfs))\n", - " for lf_i in range(num_lfs):\n", - " for word_i,y in Ls[sent_i][lf_i].items():\n", - " ys[word_i, lf_i] = y\n", - " Yws.append(ys)\n", - " return csr_matrix(vstack(Yws))\n", - "\n", - "L_words = [\n", - " create_word_lf_mat(X_sents[0], L_sents[0], len(lfs)),\n", - " create_word_lf_mat(X_sents[1], L_sents[1], len(lfs)),\n", - " create_word_lf_mat(X_sents[2], L_sents[2], len(lfs)),\n", - "]\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### C. Inspect Labeling Function Performance" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
jPolarityCoverage%Overlaps%Conflicts%CoverageCorrectIncorrectEmp. Acc.
guidelines0[1.0, 2.0]0.0060850.0046940.001478704678260.963068
stopwords120.2827960.0216180.0008303271732649680.997922
punctuation220.0994890.0042790.0002511151011425850.992615
numbers320.0353870.0028090.001737409437903040.925745
UMLS_CHV4[1.0, 2.0]0.3521450.3398880.017400407403969610440.974374
UMLS_SNOMEDCT_US5[1.0, 2.0]0.3346330.3296190.01777138714378298850.977140
UMLS_NCI6[1.0, 2.0]0.3970320.3516610.02029545933451158180.982191
UMLS_MSH7[1.0, 2.0]0.1811720.1795040.01131520960204275330.974571
UMLS_schwartz_hearst_1810.0060330.0053940.003371698649490.929799
UMLS_schwartz_hearst_2920.0131300.0120580.004858151912073120.794602
CHEBI1010.0630730.0588030.0200627297517621210.709333
CTD_chemical1110.0482410.0470390.009119558149756060.891417
CTD_disease1220.0413770.0411180.002100478746671200.974932
\n", - "
" - ], - "text/plain": [ - " j Polarity Coverage% Overlaps% Conflicts% \\\n", - "guidelines 0 [1.0, 2.0] 0.006085 0.004694 0.001478 \n", - "stopwords 1 2 0.282796 0.021618 0.000830 \n", - "punctuation 2 2 0.099489 0.004279 0.000251 \n", - "numbers 3 2 0.035387 0.002809 0.001737 \n", - "UMLS_CHV 4 [1.0, 2.0] 0.352145 0.339888 0.017400 \n", - "UMLS_SNOMEDCT_US 5 [1.0, 2.0] 0.334633 0.329619 0.017771 \n", - "UMLS_NCI 6 [1.0, 2.0] 0.397032 0.351661 0.020295 \n", - "UMLS_MSH 7 [1.0, 2.0] 0.181172 0.179504 0.011315 \n", - "UMLS_schwartz_hearst_1 8 1 0.006033 0.005394 0.003371 \n", - "UMLS_schwartz_hearst_2 9 2 0.013130 0.012058 0.004858 \n", - "CHEBI 10 1 0.063073 0.058803 0.020062 \n", - "CTD_chemical 11 1 0.048241 0.047039 0.009119 \n", - "CTD_disease 12 2 0.041377 0.041118 0.002100 \n", - "\n", - " Coverage Correct Incorrect Emp. Acc. \n", - "guidelines 704 678 26 0.963068 \n", - "stopwords 32717 32649 68 0.997922 \n", - "punctuation 11510 11425 85 0.992615 \n", - "numbers 4094 3790 304 0.925745 \n", - "UMLS_CHV 40740 39696 1044 0.974374 \n", - "UMLS_SNOMEDCT_US 38714 37829 885 0.977140 \n", - "UMLS_NCI 45933 45115 818 0.982191 \n", - "UMLS_MSH 20960 20427 533 0.974571 \n", - "UMLS_schwartz_hearst_1 698 649 49 0.929799 \n", - "UMLS_schwartz_hearst_2 1519 1207 312 0.794602 \n", - "CHEBI 7297 5176 2121 0.709333 \n", - "CTD_chemical 5581 4975 606 0.891417 \n", - "CTD_disease 4787 4667 120 0.974932 " - ] - }, - "execution_count": 37, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from trove.metrics.analysis import lf_summary\n", - "\n", - "lf_summary(L_words[0], Y=Y_words[0], lf_names=[lf.name for lf in lfs])\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. Train the Label Model" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "metadata": {}, - "outputs": [], - "source": [ - "# Trove uses a different internal mapping for labeling function abstains\n", - "def convert_label_matrix(L):\n", - " # abstain is -1\n", - " # negative is 0\n", - " L = L.toarray().copy()\n", - " L[L == 0] = -1\n", - " L[L == 2] = 0\n", - " return L\n", - "\n", - "L_words_hat = [\n", - " convert_label_matrix(L_words[0]),\n", - " convert_label_matrix(L_words[1]),\n", - " convert_label_matrix(L_words[2])\n", - "]\n", - "\n", - "Y_words_hat = [\n", - " np.array([0 if y == 2 else 1 for y in Y_words[0]]),\n", - " np.array([0 if y == 2 else 1 for y in Y_words[1]]),\n", - " np.array([0 if y == 2 else 1 for y in Y_words[2]])\n", - "]\n" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Hyperparamater Search Space: 192\n", - "Using SEQUENCE dev checkpointing\n", - "Using IO dev checkpointing\n", - "Grid search over 25 configs\n", - "[0] Label Model\n", - "[1] Label Model\n", - "[2] Label Model\n", - "[3] Label Model\n", - "[4] Label Model\n", - "[5] Label Model\n", - "[6] Label Model\n", - "[7] Label Model\n", - "[8] Label Model\n", - "{'lr': 0.0001, 'l2': 0.0001, 'n_epochs': 600, 'prec_init': 0.6, 'optimizer': 'adamax', 'lr_scheduler': 'constant'}\n", - "[TRAIN] accuracy: 97.96 | precision: 84.63 | recall: 84.34 | f1: 84.48\n", - "[DEV] accuracy: 98.19 | precision: 86.60 | recall: 86.49 | f1: 86.55\n", - "----------------------------------------------------------------------------------------\n", - "[9] Label Model\n", - "[10] Label Model\n", - "[11] Label Model\n", - "{'lr': 0.001, 'l2': 0.001, 'n_epochs': 100, 'prec_init': 0.6, 'optimizer': 
'adamax', 'lr_scheduler': 'constant'}\n", - "[TRAIN] accuracy: 98.01 | precision: 84.38 | recall: 85.34 | f1: 84.85\n", - "[DEV] accuracy: 98.22 | precision: 86.59 | recall: 87.69 | f1: 87.13\n", - "----------------------------------------------------------------------------------------\n", - "[12] Label Model\n", - "[13] Label Model\n", - "[14] Label Model\n", - "[15] Label Model\n", - "[16] Label Model\n", - "[17] Label Model\n", - "[18] Label Model\n", - "[19] Label Model\n", - "[20] Label Model\n", - "[21] Label Model\n", - "[22] Label Model\n", - "[23] Label Model\n", - "[24] Label Model\n", - "BEST\n", - "{'lr': 0.001, 'l2': 0.001, 'n_epochs': 100, 'prec_init': 0.6, 'optimizer': 'adamax', 'lr_scheduler': 'constant'}\n" - ] - } - ], - "source": [ - "import functools\n", - "from trove.models.model_search import grid_search\n", - "from snorkel.labeling.model.label_model import LabelModel\n", - "\n", - "np.random.seed(1234)\n", - "\n", - "n = L_words_hat[0].shape[0]\n", - "\n", - "param_grid = {\n", - " 'lr': [0.01, 0.005, 0.001, 0.0001],\n", - " 'l2': [0.001, 0.0001],\n", - " 'n_epochs': [50, 100, 200, 600, 700, 1000],\n", - " 'prec_init': [0.6, 0.7, 0.8, 0.9],\n", - " 'optimizer': [\"adamax\"], \n", - " 'lr_scheduler': ['constant'],\n", - "\n", - "# 'seed': list(np.random.randint(0,10000, 400)),\n", - "# 'mu_eps': [1 / 10 ** np.ceil(np.log10(n*100)), \n", - "# 1 / 10 ** np.ceil(np.log10(n*10)),\n", - "# 1 / 10 ** np.ceil(np.log10(n))]\n", - "}\n", - "\n", - "model_class_init = {\n", - " 'cardinality': 2, \n", - " 'verbose': True\n", - "}\n", - "\n", - "n_model_search = 25\n", - "num_hyperparams = functools.reduce(lambda x,y:x*y, [len(x) for x in param_grid.values()])\n", - "print(\"Hyperparamater Search Space:\", num_hyperparams)\n", - "\n", - "\n", - "L_train = L_words_hat[0]\n", - "Y_train = Y_words_hat[0]\n", - "L_dev = L_words_hat[1]\n", - "Y_dev = Y_words_hat[1]\n", - "\n", - "label_model, best_config = grid_search(LabelModel, \n", - " model_class_init, \n", - " param_grid,\n", - " train = (L_train, Y_train, X_seq_lens[0]),\n", - " dev = (L_dev, Y_dev, X_seq_lens[1]),\n", - " n_model_search=n_model_search, \n", - " val_metric='f1', \n", - " seq_eval=True,\n", - " seed=1234,\n", - " tag_fmt_ckpnt='IO')" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "BIO Tag Format\n", - "[Label Model] accuracy: 98.01 | precision: 84.38 | recall: 85.34 | f1: 84.85\n", - "[Majority Vote] accuracy: 97.57 | precision: 76.15 | recall: 83.12 | f1: 79.48\n", - "--------------------------------------------------------------------------------\n", - "BIO Tag Format\n", - "[Label Model] accuracy: 98.22 | precision: 86.59 | recall: 87.69 | f1: 87.13\n", - "[Majority Vote] accuracy: 97.83 | precision: 78.47 | recall: 85.42 | f1: 81.80\n", - "--------------------------------------------------------------------------------\n", - "BIO Tag Format\n", - "[Label Model] accuracy: 98.39 | precision: 86.10 | recall: 87.23 | f1: 86.66\n", - "[Majority Vote] accuracy: 97.85 | precision: 76.23 | recall: 84.52 | f1: 80.16\n", - "--------------------------------------------------------------------------------\n" - ] - } - ], - "source": [ - "from trove.metrics import eval_label_model # get_coverage, \n", - "\n", - "for i in range(3):\n", - " #get_coverage(L_words_hat[i], Y_words_hat[i])\n", - " #print(\"IO Tag Format\")\n", - " #eval_label_model(label_model, L_words_hat[i], Y_words_hat[i], X_seq_lens[i])\n", - " 
print(\"BIO Tag Format\")\n", - " eval_label_model(label_model, L_words_hat[i], Y_words_hat[i], X_seq_lens[i])\n", - " print('-' * 80)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 6. Export Proba Conll" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#\n", - "# TBD\n", - "#" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "trove", - "language": "python", - "name": "trove" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.11" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -}