@@ -1247,6 +1247,143 @@ def test_run_inference_with_agent_engine_with_response_column_raises_error(
             "'intermediate_events' or 'response' columns"
         ) in str(excinfo.value)

+    @mock.patch.object(_evals_utils, "EvalDatasetLoader")
+    @mock.patch(
+        "vertexai._genai._evals_common.InMemorySessionService"
+    )
+    @mock.patch("vertexai._genai._evals_common.Runner")
+    @mock.patch("vertexai._genai._evals_common.LlmAgent")
+    def test_run_inference_with_local_agent(
+        self,
+        mock_llm_agent,
+        mock_runner,
+        mock_session_service,
+        mock_eval_dataset_loader,
+    ):
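+        # Two-row eval dataset; each prompt carries its own session inputs.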
+        mock_df = pd.DataFrame(
+            {
+                "prompt": ["agent prompt", "agent prompt 2"],
+                "session_inputs": [
+                    {
+                        "user_id": "123",
+                        "state": {"a": "1"},
+                    },
+                    {
+                        "user_id": "456",
+                        "state": {"b": "2"},
+                    },
+                ],
+            }
+        )
+        mock_eval_dataset_loader.return_value.load.return_value = mock_df.to_dict(
+            orient="records"
+        )
+
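+        # Stub the agent, runner, and session service; each mocked run streams
+        # an intermediate event followed by a final response event.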
+        mock_agent_instance = mock.Mock()
+        mock_llm_agent.return_value = mock_agent_instance
+        mock_session_service.return_value.create_session = mock.AsyncMock()
+        mock_runner_instance = mock_runner.return_value
+        stream_run_return_value_1 = [
+            mock.Mock(
+                model_dump=lambda: {
+                    "id": "1",
+                    "content": {"parts": [{"text": "intermediate1"}]},
+                    "timestamp": 123,
+                    "author": "model",
+                }
+            ),
+            mock.Mock(
+                model_dump=lambda: {
+                    "id": "2",
+                    "content": {"parts": [{"text": "agent response"}]},
+                    "timestamp": 124,
+                    "author": "model",
+                }
+            ),
+        ]
+        stream_run_return_value_2 = [
+            mock.Mock(
+                model_dump=lambda: {
+                    "id": "3",
+                    "content": {"parts": [{"text": "intermediate2"}]},
+                    "timestamp": 125,
+                    "author": "model",
+                }
+            ),
+            mock.Mock(
+                model_dump=lambda: {
+                    "id": "4",
+                    "content": {"parts": [{"text": "agent response 2"}]},
+                    "timestamp": 126,
+                    "author": "model",
+                }
+            ),
+        ]
+
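+        # Each run_async call yields its mocked events as an async stream.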
+        async def async_iterator(items):
+            for item in items:
+                yield item
+
+        mock_runner_instance.run_async.side_effect = [
+            async_iterator(stream_run_return_value_1),
+            async_iterator(stream_run_return_value_2),
+        ]
+
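+        # Run inference against the local agent rather than an agent engine.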
+        inference_result = self.client.evals.run_inference(
+            agent=mock_agent_instance,
+            src=mock_df,
+        )
+
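+        # Expect one session and one runner per dataset row.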
+        mock_eval_dataset_loader.return_value.load.assert_called_once_with(mock_df)
+        assert mock_session_service.call_count == 2
+        mock_runner.assert_called_with(
+            agent=mock_agent_instance,
+            app_name="local agent run",
+            session_service=mock_session_service.return_value,
+        )
+        assert mock_runner.call_count == 2
+        assert mock_runner_instance.run_async.call_count == 2
+
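+        # The last event of each run becomes "response"; earlier events are
+        # kept under "intermediate_events".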
+        pd.testing.assert_frame_equal(
+            inference_result.eval_dataset_df,
+            pd.DataFrame(
+                {
+                    "prompt": ["agent prompt", "agent prompt 2"],
+                    "session_inputs": [
+                        {
+                            "user_id": "123",
+                            "state": {"a": "1"},
+                        },
+                        {
+                            "user_id": "456",
+                            "state": {"b": "2"},
+                        },
+                    ],
+                    "intermediate_events": [
+                        [
+                            {
+                                "event_id": "1",
+                                "content": {"parts": [{"text": "intermediate1"}]},
+                                "creation_timestamp": 123,
+                                "author": "model",
+                            }
+                        ],
+                        [
+                            {
+                                "event_id": "3",
+                                "content": {"parts": [{"text": "intermediate2"}]},
+                                "creation_timestamp": 125,
+                                "author": "model",
+                            }
+                        ],
+                    ],
+                    "response": ["agent response", "agent response 2"],
+                }
+            ),
+        )
+        assert inference_result.candidate_name is None
+        assert inference_result.gcs_source is None
+
     def test_run_inference_with_litellm_string_prompt_format(
         self,
         mock_api_client_fixture,
@@ -1599,6 +1736,7 @@ def test_run_agent_internal_success(self, mock_run_agent):
         result_df = _evals_common._run_agent_internal(
             api_client=mock_api_client,
             agent_engine=mock_agent_engine,
+            agent=None,
             prompt_dataset=prompt_dataset,
         )

@@ -1629,6 +1767,7 @@ def test_run_agent_internal_error_response(self, mock_run_agent):
         result_df = _evals_common._run_agent_internal(
             api_client=mock_api_client,
             agent_engine=mock_agent_engine,
+            agent=None,
             prompt_dataset=prompt_dataset,
         )

@@ -1655,6 +1794,7 @@ def test_run_agent_internal_malformed_event(self, mock_run_agent):
         result_df = _evals_common._run_agent_internal(
             api_client=mock_api_client,
             agent_engine=mock_agent_engine,
+            agent=None,
             prompt_dataset=prompt_dataset,
         )
         assert "response" in result_df.columns