feat: Enable streaming in data preprocessor #437

Open
wants to merge 35 commits into main
Conversation

@willmj (Collaborator) commented Jan 14, 2025

Description of the change

These changes enable streaming in the data preprocessor and add tests for streaming datasets.
Added:

  • Add streaming as an arg in DataSetConfig, similar to sampling (see the config sketch after this list)
  • Add examples of DataSetConfig in tests/artifacts/predefined_data_configs/ for streaming
  • Add unit tests
  • Since IterableDatasets can't be indexed, use the first example where column names are needed
  • User must set max_steps instead of num_train_epochs when streaming, since an IterableDataset has no known length from which to compute epochs
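
A minimal sketch of a streaming data config (modeled on the dataprocessor schema shown later in this thread; the dataset name and path are hypothetical):

    dataprocessor:
        type: default
        streaming: true
    datasets:
      - name: my_dataset            # hypothetical name
        data_paths:
          - "/path/to/data/"        # hypothetical path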

Related issue number

How to verify the PR

  • Run new unit tests which verify that HF inference works and that passing streaming in the data config returns an IterableDataset
  • Run on single GPU without error
  • Run on multi GPU without error

Was the PR tested

  • I have added >=1 unit test(s) for every new method I have added.
  • I have ensured all unit tests pass

willmj added 3 commits January 8, 2025 15:31
…r future tests, add streaming to config

Signed-off-by: Will Johnson <[email protected]>
Signed-off-by: Will Johnson <[email protected]>

Thanks for making a pull request! 😃
One of the maintainers will review and advise on the next steps.

@github-actions bot added the feat label Jan 14, 2025
@Abhishek-TAMU (Collaborator) left a comment

Thanks @willmj for integrating usage of Iterable datasets. Just some initial thoughts.

@ashokponkumar (Collaborator) commented:

Shouldn't streaming be a top-level object instead of a per-dataset object? Is it possible to mix streaming and non-streaming datasets using concat?

@seshapad (Contributor) commented:

@willmj Is this PR in a usable state? We need to run an EPT with large datasets. Without streaming, the data processing fails. We want the streaming feature to address this issue.

@kmehant (Collaborator) commented Jan 26, 2025

@willmj I would request your attention to this:

if "column_names" not in data or data.column_names is None:
if isinstance(data, IterableDataset):
if hasattr(data, "_resolve_features"):
data = data._resolve_features()
else:
raise ValueError(
"_resolve_features API is not available to fetch column names"
)
else:
raise ValueError(
f"not possible to fetch column names for the loaded dataset of type {type(data)}"
)
IterableDatasets often lose column information (sometimes on loading, sometimes after map operations are applied), so it's good to be defensive when retrieving columns wherever necessary.
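
For reference, a minimal sketch of this defensive pattern with the HF datasets library (the file name is hypothetical; _resolve_features is a private datasets API that infers the schema from a sample):

    from datasets import load_dataset, IterableDataset

    # Streaming loads defer schema resolution, so column_names may be None.
    ds = load_dataset("json", data_files="train.jsonl", split="train", streaming=True)
    if isinstance(ds, IterableDataset) and ds.column_names is None:
        # Private API: resolves features without materializing the full dataset.
        ds = ds._resolve_features()
    print(ds.column_names)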

Signed-off-by: Will Johnson <[email protected]>
    [
        (
            [TWITTER_COMPLAINTS_DATA_DIR_JSON],
            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
@Abhishek-TAMU (Collaborator) commented Jan 28, 2025

I assume this yaml file is to be used for this test case: DATA_CONFIG_YAML_STREAMING

@willmj (Collaborator, Author) commented Jan 28, 2025

@seshapad I have now had a successful tuning job with streaming on multi GPU. You should be able to try it out; let me know if you run into any errors.

Signed-off-by: Will Johnson <[email protected]>
Signed-off-by: Will Johnson <[email protected]>
Signed-off-by: Will Johnson <[email protected]>
@willmj (Collaborator, Author) commented Jan 29, 2025

Tuning + inference works! Only 200 steps, so the equivalent of less than an epoch, which is why the result is wrong, but the format is right.
Config:

      {
          "model_name_or_path": "/llama3/hf/8b_pre_trained",
          "data_config_path": "/testing/tuning/input/apply-custom-template-streaming-data-config.yaml",
          "output_dir": "/testing/tuning/output/llama3-8b/ft/tone_20250129_1045-streaming-dataconfig",
          "save_model_dir": "/testing/tuning/output/llama3-8b/ft/tone_20250129_1045-streaming-dataconfig/save_model",
          "max_steps": 200,
          "per_device_train_batch_size": 4,
          "gradient_accumulation_steps": 1,
          "learning_rate": 1e-4,
          "response_template": "\n### Response:",
          "dataset_text_field": "output"
      }

Inference result on "Text: @sho_help @showtime your arrive is terrible streaming is stop and start every couple mins. Get it together it's xmas\n\n### Label:":

{
  "responses": [
    {
      "generatedTokenCount": 2,
      "text": " polite\u003c|end_of_text|\u003e",
      "inputTokenCount": 34,
      "stopReason": "EOS_TOKEN",
      "stopSequence": "\u003c|end_of_text|\u003e"
    }
  ]
}

@seshapad (Contributor) commented Jan 30, 2025

@willmj The streaming option crashes. I have attached the log for debugging. Here is the data config:

dataprocessor:
    type: default
    sampling_stopping_strategy: all_exhausted
    seed: 66
    streaming: true
datasets:
  - name: pleias
    sampling: 1.0
    data_paths:
      - "/pleias_greek/"
    data_handlers:
      - name: apply_dataset_formatting
        arguments:
          remove_columns: ['source_directory', 'domain', 'document', 'subset', 'split', 'document_id', 'identifier', 'collection', 'license', '_meta_timestamp', '_meta_request_url', '_meta_final_url', '_meta_dataset', '_meta_job_id', '_meta_file_name', '_meta_json']
          fn_kwargs:
            dataset_text_field: "contents"

I can share the dataset with you if you wish to attempt reproducing this bug.
Configuration of the CLI used:

accelerate launch \
  --num_processes=8 \
  --dynamo_backend="no" \
  --fsdp_auto_wrap_policy="TRANSFORMER_BASED_WRAP" \
  --fsdp_cpu_ram_efficient_loading="true" \
  --fsdp_forward_prefetch="false" \
  --fsdp_offload_params="false" \
  --fsdp_sharding_strategy="HYBRID_SHARD" \
  --fsdp_state_dict_type="FULL_STATE_DICT" \
  --fsdp_sync_module_states="true" \
  --machine_rank="${RANK}" \
  --main_process_ip="${MASTER_ADDR}" \
  --main_process_port="${MASTER_PORT}" \
  --mixed_precision="no" \
  --num_machines="${WORLD_SIZE}" \
  --rdzv_backend="static" \
  --same_network \
  --use_fsdp \
  -m tuning.sft_trainer \
  --adam_beta2="0.95" \
  --aim_repo="${AIMSTACK_DB}" \
  --data_config="data_config.yaml" \
  --evaluation_strategy="no" \
  --experiment="train-nb-g8b-r18" \
  --gradient_accumulation_steps="1" \
  --gradient_checkpointing="true" \
  --include_tokens_per_second="true" \
  --learning_rate="0.0003" \
  --logging_steps="1" \
  --logging_strategy="steps" \
  --lr_scheduler_type="cosine" \
  --max_grad_norm="1" \
  --max_steps="100" \
  --model_name_or_path="ibm-granite/granite-3.1-8b-base" \
  --output_dir="/run18" \
  --packing="true" \
  --per_device_train_batch_size="8" \
  --save_steps="50" \
  --save_strategy="steps" \
  --split_batches="true" \
  --torch_dtype="bfloat16" \
  --tracker="aim" \
  --use_flash_attn="true" \
  --warmup_ratio="0.05" \
  --weight_decay="0.1" \
  2>&1 | tee -a "/run18/accelerate_launch_output.log"

cc: @ashokponkumar

… pretokenized case in data collator

Signed-off-by: Will Johnson <[email protected]>
    data = processor.load_dataset(
        None,
        streaming=processor.processor_config.streaming,
        splitName="train[:1]",
Contributor commented:

Do we still need to specify streaming if all we do is load just the first line of the train split?
Can you please check what the HF docs say about this.

@willmj (Collaborator, Author) commented:

For checking the columns, it seems fine in unit tests not to pass streaming; however, it then loads the example as a Dataset instead of an IterableDataset. If this is okay with you, we can either pass streaming through the kwargs of load_dataset, default streaming to false in load_dataset, or just set it to false for this load. Let me know what you think will work best.

Contributor commented:

If, for checking columns, a single sample can be loaded without streaming, you can choose that route and force-disable streaming in this call; I would be fine with it.

My question was whether a single sample can be loaded in all cases without performance concerns, even for large datasets. So I wanted to ask: 1) does HF load only the one sample from disk, or 2) does HF load all samples and then drop all but one?
In case (2), performance can take a hit.

@willmj (Collaborator, Author) commented:

It seems from the HF documentation on slice splits that load_dataset goes with option (1).
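
A minimal sketch of that slice-split load (file name hypothetical): "train[:1]" asks load_dataset to materialize only the first example as a regular, non-streaming Dataset:

    from datasets import load_dataset

    # The slice is applied at load time, so only one example is read.
    sample = load_dataset("json", data_files="train.jsonl", split="train[:1]")
    print(sample.column_names)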

"Setting `split_batches` to true - splitting batches among devices \
`per_device_train_batch_size` is now the global batch size, and \
should be treated as such."
)
Contributor commented:

While I can live with this check for now, I feel we should be clearer here; we must not waste a run where the user's job fails and they scramble through the logs to find this.

Is there a suggestion to make this explicit?

Also, please move this inside process_dataargs itself; we do have train_args available, so this can be done inside that function. And do we need to set accelerator_config to this dict? Do we not need to append to it?

@willmj (Collaborator, Author) commented:

Yes, per Mehant's suggestion we set accelerator_config to this dict. You bring up a good point that documentation should be added for this PR; I will add it soon.
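
For context, a minimal sketch of setting such an accelerator config on HF TrainingArguments (output_dir is hypothetical); with split_batches=True, Accelerate splits each fetched batch across devices, so per_device_train_batch_size acts as the global batch size:

    from transformers import TrainingArguments

    # split_batches makes Accelerate divide each fetched batch among devices.
    train_args = TrainingArguments(
        output_dir="out",  # hypothetical
        accelerator_config={"split_batches": True},
    )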

@willmj (Collaborator, Author) commented:

I added documentation in advanced-data-preprocessing.md, and made the warning more explicit.

@dushyantbehl (Contributor) commented:

@seshapad @HarikrishnanBalagopal can we take this branch now and kick off the EPT run which ended in error before?

@willmj has made all the code-correctness changes, so I request you to do a sanity check before we go for merge.

@seshapad (Contributor) commented Feb 6, 2025

@HarikrishnanBalagopal please provide an image for this branch. I will start an EPT.

@willmj (Collaborator, Author) commented Feb 11, 2025

According to this comment, once TRL is upgraded the highlighted case shouldn't be needed, so I have removed it. This may cause some tests to fail until the TRL upgrade is merged, specifically:

tests/test_sft_trainer.py::test_run_causallm_ft_and_inference_streaming_ept

This means this PR is waiting on upgrading TRL.

@willmj (Collaborator, Author) commented Feb 20, 2025

Removing the following test case:

@pytest.mark.parametrize(
    "datafiles, datasetconfigname",
    [
        (
            [TWITTER_COMPLAINTS_TOKENIZED_JSON],
            DATA_CONFIG_YAML_STREAMING_PRETOKENIZED,
        ),
    ],
)
def test_run_causallm_ft_and_inference_streaming_ept(datasetconfigname, datafiles):
    """Check if we can finetune causallm models using multiple datasets with multiple files"""
    with tempfile.TemporaryDirectory() as tempdir:
        data_formatting_args = copy.deepcopy(DATA_ARGS)

        # set training_data_path and response_template to none
        data_formatting_args.response_template = None
        data_formatting_args.training_data_path = None

        # add data_paths in data_config file
        with tempfile.NamedTemporaryFile(
            "w", delete=False, suffix=".yaml"
        ) as temp_yaml_file:
            with open(datasetconfigname, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f)
                datasets = data["datasets"]
                for _, d in enumerate(datasets):
                    d["data_paths"] = datafiles
                yaml.dump(data, temp_yaml_file)
                data_formatting_args.data_config_path = temp_yaml_file.name

        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.output_dir = tempdir
        train_args.max_steps = 1
        train_args.packing = True

        sft_trainer.train(MODEL_ARGS, data_formatting_args, train_args)

        # validate full ft configs
        _validate_training(tempdir)
        _, checkpoint_path = _get_latest_checkpoint_trainer_state(tempdir)

        # Load the model
        loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)

        # Run inference on the text
        output_inference = loaded_model.run(
            "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50
        )
        assert len(output_inference) > 0
        assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference

as it will be resolved by #468
