diff --git a/projects/fal/src/fal/toolkit/__init__.py b/projects/fal/src/fal/toolkit/__init__.py index deb0a4dc..863794e1 100644 --- a/projects/fal/src/fal/toolkit/__init__.py +++ b/projects/fal/src/fal/toolkit/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -from fal.toolkit.file.file import File +from fal.toolkit.file import CompressedFile, File from fal.toolkit.image.image import Image, ImageSizeInput, get_image_size from fal.toolkit.mainify import mainify from fal.toolkit.utils import ( diff --git a/projects/fal/src/fal/toolkit/file/file.py b/projects/fal/src/fal/toolkit/file/file.py index 70be1fef..2b673578 100644 --- a/projects/fal/src/fal/toolkit/file/file.py +++ b/projects/fal/src/fal/toolkit/file/file.py @@ -12,7 +12,8 @@ from fal.toolkit.utils.download_utils import download_file from pydantic import BaseModel, Field, PrivateAttr from pydantic.typing import Optional - +from tempfile import NamedTemporaryFile, TemporaryDirectory +from zipfile import ZipFile FileRepositoryFactory = Callable[[], FileRepository] @@ -149,3 +150,39 @@ def save(self, path: str | Path, overwrite: bool = False) -> Path: downloaded_path.rename(file_path) return file_path + + +@mainify +class CompressedFile(File): + _extract_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._extract_dir = None + + def __iter__(self): + if not self._extract_dir: + self._extract_files() + + files = Path(self._extract_dir.name).iterdir() # type: ignore + return iter(files) + + def _extract_files(self): + self._extract_dir = TemporaryDirectory() + + with NamedTemporaryFile() as temp_file: + file_path = temp_file.name + self.save(file_path, overwrite=True) + + with ZipFile(file_path) as zip_file: + zip_file.extractall(self._extract_dir.name) + + def glob(self, pattern: str): + if not self._extract_dir: + self._extract_files() + + return Path(self._extract_dir.name).glob(pattern) # type: ignore + + def __del__(self): + if self._extract_dir: + self._extract_dir.cleanup() diff --git a/projects/fal/tests/integration_test.py b/projects/fal/tests/integration_test.py index 05307094..0e547c70 100644 --- a/projects/fal/tests/integration_test.py +++ b/projects/fal/tests/integration_test.py @@ -4,6 +4,8 @@ from uuid import uuid4 import fal +from fal.toolkit.file.file import CompressedFile +from pydantic import BaseModel, Field import pytest from fal import FalServerlessHost, FalServerlessKeyCredentials, local, sync_dir from fal.api import FalServerlessError @@ -506,8 +508,6 @@ def fal_file_to_local_file(content: str): ], ) def test_fal_file_input(isolated_client, file_url: str, expected_content: str): - from pydantic import BaseModel, Field - class TestInput(BaseModel): file: File = Field() @@ -526,3 +526,21 @@ def init_file_on_fal(input: TestInput) -> File: # Expect value error if we try to access the file content for input file with pytest.raises(ValueError): fal_file_content_matches(file, expected_content) + + +def test_fal_compressed_file(isolated_client): + class TestInput(BaseModel): + files: CompressedFile + + @isolated_client(requirements=["pydantic==1.10.12"]) + def init_compressed_file_on_fal(input: TestInput) -> int: + extracted_file_paths = [file for file in input.files] + return extracted_file_paths + + archive_url = "https://storage.googleapis.com/falserverless/sdk_tests/compressed_file_test.zip" + test_input = TestInput(files=archive_url) + + extracted_file_paths = init_compressed_file_on_fal(test_input) + + assert all(isinstance(file, Path) for file in extracted_file_paths) + assert len(extracted_file_paths) == 3