Commit b1e8dcf

Enable tool use
1 parent d23e4f9 commit b1e8dcf

File tree

4 files changed: +463 -102 lines changed

bindings/ceylon/ceylon/llm/agent.py (+167 -16)
@@ -5,12 +5,19 @@
 import asyncio
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Dict, Any
+from typing import Dict, Any, Optional, List, Sequence

 from pydantic import BaseModel

 from ceylon.llm.models import Model, ModelSettings, ModelMessage
-from ceylon.llm.models.support.messages import MessageRole, TextPart
+from ceylon.llm.models.support.messages import (
+    MessageRole,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    ModelMessagePart
+)
+from ceylon.llm.models.support.tools import ToolDefinition
 from ceylon.processor.agent import ProcessWorker
 from ceylon.processor.data import ProcessRequest

@@ -30,6 +37,8 @@ class LLMConfig(BaseModel):
     retry_attempts: int = 3
     retry_delay: float = 1.0
     timeout: float = 30.0
+    tools: Optional[Sequence[ToolDefinition]] = None
+    parallel_tool_calls: Optional[int] = None

     class Config:
         arbitrary_types_allowed = True
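
For orientation, a minimal sketch of how the new tools field might be populated. The ToolDefinition keyword arguments are inferred from the attributes this commit reads (name, description, parameters_json_schema, and an awaitable function); the actual constructor signature is not shown in this diff, and the example tool is invented for illustration.

# Hypothetical usage sketch -- field names inferred from this diff, not confirmed by it.
from ceylon.llm.models.support.tools import ToolDefinition

async def get_time(timezone: str) -> str:
    # Toy tool body for illustration only.
    return f"The current time in {timezone} is 12:00"

config = LLMConfig(
    system_prompt="You are a helpful assistant.",  # other fields left at their defaults
    tools=[
        ToolDefinition(
            name="get_time",
            description="Get the current time in a given timezone",
            parameters_json_schema={
                "type": "object",
                "properties": {"timezone": {"type": "string"}},
                "required": ["timezone"],
            },
            function=get_time,  # awaited as `await tool.function(**part.args)`
        )
    ],
    parallel_tool_calls=None,  # forwarded to ModelSettings in the next hunk
)
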
@@ -58,33 +67,175 @@ def __init__(
         self.response_cache: Dict[str, LLMResponse] = {}
         self.processing_lock = asyncio.Lock()

-        # Initialize model context with settings
+        # Initialize model context with settings and tools
         self.model_context = self.llm_model.create_context(
             settings=ModelSettings(
                 temperature=config.temperature,
-                max_tokens=config.max_tokens
-            )
+                max_tokens=config.max_tokens,
+                parallel_tool_calls=config.parallel_tool_calls
+            ),
+            tools=config.tools or []
         )

-    async def _processor(self, request: ProcessRequest, time: int):
+    async def _process_tool_calls(
+            self,
+            message_parts: List[ModelMessagePart]
+    ) -> List[ModelMessagePart]:
+        """Process any tool calls in the message parts and return updated parts."""
+        processed_parts = []
+
+        for part in message_parts:
+            if isinstance(part, ToolCallPart):
+                try:
+                    # Find the corresponding tool
+                    tool = next(
+                        (t for t in self.config.tools or []
+                         if t.name == part.tool_name),
+                        None
+                    )
+
+                    if tool:
+                        # Execute the tool
+                        result = await tool.function(**part.args)
+
+                        # Add the tool return
+                        processed_parts.append(
+                            ToolReturnPart(
+                                tool_name=part.tool_name,
+                                content=result
+                            )
+                        )
+                    else:
+                        # Tool not found - add error message
+                        processed_parts.append(
+                            TextPart(
+                                text=f"Error: Tool '{part.tool_name}' not found"
+                            )
+                        )
+                except Exception as e:
+                    # Handle tool execution error
+                    processed_parts.append(
+                        TextPart(
+                            text=f"Error executing tool '{part.tool_name}': {str(e)}"
+                        )
+                    )
+            else:
+                processed_parts.append(part)
+
+        return processed_parts
+
+    async def _process_conversation(
+            self,
+            messages: List[ModelMessage]
+    ) -> List[ModelMessage]:
+        """Process a conversation, handling tool calls as needed."""
+        processed_messages = []
+
+        for message in messages:
+            if message.role == MessageRole.ASSISTANT:
+                # Process any tool calls in assistant messages
+                processed_parts = await self._process_tool_calls(message.parts)
+                processed_messages.append(
+                    ModelMessage(
+                        role=message.role,
+                        parts=processed_parts
+                    )
+                )
+            else:
+                processed_messages.append(message)
+
+        return processed_messages
+
+    def _parse_request_data(self, data: Any) -> str:
+        """Parse the request data into a string format."""
+        if isinstance(data, str):
+            return data
+        elif isinstance(data, dict):
+            return data.get("request", str(data))
+        else:
+            return str(data)
+
+    async def _processor(self, request: ProcessRequest, time: int) -> tuple[str, Dict[str, Any]]:
+        """Process a request using the LLM model."""
+        # Initialize conversation with system prompt
         message_list = [
             ModelMessage(
                 role=MessageRole.SYSTEM,
-                parts=[
-                    TextPart(text=self.config.system_prompt)
-                ]
-            ),
+                parts=[TextPart(text=self.config.system_prompt)]
+            )
+        ]
+
+        # Add user message
+        user_text = self._parse_request_data(request.data)
+        message_list.append(
             ModelMessage(
                 role=MessageRole.USER,
-                parts=[
-                    TextPart(text=request.data)
-                ]
+                parts=[TextPart(text=user_text)]
             )
-        ]
+        )
+
+        # Track the complete conversation
+        complete_conversation = message_list.copy()
+        final_response = None
+        metadata = {}
+
+        for attempt in range(self.config.retry_attempts):
+            try:
+                # Get model response
+                response, usage = await self.llm_model.request(
+                    message_list,
+                    self.model_context
+                )
+
+                # Add model response to conversation
+                assistant_message = ModelMessage(
+                    role=MessageRole.ASSISTANT,
+                    parts=response.parts
+                )
+                complete_conversation.append(assistant_message)
+
+                # Process any tool calls
+                complete_conversation = await self._process_conversation(
+                    complete_conversation
+                )
+
+                # Extract final text response
+                final_text_parts = [
+                    part.text for part in response.parts
+                    if isinstance(part, TextPart)
+                ]
+                final_response = " ".join(final_text_parts)
+
+                # Update metadata
+                metadata.update({
+                    "usage": {
+                        "requests": usage.requests,
+                        "request_tokens": usage.request_tokens,
+                        "response_tokens": usage.response_tokens,
+                        "total_tokens": usage.total_tokens
+                    },
+                    "attempt": attempt + 1,
+                    "tools_used": [
+                        part.tool_name for part in response.parts
+                        if isinstance(part, ToolCallPart)
+                    ]
+                })
+
+                # If we got a response, break the retry loop
+                if final_response:
+                    break
+
+            except Exception as e:
+                if attempt == self.config.retry_attempts - 1:
+                    raise
+                await asyncio.sleep(self.config.retry_delay)
+
+        if not final_response:
+            raise ValueError("No valid response generated")

-        return await self.llm_model.request(message_list, self.model_context)
+        return final_response, metadata

     async def stop(self) -> None:
         if self.llm_model:
             await self.llm_model.close()
-            await super().stop()
+        await super().stop()
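
_processor previously returned the raw (response, usage) pair from llm_model.request; it now returns a plain text string plus a metadata dict. A sketch of the returned shape, with illustrative values only (the keys mirror the metadata.update() call above):

final_response = "The current time in UTC is 12:00"
metadata = {
    "usage": {
        "requests": 1,
        "request_tokens": 42,
        "response_tokens": 17,
        "total_tokens": 59,
    },
    "attempt": 1,                # retry attempts consumed, 1-based
    "tools_used": ["get_time"],  # tool_name of each ToolCallPart in the response
}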

bindings/ceylon/ceylon/llm/models/__init__.py (+129 -3)
@@ -16,6 +16,27 @@
 from ceylon.llm.models.support.settings import ModelSettings
 from ceylon.llm.models.support.tools import ToolDefinition
 from ceylon.llm.models.support.usage import Usage, UsageLimits
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from types import TracebackType
+from typing import AsyncIterator, Optional, Sequence, Type, Any
+import re
+import json
+
+from ceylon.llm.models.support.http import AsyncHTTPClient, cached_async_http_client
+from ceylon.llm.models.support.messages import (
+    ModelMessage,
+    ModelResponse,
+    StreamedResponse,
+    MessageRole,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    ModelMessagePart
+)
+from ceylon.llm.models.support.settings import ModelSettings
+from ceylon.llm.models.support.tools import ToolDefinition
+from ceylon.llm.models.support.usage import Usage, UsageLimits


 @dataclass
@@ -27,7 +48,13 @@ class ModelContext:


 class Model(ABC):
-    """Base class for all language model implementations"""
+    """Base class for all language model implementations with tool support"""
+
+    # Regex pattern for extracting tool calls - can be overridden by subclasses
+    TOOL_CALL_PATTERN = re.compile(
+        r'<tool_call>(?P<tool_json>.*?)</tool_call>',
+        re.DOTALL
+    )

     def __init__(
             self,
@@ -147,11 +174,108 @@ def _check_usage_limits(self, usage: Usage, limits: UsageLimits) -> None:
             raise UsageLimitExceeded(
                 f"Request limit {limits.request_limit} exceeded"
             )
-        if limits.total_tokens and usage.total_tokens >= limits.total_tokens:
+        if limits.request_tokens_limit and usage.request_tokens >= limits.request_tokens_limit:
+            raise UsageLimitExceeded(
+                f"Request tokens limit {limits.request_tokens_limit} exceeded"
+            )
+        if limits.response_tokens_limit and usage.response_tokens >= limits.response_tokens_limit:
             raise UsageLimitExceeded(
-                f"Total token limit {limits.total_tokens} exceeded"
+                f"Response tokens limit {limits.response_tokens_limit} exceeded"
             )

+    def _format_tool_descriptions(self, tools: Sequence[ToolDefinition]) -> str:
+        """Format tool descriptions for system message.
+
+        Args:
+            tools: Sequence of tool definitions
+
+        Returns:
+            Formatted tool descriptions string
+        """
+        if not tools:
+            return ""
+
+        tool_descriptions = []
+        for tool in tools:
+            desc = f"- {tool.name}: {tool.description}\n"
+            desc += f" Parameters: {json.dumps(tool.parameters_json_schema)}"
+            tool_descriptions.append(desc)
+
+        return (
+            "You have access to the following tools:\n\n"
+            f"{chr(10).join(tool_descriptions)}\n\n"
+            "To use a tool, respond with XML tags like this:\n"
+            "<tool_call>{\"tool_name\": \"tool_name\", \"args\": {\"arg1\": \"value1\"}}</tool_call>\n"
+            "Wait for the tool result before continuing."
+        )
+
+    def _parse_tool_call(self, match: re.Match) -> Optional[ToolCallPart]:
+        """Parse a tool call match into a ToolCallPart.
+
+        Args:
+            match: Regex match object containing tool call JSON
+
+        Returns:
+            ToolCallPart if valid, None if invalid
+        """
+        try:
+            tool_data = json.loads(match.group('tool_json'))
+            if isinstance(tool_data, dict) and 'tool_name' in tool_data and 'args' in tool_data:
+                return ToolCallPart(
+                    tool_name=tool_data['tool_name'],
+                    args=tool_data['args']
+                )
+        except (json.JSONDecodeError, KeyError):
+            pass
+        return None
+
+    def _parse_response(self, text: str) -> list[ModelMessagePart]:
+        """Parse response text into message parts.
+
+        Args:
+            text: Raw response text from model
+
+        Returns:
+            List of ModelMessagePart objects
+        """
+        parts = []
+        current_text = []
+        last_end = 0
+
+        # Find all tool calls in the response
+        for match in self.TOOL_CALL_PATTERN.finditer(text):
+            # Add any text before the tool call
+            if match.start() > last_end:
+                prefix_text = text[last_end:match.start()].strip()
+                if prefix_text:
+                    current_text.append(prefix_text)
+
+            # Parse and add the tool call
+            tool_call = self._parse_tool_call(match)
+            if tool_call:
+                # If we have accumulated text, add it first
+                if current_text:
+                    parts.append(TextPart(text=' '.join(current_text)))
+                    current_text = []
+                parts.append(tool_call)
+            else:
+                # If tool call parsing fails, treat it as regular text
+                current_text.append(match.group(0))
+
+            last_end = match.end()
+
+        # Add any remaining text after the last tool call
+        if last_end < len(text):
+            remaining = text[last_end:].strip()
+            if remaining:
+                current_text.append(remaining)
+
+        # Add any accumulated text as final part
+        if current_text:
+            parts.append(TextPart(text=' '.join(current_text)))
+
+        return parts
+

 class UsageLimitExceeded(Exception):
     """Raised when usage limits are exceeded"""
@@ -178,10 +302,12 @@ def cached_async_http_client(timeout: int = 600, connect: int = 5,
     The default timeouts match those of OpenAI,
     see <https://github.com/openai/openai-python/blob/v1.54.4/src/openai/_constants.py#L9>.
     """
+
     def factory() -> httpx.AsyncClient:
         return httpx.AsyncClient(
             headers={"User-Agent": get_user_agent()},
             timeout=timeout,
             base_url=base_url
         )
+
     return factory
