
[torch.compile] Make HiDream torch.compile ready #11477


Draft
sayakpaul wants to merge 8 commits into main

Conversation

sayakpaul
Member

@sayakpaul sayakpaul commented May 1, 2025

What does this PR do?

Part of #11430

Trying to make the HiDream model fully compatible with torch.compile(), but it currently fails with:
https://pastebin.com/EbCFqBvw

To reproduce, run the following on a GPU machine:

RUN_COMPILE=1 RUN_SLOW=1 pytest tests/models/transformers/test_models_transformer_hidream.py -k "test_torch_compile_recompilation_and_graph_break"

I am on the following environment:

- 🤗 Diffusers version: 0.34.0.dev0
- Platform: Linux-6.8.0-55-generic-x86_64-with-glibc2.39
- Running on Google Colab?: No
- Python version: 3.10.12
- PyTorch version (GPU?): 2.7.0+cu126 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.30.2
- Transformers version: 4.51.3
- Accelerate version: 1.6.0.dev0
- PEFT version: 0.15.2.dev0
- Bitsandbytes version: 0.45.3
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
NVIDIA GeForce RTX 4090, 24564 MiB
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

@anijain2305 @StrongerXi would you have any pointers?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines -392 to +394
-tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
+count_freq = torch.bincount(flat_expert_indices, minlength=self.num_activated_experts)
+tokens_per_expert = count_freq.cumsum(dim=0)

Member Author


Just reimplemented it to eliminate the numpy() dependency.
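
For reference, a tiny standalone sketch (with made-up expert indices) of why this helps: the old form round-trips through CPU/NumPy, which torch.compile cannot trace, while the new form stays on-device in pure torch ops and yields the same cumulative counts.

import torch

flat_expert_indices = torch.tensor([0, 2, 2, 1, 3, 0])
num_activated_experts = 4

# Old path: host sync plus a NumPy round trip, opaque to torch.compile.
old = flat_expert_indices.bincount().cpu().numpy().cumsum(0)

# New path: traceable torch ops; minlength also guarantees a fixed-length
# result even when trailing experts receive no tokens.
count_freq = torch.bincount(flat_expert_indices, minlength=num_activated_experts)
new = count_freq.cumsum(dim=0)

print(old, new.tolist())  # [2 3 5 6] [2, 3, 5, 6]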

@require_torch_2
@is_torch_compile
@slow
def test_torch_compile_recompilation_and_graph_break(self):
Member Author


Relevant test for this PR.

@StrongerXi

The graph break seems to be induced by @torch.no_grad:

@anijain2305 is this known?

@sayakpaul
Member Author

The graph break seems to be induced by @torch.no_grad:

@anijain2305 is this known?

Even if we remove the decorator, it still fails with the same error.

Contributor

@anijain2305 anijain2305 left a comment


LGTM

Edit: I checked the messages and missed that there is still a graph break. I can take a look today.

@sayakpaul
Member Author

sayakpaul commented May 8, 2025 via email

@anijain2305
Contributor

Not a useful update, but there seems to be a dynamic-shapes graph break here, coming from the moe_infer function.

cc @laithsakka

@sayakpaul
Member Author

@anijain2305

Okay, I think I know why this is happening. The line that primarily causes this shape change is:

hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)

This is why the moe_infer() function, when called with single_stream_blocks, complains about the shape changes.
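
A toy illustration (made-up sizes, not the real HiDream dimensions) of the effect: concatenating the encoder states onto the image tokens along dim=1 changes the sequence length, so the single-stream blocks, and the MoE inside them, see a different token count than the double-stream blocks did.

import torch

hidden_states = torch.randn(2, 16, 8)                 # (batch, image_tokens, dim)
initial_encoder_hidden_states = torch.randn(2, 4, 8)  # (batch, text_tokens, dim)
hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
print(hidden_states.shape)  # torch.Size([2, 20, 8]): the token dimension grew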

So, I tried with dynamic=True along with

torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch.fx.experimental._config.use_duck_shape = False

It then complains:

msg = 'dynamic shape operator: aten.bincount.default; Operator does not have a meta kernel that supports dynamic output shapes, please report an issue to PyTorch'

Keeping this open for better tracking.

@sayakpaul
Member Author

Cc: @StrongerXi for the above observation too.

@StrongerXi

On it.

@sayakpaul sayakpaul added the performance (Anything related to performance improvements, profiling and benchmarking) and torch.compile labels on Jun 10, 2025
@StrongerXi

Okay, I spent some time digging into the MoE stuff; here's what I learned:

  1. HiDream has 2 branches in the MoE FFN layer, and it looks like the moe_infer branch is meant to speed up inference by explicitly skipping the experts that receive no tokens. However, that's really bad for torch.compile because (a) it creates a hard-to-resolve data dependency (the branching depends on the output of torch.bincount, which in turn depends on the data of flat_expert_indices), and (b) even if we solve (a), we'd face lots of recompilations, because torch.compile would compile each possible execution path (e.g., experts 1 & 3 firing, or experts 1, 2, 4 firing, etc.).

    for i, end_idx in enumerate(tokens_per_expert):
        start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
        if start_idx == end_idx:
            continue

  2. Then I did some benchmarking, and it turns out moe_infer isn't faster than the "training branch"; they produce identical output images, and torch.compile gives much lower e2e latency with the "training branch" (a toy equivalence sketch follows this list):

    if self.training and not self._force_inference_output:
        x = x.repeat_interleave(self.num_activated_experts, dim=0)
        y = torch.empty_like(x, dtype=wtype)
        for i, expert in enumerate(self.experts):
            y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
        y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
        y = y.view(*orig_shape).to(dtype=wtype)
        # y = AddAuxiliaryLoss.apply(y, aux_loss)
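
For concreteness, a self-contained toy version of the two branches (made-up sizes, random routing, not the actual HiDream module or weights), showing they compute the same per-token weighted sum of expert outputs:

import torch

torch.manual_seed(0)
num_experts, top_k, tokens, dim = 4, 2, 6, 8
experts = torch.nn.ModuleList(torch.nn.Linear(dim, dim, bias=False) for _ in range(num_experts))

x = torch.randn(tokens, dim)
topk_idx = torch.randint(0, num_experts, (tokens, top_k))  # toy routing
topk_weight = torch.rand(tokens, top_k)
flat_topk_idx = topk_idx.view(-1)

# "Training" branch: run every (token, expert) pair densely; shapes stay static.
x_rep = x.repeat_interleave(top_k, dim=0)
y_train = torch.empty_like(x_rep)
for i, expert in enumerate(experts):
    mask = flat_topk_idx == i
    y_train[mask] = expert(x_rep[mask])
y_train = (y_train.view(tokens, top_k, dim) * topk_weight.unsqueeze(-1)).sum(dim=1)

# "moe_infer"-style branch: sort token slots by expert and skip empty experts.
# The .tolist() / skip logic is exactly the data-dependent control flow that
# torch.compile struggles with.
idxs = flat_topk_idx.argsort()
tokens_per_expert = torch.bincount(flat_topk_idx, minlength=num_experts).cumsum(dim=0)
token_idxs = idxs // top_k
y_infer = torch.zeros_like(x)
start = 0
for i, end in enumerate(tokens_per_expert.tolist()):
    if start != end:
        sel, tok = idxs[start:end], token_idxs[start:end]
        out = experts[i](x[tok]) * topk_weight.view(-1)[sel].unsqueeze(-1)
        y_infer.index_add_(0, tok, out)
    start = end

print(torch.allclose(y_train, y_infer, atol=1e-6))  # True: both branches agree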

Then I just had to fix a small graph break here, where img_sizes is supposed to be a List[Tuple[int, int]] but was being computed as a tensor:

def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:

        # create img_sizes
        img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=device).reshape(-1)
        img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)

The fix is simple:

        # create img_sizes
        #img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=device).reshape(-1)
        #img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)
        img_sizes = [[patch_height, patch_width]] * batch_size

Here are the e2e pipeline benchmark results using the HiDream demo script and compiling the transformer:

# pytorch 2.7.1
#
# original eager:     26.6s, compiled 24.8s (fullgraph=False)
# train-branch eager: 25.9s, compiled 19.5s (fullgraph=True)

I also saw that ComfyUI uses the training branch too. So maybe we should just use the training branch in eager mode as well? Or we could add a torch.compiler.is_compiling() check to use the training branch only under compile. What do you think, @sayakpaul?
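
A minimal sketch of that second option (a toy module, not the actual HiDream MoE; the two branch methods are trivial placeholders standing in for the existing code paths):

import torch

class MoEDispatchSketch(torch.nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.ff = torch.nn.Linear(dim, dim)

    def _train_like_branch(self, x):  # placeholder for the dense "training" branch
        return self.ff(x)

    def _moe_infer_branch(self, x):   # placeholder for the token-skipping branch
        return self.ff(x)

    def forward(self, x):
        # Route to the shape-static branch whenever we are training or being
        # traced by torch.compile; keep the eager fast path otherwise.
        if self.training or torch.compiler.is_compiling():
            return self._train_like_branch(x)
        return self._moe_infer_branch(x)

module = MoEDispatchSketch().eval()
compiled = torch.compile(module, fullgraph=True)  # tracing takes the dense branch
out = compiled(torch.randn(2, 8))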

@sayakpaul
Member Author

Wow, this is terrific KT. Thanks, Ryan!

Or we could add a torch.compiler.is_compiling() to use the training branch under compile only.

This is a good approach and is worth adding. @yiyixuxu what are your thoughts?

Labels
performance, torch.compile