@@ -17,13 +17,17 @@
 from models.ssd.config import mobilenetv1_ssd_config
 from models.ssd.data_preprocessing import TrainAugmentation, TestTransform
 
+import config
+
 parser = argparse.ArgumentParser(
     description='Single Shot MultiBox Detector Training With Pytorch')
 
-parser.add_argument('--datasets', nargs='+', help='Dataset directory path')
-parser.add_argument('--validation_dataset', help='Dataset directory path')
+parser.add_argument('--train_images_path', default=config.TRAIN_IMAGES_PATH, help='train_images_path')
+parser.add_argument('--train_xmls_path', default=config.TRAIN_XMLS_PATH, help='train_xmls_path')
+parser.add_argument('--val_images_path', default=config.VAL_IMAGES_PATH, help='val_images_path')
+parser.add_argument('--val_xmls_path', default=config.VAL_XMLS_PATH, help='val_xmls_path')
 
-parser.add_argument('--net', default='mb1-ssd',
+parser.add_argument('--net', default='mb1-ssd-lite',
                     help='The network architecture')
 parser.add_argument('--freeze_base_net', action='store_true',
                     help='Freeze base net layers.')
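The new import config expects a project-level config.py exposing the four dataset paths used as argparse defaults. That module is not part of this diff; a minimal sketch (only the names come from the commit, the values are placeholders):

    # config.py -- minimal sketch; only the constant names are taken
    # from the diff, the path values are placeholders.
    TRAIN_IMAGES_PATH = "data/train/images"
    TRAIN_XMLS_PATH = "data/train/xmls"
    VAL_IMAGES_PATH = "data/val/images"
    VAL_XMLS_PATH = "data/val/xmls"

Note that later in the script the name config is reused for the SSD spec (config.image_size in a hunk below), which presumably rebinds it after the argparse defaults have already been read.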
@@ -64,11 +68,11 @@
                     help='T_max value for Cosine Annealing Scheduler.')
 
 # Train params
-parser.add_argument('--batch_size', default=32, type=int,
+parser.add_argument('--batch_size', default=8, type=int,
                     help='Batch size for training')
 parser.add_argument('--num_epochs', default=120, type=int,
                     help='the number epochs')
-parser.add_argument('--num_workers', default=4, type=int,
+parser.add_argument('--num_workers', default=1, type=int,
                     help='Number of workers used in dataloading')
 parser.add_argument('--validation_epochs', default=5, type=int,
                     help='the number epochs')
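The lowered defaults (batch size 8, a single dataloader worker) look tuned for a small GPU or local debugging; every run can still be tuned from the command line. A hypothetical invocation (the script name and paths are illustrative, not from the commit):

    python train_ssd.py --net mb1-ssd-lite --batch_size 8 --num_workers 1 \
        --train_images_path data/train/images --train_xmls_path data/train/xmls \
        --val_images_path data/val/images --val_xmls_path data/val/xmls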
@@ -96,6 +100,7 @@ def train(loader, net, criterion, optimizer, device, debug_steps=100, epoch=-1):
     running_regression_loss = 0.0
     running_classification_loss = 0.0
     for i, data in enumerate(loader):
+        print(f"i:{i}")
         images, boxes, labels = data
         images = images.to(device)
         boxes = boxes.to(device)
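The added print fires on every batch. Since the script already routes its output through logging (see the logging.info calls in the hunks below), a quieter, filterable variant would be (a suggestion, not part of this commit):

    logging.debug(f"batch index: {i}")  # visible only when the log level is DEBUG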
@@ -168,20 +173,23 @@ def test(loader, net, criterion, device):
     test_transform = TestTransform(config.image_size, config.image_mean, config.image_std)
 
     logging.info("Prepare training datasets.")
-    datasets = []
-
-    for dataset_path in args.datasets:
-        dataset = RATDataset(dataset_path, transform=train_transform, target_transform=target_transform)
-        num_classes = len(dataset.class_names)
-        datasets.append(dataset)
-    train_datasets = ConcatDataset(datasets)
+    # datasets = []
+
+    # for dataset_path in args.datasets:
+    #     dataset = RATDataset(dataset_path, transform=train_transform, target_transform=target_transform)
+    #     num_classes = len(dataset.class_names)
+    #     datasets.append(dataset)
+    # train_datasets = ConcatDataset(datasets)
+    train_datasets = RATDataset(args.train_images_path, args.train_xmls_path, transform=train_transform,
+                                target_transform=target_transform)
+    num_classes = len(train_datasets.class_names)
     logging.info(f"Train dataset size :{len(train_datasets)}")
     logging.info(train_datasets)
     train_loader = DataLoader(train_datasets, args.batch_size,
-                              num_workers=args.num_workers, shuffle=True)
+                              num_workers=args.num_workers, shuffle=True)
 
     logging.info("Prepare Validation datasets.")
-    val_dataset = RATDataset(args.validation_dataset, transform=test_transform,
+    val_dataset = RATDataset(args.val_images_path, args.val_xmls_path, transform=test_transform,
                              target_transform=target_transform, is_test=True)
     logging.info(val_dataset)
     val_loader = DataLoader(val_dataset, args.batch_size,
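RATDataset's new two-path constructor is not shown in this commit; for readers following along, here is a minimal sketch of a dataset that would satisfy the call sites above. The VOC-style one-XML-per-image layout, the ".jpg" extension, and the class list are all assumptions:

    # Sketch only: the real RATDataset lives elsewhere in this repo.
    import os
    import xml.etree.ElementTree as ET

    import cv2
    import numpy as np
    from torch.utils.data import Dataset

    class RATDataset(Dataset):
        def __init__(self, images_path, xmls_path, transform=None,
                     target_transform=None, is_test=False):
            self.images_path = images_path
            self.xmls_path = xmls_path
            self.transform = transform
            self.target_transform = target_transform
            self.is_test = is_test
            # One VOC-style XML annotation per image (assumption).
            self.ids = [os.path.splitext(f)[0]
                        for f in sorted(os.listdir(xmls_path))
                        if f.endswith(".xml")]
            self.class_names = ("BACKGROUND", "rat")  # assumed label set
            self.class_dict = {n: i for i, n in enumerate(self.class_names)}

        def __len__(self):
            return len(self.ids)

        def __getitem__(self, index):
            image_id = self.ids[index]
            image = cv2.imread(os.path.join(self.images_path, image_id + ".jpg"))
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            boxes, labels = [], []
            root = ET.parse(os.path.join(self.xmls_path, image_id + ".xml")).getroot()
            for obj in root.findall("object"):
                box = obj.find("bndbox")
                boxes.append([float(box.find(t).text)
                              for t in ("xmin", "ymin", "xmax", "ymax")])
                labels.append(self.class_dict[obj.find("name").text])
            boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
            labels = np.array(labels, dtype=np.int64)
            # Same transform protocol as TrainAugmentation/TestTransform above.
            if self.transform:
                image, boxes, labels = self.transform(image, boxes, labels)
            if self.target_transform:
                boxes, labels = self.target_transform(boxes, labels)
            return image, boxes, labels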
@@ -279,4 +287,3 @@ def test(loader, net, criterion, device):
         model_path = os.path.join(args.checkpoint_folder, f"{args.net}-Epoch-{epoch}-Loss-{val_loss}.pth")
         net.save(model_path)
         logging.info(f"Saved model {model_path}")
-
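If net.save() writes a plain state_dict (as in the upstream pytorch-ssd code this script appears to derive from; this commit does not show it), a saved checkpoint can be restored with stock PyTorch:

    import torch

    # Assumes net.save() stored a state_dict; adjust if it pickled the module.
    state = torch.load(model_path, map_location="cpu")
    net.load_state_dict(state)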