Skip to content

Commit 2202f61

Browse files
benshukbenshukJosephasafg
authored
fix: 🔧 introduce MaestroMessage (#298)
* refactor: reorganize imports by moving maestro models to their own module * refactor: update input type in maestro run methods to use MaestroMessage instead of ChatMessage * test: add parameterized tests for Maestro input formats in maestro.py * feat: add async example for multi-message processing * feat: update async example to handle multi-message input for business proposal * feat: include requirements result in README example for Maestro runs * fix: update content formatting in async_run_multi_messages.py for improved readability --------- Co-authored-by: benshuk <[email protected]> Co-authored-by: Asaf Joseph Gardin <[email protected]>
1 parent fc872d5 commit 2202f61

File tree

8 files changed

+133
-51
lines changed

8 files changed

+133
-51
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ run_result = client.beta.maestro.runs.create_and_poll(
269269
"description": "The poem should rhyme",
270270
},
271271
],
272+
include=["requirements_result"]
272273
)
273274
```
274275

ai21/clients/common/maestro/run.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33
from abc import ABC, abstractmethod
44
from typing import List
55

6-
from ai21.models.chat import ChatMessage
76
from ai21.models.maestro.run import (
8-
Tool,
9-
ToolResources,
10-
RunResponse,
7+
Budget,
118
DEFAULT_RUN_POLL_INTERVAL,
129
DEFAULT_RUN_POLL_TIMEOUT,
13-
Requirement,
14-
Budget,
10+
MaestroMessage,
1511
OutputOptions,
12+
Requirement,
13+
RunResponse,
14+
Tool,
15+
ToolResources,
1616
)
1717
from ai21.types import NOT_GIVEN, NotGiven
1818
from ai21.utils.typing import remove_not_given
@@ -24,7 +24,7 @@ class BaseMaestroRun(ABC):
2424
def _create_body(
2525
self,
2626
*,
27-
input: str | List[ChatMessage],
27+
input: str | List[MaestroMessage],
2828
models: List[str] | NotGiven,
2929
tools: List[Tool] | NotGiven,
3030
tool_resources: ToolResources | NotGiven,
@@ -33,14 +33,9 @@ def _create_body(
3333
include: List[OutputOptions] | NotGiven,
3434
**kwargs,
3535
) -> dict:
36-
if isinstance(input, list):
37-
_input = [{"role": message.role, "content": message.content} for message in input]
38-
else:
39-
_input = input
40-
4136
return remove_not_given(
4237
{
43-
"input": _input,
38+
"input": input,
4439
"models": models,
4540
"tools": tools,
4641
"tool_resources": tool_resources,
@@ -55,7 +50,7 @@ def _create_body(
5550
def create(
5651
self,
5752
*,
58-
input: str | List[ChatMessage],
53+
input: str | List[MaestroMessage],
5954
models: List[str] | NotGiven = NOT_GIVEN,
6055
tools: List[Tool] | NotGiven = NOT_GIVEN,
6156
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
@@ -78,7 +73,7 @@ def poll_for_status(self, *, run_id: str, poll_interval_sec: float, poll_timeout
7873
def create_and_poll(
7974
self,
8075
*,
81-
input: str | List[ChatMessage],
76+
input: str | List[MaestroMessage],
8277
models: List[str] | NotGiven = NOT_GIVEN,
8378
tools: List[Tool] | NotGiven = NOT_GIVEN,
8479
tool_resources: ToolResources | NotGiven = NOT_GIVEN,

ai21/clients/studio/resources/maestro/run.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,17 @@
66

77
from ai21.clients.common.maestro.run import BaseMaestroRun
88
from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource
9-
from ai21.models.chat import ChatMessage
109
from ai21.models.maestro.run import (
11-
Tool,
12-
ToolResources,
13-
RunResponse,
14-
TERMINATED_RUN_STATUSES,
10+
Budget,
1511
DEFAULT_RUN_POLL_INTERVAL,
1612
DEFAULT_RUN_POLL_TIMEOUT,
17-
Requirement,
18-
Budget,
13+
MaestroMessage,
1914
OutputOptions,
15+
Requirement,
16+
RunResponse,
17+
TERMINATED_RUN_STATUSES,
18+
Tool,
19+
ToolResources,
2020
)
2121
from ai21.types import NotGiven, NOT_GIVEN
2222

@@ -25,7 +25,7 @@ class MaestroRun(StudioResource, BaseMaestroRun):
2525
def create(
2626
self,
2727
*,
28-
input: str | List[ChatMessage],
28+
input: str | List[MaestroMessage],
2929
models: List[str] | NotGiven = NOT_GIVEN,
3030
tools: List[Tool] | NotGiven = NOT_GIVEN,
3131
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
@@ -70,7 +70,7 @@ def poll_for_status(self, *, run_id: str, poll_interval_sec: float, poll_timeout
7070
def create_and_poll(
7171
self,
7272
*,
73-
input: str | List[ChatMessage],
73+
input: str | List[MaestroMessage],
7474
models: List[str] | NotGiven = NOT_GIVEN,
7575
tools: List[Tool] | NotGiven = NOT_GIVEN,
7676
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
@@ -101,7 +101,7 @@ class AsyncMaestroRun(AsyncStudioResource, BaseMaestroRun):
101101
async def create(
102102
self,
103103
*,
104-
input: str | List[ChatMessage],
104+
input: str | List[MaestroMessage],
105105
models: List[str] | NotGiven = NOT_GIVEN,
106106
tools: List[Tool] | NotGiven = NOT_GIVEN,
107107
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
@@ -146,7 +146,7 @@ async def poll_for_status(self, *, run_id: str, poll_interval_sec: float, poll_t
146146
async def create_and_poll(
147147
self,
148148
*,
149-
input: str | List[ChatMessage],
149+
input: str | List[MaestroMessage],
150150
models: List[str] | NotGiven = NOT_GIVEN,
151151
tools: List[Tool] | NotGiven = NOT_GIVEN,
152152
tool_resources: ToolResources | NotGiven = NOT_GIVEN,

ai21/models/__init__.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,6 @@
88
ConversationalRagSource,
99
)
1010
from ai21.models.responses.file_response import FileResponse
11-
from ai21.models.maestro.run import (
12-
Requirement,
13-
Budget,
14-
Tool,
15-
ToolResources,
16-
DataSources,
17-
FileSearchResult,
18-
WebSearchResult,
19-
OutputOptions,
20-
)
2111

2212
__all__ = [
2313
"ChatMessage",
@@ -30,12 +20,4 @@
3020
"FileResponse",
3121
"ConversationalRagResponse",
3222
"ConversationalRagSource",
33-
"Requirement",
34-
"Budget",
35-
"Tool",
36-
"ToolResources",
37-
"DataSources",
38-
"FileSearchResult",
39-
"WebSearchResult",
40-
"OutputOptions",
4123
]

ai21/models/maestro/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from ai21.models.maestro.run import (
2+
Budget,
3+
DataSources,
4+
FileSearchResult,
5+
MaestroMessage,
6+
OutputOptions,
7+
Requirement,
8+
Tool,
9+
ToolResources,
10+
WebSearchResult,
11+
)
12+
13+
__all__ = [
14+
"Budget",
15+
"DataSources",
16+
"FileSearchResult",
17+
"MaestroMessage",
18+
"OutputOptions",
19+
"Requirement",
20+
"Tool",
21+
"ToolResources",
22+
"WebSearchResult",
23+
]

ai21/models/maestro/run.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
TERMINATED_RUN_STATUSES: Set[RunStatus] = {"completed", "failed", "requires_action"}
1919

2020

21+
class MaestroMessage(TypedDict):
22+
role: Role
23+
content: str
24+
25+
2126
class Tool(TypedDict):
2227
type: ToolType
2328

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import asyncio
2+
3+
from ai21 import AsyncAI21Client
4+
5+
client = AsyncAI21Client()
6+
7+
8+
async def main():
9+
try:
10+
run_result = await client.beta.maestro.runs.create_and_poll(
11+
input=[
12+
{
13+
"role": "user",
14+
"content": "Write me a memo to my boss suggesting we should pivot to selling parrots. "
15+
"We now have a window cleaning business for high rises",
16+
},
17+
{
18+
"role": "assistant",
19+
"content": "**Subject: Proposal to Pivot Business Focus**\n"
20+
"**Section 1: Current Market Challenges**\n"
21+
"High-rise window cleaning faces increasing competition.\n"
22+
"Profit margins are shrinking due to rising operational costs.\n"
23+
"Weather conditions often disrupt our service schedule.\n"
24+
"**Section 2: New Opportunity in Parrot Sales**\n"
25+
"Parrot sales are booming with pet ownership on the rise.\n"
26+
"Parrots require minimal space and resources compared to current operations.\n"
27+
"Our existing logistics can support this new venture.\n"
28+
"**Section 3: Strategic Advantages**\n"
29+
"Parrots offer higher profit margins than window cleaning.\n"
30+
"Selling parrots expands our customer base.\n"
31+
"This pivot leverages our customer service expertise",
32+
},
33+
{"role": "user", "content": "great, now beef up each section with a few more sentences"},
34+
]
35+
)
36+
37+
print(run_result)
38+
except TimeoutError:
39+
print("The run timed out")
40+
41+
42+
if __name__ == "__main__":
43+
asyncio.run(main())
Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,49 @@
11
import pytest
22

33
from ai21 import AsyncAI21Client
4+
from ai21.models.maestro import MaestroMessage
45

56

67
@pytest.mark.asyncio
78
async def test_maestro__when_upload__should_return_data_sources(): # file_in_library: str):
89
client = AsyncAI21Client()
9-
result = await client.beta.maestro.runs.create_and_poll(
10-
input="When did Einstein receive a Nobel Prize?", tools=[{"type": "file_search"}], include=["data_sources"]
10+
run = await client.beta.maestro.runs.create_and_poll(
11+
input="When did Einstein receive a Nobel Prize?",
12+
tools=[{"type": "file_search"}],
13+
include=["data_sources"],
14+
poll_timeout_sec=200,
1115
)
12-
assert result.status == "completed", "Expected 'completed' status"
13-
assert result.result, "Expected a non-empty answer"
14-
assert result.data_sources, "Expected data sources"
15-
assert len(result.data_sources["file_search"]) > 0, "Expected at least one file search data source"
16-
assert result.data_sources.get("web_search") is None, "Expected no web search data sources"
16+
assert run.status == "completed", f"[RUN {run.id}] Expected 'completed' status"
17+
assert run.result, f"[RUN {run.id}] Expected a non-empty answer"
18+
assert run.data_sources, f"[RUN {run.id}] Expected data sources"
19+
assert len(run.data_sources["file_search"]) > 0, f"[RUN {run.id}] Expected at least one file search data source"
20+
assert run.data_sources.get("web_search") is None, f"[RUN {run.id}] Expected no web search data sources"
21+
22+
23+
@pytest.mark.parametrize(
24+
"input_data,test_description",
25+
[
26+
("What is the capital of France?", "string input"),
27+
(
28+
[MaestroMessage(role="user", content="What is the capital of France?")],
29+
"Use MaestroMessage format in a list",
30+
),
31+
(
32+
[
33+
{"role": "user", "content": "I need help with geography."},
34+
{"role": "assistant", "content": "I'd be happy to help with geography questions."},
35+
{"role": "user", "content": "What is the capital of France?"},
36+
],
37+
"multi-message conversation input",
38+
),
39+
],
40+
)
41+
@pytest.mark.asyncio
42+
async def test_maestro__input_formats__should_accept_string_and_list(input_data, test_description):
43+
"""Test that input can be passed as both string and list of dictionaries."""
44+
client = AsyncAI21Client()
45+
46+
run = await client.beta.maestro.runs.create_and_poll(input=input_data, poll_timeout_sec=200)
47+
48+
assert run.status == "completed", f"[RUN {run.id}] Expected 'completed' status for {test_description}"
49+
assert run.result, f"[RUN {run.id}] Expected a non-empty answer for {test_description}"

0 commit comments

Comments
 (0)