Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions rlax/_src/multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def lambda_returns(
v_t: Array,
lambda_: Numeric = 1.,
stop_target_gradients: bool = False,
unroll: int | bool = 1,
) -> Array:
"""Estimates a multistep truncated lambda return from a trajectory.

Expand Down Expand Up @@ -93,6 +94,7 @@ def lambda_returns(
lambda_: mixing parameter; a scalar or a vector for timesteps t in [1, T].
stop_target_gradients: bool indicating whether or not to apply stop gradient
to targets.
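unroll: how many scan iterations to unroll per loop step; `True` unrolls the
scan fully and `False` leaves it rolled (forwarded to `jax.lax.scan`).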

Returns:
Multistep lambda returns.
Expand All @@ -111,7 +113,12 @@ def _body(acc, xs):
return acc, acc

_, returns = jax.lax.scan(
_body, v_t[-1], (r_t, discount_t, v_t, lambda_), reverse=True)
_body,
v_t[-1],
(r_t, discount_t, v_t, lambda_),
reverse=True,
unroll=unroll,
)

return jax.lax.select(stop_target_gradients,
jax.lax.stop_gradient(returns),
Expand Down Expand Up @@ -219,6 +226,7 @@ def importance_corrected_td_errors(
lambda_: Array,
values: Array,
stop_target_gradients: bool = False,
unroll: int | bool = 1,
) -> Array:
"""Computes the multistep td errors with per decision importance sampling.

Expand Down Expand Up @@ -246,6 +254,7 @@ def importance_corrected_td_errors(
values: sequence of state values under π for all timesteps t in [0, T].
stop_target_gradients: bool indicating whether or not to apply stop gradient
to targets.
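unroll: how many scan iterations to unroll per loop step; `True` unrolls the
scan fully and `False` leaves it rolled (forwarded to `jax.lax.scan`).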

Returns:
Off-policy estimates of the multistep td errors.
Expand All @@ -269,7 +278,12 @@ def _body(acc, xs):
return acc, acc

_, errors = jax.lax.scan(
_body, 0.0, (one_step_delta, discount_t, rho_t, lambda_), reverse=True)
_body,
0.0,
(one_step_delta, discount_t, rho_t, lambda_),
reverse=True,
unroll=unroll,
)

errors = rho_tm1 * errors
return jax.lax.select(stop_target_gradients,
Expand All @@ -282,6 +296,7 @@ def truncated_generalized_advantage_estimation(
lambda_: Union[Array, Scalar],
values: Array,
stop_target_gradients: bool = False,
unroll: int | bool = 1,
) -> Array:
"""Computes truncated generalized advantage estimates for a sequence length k.

Expand All @@ -303,6 +318,7 @@ def truncated_generalized_advantage_estimation(
values: Sequence of values under π at times [0, k]
stop_target_gradients: bool indicating whether or not to apply stop gradient
to targets.
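unroll: how many scan iterations to unroll per loop step; `True` unrolls the
scan fully and `False` leaves it rolled (forwarded to `jax.lax.scan`).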

Returns:
Multistep truncated generalized advantage estimation at times [0, k-1].
Expand All @@ -320,7 +336,12 @@ def _body(acc, xs):
return acc, acc

_, advantage_t = jax.lax.scan(
_body, 0.0, (delta_t, discount_t, lambda_), reverse=True)
_body,
0.0,
(delta_t, discount_t, lambda_),
reverse=True,
unroll=unroll,
)

return jax.lax.select(stop_target_gradients,
jax.lax.stop_gradient(advantage_t),
Expand Down Expand Up @@ -393,6 +414,7 @@ def general_off_policy_returns_from_q_and_v(
discount_t: Array,
c_t: Array,
stop_target_gradients: bool = False,
unroll: int | bool = 1,
) -> Array:
"""Calculates targets for various off-policy evaluation algorithms.

Expand Down Expand Up @@ -421,6 +443,7 @@ def general_off_policy_returns_from_q_and_v(
c_t: weights at times [1, ..., K - 1].
stop_target_gradients: bool indicating whether or not to apply stop gradient
to targets.
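unroll: how many scan iterations to unroll per loop step; `True` unrolls the
scan fully and `False` leaves it rolled (forwarded to `jax.lax.scan`).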

Returns:
Off-policy estimates of the generalized returns from states visited at times
Expand All @@ -438,7 +461,12 @@ def _body(acc, xs):
return acc, acc

_, returns = jax.lax.scan(
_body, g, (r_t[:-1], discount_t[:-1], c_t, v_t[:-1], q_t), reverse=True)
_body,
g,
(r_t[:-1], discount_t[:-1], c_t, v_t[:-1], q_t),
reverse=True,
unroll=unroll,
)
returns = jnp.concatenate([returns, g[jnp.newaxis]], axis=0)

return jax.lax.select(stop_target_gradients,
Expand Down
21 changes: 18 additions & 3 deletions rlax/_src/vtrace.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def vtrace(
lambda_: Numeric = 1.0,
clip_rho_threshold: float = 1.0,
stop_target_gradients: bool = True,
unroll: int | bool = 1,
) -> Array:
"""Calculates V-Trace errors from importance weights.

Expand All @@ -62,6 +63,7 @@ def vtrace(
lambda_: mixing parameter; a scalar or a vector for timesteps t.
clip_rho_threshold: clip threshold for importance weights.
stop_target_gradients: whether or not to apply stop gradient to targets.
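unroll: how many scan iterations to unroll per loop step; `True` unrolls the
scan fully and `False` leaves it rolled (forwarded to `jax.lax.scan`).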

Returns:
V-Trace error.
Expand All @@ -86,7 +88,12 @@ def _body(acc, xs):
return acc, acc

_, errors = jax.lax.scan(
_body, 0.0, (td_errors, discount_t, c_tm1), reverse=True)
_body,
0.0,
(td_errors, discount_t, c_tm1),
reverse=True,
unroll=unroll,
)

# Return errors, maybe disabling gradient flow through bootstrap targets.
return jax.lax.select(
Expand All @@ -104,7 +111,9 @@ def leaky_vtrace(
alpha_: float = 1.0,
lambda_: Numeric = 1.0,
clip_rho_threshold: float = 1.0,
stop_target_gradients: bool = True):
stop_target_gradients: bool = True,
unroll: int | bool = 1,
):
"""Calculates Leaky V-Trace errors from importance weights.

Leaky-Vtrace is a combination of Importance sampling and V-trace, where the
Expand All @@ -123,6 +132,7 @@ def leaky_vtrace(
lambda_: mixing parameter; a scalar or a vector for timesteps t.
clip_rho_threshold: clip threshold for importance weights.
stop_target_gradients: whether or not to apply stop gradient to targets.
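unroll: how many scan iterations to unroll per loop step; `True` unrolls the
scan fully and `False` leaves it rolled (forwarded to `jax.lax.scan`).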

Returns:
Leaky V-Trace error.
Expand Down Expand Up @@ -150,7 +160,12 @@ def _body(acc, xs):
return acc, acc

_, errors = jax.lax.scan(
_body, 0.0, (td_errors, discount_t, c_tm1), reverse=True)
_body,
0.0,
(td_errors, discount_t, c_tm1),
reverse=True,
unroll=unroll,
)

# Return errors, maybe disabling gradient flow through bootstrap targets.
return jax.lax.select(
Expand Down