-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathmain_recon.py
64 lines (56 loc) · 2.55 KB
/
main_recon.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from torchvision import datasets, transforms
from torch.utils.data.dataloader import DataLoader
import torch
from termcolor import colored
from utils.engine_recon import train, evaluation, vis_one
import torch.nn as nn
from configs import parser
from model.reconstruct.model_main import ConceptAutoencoder
import os
# main() writes its checkpoint into this directory; it is created at import time.
os.makedirs('saved_model/', exist_ok=True)


def main():
    """Train the MNIST concept autoencoder, evaluating and checkpointing every epoch.

    Configuration is read from the module-level ``args`` namespace, which the
    ``__main__`` block builds with ``configs.parser`` before calling this
    function. Attributes read directly here: ``batch_size``, ``num_workers``,
    ``num_cpt``, ``lr``, ``epoch``, ``lr_drop`` and ``fre``. The whole
    namespace is also passed to ``ConceptAutoencoder`` and ``train``.

    Side effects:
        - may download MNIST into ``../data``;
        - prints per-epoch progress and the running metric histories;
        - calls ``vis_one`` every ``args.fre`` epochs (never, if ``args.fre`` is 0);
        - overwrites ``saved_model/mnist_model_cpt{args.num_cpt}.pt`` after every epoch.
    """
    # Presumably the usual MNIST mean/std normalisation constants — confirm if changed.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    trainset = datasets.MNIST('../data', train=True, download=True, transform=transform)
    valset = datasets.MNIST('../data', train=False, transform=transform)
    trainloader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True,
                             num_workers=args.num_workers, pin_memory=False)
    valloader = DataLoader(valset, batch_size=args.batch_size, shuffle=False,
                           num_workers=args.num_workers, pin_memory=False)

    # Fall back to CPU instead of crashing on machines without a CUDA device.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Move the model to its device before constructing the optimizer, as the
    # torch.optim documentation recommends.
    model = ConceptAutoencoder(args, num_concepts=args.num_cpt)
    model.to(device)
    reconstruction_loss = nn.MSELoss()
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=args.lr)

    record_res = []
    record_att = []
    accs = []
    for i in range(args.epoch):
        print(colored('Epoch %d/%d' % (i + 1, args.epoch), 'yellow'))
        print(colored('-' * 15, 'yellow'))

        # One-off 10x learning-rate decay at epoch args.lr_drop.
        if i == args.lr_drop:
            for group in optimizer.param_groups:
                group["lr"] *= 0.1
            # Report the real new rate: it depends on args.lr, not a fixed constant.
            print(f"Adjusted learning rate to {optimizer.param_groups[0]['lr']:g}")

        train(args, model, device, trainloader, reconstruction_loss, optimizer, i)
        res_loss, att_loss, acc = evaluation(model, device, valloader, reconstruction_loss)
        record_res.append(res_loss)
        record_att.append(att_loss)
        accs.append(acc)

        # A zero frequency would raise ZeroDivisionError; treat it as "never visualise".
        if args.fre and i % args.fre == 0:
            vis_one(model, device, valloader, epoch=i, select_index=1)

        print("Reconstruction Loss: ", record_res)
        print("Attention Loss: ", record_att)
        print("Acc: ", accs)
        # Overwritten every epoch so an interrupted run still leaves the latest weights.
        torch.save(model.state_dict(), f"saved_model/mnist_model_cpt{args.num_cpt}.pt")
if __name__ == '__main__':
    # `args` is bound at module level on purpose: main() reads it as a global.
    args = parser.parse_args()
    # Fixed loss weights for this MNIST reconstruction script. They are set
    # after parsing, so they override the same options given on the command line.
    vars(args).update(
        att_bias=5,
        quantity_bias=0.1,
        distinctiveness_bias=0,
        consistence_bias=0,
    )
    os.makedirs(args.output_dir + '/', exist_ok=True)
    main()