Library agnostic to pytorch model #270

Draft · wants to merge 6 commits into master
3 changes: 2 additions & 1 deletion mplc/constants.py
@@ -49,8 +49,9 @@
TITANIC = "titanic"
ESC50 = "esc50"
IMDB = 'imdb'
CIFAR100 = "cifar100"
# Supported datasets
-SUPPORTED_DATASETS_NAMES = [MNIST, CIFAR10, TITANIC, ESC50, IMDB]
+SUPPORTED_DATASETS_NAMES = [MNIST, CIFAR10, TITANIC, ESC50, IMDB, CIFAR100]

# Number of attempts allowed before raising an error while trying to download dataset
NUMBER_OF_DOWNLOAD_ATTEMPTS = 3
76 changes: 74 additions & 2 deletions mplc/dataset.py
@@ -17,7 +17,7 @@
from librosa.feature import mfcc
from loguru import logger
from sklearn.model_selection import train_test_split
-from tensorflow.keras.datasets import cifar10, mnist, imdb
+from tensorflow.keras.datasets import cifar10, cifar100, mnist, imdb
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Conv2D, GlobalAveragePooling2D, MaxPooling2D
from tensorflow.keras.layers import Dense, Dropout
@@ -29,7 +29,8 @@
from tensorflow.keras.utils import to_categorical

from . import constants
-from .models import LogisticRegression
+from .models import LogisticRegression, ModelPytorch


class Dataset(ABC):
@@ -194,6 +195,77 @@ def generate_new_model(self):
        return model


class Cifar100(Dataset):
    def __init__(self):
        # Channels-first input shape (C, H, W), matching the PyTorch model
        self.input_shape = (3, 32, 32)
        self.num_classes = 100
        x_test, x_train, y_test, y_train = self.load_data()

        super(Cifar100, self).__init__(dataset_name='cifar100',
                                       num_classes=self.num_classes,
                                       input_shape=self.input_shape,
                                       x_train=x_train,
                                       y_train=y_train,
                                       x_test=x_test,
                                       y_test=y_test)

    def load_data(self):
        attempts = 0
        while True:
            try:
                (x_train, y_train), (x_test, y_test) = cifar100.load_data()
                break
            except (HTTPError, URLError) as e:
                error_code = e.code if hasattr(e, 'code') else e.errno
                logger.debug(
                    f'URL fetch failure on '
                    f'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz : '
                    f'{error_code} -- {e.reason}')
                if attempts < constants.NUMBER_OF_DOWNLOAD_ATTEMPTS:
                    sleep(2)
                    attempts += 1
                else:
                    raise

        # Pre-process inputs
        x_train = self.preprocess_dataset_inputs(x_train)
        x_test = self.preprocess_dataset_inputs(x_test)
        # Labels are deliberately left as integer class indices: the PyTorch
        # model's CrossEntropyLoss expects indices, not one-hot vectors
        return x_test, x_train, y_test, y_train

    # Data samples pre-processing method for inputs
    @staticmethod
    def preprocess_dataset_inputs(x):
        x = x.astype("float32")
        x /= 255

        return x

    # Data samples pre-processing method for labels
    def preprocess_dataset_labels(self, y):
        y = to_categorical(y, self.num_classes)

        return y

    # Model structure and generation
    def generate_new_model(self):
        # Unlike the other datasets, this returns a PyTorch model that mimics
        # the Keras model interface (fit, evaluate, get_weights, ...)
        model = ModelPytorch()
        return model

    # train, test, val splits
    @staticmethod
    def train_test_split_local(x, y):
        return train_test_split(x, y, test_size=0.1, random_state=42)

    @staticmethod
    def train_val_split_local(x, y):
        return train_test_split(x, y, test_size=0.1, random_state=42)
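# A minimal sketch for trying the new dataset class in isolation (not part of
# the diff; the first call downloads the CIFAR-100 archive through Keras):
#
#     from mplc.dataset import Cifar100
#     dataset = Cifar100()
#     assert dataset.num_classes == 100
#     assert dataset.input_shape == (3, 32, 32)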


class Titanic(Dataset):
def __init__(self, proportion=1,
val_proportion=0.1):
157 changes: 157 additions & 0 deletions mplc/models.py
@@ -1,10 +1,16 @@
import numpy as np
import collections
from joblib import dump, load
from loguru import logger
from sklearn.linear_model import LogisticRegression as skLR
from sklearn.metrics import log_loss
from tensorflow.keras.backend import dot
from tensorflow.keras.layers import Dense
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms


class LogisticRegression(skLR):
@@ -88,6 +94,157 @@ def load_model(path):
            path = path.replace('.h5', '.joblib')
        return load(path)

class cifar100_dataset(torch.utils.data.Dataset):
    """Wraps numpy arrays x, y as a torch Dataset, with an optional transform."""

    def __init__(self, x, y, transform=None):
        self.x = x
        self.y = y
        self.transform = transform

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        x = self.x[index]
        # Labels arrive with shape (n, 1) from keras' cifar100.load_data()
        y = torch.tensor(int(self.y[index][0]))

        if self.transform:
            x = self.transform(x)

        return x, y

class ModelPytorch(nn.Module):
    def __init__(self, num_classes=100):
        super(ModelPytorch, self).__init__()
        # Reuse the convolutional feature extractor of an untrained VGG16
        model = torchvision.models.vgg16()
        self.features = model.features
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(25088, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes)  # CIFAR-100 has 100 classes, not 1000
        )
        # Optimize this module's own parameters (shared features + new
        # classifier), not those of the throwaway torchvision model
        self.optimizer = optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        # 512 feature maps pooled to 7x7 flatten to 512 * 7 * 7 = 25088
        x = x.view(x.size(0), -1)
        return self.classifier(x)


    def fit(self, x_train, y_train, batch_size, validation_data, epochs=1, verbose=False, callbacks=None):
        # Keras-like signature; verbose and callbacks are accepted for
        # interface compatibility but are not used
        criterion = nn.CrossEntropyLoss()
        transform = transforms.Compose([transforms.ToTensor()])

        train_data = cifar100_dataset(x_train, y_train, transform)
        train_loader = data.DataLoader(train_data, batch_size=int(batch_size), shuffle=True)

        self.train()  # switch to training mode

        for epoch in range(epochs):
            for images, labels in train_loader:
                outputs = self.forward(images)
                loss = criterion(outputs, labels)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        [loss, acc] = self.evaluate(x_train, y_train)
        [val_loss, val_acc] = self.evaluate(*validation_data)
        # Mimic Keras by returning an object carrying a .history dict of metrics
        self.history = {
            'loss': [loss],
            'accuracy': [acc],
            'val_loss': [val_loss],
            'val_accuracy': [val_acc]
        }

        return self

    def evaluate(self, x_eval, y_eval, **kwargs):
        criterion = nn.CrossEntropyLoss()
        transform = transforms.Compose([transforms.ToTensor()])

        test_data = cifar100_dataset(x_eval, y_eval, transform)
        test_loader = data.DataLoader(test_data, shuffle=False)

        self.eval()  # evaluation mode: disables dropout

        val_loss = 0
        val_acc = 0
        count = 0

        with torch.no_grad():
            for images, labels in test_loader:
                count += 1
                N = images.size(0)

                outputs = self(images)
                predictions = outputs.max(1, keepdim=True)[1]

                # Accumulate per-batch loss and accuracy, then average over batches
                val_loss += criterion(outputs, labels).item()
                val_acc += predictions.eq(labels.view_as(predictions)).sum().item() / N

        model_evaluation = [val_loss / count, val_acc / count]

        return model_evaluation


    def save_weights(self, path):
        if '.h5' in path:
            logger.debug('Automatically switch file format from .h5 to .pth')
            path = path.replace('.h5', '.pth')
        torch.save(self.state_dict(), path)

    def load_weights(self, path):
        if '.h5' in path:
            logger.debug('Automatically switch file format from .h5 to .pth')
            path = path.replace('.h5', '.pth')
        # save_weights() stores a state_dict, so load it back directly
        self.load_state_dict(torch.load(path))

    def get_weights(self):
        # Export parameters as a list of numpy arrays, mirroring Keras' get_weights()
        return [tensor.cpu().numpy() for tensor in self.state_dict().values()]

    def set_weights(self, weights):
        # state_dict() returns a copy, so assigning into it would not update
        # the module; build a fresh state dict and load it instead
        state_dict = collections.OrderedDict(
            (layer, torch.tensor(weights[i]))
            for i, layer in enumerate(self.state_dict().keys()))
        self.load_state_dict(state_dict)

    def save_model(self, path):
        if '.h5' in path:
            logger.debug('Automatically switch file format from .h5 to .pth')
            path = path.replace('.h5', '.pth')
        torch.save(self, path)

    @staticmethod
    def load_model(path):
        if '.h5' in path:
            logger.debug('Automatically switch file format from .h5 to .pth')
            path = path.replace('.h5', '.pth')
        model = torch.load(path)
        return model.eval()
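# To sanity-check the Keras-style interface end to end, a reviewer sketch (not
# part of the diff); the random shapes mimic what Cifar100.load_data() produces:
#
#     import numpy as np
#     from mplc.models import ModelPytorch
#
#     x = np.random.rand(8, 32, 32, 3).astype("float32")  # channels-last, in [0, 1]
#     y = np.random.randint(0, 100, size=(8, 1))          # labels of shape (n, 1)
#
#     model = ModelPytorch()
#     history = model.fit(x, y, batch_size=4, validation_data=(x, y), epochs=1)
#     print(history.history)
#
#     weights = model.get_weights()  # list of numpy arrays, as in Keras
#     model.set_weights(weights)     # round-trips through load_state_dict()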


class NoiseAdaptationChannel(Dense):
"""
Expand Down
2 changes: 2 additions & 0 deletions mplc/scenario.py
@@ -138,6 +138,8 @@ def __init__(
            self.dataset = dataset_module.Esc50()
        elif dataset_name == constants.IMDB:
            self.dataset = dataset_module.Imdb()
        elif dataset_name == constants.CIFAR100:
            self.dataset = dataset_module.Cifar100()
        else:
            raise Exception(
                f"Dataset named '{dataset_name}' is not supported (yet). You can construct your own "
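Once merged, selecting the new dataset from a scenario should look like the sketch below. This is hypothetical usage, not part of the diff: the dataset_name argument matches the check added above, and the remaining arguments stand in for the usual mplc scenario options.

from mplc.scenario import Scenario

scenario = Scenario(partners_count=3,
                    amounts_per_partner=[0.2, 0.3, 0.5],
                    dataset_name='cifar100',
                    epoch_count=2,
                    minibatch_count=2)
scenario.run()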