Skip to content

Commit

Permalink
Determinism experiments (#138)
Browse files: browse the repository at this point in the history
  • Loading branch information
natolambert authored Jun 8, 2024
1 parent 4e2d848 commit eeb28e4
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 354 deletions.
329 changes: 0 additions & 329 deletions rewardbench/__main__.py

This file was deleted.

39 changes: 19 additions & 20 deletions rewardbench/dpo.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -213,27 +213,26 @@ def inference_step(self, batch, ref_free: bool = False) -> list:
Uses TRL inference batched logprob computation to compute chosen + rejected
logprobs then compute rewards and win rate.
"""
with torch.no_grad():
(
policy_chosen_logps,
policy_rejected_logps,
_, # policy_chosen_logits,
_, # policy_rejected_logits,
) = self.concatenated_forward(self.model, batch)

# optionally compute reward without normalizing via reference model
if not ref_free:
(
policy_chosen_logps,
policy_rejected_logps,
_, # policy_chosen_logits,
_, # policy_rejected_logits,
) = self.concatenated_forward(self.model, batch)

# optionally compute reward without normalizing via reference model
if not ref_free:
(
ref_chosen_logps,
ref_rejected_logps,
_, # ref_chosen_logits,
_, # ref_rejected_logits,
) = self.concatenated_forward(self.ref_model, batch)
chosen_logratios = policy_chosen_logps.detach().cpu() - ref_chosen_logps.detach().cpu()
rejected_logratios = policy_rejected_logps.detach().cpu() - ref_rejected_logps.detach().cpu()
else:
chosen_logratios = policy_chosen_logps.detach().cpu()
rejected_logratios = policy_rejected_logps.detach().cpu()
ref_chosen_logps,
ref_rejected_logps,
_, # ref_chosen_logits,
_, # ref_rejected_logits,
) = self.concatenated_forward(self.ref_model, batch)
chosen_logratios = policy_chosen_logps.detach().cpu() - ref_chosen_logps.detach().cpu()
rejected_logratios = policy_rejected_logps.detach().cpu() - ref_rejected_logps.detach().cpu()
else:
chosen_logratios = policy_chosen_logps.detach().cpu()
rejected_logratios = policy_rejected_logps.detach().cpu()

return chosen_logratios, rejected_logratios

Expand Down
Loading

0 comments on commit eeb28e4

Please sign in to comment.