
Commit

fixed gradients rdpo
IasonC committed Mar 22, 2024
1 parent e04d56c commit ea6ce6e
Showing 4 changed files with 179 additions and 39 deletions.
algos/linear_bandit/group_dpo_vectorised.py (14 changes: 12 additions & 2 deletions)
@@ -34,7 +34,8 @@ def __init__(
        ipo_grad_type: str = 'justdpo', ## `justdpo` (vectorised version), `linear` (IPO), or `log` (IPO)
        param_limit: int = 1, ## elements of vector θ range in [0, param_limit]
        lamba: float=0, ## L2 regularisation for closed-form regression of IPO objective in Linear Bandits case
-       train_agent: bool=True ## if True, use self.train(); else, use self.random_train() func
+       train_agent: bool=True, ## if True, use self.train(); else, use self.random_train() func
+       report_iter: int = 2000 ## log metrics after these iters
    ) -> None:
        self.state_dim = state_dim
        self.action_num = action_num
@@ -60,6 +61,7 @@ def __init__(
        self.param = np.random.uniform(0, param_limit, self.feature_dim)
        self.lamba = lamba
        self.train_agent=train_agent
+       self.report_iter = report_iter

        print('Vectorised DPO; step size = ', self.step_size)

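Taken together, the two `__init__` hunks above add a configurable logging interval. A minimal standalone sketch of the wiring, assuming a placeholder class name and a trimmed-down signature (the real constructor in `group_dpo_vectorised.py` takes many more arguments):

```python
import numpy as np

class GroupDPOSketch:  # placeholder name, not the repository's class
    def __init__(
        self,
        feature_dim: int,
        param_limit: int = 1,      # elements of θ drawn uniformly from [0, param_limit]
        train_agent: bool = True,  # True -> self.train(); False -> self.random_train()
        report_iter: int = 2000,   # log metrics every `report_iter` training iterations
    ) -> None:
        self.param = np.random.uniform(0, param_limit, feature_dim)
        self.train_agent = train_agent
        self.report_iter = report_iter  # replaces the hard-coded `step % 2000` in train()
```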
@@ -205,6 +207,7 @@ def update_once(self, dataset: List[GroupTransition]) -> float:
            for group_id in range(self.group_num):
                group_indices = group_id_idx_all[group_id]
                group_loss[group_id] = np.sum(-np.log(sigmoid(log_ratio_diff_all[group_indices]))) #+self.adj[group_id]/np.sqrt(self.group_counts[group_id]) #calculate group losses
+           self.total_loss = np.sum(-np.log(sigmoid(log_ratio_diff_all)))/len(dataset)
        elif self.ipo_grad_type=='linear':
            lin_diff = feature_diff_all @ self.param.reshape(self.feature_dim,1) - 0.5*(1/self.reg_coef)
            coef = -2*lin_diff/self.reg_coef
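For context on the hunk above: the `justdpo` branch accumulates -log σ(·) over each group's log-ratio differences, and the added line also records a dataset-level mean loss for reporting. A self-contained sketch of that computation (the function name, argument shapes, and logistic `sigmoid` are assumptions consistent with the diff):

```python
import numpy as np
from typing import List, Tuple

def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))

def group_dpo_losses(
    log_ratio_diff_all: np.ndarray,      # shape (n,): per-sample log-ratio differences
    group_id_idx_all: List[np.ndarray],  # per-group index arrays into the dataset
    group_num: int,
) -> Tuple[np.ndarray, float]:
    """Per-group summed DPO losses plus the mean loss over the whole dataset."""
    group_loss = np.zeros(group_num)
    for group_id in range(group_num):
        idx = group_id_idx_all[group_id]
        group_loss[group_id] = np.sum(-np.log(sigmoid(log_ratio_diff_all[idx])))
    total_loss = np.sum(-np.log(sigmoid(log_ratio_diff_all))) / len(log_ratio_diff_all)
    return group_loss, total_loss
```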
@@ -237,6 +240,10 @@ def update_once(self, dataset: List[GroupTransition]) -> float:
        else:
            step_size = self.step_size
        self.param = self.param - step_size * grad
+
+       #print(f'GRAD -- {grad} || PARAM -- {self.param}\n')
+       self.theta_update = step_size * grad
+
        return np.sqrt(np.sum(np.square(grad))) # grad L2-norm

    def evaluate_ipo_loss(self, dataset: List[GroupTransition], policy=None) -> float:
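The `update_once` hunk above keeps the applied update `step_size * grad` (as `self.theta_update`) so it can be printed during training, while the method still returns the gradient's L2 norm. The same step isolated as a pure function (a sketch, not the repository's API):

```python
import numpy as np

def descent_step(param: np.ndarray, grad: np.ndarray, step_size: float):
    """One gradient-descent step; returns new params, the applied update, and the grad L2-norm."""
    theta_update = step_size * grad  # stored as self.theta_update in the diff
    new_param = param - theta_update
    grad_norm = float(np.sqrt(np.sum(np.square(grad))))  # the value update_once() returns
    return new_param, theta_update, grad_norm
```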
@@ -769,7 +776,10 @@ def train(self, dataset: List[GroupTransition],
            return rew
        for step in range(self.num_iters):
            grad_norm = self.update_once(dataset)
-           if step % 2000 == 0:
+           if step % self.report_iter == 0:
+               print('UPDATE PARAM GroupDPO: ', self.theta_update)
+               print('TOTAL LOSS GroupDPO: ', self.total_loss)
+
                if self.ipo_grad_type=='justdpo':
                    train_loss = self.evaluate_loss(dataset)
                    val_loss = self.evaluate_loss(val_dataset)
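With those pieces in place, the logging cadence in `train()` is driven by `report_iter` instead of the hard-coded 2000. A hypothetical usage sketch combining the placeholders from the earlier sketches (`GroupDPOSketch`, `descent_step`, and the constant stand-in gradient are illustrative only):

```python
import numpy as np

agent = GroupDPOSketch(feature_dim=4, report_iter=500)  # tighter cadence than the default 2000

grad = 0.1 * np.ones(4)  # stand-in gradient for illustration
for step in range(2000):
    agent.param, agent.theta_update, grad_norm = descent_step(agent.param, grad, step_size=0.01)
    if step % agent.report_iter == 0:  # mirrors the updated check in train()
        print('UPDATE PARAM GroupDPO: ', agent.theta_update)
```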
