[RFC] Removing Hard-coded module paths for Parallelization #4
Comments
@HanGuo97 Nice suggestion! PRs are welcome lol. I'd be happy to collaborate on these in the coming days!
@HanGuo97 It feels like there are much simpler solutions; we only need to revise very few things in fla and flame. Update: fixed by fla-org/flash-linear-attention@7f9f83c. For now we can retrieve layers via `for i, layer in enumerate(getattr(model, model.base_model_prefix).layers): ...`
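For reference, a minimal sketch of how this generic accessor could drive FSDP wrapping. It assumes a HuggingFace-style model exposing `base_model_prefix` and a `.layers` list on its backbone, and PyTorch's composable `fully_shard` API; `shard_blocks` and the `mesh` argument are illustrative, not the actual flame code:

```python
# Illustrative only: generic block access via `base_model_prefix`, assuming
# PyTorch's composable FSDP2 `fully_shard` API is available.
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard


def shard_blocks(model: nn.Module, mesh) -> None:
    # Resolve the backbone generically instead of hard-coding `model.model`:
    # `base_model_prefix` is "model" for most FLA models but "backbone" for Mamba2.
    backbone = getattr(model, model.base_model_prefix)
    for layer in backbone.layers:
        # Shard each block independently.
        fully_shard(layer, mesh=mesh)
    # Shard the remaining parameters (embeddings, final norm, lm head) at the root.
    fully_shard(model, mesh=mesh)
```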
Generally speaking, in the context of FSDP we don't need a complex patch, as long as all blocks can be accessed via `getattr(model, model.base_model_prefix).layers`. The actual problem comes from TP/CP; a possible way is to add the rules for each model individually in a separate file.
Yeeeesssss! I found it's quite hard to define the rules in this repo. Will consider how to handle TP/CP.
Yip, maybe it's not that bad. For TP, we can just add a function for each model. This should give us enough flexibility to automatically handle any combination. The embedding, output norm and head are easier to deal with, but I am not sure if TP & FusedCrossEntropyLoss can work out of the box.
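As an illustration of what such a per-model TP rule function could look like (not the actual flame/fla API; the submodule names like `attn.q_proj` are assumptions about the layer layout, and the sketch uses torch's tensor-parallel `parallelize_module`, `ColwiseParallel`, and `RowwiseParallel`):

```python
# Illustrative only: a per-model TP rule function using torch's tensor-parallel
# API. Submodule names (attn.q_proj, mlp.down_proj, ...) are assumptions about
# the model, not the actual FLA layer layout.
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


def get_tp_plan(layer):
    # One plan per block, keyed by submodule name.
    return {
        "attn.q_proj": ColwiseParallel(),
        "attn.k_proj": ColwiseParallel(),
        "attn.v_proj": ColwiseParallel(),
        "attn.o_proj": RowwiseParallel(),
        "mlp.gate_proj": ColwiseParallel(),
        "mlp.up_proj": ColwiseParallel(),
        "mlp.down_proj": RowwiseParallel(),
    }


def apply_tp(model, tp_mesh):
    # Reuse the generic accessor from above to reach every block.
    for layer in getattr(model, model.base_model_prefix).layers:
        parallelize_module(layer, tp_mesh, get_tp_plan(layer))
```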
Ah, thanks for making this easy upstream! FYI, I made a PR for the FSDP part: #5. The TP part seems a bit more annoying, so I'm going to leave it out for now.
Closing this issue as the original hard-coded questions have been resolved. Further discussion of 4D parallelism can be found in fla-org/flash-linear-attention#148.
Proposal
The current parallelization utilities have hard-coded ways to obtain specific types of modules (e.g., layers, embeddings, norms). For instance, the following line assumes that the model has a `.model.layers` attribute:

flame/flame/parallelisms/parallelize_fla.py, line 239 (commit 816c326)

This is not necessarily true for all models in the FLA library (Mamba2 uses `.backbone`).
The folder contains a few other such instances.

I am considering adding a new file, `parallelisms/utils.py`, that allows for the registration of model classes and corresponding getters, as shown in the sketch below.

Any thoughts? I would be happy to make a PR for that, but I am not familiar enough with the FLA library to determine whether this is over-engineering. If most models indeed follow the hard-coded patterns, it might be simpler to just add some if-else statements within `parallelize_fla.py`.
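A minimal sketch of the kind of registry I have in mind (every name here, e.g. `register_layer_getter` and `get_layers`, is hypothetical, and the commented registrations use placeholder class names rather than the actual FLA classes):

```python
# parallelisms/utils.py (proposed; every name below is illustrative)
from typing import Callable, Dict, Type

import torch.nn as nn

# Maps a model class to a getter that returns its stack of blocks.
_LAYER_GETTERS: Dict[Type[nn.Module], Callable[[nn.Module], nn.ModuleList]] = {}


def register_layer_getter(
    model_cls: Type[nn.Module],
    getter: Callable[[nn.Module], nn.ModuleList],
) -> None:
    """Register how to retrieve the blocks of a given model class."""
    _LAYER_GETTERS[model_cls] = getter


def get_layers(model: nn.Module) -> nn.ModuleList:
    """Return the registered getter's result, falling back to the common pattern."""
    for cls, getter in _LAYER_GETTERS.items():
        if isinstance(model, cls):
            return getter(model)
    return model.model.layers  # current hard-coded default


# Example registrations (placeholder class names, not the actual FLA classes):
# register_layer_getter(Mamba2ForCausalLM, lambda m: m.backbone.layers)
```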
Rationale
No response