
Fix Z_Normalization #551

Merged
merged 4 commits on Nov 21, 2018
17 changes: 10 additions & 7 deletions xnmt/rl/policy_gradient.py
@@ -1,6 +1,7 @@
 from enum import Enum
 
 import dynet as dy
+import numpy as np
 
 from xnmt import events, losses, param_initializers
 from xnmt.modelparts import transforms
@@ -31,7 +32,7 @@ class PolicyGradient(Serializable):
   @events.register_xnmt_handler
   def __init__(self, policy_network=None,
                baseline=None,
-               z_normalization=True, # TODO unused?
+               z_normalization=True,
                conf_penalty=None,
                weight=1.0,
                input_dim=Ref("exp_global.default_layer_dim"),
@@ -52,6 +53,7 @@ def __init__(self, policy_network=None,
 
     self.confidence_penalty = self.add_serializable_component("conf_penalty", conf_penalty, lambda: conf_penalty) if conf_penalty is not None else None
     self.weight = weight
+    self.z_normalization = z_normalization
 
   """
   state: Input state.
@@ -90,10 +92,11 @@ def calc_loss(self, policy_reward):
       loss.add_loss("rl_baseline", baseline_loss)
     ## Z-Normalization
     rewards = dy.concatenate(rewards, d=0)
-    dim, batch_size = rewards.dim()
-    rewards_mean = dy.mean_dim(rewards, [0], False)
-    rewards_std = dy.std_dim(rewards, [0], False)
-    rewards = dy.cdiv(rewards - rewards_mean, rewards_std+1e-10)
+    if self.z_normalization:
+      rewards_value = rewards.value()
+      rewards_mean = np.mean(rewards_value)
+      rewards_std = np.std(rewards_value) + 1e-10
+      rewards = (rewards - rewards_mean) / rewards_std
     ## Calculate Confidence Penalty
     if self.confidence_penalty:
       cp_loss = self.confidence_penalty.calc_loss(self.policy_lls)
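Note on the change above: the statistics are now taken with NumPy from the reward expression's value, so the mean and standard deviation enter the computation graph as plain constants and no gradient flows through them, and the whole step is skipped when z_normalization is off. A minimal standalone sketch of this step (the helper name z_normalize_rewards and the toy batch are illustrative, not part of the PR):

import dynet as dy
import numpy as np

def z_normalize_rewards(rewards, eps=1e-10):
  """Standardize a (possibly batched) DyNet reward expression.

  The mean and std are computed with NumPy from the expression's value,
  so they are treated as constants and receive no gradient.
  """
  rewards_value = rewards.value()
  rewards_mean = np.mean(rewards_value)
  rewards_std = np.std(rewards_value) + eps  # eps guards against zero variance
  return (rewards - rewards_mean) / rewards_std

# Toy usage on a batch of three scalar rewards:
dy.renew_cg()
raw = dy.inputTensor([[1.0, 2.0, 3.0]], batched=True)
print(z_normalize_rewards(raw).value())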
@@ -108,8 +111,8 @@ def calc_loss(self, policy_reward):
       if self.valid_pos is not None:
         ll = dy.pick_batch_elems(ll, self.valid_pos[i])
         reward = dy.pick_batch_elems(reward, self.valid_pos[i])
-      reinf_loss.append(dy.sum_batches(-ll * reward))
-    loss.add_loss("rl_reinf", self.weight * dy.esum(reinf_loss))
+      reinf_loss.append(dy.sum_batches(ll * reward))
+    loss.add_loss("rl_reinf", -self.weight * dy.esum(reinf_loss))
     ## the composed losses
     return loss
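The sign fix above keeps each per-step term as a plain log-likelihood-times-reward product and applies the negation, together with the loss weight, once at the aggregated REINFORCE loss. A hedged sketch of that aggregation, assuming lists of batched scalar DyNet expressions (the function name reinforce_loss is made up for illustration):

import dynet as dy

def reinforce_loss(log_likelihoods, rewards, weight=1.0):
  """Return -weight * sum_i sum_batches(ll_i * reward_i).

  log_likelihoods: one batched scalar expression per decision step.
  rewards: batched scalar reward expressions aligned with the steps.
  Minimizing this loss maximizes the expected reward.
  """
  reinf_terms = [dy.sum_batches(ll * reward)  # ll and reward are batched scalars
                 for ll, reward in zip(log_likelihoods, rewards)]
  return -weight * dy.esum(reinf_terms)

# Toy usage: two decision steps over a batch of three sentences.
dy.renew_cg()
lls = [dy.inputTensor([[-0.5, -1.2, -0.3]], batched=True) for _ in range(2)]
rws = [dy.inputTensor([[1.0, 0.0, -1.0]], batched=True) for _ in range(2)]
loss = reinforce_loss(lls, rws, weight=1.0)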

@@ -119,21 +119,15 @@ def transduce(self, embed_sent: ExpressionSequence) -> List[ExpressionSequence]:
   def on_calc_additional_loss(self, trg, generator, generator_loss):
     if self.policy_learning is None:
       return None
-    trg_counts = dy.inputTensor([t.len_unpadded() for t in trg], batched=True)
     reward = FactoredLossExpr()
-    # Adding all reward from the translator
-    for loss_key, loss_value in generator_loss.get_nobackprop_loss().items():
-      if loss_key == 'mle':
-        reward.add_loss('mle', dy.cdiv(-loss_value, trg_counts))
-      else:
-        reward.add_loss(loss_key, -loss_value)
+    reward.add_loss("generator", -dy.inputTensor(generator_loss.value(), batched=True))
     if self.length_prior is not None:
-      reward.add_loss('seg_lp', self.length_prior.log_ll(self.seg_size_unpadded))
+      reward.add_loss('length_prior', self.length_prior.log_ll(self.seg_size_unpadded))
     reward_value = reward.value()
     if trg.batch_size() == 1:
       reward_value = [reward_value]
     reward_tensor = dy.inputTensor(reward_value, batched=True)
     ### Calculate losses
     try:
       return self.policy_learning.calc_loss(reward_tensor)
     finally:
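In this second changed file, the reward is now built directly from the generator's total loss value (negated, so a lower loss means a higher reward), plus an optional length-prior term, and packed as a batched input tensor; the earlier per-key handling that normalized the MLE loss by target length is dropped. A rough sketch of that shape with plain Python lists (build_policy_reward and its arguments are hypothetical stand-ins; the real code goes through FactoredLossExpr and also special-cases batch size 1, where .value() returns a bare scalar):

import dynet as dy

def build_policy_reward(generator_loss_values, length_prior_ll=None):
  """Pack one scalar reward per sentence into a batched DyNet tensor.

  generator_loss_values: per-sentence generator losses (no backprop).
  length_prior_ll: optional per-sentence log-likelihoods under a length prior.
  """
  reward = [-loss for loss in generator_loss_values]  # lower loss -> higher reward
  if length_prior_ll is not None:
    reward = [r + lp for r, lp in zip(reward, length_prior_ll)]
  return dy.inputTensor(reward, batched=True)

# Toy usage for a batch of three sentences:
dy.renew_cg()
reward_tensor = build_policy_reward([2.3, 1.7, 3.1], length_prior_ll=[-0.2, -0.1, -0.4])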