training_cross_validate.py
# Author: Michał Bednarek PUT Poznan
import os
import pickle
import sys
from argparse import ArgumentParser

import numpy as np
import tensorflow as tf
from sklearn.model_selection import KFold
from tqdm import tqdm

from functions import *
from net import *


def do_regression(args):
    os.makedirs(args.results, exist_ok=True)

    # load the training pickle, the real-data pickle (optionally added to training) and the test pickles
    with open(args.data_path_train, "rb") as fp:
        total_dataset = pickle.load(fp)
        print("TRAIN NUM SAMPLES: {}".format(len(total_dataset["data"])))
    with open(args.data_path_validation, "rb") as fp:
        validation_dataset = pickle.load(fp)
        print("TO-ADD NUM SAMPLES: {}".format(len(validation_dataset["data"])))
    test_ds_list = list()
    for test_ds_path in args.data_path_test:
        with open(test_ds_path, "rb") as fp:
            test_dataset = pickle.load(fp)
            test_ds_list.append(test_dataset)
            print("TEST NUM SAMPLES: {}".format(len(test_dataset["data"])))

    # start the cross-validated training
    kf = KFold(n_splits=args.num_splits, shuffle=True)
    for split_no, (train_idx, val_idx) in enumerate(kf.split(total_dataset["data"], total_dataset["stiffness"])):

        # save the sample indexes used in this split
        logs_path = os.path.join(args.results, '{}'.format(split_no))
        print("Cross-validation, split no. {}. Saving dataset sample indexes...".format(split_no))
        np.savetxt(logs_path + "{}_split_train_data_samples.txt".format(split_no), train_idx)
        np.savetxt(logs_path + "{}_split_val_data_samples.txt".format(split_no), val_idx)
        print("... saved.")

        # set up the model
        if args.model_type == "conv":
            model = ConvNet(args.batch_size)
        elif args.model_type == "conv_lstm":
            model = ConvLstmNet(args.batch_size)
        elif args.model_type == "conv_bilstm":
            model = ConvBiLstmNet(args.batch_size)
        else:
            model = ConvNet(args.batch_size)
            print("Default ConvNet created.")

        # set up the optimization procedure: Adam with an exponentially decaying learning rate
        eta = tf.Variable(args.lr)
        eta_value = tf.keras.optimizers.schedules.ExponentialDecay(args.lr, 100, 0.99)
        eta.assign(eta_value(0))
        optimizer = tf.keras.optimizers.Adam(eta)
        ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)

        # restore from the latest checkpoint if requested
        if args.restore:
            path = tf.train.latest_checkpoint(logs_path)
            ckpt.restore(path)
        ckpt_man = tf.train.CheckpointManager(ckpt, logs_path, max_to_keep=10)

        # set up the summary writers
        os.makedirs(logs_path, exist_ok=True)
        train_writer = tf.summary.create_file_writer(logs_path + "/train")
        val_writer = tf.summary.create_file_writer(logs_path + "/val")
        test_writer = tf.summary.create_file_writer(logs_path + "/test")

        # wrap the split datasets into tf generators
        train_ds, val_ds, test_ds, train_mean, train_std = create_tf_generators(total_dataset, test_ds_list, train_idx,
                                                                                val_idx, args.batch_size,
                                                                                real_data=validation_dataset,
                                                                                add_real_data=args.add_validation_to_train)

        # start training
        train_step, val_step, test_step = 0, 0, 0
        best_metric = [999999999.0 for _ in range(len(test_ds_list))]
        for _ in tqdm(range(args.epochs)):
            train_step = train(model, train_writer, train_ds, train_mean, train_std, optimizer, train_step,
                               add_noise=args.add_noise)
            val_step, _, _ = validate(model, val_writer, val_ds, train_mean, train_std, val_step)
            test_step, best_metric, save_model = validate(model, test_writer, test_ds, train_mean, train_std,
                                                          test_step, prefix="test", best_metric=best_metric)
            # update the learning rate according to the exponential decay schedule
            eta.assign(eta_value(train_step))

            # save a checkpoint whenever the best test metric improves
            if save_model:
                ckpt_man.save()
                print("Best MAPE model saved.")


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--add-validation-to-train', default=False, action='store_true')
    parser.add_argument('--data-path-train', type=str, default="./data/experiments/real_200_300/test.pickle")
    parser.add_argument('--data-path-validation', type=str, default="./data/experiments/real_200_300/train_200.pickle")
    parser.add_argument('--data-path-test', nargs="+", required=True)
    parser.add_argument('--results', type=str, default="data/logs/real_test")
    parser.add_argument('--restore', default=False, action='store_true')
    parser.add_argument('--restore-dir', type=str, default="")
    parser.add_argument('--model-type', type=str, default="conv_bilstm", choices=['conv', 'conv_lstm', 'conv_bilstm'])
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=100)
    parser.add_argument('--num-splits', type=int, default=5)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--add-noise', default=False, action='store_true')
    args, _ = parser.parse_known_args()
    # argparse already restricts --model-type via choices; keep a guard for clarity
    if args.model_type not in ['conv', 'conv_lstm', 'conv_bilstm']:
        parser.print_help()
        sys.exit(1)

    allow_memory_growth()
    print("ARGUMENTS: {}".format(args))
    do_regression(args)
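
# A minimal example invocation, sketched from the argument defaults above; the
# --data-path-test pickle name is hypothetical and should point at an actual
# test-set file in your dataset layout:
#
#   python training_cross_validate.py \
#       --data-path-train ./data/experiments/real_200_300/test.pickle \
#       --data-path-validation ./data/experiments/real_200_300/train_200.pickle \
#       --data-path-test ./data/experiments/real_200_300/my_test_set.pickle \
#       --model-type conv_bilstm --epochs 100 --batch-size 100 --num-splits 5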