diff --git a/projects/fal/src/fal/cli.py b/projects/fal/src/fal/cli.py index 38709dd9..77292f8a 100644 --- a/projects/fal/src/fal/cli.py +++ b/projects/fal/src/fal/cli.py @@ -232,6 +232,10 @@ def key_revoke(client: sdk.FalServerlessClient, key_id: str): ##### Function group ##### +ALIAS_AUTH_OPTIONS = ["public", "private", "shared"] +ALIAS_AUTH_TYPE = Literal["public", "private", "shared"] + + @click.group @click.option("--host", default=DEFAULT_HOST, envvar=HOST_ENVVAR) @click.option("--port", default=DEFAULT_PORT, envvar=PORT_ENVVAR, hidden=True) @@ -245,7 +249,7 @@ def function_cli(ctx, host: str, port: str): @click.option( "--auth", "auth_mode", - type=click.Choice(["public", "private", "shared"]), + type=click.Choice(ALIAS_AUTH_OPTIONS), default="private", ) @click.argument("file_path", required=True) @@ -256,7 +260,7 @@ def register_application( file_path: str, function_name: str, alias: str | None, - auth_mode: Literal["public", "private", "shared"], + auth_mode: ALIAS_AUTH_TYPE, ): import runpy @@ -362,6 +366,36 @@ def _alias_table(aliases: list[AliasInfo]): return table +@alias_cli.command("set") +@click.argument("alias", required=True) +@click.argument("revision", required=True) +@click.option( + "--auth", + "auth_mode", + type=click.Choice(ALIAS_AUTH_OPTIONS), + default="private", +) +@click.pass_obj +def alias_set( + client: api.FalServerlessClient, + alias: str, + revision: str, + auth_mode: ALIAS_AUTH_TYPE, +): + with client.connect() as connection: + connection.create_alias(alias, revision, auth_mode) + + +@alias_cli.command("delete") +@click.argument("alias", required=True) +@click.pass_obj +def alias_delete(client: api.FalServerlessClient, alias: str): + with client.connect() as connection: + application_id = connection.delete_alias(alias) + + console.print(f"Deleted alias '{alias}' for application '{application_id}'.") + + @alias_cli.command("list") @click.pass_obj def alias_list(client: api.FalServerlessClient): @@ -377,6 +411,12 @@ def alias_list(client: api.FalServerlessClient): @click.option("--keep-alive", "-k", type=int) @click.option("--max-multiplexing", "-m", type=int) @click.option("--max-concurrency", "-c", type=int) +# TODO: add auth_mode +# @click.option( +# "--auth", +# "auth_mode", +# type=click.Choice(ALIAS_AUTH_OPTIONS), +# ) @click.pass_obj def alias_update( client: api.FalServerlessClient, diff --git a/projects/fal/src/fal/sdk.py b/projects/fal/src/fal/sdk.py index 7382e7e8..70a221db 100644 --- a/projects/fal/src/fal/sdk.py +++ b/projects/fal/src/fal/sdk.py @@ -494,6 +494,31 @@ def run( for partial_result in self.stub.Run(request): yield from_grpc(partial_result) + def create_alias( + self, + alias: str, + revision: str, + auth_mode: Literal["public", "private", "shared"], + ): + if auth_mode == "public": + auth = isolate_proto.ApplicationAuthMode.PUBLIC + elif auth_mode == "shared": + auth = isolate_proto.ApplicationAuthMode.SHARED + else: + auth = isolate_proto.ApplicationAuthMode.PRIVATE + + request = isolate_proto.SetAliasRequest( + alias=alias, + revision=revision, + auth_mode=auth, + ) + self.stub.SetAlias(request) + + def delete_alias(self, alias: str) -> str: + request = isolate_proto.DeleteAliasRequest(alias=alias) + res: isolate_proto.DeleteAliasResult = self.stub.DeleteAlias(request) + return res.revision + def list_aliases(self) -> list[AliasInfo]: request = isolate_proto.ListAliasesRequest() response: isolate_proto.ListAliasesResult = self.stub.ListAliases(request) diff --git a/projects/fal/tests/test_apps.py b/projects/fal/tests/test_apps.py index dc5783f7..22f683c6 100644 --- a/projects/fal/tests/test_apps.py +++ b/projects/fal/tests/test_apps.py @@ -1,50 +1,70 @@ from typing import Generator import fal +import fal.api as api import pytest from fal import apps from fal.rest_client import REST_CLIENT +import time from pydantic import BaseModel from openapi_fal_rest.api.applications import app_metadata -@pytest.fixture(scope="module") -def test_app() -> Generator[str, None, None]: - # Create a temporary app, register it, and return the ID of it. +class Input(BaseModel): + lhs: int + rhs: int + wait_time: int = 0 - import time - from fal.cli import _get_user_id +class Output(BaseModel): + result: int + + +@fal.function( + keep_alive=60, + machine_type="S", + serve=True, + max_concurrency=1, +) +def addition_app(input: Input) -> Output: + print("starting...") + for _ in range(input.wait_time): + print("sleeping...") + time.sleep(1) - class Input(BaseModel): - lhs: int - rhs: int - wait_time: int = 0 + return Output(result=input.lhs + input.rhs) + + +@pytest.fixture(scope="module") +def aliased_app() -> Generator[tuple[str, str], None, None]: + # Create a temporary app, register it, and return the ID of it. - class Output(BaseModel): - result: int + import uuid - @fal.function( - keep_alive=60, - machine_type="S", - serve=True, - max_concurrency=1, + app_alias = str(uuid.uuid4()).replace("-", "")[:10] + app_revision = addition_app.host.register( + func=addition_app.func, + options=addition_app.options, + # random enough + application_name=app_alias, + application_auth_mode="private", ) - def addition_app(input: Input) -> Output: - print("starting...") - for _ in range(input.wait_time): - print("sleeping...") - time.sleep(1) + yield app_revision, app_alias # type: ignore + - return Output(result=input.lhs + input.rhs) +@pytest.fixture(scope="module") +def test_app(): + # Create a temporary app, register it, and return the ID of it. - app_alias = addition_app.host.register( + from fal.cli import _get_user_id + + app_revision = addition_app.host.register( func=addition_app.func, options=addition_app.options, ) user_id = _get_user_id() - yield f"{user_id}-{app_alias}" + yield f"{user_id}-{app_revision}" def test_app_client(test_app: str): @@ -99,3 +119,90 @@ def test_app_openapi_spec_metadata(test_app: str): openapi_spec: dict = metadata["openapi"] for key in ["openapi", "info", "paths", "components"]: assert key in openapi_spec, f"{key} key missing from openapi {openapi_spec}" + + +def test_app_update_app(aliased_app: tuple[str, str]): + app_revision, app_alias = aliased_app + + host: api.FalServerlessHost = addition_app.host # type: ignore + with host._connection as client: + # Get the registered values + res = client.list_aliases() + found = next(filter(lambda alias: alias.alias == app_alias, res), None) + assert found, f"Could not find app {app_alias} in {res}" + assert found.revision == app_revision + + with host._connection as client: + new_keep_alive = found.keep_alive + 1 + new_max_concurrency = found.max_concurrency + 1 + new_max_multiplexing = found.max_multiplexing + 1 + + res = client.update_application( + application_name=app_alias, + keep_alive=new_keep_alive, + max_concurrency=new_max_concurrency, + max_multiplexing=new_max_multiplexing, + ) + assert res.alias == app_alias + assert res.keep_alive == new_keep_alive + assert res.max_concurrency == new_max_concurrency + assert res.max_multiplexing == new_max_multiplexing + + with host._connection as client: + new_keep_alive = new_keep_alive + 1 + res = client.update_application( + application_name=app_alias, + keep_alive=new_keep_alive, + ) + assert res.alias == app_alias + assert res.keep_alive == new_keep_alive + assert res.max_concurrency == new_max_concurrency + assert res.max_multiplexing == new_max_multiplexing + + with host._connection as client: + new_max_concurrency = new_max_concurrency + 1 + res = client.update_application( + application_name=app_alias, + max_concurrency=new_max_concurrency, + ) + assert res.alias == app_alias + assert res.keep_alive == new_keep_alive + assert res.max_concurrency == new_max_concurrency + assert res.max_multiplexing == new_max_multiplexing + + +def test_app_set_delete_alias(aliased_app: tuple[str, str]): + app_revision, app_alias = aliased_app + + host: api.FalServerlessHost = addition_app.host # type: ignore + + with host._connection as client: + # Get the registered values + res = client.list_aliases() + found = next(filter(lambda alias: alias.alias == app_alias, res), None) + assert found, f"Could not find app {app_alias} in {res}" + assert found.revision == app_revision + assert found.auth_mode == "private" + + new_app_alias = f"{app_alias}-new" + with host._connection as client: + # Get the registered values + res = client.create_alias(new_app_alias, app_revision, "public") + + with host._connection as client: + # Get the registered values + res = client.list_aliases() + found = next(filter(lambda alias: alias.alias == new_app_alias, res), None) + assert found, f"Could not find app {app_alias} in {res}" + assert found.revision == app_revision + assert found.auth_mode == "public" + + with host._connection as client: + res = client.delete_alias(alias=app_alias) + assert res == app_revision + + with host._connection as client: + # Get the registered values + res = client.list_aliases() + found = next(filter(lambda alias: alias.alias == app_alias, res), None) + assert not found, f"Found app {app_alias} in {res} after deletion"