From 3723c64cc557891458cc7c4280d0300582926bbd Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Mon, 24 Feb 2025 11:54:41 -0800 Subject: [PATCH] feat: permission before tool call (#1313) --- crates/goose-cli/src/session/mod.rs | 51 ++++++-- crates/goose-cli/src/session/output.rs | 16 ++- crates/goose-server/src/routes/reply.rs | 13 +- crates/goose/src/agents/agent.rs | 3 + crates/goose/src/agents/reference.rs | 4 + crates/goose/src/agents/truncate.rs | 119 +++++++++++++++--- crates/goose/src/message.rs | 44 +++++++ .../goose/src/providers/formats/anthropic.rs | 3 + crates/goose/src/providers/formats/bedrock.rs | 3 + crates/goose/src/providers/formats/openai.rs | 3 + 10 files changed, 226 insertions(+), 33 deletions(-) diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index d359a297a..dc2c797b3 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -14,6 +14,7 @@ use goose::agents::Agent; use goose::message::{Message, MessageContent}; use mcp_core::handler::ToolError; use rand::{distributions::Alphanumeric, Rng}; +use rustyline::Editor; use std::path::PathBuf; use tokio; @@ -104,7 +105,7 @@ impl Session { } pub async fn start(&mut self) -> Result<()> { - let mut editor = rustyline::Editor::<(), rustyline::history::DefaultHistory>::new()?; + let mut editor = Editor::<(), rustyline::history::DefaultHistory>::new()?; // Load history from messages for msg in self @@ -120,7 +121,6 @@ impl Session { } } } - output::display_greeting(); loop { match input::get_input(&mut editor)? { @@ -129,7 +129,7 @@ impl Session { storage::persist_messages(&self.session_file, &self.messages)?; output::show_thinking(); - self.process_agent_response().await?; + self.process_agent_response(&mut editor).await?; output::hide_thinking(); } input::InputResult::Exit => break, @@ -188,11 +188,15 @@ impl Session { self.messages .push(Message::user().with_text(&initial_message)); storage::persist_messages(&self.session_file, &self.messages)?; - self.process_agent_response().await?; + let mut editor = Editor::<(), rustyline::history::DefaultHistory>::new()?; + self.process_agent_response(&mut editor).await?; Ok(()) } - async fn process_agent_response(&mut self) -> Result<()> { + async fn process_agent_response( + &mut self, + editor: &mut Editor<(), rustyline::history::DefaultHistory>, + ) -> Result<()> { let mut stream = self.agent.reply(&self.messages).await?; use futures::StreamExt; @@ -200,8 +204,41 @@ impl Session { tokio::select! { result = stream.next() => { match result { - Some(Ok(message)) => { - self.messages.push(message.clone()); + Some(Ok(mut message)) => { + + // Handle tool confirmation requests before rendering + if let Some(MessageContent::ToolConfirmationRequest(confirmation)) = message.content.first() { + output::hide_thinking(); + + // Format the confirmation prompt + let prompt = "Goose would like to call the above tool. Allow? (y/n):".to_string(); + + let confirmation_request = Message::user().with_tool_confirmation_request( + confirmation.id.clone(), + confirmation.tool_name.clone(), + confirmation.arguments.clone(), + Some(prompt) + ); + output::render_message(&confirmation_request); + + // Get confirmation from user + let confirmed = match input::get_input(editor)? { + input::InputResult::Message(content) => { + content.trim().to_lowercase().starts_with('y') + } + _ => false, + }; + + self.agent.handle_confirmation(confirmation.id.clone(), confirmed).await; + + message = confirmation_request; + } + + // Only push the message if it's not a tool confirmation request + if !message.content.iter().any(|content| matches!(content, MessageContent::ToolConfirmationRequest(_))) { + self.messages.push(message.clone()); + } + storage::persist_messages(&self.session_file, &self.messages)?; output::hide_thinking(); output::render_message(&message); diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs index f6ccbdfcb..e190d30cc 100644 --- a/crates/goose-cli/src/session/output.rs +++ b/crates/goose-cli/src/session/output.rs @@ -1,7 +1,7 @@ use bat::WrappingMode; use console::style; use goose::config::Config; -use goose::message::{Message, MessageContent, ToolRequest, ToolResponse}; +use goose::message::{Message, MessageContent, ToolConfirmationRequest, ToolRequest, ToolResponse}; use mcp_core::tool::ToolCall; use serde_json::Value; use std::cell::RefCell; @@ -94,6 +94,9 @@ pub fn render_message(message: &Message) { MessageContent::Text(text) => print_markdown(&text.text, theme), MessageContent::ToolRequest(req) => render_tool_request(req, theme), MessageContent::ToolResponse(resp) => render_tool_response(resp, theme), + MessageContent::ToolConfirmationRequest(req) => { + render_tool_confirmation_request(req, theme) + } MessageContent::Image(image) => { println!("Image: [data: {}, type: {}]", image.data, image.mime_type); } @@ -147,6 +150,17 @@ fn render_tool_response(resp: &ToolResponse, theme: Theme) { } } +fn render_tool_confirmation_request(req: &ToolConfirmationRequest, theme: Theme) { + match &req.prompt { + Some(prompt) => { + let colored_prompt = + prompt.replace("Allow? (y/n)", &format!("{}", style("Allow? (y/n)").cyan())); + println!("{}", colored_prompt); + } + None => print_markdown("No prompt provided", theme), + } +} + pub fn render_error(message: &str) { println!("\n {} {}\n", style("error:").red().bold(), message); } diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index b5a716ce3..03166c173 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -247,13 +247,14 @@ async fn stream_message( .await?; } } + MessageContent::ToolConfirmationRequest(_) => { + // skip tool confirmation requests + } MessageContent::Image(_) => { - // TODO - continue; + // skip images } MessageContent::ToolResponse(_) => { - // Tool responses should only come from the user - continue; + // skip tool responses } } } @@ -311,7 +312,7 @@ async fn handler( let mut stream = match agent.reply(&messages).await { Ok(stream) => stream, Err(e) => { - tracing::error!("Failed to start reply stream: {}", e); + tracing::error!("Failed to start reply stream: {:?}", e); let _ = tx .send(ProtocolFormatter::format_error(&e.to_string())) .await; @@ -398,7 +399,7 @@ async fn ask_handler( let mut stream = match agent.reply(&messages).await { Ok(stream) => stream, Err(e) => { - tracing::error!("Failed to start reply stream: {}", e); + tracing::error!("Failed to start reply stream: {:?}", e); return Err(StatusCode::INTERNAL_SERVER_ERROR); } }; diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 4500f95d1..469418b20 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -32,6 +32,9 @@ pub trait Agent: Send + Sync { /// Add custom text to be included in the system prompt async fn extend_system_prompt(&mut self, extension: String); + /// Handle a confirmation response for a tool request + async fn handle_confirmation(&self, request_id: String, confirmed: bool); + /// Override the system prompt with custom text async fn override_system_prompt(&mut self, template: String); } diff --git a/crates/goose/src/agents/reference.rs b/crates/goose/src/agents/reference.rs index 6c30435d9..bda3acce2 100644 --- a/crates/goose/src/agents/reference.rs +++ b/crates/goose/src/agents/reference.rs @@ -61,6 +61,10 @@ impl Agent for ReferenceAgent { Ok(Value::Null) } + async fn handle_confirmation(&self, _request_id: String, _confirmed: bool) { + // TODO implement + } + #[instrument(skip(self, messages), fields(user_message))] async fn reply( &self, diff --git a/crates/goose/src/agents/truncate.rs b/crates/goose/src/agents/truncate.rs index 685524d4b..86663d2df 100644 --- a/crates/goose/src/agents/truncate.rs +++ b/crates/goose/src/agents/truncate.rs @@ -2,12 +2,14 @@ /// It makes no attempt to handle context limits, and cannot read resources use async_trait::async_trait; use futures::stream::BoxStream; +use tokio::sync::mpsc; use tokio::sync::Mutex; use tracing::{debug, error, instrument, warn}; use super::Agent; use crate::agents::capabilities::Capabilities; use crate::agents::extension::{ExtensionConfig, ExtensionResult}; +use crate::config::Config; use crate::message::{Message, ToolRequest}; use crate::providers::base::Provider; use crate::providers::base::ProviderUsage; @@ -16,7 +18,7 @@ use crate::register_agent; use crate::token_counter::TokenCounter; use crate::truncate::{truncate_messages, OldestFirstTruncation}; use indoc::indoc; -use mcp_core::tool::Tool; +use mcp_core::{tool::Tool, Content}; use serde_json::{json, Value}; const MAX_TRUNCATION_ATTEMPTS: usize = 3; @@ -26,14 +28,21 @@ const ESTIMATE_FACTOR_DECAY: f32 = 0.9; pub struct TruncateAgent { capabilities: Mutex, token_counter: TokenCounter, + confirmation_tx: mpsc::Sender<(String, bool)>, // (request_id, confirmed) + confirmation_rx: Mutex>, } impl TruncateAgent { pub fn new(provider: Box) -> Self { let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name()); + // Create channel with buffer size 32 (adjust if needed) + let (tx, rx) = mpsc::channel(32); + Self { capabilities: Mutex::new(Capabilities::new(provider)), token_counter, + confirmation_tx: tx, + confirmation_rx: Mutex::new(rx), } } @@ -121,6 +130,13 @@ impl Agent for TruncateAgent { Ok(Value::Null) } + /// Handle a confirmation response for a tool request + async fn handle_confirmation(&self, request_id: String, confirmed: bool) { + if let Err(e) = self.confirmation_tx.send((request_id, confirmed)).await { + error!("Failed to send confirmation: {}", e); + } + } + #[instrument(skip(self, messages), fields(user_message))] async fn reply( &self, @@ -132,6 +148,10 @@ impl Agent for TruncateAgent { let mut tools = capabilities.get_prefixed_tools().await?; let mut truncation_attempt: usize = 0; + // Load settings from config + let config = Config::global(); + let goose_mode = config.get("GOOSE_MODE").unwrap_or("auto".to_string()); + // we add in the 2 resource tools if any extensions support resources // TODO: make sure there is no collision with another extension's tool name let read_resource_tool = Tool::new( @@ -191,7 +211,6 @@ impl Agent for TruncateAgent { Ok(Box::pin(async_stream::try_stream! { let _reply_guard = reply_span.enter(); loop { - // Attempt to get completion from provider match capabilities.provider().complete( &system_prompt, &messages, @@ -218,24 +237,86 @@ impl Agent for TruncateAgent { break; } - // Then dispatch each in parallel - let futures: Vec<_> = tool_requests - .iter() - .filter_map(|request| request.tool_call.clone().ok()) - .map(|tool_call| capabilities.dispatch_tool_call(tool_call)) - .collect(); - - // Process all the futures in parallel but wait until all are finished - let outputs = futures::future::join_all(futures).await; - - // Create a message with the responses + // Process tool requests depending on goose_mode let mut message_tool_response = Message::user(); - // Now combine these into MessageContent::ToolResponse using the original ID - for (request, output) in tool_requests.iter().zip(outputs.into_iter()) { - message_tool_response = message_tool_response.with_tool_response( - request.id.clone(), - output, - ); + // Clone goose_mode once before the match to avoid move issues + let mode = goose_mode.clone(); + match mode.as_str() { + "approve" => { + // Process each tool request sequentially with confirmation + for request in &tool_requests { + if let Ok(tool_call) = request.tool_call.clone() { + let confirmation = Message::user().with_tool_confirmation_request( + request.id.clone(), + tool_call.name.clone(), + tool_call.arguments.clone(), + Some("Goose would like to call the tool: {}\nAllow? (y/n): ".to_string()), + ); + yield confirmation; + + // Wait for confirmation response through the channel + let mut rx = self.confirmation_rx.lock().await; + if let Some((req_id, confirmed)) = rx.recv().await { + if req_id == request.id { + if confirmed { + // User approved - dispatch the tool call + let output = capabilities.dispatch_tool_call(tool_call).await; + message_tool_response = message_tool_response.with_tool_response( + request.id.clone(), + output, + ); + } else { + // User declined - add declined response + message_tool_response = message_tool_response.with_tool_response( + request.id.clone(), + Ok(vec![Content::text("User declined to run this tool.")]), + ); + } + } + } + } + } + }, + "chat" => { + // Skip all tool calls in chat mode + for request in &tool_requests { + message_tool_response = message_tool_response.with_tool_response( + request.id.clone(), + Ok(vec![Content::text( + "The following tool call was skipped in Goose chat mode. \ + In chat mode, you cannot run tool calls, instead, you can \ + only provide a detailed plan to the user. Provide an \ + explanation of the proposed tool call as if it were a plan. \ + Only if the user asks, provide a short explanation to the \ + user that they could consider running the tool above on \ + their own or with a different goose mode." + )]), + ); + } + }, + _ => { + if mode != "auto" { + warn!("Unknown GOOSE_MODE: {mode:?}. Defaulting to 'auto' mode."); + } + // Process tool requests in parallel + let mut tool_futures = Vec::new(); + for request in &tool_requests { + if let Ok(tool_call) = request.tool_call.clone() { + tool_futures.push(async { + let output = capabilities.dispatch_tool_call(tool_call).await; + (request.id.clone(), output) + }); + } + } + // Wait for all tool calls to complete + let results = futures::future::join_all(tool_futures).await; + for (request_id, output) in results { + message_tool_response = message_tool_response.with_tool_response( + request_id, + output, + ); + } + } } yield message_tool_response.clone(); diff --git a/crates/goose/src/message.rs b/crates/goose/src/message.rs index 30de253ff..3e10e9d69 100644 --- a/crates/goose/src/message.rs +++ b/crates/goose/src/message.rs @@ -12,6 +12,7 @@ use mcp_core::content::{Content, ImageContent, TextContent}; use mcp_core::handler::ToolResult; use mcp_core::role::Role; use mcp_core::tool::ToolCall; +use serde_json::Value; #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] pub struct ToolRequest { @@ -25,6 +26,14 @@ pub struct ToolResponse { pub tool_result: ToolResult>, } +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +pub struct ToolConfirmationRequest { + pub id: String, + pub tool_name: String, + pub arguments: Value, + pub prompt: Option, +} + #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] /// Content passed inside a message, which can be both simple content and tool content pub enum MessageContent { @@ -32,6 +41,7 @@ pub enum MessageContent { Image(ImageContent), ToolRequest(ToolRequest), ToolResponse(ToolResponse), + ToolConfirmationRequest(ToolConfirmationRequest), } impl MessageContent { @@ -64,6 +74,19 @@ impl MessageContent { }) } + pub fn tool_confirmation_request>( + id: S, + tool_name: String, + arguments: Value, + prompt: Option, + ) -> Self { + MessageContent::ToolConfirmationRequest(ToolConfirmationRequest { + id: id.into(), + tool_name, + arguments, + prompt, + }) + } pub fn as_tool_request(&self) -> Option<&ToolRequest> { if let MessageContent::ToolRequest(ref tool_request) = self { Some(tool_request) @@ -80,6 +103,14 @@ impl MessageContent { } } + pub fn as_tool_confirmation_request(&self) -> Option<&ToolConfirmationRequest> { + if let MessageContent::ToolConfirmationRequest(ref tool_confirmation_request) = self { + Some(tool_confirmation_request) + } else { + None + } + } + pub fn as_tool_response_text(&self) -> Option { if let Some(tool_response) = self.as_tool_response() { if let Ok(contents) = &tool_response.tool_result { @@ -178,6 +209,19 @@ impl Message { self.with_content(MessageContent::tool_response(id, result)) } + /// Add a tool confirmation request to the message + pub fn with_tool_confirmation_request>( + self, + id: S, + tool_name: String, + arguments: Value, + prompt: Option, + ) -> Self { + self.with_content(MessageContent::tool_confirmation_request( + id, tool_name, arguments, prompt, + )) + } + /// Get the concatenated text content of the message, separated by newlines pub fn as_concat_text(&self) -> String { self.content diff --git a/crates/goose/src/providers/formats/anthropic.rs b/crates/goose/src/providers/formats/anthropic.rs index 224630b03..4eadf4bcd 100644 --- a/crates/goose/src/providers/formats/anthropic.rs +++ b/crates/goose/src/providers/formats/anthropic.rs @@ -57,6 +57,9 @@ pub fn format_messages(messages: &[Message]) -> Vec { })); } } + MessageContent::ToolConfirmationRequest(_tool_confirmation_request) => { + // Skip tool confirmation requests + } MessageContent::Image(_) => continue, // Anthropic doesn't support image content yet } } diff --git a/crates/goose/src/providers/formats/bedrock.rs b/crates/goose/src/providers/formats/bedrock.rs index 812fda263..dcedf31ba 100644 --- a/crates/goose/src/providers/formats/bedrock.rs +++ b/crates/goose/src/providers/formats/bedrock.rs @@ -28,6 +28,9 @@ pub fn to_bedrock_message(message: &Message) -> Result { pub fn to_bedrock_message_content(content: &MessageContent) -> Result { Ok(match content { MessageContent::Text(text) => bedrock::ContentBlock::Text(text.text.to_string()), + MessageContent::ToolConfirmationRequest(_tool_confirmation_request) => { + bedrock::ContentBlock::Text("".to_string()) + } MessageContent::Image(_) => { bail!("Image content is not supported by Bedrock provider yet") } diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs index 32186c796..f78304163 100644 --- a/crates/goose/src/providers/formats/openai.rs +++ b/crates/goose/src/providers/formats/openai.rs @@ -136,6 +136,9 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< } } } + MessageContent::ToolConfirmationRequest(_) => { + // Skip tool confirmation requests + } MessageContent::Image(image) => { // Handle direct image content converted["content"] = json!([convert_image(image, image_format)]);