"""Various wrappers for AEC MO environments."""
from typing import Optional

import numpy as np
from gymnasium.wrappers.normalize import RunningMeanStd
from pettingzoo.utils.wrappers.base import BaseWrapper


class RecordEpisodeStatistics(BaseWrapper):
    """This wrapper will record episode statistics and add them to the info dict at the end of each episode."""

    def __init__(self, env):
        """This wrapper will record episode statistics and add them to the info dict at the end of each episode.

        Args:
            env (AECEnv): The environment to wrap.
        """
        BaseWrapper.__init__(self, env)
        # Rewards are multi-objective vectors; the totals start at 0 and
        # become numpy arrays after the first accumulation.
        self.episode_rewards = {agent: 0 for agent in self.possible_agents}
        self.episode_lengths = {agent: 0 for agent in self.possible_agents}
    def last(self, observe: bool = True):
        """Receives the latest observation from the environment, recording episode statistics."""
        obs, rews, terminated, truncated, infos = super().last(observe=observe)
        # In an AEC environment, last() returns the reward accumulated by the
        # currently selected agent, so credit it to that agent only.
        agent = self.agent_selection
        self.episode_rewards[agent] += rews
        self.episode_lengths[agent] += 1
        if terminated or truncated:
            infos["episode"] = {
                "r": self.episode_rewards,
                "l": self.episode_lengths,
            }
        return obs, rews, terminated, truncated, infos

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """Resets the environment and the episode statistics."""
        super().reset(seed=seed, options=options)
        for agent in self.env.possible_agents:
            self.episode_rewards[agent] = 0
            self.episode_lengths[agent] = 0
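
# Usage sketch (illustrative only, not part of this module): wrapping a
# PettingZoo-style AEC environment that emits multi-objective reward vectors.
# `mo_aec_env` is a placeholder name for any such environment; the loop below
# is the standard PettingZoo agent_iter pattern.
#
#     env = RecordEpisodeStatistics(mo_aec_env)
#     env.reset(seed=42)
#     for agent in env.agent_iter():
#         obs, reward, terminated, truncated, info = env.last()
#         if terminated or truncated:
#             print(info.get("episode"))  # per-agent vector returns and lengths
#             action = None
#         else:
#             action = env.action_space(agent).sample()
#         env.step(action)

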
class LinearizeReward(BaseWrapper):
    """Convert MO reward vector into scalar SO reward value.
