
Qwen3Next compatibility issue with flash-linear-attention #41732

@Tcc0403

Description


System Info

  • transformers version: 4.57.1
  • Platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.35.3
  • Safetensors version: 0.6.2
  • Accelerate version: not installed
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.7.1+cu126 (CUDA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA GeForce RTX 3080

Who can help?

@ArthurZucker @Cyrilvallez

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Install flash-linear-attention

  2. Run the following code:

```python
import torch
from transformers import AutoTokenizer
from transformers import Qwen3NextConfig
from transformers import Qwen3NextForCausalLM

config = Qwen3NextConfig(
    vocab_size=32000,
    hidden_size=896,
    intermediate_size=4864,
    num_hidden_layers=4,
    num_attention_heads=8,
    num_key_value_heads=2,
    hidden_act="silu",
    max_position_embeddings=32768,
    initializer_range=0.02,
    rms_norm_eps=1e-6,
    use_cache=True,
    tie_word_embeddings=False,
    rope_theta=10000.0,
    rope_scaling=None,
    attention_bias=False,
    use_sliding_window=False,
    sliding_window=4096,
    max_window_layers=28,
    attention_dropout=0.0,
    decoder_sparse_step=1,
    moe_intermediate_size=768,
    num_experts_per_tok=2,
    num_experts=8,
    norm_topk_prob=False,
    output_router_logits=False,
    router_aux_loss_coef=0.001,
    # dtype=torch.bfloat16,  # no error if dtype is set explicitly
)
model = Qwen3NextForCausalLM(config)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")

prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate
generate_ids = model.generate(inputs.input_ids, max_length=30)
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])
```

Error output:

```
AttributeError: module 'torch' has no attribute 'get_current_dtype'
```

Expected behavior

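The error comes from this line, which falls back to a function that does not exist in PyTorch when `config.dtype` is unset:

```python
dtype=config.dtype if config.dtype is not None else torch.get_current_dtype(),
```

I guess it's meant to call `torch.get_default_dtype()`?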
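A minimal sketch of the suggested fix, assuming the intent is simply to fall back to PyTorch's global default dtype when `config.dtype` is unset. `resolve_norm_dtype` is a hypothetical helper for illustration, not a function in transformers.

```python
import torch


def resolve_norm_dtype(config_dtype):
    """Hypothetical helper mirroring the expression quoted in the issue.

    Falls back to torch.get_default_dtype(), which exists, instead of
    torch.get_current_dtype(), which does not.
    """
    return config_dtype if config_dtype is not None else torch.get_default_dtype()


# torch exposes get_default_dtype; the attribute the model calls is missing.
print(hasattr(torch, "get_default_dtype"))  # True
print(hasattr(torch, "get_current_dtype"))  # False on PyTorch 2.7.1, hence the AttributeError

print(resolve_norm_dtype(None))            # torch.float32 unless the default was changed
print(resolve_norm_dtype(torch.bfloat16))  # torch.bfloat16
```

With this change alone, an unset `config.dtype` resolves to float32 by default, which then runs into the fla float32 limitation.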

Also, setting `config.dtype=torch.float32` won't work with flash-linear-attention either, because fla's `ChunkGatedDeltaRuleFunction` does not support float32 (the pure-PyTorch fallback does support float32, so this only fails when fla is installed). Should we raise an error during model initialization instead of failing at the first forward pass?

```
AssertionError: ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16.
```
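A sketch of the init-time check proposed above, assuming fla is detected by whether the `fla` package is importable. `fla_is_installed` and `check_gated_delta_rule_dtype` are hypothetical names, and where such a check would live in transformers is left open.

```python
import importlib.util

import torch


def fla_is_installed() -> bool:
    # Assumption: flash-linear-attention is importable as the top-level package `fla`.
    return importlib.util.find_spec("fla") is not None


def check_gated_delta_rule_dtype(config_dtype, use_fla=None):
    """Hypothetical init-time validation (not part of transformers).

    Raises ValueError up front instead of letting fla's
    ChunkGatedDeltaRuleFunction assert on float32 during the forward pass.
    """
    if use_fla is None:
        use_fla = fla_is_installed()
    # Resolve an unset dtype to PyTorch's global default (float32 unless changed).
    dtype = config_dtype if config_dtype is not None else torch.get_default_dtype()
    if use_fla and dtype == torch.float32:
        raise ValueError(
            "flash-linear-attention's ChunkGatedDeltaRuleFunction does not support "
            "float32; set config.dtype=torch.bfloat16."
        )
    return dtype


print(check_gated_delta_rule_dtype(torch.bfloat16, use_fla=True))  # torch.bfloat16
try:
    check_gated_delta_rule_dtype(torch.float32, use_fla=True)
except ValueError as err:
    print(f"ValueError: {err}")
```

Validating in `__init__` surfaces the problem when the model is built rather than at the first `generate` or forward call.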
