feat: Enable LoRA saving only for non MoE linear layers training with kernels. #530
Changes from all commits
cac0b8c
c522429
481dde6
397c9ba
79dec24
3103720
b61cbde
c12be0e
123c2d4
f68500b
55ec4b5
0cfb9f4
67bed66
2362349
a848a9b
42c420c
e3e7525
8449659
3c25265
da81f93
1424efd
b67ef0f
6a32d32
d2b6153
765ec95
b0dea82
806b716
f742e0b
2567d30
70468db
5b826c8
af408f9
7b026a9
d7c2d15
0f7796e
4e8d774
1759a2f
```diff
@@ -168,6 +168,33 @@ def train(
             "Trainer should not perform packing when using `--padding_free`"
         )

+    if fast_moe_config is not None:
+        # Checking for unsupported modules with ScatterMoE for LoRA
+        # Only raise an error for `all-linear`
+        restricted_modules = ["all-linear"]
+        if (
+            peft_config is not None
+            and hasattr(peft_config, "target_modules")
+            and any(
+                module in (peft_config.target_modules or [])
+                for module in restricted_modules
+            )
+        ):
+            raise ValueError(
+                "`--fast_moe` with LoRA does not currently support `all-linear` "
+                "as target modules at this time. Please explicitly specify target "
+                "modules when using `--fast_moe` with LoRA."
+            )
+        # If other common non-linear modules are passed, raise a warning
+        if peft_config is not None and hasattr(peft_config, "target_modules"):
+            logger.warning(
+                "You are running LoRA with the ScatterMoE plugin; please note that "
+                "passing target modules that are part of the MoE module can cause "
+                "unexpected behaviors and unsuccessful tuning while LoRA tuning with "
+                "ScatterMoE. For safe tuning, only pass linear modules such as those "
+                "in the attn layer (i.e. ['q_proj', 'v_proj', 'o_proj', 'k_proj'])."
+            )
+
     task_type = "CAUSAL_LM"
     additional_metrics = {}
```
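For reference, a minimal sketch of a PEFT `LoraConfig` that passes this check by naming linear attention modules explicitly rather than using `all-linear`; the q/k/v/o projection names are architecture-dependent assumptions, not values taken from this PR:

```python
# Minimal sketch, assuming the standard `peft` library. The projection names
# are an assumption that holds for many decoder models; adjust them to the
# actual architecture. `target_modules="all-linear"` would trigger the
# ValueError above when `--fast_moe` is enabled.
from peft import LoraConfig

peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
    # Explicit non-MoE linear modules only; avoid MoE submodules and
    # avoid "all-linear" when tuning with the ScatterMoE plugin.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
```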
```diff
@@ -360,6 +387,15 @@ def train(
         model, (peft_config,) = framework.augmentation(
             model, train_args, modifiable_args=(peft_config,)
         )
+        # HACK - For LoRA ScatterMoE, disable grad for ScatterMoE.
+        # In the future, requires_grad should be enabled for LoRA tuning
+        # with ScatterMoE and this code should be removed.
+        if peft_config is not None:
```
Collaborator: I think it's better to just do an instance check using […]

Collaborator: Maybe you should put a comment that this is a workaround, and in the future, when ScatterMoE LoRAs can be tuned, this code needs to be removed.
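For illustration, a minimal sketch of the instance-check alternative the first comment seems to point toward, in place of the name-string comparison used in the continuation of the diff below. The helper name is hypothetical, and the ScatterMoE class would have to be imported from the actual plugin; neither is specified in this thread:

```python
# Hypothetical helper (not part of this PR): freeze parameters of every
# submodule that is an instance of a given class, e.g. the plugin's
# ScatterMoE class, instead of comparing class names as strings.
def freeze_modules_of_type(model, module_cls):
    for module in model.modules():
        if isinstance(module, module_cls):
            for param in module.parameters():
                param.requires_grad = False
```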
```diff
+            for module in model.modules():
+                # Use string comparison to check if ScatterMoE module
+                if module.__class__.__name__ == "ScatterMoE":
+                    for param in module.parameters():
+                        param.requires_grad = False

     # HACK - The SFT Trainer has internal validation which inspects the name of the class
     # being used for the HF training args; if it's a TrainingArguments class, which is
```
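After this freeze, it can be worth confirming that only the intended LoRA adapter weights on non-MoE linear layers are still trainable. A small sketch of such a check follows; the helper is hypothetical and not part of this PR:

```python
# Hypothetical sanity check (not part of this PR): after the ScatterMoE
# freeze, only LoRA adapter parameters (e.g. names containing `lora_A` /
# `lora_B`) on non-MoE linear layers should still have requires_grad=True.
def summarize_trainable_params(model):
    trainable, total = 0, 0
    for name, param in model.named_parameters():
        total += param.numel()
        if param.requires_grad:
            trainable += param.numel()
            print(f"trainable: {name} {tuple(param.shape)}")
    print(f"{trainable} / {total} parameters trainable "
          f"({100.0 * trainable / max(total, 1):.2f}%)")
```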
Comment: Is this needed for the non-LoRA use case? Since after this step it would directly reach the block below.

Reply: It's needed to determine whether there is a peft config for the following if statement, but in the fine-tuning case this if statement should be false.