-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
64 lines (47 loc) · 2.33 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# train.py
import argparse
import time
from ifcb import DataDirectory
from models.feature_extractor import FeatureExtractor
from models.trainer import ModelTrainer
from utils.constants import IFCB_ASPECT_RATIO, CONTAMINATION, CHUNK_SIZE, N_JOBS, MODEL
def main():
    """Command-line entry point: extract features and train the anomaly model.

    Steps, in order:

    1. Parse command-line arguments (defaults come from ``utils.constants``).
    2. Collect the PIDs to process: one per non-blank line of ``--id-file``
       when given, otherwise the ``lid`` of every bin yielded by
       ``DataDirectory(data_dir)``.
    3. Run ``FeatureExtractor.load_extract_parallel`` over those PIDs.
    4. Train a classifier with ``ModelTrainer.train_classifier`` and hand it
       to ``ModelTrainer.save_model`` (the trainer is constructed with
       ``filepath=args.model``).

    Progress and per-phase wall-clock timings are printed to stdout.
    """
    parser = argparse.ArgumentParser(description='Train anomaly detection model')
    parser.add_argument('data_dir', help='Directory containing point cloud data')
    parser.add_argument('--id-file', help='File containing list of IDs to load')
    parser.add_argument('--n-jobs', type=int, default=N_JOBS, help='Number of parallel jobs for load/extraction phase')
    parser.add_argument('--contamination', type=float, default=CONTAMINATION, help='Expected fraction of anomalous distributions')
    parser.add_argument('--aspect-ratio', type=float, default=IFCB_ASPECT_RATIO, help='Camera frame aspect ratio (width/height)')
    parser.add_argument('--chunk-size', type=int, default=CHUNK_SIZE, help='Number of PIDs to process in each chunk')
    parser.add_argument('--model', default=MODEL, help='Model save/load path')

    args = parser.parse_args()

    # perf_counter() is monotonic, so measured durations cannot be skewed by
    # system clock adjustments during a long run.
    run_start = time.perf_counter()

    # Load PIDs
    if args.id_file:
        with open(args.id_file) as f:
            # Skip blank lines (e.g. a trailing newline at end of file) so an
            # empty string is never passed downstream as a PID.
            pids = [pid for pid in (line.strip() for line in f) if pid]
    else:
        pids = [data_bin.lid for data_bin in DataDirectory(args.data_dir)]

    phase_start = time.perf_counter()
    print(f'Loading and performing feature extraction on {len(pids)} point clouds')

    # Extract features from point clouds
    extractor = FeatureExtractor(aspect_ratio=args.aspect_ratio)
    feature_results = extractor.load_extract_parallel(
        pids, args.data_dir,
        n_jobs=args.n_jobs,
        chunk_size=args.chunk_size
    )

    extract_elapsed = time.perf_counter() - phase_start
    print(f'Extracted features for {len(feature_results)} point clouds in {extract_elapsed:.2f} seconds')

    phase_start = time.perf_counter()

    # Train the classifier
    print('Training classifier')
    trainer = ModelTrainer(filepath=args.model, contamination=args.contamination, n_jobs=args.n_jobs)
    classifier = trainer.train_classifier(feature_results)

    # Measure the training phase itself; previously the extraction duration
    # was re-printed here because the elapsed value was never recomputed.
    train_elapsed = time.perf_counter() - phase_start
    print(f'Trained classifier in {train_elapsed:.2f} seconds')

    # save the classifier
    trainer.save_model(classifier)

    total_elapsed = time.perf_counter() - run_start
    print(f'Total load/extract/train time: {total_elapsed:.2f} seconds')
# Run the training CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()