Commit f861145
[Audio] People's Speech dataset and tracer tool (#1086)
## Purpose ##

* Provide a predefined audio dataset for
  * Testing traceability of audio models
  * e2e tests with audio models
  * Simpler examples (blog)

## Prerequisites ##

* #1030
* #1085

## Changes ##

* Implement `PeoplesSpeech` dataset
  * Because of the more complex nature of audio processors, this dataset needs to hardcode some processing logic specific to models
  * Assumes that most processing is similar to whisper processing, which seems to be the standard
  * Because processing changes depending on the model, mapped outputs cannot be cached
* Add `load_from_cache_file` argument to preprocessing mapping (this was overlooked before); a brief sketch of this argument follows the commit metadata below
* Integrate dataset with tracing debugger tool

## Testing ##

```bash
llmcompressor.trace \
    --model_id openai/whisper-large-v2 \
    --model_class TraceableWhisperForConditionalGeneration \
    --modality audio
```

The traceable definition of qwen2_audio is not finished yet, but this model loads and is accepted as valid input:

```bash
llmcompressor.trace \
    --model_id Qwen/Qwen2-Audio-7B \
    --model_class Qwen2AudioForConditionalGeneration \
    --modality audio
```

---------

Signed-off-by: Kyle Sayers <[email protected]>
1 parent c1fe865 commit f861145
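A minimal sketch of the `load_from_cache_file` argument mentioned in the changes: Hugging Face `datasets.Dataset.map` normally reuses cached map results keyed by a fingerprint of the map function, so model-specific preprocessing like this dataset's must opt out. The map function below is illustrative only, not from the commit.

```python
# Minimal sketch (not part of the commit): disabling the Hugging Face
# datasets map cache, as the new `load_from_cache_file` argument allows.
from datasets import load_dataset

ds = load_dataset("MLCommons/peoples_speech", "test", split="test[:1]")
ds = ds.map(
    lambda example: {"num_samples": len(example["audio"]["array"])},
    load_from_cache_file=False,  # model-specific mapping: never reuse the cache
)
```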

File tree

3 files changed: +95 -0 lines changed

src/llmcompressor/transformers/finetune/data/__init__.py

+1
```diff
@@ -8,6 +8,7 @@
 from .flickr_30k import Flickr30K
 from .gsm8k import GSM8KDataset
 from .open_platypus import OpenPlatypusDataset
+from .peoples_speech import PeoplesSpeech
 from .ptb import PtbDataset
 from .ultrachat_200k import UltraChatDataset
 from .wikitext import WikiTextDataset
```
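Registering and exporting the class makes the dataset addressable by name. A hedged sketch, assuming the `peoples_speech` name resolves through the same entrypoints as the existing text datasets; the model id, splits, and sample count are illustrative assumptions, not from the commit:

```python
# Hedged sketch: the registered name should be usable wherever existing
# dataset names like "open_platypus" are accepted. All values below are
# illustrative assumptions.
from llmcompressor.transformers import oneshot

oneshot(
    model="openai/whisper-large-v2",
    dataset="peoples_speech",
    splits={"calibration": "test[:1%]"},
    num_calibration_samples=16,
)
```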
src/llmcompressor/transformers/finetune/data/peoples_speech.py

+90

```diff
@@ -0,0 +1,90 @@
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Dict
+
+from datasets.formatting.formatting import LazyRow
+from loguru import logger
+
+from llmcompressor.transformers.finetune.data import TextGenerationDataset
+from llmcompressor.transformers.finetune.data.base import get_columns
+from llmcompressor.typing import DatasetType, Processor
+
+if TYPE_CHECKING:
+    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+
+
+@TextGenerationDataset.register(name="peoples_speech")
+class PeoplesSpeech(TextGenerationDataset):
+    """
+    ML Commons People's Speech audio dataset
+
+    Unfortunately, due to the specialized nature of audio model preprocessing, some
+    model-specific code must be defined here. This dataset has been tested with the
+    WhisperForConditionalGeneration and Qwen2AudioForConditionalGeneration model classes
+
+    :param data_args: configuration settings for dataset loading
+    :param split: split from dataset to load, for instance `test` or `train[:5%]`
+    :param processor: processor or tokenizer to use on dataset
+    """
+
+    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+        data_args = deepcopy(data_args)
+        data_args.dataset = "MLCommons/peoples_speech"
+        data_args.dataset_config_name = "test"
+        if not data_args.overwrite_cache:
+            logger.warning(
+                "Because audio processors are more complex, dataset mapping functions "
+                "vary with model architecture and their results cannot be cached. "
+                "Setting overwrite_cache=True"
+            )
+            data_args.overwrite_cache = True
+        self.processor_type = processor.__class__.__name__
+
+        super().__init__(data_args=data_args, split=split, processor=processor)
+
+    def dataset_template(self, example):
+        audio = example["audio"]["array"]
+        sampling_rate = example["audio"]["sampling_rate"]
+
+        if self.processor_type == "Qwen2AudioProcessor":
+            messages = [
+                {"role": "user", "content": [{"audio": None}]},
+                {"role": "user", "content": [{"text": "What did the person say?"}]},
+            ]
+            text = self.processor.apply_chat_template(messages)
+            return {"audios": [audio], "sampling_rate": sampling_rate, "text": text}
+
+        else:
+            # chat template decoder ids are appended later by self.processor.__call__
+            text = " " + example["text"].capitalize()
+            return {"audio": audio, "sampling_rate": sampling_rate, "text": text}
+
+    def filter_tokenizer_args(self, dataset: DatasetType) -> DatasetType:
+        if self.processor_type == "WhisperProcessor":
+            tokenizer_args = ["audio", "sampling_rate", "text"]
+            column_names = get_columns(dataset)
+
+            return dataset.remove_columns(list(set(column_names) - set(tokenizer_args)))
+
+        else:
+            return super().filter_tokenizer_args(dataset)
+
+    def tokenize(self, data: LazyRow) -> Dict[str, Any]:
+        if self.processor_type == "WhisperProcessor":
+            inputs = self.processor(
+                audio=data["audio"],
+                sampling_rate=data["sampling_rate"],
+                text=data["text"],
+                add_special_tokens=True,
+                return_tensors="pt",
+            )
+
+            # TODO: inputs["input_features"] is a float dtype, which may conflict with
+            # the dtype of the model. Add logic to the data pipeline to move inputs to
+            # the matching model device and dtype
+            inputs["decoder_input_ids"] = inputs["labels"]
+            del inputs["labels"]
+
+            return inputs
+
+        else:
+            return super().tokenize(data)
```
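The `tokenize` override mirrors standard Whisper preprocessing, then repoints the tokenized transcript from `labels` to `decoder_input_ids` so the traced model receives concrete decoder inputs rather than a loss target. A standalone sketch of the same path using the `transformers` `WhisperProcessor`; the model id and synthetic audio are illustrative assumptions:

```python
# Standalone sketch of the Whisper path above; synthetic audio for brevity.
import numpy as np
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")

audio = np.zeros(16000, dtype=np.float32)  # 1 second of silence at 16 kHz
inputs = processor(
    audio=audio,
    sampling_rate=16000,
    text=" Hello world",  # leading space, matching dataset_template
    add_special_tokens=True,
    return_tensors="pt",
)

# The processor returns the tokenized transcript as `labels`; for tracing we
# feed those token ids to the decoder directly instead
inputs["decoder_input_ids"] = inputs["labels"]
del inputs["labels"]

print(inputs["input_features"].shape)  # torch.Size([1, 80, 3000]) log-mel features
```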

src/llmcompressor/transformers/tracing/debug.py

+4
```diff
@@ -117,6 +117,10 @@ def get_dataset_kwargs(modality: str) -> Dict[str, str]:
             "dataset": "flickr",
             "splits": {"calibration": "test[:1]"},
         },
+        "audio": {
+            "dataset": "peoples_speech",
+            "splits": {"calibration": "test[:1]"},
+        },
     }
 
     if modality not in dataset_kwargs:
```
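With this entry in place, the tracing debugger's `--modality audio` flag resolves to the new dataset:

```python
# Resolution follows directly from the mapping above
get_dataset_kwargs("audio")
# -> {"dataset": "peoples_speech", "splits": {"calibration": "test[:1]"}}
```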
