-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNopileup_dataset.py
38 lines (24 loc) · 1.07 KB
/
Nopileup_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import os
import numpy as np
import pandas as pd
from torch.utils import data
import re
class NOPDataset(data.Dataset):
    """
    Custom PyTorch dataset class mapping short-read STR loci and their
    corresponding features to long-read STR counts.

    Each item is a ``(ohe, label)`` pair where ``ohe`` is the array loaded from
    ``<ohe_dir>/<sample_name>/<sample_name>_<trid>.npy`` and ``label`` is the
    numerically largest value found in the row's comma-separated ``MC`` field.
    """

    def __init__(self, ohe_dir, metadata_file):
        """
        Args:
            ohe_dir: Root directory holding one sub-directory per sample; each
                ``.npy`` file is a one-hot encoding of the sequence and depth
                at a locus.
            metadata_file: Path to a tab-separated file containing truth labels
                and metadata. The columns ``trid``, ``sample_name`` and ``MC``
                are read by ``__getitem__``.
        """
        self.ohe_dir = ohe_dir
        self.metadata = pd.read_csv(metadata_file, sep='\t')
        # Row count is cached once; the table is not modified by this class.
        self.annots = len(self.metadata.index)

    def __len__(self):
        """Return the number of rows in the metadata table."""
        return self.annots

    def __getitem__(self, idx):
        """
        Load the encoded array and label for the row at position ``idx``.

        Returns:
            tuple: ``(ohe, label)`` with ``ohe`` the loaded NumPy array and
            ``label`` a ``float``.

        Raises:
            ValueError: If the ``MC`` field holds no values or a non-numeric
                value.
        """
        locus = self.metadata['trid'].iloc[idx]
        sample_name = self.metadata['sample_name'].iloc[idx]
        ohe_file = os.path.join(
            self.ohe_dir, sample_name, f'{sample_name}_{locus}.npy'
        )
        ohe = np.load(ohe_file)

        mc = self.metadata['MC'].iloc[idx]
        # str() guards against pandas inferring a numeric dtype for the column
        # when no row happens to contain a comma.
        tokens = [tok.strip() for tok in str(mc).split(',')]
        # Convert BEFORE taking the maximum: max() over strings compares
        # lexicographically, so "9" would beat "10".
        counts = [float(tok) for tok in tokens if tok]
        if not counts:
            raise ValueError(
                f'No motif count values in MC field {mc!r} at index {idx}'
            )
        label = max(counts)
        return (ohe, label)