Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Saving/restoring latent #86

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions big_sleep/big_sleep.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dill
import os
import sys
import subprocess
Expand Down Expand Up @@ -168,7 +169,8 @@ def __init__(
image_size,
max_classes = None,
class_temperature = 2.,
ema_decay = 0.99
ema_decay = 0.99,
restore_latents_filename = None
):
super().__init__()
assert image_size in (128, 256, 512), 'image size must be one of 128, 256, or 512'
Expand All @@ -177,8 +179,12 @@ def __init__(
self.class_temperature = class_temperature
self.ema_decay\
= ema_decay

self.init_latents()
if restore_latents_filename is None:
self.init_latents()
else:
old_state_backup = dill.load(open(restore_latents_filename, "rb"))
self.latents = old_state_backup.ema_backup


def init_latents(self):
latents = Latents(
Expand Down Expand Up @@ -208,6 +214,7 @@ def __init__(
experimental_resample = False,
ema_decay = 0.99,
center_bias = False,
restore_latents_filename = None
):
super().__init__()
self.loss_coef = loss_coef
Expand All @@ -222,7 +229,8 @@ def __init__(
image_size = image_size,
max_classes = max_classes,
class_temperature = class_temperature,
ema_decay = ema_decay
ema_decay = ema_decay,
restore_latents_filename = restore_latents_filename
)

def reset(self):
Expand Down Expand Up @@ -289,6 +297,10 @@ def forward(self, text_embeds, text_min_embeds=[], return_loss = True):
sim_loss = sum(results).mean()
return out, (lat_loss, cls_loss, sim_loss)

class CurrentStateBackup:
def __init__(self, ema, optimizer):
self.ema_backup = ema
self.optimizer_backup = optimizer

class Imagine(nn.Module):
def __init__(
Expand Down Expand Up @@ -318,12 +330,16 @@ def __init__(
ema_decay = 0.99,
num_cutouts = 128,
center_bias = False,
save_latents = False,
restore_latents_filename = None
):
super().__init__()

if torch_deterministic:
assert not bilinear, 'the deterministic (seeded) operation does not work with interpolation (PyTorch 1.7.1)'
torch.set_deterministic(True)

self.save_latents = save_latents

self.seed = seed
self.append_seed = append_seed
Expand All @@ -346,12 +362,19 @@ def __init__(
ema_decay = ema_decay,
num_cutouts = num_cutouts,
center_bias = center_bias,
restore_latents_filename = restore_latents_filename
).cuda()

self.model = model

self.lr = lr
self.optimizer = Adam(model.model.latents.model.parameters(), lr)

if restore_latents_filename is None:
self.optimizer = Adam(model.model.latents.model.parameters(), lr)
else:
old_state_backup = dill.load(open(restore_latents_filename, "rb"))
self.optimizer = old_state_backup.optimizer_backup

self.gradient_accumulate_every = gradient_accumulate_every
self.save_every = save_every

Expand Down Expand Up @@ -472,6 +495,10 @@ def train_step(self, epoch, i, pbar=None):
num = total_iterations // self.save_every

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest moving num above this if block so it can be used if save_progress hasn't been set to True

save_image(image, Path(f'./{self.text_path}.{num}{self.seed_suffix}.png'))

if self.save_latents:
current_state_backup = CurrentStateBackup(self.model.model.latents, self.optimizer)
dill.dump(current_state_backup, file = open(f'./{self.text_path}.{num}{self.seed_suffix}.backup', "wb"))

if self.save_best and top_score.item() < self.current_best_score:
self.current_best_score = top_score.item()
save_image(image, Path(f'./{self.text_path}{self.seed_suffix}.best.png'))
Expand Down