Commit 2b13495

Set _torch_version to N/A if torch is disabled.

1 parent 3a31b29 · commit 2b13495
File tree

1 file changed: +87 -28 lines changed
src/diffusers/utils/import_utils.py

Lines changed: 87 additions & 28 deletions
@@ -37,7 +37,9 @@
 else:
     import importlib.metadata as importlib_metadata
 try:
-    _package_map = importlib_metadata.packages_distributions()  # load-once to avoid expensive calls
+    _package_map = (
+        importlib_metadata.packages_distributions()
+    )  # load-once to avoid expensive calls
 except Exception:
     _package_map = None
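
Note (not part of the diff): on Python 3.10+, importlib.metadata.packages_distributions() returns a mapping from top-level import names to the distributions that provide them, and building it walks every installed distribution, which is why the module keeps a load-once cache. A minimal standalone sketch:

import importlib.metadata as importlib_metadata

# Python 3.10+ only: e.g. {"yaml": ["PyYAML"], "PIL": ["Pillow"], ...}
package_map = importlib_metadata.packages_distributions()
print(package_map.get("yaml"))  # ['PyYAML'] if PyYAML is installed, otherwise None
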

@@ -53,12 +55,23 @@
 DIFFUSERS_SLOW_IMPORT = os.environ.get("DIFFUSERS_SLOW_IMPORT", "FALSE").upper()
 DIFFUSERS_SLOW_IMPORT = DIFFUSERS_SLOW_IMPORT in ENV_VARS_TRUE_VALUES

-STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
-
-_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)
+STR_OPERATION_TO_FUNC = {
+    ">": op.gt,
+    ">=": op.ge,
+    "==": op.eq,
+    "!=": op.ne,
+    "<=": op.le,
+    "<": op.lt,
+}
+
+_is_google_colab = "google.colab" in sys.modules or any(
+    k.startswith("COLAB_") for k in os.environ
+)


-def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[bool, str]:
+def _is_package_available(
+    pkg_name: str, get_dist_name: bool = False
+) -> Tuple[bool, str]:
     global _package_map
     pkg_exists = importlib.util.find_spec(pkg_name) is not None
     pkg_version = "N/A"
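
Side note: STR_OPERATION_TO_FUNC only maps comparison strings onto functions from the standard operator module, so version checks can be driven by text. A small standalone illustration (assumes the packaging library, which diffusers already depends on):

import operator as op
from packaging.version import parse

STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}

print(STR_OPERATION_TO_FUNC[">="](parse("2.1.0"), parse("2.0.0")))  # True
print(STR_OPERATION_TO_FUNC["<"](parse("1.0.0"), parse("2.0.0")))   # True
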
@@ -69,11 +82,16 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[bool, str]:
             try:
                 # Fallback for Python < 3.10
                 for dist in importlib_metadata.distributions():
-                    _top_level_declared = (dist.read_text("top_level.txt") or "").split()
+                    _top_level_declared = (
+                        dist.read_text("top_level.txt") or ""
+                    ).split()
                     _infered_opt_names = {
-                        f.parts[0] if len(f.parts) > 1 else inspect.getmodulename(f) for f in (dist.files or [])
+                        f.parts[0] if len(f.parts) > 1 else inspect.getmodulename(f)
+                        for f in (dist.files or [])
                     } - {None}
-                    _top_level_inferred = filter(lambda name: "." not in name, _infered_opt_names)
+                    _top_level_inferred = filter(
+                        lambda name: "." not in name, _infered_opt_names
+                    )
                     for pkg in _top_level_declared or _top_level_inferred:
                         _package_map[pkg].append(dist.metadata["Name"])
         except Exception as _:
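
To illustrate the pre-3.10 fallback above (the file names below are made up, not real metadata): files inside a package directory contribute their first path component, single-file modules go through inspect.getmodulename, and anything still containing a dot is filtered out:

import inspect
from pathlib import PurePosixPath

files = [
    PurePosixPath("requests/api.py"),                # package file -> first path part
    PurePosixPath("six.py"),                         # single-file module -> getmodulename
    PurePosixPath("six-1.16.0.dist-info/METADATA"),  # metadata dir, dropped by the "." filter
]
names = {f.parts[0] if len(f.parts) > 1 else inspect.getmodulename(str(f)) for f in files} - {None}
print(sorted(n for n in names if "." not in n))  # ['requests', 'six']
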
@@ -99,16 +117,22 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[bool, str]:
 else:
     logger.info("Disabling PyTorch because USE_TORCH is set")
     _torch_available = False
+    _torch_version = "N/A"

 _jax_version = "N/A"
 _flax_version = "N/A"
 if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
-    _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
+    _flax_available = (
+        importlib.util.find_spec("jax") is not None
+        and importlib.util.find_spec("flax") is not None
+    )
     if _flax_available:
         try:
             _jax_version = importlib_metadata.version("jax")
             _flax_version = importlib_metadata.version("flax")
-            logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
+            logger.info(
+                f"JAX version {_jax_version}, Flax version {_flax_version} available."
+            )
         except importlib_metadata.PackageNotFoundError:
             _flax_available = False
 else:
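
The added _torch_version = "N/A" is the one functional change in this commit: when USE_TORCH disables PyTorch, the module previously set _torch_available = False but never bound _torch_version, so later reads of that name could fail. A hedged standalone sketch of the pattern (the truthy-value set is an assumption, following the usual ENV_VARS_TRUE_AND_AUTO_VALUES convention):

import os
import importlib.util
import importlib.metadata as importlib_metadata

ENV_VARS_TRUE_AND_AUTO_VALUES = {"1", "ON", "YES", "TRUE", "AUTO"}  # assumed truthy set
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()

if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES:
    _torch_available = importlib.util.find_spec("torch") is not None
    _torch_version = importlib_metadata.version("torch") if _torch_available else "N/A"
else:
    _torch_available = False
    _torch_version = "N/A"  # the default this commit adds, so the name is always bound

print(_torch_available, _torch_version)
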
@@ -148,7 +172,9 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[bool, str]:
     pass
 _onnx_available = _onnxruntime_version is not None
 if _onnx_available:
-    logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")
+    logger.debug(
+        f"Successfully imported onnxruntime version {_onnxruntime_version}"
+    )

 # (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed.
 # _opencv_available = importlib.util.find_spec("opencv-python") is not None
@@ -183,7 +209,9 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[bool, str]:
 _invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None
 try:
     _invisible_watermark_version = importlib_metadata.version("invisible-watermark")
-    logger.debug(f"Successfully imported invisible-watermark version {_invisible_watermark_version}")
+    logger.debug(
+        f"Successfully imported invisible-watermark version {_invisible_watermark_version}"
+    )
 except importlib_metadata.PackageNotFoundError:
     _invisible_watermark_available = False

@@ -198,7 +226,9 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[bool, str]:
 _wandb_available, _wandb_version = _is_package_available("wandb")
 _tensorboard_available, _tensorboard_version = _is_package_available("tensorboard")
 _compel_available, _compel_version = _is_package_available("compel")
-_sentencepiece_available, _sentencepiece_version = _is_package_available("sentencepiece")
+_sentencepiece_available, _sentencepiece_version = _is_package_available(
+    "sentencepiece"
+)
 _torchsde_available, _torchsde_version = _is_package_available("torchsde")
 _peft_available, _peft_version = _is_package_available("peft")
 _torchvision_available, _torchvision_version = _is_package_available("torchvision")
@@ -214,11 +244,19 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[bool, str]:
 _gguf_available, _gguf_version = _is_package_available("gguf")
 _torchao_available, _torchao_version = _is_package_available("torchao")
 _bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
-_optimum_quanto_available, _optimum_quanto_version = _is_package_available("optimum", get_dist_name=True)
-_pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_available("pytorch_retinaface")
-_better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
+_optimum_quanto_available, _optimum_quanto_version = _is_package_available(
+    "optimum", get_dist_name=True
+)
+_pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_available(
+    "pytorch_retinaface"
+)
+_better_profanity_available, _better_profanity_version = _is_package_available(
+    "better_profanity"
+)
 _nltk_available, _nltk_version = _is_package_available("nltk")
-_cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
+_cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available(
+    "cosmos_guardrail"
+)


 def is_torch_available():
@@ -374,7 +412,10 @@ def is_cosmos_guardrail_available():


 def is_hpu_available():
-    return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))
+    return all(
+        importlib.util.find_spec(lib)
+        for lib in ("habana_frameworks", "habana_frameworks.torch")
+    )


 # docstyle-ignore
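
Context: importlib.util.find_spec returns None (rather than raising) for a missing top-level module, and all() short-circuits, so the plain habana_frameworks package is probed before the dotted habana_frameworks.torch lookup. A standalone sketch of the same idea (_all_importable is an illustrative helper name, not diffusers code):

import importlib.util

def _all_importable(*names):
    # None for a missing module makes all() fail fast before any dotted lookup.
    return all(importlib.util.find_spec(n) is not None for n in names)

print(_all_importable("json", "os"))                                    # True: stdlib modules
print(_all_importable("habana_frameworks", "habana_frameworks.torch"))  # False without the Habana stack
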
@@ -560,7 +601,10 @@ def is_hpu_available():
         ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)),
         ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
         ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)),
-        ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)),
+        (
+            "invisible_watermark",
+            (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR),
+        ),
         ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
         ("safetensors", (is_safetensors_available, SAFETENSORS_IMPORT_ERROR)),
         ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
@@ -569,8 +613,14 @@ def is_hpu_available():
         ("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)),
         ("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
         ("quanto", (is_optimum_quanto_available, QUANTO_IMPORT_ERROR)),
-        ("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)),
-        ("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)),
+        (
+            "pytorch_retinaface",
+            (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR),
+        ),
+        (
+            "better_profanity",
+            (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR),
+        ),
         ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
     ]
 )
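
For orientation, a simplified sketch of how BACKENDS_MAPPING is consumed (illustrative only, not the actual diffusers implementation; ExamplePipeline and the bs4-only mapping are made up): each entry pairs an availability check with an error-message template, and requires_backends formats and raises the messages for whichever backends are missing.

import importlib.util
from collections import OrderedDict

BS4_IMPORT_ERROR = "{0} requires the beautifulsoup4 library: `pip install beautifulsoup4`"

def is_bs4_available():
    return importlib.util.find_spec("bs4") is not None

BACKENDS_MAPPING = OrderedDict([("bs4", (is_bs4_available, BS4_IMPORT_ERROR))])

def requires_backends(obj, backends):
    # Collect the formatted message of every backend whose availability check fails.
    if not isinstance(backends, (list, tuple)):
        backends = [backends]
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    failed = [msg.format(name) for available, msg in (BACKENDS_MAPPING[b] for b in backends) if not available()]
    if failed:
        raise ImportError("".join(failed))

class ExamplePipeline:
    def __init__(self):
        requires_backends(self, ["bs4"])

try:
    ExamplePipeline()
    print("bs4 is installed")
except ImportError as err:
    print(err)  # "ExamplePipeline requires the beautifulsoup4 library: ..."
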
@@ -598,9 +648,10 @@ def requires_backends(obj, backends):
             " --upgrade transformers \n```"
         )

-    if name in ["StableDiffusionDepth2ImgPipeline", "StableDiffusionPix2PixZeroPipeline"] and is_transformers_version(
-        "<", "4.26.0"
-    ):
+    if name in [
+        "StableDiffusionDepth2ImgPipeline",
+        "StableDiffusionPix2PixZeroPipeline",
+    ] and is_transformers_version("<", "4.26.0"):
         raise ImportError(
             f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install"
             " --upgrade transformers \n```"
@@ -620,7 +671,9 @@ def __getattr__(cls, key):


 # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
-def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
+def compare_versions(
+    library_or_version: Union[str, Version], operation: str, requirement_version: str
+):
     """
     Compares a library version to some requirement using a given operation.

@@ -633,7 +686,9 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
         The version to compare the library version against
     """
     if operation not in STR_OPERATION_TO_FUNC.keys():
-        raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
+        raise ValueError(
+            f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}"
+        )
     operation = STR_OPERATION_TO_FUNC[operation]
     if isinstance(library_or_version, str):
         library_or_version = parse(importlib_metadata.version(library_or_version))
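
Usage note (assuming compare_versions is imported from diffusers.utils.import_utils, where this file lives, and that diffusers is installed): a parsed Version is compared directly, while a string is treated as an installed distribution name whose version is looked up first.

from packaging.version import parse
from diffusers.utils.import_utils import compare_versions

print(compare_versions(parse("4.27.0"), ">=", "4.26.0"))  # True: direct Version comparison
print(compare_versions("diffusers", ">", "0.0.1"))        # resolves the installed diffusers version first

try:
    compare_versions(parse("1.0.0"), "~=", "1.0")
except ValueError as err:
    print(err)  # "~=" is not in STR_OPERATION_TO_FUNC, so the valid operators are listed
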
@@ -837,15 +892,19 @@ class _LazyModule(ModuleType):

    # Very heavily inspired by optuna.integration._IntegrationModule
    # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
-    def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
+    def __init__(
+        self, name, module_file, import_structure, module_spec=None, extra_objects=None
+    ):
         super().__init__(name)
         self._modules = set(import_structure.keys())
         self._class_to_module = {}
         for key, values in import_structure.items():
             for value in values:
                 self._class_to_module[value] = key
         # Needed for autocompletion in an IDE
-        self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
+        self.__all__ = list(import_structure.keys()) + list(
+            chain(*import_structure.values())
+        )
         self.__file__ = module_file
         self.__spec__ = module_spec
         self.__path__ = [os.path.dirname(module_file)]
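
To show what __init__ builds (the import_structure below is illustrative; the real one lives in the diffusers package __init__): keys are submodule names, values are the public names each submodule exports, and __all__ is the flat union of both, which is exactly the reformatted expression above.

from itertools import chain

import_structure = {
    "pipelines": ["DiffusionPipeline"],
    "schedulers": ["DDIMScheduler", "DDPMScheduler"],
}

__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
print(__all__)  # ['pipelines', 'schedulers', 'DiffusionPipeline', 'DDIMScheduler', 'DDPMScheduler']
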
