Skip to content

Commit d230bd5

Browse files
committed
fix: capture Usage, ChatResponseCached, and ToolResults
Additionally add tests to ensure these are captured properly. Signed-off-by: Donnie Adams <[email protected]>
1 parent 136573a commit d230bd5

File tree

2 files changed

+41
-2
lines changed

2 files changed

+41
-2
lines changed

gptscript/frame.py

+4
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ def __init__(self,
140140
output: list[Output] = None,
141141
error: str = "",
142142
usage: Usage = None,
143+
chatResponseCached: bool = False,
144+
toolResults: int = 0,
143145
llmRequest: Any = None,
144146
llmResponse: Any = None,
145147
):
@@ -179,6 +181,8 @@ def __init__(self,
179181
self.usage = usage
180182
if isinstance(self.usage, dict):
181183
self.usage = Usage(**self.usage)
184+
self.chatResponseCached = chatResponseCached
185+
self.toolResults = toolResults
182186
self.llmRequest = llmRequest
183187
self.llmResponse = llmResponse
184188

tests/test_gptscript.py

+37-2
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,18 @@ async def test_restart_failed_run(gptscript):
192192

193193
@pytest.mark.asyncio
194194
async def test_eval_simple_tool(gptscript, simple_tool):
195-
run = gptscript.evaluate(simple_tool)
195+
run = gptscript.evaluate(simple_tool, Options(disableCache=True))
196196
out = await run.text()
197+
prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
198+
for c in run.calls().values():
199+
prompt_tokens += c.usage.promptTokens
200+
completion_tokens += c.usage.completionTokens
201+
total_tokens += c.usage.totalTokens
202+
197203
assert "Washington" in out, "Unexpected response for tool run"
204+
assert prompt_tokens > 0, "Unexpected promptTokens for tool run"
205+
assert completion_tokens > 0, "Unexpected completionTokens for tool run"
206+
assert total_tokens > 0, "Unexpected totalTokens for tool run"
198207

199208

200209
@pytest.mark.asyncio
@@ -210,6 +219,13 @@ async def test_eval_tool_list(gptscript, tool_list):
210219
out = await run.text()
211220
assert out.strip() == "hello there", "Unexpected output from eval using a list of tools"
212221

222+
# In this case, we expect the total number of toolResults to be 1
223+
total_tool_results = 0
224+
for c in run.calls().values():
225+
total_tool_results += c.toolResults
226+
227+
assert total_tool_results == 1, "Unexpected number of toolResults"
228+
213229

214230
@pytest.mark.asyncio
215231
async def test_eval_tool_list_with_sub_tool(gptscript, tool_list):
@@ -234,6 +250,23 @@ async def collect_events(run: Run, e: CallFrame | RunFrame | PromptFrame):
234250
assert '"artists":' in stream_output, "Expected stream_output to have output"
235251

236252

253+
@pytest.mark.asyncio
254+
async def test_simple_run_file(gptscript):
255+
cwd = os.getcwd().removesuffix("/tests")
256+
run = gptscript.run(cwd + "/tests/fixtures/test.gpt")
257+
out = await run.text()
258+
assert "Ronald Reagan" in out, "Expect run file to have correct output"
259+
260+
# Run again and make sure the output is the same, and the cache is used
261+
run = gptscript.run(cwd + "/tests/fixtures/test.gpt")
262+
second_out = await run.text()
263+
assert second_out == out, "Expect run file to have same output as previous run"
264+
265+
# In this case, we expect one cached call frame
266+
for c in run.calls().values():
267+
assert c.chatResponseCached, "Expect chatResponseCached to be true"
268+
269+
237270
@pytest.mark.asyncio
238271
async def test_stream_run_file(gptscript):
239272
stream_output = ""
@@ -687,11 +720,13 @@ async def test_parse_with_metadata_then_run(gptscript):
687720
run = gptscript.evaluate(tools[0])
688721
assert "200" == await run.text(), "Expect file to have correct output"
689722

723+
690724
@pytest.mark.asyncio
691725
async def test_credentials(gptscript):
692726
name = "test-" + str(os.urandom(4).hex())
693727
now = datetime.now()
694-
res = await gptscript.create_credential(Credential(toolName=name, env={"TEST": "test"}, expiresAt=now + timedelta(seconds=5)))
728+
res = await gptscript.create_credential(
729+
Credential(toolName=name, env={"TEST": "test"}, expiresAt=now + timedelta(seconds=5)))
695730
assert not res.startswith("an error occurred"), "Unexpected error creating credential: " + res
696731

697732
sleep(5)

0 commit comments

Comments (0)