optimizations to BlobStorageEngine, ExpertLibrary #156

Merged: 4 commits, Jan 23, 2025
146 changes: 87 additions & 59 deletions mttl/models/library/backend_engine.py
@@ -137,6 +137,7 @@ def __init__(self, token: Optional[str] = None, cache_dir: Optional[str] = None)
self.cache_dir = cache_dir
# Quiet down the azure logging
logging.getLogger("azure").setLevel(logging.WARNING)
self.last_modified_cache = None

@property
def cache_dir(self):
@@ -200,16 +201,27 @@ def _get_container_client(self, repo_id, use_async=False):
storage_uri, container = self._parse_repo_id_to_storage_info(repo_id)
return self._get_blob_client(repo_id, use_async).get_container_client(container)

def _last_modified(self, repo_id: str) -> datetime.datetime:
def _last_modified(
self, repo_id: str, set_cache: bool = False
) -> datetime.datetime:
"""Get the last modified date of a repository."""
try:
return (
self._get_container_client(repo_id)
.get_container_properties()
.last_modified
)
except ResourceNotFoundError as error:
raise ValueError(f"Repository {repo_id} not found") from error

# If a cached value exists, return it; we want to avoid repeated calls to the API
if self.last_modified_cache:
return self.last_modified_cache

else:
try:
last_modified = (
self._get_container_client(repo_id)
.get_container_properties()
.last_modified
)
if set_cache:
self.last_modified_cache = last_modified
return last_modified
except ResourceNotFoundError as error:
raise ValueError(f"Repository {repo_id} not found") from error

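For context, the new `set_cache` flag turns `_last_modified` into an optionally memoized call: a batch operation resolves the container's last-modified timestamp once, reuses it for every blob in the batch, and resets the cache when the batch completes. A minimal standalone sketch of the same pattern (the `fetch_remote_timestamp` helper is hypothetical, not part of this PR):

import datetime
from typing import Optional

class TimestampCache:
    """Sketch: memoize one expensive remote call for the duration of a batch."""

    def __init__(self):
        self._cache: Optional[datetime.datetime] = None

    def last_modified(self, repo_id: str, set_cache: bool = False) -> datetime.datetime:
        if self._cache is not None:
            return self._cache  # avoid one round-trip per blob during a batch
        value = self.fetch_remote_timestamp(repo_id)  # hypothetical remote call
        if set_cache:
            self._cache = value
        return value

    def reset(self) -> None:
        self._cache = None  # call once the batch completes

    def fetch_remote_timestamp(self, repo_id: str) -> datetime.datetime:
        # stand-in for get_container_properties().last_modified
        return datetime.datetime.now(datetime.timezone.utc)
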
def get_repository_cache_dir(self, repo_id: str) -> Path:
"""Get the cache directory for a repository. If it doesn't exist, create it.
@@ -280,23 +292,17 @@ def delete_repo(self, repo_id, repo_type=None):
except ResourceNotFoundError:
logger.info(f"Container {repo_id} not found.")

def create_commit(self, repo_id, operations, commit_message="", async_mode=False):
def create_commit(self, repo_id, operations, commit_message="", async_mode=True):
asyncio.run(
self.async_create_commit(repo_id, operations, async_mode=async_mode)
)

async def async_create_commit(self, repo_id, operations, async_mode=False):
tasks = []
upload_batch = []
for op in operations:
if isinstance(op, CommitOperationAdd):
tasks.append(
self._async_upload_blob(
repo_id=repo_id,
filename=op.path_in_repo,
buffer=op.path_or_fileobj,
overwrite=True,
)
)
upload_batch.append(op)
elif isinstance(op, CommitOperationCopy):
tasks.append(
self._async_copy_blob(
@@ -314,11 +320,13 @@ async def async_create_commit(self, repo_id, operations, async_mode=False):
filename=op.path_in_repo,
)
)
if async_mode:
await asyncio.gather(*tasks)
else:
for task in tasks:
await task

# upload the added blobs in one async batch
await self.async_upload_blobs(
repo_id,
filenames=[op.path_in_repo for op in upload_batch],
buffers=[op.path_or_fileobj for op in upload_batch],
)

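The restructuring above separates the two kinds of work: copy and delete operations still become individual tasks, while every `CommitOperationAdd` is collected into `upload_batch` and handed to `async_upload_blobs` as a single batch. A simplified sketch of that dispatch (the operation shapes are stand-ins for the CommitOperation* classes used in the PR):

import asyncio

async def process_operations(operations):
    """Sketch: gather copy/delete tasks, but batch all uploads together."""
    tasks, upload_batch = [], []
    for kind, payload in operations:
        if kind == "add":                    # CommitOperationAdd in the PR
            upload_batch.append(payload)     # deferred: uploaded in one batch
        else:                                # copy/delete keep per-op tasks
            tasks.append(asyncio.sleep(0))   # stand-in for _async_copy_blob(...)
    await asyncio.gather(*tasks)
    await upload_batch_together(upload_batch)

async def upload_batch_together(paths):
    # stand-in for async_upload_blobs: one shared client, one gather
    await asyncio.gather(*(asyncio.sleep(0) for _ in paths))
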
def preupload_lfs_files(self, repo_id, additions):
# for blob storage, these operations are done in create_commit
@@ -390,62 +398,82 @@ async def async_upload_blobs(
else:
if len(buffers) != len(filenames):
raise ValueError("Filenames and buffers must have the same length.")
tasks = [
self._async_upload_blob(repo_id, filename, buffer, overwrite)
for filename, buffer in zip(filenames, buffers)
]
await asyncio.gather(*tasks)
return filenames[0] if is_str else filenames

async def _async_upload_blob(self, repo_id, filename, buffer=None, overwrite=True):
storage_uri, container = self._parse_repo_id_to_storage_info(repo_id)
self._last_modified(repo_id, set_cache=True) # set the cache for last_modified

async with self._get_blob_client(
repo_id, use_async=True
) as blob_service_client:
blob_client = blob_service_client.get_blob_client(
container=container, blob=filename
)
if buffer is not None:
await blob_client.upload_blob(buffer, overwrite=overwrite)
else:
local_cache = self._get_local_filepath(repo_id, filename)
with open(file=local_cache, mode="rb") as blob_file:
await blob_client.upload_blob(blob_file, overwrite=overwrite)
tasks = [
self._async_upload_blob(
blob_service_client, repo_id, filename, buffer, overwrite
)
for filename, buffer in zip(filenames, buffers)
]
await asyncio.gather(*tasks)

self.last_modified_cache = None # reset the cache

return filenames[0] if is_str else filenames

async def _async_upload_blob(
self, blob_service_client, repo_id, filename, buffer=None, overwrite=True
):
storage_uri, container = self._parse_repo_id_to_storage_info(repo_id)

blob_client = blob_service_client.get_blob_client(
container=container, blob=filename
)

if buffer is not None:
await blob_client.upload_blob(buffer, overwrite=overwrite)

else:
local_cache = self._get_local_filepath(repo_id, filename)

with open(file=local_cache, mode="rb") as blob_file:
await blob_client.upload_blob(blob_file, overwrite=overwrite)

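The key optimization in `async_upload_blobs` is client reuse: the `BlobServiceClient` is now opened once by the caller and passed into each `_async_upload_blob`, instead of every upload opening and tearing down its own client inside the gather. With the real azure-storage-blob aio API the shape is roughly as follows (the account URL is a placeholder and credentials are omitted):

import asyncio
from azure.storage.blob.aio import BlobServiceClient

ACCOUNT_URL = "https://<account>.blob.core.windows.net"  # placeholder

async def upload_many(container, files):
    # one client (one connection pool) shared by every upload task
    async with BlobServiceClient(account_url=ACCOUNT_URL) as service:

        async def upload_one(name, data):
            blob = service.get_blob_client(container=container, blob=name)
            await blob.upload_blob(data, overwrite=True)

        await asyncio.gather(*(upload_one(n, d) for n, d in files.items()))
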
async def async_download_blobs(
self, repo_id: str, filenames: Union[List[str], str]
) -> Union[str, List[str]]:
is_str = isinstance(filenames, str)
if is_str:
filenames = [filenames]
tasks = [
self._async_download_blob(repo_id, filename) for filename in filenames
]
local_filenames = await asyncio.gather(*tasks)
return local_filenames[0] if is_str else local_filenames

async def _async_download_blob(self, repo_id, filename):
storage_uri, container = self._parse_repo_id_to_storage_info(repo_id)
self._last_modified(repo_id, set_cache=True) # set the cache for last_modified

async with self._get_blob_client(
repo_id, use_async=True
) as blob_service_client:
# already cached!
local_filename = self._get_local_filepath(repo_id, filename)
if local_filename.exists():
return local_filename
tasks = [
self._async_download_blob(blob_service_client, repo_id, filename)
for filename in filenames
]
local_filenames = await asyncio.gather(*tasks)

blob_client = blob_service_client.get_blob_client(
container=container, blob=filename
)
self.last_modified_cache = None # reset the cache

return local_filenames[0] if is_str else local_filenames

os.makedirs(os.path.dirname(local_filename), exist_ok=True)
with open(file=local_filename, mode="wb") as blob_file:
download_stream = await blob_client.download_blob()
data = await download_stream.readall()
blob_file.write(data)
async def _async_download_blob(self, blob_service_client, repo_id, filename):
# already cached!
local_filename = self._get_local_filepath(repo_id, filename)
if local_filename.exists():
return local_filename

storage_uri, container = self._parse_repo_id_to_storage_info(repo_id)
blob_client = blob_service_client.get_blob_client(
container=container, blob=filename
)

os.makedirs(os.path.dirname(local_filename), exist_ok=True)
with open(file=local_filename, mode="wb") as blob_file:
download_stream = await blob_client.download_blob()
data = await download_stream.readall()
blob_file.write(data)
return local_filename

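Note that the built-in `open` is synchronous, so the downloaded bytes are written inside a plain `with` block even though the surrounding function is async (an async file library such as aiofiles would be needed for truly non-blocking writes). The cache short-circuit plus write pattern, as a standalone sketch that accepts any aio blob client:

from pathlib import Path

async def download_if_missing(blob_client, local_path: Path) -> Path:
    """Sketch of the short-circuit in _async_download_blob."""
    if local_path.exists():
        return local_path  # already cached locally: skip the network entirely

    local_path.parent.mkdir(parents=True, exist_ok=True)
    stream = await blob_client.download_blob()
    data = await stream.readall()
    with open(local_path, "wb") as fh:  # open() is sync: plain `with`, not `async with`
        fh.write(data)
    return local_path
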
async def async_copy_blobs(
self,
source_repo_ids,
87 changes: 61 additions & 26 deletions mttl/models/library/expert_library.py
@@ -1,3 +1,4 @@
import asyncio
import glob
import io
import os
@@ -141,31 +142,49 @@ def _build_lib(self):
logger.error("Repository not found: %s", self.repo_id)
raise e

# Function to download and process a single .meta file
def download_and_process_meta_file(file):
path_or_bytes = self.hf_hub_download(self.repo_id, file)

metadata_entry = MetadataEntry.fromdict(
torch.load(path_or_bytes, map_location="cpu", weights_only=False)
if isinstance(self, BlobExpertLibrary):
local_filenames = asyncio.run(
self.async_download_blobs(
self.repo_id,
meta_files,
)
)
return metadata_entry

# Use ThreadPoolExecutor for multithreading
with ThreadPoolExecutor() as executor:
# Submit tasks to the executor
future_to_file = {
executor.submit(download_and_process_meta_file, file): file
for file in meta_files
}

# process every meta file in new local directory
metadata = []
for future in as_completed(future_to_file):
file = future_to_file[future]
try:
data = future.result()
metadata.append(data)
except Exception as exc:
logger.error("%r generated an exception: %s", file, exc)
for file in local_filenames:
metadata_entry = MetadataEntry.fromdict(
torch.load(file, map_location="cpu", weights_only=False)
)
metadata.append(metadata_entry)

else:

# Function to download and process a single .meta file
def download_and_process_meta_file(file):
path_or_bytes = self.hf_hub_download(self.repo_id, file)

metadata_entry = MetadataEntry.fromdict(
torch.load(path_or_bytes, map_location="cpu", weights_only=False)
)
return metadata_entry

# Use ThreadPoolExecutor for multithreading
with ThreadPoolExecutor() as executor:
# Submit tasks to the executor
future_to_file = {
executor.submit(download_and_process_meta_file, file): file
for file in meta_files
}

metadata = []
for future in as_completed(future_to_file):
file = future_to_file[future]
try:
data = future.result()
metadata.append(data)
except Exception as exc:
logger.error("%r generated an exception: %s", file, exc)

for metadatum in metadata:
if self.model_name is not None and metadatum.model != self.model_name:
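
For `BlobExpertLibrary`, metadata loading thus becomes two cheap phases: one batched async download of every .meta file, followed by purely local deserialization, while other backends keep the thread-pool download path. A simplified sketch of the blob branch (`library` is assumed to expose `async_download_blobs` and `repo_id` as in this diff):

import asyncio
from typing import List

import torch

def load_meta_batch(library, meta_files: List[str]) -> list:
    """Download all .meta blobs in one event loop, then deserialize locally."""
    local_files = asyncio.run(
        library.async_download_blobs(library.repo_id, meta_files)
    )
    return [
        torch.load(f, map_location="cpu", weights_only=False) for f in local_files
    ]
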
@@ -282,7 +301,11 @@ def __len__(self):
return len(self.data)

def add_expert(
self, expert_dump: Expert, expert_name: str = None, force: bool = False
self,
expert_dump: Expert,
expert_name: str = None,
force: bool = False,
update_readme: bool = True,
):
if self.sliced:
raise ValueError("Cannot add expert to sliced library.")
@@ -307,7 +330,9 @@ def add_expert(
self._upload_weights(metadata.expert_name, expert_dump)
self._upload_metadata(metadata)
self.data[metadata.expert_name] = metadata
self._update_readme()
# Only update the readme if requested; useful when adding multiple experts in a batch
if update_readme:
self._update_readme()

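With the new `update_readme` flag, callers that add many experts can skip the per-expert readme rewrite and regenerate it once at the end, which is exactly what `clone` does in a later hunk. A hedged usage sketch (assuming `library` is an ExpertLibrary and `experts` maps names to Expert objects):

with library.batched_commit():
    for name, expert in experts.items():
        # skip the readme rewrite on every add...
        library.add_expert(expert, name, update_readme=False)

# ...and rebuild it once after the whole batch
library._update_readme()
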
def list_auxiliary_data(self) -> Dict[str, Tuple[int, str]]:
"""List auxiliary data in the library, returns a dictionary with the data type, the number of records, and a string representation of the config file."""
@@ -770,9 +795,15 @@ def clone(

only_tasks = only_tasks or self.tasks
with new_lib.batched_commit():
update_readme = False
for name, expert in self.items():
if expert.name not in new_lib:
new_lib.add_expert(expert, name, force=force)
new_lib.add_expert(expert, name, force=force, update_readme=False)
update_readme = True

# only update readme if we added new experts
if update_readme:
new_lib._update_readme()

# if the new_lib already exists, delete experts that
# are in this lib but were deleted from the expert_lib
@@ -929,7 +960,11 @@ class LocalExpertLibrary(ExpertLibrary, LocalFSEngine):
"""A local library stored on disk."""

def add_expert(
self, expert_dump: Expert, expert_name: str = None, force: bool = False
self,
expert_dump: Expert,
expert_name: str = None,
force: bool = False,
update_readme: bool = True,
):
expert_name = expert_name or expert_dump.expert_info.expert_name
if "/" in expert_name: