Skip to content

Commit

Permalink
Add argument library_name when parameters standartization (#2179)
Browse files Browse the repository at this point in the history
* avoid library_name guessing if it is known in parameters standartization

* Update optimum/subpackages.py

---------

Co-authored-by: Ilyas Moutawwakil <[email protected]>
  • Loading branch information
eaidova and IlyasMoutawwakil authored Feb 12, 2025
1 parent 856b252 commit 512d5c6
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
11 changes: 8 additions & 3 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2067,7 +2067,11 @@ def infer_library_from_model(
return library_name

@classmethod
def standardize_model_attributes(cls, model: Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]):
def standardize_model_attributes(
cls,
model: Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"],
library_name: Optional[str] = None,
):
"""
Updates the model for export. This function is suitable to make required changes to the models from different
libraries to follow transformers style.
Expand All @@ -2078,7 +2082,8 @@ def standardize_model_attributes(cls, model: Union["PreTrainedModel", "TFPreTrai
"""

library_name = TasksManager.infer_library_from_model(model)
if library_name is None:
library_name = TasksManager.infer_library_from_model(model)

if library_name == "diffusers":
inferred_model_type = None
Expand Down Expand Up @@ -2295,7 +2300,7 @@ def get_model_from_task(
kwargs["from_pt"] = True
model = model_class.from_pretrained(model_name_or_path, **kwargs)

TasksManager.standardize_model_attributes(model)
TasksManager.standardize_model_attributes(model, library_name=library_name)

return model

Expand Down
4 changes: 3 additions & 1 deletion optimum/subpackages.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ def load_namespace_modules(namespace: str, module: str):
"""
for dist in importlib_metadata.distributions():
dist_name = dist.metadata["Name"]
if not dist_name.startswith(f"{namespace}-"):
if dist_name is None:
continue
if dist_name == f"{namespace}-benchmark":
continue
if not dist_name.startswith(f"{namespace}-"):
continue
package_import_name = dist_name.replace("-", ".")
module_import_name = f"{package_import_name}.{module}"
if module_import_name in sys.modules:
Expand Down

0 comments on commit 512d5c6

Please sign in to comment.