Commit 179da0a

Authored by Will Johnson
feat: Enable LoRA saving only for non MoE linear layers training with kernels. (#530)
* save peft
* post process hf converted dir
* fix: convert hf converted checkpoint
* lora config
* save adapter config
* fix: add input linear and output linear to target modules
* fix: extend instead of append
* fix: if hasattr peft config
* fix: remove unneeded target modules
* test: lora for scattermoe
* explicitly don't support router layer
* docs: update documentation
* fix: simplify accelerate launch post processing
* tests: more target modules + ep_degree
* fix: only restrict all-linear, raise warning for other modules
* fix: augmentation test
* fix: raise error
* turn off requires grad if using scattermoe with lora
* fix: freeze scattermoe params
* fix: safer freezing

Signed-off-by: Will Johnson <[email protected]>
1 parent ebe35a3 commit 179da0a

File tree: 7 files changed, +182 / -41 lines changed

* .pylintrc
* README.md
* build/accelerate_launch.py
* tests/acceleration/test_acceleration_framework.py
* tests/test_sft_trainer.py
* tuning/config/acceleration_configs/fast_moe.py
* tuning/sft_trainer.py

.pylintrc

Lines changed: 1 addition & 1 deletion

@@ -475,7 +475,7 @@ notes-rgx=
 [REFACTORING]

 # Maximum number of nested blocks for function / method body
-max-nested-blocks=5
+max-nested-blocks=6

 # Complete name of functions that never returns. When checking for
 # inconsistent-return-statements if a never returning function is called then

README.md

Lines changed: 3 additions & 0 deletions

@@ -855,6 +855,9 @@ Notes:
 - When a boolean is passed, the expert parallel degree defaults to 1 and further the behaviour would be as follows:
     - if True, it is Scatter MoE Kernels with experts sharded based on the top level sharding protocol (e.g. FSDP).
     - if False, Scatter MoE Kernels with complete replication of experts across ranks.
+- FSDP must be used when LoRA tuning with `--fast_moe`.
+- LoRA tuning with ScatterMoE is supported, but because of inference restrictions in vLLM/vanilla PEFT, the expert layers and the router linear layer should not be trained as `target_modules` for models tuned with ScatterMoE. Users have control over which `target_modules` they wish to train:
+    - At this time, only attention layers are trainable when using LoRA with ScatterMoE. Until support for the router linear layer is added, target modules must be specified explicitly (e.g. `target_modules: ["q_proj", "v_proj", "o_proj", "k_proj"]`) instead of passing `target_modules: ["all-linear"]`.
 - `world_size` must be divisible by the `ep_degree`
 - `number of experts` in the MoE module must be divisible by the `ep_degree`
 - Running fast moe modifies the state dict of the model, which must be post-processed; this happens automatically, and the converted checkpoint can be found in the `hf_converted_checkpoint` folder within every saved checkpoint directory. Alternatively, a similar conversion can be performed manually through the [checkpoint utils](https://github.com/foundation-model-stack/fms-acceleration/blob/main/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py) script.
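Concretely, a minimal sketch of a LoRA configuration that respects this restriction, using the standard Hugging Face `peft` API (the r/alpha/dropout values are illustrative, not mandated by this change):

```python
# Sketch of a LoRA config compatible with ScatterMoE: target only attention
# projections, never "all-linear" or MoE/router layers.
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,                       # illustrative hyperparameters
    lora_alpha=32,
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
    # Explicit attention-layer modules; "all-linear" would be rejected
    # when combined with --fast_moe, per this change.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
```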

build/accelerate_launch.py

Lines changed: 34 additions & 4 deletions

@@ -146,6 +146,17 @@ def main():
                 save_model_dir, save_model_dir, num_added_tokens
             )

+        # In case of ScatterMoE LoRA
+        hf_converted_checkpoint = os.path.join(
+            save_model_dir, "hf_converted_checkpoint"
+        )
+        if os.path.exists(
+            os.path.join(hf_converted_checkpoint, "adapter_model.safetensors")
+        ):
+            post_process_vLLM_adapters_new_tokens(
+                hf_converted_checkpoint, hf_converted_checkpoint, num_added_tokens
+            )
+
     if (
         os.path.exists(os.path.join(output_dir, "added_tokens_info.json"))
         and job_config.get("save_strategy") != "no"

@@ -159,11 +170,30 @@ def main():
         for _, dirs, _ in os.walk(output_dir, topdown=False):
             for name in dirs:
                 if "checkpoint-" in name.lower():
-                    post_process_vLLM_adapters_new_tokens(
-                        os.path.join(output_dir, name),
-                        os.path.join(output_dir, name),
-                        num_added_tokens,
+                    base_checkpoint_dir = os.path.join(output_dir, name)
+                    hf_converted_checkpoint = os.path.join(
+                        base_checkpoint_dir, "hf_converted_checkpoint"
+                    )
+
+                    # Use hf_converted_checkpoint if it exists, otherwise use base_checkpoint_dir
+                    checkpoint_dir = (
+                        hf_converted_checkpoint
+                        if os.path.exists(
+                            os.path.join(
+                                hf_converted_checkpoint, "adapter_model.safetensors"
+                            )
+                        )
+                        else base_checkpoint_dir
                     )
+
+                    if os.path.exists(
+                        os.path.join(checkpoint_dir, "adapter_model.safetensors")
+                    ):
+                        post_process_vLLM_adapters_new_tokens(
+                            checkpoint_dir,
+                            checkpoint_dir,
+                            num_added_tokens,
+                        )
     else:
         logging.warning(
             "Failed to post-process: file added_tokens_info.json not in path %s",

tests/acceleration/test_acceleration_framework.py

Lines changed: 27 additions & 30 deletions

@@ -532,8 +532,8 @@ def test_framework_initialized_properly_moe():
     )

     # spy inside the train to ensure that the ilab plugin is called
-    assert spy["model_loader_calls"] == 1
-    assert spy["augmentation_calls"] == 0
+    assert spy["model_loader_calls"] == 0
+    assert spy["augmentation_calls"] == 1
     assert spy["get_ready_for_train_calls"] == 1


@@ -776,37 +776,34 @@ def test_error_raised_fast_moe_with_non_moe_model():
     """
     Ensure error is thrown when `--fast_moe` is passed and model is not MoE
     """
-    with pytest.raises(
-        AttributeError,
-        match="'LlamaConfig' object has no attribute 'num_local_experts'",
-    ):
-        with tempfile.TemporaryDirectory() as tempdir:
+    with tempfile.TemporaryDirectory() as tempdir:

-            model_args = copy.deepcopy(MODEL_ARGS)
-            model_args.model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v0.3"
-            model_args.torch_dtype = torch.bfloat16
-            train_args = copy.deepcopy(TRAIN_ARGS)
-            train_args.output_dir = tempdir
-            train_args.save_strategy = "no"
-            train_args.bf16 = True
-            data_args = copy.deepcopy(DATA_ARGS)
-            data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT
-            data_args.response_template = "\n\n### Label:"
-            data_args.dataset_text_field = "output"
+        model_args = copy.deepcopy(MODEL_ARGS)
+        model_args.model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v0.3"
+        model_args.torch_dtype = torch.bfloat16
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+        train_args.save_strategy = "no"
+        train_args.bf16 = True
+        data_args = copy.deepcopy(DATA_ARGS)
+        data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT
+        data_args.response_template = "\n\n### Label:"
+        data_args.dataset_text_field = "output"

-            # initialize a config
-            moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1))
+        # initialize a config
+        moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1))

-            # 1. mock a plugin class
-            # 2. register the mocked plugins
-            # 3. call sft_trainer.train
-            with build_framework_and_maybe_instantiate(
-                [
-                    (["training.moe.scattermoe"], ScatterMoEAccelerationPlugin),
-                ],
-                instantiate=False,
-            ):
-                with instantiate_model_patcher():
+        # 1. mock a plugin class
+        # 2. register the mocked plugins
+        # 3. call sft_trainer.train
+        with build_framework_and_maybe_instantiate(
+            [
+                (["training.moe.scattermoe"], ScatterMoEAccelerationPlugin),
+            ],
+            instantiate=False,
+        ):
+            with instantiate_model_patcher():
+                with pytest.raises((ValueError, AttributeError)):
                     sft_trainer.train(
                         model_args,
                         data_args,
tests/test_sft_trainer.py

Lines changed: 57 additions & 2 deletions

@@ -1453,6 +1453,61 @@ def test_run_moe_ft_and_inference_ep1_kernels(dataset_path, ep_degree):
     )


+@pytest.mark.skipif(
+    not is_fms_accelerate_available(plugins="moe"),
+    reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",
+)
+@pytest.mark.parametrize(
+    "target_modules",
+    [
+        "all-linear",
+        ["q_proj"],
+        ["q_proj", "k_proj"],
+        ["q_proj", "k_proj", "v_proj"],
+        ["q_proj", "k_proj", "v_proj", "o_proj"],
+    ],
+)
+@pytest.mark.parametrize("ep_degree", [True, False])
+@pytest.mark.parametrize("dataset_path", [TWITTER_COMPLAINTS_DATA_JSONL])
+def test_run_moe_lora_and_inference(dataset_path, target_modules, ep_degree):
+    """Check that we can LoRA-tune a MoE model and that the hf converted checkpoint is created"""
+    with tempfile.TemporaryDirectory() as tempdir:
+        data_args = copy.deepcopy(DATA_ARGS)
+        data_args.training_data_path = dataset_path
+        model_args = copy.deepcopy(MODEL_ARGS)
+        model_args.model_name_or_path = "ibm-granite/granite-3.1-1b-a400m-base"
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+        lora_args = copy.deepcopy(PEFT_LORA_ARGS)
+        lora_args.r = 16
+        lora_args.target_modules = target_modules
+        fast_moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=ep_degree))
+
+        if target_modules == "all-linear":
+            with pytest.raises(ValueError):
+                sft_trainer.train(
+                    model_args,
+                    data_args,
+                    train_args,
+                    lora_args,
+                    fast_moe_config=fast_moe_config,
+                )
+        else:
+            sft_trainer.train(
+                model_args,
+                data_args,
+                train_args,
+                lora_args,
+                fast_moe_config=fast_moe_config,
+            )
+            _test_run_inference(
+                checkpoint_path=os.path.join(
+                    _get_checkpoint_path(tempdir), "hf_converted_checkpoint"
+                ),
+                base_model_name_or_path="ibm-granite/granite-3.1-1b-a400m-base",
+            )
+
+
 @pytest.mark.skipif(
     not is_fms_accelerate_available(plugins="moe"),
     reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",

@@ -1491,9 +1546,9 @@ def _test_run_causallm_ft(training_args, model_args, data_args, tempdir):
     _validate_training(tempdir)


-def _test_run_inference(checkpoint_path):
+def _test_run_inference(checkpoint_path, base_model_name_or_path=None):
     # Load the model
-    loaded_model = TunedCausalLM.load(checkpoint_path)
+    loaded_model = TunedCausalLM.load(checkpoint_path, base_model_name_or_path)

     # Run inference on the text
     output_inference = loaded_model.run(
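As background for the new `base_model_name_or_path` argument: `hf_converted_checkpoint` holds only the LoRA adapter weights, so inference needs the base model loaded alongside it. A minimal sketch using the standard `transformers`/`peft` APIs rather than the repo's `TunedCausalLM` helper (the adapter path below is illustrative):

```python
# Minimal sketch: run the ScatterMoE-converted LoRA adapter on top of its base model.
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = "ibm-granite/granite-3.1-1b-a400m-base"
adapter_dir = "output/checkpoint-5/hf_converted_checkpoint"  # hypothetical path

model = AutoModelForCausalLM.from_pretrained(base)
model = PeftModel.from_pretrained(model, adapter_dir)  # adapter-only checkpoint
tokenizer = AutoTokenizer.from_pretrained(base)

inputs = tokenizer("### Text: great product\n\n### Label:", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```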

tuning/config/acceleration_configs/fast_moe.py

Lines changed: 24 additions & 4 deletions

@@ -16,6 +16,7 @@
 from dataclasses import dataclass, field
 from typing import Union
 import argparse
+import json
 import os

 # Third Party

@@ -121,10 +122,29 @@ def checkpoint(checkpoint_dir, save_dir):
                        args,
                        os.path.join(hf_converted_output_dir, TRAINING_ARGS_NAME),
                    )
-                    # Save model config files
-                    self.trainer.model.config.save_pretrained(
-                        hf_converted_output_dir
-                    )
+
+                    # Unwrap FSDP module
+                    model = self.trainer.model
+                    if hasattr(model, "module"):
+                        model = model.module
+
+                    if hasattr(model, "peft_config"):
+                        lora_config = model.peft_config["default"]
+                        config_dict = lora_config.to_dict()
+                        config_dict["target_modules"] = sorted(
+                            list(config_dict["target_modules"])
+                        )
+                        with open(
+                            os.path.join(
+                                hf_converted_output_dir, "adapter_config.json"
+                            ),
+                            "w",
+                            encoding="utf-8",
+                        ) as f:
+                            json.dump(config_dict, f, indent=2)
+
+                    else:
+                        model.config.save_pretrained(hf_converted_output_dir)

                except Exception as e:
                    raise ValueError(
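For context on the `sorted(list(...))` step above: PEFT normalizes `target_modules` to a set after initialization, and `json.dump` cannot serialize a set, so the value has to be converted before writing `adapter_config.json` (sorting also keeps the file deterministic). A small self-contained illustration with hypothetical values:

```python
import json

# A set of target modules is not JSON-serializable as-is.
config_dict = {"r": 16, "target_modules": {"q_proj", "v_proj", "k_proj", "o_proj"}}

try:
    json.dumps(config_dict)
except TypeError as err:
    print(f"raw set fails: {err}")

# sorted() both converts to a list and makes the on-disk config deterministic.
config_dict["target_modules"] = sorted(config_dict["target_modules"])
print(json.dumps(config_dict, indent=2))
```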

tuning/sft_trainer.py

Lines changed: 36 additions & 0 deletions

@@ -168,6 +168,33 @@ def train(
             "Trainer should not perform packing when using `--padding_free`"
         )

+    if fast_moe_config is not None:
+        # Checking for unsupported modules with Scatter MoE for LoRA
+        # Only raise an error for `all-linear`
+        restricted_modules = ["all-linear"]
+        if (
+            peft_config is not None
+            and hasattr(peft_config, "target_modules")
+            and any(
+                module in (peft_config.target_modules or [])
+                for module in restricted_modules
+            )
+        ):
+            raise ValueError(
+                "`--fast_moe` with LoRA does not currently support `all-linear` "
+                "as target modules. Please explicitly specify target "
+                "modules when using `--fast_moe` with LoRA."
+            )
+        # For any other target modules, raise a warning
+        if peft_config is not None and hasattr(peft_config, "target_modules"):
+            logger.warning(
+                "You are running LoRA with the ScatterMoE plugin; please note that "
+                "passing target modules that are part of the MoE module can cause unexpected "
+                "behaviors and unsuccessful tuning while LoRA tuning with ScatterMoE. "
+                "For safe tuning, only pass linear modules such as those in the attn layer "
+                "(i.e. ['q_proj', 'v_proj', 'o_proj', 'k_proj'])."
+            )
+
     task_type = "CAUSAL_LM"
     additional_metrics = {}


@@ -360,6 +387,15 @@ def train(
         model, (peft_config,) = framework.augmentation(
             model, train_args, modifiable_args=(peft_config,)
         )
+        # HACK - For LoRA ScatterMoE, disable grad for ScatterMoE.
+        # In the future, requires_grad should be enabled for LoRA tuning
+        # with ScatterMoE and this code should be removed.
+        if peft_config is not None:
+            for module in model.modules():
+                # Use string comparison to check if this is a ScatterMoE module
+                if module.__class__.__name__ == "ScatterMoE":
+                    for param in module.parameters():
+                        param.requires_grad = False

         # HACK - The SFT Trainer has internal validation which inspects the name of the class
         # being used for the HF training args; if it's a TrainingArguments class, which is
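Taken together, the new checks can be read as the following standalone sketch; this is an illustrative restatement for readers, not the actual `sft_trainer.train` signature or a helper that exists in the repo:

```python
# Illustrative restatement of the --fast_moe + LoRA target-module checks above.
import logging
from typing import Optional, Sequence, Union

logger = logging.getLogger(__name__)


def validate_fast_moe_lora(target_modules: Optional[Union[str, Sequence[str]]]) -> None:
    """Reject `all-linear` outright and warn about everything else, as train() now does."""
    modules = [target_modules] if isinstance(target_modules, str) else list(target_modules or [])
    if "all-linear" in modules:
        raise ValueError(
            "`--fast_moe` with LoRA does not currently support `all-linear` as "
            "target modules. Please explicitly specify target modules."
        )
    if modules:
        logger.warning(
            "Target modules inside the MoE block (experts, router) are not supported "
            "with ScatterMoE LoRA; prefer attention projections such as "
            "['q_proj', 'k_proj', 'v_proj', 'o_proj']."
        )


validate_fast_moe_lora(["q_proj", "k_proj", "v_proj", "o_proj"])  # warns only
# validate_fast_moe_lora("all-linear")  # would raise ValueError
```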
