Skip to content

Commit

Permalink
cancellable bash commands during cli or goosed
Browse files Browse the repository at this point in the history
  • Loading branch information
jsibbison-square committed Nov 27, 2024
1 parent a306117 commit 0cc8fc8
Show file tree
Hide file tree
Showing 10 changed files with 551 additions and 326 deletions.
7 changes: 4 additions & 3 deletions crates/goose-cli/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use goose::{agent::Agent as GooseAgent, models::message::Message, systems::System};
use tokio::sync::watch;

#[async_trait]
pub trait Agent {
fn add_system(&mut self, system: Box<dyn System>);
async fn reply(&self, messages: &[Message]) -> Result<BoxStream<'_, Result<Message>>>;
async fn reply(&self, messages: &[Message], cancel_rx: watch::Receiver<bool>) -> Result<BoxStream<'_, Result<Message>>>;
}

#[async_trait]
Expand All @@ -15,7 +16,7 @@ impl Agent for GooseAgent {
self.add_system(system);
}

async fn reply(&self, messages: &[Message]) -> Result<BoxStream<'_, Result<Message>>> {
self.reply(messages).await
async fn reply(&self, messages: &[Message], cancel_rx: watch::Receiver<bool>) -> Result<BoxStream<'_, Result<Message>>> {
self.reply(messages, cancel_rx).await
}
}
3 changes: 2 additions & 1 deletion crates/goose-cli/src/agents/mock_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use goose::{models::message::Message, systems::System};
use tokio::sync::watch;

use crate::agents::agent::Agent;

Expand All @@ -13,7 +14,7 @@ impl Agent for MockAgent {
()
}

async fn reply(&self, _messages: &[Message]) -> Result<BoxStream<'_, Result<Message>>> {
async fn reply(&self, _messages: &[Message], _cancel_rx: watch::Receiver<bool>) -> Result<BoxStream<'_, Result<Message>>> {
Ok(Box::pin(futures::stream::empty()))
}
}
4 changes: 3 additions & 1 deletion crates/goose-cli/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ impl<'a> Session<'a> {
}

async fn agent_process_messages(&mut self) {
let mut stream = match self.agent.reply(&self.messages).await {
let (cancel_tx, cancel_rx) = tokio::sync::watch::channel(false);
let mut stream = match self.agent.reply(&self.messages, cancel_rx).await {
Ok(stream) => stream,
Err(e) => {
eprintln!("Error starting reply stream: {}", e);
Expand Down Expand Up @@ -166,6 +167,7 @@ impl<'a> Session<'a> {
}
_ = tokio::signal::ctrl_c() => {
drop(stream);
cancel_tx.send(true).unwrap();
self.rewind_messages();
self.prompt.render(raw_message(" Interrupt: Resetting conversation to before the last sent message...\n"));
break;
Expand Down
18 changes: 15 additions & 3 deletions crates/goose-cli/src/systems/goose_hints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ use anyhow::Result as AnyhowResult;
use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::fs;

use goose::errors::{AgentError, AgentResult};
use goose::models::content::Content;
use goose::models::tool::Tool;
use goose::models::tool::ToolCall;
use goose::systems::System;
use goose::systems::{System, CancellableOperation, CancelFn};

#[derive(Clone)]
pub struct GooseHintsSystem {
Expand Down Expand Up @@ -74,8 +75,19 @@ impl System for GooseHintsSystem {
Ok(HashMap::new())
}

async fn call(&self, tool_call: ToolCall) -> AgentResult<Vec<Content>> {
Err(AgentError::ToolNotFound(tool_call.name))
async fn call(&self, tool_call: ToolCall) -> CancellableOperation {
// No-op cancel function since this system doesn't create long-running processes
let cancel_fn: CancelFn = Arc::new(|| {});

// Create the future that will execute the tool call
let future = Box::pin(async move {
Err(AgentError::ToolNotFound(tool_call.name))
});

CancellableOperation {
cancel: cancel_fn,
future,
}
}
}

Expand Down
50 changes: 39 additions & 11 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ use serde_json::{json, Value};
use std::{
convert::Infallible,
pin::Pin,
task::{Context, Poll},
task::{Context, Poll}, time::Duration,
};
use tokio::time::{sleep, timeout};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;

Expand Down Expand Up @@ -280,9 +281,10 @@ async fn handler(
// Convert incoming messages
let messages = convert_messages(request.messages);

let (cancel_tx, cancel_rx) = tokio::sync::watch::channel(false);
// Spawn task to handle streaming
tokio::spawn(async move {
let mut stream = match agent.reply(&messages).await {
let mut stream = match agent.reply(&messages, cancel_rx).await {
Ok(stream) => stream,
Err(e) => {
tracing::error!("Failed to start reply stream: {}", e);
Expand All @@ -292,16 +294,41 @@ async fn handler(
}
};

while let Some(response) = stream.next().await {
match response {
Ok(message) => {
if let Err(e) = stream_message(message, &tx).await {
tracing::error!("Error sending message through channel: {}", e);
break;
// Create a once-off timer
let timer = sleep(Duration::from_secs(5));
tokio::pin!(timer); // Pin the timer so it can be used in `tokio::select!`
loop {
tokio::select! {
response = timeout(Duration::from_millis(500), stream.next()) => {
match response {
Ok(Some(Ok(message))) => {
let tx_clone = tx.clone();
tokio::spawn(async move {
if let Err(e) = stream_message(message, &tx_clone).await {
tracing::error!("Error sending message through channel: {}", e);
}
});
tracing::info!("Message sent.");
}
Ok(Some(Err(e))) => {
tracing::error!("Error processing message: {}", e);
break;
}
Ok(None) => {
tracing::info!("Stream ended.");
break;
},
Err(_) => {
tracing::warn!("stream check timeout");
},
}
}
Err(e) => {
tracing::error!("Error processing message: {}", e);
_ = &mut timer => {
println!("S Timeout!!!");
cancel_tx.send(true).unwrap();
drop(stream);
tracing::warn!("Timeout reached while waiting for the next message.");
println!("E Timeout!!!.");
break;
}
}
Expand Down Expand Up @@ -337,9 +364,10 @@ async fn ask_handler(
// Create a single message for the prompt
let messages = vec![Message::user().with_text(request.prompt)];

let (_cancel_tx, cancel_rx) = tokio::sync::watch::channel(false);
// Get response from agent
let mut response_text = String::new();
let mut stream = match agent.reply(&messages).await {
let mut stream = match agent.reply(&messages, cancel_rx).await {
Ok(stream) => stream,
Err(e) => {
tracing::error!("Failed to start reply stream: {}", e);
Expand Down
Loading

0 comments on commit 0cc8fc8

Please sign in to comment.