
Commit 1236f12

Update train and test script
1 parent 7c2ad96 commit 1236f12

File tree

2 files changed, +9 −7 lines changed

tutorials/ppo/test.py (+5 −3)

```diff
@@ -4,6 +4,7 @@
 
 import gymnasium
 import gymnasium.wrappers.vector.jax_to_torch
+import jax
 import numpy as np
 import torch
 from agent import Agent
@@ -44,15 +45,16 @@
 test_env = gymnasium.wrappers.vector.jax_to_torch.JaxToTorch(norm_test_env, device=device)
 
 # Load checkpoint
-checkpoint = torch.load(Path(__file__).parent / "ppo_checkpoint.pt")
+checkpoint = torch.load(Path(__file__).parent / "ppo_checkpoint.pt", weights_only=True)
 
 # Create agent and load state
 agent = Agent(test_env).to(device)
 agent.load_state_dict(checkpoint["model_state_dict"])
 
 # Set normalization parameters
-norm_test_env.obs_rms.mean = checkpoint["obs_mean"]
-norm_test_env.obs_rms.var = checkpoint["obs_var"]
+jax_device = jax.devices(env_device)[0]
+norm_test_env.obs_rms.mean = jax.dlpack.from_dlpack(checkpoint["obs_mean"], jax_device)
+norm_test_env.obs_rms.var = jax.dlpack.from_dlpack(checkpoint["obs_var"], jax_device)
 
 # Test for 10 episodes
 n_episodes = 10
```
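On the test side, the checkpoint is now loaded with `weights_only=True`, which restricts `torch.load` to tensors and plain containers instead of arbitrary pickled objects, and the stored normalization statistics (saved as torch tensors, see train.py below) are handed back to JAX through DLPack rather than assigned directly. A minimal sketch of that torch-to-JAX hop, assuming a recent JAX where `jax.dlpack.from_dlpack` accepts any `__dlpack__`-capable array plus an optional target device; the literal `"cpu"` stands in for the `env_device` variable used above:

```python
# Sketch: load a tensors-only checkpoint, then move a tensor into JAX via DLPack.
import jax
import jax.dlpack
import torch

# Stand-in for checkpoint["obs_mean"]; with weights_only=True, torch.load
# only unpickles tensors and simple containers such as this dict.
obs_mean = torch.tensor([0.1, 0.2, 0.3])

# torch -> JAX without a NumPy/host round trip; the optional device argument
# places the result on a specific JAX device.
jax_device = jax.devices("cpu")[0]
obs_mean_jax = jax.dlpack.from_dlpack(obs_mean, jax_device)

print(obs_mean_jax.dtype)  # float32
```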

tutorials/ppo/train.py (+4 −4)

```diff
@@ -87,8 +87,8 @@ def save_model(agent, optimizer, train_envs, path):
     save_dict = {"model_state_dict": agent.state_dict(), "optim_state_dict": optimizer.state_dict()}
     # Unwrap the environment to find normalization wrapper
     if (norm_env := unwrap_norm_env(train_envs)) is not None:
-        save_dict["obs_mean"] = norm_env.obs_rms.mean
-        save_dict["obs_var"] = norm_env.obs_rms.var
+        save_dict["obs_mean"] = torch.utils.dlpack.from_dlpack(norm_env.obs_rms.mean)
+        save_dict["obs_var"] = torch.utils.dlpack.from_dlpack(norm_env.obs_rms.var)
     torch.save(save_dict, path)
 
 
@@ -414,10 +414,10 @@ def train_ppo(config: ConfigDict, wandb_log: bool = False):
         "seed": 0,
         "n_eval_envs": 64,
         "n_eval_steps": 1_000,
-        "save_model": False,
+        "save_model": True,
         "eval_interval": 999_000,
         "lr_decay": False,
     }
 )
 
-train_ppo(config, wandb_log=True)
+train_ppo(config, wandb_log=False)
```
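In train.py the conversion runs the other way: the normalization wrapper's running mean and variance live as JAX arrays, and `torch.utils.dlpack.from_dlpack` turns them into torch tensors so the whole `save_dict` can be serialized with `torch.save` and later reloaded under `weights_only=True`. A minimal sketch of the save/load round trip, assuming a recent torch whose `from_dlpack` accepts any `__dlpack__`-capable array; the `jnp.zeros` array stands in for `norm_env.obs_rms.mean`:

```python
# Sketch: JAX -> torch via DLPack before checkpointing.
import jax.numpy as jnp
import torch
import torch.utils.dlpack

obs_mean = jnp.zeros(8)  # stands in for norm_env.obs_rms.mean

# Zero-copy exchange when both frameworks address the same device;
# no detour through NumPy is needed.
obs_mean_t = torch.utils.dlpack.from_dlpack(obs_mean)

torch.save({"obs_mean": obs_mean_t}, "ppo_checkpoint.pt")
restored = torch.load("ppo_checkpoint.pt", weights_only=True)
print(restored["obs_mean"])  # tensor([0., 0., 0., 0., 0., 0., 0., 0.])
```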
