Skip to content

Commit 1dcdba3

Browse files
authored
Merge branch 'PufferAI:3.0' into vision_test
2 parents a3e73a3 + 53d9f32 commit 1dcdba3

File tree

16 files changed

+718
-445
lines changed

16 files changed

+718
-445
lines changed

pufferlib/config/ocean/breakout.ini

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,21 @@ num_envs = 8
1010
[env]
1111
num_envs = 1024
1212
frameskip = 4
13-
13+
width = 576
14+
height = 330
15+
paddle_width = 62
16+
paddle_height = 8
17+
ball_width = 32
18+
ball_height = 32
19+
brick_width = 32
20+
brick_height = 12
21+
brick_rows = 6
22+
brick_cols = 18
23+
initial_ball_speed = 256
24+
max_ball_speed = 448
25+
paddle_speed = 620
26+
continuous = 0
27+
1428
[policy]
1529
hidden_size = 128
1630

pufferlib/config/ocean/cartpole.ini

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ rnn_name = Recurrent
66

77
[env]
88
num_envs = 4096
9+
cart_mass = 1.0
10+
pole_mass = 0.1
11+
pole_length = 0.5
12+
gravity = 9.8
13+
force_mag = 10.0
14+
dt = 0.02
915

1016
[train]
1117
total_timesteps = 20_000_000

pufferlib/ocean/breakout/binding.c

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,17 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
77
env->frameskip = unpack(kwargs, "frameskip");
88
env->width = unpack(kwargs, "width");
99
env->height = unpack(kwargs, "height");
10-
env->paddle_width = unpack(kwargs, "paddle_width");
10+
env->initial_paddle_width = unpack(kwargs, "paddle_width");
1111
env->paddle_height = unpack(kwargs, "paddle_height");
1212
env->ball_width = unpack(kwargs, "ball_width");
1313
env->ball_height = unpack(kwargs, "ball_height");
1414
env->brick_width = unpack(kwargs, "brick_width");
1515
env->brick_height = unpack(kwargs, "brick_height");
1616
env->brick_rows = unpack(kwargs, "brick_rows");
1717
env->brick_cols = unpack(kwargs, "brick_cols");
18+
env->initial_ball_speed = unpack(kwargs, "initial_ball_speed");
19+
env->max_ball_speed = unpack(kwargs, "max_ball_speed");
20+
env->paddle_speed = unpack(kwargs, "paddle_speed");
1821
env->continuous = unpack(kwargs, "continuous");
1922
init(env);
2023
return 0;

pufferlib/ocean/breakout/breakout.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ void demo() {
1919
.brick_height = 12,
2020
.brick_rows = 6,
2121
.brick_cols = 18,
22+
.initial_ball_speed = 256,
23+
.max_ball_speed = 448,
24+
.paddle_speed = 620,
2225
.continuous = 0,
2326
};
2427
allocate(&env);

pufferlib/ocean/breakout/breakout.h

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#define NOOP 0
1010
#define LEFT 1
1111
#define RIGHT 2
12-
#define MAX_BALL_SPEED 448
1312
#define HALF_PADDLE_WIDTH 31
1413
#define Y_OFFSET 50
1514
#define TICK_RATE 1.0f/60.0f
@@ -55,9 +54,13 @@ typedef struct Breakout {
5554
float* brick_y;
5655
float* brick_states;
5756
int balls_fired;
57+
float initial_paddle_width;
5858
float paddle_width;
5959
float paddle_height;
60+
float paddle_speed;
6061
float ball_speed;
62+
float initial_ball_speed;
63+
float max_ball_speed;
6164
int hits;
6265
int width;
6366
int height;
@@ -289,15 +292,13 @@ bool calc_paddle_ball_collisions(Breakout* env, CollisionInfo* collision_info) {
289292
collision_info->brick_index = BRICK_INDEX_PADDLE_COLLISION;
290293

291294
env->hit_brick = false;
292-
float relative_intersection = ((env->ball_x +
293-
env->ball_width / 2) -
294-
env->paddle_x) /
295-
env->paddle_width;
295+
float relative_intersection = (
296+
(env->ball_x + env->ball_width / 2) - env->paddle_x) / env->paddle_width;
296297
float angle = -base_angle + relative_intersection * 2 * base_angle;
297298
env->ball_vx = sin(angle) * env->ball_speed * TICK_RATE;
298299
env->ball_vy = -cos(angle) * env->ball_speed * TICK_RATE;
299300
env->hits += 1;
300-
if (env->hits % 4 == 0 && env->ball_speed < MAX_BALL_SPEED) {
301+
if (env->hits % 4 == 0 && env->ball_speed < env->max_ball_speed) {
301302
env->ball_speed += 64;
302303
}
303304
if (env->score == env->half_max_score) {
@@ -336,12 +337,16 @@ void calc_all_wall_collisions(Breakout* env, CollisionInfo* collision_info) {
336337
// With rare floating point conditions, the ball could escape the bounds.
337338
// Let's handle that explicitly.
338339
void check_wall_bounds(Breakout* env) {
339-
if (env->ball_x < 0)
340-
env->ball_x += MAX_BALL_SPEED * 1.1f * TICK_RATE;
341-
if (env->ball_x > env->width)
342-
env->ball_x -= MAX_BALL_SPEED * 1.1f * TICK_RATE;
343-
if (env->ball_y < 0)
344-
env->ball_y += MAX_BALL_SPEED * 1.1f * TICK_RATE;
340+
float offset = env->max_ball_speed * 1.1f * TICK_RATE;
341+
if (env->ball_x < 0) {
342+
env->ball_x += offset;
343+
}
344+
if (env->ball_x > env->width) {
345+
env->ball_x -= offset;
346+
}
347+
if (env->ball_y < 0) {
348+
env->ball_y += offset;
349+
}
345350
}
346351

347352
void destroy_brick(Breakout* env, int brick_idx) {
@@ -353,7 +358,7 @@ void destroy_brick(Breakout* env, int brick_idx) {
353358
env->rewards[0] += gained_points;
354359

355360
if (brick_idx / env->brick_cols < 3) {
356-
env->ball_speed = MAX_BALL_SPEED;
361+
env->ball_speed = env->max_ball_speed;
357362
}
358363
}
359364

@@ -393,8 +398,8 @@ void reset_round(Breakout* env) {
393398
env->balls_fired = 0;
394399
env->hit_brick = false;
395400
env->hits = 0;
396-
env->ball_speed = 256;
397-
env->paddle_width = 2 * HALF_PADDLE_WIDTH;
401+
env->ball_speed = env->initial_ball_speed;
402+
env->paddle_width = env->initial_paddle_width;
398403

399404
env->paddle_x = env->width / 2.0 - env->paddle_width / 2;
400405
env->paddle_y = env->height - env->paddle_height - 10;
@@ -437,7 +442,7 @@ void step_frame(Breakout* env, float action) {
437442
if (env->continuous){
438443
act = action;
439444
}
440-
env->paddle_x += act * 620 * TICK_RATE;
445+
env->paddle_x += act * env->paddle_speed * TICK_RATE;
441446
if (env->paddle_x <= 0){
442447
env->paddle_x = fmaxf(0, env->paddle_x);
443448
} else {

pufferlib/ocean/breakout/breakout.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
1-
'''High-perf Pong
2-
3-
Inspired from https://gist.github.com/Yttrmin/18ecc3d2d68b407b4be1
4-
& https://jair.org/index.php/jair/article/view/10819/25823
5-
& https://www.youtube.com/watch?v=PSQt5KGv7Vk
6-
'''
7-
81
import numpy as np
92
import gymnasium
103

@@ -17,7 +10,10 @@ def __init__(self, num_envs=1, render_mode=None,
1710
paddle_width=62, paddle_height=8,
1811
ball_width=32, ball_height=32,
1912
brick_width=32, brick_height=12,
20-
brick_rows=6, brick_cols=18, continuous=False, log_interval=128,
13+
brick_rows=6, brick_cols=18,
14+
initial_ball_speed=256, max_ball_speed=448,
15+
paddle_speed=620,
16+
continuous=False, log_interval=128,
2117
buf=None, seed=0):
2218
self.single_observation_space = gymnasium.spaces.Box(low=0, high=1,
2319
shape=(10 + brick_rows*brick_cols,), dtype=np.float32)
@@ -40,10 +36,14 @@ def __init__(self, num_envs=1, render_mode=None,
4036
self.actions = self.actions.astype(np.float32)
4137

4238
self.c_envs = binding.vec_init(self.observations, self.actions, self.rewards,
43-
self.terminals, self.truncations, num_envs, seed, frameskip=frameskip, width=width, height=height,
44-
paddle_width=paddle_width, paddle_height=paddle_height, ball_width=ball_width, ball_height=ball_height,
45-
brick_width=brick_width, brick_height=brick_height, brick_rows=brick_rows,
46-
brick_cols=brick_cols, continuous=continuous
39+
self.terminals, self.truncations, num_envs, seed, frameskip=frameskip,
40+
width=width, height=height, paddle_width=paddle_width,
41+
paddle_height=paddle_height, ball_width=ball_width,
42+
ball_height=ball_height, brick_width=brick_width,
43+
brick_height=brick_height, brick_rows=brick_rows,
44+
brick_cols=brick_cols, initial_ball_speed=initial_ball_speed,
45+
max_ball_speed=max_ball_speed, paddle_speed=paddle_speed,
46+
continuous=continuous
4747
)
4848

4949
def reset(self, seed=0):

pufferlib/ocean/cartpole/binding.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33
#include "../env_binding.h"
44

55
static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
6+
env->cart_mass = unpack(kwargs, "cart_mass");
7+
env->pole_mass = unpack(kwargs, "pole_mass");
8+
env->pole_length = unpack(kwargs, "pole_length");
9+
env->gravity = unpack(kwargs, "gravity");
10+
env->force_mag = unpack(kwargs, "force_mag");
11+
env->tau = unpack(kwargs, "dt");
612
env->continuous = unpack(kwargs, "continuous");
713
init(env);
814
return 0;

pufferlib/ocean/cartpole/cartpole.h

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,6 @@
66
#include <time.h>
77
#include "raylib.h"
88

9-
#define GRAVITY 9.8f
10-
#define MASSCART 1.0f
11-
#define MASSPOLE 0.1f
12-
#define TOTAL_MASS (MASSPOLE + MASSCART)
13-
#define LENGTH 0.5f // half pole length
14-
#define POLEMASS_LENGTH (MASSPOLE * LENGTH)
15-
#define FORCE_MAG 10.0f
16-
#define TAU 0.02f // timestep duration
17-
189
#define X_THRESHOLD 2.4f
1910
#define THETA_THRESHOLD_RADIANS (12 * 2 * M_PI / 360)
2011
#define MAX_STEPS 200
@@ -51,6 +42,12 @@ struct Cartpole {
5142
float theta;
5243
float theta_dot;
5344
int tick;
45+
float cart_mass;
46+
float pole_mass;
47+
float pole_length;
48+
float gravity;
49+
float force_mag;
50+
float tau;
5451
int continuous;
5552
float episode_return;
5653
};
@@ -172,25 +169,29 @@ void c_step(Cartpole* env) {
172169
}
173170
/* ========================================================== */
174171

175-
if (!isfinite(a)) a = 0.0f;
172+
if (!isfinite(a)) {
173+
a = 0.0f;
174+
}
176175
a = fminf(fmaxf(a, -1.0f), 1.0f);
177176
env->actions[0] = a;
178177

179-
float force = env->continuous ? a * FORCE_MAG
180-
: (a > 0.5f ? FORCE_MAG : -FORCE_MAG);
178+
float force = env->continuous ? a * env->force_mag
179+
: (a > 0.5f ? env->force_mag: -env->force_mag);
181180

182181
float costheta = cosf(env->theta);
183182
float sintheta = sinf(env->theta);
184183

185-
float temp = (force + POLEMASS_LENGTH * env->theta_dot * env->theta_dot * sintheta) / TOTAL_MASS;
186-
float thetaacc = (GRAVITY * sintheta - costheta * temp) /
187-
(LENGTH * (4.0f / 3.0f - MASSPOLE * costheta * costheta / TOTAL_MASS));
188-
float xacc = temp - POLEMASS_LENGTH * thetaacc * costheta / TOTAL_MASS;
189-
190-
env->x += TAU * env->x_dot;
191-
env->x_dot += TAU * xacc;
192-
env->theta += TAU * env->theta_dot;
193-
env->theta_dot += TAU * thetaacc;
184+
float total_mass = env->cart_mass + env->pole_mass;
185+
float polemass_length = total_mass + env->pole_mass;
186+
float temp = (force + polemass_length * env->theta_dot * env->theta_dot * sintheta) / total_mass;
187+
float thetaacc = (env->gravity * sintheta - costheta * temp) /
188+
(env->pole_length * (4.0f / 3.0f - total_mass * costheta * costheta / total_mass));
189+
float xacc = temp - polemass_length * thetaacc * costheta / total_mass;
190+
191+
env->x += env->tau * env->x_dot;
192+
env->x_dot += env->tau * xacc;
193+
env->theta += env->tau * env->theta_dot;
194+
env->theta_dot += env->tau * thetaacc;
194195

195196
env->tick += 1;
196197

pufferlib/ocean/cartpole/cartpole.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
from pufferlib.ocean.cartpole import binding
55

66
class Cartpole(pufferlib.PufferEnv):
7-
def __init__(self, num_envs=1, render_mode='human', report_interval=1, continuous=False, buf=None, seed=0):
7+
def __init__(self, num_envs=1, cart_mass=1.0, pole_mass=0.1,
8+
pole_length=0.5, gravity=9.8, force_mag=10.0, dt=0.02,
9+
render_mode='human', report_interval=1, continuous=False,
10+
buf=None, seed=0):
811
self.render_mode = render_mode
912
self.num_agents = num_envs
1013
self.report_interval = report_interval
@@ -35,6 +38,12 @@ def __init__(self, num_envs=1, render_mode='human', report_interval=1, continuou
3538
self.truncations,
3639
num_envs,
3740
seed,
41+
cart_mass=cart_mass,
42+
pole_mass=pole_mass,
43+
pole_length=pole_length,
44+
gravity=gravity,
45+
force_mag=force_mag,
46+
dt=dt,
3847
continuous=int(self.continuous),
3948
)
4049

pufferlib/ocean/drone_race/binding.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ static int my_init(Env *env, PyObject *args, PyObject *kwargs) {
1313
static int my_log(PyObject *dict, Log *log) {
1414
assign_to_dict(dict, "perf", log->perf);
1515
assign_to_dict(dict, "score", log->score);
16+
assign_to_dict(dict, "collision_rate", log->collision_rate);
17+
assign_to_dict(dict, "oob", log->oob);
18+
assign_to_dict(dict, "timeout", log->timeout);
1619
assign_to_dict(dict, "episode_return", log->episode_return);
1720
assign_to_dict(dict, "episode_length", log->episode_length);
1821
assign_to_dict(dict, "n", log->n);

0 commit comments

Comments
 (0)