Skip to content

Commit 554ef6f

Browse files
committed
refactoring
1 parent 99684b2 commit 554ef6f

8 files changed

+1206
-0
lines changed

ddpm_infer_latent.py

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from utils.ddpm_utils import load_latent_trajectory
2+
from utils.ddpm_utils import plot_results_reconstruction
3+
import torch
4+
from models.ddpm_models import ContextUnet, DDPM
5+
from omegaconf import OmegaConf
6+
from utils.metric_dataloader import MetricDataPreprocessor
7+
from utils.ddpm_utils import get_diffusion_outputs_from_z
8+
from utils.plot_utils import plot_engine_run_diff
9+
from utils.plot_utils import plot_engine_run_diff_decision_boundary
10+
from utils.plot_utils import reconstruct_and_plot
11+
from utils.tshae_utils import load_tshae_model
12+
import hydra
13+
import pickle
14+
import os
15+
16+
17+
@hydra.main(version_base=None, config_path="./configs", config_name="config.yaml")
def main(config):
    """Run DDPM inference on an extrapolated latent trajectory.

    Steps performed:
      1. Load the latent trajectory from
         ``config.diffusion.extrapolated_latent.path``.
      2. Load the TSHAE model from ``config.diffusion.checkpoint_tshae.path``
         and rebuild its dataloaders from the Hydra config stored next to
         that checkpoint (``.hydra/config.yaml``).
      3. Rebuild the DDPM with the hyper-parameters stored next to the DDPM
         checkpoint and load its weights.
      4. Call ``get_diffusion_outputs_from_z`` on the validation loader,
         pickle the result to ``engine_runs_diff.pickle`` in the Hydra
         output directory and plot it with ``plot_results_reconstruction``.

    Args:
        config: Hydra/OmegaConf configuration; the ``diffusion`` section
            supplies the paths and the ``diffusion_tester`` settings.
    """
    extrapolated_z_path = config.diffusion.extrapolated_latent.path
    extrapolated_z = load_latent_trajectory(extrapolated_z_path)
    print(extrapolated_z.keys())

    tshae_checkpoint_path = config.diffusion.checkpoint_tshae.path
    print(tshae_checkpoint_path)
    # os.path.join keeps the path relative when the checkpoint path has no
    # directory component (plain concatenation would yield "/.hydra/...").
    tshae_config_path = os.path.join(
        os.path.dirname(tshae_checkpoint_path), ".hydra", "config.yaml")
    tshae_config = OmegaConf.load(tshae_config_path)

    # The data pipeline is rebuilt from the TSHAE run's own configuration.
    preproc = MetricDataPreprocessor(**tshae_config.data_preprocessor)
    train_loader, test_loader, val_loader = preproc.get_dataloaders()
    print(f"train set: {len(train_loader.dataset)} val set: {len(val_loader.dataset)}")
    model_tshae = load_tshae_model(tshae_checkpoint_path)

    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    output_dir = hydra_cfg['runtime']['output_dir']
    print(f"output dir: {output_dir}")

    # Model hyper-parameters come from the DDPM training run, not from the
    # current config, so the architecture matches the saved weights.
    ddpm_checkpoint_path = config.diffusion.checkpoint_ddpm.path
    ddpm_config_path = os.path.join(
        os.path.dirname(ddpm_checkpoint_path), ".hydra", "config.yaml")
    ddpm_checkpoint_config = OmegaConf.load(ddpm_config_path)

    ddpm_train_cfg = ddpm_checkpoint_config.diffusion.ddpm_train
    n_T = ddpm_train_cfg.n_T
    device = ddpm_train_cfg.device  # e.g. "cuda:0" or "cpu"
    z_dim = ddpm_train_cfg.z_dim
    n_feat = ddpm_train_cfg.n_feat
    drop_prob = ddpm_checkpoint_config.diffusion.ddpm_model.drop_prob

    ddpm = DDPM(
        nn_model=ContextUnet(
            in_channels=1,
            n_feat=n_feat,
            z_dim=z_dim),
        betas=(1e-4, 0.02),
        n_T=n_T,
        device=device,
        drop_prob=drop_prob)
    # map_location remaps tensors saved on another device (e.g. a different
    # GPU index) onto the device used here instead of failing at load time.
    ddpm.load_state_dict(torch.load(ddpm_checkpoint_path, map_location=device))
    ddpm.eval().to(device)
    model_tshae.eval().to(device)

    tester_cfg = config.diffusion.diffusion_tester
    engine_runs = get_diffusion_outputs_from_z(
        z_space_dict=extrapolated_z,
        tshae_model=model_tshae,
        diffusion_model=ddpm,
        dataloader=val_loader,
        num_samples=tester_cfg.num_samples,
        w=tester_cfg.w,
        quantile=tester_cfg.quantile,
        mode=tester_cfg.mode,
    )

    # Persist the raw outputs so they can be re-plotted without re-sampling.
    pickle_path = os.path.join(output_dir, "engine_runs_diff.pickle")
    with open(pickle_path, 'wb') as handle:
        pickle.dump(engine_runs, handle)

    plot_results_reconstruction(engine_runs, output_dir)


if __name__ == "__main__":
    main()

ddpm_infer_validation.py

+89
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import torch
2+
import numpy as np
3+
from models.ddpm_models import ContextUnet, DDPM
4+
from omegaconf import OmegaConf
5+
from utils.metric_dataloader import MetricDataPreprocessor
6+
from utils.ddpm_utils import get_diffusion_outputs_from_dataloader
7+
from utils.plot_utils import plot_engine_run_diff
8+
from utils.plot_utils import plot_engine_run_diff_decision_boundary
9+
from utils.tshae_utils import load_tshae_model
10+
import hydra
11+
import pickle
12+
import os
13+
14+
15+
@hydra.main(version_base=None, config_path="./configs", config_name="config.yaml")
def test(config):
    """Run DDPM inference on the validation set and plot every engine run.

    The TSHAE model is loaded from ``config.diffusion.checkpoint_tshae.path``
    and its dataloaders are rebuilt from the Hydra config stored next to that
    checkpoint. The DDPM is rebuilt from the hyper-parameters stored next to
    ``config.diffusion.checkpoint_ddpm.path`` and its weights are loaded.
    The result of ``get_diffusion_outputs_from_dataloader`` is pickled to
    ``engine_runs_diff.pickle`` in the Hydra output directory, and two plots
    are saved there for each key of that result.

    Args:
        config: Hydra/OmegaConf configuration; the ``diffusion`` section
            supplies the paths and the ``diffusion_tester`` settings.
    """
    tshae_checkpoint_path = config.diffusion.checkpoint_tshae.path
    print(tshae_checkpoint_path)
    # os.path.join keeps the path relative when the checkpoint path has no
    # directory component (plain concatenation would yield "/.hydra/...").
    tshae_config_path = os.path.join(
        os.path.dirname(tshae_checkpoint_path), ".hydra", "config.yaml")
    tshae_config = OmegaConf.load(tshae_config_path)

    # The data pipeline is rebuilt from the TSHAE run's own configuration.
    preproc = MetricDataPreprocessor(**tshae_config.data_preprocessor)
    train_loader, test_loader, val_loader = preproc.get_dataloaders()
    print(f"train set: {len(train_loader.dataset)} val set: {len(val_loader.dataset)}")
    model_tshae = load_tshae_model(tshae_checkpoint_path)

    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    output_dir = hydra_cfg['runtime']['output_dir']
    print(f"output dir: {output_dir}")

    # Model hyper-parameters come from the DDPM training run, not from the
    # current config, so the architecture matches the saved weights.
    ddpm_checkpoint_path = config.diffusion.checkpoint_ddpm.path
    ddpm_config_path = os.path.join(
        os.path.dirname(ddpm_checkpoint_path), ".hydra", "config.yaml")
    ddpm_checkpoint_config = OmegaConf.load(ddpm_config_path)

    ddpm_train_cfg = ddpm_checkpoint_config.diffusion.ddpm_train
    n_T = ddpm_train_cfg.n_T
    device = ddpm_train_cfg.device  # e.g. "cuda:0" or "cpu"
    z_dim = ddpm_train_cfg.z_dim
    n_feat = ddpm_train_cfg.n_feat
    drop_prob = ddpm_checkpoint_config.diffusion.ddpm_model.drop_prob

    ddpm = DDPM(
        nn_model=ContextUnet(
            in_channels=1,
            n_feat=n_feat,
            z_dim=z_dim),
        betas=(1e-4, 0.02),
        n_T=n_T,
        device=device,
        drop_prob=drop_prob)
    # map_location remaps tensors saved on another device (e.g. a different
    # GPU index) onto the device used here instead of failing at load time.
    ddpm.load_state_dict(torch.load(ddpm_checkpoint_path, map_location=device))
    ddpm.eval().to(device)
    model_tshae.eval().to(device)

    print(val_loader.dataset.ids)
    tester_cfg = config.diffusion.diffusion_tester
    engine_runs = get_diffusion_outputs_from_dataloader(
        tshae_model=model_tshae,
        diffusion_model=ddpm,
        dataloader=val_loader,
        num_samples=tester_cfg.num_samples,
        w=tester_cfg.w,
        quantile=tester_cfg.quantile,
        mode=tester_cfg.mode,
    )

    # Persist the raw outputs so they can be re-plotted without re-sampling.
    pickle_path = os.path.join(output_dir, "engine_runs_diff.pickle")
    with open(pickle_path, 'wb') as handle:
        pickle.dump(engine_runs, handle)

    # Both plots take engine_id=engine, so both are produced per engine.
    for engine in engine_runs:
        plot_engine_run_diff(
            engine_runs,
            engine_id=engine,
            img_path=output_dir,
            save=True)

        plot_engine_run_diff_decision_boundary(
            model_tshae,
            engine_runs,
            img_path=output_dir,
            engine_id=engine,
            title="engine_run_boundary",
            save=True,
            show=False)


if __name__ == "__main__":
    test()

ddpm_plot_history.py

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import torch
2+
import numpy as np
3+
from utils.plot_utils import plot_engine_run_diff, plot_engine_run_diff_decision_boundary, reconstruct_and_plot
4+
from models.ddpm_models import ContextUnet, DDPM
5+
from utils.metric_dataloader import MetricDataPreprocessor
6+
from utils.tshae_utils import load_tshae_model
7+
import hydra
8+
import pickle
9+
10+
11+
@hydra.main(version_base=None, config_path="./configs", config_name="config.yaml")
def plot(config):
    """Re-plot previously pickled diffusion outputs for every engine run.

    Builds the dataloaders from ``config.data_preprocessor``, loads the TSHAE
    model from ``config.diffusion.checkpoint.path``, reads
    ``engine_runs_diff.pickle`` from the current working directory and saves
    two plots per key into the Hydra output directory.

    Args:
        config: Hydra/OmegaConf configuration for this run.
    """
    loaders = MetricDataPreprocessor(**config.data_preprocessor).get_dataloaders()
    train_loader, test_loader, val_loader = loaders
    print(f"train set: {len(train_loader.dataset)} val set: {len(val_loader.dataset)}")

    model_tshae = load_tshae_model(config.diffusion.checkpoint.path)

    output_dir = hydra.core.hydra_config.HydraConfig.get()['runtime']['output_dir']
    print(f"output dir: {output_dir}")

    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # use this with pickles produced by the project's own inference scripts.
    with open("engine_runs_diff.pickle", 'rb') as pickle_file:
        engine_runs = pickle.load(pickle_file)

    # Keyword arguments shared by both plotting helpers.
    shared_kwargs = dict(img_path=output_dir, save=True, show=False)
    for engine_id in engine_runs.keys():
        plot_engine_run_diff(engine_runs, engine_id=engine_id, **shared_kwargs)
        plot_engine_run_diff_decision_boundary(
            model_tshae,
            engine_runs,
            engine_id=engine_id,
            title="engine_run_boundary",
            **shared_kwargs)


if __name__ == "__main__":
    plot()

ddpm_train.py

+173
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
from utils.metric_dataloader import MetricDataPreprocessor
2+
from models.ddpm_models import DDPM, ContextUnet
3+
from matplotlib.animation import FuncAnimation, PillowWriter
4+
from utils.tshae_utils import load_tshae_model
5+
from tqdm import tqdm
6+
import random
7+
import torch
8+
import os
9+
import matplotlib.pyplot as plt
10+
import numpy as np
11+
import torch.nn as nn
12+
13+
import hydra
14+
from hydra.utils import instantiate
15+
16+
17+
18+
def _save_sample_grid(x_all, num_rows, num_columns, run_id, rul_seed, out_path):
    """Save generated samples (top rows) above the real windows (last row).

    Args:
        x_all: tensor read as ``x_all[row * num_columns + col]``; it holds
            the generated samples followed by ``num_columns`` real windows.
        num_rows: number of generated rows; one extra row shows real data.
        num_columns: number of conditioning windows (one per column).
        run_id: identifier of the run shown in every panel title.
        rul_seed: RUL value per column, shown in the panel titles.
        out_path: destination path of the PNG file.
    """
    # Colour limits are shared by every panel, so compute them once.
    visible = x_all[:, :, :, :21]
    vmin = visible.min().item()
    vmax = visible.max().item()

    # squeeze=False keeps `axs` two-dimensional even for a single column.
    fig, axs = plt.subplots(
        nrows=num_rows + 1, ncols=num_columns, sharex=True, sharey=True,
        figsize=(20, 15), squeeze=False)
    for row in range(num_rows + 1):
        plot_type = "true" if row == num_rows else "gen"
        for col in range(num_columns):
            ax = axs[row, col]
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_title(
                f"{plot_type} Id: {run_id} RUL: {int(rul_seed[col])}", fontsize=10)
            ax.imshow(
                x_all[row * num_columns + col, :, :, :21].cpu().squeeze(),
                vmin=vmin, vmax=vmax)
    fig.savefig(out_path, dpi=100)
    plt.close(fig)


def _save_sample_gif(x_gen_store, num_rows, num_columns, out_path):
    """Animate the stored sampling steps in ``x_gen_store`` and save a GIF.

    Args:
        x_gen_store: stored intermediate samples, indexed as
            ``x_gen_store[frame, sample, 0, :, :21]``.
        num_rows: number of sample rows in the grid.
        num_columns: number of conditioning windows (one per column).
        out_path: destination path of the GIF file.
    """
    # squeeze=False keeps `axs` two-dimensional even for a single column.
    fig, axs = plt.subplots(
        nrows=num_rows, ncols=num_columns, sharex=True, sharey=True,
        figsize=(12, 7), squeeze=False)

    def animate_diff(i, x_gen_store):
        print(f'gif animating frame {i} of {x_gen_store.shape[0]}', end='\r')
        frame = x_gen_store[i, :, 0, :, :21]
        # Colour limits are per frame and shared by all panels of the frame.
        vmin, vmax = frame.min(), frame.max()
        plots = []
        for row in range(num_rows):
            for col in range(num_columns):
                ax = axs[row, col]
                ax.clear()
                ax.set_xticks([])
                ax.set_yticks([])
                plots.append(
                    ax.imshow(frame[row * num_columns + col], vmin=vmin, vmax=vmax))
        return plots

    ani = FuncAnimation(
        fig, animate_diff, fargs=[x_gen_store], interval=200, blit=False,
        repeat=True, frames=x_gen_store.shape[0])
    ani.save(out_path, dpi=100, writer=PillowWriter(fps=5))
    plt.close(fig)


@hydra.main(version_base=None, config_path="./configs", config_name="config.yaml")
def train_cmapss(config):
    """Train a DDPM conditioned on latent codes from a frozen TSHAE model.

    For every batch the TSHAE model (in eval mode, gradients disabled)
    produces a latent ``z`` that is passed to the DDPM as context, together
    with the right-padded input window. After each epoch one random training
    run is sampled, windows whose RUL is in {0, 10, ..., 90} are selected,
    and for every guidance weight in ``ddpm_train.ws_test`` a PNG grid is
    written to ``<output_dir>/images``. A GIF of the sampling process is
    also written every fifth epoch and on the last epoch. When
    ``ddpm_train.save_model`` is set, the DDPM weights are saved after the
    last epoch.

    Args:
        config: Hydra/OmegaConf configuration; reads
            ``diffusion.data_preprocessor``, ``diffusion.checkpoint_tshae``,
            ``diffusion.ddpm_train``, ``diffusion.ddpm_model`` and
            ``diffusion.optimizer``.
    """
    preproc = MetricDataPreprocessor(**config.diffusion.data_preprocessor)
    train_loader, _, val_loader = preproc.get_dataloaders()
    print(f"train set: {len(train_loader.dataset)} val set: {len(val_loader.dataset)}")

    model_tshae = load_tshae_model(config.diffusion.checkpoint_tshae.path)

    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    output_dir = hydra_cfg['runtime']['output_dir']
    print(f"output dir: {output_dir}")

    train_cfg = config.diffusion.ddpm_train
    n_epoch = train_cfg.epochs
    n_T = train_cfg.n_T
    device = train_cfg.device  # e.g. "cuda:0" or "cpu"
    z_dim = train_cfg.z_dim
    n_feat = train_cfg.n_feat
    lrate = train_cfg.lrate
    save_model = train_cfg.save_model
    ws_test = train_cfg.ws_test  # guidance weights used when sampling
    drop_prob = config.diffusion.ddpm_model.drop_prob

    save_dir = output_dir
    img_dir = os.path.join(save_dir, "images")

    ddpm = DDPM(
        nn_model=ContextUnet(
            in_channels=1,
            n_feat=n_feat,
            z_dim=z_dim),
        betas=(1e-4, 0.02),
        n_T=n_T,
        device=device,
        drop_prob=drop_prob)
    ddpm.to(device)

    # The TSHAE encoder only provides conditioning; it is never trained here.
    model_tshae.eval().to(device)
    for param in model_tshae.parameters():
        param.requires_grad = False

    optimizer = instantiate(config.diffusion.optimizer, params=ddpm.parameters())

    # Replicates the last column 11 times on the right. Built once so it is
    # also defined when the training loader yields no batches.
    # NOTE(review): presumably widens 21-column windows to the 32 columns
    # sampled below (size=(1, 32, 32)) -- confirm against the dataset.
    pad_right = nn.ReplicationPad2d((0, 11, 0, 0))

    rul_range = np.arange(0, 100, 10)
    n_sample = 4

    for ep in range(n_epoch):
        print(f'epoch {ep}')
        ddpm.train()

        # linear lrate decay
        for group in optimizer.param_groups:
            group['lr'] = lrate * (1 - ep / n_epoch)

        pbar = tqdm(train_loader)
        loss_ema = None
        pairs_mode = train_loader.dataset.return_pairs
        for data in pbar:
            # Only the anchor window is used; the other fields are ignored.
            if pairs_mode:
                x, _pos_x, _neg_x, _true_rul, _, _ = data
            else:
                x, _true_rul = data

            x = x.to(device)
            with torch.no_grad():
                _, z, _, _, _ = model_tshae(x)

            x_diffusion = pad_right(x).unsqueeze(1).to(device)
            context = z.to(device)

            optimizer.zero_grad()
            loss = ddpm(x_diffusion, context)
            loss.backward()
            optimizer.step()

            # Exponential moving average of the loss for the progress bar.
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
            pbar.set_description(f"loss: {loss_ema:.4f}")

        # For eval, save an image of currently generated samples (top rows)
        # followed by the real windows (bottom row).
        idx = random.choice(train_loader.dataset.ids)
        run_x, run_rul = train_loader.dataset.get_run(idx)
        rul_mask = torch.isin(run_rul, torch.Tensor(rul_range))
        x_samples = run_x[rul_mask].to(device)
        rul_seed = run_rul[rul_mask]

        if x_samples.shape[0] == 0:
            # Nothing to condition on; an empty grid cannot be plotted.
            print(f"run {idx} has no window with RUL in {rul_range.tolist()}; "
                  "skipping sample images")
        else:
            ddpm.eval()
            with torch.no_grad():
                _, z_samples, _, _, _ = model_tshae(x_samples)
                num_columns = z_samples.shape[0]
                num_rows = n_sample
                x_real = pad_right(x_samples).to(device).unsqueeze(1)

                for w in ws_test:
                    x_gen, x_gen_store = ddpm.sample_cmapss(
                        n_sample=n_sample, size=(1, 32, 32), device=device,
                        z_space_contexts=z_samples, guide_w=w)

                    # Real windows are appended after the generated ones.
                    x_all = torch.cat([x_gen, x_real])

                    os.makedirs(img_dir, exist_ok=True)
                    png_path = os.path.join(img_dir, f"image_ep{ep}_w{w}.png")
                    _save_sample_grid(
                        x_all, num_rows, num_columns, idx, rul_seed, png_path)
                    print('saved image at ' + png_path)

                    if ep % 5 == 0 or ep == int(n_epoch - 1):
                        # GIF of samples evolving over the sampling steps.
                        gif_path = os.path.join(img_dir, f"gif_ep{ep}_w{w}.gif")
                        _save_sample_gif(
                            x_gen_store, num_rows, num_columns, gif_path)
                        print('saved image at ' + gif_path)

        # optionally save model
        if save_model and ep == int(n_epoch - 1):
            model_path = os.path.join(save_dir, f"model_{ep}.pth")
            torch.save(ddpm.state_dict(), model_path)
            print('saved model at ' + model_path)


if __name__ == "__main__":
    train_cmapss()

0 commit comments

Comments
 (0)