forked from allenai/ai2thor-rearrangement
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathone_phase_rgb_ppo.py
39 lines (32 loc) · 1.17 KB
/
one_phase_rgb_ppo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from typing import Dict, Any
from allenact.algorithms.onpolicy_sync.losses import PPO
from allenact.algorithms.onpolicy_sync.losses.ppo import PPOConfig
from allenact.utils.experiment_utils import LinearDecay, PipelineStage
from baseline_configs.one_phase.one_phase_rgb_base import (
OnePhaseRGBBaseExperimentConfig,
)
class OnePhaseRGBPPOExperimentConfig(OnePhaseRGBBaseExperimentConfig):
    """One-phase RGB experiment configuration trained only with a PPO loss.

    Overrides the experiment tag, the number of training processes, and the
    training-pipeline description inherited from
    ``OnePhaseRGBBaseExperimentConfig``.
    """

    # Overrides the base-class value with ``None``. Presumably this disables
    # any pretrained CNN preprocessor -- TODO confirm against the base config.
    CNN_PREPROCESSOR_TYPE_AND_PRETRAINING = None

    @classmethod
    def tag(cls) -> str:
        """Return the identifier string for this experiment."""
        experiment_tag = "OnePhaseRGBPPO"
        return experiment_tag

    @classmethod
    def num_train_processes(cls) -> int:
        """Return the number of training processes (fixed at 40)."""
        process_count = 40
        return process_count

    @classmethod
    def _training_pipeline_info(cls, **kwargs) -> Dict[str, Any]:
        """Describe how the model trains.

        A single pipeline stage optimizes only the ``"ppo_loss"`` entry for
        ``cls.TRAINING_STEPS`` steps. The PPO loss receives a ``LinearDecay``
        over that same number of steps as its ``clip_decay`` and is otherwise
        configured from ``PPOConfig``.

        Args:
            **kwargs: Ignored by this implementation.

        Returns:
            A dict with the named losses, the pipeline stages, and the
            optimization settings ``num_steps``, ``num_mini_batch``,
            ``update_repeats``, ``use_lr_decay`` and ``lr``.
        """
        total_steps = cls.TRAINING_STEPS

        # The clip schedule and the lone stage share the same step budget.
        ppo_objective = PPO(clip_decay=LinearDecay(total_steps), **PPOConfig)
        sole_stage = PipelineStage(
            loss_names=["ppo_loss"], max_stage_steps=total_steps
        )

        return {
            "named_losses": {"ppo_loss": ppo_objective},
            "pipeline_stages": [sole_stage],
            "num_steps": 64,
            "num_mini_batch": 1,
            "update_repeats": 3,
            "use_lr_decay": True,
            "lr": 3e-4,
        }