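"""Integration tests for grammar-constrained response_format handling.

These tests launch a TinyLlama server with grammar support enabled, then
check that /v1/chat/completions enforces a JSON-schema response format and
rejects requests that combine a grammar with tools.
"""
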
import pytest
import requests
from pydantic import BaseModel
from typing import List


@pytest.fixture(scope="module")
def llama_grammar_handle(launcher):
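    """Launch a single-shard TinyLlama server with grammar support enabled."""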
    with launcher(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        num_shard=1,
        disable_grammar_support=False,
        use_flash_attention=False,
        max_batch_prefill_tokens=3000,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def llama_grammar(llama_grammar_handle):
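    """Wait for the server to report healthy, then return its client."""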
    await llama_grammar_handle.health(300)
    return llama_grammar_handle.client


@pytest.mark.release
@pytest.mark.asyncio
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
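    """A schema passed via response_format should constrain the generated JSON,
    whether the type is spelled "json_object", "json", or "json_schema".
    """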
    class Weather(BaseModel):
        unit: str
        temperature: List[int]

    json_payload = {
        "model": "tgi",
        "messages": [
            {
                "role": "system",
                "content": f"Respond to the user's questions and answer them in the following format: {Weather.schema()}",
            },
            {
                "role": "user",
                "content": "What's the weather like the next 3 days in San Francisco, CA?",
            },
        ],
        "seed": 42,
        "max_tokens": 500,
        "response_format": {"type": "json_object", "value": Weather.schema()},
    }
    # Send the request with the default "json_object" response format.
    response = requests.post(
        f"{llama_grammar.base_url}/v1/chat/completions",
        headers=llama_grammar.headers,
        json=json_payload,
    )

    chat_completion = response.json()
    called = chat_completion["choices"][0]["message"]["content"]

    assert response.status_code == 200
    assert called == '{ "unit": "fahrenheit", "temperature": [ 72, 79, 88 ] }'
    assert chat_completion == response_snapshot
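
    # "json" is accepted as an alias for "json_object" and should yield the
    # same constrained output.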
json_payload["response_format"]["type"] = "json"
response = requests.post(
f"{llama_grammar.base_url}/v1/chat/completions",
headers=llama_grammar.headers,
json=json_payload,
)
chat_completion = response.json()
called = chat_completion["choices"][0]["message"]["content"]
assert response.status_code == 200
assert called == '{ "unit": "fahrenheit", "temperature": [ 72, 79, 88 ] }'
assert chat_completion == response_snapshot
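
    # "json_schema" is likewise accepted and should match the same snapshot.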
json_payload["response_format"]["type"] = "json_schema"
response = requests.post(
f"{llama_grammar.base_url}/v1/chat/completions",
headers=llama_grammar.headers,
json=json_payload,
)
chat_completion = response.json()
called = chat_completion["choices"][0]["message"]["content"]
assert response.status_code == 200
assert called == '{ "unit": "fahrenheit", "temperature": [ 72, 79, 88 ] }'
assert chat_completion == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_grammar_response_format_llama_error_if_tools_not_installed(
    llama_grammar,
):
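    """A request that sets both tools and a grammar response_format is rejected."""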
    class Weather(BaseModel):
        unit: str
        temperature: List[int]

    # Send a request that combines an (empty) tools list with a grammar.
    response = requests.post(
        f"{llama_grammar.base_url}/v1/chat/completions",
        headers=llama_grammar.headers,
        json={
            "model": "tgi",
            "messages": [
                {
                    "role": "system",
                    "content": f"Respond to the user's questions and answer them in the following format: {Weather.schema()}",
                },
                {
                    "role": "user",
                    "content": "What's the weather like the next 3 days in San Francisco, CA?",
                },
            ],
            "seed": 42,
            "max_tokens": 500,
            "tools": [],
            "response_format": {"type": "json_object", "value": Weather.schema()},
        },
    )

    # 422 Unprocessable Entity: the server rejects grammar and tools together.
    assert response.status_code == 422
    assert response.json() == {
        "error": "Tool error: Grammar and tools are mutually exclusive",
        "error_type": "tool_error",
    }