
Commit 15292f7

Merge branch 'main' into docs/update-peft-integration
2 parents 1c8f60a + 6f41b18 commit 15292f7

20 files changed: +965, -127 lines changed

docs/source/dataset_formats.md

Lines changed: 0 additions & 72 deletions
@@ -132,8 +132,6 @@ preference_example = {
}
```

-Conversational datasets are useful for training chat models, but must be converted into a standard format before being used with TRL trainers. This is typically done using chat templates specific to the model being used. For more information, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
-
#### Tool Calling

Some chat templates support *tool calling*, which allows the model to interact with external functions—referred to as **tools**—during generation. This extends the conversational capabilities of the model by enabling it to output a `"tool_calls"` field instead of a standard `"content"` message whenever it decides to invoke a tool.
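
For illustration only (this example is not part of the diff above), a conversational example that uses tool calling might look like the sketch below, where `get_weather` is a hypothetical tool:

```python
messages = [
    {"role": "user", "content": "What is the weather like in Paris?"},
    {
        # Instead of a "content" string, the assistant emits a "tool_calls" field.
        "role": "assistant",
        "tool_calls": [
            {
                "type": "function",
                "function": {"name": "get_weather", "arguments": {"city": "Paris"}},
            }
        ],
    },
    # The tool's result is fed back to the model as a "tool" message.
    {"role": "tool", "name": "get_weather", "content": "15°C, partly cloudy"},
    {"role": "assistant", "content": "It is currently 15°C and partly cloudy in Paris."},
]
```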
@@ -405,76 +403,6 @@ Choosing the right dataset type depends on the task you are working on and the s
| [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |

-> [!TIP]
-> TRL trainers only support standard dataset formats, [for now](https://github.com/huggingface/trl/issues/2071). If you have a conversational dataset, you must first convert it into a standard format.
-> For more information on how to work with conversational datasets, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
-
-## Working with conversational datasets in TRL
-
-Conversational datasets are increasingly common, especially for training chat models. However, some TRL trainers don't support conversational datasets in their raw format. (For more information, see [issue #2071](https://github.com/huggingface/trl/issues/2071).) These datasets must first be converted into a standard format.
-Fortunately, TRL offers tools to easily handle this conversion, which are detailed below.
-
-### Converting a conversational dataset into a standard dataset
-
-To convert a conversational dataset into a standard dataset, you need to *apply a chat template* to the dataset. A chat template is a predefined structure that typically includes placeholders for user and assistant messages. This template is provided by the tokenizer of the model you use.
-
-For detailed instructions on using chat templating, refer to the [Chat templating section in the `transformers` documentation](https://huggingface.co/docs/transformers/en/chat_templating).
-
-In TRL, the method you apply to convert the dataset will vary depending on the task. Fortunately, TRL provides a helper function called [`apply_chat_template`] to simplify this process. Here's an example of how to use it:
-
-```python
-from transformers import AutoTokenizer
-from trl import apply_chat_template
-
-tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
-
-example = {
-    "prompt": [{"role": "user", "content": "What color is the sky?"}],
-    "completion": [{"role": "assistant", "content": "It is blue."}]
-}
-
-apply_chat_template(example, tokenizer)
-# Output:
-# {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n', 'completion': 'It is blue.<|end|>\n<|endoftext|>'}
-```
-
-Alternatively, you can use the [`~datasets.Dataset.map`] method to apply the template across an entire dataset:
-
-```python
-from datasets import Dataset
-from trl import apply_chat_template
-
-dataset_dict = {
-    "prompt": [[{"role": "user", "content": "What color is the sky?"}],
-               [{"role": "user", "content": "Where is the sun?"}]],
-    "completion": [[{"role": "assistant", "content": "It is blue."}],
-                   [{"role": "assistant", "content": "In the sky."}]]
-}
-
-dataset = Dataset.from_dict(dataset_dict)
-dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
-# Output:
-# {'prompt': ['<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n',
-#             '<|user|>\nWhere is the sun?<|end|>\n<|assistant|>\n'],
-#  'completion': ['It is blue.<|end|>\n<|endoftext|>', 'In the sky.<|end|>\n<|endoftext|>']}
-```
-
-> [!WARNING]
-> We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle of a conversation.
-> For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] is designed to handle these intricacies and ensure the correct application of chat templates for various tasks.
-
-> [!WARNING]
-> It's important to note that chat templates are model-specific. For example, if you use the chat template from [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) with the above example, you get a different output:
->
-> ```python
-> apply_chat_template(example, AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct"))
-> # Output:
-> # {'prompt': '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n',
-> # 'completion': 'It is blue.<|im_end|>\n'}
-> ```
->
-> Always use the chat template associated with the model you're working with. Using the wrong template can lead to inaccurate or unexpected results.
-
## Using any dataset with TRL: preprocessing and conversion

Many datasets come in formats tailored to specific tasks, which might not be directly compatible with TRL. To use such datasets with TRL, you may need to preprocess and convert them into the required format.
Lines changed: 59 additions & 10 deletions
@@ -1,8 +1,5 @@
# Liger Kernel Integration

-> [!WARNING]
-> Section under construction. Feel free to contribute!
-
[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. That way, we can **4x** our context length, as described in the benchmark below. They have implemented Hugging Face compatible `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, with more to come. The kernel works out of the box with [FlashAttention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed).

With this memory reduction, you can potentially turn off `cpu_offloading` or gradient checkpointing to further boost the performance.
@@ -11,19 +8,71 @@ With this memory reduction, you can potentially turn off `cpu_offloading` or gra
| --- | --- |
| ![Speed up](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-tps.png) | ![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-memory.png) |

-1. To use Liger-Kernel in [`SFTTrainer`], first install it by:
-
## Supported Trainers

Liger Kernel is supported in the following TRL trainers:
- **SFT** (Supervised Fine-Tuning)
- **DPO** (Direct Preference Optimization)
- **GRPO** (Group Relative Policy Optimization)
- **KTO** (Kahneman-Tversky Optimization)
- **GKD** (Generalized Knowledge Distillation)

## Usage

1. First, install Liger Kernel:

```bash
pip install liger-kernel
```

-2. Once installed, set `use_liger_kernel` in [`SFTConfig`]. No other changes are needed!
2. Once installed, set `use_liger_kernel=True` in your trainer config. No other changes are needed!

<hfoptions id="liger">
<hfoption id="SFT">

```python
from trl import SFTConfig

training_args = SFTConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="DPO">

```python
from trl import DPOConfig

training_args = DPOConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="GRPO">

```python
from trl import GRPOConfig

training_args = GRPOConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="KTO">

```python
-training_args = SFTConfig(
-    use_liger_kernel=True,
-    ...
-)
from trl import KTOConfig

training_args = KTOConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="GKD">

```python
from trl import GKDConfig

training_args = GKDConfig(..., use_liger_kernel=True)
```

</hfoption>
</hfoptions>

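For reference (not part of this diff), a minimal end-to-end run with the Liger kernels enabled might look like the sketch below; the model and dataset names are placeholders you would swap for your own:

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Any dataset in a format supported by SFTTrainer works here.
dataset = load_dataset("trl-lib/Capybara", split="train")

training_args = SFTConfig(output_dir="Qwen2.5-0.5B-SFT-Liger", use_liger_kernel=True)
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```
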
To learn more about Liger-Kernel, visit their [official repository](https://github.com/linkedin/Liger-Kernel/).

docs/source/openenv.md

Lines changed: 215 additions & 0 deletions
@@ -156,3 +156,218 @@ Below is the reward curve from training:
<iframe src="https://trl-lib-trackio.hf.space?project=openenv&metrics=train/rewards/reward_from_env/mean&runs=qgallouedec-1761202871&sidebar=hidden&navbar=hidden" style="width:600px; height:500px; border:0;"></iframe>

To learn more about how to create custom environments, see the [OpenEnv documentation](https://github.com/meta-pytorch/OpenEnv/blob/main/src/envs/README.md).

## Advanced Example

Let's level this up a bit by training a model to interact with a more complex environment. We'll use the word-guessing game [Wordle](https://www.nytimes.com/games/wordle/index.html) from the `textarena` environment.

### The TextArena Environment

[TextArena](https://huggingface.co/papers/2504.11442) is an open-source collection of competitive text-based games designed to evaluate reasoning skills in LLMs through games like Wordle, Snake, Tic-Tac-Toe, and more. Research has shown that such games improve model performance on reasoning tasks.

![image of textarena](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/text_arena_evals.png)

We will use the `textarena` environment to train a model to play Wordle. It is a simple text-based environment that lets the model interact with the game by making guesses and receiving feedback on them.

### Wordle

Wordle is a useful game to train a model on because it requires the model to reason about the word and the feedback provided by the environment. It is also a purely language-based game that requires no external tools or knowledge. Furthermore, we found that models from 1 billion parameters and up are able to improve at Wordle and only need about 8 tokens to generate a guess, which makes the game a good benchmark for experimenting with reinforcement learning environments without significant compute requirements.

> [!NOTE] How does Wordle work?
> Wordle is a word-guessing game where the player has to guess a 5-letter word. The player can make 6 guesses, and for each guess, the environment will provide feedback on the correctness of the guess. The player wins if they guess the word in 6 guesses or fewer. It challenges the model to generate words that are likely to be correct, and to learn from the feedback provided by the environment.
>
> For example, if the Wordle environment returns the following feedback:
>
> ```
> G U E S S
> X G Y X X
> ```
> The model has guessed the word "GUESS" and the environment has responded with the letters X, G, and Y, which correspond to the colors gray, green, and yellow in the original game. From this feedback, the model should learn that "GUESS" is not the answer, that the letter "U" is correct and in the correct position, and that the letter "E" is in the word but in the wrong position.
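
As a rough illustration of how such feedback can be turned into the letter-coverage scores used later, here is a minimal sketch. It assumes the feedback arrives as a plain-text grid like the one above; the actual script uses its own helpers (`extract_wordle_feedback` and `extract_feedback_counts`), whose exact parsing may differ.

```python
def count_feedback_letters(feedback: str) -> tuple[int, int]:
    """Count green (G) and yellow (Y) markers in a feedback grid like:

    G U E S S
    X G Y X X
    """
    # Assume the last non-empty line carries the per-letter markers.
    marker_line = [line for line in feedback.strip().splitlines() if line.strip()][-1]
    markers = marker_line.split()
    return markers.count("G"), markers.count("Y")


greens, yellows = count_feedback_letters("G U E S S\nX G Y X X")
print(greens, yellows)  # 1 1
```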

In the TextArena environment, reward is only given when the model wins the game: the reward is 1.0 if the model wins and 0.0 otherwise. This sparse signal is hard for the model to learn from, so we have added a number of custom reward functions to the script to help the model learn to play the game. The extensible nature of `reward_funcs` and `rollout_func` allows you to add any custom reward function you want to the script.

### Rollout Function

The rollout function runs one full Wordle episode, prompting the model for a guess each turn and capturing both environment rewards and auxiliary signals such as letter coverage and repetition penalties.

```python
def rollout_once(
    env: TextArenaEnv,
    tokenizer: AutoTokenizer,
    args: GRPOConfig,
    dataset_prompt: str,
    cli_args: argparse.Namespace,
    system_prompt: str,
) -> dict[str, list]:
    result = env.reset()
    observation = result.observation

    prompt_ids: list[int] = []
    completion_ids: list[int] = []
    logprobs: list[float] = []
    raw_rewards: list[float] = []
    green_scores: list[float] = []
    yellow_scores: list[float] = []
    repetition_scores: list[float] = []
    correct_scores: list[float] = []
    guess_counts: dict[str, int] = {}

    for _turn in range(cli_args.max_turns):
        # when the game is over the environment will return done=True
        if result.done:
            break

        # set up the prompt for the model
        base_prompt = observation.prompt or dataset_prompt
        user_prompt = make_user_prompt(base_prompt, observation.messages)
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        prompt_text = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
            enable_thinking=False,
        )

        # generate the completion from the model using vLLM
        vllm_result = request_vllm_completion(
            prompt_text,
            args,
            endpoint=cli_args.vllm_endpoint,
            timeout=cli_args.request_timeout,
            fallback=cli_args,
        )
        prompt_ids.extend(vllm_result["prompt_ids"])
        completion_ids.extend(vllm_result["completion_ids"])
        logprobs.extend(vllm_result["logprobs"])
        completion_text = vllm_result.get("text") or tokenizer.decode(
            vllm_result["completion_ids"], skip_special_tokens=True
        )
        # extract the guess from the completion
        guess = extract_guess(completion_text)

        # step the environment with the guess
        result = env.step(TextArenaAction(message=guess))
        raw_rewards.append(float(result.reward or 0.0))
        observation = result.observation
        correct_score = float(result.reward or 0.0)
        feedback = extract_wordle_feedback(observation)

        # update guess counts; use .get() so a first-time guess defaults to zero occurrences
        previous_occurrences = guess_counts.get(guess, 0)
        repetition_score = scale_repetition_score(previous_occurrences, len(guess_counts))
        guess_counts[guess] = previous_occurrences + 1

        # calculate custom reward signals from the feedback
        if not feedback:
            green_score = 0.0
            yellow_score = 0.0
        else:
            green_count, yellow_count = extract_feedback_counts(feedback)
            green_score = green_count / 5.0
            yellow_score = yellow_count / 5.0

        repetition_scores.append(repetition_score)
        green_scores.append(green_score)
        yellow_scores.append(yellow_score)
        correct_scores.append(correct_score)

    correct_reward_value = correct_scores[-1] if correct_scores else (raw_rewards[-1] if raw_rewards else 0.0)

    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
        "raw_rewards": raw_rewards,
        "correct_reward": correct_reward_value,
        "green_reward": green_scores[-1] if green_scores else 0.0,
        "yellow_reward": yellow_scores[-1] if yellow_scores else 0.0,
        "repetition_reward": repetition_scores[-1] if repetition_scores else 0.0,
    }
```

The environment's own reward is based on completing the game. We found that most models struggle to ever win, so we added a number of custom reward functions to the script to help the model learn the game more incrementally: at first, the model learns to cover new letters and avoid repeating guesses; as it improves, it learns to win the game.

### Reward Functions

We log four reward streams that encourage the model to solve the puzzle, cover new letters, and avoid repeating guesses:

- `reward_correct`: final win/loss signal from the environment.
- `reward_greens`: density of green letters in the last feedback.
- `reward_yellows`: density of yellow letters in the last feedback.
- `reward_repetition`: penalty for repeating the same guess multiple times.

```python
from typing import Dict, List, Optional


def reward_correct(completions: List[str], **kwargs: Optional[Dict]) -> List[float]:
    rewards = kwargs.get("correct_reward") if kwargs else None
    return [float(r) for r in rewards] if rewards is not None else [0.0] * len(completions)


def reward_greens(completions: List[str], **kwargs: Optional[Dict]) -> List[float]:
    rewards = kwargs.get("green_reward") if kwargs else None
    return [float(r) for r in rewards] if rewards is not None else [0.0] * len(completions)


def reward_yellows(completions: List[str], **kwargs: Optional[Dict]) -> List[float]:
    rewards = kwargs.get("yellow_reward") if kwargs else None
    return [float(r) for r in rewards] if rewards is not None else [0.0] * len(completions)


def reward_repetition(completions: List[str], **kwargs: Optional[Dict]) -> List[float]:
    rewards = kwargs.get("repetition_reward") if kwargs else None
    return [float(r) for r in rewards] if rewards is not None else [0.0] * len(completions)
```
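
The extra keys returned by the rollout (`correct_reward`, `green_reward`, `yellow_reward`, `repetition_reward`) are surfaced to these reward functions through `**kwargs`, so you can sanity-check the plumbing in isolation. The values below are made up for illustration:

```python
completions = ["CRANE", "CRANE"]

print(reward_correct(completions, correct_reward=[0.0, 1.0]))         # [0.0, 1.0]
print(reward_greens(completions, green_reward=[0.4, 1.0]))            # [0.4, 1.0]
print(reward_repetition(completions, repetition_reward=[0.0, -0.5]))  # [0.0, -0.5]
print(reward_yellows(completions))                                    # falls back to [0.0, 0.0]
```
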
### Training the Model

The training script wires the custom rollout and rewards into `GRPOTrainer`. The CLI exposes the configuration used during development as defaults, so you can override endpoints or hyperparameters at launch time.
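
The `grpo_config` passed below is a regular `GRPOConfig`. As a rough, illustrative starting point (these are assumed values, not the exact settings from the script):

```python
from trl import GRPOConfig

grpo_config = GRPOConfig(
    output_dir="wordle-grpo",
    use_vllm=True,                  # completions come from the external vLLM server
    num_generations=8,              # GRPO group size; must divide the effective batch size
    max_completion_length=8,        # a Wordle guess only needs a handful of tokens
    per_device_train_batch_size=8,
)
```
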
```python
parser = argparse.ArgumentParser()
# ... add CLI arguments with sensible defaults ...
cli_args = parser.parse_args()

trainer = GRPOTrainer(
    model=cli_args.model_id,
    processing_class=tokenizer,
    reward_funcs=[
        reward_correct,
        reward_greens,
        reward_yellows,
        reward_repetition,
    ],
    train_dataset=dataset,
    args=grpo_config,
    rollout_func=lambda prompts, args, processing_class: rollout_func(
        env=env,
        tokenizer=tokenizer,
        prompts=prompts,
        args=args,
        cli_args=cli_args,
        system_prompt=system_prompt,
    ),
)
trainer.train()
```

### Running the Example

The example requires two GPUs:

```bash
# Terminal 1: Start vLLM inference server
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000

# Terminal 2: Run GRPO training with OpenEnv
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py
```

### Results

The resulting model improves its performance on the game, both by reducing the number of repetitions and by increasing the number of correct guesses. However, the Qwen3-1.7B model we trained is not able to consistently win the game. The following reward curve shows the coverage of the model's guesses and the coverage of correct Y and G letters.

<iframe src="https://burtenshaw-wordle-grpo.hf.space/?project=group-Qwen-Qwen3-17B&metrics=train/rewards/reward_coverage/mean&runs=run-2025-10-26_09-39-49&sidebar=hidden&navbar=hidden" style="width:600px; height:500px; border:0;"></iframe>

We also experimented with larger models like `gpt-oss-20b` and found that the model was able to consistently win the game. However, this requires significantly more compute to train. Why not try it out yourself?