
Enables the per_tensor lowering patterns for weight pre-packing #2391


Open
wants to merge 1 commit into main

Conversation

@choudhary-devang (Collaborator) commented Jun 17, 2025

This PR is an extension of PR #2139.

Major changes:
1. Introduced a lowering pattern for "per_tensor"-quantized weights.
2. Modified the existing API get_default_arm_inductor_quantization_config so that users can choose between "per_tensor" and "per_channel" granularity for weight quantization (a minimal usage sketch follows below).
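A minimal sketch of how the two granularities can be selected through the modified helper (assuming this PR's branch is installed; the is_per_channel keyword follows its usage in the example script further below):

import torchao.quantization.pt2e.quantizer.arm_inductor_quantizer as aiq
from torchao.quantization.pt2e.quantizer.arm_inductor_quantizer import ArmInductorQuantizer

# per_channel weight quantization (existing behaviour)
per_channel_config = aiq.get_default_arm_inductor_quantization_config(is_per_channel=True)

# per_tensor weight quantization (the new lowering path added by this PR)
per_tensor_config = aiq.get_default_arm_inductor_quantization_config(is_per_channel=False)

quantizer = ArmInductorQuantizer()
quantizer.set_global(per_tensor_config)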

Supported shapes:

  1. s8:s8:f32 (per_tensor / per_channel) - input: s8, weight: s8, output: f32
  2. u8:s8:f32 (per_tensor / per_channel) - input: u8, weight: s8, output: f32

Tested and verified on different models:

  • BERT model
  • ResNet model
  • ViT model
  • Custom models

Example script for reference:

import torch
from transformers import BertModel
import copy
import time
from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, convert_pt2e
import torchao.quantization.pt2e.quantizer.arm_inductor_quantizer as aiq
from torchao.quantization.pt2e.quantizer.arm_inductor_quantizer import ArmInductorQuantizer
import torch.profiler
import torch._inductor.config as config
# Enable C++ wrapper for Inductor
config.cpp_wrapper = True
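# Freezing lets Inductor constant-fold the graph so quantized weights can be pre-packed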
config.freezing = True

model_name = "bert-base-uncased"
model = BertModel.from_pretrained(model_name)

# Set the model to eval mode
model = model.eval()

# Create the data, using dummy data here as an example
traced_bs = 32
seq_length = 128
x = torch.randint(0, 10000, (traced_bs, seq_length))
attention_mask = torch.ones((traced_bs, seq_length))
example_inputs = (x, attention_mask)

# Capture the FX Graph to be quantized
with torch.no_grad():
    exported_model = torch.export.export_for_training(model, example_inputs).module()

    # Set up the quantizer and prepare the model for post-training quantization
    quantizer = ArmInductorQuantizer()
    quantizer.set_global(aiq.get_default_arm_inductor_quantization_config(is_dynamic=True, is_per_channel=True))
    prepared_model = prepare_pt2e(exported_model, quantizer)

    # Run the prepared model to apply the quantization
    prepared_model(*example_inputs)

    # Convert the model to the quantized version
    converted_model = convert_pt2e(prepared_model)
    optimized_model = torch.compile(converted_model)
    # Warm-up call: the first invocation triggers Inductor compilation
    optimized_model(*example_inputs)
    st = time.time()
    optimized_model(*example_inputs)
    et = time.time()
    print(f"Inference time = {et - st:.4f} s\n")

cc: @jerryzh168, @fadara01, @Xia-Weiwen


pytorch-bot bot commented Jun 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2391

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit c698531 with merge base e4f2715:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jun 17, 2025