[Feature] Local seed #154

Merged
merged 4 commits on Mar 29, 2025
18 changes: 18 additions & 0 deletions tests/test_vmas.py
@@ -302,3 +302,21 @@ def test_vmas_differentiable(scenario, n_steps=10, n_envs=10):

loss = obs[-1].mean() + rews[-1].mean()
grad = torch.autograd.grad(loss, first_action)


def test_seeding():
env = make_env(scenario="balance", num_envs=2, seed=0)
env.seed(0)
random_obs = env.reset()[0][0, 0]
env.seed(0)
assert random_obs == env.reset()[0][0, 0]
env.seed(0)
torch.manual_seed(1)
assert random_obs == env.reset()[0][0, 0]

torch.manual_seed(0)
random_obs = torch.randn(1)
torch.manual_seed(0)
env.seed(1)
env.reset()
assert random_obs == torch.randn(1)
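The second half of the test checks the user-facing guarantee of this PR: seeding and resetting the environment no longer consumes the global torch RNG stream. A minimal sketch of the same property from a training loop's perspective (illustrative only, assuming the `make_env` helper used by the test and the `env.seed` API added in this PR):

```python
import torch
from vmas import make_env

env = make_env(scenario="balance", num_envs=2, seed=0)

torch.manual_seed(7)
expected = torch.randn(2)          # what an outer training loop would draw

torch.manual_seed(7)
env.seed(456)                      # environment-side seeding is isolated ...
env.reset()
assert torch.equal(expected, torch.randn(2))   # ... so the loop's draws are unchanged
```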
183 changes: 160 additions & 23 deletions vmas/simulator/environment/environment.py
@@ -1,6 +1,7 @@
# Copyright (c) 2022-2024.
# Copyright (c) 2022-2025.
# ProrokLab (https://www.proroklab.org/)
# All rights reserved.
import contextlib
import math
import random
from ctypes import byref
@@ -26,14 +27,41 @@
)


# environment for all agents in the multiagent world
# currently code assumes that no agents will be created/destroyed at runtime!
@contextlib.contextmanager
def local_seed(vmas_random_state):
torch_state = torch.random.get_rng_state()
np_state = np.random.get_state()
py_state = random.getstate()

torch.random.set_rng_state(vmas_random_state[0])
np.random.set_state(vmas_random_state[1])
random.setstate(vmas_random_state[2])
yield
vmas_random_state[0] = torch.random.get_rng_state()
vmas_random_state[1] = np.random.get_state()
vmas_random_state[2] = random.getstate()

torch.random.set_rng_state(torch_state)
np.random.set_state(np_state)
random.setstate(py_state)
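The context manager swaps in VMAS's private RNG state on entry, records the advanced state on exit, and then restores whatever global state the caller had. A small sketch of that isolation property (illustrative only, assuming `local_seed` is importable from `vmas.simulator.environment.environment` after this PR):

```python
import random
import numpy as np
import torch
from vmas.simulator.environment.environment import local_seed  # added in this PR

state = [torch.random.get_rng_state(), np.random.get_state(), random.getstate()]

torch.manual_seed(0)
before = torch.randn(1)            # caller's first draw from the global stream

torch.manual_seed(0)
with local_seed(state):            # caller's global state is saved here ...
    torch.manual_seed(123)         # ... so seeding inside only touches the local stream
    _ = torch.randn(3)
after = torch.randn(1)             # ... and restored here

assert torch.equal(before, after)  # the caller's stream is unaffected
```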


class Environment(TorchVectorizedObject):
"""
The VMAS environment
"""

metadata = {
"render.modes": ["human", "rgb_array"],
"runtime.vectorized": True,
}
vmas_random_state = [
torch.random.get_rng_state(),
np.random.get_state(),
random.getstate(),
]

@local_seed(vmas_random_state)
def __init__(
self,
scenario: BaseScenario,
@@ -68,7 +96,7 @@ def __init__(
self.grad_enabled = grad_enabled
self.terminated_truncated = terminated_truncated

observations = self.reset(seed=seed)
observations = self._reset(seed=seed)

# configure spaces
self.multidiscrete_actions = multidiscrete_actions
@@ -81,6 +109,7 @@ def __init__(
self.visible_display = None
self.text_lines = None

@local_seed(vmas_random_state)
def reset(
self,
seed: Optional[int] = None,
@@ -92,21 +121,112 @@ def reset(
Resets the environment in a vectorized way
Returns observations for all envs and agents
"""
return self._reset(
seed=seed,
return_observations=return_observations,
return_info=return_info,
return_dones=return_dones,
)

@local_seed(vmas_random_state)
def reset_at(
self,
index: int,
return_observations: bool = True,
return_info: bool = False,
return_dones: bool = False,
):
"""
Resets the environment at index
Returns observations for all agents in that environment
"""
return self._reset_at(
index=index,
return_observations=return_observations,
return_info=return_info,
return_dones=return_dones,
)

@local_seed(vmas_random_state)
def get_from_scenario(
self,
get_observations: bool,
get_rewards: bool,
get_infos: bool,
get_dones: bool,
dict_agent_names: Optional[bool] = None,
):
"""
Get the environment data from the scenario

Args:
get_observations (bool): whether to return the observations
get_rewards (bool): whether to return the rewards
get_infos (bool): whether to return the infos
get_dones (bool): whether to return the dones
dict_agent_names (bool, optional): whether to return the information in a dictionary with agent names as keys
or in a list

Returns:
The agents' data

"""
return self._get_from_scenario(
get_observations=get_observations,
get_rewards=get_rewards,
get_infos=get_infos,
get_dones=get_dones,
dict_agent_names=dict_agent_names,
)

@local_seed(vmas_random_state)
def seed(self, seed=None):
"""
Sets the seed for the environment
Args:
seed (int, optional): Seed for the environment. Defaults to None.

"""
return self._seed(seed=seed)

@local_seed(vmas_random_state)
def done(self):
"""
Get the done flags for the scenario.

Returns:
Either terminated, truncated (if self.terminated_truncated==True) or terminated + truncated (if self.terminated_truncated==False)

"""
return self._done()
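Each public entry point above is a thin wrapper: it is decorated with `@local_seed(vmas_random_state)` and delegates to an underscore-prefixed implementation, so internal call chains (for example `_reset` calling `_seed`) run inside a single RNG swap rather than nesting one per method. A stripped-down sketch of the pattern, with illustrative names only (again assuming `local_seed` from this PR is importable):

```python
import random
import numpy as np
import torch
from vmas.simulator.environment.environment import local_seed  # added in this PR

class ToyEnv:
    # shared private RNG state, captured once at class-definition time
    _rng_state = [torch.random.get_rng_state(), np.random.get_state(), random.getstate()]

    @local_seed(_rng_state)          # one swap per public call
    def reset(self, seed=None):
        return self._reset(seed)

    def _reset(self, seed=None):
        if seed is not None:
            self._seed(seed)         # internal call: no second swap
        return torch.rand(3)         # stand-in for scenario randomness

    def _seed(self, seed):
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
```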

def _reset(
self,
seed: Optional[int] = None,
return_observations: bool = True,
return_info: bool = False,
return_dones: bool = False,
):
"""
Resets the environment in a vectorized way
Returns observations for all envs and agents
"""

if seed is not None:
self.seed(seed)
self._seed(seed)
# reset world
self.scenario.env_reset_world_at(env_index=None)
self.steps = torch.zeros(self.num_envs, device=self.device)

result = self.get_from_scenario(
result = self._get_from_scenario(
get_observations=return_observations,
get_infos=return_info,
get_rewards=False,
get_dones=return_dones,
)
return result[0] if result and len(result) == 1 else result

def reset_at(
def _reset_at(
self,
index: int,
return_observations: bool = True,
@@ -121,7 +241,7 @@ def reset_at(
self.scenario.env_reset_world_at(index)
self.steps[index] = 0

result = self.get_from_scenario(
result = self._get_from_scenario(
get_observations=return_observations,
get_infos=return_info,
get_rewards=False,
@@ -130,7 +250,7 @@ def reset_at(

return result[0] if result and len(result) == 1 else result

def get_from_scenario(
def _get_from_scenario(
self,
get_observations: bool,
get_rewards: bool,
@@ -178,35 +298,41 @@ def get_from_scenario(

if self.terminated_truncated:
if get_dones:
terminated, truncated = self.done()
terminated, truncated = self._done()
result = [obs, rewards, terminated, truncated, infos]
else:
if get_dones:
dones = self.done()
dones = self._done()
result = [obs, rewards, dones, infos]

return [data for data in result if data is not None]

def seed(self, seed=None):
def _seed(self, seed=None):
"""
Sets the seed for the environment
Args:
seed (int, optional): Seed for the environment. Defaults to None.

"""
if seed is None:
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
return [seed]

@local_seed(vmas_random_state)
def step(self, actions: Union[List, Dict]):
"""Performs a vectorized step on all sub environments using `actions`.

Args:
actions: Is a list on len 'self.n_agents' of which each element is a torch.Tensor of shape
'(self.num_envs, action_size_of_agent)'.
actions: Is a list on len 'self.n_agents' of which each element is a torch.Tensor of shape '(self.num_envs, action_size_of_agent)'.

Returns:
obs: List on len 'self.n_agents' of which each element is a torch.Tensor
of shape '(self.num_envs, obs_size_of_agent)'
obs: List on len 'self.n_agents' of which each element is a torch.Tensor of shape '(self.num_envs, obs_size_of_agent)'
rewards: List on len 'self.n_agents' of which each element is a torch.Tensor of shape '(self.num_envs)'
dones: Tensor of len 'self.num_envs' of which each element is a bool
infos : List on len 'self.n_agents' of which each element is a dictionary for which each key is a metric
and the value is a tensor of shape '(self.num_envs, metric_size_per_agent)'
infos: List on len 'self.n_agents' of which each element is a dictionary for which each key is a metric and the value is a tensor of shape '(self.num_envs, metric_size_per_agent)'

Examples:
>>> import vmas
@@ -222,6 +348,7 @@ def step(self, actions: Union[List, Dict]):
>>> obs = env.reset()
>>> for _ in range(10):
... obs, rews, dones, info = env.step(env.get_random_actions())

"""
if isinstance(actions, Dict):
actions_dict = actions
@@ -269,14 +396,21 @@ def step(self, actions: Union[List, Dict]):

self.steps += 1

return self.get_from_scenario(
return self._get_from_scenario(
get_observations=True,
get_infos=True,
get_rewards=True,
get_dones=True,
)

def done(self):
def _done(self):
"""
Get the done flags for the scenario.

Returns:
Either terminated, truncated (if self.terminated_truncated==True) or terminated + truncated (if self.terminated_truncated==False)

"""
terminated = self.scenario.done().clone()

if self.max_steps is not None:
@@ -387,6 +521,7 @@ def get_agent_observation_space(self, agent: Agent, obs: AGENT_OBS_TYPE):
f"Invalid type of observation {obs} for agent {agent.name}"
)

@local_seed(vmas_random_state)
def get_random_action(self, agent: Agent) -> torch.Tensor:
"""Returns a random action for the given agent.

@@ -447,7 +582,7 @@ def get_random_action(self, agent: Agent) -> torch.Tensor:
return action

def get_random_actions(self) -> Sequence[torch.Tensor]:
"""Returns random actions for all agents that you can feed to :class:`step`
"""Returns random actions for all agents that you can feed to :meth:`step`

Returns:
Sequence[torch.tensor]: the random actions for the agents
@@ -612,6 +747,7 @@ def _set_action(self, action, agent):
)
agent.action.c += noise

@local_seed(vmas_random_state)
def render(
self,
mode="human",
@@ -635,15 +771,15 @@ def render(
Render function for environment using pyglet

On servers use mode="rgb_array" and set

```
export DISPLAY=':99.0'
Xvfb :99 -screen 0 1400x900x24 > /dev/null 2>&1 &
```

:param mode: One of human or rgb_array
:param env_index: Index of the environment to render
:param agent_index_focus: If specified the camera will stay on the agent with this index.
If None, the camera will stay in the center and zoom out to contain all agents
:param agent_index_focus: If specified the camera will stay on the agent with this index. If None, the camera will stay in the center and zoom out to contain all agents
:param visualize_when_rgb: Also run human visualization when mode=="rgb_array"
:param plot_position_function: A function to plot under the rendering.
The function takes a numpy array with shape (n_points, 2), which represents a set of x,y values to evaluate f over and plot it
@@ -657,6 +793,7 @@ def render(
:param plot_position_function_cmap_range: The range of the cmap in case plot_position_function outputs a single value
:param plot_position_function_cmap_alpha: The alpha of the cmap in case plot_position_function outputs a single value
:return: Rgb array or None, depending on the mode

"""
self._check_batch_index(env_index)
assert (