feat: background uploads #71

Closed · wants to merge 1 commit
projects/fal/src/fal/__init__.py (2 additions, 1 deletion)
@@ -3,12 +3,13 @@
from fal import apps

# TODO: DEPRECATED - use function instead
from fal.api import FalServerlessHost, LocalHost, cached
from fal.api import FalServerlessHost, LocalHost
from fal.api import function
from fal.api import function as isolated
from fal.app import App, endpoint, realtime, wrap_app
from fal.sdk import FalServerlessKeyCredentials
from fal.sync import sync_dir
from fal.toolkit.utils import cached

local = LocalHost()
serverless = FalServerlessHost()
projects/fal/src/fal/api.py (0 additions, 35 deletions)
@@ -135,41 +135,6 @@ def run(
raise NotImplementedError


def cached(func: Callable[ArgsT, ReturnT]) -> Callable[ArgsT, ReturnT]:
"""Cache the result of the given function in-memory."""
import hashlib

try:
source_code = inspect.getsource(func).encode("utf-8")
except OSError:
# TODO: explain the reason for this (e.g. we don't know how to
# check if you sent us the same function twice).
print(f"[warning] Function {func.__name__} can not be cached...")
return func

cache_key = hashlib.sha256(source_code).hexdigest()

@wraps(func)
def wrapper(
*args: ArgsT.args,
**kwargs: ArgsT.kwargs,
) -> ReturnT:
from functools import lru_cache

# HACK: Using the isolate module as a global cache.
import isolate

if not hasattr(isolate, "__cached_functions__"):
isolate.__cached_functions__ = {}

if cache_key not in isolate.__cached_functions__:
isolate.__cached_functions__[cache_key] = lru_cache(maxsize=None)(func)

return isolate.__cached_functions__[cache_key](*args, **kwargs)

return wrapper


def _prepare_partial_func(
func: Callable[ArgsT, ReturnT],
*args: ArgsT.args,
projects/fal/src/fal/toolkit/__init__.py (1 addition, 0 deletions)
@@ -10,4 +10,5 @@
clone_repository,
download_file,
download_model_weights,
cached,
)
projects/fal/src/fal/toolkit/file/file.py (3 additions, 0 deletions)
@@ -9,12 +9,14 @@
from fal.toolkit.file.providers.r2 import R2Repository
from fal.toolkit.file.types import FileData, FileRepository, RepositoryId
from fal.toolkit.mainify import mainify
from fal.toolkit.utils.cache import cached
from fal.toolkit.utils.download_utils import download_file
from pydantic import BaseModel, Field, PrivateAttr
from pydantic.typing import Optional
from tempfile import NamedTemporaryFile, TemporaryDirectory
from zipfile import ZipFile


FileRepositoryFactory = Callable[[], FileRepository]

BUILT_IN_REPOSITORIES: dict[RepositoryId, FileRepositoryFactory] = {
@@ -25,6 +27,7 @@
}


@cached
def get_builtin_repository(id: RepositoryId) -> FileRepository:
if id not in BUILT_IN_REPOSITORIES.keys():
raise ValueError(f'"{id}" is not a valid built-in file repository')
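Presumably the point of memoizing get_builtin_repository is that the stateful FalFileRepository introduced below (with its thread pool and set of pending uploads) gets built once and then reused. A minimal usage sketch, assuming the module path shown in this diff, that "fal" is one of the registered repository ids, and a fal environment where the isolate package backing @cached is installed:

from fal.toolkit.file.file import get_builtin_repository

# Repeated lookups with the same id now return the same repository object,
# because @cached memoizes the factory call per distinct argument.
repo_a = get_builtin_repository("fal")
repo_b = get_builtin_repository("fal")
assert repo_a is repo_b
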
projects/fal/src/fal/toolkit/file/providers/fal.py (40 additions, 2 deletions)
@@ -3,18 +3,31 @@
import json
import os
from base64 import b64encode
from dataclasses import dataclass
from dataclasses import dataclass, field
from concurrent.futures import ThreadPoolExecutor, Future
from urllib.error import HTTPError
from urllib.request import Request, urlopen

from fal.toolkit.exceptions import FileUploadException
from fal.toolkit.file.types import FileData, FileRepository
from fal.toolkit.mainify import mainify

# Don't allow more than 24 uploads to be in progress at once, if we are stuck
# then execute the next upload synchronously.
MAX_BACKGROUND_UPLOADS = 24


@mainify
@dataclass
class FalFileRepository(FileRepository):
thread_pool: ThreadPoolExecutor = field(default_factory=ThreadPoolExecutor)
uploads: set[Future] = field(default_factory=set)

def __post_init__(self):
self.allow_background_uploads = os.environ.get(
"FAL_ALLOW_BACKGROUND_UPLOADS", False
)

def save(self, file: FileData) -> str:
key_id = os.environ.get("FAL_KEY_ID")
key_secret = os.environ.get("FAL_KEY_SECRET")
@@ -29,6 +42,7 @@ def save(self, file: FileData) -> str:
rest_host = grpc_host.replace("api", "rest", 1)
storage_url = f"https://{rest_host}/storage/upload/initiate"

self.gc_futures()
try:
req = Request(
storage_url,
@@ -45,7 +59,14 @@ def save(self, file: FileData) -> str:
result = json.load(response)

upload_url = result["upload_url"]
self._upload_file(upload_url, file)
if (
not self.allow_background_uploads
or len(self.uploads) >= MAX_BACKGROUND_UPLOADS
):
self._upload_file(upload_url, file)
else:
future = self.thread_pool.submit(self._upload_file, upload_url, file)
self.uploads.add(future)

return result["file_url"]
except HTTPError as e:
@@ -64,6 +85,23 @@ def _upload_file(self, upload_url: str, file: FileData):
with urlopen(req):
return

def gc_futures(self):
import traceback

for future in self.uploads.copy():
if not future.done():
continue

if future in self.uploads:
self.uploads.remove(future)
Comment on lines +95 to +96 (Member):
ATL: why the if? If it's for race conditions, shouldn't we use a better locking mechanism or handle the KeyError exception instead?

exception = future.exception()
if exception is not None:
print("[Warning] Failed to upload file")
traceback.print_exception(
type(exception), exception, exception.__traceback__
)


@mainify
@dataclass
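As a side note on the review question above, here is a minimal sketch of the kind of alternative the reviewer seems to hint at; it is illustrative only, not part of the PR, and the _UploadTracker name is hypothetical. set.discard() removes an element without raising KeyError, so no membership check is needed, and a threading.Lock covers the case where the set is mutated from several threads.

from __future__ import annotations

import threading
import traceback
from concurrent.futures import Future


class _UploadTracker:
    """Hypothetical helper: collects in-flight uploads and reaps finished ones."""

    def __init__(self) -> None:
        self._uploads: set[Future] = set()
        self._lock = threading.Lock()

    def add(self, future: Future) -> None:
        with self._lock:
            self._uploads.add(future)

    def gc(self) -> None:
        with self._lock:
            done = {f for f in self._uploads if f.done()}
            for future in done:
                # discard() never raises, so a concurrent removal is harmless.
                self._uploads.discard(future)
        for future in done:
            exc = future.exception()
            if exc is not None:
                print("[Warning] Failed to upload file")
                traceback.print_exception(type(exc), exc, exc.__traceback__)
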
projects/fal/src/fal/toolkit/utils/__init__.py (1 addition, 0 deletions)
@@ -1,3 +1,4 @@
from __future__ import annotations

from fal.toolkit.utils.download_utils import *
from fal.toolkit.utils.cache import *
projects/fal/src/fal/toolkit/utils/cache.py (48 additions, 0 deletions)
@@ -0,0 +1,48 @@
from __future__ import annotations

import inspect
from functools import wraps
from typing import (
Callable,
TypeVar,
)

from typing_extensions import ParamSpec

ArgsT = ParamSpec("ArgsT")
ReturnT = TypeVar("ReturnT", covariant=True)


def cached(func: Callable[ArgsT, ReturnT]) -> Callable[ArgsT, ReturnT]:
"""Cache the result of the given function in-memory."""
import hashlib

try:
source_code = inspect.getsource(func).encode("utf-8")
except OSError:
# TODO: explain the reason for this (e.g. we don't know how to
# check if you sent us the same function twice).
print(f"[warning] Function {func.__name__} can not be cached...")
return func

cache_key = hashlib.sha256(source_code).hexdigest()

@wraps(func)
def wrapper(
*args: ArgsT.args,
**kwargs: ArgsT.kwargs,
) -> ReturnT:
from functools import lru_cache

# HACK: Using the isolate module as a global cache.
import isolate

if not hasattr(isolate, "__cached_functions__"):
isolate.__cached_functions__ = {}

if cache_key not in isolate.__cached_functions__:
isolate.__cached_functions__[cache_key] = lru_cache(maxsize=None)(func)

return isolate.__cached_functions__[cache_key](*args, **kwargs)

return wrapper
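
For illustration, a quick sketch of how the moved decorator behaves. The load_config function is hypothetical; running this assumes the decorated function lives in a regular module (the decorator falls back to no caching when the source cannot be read) and that the isolate package used as the global cache store is installed, as it is in a fal environment.

from fal.toolkit.utils.cache import cached


@cached
def load_config(name: str) -> dict:
    # Hypothetical expensive setup; the body runs once per distinct
    # (hashable) argument, after which results come from the lru_cache.
    print(f"loading {name}...")
    return {"name": name}


first = load_config("default")   # prints "loading default..."
second = load_config("default")  # served from the cache, no print
assert first is second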