|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | # ============================================================================== |
15 | | -"""Utilities to construct and learn from policy targets.""" |
| 15 | +"""Construct and learn from policy targets. Used by Muesli-based agents.""" |
16 | 16 |
|
17 | 17 | import functools |
18 | 18 |
|
19 | 19 | import chex |
20 | 20 | import distrax |
21 | 21 | import jax |
22 | 22 | import jax.numpy as jnp |
| 23 | +from rlax._src import base |
23 | 24 |
|
24 | 25 |
|
25 | 26 | @chex.dataclass(frozen=True) |
@@ -106,3 +107,190 @@ def sampled_policy_distillation_loss( |
106 | 107 | # We average over the samples, over time and batch, and if the actions are |
107 | 108 | # a continuous vector also over the actions. |
108 | 109 | return -jnp.mean(weights * jnp.maximum(log_probs, min_logp)) |
| 110 | + |
| 111 | + |
| 112 | +def cmpo_policy_targets( |
| 113 | + prior_distribution, |
| 114 | + embeddings, |
| 115 | + rng_key, |
| 116 | + baseline_value, |
| 117 | + q_provider, |
| 118 | + advantage_normalizer, |
| 119 | + *, |
| 120 | + num_actions, |
| 121 | + min_target_advantage=-jnp.inf, |
| 122 | + max_target_advantage=1.0, |
| 123 | + kl_weight=1.0, |
| 124 | +) -> PolicyTarget: |
| 125 | + """Policy targets for Clipped MPO. |
| 126 | +
|
| 127 | + The policy targets are in-expectation proportional to: |
| 128 | + `prior(a|s) * exp(clip(norm(Q(s, a))))` |
| 129 | +
|
| 130 | + See "Muesli: Combining Improvements in Policy Optimization" by Hessel et al. |
| 131 | + (https://arxiv.org/pdf/2104.06159.pdf). |
| 132 | +
|
| 133 | + Args: |
| 134 | + prior_distribution: the prior policy distribution. |
| 135 | + embeddings: embeddings for the `q_provider`. |
| 136 | + rng_key: a JAX pseudo random number generator key. |
| 137 | + baseline_value: the baseline for `advantage_normalizer`. |
| 138 | +    q_provider: a fn to compute Q-values, given `action` and `embeddings`. |
| 139 | +    advantage_normalizer: a fn to normalise advantages, given `returns` and |
| 140 | +      `baseline_value`. |
| 141 | + num_actions: The total number of discrete actions. |
| 142 | + min_target_advantage: The minimum advantage of a policy target. |
| 143 | + max_target_advantage: The max advantage of a policy target. |
| 144 | + kl_weight: The coefficient for the KL regularizer. |
| 145 | +
|
| 146 | + Returns: |
| 147 | + the clipped MPO policy targets. |
| 148 | + """ |
| 149 | + # Expecting shape [B]. |
| 150 | + chex.assert_rank(baseline_value, 1) |
| 151 | + rng_key, query_rng_key = jax.random.split(rng_key) |
| 152 | + del rng_key |
| 153 | + |
| 154 | + # Producing all actions with shape [num_actions, B]. |
| 155 | + batch_size, = baseline_value.shape |
| 156 | + actions = jnp.broadcast_to( |
| 157 | + jnp.expand_dims(jnp.arange(num_actions, dtype=jnp.int32), axis=-1), |
| 158 | + [num_actions, batch_size]) |
| 159 | + |
| 160 | + # Using vmap over the num_actions in axis=0. |
| 161 | + def _query_q(actions): |
| 162 | + return q_provider( |
| 163 | +        # Using the same rng_key for all action samples. |
| 164 | + rng_key=query_rng_key, |
| 165 | + action=actions, |
| 166 | + embeddings=embeddings) |
| 167 | + qvalues = jax.vmap(_query_q)(actions) |
| 168 | + |
| 169 | + # Using the same advantage normalization as for policy gradients. |
| 170 | + raw_advantage = advantage_normalizer( |
| 171 | + returns=qvalues, baseline_value=baseline_value) |
| 172 | + clipped_advantage = jnp.clip( |
| 173 | + raw_advantage, min_target_advantage, |
| 174 | + max_target_advantage) |
| 175 | + |
| 176 | + # Construct and normalise the weights. |
| 177 | + log_prior = prior_distribution.log_prob(actions) |
| 178 | + weights = softmax_policy_target_normalizer( |
| 179 | + log_prior + clipped_advantage / kl_weight) |
| 180 | + policy_targets = PolicyTarget(actions=actions, weights=weights) |
| 181 | + return policy_targets |
| 182 | + |
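A minimal usage sketch of `cmpo_policy_targets` (editorial, not part of the commit): the `toy_q_provider` and `toy_advantage_normalizer` below are hypothetical stand-ins, and the functions defined in this file are assumed to be in scope.

import distrax
import jax
import jax.numpy as jnp

batch_size, num_actions = 4, 3
prior = distrax.Categorical(logits=jnp.zeros([batch_size, num_actions]))
baseline_value = jnp.zeros([batch_size])

def toy_q_provider(rng_key, action, embeddings):
  # Toy Q-function: higher-indexed actions get higher values.
  del rng_key, embeddings
  return action.astype(jnp.float32)

def toy_advantage_normalizer(returns, baseline_value):
  # No-op normalizer: plain advantages.
  return returns - baseline_value

targets = cmpo_policy_targets(
    prior, embeddings=None, rng_key=jax.random.PRNGKey(0),
    baseline_value=baseline_value, q_provider=toy_q_provider,
    advantage_normalizer=toy_advantage_normalizer, num_actions=num_actions)
# targets.actions: [num_actions, batch_size] grid of all discrete actions.
# targets.weights: [num_actions, batch_size], summing to num_actions over axis 0.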
| 183 | + |
| 184 | +def sampled_cmpo_policy_targets( |
| 185 | + prior_distribution, |
| 186 | + embeddings, |
| 187 | + rng_key, |
| 188 | + baseline_value, |
| 189 | + q_provider, |
| 190 | + advantage_normalizer, |
| 191 | + *, |
| 192 | + num_actions=2, |
| 193 | + min_target_advantage=-jnp.inf, |
| 194 | + max_target_advantage=1.0, |
| 195 | + kl_weight=1.0, |
| 196 | +) -> PolicyTarget: |
| 197 | + """Policy targets for sampled CMPO. |
| 198 | +
|
| 199 | + As in CMPO the policy targets are in-expectation proportional to: |
| 200 | + `prior(a|s) * exp(clip(norm(Q(s, a))))` |
| 201 | +  However, we only sample a subset of the actions; this allows scaling to |
| 202 | +  large discrete action spaces and to continuous actions. |
| 203 | +
|
| 204 | + See "Muesli: Combining Improvements in Policy Optimization" by Hessel et al. |
| 205 | + (https://arxiv.org/pdf/2104.06159.pdf). |
| 206 | +
|
| 207 | + Args: |
| 208 | + prior_distribution: the prior policy distribution. |
| 209 | + embeddings: embeddings for the `q_provider`. |
| 210 | + rng_key: a JAX pseudo random number generator key. |
| 211 | + baseline_value: the baseline for `advantage_normalizer`. |
| 212 | +    q_provider: a fn to compute Q-values, given `action` and `embeddings`. |
| 213 | +    advantage_normalizer: a fn to normalise advantages, given `returns` and |
| 214 | +      `baseline_value`. |
| 215 | +    num_actions: The number of actions to sample from the prior. |
| 216 | + min_target_advantage: The minimum advantage of a policy target. |
| 217 | + max_target_advantage: The max advantage of a policy target. |
| 218 | + kl_weight: The coefficient for the KL regularizer. |
| 219 | +
|
| 220 | + Returns: |
| 221 | + the sampled clipped MPO policy targets. |
| 222 | + """ |
| 223 | + # Expecting shape [B]. |
| 224 | + chex.assert_rank(baseline_value, 1) |
| 225 | + query_rng_key, action_key = jax.random.split(rng_key) |
| 226 | + del rng_key |
| 227 | + |
| 228 | + # Sampling the actions from the prior. |
| 229 | + actions = prior_distribution.sample( |
| 230 | + seed=action_key, sample_shape=[num_actions]) |
| 231 | + |
| 232 | +  # Using vmap over the num_actions samples in axis=0. |
| 233 | + def _query_q(actions): |
| 234 | + return q_provider( |
| 235 | +        # Using the same rng_key for all action samples. |
| 236 | + rng_key=query_rng_key, |
| 237 | + action=actions, |
| 238 | + embeddings=embeddings) |
| 239 | + qvalues = jax.vmap(_query_q)(actions) |
| 240 | + |
| 241 | + # Using the same advantage normalization as for policy gradients. |
| 242 | + raw_advantage = advantage_normalizer( |
| 243 | + returns=qvalues, baseline_value=baseline_value) |
| 244 | + clipped_advantage = jnp.clip( |
| 245 | + raw_advantage, min_target_advantage, max_target_advantage) |
| 246 | + |
| 247 | +  # The expected normalized weight would be 1.0. The weights would be |
| 248 | +  # normalized if the baseline_value were the log of the expected weight, |
| 249 | +  # i.e., if the baseline_value were log(sum_a(prior(a|s) * exp(Q(s, a)/c))). |
| 250 | + weights = jnp.exp(clipped_advantage / kl_weight) |
| 251 | + |
| 252 | +  # The weights are tiled if using multiple continuous actions. |
| 253 | +  # It is OK to use multiple continuous actions inside Q(s, a), |
| 254 | +  # because the action is sampled from the joint distribution |
| 255 | +  # and the weight is not based on non-joint probabilities. |
| 256 | + log_prob = prior_distribution.log_prob(actions) |
| 257 | + weights = jnp.broadcast_to( |
| 258 | + base.lhs_broadcast(weights, log_prob), log_prob.shape) |
| 259 | + return PolicyTarget(actions=actions, weights=weights) |
| 260 | + |
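A similar sketch for `sampled_cmpo_policy_targets` with a factorised continuous prior, to show the per-sample weights being tiled across action dimensions (again editorial; the toy functions are hypothetical stand-ins).

import distrax
import jax
import jax.numpy as jnp

batch_size, action_dim, num_samples = 4, 2, 8
prior = distrax.Normal(loc=jnp.zeros([batch_size, action_dim]),
                       scale=jnp.ones([batch_size, action_dim]))
baseline_value = jnp.zeros([batch_size])

def toy_q_provider(rng_key, action, embeddings):
  # Toy Q-function of the joint action vector; returns shape [batch_size].
  del rng_key, embeddings
  return -jnp.sum(jnp.square(action), axis=-1)

def toy_advantage_normalizer(returns, baseline_value):
  return returns - baseline_value

targets = sampled_cmpo_policy_targets(
    prior, embeddings=None, rng_key=jax.random.PRNGKey(1),
    baseline_value=baseline_value, q_provider=toy_q_provider,
    advantage_normalizer=toy_advantage_normalizer, num_actions=num_samples)
# targets.actions: [num_samples, batch_size, action_dim] samples from the prior.
# targets.weights: [num_samples, batch_size, action_dim], i.e. the per-sample
# weight broadcast across the per-dimension log-probabilities.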
| 261 | + |
| 262 | +def softmax_policy_target_normalizer(log_weights): |
| 263 | + """Returns self-normalized weights. |
| 264 | +
|
| 265 | +  Self-normalization introduces a significant bias if the average weight |
| 266 | +  is estimated from a small number of samples. |
| 267 | +
|
| 268 | + Args: |
| 269 | + log_weights: log unnormalized weights, shape `[num_targets, ...]`. |
| 270 | +
|
| 271 | + Returns: |
| 272 | +    Weights divided by the sample average weight. Weights sum to `num_targets`. |
| 273 | + """ |
| 274 | + num_targets = log_weights.shape[0] |
| 275 | + return num_targets * jax.nn.softmax(log_weights, axis=0) |
| 276 | + |
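A tiny worked example of the self-normalization above (editorial sketch, assuming `softmax_policy_target_normalizer` is in scope):

import jax.numpy as jnp

log_w = jnp.log(jnp.array([[1.0], [2.0], [5.0]]))  # [num_targets=3, 1]
w = softmax_policy_target_normalizer(log_w)
# softmax over axis 0 gives [1, 2, 5] / 8; scaled by num_targets=3 this is
# [[0.375], [0.75], [1.875]], which sums to exactly 3 along axis 0.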
| 277 | + |
| 278 | +def loo_policy_target_normalizer(log_weights): |
| 279 | + """A leave-one-out normalizer. |
| 280 | +
|
| 281 | + Args: |
| 282 | + log_weights: log unnormalized weights, shape `[num_targets, ...]`. |
| 283 | +
|
| 284 | + Returns: |
| 285 | + Weights divided by a consistent estimate of the average weight. The weights |
| 286 | + are not guaranteed to sum to `num_targets`. |
| 287 | + """ |
| 288 | + num_targets = log_weights.shape[0] |
| 289 | + weights = jnp.exp(log_weights) |
| 290 | + # Using a safe consistent estimator of the average weight, independently of |
| 291 | + # the numerator. |
| 292 | +  # The unnormalized weights are already approximately normalized by a |
| 293 | + # baseline_value, so we use `1` as the initial estimate of the average weight. |
| 294 | + avg_weight = ( |
| 295 | + 1 + jnp.sum(weights, axis=0, keepdims=True) - weights) / num_targets |
| 296 | + return weights / avg_weight |
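With the same toy log-weights, the leave-one-out normalizer divides each weight by an estimate of the average weight in which that weight is replaced by the initial estimate of 1, so the outputs are no longer constrained to sum to `num_targets` (editorial sketch, assuming `loo_policy_target_normalizer` is in scope):

import jax.numpy as jnp

log_w = jnp.log(jnp.array([[1.0], [2.0], [5.0]]))  # [num_targets=3, 1]
w = loo_policy_target_normalizer(log_w)
# Per-entry average-weight estimates: (1 + 8 - [1, 2, 5]) / 3 = [8/3, 7/3, 4/3],
# so the outputs are [[0.375], [~0.857], [3.75]], summing to ~4.98 rather than 3.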