Merge pull request #859 from dcSpark/nico/fix_errors_2
Add OpenAI reasoning support
nicarq authored Feb 9, 2025
2 parents c6905ca + e0a86e9 commit 8173a1c
Showing 7 changed files with 228 additions and 64 deletions.
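In short, this commit teaches the OpenAI provider to recognize reasoning models (OpenAI's o1/o3-style models and Ollama's deepseek-r1), send max_completion_tokens instead of max_tokens, skip unsupported options, drop system prompts, and surface inference-chain errors as regular responses. Below is a minimal sketch of the payload selection, condensed from the openai.rs changes in this diff; the is_reasoning_model flag stands in for ModelCapabilitiesManager::has_reasoning_capabilities and the function name is illustrative only.

use serde_json::{json, Value};

// Condensed sketch of the payload logic added in openai.rs: reasoning models
// take `max_completion_tokens` and skip the extra sampling options; other
// models keep the existing `max_tokens` field.
fn build_openai_payload(
    model: &str,
    messages: Value,
    max_output_tokens: u64,
    stream: bool,
    is_reasoning_model: bool,
) -> Value {
    if is_reasoning_model {
        json!({
            "model": model,
            "messages": messages,
            "max_completion_tokens": max_output_tokens,
            "stream": stream
        })
    } else {
        json!({
            "model": model,
            "messages": messages,
            "max_tokens": max_output_tokens,
            "stream": stream
        })
    }
}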
@@ -1,4 +1,5 @@
use crate::llm_provider::error::LLMProviderError;
use crate::llm_provider::execution::chains::inference_chain_trait::InferenceChainResult;
use crate::llm_provider::job_callback_manager::JobCallbackManager;
use crate::llm_provider::job_manager::JobManager;
use crate::llm_provider::llm_stopper::LLMStopper;
@@ -239,7 +240,7 @@ impl JobManager {
let start = Instant::now();

// Call the inference chain router to choose which chain to use, and call it
let inference_response = JobManager::inference_chain_router(
let (inference_response, inference_response_content) = match JobManager::inference_chain_router(
db.clone(),
llm_provider_found,
full_job,
Expand All @@ -257,8 +258,21 @@ impl JobManager {
// sqlite_logger.clone(),
llm_stopper.clone(),
)
.await?;
let inference_response_content = inference_response.response.clone();
.await
{
Ok(response) => (response.clone(), response.response),
Err(e) => {
let error_message = format!("{}", e);
// Create a minimal inference response with the error message
let error_response = InferenceChainResult {
response: error_message.clone(),
tps: None,
answer_duration: None,
tool_calls: None,
};
(error_response, error_message)
}
};
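// Net effect of the hunk above (note, not part of the committed source): an
// error from JobManager::inference_chain_router no longer bubbles up via `?`;
// its message is wrapped into a minimal InferenceChainResult so the job still
// produces a response the user can see.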

let duration = start.elapsed();
shinkai_log(
@@ -42,8 +42,8 @@ impl JobManager {
let response = task_response?;
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Debug,
format!("inference_llm_provider_markdown> response: {:?}", response).as_str(),
ShinkaiLogLevel::Info,
format!("inference_with_llm_provider> response: {:?}", response).as_str(),
);

response
61 changes: 51 additions & 10 deletions shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs
@@ -6,7 +6,7 @@ use super::shared::openai_api::{openai_prepare_messages, MessageContent, OpenAIR
use super::LLMService;
use crate::llm_provider::execution::chains::inference_chain_trait::{FunctionCall, LLMInferenceResponse};
use crate::llm_provider::llm_stopper::LLMStopper;
use crate::managers::model_capabilities_manager::PromptResultEnum;
use crate::managers::model_capabilities_manager::{ModelCapabilitiesManager, PromptResultEnum};
use async_trait::async_trait;
use futures::StreamExt;
use reqwest::Client;
@@ -104,20 +104,32 @@ impl LLMService for OpenAI {
Err(e) => eprintln!("Failed to serialize tools_json: {:?}", e),
};

let mut payload = json!({
"model": self.model_type,
"messages": messages_json,
"max_tokens": result.remaining_output_tokens,
"stream": is_stream,
});
// Set up initial payload with appropriate token limit field based on model capabilities
let mut payload = if ModelCapabilitiesManager::has_reasoning_capabilities(&model) {
json!({
"model": self.model_type,
"messages": messages_json,
"max_completion_tokens": result.remaining_output_tokens,
"stream": is_stream,
})
} else {
json!({
"model": self.model_type,
"messages": messages_json,
"max_tokens": result.remaining_output_tokens,
"stream": is_stream,
})
};

// Conditionally add functions to the payload if tools_json is not empty
if !tools_json.is_empty() {
payload["functions"] = serde_json::Value::Array(tools_json.clone());
}

// Add options to payload
add_options_to_payload(&mut payload, config.as_ref());
// Only add options to payload for non-reasoning models
if !ModelCapabilitiesManager::has_reasoning_capabilities(&model) {
add_options_to_payload(&mut payload, config.as_ref());
}
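// Note (not in the committed diff): OpenAI's reasoning models reject the usual
// sampling options such as temperature and top_p, which is why the config
// options above are only merged in for non-reasoning models.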

// Print payload as a pretty JSON string
match serde_json::to_string_pretty(&payload) {
@@ -248,6 +260,35 @@ pub async fn parse_openai_stream_chunk(
inbox_name: Option<InboxName>,
session_id: &str,
) -> Result<Option<String>, LLMProviderError> {
// If the buffer starts with '{', assume we might be receiving a JSON error.
if buffer.trim_start().starts_with('{') {
match serde_json::from_str::<JsonValue>(buffer) {
Ok(json_data) => {
// If it has an "error" field, record that and return immediately.
if let Some(error_obj) = json_data.get("error") {
let code = error_obj
.get("code")
.and_then(|c| c.as_str())
.unwrap_or("Unknown code")
.to_string();
let msg = error_obj
.get("message")
.and_then(|m| m.as_str())
.unwrap_or("Unknown error");
// Clear the buffer since we've consumed it
buffer.clear();
return Ok(Some(format!("{}: {}", code, msg)));
}
// Once parsed, clear the buffer since we've consumed it.
buffer.clear();
}
Err(_) => {
// It's not yet valid JSON (partial) - keep the buffer and wait for more data
return Ok(None);
}
}
}
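// Illustrative example (not from the committed diff) of the kind of
// non-streaming error body the block above is meant to catch before normal
// SSE parsing:
//   {"error": {"code": "invalid_api_key", "message": "Incorrect API key provided."}}
// Partial JSON is kept in the buffer until more bytes arrive.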

let mut error_message: Option<String> = None;

loop {
@@ -571,7 +612,7 @@ pub async fn handle_streaming_response(
}

// Handle WebSocket updates for function calls
if let Some(ref manager) = ws_manager_trait {
if let Some(ref _manager) = ws_manager_trait {
if let Some(ref inbox_name) = inbox_name {
if let Some(last_function_call) = function_calls.last() {
send_tool_ws_update(&ws_manager_trait, Some(inbox_name.clone()), last_function_call)
@@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize};
use serde_json::{self};
use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::LLMProviderInterface;
use shinkai_message_primitives::schemas::prompts::Prompt;
use shinkai_message_primitives::schemas::subprompts::{SubPrompt, SubPromptType};

use super::shared_model_logic;

@@ -106,6 +107,17 @@ pub struct Usage {
}

pub fn openai_prepare_messages(model: &LLMProviderInterface, prompt: Prompt) -> Result<PromptResult, LLMProviderError> {
let mut prompt = prompt.clone();

// If this is a reasoning model, filter out system prompts before any processing
if ModelCapabilitiesManager::has_reasoning_capabilities(model) {
prompt.sub_prompts.retain(|sp| match sp {
SubPrompt::Content(SubPromptType::System, _, _) => false,
SubPrompt::Omni(SubPromptType::System, _, _, _) => false,
_ => true,
});
}

let max_input_tokens = ModelCapabilitiesManager::get_max_input_tokens(model);

// Generate the messages and filter out images
Expand All @@ -125,25 +137,35 @@ pub fn openai_prepare_messages(model: &LLMProviderInterface, prompt: Prompt) ->
.into_iter()
.partition(|message| message.role.is_some());


// Convert both sets of messages to serde Value
let messages_json = serde_json::to_value(messages_with_role)?;

// Convert tools to serde Value with name transformation
let tools_json = serde_json::to_value(
tools.clone().into_iter().map(|mut tool| {
if let Some(functions) = tool.functions.as_mut() {
for function in functions {
// Replace any characters that aren't alphanumeric, underscore, or hyphen
function.name = function.name
.chars()
.map(|c| if c.is_alphanumeric() || c == '_' || c == '-' { c } else { '_' })
.collect::<String>()
.to_lowercase();
tools
.clone()
.into_iter()
.map(|mut tool| {
if let Some(functions) = tool.functions.as_mut() {
for function in functions {
// Replace any characters that aren't alphanumeric, underscore, or hyphen
function.name = function
.name
.chars()
.map(|c| {
if c.is_alphanumeric() || c == '_' || c == '-' {
c
} else {
'_'
}
})
.collect::<String>()
.to_lowercase();
}
}
}
tool
}).collect::<Vec<_>>()
tool
})
.collect::<Vec<_>>(),
)?;

// Convert messages_json and tools_json to Vec<serde_json::Value>
@@ -211,7 +233,10 @@ pub fn openai_prepare_messages(model: &LLMProviderInterface, prompt: Prompt) ->
})
}

pub fn openai_prepare_messages_gemini(model: &LLMProviderInterface, prompt: Prompt) -> Result<PromptResult, LLMProviderError> {
pub fn openai_prepare_messages_gemini(
model: &LLMProviderInterface,
prompt: Prompt,
) -> Result<PromptResult, LLMProviderError> {
let max_input_tokens = ModelCapabilitiesManager::get_max_input_tokens(model);

// Generate the messages and filter out images
@@ -525,7 +550,10 @@ mod tests {
let response: OpenAIResponse = serde_json::from_str(response_text).expect("Failed to deserialize");

// Verify basic response fields
assert_eq!(response.id.clone().unwrap(), "chatcmpl-0cae310a-2b36-470a-9261-0f24d77b01bc");
assert_eq!(
response.id.clone().unwrap(),
"chatcmpl-0cae310a-2b36-470a-9261-0f24d77b01bc"
);
assert_eq!(response.object, "chat.completion");
assert_eq!(response.created, 1736736692);
assert_eq!(response.system_fingerprint, Some("fp_9cb648b966".to_string()));
Expand All @@ -540,10 +568,10 @@ mod tests {
let message = &choice.message;
assert_eq!(message.role, "assistant");
assert!(message.content.is_none());

let tool_calls = message.tool_calls.as_ref().expect("Should have tool_calls");
assert_eq!(tool_calls.len(), 1);

let tool_call = &tool_calls[0];
assert_eq!(tool_call.id, "call_sa3n");
assert_eq!(tool_call.call_type, "function");
Expand All @@ -563,4 +591,47 @@ mod tests {
let groq = response.groq.expect("Should have Groq info");
assert_eq!(groq.id, "req_01jhes5nvkedsb8hcw0x912fa6");
}

#[test]
fn test_system_prompt_filtering() {
// Create a prompt with both Content and Omni system prompts
let sub_prompts = vec![
SubPrompt::Content(
SubPromptType::System,
"System prompt that should be filtered".to_string(),
98,
),
SubPrompt::Content(SubPromptType::User, "User message that should remain".to_string(), 100),
SubPrompt::Omni(
SubPromptType::UserLastMessage,
"Last user message that should remain".to_string(),
vec![],
100,
),
];

let mut prompt = Prompt::new();
prompt.add_sub_prompts(sub_prompts);

// Create a mock model with reasoning capabilities
let model = SerializedLLMProvider::mock_provider_with_reasoning().model;

// Process the prompt
let result = openai_prepare_messages(&model, prompt).expect("Failed to prepare messages");

// Extract the messages from the result
let messages = match &result.messages {
PromptResultEnum::Value(value) => value.as_array().unwrap(),
_ => panic!("Expected Value variant"),
};

// Verify that only non-system messages remain
assert_eq!(messages.len(), 2, "Should only have 2 messages after filtering");

// Check that the remaining messages are the user messages
for message in messages {
let role = message["role"].as_str().unwrap();
assert_eq!(role, "user", "All remaining messages should be user messages");
}
}
}
34 changes: 21 additions & 13 deletions shinkai-bin/shinkai-node/src/managers/model_capabilities_manager.rs
@@ -1,23 +1,16 @@
use crate::llm_provider::{
error::LLMProviderError,
providers::shared::{openai_api::openai_prepare_messages, shared_model_logic::llama_prepare_messages},
error::LLMProviderError, providers::shared::{openai_api::openai_prepare_messages, shared_model_logic::llama_prepare_messages}
};
use shinkai_message_primitives::{
schemas::{
llm_message::LlmMessage,
llm_providers::{
common_agent_llm_provider::ProviderOrAgent,
serialized_llm_provider::{LLMProviderInterface, SerializedLLMProvider},
},
prompts::Prompt,
shinkai_name::ShinkaiName,
},
shinkai_utils::utils::count_tokens_from_message_llama3,
llm_message::LlmMessage, llm_providers::{
common_agent_llm_provider::ProviderOrAgent, serialized_llm_provider::{LLMProviderInterface, SerializedLLMProvider}
}, prompts::Prompt, shinkai_name::ShinkaiName
}, shinkai_utils::utils::count_tokens_from_message_llama3
};
use shinkai_sqlite::SqliteManager;
use std::{
fmt,
sync::{Arc, Weak},
fmt, sync::{Arc, Weak}
};

#[derive(Debug)]
@@ -726,6 +719,21 @@ impl ModelCapabilitiesManager {
_ => false,
}
}

/// Returns whether the given model has reasoning capabilities
pub fn has_reasoning_capabilities(model: &LLMProviderInterface) -> bool {
match model {
LLMProviderInterface::OpenAI(openai) => {
openai.model_type.starts_with("o1")
|| openai.model_type.starts_with("o2")
|| openai.model_type.starts_with("o3")
|| openai.model_type.starts_with("o4")
|| openai.model_type.starts_with("o5")
}
LLMProviderInterface::Ollama(ollama) => ollama.model_type.starts_with("deepseek-r1"),
_ => false,
}
}
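// Illustrative examples (not part of the committed diff): "o1-mini" and
// "o3-mini" match the OpenAI prefixes above, Ollama's "deepseek-r1:8b" matches
// the deepseek-r1 prefix, while "gpt-4o" matches none of them and is treated
// as a non-reasoning model.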
}

// TODO: add a tokenizer library only in the dev env and test that the