Skip to content

Commit 1cc95ad

Browse files
authored
Merge pull request #38 from Farama-Foundation/api/aec-record-statistics
RecordStatisticsWrapper for AEC
2 parents f6202b1 + bb6e3a8 commit 1cc95ad

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

momaland/utils/aec_wrappers.py

+35
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,45 @@
11
"""Various wrappers for AEC MO environments."""
2+
from typing import Optional
23

34
import numpy as np
45
from gymnasium.wrappers.normalize import RunningMeanStd
56
from pettingzoo.utils.wrappers.base import BaseWrapper
67

78

9+
class RecordEpisodeStatistics(BaseWrapper):
    """Record per-agent episode statistics for AEC MO environments.

    Tracks each agent's cumulative (possibly vector-valued) episode reward and
    episode length. When the current agent terminates or truncates, the
    statistics are exposed through ``infos["episode"]`` under the keys
    ``"r"`` (rewards dict) and ``"l"`` (lengths dict).
    """

    def __init__(self, env):
        """Initialize the wrapper and the per-agent statistics accumulators.

        Args:
            env (env): The environment to apply the wrapper
        """
        BaseWrapper.__init__(self, env)
        # Cumulative reward per agent for the current episode. Starting from the
        # scalar 0 works for both scalar and numpy-vector rewards (0 + array broadcasts).
        self.episode_rewards = {agent: 0 for agent in self.possible_agents}
        # Number of steps each agent has taken in the current episode.
        self.episode_lengths = {agent: 0 for agent in self.possible_agents}

    def last(self, observe: bool = True):
        """Receives the latest observation from the environment, recording episode statistics.

        Bug fix: ``last()`` returns the cumulative reward of the *currently
        selected* agent only, so it is accumulated for ``self.agent_selection``
        instead of being credited to every agent (the previous loop also
        over-counted every agent's episode length on each call).
        """
        obs, rews, terminated, truncated, infos = super().last(observe=observe)
        agent = self.agent_selection
        self.episode_rewards[agent] += rews
        self.episode_lengths[agent] += 1
        if terminated or truncated:
            # Expose the running statistics to the caller once this agent's episode ends.
            infos["episode"] = {
                "r": self.episode_rewards,
                "l": self.episode_lengths,
            }
        return obs, rews, terminated, truncated, infos

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """Resets the environment and the episode statistics."""
        super().reset(seed, options)
        for agent in self.env.possible_agents:
            self.episode_rewards[agent] = 0
            self.episode_lengths[agent] = 0

843
class LinearizeReward(BaseWrapper):
944
"""Convert MO reward vector into scalar SO reward value.
1045

momaland/utils/parallel_wrappers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def step(self, actions):
3333
return obs, rews, terminateds, truncateds, infos
3434

3535
def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
36-
"""Resets the environment, recording episode statistics."""
36+
"""Resets the environment and the episode statistics."""
3737
obs, info = super().reset(seed, options)
3838
for agent in self.env.possible_agents:
3939
self.episode_rewards[agent] = 0

0 commit comments

Comments
 (0)