@@ -47,16 +47,9 @@ void add_log(BitFlip *env) {
47
47
}
48
48
49
49
void c_reset (BitFlip * env ) {
50
- // Clear observations
51
50
memset (env -> observations , OFF , env -> size * 3 * sizeof (char ));
52
-
53
- // Clear n_correct
54
51
env -> n_correct = 0 ;
55
-
56
- // Always make the first bit 1 to avoid "free" rounds just by chance
57
52
env -> observations [0 ] = ON ;
58
-
59
- // Initialize target pattern
60
53
for (int i = 1 ; i < env -> size ; i ++ ) {
61
54
env -> observations [i ] = (rand () % 2 == 1 ) ? ON : OFF ;
62
55
@@ -65,32 +58,30 @@ void c_reset(BitFlip *env) {
65
58
env -> n_correct ++ ;
66
59
}
67
60
}
68
-
69
- // Initialize starting position in the middle
70
61
env -> pos = 2 * env -> size + (env -> size - 1 ) / 2 ;
71
62
env -> observations [env -> pos ] = CURSOR ;
72
-
73
- // Clear number of steps
74
63
env -> tick = 0 ;
75
64
}
76
65
77
66
void c_step (BitFlip * env ) {
78
- env -> tick ++ ;
67
+ env -> tick += 1 ;
79
68
80
- env -> observations [env -> pos ] = EMPTY ;
69
+ int action = env -> actions [0 ];
70
+ env -> terminals [0 ] = 0 ;
71
+ env -> rewards [0 ] = 0.0 ;
81
72
82
- if (env -> actions [0 ] == LEFT ) {
83
- env -> pos -- ;
84
- }
73
+ env -> observations [env -> pos ] = EMPTY ;
85
74
86
- if (env -> actions [0 ] == RIGHT ) {
87
- env -> pos ++ ;
75
+ if (action == LEFT ) {
76
+ env -> pos -= 1 ;
77
+ } else if (action == RIGHT ) {
78
+ env -> pos += 1 ;
88
79
}
89
80
90
81
if (env -> tick == 12 * env -> size || env -> pos < 2 * env -> size ||
91
82
env -> pos >= env -> size * 3 ) {
92
- env -> rewards [0 ] = -1.0 ;
93
83
env -> terminals [0 ] = 1 ;
84
+ env -> rewards [0 ] = -1.0 ;
94
85
add_log (env );
95
86
c_reset (env );
96
87
return ;
@@ -101,14 +92,13 @@ void c_step(BitFlip *env) {
101
92
int state_idx = env -> pos - env -> size ;
102
93
int target_idx = env -> pos - 2 * env -> size ;
103
94
104
- // Flip bit
105
- if (env -> actions [0 ] == FLIP ) {
95
+ if (action == FLIP ) {
106
96
env -> observations [state_idx ] ^= 1 ;
107
97
108
98
if (env -> observations [state_idx ] == env -> observations [target_idx ]) {
109
- env -> n_correct ++ ;
99
+ env -> n_correct += 1 ;
110
100
} else {
111
- env -> n_correct -- ;
101
+ env -> n_correct -= 1 ;
112
102
}
113
103
}
114
104
@@ -119,9 +109,6 @@ void c_step(BitFlip *env) {
119
109
c_reset (env );
120
110
return ;
121
111
}
122
-
123
- env -> rewards [0 ] = 0.0 ;
124
- env -> terminals [0 ] = 0 ;
125
112
}
126
113
127
114
void c_render (BitFlip * env ) {
0 commit comments