Merged
Commits
37 commits
cac0b8c
save peft
willmj Mar 24, 2025
c522429
fix: model
willmj Mar 24, 2025
481dde6
post process hf converted dir
willmj Apr 1, 2025
397c9ba
fix: convert hf converted checkpoint
willmj Apr 7, 2025
79dec24
lora config
willmj Apr 7, 2025
3103720
save adapter config
willmj Apr 7, 2025
b61cbde
fmt + comments
willmj Apr 7, 2025
c12be0e
fix: add input linear and output linear to target modules
willmj Apr 8, 2025
123c2d4
fix: extend instead of append
willmj Apr 8, 2025
f68500b
fix: if hasattr peft config
willmj Apr 8, 2025
55ec4b5
fix: remove unneeded target modules
willmj Apr 9, 2025
0cfb9f4
Merge branch 'main' into save-peft-fast-moe
willmj Apr 10, 2025
67bed66
merge: main into branch
willmj Apr 10, 2025
2362349
lint + fmt
willmj Apr 10, 2025
a848a9b
docs
willmj Apr 11, 2025
42c420c
test: lora for scattermoe
willmj Apr 14, 2025
e3e7525
fmt tests
willmj Apr 15, 2025
8449659
docs: notes on restrictions
willmj Apr 16, 2025
3c25265
explicitly don't support router layer
willmj Apr 17, 2025
da81f93
docs: generalize
willmj Apr 18, 2025
1424efd
docs: update documentation
willmj Apr 18, 2025
b67ef0f
fix: simplify accelerate launch post processing
willmj Apr 18, 2025
6a32d32
tests: more target modules + ep_degree
willmj Apr 18, 2025
d2b6153
fix: only restrict all-linear, raise warning for other modules
willmj Apr 18, 2025
765ec95
fix: augmentation test
willmj Apr 18, 2025
b0dea82
fix: raise error
willmj Apr 18, 2025
806b716
fix: raise error
willmj Apr 18, 2025
f742e0b
Merge branch 'main' into save-peft-fast-moe-limited
willmj Apr 18, 2025
2567d30
fix: make warning more general
willmj Apr 18, 2025
70468db
turn off requires grad if using scattermoe with lora
willmj Apr 18, 2025
5b826c8
fix: freeze scattermoe params
willmj Apr 18, 2025
af408f9
fix: safer freezing
willmj Apr 18, 2025
7b026a9
Merge branch 'main' into save-peft-fast-moe-limited
willmj Apr 18, 2025
d7c2d15
just use string for class name
willmj Apr 18, 2025
0f7796e
comment
willmj Apr 18, 2025
4e8d774
Merge branch 'main' into save-peft-fast-moe-limited
willmj Apr 18, 2025
1759a2f
add comment
willmj Apr 21, 2025
2 changes: 1 addition & 1 deletion .pylintrc
@@ -475,7 +475,7 @@ notes-rgx=
[REFACTORING]

# Maximum number of nested blocks for function / method body
max-nested-blocks=5
max-nested-blocks=6

# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
3 changes: 3 additions & 0 deletions README.md
@@ -855,6 +855,9 @@ Notes:
- When a boolean is passed, the expert parallel degree defaults to 1 and the behaviour is as follows:
- if True, Scatter MoE kernels are used with experts sharded based on the top-level sharding protocol (e.g. FSDP).
- if False, Scatter MoE kernels are used with complete replication of experts across ranks.
- FSDP must be used when LoRA tuning with `--fast_moe`.
- LoRA tuning with ScatterMoE is supported, but because of inference restrictions in vLLM and vanilla PEFT, the expert layers and the router linear layer should not be trained as `target_modules` for models tuned with ScatterMoE. Users retain control over which `target_modules` they wish to train (see the sketch after these notes):
- At this time, only attention layers are trainable when using LoRA with ScatterMoE. Until support for the router linear layer is added, target modules must be specified explicitly (e.g. `target_modules: ["q_proj", "v_proj", "o_proj", "k_proj"]`) instead of passing `target_modules: ["all-linear"]`.
- `world_size` must be divisible by the `ep_degree`
- `number of experts` in the MoE module must be divisible by the `ep_degree`
- Running fast moe modifies the model's state dict, so checkpoints must be post-processed; this happens automatically, and the converted checkpoint can be found in the `hf_converted_checkpoint` folder within every saved checkpoint directory. Alternatively, the same conversion can be performed manually through the [checkpoint utils](https://github.com/foundation-model-stack/fms-acceleration/blob/main/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py) script.
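
To make the restriction above concrete, below is a minimal sketch of LoRA tuning with `--fast_moe` via `sft_trainer.train`, using the `FastMoe`/`FastMoeConfig` classes added in this PR. This is a sketch only: the import paths are assumed from the file locations in this PR, and `model_args`, `data_args`, `train_args`, and `lora_args` are placeholder argument objects in the style of the repo's existing dataclasses and test fixtures.

```python
# Sketch, not a verified end-to-end script: LoRA + ScatterMoE settings only.
# model_args, data_args, train_args, and lora_args are assumed placeholder
# argument objects for an MoE base model (see tests/test_sft_trainer.py).
from tuning import sft_trainer
from tuning.config.acceleration_configs.fast_moe import FastMoe, FastMoeConfig  # path assumed from this PR

# Expert-parallel degree 1; world_size and the number of experts must be divisible by ep_degree.
fast_moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1))

# Attention projections must be listed explicitly; "all-linear" raises a ValueError.
lora_args.r = 16
lora_args.target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]

sft_trainer.train(
    model_args,
    data_args,
    train_args,   # FSDP is required when LoRA tuning with --fast_moe
    lora_args,
    fast_moe_config=fast_moe_config,
)
```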
38 changes: 34 additions & 4 deletions build/accelerate_launch.py
@@ -146,6 +146,17 @@ def main():
save_model_dir, save_model_dir, num_added_tokens
)

# In the ScatterMoE LoRA case, also post-process the HF-converted checkpoint
hf_converted_checkpoint = os.path.join(
save_model_dir, "hf_converted_checkpoint"
)
if os.path.exists(
os.path.join(hf_converted_checkpoint, "adapter_model.safetensors")
):
post_process_vLLM_adapters_new_tokens(
hf_converted_checkpoint, hf_converted_checkpoint, num_added_tokens
)

if (
os.path.exists(os.path.join(output_dir, "added_tokens_info.json"))
and job_config.get("save_strategy") != "no"
@@ -159,11 +170,30 @@ def main():
for _, dirs, _ in os.walk(output_dir, topdown=False):
for name in dirs:
if "checkpoint-" in name.lower():
post_process_vLLM_adapters_new_tokens(
os.path.join(output_dir, name),
os.path.join(output_dir, name),
num_added_tokens,
base_checkpoint_dir = os.path.join(output_dir, name)
hf_converted_checkpoint = os.path.join(
base_checkpoint_dir, "hf_converted_checkpoint"
)

# Use hf_converted_checkpoint if it exists, otherwise use base_checkpoint_dir
checkpoint_dir = (
hf_converted_checkpoint
if os.path.exists(
os.path.join(
hf_converted_checkpoint, "adapter_model.safetensors"
)
)
else base_checkpoint_dir
)

if os.path.exists(
os.path.join(checkpoint_dir, "adapter_model.safetensors")
):
post_process_vLLM_adapters_new_tokens(
checkpoint_dir,
checkpoint_dir,
num_added_tokens,
)
else:
logging.warning(
"Failed to post-process: file added_tokens_info.json not in path %s",
57 changes: 27 additions & 30 deletions tests/acceleration/test_acceleration_framework.py
@@ -532,8 +532,8 @@ def test_framework_initialized_properly_moe():
)

# spy inside train to ensure that the MoE plugin is called
assert spy["model_loader_calls"] == 1
assert spy["augmentation_calls"] == 0
assert spy["model_loader_calls"] == 0
assert spy["augmentation_calls"] == 1
assert spy["get_ready_for_train_calls"] == 1


@@ -776,37 +776,34 @@ def test_error_raised_fast_moe_with_non_moe_model():
"""
Ensure error is thrown when `--fast_moe` is passed and model is not MoE
"""
with pytest.raises(
AttributeError,
match="'LlamaConfig' object has no attribute 'num_local_experts'",
):
with tempfile.TemporaryDirectory() as tempdir:
with tempfile.TemporaryDirectory() as tempdir:

model_args = copy.deepcopy(MODEL_ARGS)
model_args.model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v0.3"
model_args.torch_dtype = torch.bfloat16
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
train_args.save_strategy = "no"
train_args.bf16 = True
data_args = copy.deepcopy(DATA_ARGS)
data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT
data_args.response_template = "\n\n### Label:"
data_args.dataset_text_field = "output"
model_args = copy.deepcopy(MODEL_ARGS)
model_args.model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v0.3"
model_args.torch_dtype = torch.bfloat16
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
train_args.save_strategy = "no"
train_args.bf16 = True
data_args = copy.deepcopy(DATA_ARGS)
data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT
data_args.response_template = "\n\n### Label:"
data_args.dataset_text_field = "output"

# initialize a config
moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1))
# initialize a config
moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1))

# 1. mock a plugin class
# 2. register the mocked plugins
# 3. call sft_trainer.train
with build_framework_and_maybe_instantiate(
[
(["training.moe.scattermoe"], ScatterMoEAccelerationPlugin),
],
instantiate=False,
):
with instantiate_model_patcher():
# 1. mock a plugin class
# 2. register the mocked plugins
# 3. call sft_trainer.train
with build_framework_and_maybe_instantiate(
[
(["training.moe.scattermoe"], ScatterMoEAccelerationPlugin),
],
instantiate=False,
):
with instantiate_model_patcher():
with pytest.raises((ValueError, AttributeError)):
sft_trainer.train(
model_args,
data_args,
59 changes: 57 additions & 2 deletions tests/test_sft_trainer.py
@@ -1453,6 +1453,61 @@ def test_run_moe_ft_and_inference_ep1_kernels(dataset_path, ep_degree):
)


@pytest.mark.skipif(
not is_fms_accelerate_available(plugins="moe"),
reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",
)
@pytest.mark.parametrize(
"target_modules",
[
"all-linear",
["q_proj"],
["q_proj", "k_proj"],
["q_proj", "k_proj", "v_proj"],
["q_proj", "k_proj", "v_proj", "o_proj"],
],
)
@pytest.mark.parametrize("ep_degree", [True, False])
@pytest.mark.parametrize("dataset_path", [TWITTER_COMPLAINTS_DATA_JSONL])
def test_run_moe_lora_and_inference(dataset_path, target_modules, ep_degree):
"""Check if we can finetune a moe model and check if hf checkpoint is created"""
with tempfile.TemporaryDirectory() as tempdir:
data_args = copy.deepcopy(DATA_ARGS)
data_args.training_data_path = dataset_path
model_args = copy.deepcopy(MODEL_ARGS)
model_args.model_name_or_path = "ibm-granite/granite-3.1-1b-a400m-base"
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
lora_args = copy.deepcopy(PEFT_LORA_ARGS)
lora_args.r = 16
lora_args.target_modules = target_modules
fast_moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=ep_degree))

if target_modules == "all-linear":
with pytest.raises(ValueError):
sft_trainer.train(
model_args,
data_args,
train_args,
lora_args,
fast_moe_config=fast_moe_config,
)
else:
sft_trainer.train(
model_args,
data_args,
train_args,
lora_args,
fast_moe_config=fast_moe_config,
)
_test_run_inference(
checkpoint_path=os.path.join(
_get_checkpoint_path(tempdir), "hf_converted_checkpoint"
),
base_model_name_or_path="ibm-granite/granite-3.1-1b-a400m-base",
)


@pytest.mark.skipif(
not is_fms_accelerate_available(plugins="moe"),
reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",
@@ -1491,9 +1546,9 @@ def _test_run_causallm_ft(training_args, model_args, data_args, tempdir):
_validate_training(tempdir)


def _test_run_inference(checkpoint_path):
def _test_run_inference(checkpoint_path, base_model_name_or_path=None):
# Load the model
loaded_model = TunedCausalLM.load(checkpoint_path)
loaded_model = TunedCausalLM.load(checkpoint_path, base_model_name_or_path)

# Run inference on the text
output_inference = loaded_model.run(
28 changes: 24 additions & 4 deletions tuning/config/acceleration_configs/fast_moe.py
@@ -16,6 +16,7 @@
from dataclasses import dataclass, field
from typing import Union
import argparse
import json
import os

# Third Party
@@ -121,10 +122,29 @@ def checkpoint(checkpoint_dir, save_dir):
args,
os.path.join(hf_converted_output_dir, TRAINING_ARGS_NAME),
)
# Save model config files
self.trainer.model.config.save_pretrained(
hf_converted_output_dir
)

# Unwrap FSDP module
model = self.trainer.model
if hasattr(model, "module"):
model = model.module
Comment on lines +127 to +129
Collaborator
Is this needed for the non-LoRA use case? Since after this step it would directly reach the block below:

                        else:
                            model.config.save_pretrained(hf_converted_output_dir)

Collaborator (Author)
It's needed to determine whether there is a PEFT config for the following if statement; in the fine-tuning case that if statement should be false.


if hasattr(model, "peft_config"):
lora_config = model.peft_config["default"]
config_dict = lora_config.to_dict()
config_dict["target_modules"] = sorted(
list(config_dict["target_modules"])
)
with open(
os.path.join(
hf_converted_output_dir, "adapter_config.json"
),
"w",
encoding="utf-8",
) as f:
json.dump(config_dict, f, indent=2)

else:
model.config.save_pretrained(hf_converted_output_dir)

except Exception as e:
raise ValueError(
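
Since the point of writing `adapter_config.json` and the converted weights into `hf_converted_checkpoint` is to make the adapter consumable by vanilla PEFT, a minimal loading sketch follows. The checkpoint path is a placeholder, the base model is the one used in this PR's test, and the prompt mirrors the test dataset's response template; this is illustrative, not code from the PR.

```python
# Sketch only: load the post-processed adapter from a checkpoint's
# hf_converted_checkpoint folder with vanilla PEFT. Paths are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_id = "ibm-granite/granite-3.1-1b-a400m-base"            # base MoE model from this PR's test
adapter_dir = "output/checkpoint-5/hf_converted_checkpoint"  # placeholder path

base = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base, adapter_dir)  # reads adapter_config.json + adapter_model.safetensors

tok = AutoTokenizer.from_pretrained(base_id)
inputs = tok("### Text: @bot my package never arrived\n\n### Label:", return_tensors="pt")
print(tok.decode(model.generate(**inputs, max_new_tokens=20)[0], skip_special_tokens=True))
```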
36 changes: 36 additions & 0 deletions tuning/sft_trainer.py
@@ -168,6 +168,33 @@ def train(
"Trainer should not perform packing when using `--padding_free`"
)

if fast_moe_config is not None:
# Checking for unsupported modules with Scatter MoE for LoRA
# Only raise an error for `all-linear`
restricted_modules = ["all-linear"]
if (
peft_config is not None
and hasattr(peft_config, "target_modules")
and any(
module in (peft_config.target_modules or [])
for module in restricted_modules
)
):
raise ValueError(
"`--fast_moe` with LoRA does not currently support `all-linear`, as "
"target modules at this time. Please explicitly specify target "
"modules when using `--fast_moe` with LoRA."
)
# For any other target modules, raise a warning
if peft_config is not None and hasattr(peft_config, "target_modules"):
logger.warning(
"You are running LoRA with the ScatterMoE plugin. Please note that "
"passing target modules that are part of the MoE module can cause unexpected "
"behavior and unsuccessful tuning while LoRA tuning with ScatterMoE. "
"For safe tuning, only pass linear modules such as those in the attention "
"layer (e.g. ['q_proj', 'v_proj', 'o_proj', 'k_proj'])."
)

task_type = "CAUSAL_LM"
additional_metrics = {}

@@ -360,6 +387,15 @@ def train(
model, (peft_config,) = framework.augmentation(
model, train_args, modifiable_args=(peft_config,)
)
# HACK - For LoRA ScatterMoE, disable grad for ScatterMoE.
# In the future, requires_grad should be enabled for LoRA tuning
# with ScatterMoE and this code should be removed.
if peft_config is not None:
Collaborator
I think it's better to just do an instance check using ScatterMoE and then freeze everything inside.

Collaborator
Maybe you should put a comment that this is a workaround, and in the future, when ScatterMoE LoRAs can be tuned, this code needs to be removed.

for module in model.modules():
# Use string comparison to check if ScatterMoE module
if module.__class__.__name__ == "ScatterMoE":
for param in module.parameters():
param.requires_grad = False

# HACK - The SFT Trainer has internal validation which inspects the name of the class
# being used for the HF training args; if it's a TrainingArguments class, which is
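
As a quick sanity check on the freezing workaround above, one can count trainable vs. frozen parameters after `framework.augmentation`. This is a generic PyTorch sketch, not code from the PR, and it assumes LoRA adapter weights follow the usual `lora_` naming convention.

```python
# Sketch only: after augmentation and the ScatterMoE freeze above, confirm that
# only LoRA adapter weights remain trainable and everything else is frozen.
from collections import Counter

def trainable_param_report(model):
    counts = Counter()
    for name, param in model.named_parameters():
        if param.requires_grad:
            counts["trainable"] += param.numel()
            # LoRA adapter weights are typically named lora_A / lora_B
            if "lora_" in name:
                counts["lora"] += param.numel()
        else:
            counts["frozen"] += param.numel()
    return counts

# Expectation under this workaround: all parameters under ScatterMoE modules are
# frozen, and every trainable parameter belongs to a LoRA adapter.
# print(trainable_param_report(model))
```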