Initial commit of Trove legacy code
jason-fries committed Nov 27, 2020
1 parent ed9da6f commit 1468a06
Showing 147 changed files with 869,402 additions and 0 deletions.
9 changes: 9 additions & 0 deletions .gitignore
@@ -1,3 +1,12 @@
# OS generated files
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
67 changes: 67 additions & 0 deletions README.md
@@ -0,0 +1,67 @@
# Trove
### NOTE: This branch is deprecated and is being refactored into master

Tools for weakly supervised sequence labeling and span classification in biomedical text. Code is provided for training and evaluating Snorkel models for unsupervised ensembling of dictionary-based and other heuristic methods of labeling text.

**Benchmark Tasks**

- 4 NER tasks from scientific literature (BC5CDR) and EHRs (ShARe/CLEF, i2b2) for *Disease, Chemical, Disorder, Drug* entities
- 2 span classification tasks (ShARe/CLEF, THYME) for *Negation* and *Temporality*

Each NER task is evaluated under ablation settings that assume increasing amounts of available labeling resources. Tiers are additive (see the example after the list).

1. Annotation Guideline Rules
2. UMLS Ontologies
3. 3rd Party Ontologies / Text-mined Lexicons
4. Task-specific Labeling Functions
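
For example, assuming the `--active_tiers` flag in the step 1 command below selects which tiers contribute labeling functions, passing `--active_tiers 12` would restrict labeling to annotation guideline rules and UMLS ontologies.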

## I. Quick Start

Download dataset and dictionary dependencies [here](https://drive.google.com/drive/folders/1zeZPZUBlV-jh-WCDK8YnkIU3Pm-LSqwu?usp=sharing)

This assumes your documents have been preprocessed into a JSON container format.
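
As a rough, illustrative sketch only (field names such as `name`, `words`, and `abs_char_offsets` mirror sentence attributes used in the pipeline code; the actual schema may differ), one preprocessed document might look like:

```
{
  "name": "doc-001",
  "sentences": [
    {
      "words": ["Aspirin", "caused", "mild", "nausea", "."],
      "abs_char_offsets": [0, 8, 15, 20, 26]
    }
  ]
}
```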

### 1. Generate Label Matrices
```
python apply-lfs.py \
--indir <INDIR> \
--outdir <OUTDIR> \
--data_root <DATA> \
--corpus cdr \
--domain pubmed \
--tag_fmt IO \
--entity_type chemical \
--top_k 4 \
--active_tiers 1234 \
--n_procs 4 > chemical.lfs.log
```

### 2. Train Label Model
Use the default Snorkel label model to combine the label matrices from step 1 into probabilistic training labels (`--n_model_search` controls how many label model configurations are searched).

```
python train-label-model.py \
--indir <INDIR> \
--outdir <OUTDIR> \
--train 0 \
--dev 1 \
--test 2 \
--n_model_search 50 \
--tag_fmt_ckpnt BIO \
--seed 1234 > chemical.label_model.log
```

### 3. Train End Model (e.g., BERT)

```
python proba-train-bert.py \
--train chemical.train.prior_balance.proba.conll.tsv \
--dev chemical.dev.prior_balance.proba.conll.tsv \
--test chemical.test.prior_balance.proba.conll.tsv \
--model biobert_v1.1_pubmed-pytorch/ \
--device cuda \
--n_epochs 10 \
--lr 1e-5 \
--seed 1234 > chemical.bert.log
```

186 changes: 186 additions & 0 deletions apply-label-model.py
@@ -0,0 +1,186 @@
import re
import os
import sys
import time
import glob
import torch
import pandas as pd
import snorkel
import collections
import numpy as np
import argparse
import pickle
import itertools
from snorkel.labeling import LabelModel
from trove.analysis.metrics import split_by_seq_len
from trove.data.dataloaders.contexts import Span


def timeit(f):
"""
Decorator for timing function calls
    :param f: function to time
    :return: wrapped function that prints elapsed wall-clock time
"""
def timed(*args, **kw):
ts = time.time()
result = f(*args, **kw)
te = time.time()
print(f'{f.__name__} took: {te - ts:2.4f} sec')
return result

return timed


def convert_label_matrix(L):
# abstain is -1
# negative is 0
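    # the input matrix encodes abstain as 0 and negative as 2; positive stays 1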
L = L.toarray().copy()
L[L == 0] = -1
L[L == 2] = 0
return L


def dump_entity_spans(outfpath,
y_word_probas,
sentences,
b=0.5,
stopwords=None):
"""
Given per-word probabilities, generate all
entity spans (assumes IO tagging)
"""
    stopwords = stopwords if stopwords else set()

snapshot = []
seq_lens = [len(s) for s in sentences]
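    # column 1 of the per-word probabilities is the positive class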
y_sents = split_by_seq_len(y_word_probas[..., 1], seq_lens)

for sent, y_proba in zip(sentences, y_sents):

y_pred = [1 if p > b else 0 for p in y_proba]

curr = []
spans = []
for i, (word, y, char_start) in enumerate(
zip(sent.words, y_pred, sent.abs_char_offsets)):
if y == 1:
curr.append((word, char_start))
elif y == 0 and curr:
curr = [(w, ch) for w, ch in curr if len(w.strip()) != 0]
spans.append(list(zip(*curr)))
curr = []
if curr:
curr = [(w, ch) for w, ch in curr if len(w.strip()) != 0]
spans.append(list(zip(*curr)))

        # initialize Span objects, skipping empties and stopwords
unique_spans = set()
for s in spans:
if not s:
continue

toks, offsets = s
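            # offsets are absolute character positions; subtract the sentence
            # start offset to get sentence-relative span boundaries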
start = offsets[0] - sent.abs_char_offsets[0]
end = offsets[-1] + len(toks[-1]) - sent.abs_char_offsets[0]
span = Span(char_start=start, char_end=end - 1, sentence=sent)
if len(span.text) <= 1 or span.text in stopwords:
continue
unique_spans.add(span)

# export rows
for span in unique_spans:
            norm_str = re.sub(r'\s{2,}', ' ', span.text.strip())
norm_str = norm_str.replace('\n', '')
snapshot.append(
[span.sentence.document.name, norm_str, span.abs_char_start,
span.abs_char_end])
snapshot[-1] = list(map(str, snapshot[-1]))

# write to TSV
with open(outfpath, 'w') as fp:
for row in snapshot:
fp.write('\t'.join(row) + '\n')

@timeit
def main(args):

# -------------------------------------------------------------------------
# Load Dataset and L Matrices
# -------------------------------------------------------------------------
X_sents = pickle.load(open(f'{args.indir}/X_sents', 'rb'))
X_seq_lens = pickle.load(open(f'{args.indir}/X_seq_lens', 'rb'))
X_doc_seq_lens = pickle.load(open(f'{args.indir}/X_doc_seq_lens', 'rb'))
L_words = pickle.load(open(f'{args.indir}/L_words', 'rb'))
X_words = pickle.load(open(f'{args.indir}/X_words', 'rb'))
Y_words = pickle.load(open(f'{args.indir}/Y_words', 'rb'))

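    # print class frequencies of the word-level labels for each split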
for i in range(len(Y_words)):
freq = collections.Counter()
for y in Y_words[i]:
freq[y] += 1
print(i, freq)

for i in range(len(L_words)):
size = L_words[i].shape if L_words[i] is not None else None
print(f'i={i} {size}')
print(f'train={args.train} dev={args.dev} test={args.test}')

    # transform label matrices to Snorkel v0.9.3+ format
if args.label_fmt == 'snorkel':
L_words = [
convert_label_matrix(L_words[0]),
convert_label_matrix(L_words[1]),
convert_label_matrix(L_words[2]),
convert_label_matrix(L_words[3]) \
if L_words[3] is not None else None
]

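        # remap gold word labels from {1, 2} to {1, 0}; the 4th entry has no gold labels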
Y_words = [
np.array([0 if y == 2 else 1 for y in Y_words[0]]),
np.array([0 if y == 2 else 1 for y in Y_words[1]]),
np.array([0 if y == 2 else 1 for y in Y_words[2]]),
np.array([])
]
print("Coverted to Snorkel 9.3+ label matrix format")

# -------------------------------------------------------------------------
# Load Best Model
# -------------------------------------------------------------------------
model = LabelModel(cardinality=2, verbose=True)
model.load(args.model)
print(f"Loaded model from {args.model}")

# -------------------------------------------------------------------------
# Predict Labels
# -------------------------------------------------------------------------
y_proba = model.predict_proba(L_words[args.split])
sentences = list(
itertools.chain.from_iterable(
[doc.sentences for doc in X_sents[args.split]]
)
)

# -------------------------------------------------------------------------
# Dump Entity Spans
# -------------------------------------------------------------------------
dump_entity_spans(args.outdir, y_proba, sentences)


if __name__=="__main__":

parser = argparse.ArgumentParser()

parser.add_argument("--model", type=str, default=None)
parser.add_argument("--input", type=str, default=None)
parser.add_argument("--outdir", type=str, default=None)
parser.add_argument("--split", type=int, default=None)
args = parser.parse_args()

print(f'PyTorch v{torch.__version__}')
print(f'Snorkel v{snorkel.__version__}')
print(args)
main(args)

