Skip to content

Commit

Permalink
Fix masked_fill inputs not being on the same device under torch.compile.
Browse files Browse the repository at this point in the history
  • Loading branch information
libinta committed Dec 12, 2023
1 parent 3aa707e commit cbf3e1a
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 0 deletions.
8 changes: 8 additions & 0 deletions examples/text-generation/run_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,14 @@ def generate():
duration = time.perf_counter() - t0
total_new_tokens_generated = args.n_iterations * args.batch_size * args.max_new_tokens
throughput = total_new_tokens_generated / duration
print(
"LIBIN DEBUG THROUGHPUT ",
throughput,
" total_new_tokens_generated ",
total_new_tokens_generated,
", duration ",
duration,
)

print()
print("Input/outputs:")
Expand Down
6 changes: 6 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
gaudi_codegen_block_forward,
gaudi_codegen_model_forward,
gaudi_conv1d_forward,
gaudi_distilbert_multiheadselfattention_forward,
gaudi_esm_for_protein_folding_forward,
gaudi_esmfolding_trunk_forward,
gaudi_falcon_attention_forward,
Expand Down Expand Up @@ -268,3 +269,8 @@ def adapt_transformers_to_gaudi():
transformers.models.mistral.modeling_mistral.MistralModel.forward = gaudi_mistral_model_forward
transformers.models.mistral.modeling_mistral.MistralAttention.forward = gaudi_mistral_attn_forward
transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward = gaudi_mistral_decoder_layer_forward

# Override DistilBERT's MultiHeadSelfAttention.forward on Gaudi so that masked_fill inputs are on the same device under torch.compile
transformers.models.distilbert.modeling_distilbert.MultiHeadSelfAttention.forward = (
gaudi_distilbert_multiheadselfattention_forward
)
3 changes: 3 additions & 0 deletions optimum/habana/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
gaudi_codegen_block_forward,
gaudi_codegen_model_forward,
)
from .distilbert import (
gaudi_distilbert_multiheadselfattention_forward,
)
from .esm import (
gaudi_esm_for_protein_folding_forward,
gaudi_esmfolding_trunk_forward,
Expand Down

0 comments on commit cbf3e1a

Please sign in to comment.