Skip to content

Commit

Permalink
feat: add alias set and alias delete cli commands (#8)
Browse files Browse the repository at this point in the history
* feat: add alias set and alias delete cli commands

* alias update and delete tests

* add alias set test too
  • Loading branch information
chamini2 authored Dec 13, 2023
1 parent 3448116 commit c6de55c
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 26 deletions.
44 changes: 42 additions & 2 deletions projects/fal/src/fal/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ def key_revoke(client: sdk.FalServerlessClient, key_id: str):


##### Function group #####
ALIAS_AUTH_OPTIONS = ["public", "private", "shared"]
ALIAS_AUTH_TYPE = Literal["public", "private", "shared"]


@click.group
@click.option("--host", default=DEFAULT_HOST, envvar=HOST_ENVVAR)
@click.option("--port", default=DEFAULT_PORT, envvar=PORT_ENVVAR, hidden=True)
Expand All @@ -245,7 +249,7 @@ def function_cli(ctx, host: str, port: str):
@click.option(
"--auth",
"auth_mode",
type=click.Choice(["public", "private", "shared"]),
type=click.Choice(ALIAS_AUTH_OPTIONS),
default="private",
)
@click.argument("file_path", required=True)
Expand All @@ -256,7 +260,7 @@ def register_application(
file_path: str,
function_name: str,
alias: str | None,
auth_mode: Literal["public", "private", "shared"],
auth_mode: ALIAS_AUTH_TYPE,
):
import runpy

Expand Down Expand Up @@ -362,6 +366,36 @@ def _alias_table(aliases: list[AliasInfo]):
return table


@alias_cli.command("set")
@click.argument("alias", required=True)
@click.argument("revision", required=True)
@click.option(
"--auth",
"auth_mode",
type=click.Choice(ALIAS_AUTH_OPTIONS),
default="private",
)
@click.pass_obj
def alias_set(
client: api.FalServerlessClient,
alias: str,
revision: str,
auth_mode: ALIAS_AUTH_TYPE,
):
with client.connect() as connection:
connection.create_alias(alias, revision, auth_mode)


@alias_cli.command("delete")
@click.argument("alias", required=True)
@click.pass_obj
def alias_delete(client: api.FalServerlessClient, alias: str):
with client.connect() as connection:
application_id = connection.delete_alias(alias)

console.print(f"Deleted alias '{alias}' for application '{application_id}'.")


@alias_cli.command("list")
@click.pass_obj
def alias_list(client: api.FalServerlessClient):
Expand All @@ -377,6 +411,12 @@ def alias_list(client: api.FalServerlessClient):
@click.option("--keep-alive", "-k", type=int)
@click.option("--max-multiplexing", "-m", type=int)
@click.option("--max-concurrency", "-c", type=int)
# TODO: add auth_mode
# @click.option(
# "--auth",
# "auth_mode",
# type=click.Choice(ALIAS_AUTH_OPTIONS),
# )
@click.pass_obj
def alias_update(
client: api.FalServerlessClient,
Expand Down
25 changes: 25 additions & 0 deletions projects/fal/src/fal/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,31 @@ def run(
for partial_result in self.stub.Run(request):
yield from_grpc(partial_result)

def create_alias(
self,
alias: str,
revision: str,
auth_mode: Literal["public", "private", "shared"],
):
if auth_mode == "public":
auth = isolate_proto.ApplicationAuthMode.PUBLIC
elif auth_mode == "shared":
auth = isolate_proto.ApplicationAuthMode.SHARED
else:
auth = isolate_proto.ApplicationAuthMode.PRIVATE

request = isolate_proto.SetAliasRequest(
alias=alias,
revision=revision,
auth_mode=auth,
)
self.stub.SetAlias(request)

def delete_alias(self, alias: str) -> str:
request = isolate_proto.DeleteAliasRequest(alias=alias)
res: isolate_proto.DeleteAliasResult = self.stub.DeleteAlias(request)
return res.revision

def list_aliases(self) -> list[AliasInfo]:
request = isolate_proto.ListAliasesRequest()
response: isolate_proto.ListAliasesResult = self.stub.ListAliases(request)
Expand Down
155 changes: 131 additions & 24 deletions projects/fal/tests/test_apps.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,70 @@
from typing import Generator

import fal
import fal.api as api
import pytest
from fal import apps
from fal.rest_client import REST_CLIENT
import time
from pydantic import BaseModel

from openapi_fal_rest.api.applications import app_metadata


@pytest.fixture(scope="module")
def test_app() -> Generator[str, None, None]:
# Create a temporary app, register it, and return the ID of it.
class Input(BaseModel):
lhs: int
rhs: int
wait_time: int = 0

import time

from fal.cli import _get_user_id
class Output(BaseModel):
result: int


@fal.function(
keep_alive=60,
machine_type="S",
serve=True,
max_concurrency=1,
)
def addition_app(input: Input) -> Output:
print("starting...")
for _ in range(input.wait_time):
print("sleeping...")
time.sleep(1)

class Input(BaseModel):
lhs: int
rhs: int
wait_time: int = 0
return Output(result=input.lhs + input.rhs)


@pytest.fixture(scope="module")
def aliased_app() -> Generator[tuple[str, str], None, None]:
# Create a temporary app, register it, and return the ID of it.

class Output(BaseModel):
result: int
import uuid

@fal.function(
keep_alive=60,
machine_type="S",
serve=True,
max_concurrency=1,
app_alias = str(uuid.uuid4()).replace("-", "")[:10]
app_revision = addition_app.host.register(
func=addition_app.func,
options=addition_app.options,
# random enough
application_name=app_alias,
application_auth_mode="private",
)
def addition_app(input: Input) -> Output:
print("starting...")
for _ in range(input.wait_time):
print("sleeping...")
time.sleep(1)
yield app_revision, app_alias # type: ignore


return Output(result=input.lhs + input.rhs)
@pytest.fixture(scope="module")
def test_app():
# Create a temporary app, register it, and return the ID of it.

app_alias = addition_app.host.register(
from fal.cli import _get_user_id

app_revision = addition_app.host.register(
func=addition_app.func,
options=addition_app.options,
)
user_id = _get_user_id()
yield f"{user_id}-{app_alias}"
yield f"{user_id}-{app_revision}"


def test_app_client(test_app: str):
Expand Down Expand Up @@ -99,3 +119,90 @@ def test_app_openapi_spec_metadata(test_app: str):
openapi_spec: dict = metadata["openapi"]
for key in ["openapi", "info", "paths", "components"]:
assert key in openapi_spec, f"{key} key missing from openapi {openapi_spec}"


def test_app_update_app(aliased_app: tuple[str, str]):
app_revision, app_alias = aliased_app

host: api.FalServerlessHost = addition_app.host # type: ignore
with host._connection as client:
# Get the registered values
res = client.list_aliases()
found = next(filter(lambda alias: alias.alias == app_alias, res), None)
assert found, f"Could not find app {app_alias} in {res}"
assert found.revision == app_revision

with host._connection as client:
new_keep_alive = found.keep_alive + 1
new_max_concurrency = found.max_concurrency + 1
new_max_multiplexing = found.max_multiplexing + 1

res = client.update_application(
application_name=app_alias,
keep_alive=new_keep_alive,
max_concurrency=new_max_concurrency,
max_multiplexing=new_max_multiplexing,
)
assert res.alias == app_alias
assert res.keep_alive == new_keep_alive
assert res.max_concurrency == new_max_concurrency
assert res.max_multiplexing == new_max_multiplexing

with host._connection as client:
new_keep_alive = new_keep_alive + 1
res = client.update_application(
application_name=app_alias,
keep_alive=new_keep_alive,
)
assert res.alias == app_alias
assert res.keep_alive == new_keep_alive
assert res.max_concurrency == new_max_concurrency
assert res.max_multiplexing == new_max_multiplexing

with host._connection as client:
new_max_concurrency = new_max_concurrency + 1
res = client.update_application(
application_name=app_alias,
max_concurrency=new_max_concurrency,
)
assert res.alias == app_alias
assert res.keep_alive == new_keep_alive
assert res.max_concurrency == new_max_concurrency
assert res.max_multiplexing == new_max_multiplexing


def test_app_set_delete_alias(aliased_app: tuple[str, str]):
app_revision, app_alias = aliased_app

host: api.FalServerlessHost = addition_app.host # type: ignore

with host._connection as client:
# Get the registered values
res = client.list_aliases()
found = next(filter(lambda alias: alias.alias == app_alias, res), None)
assert found, f"Could not find app {app_alias} in {res}"
assert found.revision == app_revision
assert found.auth_mode == "private"

new_app_alias = f"{app_alias}-new"
with host._connection as client:
# Get the registered values
res = client.create_alias(new_app_alias, app_revision, "public")

with host._connection as client:
# Get the registered values
res = client.list_aliases()
found = next(filter(lambda alias: alias.alias == new_app_alias, res), None)
assert found, f"Could not find app {app_alias} in {res}"
assert found.revision == app_revision
assert found.auth_mode == "public"

with host._connection as client:
res = client.delete_alias(alias=app_alias)
assert res == app_revision

with host._connection as client:
# Get the registered values
res = client.list_aliases()
found = next(filter(lambda alias: alias.alias == app_alias, res), None)
assert not found, f"Found app {app_alias} in {res} after deletion"

0 comments on commit c6de55c

Please sign in to comment.