diff --git a/navix/_version.py b/navix/_version.py index 2d0377e..85ee2ad 100644 --- a/navix/_version.py +++ b/navix/_version.py @@ -18,5 +18,5 @@ # under the License. -__version__ = "0.6.15" +__version__ = "0.6.16" __version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit()) diff --git a/navix/actions.py b/navix/actions.py index f7b6163..8897883 100644 --- a/navix/actions.py +++ b/navix/actions.py @@ -70,7 +70,7 @@ def _can_walk_there(state: State, position: Array) -> Tuple[Array, EventsManager obstructs = jnp.logical_and( jnp.logical_not(state.entities[k].walkable), same_position ) - walkable = jnp.logical_and(walkable, jnp.any(jnp.logical_not(obstructs))) + walkable = jnp.logical_and(walkable, jnp.all(jnp.logical_not(obstructs))) return jnp.asarray(walkable, dtype=jnp.bool_), events diff --git a/navix/environments/lava_gap.py b/navix/environments/lava_gap.py index cc67211..f87ae4c 100644 --- a/navix/environments/lava_gap.py +++ b/navix/environments/lava_gap.py @@ -96,36 +96,57 @@ def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Times register_env( "Navix-LavaGapS5-v0", lambda *args, **kwargs: LavaGap.create( *args, - **kwargs, height=5, width=5, observation_fn=kwargs.pop("observation_fn", observations.symbolic), reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), - termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), + termination_fn=kwargs.pop( + "termination_fn", + terminations.compose( + terminations.on_goal_reached, + terminations.on_lava_fall, + ), + ), + # unpacked last, once the pops above have removed the keys passed explicitly + **kwargs, ), ) register_env( "Navix-LavaGapS6-v0", lambda *args, **kwargs: LavaGap.create( *args, - **kwargs, height=6, width=6, observation_fn=kwargs.pop("observation_fn", observations.symbolic), reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), - termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), + termination_fn=kwargs.pop( + "termination_fn", + terminations.compose(
+ terminations.on_goal_reached, + terminations.on_lava_fall, + ), + ), + # unpacked last, once the pops above have removed the keys passed explicitly + **kwargs, ), ) register_env( "Navix-LavaGapS7-v0", lambda *args, **kwargs: LavaGap.create( *args, - **kwargs, height=7, width=7, observation_fn=kwargs.pop("observation_fn", observations.symbolic), reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached), - termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached), + termination_fn=kwargs.pop( + "termination_fn", + terminations.compose( + terminations.on_goal_reached, + terminations.on_lava_fall, + ), + ), + # unpacked last, once the pops above have removed the keys passed explicitly + **kwargs, ), ) diff --git a/tests/test_issues.py b/tests/test_issues.py new file mode 100644 index 0000000..91c9bb4 --- /dev/null +++ b/tests/test_issues.py @@ -0,0 +1,54 @@ +# Copyright 2023 The Navix Authors. + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import jax +import jax.numpy as jnp + +import navix as nx +from navix import observations + + +def test_82(): + """Stepping forward into a wall must leave the player's position unchanged.""" + env = nx.make( + "Navix-DoorKey-5x5-v0", + max_steps=100, + observation_fn=observations.rgb, + ) + key = jax.random.PRNGKey(5) + timestep = env.reset(key) + # Seed 5 is: + # # # # # + # P # . # + # . # .
# + # K D G # + # # # # # + + # start agent direction = EAST + prev_pos = timestep.state.entities["player"].position + # action 2 is forward + timestep = env.step(timestep, 2) # should not walk into wall + pos = timestep.state.entities["player"].position + assert jnp.array_equal(prev_pos, pos) + + +if __name__ == "__main__": + test_82()