Skip to content

Commit

Permalink
separate dataloader for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
kishwarshafin committed Jul 19, 2018
1 parent 68e40bd commit bac24f5
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 9 deletions.
10 changes: 5 additions & 5 deletions analysis/hyperband_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,21 @@ def plot_hyperband(pkl_file_path):
plt.plot(*zip(*results_xy[i][1]), 'o--')
indx = results_xy[i][0]
# labels.append(r'$(%f ,%f)$' % (result_reverse_hash[indx][0], result_reverse_hash[indx][1]))
# x1, x2, y1, y2 = plt.axis()
# plt.axis((x1, x2, y1, 0.00005))
x1, x2, y1, y2 = plt.axis()
# plt.axis((x1, x2, y1, 0.00075))
# print(min_loss, max_loss)
# plt.legend(labels, ncol=4, loc='upper center',
# bbox_to_anchor=[0.5, 1.1],
# columnspacing=1.0, labelspacing=0.0,
# handletextpad=0.0, handlelength=1.5,
# fancybox=True, shadow=True)
plt.text(3.5, 0.00004, 'Parameters tuned:\n1) Encoder learning rate\n2) Encoder weight decay\n'
'3) Decoder learning rate\n4) Decoder weight decay', verticalalignment='center',
plt.text(6.5, 0.0020, 'Parameters tuned:\n1) Encoder learning rate\n2) Encoder weight decay\n'
'3) Decoder learning rate\n4) Decoder weight decay', verticalalignment='center',
bbox=dict(facecolor='white', alpha=0.5))
plt.xlabel('Iterations')
plt.ylabel('Test loss')
plt.title('Hyper-parameter tuning with hyperband algorithm')
# labels.append(r'$y = %ix + %i$' % (i, 5 * i))
# labels.append(r'$y = %ix + %i$' % (i, 5 * i))
plt.show()


Expand Down
2 changes: 1 addition & 1 deletion call_variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import multiprocessing
from torch.autograd import Variable
from modules.models.Seq2Seq_atn import EncoderCRNN, AttnDecoderRNN
from modules.core.dataloader import SequenceDataset
from modules.core.dataloader_test import SequenceDataset
from modules.handlers.TextColor import TextColor
from collections import defaultdict
from modules.handlers.VcfWriter import VCFWriter
Expand Down
68 changes: 68 additions & 0 deletions modules/core/dataloader_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import os
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
import h5py
import torch
import pickle


class SequenceDataset(Dataset):
"""
Arguments:
A CSV file path
"""

def __init__(self, csv_path, transform=None):
data_frame = pd.read_csv(csv_path, header=None, dtype=str)
# assert data_frame[0].apply(lambda x: os.path.isfile(x.split(' ')[0])).all(), \
# "Some images referenced in the CSV file were not found"
# / data / users / jacob / image_output / run_07182018_172143 / chr19 / chr19_11546392.h5 / data / users / jacob / image_output / run_07182018_172143 / chr19 / candidate_dictionaries / chr19_11546392_11550865.pkl, 63, chr19
# 11546517, 0, TATTTTTAG * TAGAGATGGGG
self.transform = transform

self.file_info = list(data_frame[0])
self.index_info = list(data_frame[1])
self.position_info = list(data_frame[2])
self.label = list(data_frame[3])
self.reference_seq = list(data_frame[4])

@staticmethod
def load_dictionary(dictionary_location):
f = open(dictionary_location, 'rb')
dict = pickle.load(f)
f.close()
return dict

def __getitem__(self, index):
# load the image
hdf5_file_path, allele_dict_path = self.file_info[index].split(' ')
hdf5_index = int(self.index_info[index])
# load positional information
chromosome_name, genomic_start_position = self.position_info[index].split(' ')
# load genomic position information
reference_sequence = self.reference_seq[index]
# load the labels
label = self.label[index]
label = [int(x) for x in label]

hdf5_file = h5py.File(hdf5_file_path, 'r')
image_dataset = hdf5_file['images']
img = np.array(image_dataset[hdf5_index], dtype=np.uint8)

label = np.array(label)

# img = img.astype(dtype=np.uint8)
# type fix and convert to tensor
if self.transform is not None:
img = self.transform(img)
img = img.transpose(1, 2)

label = torch.from_numpy(label)

positional_information = (chromosome_name, genomic_start_position, reference_sequence, allele_dict_path)

return img, label, positional_information

def __len__(self):
return len(self.file_info)
2 changes: 1 addition & 1 deletion modules/hyperband/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def train(train_file, test_file, batch_size, epoch_limit, prev_ite, gpu_mode, nu
decoder_model.train()
batch_no = 1
with tqdm(total=len(train_loader), desc='Loss', leave=True, dynamic_ncols=True) as progress_bar:
for images, labels, positional_information in train_loader:
for images, labels in train_loader:
if gpu_mode:
# encoder_hidden = encoder_hidden.cuda()
images = images.cuda()
Expand Down
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test(data_file, batch_size, hidden_size, gpu_mode, encoder_model, decoder_mo
accuracy = 0
with torch.no_grad():
with tqdm(total=len(test_loader), desc='Accuracy: ', leave=True, dynamic_ncols=True) as pbar:
for i, (images, labels, positional_information) in enumerate(test_loader):
for i, (images, labels) in enumerate(test_loader):
if gpu_mode:
# encoder_hidden = encoder_hidden.cuda()
images = images.cuda()
Expand Down Expand Up @@ -190,7 +190,7 @@ def train(train_file, test_file, batch_size, epoch_limit, gpu_mode, num_workers,
decoder_model.train()
batch_no = 1
with tqdm(total=len(train_loader), desc='Loss', leave=True, dynamic_ncols=True) as progress_bar:
for images, labels, positional_information in train_loader:
for images, labels in train_loader:
if gpu_mode:
# encoder_hidden = encoder_hidden.cuda()
images = images.cuda()
Expand Down

0 comments on commit bac24f5

Please sign in to comment.