-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_d4rl.py
More file actions
24 lines (18 loc) · 748 Bytes
/
train_d4rl.py
File metadata and controls
24 lines (18 loc) · 748 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import fire
from ireul.algo import algo_select
from ireul.data.d4rl import load_d4rl_buffer
from ireul.evaluation import OnlineCallBackFunction
def run_algo(**kwargs):
    """Select an algorithm, load its D4RL buffer, and run training.

    The keyword arguments are handed to ``algo_select`` as a single dict.
    The observation mean/std returned by ``load_d4rl_buffer`` are written
    into the algorithm config under ``"obs_mean"`` and ``"obs_std"`` before
    the algorithm is initialised, so the init function and trainer both
    see them.

    Args:
        **kwargs: Options forwarded unchanged to ``algo_select``
            (presumably the algorithm name, task and hyper-parameter
            overrides -- confirm against ``ireul.algo``).
    """
    init_fn, trainer_factory, config = algo_select(kwargs)

    # Loading the data also yields the observation statistics; they are
    # stored in the config before anything is built from it.
    buffer, mean, std = load_d4rl_buffer(config)
    config["obs_mean"] = mean
    config["obs_std"] = std

    trainer = trainer_factory(init_fn(config), config)

    # No validation buffer is used: both the callback and the trainer
    # receive ``None`` in its place.
    evaluator = OnlineCallBackFunction()
    evaluator.initialize(train_buffer=buffer, val_buffer=None, config=config)

    trainer.train(buffer, None, callback_fn=evaluator)
if __name__ == "__main__":
    # Expose run_algo as the command-line entry point via python-fire,
    # which turns command-line flags into its keyword arguments.
    fire.Fire(component=run_algo)