Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions pufferlib/config/ocean/four_rooms.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Configuration for the Ocean "four_rooms" environment.
# NOTE(review): presumably these values override project-wide defaults
# defined elsewhere — confirm against the config loader.
[base]
package = ocean
env_name = puffer_four_rooms
policy_name = Policy
rnn_name = Recurrent

# Environment settings.
[env]
num_envs = 256

# Training hyperparameters.
[train]
total_timesteps = 10_000_000
gamma = 0.99
learning_rate = 0.015
minibatch_size = 32768
1 change: 1 addition & 0 deletions pufferlib/ocean/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def make_multiagent(buf=None, **kwargs):
'pacman': 'Pacman',
'checkers': 'Checkers',
'asteroids': 'Asteroids',
'four_rooms': 'FourRooms',
'whisker_racer': 'WhiskerRacer',
'spaces': make_spaces,
'multiagent': make_multiagent,
Expand Down
20 changes: 20 additions & 0 deletions pufferlib/ocean/four_rooms/binding.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "four_rooms.h"

#define Env FourRooms
#include "../env_binding.h"

// Initialise a FourRooms environment from the Python keyword arguments.
//
// Reads the required "size" keyword (grid side length), disables
// see-through walls, and allocates the zero-initialised size*size grid that
// stores the full state (OBJECT_IDX values, one byte per cell).
//
// Returns 0 on success. On failure a Python exception is set and -1 is
// returned; env->grid is left NULL in that case.
// NOTE(review): assumes the caller in env_binding.h treats a non-zero return
// as an error and propagates the pending Python exception — confirm.
static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
    env->size = unpack(kwargs, "size");
    env->see_through_walls = 0;
    env->grid = NULL;

    // A non-positive size would make calloc(0, ...) legal-but-useless (it may
    // return NULL) and any later grid indexing out of bounds.
    if (env->size <= 0) {
        PyErr_SetString(PyExc_ValueError, "size must be a positive integer");
        return -1;
    }

    // Compute the cell count in size_t so the multiplication cannot overflow
    // a narrower signed type.
    size_t cells = (size_t)env->size * (size_t)env->size;

    // Allocate grid memory for full state (stores OBJECT_IDX values)
    env->grid = calloc(cells, sizeof *env->grid);
    if (env->grid == NULL) {
        PyErr_NoMemory();
        return -1;
    }
    return 0;
}

// Export the fields of `log` into the Python dictionary `dict`.
//
// Writes four entries, in this order: "perf", "score", "episode_return" and
// "episode_length", each taken from the Log field of the same name.
// Always returns 0.
// NOTE(review): assign_to_dict and the Log struct are declared in headers not
// visible here — confirm the value type each entry is stored as.
static int my_log(PyObject* dict, Log* log) {
assign_to_dict(dict, "perf", log->perf);
assign_to_dict(dict, "score", log->score);
assign_to_dict(dict, "episode_return", log->episode_return);
assign_to_dict(dict, "episode_length", log->episode_length);
return 0;
}
34 changes: 34 additions & 0 deletions pufferlib/ocean/four_rooms/four_rooms.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#include "four_rooms.h"

int main() {
FourRooms env = {};
env.size = 19;
env.observations = (unsigned char*)calloc(7*7*3, sizeof(unsigned char)); // 7x7x3 for MinGrid encoding
env.actions = (int*)calloc(1, sizeof(int));
env.rewards = (float*)calloc(1, sizeof(float));
env.terminals = (unsigned char*)calloc(1, sizeof(unsigned char));
env.grid = (unsigned char*)calloc(env.size * env.size, sizeof(unsigned char));

c_reset(&env);
c_render(&env);
while (!WindowShouldClose()) {
if (IsKeyDown(KEY_LEFT_SHIFT)) {
env.actions[0] = 7; // Invalid action = no-op
if (IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)) env.actions[0] = FORWARD;
if (IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)) env.actions[0] = LEFT;
if (IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)) env.actions[0] = RIGHT;
} else {
env.actions[0] = rand() % 3; // Only use left, right, forward
}
c_step(&env);
c_render(&env);
}
free(env.observations);
free(env.actions);
free(env.rewards);
free(env.terminals);
free(env.grid);
c_close(&env);
return 0;
}

Loading