docs/source/example_overview.md (2 additions, 2 deletions)

@@ -43,8 +43,8 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl
 |[`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py)| This script shows how to use the [`experimental.cpo.CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
 |[`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py)| This script shows how to use the [`DPOTrainer`] to fine-tune a model. |
 |[`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py)| This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. |
-|[`examples/scripts/evals/judge_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/evals/judge_tldr.py)| This script shows how to use [`HfPairwiseJudge`] or [`experimental.judges.OpenAIPairwiseJudge`] to judge model generations. |
-|[`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py)| This script shows how to use the [`GKDTrainer`] to fine-tune a model. |
+|[`examples/scripts/evals/judge_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/evals/judge_tldr.py)| This script shows how to use [`experimental.judges.HfPairwiseJudge`] or [`experimental.judges.OpenAIPairwiseJudge`] to judge model generations. |
+|[`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py)| This script shows how to use the [`experimental.gkd.GKDTrainer`] to fine-tune a model. |
 |[`trl/scripts/grpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/grpo.py)| This script shows how to use the [`GRPOTrainer`] to fine-tune a model. |
 |[`examples/scripts/grpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/grpo_vlm.py)| This script shows how to use the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. |
 |[`examples/scripts/gspo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo.py)| This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune a model for reasoning using the [AI-MO/NuminaMath-TIR](https://huggingface.co/datasets/AI-MO/NuminaMath-TIR) dataset. |
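Both judge classes referenced by `judge_tldr.py` expose the same pairwise interface: you pass prompts and pairs of candidate completions and get back, for each prompt, the index of the preferred completion. As a rough sketch of that call (the `trl.experimental.judges` import path follows the references above; the default judge model and the example data are assumptions, not taken from this diff):

```python
# Hedged sketch of pairwise judging; not part of the PR.
from trl.experimental.judges import HfPairwiseJudge  # assumed experimental import path

judge = HfPairwiseJudge()  # queries a hosted instruct model via the HF Inference API by default

prompts = ["What is the capital of France?"]
completions = [["Paris is the capital of France.", "It might be Lyon."]]

# For each prompt, returns the index (0 or 1) of the preferred completion.
ranks = judge.judge(prompts, completions)
print(ranks)  # e.g. [0]
```

`OpenAIPairwiseJudge` can be swapped in with the same `judge()` call when an OpenAI API key is configured.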
docs/source/gkd_trainer.md (6 additions, 9 deletions)

@@ -19,10 +19,10 @@ This post-training method was contributed by [Kashif Rasul](https://huggingface.
 
 ## Usage tips
 
-The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs three parameters to be set via the [`GKDConfig`] namely:
+The [`experimental.gkd.GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs three parameters to be set via the [`experimental.gkd.GKDConfig`] namely:
 
 * `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and token-specific feedback on these sequences from the teacher. For values in between [0, 1] it is random between the two based on the `lmbda` value for each batch.
-* `seq_kd`: controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated out). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher.
+* `seq_kd`: controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated out). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher.
 * `beta`: controls the interpolation in the generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while for `beta=1.0` the loss approximates reverse KL divergence. For values in between [0, 1] it interpolates between the two.
 
 The authors find that on-policy data (high `lmbda`) performs better and the optimal `beta` varied depending on the task and evaluation method.

@@ -34,11 +34,8 @@ The basic API is as follows:
 
 ```python
 from datasets import Dataset
-from trl import GKDConfig, GKDTrainer
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-)
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from trl.experimental.gkd import GKDConfig, GKDTrainer
 
 NUM_DUMMY_SAMPLES = 100
 

@@ -92,11 +89,11 @@ The dataset should be formatted as a list of "messages" where each message is a
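Apart from the import path, the basic API shown in `gkd_trainer.md` is unchanged by this diff. For context, a minimal end-to-end sketch using the new `trl.experimental.gkd` import might look like the following; the model names, the dummy dataset, and the `processing_class` keyword are illustrative assumptions rather than content of this PR:

```python
# Hedged sketch of the basic GKD API with the experimental import path; not part of the PR.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl.experimental.gkd import GKDConfig, GKDTrainer

NUM_DUMMY_SAMPLES = 100

# Any student/teacher pair of causal LMs that share a tokenizer will do (names assumed).
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
student = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
teacher = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")

# Dummy conversational data in the "messages" format referenced by the last hunk above.
train_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "Hi, how are you?"},
                {"role": "assistant", "content": "I'm great, thanks. How can I help?"},
            ]
        ]
        * NUM_DUMMY_SAMPLES
    }
)

training_args = GKDConfig(
    output_dir="gkd-model",
    lmbda=0.5,     # half of the batches use on-policy, student-generated rollouts
    beta=0.5,      # interpolate between forward and reverse KL in the generalized JSD
    seq_kd=False,  # no sequence-level KD on teacher-generated outputs
)

trainer = GKDTrainer(
    model=student,
    teacher_model=teacher,
    args=training_args,
    processing_class=tokenizer,  # `tokenizer=` in older TRL versions
    train_dataset=train_dataset,
)
trainer.train()
```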
docs/source/gold_trainer.md (2 additions, 2 deletions)

@@ -13,7 +13,7 @@ Key capabilities:
 
 1. **Cross-tokenizer alignment** – GOLD incrementally decodes the student and teacher tokens, groups passages with the same visible text, and merges probabilities inside each group. This guarantees loss terms are computed over the full completion even when token boundaries differ.
 2. **Hybrid ULD loss** – when `uld_use_hybrid_loss` is enabled, GOLD compares exact vocabulary matches directly and falls back to the original sorted-probability ULD loss for unmatched tokens. This improves stability for students whose vocabularies only partially overlap with the teacher.
-3. **Seamless integration with GKD** – GOLD inherits the on-policy vs. off-policy scheduling from the [`GKDTrainer`](./gkd_trainer.md), so you can combine sequence-level KD, generalized JSD, and cross-tokenizer distillation in a single training run.
+3. **Seamless integration with GKD** – GOLD inherits the on-policy vs. off-policy scheduling from the [`experimental.gkd.GKDTrainer`], so you can combine sequence-level KD, generalized JSD, and cross-tokenizer distillation in a single training run.
 
 > [!NOTE]
 > GOLD is currently part of the `trl.experimental` namespace. APIs may change without notice while the feature is iterated on.

@@ -27,7 +27,7 @@ messages). Important configuration flags on [`GOLDConfig`] include:
 * `teacher_tokenizer_name_or_path` – required when `use_uld_loss=True`; GOLD uses the teacher tokenizer to align tokens.
 * `uld_use_hybrid_loss`, `uld_hybrid_matched_weight`, `uld_hybrid_unmatched_weight` – enables and weights the hybrid
   matched/unmatched loss.
-* `beta`, `lmbda`, `seq_kd` – inherited from `GKDConfig`, controlling the generalized JSD interpolation and on-policy
+* `beta`, `lmbda`, `seq_kd` – inherited from [`experimental.gkd.GKDConfig`], controlling the generalized JSD interpolation and on-policy
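Taken together, those flags make up the GOLD configuration. A hypothetical sketch of how they could be combined (only the flag names come from the list above; the `trl.experimental.gold` import path, the teacher checkpoint name, and all values are assumptions for illustration):

```python
# Hedged sketch of a GOLD configuration; flag names follow the documented list,
# everything else (import path, checkpoint name, values) is assumed.
from trl.experimental.gold import GOLDConfig

training_args = GOLDConfig(
    output_dir="gold-model",
    # Cross-tokenizer (ULD) distillation
    use_uld_loss=True,
    teacher_tokenizer_name_or_path="teacher-org/teacher-model",  # required when use_uld_loss=True
    # Hybrid matched/unmatched ULD loss
    uld_use_hybrid_loss=True,
    uld_hybrid_matched_weight=1.0,
    uld_hybrid_unmatched_weight=1.0,
    # Inherited from GKDConfig: generalized JSD interpolation and on-policy scheduling
    beta=0.5,
    lmbda=0.5,
    seq_kd=False,
)
```

Since `beta`, `lmbda`, and `seq_kd` are inherited from [`experimental.gkd.GKDConfig`], the on-policy scheduling behaves as described for the GKD trainer above.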
docs/source/paper_index.md (3 additions, 3 deletions)

@@ -646,12 +646,12 @@ On-Policy Distillation has been shown to outperform SFT, GRPO and can be used to
 
 Additionally on-policy distillation is more compute efficient and is less prone to overfitting when trained with limited data.
 
-To train a model with on-policy distillation using TRL, you can use the following configuration, with the [`GKDTrainer`] and [`GKDConfig`]:
+To train a model with on-policy distillation using TRL, you can use the following configuration, with the [`experimental.gkd.GKDTrainer`] and [`experimental.gkd.GKDConfig`]:
 
 ```python
-from trl import GKDConfig
+from trl.experimental.gkd import GKDConfig
 
-config = GKDConfig(
+training_args = GKDConfig(
     lmbda=1.0, # student produces rollouts for all batches
     beta=1.0, # to ensure reverse-kl as the loss function
     teacher_model_name_or_path="teacher-model", # specify the teacher model