2
2
import json
3
3
import os
4
4
import re
5
+ from typing import Optional
5
6
6
7
import requests
7
8
import structlog
8
9
import yaml
10
+ from checks import CheckLoader
9
11
from dotenv import find_dotenv , load_dotenv
10
- from sklearn .metrics .pairwise import cosine_similarity
11
-
12
- from codegate .inference .inference_engine import LlamaCppInferenceEngine
12
+ from requesters import RequesterFactory
13
13
14
14
logger = structlog .get_logger ("codegate" )
15
15
16
16
17
17
class CodegateTestRunner :
18
18
def __init__(self) -> None:
    """Create the factory used to build one requester per provider."""
    # call_codegate asks this factory for a requester on every request.
    self.requester_factory = RequesterFactory()
20
+
21
def call_codegate(
    self, url: str, headers: dict, data: dict, provider: str
) -> Optional[requests.Response]:
    """Send one request through the requester created for ``provider``.

    Args:
        url: Endpoint the request is sent to.
        headers: HTTP headers handed to the requester unchanged.
        data: Request payload handed to the requester unchanged.
        provider: Name passed to ``RequesterFactory.create_requester``.

    Returns:
        Whatever ``requester.make_request`` returns. A ``None`` result is
        logged as "No response received"; a non-200 result has its status,
        headers and body logged before being returned.
    """
    logger.debug(f"Creating requester for provider: {provider}")
    requester = self.requester_factory.create_requester(provider)
    logger.debug(f"Using requester type: {requester.__class__.__name__}")

    # Header values are filled from environment variables by run_tests and
    # may hold credentials, so only a redacted copy is written to the log.
    # The requester still receives the real headers.
    sensitive_markers = ("authorization", "api-key", "cookie", "token")
    loggable_headers = {
        name: (
            "<redacted>"
            if any(marker in str(name).lower() for marker in sensitive_markers)
            else value
        )
        for name, value in (headers or {}).items()
    }

    logger.debug(f"Making request to URL: {url}")
    logger.debug(f"Headers: {loggable_headers}")
    logger.debug(f"Data: {data}")

    response = requester.make_request(url, headers, data)

    if response is None:
        logger.error("No response received")
        return response

    if response.status_code != 200:
        logger.debug(f"Response error status: {response.status_code}")
        logger.debug(f"Response error headers: {dict(response.headers)}")
        try:
            error_content = response.json()
            logger.error(f"Request error as JSON: {error_content}")
        except ValueError:
            # Body is not JSON; fall back to the raw text.
            logger.error(f"Raw request error: {response.text}")

    return response
30
50
31
51
@staticmethod
@@ -50,6 +70,8 @@ def parse_response_message(response, streaming=True):
50
70
51
71
message_content = None
52
72
if "choices" in json_line :
73
+ if "finish_reason" in json_line ["choices" ][0 ]:
74
+ break
53
75
if "delta" in json_line ["choices" ][0 ]:
54
76
message_content = json_line ["choices" ][0 ]["delta" ].get ("content" , "" )
55
77
elif "text" in json_line ["choices" ][0 ]:
@@ -75,12 +97,6 @@ def parse_response_message(response, streaming=True):
75
97
76
98
return response_message
77
99
78
- async def calculate_string_similarity (self , str1 , str2 ):
79
- vector1 = await self .inference_engine .embed (self .embedding_model , [str1 ])
80
- vector2 = await self .inference_engine .embed (self .embedding_model , [str2 ])
81
- similarity = cosine_similarity (vector1 , vector2 )
82
- return similarity [0 ]
83
-
84
100
@staticmethod
85
101
def replace_env_variables (input_string , env ):
86
102
"""
@@ -103,51 +119,115 @@ def replacement(match):
103
119
pattern = r"ENV\w*"
104
120
return re .sub (pattern , replacement , input_string )
105
121
106
async def run_test(self, test: dict, test_headers: dict) -> None:
    """Run a single test case and log whether it passed.

    Args:
        test: Test case mapping. Reads ``name``, ``url``, ``provider`` and
            ``data`` (a JSON-encoded request body); the whole mapping is
            also passed to ``CheckLoader.load`` and to every check.
        test_headers: HTTP headers to send with the request.

    The method returns early, after logging an error, when no response is
    received or when the response carries an HTTP error status.
    """
    test_name = test["name"]
    url = test["url"]
    data = json.loads(test["data"])
    streaming = data.get("stream", False)
    provider = test["provider"]

    response = self.call_codegate(url, test_headers, data, provider)

    # requests.Response is falsy for 4xx/5xx replies, so "not response"
    # cannot tell "nothing came back" apart from "the server answered with
    # an error". Test for None explicitly and report the two separately.
    if response is None:
        logger.error(f"Test {test_name} failed: No response received")
        return
    if not response.ok:
        logger.error(f"Test {test_name} failed: HTTP {response.status_code}")
        return

    # Debug response info
    logger.debug(f"Response status: {response.status_code}")
    logger.debug(f"Response headers: {dict(response.headers)}")

    try:
        parsed_response = self.parse_response_message(response, streaming=streaming)

        # Load appropriate checks for this test
        checks = CheckLoader.load(test)

        # Run every check (no short-circuit) and pass only if all succeed.
        results = [await check.run_check(parsed_response, test) for check in checks]
        passed = all(results)
        logger.info(f"Test {test_name} passed" if passed else f"Test {test_name} failed")

    except Exception as e:
        # Per-test boundary: one broken test must not abort the whole run.
        logger.exception("Could not parse response: %s", e)
127
154
128
async def run_tests(
    self,
    testcases_file: str,
    providers: Optional[list[str]] = None,
    test_names: Optional[list[str]] = None,
) -> None:
    """Load test cases from a YAML file and run them one after another.

    Args:
        testcases_file: Path to a YAML file with top-level ``headers`` and
            ``testcases`` mappings.
        providers: If given and non-empty, only run test cases whose
            ``provider`` matches one of these names (case-insensitive).
        test_names: If given and non-empty, only run test cases whose
            ``name`` matches one of these names (case-insensitive).

    When filters are given and nothing matches, a warning is logged and no
    test is run.
    """
    with open(testcases_file, "r", encoding="utf-8") as f:
        tests = yaml.safe_load(f)

    headers = tests["headers"]
    testcases = tests["testcases"]

    # Human-readable description of the active filters, built once and
    # shared by the warning and the "Running ..." message.
    filter_msg = []
    if providers:
        filter_msg.append(f"providers: {', '.join(providers)}")
    if test_names:
        filter_msg.append(f"test names: {', '.join(test_names)}")
    filter_desc = " and ".join(filter_msg)

    if providers or test_names:
        # Lower-case the filters once instead of once per test case.
        wanted_providers = {p.lower() for p in providers} if providers else None
        wanted_names = {t.lower() for t in test_names} if test_names else None

        testcases = {
            test_id: test_data
            for test_id, test_data in testcases.items()
            if (
                wanted_providers is None
                or test_data.get("provider", "").lower() in wanted_providers
            )
            and (wanted_names is None or test_data.get("name", "").lower() in wanted_names)
        }

        if not testcases:
            logger.warning(f"No tests found for {filter_desc}")
            return

    test_count = len(testcases)
    logger.info(f"Running {test_count} tests" + (f" for {filter_desc}" if filter_desc else ""))

    for test_data in testcases.values():
        # A provider listed with no headers loads from YAML as None, and
        # dict.get only uses its default for a missing key, so fall back to
        # an empty mapping explicitly.
        provider_headers = headers.get(test_data["provider"]) or {}
        test_headers = {
            k: self.replace_env_variables(v, os.environ) for k, v in provider_headers.items()
        }
        await self.run_test(test_data, test_headers)
145
210
146
211
147
212
async def main():
    """Load the environment, read the optional filters and run the tests.

    ``CODEGATE_PROVIDERS`` and ``CODEGATE_TEST_NAMES`` are read as
    comma-separated lists; an unset or empty variable means "no filter".
    """
    load_dotenv(find_dotenv())
    test_runner = CodegateTestRunner()

    def _csv_from_env(var_name):
        # Split a comma-separated variable, dropping blank entries.
        raw_value = os.environ.get(var_name)
        if not raw_value:
            return None
        return [item.strip() for item in raw_value.split(",") if item.strip()]

    # Get providers and test names from environment variables
    selected_providers = _csv_from_env("CODEGATE_PROVIDERS")
    selected_test_names = _csv_from_env("CODEGATE_TEST_NAMES")

    await test_runner.run_tests(
        "./tests/integration/testcases.yaml",
        providers=selected_providers,
        test_names=selected_test_names,
    )
151
231
152
232
153
233
if __name__ == "__main__" :
0 commit comments