1
+ from utils .metric_dataloader import MetricDataPreprocessor
2
+ from models .ddpm_models import DDPM , ContextUnet
3
+ from matplotlib .animation import FuncAnimation , PillowWriter
4
+ from utils .tshae_utils import load_tshae_model
5
+ from tqdm import tqdm
6
+ import random
7
+ import torch
8
+ import os
9
+ import matplotlib .pyplot as plt
10
+ import numpy as np
11
+ import torch .nn as nn
12
+
13
+ import hydra
14
+ from hydra .utils import instantiate
15
+
16
+
17
+
18
def _save_sample_grid(x_all, num_rows, num_columns, idx, rul_seed, path):
    """Save a grid of generated samples with one final row of real windows.

    Args:
        x_all: tensor indexed as ``x_all[k, :, :, :21]``; the first
            ``num_rows * num_columns`` entries are plotted as "gen" rows and
            the next ``num_columns`` entries as the "true" row.
        num_rows: number of generated rows (a "true" row is appended).
        num_columns: number of columns, one per conditioning sample.
        idx: run identifier, used only in subplot titles.
        rul_seed: per-column RUL values, used only in subplot titles.
        path: destination file for the figure.
    """
    # One colour scale for the whole figure, computed once instead of once
    # per subplot. Only the first 21 columns are shown; the rest is
    # presumably the replication padding added before diffusion.
    shown = x_all[:, :, :, :21]
    vmin, vmax = shown.min().item(), shown.max().item()

    # squeeze=False keeps ``axs`` two-dimensional even with a single column,
    # so ``axs[row, col]`` is always valid.
    fig, axs = plt.subplots(
        nrows=num_rows + 1,
        ncols=num_columns,
        sharex=True,
        sharey=True,
        figsize=(20, 15),
        squeeze=False,
    )
    for row in range(num_rows + 1):
        plot_type = "true" if row == num_rows else "gen"
        for col in range(num_columns):
            ax = axs[row, col]
            ax.clear()
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_title(f"{plot_type} Id: {idx} RUL: {int(rul_seed[col])}", fontsize=10)
            ax.imshow(
                x_all[row * num_columns + col, :, :, :21].cpu().squeeze(),
                vmin=vmin,
                vmax=vmax,
            )
    fig.savefig(path, dpi=100)
    plt.close(fig)


def _save_sampling_gif(x_gen_store, num_rows, num_columns, path):
    """Animate the stored intermediate sampling steps and save them as a GIF.

    Args:
        x_gen_store: array-like indexed as ``x_gen_store[frame, k, 0, :, :21]``
            whose first axis enumerates the animation frames.
        num_rows: number of subplot rows.
        num_columns: number of subplot columns.
        path: destination file for the GIF.
    """
    fig, axs = plt.subplots(
        nrows=num_rows,
        ncols=num_columns,
        sharex=True,
        sharey=True,
        figsize=(12, 7),
        squeeze=False,
    )

    def animate_diff(i, x_gen_store):
        print(f'gif animating frame {i} of {x_gen_store.shape[0]}', end='\r')
        # Per-frame colour scale shared by every subplot of that frame.
        frame = x_gen_store[i, :, 0, :, :21]
        vmin, vmax = frame.min(), frame.max()
        plots = []
        for row in range(num_rows):
            for col in range(num_columns):
                ax = axs[row, col]
                ax.clear()
                ax.set_xticks([])
                ax.set_yticks([])
                plots.append(ax.imshow(frame[row * num_columns + col], vmin=vmin, vmax=vmax))
        return plots

    ani = FuncAnimation(
        fig,
        animate_diff,
        fargs=[x_gen_store],
        interval=200,
        blit=False,
        repeat=True,
        frames=x_gen_store.shape[0],
    )
    ani.save(path, dpi=100, writer=PillowWriter(fps=5))
    plt.close(fig)


def _visualise_epoch(ddpm, model_tshae, dataset, pad, ws_test, device, save_dir, ep, make_gif):
    """Sample from ``ddpm`` conditioned on one random run and save the figures.

    A random run is taken from ``dataset``; its windows whose RUL is one of
    0, 10, ..., 90 are encoded with ``model_tshae`` and the resulting latents
    are used as sampling contexts for every guidance weight in ``ws_test``.
    Images (and GIFs when ``make_gif`` is true) go to ``<save_dir>/images``.
    """
    rul_range = np.arange(0, 100, 10)
    idx = random.choice(dataset.ids)
    run_x, run_rul = dataset.get_run(idx)

    # Build the selection mask once and reuse it for windows and labels.
    mask = torch.isin(run_rul, torch.Tensor(rul_range))
    x_samples = run_x[mask].to(device)
    rul_seed = run_rul[mask]

    if x_samples.shape[0] == 0:
        # Nothing to condition on; a zero-column figure cannot be created.
        print(f"run {idx} has no window with RUL in {rul_range.tolist()}; "
              f"skipping sample images for epoch {ep}")
        return

    img_dir = os.path.join(save_dir, "images")
    os.makedirs(img_dir, exist_ok=True)

    ddpm.eval()
    with torch.no_grad():
        _, z_samples, _, _, _ = model_tshae(x_samples)

        n_sample = 4
        num_columns = z_samples.shape[0]
        # Real windows shown in the bottom row; identical for every guidance
        # weight, so they are padded once outside the loop.
        x_real = pad(x_samples).to(device).unsqueeze(1)

        for w in ws_test:
            x_gen, x_gen_store = ddpm.sample_cmapss(
                n_sample=n_sample,
                size=(1, 32, 32),
                device=device,
                z_space_contexts=z_samples,
                guide_w=w,
            )
            x_all = torch.cat([x_gen, x_real])

            img_path = os.path.join(img_dir, f"image_ep{ep}_w{w}.png")
            _save_sample_grid(x_all, n_sample, num_columns, idx, rul_seed, img_path)
            # Report the path that was actually written.
            print('saved image at ' + img_path)

            if make_gif:
                gif_path = os.path.join(img_dir, f"gif_ep{ep}_w{w}.gif")
                _save_sampling_gif(x_gen_store, n_sample, num_columns, gif_path)
                print('saved image at ' + gif_path)


@hydra.main(version_base=None, config_path="./configs", config_name="config.yaml")
def train_cmapss(config):
    """Train a context-conditioned DDPM on CMAPSS windows.

    The conditioning context is the latent ``z`` produced by a frozen,
    pre-trained TSHAE model. After every epoch sample images are written to
    ``<hydra output dir>/images``; GIFs of the sampling process are written
    every fifth epoch and on the last one. The final model weights are saved
    when ``save_model`` is set.

    Args:
        config: Hydra configuration. Keys read under ``config.diffusion``:
            ``data_preprocessor``, ``checkpoint_tshae.path``, ``optimizer``,
            ``ddpm_model.drop_prob`` and ``ddpm_train.{epochs, n_T, device,
            z_dim, n_feat, lrate, save_model, ws_test}``.
    """
    preproc = MetricDataPreprocessor(**config.diffusion.data_preprocessor)
    train_loader, test_loader, val_loader = preproc.get_dataloaders()
    print(f"train set: {len(train_loader.dataset)} val set: {len(val_loader.dataset)}")

    model_tshae = load_tshae_model(config.diffusion.checkpoint_tshae.path)

    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    output_dir = hydra_cfg['runtime']['output_dir']
    print(f"output dir: {output_dir}")

    train_cfg = config.diffusion.ddpm_train
    n_epoch = train_cfg.epochs
    n_T = train_cfg.n_T
    device = train_cfg.device  # e.g. "cuda:0" or "cpu"
    z_dim = train_cfg.z_dim
    n_feat = train_cfg.n_feat
    lrate = train_cfg.lrate
    save_model = train_cfg.save_model
    save_dir = output_dir
    ws_test = train_cfg.ws_test  # guidance weights used when sampling
    drop_prob = config.diffusion.ddpm_model.drop_prob

    ddpm = DDPM(
        nn_model=ContextUnet(in_channels=1, n_feat=n_feat, z_dim=z_dim),
        betas=(1e-4, 0.02),
        n_T=n_T,
        device=device,
        drop_prob=drop_prob,
    )
    ddpm.to(device)

    # The TSHAE encoder only provides conditioning latents; keep it frozen.
    model_tshae.eval().to(device)
    for param in model_tshae.parameters():
        param.requires_grad = False

    optimizer = instantiate(config.diffusion.optimizer, params=ddpm.parameters())

    # Pads 11 replicated columns on the right of the last dimension
    # (presumably 21 features -> 32 to match the 32x32 sampling size).
    # Created once: it is stateless and is also needed after the batch loop.
    pad = nn.ReplicationPad2d((0, 11, 0, 0))

    pairs_mode = train_loader.dataset.return_pairs

    for ep in range(n_epoch):
        print(f'epoch {ep}')
        ddpm.train()

        # Linear learning-rate decay; this overrides any lr set through the
        # optimizer config.
        current_lr = lrate * (1 - ep / n_epoch)
        for group in optimizer.param_groups:
            group['lr'] = current_lr

        pbar = tqdm(train_loader)
        loss_ema = None
        for data in pbar:
            # Only the input window is used; the remaining items are ignored.
            if pairs_mode:
                x, _pos_x, _neg_x, _true_rul, _, _ = data
            else:
                x, _true_rul = data

            x = x.to(device)
            with torch.no_grad():
                _, z, _, _, _ = model_tshae(x)

            optimizer.zero_grad()
            x_diffusion = pad(x).unsqueeze(1).to(device)
            context = z.to(device)
            loss = ddpm(x_diffusion, context)
            loss.backward()

            # Exponential moving average of the loss for the progress bar.
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
            pbar.set_description(f"loss: {loss_ema:.4f}")
            optimizer.step()

        is_last_epoch = ep == n_epoch - 1
        _visualise_epoch(
            ddpm,
            model_tshae,
            train_loader.dataset,
            pad,
            ws_test,
            device,
            save_dir,
            ep,
            make_gif=(ep % 5 == 0 or is_last_epoch),
        )

        # Optionally save the final model.
        if save_model and is_last_epoch:
            model_path = os.path.join(save_dir, f"model_{ep}.pth")
            torch.save(ddpm.state_dict(), model_path)
            print('saved model at ' + model_path)
170
+
171
+
172
if __name__ == "__main__":
    # Script entry point. The call passes no arguments; ``config`` is expected
    # to be injected by the @hydra.main decorator applied to train_cmapss.
    train_cmapss()
0 commit comments