Skip to content

Commit 80f6367

Browse files
Add helper function to clean up artifacts
1 parent 5a21d66 commit 80f6367

File tree

2 files changed

+55
-9
lines changed

2 files changed

+55
-9
lines changed

ads/aqua/common/utils.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
"""AQUA utils and constants."""
55

@@ -11,6 +11,7 @@
1111
import random
1212
import re
1313
import shlex
14+
import shutil
1415
import subprocess
1516
from datetime import datetime, timedelta
1617
from functools import wraps
@@ -21,6 +22,8 @@
2122
import fsspec
2223
import oci
2324
from cachetools import TTLCache, cached
25+
from huggingface_hub.constants import HF_HUB_CACHE
26+
from huggingface_hub.file_download import repo_folder_name
2427
from huggingface_hub.hf_api import HfApi, ModelInfo
2528
from huggingface_hub.utils import (
2629
GatedRepoError,
@@ -788,7 +791,9 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
788791
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
789792

790793

791-
def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None) -> str:
794+
def upload_folder(
795+
os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None
796+
) -> str:
792797
"""Upload the local folder to the object storage
793798
794799
Args:
@@ -818,6 +823,38 @@ def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern
818823
return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path
819824

820825

826+
def _remove_empty_parents(start_dir: str, stop_dir: str) -> None:
    """Remove empty directories from ``start_dir`` upwards, stopping before ``stop_dir``.

    Only directories located strictly inside ``stop_dir`` are considered, and a
    directory is removed only when it is empty. The walk ends at the first
    directory that cannot be removed.
    """
    stop_dir = os.path.abspath(stop_dir)
    current = os.path.abspath(start_dir)
    # The prefix check guarantees we never touch ``stop_dir`` itself or anything
    # outside of it, even if the model name contains unexpected path segments.
    while current.startswith(stop_dir + os.sep):
        try:
            # os.rmdir refuses to delete non-empty directories, which ends the walk.
            os.rmdir(current)
        except OSError:
            break
        logger.debug(f"Deleted empty local directory: {current}")
        current = os.path.dirname(current)


def cleanup_local_hf_model_artifact(
    model_name: str,
    local_dir: str = None,
) -> None:
    """
    Helper function that deletes local artifacts downloaded from Hugging Face to free up disk space.

    Two locations are cleaned up:

    1. ``<local_dir>/<model_name>`` when ``local_dir`` is given and is a directory.
       Intermediate directories left empty by the removal (for example the
       ``org`` part of an ``org/model`` name) are removed as well, and
       ``local_dir`` itself is removed if it ends up empty.
    2. The model's folder under the Hugging Face hub cache (``HF_HUB_CACHE``).

    Paths that do not exist, or that are not directories, are skipped.

    Parameters
    ----------
    model_name: str
        Name of the Hugging Face model (repository id).
    local_dir: str, optional
        Local directory under which the model was downloaded. Defaults to None,
        in which case only the Hugging Face cache folder is cleaned up.

    Returns
    -------
    None
    """
    # isdir (rather than exists) so that a regular file at either path is
    # skipped instead of making rmtree/listdir raise.
    if local_dir and os.path.isdir(local_dir):
        model_dir = os.path.join(local_dir, model_name)
        if os.path.isdir(model_dir):
            shutil.rmtree(model_dir)
            logger.debug(f"Deleted local model artifact directory: {model_dir}")
            # A name such as "org/model" leaves an empty "org" directory behind;
            # prune it so the emptiness check on local_dir below can succeed.
            _remove_empty_parents(
                os.path.dirname(os.path.abspath(model_dir)), local_dir
            )

        if not os.listdir(local_dir):
            shutil.rmtree(local_dir)
            logger.debug(f"Deleted local directory {local_dir} as it is empty.")

    hf_local_path = os.path.join(
        HF_HUB_CACHE, repo_folder_name(repo_id=model_name, repo_type="model")
    )
    if os.path.isdir(hf_local_path):
        shutil.rmtree(hf_local_path)
        logger.debug(
            f"Deleted local Hugging Face cache directory {hf_local_path} for the model {model_name} "
        )
856+
857+
821858
def is_service_managed_container(container):
    """Check whether ``container`` starts with the service managed container URI scheme.

    A falsy ``container`` (e.g. ``None`` or an empty string) is returned
    unchanged; otherwise the result of the ``startswith`` check against
    ``SERVICE_MANAGED_CONTAINER_URI_SCHEME`` is returned.
    """
    if not container:
        return container
    return container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
823860

ads/aqua/model/model.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
import os
55
import pathlib
@@ -23,6 +23,7 @@
2323
from ads.aqua.common.utils import (
2424
LifecycleStatus,
2525
_build_resource_identifier,
26+
cleanup_local_hf_model_artifact,
2627
copy_model_config,
2728
create_word_icon,
2829
generate_tei_cmd_var,
@@ -1322,20 +1323,18 @@ def _download_model_from_hf(
13221323
Returns
13231324
-------
13241325
model_artifact_path (str): Location where the model artifacts are downloaded.
1325-
13261326
"""
13271327
# Download the model from hub
1328-
if not local_dir:
1329-
local_dir = os.path.join(os.path.expanduser("~"), "cached-model")
1330-
local_dir = os.path.join(local_dir, model_name)
1331-
os.makedirs(local_dir, exist_ok=True)
1328+
if local_dir:
1329+
local_dir = os.path.join(local_dir, model_name)
1330+
os.makedirs(local_dir, exist_ok=True)
13321331
snapshot_download(
13331332
repo_id=model_name,
13341333
local_dir=local_dir,
13351334
allow_patterns=allow_patterns,
13361335
ignore_patterns=ignore_patterns,
13371336
)
1338-
# Upload to object storage and skip .cache/huggingface/ folder
1337+
# Upload to object storage
13391338
model_artifact_path = upload_folder(
13401339
os_path=os_path,
13411340
local_dir=local_dir,
@@ -1365,6 +1364,8 @@ def register(
13651364
ignore_patterns (list): Model files matching any of the patterns are not downloaded.
13661365
Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`.
13671366
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
1367+
delete_from_local (bool): Deletes downloaded files from local machine after model is successfully
1368+
registered. Set to True by default.
13681369
13691370
Returns:
13701371
AquaModel:
@@ -1474,6 +1475,14 @@ def register(
14741475
detail=validation_result.telemetry_model_name,
14751476
)
14761477

1478+
if (
1479+
import_model_details.download_from_hf
1480+
and import_model_details.delete_from_local
1481+
):
1482+
cleanup_local_hf_model_artifact(
1483+
model_name=model_name, local_dir=import_model_details.local_dir
1484+
)
1485+
14771486
return AquaModel(**aqua_model_attributes)
14781487

14791488
def _if_show(self, model: DataScienceModel) -> bool:

0 commit comments

Comments
 (0)