Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: providers filter goose messages #637

Closed
wants to merge 8 commits into from
1 change: 1 addition & 0 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ async fn stream_message(
}
}
}
Role::Goose => (),
}
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion crates/goose/src/agents/system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub enum SystemError {
Initialization(SystemConfig),
#[error("Failed a client call to an MCP server: {0}")]
Client(#[from] ClientError),
#[error("Messages exceeded context-limit and could not be truncated to fit.")]
#[error("User message exceeded context-limit. History could not be truncated to accommodate.")]
ContextLimit,
#[error("Transport error: {0}")]
Transport(#[from] mcp_client::transport::Error),
Expand Down
51 changes: 23 additions & 28 deletions crates/goose/src/agents/truncate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,24 @@ use crate::providers::base::Provider;
use crate::providers::base::ProviderUsage;
use crate::register_agent;
use crate::token_counter::TokenCounter;
use mcp_core::Tool;
use mcp_core::{Role, Tool};
use serde_json::Value;

/// Agent impl. that truncates oldest messages when payload over LLM ctx-limit
pub struct TruncateAgent {
capabilities: Mutex<Capabilities>,
_token_counter: TokenCounter,
token_counter: TokenCounter,
}

impl TruncateAgent {
pub fn new(provider: Box<dyn Provider>) -> Self {
Self {
capabilities: Mutex::new(Capabilities::new(provider)),
_token_counter: TokenCounter::new(),
token_counter: TokenCounter::new(),
}
}

async fn prepare_inference(
async fn enforce_ctx_limit_pre_flight(
&self,
system_prompt: &str,
tools: &[Tool],
Expand All @@ -46,25 +46,22 @@ impl TruncateAgent {

let model = Some(model_name);
let approx_count =
self._token_counter
self.token_counter
.count_everything(system_prompt, messages, tools, &resources, model);

let mut new_messages = messages.to_vec();
if approx_count > target_limit {
let user_msg_size = self.text_content_size(new_messages.last(), model);
if user_msg_size > target_limit {
debug!(
"[WARNING] User message of {} tokens exceeds the token budget by {}.",
user_msg_size,
user_msg_size - target_limit
);
return Err(SystemError::ContextLimit);
}

new_messages = self.chop_front_messages(messages, approx_count, target_limit, model);

if new_messages.is_empty() {
return Err(SystemError::ContextLimit);
}

// add goose message
let alert_val = "Some of the oldest messages in the conversation history \
have been truncated to keep history within the LLM context-limit.";
let alert_msg = Message::goose().with_text(alert_val);
new_messages.push(alert_msg);
}

Ok(new_messages)
Expand All @@ -76,7 +73,7 @@ impl TruncateAgent {
.and_then(|c| c.as_text());

if let Some(txt) = text {
let count = self._token_counter.count_tokens(txt, model);
let count = self.token_counter.count_tokens(txt, model);
return count;
}

Expand Down Expand Up @@ -130,7 +127,7 @@ impl Agent for TruncateAgent {
let mut capabilities = self.capabilities.lock().await;
let tools = capabilities.get_prefixed_tools().await?;
let system_prompt = capabilities.get_system_prompt().await;
let _estimated_limit = capabilities
let estimated_limit = capabilities
.provider()
.get_model_config()
.get_estimated_limit();
Expand All @@ -146,11 +143,11 @@ impl Agent for TruncateAgent {

// Update conversation history for the start of the reply
let mut messages = self
.prepare_inference(
.enforce_ctx_limit_pre_flight(
&system_prompt,
&tools,
messages,
_estimated_limit,
estimated_limit,
&capabilities.provider().get_model_config().model_name,
&mut capabilities.get_resources().await?,
)
Expand All @@ -170,6 +167,13 @@ impl Agent for TruncateAgent {
// Yield the assistant's response
yield response.clone();

// If context-limit enforcement appended a goose truncation notice, yield it to the client
if let Some(msg) = messages.last() {
if msg.role == Role::Goose {
yield messages.last().unwrap().clone();
}
}

tokio::task::yield_now().await;

// First collect any tool requests
Expand Down Expand Up @@ -204,15 +208,6 @@ impl Agent for TruncateAgent {

yield message_tool_response.clone();

messages = self.prepare_inference(
&system_prompt,
&tools,
&messages,
_estimated_limit,
&capabilities.provider().get_model_config().model_name,
&mut capabilities.get_resources().await?
).await?;

messages.push(response);
messages.push(message_tool_response);
}
Expand Down
9 changes: 9 additions & 0 deletions crates/goose/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,15 @@ impl Message {
}
}

/// Create a new goose (internal) message with the current timestamp.
///
/// Goose messages carry system-generated notices (e.g. the
/// context-limit truncation alert) and are filtered out by the
/// provider conversion layers before requests are sent to the LLM.
pub fn goose() -> Self {
    Message {
        role: Role::Goose,
        created: Utc::now().timestamp(),
        content: Vec::new(),
    }
}

/// Add any MessageContent to the message
pub fn with_content(mut self, content: MessageContent) -> Self {
self.content.push(content);
Expand Down
1 change: 1 addition & 0 deletions crates/goose/src/providers/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ impl AnthropicProvider {
let role = match message.role {
Role::User => "user",
Role::Assistant => "assistant",
Role::Goose => continue,
};

let mut content = Vec::new();
Expand Down
2 changes: 2 additions & 0 deletions crates/goose/src/providers/google.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ impl GoogleProvider {
fn messages_to_google_spec(&self, messages: &[Message]) -> Vec<Value> {
messages
.iter()
.filter(|message| message.role != Role::Goose)
.map(|message| {
let role = if message.role == Role::User {
"user"
Expand Down Expand Up @@ -457,6 +458,7 @@ mod tests {
let messages = vec![
set_up_text_message("Hello", Role::User),
set_up_text_message("World", Role::Assistant),
set_up_text_message("Please don't notice me.", Role::Goose),
];
let payload = provider.messages_to_google_spec(&messages);
assert_eq!(payload.len(), 2);
Expand Down
11 changes: 9 additions & 2 deletions crates/goose/src/providers/openai_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ pub fn messages_to_openai_spec(
) -> Vec<Value> {
let mut messages_spec = Vec::new();
for message in messages {
if message.role == Role::Goose {
continue;
}

let mut converted = json!({
"role": message.role
});
Expand Down Expand Up @@ -378,8 +382,11 @@ mod tests {

#[test]
fn test_messages_to_openai_spec() -> anyhow::Result<()> {
let message = Message::user().with_text("Hello");
let spec = messages_to_openai_spec(&[message], &ImageFormat::OpenAi, false);
let messages = vec![
Message::user().with_text("Hello"),
Message::goose().with_text("Please don't notice me."),
];
let spec = messages_to_openai_spec(&messages, &ImageFormat::OpenAi, false);

assert_eq!(spec.len(), 1);
assert_eq!(spec[0]["role"], "user");
Expand Down
1 change: 1 addition & 0 deletions crates/mcp-core/src/role.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ use serde::{Deserialize, Serialize};
pub enum Role {
    /// A message authored by the human end-user.
    User,
    /// A message produced by the model/assistant.
    Assistant,
    /// An internal message injected by goose itself (e.g. a truncation
    /// notice); provider implementations skip these when building
    /// requests, so they are never sent to the LLM.
    Goose,
}
Loading