@@ -37,7 +37,9 @@
 else:
     import importlib.metadata as importlib_metadata
 
 try:
-    _package_map = importlib_metadata.packages_distributions()  # load-once to avoid expensive calls
+    _package_map = (
+        importlib_metadata.packages_distributions()
+    )  # load-once to avoid expensive calls
 except Exception:
     _package_map = None
 
@@ -53,12 +55,23 @@
 DIFFUSERS_SLOW_IMPORT = os.environ.get("DIFFUSERS_SLOW_IMPORT", "FALSE").upper()
 DIFFUSERS_SLOW_IMPORT = DIFFUSERS_SLOW_IMPORT in ENV_VARS_TRUE_VALUES
 
-STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
-
-_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)
+STR_OPERATION_TO_FUNC = {
+    ">": op.gt,
+    ">=": op.ge,
+    "==": op.eq,
+    "!=": op.ne,
+    "<=": op.le,
+    "<": op.lt,
+}
+
+_is_google_colab = "google.colab" in sys.modules or any(
+    k.startswith("COLAB_") for k in os.environ
+)
 
 
-def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[bool, str]:
+def _is_package_available(
+    pkg_name: str, get_dist_name: bool = False
+) -> Tuple[bool, str]:
     global _package_map
     pkg_exists = importlib.util.find_spec(pkg_name) is not None
     pkg_version = "N/A"
@@ -69,11 +82,16 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
             try:
                 # Fallback for Python < 3.10
                 for dist in importlib_metadata.distributions():
-                    _top_level_declared = (dist.read_text("top_level.txt") or "").split()
+                    _top_level_declared = (
+                        dist.read_text("top_level.txt") or ""
+                    ).split()
                     _infered_opt_names = {
-                        f.parts[0] if len(f.parts) > 1 else inspect.getmodulename(f) for f in (dist.files or [])
+                        f.parts[0] if len(f.parts) > 1 else inspect.getmodulename(f)
+                        for f in (dist.files or [])
                     } - {None}
-                    _top_level_inferred = filter(lambda name: "." not in name, _infered_opt_names)
+                    _top_level_inferred = filter(
+                        lambda name: "." not in name, _infered_opt_names
+                    )
                     for pkg in _top_level_declared or _top_level_inferred:
                         _package_map[pkg].append(dist.metadata["Name"])
             except Exception as _:
@@ -99,16 +117,22 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
 else:
     logger.info("Disabling PyTorch because USE_TORCH is set")
     _torch_available = False
+    _torch_version = "N/A"
 
 _jax_version = "N/A"
 _flax_version = "N/A"
 if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
-    _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
+    _flax_available = (
+        importlib.util.find_spec("jax") is not None
+        and importlib.util.find_spec("flax") is not None
+    )
     if _flax_available:
         try:
             _jax_version = importlib_metadata.version("jax")
             _flax_version = importlib_metadata.version("flax")
-            logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
+            logger.info(
+                f"JAX version {_jax_version}, Flax version {_flax_version} available."
+            )
         except importlib_metadata.PackageNotFoundError:
             _flax_available = False
 else:
@@ -148,7 +172,9 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
             pass
     _onnx_available = _onnxruntime_version is not None
     if _onnx_available:
-        logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")
+        logger.debug(
+            f"Successfully imported onnxruntime version {_onnxruntime_version}"
+        )
 
 # (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed.
 # _opencv_available = importlib.util.find_spec("opencv-python") is not None
@@ -183,7 +209,9 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
 _invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None
 try:
     _invisible_watermark_version = importlib_metadata.version("invisible-watermark")
-    logger.debug(f"Successfully imported invisible-watermark version {_invisible_watermark_version}")
+    logger.debug(
+        f"Successfully imported invisible-watermark version {_invisible_watermark_version}"
+    )
 except importlib_metadata.PackageNotFoundError:
     _invisible_watermark_available = False
 
@@ -198,7 +226,9 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
 _wandb_available, _wandb_version = _is_package_available("wandb")
 _tensorboard_available, _tensorboard_version = _is_package_available("tensorboard")
 _compel_available, _compel_version = _is_package_available("compel")
-_sentencepiece_available, _sentencepiece_version = _is_package_available("sentencepiece")
+_sentencepiece_available, _sentencepiece_version = _is_package_available(
+    "sentencepiece"
+)
 _torchsde_available, _torchsde_version = _is_package_available("torchsde")
 _peft_available, _peft_version = _is_package_available("peft")
 _torchvision_available, _torchvision_version = _is_package_available("torchvision")
@@ -214,11 +244,19 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
 _gguf_available, _gguf_version = _is_package_available("gguf")
 _torchao_available, _torchao_version = _is_package_available("torchao")
 _bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
-_optimum_quanto_available, _optimum_quanto_version = _is_package_available("optimum", get_dist_name=True)
-_pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_available("pytorch_retinaface")
-_better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
+_optimum_quanto_available, _optimum_quanto_version = _is_package_available(
+    "optimum", get_dist_name=True
+)
+_pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_available(
+    "pytorch_retinaface"
+)
+_better_profanity_available, _better_profanity_version = _is_package_available(
+    "better_profanity"
+)
 _nltk_available, _nltk_version = _is_package_available("nltk")
-_cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
+_cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available(
+    "cosmos_guardrail"
+)
 
 
 def is_torch_available():
@@ -374,7 +412,10 @@ def is_cosmos_guardrail_available():
 
 
 def is_hpu_available():
-    return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))
+    return all(
+        importlib.util.find_spec(lib)
+        for lib in ("habana_frameworks", "habana_frameworks.torch")
+    )
 
 
 # docstyle-ignore
@@ -560,7 +601,10 @@ def is_hpu_available():
         ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)),
         ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
         ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)),
-        ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)),
+        (
+            "invisible_watermark",
+            (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR),
+        ),
         ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
         ("safetensors", (is_safetensors_available, SAFETENSORS_IMPORT_ERROR)),
         ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
@@ -569,8 +613,14 @@ def is_hpu_available():
         ("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)),
         ("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
         ("quanto", (is_optimum_quanto_available, QUANTO_IMPORT_ERROR)),
-        ("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)),
-        ("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)),
+        (
+            "pytorch_retinaface",
+            (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR),
+        ),
+        (
+            "better_profanity",
+            (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR),
+        ),
         ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
     ]
 )
@@ -598,9 +648,10 @@ def requires_backends(obj, backends):
             " --upgrade transformers \n```"
         )
 
-    if name in ["StableDiffusionDepth2ImgPipeline", "StableDiffusionPix2PixZeroPipeline"] and is_transformers_version(
-        "<", "4.26.0"
-    ):
+    if name in [
+        "StableDiffusionDepth2ImgPipeline",
+        "StableDiffusionPix2PixZeroPipeline",
+    ] and is_transformers_version("<", "4.26.0"):
         raise ImportError(
             f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install"
             " --upgrade transformers \n```"
@@ -620,7 +671,9 @@ def __getattr__(cls, key):
 
 
 # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
-def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
+def compare_versions(
+    library_or_version: Union[str, Version], operation: str, requirement_version: str
+):
     """
     Compares a library version to some requirement using a given operation.
 
@@ -633,7 +686,9 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, re
             The version to compare the library version against
     """
     if operation not in STR_OPERATION_TO_FUNC.keys():
-        raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
+        raise ValueError(
+            f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}"
+        )
     operation = STR_OPERATION_TO_FUNC[operation]
     if isinstance(library_or_version, str):
         library_or_version = parse(importlib_metadata.version(library_or_version))
@@ -837,15 +892,19 @@ class _LazyModule(ModuleType):
 
     # Very heavily inspired by optuna.integration._IntegrationModule
     # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
-    def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
+    def __init__(
+        self, name, module_file, import_structure, module_spec=None, extra_objects=None
+    ):
         super().__init__(name)
         self._modules = set(import_structure.keys())
         self._class_to_module = {}
         for key, values in import_structure.items():
             for value in values:
                 self._class_to_module[value] = key
         # Needed for autocompletion in an IDE
-        self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
+        self.__all__ = list(import_structure.keys()) + list(
+            chain(*import_structure.values())
+        )
         self.__file__ = module_file
         self.__spec__ = module_spec
         self.__path__ = [os.path.dirname(module_file)]
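For readers skimming the diff, the version-comparison helpers being re-wrapped above are small enough to reproduce in isolation. Below is a minimal, standalone sketch that mirrors `STR_OPERATION_TO_FUNC` and `compare_versions` as they appear in this file; the `numpy` distribution name and the requirement strings in the demo at the bottom are illustrative placeholders only, not anything this commit touches.

```python
# Standalone sketch of the operator-table version check used in import_utils.py.
# Assumes `packaging` is installed and, for the demo, that `numpy` is available.
import operator as op
import importlib.metadata as importlib_metadata

from packaging.version import Version, parse

STR_OPERATION_TO_FUNC = {
    ">": op.gt,
    ">=": op.ge,
    "==": op.eq,
    "!=": op.ne,
    "<=": op.le,
    "<": op.lt,
}


def compare_versions(library_or_version, operation: str, requirement_version: str) -> bool:
    """Compare an installed library (by name) or an already-parsed Version against a requirement."""
    if operation not in STR_OPERATION_TO_FUNC:
        raise ValueError(
            f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}"
        )
    compare = STR_OPERATION_TO_FUNC[operation]
    if isinstance(library_or_version, str):
        # Resolve the installed distribution's version string, e.g. "1.26.4", and parse it.
        library_or_version = parse(importlib_metadata.version(library_or_version))
    return compare(library_or_version, parse(requirement_version))


if __name__ == "__main__":
    # Hypothetical usage: gate a code path on an installed package's version.
    print(compare_versions("numpy", ">=", "1.20.0"))
    print(compare_versions(Version("4.25.1"), "<", "4.26.0"))
```

The `is_transformers_version("<", "4.26.0")` check in `requires_backends` above follows the same pattern: resolve the installed version once through `importlib.metadata`, parse it with `packaging`, and dispatch the comparison through the operator table.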