Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions ms_agent/llm/anthropic_llm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import inspect
from typing import Any, Dict, Generator, Iterator, List, Optional, Union

import json5
from ms_agent.llm import LLM
from ms_agent.llm.utils import Message, Tool, ToolCall
from ms_agent.utils import assert_package_exist, retry
Expand Down Expand Up @@ -109,10 +108,20 @@ def _call_llm(self,
if formatted_messages[0]['role'] == 'system':
system = formatted_messages[0]['content']
formatted_messages = formatted_messages[1:]

max_tokens = kwargs.pop('max_tokens', 16000)
extra_body = kwargs.get('extra_body', {})
enable_thinking = extra_body.get('enable_thinking', False)
thinking_budget = extra_body.get('thinking_budget', max_tokens)

params = {
'model': self.model,
'messages': formatted_messages,
'max_tokens': kwargs.pop('max_tokens', 1024),
'max_tokens': max_tokens,
'thinking': {
'type': 'enabled' if enable_thinking else 'disabled',
'budget_tokens': thinking_budget
}
}

if system:
Expand Down Expand Up @@ -163,15 +172,22 @@ def _stream_format_output_message(self,
)
tool_call_id_map = {} # index -> tool_call_id (用于去重 yield)
with stream_manager as stream:
full_content = ''
full_thinking = ''
for event in stream:
event_type = getattr(event, 'type')
if event_type == 'message_start':
msg = event.message
current_message.id = msg.id
tool_call_id_map = {}
yield current_message
elif event_type == 'text':
current_message.content = event.snapshot
elif event_type == 'content_block_delta':
if event.delta.type == 'thinking_delta':
full_thinking += event.delta.thinking
current_message.reasoning_content = full_thinking
elif event.delta.type == 'text_delta':
full_content += event.delta.text
current_message.content = full_content
yield current_message
elif event_type == 'message_stop':
final_msg = getattr(event, 'message')
Expand Down
29 changes: 9 additions & 20 deletions tests/llm/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from modelscope.utils.test_utils import test_level

API_CALL_MAX_TOKEN = 50
API_CALL_MAX_TOKEN = 500


class OpenaiLLM(unittest.TestCase):
Expand Down Expand Up @@ -124,34 +124,23 @@ def test_tool_no_stream(self):
print(res)
assert (len(res.tool_calls))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_agent_multi_round(self):
import asyncio

async def main():
agent = LLMAgent(config=self.conf, mcp_config=self.mcp_config)
if hasattr(agent.config, 'callbacks'):
agent.config.callbacks.remove('input_callback') # noqa
res = await agent.run('访问www.baidu.com')
print(res)
assert ('robots.txt' in res[-1].content)

asyncio.run(main())

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_stream_agent_multi_round(self):
def test_stream_agent_multi_round_with_thinking(self):
import asyncio
from copy import deepcopy

async def main():
conf2 = deepcopy(self.conf)
conf2.generation_config.stream = True
conf2.llm.model = 'Qwen/Qwen3-235B-A22B'
conf2.generation_config.extra_body.enable_thinking = True
agent = LLMAgent(config=conf2, mcp_config=self.mcp_config)
if hasattr(agent.config, 'callbacks'):
agent.config.callbacks.remove('input_callback') # noqa
res = await agent.run('访问www.baidu.com')
print('res:', res)
assert ('robots.txt' in res[-1].content)
res = await agent.run('访问www.baidu.com', stream=True)
async for chunk in res:
print('res: ', chunk)
assert ('robots.txt' in chunk[-1].content)
assert (chunk[-1].reasoning_content)

asyncio.run(main())

Expand Down
Loading