Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions QEfficient/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
QEFFCommonLoader,
)
from QEfficient.compile.compile_helper import compile
from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEffFluxPipeline
from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
from QEfficient.peft import QEffAutoPeftModelForCausalLM
Expand All @@ -39,6 +40,7 @@
"QEFFAutoModelForImageTextToText",
"QEFFAutoModelForSpeechSeq2Seq",
"QEFFCommonLoader",
"QEffFluxPipeline",
]
# For faster downloads via hf_transfer
# This code is put above import statements as this needs to be executed before
Expand Down
55 changes: 29 additions & 26 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import gc
import inspect
import logging
import re
import shutil
import subprocess
import warnings
Expand All @@ -21,26 +20,21 @@

from QEfficient.base.onnx_transforms import (
BaseOnnxTransform,
CustomOpTransform,
OnnxTransformPipeline,
RenameFunctionOutputsTransform,
)
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.transformers.cache_utils import InvalidIndexProvider
from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
from QEfficient.utils import (
constants,
create_json,
create_model_params,
dump_qconfig,
export_wrapper,
generate_mdp_partition_config,
hash_dict_params,
load_json,
)
from QEfficient.utils.torch_patches import apply_torch_patches, undo_torch_patches
from QEfficient.utils.export_utils import export_wrapper

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -125,9 +119,35 @@ def _model_offloaded_check(self) -> None:
logger.error(error_msg)
raise RuntimeError(error_msg)

@property
def model_name(self) -> str:
    """
    Return the wrapped model's class name with a leading QEff/QEFF removed.

    The name is read from ``self.model.__class__.__name__``. When it starts
    with ``QEff`` or ``QEFF`` the first four characters are dropped;
    otherwise it is returned unchanged.

    Returns:
        str: Model class name (e.g., "CLIPTextModel" instead of "QEffCLIPTextModel")
    """
    class_name = self.model.__class__.__name__
    # Both prefix spellings are four characters long, so one slice covers either.
    has_qeff_prefix = class_name.startswith(("QEff", "QEFF"))
    return class_name[4:] if has_qeff_prefix else class_name

@property
@abstractmethod
def get_model_config(self) -> Dict:
    """
    Get the model configuration as a dictionary.

    This is an abstract property that must be implemented by all subclasses.
    Typically returns: self.model.config.__dict__

    Returns:
        Dict: The configuration dictionary of the underlying model
    """
    # The stale abstract ``model_name`` stub that used to sit between the
    # decorators and this definition is dropped: it re-declared ``model_name``
    # and left ``get_model_config`` undecorated.
    ...

@abstractmethod
def export(self, export_dir: Optional[str] = None) -> Path:
Expand Down Expand Up @@ -188,7 +208,6 @@ def _export(
onnx_transform_kwargs: Optional[Dict[str, any]] = None,
export_dir: Optional[str] = None,
offload_pt_weights: bool = True,
use_onnx_subfunctions: bool = False,
) -> str:
"""
Export the PyTorch model to ONNX and apply ONNX transforms
Expand Down Expand Up @@ -253,18 +272,8 @@ def _export(
input_names.append(param)

try:
# Initialize the registry with your custom ops
# Export to ONNX
export_kwargs = {} if export_kwargs is None else export_kwargs
if use_onnx_subfunctions:
warnings.warn(
"The subfunction feature is experimental. Please note that using compile consecutively with and without subfunction may produce inconsistent results."
)
apply_torch_patches()
InvalidIndexProvider.SUBFUNC_ENABLED = True
output_names = [re.sub("_RetainedState", "_InternalRetainedState", s) for s in output_names]
export_kwargs["export_modules_as_functions"] = get_decoder_layer_classes_for_export(self.model)
self._onnx_transforms.append(RenameFunctionOutputsTransform)
self._onnx_transforms.append(CustomOpTransform)

torch.onnx.export(
self.model,
Expand Down Expand Up @@ -309,12 +318,6 @@ def _export(
finally:
shutil.rmtree(tmp_onnx_dir, ignore_errors=True)

if use_onnx_subfunctions:
undo_torch_patches()
InvalidIndexProvider.SUBFUNC_ENABLED = False
self._onnx_transforms.remove(CustomOpTransform)
self._onnx_transforms.remove(RenameFunctionOutputsTransform)

self.onnx_path = onnx_path
return onnx_path

Expand Down
27 changes: 27 additions & 0 deletions QEfficient/base/pytorch_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,33 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
raise NotImplementedError("Use subclasses for Pytorch transform")


class ProxyModuleMappingTransform(PytorchTransform):
    """
    Swap the class of matching modules in place, driven by ``_module_mapping``.

    Each module yielded by ``model.named_modules()`` is checked against every
    ``(base_type, repl_type)`` pair of the mapping. On an ``isinstance`` match
    the module's ``__class__`` is reassigned to ``repl_type``, so the existing
    object (and whatever state it holds) is kept. Entries keyed by
    ``nn.Linear`` are only applied to modules whose last dotted name component
    is ``lm_head``.
    """

    _module_mapping: Dict[Type[nn.Module], Type[nn.Module]]

    @classmethod
    def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        """
        Apply the class replacements to ``model``.

        Args:
            model: Module tree to walk with ``named_modules()``.

        Returns:
            Tuple[nn.Module, bool]: The same ``model`` object, and ``True`` when
            at least one module had its class replaced.
        """
        transformed = False
        for name, module in model.named_modules():
            # Leaf component of the dotted module path ("" for the root module).
            leaf_is_lm_head = name.rpartition(".")[2] == "lm_head"
            for base_type, repl_type in cls._module_mapping.items():
                if not isinstance(module, base_type):
                    continue
                # nn.Linear replacements are restricted to the lm_head module.
                if base_type is nn.Linear and not leaf_is_lm_head:
                    continue
                try:
                    module.__class__ = repl_type
                except Exception as e:
                    # Best effort: a failed swap is logged and the walk continues.
                    logger.warning(f"Failed to replace module {name} ({base_type}) -> {repl_type}: {e}")
                else:
                    transformed = True

        return model, transformed


class ModuleMappingTransform(PytorchTransform):
"""
Replaces the PyTorch modules based on the _module_mapping class variable.
Expand Down
95 changes: 95 additions & 0 deletions QEfficient/diffusers/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@

<div align="center">


# **Diffusion Models on Qualcomm Cloud AI 100**


<div align="center">

### 🎨 **Experience the Future of AI Image Generation**

*Optimized for Qualcomm Cloud AI 100*

<img src="../../docs/image/girl_laughing.png" alt="Sample Output" width="400">

**Generated with**: `black-forest-labs/FLUX.1-schnell` • `"A girl laughing"` • 4 steps • 0.0 guidance scale • ⚡



</div>



[![Diffusers](https://img.shields.io/badge/Diffusers-0.35.1-orange.svg)](https://github.com/huggingface/diffusers)
</div>

---

## ✨ Overview

QEfficient Diffusers brings the power of state-of-the-art diffusion models to Qualcomm Cloud AI 100 hardware for text-to-image generation. Built on top of the popular HuggingFace Diffusers library, our optimized pipeline provides seamless inference on Qualcomm Cloud AI 100 hardware.

## 🛠️ Installation

### Prerequisites

Ensure you have Python 3.8+ and the required dependencies:

```bash
# Create Python virtual environment (Recommended Python 3.10)
sudo apt install python3.10-venv
python3.10 -m venv qeff_env
source qeff_env/bin/activate
pip install -U pip
```

### Install QEfficient

```bash
# Install from GitHub (includes diffusers support)
pip install git+https://github.com/quic/efficient-transformers

# Or build from source
git clone https://github.com/quic/efficient-transformers.git
cd efficient-transformers
pip install build wheel
python -m build --wheel --outdir dist
pip install dist/qefficient-0.0.1.dev0-py3-none-any.whl
```

---

## 🎯 Supported Models
- ✅ [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co/black-forest-labs/FLUX.1-schnell)

---


## 📚 Examples

Check out our comprehensive examples in the [`examples/diffusers/`](../../examples/diffusers/) directory.

---

## 🤝 Contributing

We welcome contributions! Please see our [Contributing Guide](../../CONTRIBUTING.md) for details.



---

## 🙏 Acknowledgments

- **HuggingFace Diffusers**: For the excellent foundation library
- **Stability AI**: For the amazing Stable Diffusion models
---

## 📞 Support

- 📖 **Documentation**: [https://quic.github.io/efficient-transformers/](https://quic.github.io/efficient-transformers/)
- 🐛 **Issues**: [GitHub Issues](https://github.com/quic/efficient-transformers/issues)

---

6 changes: 6 additions & 0 deletions QEfficient/diffusers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------
6 changes: 6 additions & 0 deletions QEfficient/diffusers/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------
40 changes: 40 additions & 0 deletions QEfficient/diffusers/models/normalization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------
from typing import Optional, Tuple

import torch
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle


class QEffAdaLayerNormZero(AdaLayerNormZero):
    """
    ``AdaLayerNormZero`` subclass whose forward applies a supplied shift and scale.

    NOTE(review): presumably the modulation tensors are computed by the caller
    rather than inside this module — confirm against the calling block.
    """

    def forward(
        self,
        x: torch.Tensor,
        shift_msa: Optional[torch.Tensor] = None,
        scale_msa: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Normalize ``x`` and modulate it with ``scale_msa`` and ``shift_msa``.

        Computes ``self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]``.

        Args:
            x: Input tensor passed through ``self.norm``.
            shift_msa: Additive term; a new axis is inserted at position 1
                before it is added. Must be provided despite the ``None``
                default, because it is indexed unconditionally.
            scale_msa: Multiplicative term, applied as ``1 + scale_msa`` with a
                new axis inserted at position 1. Must be provided despite the
                ``None`` default, because it is indexed unconditionally.

        Returns:
            torch.Tensor: The modulated tensor. A single tensor is returned,
            not a tuple.
        """
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x


class QEffAdaLayerNormZeroSingle(AdaLayerNormZeroSingle):
    """
    ``AdaLayerNormZeroSingle`` subclass whose forward applies a supplied scale and shift.

    NOTE(review): presumably the modulation tensors are computed by the caller
    rather than inside this module — confirm against the calling block.
    """

    def forward(
        self,
        x: torch.Tensor,
        scale_msa: Optional[torch.Tensor] = None,
        shift_msa: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Normalize ``x`` and modulate it with ``scale_msa`` and ``shift_msa``.

        Computes ``self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]``.
        Note the positional order here is ``scale_msa`` then ``shift_msa``.

        Args:
            x: Input tensor passed through ``self.norm``.
            scale_msa: Multiplicative term, applied as ``1 + scale_msa`` with a
                new axis inserted at position 1. Must be provided despite the
                ``None`` default, because it is indexed unconditionally.
            shift_msa: Additive term; a new axis is inserted at position 1
                before it is added. Must be provided despite the ``None``
                default, because it is indexed unconditionally.

        Returns:
            torch.Tensor: The modulated tensor. A single tensor is returned,
            not a tuple.
        """
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x


class QEffAdaLayerNormContinuous(AdaLayerNormContinuous):
    """
    ``AdaLayerNormContinuous`` subclass that uses the conditioning embedding directly.

    The forward pass splits the given embedding into a scale half and a shift
    half and applies them to the normalized input.
    """

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        """
        Normalize ``x`` and modulate it with the two halves of ``conditioning_embedding``.

        Args:
            x: Input tensor passed through ``self.norm``.
            conditioning_embedding: Tensor chunked in two along dim 1; the first
                chunk is the scale and the second is the shift.

        Returns:
            torch.Tensor: ``self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]``.
        """
        # First half along dim 1 scales, second half shifts.
        scale, shift = torch.chunk(conditioning_embedding, 2, dim=1)
        normalized = self.norm(x)
        modulated = normalized * (1 + scale)[:, None, :]
        return modulated + shift[:, None, :]
56 changes: 56 additions & 0 deletions QEfficient/diffusers/models/pytorch_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
from diffusers.models.transformers.transformer_flux import (
FluxAttention,
FluxAttnProcessor,
FluxSingleTransformerBlock,
FluxTransformer2DModel,
FluxTransformerBlock,
)
from torch import nn

from QEfficient.base.pytorch_transforms import ModuleMappingTransform
from QEfficient.customop.rms_norm import CustomRMSNormAIC
from QEfficient.diffusers.models.normalization import (
QEffAdaLayerNormContinuous,
QEffAdaLayerNormZero,
QEffAdaLayerNormZeroSingle,
)
from QEfficient.diffusers.models.transformers.transformer_flux import (
QEffFluxAttention,
QEffFluxAttnProcessor,
QEffFluxSingleTransformerBlock,
QEffFluxTransformer2DModel,
QEffFluxTransformerBlock,
)


class CustomOpsTransform(ModuleMappingTransform):
    """Maps both RMSNorm implementations to ``CustomRMSNormAIC``."""

    _module_mapping = {
        # RMSNorm as shipped in diffusers.models.normalization
        RMSNorm: CustomRMSNormAIC,
        # for torch.nn.RMSNorm
        nn.RMSNorm: CustomRMSNormAIC,
    }


class AttentionTransform(ModuleMappingTransform):
    """Maps the Flux transformer, block and attention classes to their QEff counterparts."""

    _module_mapping = {
        # Transformer blocks and the top-level transformer model
        FluxSingleTransformerBlock: QEffFluxSingleTransformerBlock,
        FluxTransformerBlock: QEffFluxTransformerBlock,
        FluxTransformer2DModel: QEffFluxTransformer2DModel,
        # Attention layer and its processor
        FluxAttention: QEffFluxAttention,
        FluxAttnProcessor: QEffFluxAttnProcessor,
    }


class NormalizationTransform(ModuleMappingTransform):
    """Maps the diffusers adaptive layer-norm classes to their QEff counterparts."""

    _module_mapping = {
        # Each AdaLayerNorm variant is paired with the QEff class of the same name.
        AdaLayerNormZero: QEffAdaLayerNormZero,
        AdaLayerNormZeroSingle: QEffAdaLayerNormZeroSingle,
        AdaLayerNormContinuous: QEffAdaLayerNormContinuous,
    }
6 changes: 6 additions & 0 deletions QEfficient/diffusers/models/transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------
Loading