
Fail to run Tiny-llama example #143

Open
shira-g opened this issue Dec 5, 2024 · 10 comments

Comments

@shira-g

shira-g commented Dec 5, 2024

Describe the bug
I run the following example: https://github.com/intel/intel-npu-acceleration-library?tab=readme-ov-file#run-a-tiny-llama-model-on-the-npu

and it fails with:
Traceback (most recent call last):
  File "C:\Users\sdp\shira\npu_acc\our_bench.py", line 7, in <module>
    model = NPUModelForCausalLM.from_pretrained(model_id, use_cache=True, dtype=torch.int8).eval()
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\sdp\miniconda3\envs\npu_acc\Lib\functools.py", line 388, in _method
    return self.func(cls_or_self, *self.args, *args, **keywords)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: NPUModel.from_pretrained() missing 1 required positional argument: 'config'

  • System: Intel(R) Core(TM) Ultra 5 238V 2.10GHz
  • OS: Windows 11
@alessandropalla
Contributor

Thanks for pointing out outdated documentation! You can find an example with the new API here: https://github.com/intel/intel-npu-acceleration-library/blob/main/examples/llama.py
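For reference, a minimal sketch of the CompilerConfig-based call with the new API; the import paths and keyword names here are inferred from the error message above and from the linked example, so treat this as a sketch rather than the definitive usage:

import intel_npu_acceleration_library
from intel_npu_acceleration_library import NPUModelForCausalLM
from intel_npu_acceleration_library.compiler import CompilerConfig

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# from_pretrained now takes a CompilerConfig instead of a bare dtype keyword
compiler_conf = CompilerConfig(dtype=intel_npu_acceleration_library.int8)
model = NPUModelForCausalLM.from_pretrained(
    model_id, config=compiler_conf, use_cache=True
).eval()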

@alessandropalla
Contributor

Reopening until doc is updated

@ElProfessorFR

@alessandropalla The example you gave doesn't work.

    from intel_npu_acceleration_library.compiler import CompilerConfig
ImportError: cannot import name 'CompilerConfig' from 'intel_npu_acceleration_library.compiler'

@alessandropalla
Contributor

@alessandropalla The example you gave doesn't work.

    from intel_npu_acceleration_library.compiler import CompilerConfig
ImportError: cannot import name 'CompilerConfig' from 'intel_npu_acceleration_library.compiler'

You need to use the latest library release.
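To confirm which version is installed, a quick check (assuming the package was installed from PyPI under the name intel-npu-acceleration-library):

from importlib.metadata import version

# Prints the installed distribution version, e.g. "1.4.0"
print(version("intel-npu-acceleration-library"))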

@Jaylyn-Barbee

Jaylyn-Barbee commented Jan 13, 2025

@alessandropalla I get a different error when attempting to run the updated Llama file you linked above.
Python 3.11, Windows 11

PS D:\source\repos\PythonNPUTest> python .\llama_new.py
Compiling model TinyLlama/TinyLlama-1.1B-Chat-v1.0 int4 for the NPU
Traceback (most recent call last):
  File "D:\source\repos\PythonNPUTest\llama_new.py", line 13, in <module>
    model = NPUModelForCausalLM.from_pretrained(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.11_3.11.2544.0_x64__qbz5n2kfra8p0\Lib\functools.py", line 388, in _method
    return self.func(cls_or_self, *self.args, *args, **keywords)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jaylynbarbee\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\intel_npu_acceleration_library\modelling.py", line 104, in from_pretrained
    model = npu_lib.compile(model, config)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jaylynbarbee\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\intel_npu_acceleration_library\compiler.py", line 67, in compile
    apply_general_optimizations(model)
  File "C:\Users\jaylynbarbee\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\intel_npu_acceleration_library\compiler.py", line 91, in apply_general_optimizations
    optimize_llama_attention(model)
  File "C:\Users\jaylynbarbee\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\intel_npu_acceleration_library\compiler.py", line 139, in wrapper
    wrapper(layer, *args, **kwargs)
  File "C:\Users\jaylynbarbee\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\intel_npu_acceleration_library\compiler.py", line 139, in wrapper
    wrapper(layer, *args, **kwargs)
  File "C:\Users\jaylynbarbee\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\intel_npu_acceleration_library\compiler.py", line 139, in wrapper
    wrapper(layer, *args, **kwargs)
  File "C:\Users\jaylynbarbee\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\intel_npu_acceleration_library\compiler.py", line 128, in wrapper
    new_layer = func(name, layer, *args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jaylynbarbee\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\intel_npu_acceleration_library\compiler.py", line 208, in optimize_llama_attention
    return nn.LlamaAttention.fromTorch(layer)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jaylynbarbee\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\intel_npu_acceleration_library\nn\llm.py", line 267, in fromTorch
    rotary_emb=layer.rotary_emb,
               ^^^^^^^^^^^^^^^^
  File "C:\Users\jaylynbarbee\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\nn\modules\module.py", line 1931, in __getattr__
    raise AttributeError(
AttributeError: 'LlamaAttention' object has no attribute 'rotary_emb'

@richardanichols

I was also having similar issues, but I was eventually able to find workarounds for each of them and get things running.

Here is my setup:

  • Intel(R) Core(TM) Ultra 7 265K
  • Windows 11
  • Python 3.12.8 (also tried with Python 3.11, same behavior)
  • torch 2.5.1
  • transformers 4.48.1
  • intel_npu_acceleration_library 1.4.0

I first got the same error that @shira-g reported:

TypeError: NPUModel.from_pretrained() missing 1 required positional argument: 'config'

Following the advice of @alessandropalla got me past that one.

Then I encountered the same error as @Jaylyn-Barbee:

AttributeError: 'LlamaAttention' object has no attribute 'rotary_emb'

I ended up writing a function in my script to modify the model object to give the attention layers the rotary_emb attribute that the npu_library was expecting (see below).

Then, I got a new error:

File "C:\Python312\Lib\site-packages\transformers\models\llama\modeling_llama.py", line 332, in forward
hidden_states, self_attn_weights = self.self_attn(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: too many values to unpack (expected 2)

My fix for this error was to modify \Lib\site-packages\intel_npu_acceleration_library\nn\llm.py, line 245 from this:
return attn_output, None, past_key_value
to this:
return attn_output, None

That got things running for me.
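If you would rather not edit files inside site-packages, an equivalent runtime monkey-patch based only on the return statement quoted above would look roughly like this; the intel_npu_acceleration_library.nn.LlamaAttention path is an assumption taken from the traceback earlier in the thread, so check it against your installed version:

import intel_npu_acceleration_library.nn as npu_nn

_original_forward = npu_nn.LlamaAttention.forward

def _forward_two_values(self, *args, **kwargs):
    outputs = _original_forward(self, *args, **kwargs)
    # llm.py returns (attn_output, None, past_key_value); newer transformers
    # unpacks only two values from self_attn, so drop the trailing element.
    return outputs[0], outputs[1]

npu_nn.LlamaAttention.forward = _forward_two_values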

My complete script (I pieced together a few different scripts, so it's a little different from the original Intel example):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import intel_npu_acceleration_library
from intel_npu_acceleration_library.compiler import CompilerConfig
from intel_npu_acceleration_library.nn.module import NPUModuleWrapper
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding
from transformers.models.gemma.modeling_gemma import GemmaAttention, GemmaRotaryEmbedding

def fix_npu_model(model: torch.nn.Module):
    # Recursively walk the model and give every Llama/Gemma attention layer a
    # rotary_emb attribute, which recent transformers releases no longer set but
    # intel_npu_acceleration_library's LlamaAttention.fromTorch still reads.
    if not isinstance(model, NPUModuleWrapper):
        for _, layer in model.named_children():
            if not hasattr(layer, 'rotary_emb'):
                if isinstance(layer, LlamaAttention):
                    layer.rotary_emb = LlamaRotaryEmbedding
                if isinstance(layer, GemmaAttention):
                    layer.rotary_emb = GemmaRotaryEmbedding
            if not isinstance(layer, NPUModuleWrapper):
                fix_npu_model(layer)

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=True).eval()
fix_npu_model(model)

print("Compile model for the NPU")
compiler_conf = CompilerConfig(dtype=intel_npu_acceleration_library.int8)
model = intel_npu_acceleration_library.compile(model, config=compiler_conf)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
tokenizer.pad_token_id = tokenizer.eos_token_id
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Use the tokenizer's chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating
messages = [
    {
        "role": "system",
        "content": "You are a friendly chatbot who always responds in a helpful way",
    },
    {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
]
prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
print(outputs[0]["generated_text"])

I'm new to pytorch, transformers, and the intel_npu_acceleration_library, so I don't have a specific recommendation on how to fix the library and/or the official example script - but maybe someone will find my workarounds helpful.

@lcwyylcwyy

lcwyylcwyy commented Feb 2, 2025

I borrowed the idea from @richardanichols and fixed the code as follows:

#
# Copyright © 2024 Intel Corporation
# SPDX-License-Identifier: Apache 2.0
#

from transformers import TextStreamer
import intel_npu_acceleration_library
from intel_npu_acceleration_library.compiler import CompilerConfig
from intel_npu_acceleration_library.nn.module import NPUModuleWrapper
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding
from transformers.models.gemma.modeling_gemma import GemmaAttention, GemmaRotaryEmbedding
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
def fix_npu_model(model: torch.nn.Module):
    if not isinstance(model, NPUModuleWrapper):
        for _, layer in model.named_children():
            if not hasattr(layer, 'rotary_emb'):
                if isinstance(layer, LlamaAttention):
                    layer.rotary_emb = LlamaRotaryEmbedding
                if isinstance(layer, GemmaAttention):
                    layer.rotary_emb = GemmaRotaryEmbedding
            if not isinstance(layer, NPUModuleWrapper):
                fix_npu_model(layer)


model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=True,  attn_implementation="sdpa")
fix_npu_model(model)

compiler_conf = CompilerConfig(dtype=intel_npu_acceleration_library.int8)
model = intel_npu_acceleration_library.compile(model, config=compiler_conf).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
tokenizer.pad_token_id = tokenizer.eos_token_id
streamer = TextStreamer(tokenizer, skip_special_tokens=True)


query = "Hello, who are you?"
prefix = tokenizer(query, return_tensors="pt")["input_ids"]


generation_kwargs = dict(
    input_ids=prefix,
    streamer=streamer,
    do_sample=True,
    top_k=50,
    top_p=0.9,
    max_new_tokens=512,
)

print("Run inference")
_ = model.generate(**generation_kwargs)

The output is

Run inference
Hello, who are you?
I have a very clear image of you
I can see your entire profile
And all your personal details

[Verse 2]
You're kind of funny, you're clever
You're not like everyone else
You know what you want, and you get it
You're a catch, you got me intrigued

[Chorus]
Who are you?
Who are you?
Who are you?
Who are you?

[Verse 3]
You're a bit of a mystery
A mystery wrapped up in a package
You're the one who keeps me guessing
You're a charmer, you're a lass

[Chorus]
Who are you?
Who are you?
Who are you?
Who are you?

[Bridge]
Your personality's a mystery, a whirlwind
Its vibe is strong, I feel the pull
A spark, a fire, you're what I need
You're the reason I sing

[Chorus]
Who are you?
Who are you?
Who are you?
Who are you?

[Outro]
Who are you?
Who are you?
Who are you?
Who are you?

[Tagline]
Dare to be different, and let's explore
Who are you?
Who are you?
Who are you?


I don't know whether it is correct, but it works for me.

@SuperFico2100

@lcwyylcwyy I tried yours and it gives me this output:

Run inference
Hello, who are Traceback (most recent call last):
  File "<frozen runpy>", line 189, in _run_module_as_main
  File "<frozen runpy>", line 112, in _get_module_details
  File "C:\Users\Sterco\Documents\coding\python\LLM\test3.py", line 52, in <module>
    _ = model.generate(**generation_kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Sterco\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Sterco\AppData\Local\Programs\Python\Python312\Lib\site-packages\transformers\generation\utils.py", line 2255, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "C:\Users\Sterco\AppData\Local\Programs\Python\Python312\Lib\site-packages\transformers\generation\utils.py", line 3254, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Sterco\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Sterco\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Sterco\AppData\Local\Programs\Python\Python312\Lib\site-packages\transformers\models\llama\modeling_llama.py", line 831, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "C:\Users\Sterco\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Sterco\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Sterco\AppData\Local\Programs\Python\Python312\Lib\site-packages\transformers\models\llama\modeling_llama.py", line 589, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "C:\Users\Sterco\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Sterco\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Sterco\AppData\Local\Programs\Python\Python312\Lib\site-packages\transformers\models\llama\modeling_llama.py", line 332, in forward
    hidden_states, self_attn_weights = self.self_attn(
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: too many values to unpack (expected 2)

@lcwyylcwyy

lcwyylcwyy commented Feb 3, 2025

@SuperFico2100 Yes, I found the reason: in /home/xxxxx/anaconda3/envs/intel_npu/lib/python3.12/site-packages/intel_npu_acceleration_library/nn/llm.py, line 245, the return value needs to be changed from

return attn_output, None, past_key_value

to

return attn_output, None

@lcwyylcwyy

lcwyylcwyy commented Feb 3, 2025

(Quoting my earlier script and output from above.)

Thanks to @SuperFico2100. The code does not work with intel_npu_acceleration_library 1.4.0 until you change intel_npu_acceleration_library/nn/llm.py, line 245, from

return attn_output, None, past_key_value

to

return attn_output, None
