Skip to content

Commit abb2a63

Browse files
committed
merge with master + tests + fixes and scripts
1 parent c2d4e74 commit abb2a63

File tree

6 files changed

+180
-4
lines changed

6 files changed

+180
-4
lines changed

ai_engine_sdk/client.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,10 +273,11 @@ async def delete(self):
273273

274274
async def execute_function(self, function_ids: list[str], objective: str, context: str|None = None):
275275
await self._submit_message(
276-
payload=ApiUserMessageExecuteFunctions.parse_obj({
276+
payload=ApiUserMessageExecuteFunctions.model_validate({
277277
"functions": function_ids,
278278
"objective": objective,
279-
"context": context or ""
279+
"context": context or "",
280+
'session_id': self.session_id,
280281
})
281282
)
282283

@@ -285,6 +286,7 @@ def __init__(self, api_key: str, options: Optional[dict] = None):
285286
self._api_base_url = options.get('api_base_url') if options and 'api_base_url' in options else default_api_base_url
286287
self._api_key = api_key
287288

289+
288290
####
289291
# Function groups
290292
####
@@ -391,7 +393,7 @@ async def get_functions_by_function_group(self, function_group_id: str) -> list[
391393
if "functions" in raw_response:
392394
list(
393395
map(
394-
lambda function_name: FunctionGroupFunctions.parse_obj({"name": function_name}),
396+
lambda function_name: FunctionGroupFunctions.model_validate({"name": function_name}),
395397
raw_response["functions"]
396398
)
397399
)

examples/execute_function.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import argparse
2+
import asyncio
3+
import os
4+
from pprint import pprint
5+
6+
from faker.utils.decorators import lowercase
7+
8+
from ai_engine_sdk import AiEngine, FunctionGroup, ApiBaseMessage
9+
from ai_engine_sdk.client import Session
10+
from tests.conftest import function_groups
11+
12+
13+
async def main(
14+
target_environment: str,
15+
agentverse_api_key: str,
16+
function_uuid: str,
17+
function_group_uuid: str
18+
):
19+
# Request from cli args.
20+
options = {}
21+
if target_environment:
22+
options = {"api_base_url": target_environment}
23+
24+
ai_engine = AiEngine(api_key=agentverse_api_key, options=options)
25+
26+
session: Session = await ai_engine.create_session(function_group=function_group_uuid)
27+
await session.execute_function(function_ids=[function_uuid], objective="", context="")
28+
29+
try:
30+
empty_count = 0
31+
session_ended = False
32+
33+
print("Waiting for execution:")
34+
while empty_count < 100:
35+
messages: list[ApiBaseMessage] = await session.get_messages()
36+
if messages:
37+
pprint(messages)
38+
if any((msg.type.lower() == "stop" for msg in messages)):
39+
print("DONE")
40+
break
41+
if len(messages) % 10 == 0:
42+
print("Wait...")
43+
if len(messages) == 0:
44+
empty_count += 1
45+
else:
46+
empty_count = 0
47+
48+
49+
except Exception as ex:
50+
pprint(ex)
51+
raise
52+
53+
if __name__ == '__main__':
54+
from dotenv import load_dotenv
55+
load_dotenv()
56+
api_key = os.getenv("AV_API_KEY", "")
57+
58+
parser = argparse.ArgumentParser()
59+
parser.add_argument(
60+
"-e",
61+
"--target_environment",
62+
type=str,
63+
required=False,
64+
help="The target environment: staging, localhost, production... You need to explicitly add the domain. By default it will be production."
65+
)
66+
parser.add_argument(
67+
"-fg",
68+
"--function_group_uuid",
69+
type=str,
70+
required=True,
71+
)
72+
parser.add_argument(
73+
"-f",
74+
"--function_uuid",
75+
type=str,
76+
required=True,
77+
)
78+
args = parser.parse_args()
79+
80+
result = asyncio.run(
81+
main(
82+
agentverse_api_key=api_key,
83+
target_environment=args.target_environment,
84+
function_group_uuid=args.function_group_uuid,
85+
function_uuid=args.function_uuid
86+
)
87+
)
88+
pprint(result)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import argparse
2+
import asyncio
3+
import os
4+
from pprint import pprint
5+
6+
from ai_engine_sdk import FunctionGroup, AiEngine
7+
from tests.integration.test_ai_engine_client import api_key
8+
9+
10+
async def main(
11+
function_group_name: str,
12+
agentverse_api_key: str,
13+
target_environment: str | None = None,
14+
):
15+
# Request from cli args.
16+
options = {}
17+
if target_environment:
18+
options = {"api_base_url": target_environment}
19+
20+
ai_engine: AiEngine = AiEngine(api_key=agentverse_api_key, options=options)
21+
function_groups: list[FunctionGroup] = await ai_engine.get_function_groups()
22+
23+
target_function_group = next((g for g in function_groups if g.name == function_group_name), None)
24+
if target_function_group is None:
25+
raise Exception(f'Could not find "{target_function_group}" function group.')
26+
27+
return await ai_engine.get_functions_by_function_group(function_group_id=target_function_group.uuid)
28+
29+
30+
31+
if __name__ == "__main__":
32+
from dotenv import load_dotenv
33+
load_dotenv()
34+
api_key = os.getenv("AV_API_KEY", "")
35+
36+
# Parse CLI arguments
37+
parser = argparse.ArgumentParser()
38+
39+
parser.add_argument(
40+
"-e",
41+
"--target_environment",
42+
type=str,
43+
required=False,
44+
help="The target environment: staging, localhost, production... You need to explicitly add the domain. By default it will be production."
45+
)
46+
parser.add_argument(
47+
"-fgn",
48+
"--fg_name",
49+
type=str,
50+
required=True,
51+
)
52+
args = parser.parse_args()
53+
54+
target_environment = args.target_environment
55+
56+
res = asyncio.run(
57+
main(
58+
agentverse_api_key=api_key,
59+
function_group_name=args.fg_name,
60+
target_environment=args.target_environment
61+
)
62+
)
63+
pprint(res)

tests/conftest.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,17 @@ async def function_groups(ai_engine_client) -> list[FunctionGroup]:
3434
# session: Session = await ai_engine_client.create_session(
3535
# function_group=function_groups, opts={"model": "next-gen"}
3636
# )
37-
# return session
37+
# return session
38+
39+
40+
@pytest.fixture(scope="session")
41+
def valid_public_function_uuid() -> str:
42+
# TODO: Do it programmatically (when test fails bc of it will be good moment)
43+
# 'Cornerstone Software' from Public fg and staging
44+
return "312712ae-eb70-42f7-bb5a-ad21ce6d73c3"
45+
46+
47+
@pytest.fixture(scope="session")
48+
def public_function_group() -> FunctionGroup:
49+
# TODO: Do it programmatically (when test fails bc of it will be good moment)
50+
return FunctionGroup(uuid="e504eabb-4bc7-458d-aa8c-7c3748f8952c", name="Public", isPrivate=False)

tests/integration/test_ai_engine_client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,16 @@ async def test_create_session(self, ai_engine_client: AiEngine):
6060
# await ai_engine_client.delete_function_group()
6161

6262

63+
@pytest.mark.asyncio
64+
async def test_execute_function(self, ai_engine_client: AiEngine, public_function_group: FunctionGroup, valid_public_function_uuid: str):
65+
session: Session = await ai_engine_client.create_session(function_group=public_function_group.uuid)
66+
result = await session.execute_function(
67+
function_ids=[valid_public_function_uuid],
68+
objective="Test software",
69+
context=""
70+
)
71+
72+
6373
@pytest.mark.asyncio
6474
async def test_create_function_group_and_list_them(self, ai_engine_client: AiEngine):
6575
name = fake.company()

0 commit comments

Comments
 (0)