Skip to content

Commit

Permalink
feat: fal compressed file
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Rochetti <[email protected]>
  • Loading branch information
badayvedat and drochetti committed Dec 26, 2023
1 parent 41e507a commit 6347826
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 3 deletions.
2 changes: 1 addition & 1 deletion projects/fal/src/fal/toolkit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from fal.toolkit.file import File
from fal.toolkit.file import CompressedFile, File
from fal.toolkit.image.image import Image, ImageSizeInput, get_image_size
from fal.toolkit.mainify import mainify
from fal.toolkit.utils import (
Expand Down
34 changes: 34 additions & 0 deletions projects/fal/src/fal/toolkit/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from typing import Any, Callable
from urllib.parse import urlparse
from zipfile import ZipFile

from fal.toolkit.file.providers.fal import FalFileRepository, InMemoryRepository
from fal.toolkit.file.providers.gcp import GoogleStorageRepository
Expand All @@ -13,6 +14,7 @@
from fal.toolkit.mainify import mainify
from pydantic import BaseModel, Field, PrivateAttr
from pydantic.typing import Optional
from tempfile import TemporaryDirectory

FileRepositoryFactory = Callable[[], FileRepository]

Expand Down Expand Up @@ -152,3 +154,35 @@ def as_bytes(self) -> bytes:
def save(self, path: str | Path):
file_path = Path(path)
file_path.write_bytes(self.as_bytes())


@mainify
class CompressedFile(File):
_extract_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)

def __init__(self, **kwargs):
super().__init__(**kwargs)
self._extract_dir = None

def __iter__(self):
if not self._extract_dir:
self._extract_files(self.as_bytes())

files = Path(self._extract_dir.name).iterdir() # type: ignore
return iter(files)

def _extract_files(self, file_bytes: bytes):
self._extract_dir = TemporaryDirectory()

with ZipFile(BytesIO(file_bytes)) as zip_file:
zip_file.extractall(self._extract_dir.name)

def glob(self, pattern: str):
if not self._extract_dir:
self._extract_files(self.as_bytes())

return Path(self._extract_dir.name).glob(pattern) # type: ignore

def __del__(self):
if self._extract_dir:
self._extract_dir.cleanup()
28 changes: 26 additions & 2 deletions projects/fal/tests/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
from pathlib import Path
import tempfile
from uuid import uuid4
from fal.toolkit import File
import fal
import pytest
from fal import FalServerlessHost, FalServerlessKeyCredentials, local, sync_dir
from fal.api import FalServerlessError
from fal.toolkit import clone_repository, download_file, download_model_weights
from fal.toolkit import (
clone_repository,
download_file,
download_model_weights,
CompressedFile,
File,
)
from fal.toolkit.utils.download_utils import _get_git_revision_hash, _hash_url
from pydantic import BaseModel


def test_isolated(isolated_client):
Expand Down Expand Up @@ -523,3 +529,21 @@ def init_file_on_fal(input: TestInput) -> File:
# File will be downloaded when content is accessed
assert fal_file_content_matches(file, expected_content)
assert fal_file_downloaded(file)


def test_fal_compressed_file(isolated_client):
class TestInput(BaseModel):
files: CompressedFile

@isolated_client(requirements=["pydantic==1.10.12"])
def init_compressed_file_on_fal(input: TestInput) -> int:
extracted_file_paths = [file for file in input.files]
return extracted_file_paths

archive_url = "https://storage.googleapis.com/falserverless/sdk_tests/compressed_file_test.zip"
test_input = TestInput(files=archive_url)

extracted_file_paths = init_compressed_file_on_fal(test_input)

assert all(isinstance(file, Path) for file in extracted_file_paths)
assert len(extracted_file_paths) == 3
17 changes: 17 additions & 0 deletions projects/fal/tests/toolkit/file_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

import os
from base64 import b64encode
from pathlib import Path

import pytest
from fal.toolkit.file.file import File, GoogleStorageRepository
from fal.toolkit import CompressedFile
from pydantic import BaseModel


def test_binary_content_matches():
Expand Down Expand Up @@ -62,3 +65,17 @@ def test_gcp_storage_if_available():
assert file.url.startswith(
"https://storage.googleapis.com/fal_registry_image_results/"
)


def test_compressed_file():
class TestInput(BaseModel):
files: CompressedFile

archive_url = "https://storage.googleapis.com/falserverless/sdk_tests/compressed_file_test.zip"

test_input = TestInput(files=archive_url)

extracted_file_paths = [file for file in test_input.files]

assert all(isinstance(file, Path) for file in extracted_file_paths)
assert len(extracted_file_paths) == 3

0 comments on commit 6347826

Please sign in to comment.