From d9e3000db700e5e0cc7cd3f9a2d7fb9e047b35df Mon Sep 17 00:00:00 2001 From: Jarrod Sibbison <72240382+jsibbison-square@users.noreply.github.com> Date: Tue, 3 Dec 2024 21:30:39 +1100 Subject: [PATCH] fix: Kills recursively processes from spawned agent shell commands (#384) --- crates/goose-cli/src/session.rs | 2 + crates/goose-server/src/routes/reply.rs | 2 + crates/goose/Cargo.toml | 6 +- crates/goose/src/developer.rs | 23 +++- crates/goose/src/lib.rs | 1 + crates/goose/src/process_store.rs | 151 ++++++++++++++++++++++++ 6 files changed, 181 insertions(+), 4 deletions(-) create mode 100644 crates/goose/src/process_store.rs diff --git a/crates/goose-cli/src/session.rs b/crates/goose-cli/src/session.rs index 9a8ec1985..c600cd910 100644 --- a/crates/goose-cli/src/session.rs +++ b/crates/goose-cli/src/session.rs @@ -173,6 +173,8 @@ We've removed the conversation up to the most recent user message } } _ = tokio::signal::ctrl_c() => { + // Kill any running processes when the client disconnects + goose::process_store::kill_processes(); drop(stream); self.handle_interrupted_messages(); break; diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 37c1367ce..58769bb2a 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -305,6 +305,8 @@ async fn handler( } Err(_) => { // Heartbeat, used to detect disconnected clients and then end running tools. 
if tx.is_closed() { + // Kill any running processes when the client disconnects + goose::process_store::kill_processes(); break; } continue; diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index 1d837aec4..e0ccef5c9 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -38,12 +38,15 @@ tower-http = { version = "0.5", features = ["cors"] } webbrowser = "0.8" dotenv = "0.15" xcap = "0.0.14" -libc = "=0.2.164" +libc = "=0.2.167" +lazy_static = "1.5" +kill_tree = "0.2.4" keyring = { version = "3.6.1", features = ["apple-native", "windows-native", "sync-secret-service"] } shellexpand = "3.1.0" [dev-dependencies] +sysinfo = "0.32.1" wiremock = "0.6.0" mockito = "1.2" tempfile = "3.8" @@ -52,4 +55,3 @@ mockall = "0.11" [[example]] name = "databricks_oauth" path = "examples/databricks_oauth.rs" - diff --git a/crates/goose/src/developer.rs b/crates/goose/src/developer.rs index f090790d6..ad69b8fb9 100644 --- a/crates/goose/src/developer.rs +++ b/crates/goose/src/developer.rs @@ -8,6 +8,7 @@ use serde_json::{json, Value}; use std::collections::{HashMap, HashSet}; use std::io::Cursor; use std::path::{Path, PathBuf}; +use std::process::Stdio; use std::sync::Mutex; use tokio::process::Command; use xcap::Monitor; @@ -192,14 +193,32 @@ impl DeveloperSystem { let cmd_with_redirect = format!("{} 2>&1", command); // Execute the command - let output = Command::new("bash") + let child = Command::new("bash") + .stdout(Stdio::piped()) // These two pipes required to capture output later. + .stderr(Stdio::piped()) .kill_on_drop(true) // Critical so that the command is killed when the agent.reply stream is interrupted. 
.arg("-c") .arg(cmd_with_redirect) - .output() + .spawn() + .map_err(|e| AgentError::ExecutionError(e.to_string()))?; + + // Store the process ID with the command as the key + let pid: Option<u32> = child.id(); + if let Some(pid) = pid { + crate::process_store::store_process(pid); + } + + // Wait for the command to complete and get output + let output = child + .wait_with_output() .await .map_err(|e| AgentError::ExecutionError(e.to_string()))?; + // Remove the process ID from the store + if let Some(pid) = pid { + crate::process_store::remove_process(pid); + } + let output_str = format!( "Finished with Status Code: {}\nOutput:\n{}", output.status, diff --git a/crates/goose/src/lib.rs b/crates/goose/src/lib.rs index c77c1fb69..989743cf8 100644 --- a/crates/goose/src/lib.rs +++ b/crates/goose/src/lib.rs @@ -4,6 +4,7 @@ pub mod errors; pub mod key_manager; pub mod memory; pub mod models; +pub mod process_store; pub mod prompt_template; pub mod providers; pub mod systems; diff --git a/crates/goose/src/process_store.rs b/crates/goose/src/process_store.rs new file mode 100644 index 000000000..dd8185f5e --- /dev/null +++ b/crates/goose/src/process_store.rs @@ -0,0 +1,151 @@ +use kill_tree::{blocking::kill_tree_with_config, Config}; +use lazy_static::lazy_static; +use std::sync::Mutex; + +// Singleton that will store process IDs for spawned child processes implementing agent tasks. +lazy_static! { + static ref PROCESS_STORE: Mutex<Vec<u32>> = Mutex::new(Vec::new()); +} + +pub fn store_process(pid: u32) { + let mut store = PROCESS_STORE.lock().unwrap(); + store.push(pid); +} + +// This removes the record of a process from the store, it does not kill it or check that it is dead.
+pub fn remove_process(pid: u32) -> bool { + let mut store = PROCESS_STORE.lock().unwrap(); + if let Some(index) = store.iter().position(|&x| x == pid) { + store.remove(index); + true + } else { + false + } +} + +/// Kill all stored processes +pub fn kill_processes() { + let mut killed_processes = Vec::new(); + { + let store = PROCESS_STORE.lock().unwrap(); + for &pid in store.iter() { + let config = Config { + signal: "SIGKILL".to_string(), + ..Default::default() + }; + let outputs = match kill_tree_with_config(pid, &config) { + Ok(outputs) => outputs, + Err(e) => { + eprintln!("Failed to kill process {}: {}", pid, e); + continue; + } + }; + for output in outputs { + match output { + kill_tree::Output::Killed { process_id, .. } => { + killed_processes.push(process_id); + } + kill_tree::Output::MaybeAlreadyTerminated { process_id, .. } => { + killed_processes.push(process_id); + } + } + } + } + } + // Clean up the store + for pid in killed_processes { + remove_process(pid); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::os::unix::fs::PermissionsExt; + use std::time::Duration; + use std::{fs, thread}; + use sysinfo::{Pid, ProcessesToUpdate, System}; + use tokio::process::Command; + + #[tokio::test] + async fn test_kill_processes_with_children() { + // Create a temporary script that spawns a child process + let temp_dir = std::env::temp_dir(); + let script_path = temp_dir.join("test_script.sh"); + let script_content = r#"#!/bin/bash + # Sleep in the parent process + sleep 300 + "#; + + fs::write(&script_path, script_content).unwrap(); + fs::set_permissions(&script_path, std::fs::Permissions::from_mode(0o755)).unwrap(); + + // Start the parent process which will spawn a child + let parent = Command::new("bash") + .arg("-c") + .arg(script_path.to_str().unwrap()) + .spawn() + .expect("Failed to start parent process"); + + let parent_pid = parent.id().unwrap() as u32; + + // Store the parent process ID + store_process(parent_pid); + + // Give processes 
time to start + thread::sleep(Duration::from_secs(1)); + + // Get the child process ID using pgrep + let child_pids = Command::new("pgrep") + .arg("-P") + .arg(parent_pid.to_string()) + .output() + .await + .expect("Failed to get child PIDs"); + + let child_pid_str = String::from_utf8_lossy(&child_pids.stdout); + let child_pids: Vec<u32> = child_pid_str + .lines() + .filter_map(|line| line.trim().parse::<u32>().ok()) + .collect(); + assert!(child_pids.len() == 1); + + // Verify processes are running + assert!(is_process_running(parent_pid).await); + assert!(is_process_running(child_pids[0]).await); + + kill_processes(); + + // Wait until processes are killed + let mut attempts = 0; + while attempts < 10 { + if !is_process_running(parent_pid).await && !is_process_running(child_pids[0]).await { + break; + } + thread::sleep(Duration::from_millis(100)); + attempts += 1; + } + + // Verify processes are dead + assert!(!is_process_running(parent_pid).await); + assert!(!is_process_running(child_pids[0]).await); + + // Clean up the temporary script + fs::remove_file(script_path).unwrap(); + } + + async fn is_process_running(pid: u32) -> bool { + let mut system = System::new_all(); + system.refresh_processes(ProcessesToUpdate::All, true); + + match system.process(Pid::from_u32(pid)) { + Some(process) => !matches!( + process.status(), + sysinfo::ProcessStatus::Stop + | sysinfo::ProcessStatus::Zombie + | sysinfo::ProcessStatus::Dead + ), + None => false, + } + } +}