Skip to content

Commit 6e3d988

Browse files
committed
get hidream transformer fully torch.compile compatible.
1 parent 1d1e715 commit 6e3d988

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

src/diffusers/models/transformers/transformer_hidream_image.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,9 @@ def forward(self, x):
 389  389      def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
 390  390          expert_cache = torch.zeros_like(x)
 391  391          idxs = flat_expert_indices.argsort()
 392      -        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
      392 +        count_freq = torch.bincount(flat_expert_indices, minlength=self.num_activated_experts)
      393 +        tokens_per_expert = count_freq.cumsum(dim=0)
      394 +
 393  395          token_idxs = idxs // self.num_activated_experts
 394  396          for i, end_idx in enumerate(tokens_per_expert):
 395  397              start_idx = 0 if i == 0 else tokens_per_expert[i - 1]

tests/models/transformers/test_models_transformer_hidream.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
 20  20  from diffusers import HiDreamImageTransformer2DModel
 21  21  from diffusers.utils.testing_utils import (
 22  22      enable_full_determinism,
     23 +    is_torch_compile,
     24 +    require_torch_2,
     25 +    require_torch_gpu,
     26 +    slow,
 23  27      torch_device,
 24  28  )
 25  29

@@ -94,3 +98,20 @@ def test_set_attn_processor_for_determinism(self):
 94   98      def test_gradient_checkpointing_is_applied(self):
 95   99          expected_set = {"HiDreamImageTransformer2DModel"}
 96  100          super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
    101 +
    102 +    @require_torch_gpu
    103 +    @require_torch_2
    104 +    @is_torch_compile
    105 +    @slow
    106 +    def test_torch_compile_recompilation_and_graph_break(self):
    107 +        torch._dynamo.reset()
    108 +        torch._dynamo.config.capture_dynamic_output_shape_ops = True
    109 +
    110 +        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    111 +
    112 +        model = self.model_class(**init_dict).to(torch_device)
    113 +        model = torch.compile(model, fullgraph=True)
    114 +
    115 +        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
    116 +            _ = model(**inputs_dict)
    117 +            _ = model(**inputs_dict)

0 commit comments

Comments (0)