Commit 28cc904 (parent 4b95ac3): 51 changed files with 132,098 additions and 0 deletions.
@@ -0,0 +1 @@
# nis-patient-encoding
@@ -0,0 +1 @@
#
@@ -0,0 +1,171 @@
import numpy as np
import h5py
import tables
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from functools import reduce

from data.data_loader import NISDatabase
from data.cohort_filterer import CohortFilterer
from similarity.matchers import DeepMatcher
# from similarity.matchers import PropensityScoreMatcher
# from similarity.matchers import FuzzyMatcher
from utils import feature_utils
from utils.code_mappings import map_icd9_to_numeric

INCLUSION_INFO = {'DXCCS': [122, 109]}
CASE_CONTROL_INFO = {'DX': [map_icd9_to_numeric('42731')]}
COHORT_FILTER_INFO = {'inclusion': INCLUSION_INFO, 'case_control': CASE_CONTROL_INFO}

DATA_FOLDER = '/home/aisinai/work/repos/nis_patient_encoding/data/raw/'

DATASETS = {
    'filtering': f'{DATA_FOLDER}NIS_Pruned.h5',
    'matching': f'{DATA_FOLDER}NIS_2012_2014_proto_emb_v2.h5',
    'eval_input': f'{DATA_FOLDER}NIS_2012_2014_proto_emb_v2.h5',
    'eval_output': f'{DATA_FOLDER}NIS_2012_2014_proto_emb_v2.h5',
}

MATCHERS = {
    # 'propensity' : ,
    'deep': DeepMatcher,
    # 'fuzzy' : FuzzyMatcher,
}

class CohortBuilder:
    """Manages initial cohort construction.

    Takes the full dataset with raw ICD-9 codes and finds the indices that map to particular
    cohorts. These indices may then be used by downstream classes on more pruned/filtered
    datasets for use in the AE system.
    """

    def __init__(self, datasets=DATASETS, filter_params=COHORT_FILTER_INFO):
        """Manages loading of NIS database samples from an H5 file.

        Note: reimplemented here; @TODO refactor this entire class structure.
        """
        # super().__init__(filename, 'TRAIN')

        self.filter_params = filter_params
        self.datasets = datasets

        self.inds = {
            'filtered': {'case': [], 'control': []},  # full dataset -> rows included in the current analysis
            'matched': {'case': [], 'control': []}    # included rows -> matched case/control pairs
        }

        self.filtered = False
        self.matched = False

    def prepare_for_match(self):
        raise NotImplementedError

    def filter_patients(self, sample=None):
        """Build a case-control group based on fundamental inclusion and case/control criteria.

        If `sample` is given (a fraction in (0, 1]), randomly subsample both cohorts.
        """
        filter_dataset = self.datasets['filtering']

        filterer = CohortFilterer(filter_dataset, self.filter_params)
        self.inds['filtered'] = filterer.filter_patients()

        if sample:
            case_alias = self.inds['filtered']['case']
            control_alias = self.inds['filtered']['control']
            self.inds['filtered']['case'] = np.random.choice(
                case_alias, int(sample * case_alias.shape[0]), replace=False)
            self.inds['filtered']['control'] = np.random.choice(
                control_alias, int(sample * control_alias.shape[0]), replace=False)

        self.filtered = True

    def match_patients(self, model, device='cpu', save_dir=None):
        """Match patients for each registered matcher."""
        if not self.filtered:
            raise ValueError("Must call filter_patients() before matching.")

        matcher_input = {
            'dataset': self.datasets['matching'],
            'state_inds': self.inds['filtered']
        }

        self.matchers = {}

        for matcher_name, matcher_class in MATCHERS.items():
            matcher = matcher_class(matcher_input['dataset'], matcher_input['state_inds'],
                                    model, device=device, save_dir=save_dir)

            self.matchers[matcher_name] = {'matcher': matcher}

            matcher.prepare_for_matching()
            matcher.match()

            self.matchers[matcher_name]['matches'] = matcher.matches

        self.matched = True

    def set_match_keys(self, matcher_name):
        """Translate a matcher's relative match indices back to absolute dataset indices."""
        matches = self.matchers[matcher_name]['matches']
        self.inds['matched']['case'] = self.inds['filtered']['case'][matches['case']]
        self.inds['matched']['control'] = self.inds['filtered']['control'][matches['control']]

    def _create_batch_for_eval(self, matcher_name, feature_info):
        """Collect per-feature ground-truth arrays for the matched case and control cohorts."""
        batches = {}

        self.set_match_keys(matcher_name)

        db = NISDatabase(self.datasets['eval_input'], 'case', state_inds=self.inds['matched'])

        for cohort_type in self.inds['matched']:
            cohort_presenting = {}

            db.change_state(cohort_type)
            cohort_dl = DataLoader(db, batch_size=1000, pin_memory=True, num_workers=1)

            for batch in cohort_dl:
                target = self._isolate_features(batch, feature_info)

                for feature, recon in target.items():
                    recon_np = np.array(recon.detach().to('cpu'))
                    if feature not in cohort_presenting:
                        cohort_presenting[feature] = recon_np
                    else:
                        cohort_presenting[feature] = np.vstack((cohort_presenting[feature], recon_np))

            batches[cohort_type] = cohort_presenting

        return batches

    @staticmethod
    def _isolate_features(x, feature_info):
        """Split a raw batch into per-feature ground-truth targets (multi-hot, one-hot, continuous)."""
        ground_truth = {}

        # Multi-hot targets: each embedding group spans several code columns, collapsed with a union.
        for embedding, embedding_layer in feature_info['embedding'].items():
            features_to_use = embedding_layer['feature_idx']
            batch_emb = x[:, features_to_use].to(torch.long)

            ground_truth[embedding] = nn.functional.one_hot(batch_emb, num_classes=embedding_layer['num_classes'])
            ground_truth[embedding] = (torch.sum(ground_truth[embedding], dim=1) > 0).to(torch.float)

        # One-hot targets for single categorical columns.
        for one_hot_name, one_hot in feature_info['one_hots'].items():
            nc = one_hot['num_classes']
            one_hot_encoded = nn.functional.one_hot(x[:, one_hot['feature_idx']].to(torch.long), num_classes=nc)

            ground_truth[one_hot_name] = one_hot_encoded

        # Continuous targets are passed through as single columns.
        for cont_name, continuous_feature in feature_info['continuous'].items():
            ground_truth[cont_name] = x[:, continuous_feature['feature_idx']].view(-1, 1)

        return ground_truth
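For orientation, a minimal driver sketch for CohortBuilder follows. It is not part of this commit: trained_model (an encoder compatible with DeepMatcher) and feature_info (the embedding/one-hot/continuous layout consumed by _isolate_features) are assumed to come from elsewhere in the repository, and the save_dir path is hypothetical.

# Hypothetical driver; trained_model, feature_info, and save_dir are assumptions, not part of this commit.
builder = CohortBuilder()                       # defaults: DATASETS paths and COHORT_FILTER_INFO
builder.filter_patients(sample=0.1)             # keep a random 10% of each filtered cohort
builder.match_patients(trained_model, device='cuda:0', save_dir='/tmp/matches')
batches = builder._create_batch_for_eval('deep', feature_info)   # 'deep' is the registered DeepMatcher
case_features, control_features = batches['case'], batches['control']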
@@ -0,0 +1,132 @@
import numpy as np
import h5py
import torch
import torch.nn as nn

from functools import reduce

from data.data_loader import NISDatabase
from utils import feature_utils

ALLOWED_STATES = ['case', 'control']

class CohortFilterer(NISDatabase):
    """Finds dataset row indices that satisfy inclusion and case/control criteria."""

    def __init__(self, dataset_fn, filter_params):
        super().__init__(dataset_fn, 'full')

        self.filter_params = filter_params

    def filter_patients(self):
        """Build a case-control group based on fundamental criteria."""
        if self.dataset is None:
            self.dataset = h5py.File(self.filename, 'r')['dataset']

        # Find feature indices belonging to specific criteria
        inclusion_info = self.filter_params['inclusion']
        # exclusion_info = self.filter_params['exclusion']
        case_control_info = self.filter_params['case_control']

        inclusion_inds = self.check_criteria(inclusion_info, case_control=False)
        # exclusion_inds = self.check_criteria(exclusion_info, case_control=False)
        case_inds, control_inds = self.check_criteria(case_control_info, case_control=True)

        filtered_inds = {}
        # inclusion_exclusion_inds = np.setdiff1d(inclusion_inds, exclusion_inds)
        filtered_inds['case'] = np.intersect1d(inclusion_inds, case_inds)
        filtered_inds['control'] = np.intersect1d(inclusion_inds, control_inds)

        return filtered_inds

    def find_feature(self, pattern):
        """Find the column indices for a feature family (e.g. 'DX' matches DX, DX1, DX2, ...)."""
        idxs = []
        for idx, header in enumerate(self.headers):
            header = header.decode('utf-8')
            lp = len(pattern)

            # Exact match, or the pattern followed by a digit (numbered continuation columns)
            if header == pattern:
                idxs.append(idx)
            elif header[:lp] == pattern and header[lp].isdigit():
                idxs.append(idx)

        return idxs

    def check_criteria(self, criteria, case_control=False):
        """Intersect, across all criteria, the patient indices that uphold each criterion."""
        if case_control:
            pts_meeting_criteria = {key: [] for key in ['case', 'control']}
        else:
            pts_meeting_criteria = []

        if len(criteria) == 0:  # mostly for exclusion criteria
            return np.array([])

        for name, criterion in criteria.items():
            print(name, criterion)
            feature_inds = self.find_feature(name)
            pts_meeting_criterion = self.search_by_chunk(self.dataset, feature_inds, criterion, case_control)

            if case_control:
                pts_meeting_criteria['case'].append(pts_meeting_criterion['case'])
                pts_meeting_criteria['control'].append(pts_meeting_criterion['control'])
            else:
                pts_meeting_criteria.append(pts_meeting_criterion)

        if case_control:
            return reduce(np.intersect1d, pts_meeting_criteria['case']), \
                   reduce(np.intersect1d, pts_meeting_criteria['control'])
        else:
            return reduce(np.intersect1d, pts_meeting_criteria)

    @staticmethod
    def search_by_chunk(dataset, feature_inds, criterion, case_control=False):
        """Scan the H5 dataset in row chunks and return the rows where any listed code appears."""
        count = 0
        chunk_size = 1000000

        if case_control:
            inds = {key: [] for key in ['case', 'control']}
        else:
            inds = []

        while count < dataset.shape[0]:
            # Shrink the final chunk so it does not run past the end of the dataset
            if count + chunk_size > dataset.shape[0]:
                chunk_size = dataset.shape[0] - count

            chunk = dataset[count:count + chunk_size, feature_inds].reshape(chunk_size, -1)

            # One column per criterion value: does that value appear in any of the feature columns?
            match = np.empty((chunk.shape[0], len(criterion)))
            for i, criterion_i in enumerate(criterion):
                matches_i = (chunk == criterion_i).reshape(-1, chunk.shape[1])  # element-wise match
                match[:, i] = (np.sum(matches_i, axis=1) > 0)                   # any feature column matches

            match = np.sum(match, axis=1)  # union across criterion values; @TODO make more generalizable

            if case_control:
                inds['case'].extend(np.where(match)[0] + count)           # rows with at least one match
                inds['control'].extend(np.where(match == 0)[0] + count)   # rows with no match
            else:
                inds.extend(np.where(match)[0] + count)

            count += chunk_size

        return inds
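For reference, a minimal standalone sketch of driving the filterer directly. The file path and the filter dictionary mirror the constants defined alongside CohortBuilder above; invoking the class this way is an assumption for illustration, not something this commit does.

# Hypothetical standalone run; reuses the filter layout and data path from the CohortBuilder module.
from utils.code_mappings import map_icd9_to_numeric

filter_params = {
    'inclusion': {'DXCCS': [122, 109]},                       # CCS categories required for inclusion
    'case_control': {'DX': [map_icd9_to_numeric('42731')]},   # ICD-9 427.31 (atrial fibrillation) defines cases
}
filterer = CohortFilterer('/home/aisinai/work/repos/nis_patient_encoding/data/raw/NIS_Pruned.h5', filter_params)
inds = filterer.filter_patients()                             # {'case': ..., 'control': ...} absolute row indices
print(inds['case'].shape[0], inds['control'].shape[0])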