-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathrun_avg_ens.py
executable file
·69 lines (54 loc) · 1.78 KB
/
run_avg_ens.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
########################################################################
#
# @author : Emmanouil Sylligardos
# @when : Winter Semester 2022/2023
# @where : LIPADE internship Paris
# @title : MSAD (Model Selection Anomaly Detection)
# @component: root
# @file : run_avg_ens
#
########################################################################
from models.model.avg_ens import Avg_ens
from utils.scores_loader import ScoresLoader
from utils.data_loader import DataLoader
from utils.metrics_loader import MetricsLoader
from utils.config import *
import argparse
import numpy as np
import sys
def create_avg_ens(n_jobs=1):
    '''Fit the 'Avg_ens' model and save its value for every metric

    Reads the metric names, the datasets and the scores, discards every
    entry whose index the scores loader returned as failed, fits the
    ensemble and hands the per-metric values to the metrics loader to be
    written under the name 'AVG_ENS'.

    :param n_jobs: Threads to use in parallel to compute the metrics faster
    '''
    # Names of the metrics to compute
    metrics_io = MetricsLoader(TSB_metrics_path)
    metric_names = metrics_io.get_names()

    # Data of every dataset the loader knows about
    data_io = DataLoader(TSB_data_path)
    dataset_names = data_io.get_dataset_names()
    x, y, fnames = data_io.load(dataset_names)

    # Scores for the loaded files, plus the indexes that failed to load
    scores, failed = ScoresLoader(TSB_scores_path).load(fnames)

    # Drop failed entries from the highest index down, so that a deletion
    # never shifts a position that still has to be removed. Nothing
    # happens when there are no failed indexes.
    # NOTE(review): assumes x, y and fnames are lists of equal length.
    for bad in sorted(failed, reverse=True):
        for seq in (x, y, fnames):
            del seq[bad]

    # Fit the ensemble, then write its values one metric at a time
    values = Avg_ens().fit(y, scores, metric_names, n_jobs=n_jobs)
    for name in metric_names:
        metrics_io.write(values[name], fnames, 'AVG_ENS', name)
if __name__ == "__main__":
    # Command-line entry point: read the number of threads and build the
    # average ensemble with it.
    parser = argparse.ArgumentParser(
        # Same name as the script file, so usage/help output is accurate
        prog='run_avg_ens',
        description="Create the average ensemble model"
    )
    parser.add_argument(
        '-n', '--n_jobs', type=int, default=4,
        help='Threads to use for parallel computation'
    )
    args = parser.parse_args()

    create_avg_ens(n_jobs=args.n_jobs)