Skip to content

Commit

Permalink
feat: permission before tool call (#1313)
Browse files Browse the repository at this point in the history
  • Loading branch information
wendytang authored Feb 24, 2025
1 parent bef7551 commit 3723c64
Show file tree
Hide file tree
Showing 10 changed files with 226 additions and 33 deletions.
51 changes: 44 additions & 7 deletions crates/goose-cli/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use goose::agents::Agent;
use goose::message::{Message, MessageContent};
use mcp_core::handler::ToolError;
use rand::{distributions::Alphanumeric, Rng};
use rustyline::Editor;
use std::path::PathBuf;
use tokio;

Expand Down Expand Up @@ -104,7 +105,7 @@ impl Session {
}

pub async fn start(&mut self) -> Result<()> {
let mut editor = rustyline::Editor::<(), rustyline::history::DefaultHistory>::new()?;
let mut editor = Editor::<(), rustyline::history::DefaultHistory>::new()?;

// Load history from messages
for msg in self
Expand All @@ -120,7 +121,6 @@ impl Session {
}
}
}

output::display_greeting();
loop {
match input::get_input(&mut editor)? {
Expand All @@ -129,7 +129,7 @@ impl Session {
storage::persist_messages(&self.session_file, &self.messages)?;

output::show_thinking();
self.process_agent_response().await?;
self.process_agent_response(&mut editor).await?;
output::hide_thinking();
}
input::InputResult::Exit => break,
Expand Down Expand Up @@ -188,20 +188,57 @@ impl Session {
self.messages
.push(Message::user().with_text(&initial_message));
storage::persist_messages(&self.session_file, &self.messages)?;
self.process_agent_response().await?;
let mut editor = Editor::<(), rustyline::history::DefaultHistory>::new()?;
self.process_agent_response(&mut editor).await?;
Ok(())
}

async fn process_agent_response(&mut self) -> Result<()> {
async fn process_agent_response(
&mut self,
editor: &mut Editor<(), rustyline::history::DefaultHistory>,
) -> Result<()> {
let mut stream = self.agent.reply(&self.messages).await?;

use futures::StreamExt;
loop {
tokio::select! {
result = stream.next() => {
match result {
Some(Ok(message)) => {
self.messages.push(message.clone());
Some(Ok(mut message)) => {

// Handle tool confirmation requests before rendering
if let Some(MessageContent::ToolConfirmationRequest(confirmation)) = message.content.first() {
output::hide_thinking();

// Format the confirmation prompt
let prompt = "Goose would like to call the above tool. Allow? (y/n):".to_string();

let confirmation_request = Message::user().with_tool_confirmation_request(
confirmation.id.clone(),
confirmation.tool_name.clone(),
confirmation.arguments.clone(),
Some(prompt)
);
output::render_message(&confirmation_request);

// Get confirmation from user
let confirmed = match input::get_input(editor)? {
input::InputResult::Message(content) => {
content.trim().to_lowercase().starts_with('y')
}
_ => false,
};

self.agent.handle_confirmation(confirmation.id.clone(), confirmed).await;

message = confirmation_request;
}

// Only push the message if it's not a tool confirmation request
if !message.content.iter().any(|content| matches!(content, MessageContent::ToolConfirmationRequest(_))) {
self.messages.push(message.clone());
}

storage::persist_messages(&self.session_file, &self.messages)?;
output::hide_thinking();
output::render_message(&message);
Expand Down
16 changes: 15 additions & 1 deletion crates/goose-cli/src/session/output.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use bat::WrappingMode;
use console::style;
use goose::config::Config;
use goose::message::{Message, MessageContent, ToolRequest, ToolResponse};
use goose::message::{Message, MessageContent, ToolConfirmationRequest, ToolRequest, ToolResponse};
use mcp_core::tool::ToolCall;
use serde_json::Value;
use std::cell::RefCell;
Expand Down Expand Up @@ -94,6 +94,9 @@ pub fn render_message(message: &Message) {
MessageContent::Text(text) => print_markdown(&text.text, theme),
MessageContent::ToolRequest(req) => render_tool_request(req, theme),
MessageContent::ToolResponse(resp) => render_tool_response(resp, theme),
MessageContent::ToolConfirmationRequest(req) => {
render_tool_confirmation_request(req, theme)
}
MessageContent::Image(image) => {
println!("Image: [data: {}, type: {}]", image.data, image.mime_type);
}
Expand Down Expand Up @@ -147,6 +150,17 @@ fn render_tool_response(resp: &ToolResponse, theme: Theme) {
}
}

fn render_tool_confirmation_request(req: &ToolConfirmationRequest, theme: Theme) {
match &req.prompt {
Some(prompt) => {
let colored_prompt =
prompt.replace("Allow? (y/n)", &format!("{}", style("Allow? (y/n)").cyan()));
println!("{}", colored_prompt);
}
None => print_markdown("No prompt provided", theme),
}
}

pub fn render_error(message: &str) {
println!("\n {} {}\n", style("error:").red().bold(), message);
}
Expand Down
13 changes: 7 additions & 6 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,13 +247,14 @@ async fn stream_message(
.await?;
}
}
MessageContent::ToolConfirmationRequest(_) => {
// skip tool confirmation requests
}
MessageContent::Image(_) => {
// TODO
continue;
// skip images
}
MessageContent::ToolResponse(_) => {
// Tool responses should only come from the user
continue;
// skip tool responses
}
}
}
Expand Down Expand Up @@ -311,7 +312,7 @@ async fn handler(
let mut stream = match agent.reply(&messages).await {
Ok(stream) => stream,
Err(e) => {
tracing::error!("Failed to start reply stream: {}", e);
tracing::error!("Failed to start reply stream: {:?}", e);
let _ = tx
.send(ProtocolFormatter::format_error(&e.to_string()))
.await;
Expand Down Expand Up @@ -398,7 +399,7 @@ async fn ask_handler(
let mut stream = match agent.reply(&messages).await {
Ok(stream) => stream,
Err(e) => {
tracing::error!("Failed to start reply stream: {}", e);
tracing::error!("Failed to start reply stream: {:?}", e);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
};
Expand Down
3 changes: 3 additions & 0 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ pub trait Agent: Send + Sync {
/// Add custom text to be included in the system prompt
async fn extend_system_prompt(&mut self, extension: String);

/// Handle a confirmation response for a tool request
async fn handle_confirmation(&self, request_id: String, confirmed: bool);

/// Override the system prompt with custom text
async fn override_system_prompt(&mut self, template: String);
}
4 changes: 4 additions & 0 deletions crates/goose/src/agents/reference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ impl Agent for ReferenceAgent {
Ok(Value::Null)
}

async fn handle_confirmation(&self, _request_id: String, _confirmed: bool) {
// TODO implement
}

#[instrument(skip(self, messages), fields(user_message))]
async fn reply(
&self,
Expand Down
119 changes: 100 additions & 19 deletions crates/goose/src/agents/truncate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
/// It makes no attempt to handle context limits, and cannot read resources
use async_trait::async_trait;
use futures::stream::BoxStream;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tracing::{debug, error, instrument, warn};

use super::Agent;
use crate::agents::capabilities::Capabilities;
use crate::agents::extension::{ExtensionConfig, ExtensionResult};
use crate::config::Config;
use crate::message::{Message, ToolRequest};
use crate::providers::base::Provider;
use crate::providers::base::ProviderUsage;
Expand All @@ -16,7 +18,7 @@ use crate::register_agent;
use crate::token_counter::TokenCounter;
use crate::truncate::{truncate_messages, OldestFirstTruncation};
use indoc::indoc;
use mcp_core::tool::Tool;
use mcp_core::{tool::Tool, Content};
use serde_json::{json, Value};

const MAX_TRUNCATION_ATTEMPTS: usize = 3;
Expand All @@ -26,14 +28,21 @@ const ESTIMATE_FACTOR_DECAY: f32 = 0.9;
pub struct TruncateAgent {
capabilities: Mutex<Capabilities>,
token_counter: TokenCounter,
confirmation_tx: mpsc::Sender<(String, bool)>, // (request_id, confirmed)
confirmation_rx: Mutex<mpsc::Receiver<(String, bool)>>,
}

impl TruncateAgent {
pub fn new(provider: Box<dyn Provider>) -> Self {
let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name());
// Create channel with buffer size 32 (adjust if needed)
let (tx, rx) = mpsc::channel(32);

Self {
capabilities: Mutex::new(Capabilities::new(provider)),
token_counter,
confirmation_tx: tx,
confirmation_rx: Mutex::new(rx),
}
}

Expand Down Expand Up @@ -121,6 +130,13 @@ impl Agent for TruncateAgent {
Ok(Value::Null)
}

/// Handle a confirmation response for a tool request
async fn handle_confirmation(&self, request_id: String, confirmed: bool) {
if let Err(e) = self.confirmation_tx.send((request_id, confirmed)).await {
error!("Failed to send confirmation: {}", e);
}
}

#[instrument(skip(self, messages), fields(user_message))]
async fn reply(
&self,
Expand All @@ -132,6 +148,10 @@ impl Agent for TruncateAgent {
let mut tools = capabilities.get_prefixed_tools().await?;
let mut truncation_attempt: usize = 0;

// Load settings from config
let config = Config::global();
let goose_mode = config.get("GOOSE_MODE").unwrap_or("auto".to_string());

// we add in the 2 resource tools if any extensions support resources
// TODO: make sure there is no collision with another extension's tool name
let read_resource_tool = Tool::new(
Expand Down Expand Up @@ -191,7 +211,6 @@ impl Agent for TruncateAgent {
Ok(Box::pin(async_stream::try_stream! {
let _reply_guard = reply_span.enter();
loop {
// Attempt to get completion from provider
match capabilities.provider().complete(
&system_prompt,
&messages,
Expand All @@ -218,24 +237,86 @@ impl Agent for TruncateAgent {
break;
}

// Then dispatch each in parallel
let futures: Vec<_> = tool_requests
.iter()
.filter_map(|request| request.tool_call.clone().ok())
.map(|tool_call| capabilities.dispatch_tool_call(tool_call))
.collect();

// Process all the futures in parallel but wait until all are finished
let outputs = futures::future::join_all(futures).await;

// Create a message with the responses
// Process tool requests depending on goose_mode
let mut message_tool_response = Message::user();
// Now combine these into MessageContent::ToolResponse using the original ID
for (request, output) in tool_requests.iter().zip(outputs.into_iter()) {
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
output,
);
// Clone goose_mode once before the match to avoid move issues
let mode = goose_mode.clone();
match mode.as_str() {
"approve" => {
// Process each tool request sequentially with confirmation
for request in &tool_requests {
if let Ok(tool_call) = request.tool_call.clone() {
let confirmation = Message::user().with_tool_confirmation_request(
request.id.clone(),
tool_call.name.clone(),
tool_call.arguments.clone(),
Some("Goose would like to call the tool: {}\nAllow? (y/n): ".to_string()),
);
yield confirmation;

// Wait for confirmation response through the channel
let mut rx = self.confirmation_rx.lock().await;
if let Some((req_id, confirmed)) = rx.recv().await {
if req_id == request.id {
if confirmed {
// User approved - dispatch the tool call
let output = capabilities.dispatch_tool_call(tool_call).await;
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
output,
);
} else {
// User declined - add declined response
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
Ok(vec![Content::text("User declined to run this tool.")]),
);
}
}
}
}
}
},
"chat" => {
// Skip all tool calls in chat mode
for request in &tool_requests {
message_tool_response = message_tool_response.with_tool_response(
request.id.clone(),
Ok(vec![Content::text(
"The following tool call was skipped in Goose chat mode. \
In chat mode, you cannot run tool calls, instead, you can \
only provide a detailed plan to the user. Provide an \
explanation of the proposed tool call as if it were a plan. \
Only if the user asks, provide a short explanation to the \
user that they could consider running the tool above on \
their own or with a different goose mode."
)]),
);
}
},
_ => {
if mode != "auto" {
warn!("Unknown GOOSE_MODE: {mode:?}. Defaulting to 'auto' mode.");
}
// Process tool requests in parallel
let mut tool_futures = Vec::new();
for request in &tool_requests {
if let Ok(tool_call) = request.tool_call.clone() {
tool_futures.push(async {
let output = capabilities.dispatch_tool_call(tool_call).await;
(request.id.clone(), output)
});
}
}
// Wait for all tool calls to complete
let results = futures::future::join_all(tool_futures).await;
for (request_id, output) in results {
message_tool_response = message_tool_response.with_tool_response(
request_id,
output,
);
}
}
}

yield message_tool_response.clone();
Expand Down
Loading

0 comments on commit 3723c64

Please sign in to comment.