|
| 1 | +from copy import deepcopy |
| 2 | +from typing import TYPE_CHECKING, Any, Dict |
| 3 | + |
| 4 | +from datasets.formatting.formatting import LazyRow |
| 5 | +from loguru import logger |
| 6 | + |
| 7 | +from llmcompressor.transformers.finetune.data import TextGenerationDataset |
| 8 | +from llmcompressor.transformers.finetune.data.base import get_columns |
| 9 | +from llmcompressor.typing import DatasetType, Processor |
| 10 | + |
| 11 | +if TYPE_CHECKING: |
| 12 | + from llmcompressor.transformers import DataTrainingArguments as DataArgs |
| 13 | + |
| 14 | + |
| 15 | +@TextGenerationDataset.register(name="peoples_speech") |
| 16 | +class PeoplesSpeech(TextGenerationDataset): |
| 17 | + """ |
| 18 | + ML Commons People's Speech audio dataset |
| 19 | + |
| 20 | + Unfortunately, due to the specialized nature of audio model preprocessing, some |
| 21 | +    model-specific code must be defined here. This dataset has been tested with the |
| 22 | +    WhisperForConditionalGeneration and Qwen2AudioForConditionalGeneration model classes. |
| 23 | + |
| 24 | + :param data_args: configuration settings for dataset loading |
| 25 | + :param split: split from dataset to load, for instance `test` or `train[:5%]` |
| 26 | + :param processor: processor or tokenizer to use on dataset |
| 27 | + """ |
| 28 | + |
| 29 | + def __init__(self, data_args: "DataArgs", split: str, processor: Processor): |
| 30 | + data_args = deepcopy(data_args) |
| 31 | + data_args.dataset = "MLCommons/peoples_speech" |
| 32 | + data_args.dataset_config_name = "test" |
| 33 | + if not data_args.overwrite_cache: |
| 34 | + logger.warning( |
| 35 | + "Because audio processors are more complex, dataset mapping functions " |
| 36 | + "vary with model architecture and their results cannot be cached. " |
| 37 | + "Setting overwrite_cache=True" |
| 38 | + ) |
| 39 | + data_args.overwrite_cache = True |
| 40 | + self.processor_type = processor.__class__.__name__ |
| 41 | + |
| 42 | + super().__init__(data_args=data_args, split=split, processor=processor) |
| 43 | + |
| 44 | + def dataset_template(self, example): |
| 45 | + audio = example["audio"]["array"] |
| 46 | + sampling_rate = example["audio"]["sampling_rate"] |
| 47 | + |
| 48 | + if self.processor_type == "Qwen2AudioProcessor": |
| 49 | + messages = [ |
| 50 | + {"role": "user", "content": [{"audio": None}]}, |
| 51 | + {"role": "user", "content": [{"text": "What did the person say?"}]}, |
| 52 | + ] |
| 53 | + text = self.processor.apply_chat_template(messages) |
| 54 | + return {"audios": [audio], "sampling_rate": sampling_rate, "text": text} |
| 55 | + |
| 56 | + else: |
| 57 | + # chat template decoder ids are appended later by self.processor.__call__ |
| 58 | + text = " " + example["text"].capitalize() |
| 59 | + return {"audio": audio, "sampling_rate": sampling_rate, "text": text} |
| 60 | + |
| 61 | + def filter_tokenizer_args(self, dataset: DatasetType) -> DatasetType: |
| 62 | + if self.processor_type == "WhisperProcessor": |
| 63 | + tokenizer_args = ["audio", "sampling_rate", "text"] |
| 64 | + column_names = get_columns(dataset) |
| 65 | + |
| 66 | + return dataset.remove_columns(list(set(column_names) - set(tokenizer_args))) |
| 67 | + |
| 68 | + else: |
| 69 | + return super().filter_tokenizer_args(dataset) |
| 70 | + |
| 71 | + def tokenize(self, data: LazyRow) -> Dict[str, Any]: |
| 72 | + if self.processor_type == "WhisperProcessor": |
| 73 | + inputs = self.processor( |
| 74 | + audio=data["audio"], |
| 75 | + sampling_rate=data["sampling_rate"], |
| 76 | + text=data["text"], |
| 77 | + add_special_tokens=True, |
| 78 | + return_tensors="pt", |
| 79 | + ) |
| 80 | + |
| 81 | + # TODO: inputs["input_features"] is a float dtype, which may conflict with |
| 82 | +            # the dtype of the model. Add logic to the data pipeline to move inputs to |
| 83 | + # the matching model device and dtype |
| 84 | + inputs["decoder_input_ids"] = inputs["labels"] |
| 85 | + del inputs["labels"] |
| 86 | + |
| 87 | + return inputs |
| 88 | + |
| 89 | + else: |
| 90 | + return super().tokenize(data) |
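
For readers wiring this up: registering the class under `peoples_speech` should make it selectable by name through the dataset registry. A rough usage sketch follows; the model id, recipe path, and the exact `oneshot` keyword names are illustrative assumptions, not part of this commit.

```python
from transformers import WhisperForConditionalGeneration

from llmcompressor.transformers import oneshot

# Illustrative model choice; any Whisper checkpoint exercises the
# WhisperProcessor branch of the dataset defined above.
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")

# "peoples_speech" is the registry name declared in this commit; the recipe
# path and calibration arguments are placeholders.
oneshot(
    model=model,
    dataset="peoples_speech",
    recipe="recipe.yaml",
    num_calibration_samples=64,
    splits={"calibration": "test[:64]"},
)
```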
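On the TODO in `tokenize`: the processor returns `input_features` as float32, which can clash with a model loaded in float16/bfloat16. A minimal sketch of the per-batch cast a data pipeline could apply, assuming plain PyTorch tensors (the helper name below is hypothetical):

```python
import torch


def align_batch_to_model(batch: dict, model: torch.nn.Module) -> dict:
    """Move tensors to the model's device; cast floating-point tensors
    (e.g. input_features) to the model's dtype, leaving integer tensors
    such as decoder_input_ids untouched."""
    param = next(model.parameters())
    aligned = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            if value.is_floating_point():
                value = value.to(device=param.device, dtype=param.dtype)
            else:
                value = value.to(device=param.device)
        aligned[key] = value
    return aligned
```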