
Jina model support [experimental] #502

Draft: wants to merge 2 commits into base: main

15 changes: 11 additions & 4 deletions QEfficient/transformers/embeddings/embedding_utils.py
@@ -80,6 +80,13 @@ def cls_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor)
}


def embedding_forward(
self, input_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, **kwargs
):
print("Forward swapped with new one")
output = self.old_forward(input_ids=input_ids, position_ids=position_ids, **kwargs)
return output[0]

class PooledModel(nn.Module):
"""
Adds pooling functionality to embedding model.
@@ -92,10 +99,10 @@ def __init__(self, base_model, pooling_fn):
self.pooling_fn = pooling_fn

def forward(
self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs
self, input_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, **kwargs
):
output = self.base_model(input_ids, attention_mask, **kwargs)
return self.pooling_fn(output[0], attention_mask)
output = self.base_model(input_ids, position_ids, **kwargs)
return self.pooling_fn(output[0], position_ids)


def validate_user_pooling_function(user_function):
@@ -119,7 +126,7 @@ def validate_user_pooling_function(user_function):
raise TypeError("Provided pooling function is not callable.")

sig = inspect.signature(user_function)
required_args = {"last_hidden_states", "attention_mask"}
required_args = {"last_hidden_states", "position_ids"}
if not required_args.issubset(sig.parameters.keys()):
raise ValueError(f"Pooling function must accept arguments: {required_args}")
return user_function
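
With this change, PooledModel calls the pooling function with position_ids instead of attention_mask, and validate_user_pooling_function now requires the argument names last_hidden_states and position_ids. For illustration only, a minimal pooling function written against the new signature could look like the sketch below; the function name is hypothetical, and it assumes the convention used by the generate path in modeling_auto.py, where padded positions carry position_id == -1.

import torch


def mean_pool_by_position(last_hidden_states: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
    # Treat any token whose position_id is not -1 as valid (assumed padding convention).
    mask = (position_ids != -1).unsqueeze(-1).to(last_hidden_states.dtype)
    summed = (last_hidden_states * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1.0)
    return summed / counts
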
32 changes: 19 additions & 13 deletions QEfficient/transformers/models/modeling_auto.py
@@ -38,6 +38,7 @@
from QEfficient.transformers.modeling_utils import DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH
from QEfficient.transformers.models.pytorch_transforms import (
CustomOpsTransform,
EmbeddingTransform,
KVCacheExternalModuleMapperTransform,
KVCacheTransform,
PoolingTransform,
@@ -160,7 +161,7 @@ class QEFFAutoModel(QEFFTransformersBase):
"""

_hf_auto_class = AutoModel
_pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
_pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, EmbeddingTransform]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(self, model: nn.Module, pooling=None, **kwargs):
@@ -267,10 +268,10 @@ def export(self, export_dir: Optional[str] = None) -> str:

example_inputs = {
"input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
"attention_mask": torch.ones((bs, seq_len), dtype=torch.int64),
"position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1),
}

dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}, "attention_mask": {0: "batch_size", 1: "seq_len"}}
dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}}

output_names = ["output"]

@@ -396,32 +397,37 @@ def cloud_ai_100_feature_generate(
# To handle single seq_len as we can't fetch allowed shapes for single seq_len
self.seq_len = self.qpc_session.bindings[0].dims[1] if not hasattr(self, "seq_len") else self.seq_len

qpc_inputs = {}
input_ids = np.array(
torch.nn.functional.pad(inputs["input_ids"], (0, self.seq_len - input_ids_len), "constant", 0)
)
attention_mask = np.array(
torch.nn.functional.pad(
inputs["attention_mask"], (0, self.seq_len - inputs["attention_mask"].size(1)), "constant", 0
qpc_inputs["input_ids"] = input_ids
qpc_input_names = self.qpc_session.input_names

if "position_ids" in qpc_input_names:
attention_mask = np.array(
torch.nn.functional.pad(
inputs["attention_mask"], (0, self.seq_len - inputs["attention_mask"].size(1)), "constant", 0
)
)
)

inputs = dict(input_ids=input_ids, attention_mask=attention_mask)
position_ids = np.where(attention_mask == 1, np.arange(attention_mask.shape[1]), -1)
qpc_inputs["position_ids"] = position_ids

# TODO: Remove try and catch after compiler fix
try:
outputs = {
"output": np.random.randn(*list(self.qpc_session.bindings[2].dims)).astype(np.float32),
"output": np.random.randn(*list(self.qpc_session.bindings[-1].dims)).astype(np.float32),
}
self.qpc_session.set_buffers(outputs)
outputs = self.qpc_session.run(inputs)
outputs = self.qpc_session.run(qpc_inputs)
except Exception:
outputs = {
"output": np.random.randn(self.batch_size, self.seq_len, self.qpc_session.bindings[2].dims[1]).astype(
"output": np.random.randn(self.batch_size, self.seq_len, self.qpc_session.bindings[-1].dims[1]).astype(
np.float32
),
}
self.qpc_session.set_buffers(outputs)
outputs = self.qpc_session.run(inputs)
outputs = self.qpc_session.run(qpc_inputs)
return outputs

def pytorch_feature_generate(self, model, inputs: Union[torch.Tensor, np.ndarray]) -> List[torch.Tensor]:
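
The new generate path pads input_ids to the compiled sequence length, derives position_ids from the tokenizer's attention_mask (padded positions get -1, valid tokens their index), and sends position_ids only when the compiled binary lists them among its input names. A standalone sketch of that preparation, with made-up token ids and an illustrative seq_len in place of the QPC session plumbing:

import numpy as np
import torch

seq_len = 8  # illustrative; in the class this comes from the QPC bindings
inputs = {
    "input_ids": torch.tensor([[12, 345, 678, 9]]),  # made-up token ids
    "attention_mask": torch.tensor([[1, 1, 1, 1]]),
}

input_ids = np.array(
    torch.nn.functional.pad(inputs["input_ids"], (0, seq_len - inputs["input_ids"].shape[1]), "constant", 0)
)
attention_mask = np.array(
    torch.nn.functional.pad(inputs["attention_mask"], (0, seq_len - inputs["attention_mask"].shape[1]), "constant", 0)
)
position_ids = np.where(attention_mask == 1, np.arange(attention_mask.shape[1]), -1)
print(position_ids)  # [[ 0  1  2  3 -1 -1 -1 -1]]
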
11 changes: 10 additions & 1 deletion QEfficient/transformers/models/pytorch_transforms.py
@@ -5,322 +5,322 @@
#
# -----------------------------------------------------------------------------

import warnings
from types import MethodType
from typing import Callable, Optional, Tuple, Union

from torch import nn
from transformers.models.codegen.modeling_codegen import (
CodeGenAttention,
CodeGenBlock,
CodeGenForCausalLM,
CodeGenModel,
)
from transformers.models.falcon.modeling_falcon import (
FalconAttention,
FalconDecoderLayer,
FalconForCausalLM,
FalconModel,
)
from transformers.models.gemma.modeling_gemma import (
GemmaAttention,
GemmaDecoderLayer,
GemmaForCausalLM,
GemmaModel,
GemmaRMSNorm,
)
from transformers.models.gemma2.modeling_gemma2 import (
Gemma2Attention,
Gemma2DecoderLayer,
Gemma2ForCausalLM,
Gemma2Model,
Gemma2RMSNorm,
)
from transformers.models.gemma3.modeling_gemma3 import (
Gemma3Attention,
Gemma3DecoderLayer,
Gemma3ForCausalLM,
Gemma3ForConditionalGeneration,
Gemma3RMSNorm,
Gemma3TextModel,
)
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel, GPT2Model
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import (
GPTBigCodeAttention,
GPTBigCodeBlock,
GPTBigCodeForCausalLM,
GPTBigCodeModel,
)
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJForCausalLM, GPTJModel
from transformers.models.granite.modeling_granite import (
GraniteAttention,
GraniteForCausalLM,
GraniteModel,
GraniteRMSNorm,
)
from transformers.models.granitemoe.modeling_granitemoe import (
GraniteMoeAttention,
GraniteMoeForCausalLM,
GraniteMoeModel,
GraniteMoeMoE,
GraniteMoeParallelExperts,
GraniteMoeRMSNorm,
GraniteMoeRotaryEmbedding,
GraniteMoeTopKGating,
)
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
)
from transformers.models.llama4.modeling_llama4 import (
Llama4ForCausalLM,
Llama4ForConditionalGeneration,
Llama4TextAttention,
Llama4TextDecoderLayer,
Llama4TextExperts,
Llama4TextModel,
Llama4TextMoe,
Llama4TextRMSNorm,
Llama4VisionAttention,
Llama4VisionModel,
)
from transformers.models.llava.modeling_llava import (
LlavaForConditionalGeneration,
)
from transformers.models.llava_next.modeling_llava_next import (
LlavaNextForConditionalGeneration,
)
from transformers.models.mistral.modeling_mistral import (
MistralAttention,
MistralDecoderLayer,
MistralForCausalLM,
MistralModel,
MistralRMSNorm,
)
from transformers.models.mixtral.modeling_mixtral import (
MixtralAttention,
MixtralDecoderLayer,
MixtralForCausalLM,
MixtralModel,
MixtralRMSNorm,
MixtralSparseMoeBlock,
)
from transformers.models.mllama.modeling_mllama import (
MllamaCrossAttentionDecoderLayer,
MllamaForCausalLM,
MllamaForConditionalGeneration,
MllamaRotaryEmbedding,
MllamaSelfAttentionDecoderLayer,
MllamaTextCrossAttention,
MllamaTextModel,
MllamaTextRMSNorm,
MllamaTextSelfAttention,
MllamaVisionModel,
)
from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel
from transformers.models.phi.modeling_phi import PhiAttention, PhiDecoderLayer, PhiForCausalLM, PhiModel
from transformers.models.phi3.modeling_phi3 import (
Phi3Attention,
Phi3DecoderLayer,
Phi3ForCausalLM,
Phi3Model,
Phi3RMSNorm,
)
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2Attention,
Qwen2DecoderLayer,
Qwen2ForCausalLM,
Qwen2Model,
Qwen2RMSNorm,
)
from transformers.models.starcoder2.modeling_starcoder2 import (
Starcoder2Attention,
Starcoder2DecoderLayer,
Starcoder2ForCausalLM,
Starcoder2Model,
)
from transformers.models.whisper.modeling_whisper import (
WhisperAttention,
WhisperDecoder,
WhisperDecoderLayer,
WhisperEncoder,
WhisperForConditionalGeneration,
WhisperModel,
WhisperPositionalEmbedding,
)

from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform
from QEfficient.customop import CustomRMSNormAIC, GemmaCustomRMSNormAIC
from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function
from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, embedding_forward, validate_user_pooling_function
from QEfficient.transformers.models.codegen.modeling_codegen import (
QEffCodeGenAttention,
QeffCodeGenBlock,
QEffCodeGenForCausalLM,
QEffCodeGenModel,
)
from QEfficient.transformers.models.falcon.modeling_falcon import (
QEffFalconAttention,
QEffFalconDecoderLayer,
QEffFalconForCausalLM,
QEffFalconModel,
)
from QEfficient.transformers.models.gemma.modeling_gemma import (
QEffGemmaAttention,
QEffGemmaDecoderLayer,
QEffGemmaForCausalLM,
QEffGemmaModel,
)
from QEfficient.transformers.models.gemma2.modeling_gemma2 import (
QEffGemma2Attention,
QEffGemma2DecoderLayer,
QEffGemma2ForCausalLM,
QEffGemma2Model,
)
from QEfficient.transformers.models.gemma3.modeling_gemma3 import (
QEffGemma3Attention,
QEffGemma3CustomRMSNormAIC,
QEffGemma3DecoderLayer,
QEffGemma3ForCausalLMModel,
QEffGemma3ForConditionalGeneration,
QEffGemma3TextModel,
)
from QEfficient.transformers.models.gpt2.modeling_gpt2 import (
QEffGPT2Attention,
QEffGPT2Block,
QEffGPT2LMHeadModel,
QEffGPT2Model,
)
from QEfficient.transformers.models.gpt_bigcode.modeling_gpt_bigcode import (
QEffGPTBigCodeAttention,
QEffGPTBigCodeBlock,
QEffGPTBigCodeForCausalLM,
QEffGPTBigCodeModel,
)
from QEfficient.transformers.models.gptj.modeling_gptj import (
QEffGPTJAttention,
QEffGPTJBlock,
QEffGPTJForCausalLM,
QEffGPTJModel,
)
from QEfficient.transformers.models.granite.modeling_granite import (
QEffGraniteAttention,
QEffGraniteForCausalLM,
QEffGraniteModel,
)
from QEfficient.transformers.models.granitemoe.modeling_granitemoe import (
QEffGraniteMoeAttention,
QEffGraniteMoeForCausalLM,
QEffGraniteMoeModel,
QEffGraniteMoeMoE,
QEffGraniteMoeParallelExperts,
QEffGraniteMoeRotaryEmbedding,
QEffGraniteMoeTopKGating,
)
from QEfficient.transformers.models.grok_1.modeling_grok1 import (
QEFFGrok1CustomRMSNormAIC,
QEffGrok1DecoderLayer,
QEffGrok1Model,
QEffGrok1ModelForCausalLM,
QEffGrok1MoeBlock,
QEffGrok1MultiHeadAttention,
)
from QEfficient.transformers.models.internvl.modeling_internvl import (
QEffInternVisionEmbeddings,
QEffInternVLModel,
)
from QEfficient.transformers.models.llama.modeling_llama import (
QEffLlamaAttention,
QEffLlamaDecoderLayer,
QEffLlamaForCausalLM,
QEffLlamaModel,
)
from QEfficient.transformers.models.llama4.modeling_llama4 import (
QEffLlama4ForCausalLM,
QEffLlama4ForConditionalGeneration,
QEffLlama4TextAttention,
QEffLlama4TextDecoderLayer,
QEffLlama4TextExperts,
QEffLlama4TextModel,
QEffLlama4TextMoe,
QEffLlama4VisionAttention,
QEffLlama4VisionModel,
)
from QEfficient.transformers.models.llava.modeling_llava import (
QEffLlavaForConditionalGeneration,
)
from QEfficient.transformers.models.llava_next.modeling_llava_next import (
QEffLlavaNextForConditionalGeneration,
)
from QEfficient.transformers.models.mistral.modeling_mistral import (
QEffMistralAttention,
QEffMistralDecoderLayer,
QEffMistralForCausalLM,
QEffMistralModel,
)
from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import (
QEffMixtralAttention,
QeffMixtralDecoderLayer,
QEffMixtralForCausalLM,
QEffMixtralModel,
QEffMixtralSparseMoeBlock,
)
from QEfficient.transformers.models.mllama.modeling_mllama import (
QEffMllamaCrossAttentionDecoderLayer,
QEffMllamaForCausalLM,
QEffMllamaForConditionalGeneration,
QEffMllamaRotaryEmbedding,
QEffMllamaSelfAttentionDecoderLayer,
QEffMllamaTextCrossAttentionSingleQPC,
QEffMllamaTextCrossAttentionTwoQPC,
QEffMllamaTextModel,
QEffMllamaTextSelfAttention,
QEffMllamaVisionModel,
)
from QEfficient.transformers.models.mpt.modeling_mpt import (
QEffMptAttention,
QEffMptBlock,
QEffMptForCausalLM,
QEFfMptModel,
)
from QEfficient.transformers.models.phi.modeling_phi import (
QEffPhiAttention,
QEffPhiDecoderLayer,
QEffPhiForCausalLM,
QEffPhiModel,
)
from QEfficient.transformers.models.phi3.modeling_phi3 import (
QEffPhi3Attention,
QEffPhi3DecoderLayer,
QEffPhi3ForCausalLM,
QEffPhi3Model,
)
from QEfficient.transformers.models.qwen2.modeling_qwen2 import (
QEffQwen2Attention,
QEffQwen2DecoderLayer,
QEffQwen2ForCausalLM,
QEffQwen2Model,
)
from QEfficient.transformers.models.starcoder2.modeling_starcoder2 import (
QEffStarcoder2Attention,
QEFFStarcoder2DecoderLayer,
QEffStarcoder2ForCausalLM,
QEffStarcoder2Model,
)
from QEfficient.transformers.models.whisper.modeling_whisper import (
QEffWhisperAttention,
QEffWhisperDecoder,
QEffWhisperDecoderLayer,
QEffWhisperEncoder,
QEffWhisperForConditionalGeneration,
QEffWhisperModel,
QEffWhisperPositionalEmbedding,
)
from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry
from QEfficient.transformers.sampler.sampler import sampler_forward
from QEfficient.transformers.spd.spd_transform_forward import tlm_forward

Check failure (GitHub Actions / lint): Ruff (I001) QEfficient/transformers/models/pytorch_transforms.py:8:1 Import block is un-sorted or un-formatted

SPD_TARGET = "target"

@@ -632,3 +632,12 @@
model = PooledModel(model, pooling_method)
warnings.warn("Pooling is applied to the model.")
return model, transformed
class EmbeddingTransform:
@classmethod
def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]:
transformed = False
model.old_forward = model.forward
model.forward = MethodType(embedding_forward, model)
transformed = True

return model, transformed
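
EmbeddingTransform keeps a reference to the original forward as old_forward and rebinds embedding_forward with types.MethodType, so the exported module returns only the last hidden state (output[0]). A self-contained toy illustration of that swap, using a dummy module rather than the QEfficient classes:

from types import MethodType

import torch
from torch import nn


class ToyEncoder(nn.Module):
    def forward(self, input_ids):
        hidden = torch.randn(input_ids.shape[0], input_ids.shape[1], 4)
        return (hidden, "other_outputs")  # tuple-style model output


def hidden_only_forward(self, input_ids):
    # Call the saved original forward, keep only the hidden states.
    return self.old_forward(input_ids)[0]


model = ToyEncoder()
model.old_forward = model.forward
model.forward = MethodType(hidden_only_forward, model)
print(model(torch.zeros(1, 3, dtype=torch.long)).shape)  # torch.Size([1, 3, 4])
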
31 changes: 24 additions & 7 deletions examples/embedding_model.py
@@ -8,33 +8,50 @@
# This is the work example of the Embedding model with the AI 100
# For more information, visit: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2

import torch
from transformers import AutoTokenizer

from QEfficient import QEFFAutoModel as AutoModel


def max_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
last_hidden_states[input_mask_expanded == 0] = -1e9
return torch.max(last_hidden_states, 1)[0]
# def max_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
# input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
# last_hidden_states[input_mask_expanded == 0] = -1e9
# return torch.max(last_hidden_states, 1)[0]

import torch

Check failure (GitHub Actions / lint): Ruff (F811) examples/embedding_model.py:22:8 Redefinition of unused `torch` from line 11
Check failure (GitHub Actions / lint): Ruff (I001) examples/embedding_model.py:11:1 Import block is un-sorted or un-formatted

# def max_pooling(last_hidden_states: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
# # Expand position_ids to match the shape of last_hidden_states
# position_mask_expanded = position_ids.unsqueeze(-1).expand(last_hidden_states.size()).float()

# # Mask out positions with a special value (e.g., -1e9) where position_id is 0
# last_hidden_states[position_mask_expanded == 0] = -1e9

# # Apply max pooling across the sequence length dimension
# return torch.max(last_hidden_states, dim=1)[0]
def max_pooling(last_hidden_states: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
# Create a mask where position_ids > 0 (or use a different condition based on your data)
position_mask = (position_ids > 0).unsqueeze(-1).expand(last_hidden_states.size()).float()
last_hidden_states[position_mask == 0] = -1e9
return torch.max(last_hidden_states, 1)[0]

# Sentences we want sentence embeddings for
sentences = "This is an example sentence"

model_name = "jinaai/jina-embeddings-v2-base-code"
# Load model from HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
tokenizer = AutoTokenizer.from_pretrained(model_name)


# You can specify the pooling strategy either as a string (e.g., "max") or by passing a custom pooling function.
# If no pooling is specified, the model will return its default output (typically token embeddings).
qeff_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", pooling=max_pooling)
qeff_model = AutoModel.from_pretrained(model_name, pooling=max_pooling, trust_remote_code=True, num_hidden_layers=1)
# qeff_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", pooling="max")
# qeff_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Here seq_len can be list of seq_len or single int
qeff_model.compile(num_cores=16, seq_len=[32, 64])
qeff_model.compile(num_cores=16, seq_len=32)
# qeff_model.compile(num_cores=16, seq_len=32)


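
Since the exported model and the pooling path now consume position_ids, a caller using the PyTorch runtime would build them from the tokenizer's attention_mask before calling generate. The sketch below mirrors the test changes in the next file and reuses the tokenizer and sentences defined in this example; treating it as the expected caller-side conversion is an assumption, not something this example script does itself.

encoded = tokenizer(sentences, return_tensors="pt")
position_ids = torch.where(encoded["attention_mask"] == 1, torch.arange(encoded["attention_mask"].shape[1]), -1)
encoded["position_ids"] = position_ids
encoded.pop("attention_mask")
# pt_embeddings = qeff_model.generate(inputs=encoded, runtime_ai100=False)
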
155 changes: 79 additions & 76 deletions tests/transformers/models/test_embedding_models.py
@@ -10,17 +10,17 @@

import numpy as np
import onnxruntime as ort
import pytest

Check failure (GitHub Actions / lint): Ruff (F401) tests/transformers/models/test_embedding_models.py:13:8 `pytest` imported but unused
import torch
from transformers import AutoModel, AutoTokenizer

from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP
from QEfficient.transformers.models.modeling_auto import QEFFAutoModel
from QEfficient.utils._utils import create_json

Check failure (GitHub Actions / lint): Ruff (F401) tests/transformers/models/test_embedding_models.py:19:37 `QEfficient.utils._utils.create_json` imported but unused
from QEfficient.utils.constants import Constants, QnnConstants

Check failure (GitHub Actions / lint): Ruff (F401) tests/transformers/models/test_embedding_models.py:20:51 `QEfficient.utils.constants.QnnConstants` imported but unused

embed_test_models = [
{"model_name": "jinaai/jina-embeddings-v2-base-code", "pooling": "mean"},
# {"model_name": "jinaai/jina-embeddings-v2-base-code", "pooling": "mean"},
{"model_name": "sentence-transformers/nli-bert-base-cls-pooling", "pooling": "cls"},
]

@@ -40,7 +40,7 @@
# Original PyTorch model
pt_model = AutoModel.from_pretrained(
model_name,
num_hidden_layers=n_layer,
# num_hidden_layers=n_layer,
attn_implementation="eager",
trust_remote_code=True,
)
@@ -58,6 +58,10 @@
qeff_model = QEFFAutoModel(pt_model, pretrained_model_name_or_path=model_name, pooling=pooling)

# QEff transformed PyTorch model output
position_ids = torch.where(inputs["attention_mask"] == 1, torch.arange(inputs["attention_mask"].shape[1]), -1)
inputs["position_ids"] = position_ids
inputs.pop("attention_mask")

qeff_pt_outputs = qeff_model.generate(inputs=inputs, runtime_ai100=False)
qeff_pt_embeddings = qeff_pt_outputs if pooling else qeff_pt_outputs[0]

@@ -71,9 +75,12 @@

# Prepare the inputs for ONNX Runtime
input_ids = np.array(inputs["input_ids"])
attention_mask = np.array(inputs["attention_mask"])
position_ids = np.array(inputs["position_ids"])

onnx_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
onnx_inputs = {"input_ids": input_ids}

if len(ort_session.get_inputs()) > 1 and ort_session.get_inputs()[1].name == "position_ids":
onnx_inputs["position_ids"] = position_ids

# Run inference
onnx_outputs = ort_session.run(None, onnx_inputs)
@@ -88,6 +95,7 @@
enable_qnn=enable_qnn,
qnn_config=qnn_config,
)
inputs = tokenizer("My name is", return_tensors="pt")
ai100_output = qeff_model.generate(inputs=inputs)
qeff_ai100_embeddings = (
ai100_output["output"] if pooling else ai100_output["output"][:, : inputs["input_ids"].shape[1], :]
@@ -100,84 +108,79 @@
assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json"))


@pytest.mark.on_qaic
@pytest.mark.parametrize("model", embed_test_models)
def test_embed_model_pytorch_vs_onnx_vs_ai100(model):
"""
Test function to validate output of the Pytorch, ONNX and AI 100 runtime model output.
"""
check_embed_pytorch_vs_ort_vs_ai100(model_name=model["model_name"], seq_len=32, n_layer=1)


@pytest.mark.on_qaic
@pytest.mark.parametrize("model", embed_test_models)
def test_embed_model_pytorch_vs_onnx_vs_ai100_pooling(model):
"""
Test function to validate output of the Pytorch, ONNX and AI 100 runtime model output with pooling.
"""
check_embed_pytorch_vs_ort_vs_ai100(model_name=model["model_name"], seq_len=32, n_layer=1, pooling=model["pooling"])


@pytest.mark.on_qaic
@pytest.mark.parametrize("model", embed_test_models[:1])
def test_embed_model_pytorch_vs_onnx_vs_ai100_multiple_seq_len(model):
"""
Test function to validate output of the Pytorch, ONNX and AI 100 runtime model output with multiple seq_len.
"""
check_embed_pytorch_vs_ort_vs_ai100(model_name=model["model_name"], seq_len=[32, 20], n_layer=1)


########## QNN TESTS ##############
check_embed_pytorch_vs_ort_vs_ai100(model_name="jinaai/jina-embeddings-v2-base-code", seq_len=32, n_layer=1)


@pytest.mark.on_qaic
@pytest.mark.qnn
@pytest.mark.parametrize("model_name", embed_test_models)
def test_embed_model_pytorch_vs_onnx_vs_ai100_qnn(model_name):
"""
QNN Compilation path test.
Test function to validate output of the Pytorch, ONNX and AI 100 runtime model output.
"""
qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json")
create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG)
# @pytest.mark.on_qaic
# @pytest.mark.parametrize("model", embed_test_models)
# def test_embed_model_pytorch_vs_onnx_vs_ai100_pooling(model):
# """
# Test function to validate output of the Pytorch, ONNX and AI 100 runtime model output with pooling.
# """
# check_embed_pytorch_vs_ort_vs_ai100(model_name=model["model_name"], seq_len=32, n_layer=1, pooling=model["pooling"])

check_embed_pytorch_vs_ort_vs_ai100(
model_name=model_name["model_name"], seq_len=32, n_layer=1, enable_qnn=True, qnn_config=qnn_config_json_path
)

# @pytest.mark.on_qaic
# @pytest.mark.parametrize("model", embed_test_models[:1])
# def test_embed_model_pytorch_vs_onnx_vs_ai100_multiple_seq_len(model):
# """
# Test function to validate output of the Pytorch, ONNX and AI 100 runtime model output with multiple seq_len.
# """
# check_embed_pytorch_vs_ort_vs_ai100(model_name=model["model_name"], seq_len=[32, 20], n_layer=1)

@pytest.mark.on_qaic
@pytest.mark.qnn
@pytest.mark.parametrize("model", embed_test_models)
def test_embed_model_pytorch_vs_onnx_vs_ai100_pooling_qnn(model):
"""
QNN Compilation path test.
Test function to validate output of the Pytorch, ONNX and AI 100 runtime model output with pooling.
"""
qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json")
create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG)

check_embed_pytorch_vs_ort_vs_ai100(
model_name=model["model_name"],
seq_len=32,
n_layer=1,
pooling=model["pooling"],
enable_qnn=True,
qnn_config=qnn_config_json_path,
)

########## QNN TESTS ##############

@pytest.mark.on_qaic
@pytest.mark.qnn
@pytest.mark.parametrize("model", [embed_test_models[0]])
def test_embed_model_pytorch_vs_onnx_vs_ai100_multiple_seq_len_qnn(model):
"""
QNN Compilation path test.
Test function to validate output of the Pytorch, ONNX and AI 100 runtime model output with multiple seq_len.
"""
qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json")
create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG)

check_embed_pytorch_vs_ort_vs_ai100(
model_name=model["model_name"], seq_len=[32, 20], n_layer=1, enable_qnn=True, qnn_config=qnn_config_json_path
)
# @pytest.mark.on_qaic
# @pytest.mark.qnn
# @pytest.mark.parametrize("model_name", embed_test_models)
# def test_embed_model_pytorch_vs_onnx_vs_ai100_qnn(model_name):
# """
# QNN Compilation path test.
# Test function to validate output of the Pytorch, ONNX and AI 100 runtime model output.
# """
# qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json")
# create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG)

# check_embed_pytorch_vs_ort_vs_ai100(
# model_name=model_name["model_name"], seq_len=32, n_layer=1, enable_qnn=True, qnn_config=qnn_config_json_path
# )


# @pytest.mark.on_qaic
# @pytest.mark.qnn
# @pytest.mark.parametrize("model", embed_test_models)
# def test_embed_model_pytorch_vs_onnx_vs_ai100_pooling_qnn(model):
# """
# QNN Compilation path test.
# Test function to validate output of the Pytorch, ONNX and AI 100 runtime model output with pooling.
# """
# qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json")
# create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG)

# check_embed_pytorch_vs_ort_vs_ai100(
# model_name=model["model_name"],
# seq_len=32,
# n_layer=1,
# pooling=model["pooling"],
# enable_qnn=True,
# qnn_config=qnn_config_json_path,
# )


# @pytest.mark.on_qaic
# @pytest.mark.qnn
# @pytest.mark.parametrize("model", [embed_test_models[0]])
# def test_embed_model_pytorch_vs_onnx_vs_ai100_multiple_seq_len_qnn(model):
# """
# QNN Compilation path test.
# Test function to validate output of the Pytorch, ONNX and AI 100 runtime model output with multiple seq_len.
# """
# qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json")
# create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG)

# check_embed_pytorch_vs_ort_vs_ai100(
# model_name=model["model_name"], seq_len=[32, 20], n_layer=1, enable_qnn=True, qnn_config=qnn_config_json_path
# )