-
Notifications
You must be signed in to change notification settings - Fork 151
/
Copy pathagent.py
142 lines (120 loc) · 4.71 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import json
from typing import Callable
from computers import Computer
from utils import (
check_blocklisted_url,
create_response,
pp,
sanitize_message,
show_image,
)
class Agent:
    """
    A sample agent class that can be used to interact with a computer.

    (See simple_cua_loop.py for a simple example without an agent.)

    The agent repeatedly asks the model for a response (via
    ``create_response``), executes the function calls and computer actions
    contained in that response against ``self.computer``, and feeds the
    results back until the model emits an assistant message.
    """

    def __init__(
        self,
        model="computer-use-preview",
        computer: "Computer | None" = None,
        tools: "list[dict] | None" = None,
        acknowledge_safety_check_callback: Callable = lambda message: False,
    ):
        """Create an agent.

        Args:
            model: Model name forwarded to ``create_response``.
            computer: Object the actions are executed on. This class reads
                ``dimensions``, ``environment``, ``screenshot()`` and, in
                browser environments, ``get_current_url()`` from it.
            tools: Extra tool definitions sent to the model. The list is
                copied, so the caller's list is never modified.
            acknowledge_safety_check_callback: Called with the message of
                each pending safety check; must return a truthy value to
                acknowledge it. The default rejects every check.
        """
        self.model = model
        self.computer = computer
        # Copy: appending below must not leak into the caller's list or into
        # a default shared between Agent instances.
        self.tools = list(tools) if tools else []
        self.print_steps = True
        self.debug = False
        self.show_images = False
        self.acknowledge_safety_check_callback = acknowledge_safety_check_callback

        if computer:
            self.tools.append(
                {
                    "type": "computer_use_preview",
                    "display_width": computer.dimensions[0],
                    "display_height": computer.dimensions[1],
                    "environment": computer.environment,
                }
            )

    def debug_print(self, *args):
        """Pretty-print ``args`` with ``pp`` when debug mode is enabled."""
        if self.debug:
            pp(*args)

    def handle_item(self, item):
        """Handle each item; may cause a computer action + screenshot.

        Args:
            item: One output item of a model response. The ``"type"`` key
                selects the handling: ``"message"``, ``"function_call"`` or
                ``"computer_call"``. Any other type is ignored.

        Returns:
            A list of items to append to the conversation: one
            ``function_call_output`` or ``computer_call_output`` entry, or an
            empty list when the item needs no reply.

        Raises:
            ValueError: If a pending safety check of a ``computer_call`` is
                not acknowledged by ``acknowledge_safety_check_callback``.
        """
        item_type = item["type"]

        if item_type == "message":
            if self.print_steps:
                print(item["content"][0]["text"])
            return []

        if item_type == "function_call":
            name, args = item["name"], json.loads(item["arguments"])
            if self.print_steps:
                print(f"{name}({args})")

            if hasattr(self.computer, name):  # if function exists on computer, call it
                getattr(self.computer, name)(**args)
            return [
                {
                    "type": "function_call_output",
                    "call_id": item["call_id"],
                    "output": "success",  # hard-coded output for demo
                }
            ]

        if item_type == "computer_call":
            action = item["action"]
            action_type = action["type"]
            action_args = {k: v for k, v in action.items() if k != "type"}
            if self.print_steps:
                print(f"{action_type}({action_args})")

            # Every safety check must be acknowledged BEFORE the action is
            # executed; otherwise a rejected action would already have run.
            pending_checks = item.get("pending_safety_checks", [])
            for check in pending_checks:
                message = check["message"]
                if not self.acknowledge_safety_check_callback(message):
                    raise ValueError(
                        f"Safety check failed: {message}. Cannot continue with unacknowledged safety checks."
                    )

            # The action type names the method to invoke on the computer.
            getattr(self.computer, action_type)(**action_args)

            screenshot_base64 = self.computer.screenshot()
            if self.show_images:
                show_image(screenshot_base64)

            call_output = {
                "type": "computer_call_output",
                "call_id": item["call_id"],
                "acknowledged_safety_checks": pending_checks,
                "output": {
                    "type": "input_image",
                    "image_url": f"data:image/png;base64,{screenshot_base64}",
                },
            }

            # additional URL safety checks for browser environments
            # NOTE(review): check_blocklisted_url presumably raises for a
            # blocked URL - confirm against utils.
            if self.computer.environment == "browser":
                current_url = self.computer.get_current_url()
                check_blocklisted_url(current_url)
                call_output["output"]["current_url"] = current_url

            return [call_output]

        return []

    def run_full_turn(
        self, input_items, print_steps=True, debug=False, show_images=False
    ):
        """Query the model and execute its actions until it answers.

        Args:
            input_items: Conversation items sent to the model as context.
            print_steps: Print messages and executed calls as they happen.
            debug: Pretty-print the requests and responses.
            show_images: Display each screenshot with ``show_image``.

        Returns:
            All items produced during this turn (model output items plus the
            call outputs generated by ``handle_item``), in order.

        Raises:
            ValueError: If a response carries no ``"output"`` key, or a
                safety check is rejected (see ``handle_item``).
        """
        self.print_steps = print_steps
        self.debug = debug
        self.show_images = show_images
        new_items = []

        # keep looping until we get a final response
        while not new_items or new_items[-1].get("role") != "assistant":
            self.debug_print([sanitize_message(msg) for msg in input_items + new_items])

            response = create_response(
                model=self.model,
                input=input_items + new_items,
                tools=self.tools,
                truncation="auto",
            )
            self.debug_print(response)

            # A response without output is an error regardless of debug
            # mode; debug only controls whether it is printed first.
            if "output" not in response:
                if self.debug:
                    print(response)
                raise ValueError("No output from model")

            new_items += response["output"]
            for item in response["output"]:
                new_items += self.handle_item(item)

        return new_items