1
1
#!/usr/bin/env python
2
2
# -*- coding: utf-8 -*-
3
3
4
+ import time
4
5
from collections import Counter , defaultdict
5
6
from dataclasses import dataclass , field
6
7
from abc import ABC , abstractmethod
@@ -21,7 +22,7 @@ class ChatMessage:
21
22
class ChatHistory :
22
23
def __init__(self):
    """Create an empty history with no accumulated usage."""
    # Aggregated usage, keyed by metric name; add_message folds each
    # message's usage mapping into this Counter.
    self._total_tokens = Counter()
    self.messages = []
25
26
26
27
def __len__(self):
    """Return the number of messages stored in this history."""
    stored = self.messages
    return len(stored)
@@ -31,11 +32,15 @@ def add(self, role, content):
31
32
32
33
def add_message(self, message: ChatMessage):
    """Append *message* to the history and fold its usage into the totals.

    ``message.usage`` is added to the running Counter with ``+=``, so it is
    expected to be Counter-compatible (a mapping of metric name to number).
    """
    # Counter's in-place addition sums matching keys and keeps only
    # entries whose resulting value is positive.
    self._total_tokens += message.usage
    self.messages.append(message)
35
36
36
- @property
37
- def total_tokens (self ):
38
- return self ._total_tokens
37
def get_usage(self):
    """Return a lazy iterator over the usage of each assistant message.

    Messages with any other role are skipped; order follows the history.
    """
    # A generator expression is already an iterator, so it is returned as-is.
    return (entry.usage for entry in self.messages if entry.role == "assistant")
39
+
40
def get_summary(self):
    """Return a plain dict summarising accumulated usage for this history.

    The result holds every metric accumulated so far plus a ``'rounds'``
    entry counting the assistant messages in the history.

    Float metrics are rounded to three decimal places: they are built by
    summing per-call values, and repeated float addition otherwise leaks
    artefacts such as ``0.30000000000000004`` into the summary.
    """
    summary = {
        key: round(value, 3) if isinstance(value, float) else value
        for key, value in self._total_tokens.items()
    }
    summary['rounds'] = sum(1 for row in self.messages if row.role == "assistant")
    return summary
39
44
40
45
def get_messages(self):
    """Return the history as a list of ``{"role", "content"}`` dicts.

    Only the role and content of each message are exported; a new list of
    new dicts is built on every call.
    """
    payload = []
    for entry in self.messages:
        payload.append({"role": entry.role, "content": entry.content})
    return payload
@@ -60,6 +65,10 @@ def get_completion(self, messages):
60
65
def add_system_prompt(self, history, system_prompt):
    """Record *system_prompt* in *history* as a message with role "system"."""
    role = "system"
    history.add(role, system_prompt)
62
67
68
@abstractmethod
def parse_usage(self, response):
    """Extract usage statistics from a raw provider *response*.

    Implementations return a mutable mapping of metric name to number;
    the caller adds a ``'time'`` entry to the returned object and then
    attaches it to the parsed message.
    """
71
+
63
72
@abstractmethod
def parse_response(self, response):
    """Convert a raw provider *response* into a chat message object.

    The caller reads ``reason`` and ``content`` from the result, assigns
    its ``usage`` attribute, and stores it in the history.
    """
@@ -70,16 +79,23 @@ def __call__(self, history, prompt, system_prompt=None):
70
79
self .add_system_prompt (history , system_prompt )
71
80
history .add ("user" , prompt )
72
81
82
+ start = time .time ()
73
83
response = self .get_completion (history .get_messages ())
84
+ end = time .time ()
74
85
if response :
75
86
msg = self .parse_response (response )
87
+ usage = self .parse_usage (response )
88
+ usage ['time' ] = round (end - start , 3 )
89
+ msg .usage = usage
76
90
history .add_message (msg )
77
91
if msg .reason :
78
92
response = f"{ T ('think' )} :\n ---\n { msg .reason } \n ---\n { msg .content } "
79
93
else :
80
94
response = msg .content
81
95
return response
82
-
96
+
97
+ # https://platform.openai.com/docs/api-reference/chat/create
98
+ # https://api-docs.deepseek.com/api/create-chat-completion
83
99
class OpenAIClient (BaseClient ):
84
100
def __init__ (self , config ):
85
101
super ().__init__ (config )
@@ -88,15 +104,20 @@ def __init__(self, config):
88
104
def add_system_prompt(self, history, system_prompt):
    """Store *system_prompt* in *history* as a "system" role message."""
    role = "system"
    history.add(role, system_prompt)
90
106
107
def parse_usage(self, response):
    """Return token usage for a chat completion as a Counter.

    Keys are ``'total_tokens'``, ``'input_tokens'`` (the API's
    ``prompt_tokens``) and ``'output_tokens'`` (``completion_tokens``).

    The ``usage`` object is optional on a completion. When it is absent
    or ``None`` all three counts are reported as zero, so a successful
    completion is not lost to an ``AttributeError`` while reading its
    statistics.
    """
    usage = getattr(response, "usage", None)
    if usage is None:
        return Counter({'total_tokens': 0,
                        'input_tokens': 0,
                        'output_tokens': 0})
    return Counter({'total_tokens': usage.total_tokens,
                    'input_tokens': usage.prompt_tokens,
                    'output_tokens': usage.completion_tokens})
112
+
91
113
def parse_response(self, response):
    """Build a ChatMessage from the first choice of a chat completion.

    Role and content are copied from the choice's message. ``reason`` is
    the message's ``reasoning_content`` attribute when present, otherwise
    ``None``.
    """
    reply = response.choices[0].message
    return ChatMessage(
        role=reply.role,
        content=reply.content,
        reason=getattr(reply, "reasoning_content", None),
    )
100
121
101
122
def get_completion (self , messages ):
102
123
try :
@@ -110,12 +131,18 @@ def get_completion(self, messages):
110
131
self .console .print (f"❌ [bold red]{ self .name } API { T ('call_failed' )} : [yellow]{ str (e )} " )
111
132
response = None
112
133
return response
113
-
134
+
135
+ # https://github.com/ollama/ollama/blob/main/docs/api.md
114
136
class OllamaClient (BaseClient ):
115
137
def __init__(self, config):
    """Run the base-client initialiser, then create an HTTP session.

    The session is stored on ``self._session``.
    """
    super().__init__(config)
    session = requests.Session()
    self._session = session
118
140
141
def parse_usage(self, response):
    """Return token usage for an Ollama chat response as a Counter.

    ``'input_tokens'`` comes from ``prompt_eval_count`` and
    ``'output_tokens'`` from ``eval_count``; ``'total_tokens'`` is their
    sum.

    Either counter may be missing from the response (Ollama has been seen
    to omit ``prompt_eval_count`` when the prompt is served from cache).
    A missing or null counter is treated as zero instead of raising
    ``KeyError`` after a successful completion.
    """
    prompt_tokens = response.get('prompt_eval_count') or 0
    completion_tokens = response.get('eval_count') or 0
    return Counter({'input_tokens': prompt_tokens,
                    'output_tokens': completion_tokens,
                    'total_tokens': prompt_tokens + completion_tokens})
145
+
119
146
def parse_response(self, response):
    """Build a ChatMessage from the ``"message"`` entry of *response*.

    The entry's ``'role'`` and ``'content'`` values are copied over.
    """
    payload = response["message"]
    role, content = payload['role'], payload['content']
    return ChatMessage(role=role, content=content)
@@ -139,19 +166,22 @@ def get_completion(self, messages):
139
166
response = None
140
167
return response
141
168
169
+ # https://docs.anthropic.com/en/api/messages
142
170
class ClaudeClient (BaseClient ):
143
171
def __init__(self, config):
    """Run the base-client initialiser, then create the Anthropic client.

    The client is stored on ``self._client``.
    """
    super().__init__(config)
    # _api_key and _timeout are read from self here; presumably the base
    # initialiser sets them from config -- confirm against BaseClient.
    self._client = anthropic.Anthropic(
        api_key=self._api_key,
        timeout=self._timeout,
    )
146
174
175
def parse_usage(self, response):
    """Return token usage for an Anthropic response as a Counter.

    ``'input_tokens'`` and ``'output_tokens'`` are read from
    ``response.usage``; ``'total_tokens'`` is computed here as their sum.
    """
    stats = response.usage
    prompt_tokens = stats.input_tokens
    completion_tokens = stats.output_tokens
    return Counter({
        'input_tokens': prompt_tokens,
        'output_tokens': completion_tokens,
        'total_tokens': prompt_tokens + completion_tokens,
    })
180
+
147
181
def parse_response(self, response):
    """Build a ChatMessage from an Anthropic response.

    The content is the ``text`` of the first content block; the role is
    taken from the response itself.
    """
    first_block = response.content[0]
    text = first_block.text
    return ChatMessage(role=response.role, content=text)
155
185
156
186
def add_system_prompt (self , history , system_prompt ):
157
187
self ._system_prompt = system_prompt
0 commit comments