Skip to content

Commit

Permalink
fix: Kills recursively processes from spawned agent shell commands (#384)
Browse files Browse the repository at this point in the history
  • Loading branch information
jsibbison-square authored Dec 3, 2024
1 parent 798a7cf commit d9e3000
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 4 deletions.
2 changes: 2 additions & 0 deletions crates/goose-cli/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ We've removed the conversation up to the most recent user message
}
}
_ = tokio::signal::ctrl_c() => {
// Kill any running processes when the user interrupts the session with Ctrl-C
goose::process_store::kill_processes();
drop(stream);
self.handle_interrupted_messages();
break;
Expand Down
2 changes: 2 additions & 0 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,8 @@ async fn handler(
}
Err(_) => { // Heartbeat, used to detect disconnected clients and then end running tools.
if tx.is_closed() {
// Kill any running processes when the client disconnects
goose::process_store::kill_processes();
break;
}
continue;
Expand Down
6 changes: 4 additions & 2 deletions crates/goose/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,15 @@ tower-http = { version = "0.5", features = ["cors"] }
webbrowser = "0.8"
dotenv = "0.15"
xcap = "0.0.14"
libc = "=0.2.164"
libc = "=0.2.167"
lazy_static = "1.5"
kill_tree = "0.2.4"

keyring = { version = "3.6.1", features = ["apple-native", "windows-native", "sync-secret-service"] }
shellexpand = "3.1.0"

[dev-dependencies]
sysinfo = "0.32.1"
wiremock = "0.6.0"
mockito = "1.2"
tempfile = "3.8"
Expand All @@ -52,4 +55,3 @@ mockall = "0.11"
[[example]]
name = "databricks_oauth"
path = "examples/databricks_oauth.rs"

23 changes: 21 additions & 2 deletions crates/goose/src/developer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use serde_json::{json, Value};
use std::collections::{HashMap, HashSet};
use std::io::Cursor;
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::sync::Mutex;
use tokio::process::Command;
use xcap::Monitor;
Expand Down Expand Up @@ -192,14 +193,32 @@ impl DeveloperSystem {
let cmd_with_redirect = format!("{} 2>&1", command);

// Execute the command
let output = Command::new("bash")
let child = Command::new("bash")
.stdout(Stdio::piped()) // These two pipes required to capture output later.
.stderr(Stdio::piped())
.kill_on_drop(true) // Critical so that the command is killed when the agent.reply stream is interrupted.
.arg("-c")
.arg(cmd_with_redirect)
.output()
.spawn()
.map_err(|e| AgentError::ExecutionError(e.to_string()))?;

// Record the process ID in the process store so it can be killed if the reply is interrupted
let pid: Option<u32> = child.id();
if let Some(pid) = pid {
crate::process_store::store_process(pid);
}

// Wait for the command to complete and get output
let output = child
.wait_with_output()
.await
.map_err(|e| AgentError::ExecutionError(e.to_string()))?;

// Remove the process ID from the store
if let Some(pid) = pid {
crate::process_store::remove_process(pid);
}

let output_str = format!(
"Finished with Status Code: {}\nOutput:\n{}",
output.status,
Expand Down
1 change: 1 addition & 0 deletions crates/goose/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod errors;
pub mod key_manager;
pub mod memory;
pub mod models;
pub mod process_store;
pub mod prompt_template;
pub mod providers;
pub mod systems;
Expand Down
151 changes: 151 additions & 0 deletions crates/goose/src/process_store.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
use kill_tree::{blocking::kill_tree_with_config, Config};
use lazy_static::lazy_static;
use std::sync::Mutex;

// Process-wide singleton holding the process IDs of spawned child processes that implement
// agent tasks. PIDs are appended by `store_process`, dropped by `remove_process`, and
// iterated by `kill_processes`. All access goes through the mutex.
lazy_static! {
static ref PROCESS_STORE: Mutex<Vec<u32>> = Mutex::new(Vec::new());
}

/// Records `pid` in the global process store so it can later be killed by `kill_processes`.
///
/// # Panics
///
/// Panics if the store's mutex has been poisoned.
pub fn store_process(pid: u32) {
    PROCESS_STORE.lock().unwrap().push(pid);
}

/// Forgets `pid`: drops its first matching entry from the global process store.
///
/// This only updates the bookkeeping. The process itself is neither signalled nor checked
/// for liveness.
///
/// Returns `true` when an entry for `pid` was found and removed, `false` otherwise.
///
/// # Panics
///
/// Panics if the store's mutex has been poisoned.
pub fn remove_process(pid: u32) -> bool {
    let mut pids = PROCESS_STORE.lock().unwrap();
    match pids.iter().position(|&stored| stored == pid) {
        Some(idx) => {
            pids.remove(idx);
            true
        }
        None => false,
    }
}

/// Kill all stored processes.
///
/// Every PID currently in the store is passed to `kill_tree` with `SIGKILL`, which targets
/// the process together with its process tree. Each process ID that `kill_tree` reports as
/// killed (or as possibly already terminated) is then removed from the store. A PID whose
/// kill attempt returns an error is logged to stderr and left in the store.
///
/// # Panics
///
/// Panics if the store's mutex has been poisoned.
pub fn kill_processes() {
    // Snapshot the PIDs and release the lock immediately: killing a process tree is a
    // blocking operation, and holding the mutex across it would stall every concurrent
    // `store_process` / `remove_process` caller.
    let pids: Vec<u32> = PROCESS_STORE.lock().unwrap().to_vec();

    // The kill configuration is the same for every PID, so build it once.
    let config = Config {
        signal: "SIGKILL".to_string(),
        ..Default::default()
    };

    let mut killed_processes = Vec::new();
    for pid in pids {
        match kill_tree_with_config(pid, &config) {
            Ok(outputs) => {
                // Both outcomes mean the process is gone, so both are eligible for cleanup.
                killed_processes.extend(outputs.into_iter().map(|output| match output {
                    kill_tree::Output::Killed { process_id, .. }
                    | kill_tree::Output::MaybeAlreadyTerminated { process_id, .. } => process_id,
                }));
            }
            Err(e) => eprintln!("Failed to kill process {}: {}", pid, e),
        }
    }

    // Clean up the store. The list also contains descendant PIDs that were never stored;
    // `remove_process` simply returns `false` for those.
    for pid in killed_processes {
        remove_process(pid);
    }
}

// Unix-only: the test relies on `bash`, `pgrep`, and Unix file permissions, and the module
// imports `std::os::unix`, which does not exist on other targets.
#[cfg(all(test, unix))]
mod tests {
    use super::*;
    use std::os::unix::fs::PermissionsExt;
    use std::time::Duration;
    use std::{fs, thread};
    use sysinfo::{Pid, ProcessesToUpdate, System};
    use tokio::process::Command;

    /// Spawns a script that runs a long `sleep`, registers the parent PID in the store, and
    /// checks that `kill_processes` terminates both the parent and its child.
    #[tokio::test]
    async fn test_kill_processes_with_children() {
        // Create a temporary script that spawns a child process (`sleep`). The file name
        // includes our own PID so concurrent test runs do not overwrite each other's script.
        let temp_dir = std::env::temp_dir();
        let script_path = temp_dir.join(format!("test_script_{}.sh", std::process::id()));
        let script_content = r#"#!/bin/bash
# Sleep in the parent process
sleep 300
"#;

        fs::write(&script_path, script_content).unwrap();
        fs::set_permissions(&script_path, std::fs::Permissions::from_mode(0o755)).unwrap();

        // Start the parent process which will spawn a child
        let parent = Command::new("bash")
            .arg("-c")
            .arg(script_path.to_str().unwrap())
            .spawn()
            .expect("Failed to start parent process");

        // `id()` already yields `Option<u32>`; it is `Some` because the child has not been
        // awaited yet.
        let parent_pid = parent
            .id()
            .expect("freshly spawned child should still have a PID");

        // Store the parent process ID
        store_process(parent_pid);

        // Give processes time to start. A blocking sleep is tolerable here because the
        // test is the only task on its runtime.
        thread::sleep(Duration::from_secs(1));

        // Get the child process ID using pgrep
        let pgrep_output = Command::new("pgrep")
            .arg("-P")
            .arg(parent_pid.to_string())
            .output()
            .await
            .expect("Failed to get child PIDs");

        let child_pid_str = String::from_utf8_lossy(&pgrep_output.stdout);
        let child_pids: Vec<u32> = child_pid_str
            .lines()
            .filter_map(|line| line.trim().parse::<u32>().ok())
            .collect();
        assert_eq!(
            child_pids.len(),
            1,
            "expected exactly one child process, found {:?}",
            child_pids
        );
        let child_pid = child_pids[0];

        // Verify processes are running
        assert!(is_process_running(parent_pid).await);
        assert!(is_process_running(child_pid).await);

        kill_processes();

        // Wait (up to ~1 second) until processes are killed
        for _ in 0..10 {
            if !is_process_running(parent_pid).await && !is_process_running(child_pid).await {
                break;
            }
            thread::sleep(Duration::from_millis(100));
        }

        // Verify processes are dead
        assert!(!is_process_running(parent_pid).await);
        assert!(!is_process_running(child_pid).await);

        // Clean up the temporary script
        fs::remove_file(script_path).unwrap();
    }

    /// Returns `true` when `pid` is present in the process table and is not reported as
    /// stopped, zombie, or dead.
    async fn is_process_running(pid: u32) -> bool {
        let mut system = System::new_all();
        system.refresh_processes(ProcessesToUpdate::All, true);

        match system.process(Pid::from_u32(pid)) {
            Some(process) => !matches!(
                process.status(),
                sysinfo::ProcessStatus::Stop
                    | sysinfo::ProcessStatus::Zombie
                    | sysinfo::ProcessStatus::Dead
            ),
            None => false,
        }
    }
}

0 comments on commit d9e3000

Please sign in to comment.