diff --git a/crates/goose/src/providers/errors.rs b/crates/goose/src/providers/errors.rs index 1536107d9..fb8bf4d4a 100644 --- a/crates/goose/src/providers/errors.rs +++ b/crates/goose/src/providers/errors.rs @@ -74,3 +74,50 @@ impl GoogleErrorCode { } } } + +#[derive(serde::Deserialize)] +pub struct OpenAIError { + pub code: Option, + pub message: Option, + #[serde(rename = "type")] + pub error_type: Option, +} + +impl OpenAIError { + pub fn is_context_length_exceeded(&self) -> bool { + if let Some(code) = &self.code { + code == "context_length_exceeded" || code == "string_above_max_length" + } else { + false + } + } +} + +impl std::fmt::Display for OpenAIError { + /// Format the error for display. + /// E.g. {"message": "Invalid API key", "code": "invalid_api_key", "type": "client_error"} + /// would be formatted as "Invalid API key (code: invalid_api_key, type: client_error)" + /// and {"message": "Foo"} as just "Foo", etc. + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(message) = &self.message { + write!(f, "{}", message)?; + } + let mut in_parenthesis = false; + if let Some(code) = &self.code { + write!(f, " (code: {}", code)?; + in_parenthesis = true; + } + if let Some(typ) = &self.error_type { + if in_parenthesis { + write!(f, ", type: {}", typ)?; + } else { + write!(f, " (type: {}", typ)?; + in_parenthesis = true; + } + } + if in_parenthesis { + write!(f, ")")?; + } + Ok(()) + } +} diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index 65502617f..ef652b635 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -5,13 +5,18 @@ use base64::Engine; use regex::Regex; use reqwest::{Response, StatusCode}; use serde::{Deserialize, Serialize}; -use serde_json::{json, Map, Value}; +use serde_json::{from_value, json, Map, Value}; use std::io::Read; use std::path::Path; -use crate::providers::errors::ProviderError; +use crate::providers::errors::{OpenAIError, ProviderError}; use mcp_core::content::ImageContent; +#[derive(serde::Deserialize)] +struct OpenAIErrorResponse { + error: OpenAIError, +} + #[derive(Debug, Copy, Clone, Serialize, Deserialize)] pub enum ImageFormat { OpenAi, @@ -55,29 +60,18 @@ pub async fn handle_response_openai_compat(response: Response) -> Result { - let mut message = "Unknown error".to_string(); - if let Some(error) = payload.get("error") { - tracing::debug!("Bad Request Error: {error:?}"); - message = error - .get("message") - .and_then(|m| m.as_str()) - .unwrap_or("Unknown error") - .to_string(); - - if let Some(code) = error.get("code").and_then(|c| c.as_str()) { - if code == "context_length_exceeded" || code == "string_above_max_length" { - return Err(ProviderError::ContextLengthExceeded(message)); - } - } - } + StatusCode::BAD_REQUEST | StatusCode::NOT_FOUND => { tracing::debug!( "{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload) ); - Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, message))) - } - StatusCode::NOT_FOUND => { - Err(ProviderError::RequestFailed(format!("{:?}", payload))) + if let Ok(err_resp) = from_value::(payload) { + let err = err_resp.error; + if err.is_context_length_exceeded() { + return Err(ProviderError::ContextLengthExceeded(err.message.unwrap_or("Unknown error".to_string()))); + } + return Err(ProviderError::RequestFailed(format!("{} (status {})", err, status.as_u16()))); + } + Err(ProviderError::RequestFailed(format!("Unknown error (status {})", status))) } StatusCode::TOO_MANY_REQUESTS => { Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))