Commit 2a2b262

Author: wbw520
Message: modify the training process
Parent: e4d9d84

10 files changed: +65 -40 lines

Diff for: configs.py (+3 -2)

@@ -6,15 +6,16 @@ def get_args_parser():

     # train settings
     parser.add_argument("--dataset", type=str, default="cityscapes")
-    parser.add_argument("--model_name", type=str, default="PSPNet")
+    parser.add_argument("--model_name", type=str, default="Segmenter")
+    parser.add_argument("--pre_model", type=str, default="mae_pre_epoch99_crop640_patch8_ed768_depth12_head12.pt")
     parser.add_argument("--batch_size", type=int, default=1,
                         help="Number of images sent to the network in one step.")
     parser.add_argument("--root", type=str, default="/home/wangbowen/DATA/cityscapes",
                         help="Path to the directory containing the image list.")
     parser.add_argument("--crop_size", type=int, default=[640, 640],
                         help="crop size for training and inference slice.")
     parser.add_argument("--stride_rate", type=float, default=0.5, help="stride ratio.")
-    parser.add_argument("--num_epoch", type=int, default=60, help="Number of training epochs.")
+    parser.add_argument("--num_epoch", type=int, default=200, help="Number of training epochs.")
     parser.add_argument("--num_classes", type=int, default=19, help="Number of classes for the dataset.")
     parser.add_argument('--accum_iter', default=1, type=int,
                         help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
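
The new --pre_model default encodes the MAE pre-training hyperparameters in the checkpoint file name: epoch, crop size, patch size, encoder embed dim, depth, and head count (see the save path constructed in mae_pre-training.py below). A minimal sketch, assuming only that naming scheme, of recovering those fields from a file name; the helper and its key names are illustrative, not part of the repo:

import re

def parse_pre_model_name(name):
    # Matches names like "mae_pre_epoch99_crop640_patch8_ed768_depth12_head12.pt".
    m = re.match(r"mae_pre_epoch(\d+)_crop(\d+)_patch(\d+)_ed(\d+)_depth(\d+)_head(\d+)\.pt", name)
    if m is None:
        raise ValueError("unexpected checkpoint name: " + name)
    keys = ["epoch", "crop_size", "patch_size", "encoder_embed_dim",
            "encoder_depth", "encoder_num_head"]
    return dict(zip(keys, map(int, m.groups())))

print(parse_pre_model_name("mae_pre_epoch99_crop640_patch8_ed768_depth12_head12.pt"))
# {'epoch': 99, 'crop_size': 640, 'patch_size': 8, 'encoder_embed_dim': 768,
#  'encoder_depth': 12, 'encoder_num_head': 12}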

Diff for: mae_pre-training.py (+23 -15)

@@ -7,7 +7,7 @@
 from torch.utils.data import DistributedSampler
 from data import cityscapes
 from data.loader_tools import get_joint_transformations, get_standard_transformations
-from model import mae_model as models_mae
+from model.mae_model import mae_vit
 from utils2.misc import NativeScalerWithGradNormCount as NativeScaler
 import timm.optim.optim_factory as optim_factory
 import os
@@ -23,17 +23,14 @@
 def get_args_parser():
     parser = argparse.ArgumentParser('MAE pre-training', add_help=False)
-    parser.add_argument('--batch_size', default=1, type=int,
+    parser.add_argument('--batch_size', default=4, type=int,
                         help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
-    parser.add_argument('--epochs', default=200, type=int)
+    parser.add_argument('--num_epochs', default=100, type=int)
     parser.add_argument('--accum_iter', default=1, type=int,
                         help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
     parser.add_argument('--use_pre', default=False, type=bool)

     # Model parameters
-    parser.add_argument('--model', default='mae_vit_large_patch8', type=str, metavar='MODEL',
-                        help='Name of model to train')
-    parser.add_argument('--patch_size', default=8, type=int, help='size of patch')
     parser.add_argument("--crop_size", type=int, default=[640, 640],
                         help="crop size for training and inference slice.")
     parser.add_argument("--stride_rate", type=float, default=0.5, help="stride ratio.")
@@ -45,6 +42,15 @@ def get_args_parser():
                         help='Use (per-patch) normalized pixels as targets for computing loss')
     parser.set_defaults(norm_pix_loss=False)

+    # ViT settings
+    parser.add_argument("--patch_size", type=int, default=16, help="define the patch size.")
+    parser.add_argument("--encoder_embed_dim", type=int, default=1024, help="dimension for encoder.")
+    parser.add_argument("--decoder_embed_dim", type=int, default=512, help="dimension for decoder.")
+    parser.add_argument("--encoder_depth", type=int, default=24, help="depth for encoder.")
+    parser.add_argument("--decoder_depth", type=int, default=8, help="depth for decoder.")
+    parser.add_argument("--encoder_num_head", type=int, default=16, help="head number for encoder.")
+    parser.add_argument("--decoder_num_head", type=int, default=16, help="head number for decoder.")
+
     # Optimizer parameters
     parser.add_argument('--weight_decay', type=float, default=0.05,
                         help='weight decay (default: 0.05)')
@@ -56,13 +62,13 @@ def get_args_parser():
     parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                         help='lower lr bound for cyclic schedulers that hit 0')

-    parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
+    parser.add_argument('--warmup_epochs', type=int, default=20, metavar='N',
                         help='epochs to warmup LR')

     # Dataset parameters
     parser.add_argument('--root', default="/home/wangbowen/DATA/cityscapes", type=str,
                         help='dataset path')
-    parser.add_argument('--output_dir', default='save_model',
+    parser.add_argument('--output_dir', default='save_model/',
                         help='path where to save, empty for no saving')
     parser.add_argument('--log_dir', default='save_model',
                         help='path where to tensorboard log')
@@ -86,7 +92,7 @@ def main():
     misc.init_distributed_mode(args)
     device = torch.device(args.device)
     cudnn.benchmark = True
-    model = models_mae.__dict__[args.model](norm_pix_loss=args.norm_pix_loss, img_size=args.crop_size[0])
+    model = mae_vit(args)
     model.to(device)
     model_without_ddp = model
     # print("Model = %s" % str(model_without_ddp))
@@ -97,6 +103,7 @@ def main():
                           standard_transform=standard_transformations)

     if args.use_pre:
+        # use the pre-trained parameters from the MAE paper
         checkpoint = torch.load("save_model/mae_visualize_vit_large.pth", map_location='cpu')
         checkpoint_model = checkpoint['model']
         interpolate_pos_embed(model, checkpoint_model)
@@ -138,17 +145,18 @@ def main():
     print(optimizer)
     loss_scaler = NativeScaler()

-    print(f"Start training for {args.epochs} epochs")
+    print(f"Start training for {args.num_epochs} epochs")
     start_time = time.time()

-    for epoch in range(args.epochs):
+    for epoch in range(args.num_epochs):
         if args.distributed:
             sampler_train.set_epoch(epoch)

-        if args.output_dir and ((epoch + 1) % 20 == 0 or epoch + 1 == args.epochs):
-            misc.save_model(
-                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
-                loss_scaler=loss_scaler, epoch=epoch)
+        if args.output_dir and ((epoch + 1) % 50 == 0 or epoch + 1 == args.num_epochs):
+            torch.save(model_without_ddp.state_dict(),
+                       args.output_dir + "mae_pre_epoch" + str(epoch) + "_crop" + str(args.crop_size[0]) + "_patch" +
+                       str(args.patch_size) + "_ed" + str(args.encoder_embed_dim) + "_depth" + str(args.encoder_depth) +
+                       "_head" + str(args.encoder_num_head) + ".pt")

         train_stats = train_one_epoch(
             model, train_loader,
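
Two details of the new save path are easy to miss: the checkpoint name now records the encoder configuration, and the --output_dir default gained a trailing slash because the path is built by plain string concatenation. A minimal sketch of what the expression above produces for the new defaults, saving at epoch 99:

output_dir, epoch = "save_model/", 99
crop_size, patch_size = [640, 640], 16
encoder_embed_dim, encoder_depth, encoder_num_head = 1024, 24, 16

path = (output_dir + "mae_pre_epoch" + str(epoch) + "_crop" + str(crop_size[0]) +
        "_patch" + str(patch_size) + "_ed" + str(encoder_embed_dim) +
        "_depth" + str(encoder_depth) + "_head" + str(encoder_num_head) + ".pt")
print(path)  # save_model/mae_pre_epoch99_crop640_patch16_ed1024_depth24_head16.pt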

Diff for: model/mae_model.py (+5 -9)

@@ -216,13 +216,9 @@ def forward(self, imgs, mask_ratio=0.75):
         return loss, pred, mask


-def mae_vit(**kwargs):
+def mae_vit(args):
     model = MaskedAutoencoderViT(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12,
-        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
-        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
-    return model
-
-
-# set recommended archs
-mae_vit = mae_vit
+        img_size=args.crop_size[0], patch_size=args.patch_size, embed_dim=args.encoder_embed_dim, depth=args.encoder_depth, num_heads=args.encoder_num_head,
+        decoder_embed_dim=args.decoder_embed_dim, decoder_depth=args.decoder_depth, decoder_num_heads=args.decoder_num_head,
+        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6))
+    return model
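
mae_vit now reads the whole architecture from the parsed args instead of hard-coding one configuration. A minimal construction sketch, assuming only the argument names defined in mae_pre-training.py above; the Namespace mirrors those defaults for illustration:

import argparse
from model.mae_model import mae_vit

# Illustrative stand-in for the parsed command-line args.
args = argparse.Namespace(
    crop_size=[640, 640], patch_size=16,
    encoder_embed_dim=1024, encoder_depth=24, encoder_num_head=16,
    decoder_embed_dim=512, decoder_depth=8, decoder_num_head=16)
model = mae_vit(args)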

Diff for: model/segmenter.py (+13)

@@ -1,6 +1,8 @@
 from model.vit_model import vit_encoder
 from model.segmenter_decoder import sg_vit_mask_decoder
 import torch.nn as nn
+import torch
+import os
 import torch.nn.functional as F


@@ -42,6 +44,17 @@ def get_attention_map_dec(self, im, layer_id):

 def create_segmenter(args):
     encoder = vit_encoder(img_size=args.crop_size[0], patch_size=args.patch_size, embed_dim=args.encoder_embed_dim, depth=args.encoder_depth, num_heads=args.encoder_num_head)
+    check_point = torch.load(os.path.join("save_model", args.pre_model), map_location="cuda:0")
+    state_dict = ["decoder", "mask_token"]
+    record = []
+    for k, v in check_point.items():
+        if state_dict[0] in k or state_dict[1] in k:
+            record.append(k)
+    for item in record:
+        del check_point[item]
+    encoder.load_state_dict(check_point, strict=True)
+    print("load pre-trained weight from: ", args.pre_model)
+
     decoder = sg_vit_mask_decoder(patch_size=args.patch_size, encoder_embed_dim=args.encoder_embed_dim,
                                   decoder_embed_dim=args.decoder_embed_dim, decoder_depth=args.decoder_depth, decoder_num_heads=args.decoder_num_head, n_cls=args.num_classes)
     model = Segmenter(encoder, decoder)
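
The filtering in create_segmenter keeps only encoder weights: keys containing "decoder" or "mask_token" belong to the MAE reconstruction branch and have no counterpart in vit_encoder, so they must be dropped for load_state_dict(..., strict=True) to succeed. The same idea as a standalone sketch; the file name is illustrative, taken from the configs.py default:

import torch

check_point = torch.load("save_model/mae_pre_epoch99_crop640_patch8_ed768_depth12_head12.pt",
                         map_location="cpu")
# Keep only encoder weights; MAE decoder/mask-token keys would fail a strict load.
encoder_state = {k: v for k, v in check_point.items()
                 if "decoder" not in k and "mask_token" not in k}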

Diff for: model/segmenter_decoder.py (-3)

@@ -62,10 +62,7 @@ def forward(self, x, im_size):

         patches = patches / patches.norm(dim=-1, keepdim=True)
         cls_seg_feat = cls_seg_feat / cls_seg_feat.norm(dim=-1, keepdim=True)
-        print(patches.shape)
-        print(cls_seg_feat.shape)
         masks = patches @ cls_seg_feat.transpose(1, 2)
-        print(masks.shape)
         masks = self.mask_norm(masks)
         masks = rearrange(masks, "b (h w) n -> b n h w", h=int(GS))
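
Since patches and cls_seg_feat are both L2-normalized just above, the matrix product the removed prints were inspecting is a cosine-similarity map: each entry of masks scores one patch embedding against one class embedding. A shape sketch with illustrative sizes (a 640 crop with patch size 8 gives 80 x 80 = 6400 patch tokens; 19 classes as in configs.py):

import torch

b, n_patches, n_cls, d = 2, 6400, 19, 768   # illustrative sizes
patches = torch.randn(b, n_patches, d)
cls_seg_feat = torch.randn(b, n_cls, d)
patches = patches / patches.norm(dim=-1, keepdim=True)
cls_seg_feat = cls_seg_feat / cls_seg_feat.norm(dim=-1, keepdim=True)
masks = patches @ cls_seg_feat.transpose(1, 2)
print(masks.shape)  # torch.Size([2, 6400, 19]), cosine scores in [-1, 1]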

Diff for: model/vit_model.py (+1 -1)

@@ -79,5 +79,5 @@ def resize_pos_embed(posemb, grid_old_shape, grid_new_shape, num_extra_tokens):


 def vit_encoder(**kwargs):
-    model = VisionTransformer(mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model = VisionTransformer(mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), num_classes=0, **kwargs)
     return model
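
If this VisionTransformer follows the usual timm convention, num_classes=0 swaps the classification head for an identity, so the encoder hands token features to the Segmenter decoder instead of logits, and no classifier-head keys appear in the state dict that create_segmenter loads strictly. A hypothetical check, assuming a timm-style head attribute (not confirmed by this diff):

from model.vit_model import vit_encoder

enc = vit_encoder(img_size=640, patch_size=16, embed_dim=1024, depth=24, num_heads=16)
print(enc.head)  # expected under the timm convention: Identity()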

Diff for: sample_demo/color_mask_translate.py (+1 -1)

@@ -7,7 +7,7 @@


 def main():
-    root = "/home/wangbowen/DATA/facades/facade_raw"
+    root = "/home/wangbowen/DATA/Facade/"
     item_list = get_name(root, mode_folder=False)
     image_list = []
     for item in item_list:

Diff for: train.py (+1)

@@ -49,6 +49,7 @@ def main():
     # param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
     param_groups = [p for p in model_without_ddp.parameters() if p.requires_grad]
     optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
+
     print(optimizer)
     loss_scaler = NativeScaler()

Diff for: utils/engine.py (+17 -8)

@@ -10,11 +10,14 @@
 def train_model(args, epoch, model, train_loader, criterion, optimizer, loss_scaler, device):
     model.train()
     train_main_loss = AverageMeter('Train Main Loss', ':.5')
-    train_aux_loss = AverageMeter('Train Aux Loss', ':.5')
     lr = AverageMeter('lr', ':.5')
     L = len(train_loader)
     curr_iter = epoch * L
-    progress = ProgressMeter(L, [train_main_loss, train_aux_loss, lr], prefix="Epoch: [{}]".format(epoch))
+    record = [lr, train_main_loss]
+    if args.model_name == "PSPNet":
+        train_aux_loss = AverageMeter('Train Aux Loss', ':.5')
+        record.append(train_aux_loss)
+    progress = ProgressMeter(L, record, prefix="Epoch: [{}]".format(epoch))
     accum_iter = args.accum_iter

     for data_iter_step, data in enumerate(train_loader):
@@ -26,10 +29,15 @@ def train_model(args, epoch, model, train_loader, criterion, optimizer, loss_scaler, device):
         mask = data["masks"].to(device, dtype=torch.int64)

         with torch.cuda.amp.autocast():
-            outputs, aux = model(inputs)
-            main_loss = criterion(outputs, mask)
-            aux_loss = criterion(aux, mask)
-            loss = main_loss + 0.4 * aux_loss
+            if args.model_name == "PSPNet":
+                outputs, aux = model(inputs)
+                main_loss = criterion(outputs, mask)
+                aux_loss = criterion(aux, mask)
+                loss = main_loss + 0.4 * aux_loss
+            else:
+                outputs = model(inputs)
+                main_loss = criterion(outputs, mask)
+                loss = main_loss

         loss_value = loss.item()
         if not math.isfinite(loss_value):
@@ -43,7 +51,8 @@ def train_model(args, epoch, model, train_loader, criterion, optimizer, loss_scaler, device):
         torch.cuda.synchronize()

         train_main_loss.update(main_loss.item())
-        train_aux_loss.update(aux_loss.item())
+        if args.model_name == "PSPNet":
+            train_aux_loss.update(aux_loss.item())
         lr.update(optimizer.param_groups[0]['lr'])

         curr_iter += 1
@@ -77,7 +86,7 @@ def evaluation(args, best_record, epoch, model, model_without_ddp, val_loader, c
             best_record['acc_cls'] = acc_cls
             best_record['mean_iou'] = mean_iou
             if args.output_dir:
-                torch.save(model_without_ddp.state_dict(), args.output_dir + str(epoch) + "_epoch_" + args.model_name + ".pt")
+                torch.save(model_without_ddp.state_dict(), args.output_dir + args.model_name + ".pt")

     print('-----------------------------------------------------------------------------------------------------------')
     print('[epoch %d], [val loss %.5f], [acc %.5f], [acc_cls %.5f], [mean_iou %.5f]' % (
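
With the save-path change, the best checkpoint now overwrites a single file named after the model instead of accumulating one file per improving epoch. A minimal loading sketch under that naming; the concrete path is an assumption, since configs.py's --output_dir default is not shown in this diff:

import torch

# Assumed path: args.output_dir + args.model_name + ".pt", with output_dir
# "save_model/" and the new model_name default "Segmenter".
best_state = torch.load("save_model/Segmenter.pt", map_location="cpu")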

Diff for: utils2/lr_sched.py (+1 -1)

@@ -13,7 +13,7 @@ def adjust_learning_rate(optimizer, epoch, args):
         lr = args.lr * epoch / args.warmup_epochs
     else:
         lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
-             (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.num_epoch - args.warmup_epochs)))
+             (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.num_epochs - args.warmup_epochs)))
     for param_group in optimizer.param_groups:
         if "lr_scale" in param_group:
             param_group["lr"] = lr * param_group["lr_scale"]
