Skip to content

Commit

Permalink
chi modification on bandit code for avg-worst tradeoff
Browse files Browse the repository at this point in the history
  • Loading branch information
IasonC committed Sep 6, 2024
1 parent e99592e commit e624431
Show file tree
Hide file tree
Showing 4 changed files with 2,094 additions and 9 deletions.
18 changes: 13 additions & 5 deletions algos/linear_bandit/group_robust_dpo_vectorised_gradfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
param_limit: int = 1, ## elements of vector θ range in [0, param_limit]
use_closed_form: bool=False, ## closed-form regression solution for IPO
lamba: float =0, ## L2 regularisation for closed-form regression of IPO objective in Linear Bandits case
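chi: float = 1.0, ## tradeoff between worst-case and average accuracy: each group is weighted by (1-chi)/group_num + chi*group_weights[group], so chi=1 uses only the adaptive group weights and chi=0 weights all groups uniformly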
l2_reg_rdpo: float = 0, ## L2 regularisation for vectorised RDPO
reg_by_group_weights: float = 0, ## regularisation on vectorised RDPO by subtracting step*group_weights^2
train_agent: bool=True, ## if True, use self.train(); else, use self.random_train() func
Expand Down Expand Up @@ -89,6 +90,7 @@ def __init__(

self.use_closed_form=use_closed_form
self.lamba=lamba
self.chi=chi
self.l2_reg_rdpo = l2_reg_rdpo
self.reg_by_group_weights = reg_by_group_weights
self.train_agent=train_agent
Expand Down Expand Up @@ -383,13 +385,14 @@ def sample_group_transition(group_id):
else:
exp_step_size = self.exp_step_size

self.group_weights = self.group_weights*np.exp(exp_step_size*group_loss)#update weights based on group loss calculated
self.group_weights = self.group_weights*np.exp(exp_step_size*group_loss*self.chi)# exponentiated update of the group weights from the calculated group loss, with the exponent scaled by chi
self.group_weights = self.group_weights/np.sum(self.group_weights)#normalize the weights

weighted_group_grad = np.zeros_like(group_grad)
for group_id in range(self.group_num):
group_indices = group_id_idx_all[group_id]
group_grad_weighted_sum = np.sum(-neg_cur_data_grad[group_indices], axis=0) * self.group_weights[group_id]
worst_avg_tradeoff = (1-self.chi)/self.group_num + self.chi*self.group_weights[group_id]
group_grad_weighted_sum = np.sum(-neg_cur_data_grad[group_indices], axis=0) * worst_avg_tradeoff # * self.group_weights[group_id]
if self.importance_sampling==False: # divide grads by group counts in RDPO
weighted_group_grad[group_id] = group_grad_weighted_sum / cur_group_counts[group_id] # len(sampled_group_transitions) ############### had self.group_weights[group_id] scaling before
else: # Importance Sampling has weights assigned to inv-freq and then group-grads divided by batch size
Expand Down Expand Up @@ -626,8 +629,9 @@ def sample_group_transition(group_id):

if self.importance_sampling==False:
#print(self.group_weights,group_loss,np.exp(self.exp_step_size*group_loss))
self.group_weights=self.group_weights*np.exp(self.exp_step_size*group_loss)#update weights based on group loss calculated
#print(self.group_weights)
print("CHI = ", self.chi, " type ", type(self.chi))
self.group_weights=self.group_weights*np.exp(self.exp_step_size*group_loss*self.chi)# exponentiated update of the group weights from the calculated group loss, with the exponent scaled by chi
print(self.group_weights)
self.group_weights=self.group_weights/np.sum(self.group_weights)#normalize the weights
self.hist_grad_squared_norm += np.sum(np.square(grad))
self.hist_group_loss+=group_loss
Expand All @@ -637,7 +641,10 @@ def sample_group_transition(group_id):
else:
step_size = self.step_size
#print(grad)
self.cur_group_counts_closedform = cur_group_counts
live_grad=self.WeightedRegression(sampled_group_transitions,self.lamba)
self.theta_update = 'ClosedForm'
self.total_loss = 'ClosedForm'
#self.param=np.array([1.0,2.0])
return np.sqrt(np.sum(np.square(grad))),live_grad

Expand All @@ -660,7 +667,8 @@ def WeightedRegression(self, dataset: List[GroupTransition], lamba: float)-> flo
self.feature_func(state, non_pref_act,group_id),
)
Y.append(feat_pref_act-feat_non_pref_act)
w.append(self.group_weights[group_id])
w.append( (1-self.chi)/self.group_num + self.chi*self.group_weights[group_id] )

Y=np.array(Y)
w=np.array(w)
#print(w)
Expand Down
Loading

0 comments on commit e624431

Please sign in to comment.