
Commit fe8b3f5

suryabhupa authored and RLaxDev committed
rlax: Upstream Muesli utilities to rlax.
We now provide methods for constructing the clipped MPO (CMPO) policy targets used as part of the Muesli agent loss. These CMPO targets are, in expectation, proportional to:

`prior(a|s) * exp(clip(norm(Q(s, a))))`

where the prior is computed by the actor policy head, and the Q values are computed using the learned model's reward and value heads. See "Muesli: Combining Improvements in Policy Optimization" by Hessel et al. (https://arxiv.org/pdf/2104.06159.pdf) for more details.

PiperOrigin-RevId: 493987878
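For intuition, here is a minimal JAX sketch of the weighting described above, assuming a small enumerable action set; the names (`prior_logits`, `qvalues`, `baseline_value`) and the simple standard-deviation normalization are illustrative stand-ins, not the rlax API itself:

import jax
import jax.numpy as jnp

def cmpo_weights_sketch(prior_logits, qvalues, baseline_value,
                        max_adv=1.0, kl_weight=1.0):
  # prior_logits, qvalues: [num_actions, B]; baseline_value: [B].
  advantages = qvalues - baseline_value
  # Illustrative normalization by the advantage scale; an agent would
  # plug in its own advantage normalizer here.
  norm_adv = advantages / (jnp.std(advantages) + 1e-8)
  clipped = jnp.clip(norm_adv, -max_adv, max_adv)
  # Unnormalized targets: prior(a|s) * exp(clip(norm(Q(s, a)))).
  log_prior = jax.nn.log_softmax(prior_logits, axis=0)
  log_weights = log_prior + clipped / kl_weight
  # Self-normalize over actions; the weights then sum to num_actions.
  num_actions = qvalues.shape[0]
  return num_actions * jax.nn.softmax(log_weights, axis=0)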
1 parent 44ef3f0 commit fe8b3f5

File tree

docs/api.rst
rlax/__init__.py
rlax/_src/policy_targets.py

3 files changed: +207 -1 lines changed


docs/api.rst

Lines changed: 14 additions & 0 deletions
@@ -229,6 +229,7 @@ Policy Optimization
 .. autosummary::
 
     clipped_surrogate_pg_loss
+    cmpo_policy_targets
     constant_policy_targets
     dpg_loss
     entropy_loss
@@ -238,6 +239,7 @@ Policy Optimization
     qpg_loss
     rm_loss
     rpg_loss
+    sampled_cmpo_policy_targets
     sampled_policy_distillation_loss
     zero_policy_targets
 
@@ -247,6 +249,18 @@ Clipped Surrogate PG Loss
 .. autofunction:: clipped_surrogate_pg_loss
 
 
+CMPO Policy Targets
+~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: cmpo_policy_targets
+
+
+Sampled CMPO Policy Targets
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: sampled_cmpo_policy_targets
+
+
 Compute Parametric KL Penalty and Dual Loss
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

rlax/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -88,8 +88,10 @@
 from rlax._src.policy_gradients import qpg_loss
 from rlax._src.policy_gradients import rm_loss
 from rlax._src.policy_gradients import rpg_loss
+from rlax._src.policy_targets import cmpo_policy_targets
 from rlax._src.policy_targets import constant_policy_targets
 from rlax._src.policy_targets import PolicyTarget
+from rlax._src.policy_targets import sampled_cmpo_policy_targets
 from rlax._src.policy_targets import sampled_policy_distillation_loss
 from rlax._src.policy_targets import zero_policy_targets
 from rlax._src.pop_art import art
@@ -159,6 +161,7 @@
     "categorical_td_learning",
     "clip_gradient",
     "clipped_surrogate_pg_loss",
+    "cmpo_policy_targets",
     "compose_tx",
     "conditional_update",
     "constant_policy_targets",
@@ -230,6 +233,7 @@
     "rpg_loss",
     "sample_start_indices",
     "sampled_policy_distillation_loss",
+    "sampled_cmpo_policy_targets",
     "sarsa",
     "sarsa_lambda",
     "sigmoid",

rlax/_src/policy_targets.py

Lines changed: 189 additions & 1 deletion
@@ -12,14 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Utilities to construct and learn from policy targets."""
+"""Construct and learn from policy targets. Used by Muesli-based agents."""
 
 import functools
 
 import chex
 import distrax
 import jax
 import jax.numpy as jnp
+from rlax._src import base
 
 
 @chex.dataclass(frozen=True)
@@ -106,3 +107,190 @@ def sampled_policy_distillation_loss(
   # We average over the samples, over time and batch, and if the actions are
   # a continuous vector also over the actions.
   return -jnp.mean(weights * jnp.maximum(log_probs, min_logp))
+
+
+def cmpo_policy_targets(
+    prior_distribution,
+    embeddings,
+    rng_key,
+    baseline_value,
+    q_provider,
+    advantage_normalizer,
+    *,
+    num_actions,
+    min_target_advantage=-jnp.inf,
+    max_target_advantage=1.0,
+    kl_weight=1.0,
+) -> PolicyTarget:
+  """Policy targets for Clipped MPO.
+
+  The policy targets are in expectation proportional to:
+  `prior(a|s) * exp(clip(norm(Q(s, a))))`
+
+  See "Muesli: Combining Improvements in Policy Optimization" by Hessel et al.
+  (https://arxiv.org/pdf/2104.06159.pdf).
+
+  Args:
+    prior_distribution: the prior policy distribution.
+    embeddings: embeddings for the `q_provider`.
+    rng_key: a JAX pseudo random number generator key.
+    baseline_value: the baseline for `advantage_normalizer`.
+    q_provider: a fn to compute q values.
+    advantage_normalizer: a fn to normalise advantages.
+    *,
+    num_actions: The total number of discrete actions.
+    min_target_advantage: The minimum advantage of a policy target.
+    max_target_advantage: The max advantage of a policy target.
+    kl_weight: The coefficient for the KL regularizer.
+
+  Returns:
+    the clipped MPO policy targets.
+  """
+  # Expecting shape [B].
+  chex.assert_rank(baseline_value, 1)
+  rng_key, query_rng_key = jax.random.split(rng_key)
+  del rng_key
+
+  # Producing all actions with shape [num_actions, B].
+  batch_size, = baseline_value.shape
+  actions = jnp.broadcast_to(
+      jnp.expand_dims(jnp.arange(num_actions, dtype=jnp.int32), axis=-1),
+      [num_actions, batch_size])
+
+  # Using vmap over the num_actions in axis=0.
+  def _query_q(actions):
+    return q_provider(
+        # Using the same rng_key for all the action samples.
+        rng_key=query_rng_key,
+        action=actions,
+        embeddings=embeddings)
+  qvalues = jax.vmap(_query_q)(actions)
+
+  # Using the same advantage normalization as for policy gradients.
+  raw_advantage = advantage_normalizer(
+      returns=qvalues, baseline_value=baseline_value)
+  clipped_advantage = jnp.clip(
+      raw_advantage, min_target_advantage,
+      max_target_advantage)
+
+  # Construct and normalise the weights.
+  log_prior = prior_distribution.log_prob(actions)
+  weights = softmax_policy_target_normalizer(
+      log_prior + clipped_advantage / kl_weight)
+  policy_targets = PolicyTarget(actions=actions, weights=weights)
+  return policy_targets
+
+
+def sampled_cmpo_policy_targets(
+    prior_distribution,
+    embeddings,
+    rng_key,
+    baseline_value,
+    q_provider,
+    advantage_normalizer,
+    *,
+    num_actions=2,
+    min_target_advantage=-jnp.inf,
+    max_target_advantage=1.0,
+    kl_weight=1.0,
+) -> PolicyTarget:
+  """Policy targets for sampled CMPO.
+
+  As in CMPO, the policy targets are in expectation proportional to:
+  `prior(a|s) * exp(clip(norm(Q(s, a))))`
+  However, we only sample a subset of the actions; this allows scaling to
+  large discrete action spaces and to continuous actions.
+
+  See "Muesli: Combining Improvements in Policy Optimization" by Hessel et al.
+  (https://arxiv.org/pdf/2104.06159.pdf).
+
+  Args:
+    prior_distribution: the prior policy distribution.
+    embeddings: embeddings for the `q_provider`.
+    rng_key: a JAX pseudo random number generator key.
+    baseline_value: the baseline for `advantage_normalizer`.
+    q_provider: a fn to compute q values.
+    advantage_normalizer: a fn to normalise advantages.
+    *,
+    num_actions: The number of actions to expand on each step.
+    min_target_advantage: The minimum advantage of a policy target.
+    max_target_advantage: The max advantage of a policy target.
+    kl_weight: The coefficient for the KL regularizer.
+
+  Returns:
+    the sampled clipped MPO policy targets.
+  """
+  # Expecting shape [B].
+  chex.assert_rank(baseline_value, 1)
+  query_rng_key, action_key = jax.random.split(rng_key)
+  del rng_key
+
+  # Sampling the actions from the prior.
+  actions = prior_distribution.sample(
+      seed=action_key, sample_shape=[num_actions])
+
+  # Using vmap over the num_expanded in axis=0.
+  def _query_q(actions):
+    return q_provider(
+        # Using the same rng_key for all the action samples.
+        rng_key=query_rng_key,
+        action=actions,
+        embeddings=embeddings)
+  qvalues = jax.vmap(_query_q)(actions)
+
+  # Using the same advantage normalization as for policy gradients.
+  raw_advantage = advantage_normalizer(
+      returns=qvalues, baseline_value=baseline_value)
+  clipped_advantage = jnp.clip(
+      raw_advantage, min_target_advantage, max_target_advantage)
+
+  # The expected normalized weight would be 1.0. The weights would be
+  # normalized if the baseline_value were the log of the expected weight,
+  # i.e. if the baseline_value were log(sum_a(prior(a|s) * exp(Q(s, a)/c))).
+  weights = jnp.exp(clipped_advantage / kl_weight)
+
+  # The weights are tiled, if using multiple continuous actions.
+  # It is OK to use multiple continuous actions inside the Q(s, a),
+  # because the action is sampled from the joint distribution
+  # and the weight is not based on non-joint probabilities.
+  log_prob = prior_distribution.log_prob(actions)
+  weights = jnp.broadcast_to(
+      base.lhs_broadcast(weights, log_prob), log_prob.shape)
+  return PolicyTarget(actions=actions, weights=weights)
+
+
+def softmax_policy_target_normalizer(log_weights):
+  """Returns self-normalized weights.
+
+  The self-normalizing weights introduce a significant bias
+  if the average weight is computed from a small number of samples.
+
+  Args:
+    log_weights: log unnormalized weights, shape `[num_targets, ...]`.
+
+  Returns:
+    Weights divided by the average weight of the sample; weights sum to `num_targets`.
+  """
+  num_targets = log_weights.shape[0]
+  return num_targets * jax.nn.softmax(log_weights, axis=0)
+
+
+def loo_policy_target_normalizer(log_weights):
+  """A leave-one-out normalizer.
+
+  Args:
+    log_weights: log unnormalized weights, shape `[num_targets, ...]`.
+
+  Returns:
+    Weights divided by a consistent estimate of the average weight. The weights
+    are not guaranteed to sum to `num_targets`.
+  """
+  num_targets = log_weights.shape[0]
+  weights = jnp.exp(log_weights)
+  # Using a safe consistent estimator of the average weight, independently of
+  # the numerator.
+  # The unnormalized weights are already approximately normalized by a
+  # baseline_value, so we use `1` as the initial estimate of the average weight.
+  avg_weight = (
+      1 + jnp.sum(weights, axis=0, keepdims=True) - weights) / num_targets
+  return weights / avg_weight
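To show how the new entry points fit together, here is a minimal usage sketch of `sampled_cmpo_policy_targets` with a `distrax` Gaussian prior; the toy `q_provider` and `advantage_normalizer` callables below are placeholder stand-ins, not part of rlax:

import distrax
import jax
import jax.numpy as jnp
import rlax

batch_size, action_dim = 4, 2
prior = distrax.MultivariateNormalDiag(
    loc=jnp.zeros([batch_size, action_dim]),
    scale_diag=jnp.ones([batch_size, action_dim]))
embeddings = jnp.ones([batch_size, action_dim])
baseline_value = jnp.zeros([batch_size])

def q_provider(rng_key, action, embeddings):
  # Placeholder Q head: a dot product between the action and the embedding.
  del rng_key
  return jnp.sum(action * embeddings, axis=-1)

def advantage_normalizer(returns, baseline_value):
  # Placeholder normalizer: plain baseline subtraction.
  return returns - baseline_value

targets = rlax.sampled_cmpo_policy_targets(
    prior_distribution=prior,
    embeddings=embeddings,
    rng_key=jax.random.PRNGKey(0),
    baseline_value=baseline_value,
    q_provider=q_provider,
    advantage_normalizer=advantage_normalizer,
    num_actions=16)
# targets.actions: [16, 4, 2] sampled actions; targets.weights: [16, 4]
# non-negative weights that a policy head can be distilled towards,
# e.g. with sampled_policy_distillation_loss.

Note that the exhaustive `cmpo_policy_targets` self-normalizes its weights with `softmax_policy_target_normalizer`, which the docstring flags as significantly biased when the average weight is estimated from few samples; `loo_policy_target_normalizer` is the leave-one-out alternative that estimates the average weight consistently.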
