Updated preprocessing and convolution
taj-gillin committed May 2, 2023
1 parent 08a0c7b commit 09afa4d
Showing 12 changed files with 228 additions and 101 deletions.
Binary file modified .DS_Store
2 changes: 1 addition & 1 deletion README.md
@@ -1,6 +1,6 @@
## Installation instructions

1. Run `conda env create --file environment.yml`
1. Run `conda env create --file /config/environment.yml`
2. Run `conda activate pix2pix-pytorch`

# pix2pix-terraform
Binary file added __pycache__/preprocess.cpython-310.pyc
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Binary file modified data/.DS_Store
File renamed without changes.
234 changes: 160 additions & 74 deletions pix2pix.py
@@ -7,42 +7,22 @@
from preprocess import GetDataset
import random

dataset = GetDataset()
# dataset = GetDataset()

composed = transforms.Compose(
[transforms.Resize(286),
transforms.RandomCrop(256)])

for i in range(0, len(dataset), 2):
#Satellite, Elevation Pair
pair = [dataset[i]['image'], dataset[i+1]['image']]
pair = [transforms.ToTensor()(image) for image in pair]

#Performs a Resize and RandomCrop
transformed_pair = [composed(image) for image in pair]

#Performs the same flip for both images
flip = random.choice([transforms.RandomHorizontalFlip(1),
transforms.RandomHorizontalFlip(0)])
transformed_pair = [flip(image) for image in pair]

print(transformed_pair[0].shape)
print(transformed_pair[1].shape)
print(i)
#plt.show()

OUTPUT_CHANNELS = 3

def downsample(filters, size, apply_batchnorm=True):
def downsample(in_channels, out_channels, size, apply_batchnorm=True):
#3 in_channels, 3 out_channels, 4x4 size, 2 stride
conv = torch.nn.Conv2d(filters, filters, size, strides=2, padding="same", bias=False)
conv = torch.nn.Conv2d(in_channels, out_channels, size, stride=2, padding=1, bias=False)
    torch.nn.init.normal_(conv.weight, 0, 0.02)

    # BatchNorm over the conv's output channels
    batchnorm = torch.nn.BatchNorm2d(out_channels)
    leakyrelu = torch.nn.LeakyReLU()

def call(x):
print(x)
x = conv(x)
if apply_batchnorm:
x = batchnorm(x)
@@ -51,10 +31,14 @@ def call(x):

return call

test_down = downsample(3, 64, 4, apply_batchnorm=False)
inp = torch.zeros((1,3,256,256))
print("Test: ")
print(test_down(inp).size())
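# Editor's sketch (not part of the original commit): the 128x128 result follows
# from the Conv2d output-size formula, out = floor((in + 2*pad - k) / s) + 1;
# with in=256, k=4, s=2, pad=1 that gives (256 + 2 - 4) // 2 + 1 = 128.
assert test_down(inp).size() == torch.Size([1, 64, 128, 128])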

def upsample(x, filters, size, apply_dropout=False):
def upsample(in_channels, out_channels, size, apply_dropout=False):
    # in_channels -> out_channels, 4x4 kernel, 2 stride; padding=1 exactly doubles the spatial size
convT = torch.nn.ConvTranspose2d(filters, filters, size, strides=2, padding="same", bias=False)
    convT = torch.nn.ConvTranspose2d(in_channels, out_channels, size, stride=2, padding=1, bias=False)
    torch.nn.init.normal_(convT.weight, 0, 0.02)

#expected 3 channels
Expand All @@ -74,51 +58,153 @@ def call(x):

#inputs go here

down_model = downsample(None)
up_model = upsample(None)

def Generator():

inputs = torch.tensor(np.zeros(3,256,256))
down_stack = [
downsample(64, 4, apply_batchnorm=False), # (batch_size, 128, 128, 64)
downsample(128, 4), # (batch_size, 64, 64, 128)
downsample(256, 4), # (batch_size, 32, 32, 256)
downsample(512, 4), # (batch_size, 16, 16, 512)
downsample(512, 4), # (batch_size, 8, 8, 512)
downsample(512, 4), # (batch_size, 4, 4, 512)
downsample(512, 4), # (batch_size, 2, 2, 512)
downsample(512, 4), # (batch_size, 1, 1, 512)
]
up_stack = [
upsample(512, 4, apply_dropout=True), # (batch_size, 2, 2, 1024)
upsample(512, 4, apply_dropout=True), # (batch_size, 4, 4, 1024)
upsample(512, 4, apply_dropout=True), # (batch_size, 8, 8, 1024)
upsample(512, 4), # (batch_size, 16, 16, 1024)
upsample(256, 4), # (batch_size, 32, 32, 512)
upsample(128, 4), # (batch_size, 64, 64, 256)
upsample(64, 4), # (batch_size, 128, 128, 128)
]
last = torch.nn.ConvTranspose2d(3, 3, 4, strides=2, padding="same", bias=False)
tanh = torch.nn.Tanh()
torch.nn.init.normal(last.weight, 0, 0.02)

x = inputs

# Downsampling through the model
skips = []
for down in down_stack:
x = down(x)
skips.append(x)

skips = reversed(skips[:-1])

# Upsampling and establishing the skip connections
for up, skip in zip(up_stack, skips):
x = up(x)
x = tf.keras.layers.Concatenate()([x, skip])

x = last(x)
x = tanh(x)

return tf.keras.Model(inputs=inputs, outputs=x)
# down_model = downsample(None)
# up_model = upsample(None)

class Generator(torch.nn.Module):

def __init__(self):
super().__init__()

# inputs = torch.tensor(np.zeros(3,256,256))
        # Note: downsample/upsample return plain closures (not Modules), so
        # their conv weights are not registered in this Module's parameters()
        self.down_stack = [
            downsample(3, 64, 4, apply_batchnorm=False),  # (batch_size, 64, 128, 128)
            downsample(64, 128, 4),   # (batch_size, 128, 64, 64)
            downsample(128, 256, 4),  # (batch_size, 256, 32, 32)
            downsample(256, 512, 4),  # (batch_size, 512, 16, 16)
            downsample(512, 512, 4),  # (batch_size, 512, 8, 8)
            downsample(512, 512, 4),  # (batch_size, 512, 4, 4)
            downsample(512, 512, 4),  # (batch_size, 512, 2, 2)
            downsample(512, 512, 4),  # (batch_size, 512, 1, 1)
        ]

        self.up_stack = [
            upsample(512, 512, 4, apply_dropout=True),   # (batch_size, 1024, 2, 2) after skip concat
            upsample(1024, 512, 4, apply_dropout=True),  # (batch_size, 1024, 4, 4)
            upsample(1024, 512, 4, apply_dropout=True),  # (batch_size, 1024, 8, 8)
            upsample(1024, 512, 4),  # (batch_size, 1024, 16, 16)
            upsample(1024, 256, 4),  # (batch_size, 512, 32, 32)
            upsample(512, 128, 4),   # (batch_size, 256, 64, 64)
            upsample(256, 64, 4),    # (batch_size, 128, 128, 128)
        ]

        self.last = torch.nn.ConvTranspose2d(128, 3, 4, stride=2, padding=1, bias=False)  # (batch_size, 3, 256, 256)
self.tanh = torch.nn.Tanh()

# torch.nn.init.normal(last.weight, 0, 0.02)

def forward(self, x):
# Downsampling through the model
skips = []
for down in self.down_stack:
print(f"X with size: {x.size()}")
x = down(x)
skips.append(x)

skips = reversed(skips[:-1])

# Upsampling and establishing the skip connections
for up, skip in zip(self.up_stack, skips):
x = up(x)
            x = torch.cat([x, skip], dim=1)  # concatenate along the channel axis

x = self.last(x)
x = self.tanh(x)

return x

# return torch.nn.Model(inputs=inputs, outputs=x)

generator = Generator()
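# Editor's sketch (assumption, not in the commit): with the channel bookkeeping
# above, the generator should map (N, 3, 256, 256) -> (N, 3, 256, 256). A batch
# of at least 2 is needed because BatchNorm2d cannot normalise a single value
# per channel at the 1x1 bottleneck in training mode.
# fake = generator(torch.zeros(2, 3, 256, 256))  # -> torch.Size([2, 3, 256, 256])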

# inputs = torch.tensor(np.zeros((3,256,256)))
# plt.imshow(generator(inputs)[0])



LAMBDA = 100

# The discriminator's final conv emits raw logits (no sigmoid), so use the
# with-logits form of binary cross-entropy
loss_object = torch.nn.BCEWithLogitsLoss()

def generator_loss(disc_generated_output, gen_output, target):
gan_loss = loss_object(disc_generated_output, torch.ones_like(disc_generated_output))

# mean absolute error
l1_loss = torch.mean(torch.abs(target - gen_output))

total_gen_loss = gan_loss + (LAMBDA * l1_loss)

return total_gen_loss, gan_loss, l1_loss
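# Editor's sketch (not in the commit): exercising generator_loss on dummy
# tensors. disc_generated_output is a grid of raw patch logits, and LAMBDA=100
# weights the L1 term as in the pix2pix paper.
_fake_logits = torch.randn(1, 1, 30, 30)
_gen_out = torch.rand(1, 3, 256, 256)
_target = torch.rand(1, 3, 256, 256)
_total, _gan, _l1 = generator_loss(_fake_logits, _gen_out, _target)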

# Define the Discriminator
class Discriminator(torch.nn.Module):

def __init__(self):
super().__init__()

        self.down1 = downsample(3, 64, 4, apply_batchnorm=False)
        self.down2 = downsample(64, 128, 4)
        self.down3 = downsample(128, 256, 4)

self.zero_pad1 = torch.nn.ZeroPad2d(1)
self.conv = torch.nn.Conv2d(256, 512, 4, stride=1, bias=False)

self.batchnorm = torch.nn.BatchNorm2d(512)
self.leaky_relu = torch.nn.LeakyReLU()

self.zero_pad2 = torch.nn.ZeroPad2d(1)

self.last = torch.nn.Conv2d(512, 1, 4, stride=1, bias=False)

def forward(self, x):
x = self.down1(x)
x = self.down2(x)
x = self.down3(x)

x = self.zero_pad1(x)
x = self.conv(x)

x = self.batchnorm(x)
x = self.leaky_relu(x)

x = self.zero_pad2(x)

x = self.last(x)

return x

discriminator = Discriminator()
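# Editor's sketch (assumption, not in the commit): this is a PatchGAN head, so
# the output is a grid of per-patch logits rather than a single score:
# 256 -> 128 -> 64 -> 32 (three downsamples), pad to 34, 4x4 conv -> 31,
# pad to 33, 4x4 conv -> 30.
# patch_logits = discriminator(torch.zeros(1, 3, 256, 256))  # -> (1, 1, 30, 30)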

# Define discriminator loss
def discriminator_loss(disc_real_output, disc_generated_output):
real_loss = loss_object(disc_real_output, torch.ones_like(disc_real_output))

generated_loss = loss_object(disc_generated_output, torch.zeros_like(disc_generated_output))

total_disc_loss = real_loss + generated_loss

return total_disc_loss


# Optimizers
generator_optimizer = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
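# Editor's sketch (assumption; the commit does not define a training loop yet):
# one adversarial training step wiring the pieces above together. Caveat: since
# down_stack/up_stack hold plain closures, their conv weights are not in
# generator.parameters() and would not actually be updated as written.
def train_step(input_image, target):
    # Generator update: fool the discriminator while staying L1-close to the target
    generator_optimizer.zero_grad()
    gen_output = generator(input_image)
    disc_generated = discriminator(gen_output)
    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated, gen_output, target)
    gen_total_loss.backward()
    generator_optimizer.step()

    # Discriminator update: real images -> 1, generated images -> 0
    discriminator_optimizer.zero_grad()
    disc_real = discriminator(target)
    disc_fake = discriminator(gen_output.detach())
    disc_loss = discriminator_loss(disc_real, disc_fake)
    disc_loss.backward()
    discriminator_optimizer.step()

    return gen_total_loss, disc_loss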

# Generate images
def generate_images(model, test_input, tar):
prediction = model(test_input)
plt.figure(figsize=(15, 15))

display_list = [test_input[0], tar[0], prediction[0]]
title = ['Input Image', 'Ground Truth', 'Predicted Image']

for i in range(3):
plt.subplot(1, 3, i+1)
plt.title(title[i])
        # Scale pixel values back to the [0, 1] range and put channels last for imshow
        plt.imshow((display_list[i] * 0.5 + 0.5).detach().permute(1, 2, 0))
plt.axis('off')
plt.show()

if __name__ == '__main__':
print('runnin')
75 changes: 49 additions & 26 deletions preprocess.py
@@ -1,47 +1,70 @@
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'

import os
import random
import numpy as np
import torch
from skimage import io, transform
from skimage import io
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

os.environ['KMP_DUPLICATE_LIB_OK']='True'

class SatelliteElevationDataset(Dataset):
'''Satellite Elevation Dataset'''

def __init__(self, root_dir, tile_ct):
self.root_dir = root_dir
self.filenames = []
for i in range(0, tile_ct * 256, 256):
for j in range(0, tile_ct * 256, 256):
self.filenames.append(f"{i},{j}c.jpg")
self.filenames.append(f"{i},{j}.jpg")
def __init__(self, root_dirs, tile_cts, transform=None):
'''
        Arguments:
            root_dirs (list[string]): list of paths to the root directories of the image data
            tile_cts (list[int]): number of 51.2 km squares per side of each area from which data was collected
            transform (callable, optional): optional transform to be applied on a sample
'''
self.transform = transform
self.elevation_imgs = []
self.satellite_imgs = []
for root_dir, tile_ct in zip(root_dirs, tile_cts):
for i in range(0, tile_ct * 256, 256):
for j in range(0, tile_ct * 256, 256):
                self.elevation_imgs.append(os.path.join(root_dir, f"{i},{j}c.jpg"))
                self.satellite_imgs.append(os.path.join(root_dir, f"{i},{j}.jpg"))

def __len__(self):
return len(self.filenames)
return len(self.elevation_imgs)

def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
img_name = os.path.join(self.root_dir + self.filenames[idx])
image = io.imread(img_name)
sample = {'image': image}
elevation_img_name = self.elevation_imgs[idx]
satellite_img_name = self.satellite_imgs[idx]
elevation_img = io.imread(elevation_img_name)
satellite_img = io.imread(satellite_img_name)
sample = {'elevation': elevation_img, 'satellite': satellite_img}
if self.transform:
sample = self.transform(sample)
return sample

def GetDataset():
return SatelliteElevationDataset("data/ANDES/", 12)
# preprocessing: apply random jittering and mirroring to preprocess the training set
def transform(sample):
    resize = transforms.Resize(286)
    # Choose the flip once (p=0 or p=1) so both images get the same mirroring
    flip = random.choice([transforms.RandomHorizontalFlip(0), transforms.RandomHorizontalFlip(1)])

    elevation_img, satellite_img = sample['elevation'], sample['satellite']

    elevation_img = transforms.ToTensor()(elevation_img)
    satellite_img = transforms.ToTensor()(satellite_img)

    elevation_img = resize(elevation_img)
    satellite_img = resize(satellite_img)

    # Sample one crop window and apply it to both images; calling RandomCrop
    # separately on each image would crop them at different locations and
    # break the satellite/elevation pairing
    top, left, h, w = transforms.RandomCrop.get_params(elevation_img, (256, 256))
    elevation_img = transforms.functional.crop(elevation_img, top, left, h, w)
    satellite_img = transforms.functional.crop(satellite_img, top, left, h, w)

    elevation_img = flip(elevation_img)
    satellite_img = flip(satellite_img)

    return {'elevation': elevation_img, 'satellite': satellite_img}

def GetDataset():
return SatelliteElevationDataset(["data/CALI/", "data/ANDES/"], [12, 12], transform=transform)
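# Editor's sketch (assumption, not in the commit): the dataset plugs straight
# into a standard DataLoader once the data/CALI and data/ANDES tiles exist;
# each batch is then a dict of (B, 3, 256, 256) tensors.
# loader = DataLoader(GetDataset(), batch_size=4, shuffle=True)
# batch = next(iter(loader))
# print(batch['satellite'].shape, batch['elevation'].shape)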
18 changes: 18 additions & 0 deletions test.py
@@ -0,0 +1,18 @@
import torch
import numpy as np
from skimage import io, transform
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from preprocess import GetDataset
import random

dataset = GetDataset()

fig = plt.figure()
sample = dataset[67]
elevation_img = sample['elevation']
# Rearrange CHW -> HWC for matplotlib
elevation_img = elevation_img.permute(1, 2, 0)
plt.imshow(elevation_img)
plt.show()
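# Editor's sketch (not in the commit): the paired satellite tile can be shown
# the same way as a quick visual check that the pair stays aligned.
# plt.imshow(sample['satellite'].permute(1, 2, 0)); plt.show()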
