Skip to content

Commit

Permalink
fix: clean up providers (#650)
Browse files Browse the repository at this point in the history
  • Loading branch information
baxen authored Jan 20, 2025
1 parent 67c9c75 commit f8a577c
Show file tree
Hide file tree
Showing 24 changed files with 1,412 additions and 2,873 deletions.
31 changes: 4 additions & 27 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,6 @@ jobs:
sudo apt update -y
sudo apt install -y libdbus-1-dev gnome-keyring libxcb1-dev
- name: Start gnome-keyring
# run gnome-keyring with 'foobar' as password for the login keyring
# this will create a new login keyring and unlock it
# the login password doesn't matter, but the keyring must be unlocked for the tests to work
run: |
gnome-keyring-daemon --components=secrets --daemonize --unlock <<< 'foobar'
- name: Setup Rust
uses: dtolnay/rust-toolchain@stable
with:
Expand Down Expand Up @@ -76,27 +69,11 @@ jobs:
restore-keys: |
${{ runner.os }}-cargo-build-
- name: Install Ollama
run: curl -fsSL https://ollama.com/install.sh | sh

- name: Start Ollama
- name: Build and Test
run: |
# Run the background, in a way that survives to the next step
nohup ollama serve > ollama.log 2>&1 &
# Block using the ready endpoint
time curl --retry 5 --retry-connrefused --retry-delay 1 -sf http://localhost:11434
- name: Test Ollama Model
run: ollama run qwen2.5 hello || cat ollama.log

- name: Build Rust Project
run: cargo build

- name: Run Tests
run: cargo test --verbose
env:
OLLAMA_MODEL: "qwen2.5"

gnome-keyring-daemon --components=secrets --daemonize --unlock <<< 'foobar'
cargo test
## TODO: Need to decide if we wanna error out on clippy warnings. It was not being used before.
# - name: Run Cargo Clippy (Lint)
# run: cargo clippy -- -D warnings
Expand Down
3 changes: 0 additions & 3 deletions crates/goose-cli/src/log_usage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ pub fn log_usage(session_file: String, usage: Vec<ProviderUsage>) {
#[cfg(test)]
mod tests {
use goose::providers::base::{ProviderUsage, Usage};
use rust_decimal_macros::dec;

use crate::{
log_usage::{log_usage, SessionLog},
Expand All @@ -72,7 +71,6 @@ mod tests {
vec![ProviderUsage::new(
"model".to_string(),
Usage::new(Some(10), Some(20), Some(30)),
Some(dec!(0.5)),
)],
);

Expand All @@ -87,7 +85,6 @@ mod tests {
assert_eq!(log.usage[0].usage.output_tokens, Some(20));
assert_eq!(log.usage[0].usage.total_tokens, Some(30));
assert_eq!(log.usage[0].model, "model");
assert_eq!(log.usage[0].cost, Some(dec!(0.5)));

// Remove the log file after test
std::fs::remove_file(&log_file).ok();
Expand Down
80 changes: 14 additions & 66 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use axum::{
use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use goose::message::{Message, MessageContent};
use goose::providers::base::ModerationError;

use mcp_core::{content::Content, role::Role};
use serde::Deserialize;
use serde_json::{json, Value};
Expand Down Expand Up @@ -166,18 +166,6 @@ impl ProtocolFormatter {
format!("3:{}\n", encoded_error)
}

fn format_moderation_error(error: &ModerationError) -> String {
let error_part = match error {
ModerationError::ContentFlagged { categories, .. } => {
format!(
"Content was flagged by moderation in the following categories: {}",
categories
)
}
};
format!("3:\"{}\"\n", error_part)
}

fn format_finish(reason: &str) -> String {
// Finish messages start with "d:"
let finish = json!({
Expand Down Expand Up @@ -324,19 +312,9 @@ async fn handler(
Ok(stream) => stream,
Err(e) => {
tracing::error!("Failed to start reply stream: {}", e);
// Check if it's a moderation error
if let Some(moderation_error) = e.downcast_ref::<ModerationError>() {
let _ = tx
.send(ProtocolFormatter::format_moderation_error(moderation_error))
.await;
// Kill the stream since we encountered a moderation error
} else {
// Send a generic error message
let _ = tx
.send(ProtocolFormatter::format_error(&e.to_string()))
.await;
}
// Send a finish message with error as the reason
let _ = tx
.send(ProtocolFormatter::format_error(&e.to_string()))
.await;
let _ = tx.send(ProtocolFormatter::format_finish("error")).await;
return;
}
Expand All @@ -355,12 +333,7 @@ async fn handler(
}
Ok(Some(Err(e))) => {
tracing::error!("Error processing message: {}", e);
// Check if it's a moderation error
if let Some(moderation_error) = e.downcast_ref::<ModerationError>() {
let _ = tx.send(ProtocolFormatter::format_moderation_error(moderation_error)).await;
} else {
let _ = tx.send(ProtocolFormatter::format_error(&e.to_string())).await;
}
let _ = tx.send(ProtocolFormatter::format_error(&e.to_string())).await;
break;
}
Ok(None) => {
Expand Down Expand Up @@ -467,7 +440,7 @@ mod tests {
use super::*;
use goose::{
agents::DefaultAgent as Agent,
providers::base::{Moderation, ModerationResult, Provider, ProviderUsage, Usage},
providers::base::{Provider, ProviderUsage, Usage},
providers::configs::ModelConfig,
};
use mcp_core::tool::Tool;
Expand All @@ -480,37 +453,27 @@ mod tests {

#[async_trait::async_trait]
impl Provider for MockProvider {
async fn complete_internal(
fn get_model_config(&self) -> &ModelConfig {
&self.model_config
}

async fn complete(
&self,
_system_prompt: &str,
_system: &str,
_messages: &[Message],
_tools: &[Tool],
) -> Result<(Message, ProviderUsage), anyhow::Error> {
) -> anyhow::Result<(Message, ProviderUsage)> {
Ok((
Message::assistant().with_text("Mock response"),
ProviderUsage::new("mock".to_string(), Usage::default(), None),
ProviderUsage::new("mock".to_string(), Usage::default()),
))
}

fn get_model_config(&self) -> &ModelConfig {
&self.model_config
}

fn get_usage(&self, _data: &Value) -> anyhow::Result<Usage> {
Ok(Usage::new(None, None, None))
}
}

#[async_trait::async_trait]
impl Moderation for MockProvider {
async fn moderate_content(
&self,
_content: &str,
) -> Result<ModerationResult, anyhow::Error> {
Ok(ModerationResult::new(false, None, None))
}
}

#[test]
fn test_convert_messages_user_only() {
let incoming = vec![IncomingMessage {
Expand Down Expand Up @@ -584,21 +547,6 @@ mod tests {
assert!(formatted.starts_with("3:"));
assert!(formatted.contains("Test error"));

// Test moderation error formatting
let moderation_error = ModerationError::ContentFlagged {
categories: "hate, violence".to_string(),
category_scores: Some(json!({
"hate": 0.9,
"violence": 0.8
})),
};
let formatted = ProtocolFormatter::format_moderation_error(&moderation_error);
println!("{}", formatted);
assert!(formatted.starts_with("3:"));
assert!(
formatted.contains("Content was flagged by moderation in the following categories:")
);

// Test finish formatting
let formatted = ProtocolFormatter::format_finish("stop");
assert!(formatted.starts_with("d:"));
Expand Down
9 changes: 0 additions & 9 deletions crates/goose/.env.example

This file was deleted.

2 changes: 1 addition & 1 deletion crates/goose/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,4 @@ path = "examples/databricks_oauth.rs"

[[bench]]
name = "tokenization_benchmark"
harness = false
harness = false
6 changes: 0 additions & 6 deletions crates/goose/src/agents/capabilities.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use chrono::{DateTime, TimeZone, Utc};
use mcp_client::McpService;
use rust_decimal_macros::dec;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::LazyLock;
Expand Down Expand Up @@ -193,11 +192,6 @@ impl Capabilities {
e.usage.total_tokens = Some(
e.usage.total_tokens.unwrap_or(0) + usage.usage.total_tokens.unwrap_or(0),
);
if e.cost.is_none() || usage.cost.is_none() {
e.cost = None; // Pricing is not available for all models
} else {
e.cost = Some(e.cost.unwrap_or(dec!(0)) + usage.cost.unwrap_or(dec!(0)));
}
})
.or_insert_with(|| usage.clone());
});
Expand Down
Loading

0 comments on commit f8a577c

Please sign in to comment.