Skip to content

Commit

Permalink
Merge pull request #29 from epignatelli/dev/walkable
Browse files Browse the repository at this point in the history
Dev/walkable
  • Loading branch information
epignatelli authored Jun 20, 2023
2 parents c37dad5 + 6c21be4 commit d850a0f
Show file tree
Hide file tree
Showing 14 changed files with 144 additions and 48 deletions.
1 change: 0 additions & 1 deletion VERSION.txt

This file was deleted.

4 changes: 0 additions & 4 deletions navix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@
# under the License.


__version__ = open("VERSION.txt", "r").read().strip()
__version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit())


from . import (
actions,
components,
Expand Down
22 changes: 22 additions & 0 deletions navix/_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2023 The Navix Authors.

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


__version__ = "0.2.2"
__version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit())
17 changes: 15 additions & 2 deletions navix/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from __future__ import annotations
from typing import Tuple

import chex
import jax
import jax.numpy as jnp
from jax import Array
Expand All @@ -47,9 +46,23 @@ def _translate(position: Array, direction: Array) -> Array:
return jax.lax.switch(direction, moves, position)


def _move_allowed(state: State, position: Array) -> Array:
# according to the grid
walkable = jnp.equal(state.grid[tuple(position)], 0)
# and not occupied by another non-walkable entity
occupied_keys = jax.vmap(lambda x: jnp.array_equal(x, position))(
state.keys.position
)
occupied_doors = jax.vmap(lambda x: jnp.array_equal(x, position))(
state.doors.position
)
occupied = jnp.any(jnp.concatenate([occupied_keys, occupied_doors]))
return jnp.logical_and(walkable, jnp.logical_not(occupied))


def _move(state: State, direction: Array) -> State:
new_position = _translate(state.player.position, direction)
can_move = jnp.equal(state.grid[tuple(state.player.position)], 0)
can_move = _move_allowed(state, new_position)
new_position = jnp.where(can_move, new_position, state.player.position)
player = state.player.replace(position=new_position)
return state.replace(player=player)
Expand Down
2 changes: 1 addition & 1 deletion navix/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class Pickable(Component):

@property
def tag(self):
return - self.id
return -self.id


class Consumable(Component):
Expand Down
11 changes: 9 additions & 2 deletions navix/environments/keydoor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from navix.environments import Environment
from navix.components import State, Player, Pickable, Consumable, Goal
from navix.environments import Timestep
from navix.grid import two_rooms, random_positions, random_directions, mask_by_coordinates
from navix.grid import (
two_rooms,
random_positions,
random_directions,
mask_by_coordinates,
)


class KeyDoor(Environment):
Expand All @@ -29,7 +34,9 @@ def reset(self, key) -> Timestep:

# spawn the goal in the second room
second_room = jnp.where(first_room_mask, -1, grid)
goal_pos = random_positions(k2, second_room, exclude=jnp.stack([player_pos, key_pos]))
goal_pos = random_positions(
k2, second_room, exclude=jnp.stack([player_pos, key_pos])
)
goals = Goal(position=goal_pos[None])

# add the door
Expand Down
10 changes: 6 additions & 4 deletions navix/graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ def colour_chart(size: int = TILE_SIZE) -> Array:
grid = jnp.zeros((size * len(colours), size * len(colours), 3), dtype=jnp.uint8)
for i, colour in enumerate(colours):
for j, colour in enumerate(colours):
grid = grid.at[
i * size : (i + 1) * size, j * size : (j + 1) * size
].set(colour)
grid = grid.at[i * size : (i + 1) * size, j * size : (j + 1) * size].set(
colour
)
return grid


Expand Down Expand Up @@ -153,7 +153,9 @@ def door_tile(size: int = TILE_SIZE, colour: Array = BROWN) -> Array:
x_0 = TILE_SIZE - TILE_SIZE // 4
y_centre = TILE_SIZE // 2
y_size = TILE_SIZE // 5
door = door.at[y_centre - y_size // 2:y_centre + y_size // 2, x_0:x_0 + 1].set(1)
door = door.at[y_centre - y_size // 2 : y_centre + y_size // 2, x_0 : x_0 + 1].set(
1
)
return colorise_tile(door, colour)


Expand Down
6 changes: 3 additions & 3 deletions navix/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ def random_positions(

# temporarily set excluded positions to 0 probability
mask = jnp.min(
jax.vmap(
lambda position: mask.at[idx_from_coordinates(grid, position)].set(0)
)(exclude),
jax.vmap(lambda position: mask.at[idx_from_coordinates(grid, position)].set(0))(
exclude
),
axis=0,
)

Expand Down
52 changes: 31 additions & 21 deletions navix/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
floor_tile,
wall_tile,
mosaic,
TILE_SIZE
TILE_SIZE,
)


Expand All @@ -56,39 +56,46 @@ def categorical(state: State) -> Array:
)
# place doors
grid = jnp.max(
jax.vmap(lambda door: grid.at[tuple(door.position)].set(door.tag))(
state.doors
),
jax.vmap(lambda door: grid.at[tuple(door.position)].set(door.tag))(state.doors),
axis=0,
)
# place goals
grid = jnp.max(
jax.vmap(lambda goal: grid.at[tuple(goal.position)].set(goal.tag))(state.goals), axis=0
jax.vmap(lambda goal: grid.at[tuple(goal.position)].set(goal.tag))(state.goals),
axis=0,
)
# place player last, always on top
grid = grid.at[tuple(state.player.position)].set(state.player.tag)
return grid


def rgb(state: State) -> Array:
positions = jnp.stack([
*state.keys.position,
*state.doors.position,
*state.goals.position,
state.player.position,
])

tiles = jnp.stack([
*([key_tile()] * len(state.keys.position)),
*([door_tile()] * len(state.doors.position)),
*([diamond_tile()] * len(state.goals.position)),
triangle_east_tile(),
])
positions = jnp.stack(
[
*state.keys.position,
*state.doors.position,
*state.goals.position,
state.player.position,
]
)

tiles = jnp.stack(
[
*([key_tile()] * len(state.keys.position)),
*([door_tile()] * len(state.doors.position)),
*([diamond_tile()] * len(state.goals.position)),
triangle_east_tile(),
]
)

def draw(carry, x):
image = carry
mask, tile = x
mask = jax.image.resize(mask,(mask.shape[0] * TILE_SIZE, mask.shape[1] * TILE_SIZE), method='nearest')
mask = jax.image.resize(
mask,
(mask.shape[0] * TILE_SIZE, mask.shape[1] * TILE_SIZE),
method="nearest",
)
mask = jnp.stack([mask] * tile.shape[-1], axis=-1)
tiled = mosaic(state.grid, tile)
image = jnp.where(mask, tiled, image)
Expand All @@ -99,7 +106,10 @@ def body_fun(carry, x):
mask = jnp.zeros_like(state.grid).at[tuple(position)].set(1)
return draw(carry, (mask, tile))

background = jnp.zeros((state.grid.shape[0] * TILE_SIZE, state.grid.shape[1] * TILE_SIZE, 3), dtype=jnp.uint8)
background = jnp.zeros(
(state.grid.shape[0] * TILE_SIZE, state.grid.shape[1] * TILE_SIZE, 3),
dtype=jnp.uint8,
)

# add floor
floor_mask = jnp.where(state.grid == 0, 1, 0)
Expand All @@ -113,4 +123,4 @@ def body_fun(carry, x):

# add entities
image, _ = jax.lax.scan(body_fun, background, (positions, tiles))
return image # type: ignore
return image # type: ignore
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ bug_tracker = "https://github.com/epignatelli/navix/issues"


[tool.setuptools.dynamic]
version = {file = "VERSION.txt"}
version = {file = "_version.py"}
dependencies = {file = "./requirements.txt"}


Expand Down
21 changes: 16 additions & 5 deletions scripts/release.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
# get current version
SCRIPT=$(readlink -f "$0")
#!/bin/bash

# get current script directory
SCRIPT="$(readlink -f "$0")"
SCRIPT_DIR=$(dirname "$SCRIPT")
echo "Script dir is: $SCRIPT_DIR"
VERSION=$SCRIPT_DIR/../VERSION.txt
echo "Current version is: $(cat $VERSION)"

# get version file
VERSION_FILE="$SCRIPT_DIR/../navix/_version.py"
VERSION_CONTENT="$(cat "$VERSION_FILE")"
echo "Version file found at: $VERSION_FILE and contains:"
echo "$VERSION_CONTENT"

# extract version
VERSION=$(cat navix/_version.py | grep "__version__ = " | cut -d'=' -f2 | sed 's,\",,g' | sed "s,',,g" | sed 's, ,,g')
echo "Current version is:"
echo "$VERSION"

# cd to repo dir
REPO_DIR="$(cd "$(dirname -- "$1")" >/dev/null; pwd -P)/$(basename -- "$1")"
Expand All @@ -18,4 +29,4 @@ git push origin $(cat $VERSION)
gh release create $VERSION

# trigger CD
gh workflow run cd.yml
gh workflow run cd.yml
36 changes: 35 additions & 1 deletion tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,39 @@ def test_rotation():
return


def test_walkable():
height = 6
width = 18
env = nx.environments.KeyDoor(
height=height,
width=width,
max_steps=100,
observation_fn=nx.observations.categorical,
)

key = jax.random.PRNGKey(0)
timestep = env.reset(key)
actions = (
2,
3, # in front of key after this
)
actions_stuck = (
3, # should not be able to move forward
3, # should not be able to move forward
2, # rotate towards the wall
3, # should not be able to move forward
3, # should not be able to move forward
)
for action in actions:
timestep = env.step(timestep, jnp.asarray(action))
print(timestep.state.player.position)

for action in actions_stuck:
next_timestep = env.step(timestep, jnp.asarray(action))
print(timestep.state.player.position)
assert jnp.array_equal(timestep.state.player.position, next_timestep.state.player.position)
timestep = next_timestep

if __name__ == "__main__":
test_rotation()
# test_rotation()
test_walkable()
6 changes: 4 additions & 2 deletions tests/test_observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@


def test_rgb():
env = nx.environments.KeyDoor(height=10, width=5, max_steps=100, observation_fn=nx.observations.rgb)
env = nx.environments.KeyDoor(
height=10, width=5, max_steps=100, observation_fn=nx.observations.rgb
)
key = jax.random.PRNGKey(4)
state = jax.jit(env.reset)(key)
print(state.observation)


if __name__ == "__main__":
test_rgb()
test_rgb()
2 changes: 1 addition & 1 deletion tests/test_terminations.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,4 @@ def f():

if __name__ == "__main__":
test_termination()
test_truncation()
test_truncation()

0 comments on commit d850a0f

Please sign in to comment.