
[Bug]: torch.OutOfMemoryError: CUDA out of memory for SQ+GPTQ on Llama 70B with basic pipeline on 4xH100 #1858


Description

@nsantavas

⚙️ Your current environment

The output of `python collect_env.py`
### Environment Information ###
Operating System: `Linux-5.15.0-153-generic-x86_64-with-glibc2.35`
Python Version: `3.10.18 | packaged by conda-forge | (main, Jun  4 2025, 14:45:41) [GCC 13.3.0]`
llm-compressor Version: `0.7.1`
compressed-tensors Version: `0.11.0`
transformers Version: `4.55.2`
torch Version: `2.8.0`
CUDA Devices: `['NVIDIA H100 80GB HBM3', 'NVIDIA H100 80GB HBM3', 'NVIDIA H100 80GB HBM3', 'NVIDIA H100 80GB HBM3', 'NVIDIA H100 80GB HBM3', 'NVIDIA H100 80GB HBM3', 'NVIDIA H100 80GB HBM3', 'NVIDIA H100 80GB HBM3']`
AMD Devices: `None`

🐛 Describe the bug

I am trying to quantize Llama 70B with SQ+GPTQ on 4xH100s, but I get an OOM error once SQ starts. The same happens even if I disable SQ.

2025-09-23T09:11:39.662418+0000 | reset | INFO - Compression lifecycle reset
2025-09-23T09:11:39.667505+0000 | _create_default_logger | INFO - Logging all LLM Compressor modifier-level logs to sparse_logs/23-09-2025_09.11.39.log
2025-09-23T09:11:39.671668+0000 | from_modifiers | INFO - Creating recipe from modifiers
2025-09-23T09:11:39.792732+0000 | initialize | INFO - Compression lifecycle initialized for 1 modifiers
2025-09-23T09:11:39.792828+0000 | from_modifiers | WARNING - Calibration pipeline is set to `basic`, but it is recommended to use `sequential`
Calibrating:   0%|                                      | 0/256 [00:00<?, ?it/s]
...
pressor/modifiers/quantization/gptq/base.py", line 232, in calibrate_module
    self._hessians[module] = make_empty_hessian(module, device=init_device)
  File "/home/nsantavas/miniforge3/envs/test/lib/python3.10/site-packages/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py", line 30, in make_empty_hessian
    return torch.zeros((num_columns, num_columns), device=device, dtype=GPTQ_PRECISION)
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.06 GiB. GPU 0 has a total capacity of 79.21 GiB of which 992.75 MiB is free. Including non-PyTorch memory, this process has 78.23 GiB memory in use. Of the allocated memory 70.31 GiB is allocated by PyTorch, and 7.26 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
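
For what it's worth, the failing 3.06 GiB allocation is consistent with a single float32 Hessian for a `down_proj` layer, assuming Llama 3.1 70B's dimensions (hidden_size 8192, intermediate_size 28672) and `GPTQ_PRECISION` being float32. A rough back-of-the-envelope check (my own numbers, not taken from the library):

```python
# Back-of-the-envelope check of the failing allocation, assuming Llama 3.1 70B
# dimensions (hidden_size=8192, intermediate_size=28672) and float32 Hessians.
hidden, inter = 8192, 28672
bytes_per_el = 4  # float32


def hessian_gib(num_columns: int) -> float:
    """Size of a (num_columns, num_columns) float32 Hessian in GiB."""
    return num_columns * num_columns * bytes_per_el / 2**30


print(f"attention/gate/up proj Hessian: {hessian_gib(hidden):.2f} GiB")  # ~0.25 GiB
print(f"down_proj Hessian:              {hessian_gib(inter):.2f} GiB")   # ~3.06 GiB, matches the OOM

# Six projections with 8192 input columns plus down_proj per decoder layer;
# if the basic pipeline keeps all of them resident, 80 layers add up quickly.
per_layer = 6 * hessian_gib(hidden) + hessian_gib(inter)
print(f"Hessians per decoder layer:     {per_layer:.2f} GiB (x80 layers)")
```

If that reading is right, holding every Hessian at once would need far more memory than the GPUs have left over after the bf16 weights, which would explain why only a layer-at-a-time (sequential) pass stays within budget.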

Interestingly enough, this is not happening with the following setup:

### Environment Information ###
Operating System: `Linux-5.15.0-153-generic-x86_64-with-glibc2.35`
Python Version: `3.10.18 | packaged by conda-forge | (main, Jun  4 2025, 14:45:41) [GCC 13.3.0]`
llm-compressor Version: `0.5.1`
compressed-tensors Version: `0.9.4`
transformers Version: `4.51.1`
torch Version: `2.7.0`
CUDA Devices: `['NVIDIA H100 80GB HBM3', 'NVIDIA H100 80GB HBM3', 'NVIDIA H100 80GB HBM3', 'NVIDIA H100 80GB HBM3', 'NVIDIA H100 80GB HBM3', 'NVIDIA H100 80GB HBM3', 'NVIDIA H100 80GB HBM3', 'NVIDIA H100 80GB HBM3']`
AMD Devices: `None`

Any ideas why this is happening? Ideally I'd like to avoid using the sequential pipeline.
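
One workaround I'm considering (untested sketch; the 60GiB cap is only an illustrative number) is to leave per-GPU headroom when sharding the model, so the calibration buffers don't compete with the weights on GPU 0:

```python
# Sketch: cap how much of each GPU the weights may occupy so that calibration
# state (Hessians, activations) has headroom. Values are illustrative only.
import torch
from transformers import AutoModelForCausalLM

max_memory = {i: "60GiB" for i in range(torch.cuda.device_count())}

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-70B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    max_memory=max_memory,  # leave ~20GiB per H100 for calibration state
)
```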

🛠️ Steps to reproduce

```python
import torch
from datasets import load_dataset
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.cuda.empty_cache()

print("Available GPUs:", torch.cuda.device_count())
for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")


MODEL_ID = "meta-llama/Llama-3.1-70B-Instruct"
# Select calibration dataset.
DATASET_ID = "neuralmagic/calibration"
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 2048


tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, "LLM")
ds = ds.shuffle(seed=42)["train"].select(range(NUM_CALIBRATION_SAMPLES))


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto"
)


recipe = [
    SmoothQuantModifier(smoothing_strength=0.8),
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]

oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    # pipeline="basic",  # uncomment on llm-compressor versions > 0.5.1
)
```
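
For completeness, the fallback I'm trying to avoid, plus the allocator flag from the error message, would look roughly like this (sketch, not yet validated on my side):

```python
# Fallback sketch: sequential pipeline (processes one layer at a time) plus the
# allocator setting suggested by the error message. The env var must be set
# before CUDA is initialized, ideally in the shell that launches the script.
import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    pipeline="sequential",  # the value the llm-compressor warning recommends
)
```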
