
[RFC] Removing Hard-coded module paths for Parallelization #4

Closed
HanGuo97 opened this issue Jan 26, 2025 · 7 comments
Labels
enhancement New feature or request

Comments

@HanGuo97
Collaborator

Proposal

The current parallelization utilities have hard-coded methods to obtain specific types of modules (e.g., layers, embedding, norms). For instance, the following line assumes that the model has a .model.layers attribute.

for layer_id, block in enumerate(model.model.layers):

This is not necessarily true for all models in the FLA library (Mamba2 uses .backbone). The folder contains a few other such instances.
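
For illustration, the same loop needs a different attribute path for a Mamba2-style model (assuming, as a sketch, that its backbone also exposes its block list as .layers):

# Typical xxxForCausalLM layout
for layer_id, block in enumerate(model.model.layers):
    ...

# Mamba2-style layout, where the block stack lives under `.backbone`
for layer_id, block in enumerate(model.backbone.layers):
    ...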

I am considering adding a new file parallelisms/utils.py that allows for the registration of model classes and corresponding getters, as shown below:

# Register a model class
ModelRegistry.register(
    xxxForCausalLM,
    embeddings_path="model.embedding",  # Custom path if different from default
    norm_path="model.norm",
    lm_head_path="lm_head",
    layers_path="model.layers"
)

# Utilize the registry in parallelisms/parallelize_fla.py
model = xxxForCausalLM(...)
embeddings = get_embeddings(model)
norm = get_norm(model)
lm_head = get_lm_head(model)
layers = get_layers(model)
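
For concreteness, a minimal sketch of what such a registry could look like (ModelRegistry, the get_* helpers, and the default paths are all hypothetical names from this proposal, not existing fla APIs):

import operator

class ModelRegistry:
    """Maps model classes to the attribute paths of their key submodules."""
    _registry = {}
    _defaults = {
        "embeddings_path": "model.embeddings",
        "norm_path": "model.norm",
        "lm_head_path": "lm_head",
        "layers_path": "model.layers",
    }

    @classmethod
    def register(cls, model_cls, **paths):
        cls._registry[model_cls] = {**cls._defaults, **paths}

    @classmethod
    def get(cls, model, key):
        paths = cls._registry.get(type(model), cls._defaults)
        # Resolve a dotted path such as "model.layers" on the model instance.
        return operator.attrgetter(paths[key])(model)

def get_embeddings(model):
    return ModelRegistry.get(model, "embeddings_path")

def get_norm(model):
    return ModelRegistry.get(model, "norm_path")

def get_lm_head(model):
    return ModelRegistry.get(model, "lm_head_path")

def get_layers(model):
    return ModelRegistry.get(model, "layers_path")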

Any thoughts? I would be happy to make a PR for that, but I am not familiar enough with the FLA library to determine if this is over-engineering. If most models indeed follow the hard-coded patterns, it might be simpler to just add some if-else statements within parallelize_fla.py.


@HanGuo97 HanGuo97 added the enhancement New feature or request label Jan 26, 2025
@yzhangcs
Member

@HanGuo97 Nice suggestion! PRs are welcome lol. I'd be happy to collaborate on this in the coming days!

@yzhangcs
Member

yzhangcs commented Jan 27, 2025

@HanGuo97 It feels like there are much simpler solutions.
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mamba/modeling_mamba.py#L385

We only need to revise a few things in fla and flame.

Update:

Fixed by fla-org/flash-linear-attention@7f9f83c

For now we can retrieve layers via

for i, layer in enumerate(getattr(model, model.base_model_prefix).layers):
    ...
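
For illustration (a hedged note, not from the original comment): base_model_prefix is a class attribute defined by each HF Transformers model family, so the same line resolves to the correct backbone regardless of the layout:

# base_model_prefix is set per model family in HF Transformers, e.g.
#   LlamaForCausalLM.base_model_prefix  == "model"     -> model.model.layers
#   Mamba2ForCausalLM.base_model_prefix == "backbone"  -> model.backbone.layers
backbone = getattr(model, model.base_model_prefix)
for i, layer in enumerate(backbone.layers):
    ...  # apply per-layer sharding/parallelization here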

@rakkit
Contributor

rakkit commented Jan 27, 2025

Generally speaking, in the context of FSDP we don't need a complex patch, as long as all blocks can be accessed via model.model.layers, which typically follows HF Transformers' xxxForCausalLM design. In principle, we should only use FSDP to shard the blocks (the head/embeddings should go to TP if needed).
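
A minimal sketch of that idea, assuming PyTorch's composable fully_shard API and the base_model_prefix accessor shown above (illustrative only, not the actual flame code):

from torch.distributed._composable.fsdp import fully_shard

def apply_fsdp(model, dp_mesh):
    # Shard each block individually ...
    for layer in getattr(model, model.base_model_prefix).layers:
        fully_shard(layer, mesh=dp_mesh)
    # ... then a root-level call groups the remaining parameters
    # (embeddings/norm/head, which could instead be handled by TP).
    fully_shard(model, mesh=dp_mesh)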

The actual problem comes from TP/CP; a possible way is to add the rules for each model individually in a separate file.

@yzhangcs
Member

yzhangcs commented Jan 27, 2025

@rakkit

the actual problem comes from TP/CP, a more realistic way is to add the rules for each model individually in a separate file.

Yeeeesssss! I found it's quite hard to define rules in this repo. I'll consider handling TP/CP in fla model by model in the near future.

@rakkit
Contributor

rakkit commented Jan 27, 2025

yip. maybe it's not that bad.

For TP, we can just add a function at the block, MLP, and attention level in fla that returns the tp_plan.
Here we can just call tp_plan = block.get_tp_plan(), and at the block level in the fla lib we can recursively query the MLP's/attention's TP plans.

This should give us enough flexibility to automatically handle any combination of blocks in fla.
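
A rough sketch of this recursive tp_plan idea, assuming PyTorch's tensor-parallel styles; get_tp_plan and the projection names are hypothetical, nothing like this exists in fla yet:

import torch.nn as nn
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class Attention(nn.Module):
    # ... existing fla attention definition ...
    def get_tp_plan(self):
        # Keys are sub-module paths relative to this module.
        return {
            "q_proj": ColwiseParallel(),
            "k_proj": ColwiseParallel(),
            "v_proj": ColwiseParallel(),
            "o_proj": RowwiseParallel(),
        }

class MLP(nn.Module):
    # ... existing fla MLP definition ...
    def get_tp_plan(self):
        return {
            "gate_proj": ColwiseParallel(),
            "up_proj": ColwiseParallel(),
            "down_proj": RowwiseParallel(),
        }

class Block(nn.Module):
    # ... existing fla block definition (self.attn, self.mlp) ...
    def get_tp_plan(self):
        # Recursively query sub-modules and re-prefix their plan keys.
        plan = {f"attn.{k}": v for k, v in self.attn.get_tp_plan().items()}
        plan.update({f"mlp.{k}": v for k, v in self.mlp.get_tp_plan().items()})
        return plan

def apply_tp(model, tp_mesh):
    # parallelize_fla.py can then handle any combination of blocks uniformly.
    for block in getattr(model, model.base_model_prefix).layers:
        parallelize_module(block, tp_mesh, block.get_tp_plan())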

The embedding, output norm, and head are easier to deal with, but I'm not sure whether TP & FusedCrossEntropyLoss can work out of the box.

@HanGuo97
Collaborator Author

Ah, thanks for making this easy upstream!

FYI, I made a PR for the FSDP part: #5. The TP part seems to be a bit more annoying, so I'm going to leave it out for now.

@yzhangcs
Member

Closing this issue as the original hard-coded path questions have been solved. Further discussion of 4D parallelism can be found in fla-org/flash-linear-attention#148.
