Skip to content

Commit

Permalink
feat(toolkit): fal compressed file (#60)
Browse files Browse the repository at this point in the history
* feat(toolkit): fal compressed file

* fix: type check
  • Loading branch information
badayvedat authored Jan 26, 2024
1 parent dead59e commit f064e50
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 4 deletions.
2 changes: 1 addition & 1 deletion projects/fal/src/fal/toolkit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from fal.toolkit.file.file import File
from fal.toolkit.file import CompressedFile, File
from fal.toolkit.image.image import Image, ImageSizeInput, get_image_size
from fal.toolkit.mainify import mainify
from fal.toolkit.utils import (
Expand Down
39 changes: 38 additions & 1 deletion projects/fal/src/fal/toolkit/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from fal.toolkit.utils.download_utils import download_file
from pydantic import BaseModel, Field, PrivateAttr
from pydantic.typing import Optional

from tempfile import NamedTemporaryFile, TemporaryDirectory
from zipfile import ZipFile

FileRepositoryFactory = Callable[[], FileRepository]

Expand Down Expand Up @@ -149,3 +150,39 @@ def save(self, path: str | Path, overwrite: bool = False) -> Path:
downloaded_path.rename(file_path)

return file_path


@mainify
class CompressedFile(File):
_extract_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)

def __init__(self, **kwargs):
super().__init__(**kwargs)
self._extract_dir = None

def __iter__(self):
if not self._extract_dir:
self._extract_files()

files = Path(self._extract_dir.name).iterdir() # type: ignore
return iter(files)

def _extract_files(self):
self._extract_dir = TemporaryDirectory()

with NamedTemporaryFile() as temp_file:
file_path = temp_file.name
self.save(file_path, overwrite=True)

with ZipFile(file_path) as zip_file:
zip_file.extractall(self._extract_dir.name)

def glob(self, pattern: str):
if not self._extract_dir:
self._extract_files()

return Path(self._extract_dir.name).glob(pattern) # type: ignore

def __del__(self):
if self._extract_dir:
self._extract_dir.cleanup()
22 changes: 20 additions & 2 deletions projects/fal/tests/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from uuid import uuid4

import fal
from fal.toolkit.file.file import CompressedFile
from pydantic import BaseModel, Field
import pytest
from fal import FalServerlessHost, FalServerlessKeyCredentials, local, sync_dir
from fal.api import FalServerlessError
Expand Down Expand Up @@ -506,8 +508,6 @@ def fal_file_to_local_file(content: str):
],
)
def test_fal_file_input(isolated_client, file_url: str, expected_content: str):
from pydantic import BaseModel, Field

class TestInput(BaseModel):
file: File = Field()

Expand All @@ -526,3 +526,21 @@ def init_file_on_fal(input: TestInput) -> File:
# Expect value error if we try to access the file content for input file
with pytest.raises(ValueError):
fal_file_content_matches(file, expected_content)


def test_fal_compressed_file(isolated_client):
class TestInput(BaseModel):
files: CompressedFile

@isolated_client(requirements=["pydantic==1.10.12"])
def init_compressed_file_on_fal(input: TestInput) -> int:
extracted_file_paths = [file for file in input.files]
return extracted_file_paths

archive_url = "https://storage.googleapis.com/falserverless/sdk_tests/compressed_file_test.zip"
test_input = TestInput(files=archive_url)

extracted_file_paths = init_compressed_file_on_fal(test_input)

assert all(isinstance(file, Path) for file in extracted_file_paths)
assert len(extracted_file_paths) == 3

0 comments on commit f064e50

Please sign in to comment.