feat: Enable streaming in data preprocessor #437

Open
willmj wants to merge 35 commits into main from datapreprocessor-streaming
Changes from 11 commits
Commits (35)
c394509
feat: first unit test to test if streaming works, example template fo…
willmj Jan 8, 2025
7895082
test: first draft of tests
willmj Jan 9, 2025
3956125
feat: enable streaming explicitly passing argument
willmj Jan 10, 2025
dc851bb
fix: add back if data return false
willmj Jan 14, 2025
1b54137
feat: (currently broken) move streaming to data preprocessor config i…
willmj Jan 16, 2025
4bedbe6
feat: add streaming param in data preprocessor config
willmj Jan 17, 2025
6e04cf0
fix: split batches
willmj Jan 28, 2025
88ebe95
fix: pass split batches correctly to train args
willmj Jan 28, 2025
390e273
fmt + lint
willmj Jan 28, 2025
906a633
Merge branch 'main' into datapreprocessor-streaming
willmj Jan 28, 2025
5980c31
fix: column names conditional
willmj Jan 28, 2025
50316a4
fix: logging
willmj Jan 29, 2025
52de5a1
fmt
willmj Jan 29, 2025
c3d9679
fix: validate mergeable datasets for IterableDatasets
willmj Jan 29, 2025
10cf8db
tests: more streaming tests
willmj Jan 29, 2025
bf84480
fix: use resolve_features function for iterable datasets, add packing…
willmj Jan 30, 2025
b4154e2
docs: docstring include streaming
willmj Jan 30, 2025
25d0a98
fix: indentation
willmj Jan 30, 2025
ccccc9f
fix: resolve iterable dataset features save
willmj Jan 31, 2025
77cc16e
docs: more specific logger info
willmj Feb 5, 2025
7f2f9c4
merge: branch 'main' into datapreprocessor-streaming
willmj Feb 5, 2025
ceea99e
fix: remove streaming variable from _process_dataset_configs
willmj Feb 6, 2025
b7f1016
fmt
willmj Feb 6, 2025
65ba578
fix: add check for iterabledatasetdict, nits
willmj Feb 6, 2025
c61921f
fix: make streaming part of load kwargs
willmj Feb 6, 2025
e854994
fix: check by dataset instead of flag
willmj Feb 6, 2025
b76621b
fix: move split_batches update to process_dataargs
willmj Feb 6, 2025
786e1fc
docs: streaming documentation
willmj Feb 6, 2025
2cbe114
fix: set streaming to False when loading 1 example for column names
willmj Feb 7, 2025
d356ad3
docs: add docstrings to unit tests and add missing docstrings from pr…
willmj Feb 10, 2025
d2019ea
fix: remove packing pretokenized case as it will be fixed once trl PR…
willmj Feb 11, 2025
1a3a3dd
merge: branch 'main' into datapreprocessor-streaming
willmj Feb 11, 2025
8bbb0f4
merge: main into datapreprocessor-streaming
willmj Feb 20, 2025
f6b031f
test: remove test-case
willmj Feb 20, 2025
2805fb9
test: make test compatible with streamin PR
willmj Feb 20, 2025
3 changes: 3 additions & 0 deletions tests/artifacts/predefined_data_configs/__init__.py
@@ -31,3 +31,6 @@
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML = os.path.join(
    PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_sampling.yaml"
)
DATA_CONFIG_YAML_STREAMING = os.path.join(
    PREDEFINED_DATA_CONFIGS, "tokenize_and_apply_input_masking_streaming.yaml"
)
@@ -0,0 +1,15 @@
dataprocessor:
  type: default
  streaming: true
datasets:
  - name: apply_custom_data_template
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: apply_custom_data_formatting_template
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            dataset_text_field: "dataset_text_field"
            template: "dataset_template"
@@ -0,0 +1,15 @@
dataprocessor:
  type: default
  streaming: true
datasets:
  - name: text_dataset_input_output_masking
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_input_masking
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            input_field_name: input
            output_field_name: output
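
Setting streaming: true in the dataprocessor block switches the underlying Hugging Face datasets load into streaming mode, so the preprocessor returns a lazily consumed IterableDataset rather than a fully materialized Dataset; this is also why the new tests below require max_steps to be set. A rough standalone sketch of that loading behaviour, using the datasets library directly with a hypothetical local "twitter_complaints.jsonl" file (illustrative only, not code from this PR):

# Illustrative sketch only, not part of this PR's diff. "twitter_complaints.jsonl"
# is a hypothetical local file standing in for FILE_PATH in the configs above.
from datasets import Dataset, IterableDataset, load_dataset

eager = load_dataset("json", data_files="twitter_complaints.jsonl", split="train")
lazy = load_dataset(
    "json", data_files="twitter_complaints.jsonl", split="train", streaming=True
)

assert isinstance(eager, Dataset)  # fully materialized; has len() and column_names
assert isinstance(lazy, IterableDataset)  # no len(), so a trainer needs max_steps
print(next(iter(lazy)))  # examples are read lazily, one at a time
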
29 changes: 29 additions & 0 deletions tests/data/test_data_handlers.py
@@ -57,6 +57,35 @@ def test_apply_custom_formatting_template():
    assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response


def test_apply_custom_formatting_template_iterable():
    """Tests that apply_custom_data_formatting_template works with a streaming (iterable) dataset."""
    json_dataset = datasets.load_dataset(
        "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL, streaming=True
    )
    template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    formatted_dataset_field = "formatted_data_field"
    formatted_dataset = json_dataset.map(
        apply_custom_data_formatting_template,
        fn_kwargs={
            "tokenizer": tokenizer,
            "dataset_text_field": formatted_dataset_field,
            "template": template,
        },
    )
    # First response from the data file that is read.
    expected_response = (
        "### Input: @HMRCcustomers No this is my first job"
        + " \n\n ### Response: no complaint"
        + tokenizer.eos_token
    )

    first_sample = next(iter(formatted_dataset["train"]))

    # A new dataset_text_field is created in the dataset
    assert formatted_dataset_field in first_sample
    assert first_sample[formatted_dataset_field] == expected_response


def test_apply_custom_formatting_template_gives_error_with_wrong_keys():
    """Tests that the formatting function will throw error if wrong keys are passed to template"""
    json_dataset = datasets.load_dataset(
197 changes: 181 additions & 16 deletions tests/data/test_data_preprocessing_utils.py
@@ -19,7 +19,7 @@
import tempfile

# Third Party
from datasets import Dataset
from datasets import Dataset, IterableDataset
from transformers import AutoTokenizer, DataCollatorForSeq2Seq
from trl import DataCollatorForCompletionOnlyLM
import datasets
@@ -33,6 +33,7 @@
    DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
    DATA_CONFIG_PRETOKENIZE_JSON_DATA_YAML,
    DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
    DATA_CONFIG_YAML_STREAMING,
)
from tests.artifacts.testdata import (
    MODEL_NAME,
@@ -140,7 +141,10 @@ def test_load_dataset_with_datafile(datafile, column_names):
        processor_config=DataPreProcessorConfig(), tokenizer=None
    )
    load_dataset = processor.load_dataset(
        datasetconfig=None, splitName="train", datafile=datafile
        datasetconfig=None,
        streaming=processor.processor_config.streaming,
        splitName="train",
        datafile=datafile,
    )
    assert set(load_dataset.column_names) == column_names

@@ -155,7 +159,10 @@ def test_load_dataset_with_hf_dataset(hf_dataset, splitName):
        processor_config=DataPreProcessorConfig(), tokenizer=None
    )
    load_dataset = processor.load_dataset(
        datasetconfig=datasetconfig, splitName=splitName, datafile=None
        datasetconfig=datasetconfig,
        streaming=processor.processor_config.streaming,
        splitName=splitName,
        datafile=None,
    )
    assert isinstance(load_dataset, Dataset)

@@ -250,7 +257,10 @@ def test_load_dataset_with_datasetconfig(
        processor_config=DataPreProcessorConfig(), tokenizer=None
    )
    load_dataset = processor.load_dataset(
        datasetconfig=datasetconfig, splitName="train", datafile=None
        datasetconfig=datasetconfig,
        streaming=processor.processor_config.streaming,
        splitName="train",
        datafile=None,
    )
    assert set(load_dataset.column_names) == column_names

@@ -280,7 +290,10 @@ def test_load_dataset_with_non_exist_path(data_paths, datasetconfigname):
    )
    with pytest.raises((datasets.exceptions.DatasetNotFoundError, ValueError)):
        processor.load_dataset(
            datasetconfig=datasetconfig, splitName="train", datafile=None
            datasetconfig=datasetconfig,
            streaming=processor.processor_config.streaming,
            splitName="train",
            datafile=None,
        )


@@ -302,7 +315,10 @@ def test_load_dataset_with_datasetconfig_incorrect_builder(
    )
    with pytest.raises(pyarrow.lib.ArrowInvalid):
        processor.load_dataset(
            datasetconfig=datasetconfig, splitName="train", datafile=None
            datasetconfig=datasetconfig,
            streaming=processor.processor_config.streaming,
            splitName="train",
            datafile=None,
        )


@@ -331,7 +347,10 @@ def test_load_dataset_with_dataconfig_and_datafile(datafile, datasetconfigname):
    )
    with pytest.raises(ValueError):
        processor.load_dataset(
            datasetconfig=datasetconfig, splitName="train", datafile=datafile
            datasetconfig=datasetconfig,
            streaming=processor.processor_config.streaming,
            splitName="train",
            datafile=datafile,
        )


@@ -361,7 +380,10 @@ def test_load_dataset_with_dataconfig_and_datafolder(datasetconfig, column_names
        processor_config=DataPreProcessorConfig(), tokenizer=None
    )
    load_dataset = processor.load_dataset(
        datasetconfig=datasetconfig, splitName="train", datafile=None
        datasetconfig=datasetconfig,
        streaming=processor.processor_config.streaming,
        splitName="train",
        datafile=None,
    )
    assert set(load_dataset.column_names) == column_names

@@ -383,7 +405,10 @@ def test_load_dataset_with_dataconfig_and_datafolder_incorrect_builder(datasetco
    )
    with pytest.raises(pyarrow.lib.ArrowInvalid):
        processor.load_dataset(
            datasetconfig=datasetconfig, splitName="train", datafile=None
            datasetconfig=datasetconfig,
            streaming=processor.processor_config.streaming,
            splitName="train",
            datafile=None,
        )


@@ -393,7 +418,12 @@ def test_load_dataset_without_dataconfig_and_datafile():
        processor_config=DataPreProcessorConfig(), tokenizer=None
    )
    with pytest.raises(ValueError):
        processor.load_dataset(datasetconfig=None, splitName="train", datafile=None)
        processor.load_dataset(
            datasetconfig=None,
            streaming=processor.processor_config.streaming,
            splitName="train",
            datafile=None,
        )


@pytest.mark.parametrize(
@@ -430,7 +460,10 @@ def test_load_dataset_with_datasetconfig_files_folders(
        processor_config=DataPreProcessorConfig(), tokenizer=None
    )
    load_dataset = processor.load_dataset(
        datasetconfig=datasetconfig, splitName="train", datafile=None
        datasetconfig=datasetconfig,
        streaming=processor.processor_config.streaming,
        splitName="train",
        datafile=None,
    )
    assert set(load_dataset.column_names) == column_names

@@ -460,7 +493,10 @@ def test_load_dataset_with_datasetconfig_files_folders_incorrect_builder(
    )
    with pytest.raises(ValueError):
        processor.load_dataset(
            datasetconfig=datasetconfig, splitName="train", datafile=None
            datasetconfig=datasetconfig,
            streaming=processor.processor_config.streaming,
            splitName="train",
            datafile=None,
        )


@@ -686,6 +722,117 @@ def test_process_data_args_throws_error_where_needed(data_args, packing):
        (_, _, _, _, _, _) = process_dataargs(data_args, tokenizer, TRAIN_ARGS)


@pytest.mark.parametrize(
    "data_config_path, data_path",
    [
        (
            DATA_CONFIG_YAML_STREAMING,
            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
        ),
    ],
)
def test_process_streaming_dataconfig_file(data_config_path, data_path):
    """Ensure that datasets are formatted and validated correctly based on the arguments passed in config file."""
    with open(data_config_path, "r") as f:
        yaml_content = yaml.safe_load(f)
    yaml_content["datasets"][0]["data_paths"][0] = data_path
    datasets_name = yaml_content["datasets"][0]["name"]

    # Modify input_field_name and output_field_name according to dataset
    if datasets_name == "text_dataset_input_output_masking":
        yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
            "input_field_name": "input",
            "output_field_name": "output",
        }

    # Modify dataset_text_field and template according to dataset
    formatted_dataset_field = "formatted_data_field"
    if datasets_name == "apply_custom_data_template":
        template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
        yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
            "dataset_text_field": formatted_dataset_field,
            "template": template,
        }

    with tempfile.NamedTemporaryFile(
        "w", delete=False, suffix=".yaml"
    ) as temp_yaml_file:
        yaml.dump(yaml_content, temp_yaml_file)
        temp_yaml_file_path = temp_yaml_file.name
        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    TRAIN_ARGS = configs.TrainingArguments(
        max_steps=1,
        output_dir="tmp",  # Not needed but positional
    )

    (train_set, _, _) = _process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
    assert isinstance(train_set, IterableDataset)

    # Grab the keys since IterableDataset has no column names
    first_example = next(iter(train_set))
    set_column_names = list(first_example.keys())

    if datasets_name == "text_dataset_input_output_masking":
        column_names = set(["input_ids", "attention_mask", "labels"])
        assert set(set_column_names) == column_names
    elif datasets_name == "pretokenized_dataset":
        assert set(["input_ids", "labels"]).issubset(set(set_column_names))
    elif datasets_name == "apply_custom_data_template":
        assert formatted_dataset_field in set(set_column_names)


@pytest.mark.parametrize(
    "data_config_path, data_path",
    [
        (
            DATA_CONFIG_YAML_STREAMING,
            TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
        ),
    ],
)
def test_process_streaming_dataconfig_file_no_max_steps(data_config_path, data_path):
    """Ensure that if max steps aren't passed with streaming, error is raised"""
    with open(data_config_path, "r") as f:
        yaml_content = yaml.safe_load(f)
    yaml_content["datasets"][0]["data_paths"][0] = data_path
    datasets_name = yaml_content["datasets"][0]["name"]

    # Modify input_field_name and output_field_name according to dataset
    if datasets_name == "text_dataset_input_output_masking":
        yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
            "input_field_name": "input",
            "output_field_name": "output",
        }

    # Modify dataset_text_field and template according to dataset
    formatted_dataset_field = "formatted_data_field"
    if datasets_name == "apply_custom_data_template":
        template = "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}"
        yaml_content["datasets"][0]["data_handlers"][0]["arguments"]["fn_kwargs"] = {
            "dataset_text_field": formatted_dataset_field,
            "template": template,
        }

    with tempfile.NamedTemporaryFile(
        "w", delete=False, suffix=".yaml"
    ) as temp_yaml_file:
        yaml.dump(yaml_content, temp_yaml_file)
        temp_yaml_file_path = temp_yaml_file.name
        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    TRAIN_ARGS = configs.TrainingArguments(
        output_dir="tmp",  # Not needed but positional
    )

    with pytest.raises(ValueError):
        (train_set, _, _) = _process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)


@pytest.mark.parametrize(
    "data_config_path, data_path",
    [
@@ -746,7 +893,12 @@ def test_process_dataconfig_file(data_config_path, data_path):
        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)

    TRAIN_ARGS = configs.TrainingArguments(
        output_dir="tmp",  # Not needed but positional
    )

    (train_set, _, _) = _process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
    assert isinstance(train_set, Dataset)
    if datasets_name == "text_dataset_input_output_masking":
        column_names = set(["input_ids", "attention_mask", "labels"])
@@ -873,7 +1025,12 @@ def test_process_dataconfig_multiple_files(data_config_path, data_path_list):
        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)

    TRAIN_ARGS = configs.TrainingArguments(
        output_dir="tmp",  # Not needed but positional
    )

    (train_set, _, _) = _process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
    assert isinstance(train_set, Dataset)
    if datasets_name == "text_dataset_input_output_masking":
        column_names = set(["input_ids", "attention_mask", "labels"])
@@ -937,7 +1094,12 @@ def test_process_dataconfig_multiple_files_folders_with_globbing(
        data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)

    TRAIN_ARGS = configs.TrainingArguments(
        output_dir="tmp",  # Not needed but positional
    )

    (train_set, _, _) = _process_dataconfig_file(data_args, TRAIN_ARGS, tokenizer)
    assert isinstance(train_set, Dataset)
    assert set(["input_ids", "attention_mask", "labels"]).issubset(
        set(train_set.column_names)
@@ -996,7 +1158,10 @@ def test_process_dataconfig_multiple_files_folders_without_builder(
        (datasets.exceptions.DatasetNotFoundError, ValueError, pyarrow.lib.ArrowInvalid)
    ):
        processor.load_dataset(
            datasetconfig=datasetconfig, splitName="train", datafile=None
            datasetconfig=datasetconfig,
            streaming=processor.processor_config.streaming,
            splitName="train",
            datafile=None,
        )

