@@ -5,12 +5,19 @@
 import asyncio
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Dict, Any
+from typing import Dict, Any, Optional, List, Sequence

 from pydantic import BaseModel

 from ceylon.llm.models import Model, ModelSettings, ModelMessage
-from ceylon.llm.models.support.messages import MessageRole, TextPart
+from ceylon.llm.models.support.messages import (
+    MessageRole,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    ModelMessagePart
+)
+from ceylon.llm.models.support.tools import ToolDefinition
 from ceylon.processor.agent import ProcessWorker
 from ceylon.processor.data import ProcessRequest

@@ -30,6 +37,8 @@ class LLMConfig(BaseModel):
     retry_attempts: int = 3
     retry_delay: float = 1.0
     timeout: float = 30.0
+    tools: Optional[Sequence[ToolDefinition]] = None
+    parallel_tool_calls: Optional[int] = None

     class Config:
         arbitrary_types_allowed = True
@@ -58,33 +67,175 @@ def __init__(
         self.response_cache: Dict[str, LLMResponse] = {}
         self.processing_lock = asyncio.Lock()

-        # Initialize model context with settings
+        # Initialize model context with settings and tools
         self.model_context = self.llm_model.create_context(
             settings=ModelSettings(
                 temperature=config.temperature,
-                max_tokens=config.max_tokens
-            )
+                max_tokens=config.max_tokens,
+                parallel_tool_calls=config.parallel_tool_calls
+            ),
+            tools=config.tools or []
         )

-    async def _processor(self, request: ProcessRequest, time: int):
+    async def _process_tool_calls(
+            self,
+            message_parts: List[ModelMessagePart]
+    ) -> List[ModelMessagePart]:
+        """Process any tool calls in the message parts and return updated parts."""
+        processed_parts = []
+
+        for part in message_parts:
+            if isinstance(part, ToolCallPart):
+                try:
+                    # Find the corresponding tool
+                    tool = next(
+                        (t for t in self.config.tools or []
+                         if t.name == part.tool_name),
+                        None
+                    )
+
+                    if tool:
+                        # Execute the tool
+                        result = await tool.function(**part.args)
+
+                        # Add the tool return
+                        processed_parts.append(
+                            ToolReturnPart(
+                                tool_name=part.tool_name,
+                                content=result
+                            )
+                        )
+                    else:
+                        # Tool not found - add error message
+                        processed_parts.append(
+                            TextPart(
+                                text=f"Error: Tool '{part.tool_name}' not found"
+                            )
+                        )
+                except Exception as e:
+                    # Handle tool execution error
+                    processed_parts.append(
+                        TextPart(
+                            text=f"Error executing tool '{part.tool_name}': {str(e)}"
+                        )
+                    )
+            else:
+                processed_parts.append(part)
+
+        return processed_parts
+
+    async def _process_conversation(
+            self,
+            messages: List[ModelMessage]
+    ) -> List[ModelMessage]:
+        """Process a conversation, handling tool calls as needed."""
+        processed_messages = []
+
+        for message in messages:
+            if message.role == MessageRole.ASSISTANT:
+                # Process any tool calls in assistant messages
+                processed_parts = await self._process_tool_calls(message.parts)
+                processed_messages.append(
+                    ModelMessage(
+                        role=message.role,
+                        parts=processed_parts
+                    )
+                )
+            else:
+                processed_messages.append(message)
+
+        return processed_messages
+
+    def _parse_request_data(self, data: Any) -> str:
+        """Parse the request data into a string format."""
+        if isinstance(data, str):
+            return data
+        elif isinstance(data, dict):
+            return data.get("request", str(data))
+        else:
+            return str(data)
+
+    async def _processor(self, request: ProcessRequest, time: int) -> tuple[str, Dict[str, Any]]:
+        """Process a request using the LLM model."""
+        # Initialize conversation with system prompt
         message_list = [
             ModelMessage(
                 role=MessageRole.SYSTEM,
-                parts=[
-                    TextPart(text=self.config.system_prompt)
-                ]
-            ),
+                parts=[TextPart(text=self.config.system_prompt)]
+            )
+        ]
+
+        # Add user message
+        user_text = self._parse_request_data(request.data)
+        message_list.append(
             ModelMessage(
                 role=MessageRole.USER,
-                parts=[
-                    TextPart(text=request.data)
-                ]
+                parts=[TextPart(text=user_text)]
             )
-        ]
+        )
+
+        # Track the complete conversation
+        complete_conversation = message_list.copy()
+        final_response = None
+        metadata = {}
+
+        for attempt in range(self.config.retry_attempts):
+            try:
+                # Get model response
+                response, usage = await self.llm_model.request(
+                    message_list,
+                    self.model_context
+                )
+
+                # Add model response to conversation
+                assistant_message = ModelMessage(
+                    role=MessageRole.ASSISTANT,
+                    parts=response.parts
+                )
+                complete_conversation.append(assistant_message)
+
+                # Process any tool calls
+                complete_conversation = await self._process_conversation(
+                    complete_conversation
+                )
+
+                # Extract final text response
+                final_text_parts = [
+                    part.text for part in response.parts
+                    if isinstance(part, TextPart)
+                ]
+                final_response = " ".join(final_text_parts)
+
+                # Update metadata
+                metadata.update({
+                    "usage": {
+                        "requests": usage.requests,
+                        "request_tokens": usage.request_tokens,
+                        "response_tokens": usage.response_tokens,
+                        "total_tokens": usage.total_tokens
+                    },
+                    "attempt": attempt + 1,
+                    "tools_used": [
+                        part.tool_name for part in response.parts
+                        if isinstance(part, ToolCallPart)
+                    ]
+                })
+
+                # If we got a response, break the retry loop
+                if final_response:
+                    break
+
+            except Exception as e:
+                if attempt == self.config.retry_attempts - 1:
+                    raise
+                await asyncio.sleep(self.config.retry_delay)
+
+        if not final_response:
+            raise ValueError("No valid response generated")

-        return await self.llm_model.request(message_list, self.model_context)
+        return final_response, metadata

     async def stop(self) -> None:
         if self.llm_model:
             await self.llm_model.close()
-        await super().stop()
+        await super().stop()
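
Below is a minimal usage sketch for the new tools and parallel_tool_calls config fields. The ToolDefinition keyword arguments shown are assumptions for illustration only; the diff itself guarantees just that each tool exposes a name (matched against part.tool_name) and an async function that is awaited as tool.function(**part.args).

# Usage sketch (hypothetical): wiring a tool into LLMConfig.
# ToolDefinition's constructor kwargs below are assumed, not taken from this diff.
from ceylon.llm.models.support.tools import ToolDefinition


async def lookup_weather(city: str) -> str:
    # Stand-in tool body; a real tool would query an external service.
    return f"22C and clear in {city}"


config = LLMConfig(                  # LLMConfig as extended in this diff
    system_prompt="You are a helpful assistant.",
    tools=[
        ToolDefinition(              # assumed constructor shape
            name="lookup_weather",   # matched against part.tool_name
            function=lookup_weather, # awaited as tool.function(**part.args)
        )
    ],
    parallel_tool_calls=1,           # forwarded into ModelSettings
)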