
Commit e7ad077

ydshieh and Borda authored
byebye torch 2.0 (#37277)
* bump Torch 2.1 with broken compatibility `torch.compile`
* dep table
* remove usage of is_torch_greater_or_equal_than_2_1
* remove usage of is_torch_greater_or_equal_than_2_1
* remove if is_torch_greater_or_equal("2.1.0")
* remove torch >= "2.1.0"
* deal with 2.0.0
* PyTorch 2.0+ --> PyTorch 2.1+
* ruff 1
* difficult ruff
* address comment
* address comment

---------

Co-authored-by: Jirka B <[email protected]>
Co-authored-by: ydshieh <[email protected]>
1 parent: 99f9f10 · commit: e7ad077

28 files changed: +38 -113 lines

Diff for: README.md (+1 -1)

@@ -70,7 +70,7 @@ Explore the [Hub](https://huggingface.com/) today to find a model and use Transf
 
 ## Installation
 
-Transformers works with Python 3.9+ [PyTorch](https://pytorch.org/get-started/locally/) 2.0+, [TensorFlow](https://www.tensorflow.org/install/pip) 2.6+, and [Flax](https://flax.readthedocs.io/en/latest/) 0.4.1+.
+Transformers works with Python 3.9+ [PyTorch](https://pytorch.org/get-started/locally/) 2.1+, [TensorFlow](https://www.tensorflow.org/install/pip) 2.6+, and [Flax](https://flax.readthedocs.io/en/latest/) 0.4.1+.
 
 Create and activate a virtual environment with [venv](https://docs.python.org/3/library/venv.html) or [uv](https://docs.astral.sh/uv/), a fast Rust-based Python package and project manager.

Diff for: docs/source/en/installation.md (+1 -1)

@@ -20,7 +20,7 @@ rendered properly in your Markdown viewer.
 
 # Installation
 
-Transformers works with [PyTorch](https://pytorch.org/get-started/locally/), [TensorFlow 2.0](https://www.tensorflow.org/install/pip), and [Flax](https://flax.readthedocs.io/en/latest/). It has been tested on Python 3.9+, PyTorch 2.0+, TensorFlow 2.6+, and Flax 0.4.1+.
+Transformers works with [PyTorch](https://pytorch.org/get-started/locally/), [TensorFlow 2.0](https://www.tensorflow.org/install/pip), and [Flax](https://flax.readthedocs.io/en/latest/). It has been tested on Python 3.9+, PyTorch 2.1+, TensorFlow 2.6+, and Flax 0.4.1+.
 
 ## Virtual environment

Diff for: i18n/README_ar.md (+1 -1)

@@ -245,7 +245,7 @@ limitations under the License.
 
 ### باستخدام pip
 
-تم اختبار هذا المستودع على Python 3.9+، Flax 0.4.1+، PyTorch 2.0+، و TensorFlow 2.6+.
+تم اختبار هذا المستودع على Python 3.9+، Flax 0.4.1+، PyTorch 2.1+، و TensorFlow 2.6+.
 
 يجب تثبيت 🤗 Transformers في [بيئة افتراضية](https://docs.python.org/3/library/venv.html). إذا كنت غير معتاد على البيئات الافتراضية Python، فراجع [دليل المستخدم](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

Diff for: i18n/README_de.md (+1 -1)

@@ -246,7 +246,7 @@ Das Modell selbst ist ein reguläres [PyTorch `nn.Module`](https://pytorch.org/d
 
 ### Mit pip
 
-Dieses Repository wurde mit Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ und TensorFlow 2.6+ getestet.
+Dieses Repository wurde mit Python 3.9+, Flax 0.4.1+, PyTorch 2.1+ und TensorFlow 2.6+ getestet.
 
 Sie sollten 🤗 Transformers in einer [virtuellen Umgebung](https://docs.python.org/3/library/venv.html) installieren. Wenn Sie mit virtuellen Python-Umgebungen nicht vertraut sind, schauen Sie sich den [Benutzerleitfaden](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) an.

Diff for: i18n/README_es.md (+1 -1)

@@ -222,7 +222,7 @@ El modelo en si es un [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.h
 
 ### Con pip
 
-Este repositorio está probado en Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ y TensorFlow 2.6+.
+Este repositorio está probado en Python 3.9+, Flax 0.4.1+, PyTorch 2.1+ y TensorFlow 2.6+.
 
 Deberías instalar 🤗 Transformers en un [entorno virtual](https://docs.python.org/3/library/venv.html). Si no estas familiarizado con los entornos virtuales de Python, consulta la [guía de usuario](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

Diff for: i18n/README_fr.md (+1 -1)

@@ -243,7 +243,7 @@ Le modèle lui-même est un module [`nn.Module` PyTorch](https://pytorch.org/doc
 
 ### Avec pip
 
-Ce référentiel est testé sur Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ et TensorFlow 2.6+.
+Ce référentiel est testé sur Python 3.9+, Flax 0.4.1+, PyTorch 2.1+ et TensorFlow 2.6+.
 
 Vous devriez installer 🤗 Transformers dans un [environnement virtuel](https://docs.python.org/3/library/venv.html). Si vous n'êtes pas familier avec les environnements virtuels Python, consultez le [guide utilisateur](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

Diff for: i18n/README_hd.md (+1 -1)

@@ -198,7 +198,7 @@ checkpoint: जाँच बिंदु
 
 ### पिप का उपयोग करना
 
-इस रिपॉजिटरी का परीक्षण Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ और TensorFlow 2.6+ के तहत किया गया है।
+इस रिपॉजिटरी का परीक्षण Python 3.9+, Flax 0.4.1+, PyTorch 2.1+ और TensorFlow 2.6+ के तहत किया गया है।
 
 आप [वर्चुअल एनवायरनमेंट](https://docs.python.org/3/library/venv.html) में 🤗 ट्रांसफॉर्मर इंस्टॉल कर सकते हैं। यदि आप अभी तक पायथन के वर्चुअल एनवायरनमेंट से परिचित नहीं हैं, तो कृपया इसे [उपयोगकर्ता निर्देश](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) पढ़ें।

Diff for: i18n/README_ja.md (+1 -1)

@@ -256,7 +256,7 @@ Hugging Faceチームによって作られた **[トランスフォーマーを
 
 ### pipにて
 
-このリポジトリは、Python 3.9+, Flax 0.4.1+, PyTorch 2.0+, TensorFlow 2.6+ でテストされています。
+このリポジトリは、Python 3.9+, Flax 0.4.1+, PyTorch 2.1+, TensorFlow 2.6+ でテストされています。
 
 🤗Transformersは[仮想環境](https://docs.python.org/3/library/venv.html)にインストールする必要があります。Pythonの仮想環境に慣れていない場合は、[ユーザーガイド](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)を確認してください。

Diff for: i18n/README_ko.md (+1 -1)

@@ -242,7 +242,7 @@ Transformers에 달린 100,000개의 별을 축하하기 위해, 우리는 커
 
 ### pip로 설치하기
 
-이 저장소는 Python 3.9+, Flax 0.4.1+, PyTorch 2.0+, TensorFlow 2.6+에서 테스트 되었습니다.
+이 저장소는 Python 3.9+, Flax 0.4.1+, PyTorch 2.1+, TensorFlow 2.6+에서 테스트 되었습니다.
 
 [가상 환경](https://docs.python.org/3/library/venv.html)에 🤗 Transformers를 설치하세요. Python 가상 환경에 익숙하지 않다면, [사용자 가이드](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)를 확인하세요.

Diff for: i18n/README_pt-br.md (+1 -1)

@@ -253,7 +253,7 @@ O modelo em si é um [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.ht
 
 ### Com pip
 
-Este repositório é testado no Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ e TensorFlow 2.6+.
+Este repositório é testado no Python 3.9+, Flax 0.4.1+, PyTorch 2.1+ e TensorFlow 2.6+.
 
 Você deve instalar o 🤗 Transformers em um [ambiente virtual](https://docs.python.org/3/library/venv.html). Se você não está familiarizado com ambientes virtuais em Python, confira o [guia do usuário](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

Diff for: i18n/README_ru.md (+1 -1)

@@ -244,7 +244,7 @@ Hugging Face Hub. Мы хотим, чтобы Transformers позволил ра
 
 ### С помощью pip
 
-Данный репозиторий протестирован на Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ и TensorFlow 2.6+.
+Данный репозиторий протестирован на Python 3.9+, Flax 0.4.1+, PyTorch 2.1+ и TensorFlow 2.6+.
 
 Устанавливать 🤗 Transformers следует в [виртуальной среде](https://docs.python.org/3/library/venv.html). Если вы не знакомы с виртуальными средами Python, ознакомьтесь с [руководством пользователя](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

Diff for: i18n/README_te.md (+1 -1)

@@ -246,7 +246,7 @@ limitations under the License.
 
 ### పిప్ తో
 
-ఈ రిపోజిటరీ పైథాన్ 3.9+, ఫ్లాక్స్ 0.4.1+, PyTorch 2.0+ మరియు TensorFlow 2.6+లో పరీక్షించబడింది.
+ఈ రిపోజిటరీ పైథాన్ 3.9+, ఫ్లాక్స్ 0.4.1+, PyTorch 2.1+ మరియు TensorFlow 2.6+లో పరీక్షించబడింది.
 
 మీరు [వర్చువల్ వాతావరణం](https://docs.python.org/3/library/venv.html)లో 🤗 ట్రాన్స్‌ఫార్మర్‌లను ఇన్‌స్టాల్ చేయాలి. మీకు పైథాన్ వర్చువల్ పరిసరాల గురించి తెలియకుంటే, [యూజర్ గైడ్](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) చూడండి.

Diff for: i18n/README_ur.md (+1 -1)

@@ -259,7 +259,7 @@ limitations under the License.
 
 #### &#8207; pip کے ساتھ
 
-یہ ریپوزٹری Python 3.9+، Flax 0.4.1+، PyTorch 2.0+، اور TensorFlow 2.6+ پر ٹیسٹ کی گئی ہے۔
+یہ ریپوزٹری Python 3.9+، Flax 0.4.1+، PyTorch 2.1+، اور TensorFlow 2.6+ پر ٹیسٹ کی گئی ہے۔
 
 آپ کو 🤗 Transformers کو ایک [ورچوئل ماحول](https://docs.python.org/3/library/venv.html) میں انسٹال کرنا چاہیے۔ اگر آپ Python ورچوئل ماحول سے واقف نہیں ہیں، تو [یوزر گائیڈ](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) دیکھیں۔

Diff for: i18n/README_vi.md (+1 -1)

@@ -245,7 +245,7 @@ Chính mô hình là một [Pytorch `nn.Module`](https://pytorch.org/docs/stable
 
 ### Sử dụng pip
 
-Thư viện này được kiểm tra trên Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ và TensorFlow 2.6+.
+Thư viện này được kiểm tra trên Python 3.9+, Flax 0.4.1+, PyTorch 2.1+ và TensorFlow 2.6+.
 
 Bạn nên cài đặt 🤗 Transformers trong một [môi trường ảo Python](https://docs.python.org/3/library/venv.html). Nếu bạn chưa quen với môi trường ảo Python, hãy xem [hướng dẫn sử dụng](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

Diff for: i18n/README_zh-hans.md (+1 -1)

@@ -198,7 +198,7 @@ checkpoint: 检查点
 
 ### 使用 pip
 
-这个仓库已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.0+ 和 TensorFlow 2.6+ 下经过测试。
+这个仓库已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.1+ 和 TensorFlow 2.6+ 下经过测试。
 
 你可以在[虚拟环境](https://docs.python.org/3/library/venv.html)中安装 🤗 Transformers。如果你还不熟悉 Python 的虚拟环境,请阅此[用户说明](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)

Diff for: i18n/README_zh-hant.md (+1 -1)

@@ -210,7 +210,7 @@ Tokenizer 為所有的預訓練模型提供了預處理,並可以直接轉換
 
 ### 使用 pip
 
-這個 Repository 已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.0+ 和 TensorFlow 2.6+ 下經過測試。
+這個 Repository 已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.1+ 和 TensorFlow 2.6+ 下經過測試。
 
 你可以在[虛擬環境](https://docs.python.org/3/library/venv.html)中安裝 🤗 Transformers。如果你還不熟悉 Python 的虛擬環境,請閱此[使用者指引](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)

Diff for: setup.py (+1 -1)

@@ -187,7 +187,7 @@
     "tiktoken",
     "timm<=1.0.11",
     "tokenizers>=0.21,<0.22",
-    "torch>=2.0",
+    "torch>=2.1",
     "torchaudio",
     "torchvision",
     "pyctcdecode>=0.4.0",

Diff for: src/transformers/dependency_versions_table.py (+1 -1)

@@ -92,7 +92,7 @@
     "tiktoken": "tiktoken",
     "timm": "timm<=1.0.11",
     "tokenizers": "tokenizers>=0.21,<0.22",
-    "torch": "torch>=2.0",
+    "torch": "torch>=2.1",
     "torchaudio": "torchaudio",
     "torchvision": "torchvision",
     "pyctcdecode": "pyctcdecode>=0.4.0",

Diff for: src/transformers/modeling_utils.py (+2 -12)

@@ -485,20 +485,15 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
     "F64": torch.float64,
     "I64": torch.int64,
     "F8_E4M3": torch.float8_e4m3fn,
+    "F8_E5M2": torch.float8_e5m2,
 }
 
-if is_torch_greater_or_equal("2.1.0"):
-    str_to_torch_dtype["F8_E4M3"] = torch.float8_e4m3fn
 
 if is_torch_greater_or_equal("2.3.0"):
     str_to_torch_dtype["U16"] = torch.uint16
     str_to_torch_dtype["U32"] = torch.uint32
     str_to_torch_dtype["U64"] = torch.uint64
 
-if is_torch_greater_or_equal("2.1.0"):
-    str_to_torch_dtype["F8_E4M3"] = torch.float8_e4m3fn
-    str_to_torch_dtype["F8_E5M2"] = torch.float8_e5m2
-
 
 def load_state_dict(
     checkpoint_file: Union[str, os.PathLike],
@@ -546,12 +541,7 @@ def load_state_dict(
             map_location = "cpu"
         extra_args = {}
         # mmap can only be used with files serialized with zipfile-based format.
-        if (
-            isinstance(checkpoint_file, str)
-            and map_location != "meta"
-            and version.parse(torch.__version__) >= version.parse("2.1.0")
-            and is_zipfile(checkpoint_file)
-        ):
+        if isinstance(checkpoint_file, str) and map_location != "meta" and is_zipfile(checkpoint_file):
            extra_args = {"mmap": True}
         return torch.load(
             checkpoint_file,
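With the 2.1 floor in place, the fp8 dtypes and the `mmap` fast path no longer need runtime version checks: `torch.float8_e5m2` and the `mmap` keyword of `torch.load` have both been available since PyTorch 2.1. A minimal sketch of the simplified loading pattern (the checkpoint file name is a placeholder):

```python
# Sketch of the now-unconditional mmap loading path; "checkpoint.bin" stands in
# for a zipfile-serialized PyTorch checkpoint on disk.
from zipfile import is_zipfile

import torch

checkpoint_file = "checkpoint.bin"
extra_args = {}
# mmap can only be used with files serialized with the zipfile-based format,
# and torch>=2.1 (the new minimum) always accepts the mmap keyword.
if isinstance(checkpoint_file, str) and is_zipfile(checkpoint_file):
    extra_args = {"mmap": True}

state_dict = torch.load(checkpoint_file, map_location="cpu", weights_only=True, **extra_args)
```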

Diff for: src/transformers/models/mask2former/modeling_mask2former.py (+1 -13)

@@ -34,10 +34,8 @@
 )
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import is_torch_greater_or_equal_than_2_1
 from ...utils import is_accelerate_available, logging
 from ...utils.backbone_utils import load_backbone
-from ...utils.import_utils import is_torchdynamo_compiling
 from .configuration_mask2former import Mask2FormerConfig
 
 
@@ -2018,18 +2016,8 @@ def forward(
     ):
         mask_embeddings = self.mask_embedder(outputs.transpose(0, 1))
 
-        is_tracing = torch.jit.is_tracing() or isinstance(outputs, torch.fx.Proxy) or is_torchdynamo_compiling()
         # Sum up over the channels
-        if is_tracing and not is_torch_greater_or_equal_than_2_1:
-            # Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
-            batch_size, num_queries, num_channels = mask_embeddings.shape
-            _, _, height, width = pixel_embeddings.shape
-            outputs_mask = torch.zeros((batch_size, num_queries, height, width), device=mask_embeddings.device)
-            for c in range(num_channels):
-                outputs_mask += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]
-
-        else:
-            outputs_mask = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
+        outputs_mask = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
 
         attention_mask = nn.functional.interpolate(
             outputs_mask, size=attention_mask_target_size, mode="bilinear", align_corners=False
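The deleted branch existed only as a tracing-friendly fallback for torch older than 2.1; numerically, the per-channel loop and the `einsum` compute the same contraction. A small self-contained check of that equivalence (tensor shapes are arbitrary examples, not values used by the model):

```python
# Illustrative equivalence check between the removed jit-friendly loop and the
# einsum that now runs unconditionally.
import torch

batch_size, num_queries, num_channels, height, width = 2, 5, 8, 16, 16
mask_embeddings = torch.randn(batch_size, num_queries, num_channels)
pixel_embeddings = torch.randn(batch_size, num_channels, height, width)

# Removed fallback: explicit sum over the channel dimension.
loop_mask = torch.zeros((batch_size, num_queries, height, width))
for c in range(num_channels):
    loop_mask += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]

# Retained path: a single contraction over the channel dimension.
einsum_mask = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)

assert torch.allclose(loop_mask, einsum_mask, atol=1e-5)
```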

Diff for: src/transformers/models/maskformer/modeling_maskformer.py (+2 -27)

@@ -27,7 +27,6 @@
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_outputs import BaseModelOutputWithCrossAttentions
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import is_torch_greater_or_equal_than_2_1
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -39,7 +38,6 @@
     requires_backends,
 )
 from ...utils.backbone_utils import load_backbone
-from ...utils.import_utils import is_torchdynamo_compiling
 from ..detr import DetrConfig
 from .configuration_maskformer import MaskFormerConfig
 from .configuration_maskformer_swin import MaskFormerSwinConfig
@@ -1685,26 +1683,14 @@ def get_logits(self, outputs: MaskFormerModelOutput) -> Tuple[Tensor, Tensor, Di
         # get the auxiliary predictions (one for each decoder's layer)
         auxiliary_logits: List[str, Tensor] = []
 
-        is_tracing = torch.jit.is_tracing() or isinstance(outputs, torch.fx.Proxy) or is_torchdynamo_compiling()
         # This code is a little bit cumbersome, an improvement can be to return a list of predictions. If we have auxiliary loss then we are going to return more than one element in the list
         if self.config.use_auxiliary_loss:
             stacked_transformer_decoder_outputs = torch.stack(outputs.transformer_decoder_hidden_states)
             classes = self.class_predictor(stacked_transformer_decoder_outputs)
             class_queries_logits = classes[-1]
             # get the masks
             mask_embeddings = self.mask_embedder(stacked_transformer_decoder_outputs)
-
-            if is_tracing and not is_torch_greater_or_equal_than_2_1:
-                # Equivalent to einsum('lbqc, bchw -> lbqhw') but jit friendly
-                num_embeddings, batch_size, num_queries, num_channels = mask_embeddings.shape
-                _, _, height, width = pixel_embeddings.shape
-                binaries_masks = torch.zeros(
-                    (num_embeddings, batch_size, num_queries, height, width), device=mask_embeddings.device
-                )
-                for c in range(num_channels):
-                    binaries_masks += mask_embeddings[..., c][..., None, None] * pixel_embeddings[None, :, None, c]
-            else:
-                binaries_masks = torch.einsum("lbqc, bchw -> lbqhw", mask_embeddings, pixel_embeddings)
+            binaries_masks = torch.einsum("lbqc, bchw -> lbqhw", mask_embeddings, pixel_embeddings)
 
             masks_queries_logits = binaries_masks[-1]
             # go til [:-1] because the last one is always used
@@ -1720,18 +1706,7 @@ def get_logits(self, outputs: MaskFormerModelOutput) -> Tuple[Tensor, Tensor, Di
             # get the masks
             mask_embeddings = self.mask_embedder(transformer_decoder_hidden_states)
             # sum up over the channels
-
-            if is_tracing and not is_torch_greater_or_equal_than_2_1:
-                # Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
-                batch_size, num_queries, num_channels = mask_embeddings.shape
-                _, _, height, width = pixel_embeddings.shape
-                masks_queries_logits = torch.zeros(
-                    (batch_size, num_queries, height, width), device=mask_embeddings.device
-                )
-                for c in range(num_channels):
-                    masks_queries_logits += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]
-            else:
-                masks_queries_logits = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
+            masks_queries_logits = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
 
         return class_queries_logits, masks_queries_logits, auxiliary_logits
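MaskFormer's auxiliary-loss branch stacks the decoder layers first, so the removed fallback there covered the layer-stacked contraction `lbqc, bchw -> lbqhw` rather than the per-batch one above. As an illustration with invented shapes, that contraction is simply the per-layer `bqc, bchw -> bqhw` applied to every decoder layer:

```python
# Illustrative check that the stacked contraction used for auxiliary losses
# matches applying the per-layer einsum to each decoder layer and stacking.
import torch

num_layers, batch_size, num_queries, num_channels, height, width = 3, 2, 5, 8, 16, 16
mask_embeddings = torch.randn(num_layers, batch_size, num_queries, num_channels)
pixel_embeddings = torch.randn(batch_size, num_channels, height, width)

stacked = torch.einsum("lbqc, bchw -> lbqhw", mask_embeddings, pixel_embeddings)
per_layer = torch.stack(
    [torch.einsum("bqc, bchw -> bqhw", mask_embeddings[layer], pixel_embeddings) for layer in range(num_layers)]
)

assert torch.allclose(stacked, per_layer, atol=1e-5)
```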

Diff for: src/transformers/pytorch_utils.py (+1 -1)

@@ -32,9 +32,9 @@
 is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
 is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
 is_torch_greater_or_equal_than_2_2 = is_torch_greater_or_equal("2.2", accept_dev=True)
-is_torch_greater_or_equal_than_2_1 = is_torch_greater_or_equal("2.1", accept_dev=True)
 
 # For backwards compatibility (e.g. some remote codes on Hub using those variables).
+is_torch_greater_or_equal_than_2_1 = is_torch_greater_or_equal("2.1", accept_dev=True)
 is_torch_greater_or_equal_than_2_0 = is_torch_greater_or_equal("2.0", accept_dev=True)
 is_torch_greater_or_equal_than_1_13 = is_torch_greater_or_equal("1.13", accept_dev=True)
 is_torch_greater_or_equal_than_1_12 = is_torch_greater_or_equal("1.12", accept_dev=True)
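Inside the library `is_torch_greater_or_equal_than_2_1` is now always `True`, but the flag is kept in the backward-compatibility block because remote code on the Hub may still import it. A hedged sketch of how such external code typically consumes these flags (the variable names below are hypothetical, only the imports come from this file):

```python
# Hypothetical remote-code-style usage of the backward-compatibility flags.
import torch

from transformers.pytorch_utils import (
    is_torch_greater_or_equal_than_2_1,
    is_torch_greater_or_equal_than_2_3,
)

# Effectively a no-op after this commit, since torch>=2.1 is now guaranteed.
supports_mmap_load = is_torch_greater_or_equal_than_2_1
# Newer gates still do real work, e.g. torch.uint16 only exists on torch>=2.3.
uint16_dtype = torch.uint16 if is_torch_greater_or_equal_than_2_3 else None

print(supports_mmap_load, uint16_dtype)
```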

Diff for: src/transformers/quantizers/quantizer_fbgemm_fp8.py (+2 -5)

@@ -11,11 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import importlib
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
-from packaging import version
-
 from .base import HfQuantizer
 
 
@@ -48,9 +45,9 @@ def __init__(self, quantization_config, **kwargs):
         self.quantization_config = quantization_config
 
     def validate_environment(self, *args, **kwargs):
-        if not is_torch_available() or version.parse(importlib.metadata.version("torch")) < version.parse("2.1.0"):
+        if not is_torch_available():
             raise ImportError(
-                "Using fbgemm fp8 quantization requires torch > 2.1.0"
+                "Using fbgemm fp8 quantization requires torch >= 2.1.0"
                 "Please install the latest version of torch ( pip install --upgrade torch )"
             )
         if not is_fbgemm_gpu_available():
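Every supported torch now satisfies the quantizer's old `>= 2.1.0` requirement, so `validate_environment` only needs to confirm that torch and `fbgemm-gpu` are importable. For context, a hedged usage sketch of the FP8 quantizer that triggers this check; the model id is just an example checkpoint, and running it requires a CUDA GPU with `fbgemm-gpu` installed:

```python
# Hedged usage sketch of fbgemm FP8 quantization; "meta-llama/Llama-3.1-8B" is
# only an example model id, and fbgemm-gpu must be installed for this to run.
from transformers import AutoModelForCausalLM, AutoTokenizer, FbgemmFp8Config

model_id = "meta-llama/Llama-3.1-8B"
quantization_config = FbgemmFp8Config()

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=quantization_config,  # runs the quantizer's validate_environment()
)
```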
