Skip to content

Commit

Permalink
fix(toolkit): fix repository signatures (#422)
Browse files Browse the repository at this point in the history
  • Loading branch information
efiop authored Feb 18, 2025
1 parent 445cc21 commit 25ea9a9
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 6 deletions.
4 changes: 3 additions & 1 deletion projects/fal/src/fal/toolkit/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ def from_bytes(
fallback_repo = get_builtin_repository(fallback_repository)

url = fallback_repo.save(
fdata, object_lifecycle_preference, **fallback_save_kwargs
fdata,
object_lifecycle_preference=object_lifecycle_preference,
**fallback_save_kwargs,
)

return cls(
Expand Down
31 changes: 27 additions & 4 deletions projects/fal/src/fal/toolkit/file/providers/fal.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,13 @@ def _save(self, file: FileData, storage_type: str) -> str:
@dataclass
class FalFileRepository(FalFileRepositoryBase):
def save(
self, file: FileData, object_lifecycle_preference: dict[str, str] | None = None
self,
file: FileData,
multipart: bool | None = None,
multipart_threshold: int | None = None,
multipart_chunk_size: int | None = None,
multipart_max_concurrency: int | None = None,
object_lifecycle_preference: dict[str, str] | None = None,
) -> str:
return self._save(file, "gcs")

Expand Down Expand Up @@ -834,7 +840,10 @@ def save_file(
content_type=content_type,
file_name=os.path.basename(file_path),
)
url = self.save(data, object_lifecycle_preference)
url = self.save(
data,
object_lifecycle_preference=object_lifecycle_preference,
)

return url, data

Expand All @@ -844,6 +853,10 @@ class InMemoryRepository(FileRepository):
def save(
self,
file: FileData,
multipart: bool | None = None,
multipart_threshold: int | None = None,
multipart_chunk_size: int | None = None,
multipart_max_concurrency: int | None = None,
object_lifecycle_preference: dict[str, str] | None = None,
) -> str:
return f'data:{file.content_type};base64,{b64encode(file.data).decode("utf-8")}'
Expand All @@ -865,6 +878,10 @@ def _object_lifecycle_headers(
def save(
self,
file: FileData,
multipart: bool | None = None,
multipart_threshold: int | None = None,
multipart_chunk_size: int | None = None,
multipart_max_concurrency: int | None = None,
object_lifecycle_preference: dict[str, str] | None = None,
) -> str:
headers = {
Expand Down Expand Up @@ -1013,7 +1030,10 @@ def save_file(
content_type=content_type,
file_name=os.path.basename(file_path),
)
url = self.save(data, object_lifecycle_preference)
url = self.save(
data,
object_lifecycle_preference=object_lifecycle_preference,
)

return url, data

Expand Down Expand Up @@ -1119,6 +1139,9 @@ def save_file(
content_type=content_type,
file_name=os.path.basename(file_path),
)
url = self.save(data, object_lifecycle_preference)
url = self.save(
data,
object_lifecycle_preference=object_lifecycle_preference,
)

return url, data
4 changes: 4 additions & 0 deletions projects/fal/src/fal/toolkit/file/providers/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def bucket(self):
def save(
self,
data: FileData,
multipart: bool | None = None,
multipart_threshold: int | None = None,
multipart_chunk_size: int | None = None,
multipart_max_concurrency: int | None = None,
object_lifecycle_preference: Optional[dict[str, str]] = None,
) -> str:
destination_path = posixpath.join(
Expand Down
4 changes: 4 additions & 0 deletions projects/fal/src/fal/toolkit/file/providers/r2.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def bucket(self):
def save(
self,
data: FileData,
multipart: bool | None = None,
multipart_threshold: int | None = None,
multipart_chunk_size: int | None = None,
multipart_max_concurrency: int | None = None,
object_lifecycle_preference: Optional[dict[str, str]] = None,
) -> str:
destination_path = posixpath.join(
Expand Down
4 changes: 4 additions & 0 deletions projects/fal/src/fal/toolkit/file/providers/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ def storage_client(self):
def save(
self,
data: FileData,
multipart: bool | None = None,
multipart_threshold: int | None = None,
multipart_chunk_size: int | None = None,
multipart_max_concurrency: int | None = None,
object_lifecycle_preference: Optional[dict[str, str]] = None,
key: Optional[str] = None,
) -> str:
Expand Down
13 changes: 12 additions & 1 deletion projects/fal/src/fal/toolkit/file/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class FileRepository:
def save(
self,
data: FileData,
multipart: bool | None = None,
multipart_threshold: int | None = None,
multipart_chunk_size: int | None = None,
multipart_max_concurrency: int | None = None,
object_lifecycle_preference: Optional[dict[str, str]] = None,
) -> str:
raise NotImplementedError()
Expand All @@ -59,4 +63,11 @@ def save_file(
with open(file_path, "rb") as fobj:
data = FileData(fobj.read(), content_type, Path(file_path).name)

return self.save(data, object_lifecycle_preference), data
return self.save(
data,
multipart=multipart,
multipart_threshold=multipart_threshold,
multipart_chunk_size=multipart_chunk_size,
multipart_max_concurrency=multipart_max_concurrency,
object_lifecycle_preference=object_lifecycle_preference,
), data
7 changes: 7 additions & 0 deletions projects/fal/tests/test_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,13 @@ def hello_file():
return File.from_bytes(
b"Hello fal storage from isolated",
repository=repo_type,
save_kwargs={
"multipart": False,
"multipart_threshold": 1024 * 1024,
"multipart_chunk_size": 1024 * 1024,
"multipart_max_concurrency": 10,
"object_lifecycle_preference": {},
},
fallback_repository=None,
)

Expand Down

0 comments on commit 25ea9a9

Please sign in to comment.