From ceb80ca189dce987771d8b29156b3005cdff7e79 Mon Sep 17 00:00:00 2001 From: Jarrod Sibbison <72240382+jsibbison-square@users.noreply.github.com> Date: Thu, 28 Nov 2024 13:07:37 +1100 Subject: [PATCH] Interrupt running shell tool commands (#365) --- crates/goose-server/src/routes/reply.rs | 35 ++++++++++++++++++------- crates/goose/src/developer.rs | 4 ++- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index d9b3aa3e6..4d50faf6d 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -22,8 +22,10 @@ use std::{ convert::Infallible, pin::Pin, task::{Context, Poll}, + time::Duration, }; use tokio::sync::mpsc; +use tokio::time::timeout; use tokio_stream::wrappers::ReceiverStream; // Types matching the incoming JSON structure @@ -293,18 +295,31 @@ async fn handler( } }; - while let Some(response) = stream.next().await { - match response { - Ok(message) => { - if let Err(e) = stream_message(message, &tx).await { - tracing::error!("Error sending message through channel: {}", e); - break; + loop { + tokio::select! { + response = timeout(Duration::from_millis(500), stream.next()) => { + match response { + Ok(Some(Ok(message))) => { + if let Err(e) = stream_message(message, &tx).await { + tracing::error!("Error sending message through channel: {}", e); + break; + } + } + Ok(Some(Err(e))) => { + tracing::error!("Error processing message: {}", e); + break; + } + Ok(None) => { + break; + } + Err(_) => { // Heartbeat, used to detect disconnected clients and then end running tools. + if tx.is_closed() { + break; + } + continue; + } } } - Err(e) => { - tracing::error!("Error processing message: {}", e); - break; - } } } diff --git a/crates/goose/src/developer.rs b/crates/goose/src/developer.rs index 1fbc91ed7..5e5e5d7c7 100644 --- a/crates/goose/src/developer.rs +++ b/crates/goose/src/developer.rs @@ -8,8 +8,8 @@ use serde_json::{json, Value}; use std::collections::{HashMap, HashSet}; use std::io::Cursor; use std::path::{Path, PathBuf}; -use std::process::Command; use std::sync::Mutex; +use tokio::process::Command; use xcap::Monitor; use crate::errors::{AgentError, AgentResult}; @@ -192,9 +192,11 @@ impl DeveloperSystem { // Execute the command let output = Command::new("bash") + .kill_on_drop(true) // Critical so that the command is killed when the agent.reply stream is interrupted. .arg("-c") .arg(cmd_with_redirect) .output() + .await .map_err(|e| AgentError::ExecutionError(e.to_string()))?; let output_str = format!(