
Fix runtime issues for FSDP/DeepSpeed training #1017

Open

SajanGhimire1 wants to merge 2 commits into linkedin:main from SajanGhimire1:patch-1
Conversation

@SajanGhimire1

  • Removed unsafe usage of _MISSING_TYPE in parse_args.
  • Fixed KeyError in DataModule by correcting dataset field access.
  • Replaced set-based FSDP auto_wrap_policy with transformer_auto_wrap_policy.
  • Corrected Lightning precision strings to valid values (bf16-mixed).
  • Fixed devices argument to safely detect available GPUs.
  • Added safe get() for labels in training/validation steps to avoid KeyError.
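The auto-wrap change (third bullet) follows this shape. FSDP expects `auto_wrap_policy` to be a callable, not a raw set of module classes; the real callable is `torch.distributed.fsdp.wrap.transformer_auto_wrap_policy`. This is a minimal, torch-free sketch in which the policy function and `LlamaDecoderLayer` are stand-ins that mirror the real signatures:

```python
from functools import partial

# Stand-in mirroring torch.distributed.fsdp.wrap.transformer_auto_wrap_policy:
# always recurse into children; wrap a module iff it is a transformer block.
def transformer_auto_wrap_policy(module, recurse, nonwrapped_numel, transformer_layer_cls):
    return recurse or isinstance(module, tuple(transformer_layer_cls))

class LlamaDecoderLayer:
    """Stand-in for transformers.models.llama.modeling_llama.LlamaDecoderLayer."""

# The fix: bind the layer classes into a callable policy via functools.partial,
# instead of passing the bare set {LlamaDecoderLayer} as the policy.
auto_wrap_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)
```

In the actual code the same `partial(...)` is handed to the FSDP strategy in place of the set, with the real decoder-layer classes from transformers.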

Summary

Ensures stable and correct training across multi-GPU setups with FSDP/DeepSpeed by fixing dataset handling, auto-wrap policy, precision settings, and device detection.
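The precision and device fixes can be sketched as below. The keys match `pytorch_lightning.Trainer` arguments; the CPU-fallback logic is an assumption about the PR's intent, not its exact diff:

```python
# Detect available GPUs safely; fall back to CPU when torch or CUDA is absent.
try:
    import torch
    num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
except ImportError:  # torch not installed in this environment
    num_gpus = 0

trainer_kwargs = {
    "precision": "bf16-mixed",  # valid Lightning 2.x precision string
    "accelerator": "gpu" if num_gpus > 0 else "cpu",
    "devices": num_gpus if num_gpus > 0 else 1,  # Trainer requires devices >= 1
}
```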

Testing Done

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

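The safe label access from the change list (last bullet) can be sketched as follows. Falling back to `input_ids` for a causal-LM loss is an assumption, not necessarily the PR's exact choice:

```python
# Use dict.get so a batch without a "labels" key does not raise KeyError
# in training_step/validation_step.
def extract_labels(batch):
    labels = batch.get("labels")
    if labels is None:
        # Assumed fallback: causal-LM training often reuses input_ids as labels.
        labels = batch.get("input_ids")
    return labels
```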
from torch.utils.data import DataLoader
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
from trl import DataCollatorForCompletionOnlyLM
Collaborator

What trl version should I use? I couldn't import DataCollatorForCompletionOnlyLM

Author

The DataCollatorForCompletionOnlyLM class requires trl >= 0.8.0; please make sure that version or higher is installed to avoid import errors.

Collaborator

trl==0.26.2 doesn't work

Author

trl==0.26.2 no longer exposes DataCollatorForCompletionOnlyLM in the same way. This example is intended to work with older TRL releases where the collator exists, specifically trl>=0.8.0,<0.21.0. In newer TRL versions (including 0.26.x), the collator was refactored/removed, which causes the import error.
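A requirements-file pin matching that range (a suggestion, not part of the PR itself):

```
trl>=0.8.0,<0.21.0
```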

data: str = "cais/mmlu"
output_dir: str = "mmlu_finetuning"
max_length: int = 2048
# for llama3 8B model, deepspeed will OOM with 16 on 8XA100 80G and 8 will OOM on 8XA100 40G
Collaborator

why removing comments?

Restore removed comments without functional changes