Conversational datasets are useful for training chat models, but must be converted into a standard format before being used with TRL trainers. This is typically done using chat templates specific to the model being used. For more information, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
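For reference, a conversational prompt-completion example stores lists of messages (role/content dictionaries) rather than plain strings:

```python
# Conversational prompt-completion example: messages instead of raw strings
prompt_completion_example = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}],
    "completion": [{"role": "assistant", "content": "It is blue."}],
}
```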
#### Tool Calling
Some chat templates support *tool calling*, which allows the model to interact with external functions—referred to as **tools**—during generation. This extends the conversational capabilities of the model by enabling it to output a `"tool_calls"` field instead of a standard `"content"` message whenever it decides to invoke a tool.
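For example, a conversation in which the assistant invokes a hypothetical `get_weather` tool might look like the following sketch, which follows the `transformers` tool-calling message format:

```python
messages = [
    {"role": "user", "content": "What's the weather like in Paris?"},
    # the assistant decides to call a tool instead of answering directly
    {"role": "assistant", "tool_calls": [
        {"type": "function", "function": {"name": "get_weather", "arguments": {"city": "Paris"}}}
    ]},
    # the tool's result is fed back into the conversation
    {"role": "tool", "name": "get_weather", "content": "21°C and sunny"},
    {"role": "assistant", "content": "It is currently 21°C and sunny in Paris."},
]
```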
Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using.

| Trainer | Expected dataset type |
| --- | --- |
| [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
> [!TIP]
> TRL trainers only support standard dataset formats, [for now](https://github.com/huggingface/trl/issues/2071). If you have a conversational dataset, you must first convert it into a standard format.
> For more information on how to work with conversational datasets, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
## Working with conversational datasets in TRL
Conversational datasets are increasingly common, especially for training chat models. However, some TRL trainers don't support conversational datasets in their raw format. (For more information, see [issue #2071](https://github.com/huggingface/trl/issues/2071).) These datasets must first be converted into a standard format.
Fortunately, TRL offers tools to easily handle this conversion, which are detailed below.
### Converting a conversational dataset into a standard dataset
To convert a conversational dataset into a standard dataset, you need to *apply a chat template* to the dataset. A chat template is a predefined structure that typically includes placeholders for user and assistant messages. This template is provided by the tokenizer of the model you use.
For detailed instructions on using chat templating, refer to the [Chat templating section in the `transformers` documentation](https://huggingface.co/docs/transformers/en/chat_templating).
In TRL, the method you apply to convert the dataset will vary depending on the task. Fortunately, TRL provides a helper function called [`apply_chat_template`] to simplify this process. Here's an example of how to use it (shown with the `microsoft/Phi-3-mini-128k-instruct` tokenizer, whose chat template matches the output below):

```python
from datasets import Dataset
from transformers import AutoTokenizer
from trl import apply_chat_template

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")

dataset = Dataset.from_dict({
    "prompt": [[{"role": "user", "content": "What color is the sky?"}],
               [{"role": "user", "content": "Where is the sun?"}]],
    "completion": [[{"role": "assistant", "content": "It is blue."}],
                   [{"role": "assistant", "content": "In the sky."}]],
})
dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})

print(dataset[:])
# {'prompt': ['<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n',
#             '<|user|>\nWhere is the sun?<|end|>\n<|assistant|>\n'],
#  'completion': ['It is blue.<|end|>\n<|endoftext|>', 'In the sky.<|end|>\n<|endoftext|>']}
```
> [!WARNING]
> We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle of a conversation.
> For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] function is designed to handle these intricacies and ensure the correct application of chat templates for various tasks.
> [!WARNING]
> It's important to note that chat templates are model-specific. For example, if you use the chat template from [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) with the above example, you get a different output:
>
> ```
> # {'prompt': '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n',
> #  'completion': 'It is blue.<|im_end|>\n'}
> ```
>
> Always use the chat template associated with the model you're working with. Using the wrong template can lead to inaccurate or unexpected results.
## Using any dataset with TRL: preprocessing and conversion
Many datasets come in formats tailored to specific tasks, which might not be directly compatible with TRL. To use such datasets with TRL, you may need to preprocess and convert them into the required format.
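For example, if a dataset stores instruction/response pairs under its own column names, a small `map` is often enough to convert it into the conversational prompt-completion format. Here is a sketch; the data file and column names below are hypothetical:

```python
from datasets import load_dataset

dataset = load_dataset("json", data_files="my_data.jsonl", split="train")

def to_prompt_completion(example):
    # wrap the raw strings into role/content messages expected by TRL
    return {
        "prompt": [{"role": "user", "content": example["instruction"]}],
        "completion": [{"role": "assistant", "content": example["response"]}],
    }

dataset = dataset.map(to_prompt_completion, remove_columns=["instruction", "response"])
```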
## Liger Kernel Integration

> Section under construction. Feel free to contribute!
[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%, which enables up to **4x** longer context lengths, as described in the benchmark below. The kernels include Hugging Face-compatible implementations of `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, and `FusedLinearCrossEntropy`, with more to come. They work out of the box with [FlashAttention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed).
With this memory reduction, you can potentially turn off `cpu_offloading` or gradient checkpointing to further boost performance.
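For example, with [`SFTTrainer`] the kernels are enabled through a single flag on the config. A minimal sketch, where the model and dataset are placeholders:

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# `use_liger_kernel=True` patches the model with Liger's fused Triton kernels
training_args = SFTConfig(output_dir="liger-sft", use_liger_kernel=True)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    args=training_args,
    train_dataset=load_dataset("trl-lib/Capybara", split="train"),
)
trainer.train()
```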
## OpenEnv Integration

To learn more about how to create custom environments, see the [OpenEnv documentation](https://github.com/meta-pytorch/OpenEnv/blob/main/src/envs/README.md).
## Advanced Example
Let's level this up a bit by training a model to interact with a more complex environment. We'll use the word-guessing game [Wordle](https://www.nytimes.com/games/wordle/index.html) from the `textarena` environment.
### The TextArena Environment
[TextArena](https://huggingface.co/papers/2504.11442) is an open-source collection of competitive text-based games designed to evaluate reasoning skills in LLMs using textual games like Wordle, Snake, Tic-Tac-Toe, and more. Research has shown that such games improve model performance on reasoning tasks.

We will use the `textarena` environment to train a model to play Wordle. The environment is a simple text-based environment that lets the model interact with the game by making guesses and receiving feedback on them.
### Wordle
Wordle is a useful game to train a model on because it requires the model to reason about the word and the feedback provided by the environment. It is also a purely language-based game that requires no external tools or knowledge. Furthermore, we found that models from 1 billion parameters and up are able to improve at Wordle while needing only 8 tokens to generate a guess, which makes the game a good benchmark for experimenting with reinforcement learning environments without significant compute requirements.
> [!NOTE] How does Wordle work?
> Wordle is a word-guessing game where the player has to guess a 5-letter word. The player can make 6 guesses, and for each guess the environment provides feedback on its correctness. The player wins if they guess the word in 6 guesses or fewer. The game challenges the model to generate words that are likely to be correct and to learn from the feedback provided by the environment.
>
> For example, if the Wordle environment returns the following feedback:
>
> ```
> G U E S S
> X G Y X X
> ```
>
> The model has guessed the word "GUESS" and the environment has provided feedback using the letters X, G, and Y, which correspond to the gray, green, and yellow tiles in the original game. From this feedback, the model should learn that the guess "GUESS" is incorrect: the letter "E" is in the word but in the wrong position, and the letter "U" is correct and in the correct position.
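To make the feedback format concrete, here is a small, hypothetical helper (not part of the training script) that turns one row of feedback into the green/yellow densities used by the reward functions described below:

```python
def feedback_densities(feedback_row: str) -> tuple[float, float]:
    """Turn one row of Wordle feedback, e.g. "X G Y X X", into
    (green_density, yellow_density) scores between 0 and 1."""
    marks = feedback_row.split()
    return marks.count("G") / len(marks), marks.count("Y") / len(marks)

print(feedback_densities("X G Y X X"))  # (0.2, 0.2)
```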
In the TextArena environment, reward is only given when the model wins the game: 1.0 for a win and 0.0 otherwise. This sparse signal is not very efficient for learning, so we have added a number of custom reward functions to the script to help the model learn to play the game. The extensible nature of `reward_funcs` and `rollout_func` allows you to add any custom reward function you want.
### Rollout Function
The rollout function runs one full Wordle episode, prompting the model for a guess each turn and capturing both environment rewards and auxiliary signals such as letter coverage and repetition penalties.
```python
def rollout_once(
    env: TextArenaEnv,
    tokenizer: AutoTokenizer,
    args: GRPOConfig,
    dataset_prompt: str,
    cli_args: argparse.Namespace,
    system_prompt: str,
) -> dict[str, list]:
    result = env.reset()
    observation = result.observation

    prompt_ids: list[int] = []
    completion_ids: list[int] = []
    logprobs: list[float] = []
    raw_rewards: list[float] = []
    green_scores: list[float] = []
    yellow_scores: list[float] = []
    repetition_scores: list[float] = []
    correct_scores: list[float] = []
    guess_counts: dict[str, int] = {}

    for _turn in range(cli_args.max_turns):
        # when the game is over the environment will return done=True
        if result.done:
            break

        # set up the prompt for the model
        base_prompt = observation.prompt or dataset_prompt

        # ... query the model for a guess, step the environment, and append the
        # per-turn rewards (greens, yellows, repetition, correctness) ...

    correct_reward_value = correct_scores[-1] if correct_scores else (raw_rewards[-1] if raw_rewards else 0.0)

    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
        "raw_rewards": raw_rewards,
        "correct_reward": correct_reward_value,
        "green_reward": green_scores[-1] if green_scores else 0.0,
        "yellow_reward": yellow_scores[-1] if yellow_scores else 0.0,
        "repetition_reward": repetition_scores[-1] if repetition_scores else 0.0,
    }
```
The environment's own reward signal is based only on completing the game. We found that most models struggle to ever win, so we have added custom reward functions to the script that let the model improve more incrementally: at first it learns to cover new letters and avoid repeating guesses, and as it improves it learns to win the game.
### Reward Functions
We log four reward streams that encourage the model to solve the puzzle, cover new letters, and avoid repeating guesses:
- `reward_correct`: final win/loss signal from the environment.
- `reward_greens`: density of green letters in the last feedback.
- `reward_yellows`: density of yellow letters in the last feedback.
- `reward_repetition`: penalty for guessing the same token multiple times.
For example, the repetition reward simply reads the per-rollout penalty forwarded by the rollout function:

```python
def reward_repetition(completions: list, **kwargs) -> list[float]:
    rewards = kwargs.get("repetition_reward") if kwargs else None
    return [float(r) for r in rewards] if rewards is not None else [0.0] * len(completions)
```
### Training the Model
The training script wires the custom rollout and rewards into `GRPOTrainer`. The CLI exposes the configuration used during development as defaults, so you can override endpoints or hyperparameters at launch time.
```python
parser = argparse.ArgumentParser()
# ... add CLI arguments with sensible defaults ...
```
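A minimal sketch of how the pieces might be wired together is shown below; the CLI argument names, dataset construction, and hyperparameters are assumptions, and `rollout_func` stands for the batching wrapper around `rollout_once` used by the script:

```python
from trl import GRPOConfig, GRPOTrainer

cli_args = parser.parse_args()

training_args = GRPOConfig(
    output_dir=cli_args.output_dir,
    max_completion_length=8,  # a Wordle guess only needs a few tokens
    num_generations=cli_args.num_generations,
)

trainer = GRPOTrainer(
    model=cli_args.model_id,
    args=training_args,
    train_dataset=dataset,  # prompts that seed each Wordle episode
    reward_funcs=[reward_correct, reward_greens, reward_yellows, reward_repetition],
    rollout_func=rollout_func,
)
trainer.train()
```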
The resulting model improves its performance on the game, both by reducing the number of repetitions and by increasing the number of correct guesses. However, the Qwen3-1.7B model we trained is not able to consistently win the game. The following reward curve shows the letter coverage of the model's guesses and the coverage of correct Y and G letters.
We experimented with larger models like `gpt-oss-20b` and found that they were able to consistently win the game. However, this requires significantly more compute to train. Why not try it out yourself?