import requests
import openai
from anthropic import Anthropic
+import tiktoken


class OllamaAPI:
    def __init__(self, model):
        self.model = model
        self.base_url = "http://localhost:11434/api"
+        self.token_count = 0

+    def count_tokens(self, text):
+        # Ollama doesn't provide a built-in token counter, so we'll use tiktoken as an approximation
+        return len(tiktoken.encoding_for_model("gpt-4o").encode(text))
+
    def generate(self, prompt):
        url = f"{self.base_url}/generate"
        data = {"model": self.model, "prompt": prompt, "stream": True}
@@ -39,6 +45,18 @@ def generate(self, prompt):
                    except json.JSONDecodeError:
                        print(f"Error decoding JSON: {decoded_line}")
            print()  # Print a newline at the end
+
+            # Extract content between markers
+            start_marker = "^^^start^^^"
+            end_marker = "^^^end^^^"
+            start_index = full_response.find(start_marker)
+            end_index = full_response.find(end_marker)
+            if start_index != -1 and end_index != -1:
+                full_response = full_response[start_index + len(start_marker):end_index].strip()
+
+            self.token_count = self.count_tokens(full_response)
+            print(f"Token count: {self.token_count}")
+
            return full_response
        else:
            raise Exception(f"Ollama API error: {response.text}")
@@ -53,6 +71,10 @@ def __init__(self, model):
        if not self.api_key:
            raise ValueError("OPENAI_API_KEY environment variable is not set")
        openai.api_key = self.api_key
+        self.token_count = 0
+
+    def count_tokens(self, text):
+        return len(tiktoken.encoding_for_model(self.model).encode(text))

    def generate(self, prompt):
        try:
@@ -68,11 +90,24 @@ def generate(self, prompt):
                full_response += chunk_text
                print(chunk_text, end="", flush=True)
            print()  # Print a newline at the end
+
+            # Extract content between markers
+            start_marker = "^^^start^^^"
+            end_marker = "^^^end^^^"
+            start_index = full_response.find(start_marker)
+            end_index = full_response.find(end_marker)
+            if start_index != -1 and end_index != -1:
+                full_response = full_response[start_index + len(start_marker):end_index].strip()
+
+            self.token_count = self.count_tokens(full_response)
+            print(f"Token count: {self.token_count}")
+
            return full_response
        except Exception as e:
            raise Exception(f"OpenAI API error: {str(e)}")


+
class ClaudeAPI:
    def __init__(self, model):
        if model == "mistral-nemo":
@@ -82,22 +117,53 @@ def __init__(self, model):
        if not self.api_key:
            raise ValueError("ANTHROPIC_API_KEY environment variable is not set")
        self.client = Anthropic(api_key=self.api_key)
+        self.token_count = 0
+
+    def count_tokens(self, text):
+        # Claude's tokenizer isn't available locally, so we'll use tiktoken as an approximation
+        return len(tiktoken.encoding_for_model("gpt-4o").encode(text))

    def generate(self, prompt):
        try:
-            response = self.client.messages.create(
-                model=self.model,
-                messages=[{"role": "user", "content": prompt}],
-                stream=True,
-                max_tokens=1000,
-            )
            full_response = ""
-            for completion in response:
-                if completion.type == "content_block_delta":
-                    chunk_text = completion.delta.text
-                    full_response += chunk_text
-                    print(chunk_text, end="", flush=True)
+            max_iterations = 5  # Adjust this value as needed
+            continuation_prompt = prompt
+
+            for iteration in range(max_iterations):
+                response = self.client.messages.create(
+                    model=self.model,
+                    messages=[{"role": "user", "content": continuation_prompt}],
+                    stream=True,
+                    max_tokens=1000,
+                )
+
+                chunk_response = ""
+                for completion in response:
+                    if completion.type == "content_block_delta":
+                        chunk_text = completion.delta.text
+                        chunk_response += chunk_text
+                        print(chunk_text, end="", flush=True)
+
+                full_response += chunk_response
+
+                if "^^^end^^^" in chunk_response:
+                    break
+
+                continuation_prompt = f"Continue from where you left off. Previous response: {chunk_response}"
+
            print()  # Print a newline at the end
+
+            # Extract content between markers
+            start_marker = "^^^start^^^"
+            end_marker = "^^^end^^^"
+            start_index = full_response.find(start_marker)
+            end_index = full_response.find(end_marker)
+            if start_index != -1 and end_index != -1:
+                full_response = full_response[start_index + len(start_marker):end_index].strip()
+
+            self.token_count = self.count_tokens(full_response)
+            print(f"Token count: {self.token_count}")
+
            return full_response
        except Exception as e:
            raise Exception(f"Claude API error: {str(e)}")
@@ -119,6 +185,11 @@ def __init__(
        self.pwd = os.getcwd() + "/" + self.project_name
        self.llm = self.setup_llm()
        self.previous_suggestions = set()
+        self.token_counts = {}
+
+    def count_tokens(self, text):
+        encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
+        return len(encoding.encode(text))

    def setup_llm(self):
        if self.provider == "ollama":
@@ -169,6 +240,9 @@ def run_task(self):
                break
            test_check_attempts += 1

+        total_tokens = sum(self.token_counts.values())
+        print(f"\nTotal tokens used: {total_tokens}")
+
        print(
            "Task completed. Please review the output and make any necessary manual adjustments."
        )
@@ -303,7 +377,7 @@ def implement_solution(self, max_attempts=3):
        prompt = f"""
        Create a comprehensive implementation for the task: {self.task}.
        You must follow these rules strictly:
-        1, IMPORTANT: Never use pass statements in your code or tests. Always provide a meaningful implementation.
+        1. IMPORTANT: Never use pass statements in your code or tests. Always provide a meaningful implementation.
        2. CRITICAL: Use the following code block format for specifying file content:
        For code files, use:
        <<<main.py>>>
@@ -332,6 +406,7 @@ def implement_solution(self, max_attempts=3):
        13. IMPORTANT: Always pytest parameterize tests for different cases.
        14. CRITICAL: Always use `import main` to import the main.py file in the test file.
        15. IMPORTANT: Only mock external services or APIs in tests.
+        16. IMPORTANT: Enclose your entire response between ^^^start^^^ and ^^^end^^^ markers.
        Working directory: {self.pwd}
        """

@@ -340,6 +415,14 @@ def implement_solution(self, max_attempts=3):
        solution = self.get_response(prompt)
        self.logger.info(f"Received solution:\n{solution}")

+        # Extract content between markers
+        start_marker = "^^^start^^^"
+        end_marker = "^^^end^^^"
+        start_index = solution.find(start_marker)
+        end_index = solution.find(end_marker)
+        if start_index != -1 and end_index != -1:
+            solution = solution[start_index + len(start_marker):end_index].strip()
+
        # Parse and execute any uv add commands
        uv_commands = [
            line.strip()
@@ -392,11 +475,14 @@ def extract_file_contents_direct(self, solution):

    def get_response(self, prompt):
        try:
-            return self.llm.generate(prompt)
+            response = self.llm.generate(prompt)
+            prompt_key = prompt[:50]  # Use first 50 characters as a key
+            self.token_counts[prompt_key] = self.llm.token_count
+            return response
        except Exception as e:
            self.logger.error(f"Error getting response from {self.provider}: {str(e)}")
            return ""
-
+

    def code_check(self, file_path):
        try:
            # Run autopep8 to automatically fix style issues
@@ -498,6 +584,7 @@ def improve_test_file(self, test_output):
        8. IMPORTANT: Always pytest parameterize tests for different cases.
        9. CRITICAL: Always use `import main` to import the main.py file in the test file.
        10. IMPORTANT: Only mock external services or APIs in tests.
+        11. IMPORTANT: Enclose your entire response between ^^^start^^^ and ^^^end^^^ markers.
        Working directory: {self.pwd}
        """
        proposed_improvements = self.get_response(prompt)
@@ -522,6 +609,7 @@ def validate_implementation(self, proposed_improvements):
        If the implementation is correct or mostly correct, respond with 'VALID'.
        If the implementation is completely unrelated or fundamentally flawed, respond with 'INVALID'.
        Do not provide any additional information or explanations.
+        IMPORTANT: Enclose your entire response between ^^^start^^^ and ^^^end^^^ markers.
        """
        response = self.get_response(prompt)

@@ -567,6 +655,7 @@ def improve_code(
        # File content here
        <<<end>>>
        7. CRITICAL: Do not explain the task only implement the required functionality in the code blocks.
+        8. IMPORTANT: Enclose your entire response between ^^^start^^^ and ^^^end^^^ markers.
        Working directory: {self.pwd}
        """
        proposed_improvements = self.get_response(prompt)
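
The same ^^^start^^^/^^^end^^^ extraction block now appears in OllamaAPI.generate, OpenAIAPI.generate, ClaudeAPI.generate, and implement_solution. A minimal sketch of a shared helper that could replace those copies (the function name and placement are assumptions, not part of this commit):

def extract_between_markers(text, start_marker="^^^start^^^", end_marker="^^^end^^^"):
    # Return only the content enclosed by the sentinel markers
    start_index = text.find(start_marker)
    end_index = text.find(end_marker)
    if start_index != -1 and end_index != -1:
        return text[start_index + len(start_marker):end_index].strip()
    # Fall back to the raw text when either marker is missing
    return text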
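Note on the reported token counts: tiktoken.encoding_for_model("gpt-4o") and tiktoken.encoding_for_model("gpt-3.5-turbo") return OpenAI encodings, so the figures printed for Ollama and Claude models are approximations rather than exact counts. A small usage sketch of the approach the patch takes (the sample text is illustrative only):

import tiktoken

# Same approximation as count_tokens(): encode with an OpenAI tokenizer
# and report the length of the resulting token list.
encoding = tiktoken.encoding_for_model("gpt-4o")
print(len(encoding.encode("Create a comprehensive implementation for the task")))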