Add a feature of UniformQDQ - support CV/NLP models' ops, including Conv, DepthwiseConv2D, MatMul, etc. Additional op support to be added upon request. #2155
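
For readers unfamiliar with the term, uniform QDQ (quantize-dequantize) maps float values onto an evenly spaced integer grid and back again. The following is a minimal pure-Python sketch of asymmetric uint8 QDQ, not the implementation in this PR. The function names (`qdq_params`, `quantize`, `dequantize`, `qdq`) and the choice to widen the range so that 0.0 is exactly representable are assumptions made for illustration.

```python
def qdq_params(x_min, x_max, qmin=0, qmax=255):
    """Return (scale, zero_point) for asymmetric uniform quantization."""
    # Assumption: widen the range to include 0.0 so zero maps to an exact integer.
    x_min = min(x_min, 0.0)
    x_max = max(x_max, 0.0)
    if x_max == x_min:
        # Degenerate all-zero range: any positive scale works.
        return 1.0, qmin
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = round(qmin - x_min / scale)
    # Keep the zero point inside the integer range.
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point


def quantize(x, scale, zero_point, qmin=0, qmax=255):
    """Map a float to the nearest grid integer, clamped to [qmin, qmax]."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize(q, scale, zero_point):
    """Map a grid integer back to its float value."""
    return (q - zero_point) * scale


def qdq(values, qmin=0, qmax=255):
    """Quantize then dequantize a list of floats using its own min/max."""
    scale, zp = qdq_params(min(values), max(values), qmin, qmax)
    return [dequantize(quantize(v, scale, zp, qmin, qmax), scale, zp) for v in values]
```

With this scheme each in-range value comes back within half a quantization step of the original, and values outside the calibrated range saturate at `qmin` or `qmax`.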

Open: wants to merge 30 commits into base: master

Commits (30):
4c9ff61
Enable Uniform QDQ for Keras Models
zehao-intel Mar 20, 2024
bfd0524
fix bugs
zehao-intel Mar 20, 2024
f5e6726
fix import
zehao-intel Mar 20, 2024
70cb8d3
support saved_model out
zehao-intel Apr 19, 2024
5c03b79
fix import
zehao-intel Apr 21, 2024
178449d
Merge branch 'master' of https://github.com/intel/neural-compressor
zehao-intel Apr 21, 2024
ed31916
Merge branch 'master' into zehao/uniform_qdq
zehao-intel Apr 21, 2024
b1ca538
fix issues
zehao-intel Apr 22, 2024
9dd4beb
add hf resnet50 example
zehao-intel Apr 22, 2024
2a6d162
fix uint8 max
zehao-intel Apr 22, 2024
48013dc
fix quint range for sequential or functional keras model
zehao-intel Apr 22, 2024
1959471
modify bert example
zehao-intel Apr 22, 2024
d5223eb
refine zp calculation for uint8
zehao-intel Jun 24, 2024
1db86f1
fix resnet50
zehao-intel Jul 3, 2024
2427209
fix getting value dict for weight min max
zehao-intel Jul 4, 2024
15d5c28
fix zp and scale factor
zehao-intel Jul 9, 2024
506dfa5
Update for generating QdQ
qgao007 Mar 10, 2025
58e7afa
Merge remote-tracking branch 'origin/master' into qg/refactor
qgao007 Mar 17, 2025
d564b77
update name changes
qgao007 Mar 17, 2025
ff5cb06
Add raise error for Unsupported op type for per-channel quantization
qgao007 Mar 20, 2025
7f47d66
add ssd_mobilenet
qgao007 Mar 21, 2025
6564fa1
remove debugging print
qgao007 Mar 26, 2025
2660fa3
add support for ssd_mobile
qgao007 Mar 26, 2025
3bd7ba9
clean debug
qgao007 Mar 27, 2025
250a03c
Merge remote-tracking branch 'origin/master' into qg/refactor before PR
qgao007 Mar 27, 2025
17b9e16
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 27, 2025
af3910a
Update to address Leon's feedback
qgao007 Apr 7, 2025
e25cf25
Merge remote-tracking branch 'origin/master' into qg/refactor
qgao007 Apr 7, 2025
9a7b86f
Add missing pydantic for UT pass
qgao007 Apr 18, 2025
dec25bd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 18, 2025
945 changes: 945 additions & 0 deletions examples/keras/image_recognition/hf_bert/main.py

Large diffs are not rendered by default.

76 changes: 76 additions & 0 deletions examples/keras/image_recognition/hf_resnet50/README.md
@@ -0,0 +1,76 @@
Step-by-Step
============

This document describes how to quantize TensorFlow Keras models with Intel® Neural Compressor.
This example can run on Intel CPUs and GPUs.


# Prerequisite

## 1. Environment

### Installation
```shell
# Install Intel® Neural Compressor
pip install neural-compressor
```

### Install Requirements
TensorFlow and intel-extension-for-tensorflow must be installed to run this example.
By default, the Intel Extension for TensorFlow build for Intel CPUs is installed.
```shell
pip install -r requirements.txt
```
> Note: Validated TensorFlow [Version](/docs/source/installation_guide.md#validated-software-environment).

## 2. Prepare Pretrained model

The pretrained model is provided by [Keras Applications](https://keras.io/api/applications/). To prepare the model, run the following command:
```
python prepare_model.py --output_model=/path/to/model
```
`--output_model`: the model should be saved in SavedModel or H5 format.

## 3. Prepare Dataset

TensorFlow [models](https://github.com/tensorflow/models) repo provides [scripts and instructions](https://github.com/tensorflow/models/tree/master/research/slim#an-automated-script-for-processing-imagenet-data) to download, process and convert the ImageNet dataset to the TF records format.
We also provide related scripts in the `imagenet_prepare` directory. To download the raw images, you must create an account with image-net.org. If you have already downloaded the raw data and preprocessed the validation data by moving the images into the appropriate sub-directory based on the label (synset) of each image, you can use the commands below to convert it to the TF records format.

```shell
cd examples/keras/image_recognition/
# convert validation subset
bash prepare_dataset.sh --output_dir=/resnetv2_50/quantization/ptq/data --raw_dir=/PATH/TO/img_raw/val/ --subset=validation
# convert train subset
bash prepare_dataset.sh --output_dir=/resnetv2_50/quantization/ptq/data --raw_dir=/PATH/TO/img_raw/train/ --subset=train
cd resnetv2_50/quantization/ptq
```
> **Note**:
> The raw ImageNet dataset, stored as JPEG files, should be in the following directory structure. Taking the validation set as an example:<br>
> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;/PATH/TO/img_raw/val/n01440764/ILSVRC2012_val_00000293.JPEG<br>
> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;/PATH/TO/img_raw/val/n01440764/ILSVRC2012_val_00000543.JPEG<br>
> where 'n01440764' is the unique synset label associated with these images.
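
Before running the conversion script, it can help to confirm that the raw directory matches the layout described in the note above. The sketch below is a hypothetical helper, not part of this example's scripts: the name `summarize_imagenet_dir` and the assumption that synset labels are `n` followed by eight digits are mine. It counts the JPEG files found under each synset sub-directory.

```python
import re
from pathlib import Path

# Assumption: synset labels look like 'n01440764' (the letter n plus 8 digits).
SYNSET_RE = re.compile(r"^n\d{8}$")


def summarize_imagenet_dir(raw_dir):
    """Map each synset sub-directory of raw_dir to its number of JPEG files."""
    counts = {}
    for entry in sorted(Path(raw_dir).iterdir()):
        # Skip plain files and directories that are not named like a synset.
        if entry.is_dir() and SYNSET_RE.match(entry.name):
            counts[entry.name] = sum(
                1 for f in entry.iterdir() if f.suffix.lower() in {".jpeg", ".jpg"}
            )
    return counts
```

For example, `summarize_imagenet_dir("/PATH/TO/img_raw/val/")` returns a dictionary keyed by synset label. An empty dictionary suggests the images have not yet been moved into per-synset sub-directories.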

# Run Command

## Quantization Config
The quantization config class defaults to settings for running on Intel CPUs. To run this example on Intel GPUs, set the 'backend' parameter to 'itex' and the 'device' parameter to 'gpu'.

```
config = PostTrainingQuantConfig(
    device="gpu",
    backend="itex",
    ...
)
```

## Quantization
```shell
bash run_quant.sh --input_model=./resnetv2_50_keras/ --output_model=./result --dataset_location=/path/to/evaluation/dataset
```

## Benchmark
```shell
bash run_benchmark.sh --input_model=./result --mode=accuracy --dataset_location=/path/to/evaluation/dataset --batch_size=32
bash run_benchmark.sh --input_model=./result --mode=performance --dataset_location=/path/to/evaluation/dataset --batch_size=1
```
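
In performance mode, `main.py` in this example times each batch, discards the first five iterations as warm-up, divides the mean batch time by the batch size to get per-image latency, and reports throughput as the reciprocal of that latency. The following standalone sketch reproduces that arithmetic in plain Python. The function name `summarize_latency` and the explicit error on too few samples are illustrative assumptions, not part of the example.

```python
from statistics import mean


def summarize_latency(batch_seconds, batch_size, warmup=5):
    """Return (latency_per_image_seconds, throughput_images_per_second)."""
    # Drop the warm-up iterations, which include one-time setup costs.
    steady = batch_seconds[warmup:]
    if not steady:
        raise ValueError("need more than `warmup` timed batches")
    # Mean time per batch divided by batch size gives time per image.
    latency = mean(steady) / batch_size
    return latency, 1.0 / latency
```

Because throughput is derived from per-image latency, a run with fewer timed batches than the warm-up count has no steady-state samples, which is why the sketch raises instead of returning a value.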

171 changes: 171 additions & 0 deletions examples/keras/image_recognition/hf_resnet50/main.py
@@ -0,0 +1,171 @@
#
# -*- coding: utf-8 -*-
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import numpy as np
import tensorflow as tf
from neural_compressor.utils import logger
# tf.config.optimizer.set_experimental_options({'remapping': False})
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

flags = tf.compat.v1.flags
FLAGS = flags.FLAGS

## Required parameters
flags.DEFINE_string(
'input_model', None, 'Run inference with specified keras model.')

flags.DEFINE_string(
'output_model', None, 'The output quantized model.')

flags.DEFINE_string(
'mode', 'performance', 'define benchmark mode for accuracy or performance')

flags.DEFINE_bool(
'tune', False, 'whether to tune the model')

flags.DEFINE_bool(
'benchmark', False, 'whether to benchmark the model')

flags.DEFINE_string(
'calib_data', None, 'location of calibration dataset')

flags.DEFINE_string(
'eval_data', None, 'location of evaluate dataset')

flags.DEFINE_integer('batch_size', 32, 'batch_size')

flags.DEFINE_integer(
'iters', 100, 'maximum iteration when evaluating performance')

from neural_compressor import Metric
from neural_compressor.data.transforms.transform import ComposeTransform
from neural_compressor.data.datasets.dataset import TensorflowImageRecord
from neural_compressor.data.transforms.imagenet_transform import LabelShift
from neural_compressor.data.dataloaders.tensorflow_dataloader import TensorflowDataLoader
from neural_compressor.data.transforms.imagenet_transform import BilinearImagenetTransform

height = width = 224
eval_dataset = TensorflowImageRecord(root=FLAGS.eval_data, transform=ComposeTransform(transform_list= \
    [BilinearImagenetTransform(height=height, width=width)]))

eval_dataloader = TensorflowDataLoader(dataset=eval_dataset, batch_size=FLAGS.batch_size)

if FLAGS.calib_data:
    calib_dataset = TensorflowImageRecord(root=FLAGS.calib_data, transform= \
        ComposeTransform(transform_list= [BilinearImagenetTransform(height=height, width=width)]))
    calib_dataloader = TensorflowDataLoader(dataset=calib_dataset, batch_size=10)

def evaluate(model):
    """
    Custom evaluate function to inference the model for specified metric on validation dataset.

    Args:
        model (tf.keras.Model): The input model will be the object of tf.keras.Model.

    Returns:
        accuracy (float): evaluation result, the larger is better.
    """
    infer = model.signatures["serving_default"]
    # print ("infer.inputs: {}".format(infer.inputs))
    output_dict_keys = infer.structured_outputs.keys()
    output_name = list(output_dict_keys)[0]
    postprocess = LabelShift(label_shift=1)
    from neural_compressor import METRICS
    metrics = METRICS('tensorflow')
    metric = metrics['topk']()
    latency_list = []

    def eval_func(dataloader, metric):
        warmup = 5
        iteration = None
        latency_list = []
        if FLAGS.benchmark and FLAGS.mode == 'performance':
            iteration = FLAGS.iters
        predict_fun = tf.function(infer, jit_compile=False)
        for idx, (inputs, labels) in enumerate(dataloader):
            inputs = np.array(inputs)
            input_tensor = tf.constant(inputs, dtype=tf.float32)
            input_tensor = tf.transpose(input_tensor, perm=[0, 3, 1, 2])
            start = time.time()
            predictions = predict_fun(input_tensor)[output_name]
            end = time.time()
            predictions, labels = postprocess((predictions, labels))
            predictions = predictions.numpy()
            metric.update(predictions, labels)
            latency_list.append(end - start)
            if iteration and idx >= iteration:
                break
        latency = np.array(latency_list[warmup:]).mean() / eval_dataloader.batch_size
        return latency

    latency = eval_func(eval_dataloader, metric)
    if FLAGS.benchmark:
        logger.info("\n{} mode benchmark result:".format(FLAGS.mode))
        for i, res in enumerate(latency_list):
            logger.debug("Iteration {} result {}:".format(i, res))
    if FLAGS.benchmark and FLAGS.mode == 'performance':
        logger.info("Batch size = {}".format(eval_dataloader.batch_size))
        logger.info("Latency: {:.3f} ms".format(latency * 1000))
        logger.info("Throughput: {:.3f} images/sec".format(1. / latency))
    acc = metric.result()
    return acc

def main(_):
    if FLAGS.tune:
        from neural_compressor.quantization import fit
        from neural_compressor.config import PostTrainingQuantConfig, AccuracyCriterion
        from neural_compressor import set_random_seed
        set_random_seed(9527)
        excluded_op_type = {
            'matmul': {
                'weight':{
                    'dtype':['fp32']
                },
                'activation':{
                    'dtype':['fp32']
                }
            }
        }
        config = PostTrainingQuantConfig(backend='itex',
            calibration_sampling_size=[50, 100],
            accuracy_criterion = AccuracyCriterion(tolerable_loss=0.9999),)
            #op_type_dict=excluded_op_type,)
        q_model = fit(
            model=FLAGS.input_model,
            conf=config,
            calib_func=evaluate,
            eval_func=evaluate)
        q_model.save(FLAGS.output_model)

    if FLAGS.benchmark:
        from neural_compressor.benchmark import fit
        from neural_compressor.config import BenchmarkConfig
        if FLAGS.mode == 'performance':
            conf = BenchmarkConfig(backend='itex', cores_per_instance=4, num_of_instance=1)
            fit(FLAGS.input_model, conf, b_func=evaluate)
        else:
            # from neural_compressor.model import Model
            # model = Model(FLAGS.input_model).model
            from tensorflow.python.saved_model import load
            model = load.load(FLAGS.input_model)
            accuracy = evaluate(model)
            logger.info('Batch size = %d' % FLAGS.batch_size)
            logger.info("Accuracy: %.5f" % accuracy)

if __name__ == "__main__":
    tf.compat.v1.app.run()
7 changes: 7 additions & 0 deletions examples/keras/image_recognition/hf_resnet50/prepare_model.py
@@ -0,0 +1,7 @@
import tensorflow as tf
from transformers import TFResNetForImageClassification

# Download Resnet50 from HuggingFace and save it as saved model
# It will be saved at resnet50-saved-model/saved_model/1
model = TFResNetForImageClassification.from_pretrained("microsoft/resnet-50")
model.save_pretrained('resnet50-saved-model', saved_model=True)
2 changes: 2 additions & 0 deletions examples/keras/image_recognition/hf_resnet50/requirements.txt
@@ -0,0 +1,2 @@
tensorflow>=2.11.1
intel-extension-for-tensorflow[cpu]
50 changes: 50 additions & 0 deletions examples/keras/image_recognition/hf_resnet50/run_benchmark.sh
@@ -0,0 +1,50 @@
#!/bin/bash
set -x

function main {

  init_params "$@"
  run_benchmark

}

# init params
function init_params {
  batch_size=32
  iters=100

  for var in "$@"
  do
    case $var in
      --input_model=*)
          input_model=$(echo $var |cut -f2 -d=)
      ;;
      --mode=*)
          mode=$(echo $var |cut -f2 -d=)
      ;;
      --dataset_location=*)
          dataset_location=$(echo $var |cut -f2 -d=)
      ;;
      --batch_size=*)
          batch_size=$(echo $var |cut -f2 -d=)
      ;;
      --iters=*)
          iters=$(echo $var |cut -f2 -d=)
      ;;
    esac
  done

}

# run_benchmark
function run_benchmark {

    python main.py \
            --input_model ${input_model} \
            --benchmark \
            --mode ${mode} \
            --eval_data ${dataset_location} \
            --batch_size ${batch_size} \
            --iters ${iters}
}

main "$@"
40 changes: 40 additions & 0 deletions examples/keras/image_recognition/hf_resnet50/run_quant.sh
@@ -0,0 +1,40 @@
#!/bin/bash
set -x

function main {
  init_params "$@"
  run_tuning

}

# init params
function init_params {

  for var in "$@"
  do
    case $var in
      --input_model=*)
          input_model=$(echo $var |cut -f2 -d=)
      ;;
      --output_model=*)
          output_model=$(echo $var |cut -f2 -d=)
      ;;
      --dataset_location=*)
          dataset_location=$(echo $var |cut -f2 -d=)
      ;;
    esac
  done

}

# run_tuning
function run_tuning {
    python main.py \
            --input_model ${input_model} \
            --output_model ${output_model} \
            --eval_data ${dataset_location} \
            --calib_data ${dataset_location} \
            --tune
}

main "$@"
@@ -377,7 +377,7 @@ def _process_image_files(name, filenames, synsets, labels, humans, num_shards):
   assert len(filenames) == len(humans)

   # Break all images into batches with a [ranges[i][0], ranges[i][1]].
-  spacing = np.linspace(0, len(filenames), FLAGS.num_threads + 1).astype(np.int)
+  spacing = np.linspace(0, len(filenames), FLAGS.num_threads + 1).astype(np.int32)
   ranges = []
   threads = []
   for i in xrange(len(spacing) - 1):
@@ -118,10 +118,13 @@ def main(_):
     if FLAGS.tune:
         from neural_compressor.quantization import fit
         from neural_compressor.config import PostTrainingQuantConfig
+        from neural_compressor.config import AccuracyCriterion
         from neural_compressor import set_random_seed
         set_random_seed(9527)
+        accuracy_criterion = AccuracyCriterion(criterion='absolute')
+
         config = PostTrainingQuantConfig(backend='itex',
-            calibration_sampling_size=[50, 100])
+            calibration_sampling_size=[50, 100], accuracy_criterion=accuracy_criterion)
         q_model = fit(
             model=FLAGS.input_model,
             conf=config,
@@ -124,9 +124,12 @@ def main(_):
     if FLAGS.tune:
         from neural_compressor.quantization import fit
         from neural_compressor.config import PostTrainingQuantConfig
+        from neural_compressor.config import AccuracyCriterion
         from neural_compressor import set_random_seed
         set_random_seed(9524)
+        accuracy_criterion = AccuracyCriterion(criterion='absolute')
         config = PostTrainingQuantConfig(backend='itex',
+            accuracy_criterion=accuracy_criterion,
             calibration_sampling_size=[10, 15])
         q_model = fit(
             model=FLAGS.input_model,
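
The hunks above switch existing examples to `AccuracyCriterion(criterion='absolute')`, while the new `hf_resnet50/main.py` passes `tolerable_loss=0.9999` without naming a criterion. As a rough guide to what that difference means, here is a pure-Python sketch of how a relative and an absolute criterion are commonly evaluated for a higher-is-better metric. It is my reading of the concept, not Neural Compressor's code: the function name `meets_accuracy_goal` and its default values are assumptions for illustration.

```python
def meets_accuracy_goal(baseline, candidate, criterion="relative", tolerable_loss=0.01):
    """Return True if `candidate` accuracy is acceptable (higher is better)."""
    if criterion == "relative":
        # Loss is measured as a fraction of the baseline accuracy.
        threshold = baseline * (1.0 - tolerable_loss)
    elif criterion == "absolute":
        # Loss is measured in raw accuracy points.
        threshold = baseline - tolerable_loss
    else:
        raise ValueError(f"unknown criterion: {criterion!r}")
    return candidate >= threshold
```

Under this reading, a relative tolerance of 0.9999 accepts almost any quantized model, so reviewers may want to confirm that such a loose setting is intended for the ResNet-50 example.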