Skip to content

Commit

Permalink
fix: handle OpenAI API errors better (#1291)
Browse files Browse the repository at this point in the history
  • Loading branch information
akx authored Feb 23, 2025
1 parent 9693b40 commit 45b3812
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 22 deletions.
47 changes: 47 additions & 0 deletions crates/goose/src/providers/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,50 @@ impl GoogleErrorCode {
}
}
}

#[derive(serde::Deserialize)]
pub struct OpenAIError {
pub code: Option<String>,
pub message: Option<String>,
#[serde(rename = "type")]
pub error_type: Option<String>,
}

impl OpenAIError {
pub fn is_context_length_exceeded(&self) -> bool {
if let Some(code) = &self.code {
code == "context_length_exceeded" || code == "string_above_max_length"
} else {
false
}
}
}

impl std::fmt::Display for OpenAIError {
/// Format the error for display.
/// E.g. {"message": "Invalid API key", "code": "invalid_api_key", "type": "client_error"}
/// would be formatted as "Invalid API key (code: invalid_api_key, type: client_error)"
/// and {"message": "Foo"} as just "Foo", etc.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(message) = &self.message {
write!(f, "{}", message)?;
}
let mut in_parenthesis = false;
if let Some(code) = &self.code {
write!(f, " (code: {}", code)?;
in_parenthesis = true;
}
if let Some(typ) = &self.error_type {
if in_parenthesis {
write!(f, ", type: {}", typ)?;
} else {
write!(f, " (type: {}", typ)?;
in_parenthesis = true;
}
}
if in_parenthesis {
write!(f, ")")?;
}
Ok(())
}
}
38 changes: 16 additions & 22 deletions crates/goose/src/providers/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@ use base64::Engine;
use regex::Regex;
use reqwest::{Response, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::{json, Map, Value};
use serde_json::{from_value, json, Map, Value};
use std::io::Read;
use std::path::Path;

use crate::providers::errors::ProviderError;
use crate::providers::errors::{OpenAIError, ProviderError};
use mcp_core::content::ImageContent;

#[derive(serde::Deserialize)]
struct OpenAIErrorResponse {
error: OpenAIError,
}

#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
pub enum ImageFormat {
OpenAi,
Expand Down Expand Up @@ -55,29 +60,18 @@ pub async fn handle_response_openai_compat(response: Response) -> Result<Value,
Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \
Status: {}. Response: {:?}", status, payload)))
}
StatusCode::BAD_REQUEST => {
let mut message = "Unknown error".to_string();
if let Some(error) = payload.get("error") {
tracing::debug!("Bad Request Error: {error:?}");
message = error
.get("message")
.and_then(|m| m.as_str())
.unwrap_or("Unknown error")
.to_string();

if let Some(code) = error.get("code").and_then(|c| c.as_str()) {
if code == "context_length_exceeded" || code == "string_above_max_length" {
return Err(ProviderError::ContextLengthExceeded(message));
}
}
}
StatusCode::BAD_REQUEST | StatusCode::NOT_FOUND => {
tracing::debug!(
"{}", format!("Provider request failed with status: {}. Payload: {:?}", status, payload)
);
Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", status, message)))
}
StatusCode::NOT_FOUND => {
Err(ProviderError::RequestFailed(format!("{:?}", payload)))
if let Ok(err_resp) = from_value::<OpenAIErrorResponse>(payload) {
let err = err_resp.error;
if err.is_context_length_exceeded() {
return Err(ProviderError::ContextLengthExceeded(err.message.unwrap_or("Unknown error".to_string())));
}
return Err(ProviderError::RequestFailed(format!("{} (status {})", err, status.as_u16())));
}
Err(ProviderError::RequestFailed(format!("Unknown error (status {})", status)))
}
StatusCode::TOO_MANY_REQUESTS => {
Err(ProviderError::RateLimitExceeded(format!("{:?}", payload)))
Expand Down

0 comments on commit 45b3812

Please sign in to comment.