Skip to content

Commit ff4a3a7

Browse files
committed
Add integration tests for the copilot provider
Since the copilot provider is a proxy, we add a "requester" module that, depending on the provider, makes a request either using raw Python requests as before or by setting a proxy and using a CA cert file. To be able to add more tests, we also add more kinds of checks: in addition to the existing one, which makes sure the reply is similar to the expected one using cosine distance, we add checks that make sure the LLM reply contains or doesn't contain a string. We use those to add a test that ensures that the copilot provider chat works and that the copilot chat refuses to generate a code snippet with a malicious package. To be able to run a subset of tests, we also add the ability to select tests by provider (the `CODEGATE_PROVIDERS` environment variable) or by test name (the `CODEGATE_TEST_NAMES` environment variable). These serve as the base for further integration tests. To run them, call: ``` CODEGATE_PROVIDERS=copilot \ CA_CERT_FILE=/Users/you/devel/codegate/codegate_volume/certs/ca.crt \ ENV_COPILOT_KEY=your-openapi-key \ python tests/integration/integration_tests.py ``` Related: #402
1 parent 19ffa83 commit ff4a3a7

File tree

4 files changed

+313
-52
lines changed

4 files changed

+313
-52
lines changed

tests/integration/checks.py

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from abc import ABC, abstractmethod
2+
from typing import List
3+
4+
import structlog
5+
from sklearn.metrics.pairwise import cosine_similarity
6+
7+
from codegate.inference.inference_engine import LlamaCppInferenceEngine
8+
9+
logger = structlog.get_logger("codegate")
10+
11+
12+
class BaseCheck(ABC):
    """Interface for one pass/fail check applied to a parsed LLM reply.

    Every check carries the name of the test case it was created for so
    that its log output can identify the test.
    """

    def __init__(self, test_name: str):
        """Store the name of the owning test case."""
        self.test_name = test_name

    @abstractmethod
    async def run_check(self, parsed_response: str, test_data: dict) -> bool:
        """Evaluate ``parsed_response`` against the test definition.

        Args:
            parsed_response: The reply text extracted from the response.
            test_data: The test case definition the check reads its
                expectation from.

        Returns:
            ``True`` when the check passes, ``False`` otherwise.
        """
19+
20+
21+
class CheckLoader:
    """Builds the checks that a test case definition asks for."""

    @staticmethod
    def load(test_data: dict) -> List[BaseCheck]:
        """Create one check per check key that has a truthy value in ``test_data``.

        The resulting list keeps a fixed order: distance, contains,
        does-not-contain. It is empty when none of the keys is set.
        """
        name = test_data.get("name")
        candidates = (DistanceCheck, ContainsCheck, DoesNotContainCheck)
        return [
            check_cls(name)
            for check_cls in candidates
            if test_data.get(check_cls.KEY)
        ]
34+
35+
36+
class DistanceCheck(BaseCheck):
    """Check that the reply is close to an expected reply in embedding space.

    Both strings are embedded with the inference engine and compared with
    cosine similarity; the check fails when the similarity is below the
    threshold. The expected reply is read from the ``likes`` key of the
    test definition.
    """

    KEY = "likes"

    # Defaults used when the constructor is called with only a test name.
    DEFAULT_EMBEDDING_MODEL = "codegate_volume/models/all-minilm-L6-v2-q5_k_m.gguf"
    DEFAULT_THRESHOLD = 0.8

    def __init__(
        self,
        test_name: str,
        embedding_model: str = DEFAULT_EMBEDDING_MODEL,
        threshold: float = DEFAULT_THRESHOLD,
    ):
        """Set up the inference engine used for embeddings.

        Args:
            test_name: Name of the owning test case, used in log output.
            embedding_model: Path of the embedding model handed to the
                inference engine.
            threshold: Minimum cosine similarity for the check to pass.
        """
        super().__init__(test_name)
        self.inference_engine = LlamaCppInferenceEngine()
        self.embedding_model = embedding_model
        self.threshold = threshold

    async def _calculate_string_similarity(self, str1, str2):
        """Return the first row of the cosine similarity between the two embeddings."""
        vector1 = await self.inference_engine.embed(self.embedding_model, [str1])
        vector2 = await self.inference_engine.embed(self.embedding_model, [str2])
        similarity = cosine_similarity(vector1, vector2)
        return similarity[0]

    async def run_check(self, parsed_response: str, test_data: dict) -> bool:
        """Return ``True`` unless the similarity falls below the threshold.

        On failure the similarity, the reply and the expected reply are
        logged at error level.
        """
        expected = test_data[DistanceCheck.KEY]
        similarity = await self._calculate_string_similarity(parsed_response, expected)
        if similarity < self.threshold:
            logger.error(f"Test {self.test_name} failed")
            logger.error(f"Similarity: {similarity}")
            logger.error(f"Response: {parsed_response}")
            logger.error(f"Expected Response: {expected}")
            return False
        return True
61+
62+
63+
class ContainsCheck(BaseCheck):
    """Check that the reply includes an expected substring.

    The substring is read from the ``contains`` key of the test definition.
    """

    KEY = "contains"

    async def run_check(self, parsed_response: str, test_data: dict) -> bool:
        """Return ``True`` when the stripped expected text occurs in the reply.

        On failure the reply and the expected text are logged at error level.
        """
        expected = test_data[ContainsCheck.KEY]
        if expected.strip() in parsed_response:
            return True

        logger.error(f"Test {self.test_name} failed")
        logger.error(f"Response: {parsed_response}")
        logger.error(f"Expected Response to contain: '{expected}'")
        return False
73+
74+
75+
class DoesNotContainCheck(BaseCheck):
    """Check that the reply does not include a forbidden substring.

    The substring is read from the ``does_not_contain`` key of the test
    definition.
    """

    KEY = "does_not_contain"

    async def run_check(self, parsed_response: str, test_data: dict) -> bool:
        """Return ``True`` when the stripped forbidden text is absent from the reply.

        On failure the reply and the forbidden text are logged at error level.
        """
        forbidden = test_data[DoesNotContainCheck.KEY]
        if forbidden.strip() not in parsed_response:
            return True

        logger.error(f"Test {self.test_name} failed")
        logger.error(f"Response: {parsed_response}")
        logger.error(
            f"Expected Response to not contain: '{forbidden}'"
        )
        return False

tests/integration/integration_tests.py

+122-42
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,50 @@
22
import json
33
import os
44
import re
5+
from typing import Optional
56

67
import requests
78
import structlog
89
import yaml
10+
from checks import CheckLoader
911
from dotenv import find_dotenv, load_dotenv
10-
from sklearn.metrics.pairwise import cosine_similarity
11-
12-
from codegate.inference.inference_engine import LlamaCppInferenceEngine
12+
from requesters import RequesterFactory
1313

1414
logger = structlog.get_logger("codegate")
1515

1616

1717
class CodegateTestRunner:
1818
def __init__(self):
    """Create the runner with the factory used to pick a requester per provider."""
    self.requester_factory = RequesterFactory()
20+
21+
def call_codegate(
    self, url: str, headers: dict, data: dict, provider: str
) -> Optional[requests.Response]:
    """Send one test request using the requester that matches ``provider``.

    Args:
        url: Endpoint to post to.
        headers: Request headers for the provider.
        data: Request payload.
        provider: Provider name used to select the requester.

    Returns:
        The response, or ``None`` when the request could not be completed.
        Responses with a status other than 200 are logged and still returned.
    """
    logger.debug(f"Creating requester for provider: {provider}")
    requester = self.requester_factory.create_requester(provider)
    logger.debug(f"Using requester type: {requester.__class__.__name__}")

    logger.debug(f"Making request to URL: {url}")
    logger.debug(f"Headers: {headers}")
    logger.debug(f"Data: {data}")

    # A transport failure (connection refused, TLS error, ...) must not abort
    # the whole run: log it and report "no response" to the caller instead.
    try:
        response = requester.make_request(url, headers, data)
    except requests.RequestException as e:
        logger.exception("An error occurred: %s", e)
        response = None

    if response is None:
        logger.error("No response received")
        return None

    if response.status_code != 200:
        logger.debug(f"Response error status: {response.status_code}")
        logger.debug(f"Response error headers: {dict(response.headers)}")
        try:
            error_content = response.json()
            logger.error(f"Request error as JSON: {error_content}")
        except ValueError:
            # If not JSON, try to get raw text
            logger.error(f"Raw request error: {response.text}")

    return response
3050

3151
@staticmethod
@@ -50,6 +70,8 @@ def parse_response_message(response, streaming=True):
5070

5171
message_content = None
5272
if "choices" in json_line:
73+
if "finish_reason" in json_line["choices"][0]:
74+
break
5375
if "delta" in json_line["choices"][0]:
5476
message_content = json_line["choices"][0]["delta"].get("content", "")
5577
elif "text" in json_line["choices"][0]:
@@ -75,12 +97,6 @@ def parse_response_message(response, streaming=True):
7597

7698
return response_message
7799

78-
async def calculate_string_similarity(self, str1, str2):
79-
vector1 = await self.inference_engine.embed(self.embedding_model, [str1])
80-
vector2 = await self.inference_engine.embed(self.embedding_model, [str2])
81-
similarity = cosine_similarity(vector1, vector2)
82-
return similarity[0]
83-
84100
@staticmethod
85101
def replace_env_variables(input_string, env):
86102
"""
@@ -103,51 +119,115 @@ def replacement(match):
103119
pattern = r"ENV\w*"
104120
return re.sub(pattern, replacement, input_string)
105121

106-
async def run_test(self, test, test_headers):
122+
async def run_test(self, test: dict, test_headers: dict) -> None:
    """Run a single test case and log whether it passed.

    The request is built from the test's ``url``, ``data`` (a JSON string)
    and ``provider``. The reply is parsed and every check loaded for the
    test is run against it; the test passes only when all checks pass.

    Args:
        test: The test case definition.
        test_headers: Request headers to send for this test.
    """
    test_name = test["name"]
    url = test["url"]
    data = json.loads(test["data"])
    streaming = data.get("stream", False)
    provider = test["provider"]

    response = self.call_codegate(url, test_headers, data, provider)
    if response is None:
        logger.error(f"Test {test_name} failed: No response received")
        return
    # A requests.Response is falsy for 4xx/5xx statuses, so test `ok`
    # explicitly and report the status instead of "No response received".
    if not response.ok:
        logger.error(f"Test {test_name} failed: HTTP status {response.status_code}")
        return

    # Debug response info
    logger.debug(f"Response status: {response.status_code}")
    logger.debug(f"Response headers: {dict(response.headers)}")

    try:
        parsed_response = self.parse_response_message(response, streaming=streaming)

        # Load appropriate checks for this test
        checks = CheckLoader.load(test)

        # Run every check (no short-circuit) so each failure gets logged.
        results = [await check.run_check(parsed_response, test) for check in checks]
        passed = all(results)
        logger.info(f"Test {test_name} passed" if passed else f"Test {test_name} failed")

    except Exception as e:
        logger.exception("Could not parse response: %s", e)
127154

128-
async def run_tests(self, testcases_file):
155+
async def run_tests(
    self,
    testcases_file: str,
    providers: Optional[list[str]] = None,
    test_names: Optional[list[str]] = None,
) -> None:
    """Load the test cases from a YAML file and run the selected ones.

    Args:
        testcases_file: Path of the YAML file with ``headers`` and
            ``testcases`` sections.
        providers: When given, only tests whose ``provider`` matches one of
            these names (case-insensitive) are run.
        test_names: When given, only tests whose ``name`` matches one of
            these names (case-insensitive) are run.
    """
    with open(testcases_file, "r") as f:
        tests = yaml.safe_load(f)

    headers = tests["headers"]
    testcases = tests["testcases"]

    # Human-readable description of the active filters, used in log output.
    filter_msg = []
    if providers:
        filter_msg.append(f"providers: {', '.join(providers)}")
    if test_names:
        filter_msg.append(f"test names: {', '.join(test_names)}")

    if providers or test_names:
        # Lower-case the filters once instead of once per test case.
        wanted_providers = {p.lower() for p in providers} if providers else None
        wanted_names = {t.lower() for t in test_names} if test_names else None

        filtered_testcases = {}
        for test_id, test_data in testcases.items():
            if (
                wanted_providers is not None
                and test_data.get("provider", "").lower() not in wanted_providers
            ):
                continue
            if (
                wanted_names is not None
                and test_data.get("name", "").lower() not in wanted_names
            ):
                continue
            filtered_testcases[test_id] = test_data

        testcases = filtered_testcases

        if not testcases:
            logger.warning(f"No tests found for {' and '.join(filter_msg)}")
            return

    test_count = len(testcases)
    logger.info(
        f"Running {test_count} tests"
        + (f" for {' and '.join(filter_msg)}" if filter_msg else "")
    )

    for test_data in testcases.values():
        # A provider key with no entries loads as None from YAML, so fall
        # back to an empty mapping instead of calling .items() on None.
        provider_headers = headers.get(test_data["provider"]) or {}
        test_headers = {
            k: self.replace_env_variables(v, os.environ) for k, v in provider_headers.items()
        }
        await self.run_test(test_data, test_headers)
145210

146211

147212
async def main():
    """Load the environment and run the integration tests.

    The ``CODEGATE_PROVIDERS`` and ``CODEGATE_TEST_NAMES`` environment
    variables, when set, are comma-separated lists that restrict which
    test cases run.
    """
    load_dotenv(find_dotenv())
    test_runner = CodegateTestRunner()

    def _split_csv(raw):
        # Unset or empty variable means "no filter".
        if not raw:
            return None
        return [item.strip() for item in raw.split(",") if item.strip()]

    # Get providers and test names from environment variables
    providers = _split_csv(os.environ.get("CODEGATE_PROVIDERS"))
    test_names = _split_csv(os.environ.get("CODEGATE_TEST_NAMES"))

    await test_runner.run_tests(
        "./tests/integration/testcases.yaml", providers=providers, test_names=test_names
    )
151231

152232

153233
if __name__ == "__main__":

tests/integration/requesters.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import json
2+
import os
3+
from abc import ABC, abstractmethod
4+
from typing import Optional
5+
6+
import requests
7+
import structlog
8+
9+
logger = structlog.get_logger("codegate")
10+
11+
12+
class BaseRequester(ABC):
    """Interface for the objects that send a test request to codegate."""

    @abstractmethod
    def make_request(self, url: str, headers: dict, data: dict) -> Optional[requests.Response]:
        """Post ``data`` to ``url`` with ``headers`` and return the response."""
16+
17+
18+
class StandardRequester(BaseRequester):
    """Requester that posts directly to the given URL."""

    def make_request(self, url: str, headers: dict, data: dict) -> Optional[requests.Response]:
        """Post ``data`` as a JSON body to ``url``.

        Args:
            url: Endpoint to post to.
            headers: Request headers; the dict passed in is left unmodified.
            data: Payload, serialized to a JSON string before sending.

        Returns:
            The response returned by ``requests.post``.
        """
        # Build a new dict so the caller's headers are not mutated, and
        # ensure Content-Type is always set correctly.
        request_headers = {**headers, "Content-Type": "application/json"}

        # Explicitly serialize to a JSON string and send it through `data`
        # rather than the `json` parameter.
        json_data = json.dumps(data)

        return requests.post(url, headers=request_headers, data=json_data)
29+
30+
31+
class CopilotRequester(BaseRequester):
    """Requester that sends the request through the proxy on localhost:8990.

    The CA certificate used to verify TLS is read from the ``CA_CERT_FILE``
    environment variable.
    """

    def make_request(self, url: str, headers: dict, data: dict) -> Optional[requests.Response]:
        """Post ``data`` as a JSON body to ``url`` via the proxy, streaming the reply.

        Args:
            url: Endpoint to post to.
            headers: Request headers; the dict passed in is left unmodified.
            data: Payload, serialized to a JSON string before sending.

        Returns:
            The response returned by ``requests.post``.
        """
        # Build a new dict so the caller's headers are not mutated, and
        # ensure Content-Type is always set correctly.
        request_headers = {**headers, "Content-Type": "application/json"}

        # Explicitly serialize to a JSON string and send it through `data`
        # rather than the `json` parameter.
        json_data = json.dumps(data)

        ca_cert_file = os.environ.get("CA_CERT_FILE")
        if not ca_cert_file:
            # verify=None makes requests use its default verification, which
            # will likely reject the proxy's certificate; say why up front.
            logger.warning(
                "CA_CERT_FILE is not set; TLS verification through the proxy may fail"
            )

        return requests.post(
            url,
            data=json_data,
            headers=request_headers,
            proxies={"https": "https://localhost:8990", "http": "http://localhost:8990"},
            verify=ca_cert_file,
            stream=True,
        )
47+
48+
49+
class RequesterFactory:
    """Chooses the requester implementation for a provider."""

    @staticmethod
    def create_requester(provider: str) -> BaseRequester:
        """Return a ``CopilotRequester`` for "copilot" (any case), else a ``StandardRequester``."""
        special_cases = {"copilot": CopilotRequester}
        requester_cls = special_cases.get(provider.lower(), StandardRequester)
        return requester_cls()

0 commit comments

Comments
 (0)