[torch.compile] Make HiDream torch.compile ready #11477
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
-    tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
+    count_freq = torch.bincount(flat_expert_indices, minlength=self.num_activated_experts)
+    tokens_per_expert = count_freq.cumsum(dim=0)
Just reimplemented it to eliminate the numpy() dependency.
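For illustration, a tiny standalone sketch of the change (toy tensors; `num_activated_experts` here is just an example value, not the model's real config):

```python
import torch

# toy routing result: each token's assigned expert id
flat_expert_indices = torch.randint(0, 4, (64,))
num_activated_experts = 4  # illustrative only

# before: host sync + NumPy round-trip
# tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)

# after: stays in torch, no .cpu()/.numpy()
count_freq = torch.bincount(flat_expert_indices, minlength=num_activated_experts)
tokens_per_expert = count_freq.cumsum(dim=0)
print(tokens_per_expert)
```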
@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
Relevant test for this PR.
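Not the diffusers test itself, but roughly the pattern such a test follows: compile with fullgraph=True so any graph break errors out, and turn on error_on_recompile so unexpected recompilations fail loudly (the module and shapes below are toy stand-ins):

```python
import torch
import torch._dynamo as dynamo

class TinyBlock(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x) * 2.0

dynamo.reset()
model = TinyBlock()

# fullgraph=True -> any graph break raises instead of silently falling back to eager
compiled = torch.compile(model, fullgraph=True)

# error_on_recompile -> a second compilation of the same code raises
with dynamo.config.patch(error_on_recompile=True):
    for _ in range(3):
        compiled(torch.randn(2, 8))
```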
The graph break seems to be induced by
@anijain2305 is this known?
Even if we remove the decorator, it still fails with the same error.
LGTM
Edit: Checked the messages again; I missed that there is still a graph break. I can take a look today.
Thanks! Appreciate it.
Not a useful update, but there seems to be a dynamic-shapes graph break here coming from the moe_infer function. cc @laithsakka
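For readers following along, a simplified sketch (not the exact diffusers code) of the moe_infer-style routing pattern where the data-dependent shape shows up: the per-expert token counts come from bincount on runtime values, so every downstream slice has a value-dependent length:

```python
import torch

def moe_infer_sketch(x, flat_expert_indices, experts):
    # x: (num_tokens, dim); flat_expert_indices: (num_tokens,) expert id per token
    out = torch.zeros_like(x)
    order = flat_expert_indices.argsort()
    # bincount's result depends on the *contents* of flat_expert_indices,
    # so the slice boundaries below are data-dependent
    tokens_per_expert = torch.bincount(flat_expert_indices, minlength=len(experts)).cumsum(dim=0)
    start = 0
    for i, end in enumerate(tokens_per_expert.tolist()):
        if start != end:
            idx = order[start:end]
            out[idx] = experts[i](x[idx])
        start = end
    return out

experts = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])
x = torch.randn(16, 8)
ids = torch.randint(0, 4, (16,))
print(moe_infer_sketch(x, ids, experts).shape)
```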
Okay, I think I know why this is happening. The line that primarily causes this shape change is:

This is why the [...]. So, I tried with:

torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch.fx.experimental._config.use_duck_shape = False

It then complains:
Keeping this open, maybe for better tracking.
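As a self-contained illustration of the class of failure (a toy function, not the HiDream MoE itself):

```python
import torch

def routing(flat_expert_indices):
    # bincount's output size depends on tensor *values*; dynamo treats it as a
    # dynamic-output-shape op and breaks the graph on it under the default config
    counts = torch.bincount(flat_expert_indices, minlength=4)
    return counts.cumsum(dim=0)

idx = torch.randint(0, 4, (64,))
try:
    torch.compile(routing, fullgraph=True)(idx)
except Exception as e:
    print(type(e).__name__)  # expected: a dynamo Unsupported-style error

# the flags tried above; in this PR's case compilation still complained
# afterwards, which is what is being tracked here
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch.fx.experimental._config.use_duck_shape = False
```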
Cc: @StrongerXi for the above observation too.
On it.
Okay, I spent some time digging into the MoE stuff; here's what I learned:
Then I just have to fix a small graph break here, where img_sizes is created:

src/diffusers/models/transformers/transformer_hidream_image.py, lines 718 to 720 at dacae33
The fix is simple (a toy comparison of the two variants is sketched below):

# create img_sizes
# img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=device).reshape(-1)
# img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)
img_sizes = [[patch_height, patch_width]] * batch_size

Here are the e2e pipeline benchmark results using the hidream demo script, compiling the transformer:
I also saw that ComfyUI uses the training branch too. So maybe we should just use the training branch in eager as well? Or we could add a [...]
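A toy comparison of the two img_sizes variants referenced above (the shapes and the downstream use are made up for illustration; torch._dynamo.explain just counts graph breaks):

```python
import torch

batch_size, patch_height, patch_width = 2, 64, 64

def with_tensor(x):
    img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=x.device).reshape(-1)
    img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)
    # reading the sizes back as Python ints forces a graph break
    h, w = int(img_sizes[0, 0]), int(img_sizes[0, 1])
    return x.new_zeros(batch_size, h * w)

def with_list(x):
    img_sizes = [[patch_height, patch_width]] * batch_size
    h, w = img_sizes[0]  # plain Python, nothing for dynamo to break on
    return x.new_zeros(batch_size, h * w)

x = torch.randn(1)
print(torch._dynamo.explain(with_tensor)(x).graph_break_count)
print(torch._dynamo.explain(with_list)(x).graph_break_count)
```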
Wow, this is terrific KT. Thanks, Ryan!
This is a good approach and is worth adding. @yiyixuxu what are your thoughts?
What does this PR do?
Part of #11430
Trying to make the HiDream model fully compatible with torch.compile(), but it fails with: https://pastebin.com/EbCFqBvw
To reproduce, run the following on a GPU machine:
RUN_COMPILE=1 RUN_SLOW=1 pytest tests/models/transformers/test_models_transformer_hidream.py -k "test_torch_compile_recompilation_and_graph_break"
I am on the following env:
@anijain2305 @StrongerXi would you have any pointers?