Skip to content

Commit

Permalink
feat(toolkit): input types (#380)
Browse files Browse the repository at this point in the history
* feat(toolkit): input types

* feat: support pydantic v2

* feat: add loading

* refactor: use read_image_from_url to load images

* refactor: naming

---------

Co-authored-by: Ruslan Kuprieiev <[email protected]>
  • Loading branch information
badayvedat and efiop authored Jan 21, 2025
1 parent e9212ba commit 72a7376
Show file tree
Hide file tree
Showing 2 changed files with 239 additions and 0 deletions.
140 changes: 140 additions & 0 deletions projects/fal/src/fal/toolkit/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import re
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Generator, Union

import pydantic
from pydantic.utils import update_not_none

from fal.toolkit.image import read_image_from_url
from fal.toolkit.utils.download_utils import download_file

# https://github.com/pydantic/pydantic/pull/2573
if not hasattr(pydantic, "__version__") or pydantic.__version__.startswith("1."):
IS_PYDANTIC_V2 = False
else:
IS_PYDANTIC_V2 = True

MAX_DATA_URI_LENGTH = 10 * 1024 * 1024
MAX_HTTPS_URL_LENGTH = 2048

HTTP_URL_REGEX = (
r"^https:\/\/(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}(?::\d{1,5})?(?:\/[^\s]*)?$"
)


class DownloadFileMixin:
@contextmanager
def as_temp_file(self) -> Generator[Path, None, None]:
with tempfile.TemporaryDirectory() as temp_dir:
yield download_file(str(self), temp_dir)


class DownloadImageMixin:
def to_pil(self):
return read_image_from_url(str(self))


class DataUri(DownloadFileMixin, str):
if IS_PYDANTIC_V2:

@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler) -> Any:
return {
"type": "str",
"pattern": "^data:",
"max_length": MAX_DATA_URI_LENGTH,
"strip_whitespace": True,
}

def __get_pydantic_json_schema__(cls, core_schema, handler) -> Dict[str, Any]:
json_schema = handler(core_schema)
json_schema.update(format="data-uri")
return json_schema
else:

@classmethod
def __get_validators__(cls):
yield cls.validate

@classmethod
def validate(cls, value: Any) -> "DataUri":
from pydantic.validators import str_validator

value = str_validator(value)
value = value.strip()

if not value.startswith("data:"):
raise ValueError("Data URI must start with 'data:'")

if len(value) > MAX_DATA_URI_LENGTH:
raise ValueError(
f"Data URI is too long. Max length is {MAX_DATA_URI_LENGTH} bytes."
)

return cls(value)

@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
update_not_none(field_schema, format="data-uri")


class HttpsUrl(DownloadFileMixin, str):
if IS_PYDANTIC_V2:

@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler) -> Any:
return {
"type": "str",
"pattern": HTTP_URL_REGEX,
"max_length": MAX_HTTPS_URL_LENGTH,
"strip_whitespace": True,
}

def __get_pydantic_json_schema__(cls, core_schema, handler) -> Dict[str, Any]:
json_schema = handler(core_schema)
json_schema.update(format="https-url")
return json_schema

else:

@classmethod
def __get_validators__(cls):
yield cls.validate

@classmethod
def validate(cls, value: Any) -> "HttpsUrl":
from pydantic.validators import str_validator

value = str_validator(value)
value = value.strip()

if not re.match(HTTP_URL_REGEX, value):
raise ValueError(
"URL must start with 'https://' and follow the correct format."
)

if len(value) > MAX_HTTPS_URL_LENGTH:
raise ValueError(
f"HTTPS URL is too long. Max length is "
f"{MAX_HTTPS_URL_LENGTH} characters."
)

return cls(value)

@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
update_not_none(field_schema, format="https-url")


class ImageHttpsUrl(DownloadImageMixin, HttpsUrl):
pass


class ImageDataUri(DownloadImageMixin, DataUri):
pass


FileInput = Union[HttpsUrl, DataUri]
ImageInput = Union[ImageHttpsUrl, ImageDataUri]
99 changes: 99 additions & 0 deletions projects/fal/tests/toolkit/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import pytest
from pydantic import BaseModel, ValidationError

from fal.toolkit.types import MAX_DATA_URI_LENGTH, MAX_HTTPS_URL_LENGTH, FileInput


class DummyModel(BaseModel):
url: FileInput


class TestFileInput:
def test_valid_https_urls(self):
# Test basic HTTPS URL
model = DummyModel(url="https://example.com")
assert model.url == "https://example.com"

# Test HTTPS URL with path
model = DummyModel(url="https://example.com/path/to/resource")
assert model.url == "https://example.com/path/to/resource"

# Test HTTPS URL with query parameters
model = DummyModel(url="https://example.com/search?q=test&page=1")
assert model.url == "https://example.com/search?q=test&page=1"

# Test HTTPS URL with subdomain
model = DummyModel(url="https://sub.example.com")
assert model.url == "https://sub.example.com"

# Test HTTPS URL with port
model = DummyModel(url="https://example.com:8443")
assert model.url == "https://example.com:8443"

# Test HTTPS URL with whitespace
model = DummyModel(url=" https://example.com ")
assert model.url == "https://example.com"

# TODO: should we even allow this?
# Test HTTPS URL with port
model = DummyModel(url="https://example.com:8443")
assert model.url == "https://example.com:8443"

def test_valid_data_uris(self):
# Test basic data URI
model = DummyModel(url="data:text/plain;base64,SGVsbG8gV29ybGQ=")
assert model.url == "data:text/plain;base64,SGVsbG8gV29ybGQ="

# Test data URI with image
image_uri = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" # noqa: E501
model = DummyModel(url=image_uri)
assert model.url == image_uri

# Test data URI with whitespace
model = DummyModel(url=" data:text/plain,Hello World ")
assert model.url == "data:text/plain,Hello World"

def test_invalid_inputs(self):
# Test HTTP URL (non-HTTPS)
with pytest.raises(ValueError):
DummyModel(url="http://example.com")

# Test malformed URL
with pytest.raises(ValueError):
DummyModel(url="not-a-url")

# Test invalid data URI
with pytest.raises(ValueError):
DummyModel(url="invalid-data-uri")

# Test empty string
with pytest.raises(ValueError):
DummyModel(url="")

# Test None value
with pytest.raises(ValueError):
DummyModel(url=None)

def test_length_limits(self):
# Test HTTPS URL at max length
domain = "example.com"
path_length = MAX_HTTPS_URL_LENGTH - len(f"https://{domain}/")
long_url = f"https://{domain}/{'a' * path_length}"
model = DummyModel(url=long_url)
assert model.url == long_url

# Test HTTPS URL exceeding max length
too_long_url = f"https://example.com/{'a' * MAX_HTTPS_URL_LENGTH}"
with pytest.raises(ValidationError):
DummyModel(url=too_long_url)

# Test data URI at max length
uri_prefix = "data:text/plain,"
long_uri = f"{uri_prefix}{'a' * (MAX_DATA_URI_LENGTH - len(uri_prefix))}"
model = DummyModel(url=long_uri)
assert model.url == long_uri

# Test data URI exceeding max length
too_long_uri = f"data:text/plain,{'a' * MAX_DATA_URI_LENGTH}"
with pytest.raises(ValueError):
DummyModel(url=too_long_uri)

0 comments on commit 72a7376

Please sign in to comment.