Skip to content

byebye torch 2.0 #37277

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Apr 7, 2025
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Explore the [Hub](https://huggingface.com/) today to find a model and use Transf

## Installation

Transformers works with Python 3.9+ [PyTorch](https://pytorch.org/get-started/locally/) 2.0+, [TensorFlow](https://www.tensorflow.org/install/pip) 2.6+, and [Flax](https://flax.readthedocs.io/en/latest/) 0.4.1+.
Transformers works with Python 3.9+ [PyTorch](https://pytorch.org/get-started/locally/) 2.1+, [TensorFlow](https://www.tensorflow.org/install/pip) 2.6+, and [Flax](https://flax.readthedocs.io/en/latest/) 0.4.1+.

Create and activate a virtual environment with [venv](https://docs.python.org/3/library/venv.html) or [uv](https://docs.astral.sh/uv/), a fast Rust-based Python package and project manager.

Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ rendered properly in your Markdown viewer.

# Installation

Transformers works with [PyTorch](https://pytorch.org/get-started/locally/), [TensorFlow 2.0](https://www.tensorflow.org/install/pip), and [Flax](https://flax.readthedocs.io/en/latest/). It has been tested on Python 3.9+, PyTorch 2.0+, TensorFlow 2.6+, and Flax 0.4.1+.
Transformers works with [PyTorch](https://pytorch.org/get-started/locally/), [TensorFlow 2.0](https://www.tensorflow.org/install/pip), and [Flax](https://flax.readthedocs.io/en/latest/). It has been tested on Python 3.9+, PyTorch 2.1+, TensorFlow 2.6+, and Flax 0.4.1+.

## Virtual environment

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_ar.md
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ limitations under the License.

### باستخدام pip

تم اختبار هذا المستودع على Python 3.9+، Flax 0.4.1+، PyTorch 2.0+، و TensorFlow 2.6+.
تم اختبار هذا المستودع على Python 3.9+، Flax 0.4.1+، PyTorch 2.1+، و TensorFlow 2.6+.

يجب تثبيت 🤗 Transformers في [بيئة افتراضية](https://docs.python.org/3/library/venv.html). إذا كنت غير معتاد على البيئات الافتراضية Python، فراجع [دليل المستخدم](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_de.md
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ Das Modell selbst ist ein reguläres [PyTorch `nn.Module`](https://pytorch.org/d

### Mit pip

Dieses Repository wurde mit Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ und TensorFlow 2.6+ getestet.
Dieses Repository wurde mit Python 3.9+, Flax 0.4.1+, PyTorch 2.1+ und TensorFlow 2.6+ getestet.

Sie sollten 🤗 Transformers in einer [virtuellen Umgebung](https://docs.python.org/3/library/venv.html) installieren. Wenn Sie mit virtuellen Python-Umgebungen nicht vertraut sind, schauen Sie sich den [Benutzerleitfaden](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) an.

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_es.md
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ El modelo en si es un [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.h

### Con pip

Este repositorio está probado en Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ y TensorFlow 2.6+.
Este repositorio está probado en Python 3.9+, Flax 0.4.1+, PyTorch 2.1+ y TensorFlow 2.6+.

Deberías instalar 🤗 Transformers en un [entorno virtual](https://docs.python.org/3/library/venv.html). Si no estas familiarizado con los entornos virtuales de Python, consulta la [guía de usuario](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_fr.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ Le modèle lui-même est un module [`nn.Module` PyTorch](https://pytorch.org/doc

### Avec pip

Ce référentiel est testé sur Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ et TensorFlow 2.6+.
Ce référentiel est testé sur Python 3.9+, Flax 0.4.1+, PyTorch 2.1+ et TensorFlow 2.6+.

Vous devriez installer 🤗 Transformers dans un [environnement virtuel](https://docs.python.org/3/library/venv.html). Si vous n'êtes pas familier avec les environnements virtuels Python, consultez le [guide utilisateur](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_hd.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ checkpoint: जाँच बिंदु

### पिप का उपयोग करना

इस रिपॉजिटरी का परीक्षण Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ और TensorFlow 2.6+ के तहत किया गया है।
इस रिपॉजिटरी का परीक्षण Python 3.9+, Flax 0.4.1+, PyTorch 2.1+ और TensorFlow 2.6+ के तहत किया गया है।

आप [वर्चुअल एनवायरनमेंट](https://docs.python.org/3/library/venv.html) में 🤗 ट्रांसफॉर्मर इंस्टॉल कर सकते हैं। यदि आप अभी तक पायथन के वर्चुअल एनवायरनमेंट से परिचित नहीं हैं, तो कृपया इसे [उपयोगकर्ता निर्देश](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) पढ़ें।

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_ja.md
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ Hugging Faceチームによって作られた **[トランスフォーマーを

### pipにて

このリポジトリは、Python 3.9+, Flax 0.4.1+, PyTorch 2.0+, TensorFlow 2.6+ でテストされています。
このリポジトリは、Python 3.9+, Flax 0.4.1+, PyTorch 2.1+, TensorFlow 2.6+ でテストされています。

🤗Transformersは[仮想環境](https://docs.python.org/3/library/venv.html)にインストールする必要があります。Pythonの仮想環境に慣れていない場合は、[ユーザーガイド](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)を確認してください。

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_ko.md
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ Transformers에 달린 100,000개의 별을 축하하기 위해, 우리는 커

### pip로 설치하기

이 저장소는 Python 3.9+, Flax 0.4.1+, PyTorch 2.0+, TensorFlow 2.6+에서 테스트 되었습니다.
이 저장소는 Python 3.9+, Flax 0.4.1+, PyTorch 2.1+, TensorFlow 2.6+에서 테스트 되었습니다.

[가상 환경](https://docs.python.org/3/library/venv.html)에 🤗 Transformers를 설치하세요. Python 가상 환경에 익숙하지 않다면, [사용자 가이드](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)를 확인하세요.

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_pt-br.md
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ O modelo em si é um [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.ht

### Com pip

Este repositório é testado no Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ e TensorFlow 2.6+.
Este repositório é testado no Python 3.9+, Flax 0.4.1+, PyTorch 2.1+ e TensorFlow 2.6+.

Você deve instalar o 🤗 Transformers em um [ambiente virtual](https://docs.python.org/3/library/venv.html). Se você não está familiarizado com ambientes virtuais em Python, confira o [guia do usuário](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_ru.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ Hugging Face Hub. Мы хотим, чтобы Transformers позволил ра

### С помощью pip

Данный репозиторий протестирован на Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ и TensorFlow 2.6+.
Данный репозиторий протестирован на Python 3.9+, Flax 0.4.1+, PyTorch 2.1+ и TensorFlow 2.6+.

Устанавливать 🤗 Transformers следует в [виртуальной среде](https://docs.python.org/3/library/venv.html). Если вы не знакомы с виртуальными средами Python, ознакомьтесь с [руководством пользователя](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_te.md
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ limitations under the License.

### పిప్ తో

ఈ రిపోజిటరీ పైథాన్ 3.9+, ఫ్లాక్స్ 0.4.1+, PyTorch 2.0+ మరియు TensorFlow 2.6+లో పరీక్షించబడింది.
ఈ రిపోజిటరీ పైథాన్ 3.9+, ఫ్లాక్స్ 0.4.1+, PyTorch 2.1+ మరియు TensorFlow 2.6+లో పరీక్షించబడింది.

మీరు [వర్చువల్ వాతావరణం](https://docs.python.org/3/library/venv.html)లో 🤗 ట్రాన్స్‌ఫార్మర్‌లను ఇన్‌స్టాల్ చేయాలి. మీకు పైథాన్ వర్చువల్ పరిసరాల గురించి తెలియకుంటే, [యూజర్ గైడ్](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) చూడండి.

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_ur.md
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ limitations under the License.

#### ‏ pip کے ساتھ

یہ ریپوزٹری Python 3.9+، Flax 0.4.1+، PyTorch 2.0+، اور TensorFlow 2.6+ پر ٹیسٹ کی گئی ہے۔
یہ ریپوزٹری Python 3.9+، Flax 0.4.1+، PyTorch 2.1+، اور TensorFlow 2.6+ پر ٹیسٹ کی گئی ہے۔

آپ کو 🤗 Transformers کو ایک [ورچوئل ماحول](https://docs.python.org/3/library/venv.html) میں انسٹال کرنا چاہیے۔ اگر آپ Python ورچوئل ماحول سے واقف نہیں ہیں، تو [یوزر گائیڈ](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) دیکھیں۔

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_vi.md
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ Chính mô hình là một [Pytorch `nn.Module`](https://pytorch.org/docs/stable

### Sử dụng pip

Thư viện này được kiểm tra trên Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ và TensorFlow 2.6+.
Thư viện này được kiểm tra trên Python 3.9+, Flax 0.4.1+, PyTorch 2.1+ và TensorFlow 2.6+.

Bạn nên cài đặt 🤗 Transformers trong một [môi trường ảo Python](https://docs.python.org/3/library/venv.html). Nếu bạn chưa quen với môi trường ảo Python, hãy xem [hướng dẫn sử dụng](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_zh-hans.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ checkpoint: 检查点

### 使用 pip

这个仓库已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.0+ 和 TensorFlow 2.6+ 下经过测试。
这个仓库已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.1+ 和 TensorFlow 2.6+ 下经过测试。

你可以在[虚拟环境](https://docs.python.org/3/library/venv.html)中安装 🤗 Transformers。如果你还不熟悉 Python 的虚拟环境,请阅此[用户说明](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)。

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_zh-hant.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ Tokenizer 為所有的預訓練模型提供了預處理,並可以直接轉換

### 使用 pip

這個 Repository 已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.0+ 和 TensorFlow 2.6+ 下經過測試。
這個 Repository 已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.1+ 和 TensorFlow 2.6+ 下經過測試。

你可以在[虛擬環境](https://docs.python.org/3/library/venv.html)中安裝 🤗 Transformers。如果你還不熟悉 Python 的虛擬環境,請閱此[使用者指引](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)。

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@
"tiktoken",
"timm<=1.0.11",
"tokenizers>=0.21,<0.22",
"torch>=2.0",
"torch>=2.1",
"torchaudio",
"torchvision",
"pyctcdecode>=0.4.0",
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
"tiktoken": "tiktoken",
"timm": "timm<=1.0.11",
"tokenizers": "tokenizers>=0.21,<0.22",
"torch": "torch>=2.0",
"torch": "torch>=2.1",
Copy link
Contributor

@cyyever cyyever Apr 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we specify torch>=2.1.2

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain the reason? Usually we are pin major/minor version, but not patch version. We can do that if there is good reason of course.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because is_torch_sdpa_available returns True on pt>=2.1.2

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Given that most people are likely using newer torch versions, such specific pin (to the patch version) isn't really necessary IMO.

IIRC, when we need to pin such patch version(s), the reasons were earlier patch version(s) have serious issues and make several models not useable. One example I remembered was audio models failed at some point.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will merge as it is. If we find such pin is necesssary, we can update in the future.

"torchaudio": "torchaudio",
"torchvision": "torchvision",
"pyctcdecode": "pyctcdecode>=0.4.0",
Expand Down
14 changes: 2 additions & 12 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,20 +485,15 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
"F64": torch.float64,
"I64": torch.int64,
"F8_E4M3": torch.float8_e4m3fn,
"F8_E5M2": torch.float8_e5m2,
}

if is_torch_greater_or_equal("2.1.0"):
str_to_torch_dtype["F8_E4M3"] = torch.float8_e4m3fn

if is_torch_greater_or_equal("2.3.0"):
str_to_torch_dtype["U16"] = torch.uint16
str_to_torch_dtype["U32"] = torch.uint32
str_to_torch_dtype["U64"] = torch.uint64

if is_torch_greater_or_equal("2.1.0"):
str_to_torch_dtype["F8_E4M3"] = torch.float8_e4m3fn
str_to_torch_dtype["F8_E5M2"] = torch.float8_e5m2


def load_state_dict(
checkpoint_file: Union[str, os.PathLike],
Expand Down Expand Up @@ -546,12 +541,7 @@ def load_state_dict(
map_location = "cpu"
extra_args = {}
# mmap can only be used with files serialized with zipfile-based format.
if (
isinstance(checkpoint_file, str)
and map_location != "meta"
and version.parse(torch.__version__) >= version.parse("2.1.0")
and is_zipfile(checkpoint_file)
):
if isinstance(checkpoint_file, str) and map_location != "meta" and is_zipfile(checkpoint_file):
extra_args = {"mmap": True}
return torch.load(
checkpoint_file,
Expand Down
14 changes: 1 addition & 13 deletions src/transformers/models/mask2former/modeling_mask2former.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,8 @@
)
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_2_1
from ...utils import is_accelerate_available, logging
from ...utils.backbone_utils import load_backbone
from ...utils.import_utils import is_torchdynamo_compiling
from .configuration_mask2former import Mask2FormerConfig


Expand Down Expand Up @@ -2018,18 +2016,8 @@ def forward(
):
mask_embeddings = self.mask_embedder(outputs.transpose(0, 1))

is_tracing = torch.jit.is_tracing() or isinstance(outputs, torch.fx.Proxy) or is_torchdynamo_compiling()
# Sum up over the channels
if is_tracing and not is_torch_greater_or_equal_than_2_1:
# Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
batch_size, num_queries, num_channels = mask_embeddings.shape
_, _, height, width = pixel_embeddings.shape
outputs_mask = torch.zeros((batch_size, num_queries, height, width), device=mask_embeddings.device)
for c in range(num_channels):
outputs_mask += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]

else:
outputs_mask = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
outputs_mask = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)

attention_mask = nn.functional.interpolate(
outputs_mask, size=attention_mask_target_size, mode="bilinear", align_corners=False
Expand Down
29 changes: 2 additions & 27 deletions src/transformers/models/maskformer/modeling_maskformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_2_1
from ...utils import (
ModelOutput,
add_start_docstrings,
Expand All @@ -39,7 +38,6 @@
requires_backends,
)
from ...utils.backbone_utils import load_backbone
from ...utils.import_utils import is_torchdynamo_compiling
from ..detr import DetrConfig
from .configuration_maskformer import MaskFormerConfig
from .configuration_maskformer_swin import MaskFormerSwinConfig
Expand Down Expand Up @@ -1685,26 +1683,14 @@ def get_logits(self, outputs: MaskFormerModelOutput) -> Tuple[Tensor, Tensor, Di
# get the auxiliary predictions (one for each decoder's layer)
auxiliary_logits: List[str, Tensor] = []

is_tracing = torch.jit.is_tracing() or isinstance(outputs, torch.fx.Proxy) or is_torchdynamo_compiling()
# This code is a little bit cumbersome, an improvement can be to return a list of predictions. If we have auxiliary loss then we are going to return more than one element in the list
if self.config.use_auxiliary_loss:
stacked_transformer_decoder_outputs = torch.stack(outputs.transformer_decoder_hidden_states)
classes = self.class_predictor(stacked_transformer_decoder_outputs)
class_queries_logits = classes[-1]
# get the masks
mask_embeddings = self.mask_embedder(stacked_transformer_decoder_outputs)

if is_tracing and not is_torch_greater_or_equal_than_2_1:
# Equivalent to einsum('lbqc, bchw -> lbqhw') but jit friendly
num_embeddings, batch_size, num_queries, num_channels = mask_embeddings.shape
_, _, height, width = pixel_embeddings.shape
binaries_masks = torch.zeros(
(num_embeddings, batch_size, num_queries, height, width), device=mask_embeddings.device
)
for c in range(num_channels):
binaries_masks += mask_embeddings[..., c][..., None, None] * pixel_embeddings[None, :, None, c]
else:
binaries_masks = torch.einsum("lbqc, bchw -> lbqhw", mask_embeddings, pixel_embeddings)
binaries_masks = torch.einsum("lbqc, bchw -> lbqhw", mask_embeddings, pixel_embeddings)

masks_queries_logits = binaries_masks[-1]
# go til [:-1] because the last one is always used
Expand All @@ -1720,18 +1706,7 @@ def get_logits(self, outputs: MaskFormerModelOutput) -> Tuple[Tensor, Tensor, Di
# get the masks
mask_embeddings = self.mask_embedder(transformer_decoder_hidden_states)
# sum up over the channels

if is_tracing and not is_torch_greater_or_equal_than_2_1:
# Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
batch_size, num_queries, num_channels = mask_embeddings.shape
_, _, height, width = pixel_embeddings.shape
masks_queries_logits = torch.zeros(
(batch_size, num_queries, height, width), device=mask_embeddings.device
)
for c in range(num_channels):
masks_queries_logits += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]
else:
masks_queries_logits = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
masks_queries_logits = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)

return class_queries_logits, masks_queries_logits, auxiliary_logits

Expand Down
2 changes: 1 addition & 1 deletion src/transformers/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
is_torch_greater_or_equal_than_2_2 = is_torch_greater_or_equal("2.2", accept_dev=True)
is_torch_greater_or_equal_than_2_1 = is_torch_greater_or_equal("2.1", accept_dev=True)

# For backwards compatibility (e.g. some remote codes on Hub using those variables).
is_torch_greater_or_equal_than_2_1 = is_torch_greater_or_equal("2.1", accept_dev=True)
is_torch_greater_or_equal_than_2_0 = is_torch_greater_or_equal("2.0", accept_dev=True)
is_torch_greater_or_equal_than_1_13 = is_torch_greater_or_equal("1.13", accept_dev=True)
is_torch_greater_or_equal_than_1_12 = is_torch_greater_or_equal("1.12", accept_dev=True)
Expand Down
7 changes: 2 additions & 5 deletions src/transformers/quantizers/quantizer_fbgemm_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from packaging import version

from .base import HfQuantizer


Expand Down Expand Up @@ -48,9 +45,9 @@ def __init__(self, quantization_config, **kwargs):
self.quantization_config = quantization_config

def validate_environment(self, *args, **kwargs):
if not is_torch_available() or version.parse(importlib.metadata.version("torch")) < version.parse("2.1.0"):
if not is_torch_available():
raise ImportError(
"Using fbgemm fp8 quantization requires torch > 2.1.0"
"Using fbgemm fp8 quantization requires torch >= 2.1.0"
"Please install the latest version of torch ( pip install --upgrade torch )"
)
if not is_fbgemm_gpu_available():
Expand Down
Loading