run_tuning.py
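"""Hyperparameter tuning driver built on OpenTuner.

Searches over learning_rate, gamma, batch_size and n_steps, evaluating each
configuration with ddr_learning_helpers.tuneable.run_tuning and maximising
the score it returns.
"""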
import opentuner
from opentuner import ConfigurationManipulator, EnumParameter
from opentuner import LogFloatParameter
from opentuner import MeasurementInterface
from opentuner import Result
from ddr_learning_helpers.tuneable import run_tuning


class DdrTuner(MeasurementInterface):
    def manipulator(self):
        """
        Define the search space by creating a
        ConfigurationManipulator
        """
        manipulator = ConfigurationManipulator()
        manipulator.add_parameter(
            LogFloatParameter('learning_rate', 1e-5, 1e-1))
        manipulator.add_parameter(EnumParameter('gamma', [0.9, 0.99, 0.9999]))
        manipulator.add_parameter(EnumParameter('batch_size', [32, 64, 256]))
        manipulator.add_parameter(EnumParameter('n_steps', [128]))
        return manipulator
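
    # Illustrative only: a configuration drawn from the search space above is
    # passed to run() as a plain dict, e.g. (values are an assumption, not the
    # output of an actual tuning run):
    #   {'learning_rate': 3e-4, 'gamma': 0.99, 'batch_size': 64, 'n_steps': 128}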
    def run(self, desired_result, input, limit):
        """
        Run training with particular hyperparameters and see how good the
        performance is
        """
        cfg = desired_result.configuration.data
        print("Running with config: ", cfg)
        result = run_tuning(cfg, self.args.config)
        print("Config: ", cfg, "\nResult: ", result)
        # OpenTuner minimises Result.time, so return the negated score to
        # maximise the training result instead.
        return Result(time=-result)
    def save_final_config(self, configuration):
        """called at the end of tuning"""
        print("Optimal hyperparameters written to hyperparams_final.json:",
              configuration.data)
        # save_to_file expects the configuration dict itself and serialises it
        # to JSON, so pass configuration.data rather than an encoded repr.
        self.manipulator().save_to_file(
            configuration.data, 'hyperparams_final.json')


if __name__ == '__main__':
    argparser = opentuner.default_argparser()
    argparser.add_argument('-c', action='store', dest='config',
                           help="Config file to read for the training")
    DdrTuner.main(argparser.parse_args())
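
# Example invocation (a sketch: the config filename is an assumption, while
# --test-limit and --stop-after come from OpenTuner's default_argparser()):
#   python run_tuning.py -c ddr_config.yml --test-limit=50 --stop-after=3600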