Commit b7918c0

Move GKDTrainer to experimental module (#4474)
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 07b5011 commit b7918c0

File tree

16 files changed: +612, -549 lines

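For downstream users, the practical effect of this commit is an import-path change: `GKDTrainer` and `GKDConfig` now live in the experimental module, and the diffs below update every documented call site. A minimal migration sketch (the commented-out line is the path being removed):

```python
# Old import path, removed by this commit:
# from trl import GKDConfig, GKDTrainer

# New import path under the experimental module:
from trl.experimental.gkd import GKDConfig, GKDTrainer
```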

docs/source/_toctree.yml

Lines changed: 2 additions & 2 deletions
@@ -60,8 +60,6 @@
     title: DPO
   - local: online_dpo_trainer
     title: Online DPO
-  - local: gkd_trainer
-    title: GKD
   - local: grpo_trainer
     title: GRPO
   - local: kto_trainer
@@ -107,6 +105,8 @@
     title: CPO
   - local: gfpo
     title: GFPO
+  - local: gkd_trainer
+    title: GKD
   - local: gold_trainer
     title: GOLD
   - local: grpo_with_replay_buffer

docs/source/dataset_formats.md

Lines changed: 1 addition & 1 deletion
@@ -390,7 +390,7 @@ Choosing the right dataset type depends on the task you are working on and the s
 | [`experimental.bco.BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
 | [`experimental.cpo.CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
 | [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
-| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) |
+| [`experimental.gkd.GKDTrainer`] | [Prompt-completion](#prompt-completion) |
 | [`GRPOTrainer`] | [Prompt-only](#prompt-only) |
 | [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
 | [`NashMDTrainer`] | [Prompt-only](#prompt-only) |

docs/source/example_overview.md

Lines changed: 2 additions & 2 deletions
@@ -43,8 +43,8 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl
 | [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`experimental.cpo.CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
 | [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a model. |
 | [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. |
-| [`examples/scripts/evals/judge_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/evals/judge_tldr.py) | This script shows how to use [`HfPairwiseJudge`] or [`experimental.judges.OpenAIPairwiseJudge`] to judge model generations. |
-| [`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py) | This script shows how to use the [`GKDTrainer`] to fine-tune a model. |
+| [`examples/scripts/evals/judge_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/evals/judge_tldr.py) | This script shows how to use [`experimental.judges.HfPairwiseJudge`] or [`experimental.judges.OpenAIPairwiseJudge`] to judge model generations. |
+| [`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py) | This script shows how to use the [`experimental.gkd.GKDTrainer`] to fine-tune a model. |
 | [`trl/scripts/grpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/grpo.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a model. |
 | [`examples/scripts/grpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/grpo_vlm.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. |
 | [`examples/scripts/gspo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune model for reasoning using the [AI-MO/NuminaMath-TIR](https://huggingface.co/datasets/AI-MO/NuminaMath-TIR) dataset. |

docs/source/gkd_trainer.md

Lines changed: 6 additions & 9 deletions
@@ -19,10 +19,10 @@ This post-training method was contributed by [Kashif Rasul](https://huggingface.
 
 ## Usage tips
 
-The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs three parameters to be set via the [`GKDConfig`] namely:
+The [`experimental.gkd.GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs three parameters to be set via the [`experimental.gkd.GKDConfig`] namely:
 
 * `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and token-specific feedback on these sequences from the teacher. For values in between [0, 1] it is random between the two based on the `lmbda` value for each batch.
-* `seq_kd`: controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated out). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher.
+* `seq_kd`: controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated out). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher.
 * `beta`: controls the interpolation in the generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while for `beta=1.0` the loss approximates reverse KL divergence. For values in between [0, 1] it interpolates between the two.
 
 The authors find that on-policy data (high `lmbda`) performs better and the optimal `beta` varied depending on the task and evaluation method.
@@ -34,11 +34,8 @@ The basic API is as follows:
 
 ```python
 from datasets import Dataset
-from trl import GKDConfig, GKDTrainer
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-)
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from trl.experimental.gkd import GKDConfig, GKDTrainer
 
 NUM_DUMMY_SAMPLES = 100
 
@@ -92,11 +89,11 @@ The dataset should be formatted as a list of "messages" where each message is a
 
 ## GKDTrainer
 
-[[autodoc]] GKDTrainer
+[[autodoc]] experimental.gkd.GKDTrainer
     - train
     - save_model
     - push_to_hub
 
 ## GKDConfig
 
-[[autodoc]] GKDConfig
+[[autodoc]] experimental.gkd.GKDConfig
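Putting the usage tips above together under the new import path, here is a minimal configuration sketch; the output directory and knob values are illustrative placeholders, not taken from this commit:

```python
from trl.experimental.gkd import GKDConfig

# Illustrative values only: a high `lmbda` favors on-policy student rollouts,
# `beta` interpolates the generalized JSD (0.0 ~ forward KL, 1.0 ~ reverse KL),
# and `seq_kd` toggles sequence-level KD on teacher-generated outputs.
training_args = GKDConfig(
    output_dir="gkd-model",  # placeholder output directory
    lmbda=0.7,
    beta=0.5,
    seq_kd=False,
)
```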

docs/source/gold_trainer.md

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@ Key capabilities:
 
 1. **Cross-tokenizer alignment** – GOLD incrementally decodes the student and teacher tokens, groups passages with the same visible text, and merges probabilities inside each group. This guarantees loss terms are computed over the full completion even when token boundaries differ.
 2. **Hybrid ULD loss** – when `uld_use_hybrid_loss` is enabled, GOLD compares exact vocabulary matches directly and falls back to the original sorted-probability ULD loss for unmatched tokens. This improves stability for students whose vocabularies only partially overlap with the teacher.
-3. **Seamless integration with GKD** – GOLD inherits the on-policy vs. off-policy scheduling from the [`GKDTrainer`](./gkd_trainer.md), so you can combine sequence-level KD, generalized JSD, and cross-tokenizer distillation in a single training run.
+3. **Seamless integration with GKD** – GOLD inherits the on-policy vs. off-policy scheduling from the [`experimental.gkd.GKDTrainer`], so you can combine sequence-level KD, generalized JSD, and cross-tokenizer distillation in a single training run.
 
 > [!NOTE]
 > GOLD is currently part of the `trl.experimental` namespace. APIs may change without notice while the feature is iterated on.
@@ -27,7 +27,7 @@ messages). Important configuration flags on [`GOLDConfig`] include:
 * `teacher_tokenizer_name_or_path` – required when `use_uld_loss=True`; GOLD uses the teacher tokenizer to align tokens.
 * `uld_use_hybrid_loss`, `uld_hybrid_matched_weight`, `uld_hybrid_unmatched_weight` – enables and weights the hybrid
   matched/unmatched loss.
-* `beta`, `lmbda`, `seq_kd` – inherited from `GKDConfig`, controlling the generalized JSD interpolation and on-policy
+* `beta`, `lmbda`, `seq_kd` – inherited from [`experimental.gkd.GKDConfig`], controlling the generalized JSD interpolation and on-policy
   sampling ratio.
 
 A minimal end-to-end example:

docs/source/index.md

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL
 
 ### Knowledge distillation
 
-- [`GKDTrainer`]
+- [`experimental.gkd.GKDTrainer`] 🧪
 - [`experimental.minillm.MiniLLMTrainer`] 🧪
 
 </div>

docs/source/liger_kernel_integration.md

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ training_args = KTOConfig(..., use_liger_kernel=True)
 <hfoption id="GKD">
 
 ```python
-from trl import GKDConfig
+from trl.experimental.gkd import GKDConfig
 
 training_args = GKDConfig(..., use_liger_kernel=True)
 ```

docs/source/paper_index.md

Lines changed: 3 additions & 3 deletions
@@ -646,12 +646,12 @@ On-Policy Distillation has been shown to outperform SFT, GRPO and can be used to
 
 Additionally on-policy distillation is more compute efficient and is less prone to overfitting when trained with limited data.
 
-To train a model with on-policy distillation using TRL, you can use the following configuration, with the [`GKDTrainer`] and [`GKDConfig`]:
+To train a model with on-policy distillation using TRL, you can use the following configuration, with the [`experimental.gkd.GKDTrainer`] and [`experimental.gkd.GKDConfig`]:
 
 ```python
-from trl import GKDConfig
+from trl.experimental.gkd import GKDConfig
 
-config = GKDConfig(
+training_args = GKDConfig(
     lmbda=1.0,  # student produces rollouts for all batches
     beta=1.0,  # to ensure reverse-kl as the loss function
     teacher_model_name_or_path="teacher-model",  # specify the teacher model
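For context, a hedged sketch of how such a configuration is typically wired into the trainer, following the pattern in `examples/scripts/gkd.py`; the student checkpoint, tokenizer, and dummy dataset below are placeholders and not part of this diff:

```python
from datasets import Dataset
from transformers import AutoTokenizer
from trl.experimental.gkd import GKDConfig, GKDTrainer

training_args = GKDConfig(
    output_dir="on-policy-distilled-student",    # placeholder
    lmbda=1.0,                                   # student produces rollouts for all batches
    beta=1.0,                                    # reverse-KL end of the generalized JSD
    teacher_model_name_or_path="teacher-model",  # placeholder teacher checkpoint
)

tokenizer = AutoTokenizer.from_pretrained("student-model")  # placeholder student checkpoint

# Tiny dummy dataset in the conversational "messages" format that GKD expects.
train_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "The sky is blue."},
            ]
        ]
        * 8
    }
)

trainer = GKDTrainer(
    model="student-model",  # placeholder student checkpoint
    teacher_model=training_args.teacher_model_name_or_path,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```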

docs/source/reducing_memory_usage.md

Lines changed: 1 addition & 1 deletion
@@ -165,7 +165,7 @@ training_args = KTOConfig(..., use_liger_kernel=True)
 <hfoption id="GKD">
 
 ```python
-from trl import GKDConfig
+from trl.experimental.gkd import GKDConfig
 
 training_args = GKDConfig(..., use_liger_kernel=True)
 ```

examples/scripts/gkd.py

Lines changed: 1 addition & 2 deletions
@@ -58,8 +58,6 @@
 from transformers import AutoTokenizer, GenerationConfig
 
 from trl import (
-    GKDConfig,
-    GKDTrainer,
     LogCompletionsCallback,
     ModelConfig,
     ScriptArguments,
@@ -68,6 +66,7 @@
     get_peft_config,
     get_quantization_config,
 )
+from trl.experimental.gkd import GKDConfig, GKDTrainer
 
 
 # Enable logging in a Hugging Face Space
