# config.yaml (forked from AlignmentResearch/vlmrm)
# RL environment name
env_name: CartPole-v1
# Base path to save logs and checkpoints
base_path: /data/evan_gunter/vlmrm/runs/training_cartpole
# Seed for reproducibility
seed: 42
description: CartPole training using CLIP reward
# Wandb tags
tags:
  - training
  - cartpole
  - CLIP
# CLIP reward settings. The keys below must be nested under `reward`;
# left at column 0, `reward` parses as null and its children become
# top-level keys.
reward:
  name: clip
  # CLIP model name
  pretrained_model: ViT-g-14/laion2b_s34b_b88k
  # CLIP batch size per synchronous inference step.
  # Batch size must be divisible by n_workers (GPU count)
  # so that it can be shared among workers, and must be a divisor
  # of n_envs * episode_length so that all batches can be of the
  # same size (no support for variable batch size as of now.)
  # This is distinct from the SAC `batch_size` later in the file.
  # NOTE(review): presumably n_envs = n_workers * n_envs_per_worker; if so,
  # with n_envs_per_worker: 4 and episode_length: 200 the value 1600 only
  # divides n_envs * episode_length when n_workers is even - confirm.
  batch_size: 1600
  # Alpha value of Baseline CLIP (CO-RELATE)
  alpha: 0.5
  # Description of the goal state
  target_prompts:
    - pole vertically upright on top of the cart
  # Description of the environment
  baseline_prompts:
    - pole and cart
  # Path to pre-saved model weights. When executing multiple runs,
  # mount a volume to this path to avoid downloading the model
  # weights multiple times.
  cache_dir: /data/evan_gunter/vlmrm/.cache_clip
# RL (SAC, per the key comments below) training settings. The keys must be
# nested under `rl`; left at column 0, `rl` parses as null and its children
# become top-level keys.
rl:
  policy_name: MlpPolicy
  # Total number of simulation steps to be collected.
  n_steps: 100000
  # Number of environments per worker (GPU)
  n_envs_per_worker: 4
  # Desired episode length
  episode_length: 200
  # Number of env steps to collect before training
  learning_starts: 100
  # Number of collected env steps between training iterations
  train_freq: 200
  # SAC buffer sample size per gradient step
  batch_size: 64
  # Number of gradient steps (each sampling one batch from the buffer)
  # per training iteration.
  # NOTE(review): reworded from "samples to collect from the buffer per
  # training step" - confirm against the trainer.
  gradient_steps: 1
  # SAC target network update rate
  tau: 0.005
  # SAC discount factor
  gamma: 0.99
  # SAC optimizer learning rate. Written with an explicit decimal point:
  # YAML 1.1 parsers (e.g. PyYAML) load a bare `3e-4` as a string.
  learning_rate: 3.0e-4
# Logging cadence. The keys must be nested under `logging`; left at
# column 0, `logging` parses as null.
logging:
  # Number of env steps between checkpoints
  checkpoint_freq: 800
  # Number of env steps between videos
  video_freq: 800