Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions bsuite/environments/catch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
# ============================================================================
"""Catch reinforcement learning environment."""
import warnings

from typing import Optional

Expand Down Expand Up @@ -65,6 +66,13 @@ def __init__(self,
self._total_regret = 0.
self.bsuite_num_episodes = sweep.NUM_EPISODES

def _get_observation(self):
    """Renders the current ball and paddle positions onto the board.

    The internal board buffer is zeroed, the ball cell and the paddle cell
    are each set to 1, and a copy of the buffer is returned so the caller
    never holds a reference to the internal array.

    Returns:
      A copy of the freshly rendered board.
    """
    board = self._board
    board.fill(0.)
    occupied_cells = (
        (self._ball_y, self._ball_x),
        (self._paddle_y, self._paddle_x),
    )
    for row, col in occupied_cells:
        board[row, col] = 1.
    return board.copy()

def _reset(self) -> dm_env.TimeStep:
"""Returns the first `TimeStep` of a new episode."""
self._reset_next_step = False
Expand Down Expand Up @@ -107,11 +115,10 @@ def action_spec(self) -> specs.DiscreteArray:
dtype=np.int, num_values=len(_ACTIONS), name="action")

def _observation(self) -> np.ndarray:
    """Deprecated alias for `_get_observation`.

    Emits a `DeprecationWarning` and then delegates to
    `self._get_observation()`, so the board-rendering logic lives in one
    place only.

    Returns:
      Whatever `self._get_observation()` returns.
    """
    warnings.warn(
        "Deprecated method `_observation`, use `_get_observation` instead.",
        DeprecationWarning,
        # Attribute the warning to the caller of `_observation`, not to this
        # shim, so the offending call site is what gets reported.
        stacklevel=2,
    )
    return self._get_observation()

def bsuite_info(self):
    """Returns a dict holding the regret accumulated so far.

    Returns:
      A dict with the single key `total_regret`, mapped to the current value
      of `self._total_regret`.
    """
    info = {"total_regret": self._total_regret}
    return info