Skip to content

Commit d0861c3

Browse files
author
Joseph Suarez
committed
Initial contextual breakout; exposed more variables
1 parent 55998af commit d0861c3

File tree

5 files changed

+55
-30
lines changed

5 files changed

+55
-30
lines changed

pufferlib/config/ocean/breakout.ini

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,21 @@ num_envs = 8
1010
[env]
1111
num_envs = 1024
1212
frameskip = 4
13-
13+
width = 576
14+
height = 330
15+
paddle_width = 62
16+
paddle_height = 8
17+
ball_width = 32
18+
ball_height = 32
19+
brick_width = 32
20+
brick_height = 12
21+
brick_rows = 6
22+
brick_cols = 18
23+
initial_ball_speed = 256
24+
max_ball_speed = 448
25+
paddle_speed = 620
26+
continuous = 0
27+
1428
[policy]
1529
hidden_size = 128
1630

pufferlib/ocean/breakout/binding.c

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,17 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
77
env->frameskip = unpack(kwargs, "frameskip");
88
env->width = unpack(kwargs, "width");
99
env->height = unpack(kwargs, "height");
10-
env->paddle_width = unpack(kwargs, "paddle_width");
10+
env->initial_paddle_width = unpack(kwargs, "paddle_width");
1111
env->paddle_height = unpack(kwargs, "paddle_height");
1212
env->ball_width = unpack(kwargs, "ball_width");
1313
env->ball_height = unpack(kwargs, "ball_height");
1414
env->brick_width = unpack(kwargs, "brick_width");
1515
env->brick_height = unpack(kwargs, "brick_height");
1616
env->brick_rows = unpack(kwargs, "brick_rows");
1717
env->brick_cols = unpack(kwargs, "brick_cols");
18+
env->initial_ball_speed = unpack(kwargs, "initial_ball_speed");
19+
env->max_ball_speed = unpack(kwargs, "max_ball_speed");
20+
env->paddle_speed = unpack(kwargs, "paddle_speed");
1821
env->continuous = unpack(kwargs, "continuous");
1922
init(env);
2023
return 0;

pufferlib/ocean/breakout/breakout.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ void demo() {
1919
.brick_height = 12,
2020
.brick_rows = 6,
2121
.brick_cols = 18,
22+
.initial_ball_speed = 256,
23+
.max_ball_speed = 448,
24+
.paddle_speed = 620,
2225
.continuous = 0,
2326
};
2427
allocate(&env);

pufferlib/ocean/breakout/breakout.h

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#define NOOP 0
1010
#define LEFT 1
1111
#define RIGHT 2
12-
#define MAX_BALL_SPEED 448
1312
#define HALF_PADDLE_WIDTH 31
1413
#define Y_OFFSET 50
1514
#define TICK_RATE 1.0f/60.0f
@@ -55,9 +54,13 @@ typedef struct Breakout {
5554
float* brick_y;
5655
float* brick_states;
5756
int balls_fired;
57+
float initial_paddle_width;
5858
float paddle_width;
5959
float paddle_height;
60+
float paddle_speed;
6061
float ball_speed;
62+
float initial_ball_speed;
63+
float max_ball_speed;
6164
int hits;
6265
int width;
6366
int height;
@@ -289,15 +292,13 @@ bool calc_paddle_ball_collisions(Breakout* env, CollisionInfo* collision_info) {
289292
collision_info->brick_index = BRICK_INDEX_PADDLE_COLLISION;
290293

291294
env->hit_brick = false;
292-
float relative_intersection = ((env->ball_x +
293-
env->ball_width / 2) -
294-
env->paddle_x) /
295-
env->paddle_width;
295+
float relative_intersection = (
296+
(env->ball_x + env->ball_width / 2) - env->paddle_x) / env->paddle_width;
296297
float angle = -base_angle + relative_intersection * 2 * base_angle;
297298
env->ball_vx = sin(angle) * env->ball_speed * TICK_RATE;
298299
env->ball_vy = -cos(angle) * env->ball_speed * TICK_RATE;
299300
env->hits += 1;
300-
if (env->hits % 4 == 0 && env->ball_speed < MAX_BALL_SPEED) {
301+
if (env->hits % 4 == 0 && env->ball_speed < env->max_ball_speed) {
301302
env->ball_speed += 64;
302303
}
303304
if (env->score == env->half_max_score) {
@@ -336,12 +337,16 @@ void calc_all_wall_collisions(Breakout* env, CollisionInfo* collision_info) {
336337
// With rare floating point conditions, the ball could escape the bounds.
337338
// Let's handle that explicitly.
338339
void check_wall_bounds(Breakout* env) {
339-
if (env->ball_x < 0)
340-
env->ball_x += MAX_BALL_SPEED * 1.1f * TICK_RATE;
341-
if (env->ball_x > env->width)
342-
env->ball_x -= MAX_BALL_SPEED * 1.1f * TICK_RATE;
343-
if (env->ball_y < 0)
344-
env->ball_y += MAX_BALL_SPEED * 1.1f * TICK_RATE;
340+
float offset = env->max_ball_speed * 1.1f * TICK_RATE;
341+
if (env->ball_x < 0) {
342+
env->ball_x += offset;
343+
}
344+
if (env->ball_x > env->width) {
345+
env->ball_x -= offset;
346+
}
347+
if (env->ball_y < 0) {
348+
env->ball_y += offset;
349+
}
345350
}
346351

347352
void destroy_brick(Breakout* env, int brick_idx) {
@@ -353,7 +358,7 @@ void destroy_brick(Breakout* env, int brick_idx) {
353358
env->rewards[0] += gained_points;
354359

355360
if (brick_idx / env->brick_cols < 3) {
356-
env->ball_speed = MAX_BALL_SPEED;
361+
env->ball_speed = env->max_ball_speed;
357362
}
358363
}
359364

@@ -393,8 +398,8 @@ void reset_round(Breakout* env) {
393398
env->balls_fired = 0;
394399
env->hit_brick = false;
395400
env->hits = 0;
396-
env->ball_speed = 256;
397-
env->paddle_width = 2 * HALF_PADDLE_WIDTH;
401+
env->ball_speed = env->initial_ball_speed;
402+
env->paddle_width = env->initial_paddle_width;
398403

399404
env->paddle_x = env->width / 2.0 - env->paddle_width / 2;
400405
env->paddle_y = env->height - env->paddle_height - 10;
@@ -437,7 +442,7 @@ void step_frame(Breakout* env, float action) {
437442
if (env->continuous){
438443
act = action;
439444
}
440-
env->paddle_x += act * 620 * TICK_RATE;
445+
env->paddle_x += act * env->paddle_speed * TICK_RATE;
441446
if (env->paddle_x <= 0){
442447
env->paddle_x = fmaxf(0, env->paddle_x);
443448
} else {

pufferlib/ocean/breakout/breakout.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
1-
'''High-perf Pong
2-
3-
Inspired from https://gist.github.com/Yttrmin/18ecc3d2d68b407b4be1
4-
& https://jair.org/index.php/jair/article/view/10819/25823
5-
& https://www.youtube.com/watch?v=PSQt5KGv7Vk
6-
'''
7-
81
import numpy as np
92
import gymnasium
103

@@ -17,7 +10,10 @@ def __init__(self, num_envs=1, render_mode=None,
1710
paddle_width=62, paddle_height=8,
1811
ball_width=32, ball_height=32,
1912
brick_width=32, brick_height=12,
20-
brick_rows=6, brick_cols=18, continuous=False, log_interval=128,
13+
brick_rows=6, brick_cols=18,
14+
initial_ball_speed=256, max_ball_speed=448,
15+
paddle_speed=620,
16+
continuous=False, log_interval=128,
2117
buf=None, seed=0):
2218
self.single_observation_space = gymnasium.spaces.Box(low=0, high=1,
2319
shape=(10 + brick_rows*brick_cols,), dtype=np.float32)
@@ -40,10 +36,14 @@ def __init__(self, num_envs=1, render_mode=None,
4036
self.actions = self.actions.astype(np.float32)
4137

4238
self.c_envs = binding.vec_init(self.observations, self.actions, self.rewards,
43-
self.terminals, self.truncations, num_envs, seed, frameskip=frameskip, width=width, height=height,
44-
paddle_width=paddle_width, paddle_height=paddle_height, ball_width=ball_width, ball_height=ball_height,
45-
brick_width=brick_width, brick_height=brick_height, brick_rows=brick_rows,
46-
brick_cols=brick_cols, continuous=continuous
39+
self.terminals, self.truncations, num_envs, seed, frameskip=frameskip,
40+
width=width, height=height, paddle_width=paddle_width,
41+
paddle_height=paddle_height, ball_width=ball_width,
42+
ball_height=ball_height, brick_width=brick_width,
43+
brick_height=brick_height, brick_rows=brick_rows,
44+
brick_cols=brick_cols, initial_ball_speed=initial_ball_speed,
45+
max_ball_speed=max_ball_speed, paddle_speed=paddle_speed,
46+
continuous=continuous
4747
)
4848

4949
def reset(self, seed=0):

0 commit comments

Comments
 (0)