[RLlib] Single agent episode is incompatible with complex observations (when using DQN) #49916

Open
laknath opened this issue Jan 17, 2025 · 0 comments
Labels: bug (Something that is supposed to be working; but isn't) · rllib (RLlib related issues) · triage (Needs triage: priority, bug/not-bug, owning component)

Comments


laknath commented Jan 17, 2025

What happened + What you expected to happen

When using an environment with complex observations (a dictionary observation space) together with RLlib's DQN, the default `SingleAgentEpisode` does not appear to support dict observations:

  File "/Users/workspace/train_dqn.py", line 524, in <module>
    train(args, model=model)
  File "/Users/workspace/train_dqn.py", line 193, in train
    results = algo.train()
              ^^^^^^^^^^^^
  File "/Users/workspace/.venv2/lib/python3.12/site-packages/ray/tune/trainable/trainable.py", line 331, in train
    raise skipped from exception_cause(skipped)
  File "/Users/workspace/.venv2/lib/python3.12/site-packages/ray/tune/trainable/trainable.py", line 328, in train
    result = self.step()
             ^^^^^^^^^^^
  File "/Users/workspace/.venv2/lib/python3.12/site-packages/ray/rllib/algorithms/algorithm.py", line 936, in step
    train_results, train_iter_ctx = self._run_one_training_iteration()
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/workspace/.venv2/lib/python3.12/site-packages/ray/rllib/algorithms/algorithm.py", line 3201, in _run_one_training_iteration
    training_step_return_value = self.training_step()
                                 ^^^^^^^^^^^^^^^^^^^^
  File "/Users/workspace/.venv2/lib/python3.12/site-packages/ray/rllib/algorithms/dqn/dqn.py", line 655, in training_step
    return self._training_step_new_api_stack(with_noise_reset=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/workspace/.venv2/lib/python3.12/site-packages/ray/rllib/algorithms/dqn/dqn.py", line 679, in _training_step_new_api_stack
    self.local_replay_buffer.add(episodes)
  File "/Users/workspace/.venv2/lib/python3.12/site-packages/ray/rllib/utils/replay_buffers/prioritized_episode_buffer.py", line 282, in add
    existing_eps.concat_episode(eps)
  File "/Users/workspace/.venv2/lib/python3.12/site-packages/ray/rllib/env/single_agent_episode.py", line 618, in concat_episode
    assert np.all(other.observations[0] == self.observations[-1])
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
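The assert fails because the episode's observations are dicts of arrays: Python's `dict.__eq__` compares the values with `==` and then coerces each result to `bool`, which is exactly the ambiguity NumPy complains about for multi-element arrays. A minimal NumPy-only sketch (the key names are illustrative, not RLlib's):

```python
import numpy as np

# Two structurally identical dict observations, as a Dict observation
# space would produce (keys made up for illustration).
obs_prev = {"pos": np.array([0.1, 0.2]), "vel": np.array([0.0, 0.0])}
obs_next = {"pos": np.array([0.1, 0.2]), "vel": np.array([0.0, 0.0])}

try:
    # dict equality compares values with `==` and coerces the result to
    # bool; for multi-element arrays that coercion raises ValueError.
    np.all(obs_prev == obs_next)
except ValueError as err:
    print(err)
```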

When the assert on line 618 of `ray/rllib/env/single_agent_episode.py` is removed, a second error surfaces further down the same call path:

Traceback (most recent call last):
  File "/Users/workspace/train_dqn.py", line 524, in <module>
    train(args, model=model)
  File "/Users/workspace/train_dqn.py", line 193, in train
    results = algo.train()
              ^^^^^^^^^^^^
  File "/Users/workspace/.venv2/lib/python3.12/site-packages/ray/tune/trainable/trainable.py", line 331, in train
    raise skipped from exception_cause(skipped)
  File "/Users/workspace/.venv2/lib/python3.12/site-packages/ray/tune/trainable/trainable.py", line 328, in train
    result = self.step()
             ^^^^^^^^^^^
  File "/Users/workspace/.venv2/lib/python3.12/site-packages/ray/rllib/algorithms/algorithm.py", line 936, in step
    train_results, train_iter_ctx = self._run_one_training_iteration()
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/workspace/.venv2/lib/python3.12/site-packages/ray/rllib/algorithms/algorithm.py", line 3201, in _run_one_training_iteration
    training_step_return_value = self.training_step()
                                 ^^^^^^^^^^^^^^^^^^^^
  File "/Users/workspace/.venv2/lib/python3.12/site-packages/ray/rllib/algorithms/dqn/dqn.py", line 655, in training_step
    return self._training_step_new_api_stack(with_noise_reset=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/workspace/.venv2/lib/python3.12/site-packages/ray/rllib/algorithms/dqn/dqn.py", line 679, in _training_step_new_api_stack
    self.local_replay_buffer.add(episodes)
  File "/Users/workspace/.venv2/lib/python3.12/site-packages/ray/rllib/utils/replay_buffers/prioritized_episode_buffer.py", line 282, in add
    existing_eps.concat_episode(eps)
  File "/Users/workspace/.venv2/lib/python3.12/site-packages/ray/rllib/env/single_agent_episode.py", line 626, in concat_episode
    self.observations.extend(other.get_observations())
  File "/Users/workspace/.venv2/lib/python3.12/site-packages/ray/rllib/env/utils/infinite_lookback_buffer.py", line 125, in extend
    self.data = tree.map_structure(
                ^^^^^^^^^^^^^^^^^^^
  File "/Users/workspace/.venv2/lib/python3.12/site-packages/tree/__init__.py", line 426, in map_structure
    assert_same_structure(structures[0], other, check_types=check_types)
  File "/Users/workspace/.venv2/lib/python3.12/site-packages/tree/__init__.py", line 281, in assert_same_structure
    raise type(e)("%s\n"
ValueError: The two structures don't have the same nested structure.

Versions / Dependencies

ray 2.40.0
Python 3.12.7
macOS 15.1.1

Reproduction script

import gymnasium as gym
import gymnasium.spaces as spaces
from dataclasses import dataclass

from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.examples.envs.classes.cartpole_with_dict_observation_space import (
    CartPoleWithDictObservationSpace,
)
from ray.tune.registry import register_env
from ray.rllib.utils.annotations import override
from ray.rllib.utils.torch_utils import flatten_inputs_to_1d_tensor
from ray.rllib.core.columns import Columns
from ray.rllib.core.models.base import Encoder, ENCODER_OUT
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.connectors.env_to_module import FlattenObservations
from ray.rllib.algorithms.dqn.dqn_rainbow_noisy_net_configs import NoisyMLPConfig
from ray.rllib.algorithms.dqn.torch.dqn_rainbow_torch_noisy_net import TorchNoisyMLPEncoder
from ray.rllib.algorithms.dqn.dqn_rainbow_catalog import DQNRainbowCatalog

register_env("CartPoleComplex-v1", lambda _: CartPoleWithDictObservationSpace())

class MaskedMultiInputTorchNoisyMLPEncoder(TorchNoisyMLPEncoder):
    """Noisy-MLP encoder for multi-input (i.e. Dict) observation spaces."""

    def __init__(self, observation_space: gym.Space, config: NoisyMLPConfig) -> None:
        super().__init__(config)
        self.observation_space = observation_space

    @override(TorchNoisyMLPEncoder)
    def _forward(self, inputs: dict, **kwargs) -> dict:
        # Flatten the dict observation into a single 1D tensor before the MLP.
        preprocessed_obs = flatten_inputs_to_1d_tensor(
            inputs[Columns.OBS], self.observation_space
        )
        return {ENCODER_OUT: self.net(preprocessed_obs)}

@dataclass
class MultiInputMLPEncoderConfig(NoisyMLPConfig):
    """Encoder config used to build models for multi-input (i.e. Dict) states."""

    observation_space: gym.Space = None
    model_config_dict: dict = None

    def build(self, framework: str = "torch") -> "Encoder":
        self._validate(framework)

        if framework != "torch":
            raise ValueError(
                f"`NoisyMLPEncoder` is not implemented for framework {framework}."
            )
        return MaskedMultiInputTorchNoisyMLPEncoder(
            observation_space=self.observation_space,
            config=self,
        )

class MultiInputDQNCatalog(DQNRainbowCatalog):
    """The catalog class used to build models for multi input states (i.e. Dictionary)."""

    @classmethod
    def _get_encoder_config(
        cls,
        observation_space: gym.Space,
        **kwargs,
    ):
        if isinstance(observation_space, spaces.Dict):
            # Compute the flattened observation space to size the encoder input.
            flat_obs = FlattenObservations(
                observation_space, None, multi_agent=False
            ).recompute_output_observation_space(None, None)
            return MultiInputMLPEncoderConfig(
                observation_space=observation_space,
                model_config_dict=kwargs.get("model_config_dict", {}),
                input_dims=flat_obs.shape,
            )
        else:
            return super()._get_encoder_config(observation_space, **kwargs)

config = (
    DQNConfig()
    .environment(env="CartPoleComplex-v1")
    .training(
        lr=0.0005 * (1) ** 0.5,
        train_batch_size_per_learner=32,
        replay_buffer_config={
            "type": "PrioritizedEpisodeReplayBuffer",
            "capacity": 50000,
            "alpha": 0.6,
            "beta": 0.4,
        },
        n_step=(2, 5),
        double_q=False,
        dueling=True,
        epsilon=[(0, 1.0), (10000, 0.02)],
    )
    .rl_module(
        rl_module_spec=RLModuleSpec(
            catalog_class=MultiInputDQNCatalog,
            model_config={
                "head_fcnet_hiddens": [64, 64],
                "head_fcnet_activation": "relu",
                "double_q": False,
                "uses_dueling": False,
                "noisy": False,
                "num_atoms": 1,
                "epsilon": 0.1,
            },
        )
    )
)

algo = config.build()
results = algo.train()
print(f"results={results}")
algo.stop()
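As a possible interim workaround (an untested sketch, and an avoidance rather than a fix of the underlying bug): flatten the observations at the environment level with `gymnasium.wrappers.FlattenObservation`, so episodes only ever contain flat `Box` observations and the dict comparison in `concat_episode` is never reached:

```python
import gymnasium as gym
from ray.tune.registry import register_env
from ray.rllib.examples.envs.classes.cartpole_with_dict_observation_space import (
    CartPoleWithDictObservationSpace,
)

# Untested workaround sketch: wrap the env so it emits a flat Box
# observation instead of a dict. Note this loses the per-key structure
# of the observation.
register_env(
    "CartPoleFlat-v1",
    lambda _: gym.wrappers.FlattenObservation(CartPoleWithDictObservationSpace()),
)
```

With a flat observation space, the custom encoder/catalog in the script above would likely be unnecessary, since the default MLP encoder already handles flat `Box` observations.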

Issue Severity

High: It blocks me from completing my task.

@laknath added the bug and triage labels on Jan 17, 2025
@jcotant1 added the rllib label on Jan 21, 2025