Skip to content

Commit

Permalink
feat(toolkit): support multipart uploads without long-lived tokens (#409)
Browse files Browse the repository at this point in the history
  • Loading branch information
efiop authored Feb 8, 2025
1 parent 6aebf90 commit e6e602c
Showing 1 changed file with 322 additions and 8 deletions.
330 changes: 322 additions & 8 deletions projects/fal/src/fal/toolkit/file/providers/fal.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,6 @@ def save(
return self._save(file, "gcs")


@dataclass
class FalFileRepositoryV3(FalFileRepositoryBase):
def save(
self, file: FileData, object_lifecycle_preference: dict[str, str] | None = None
) -> str:
return self._save(file, "fal-cdn-v3")


class MultipartUpload:
MULTIPART_THRESHOLD = 100 * 1024 * 1024
MULTIPART_CHUNK_SIZE = 100 * 1024 * 1024
Expand Down Expand Up @@ -366,6 +358,212 @@ def _upload_part(pn: int) -> None:
return multipart.complete()


class MultipartUploadV3:
    """Multipart upload to the fal CDN (v3) via short-lived, per-upload URLs.

    A single ``initiate-multipart`` REST call returns an ``upload_url``;
    each part is PUT to ``{upload_url}/{part_number}`` and the upload is
    finalized by POSTing the collected ETags to ``{upload_url}/complete``.
    No long-lived token is required beyond the initiate call.
    """

    # Callers use this cutoff to decide when multipart is worthwhile.
    MULTIPART_THRESHOLD = 100 * 1024 * 1024
    # Size of each uploaded part, in bytes.
    MULTIPART_CHUNK_SIZE = 10 * 1024 * 1024
    # Default number of concurrent part uploads.
    MULTIPART_MAX_CONCURRENCY = 10

    def __init__(
        self,
        file_name: str,
        chunk_size: int | None = None,
        content_type: str | None = None,
        max_concurrency: int | None = None,
    ) -> None:
        """
        Args:
            file_name: Name reported to the storage service.
            chunk_size: Part size in bytes (defaults to MULTIPART_CHUNK_SIZE).
            content_type: MIME type (defaults to application/octet-stream).
            max_concurrency: Max parallel part uploads
                (defaults to MULTIPART_MAX_CONCURRENCY).
        """
        self.file_name = file_name
        self.chunk_size = chunk_size or self.MULTIPART_CHUNK_SIZE
        self.content_type = content_type or "application/octet-stream"
        self.max_concurrency = max_concurrency or self.MULTIPART_MAX_CONCURRENCY

        # Populated by create(); None until the upload has been initiated.
        self._access_url: str | None = None
        self._upload_url: str | None = None

        # ETags of uploaded parts. Appended from worker threads; list.append
        # is atomic under the GIL, so no extra locking is needed.
        self._parts: list[dict] = []

    @property
    def access_url(self) -> str:
        """Public URL of the uploaded file; requires create() to have run."""
        if not self._access_url:
            raise FileUploadException("Upload not initiated")
        return self._access_url

    @property
    def upload_url(self) -> str:
        """Short-lived URL parts are uploaded to; requires create() to have run."""
        if not self._upload_url:
            raise FileUploadException("Upload not initiated")
        return self._upload_url

    @property
    def auth_headers(self) -> dict[str, str]:
        """Authorization header built from the configured FAL key.

        Raises:
            FileUploadException: If no FAL key is configured.
        """
        fal_key = key_credentials()
        if not fal_key:
            raise FileUploadException("FAL_KEY must be set")

        key_id, key_secret = fal_key
        return {
            "Authorization": f"Key {key_id}:{key_secret}",
        }

    def create(self):
        """Initiate the multipart upload and store the access/upload URLs.

        Raises:
            FileUploadException: If the REST call fails.
        """
        grpc_host = os.environ.get("FAL_HOST", "api.alpha.fal.ai")
        # The REST API lives on the "rest" subdomain of the gRPC host.
        rest_host = grpc_host.replace("api", "rest", 1)
        url = f"https://{rest_host}/storage/upload/initiate-multipart?storage_type=fal-cdn-v3"

        try:
            req = Request(
                url,
                method="POST",
                headers={
                    **self.auth_headers,
                    "Accept": "application/json",
                },
                data=json.dumps(
                    {
                        "file_name": self.file_name,
                        "content_type": self.content_type,
                    }
                ).encode(),
            )

            with urlopen(req) as response:
                result = json.load(response)
                self._access_url = result["file_url"]
                self._upload_url = result["upload_url"]

        except HTTPError as exc:
            raise FileUploadException(
                f"Error initiating upload. Status {exc.status}: {exc.reason}"
            ) from exc

    @retry(max_retries=5, base_delay=1, backoff_type="exponential", jitter=True)
    def upload_part(self, part_number: int, data: bytes) -> None:
        """Upload one part (1-based ``part_number``) and record its ETag.

        Transient failures are retried by the decorator; the ETag is only
        appended after a successful PUT, so a retried part is recorded once.

        Raises:
            FileUploadException: If the PUT still fails after retries.
        """
        parsed = urlparse(self.upload_url)
        part_path = parsed.path + f"/{part_number}"
        url = urlunparse(parsed._replace(path=part_path))

        req = Request(
            url,
            method="PUT",
            headers={
                "Content-Type": self.content_type,
            },
            data=data,
        )

        try:
            with urlopen(req) as resp:
                self._parts.append(
                    {
                        "partNumber": part_number,
                        "etag": resp.headers["ETag"],
                    }
                )
        except HTTPError as exc:
            raise FileUploadException(
                f"Error uploading part {part_number} to {url}. "
                f"Status {exc.status}: {exc.reason}"
            ) from exc

    def complete(self) -> str:
        """Finalize the upload and return the file's public access URL.

        Raises:
            FileUploadException: If the completion call fails.
        """
        parsed = urlparse(self.upload_url)
        complete_path = parsed.path + "/complete"
        url = urlunparse(parsed._replace(path=complete_path))

        # Concurrent workers append to self._parts in completion order, but
        # multipart completion APIs expect parts listed in ascending
        # partNumber order — sort before sending.
        ordered_parts = sorted(self._parts, key=lambda part: part["partNumber"])

        try:
            req = Request(
                url,
                method="POST",
                headers={
                    "Accept": "application/json",
                    "Content-Type": "application/json",
                },
                data=json.dumps({"parts": ordered_parts}).encode(),
            )
            with urlopen(req):
                pass
        except HTTPError as e:
            raise FileUploadException(
                f"Error completing upload {url}. Status {e.status}: {e.reason}"
            ) from e

        return self.access_url

    @classmethod
    def save(
        cls,
        file: FileData,
        chunk_size: int | None = None,
        max_concurrency: int | None = None,
    ):
        """Upload an in-memory ``FileData`` in parts; return its access URL."""
        import concurrent.futures

        multipart = cls(
            file.file_name,
            chunk_size=chunk_size,
            content_type=file.content_type,
            max_concurrency=max_concurrency,
        )
        multipart.create()

        parts = math.ceil(len(file.data) / multipart.chunk_size)
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=multipart.max_concurrency
        ) as executor:
            futures = []
            for part_number in range(1, parts + 1):
                start = (part_number - 1) * multipart.chunk_size
                data = file.data[start : start + multipart.chunk_size]
                futures.append(
                    executor.submit(multipart.upload_part, part_number, data)
                )

            # Propagate the first upload error, if any.
            for future in concurrent.futures.as_completed(futures):
                future.result()

        return multipart.complete()

    @classmethod
    def save_file(
        cls,
        file_path: str | Path,
        chunk_size: int | None = None,
        content_type: str | None = None,
        max_concurrency: int | None = None,
    ) -> str:
        """Upload a file from disk in parts; return its access URL.

        Each worker opens its own handle and reads only its chunk, so the
        whole file is never held in memory at once.
        """
        import concurrent.futures

        file_name = os.path.basename(file_path)
        size = os.path.getsize(file_path)

        multipart = cls(
            file_name,
            chunk_size=chunk_size,
            content_type=content_type,
            max_concurrency=max_concurrency,
        )
        multipart.create()

        def _upload_part(pn: int) -> None:
            # Per-thread file handle; seek to this part's offset and read it.
            with open(file_path, "rb") as f:
                f.seek((pn - 1) * multipart.chunk_size)
                data = f.read(multipart.chunk_size)
            multipart.upload_part(pn, data)

        parts = math.ceil(size / multipart.chunk_size)
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=multipart.max_concurrency
        ) as executor:
            futures = [
                executor.submit(_upload_part, part_number)
                for part_number in range(1, parts + 1)
            ]

            # Propagate the first upload error, if any.
            for future in concurrent.futures.as_completed(futures):
                future.result()

        return multipart.complete()


class InternalMultipartUploadV3:
MULTIPART_THRESHOLD = 100 * 1024 * 1024
MULTIPART_CHUNK_SIZE = 10 * 1024 * 1024
Expand Down Expand Up @@ -703,6 +901,122 @@ def auth_headers(self) -> dict[str, str]:
}


@dataclass
class FalFileRepositoryV3(FileRepository):
    """File repository backed by the fal CDN (v3) using short-lived upload URLs.

    Small payloads are uploaded with a single PUT to a URL obtained from the
    ``initiate`` REST endpoint; large payloads (or when ``multipart`` is
    forced) are delegated to :class:`MultipartUploadV3`.
    """

    @property
    def auth_headers(self) -> dict[str, str]:
        """Authorization headers built from the configured FAL key.

        Raises:
            FileUploadException: If no FAL key is configured.
        """
        fal_key = key_credentials()
        if not fal_key:
            raise FileUploadException("FAL_KEY must be set")

        key_id, key_secret = fal_key
        return {
            "Authorization": f"Key {key_id}:{key_secret}",
            "User-Agent": "fal/0.1.0",
        }

    @retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
    def save(
        self,
        file: FileData,
        multipart: bool | None = None,
        multipart_threshold: int | None = None,
        multipart_chunk_size: int | None = None,
        multipart_max_concurrency: int | None = None,
        object_lifecycle_preference: dict[str, str] | None = None,
    ) -> str:
        """Upload ``file`` and return its public URL.

        Args:
            file: In-memory file payload.
            multipart: Force multipart on/off; if None, decided by payload
                size against ``multipart_threshold``.
            multipart_threshold: Size cutoff for auto-selecting multipart.
            multipart_chunk_size: Part size for multipart uploads.
            multipart_max_concurrency: Parallelism for multipart uploads.
            object_lifecycle_preference: Accepted for interface compatibility
                with other repositories; not used by this backend.

        Raises:
            FileUploadException: If initiating or performing the upload fails.
        """
        if multipart is None:
            threshold = multipart_threshold or MultipartUploadV3.MULTIPART_THRESHOLD
            multipart = len(file.data) > threshold

        if multipart:
            return MultipartUploadV3.save(
                file,
                chunk_size=multipart_chunk_size,
                max_concurrency=multipart_max_concurrency,
            )

        headers = {
            **self.auth_headers,
            "Accept": "application/json",
            "Content-Type": "application/json",
        }

        grpc_host = os.environ.get("FAL_HOST", "api.alpha.fal.ai")
        # The REST API lives on the "rest" subdomain of the gRPC host.
        rest_host = grpc_host.replace("api", "rest", 1)
        url = f"https://{rest_host}/storage/upload/initiate?storage_type=fal-cdn-v3"

        request = Request(
            url,
            headers=headers,
            method="POST",
            data=json.dumps(
                {
                    "file_name": file.file_name,
                    "content_type": file.content_type,
                }
            ).encode(),
        )
        try:
            with urlopen(request) as response:
                result = json.load(response)
                file_url = result["file_url"]
                upload_url = result["upload_url"]
        except HTTPError as e:
            raise FileUploadException(
                f"Error initiating upload. Status {e.status}: {e.reason}"
            ) from e

        # Single PUT of the whole payload to the short-lived upload URL.
        request = Request(
            upload_url,
            headers={"Content-Type": file.content_type},
            method="PUT",
            data=file.data,
        )
        try:
            with urlopen(request):
                pass
        except HTTPError as e:
            raise FileUploadException(
                f"Error uploading file. Status {e.status}: {e.reason}"
            ) from e

        return file_url

    def save_file(
        self,
        file_path: str | Path,
        content_type: str,
        multipart: bool | None = None,
        multipart_threshold: int | None = None,
        multipart_chunk_size: int | None = None,
        multipart_max_concurrency: int | None = None,
        object_lifecycle_preference: dict[str, str] | None = None,
    ) -> tuple[str, FileData | None]:
        """Upload a file from disk; return ``(url, data)``.

        ``data`` is the in-memory ``FileData`` for single-shot uploads and
        None for multipart uploads (the file is streamed from disk instead).
        """
        if multipart is None:
            threshold = multipart_threshold or MultipartUploadV3.MULTIPART_THRESHOLD
            multipart = os.path.getsize(file_path) > threshold

        if multipart:
            url = MultipartUploadV3.save_file(
                file_path,
                chunk_size=multipart_chunk_size,
                content_type=content_type,
                max_concurrency=multipart_max_concurrency,
            )
            data = None
        else:
            with open(file_path, "rb") as f:
                data = FileData(
                    f.read(),
                    content_type=content_type,
                    file_name=os.path.basename(file_path),
                )
            # BUG FIX: object_lifecycle_preference was previously passed
            # positionally, landing in save()'s second parameter (multipart)
            # — any non-None preference dict is truthy and silently forced a
            # multipart upload for small files. Pass it by keyword, and pin
            # multipart=False since the size decision was already made above.
            url = self.save(
                data,
                multipart=False,
                object_lifecycle_preference=object_lifecycle_preference,
            )

        return url, data


# This is only available for internal users to have long-lived access tokens
@dataclass
class InternalFalFileRepositoryV3(FileRepository):
Expand Down

0 comments on commit e6e602c

Please sign in to comment.