Commit 28cc904: Add files via upload

RespectableGlioma authored Mar 30, 2021
1 parent 4b95ac3 commit 28cc904
Showing 51 changed files with 132,098 additions and 0 deletions.
1 change: 1 addition & 0 deletions nis-patient-encoding/README.md
@@ -0,0 +1 @@
# nis-patient-encoding
1 change: 1 addition & 0 deletions nis-patient-encoding/data/__init__.py
@@ -0,0 +1 @@
#
171 changes: 171 additions & 0 deletions nis-patient-encoding/data/cohort_builder.py
@@ -0,0 +1,171 @@
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from data.data_loader import NISDatabase
from data.cohort_filterer import CohortFilterer
from similarity.matchers import DeepMatcher
# from similarity.matchers import PropensityScoreMatcher
# from similarity.matchers import FuzzyMatcher
from utils.code_mappings import map_icd9_to_numeric

INCLUSION_INFO = {'DXCCS': [122, 109]}
CASE_CONTROL_INFO = {'DX': [map_icd9_to_numeric('42731'),]}
COHORT_FILTER_INFO = {'inclusion': INCLUSION_INFO, 'case_control': CASE_CONTROL_INFO}

DATA_FOLDER = '/home/aisinai/work/repos/nis_patient_encoding/data/raw/'

DATASETS = {
    'filtering': f'{DATA_FOLDER}NIS_Pruned.h5',
    'matching': f'{DATA_FOLDER}NIS_2012_2014_proto_emb_v2.h5',
    'eval_input': f'{DATA_FOLDER}NIS_2012_2014_proto_emb_v2.h5',
    'eval_output': f'{DATA_FOLDER}NIS_2012_2014_proto_emb_v2.h5',
}

MATCHERS = {
    # 'propensity': PropensityScoreMatcher,
    'deep': DeepMatcher,
    # 'fuzzy': FuzzyMatcher,
}

class CohortBuilder:
    """Manages initial cohort construction.

    This class takes the full DataFrame with raw ICD-9 codes and finds the indices
    that map to particular cohorts. Downstream classes may then use these indices
    on more pruned/filtered datasets in the AE system.
    """

    def __init__(self, datasets=DATASETS, filter_params=COHORT_FILTER_INFO):
        """Manages loading of NIS database samples in an H5 file.

        Note: reimplemented here, but @TODO to refactor this entire class structure.
        """
        # super().__init__(filename, 'TRAIN')

        self.filter_params = filter_params
        self.datasets = datasets

        self.matcher_results = {matcher: {'case': [], 'control': []} for matcher in MATCHERS}
        self.inds = {
            'filtered': {'case': [], 'control': []},  # full dataset -> included in current analysis
            'matched': {'case': [], 'control': []},   # included in current analysis -> matched
        }

        self.filtered = False
        self.matched = False

    def prepare_for_match(self):
        raise NotImplementedError

    def filter_patients(self, sample=None):
        """Build a case-control group based on very fundamental criteria."""
        filterer = CohortFilterer(self.datasets['filtering'], self.filter_params)
        self.inds['filtered'] = filterer.filter_patients()

        if sample:
            # Optionally keep only a random fraction of each cohort.
            case_alias = self.inds['filtered']['case']
            control_alias = self.inds['filtered']['control']
            self.inds['filtered']['case'] = np.random.choice(
                case_alias, int(sample * case_alias.shape[0]), replace=False)
            self.inds['filtered']['control'] = np.random.choice(
                control_alias, int(sample * control_alias.shape[0]), replace=False)

        self.filtered = True

    def match_patients(self, model, device='cpu', save_dir=None):
        """Match patients for each matcher."""
        if not self.filtered:
            raise ValueError("Must perform filtering before matching, silly!")

        matcher_input = {
            'dataset': self.datasets['matching'],
            'state_inds': self.inds['filtered'],
        }

        self.matchers = {}

        for matcher_name, matcher_class in MATCHERS.items():
            matcher = matcher_class(matcher_input['dataset'], matcher_input['state_inds'],
                                    model, device=device, save_dir=save_dir)

            self.matchers[matcher_name] = {'matcher': matcher}

            matcher.prepare_for_matching()
            matcher.match()

            self.matchers[matcher_name]['matches'] = matcher.matches

    def set_match_keys(self, matcher_name):
        """Map a matcher's match indices back to indices into the full dataset."""
        matches = self.matchers[matcher_name]['matches']
        self.inds['matched']['case'] = self.inds['filtered']['case'][matches['case']]
        self.inds['matched']['control'] = self.inds['filtered']['control'][matches['control']]

    def _create_batch_for_eval(self, matcher_name, feature_info):
        """Collect per-feature arrays for the matched case and control cohorts."""
        batches = {}

        self.set_match_keys(matcher_name)

        db = NISDatabase(self.datasets['eval_input'], 'case', state_inds=self.inds['matched'])

        for cohort_type in self.inds['matched']:
            cohort_presenting = {}

            db.change_state(cohort_type)
            cohort_dl = DataLoader(db, batch_size=1000, pin_memory=True, num_workers=1)

            for batch in cohort_dl:
                target = self._isolate_features(batch, feature_info)

                for feature, recon in target.items():
                    recon_np = recon.detach().cpu().numpy()
                    if feature not in cohort_presenting:
                        cohort_presenting[feature] = recon_np
                    else:
                        cohort_presenting[feature] = np.vstack((cohort_presenting[feature], recon_np))

            batches[cohort_type] = cohort_presenting

        return batches

    @staticmethod
    def _isolate_features(x, feature_info):
        """Modify the features as needed."""

        # Initialize stores
        ground_truth = {}

        for embedding, embedding_layer in feature_info['embedding'].items():
            features_to_use = embedding_layer['feature_idx']
            batch_emb = x[:, features_to_use].to(torch.long)

            # Multi-hot encode: one-hot each code slot, then collapse across slots.
            ground_truth[embedding] = nn.functional.one_hot(batch_emb, num_classes=embedding_layer['num_classes'])
            ground_truth[embedding] = (torch.sum(ground_truth[embedding], dim=1) > 0).to(torch.float)

        for one_hot_name, one_hot in feature_info['one_hots'].items():
            nc = one_hot['num_classes']
            ground_truth[one_hot_name] = nn.functional.one_hot(
                x[:, one_hot['feature_idx']].to(torch.long), num_classes=nc)

        for cont_name, continuous_feature in feature_info['continuous'].items():
            ground_truth[cont_name] = x[:, continuous_feature['feature_idx']].view(-1, 1)

        return ground_truth
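
For context, a minimal sketch of how CohortBuilder might be driven end to end, assuming a trained autoencoder model and a feature_info dict shaped the way _isolate_features expects; the checkpoint path, feature indices, and class counts below are illustrative assumptions, not values from this repository.

import torch

from data.cohort_builder import CohortBuilder

# Hypothetical checkpoint and feature layout: adjust to the real dataset.
model = torch.load('checkpoints/autoencoder.pt', map_location='cpu')

feature_info = {
    'embedding': {'DX': {'feature_idx': [0, 1, 2], 'num_classes': 300}},
    'one_hots': {'FEMALE': {'feature_idx': 3, 'num_classes': 2}},
    'continuous': {'AGE': {'feature_idx': 4}},
}

builder = CohortBuilder()
builder.filter_patients(sample=0.1)   # keep a random 10% of each cohort
builder.match_patients(model)         # runs every matcher in MATCHERS ('deep' here)
batches = builder._create_batch_for_eval('deep', feature_info)
# batches['case'] and batches['control'] now hold aligned per-feature numpy arrays.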
132 changes: 132 additions & 0 deletions nis-patient-encoding/data/cohort_filterer.py
@@ -0,0 +1,132 @@
import numpy as np
import h5py

from functools import reduce

from data.data_loader import NISDatabase

ALLOWED_STATES = ['case', 'control']

class CohortFilterer(NISDatabase):
    """Filters the full NIS dataset into case and control index sets."""

    def __init__(self, dataset_fn, filter_params):
        super().__init__(dataset_fn, 'full')
        self.filter_params = filter_params

    def filter_patients(self):
        """Build a case-control group based on very fundamental criteria."""
        if self.dataset is None:
            self.dataset = h5py.File(self.filename, 'r')['dataset']

        # Find feature indices belonging to specific criteria
        inclusion_info = self.filter_params['inclusion']
        # exclusion_info = self.filter_params['exclusion']
        case_control_info = self.filter_params['case_control']

        inclusion_inds = self.check_criteria(inclusion_info, case_control=False)
        # exclusion_inds = self.check_criteria(exclusion_info, case_control=False)
        case_inds, control_inds = self.check_criteria(case_control_info, case_control=True)

        filtered_inds = {}
        # inclusion_exclusion_inds = np.setdiff1d(inclusion_inds, exclusion_inds)
        filtered_inds['case'] = np.intersect1d(inclusion_inds, case_inds)
        filtered_inds['control'] = np.intersect1d(inclusion_inds, control_inds)

        return filtered_inds

    def find_feature(self, pattern):
        """Find the column indices for a feature type in the dataset.

        Matches the exact header as well as numbered continuations
        (e.g. 'DX' also matches 'DX1', 'DX2', ...).
        """
        idxs = []
        for idx, header in enumerate(self.headers):
            header = header.decode('utf-8')
            lp = len(pattern)

            if header == pattern:
                idxs.append(idx)
            elif header[:lp] == pattern and header[lp].isdigit():
                idxs.append(idx)

        return idxs

    def check_criteria(self, criteria, case_control=False):
        """Perform an intersection across all criteria that are upheld by the data."""
        if case_control:
            pts_meeting_criteria = {key: [] for key in ['case', 'control']}
        else:
            pts_meeting_criteria = []

        if len(criteria) == 0:  # mostly for exclusion criteria
            return np.array([])

        for name, criterion in criteria.items():
            print(name, criterion)
            feature_inds = self.find_feature(name)
            pts_meeting_criterion = self.search_by_chunk(self.dataset, feature_inds, criterion, case_control)

            if case_control:
                pts_meeting_criteria['case'].append(pts_meeting_criterion['case'])
                pts_meeting_criteria['control'].append(pts_meeting_criterion['control'])
            else:
                pts_meeting_criteria.append(pts_meeting_criterion)

        if case_control:
            return (reduce(np.intersect1d, pts_meeting_criteria['case']),
                    reduce(np.intersect1d, pts_meeting_criteria['control']))
        return reduce(np.intersect1d, pts_meeting_criteria)

    @staticmethod
    def search_by_chunk(dataset, feature_inds, criterion, case_control=False):
        """Scan the dataset chunk by chunk for rows where any selected feature matches any criterion value."""
        count = 0
        chunk_size = 1000000

        if case_control:
            inds = {key: [] for key in ['case', 'control']}
        else:
            inds = []

        while count < dataset.shape[0]:
            print(count)
            # Shrink the final chunk to the rows that remain
            if count + chunk_size > dataset.shape[0]:
                chunk_size = dataset.shape[0] - count

            chunk = dataset[count:count + chunk_size, feature_inds].reshape(chunk_size, -1)

            # For each criterion value, flag rows where any feature column matches
            match = np.empty((chunk.shape[0], len(criterion)))
            for i, criterion_i in enumerate(criterion):
                matches_i = (chunk == criterion_i).reshape(-1, chunk.shape[1])  # perform match
                match[:, i] = (np.sum(matches_i, axis=1) > 0)  # across all features

            match = np.sum(match, axis=1)  # assuming union, but @TODO make more generalizable

            if case_control:
                inds['case'].extend(np.where(match)[0] + count)
                inds['control'].extend(np.where(match == 0)[0] + count)
            else:
                inds.extend(np.where(match)[0] + count)

            count += chunk_size

        return inds
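
For reference, a minimal sketch of driving CohortFilterer directly, assuming a local copy of the pruned NIS H5 file; the path is a placeholder, and the criteria simply mirror the constants in cohort_builder.py ('42731' is the ICD-9 code 427.31, atrial fibrillation).

from data.cohort_filterer import CohortFilterer
from utils.code_mappings import map_icd9_to_numeric

filter_params = {
    'inclusion': {'DXCCS': [122, 109]},                      # CCS diagnosis categories
    'case_control': {'DX': [map_icd9_to_numeric('42731')]},  # ICD-9 427.31
}

# Hypothetical path to the pruned NIS H5 file.
filterer = CohortFilterer('/path/to/NIS_Pruned.h5', filter_params)
inds = filterer.filter_patients()
print(inds['case'].shape[0], 'cases /', inds['control'].shape[0], 'controls')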