From 261f0654de605c6a260784e47e9a17a737a1a985 Mon Sep 17 00:00:00 2001 From: Sina Date: Thu, 24 Jun 2021 15:17:33 +0430 Subject: [PATCH] Add project --- .gitignore | 1 + README.md | 26 +++ main.py | 144 +++++++++++++++++ partedvae/__init__.py | 0 partedvae/models.py | 278 ++++++++++++++++++++++++++++++++ partedvae/training.py | 266 ++++++++++++++++++++++++++++++ requirements.txt | 8 + result/.gitignore | 1 + utils/__init__.py | 0 utils/dataloaders.py | 240 +++++++++++++++++++++++++++ utils/fast_tensor_dataloader.py | 54 +++++++ utils/load_model.py | 19 +++ utils/metrics.py | 140 ++++++++++++++++ viz/__init__.py | 0 viz/visualize.py | 259 +++++++++++++++++++++++++++++ 15 files changed, 1436 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 main.py create mode 100644 partedvae/__init__.py create mode 100644 partedvae/models.py create mode 100644 partedvae/training.py create mode 100644 requirements.txt create mode 100644 result/.gitignore create mode 100644 utils/__init__.py create mode 100644 utils/dataloaders.py create mode 100644 utils/fast_tensor_dataloader.py create mode 100644 utils/load_model.py create mode 100644 utils/metrics.py create mode 100644 viz/__init__.py create mode 100644 viz/visualize.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..485dee6 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.idea diff --git a/README.md b/README.md new file mode 100644 index 0000000..2982e7c --- /dev/null +++ b/README.md @@ -0,0 +1,26 @@ +# PartedVAE + +PyTorch implementation of [Semi-Supervised Disentanglement of Class-Related and Class-Independent Factors in VAE](https://arxiv.org/abs/2102.00892) (PartedVAE). + +This repository's structure is based on the [joint-vae](https://github.com/Schlumberger/joint-vae) repository. + +## Usage + +Use `main.py` to train the model; add any tests and evaluations you need at the end of that script.
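+
+For reference, a minimal training run that mirrors the MNIST settings in `main.py` (all values below come from that script) looks roughly like this:
+
+```python
+import itertools
+
+import torch
+from torch import optim
+
+from partedvae.models import VAE
+from partedvae.training import Trainer
+from utils.dataloaders import get_mnist_dataloaders
+
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+train_loader, test_loader, warm_up_loader = get_mnist_dataloaders(
+    batch_size=64, path_to_data='../datasets/', device=device, warm_up=True)
+
+# 'z': class-independent dims, 'c': cardinalities of the discrete factors,
+# 'single_u': class-dependent dims per discrete factor
+latent_spec = {'z': 6, 'c': [10], 'single_u': 10}
+model = VAE(img_size=(1, 32, 32), latent_spec=latent_spec, c_priors=[[0.1] * 10], device=device)
+
+# The warm-up optimizer updates only the encoder and the c-head
+optimizer_warm_up = optim.Adam(itertools.chain(
+    model.img_to_features.parameters(),
+    model.features_to_hidden.parameters(),
+    model.h_to_c_logit_fc.parameters()), lr=5e-4)
+optimizer_model = optim.Adam(model.parameters(), lr=5e-4)
+
+trainer = Trainer(model, [optimizer_warm_up, optimizer_model], dataset='mnist', device=device,
+                  recon_type='bce', z_capacity=[0., 7., 100000, 15], u_capacity=[0., 7., 100000, 15],
+                  c_gamma=15., entropy_gamma=30., bc_gamma=30., bc_threshold=0.15)
+trainer.train(train_loader, warm_up_loader=warm_up_loader, epochs=80)
+```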
+ +## Citing + +If you find our work useful in your research, please cite using: + +``` +@article{hajimiri2021semi, + title={Semi-Supervised Disentanglement of Class-Related and Class-Independent Factors in VAE}, + author={Hajimiri, Sina and Lotfi, Aryo and Soleymani Baghshah, Mahdieh}, + journal={arXiv preprint arXiv:2102.00892}, + year={2021} +} +``` + +## License + +MIT \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..e2bfa35 --- /dev/null +++ b/main.py @@ -0,0 +1,144 @@ +import itertools + +import torch +from torch import optim + +from partedvae.models import VAE +from partedvae.training import Trainer +from utils.dataloaders import get_dsprites_dataloader, get_mnist_dataloaders, get_celeba_dataloader +from utils.load_model import load +from viz.visualize import Visualizer +from utils.metrics import dis_by_fact_metric + +load_model_path = '/path/to/saved/model/' +dataset = 'celeba' + +LOAD_MODEL = False +LOAD_DATASET = True +TRAIN = True and LOAD_DATASET +SAVE = True and TRAIN +WARM_UP = True and TRAIN +RECON_TYPE = 'abs' if dataset == 'celeba' else 'bce' # 'mse' is also possible + +epochs = 80 + +batch_size = 64 +lr_warm_up = 5e-4 +lr_model = 5e-4 + +use_cuda = torch.cuda.is_available() +device = torch.device('cuda:0' if use_cuda else 'cpu') + + +def save(trainer, z_capacity, u_capacities, latent_spec, epochs, lr_warm_up, lr_model, dataset, recon_type): + torch.save(trainer.model.state_dict(), 'model.pt') + with open('specs.json', 'w') as f: + f.write('''{ + "z_capacity": %s, + "u_capacity": %s, + "latent_spec": %s, + "epochs": %d, + "lr_warm_up": %f, + "lr_model": %f, + "dataset": "%s", + "recon_type": "%s" + }''' % (str(z_capacity), str(u_capacities), str(latent_spec).replace("'", '"'), epochs, + lr_warm_up, lr_model, dataset, recon_type)) + + +if __name__ == '__main__': + if dataset == 'dsprites': + disc_priors = [[0.33, 0.33, 0.34]] + disc_count = len(disc_priors) + img_size = (1, 64, 64) + latent_spec = { + 'z': 5, + 'c': [3], + 'single_u': 1, + } + z_capacity = [0., 30., 300000, 50.] + u_capacity = [0., 5., 300000, 50.] + g_c, g_h = 100., 10. + g_bc = 10. + bc_threshold = 0.1 + elif dataset == 'mnist': + disc_priors = [[0.1] * 10] + disc_count = len(disc_priors) + img_size = (1, 32, 32) + latent_spec = { + 'z': 6, + 'c': [10], + 'single_u': 10 + } + z_capacity = [0., 7.0, 100000, 15] + u_capacity = [0., 7.0, 100000, 15] + g_c, g_h = 15., 30. + g_bc = 30. + bc_threshold = 0.15 + else: + disc_priors = [[0.42, 0.33, 0.18, 0.06], [0.9, 0.07, 0.03], [0.85, 0.15], [0.74, 0.15, 0.11], + [0.93, 0.07], [0.47, 0.53], [0.95, 0.05], [0.57, 0.43]] + disc_count = len(disc_priors) + img_size = (3, 218, 178) + latent_spec = { + 'z': 10, + 'c': [4, 3, 2, 3, 2, 2, 2, 2], + 'single_u': 1, + } + z_capacity = [0., 30., 125000, 1000.] + u_capacity = [0., 15., 125000, 1000.] + g_c, g_h = 2000., 10. + g_bc = 500. 
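+ # Capacity specs are [min, max, num_iters, gamma]: the target KL anneals linearly from min to max over num_iters training steps, and gamma weights |KL - target| in the loss; g_bc penalizes pairwise overlap (Bhattacharyya coefficient) between the learned u priors above bc_threshold (see Trainer._loss_function)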
+ bc_threshold = 0.2 + + if LOAD_DATASET: + if dataset == 'dsprites': + train_loader, warm_up_loader = get_dsprites_dataloader(batch_size=64, fraction=1, + path_to_data='../datasets/dsprites/ndarray.npz', + device=device, warm_up=WARM_UP) + test_loader = train_loader + elif dataset == 'mnist': + train_loader, test_loader, warm_up_loader = get_mnist_dataloaders(batch_size=64, + path_to_data='../datasets/', + device=device, warm_up=WARM_UP) + else: + train_loader, test_loader, warm_up_loader = get_celeba_dataloader(batch_size=64, + path_to_data='../datasets/', + device=device, warm_up=WARM_UP) + if not WARM_UP: + warm_up_loader = None + + if LOAD_MODEL: + # Note: When you load a model, capacities are restarted, which isn't intuitive if you are gonna re-train it + model = load(load_model_path, img_size=img_size, disc_priors=disc_priors, device=device) + model.sigmoid_coef = 8. + else: + model = VAE(img_size=img_size, latent_spec=latent_spec, c_priors=disc_priors, device=device) + + viz = Visualizer(model, root='result/') + + if TRAIN: + optimizer_warm_up = optim.Adam(itertools.chain(*[ + model.img_to_features.parameters(), + model.features_to_hidden.parameters(), + model.h_to_c_logit_fc.parameters() + ]), lr=lr_warm_up) + optimizer_model = optim.Adam(model.parameters(), lr=lr_model) + optimizers = [optimizer_warm_up, optimizer_model] + + trainer = Trainer(model, optimizers, dataset=dataset, device=device, recon_type=RECON_TYPE, + z_capacity=z_capacity, u_capacity=u_capacity, c_gamma=g_c, entropy_gamma=g_h, + bc_gamma=g_bc, bc_threshold=bc_threshold) + trainer.train(train_loader, warm_up_loader=warm_up_loader, epochs=epochs, run_after_epoch=None, + run_after_epoch_args=[]) + + if SAVE: + save(trainer, z_capacity, u_capacity, latent_spec, epochs, lr_warm_up, lr_model, dataset, RECON_TYPE) + + with torch.no_grad(): + if LOAD_DATASET: + loader = test_loader if test_loader else train_loader + for batch, labels in loader: + break + + viz.reconstructions(batch) diff --git a/partedvae/__init__.py b/partedvae/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/partedvae/models.py b/partedvae/models.py new file mode 100644 index 0000000..fc7f1ab --- /dev/null +++ b/partedvae/models.py @@ -0,0 +1,278 @@ +import itertools + +import numpy as np +import torch +from torch import nn + +EPS = 1e-12 + + +def my_tanh(x): + return 2 * torch.tanh(x) - 1 # std = e^{0.5 * logvar}: [0.22, 1.64] + + +class PartialLogSoftmax: + def __init__(self, dimensions, device): + self.device = device + self.sum_dims = sum(dimensions) + self.minus_inf = torch.tensor([float('-inf')], device=self.device, requires_grad=False) + self.eps = torch.tensor([EPS], device=self.device, requires_grad=False) + + self.unwrap_logits_mask = torch.zeros(1, self.sum_dims, self.sum_dims, len(dimensions), device=self.device, requires_grad=False) + self.scatter_log_denominator_mask = torch.zeros(len(dimensions), self.sum_dims, device=self.device, requires_grad=False) + + start = 0 + for dim_idx, size in enumerate(dimensions): + self.unwrap_logits_mask[0, torch.arange(start, start + size), torch.arange(start, start + size), dim_idx] = 1 + self.scatter_log_denominator_mask[dim_idx, torch.arange(start, start + size)] = 1 + start += size + + def __call__(self, logits): + logits = torch.where(logits == 0, self.eps, logits) # Later on, we replace 0s with -inf, but 0s in logits have meaning and shouldn't be replaced + unwrapped_logits = logits.unsqueeze(1).unsqueeze(2).matmul(self.unwrap_logits_mask).squeeze(2) # batch_size, sum_dims, 
count_dims + unwrapped_logits = torch.where(unwrapped_logits == 0, self.minus_inf, unwrapped_logits) + log_denominator = torch.logsumexp(unwrapped_logits, dim=1) + log_denominator = log_denominator.matmul(self.scatter_log_denominator_mask) + log_softmax = logits - log_denominator + return log_softmax + + +class VAE(nn.Module): + def __init__(self, img_size, latent_spec, c_priors, device, temperature=0.67): + super(VAE, self).__init__() + self.device = device + + # Parameters + self.img_size = img_size + self.has_indep = latent_spec.get('z', 0) > 0 + self.has_dep = len(latent_spec.get('c', [])) > 0 + if self.has_dep and latent_spec.get('single_u', 0) < 1: + raise RuntimeError('Model has c variables but u_dim is 0') + + self.latent_spec = latent_spec + self.num_pixels = img_size[0] * img_size[1] * img_size[2] + self.temperature = temperature + self.hidden_dim = 256 + self.reshape = (64, 5, 5) if img_size[1:] == (218, 178) else (64, 4, 4) # Shape required to start transpose convs + + self.z_dim, self.single_u_dim, self.u_dim = 0, 0, 0 + self.c_dims, self.c_count, self.sum_c_dims = 0, 0, 0 + self.c_priors = [] + if self.has_indep: + self.z_dim = self.latent_spec['z'] + + if self.has_dep: + self.c_priors = torch.tensor(sum(c_priors, []), device=self.device, requires_grad=False) + self.c_dims = self.latent_spec['c'] + self.c_count = len(self.c_dims) + self.sum_c_dims = sum(self.c_dims) + self.single_u_dim = self.latent_spec['single_u'] + self.u_dim = self.single_u_dim * self.c_count + + self.latent_dim = self.z_dim + self.u_dim + + # Encoder + if self.img_size[1:] == (32, 32) or self.img_size[1:] == (64, 64): + encoder_layers = [ + nn.Conv2d(self.img_size[0], 32, (4, 4), stride=2, padding=1), + nn.ReLU() + ] + if self.img_size[1:] == (64, 64): + encoder_layers += [ + nn.Conv2d(32, 32, (4, 4), stride=2, padding=1), + nn.ReLU() + ] + encoder_layers += [ + nn.Conv2d(32, 64, (4, 4), stride=2, padding=1), + nn.ReLU(), + nn.Conv2d(64, 64, (4, 4), stride=2, padding=1), + nn.LeakyReLU(0.1), + ] + elif self.img_size[1:] == (218, 178): + encoder_layers = [ + nn.Conv2d(self.img_size[0], 32, (4, 4), stride=2, padding=(0, 2)), # out: 108 x 90 + nn.ReLU(), + nn.Conv2d(32, 32, (4, 4), stride=2, padding=(1, 2)), # 54 x 46 + nn.ReLU(), + nn.Conv2d(32, 64, (4, 4), stride=2, padding=(0, 2)), # 26 x 24 + nn.ReLU(), + nn.Conv2d(64, 64, (4, 4), stride=2, padding=(0, 1)), # 12 x 12 + nn.ReLU(), + nn.Conv2d(64, 64, (4, 4), stride=2, padding=0), # 5 x 5 + nn.LeakyReLU(0.1), + ] + else: + raise RuntimeError('img_size not supported') + + # Define encoder + self.img_to_features = nn.Sequential(*encoder_layers) + + self.features_to_hidden = nn.Sequential( + nn.Linear(np.prod(self.reshape), self.hidden_dim), + nn.LeakyReLU(0.1), + ) + + # Latent Space + # FC: Fully Connected, PC: Partially Connected + if self.has_indep: + self.z_mean_fc = nn.Linear(self.hidden_dim, self.z_dim) + self.z_logvar_fc = nn.Linear(self.hidden_dim, self.z_dim) + + if self.has_dep: + self.h_to_c_logit_fc = nn.Linear(self.hidden_dim, self.sum_c_dims) + + self.c_to_a_logit_pc = nn.Linear(self.sum_c_dims, self.c_count * self.hidden_dim) # A Sigmoid should be placed after this layer + self.c_to_a_logit_mask = torch.zeros(self.c_count * self.hidden_dim, self.sum_c_dims, requires_grad=False) + h_start = 0 + for i, h_size in enumerate(self.c_dims): + v_start, v_end = i * self.hidden_dim, (i + 1) * self.hidden_dim + h_end = h_start + h_size + indices = itertools.product(range(v_start, v_end), range(h_start, h_end)) +
self.c_to_a_logit_mask[list(zip(*indices))] = 1 # It actually unzips :D + h_start = h_end + with torch.no_grad(): + self.c_to_a_logit_pc.weight.mul_(self.c_to_a_logit_mask) + + self.h_dot_a_to_u_mean_pc = nn.Linear(self.c_count * self.hidden_dim, self.u_dim) + self.h_dot_a_to_u_logvar_pc = nn.Linear(self.c_count * self.hidden_dim, self.u_dim) + self.h_dot_a_to_u_mask = torch.zeros(self.u_dim, self.c_count * self.hidden_dim, requires_grad=False) + for i, dim in enumerate(self.c_dims): + v_start, v_end = i * self.single_u_dim, (i + 1) * self.single_u_dim + h_start, h_end = i * self.hidden_dim, (i + 1) * self.hidden_dim + indices = itertools.product(range(v_start, v_end), range(h_start, h_end)) + self.h_dot_a_to_u_mask[list(zip(*indices))] = 1 + with torch.no_grad(): + self.h_dot_a_to_u_mean_pc.weight.mul_(self.h_dot_a_to_u_mask) + self.h_dot_a_to_u_logvar_pc.weight.mul_(self.h_dot_a_to_u_mask) + + # These lines should be after the multiplications in torch.no_grad(), because model (and therefore h_dot_e_to_u_mean_pc.weight) hasn't gone to GPU yet + self.c_to_a_logit_mask = self.c_to_a_logit_mask.to(self.device) + self.h_dot_a_to_u_mask = self.h_dot_a_to_u_mask.to(self.device) + + # Decoder + self.latent_to_features = nn.Sequential( + nn.Linear(self.latent_dim, self.hidden_dim), + nn.ReLU(), + nn.Linear(self.hidden_dim, np.prod(self.reshape)), + nn.ReLU() + ) + + if self.img_size[1:] == (32, 32) or self.img_size[1:] == (64, 64): + if self.img_size[1:] == (64, 64): + decoder_layers = [ + nn.ConvTranspose2d(64, 64, (4, 4), stride=2, padding=1), + nn.ReLU() + ] + else: + decoder_layers = list() + decoder_layers += [ + nn.ConvTranspose2d(64, 32, (4, 4), stride=2, padding=1), + nn.ReLU(), + nn.ConvTranspose2d(32, 32, (4, 4), stride=2, padding=1), + nn.ReLU(), + nn.ConvTranspose2d(32, self.img_size[0], (4, 4), stride=2, padding=1), + nn.Sigmoid() + ] + elif self.img_size[1:] == (218, 178): + decoder_layers = [ + nn.ConvTranspose2d(64, 64, (4, 4), stride=2, padding=0), + nn.ReLU(), + nn.ConvTranspose2d(64, 64, (4, 4), stride=2, padding=(0, 1)), + nn.ReLU(), + nn.ConvTranspose2d(64, 32, (4, 4), stride=2, padding=(0, 2)), + nn.ReLU(), + nn.ConvTranspose2d(32, 32, (4, 4), stride=2, padding=(1, 2)), + nn.ReLU(), + nn.ConvTranspose2d(32, self.img_size[0], (4, 4), stride=2, padding=(0, 2)), + nn.Sigmoid() + ] + + self.features_to_img = nn.Sequential(*decoder_layers) + + # Define psi network + self.u_prior_means = nn.Parameter(torch.randn(self.sum_c_dims, self.single_u_dim), requires_grad=True) + self.u_prior_logvars_before_tanh = nn.Parameter(torch.randn(self.sum_c_dims, self.single_u_dim), requires_grad=True) + + self.sigmoid_coef = 1 + + self.logsoftmaxer = PartialLogSoftmax(self.c_dims, device=self.device) + + self.to(self.device) + + @property + def u_prior_logvars(self): + return my_tanh(self.u_prior_logvars_before_tanh) + + def my_sigmoid(self, x): + if not self.training or self.sigmoid_coef > 8: + return torch.sigmoid(8 * x) + if self.sigmoid_coef < 8: + self.sigmoid_coef += 2e-4 + return torch.sigmoid(self.sigmoid_coef * x) + + def encode(self, x, only_disc_dist=False): # x: (N, C, H, W) + batch_size = x.size()[0] + + features = self.img_to_features(x) + hidden = self.features_to_hidden(features.view(batch_size, -1)) + + latent_dist = dict() + if self.has_dep: + c_logit = self.h_to_c_logit_fc(hidden) + latent_dist['log_c'] = self.logsoftmaxer(c_logit) + if only_disc_dist: + return latent_dist + + sampled_c = self.sample_gumbel_partial_softmax(c_logit) # One hot (sort of) + a_logit = 
self.c_to_a_logit_pc(sampled_c) + a = self.my_sigmoid(a_logit) + h_dot_a = hidden.repeat(1, self.c_count) * a + latent_dist['u'] = [self.h_dot_a_to_u_mean_pc(h_dot_a), self.h_dot_a_to_u_logvar_pc(h_dot_a)] + + if self.has_indep: + latent_dist['z'] = [self.z_mean_fc(hidden), self.z_logvar_fc(hidden)] + + return latent_dist + + def reparameterize(self, latent_dist): + latent_sample = list() + for mean, logvar in [latent_dist['z'], latent_dist['u']]: + sample = self.sample_normal(mean, logvar) + latent_sample.append(sample) + return torch.cat(latent_sample, dim=1) + + def sample_normal(self, mean, logvar): + if self.training: + std = torch.exp(0.5 * logvar) + eps = torch.zeros(std.size(), device=self.device).normal_() + return mean + std * eps + else: + return mean + + def sample_gumbel_partial_softmax(self, c_logit): + if self.training: + unif = torch.rand(c_logit.size(), device=self.device) + gumbel = -torch.log(-torch.log(unif + EPS) + EPS) + logit = (c_logit + gumbel) / self.temperature + return torch.exp(self.logsoftmaxer(logit)) + else: + alphas = torch.exp(self.logsoftmaxer(c_logit)) + one_hot_samples = torch.zeros(alphas.size(), device=self.device) + start = 0 + for size in self.c_dims: # Here speed is not that important + alpha = alphas[:, start:start+size] + _, max_alpha = torch.max(alpha, dim=1) + one_hot_sample = torch.zeros(alpha.size(), device=self.device) + one_hot_sample.scatter_(1, max_alpha.view(-1, 1).data, 1) + one_hot_samples[:, start:start+size] = one_hot_sample + start += size + return one_hot_samples + + def decode(self, latent_sample): + features = self.latent_to_features(latent_sample) + return self.features_to_img(features.view(-1, *self.reshape)) + + def forward(self, x): + latent_dist = self.encode(x) + latent_sample = self.reparameterize(latent_dist) + return self.decode(latent_sample), latent_dist diff --git a/partedvae/training.py b/partedvae/training.py new file mode 100644 index 0000000..3259970 --- /dev/null +++ b/partedvae/training.py @@ -0,0 +1,266 @@ +import itertools +import math +import random +from time import time + +import numpy as np +import pandas as pd +import torch +from torch.nn import functional as F +from torch.optim.lr_scheduler import ReduceLROnPlateau + +EPS = 1e-12 + + +class Trainer: + def __init__(self, model, optimizers, dataset=None, device=None, recon_type=None, z_capacity=None, + u_capacity=None, c_gamma=None, entropy_gamma=None, bc_gamma=None, bc_threshold=None): + self.dataset = dataset + self.device = device + + self.model = model.to(self.device) + + self.optimizer_warm_up, self.optimizer_model = optimizers + if dataset == 'dsprites': + self.scheduler_warm_up = ReduceLROnPlateau(self.optimizer_warm_up, factor=0.5, patience=1, threshold=1e-1, + threshold_mode='rel', cooldown=2, min_lr=0, eps=1e-06, + verbose=True) + self.scheduler_model = ReduceLROnPlateau(self.optimizer_model, factor=0.5, patience=1, threshold=1e-2, + threshold_mode='rel', cooldown=4, min_lr=0, eps=1e-07, + verbose=True) + else: + self.scheduler_warm_up = ReduceLROnPlateau(self.optimizer_warm_up, factor=0.5, patience=2, threshold=1e-1, + threshold_mode='rel', cooldown=3, min_lr=0, eps=1e-06, + verbose=True) + self.scheduler_model = ReduceLROnPlateau(self.optimizer_model, factor=0.5, patience=2, threshold=1e-2, + threshold_mode='rel', cooldown=4, min_lr=0, eps=1e-07, + verbose=True) + + self.recon_type = recon_type + + self.z_capacity = z_capacity + self.u_capacity = u_capacity + self.c_gamma = c_gamma + self.entropy_gamma = entropy_gamma + self.bc_gamma = bc_gamma 
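+ # Only the part of the pairwise Bhattacharyya coefficient between u priors that exceeds bc_threshold is penalized, weighted by bc_gamma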
+ self.bc_threshold = bc_threshold + + # The following variable is used in computing KLD, it is computed here once, for speeding up + self.u_kl_func_valid_indices = torch.zeros(self.model.sum_c_dims, dtype=torch.long, device=self.device, requires_grad=False) + start = 0 + for value, disc_dim in enumerate(self.model.c_dims): + self.u_kl_func_valid_indices[start:start + disc_dim] = value + start += disc_dim + + # Used in computing KLs and Entropy for each random variable + self.unwrap_mask = torch.zeros(self.model.sum_c_dims, self.model.sum_c_dims, self.model.c_count, device=self.device, requires_grad=False) + start = 0 + for dim_idx, size in enumerate(self.model.c_dims): + self.unwrap_mask[torch.arange(start, start + size), torch.arange(start, start + size), dim_idx] = 1 + start += size + + # Used in computing BC + self.u_valid_prior_BC_mask = torch.zeros(self.model.sum_c_dims, self.model.sum_c_dims, device=self.device, requires_grad=False) + start = 0 + for dim_idx, size in enumerate(self.model.c_dims): + indices = itertools.product(range(start, start + size), range(start, start + size)) + self.u_valid_prior_BC_mask[list(zip(*indices))] = 1 + start += size + self.u_valid_prior_BC_mask.tril_(diagonal=-1) + + self.num_steps = 0 + self.batch_size = None + + def train(self, data_loader, warm_up_loader=None, epochs=10, run_after_epoch=None, run_after_epoch_args=None): + self.batch_size = data_loader.batch_size + self.model.train() + + if warm_up_loader is None: + print('No warm-up') + + for epoch in range(epochs): + t = time() + warm_up_mean_loss, separated_mean_epoch_loss = self._train_epoch(data_loader, warm_up_loader) + others, kl_z_dims, df = [ + list(separated_mean_epoch_loss[:12]), + list(np.round(separated_mean_epoch_loss[12:12+self.model.z_dim], 2)), + pd.DataFrame((100 * separated_mean_epoch_loss[-3 * self.model.c_count:]).round(2).reshape(3, -1).transpose(), columns=['100 * E[KL(U_{C_i})]', '100 * KL(C_i)', '100 * H(q(c_i|x))']) + ] + with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.width', None): + print(""" +Epoch: %d\t MLoss: %.3f \tWarmUpLoss: %.7f\t\tTime: %.2f' +%.3f \t+ %.5f\t+ %.5f\t+ %.5f\t+ %.5f\t+ %.5f +(Recon) \tz %.5f\tu %.5f\tc %.5f\th %.5f\tbc %.5f +z kls: \t%s +\t%s\n""" % (epoch + 1, others[0], warm_up_mean_loss, round((time() - t) / 60, 2), *others[1:], str(kl_z_dims), str(df).replace('\n', '\n\t ')), flush=True) + + if warm_up_loader is not None: + self.scheduler_warm_up.step(warm_up_mean_loss) + self.scheduler_model.step(separated_mean_epoch_loss[0]) + if run_after_epoch is not None: + run_after_epoch(epoch, *run_after_epoch_args) + + def _train_epoch(self, data_loader, warm_up_loader=None): + warm_up_loss, separated_sum_loss = 0, 0 + for batch_idx, (data, label) in enumerate(data_loader): + warm_up_loss += self._warm_up(warm_up_loader) + separated_sum_loss += self._train_iteration(data) + + return warm_up_loss / (batch_idx + 1), separated_sum_loss / len(data_loader.dataset) + + def _warm_up(self, loader): + if not loader: + return 0 + + epoch_loss = 0 + for data, label in loader: + data, label = data.to(self.device), label.to(self.device) + + self.optimizer_warm_up.zero_grad() + latent_dist = self.model.encode(data, only_disc_dist=True) + sum_ce = torch.sum(-1 * label * latent_dist['log_c'], dim=1) + loss = torch.mean(sum_ce) + loss.backward() + self.optimizer_warm_up.step() + epoch_loss += loss.item() + + if random.random() < 0.001: + print('%.3f' % epoch_loss, end=', ', flush=False) + return epoch_loss + + def 
_train_iteration(self, data): + self.num_steps += 1 + data = data.to(self.device) + + recon_batch, latent_dist = self.model(data) + loss, separated_mean_loss = self._loss_function(data, recon_batch, latent_dist) + if np.isnan(separated_mean_loss[0]): + raise Exception('NaN!') + + self.optimizer_model.zero_grad() + loss.backward() + # Applying partially connected layers + with torch.no_grad(): + self.model.c_to_a_logit_pc.weight.grad.mul_(self.model.c_to_a_logit_mask) + self.model.h_dot_a_to_u_mean_pc.weight.grad.mul_(self.model.h_dot_a_to_u_mask) + self.model.h_dot_a_to_u_logvar_pc.weight.grad.mul_(self.model.h_dot_a_to_u_mask) + self.optimizer_model.step() + + return separated_mean_loss * data.size(0) + + def _loss_function(self, data, recon_data, latent_dist): + if self.recon_type.lower() == 'bce': + recon_loss = F.binary_cross_entropy(recon_data, data, reduction='none') + elif self.recon_type.lower() == 'mse': + recon_loss = F.mse_loss(recon_data, data, reduction='none') + else: + recon_loss = torch.abs(recon_data - data) + recon_loss = torch.mean(torch.sum(recon_loss, dim=[1, 2, 3])) + + tmp_zero = torch.zeros(1, device=self.device, requires_grad=False) + z_kl, z_loss, each_c_kl, c_loss, each_c_entropy, c_entropy_loss, u_loss, each_u_expected_kl = 8 * [tmp_zero] + bc, priors_intersection_loss = 2 * [tmp_zero] + mean_bc = 0 + + if self.model.has_indep: + mean, logvar = latent_dist['z'] + z_each_dim_kl = self._kld_each_dim_with_standard_gaussian(mean, logvar) + z_kl = torch.sum(z_each_dim_kl) + cap_min, cap_max, num_iters, gamma = self.z_capacity + cap_current = (cap_max - cap_min) * self.num_steps / float(num_iters) + cap_min + cap_current = min(cap_current, cap_max) + z_loss = gamma * torch.abs(cap_current - z_kl) + + if self.model.has_dep: + each_c_kl = self._each_c_kl_loss(latent_dist['log_c']) + c_loss = self.c_gamma * torch.sum(each_c_kl) + + each_c_entropy = self._each_c_entropy(latent_dist['log_c']) + c_entropy_loss = self.entropy_gamma * torch.sum(each_c_entropy) + + each_u_expected_kl = self._each_u_expected_kl_loss(latent_dist['log_c'], latent_dist['u']) + cap_min, cap_max, num_iters, gamma = self.u_capacity + cap_current = (cap_max - cap_min) * self.num_steps / float(num_iters) + cap_min + cap_current = min(cap_current, cap_max) + u_loss = gamma * torch.abs(cap_current - torch.sum(each_u_expected_kl)) + + bc = self._bhattacharyya_coefficient_inter_priors(self.model.u_prior_means, self.model.u_prior_logvars, self.u_valid_prior_BC_mask) + priors_intersection_loss = self.bc_gamma * torch.sum(torch.clamp_min(bc - self.bc_threshold, min=0)) + mean_bc = (torch.sum(bc) / torch.sum(self.u_valid_prior_BC_mask)).item() + + # Total loss + total_loss = recon_loss + z_loss + c_loss + u_loss + c_entropy_loss + priors_intersection_loss + + return ( + total_loss, + np.array([total_loss.item(), recon_loss.item(), z_loss.item(), u_loss.item(), c_loss.item(), + c_entropy_loss.item(), priors_intersection_loss.item(), + z_kl.item(), torch.sum(each_u_expected_kl).item(), torch.sum(each_c_kl).item(), + torch.sum(each_c_entropy).item(), mean_bc, + *z_each_dim_kl.detach().cpu().numpy(), + *each_u_expected_kl.detach().cpu().numpy(), + *each_c_kl.detach().cpu().numpy(), + *each_c_entropy.detach().cpu().numpy()]) + ) + + def _kld_each_dim_with_standard_gaussian(self, mean, logvar): + kl_values = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp()) + kl_means = torch.mean(kl_values, dim=0) + return kl_means + + def _bhattacharyya_coefficient_inter_priors(self, mu, logvar, mask): + variance = 
torch.exp(logvar) + avg_var = 0.5 * (variance.unsqueeze(0) + variance.unsqueeze(1)) # n_priors, n_priors, d + inv_avg_var = 1 / (avg_var + EPS) + diff_mean = mu.unsqueeze(0) - mu.unsqueeze(1) + db_first_term = 1/8 * torch.sum(diff_mean * inv_avg_var * diff_mean, dim=2) # n_priors, n_priors + db_second_term = 0.5 * (torch.sum(torch.log(avg_var + EPS), dim=2) + - 0.5 * (torch.sum(logvar, dim=1).unsqueeze(0) + torch.sum(logvar, dim=1).unsqueeze(1))) + db = db_first_term + db_second_term + bc = torch.exp(-db) + valid_bc = bc.mul(mask) + return valid_bc + + def _kld_each_dim_data_and_priors(self, prior_means, prior_logvars, batch_mean, batch_logvar, parts_count, valid_indices): + n_priors, d = prior_means.size() + + batch_logvar = batch_logvar.view(-1, 1, parts_count, d).expand(-1, n_priors, -1, -1)[:, torch.arange(n_priors), valid_indices, :] + batch_var = torch.exp(batch_logvar) # batch_size, n_priors, d + + diff_mean_with_invalid_items = prior_means.unsqueeze(0).unsqueeze(2) - batch_mean.view(-1, 1, parts_count, d) # batch_size, n_priors, disc_count, d + diff_mean = diff_mean_with_invalid_items[:, torch.arange(n_priors), valid_indices, :] # batch_size, n_priors, d + + priors_unsqueezed_inv_var = torch.exp(-1 * prior_logvars).unsqueeze(0) # 1, n_priors, d + + return 0.5 * ( + prior_logvars.unsqueeze(0) # 1, n_priors, d + - batch_logvar # batch_size, n_priors, d + - 1 + + priors_unsqueezed_inv_var * batch_var # batch_size, n_priors, d + + diff_mean * priors_unsqueezed_inv_var * diff_mean # batch_size, n_priors, d + ) # batch_size, n_priors, d + + def _each_u_expected_kl_loss(self, log_q_cs_given_x, u_dist): + each_dim_kld = self._kld_each_dim_data_and_priors(self.model.u_prior_means, self.model.u_prior_logvars, + u_dist[0], u_dist[1], + self.model.c_count, self.u_kl_func_valid_indices) + kld = torch.sum(each_dim_kld, dim=2) + kld_dot_prob = torch.exp(log_q_cs_given_x) * kld + unwrapped_kld_dot_prob = kld_dot_prob.unsqueeze(1).unsqueeze(2).matmul(self.unwrap_mask.unsqueeze(0)).squeeze(2) + expected_kld = torch.sum(unwrapped_kld_dot_prob, dim=1) # batch_size, disc_count + each_u_expected_kl_loss = torch.mean(expected_kld, dim=0) + return each_u_expected_kl_loss + + def _each_c_kl_loss(self, log_q_cs_given_x): # log_q_cs_given_x's size is batch * sum_disc_dim + log_q_cs = torch.logsumexp(log_q_cs_given_x, dim=0) - math.log(log_q_cs_given_x.size(0)) + q_log_q_on_p = torch.exp(log_q_cs) * (log_q_cs - torch.log(self.model.c_priors + EPS)) + unwrapped_q_log_q_on_p = q_log_q_on_p.matmul(self.unwrap_mask) + each_c_kl_loss = torch.sum(unwrapped_q_log_q_on_p, dim=0) + return each_c_kl_loss + + def _each_c_entropy(self, log_q_cs_given_x): + q_log_q = torch.exp(log_q_cs_given_x) * log_q_cs_given_x + unwrapped_q_log_q = q_log_q.matmul(self.unwrap_mask) + batch_each_c_neg_entropy = torch.sum(unwrapped_q_log_q, dim=0) + each_c_entropy = -1 * torch.mean(batch_each_c_neg_entropy, dim=0) + return each_c_entropy diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..964a273 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +torch~=1.6.0 +torchvision~=0.7.0 +scipy~=1.5.0 +imageio~=2.9.0 +scikit-image~=0.16.2 +scikit-learn~=0.23.1 +pandas~=1.1.0 +pillow~=7.2.0 diff --git a/result/.gitignore b/result/.gitignore new file mode 100644 index 0000000..8d1c8b6 --- /dev/null +++ b/result/.gitignore @@ -0,0 +1 @@ + diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/dataloaders.py b/utils/dataloaders.py new file mode 100644 index 
0000000..9a4a243 --- /dev/null +++ b/utils/dataloaders.py @@ -0,0 +1,240 @@ +import numpy as np +import torch +from torch.nn import functional as F +from torch.utils.data import Dataset, DataLoader +from torchvision import datasets, transforms +from torchvision.utils import save_image + +from utils.fast_tensor_dataloader import FastTensorDataLoader + + +def get_mnist_dataloaders(batch_size=128, path_to_data='../data', warm_up=True, device=None): + data_transforms = transforms.Compose([ + transforms.Resize(32), + transforms.ToTensor() + ]) + target_transform = lambda x: F.one_hot(torch.tensor(x), num_classes=10) + train_data = datasets.MNIST(path_to_data, train=True, download=True, transform=data_transforms, target_transform=target_transform) + test_data = datasets.MNIST(path_to_data, train=False, transform=data_transforms, target_transform=target_transform) + train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True) + test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True) + + warm_up_loader = None + if warm_up: + warm_up_x, warm_up_y = WarmUpMNISTDataset(path_to_data, count=256, transform=data_transforms, target_transform=target_transform, device=device).get_tensors() + warm_up_loader = FastTensorDataLoader(warm_up_x, warm_up_y, batch_size=batch_size, shuffle=True) + + return train_loader, test_loader, warm_up_loader + + +class WarmUpMNISTDataset(datasets.MNIST): + def __init__(self, root, transform=None, target_transform=None, download=False, count=256, device=None): + self.__class__.__name__ = 'MNIST' # This is used in directory structure of datasets.MNIST + super(WarmUpMNISTDataset, self).__init__(root, train=True, transform=transform, target_transform=target_transform, download=download) + self.device = device + self.count = count + self.delete = set() + self.mapping = list(set(range(count + len(self.delete))) - self.delete) + + self.save_all_images() + + def __len__(self): + return self.count + + def __getitem__(self, index): + translated_index = self.mapping[index] + return super().__getitem__(translated_index) + + def get_tensors(self): + x_shape, y_shape = self[0][0].shape, self[0][1].shape + x, y = torch.zeros(self.count, *x_shape, device=self.device), torch.zeros(self.count, *y_shape, device=self.device) + for i, (data, label) in enumerate(self): + x[i], y[i] = data.to(self.device), label.to(self.device) + return x, y + + def save_all_images(self): + x_shape = self[0][0].shape + all_images = torch.zeros(self.count, *x_shape) + for i, (data, label) in enumerate(self): + all_images[i] = data + save_image(all_images, 'warm_up.png', nrow=(len(self) // 16)) + + +def get_celeba_dataloader(batch_size=128, path_to_data='../celeba_64', device=None, warm_up=True): + data_transforms = transforms.Compose([ + # transforms.Resize(64), + # transforms.CenterCrop(64), + transforms.ToTensor(), + # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + dataset_kwargs = { + 'target_type': 'attr', + 'transform': data_transforms, + } + train_data = datasets.CelebA(path_to_data, split='train', download=True, **dataset_kwargs) + test_data = datasets.CelebA(path_to_data, split='test', **dataset_kwargs) + # warm_up_data = WarmUpCelebADataset(path_to_data, split='train', target_transform=target_transforms, **dataset_kwargs) + + dataloader_kwargs = { + 'batch_size': batch_size, + 'shuffle': True, + 'pin_memory': device.type != 'cpu', + # 'pin_memory': False, + 'num_workers': 0 if device.type == 'cpu' else 4, + } + train_loader = DataLoader(train_data, 
**dataloader_kwargs) + test_loader = DataLoader(test_data, **dataloader_kwargs) + # warm_up_loader = DataLoader(warm_up_data, **dataloader_kwargs) + + warm_up_loader = None + if warm_up: + # target_transforms = transforms.Compose([ + # lambda x: x[celeba_good_columns], + # # lambda x: torch.flatten(F.one_hot(x, num_classes=2)) + # my_celeba_target_transfrom + # ]) + warm_up_x, warm_up_y = WarmUpCelebADataset(path_to_data, count=800, device=device, **dataset_kwargs).get_tensors() # TODO: If it is good, make the class simpler + warm_up_loader = FastTensorDataLoader(warm_up_x, warm_up_y, batch_size=batch_size, shuffle=True) + + return train_loader, test_loader, warm_up_loader + + +class WarmUpCelebADataset(datasets.CelebA): + def __init__(self, root, split="train", target_type="attr", transform=None, target_transform=None, download=False, + count=256, device=None): + super().__init__(root, split, target_type, transform, target_transform, download) + self.count = count + self.device = device + # self.delete = {2, 36, 43, 66, 74, 96, 119, 148, 149, 162, 166, 168, 183, 188, 198} # From 0 to 255+15 + # self.delete = {43, 74, 162, 183} # From 0 to 299 + self.delete = set() + self.mapping = list(set(range(count + len(self.delete))) - self.delete) + + self.labels = torch.tensor(np.genfromtxt('warm_up_labels.csv', delimiter=','), dtype=torch.float) + + # self.save_all_images() + + def __len__(self): + return self.count + + def __getitem__(self, index): + # return super().__getitem__(index) + translated_index = self.mapping[index] + x, _ = super().__getitem__(translated_index) + return x, self.labels[translated_index] + + def get_tensors(self): + x_shape, y_shape = self[0][0].shape, self[0][1].shape + x, y = torch.zeros(self.count, *x_shape, device=self.device), torch.zeros(self.count, *y_shape, device=self.device) + for i, (data, label) in enumerate(self): + x[i], y[i] = data.to(self.device), label.to(self.device) + return x, y + + def save_all_images(self): + x_shape = self[0][0].shape + all_images = torch.zeros(self.count, *x_shape) + for i, (data, label) in enumerate(self): + all_images[i] = data + save_image(all_images, 'warm_up.png', nrow=(len(self) // 16)) + + +def get_dsprites_dataloader(batch_size=128, path_to_data='../dsprites/ndarray.npz', fraction=1., device=None, warm_up=False): + dsprites_data = DSpritesDataset(path_to_data, fraction=fraction, device=device) + # dsprites_loader = FastTensorDataLoader(*dsprites_data.get_tensors(), batch_size=batch_size, shuffle=True) # Comment if you have memory limits, and uncomment the next line + dataloader_kwargs = { + 'batch_size': batch_size, + 'shuffle': True, + 'pin_memory': device.type != 'cpu', + 'num_workers': 0 if device.type == 'cpu' else 4, + } + dsprites_loader = DataLoader(dsprites_data, **dataloader_kwargs) + + warm_up_loader = None + if warm_up: + warm_up_data = DSpritesWarmUpDataset(path_to_data, device=device) + warm_up_loader = FastTensorDataLoader(*warm_up_data.get_tensors(), batch_size=batch_size, shuffle=True) + + return dsprites_loader, warm_up_loader + + +class DSpritesWarmUpDataset(Dataset): + # Color[1], Shape[3], Scale, Orientation, PosX, PosY + def __init__(self, path_to_data, size=10000, device=None): # was 100, 737, 1000, 3686, 10000 + self.device = device + data = np.load(path_to_data) + indices = self.good_indices(size) + self.imgs = np.expand_dims(data['imgs'][indices], axis=1) + + shape_value = data['latents_classes'][indices, 1] + self.classes = np.zeros((size, 3)) + self.classes[np.arange(size), shape_value] = 1 + + 
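# Sanity check: good_indices draws the three shapes equally often, so the class frequencies printed below should be roughly uniform +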
print(np.mean(self.classes, axis=0)) + + def good_indices(self, size): + # if size < 3 * 6 * 2 * 2 * 2: + # raise Exception('Too small!') + indices = np.zeros(size, dtype=np.long) + # [1, 3, 6, 40, 32, 32] + module = np.array([737280, 245760, 40960, 1024, 32, 1]) + i = 0 + while True: + for y_span in range(2): + for x_span in range(2): + for orientation_span in range(2): + for scale in range(6): + for shape in range(3): + orientation = int(np.random.randint(0, 20, 1) + orientation_span * 20) + x = int(np.random.randint(0, 16, 1) + x_span * 16) + y = int(np.random.randint(0, 16, 1) + y_span * 16) + sample = np.array([0, shape, scale, orientation, x, y]) + indices[i] = np.sum(sample * module) + i += 1 + if i >= size: + return indices + + def __len__(self): + return len(self.imgs) + + def __getitem__(self, idx): + return self.imgs[idx], self.classes[idx] + + def get_tensors(self): + return torch.tensor(self.imgs, dtype=torch.float, device=self.device), torch.tensor(self.classes, device=self.device) + + +class DSpritesDataset(Dataset): + # Color[1], Shape[3], Scale, Orientation, PosX, PosY + def __init__(self, path_to_data, fraction=1., device=None): + self.device = device + data = np.load(path_to_data) + self.imgs = data['imgs'] + self.imgs = np.expand_dims(self.imgs, axis=1) + self.classes = data['latents_classes'] + if fraction < 1: + indices = np.random.choice(737280, size=int(fraction * 737280), replace=False) + self.imgs = self.imgs[indices] + self.classes = self.classes[indices] + # self.attrs = data['latents_values'][indices] + # self.transform = transform + + def __len__(self): + return len(self.imgs) + + def __getitem__(self, idx): + # # Each image in the dataset has binary values so multiply by 255 to get + # # pixel values + # sample = self.imgs[idx] * 255 + # # Add extra dimension to turn shape into (H, W) -> (H, W, C) + # sample = sample.reshape(sample.shape + (1,)) + + # if self.transform: + # sample = self.transform(sample) + # Since there are no labels, we just return 0 for the "label" here + # return sample, (self.classes[idx], self.attrs[idx]) + + # return torch.tensor(self.imgs[idx], dtype=torch.float, device=self.device), torch.tensor(self.classes[idx], device=self.device) + return torch.tensor(self.imgs[idx], dtype=torch.float), torch.tensor(self.classes[idx]) + + def get_tensors(self): + return torch.tensor(self.imgs, dtype=torch.float, device=self.device), torch.tensor(self.classes, device=self.device) diff --git a/utils/fast_tensor_dataloader.py b/utils/fast_tensor_dataloader.py new file mode 100644 index 0000000..bff60b3 --- /dev/null +++ b/utils/fast_tensor_dataloader.py @@ -0,0 +1,54 @@ +# https://github.com/hcarlens/pytorch-tabular/blob/master/fast_tensor_data_loader.py +import torch + + +class FastTensorDataLoader: + """ + A DataLoader-like object for a set of tensors that can be much faster than + TensorDataset + DataLoader because dataloader grabs individual indices of + the dataset and calls cat (slow). + Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6 + """ + def __init__(self, *tensors, batch_size=32, shuffle=False, drop_last=True): + """ + Initialize a FastTensorDataLoader. + :param *tensors: tensors to store. Must have the same length @ dim 0. + :param batch_size: batch size to load. + :param shuffle: if True, shuffle the data *in-place* whenever an + iterator is created out of this object. + :returns: A FastTensorDataLoader. 
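+ Note: drop_last defaults to True here (unlike torch's DataLoader), so a trailing incomplete batch is skipped.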
+ """ + assert all(t.shape[0] == tensors[0].shape[0] for t in tensors) + self.tensors = tensors + + self.dataset_len = self.tensors[0].shape[0] + self.batch_size = batch_size + self.shuffle = shuffle + self.drop_last = drop_last + + # Calculate # batches + n_batches, remainder = divmod(self.dataset_len, self.batch_size) + if remainder > 0 and not self.drop_last: + n_batches += 1 + self.n_batches = n_batches + + def __iter__(self): + if self.shuffle: + r = torch.randperm(self.dataset_len) + self.tensors = [t[r] for t in self.tensors] + self.i = 0 + return self + + def __next__(self): + if self.i >= self.dataset_len or (self.drop_last and self.i + self.batch_size > self.dataset_len): + raise StopIteration + batch = tuple(t[self.i:self.i+self.batch_size] for t in self.tensors) + self.i += self.batch_size + return batch + + def __len__(self): + return self.n_batches + + @property + def dataset(self): + return self.tensors[0] diff --git a/utils/load_model.py b/utils/load_model.py new file mode 100644 index 0000000..993b333 --- /dev/null +++ b/utils/load_model.py @@ -0,0 +1,19 @@ +import json + +import torch + +from partedvae.models import VAE + + +def load(path, img_size, disc_priors, device): + path_to_specs = path + 'specs.json' + path_to_model = path + 'model.pt' + + with open(path_to_specs) as specs_file: + specs = json.load(specs_file) + latent_spec = specs["latent_spec"] + + model = VAE(img_size=img_size, latent_spec=latent_spec, c_priors=disc_priors, device=device) + model.load_state_dict(torch.load(path_to_model, map_location=lambda storage, loc: storage)) + + return model diff --git a/utils/metrics.py b/utils/metrics.py new file mode 100644 index 0000000..584a0c7 --- /dev/null +++ b/utils/metrics.py @@ -0,0 +1,140 @@ +import numpy as np +from sklearn import svm, linear_model +import torch + + +def all_metrics(dsprites_loader, model): + latents, classes = get_all_repr(dsprites_loader, model) + bvae = beta_vae_metric(latents, classes) + fact = dis_by_fact_metric(latents, classes) + sap = compute_sap_score(latents, classes) + return sap, fact, bvae + + +def get_all_repr(dsprites_loader, model): + model.eval() + with torch.no_grad(): + latents = np.zeros((len(dsprites_loader.dataset), model.latent_dim)) + classes = np.zeros((len(dsprites_loader.dataset), 6)) + start = 0 + for i, (x, c) in enumerate(dsprites_loader): + latent_dist = model.encode(x.to(model.device)) + r = range(start, start + x.size(0)) + latents[r, :] = torch.cat([latent_dist['z'][0], latent_dist['u'][0]], dim=1).cpu().numpy() + classes[r] = c.cpu().numpy() + start += x.size(0) + return latents, classes + + +# beta-VAE +L = 100 +dsprites_classes_num_states = [1, 3, 6, 40, 32, 32] + +def get_diffs_and_labels(latents, classes): + D = latents.shape[1] + K = classes.shape[1] + diffs, labels = list(), list() + for k in range(1, K): # Ignore 'Color' + for val in range(dsprites_classes_num_states[k]): + all_fk = np.where(classes[:, k] == val)[0] + for i in range(len(all_fk) // (2 * L)): + r = range(i * (2 * L), (i + 1) * (2 * L)) + current = all_fk[r] + current_latents = latents[current].reshape(2, -1, D) + diff = np.abs(current_latents[0, :, :] - current_latents[1, :, :]) + diffs.append(np.mean(diff, axis=0)) + labels.append(np.array([k])) + diffs = np.stack(diffs, axis=0) + labels = np.stack(labels, axis=0) + return diffs, labels + + +def beta_vae_metric(latents, classes): + N = latents.shape[0] + train_z, train_y, test_z, test_y = subsample_train_and_test(latents, classes, int(N * 0.8), int(N * 0.2)) + train_diffs, train_labels 
= get_diffs_and_labels(train_z, train_y) + test_diffs, test_labels = get_diffs_and_labels(test_z, test_y) + + model = linear_model.LogisticRegression() + model.fit(train_diffs, train_labels) + train_score = model.score(train_diffs, train_labels) + test_score = model.score(test_diffs, test_labels) + + print(round(train_score, 4), round(test_score, 4), flush=True) + return train_score, test_score + + +# FactorVAE +def dis_by_fact_metric(latents, classes): + D = latents.shape[1] + K = classes.shape[1] + + stds = np.std(latents, axis=0, keepdims=True) + normalized_latents = latents / stds + + stats = np.zeros((K, D)) + for k in range(1, K): # Ignore 'Color' + for val in range(dsprites_classes_num_states[k]): + all_fk = np.where(classes[:, k] == val)[0] + for i in range(len(all_fk) // L): + r = range(i * L, (i + 1) * L) + current = all_fk[r] + vars = np.var(normalized_latents[current], axis=0) + d_star = np.argmin(vars) + stats[k, d_star] += 1 + + # Prune collapsed latent dimensions + print(stds) + effective_stats = np.copy(stats) + for i in range(D): + if stds[0, i] < 1e-1: + effective_stats[:, i] = 0 + with open('latents.npy', 'wb') as f: + np.save(f, latents) + + # A single latent dimension should only correspond to a single factor of variation, but a factor of variation can relate to multiple latent elements + score = np.sum(np.max(effective_stats, axis=0)) / np.sum(effective_stats) + np.set_printoptions(suppress=True) + print(stats) + print(score, flush=True) + return score, stats # or effective_stats? + + +# SAP +def subsample_train_and_test(latents, classes, num_train, num_test): + indices = np.random.choice(np.arange(len(latents)), size=num_train+num_test, replace=False) + # indices = torch.multinomial(torch.arange(len(latents)), num_train+num_test, replacement=False) + train_indices, test_indices = indices[:num_train], indices[num_train:] + return latents[train_indices], classes[train_indices], latents[test_indices], classes[test_indices] + + +def compute_sap_score(latents, classes): + train_z, train_y, test_z, test_y = subsample_train_and_test(latents, classes, 20000, 5000) + + matrix = compute_score_matrix(train_z, train_y, test_z, test_y) + # print(matrix) + score = compute_avg_diff_top_two(matrix) + print('SAP:', score, flush=True) + return score + + +# From https://github.com/google-research/disentanglement_lib +def compute_score_matrix(mus, ys, mus_test, ys_test): + D = mus.shape[1] + K = ys.shape[1] + score_matrix = np.zeros((D, K)) + for i in range(D): + for j in range(1, K): + mu_i, mu_i_test = mus[:, i], mus_test[:, i] + y_j, y_j_test = ys[:, j], ys_test[:, j] + classifier = svm.LinearSVC(C=0.01, class_weight='balanced', dual=False) + classifier.fit(mu_i[:, np.newaxis], y_j) + pred = classifier.predict(mu_i_test[:, np.newaxis]) + score_matrix[i, j] = np.mean(pred == y_j_test) + return score_matrix + + +def compute_avg_diff_top_two(matrix): + sorted_matrix = np.sort(matrix, axis=0) + diff = sorted_matrix[-1, :] - sorted_matrix[-2, :] + return np.mean(diff[1:]) # Ignore 'Color' diff --git a/viz/__init__.py b/viz/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/viz/visualize.py b/viz/visualize.py new file mode 100644 index 0000000..b1be9bd --- /dev/null +++ b/viz/visualize.py @@ -0,0 +1,259 @@ +import math + +import numpy as np +import torch +from PIL import Image +from scipy import stats +from torchvision.utils import make_grid, save_image + + +def my_save_image(grid, filename, resize=False): + # Add 0.5 after unnormalizing to [0, 255] to round to nearest 
integer + ndarr = grid.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() + im = Image.fromarray(ndarr) + if resize: + im = im.resize((im.size[0] // 3, im.size[1] // 3), Image.ANTIALIAS) + im.save(filename) + + +class Visualizer: + def __init__(self, model, root='result/'): + self.device = model.device + self.model = model + self.root = root + + def reconstructions(self, data, size=(8, 8), filename='recon.png'): + # Plot reconstructions in test mode, i.e. without sampling from latent + self.model.eval() + data = data.to(self.device) + with torch.no_grad(): + recon_data, _ = self.model(data) + self.model.train() + + # Upper half of plot will contain data, bottom half will contain reconstructions + num_images = size[0] * size[1] // 2 + originals = data[:num_images].cpu() + reconstructions = recon_data.view(-1, *self.model.img_size)[:num_images].cpu() + # If there are fewer examples given than spaces available in grid, augment with blank images + num_examples = originals.size()[0] + if num_images > num_examples: + blank_images = torch.zeros((num_images - num_examples,) + originals.size()[1:]) + originals = torch.cat([originals, blank_images]) + reconstructions = torch.cat([reconstructions, blank_images]) + comparison = torch.cat([originals, reconstructions]) + + save_image(comparison.data, self.root + filename, nrow=size[0], pad_value=0.3) + # return make_grid(comparison.data, nrow=size[0]) + + def _traverse_standard_gaussian(self, idx, size, d, sample_prior=False): # TODO size not used in cdf_traversals + samples = torch.randn(size, d, device=self.device) if sample_prior else torch.zeros(size, d, device=self.device) + + if idx is not None: + # Sweep over linearly spaced coordinates transformed through the inverse CDF (ppf) of + # a gaussian since the prior of the latent space is gaussian + # cdf_traversal = np.linspace(0.05, 0.95, size) + cdf_traversal = np.array([0.001, 0.01, 0.1, 0.25, 0.4, 0.6, 0.75, 0.9, 0.99, 0.999]) + cont_traversal = torch.tensor(stats.norm.ppf(cdf_traversal), device=self.device) + samples[:, idx] += cont_traversal + + return samples + + def _traverse_custom_gaussian(self, idx, size, mean, std): # sample_prior not implemented # TODO size not used in cdf_traversals + samples = mean.unsqueeze(0).repeat(size, 1) + + if idx is not None: + # Sweep over linearly spaced coordinates transformed through the inverse CDF (ppf) of + # a gaussian since the prior of the latent space is gaussian + # cdf_traversal = np.linspace(0.05, 0.95, size) + cdf_traversal = np.array([0.001, 0.01, 0.1, 0.25, 0.4, 0.6, 0.75, 0.9, 0.99, 0.999]) + cont_traversal = torch.tensor(stats.norm.ppf(cdf_traversal), device=self.device) + samples[:, idx] += std[idx] * cont_traversal + + return samples + + # C[j] = k | j in [0, C count), k in [0, disc_dim) + def traverse_with_fix_c(self, j, k, dz_mean, dz_logvar, size=10, path='./', filename_prefix='', filename_suffix='.png', resize=True): + dz_std = torch.exp(0.5 * dz_logvar) + rows = list() + for cont_idx in range(self.model.z_dim): + line = list() + line.append(self._traverse_standard_gaussian(cont_idx, size, self.model.z_dim)) + line.append(self._traverse_custom_gaussian(None, size, dz_mean, dz_std)) + rows.append(torch.cat(line, dim=1)) + for dz_idx in range(j * self.model.single_u_dim, (j + 1) * self.model.single_u_dim): + line = list() + line.append(self._traverse_standard_gaussian(None, size, self.model.z_dim)) + line.append(self._traverse_custom_gaussian(dz_idx, size, dz_mean, dz_std)) + rows.append(torch.cat(line, 
dim=1)) + + generated = self._decode_latents(torch.cat(rows, dim=0)) + + filename = filename_prefix + 'c_' + str(j) + '_' + str(k) + filename_suffix + grid = make_grid(generated.data, nrow=size, pad_value=0.3) + + # Add a red line to distinguish z from dz + place = self.model.z_dim * (self.model.img_size[1] + 2) + grid[0, place:place + 2, 2:-2] = 1 + + # Transfer to a new grid + rows = self.model.z_dim + self.model.single_u_dim + height = math.ceil(rows / 2) * (self.model.img_size[1] + 2) + 2 + width = grid.size(2) * 2 + 2 + new_grid = torch.zeros(3, height, width, device=self.device) + new_grid[:, :, :grid.size(2)] = grid[:, :height, :] + new_grid[1, :, grid.size(2):grid.size(2) + 2] = 1 + new_grid[:, 2:grid.size(1) - height + 2, grid.size(2) + 2:] = grid[:, height:, :] + + my_save_image(new_grid, self.root + path + filename, resize=resize) + + def celeba_all_traversals(self, path='traversals_normal/', bang=None, gender=None, beard=None, hat=None, resize=True): + with torch.no_grad(): + dz_base_prior_indices = np.cumsum([0] + self.model.c_dims[:-1]) + for disc_idx, disc_dim in enumerate(self.model.c_dims): + for i in range(disc_dim): + priors_indices = dz_base_prior_indices.copy() + priors_indices[disc_idx] += i + if disc_idx in [beard, hat]: + priors_indices[gender] += 1 + mean, logvar = self.model.u_prior_means[priors_indices].flatten(), self.model.u_prior_logvars[priors_indices].flatten() + self.traverse_with_fix_c(disc_idx, i, mean, logvar, path=path, resize=resize) + if disc_idx == bang: + priors_indices[gender] += 1 + mean, logvar = self.model.u_prior_means[priors_indices].flatten(), self.model.u_prior_logvars[priors_indices].flatten() + self.traverse_with_fix_c(disc_idx, i, mean, logvar, path=path, filename_prefix='male_', resize=resize) + + def traverse_desired_u(self, desired_us, dim, name): + z = torch.zeros(self.model.z_dim).unsqueeze(0) + priors_indices = np.cumsum([0] + self.model.c_dims[:-1]) + mean = self.model.u_prior_means[priors_indices].detach().flatten().unsqueeze(0) + u_list = list() + for desired_u in desired_us: + u = mean.clone() + u[0, dim] = desired_u + u_list.append(torch.cat([z, u], dim=1)) + generated = self._decode_latents(torch.cat(u_list, dim=0)) + + generated = generated[:, :, 20:-20, :] + + grid = make_grid(generated.data, nrow=len(desired_us), pad_value=0) + my_save_image(grid, self.root + '%s.png' % name, resize=False) + + def transform(self, images): + intermediary = 8 + N, image_shape = images.size()[0], images.size()[1:] + assert N % 2 == 0 + all_images = torch.zeros(N // 2, intermediary + 2, *image_shape) + all_images[:, 0], all_images[:, -1] = images[:N//2], images[N//2:] + + latent_dist = self.model.encode(images) + z = latent_dist['z'][0].view(2, N // 2, -1).unsqueeze(2) + u = latent_dist['u'][0].view(2, N // 2, -1).unsqueeze(2) + coefs = torch.linspace(0, 1, intermediary).view(1, -1, 1) + z_interpolation = z[0] + (z[1] - z[0]) * coefs + u_interpolation = u[0] + (u[1] - u[0]) * coefs + latents = torch.cat([z_interpolation, u_interpolation], dim=2) + all_images[:, 1:-1] = self._decode_latents(latents).view(N//2, intermediary, *image_shape) + grid = make_grid(all_images.view(-1, *image_shape), nrow=intermediary+2, pad_value=0) + my_save_image(grid, self.root + 'transform.png', resize=False) + + def swap(self, images): + count = len(images) + image_shape = images.size()[1:] + latent_dist = self.model.encode(images) + z = latent_dist['z'][0] + u = latent_dist['u'][0] + latents =
list() + for i in range(count): + for j in range(count): + latents.append(torch.cat([z[j], u[i]]).unsqueeze(0)) + all_images = torch.zeros(count+1, count+1, *image_shape) + all_images[0, 1:] = images + all_images[1:, 0] = images + # all_images[torch.arange(count+1, (count+1)*(count+1), count+1)] = images + all_images[1:, 1:] = self._decode_latents(torch.cat(latents, dim=0)).view(count, count, *image_shape) + grid = make_grid(all_images.view(-1, *image_shape), nrow=count+1, pad_value=0) + my_save_image(grid, self.root + 'swap.png', resize=False) + + def z_traversal(self): + intermediary = 10 + linespace = torch.linspace(-3, 3, intermediary) + priors_indices = np.cumsum([0] + self.model.c_dims[:-1]) + u = self.model.u_prior_means[priors_indices].detach().flatten().unsqueeze(0) + latents = list() + + for i in [2, 3, 4, 6, 8, 9]: + z = torch.zeros(intermediary, self.model.z_dim) + z[:, i] = linespace + latents.append(torch.cat([z, u.expand(intermediary, -1)], dim=1)) + generated = self._decode_latents(torch.cat(latents, dim=0)) + + generated = generated[:, :, 20:-20, :] + + grid = make_grid(generated, nrow=intermediary, pad_value=0) + my_save_image(grid, self.root + 'z.png', resize=False) + + def celeba_u_traversal(self): # 2 sigma for u, 0.1 * z + size = 10 + with torch.no_grad(): + base_prior_indices = np.cumsum([0] + self.model.c_dims[:-1]) + default_female = self.model.u_prior_means[base_prior_indices].flatten() + base_prior_indices[-1] += 1 + default_male = self.model.u_prior_means[base_prior_indices].flatten() + pairs = [[3, 9, 10, False], [3, 9, 11, True], [4, 12, 13, False], [5, 14, 15, False], [6, 16, 17, True], [7, 18, 19, False]] + rows = list() + for dim, left, right, male in pairs: + z = torch.randn(1, self.model.z_dim).expand(size, -1) * 0.5 + u = (default_male.clone() if male else default_female.clone()).unsqueeze(0).repeat(size, 1) + left_mu, left_std = self.model.u_prior_means[left, 0], torch.exp(0.5 * self.model.u_prior_logvars[left, 0]) + right_mu, right_std = self.model.u_prior_means[right, 0], torch.exp(0.5 * self.model.u_prior_logvars[right, 0]) + minn = min(left_mu - 2*left_std, right_mu - 2*right_std) + maxx = max(left_mu + 2*left_std, right_mu + 2*right_std) + u[:, dim] = torch.linspace(minn, maxx, size) + rows.append(torch.cat([z, u], dim=1)) + generated = self._decode_latents(torch.cat(rows, dim=0)) + + generated = generated[:, :, 20:-20, :] + + grid = make_grid(generated, nrow=size, pad_value=0) + my_save_image(grid, self.root + 'u.png', resize=False) + + def no_bc_celeba_u_traversal(self): # 2 sigma for u, 0.1 * z + size = 10 + with torch.no_grad(): + base_prior_indices = np.cumsum([0] + self.model.c_dims[:-1]) + default_female = self.model.u_prior_means[base_prior_indices].flatten() + base_prior_indices[-1] += 1 + default_male = self.model.u_prior_means[base_prior_indices].flatten() + pairs = [[1, 4, 6, True], [2, 7, 8, True], [4, 12, 13, False], [6, 16, 17, True]] + rows = list() + for dim, left, right, male in pairs: + z = torch.randn(1, self.model.z_dim).expand(size, -1) * 0.5 + u = (default_male.clone() if male else default_female.clone()).unsqueeze(0).repeat(size, 1) + left_mu, left_std = self.model.u_prior_means[left, 0], torch.exp(0.5 * self.model.u_prior_logvars[left, 0]) + right_mu, right_std = self.model.u_prior_means[right, 0], torch.exp(0.5 * self.model.u_prior_logvars[right, 0]) + minn = min(left_mu - 2*left_std, right_mu - 2*right_std) + maxx = max(left_mu + 2*left_std, right_mu + 2*right_std) + u[:, dim] = torch.linspace(minn, maxx, size) + 
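# Each grid row sweeps one u dimension between the 2-sigma extremes of its two priors, with all other latents held fixed +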
rows.append(torch.cat([z, u], dim=1)) + generated = self._decode_latents(torch.cat(rows, dim=0)) + + generated = generated[:, :, 20:-20, :] + + grid = make_grid(generated, nrow=size, pad_value=0) + my_save_image(grid, self.root + 'u.png', resize=False) + + def all_traversals(self, path='traversals_normal/'): + with torch.no_grad(): + dz_base_prior_indices = np.cumsum([0] + self.model.c_dims[:-1]) + for disc_idx, disc_dim in enumerate(self.model.c_dims): + for i in range(disc_dim): + priors_indices = dz_base_prior_indices.copy() + priors_indices[disc_idx] += i + mean, logvar = self.model.u_prior_means[priors_indices].flatten(), self.model.u_prior_logvars[priors_indices].flatten() + self.traverse_with_fix_c(disc_idx, i, mean, logvar, path=path, resize=False) + + def _decode_latents(self, latent_samples): + with torch.no_grad(): + return self.model.decode(latent_samples).cpu()
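+
+# Rough usage sketch (assumes a trained `model` and a `batch` of images obtained as in main.py):
+#   viz = Visualizer(model, root='result/')
+#   viz.reconstructions(batch)   # originals in the top half, reconstructions below
+#   viz.all_traversals()         # one traversal grid per (discrete factor, value) pair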