-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathbasic.py
60 lines (46 loc) · 1.84 KB
/
basic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#
# Copyright (C) 2025 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0
#
import datasets
from transformers import BatchEncoding
from llmart import TaggedTokenizer, Transform
class BasicConfig(datasets.BuilderConfig):
    """Builder config that carries a tagged tokenizer plus the transforms
    used to mark the prompt and completion spans of a conversation.

    All three collaborators are required keyword arguments; remaining
    positional/keyword arguments are forwarded to ``datasets.BuilderConfig``.
    """

    def __init__(
        self,
        *args,
        tokenizer: TaggedTokenizer,
        mark_prompt: Transform,
        mark_completion: Transform,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Keep references for the builder's _generate_examples to consume.
        self.mark_completion = mark_completion
        self.mark_prompt = mark_prompt
        self.tokenizer = tokenizer
class BasicBuilder(datasets.GeneratorBasedBuilder):
    """Dataset builder that yields a single hard-coded, tokenized
    conversation as the ``train`` split.

    The prompt and completion spans are tagged via the transforms carried
    by :class:`BasicConfig`, and labels are masked so only the assistant
    response is supervised.
    """

    BUILDER_CONFIG_CLASS = BasicConfig

    def _info(self):
        """Return a bare dataset-info record (no schema declared)."""
        return datasets.DatasetInfo()

    def _split_generators(self, dl_manager):
        """Expose a single 'train' split; nothing needs downloading."""
        del dl_manager  # unused — no external data to fetch
        return [datasets.SplitGenerator(name="train")]

    def _generate_examples(self, **kwargs):
        """Yield one (key, example) pair of tokenized tensors and masks."""
        cfg = self.config
        prompt_marker: Transform = cfg.mark_prompt  # type: ignore
        completion_marker: Transform = cfg.mark_completion  # type: ignore

        # Single-turn conversation; the markers tag the spans we care about.
        messages = [
            {
                "role": "user",
                "content": prompt_marker("Tell me about the planet Saturn."),
            },
            {
                "role": "assistant",
                "content": completion_marker("NO WAY JOSE"),
            },
        ]

        # Tokenize the chat; the tagged tokenizer also returns span masks.
        encoded: BatchEncoding = cfg.tokenizer.apply_chat_template(  # type: ignore
            messages, return_tensors="pt", return_dict=True
        )

        # Supervise only the marked response: every other position gets the
        # ignore index (-100) in the labels.
        response_mask = encoded["response_mask"]  # type: ignore
        labels = encoded["input_ids"].clone()  # type: ignore
        labels[~response_mask] = -100  # type: ignore
        encoded["labels"] = labels

        # apply_chat_template adds a batch axis; strip it per field.
        yield 0, {key: value[0] for key, value in encoded.items()}