-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathenvs.py
44 lines (34 loc) · 1.75 KB
/
envs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import gym
import pybullet_envs
# Names of the pybullet environment classes, grouped by the pybullet_envs
# module that defines them. make_env() checks a requested name against each
# group to decide which module to import and how to construct the class.

# Classes looked up in ``pybullet_envs.bullet``.
PYBULLET_ENVS = [
    "CartPoleBulletEnv",
    "CartPoleContinuousBulletEnv",
    "MinitaurBulletEnv",
    "MinitaurBulletDuckEnv",
    "RacecarGymEnv",
    "KukaGymEnv",
    "KukaCamGymEnv",
    "KukaDiverseObjectEnv",
]

# Classes looked up in ``pybullet_envs.deep_mimic.gym_env.deep_mimic_env``.
PYBULLET_ENVS_DEEPMIMIC = [
    "HumanoidDeepMimicBackflipBulletEnv",
    "HumanoidDeepMimicWalkBulletEnv",
]

# Classes looked up in ``pybullet_envs.gym_pendulum_envs``.
PYBULLET_ENVS_PENDULUM = [
    "InvertedPendulumBulletEnv",
    "InvertedDoublePendulumBulletEnv",
    "InvertedPendulumSwingupBulletEnv",
]

# Classes looked up in ``pybullet_envs.gym_manipulator_envs``.
PYBULLET_ENVS_MANIPULATOR = [
    "ReacherBulletEnv",
    "PusherBulletEnv",
    "ThrowerBulletEnv",
    "StrikerBulletEnv",
]

# Classes looked up in ``pybullet_envs.gym_locomotion_envs``.
PYBULLET_ENVS_LOCOMOTION = [
    "Walker2DBulletEnv",
    "HalfCheetahBulletEnv",
    "AntBulletEnv",
    "HopperBulletEnv",
    "HumanoidBulletEnv",
    "HumanoidFlagrunBulletEnv",
]
def make_env(env_name, render=False):
    """Construct the environment named *env_name*.

    Names listed in one of the module-level ``PYBULLET_ENVS*`` groups are
    built by importing the matching ``pybullet_envs`` submodule and
    instantiating the class of that name directly. Any other name is passed
    to ``gym.make``.

    Args:
        env_name: Class name of a pybullet environment (for example
            ``"HopperBulletEnv"``), or any id accepted by ``gym.make``.
        render: Request on-screen rendering. It is forwarded to the pybullet
            constructors. It is NOT forwarded on the ``gym.make`` fallback,
            where it is ignored.

    Returns:
        The constructed environment instance.

    Raises:
        AttributeError: If a name listed in a group is missing from its
            module (the lookup is a plain ``getattr``).
        Whatever ``gym.make`` raises for an unknown id, on the fallback path.
    """
    if env_name in PYBULLET_ENVS:
        from pybullet_envs import bullet
        # Most classes in pybullet_envs.bullet call their rendering flag
        # ``renders``. The Minitaur envs (MinitaurBulletEnv,
        # MinitaurBulletDuckEnv) call it ``render``. Passing ``renders`` to
        # them raises TypeError.
        flag = "render" if env_name.startswith("Minitaur") else "renders"
        return getattr(bullet, env_name)(**{flag: render})

    if env_name in PYBULLET_ENVS_DEEPMIMIC:
        from pybullet_envs.deep_mimic.gym_env import deep_mimic_env
        return getattr(deep_mimic_env, env_name)(renders=render)

    if env_name in PYBULLET_ENVS_PENDULUM:
        from pybullet_envs import gym_pendulum_envs
        env = getattr(gym_pendulum_envs, env_name)()
        if render:
            # These constructors take no rendering flag, so rendering is
            # requested through render(mode='human') right after construction
            # instead.
            env.render(mode='human')
        return env

    if env_name in PYBULLET_ENVS_MANIPULATOR:
        from pybullet_envs import gym_manipulator_envs
        return getattr(gym_manipulator_envs, env_name)(render=render)

    if env_name in PYBULLET_ENVS_LOCOMOTION:
        from pybullet_envs import gym_locomotion_envs
        return getattr(gym_locomotion_envs, env_name)(render=render)

    # Not a known pybullet class name: defer to the gym registry.
    return gym.make(env_name)