Skip to content

Commit 66ae92d

Browse files
shawn-yang-googlecopybara-github
authored andcommitted
chore: Make runnable_name required.
PiperOrigin-RevId: 731445881
1 parent 632730c commit 66ae92d

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

tests/unit/vertex_ag2/test_reasoning_engine_templates_ag2.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
_TEST_LOCATION = "us-central1"
3232
_TEST_PROJECT = "test-project"
3333
_TEST_MODEL = "gemini-1.0-pro"
34+
_TEST_RUNNABLE_NAME = "test-runnable"
3435
_TEST_SYSTEM_INSTRUCTION = "You are a helpful bot."
3536

3637

@@ -127,8 +128,11 @@ def teardown_method(self):
127128
initializer.global_pool.shutdown(wait=True)
128129

129130
def test_initialization(self):
130-
agent = reasoning_engines.AG2Agent(model=_TEST_MODEL)
131+
agent = reasoning_engines.AG2Agent(
132+
model=_TEST_MODEL, runnable_name=_TEST_RUNNABLE_NAME
133+
)
131134
assert agent._model_name == _TEST_MODEL
135+
assert agent._runnable_name == _TEST_RUNNABLE_NAME
132136
assert agent._project == _TEST_PROJECT
133137
assert agent._location == _TEST_LOCATION
134138
assert agent._runnable is None
@@ -140,6 +144,7 @@ def test_initialization_with_tools(self, autogen_tools_mock):
140144
]
141145
agent = reasoning_engines.AG2Agent(
142146
model=_TEST_MODEL,
147+
runnable_name=_TEST_RUNNABLE_NAME,
143148
system_instruction=_TEST_SYSTEM_INSTRUCTION,
144149
tools=tools,
145150
runnable_builder=lambda **kwargs: kwargs,
@@ -154,6 +159,7 @@ def test_initialization_with_tools(self, autogen_tools_mock):
154159
def test_set_up(self):
155160
agent = reasoning_engines.AG2Agent(
156161
model=_TEST_MODEL,
162+
runnable_name=_TEST_RUNNABLE_NAME,
157163
runnable_builder=lambda **kwargs: kwargs,
158164
)
159165
assert agent._runnable is None
@@ -163,6 +169,7 @@ def test_set_up(self):
163169
def test_clone(self):
164170
agent = reasoning_engines.AG2Agent(
165171
model=_TEST_MODEL,
172+
runnable_name=_TEST_RUNNABLE_NAME,
166173
runnable_builder=lambda **kwargs: kwargs,
167174
)
168175
agent.set_up()
@@ -176,6 +183,7 @@ def test_clone(self):
176183
def test_query(self, dataclasses_asdict_mock):
177184
agent = reasoning_engines.AG2Agent(
178185
model=_TEST_MODEL,
186+
runnable_name=_TEST_RUNNABLE_NAME,
179187
)
180188
agent._runnable = mock.Mock()
181189
mocks = mock.Mock()
@@ -202,6 +210,7 @@ def test_enable_tracing(
202210
):
203211
agent = reasoning_engines.AG2Agent(
204212
model=_TEST_MODEL,
213+
runnable_name=_TEST_RUNNABLE_NAME,
205214
enable_tracing=True,
206215
)
207216
assert agent._enable_tracing is True
@@ -220,5 +229,6 @@ def test_raise_untyped_input_args(self, vertexai_init_mock):
220229
with pytest.raises(TypeError, match=r"has untyped input_arg"):
221230
reasoning_engines.AG2Agent(
222231
model=_TEST_MODEL,
232+
runnable_name=_TEST_RUNNABLE_NAME,
223233
tools=[_return_input_no_typing],
224234
)

vertexai/preview/reasoning_engines/templates/ag2.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def _prepare_runnable_kwargs(
7171
if "system_message" not in runnable_kwargs and system_instruction:
7272
runnable_kwargs["system_message"] = system_instruction
7373

74-
if "name" not in runnable_kwargs and runnable_name:
74+
if "name" not in runnable_kwargs:
7575
runnable_kwargs["name"] = runnable_name
7676

7777
if "llm_config" not in runnable_kwargs:
@@ -146,11 +146,11 @@ class AG2Agent:
146146
def __init__(
147147
self,
148148
model: str,
149+
runnable_name: str,
149150
*,
150151
api_type: Optional[str] = None,
151152
llm_config: Optional[Mapping[str, Any]] = None,
152153
system_instruction: Optional[str] = None,
153-
runnable_name: Optional[str] = None,
154154
runnable_kwargs: Optional[Mapping[str, Any]] = None,
155155
runnable_builder: Optional[Callable[..., "ConversableAgent"]] = None,
156156
tools: Optional[Sequence[Callable[..., Any]]] = None,
@@ -201,6 +201,11 @@ def __init__(
201201
Required. The name of the model (e.g. "gemini-1.0-pro").
202202
Used to create a default `llm_config` if one is not provided.
203203
This parameter is ignored if `llm_config` is provided.
204+
runnable_name (str):
205+
Required. The name of the runnable.
206+
This name is used as the default `runnable_kwargs["name"]`
207+
unless `runnable_kwargs` already contains a "name", in which
208+
case the provided `runnable_kwargs["name"]` will be used.
204209
api_type (str):
205210
Optional. The API type to use for the language model.
206211
Used to create a default `llm_config` if one is not provided.
@@ -219,11 +224,6 @@ def __init__(
219224
`runnable_kwargs["system_message"]` unless `runnable_kwargs`
220225
already contains a "system_message", in which case the provided
221226
`runnable_kwargs["system_message"]` will be used.
222-
runnable_name (str):
223-
Optional. The name of the runnable.
224-
This name is used as the default `runnable_kwargs["name"]`
225-
unless `runnable_kwargs` already contains a "name", in which
226-
case the provided `runnable_kwargs["name"]` will be used.
227227
runnable_kwargs (Mapping[str, Any]):
228228
Optional. Additional keyword arguments for the constructor of
229229
the runnable. Details of the kwargs can be found in

0 commit comments

Comments
 (0)