nlcd_dataset.py
from __future__ import annotations

import json
from pathlib import Path

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.io import read_image
from torchvision.transforms.functional import convert_image_dtype

import utils
from build_nlcd_dataset import BuildNlcdDataset
class NlcdDataModule(pl.LightningDataModule):
    """LightningDataModule that builds (if needed) and serves the NLCD dataset."""

    def __init__(
        self,
        cvusa_root: Path | None = None,
        out_root: Path | None = None,
        nlcd_csv_path: Path | None = None,
        nlcd_root: Path | None = None,
        batch_size: int = 2,
        zoom: str = "18",
        num_workers: int = 2,
        start_index: int = 0,
        num_items: int = 1000,
        valid_pct: float = 0.05,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.cvusa_root = cvusa_root
        self.out_root = out_root
        self.nlcd_root = nlcd_root
        self.num_workers = num_workers
        self.start_index = start_index
        self.num_items = num_items
        self.nlcd_csv_path = nlcd_csv_path
        self.valid_pct = valid_pct
        self.zoom = zoom

    def setup(self, stage=None):
        # Build the NLCD label CSV from the CVUSA imagery if it does not exist yet.
        if not self.nlcd_csv_path.is_file():
            bnd = BuildNlcdDataset(
                num_items=self.num_items,
                nlcd_root=self.nlcd_root,
                cvusa_root=self.cvusa_root,
                out_root=self.out_root,
                zoom=self.zoom,
            )
            bnd.build()
        data = utils.read_csv(self.nlcd_csv_path)
        self.tfm = transforms.Compose([transforms.Resize([256, 256])])
        if self.num_items is None:
            self.num_items = len(data)
        rows = data[self.start_index : self.start_index + self.num_items]
        # Deterministic train/validation split.
        valid_size = int(self.valid_pct * self.num_items)
        train_rows, valid_rows = random_split(
            rows,
            [self.num_items - valid_size, valid_size],
            generator=torch.Generator().manual_seed(42),
        )
        self.train_nlcd_dataset = NlcdDataset(
            data=train_rows, cvusa_root=self.cvusa_root, tfms=self.tfm
        )
        self.valid_nlcd_dataset = NlcdDataset(
            data=valid_rows, cvusa_root=self.cvusa_root, tfms=self.tfm
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_nlcd_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False,
            drop_last=True,
            collate_fn=utils.collate_fn,
        )

    def val_dataloader(self):
        return DataLoader(
            self.valid_nlcd_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False,
            drop_last=True,
            collate_fn=utils.collate_fn,
        )
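
# Illustrative sketch (not part of the original module): the DataModule above is
# intended to be handed to a PyTorch Lightning Trainer alongside some
# LightningModule. The paths and the `model` name below are hypothetical
# placeholders, not values shipped with this repository.
#
#   dm = NlcdDataModule(
#       cvusa_root=Path("/data/cvusa"),
#       out_root=Path("/data/nlcd_out"),
#       nlcd_csv_path=Path("/data/nlcd_out/nlcd.csv"),
#       nlcd_root=Path("/data/nlcd"),
#   )
#   trainer = pl.Trainer(max_epochs=1)
#   trainer.fit(model, datamodule=dm)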
class NlcdDataset(Dataset):
    """Dataset of CVUSA aerial images paired with NLCD land-cover labels."""

    def __init__(self, cvusa_root: Path, data, tfms=None) -> None:
        super().__init__()
        self.data = data
        self.cvusa_root = cvusa_root
        self.tfms = tfms

    def __getitem__(self, index):
        if index >= len(self) or index < 0:
            raise IndexError("index out of range for dataset")
        row = self.data[index]
        img_name = row[0]
        lat = float(row[1])
        lon = float(row[2])
        labels = json.loads(row[3])
        label = labels[0]
        clabels = json.loads(row[4])
        if len(clabels) == 0:
            # Rows without a coarse label are skipped; the collate_fn is
            # expected to filter out None samples.
            return None
        clabel = clabels[0]
        try:
            img = read_image(str(self.cvusa_root / img_name))
            if img.shape[0] != 3:
                print(f"expected 3 channels in image with id {img_name}, got {img.shape[0]}")
                return None
        except Exception as e:
            print(f"could not read image with id {img_name}, error msg {str(e)}")
            return None
        if self.tfms:
            img = self.tfms(img)
        # Convert the uint8 image tensor to float in [0, 1].
        img = convert_image_dtype(img)
        sample = {
            # 'image_path': str(img_name),
            "aerial_img": img,
            "lat": lat,
            "lon": lon,
            "nlcd_labels": label,
            "nlcd_coarse_labels": clabel,
        }
        return sample

    def __len__(self):
        return len(self.data)
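
# A minimal smoke-test sketch, assuming the CSV and imagery paths below exist on
# disk (they are placeholders, not paths shipped with this repo) and that
# utils.collate_fn drops None samples as the dataloaders expect.
if __name__ == "__main__":
    dm = NlcdDataModule(
        cvusa_root=Path("/data/cvusa"),            # hypothetical CVUSA imagery root
        out_root=Path("/data/nlcd_out"),           # hypothetical output directory
        nlcd_csv_path=Path("/data/nlcd_out/nlcd.csv"),
        nlcd_root=Path("/data/nlcd"),              # hypothetical NLCD raster root
        batch_size=2,
        num_items=100,
    )
    dm.setup()
    # Pull one batch to confirm the pipeline runs end to end.
    batch = next(iter(dm.train_dataloader()))
    print(type(batch))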