Skip to content

Commit 7ee5248

Browse files
authored
[Feature] return_assistant_tokens_mask for SFT (#3014)
1 parent 7576e47 commit 7ee5248

File tree

2 files changed

+118
-6
lines changed

2 files changed

+118
-6
lines changed

test/llm/test_data.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,24 @@ def test_content_base(self):
318318
The result is""",
319319
]
320320

321+
def test_history_assistant_mask(self):
322+
from transformers import AutoTokenizer
323+
324+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
325+
for test_case in self.TEST_CASES:
326+
history = History.from_text(test_case, chat_template_name="qwen")
327+
proc = history.apply_chat_template(
328+
tokenizer=tokenizer,
329+
chat_template_name="qwen",
330+
add_generation_prompt=False,
331+
return_dict=True,
332+
return_assistant_tokens_mask=True,
333+
)
334+
if "assistant" in history.role:
335+
assert proc["assistant_masks"].any()
336+
else:
337+
assert not proc["assistant_masks"].any()
338+
321339
def test_history_completion(self):
322340
"""Test the History class's handling of complete and incomplete messages."""
323341

torchrl/data/llm/chat.py

Lines changed: 100 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111

1212
import torch
1313

14+
1415
from tensordict import lazy_stack, LazyStackedTensorDict, list_to_stack, TensorClass
1516
from tensordict.utils import _maybe_correct_neg_dim
1617

1718
from torchrl._utils import logger as torchrl_logger
1819

20+
1921
_CHAT_TEMPLATES = {
2022
"chatml_format": """{% for message in messages %}
2123
{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
@@ -24,7 +26,64 @@
2426
{{- '<|im_start|>assistant\n' }}
2527
{%- endif %}
2628
""",
27-
"qwen": """'{%- if tools %}\n {{- \'<|im_start|>system\\n\' }}\n {%- if messages[0][\'role\'] == \'system\' %}\n {{- messages[0][\'content\'] }}\n {%- else %}\n {{- \'You are a helpful assistant.\' }}\n {%- endif %}\n {{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>" }}\n {%- for tool in tools %}\n {{- "\\n" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- "\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\"name\\": <function-name>, \\"arguments\\": <args-json-object>}\\n</tool_call><|im_end|>\\n" }}\n{%- else %}\n {%- if messages[0][\'role\'] == \'system\' %}\n {{- \'<|im_start|>system\\n\' + messages[0][\'content\'] + \'<|im_end|>\\n\' }}\n {%- else %}\n {{- \'<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n\' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}\n {{- \'<|im_start|>\' + message.role + \'\\n\' + message.content + \'<|im_end|>\' + \'\\n\' }}\n {%- elif message.role == "assistant" %}\n {{- \'<|im_start|>\' + message.role }}\n {%- if message.content %}\n {{- \'\\n\' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- \'\\n<tool_call>\\n{"name": "\' }}\n {{- tool_call.name }}\n {{- \'", "arguments": \' }}\n {{- tool_call.arguments | tojson }}\n {{- \'}\\n</tool_call>\' }}\n {%- endfor %}\n {{- \'<|im_end|>\\n\' }}\n {%- elif message.role == "tool" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}\n {{- \'<|im_start|>user\' }}\n {%- endif %}\n {{- \'\\n<tool_response>\\n\' }}\n {{- 
message.content }}\n {{- \'\\n</tool_response>\' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}\n {{- \'<|im_end|>\\n\' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|im_start|>assistant\\n\' }}\n{%- endif %}\n'""",
29+
"qwen": """
30+
{%- if tools %}
31+
{{- '<|im_start|>system\\n' }}
32+
{%- if messages[0]['role'] == 'system' %}
33+
{{- messages[0]['content'] }}
34+
{%- else %}
35+
{{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}
36+
{%- endif %}
37+
{{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>" }}
38+
{%- for tool in tools %}
39+
{{- "\\n" }}
40+
{{- tool | tojson }}
41+
{%- endfor %}
42+
{{- "\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n" }}
43+
{%- else %}
44+
{%- if messages[0]['role'] == 'system' %}
45+
{{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}
46+
{%- else %}
47+
{{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}
48+
{%- endif %}
49+
{%- endif %}
50+
{%- for message in messages %}
51+
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
52+
{{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}
53+
{%- elif (message.role == "assistant" and not message.tool_calls) %}
54+
{% generation %} {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }} {% endgeneration %}
55+
{%- elif message.role == "assistant" %}
56+
{% generation %}{{- '<|im_start|>' + message.role }}
57+
{%- if message.content %}
58+
{{- '\\n' + message.content }}
59+
{%- endif %}
60+
{%- for tool_call in message.tool_calls %}
61+
{%- if tool_call.function is defined %}
62+
{%- set tool_call = tool_call.function %}
63+
{%- endif %}
64+
{{- '\\n<tool_call>\\n{\\\"name\\\": \\\"' }}
65+
{{- tool_call.name }}
66+
{{- '\\\", \\\"arguments\\\": ' }}
67+
{{- tool_call.arguments | tojson }}
68+
{{- '}\\n</tool_call>' }}
69+
{%- endfor %}
70+
{{- '<|im_end|>\\n' }}{% endgeneration %}
71+
{%- elif message.role == "tool" %}
72+
{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
73+
{{- '<|im_start|>user' }}
74+
{%- endif %}
75+
{{- '\\n<tool_response>\\n' }}
76+
{{- message.content }}
77+
{{- '\\n</tool_response>' }}
78+
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
79+
{{- '<|im_end|>\\n' }}
80+
{%- endif %}
81+
{%- endif %}
82+
{%- endfor %}
83+
{%- if add_generation_prompt %}
84+
{% generation %}{{- '<|im_start|>assistant\\n' }}{% endgeneration %}
85+
{%- endif %}
86+
""",
2887
}
2988

3089

@@ -210,12 +269,14 @@ def apply_chat_template(
210269
tokenizer: transformers.AutoTokenizer | transformers.AutoProcessor, # noqa
211270
add_generation_prompt: bool = True,
212271
chat_template: str | None = None,
272+
chat_template_name: Literal["chatml_format", "qwen"] | None = None,
213273
continue_final_message: bool = False,
214-
tokenize: bool = False,
274+
tokenize: bool | None = None,
215275
padding: bool | str = False,
216276
truncation: bool | str = False,
217-
return_tensors: str | None = "pt",
218-
return_dict: bool = False,
277+
return_tensors: str | None = None,
278+
return_dict: bool | None = None,
279+
return_assistant_tokens_mask: bool = False,
219280
**kwargs,
220281
):
221282
"""Applies a chat template to the history.
@@ -224,37 +285,68 @@ def apply_chat_template(
224285
tokenizer (transformers.PreTrainedTokenizer | transformers.AutoProcessor): The tokenizer to use.
225286
add_generation_prompt (bool, optional): Whether to add a generation prompt. Defaults to `True`.
226287
chat_template (str, optional): The chat template to use. Defaults to the tokenizer's default template.
288+
chat_template_name (Literal["chatml_format", "qwen"], optional): The name of the chat template to use.
289+
Takes precedence over `tokenizer.chat_template`. Defaults to `None`.
227290
continue_final_message (bool, optional): Whether to continue the final message. Defaults to `False`.
228291
tokenize (bool, optional): Whether to tokenize the output. Defaults to `False`.
229292
padding (bool | str, optional): The padding strategy to use. Defaults to `False`.
230293
truncation (bool | str, optional): The truncation strategy to use. Defaults to `False`.
231294
return_tensors (str | None, optional): The type of tensors to return. Defaults to "pt".
232295
return_dict (bool, optional): Whether to return a dictionary. Defaults to `False`.
296+
return_assistant_tokens_mask (bool, optional): Whether to return a mask of the assistant generated tokens.
297+
For tokens generated by the assistant, the mask will contain `1`.
298+
For user and system tokens, the mask will contain `0`.
299+
This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
300+
Defaults to `False`.
301+
302+
.. note:: By default, the `"qwen"` chat template does not support this functionality. A modified version of the template
303+
can be used by setting `chat_template_name="qwen"`, which will override the default template from the tokenizer.
304+
For other tokenizers, similar edits can be made to the template and passed to the method via the `chat_template` argument.
305+
233306
**kwargs: Additional keyword arguments to pass to the tokenizer `apply_chat_template` method.
234307
235308
Returns:
236309
The formatted history.
237310
"""
238311
if chat_template is None:
239-
if tokenizer is None:
312+
if chat_template_name is not None:
313+
chat_template = _CHAT_TEMPLATES[chat_template_name]
314+
chat_template_name = None
315+
elif tokenizer is None:
240316
raise RuntimeError(
241317
"You must specify a tokenizer to use when chat_template is not specified."
242318
)
243-
chat_template = tokenizer.chat_template
319+
else:
320+
chat_template = tokenizer.chat_template
244321
if chat_template is None:
245322
chat_template = _CHAT_TEMPLATES["chatml_format"]
323+
if tokenize is None:
324+
if return_assistant_tokens_mask or return_tensors is not None:
325+
tokenize = True
326+
else:
327+
tokenize = False
328+
if tokenize:
329+
if return_tensors is None:
330+
return_tensors = "pt"
331+
if return_dict is None and return_assistant_tokens_mask:
332+
return_dict = True
333+
elif return_dict is None:
334+
return_dict = False
335+
246336
if self.ndim > 1:
247337
return [
248338
self[i].apply_chat_template(
249339
tokenizer=tokenizer,
250340
add_generation_prompt=add_generation_prompt,
251341
chat_template=chat_template,
342+
chat_template_name=chat_template_name,
252343
tokenize=tokenize,
253344
padding=padding,
254345
truncation=truncation,
255346
return_tensors=return_tensors,
256347
continue_final_message=continue_final_message,
257348
return_dict=return_dict,
349+
return_assistant_tokens_mask=return_assistant_tokens_mask,
258350
**kwargs,
259351
)
260352
for i in range(self.batch_size[0])
@@ -274,6 +366,8 @@ def apply_chat_template(
274366
return_tensors=return_tensors,
275367
continue_final_message=continue_final_message,
276368
return_dict=return_dict,
369+
return_assistant_tokens_mask=return_assistant_tokens_mask,
370+
**kwargs,
277371
)
278372

279373
@classmethod

0 commit comments

Comments (0)