System Info
- transformers version: 4.57.1
- Platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.35.3
- Safetensors version: 0.6.2
- Accelerate version: not installed
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.7.1+cu126 (CUDA)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?:
- Using GPU in script?:
- GPU type: NVIDIA GeForce RTX 3080
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
- Install flash-linear-attention
- Run the following code:
import torch  # only needed if the dtype argument below is uncommented

from transformers import AutoTokenizer, Qwen3NextConfig, Qwen3NextForCausalLM
config = Qwen3NextConfig(
    vocab_size=32000,
    hidden_size=896,
    intermediate_size=4864,
    num_hidden_layers=4,
    num_attention_heads=8,
    num_key_value_heads=2,
    hidden_act="silu",
    max_position_embeddings=32768,
    initializer_range=0.02,
    rms_norm_eps=1e-6,
    use_cache=True,
    tie_word_embeddings=False,
    rope_theta=10000.0,
    rope_scaling=None,
    attention_bias=False,
    use_sliding_window=False,
    sliding_window=4096,
    max_window_layers=28,
    attention_dropout=0.0,
    decoder_sparse_step=1,
    moe_intermediate_size=768,
    num_experts_per_tok=2,
    num_experts=8,
    norm_topk_prob=False,
    output_router_logits=False,
    router_aux_loss_coef=0.001,
    # dtype=torch.bfloat16,  <---- No issue if dtype is given
)
model = Qwen3NextForCausalLM(config)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")
prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt")
# Generate
generate_ids = model.generate(inputs.input_ids, max_length=30)
tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

Error output:
AttributeError: module 'torch' has no attribute 'get_current_dtype'
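
As noted in the commented-out dtype line above, the script runs if the config sets an explicit half-precision dtype. A sketch of that workaround, with the config abridged for brevity (defaults fill the omitted fields):

import torch
from transformers import Qwen3NextConfig, Qwen3NextForCausalLM

config = Qwen3NextConfig(
    vocab_size=32000,
    num_hidden_layers=4,
    dtype=torch.bfloat16,  # explicit dtype avoids the torch.get_current_dtype() call
)
model = Qwen3NextForCausalLM(config)  # initializes without the AttributeError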
Expected behavior
The offending line is:

dtype=config.dtype if config.dtype is not None else torch.get_current_dtype(),

I assume this is meant to call torch.get_default_dtype(), since torch has no get_current_dtype attribute.
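
A minimal sketch of what I assume the intended line looks like (context abbreviated; the surrounding transformers code may differ):

# current: fails, because torch exposes no get_current_dtype()
# dtype=config.dtype if config.dtype is not None else torch.get_current_dtype(),

# presumed intent: torch.get_default_dtype() returns the global default dtype (float32 unless changed)
dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),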
Also, setting config.dtype=torch.float32 does not work with flash-linear-attention either, because fla's ChunkGatedDeltaRuleFunction does not support float32 (the pure torch implementation does support float32; the failure only appears when fla is installed). Should the error be raised during model initialization instead?

AssertionError: ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16.
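
If the check were moved to initialization, a rough sketch could look like this (the helper name and call site are my own illustration, not the actual transformers or fla API):

import importlib.util
import torch

def _validate_fla_dtype(config):
    # Hypothetical early check: fla's chunked gated-delta-rule kernel only
    # supports half precision, so fail at model construction instead of deep
    # inside the first generate() call.
    fla_installed = importlib.util.find_spec("fla") is not None
    effective_dtype = config.dtype if config.dtype is not None else torch.get_default_dtype()
    if fla_installed and effective_dtype == torch.float32:
        raise ValueError(
            "flash-linear-attention does not support float32; "
            "please set config.dtype to torch.bfloat16."
        )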