add openai reasoning support #859

Merged
merged 3 commits on Feb 9, 2025
@@ -1,4 +1,5 @@
use crate::llm_provider::error::LLMProviderError;
use crate::llm_provider::execution::chains::inference_chain_trait::InferenceChainResult;
use crate::llm_provider::job_callback_manager::JobCallbackManager;
use crate::llm_provider::job_manager::JobManager;
use crate::llm_provider::llm_stopper::LLMStopper;
@@ -239,7 +240,7 @@ impl JobManager {
let start = Instant::now();

// Call the inference chain router to choose which chain to use, and call it
let inference_response = JobManager::inference_chain_router(
let (inference_response, inference_response_content) = match JobManager::inference_chain_router(
db.clone(),
llm_provider_found,
full_job,
@@ -257,8 +258,21 @@ impl JobManager {
// sqlite_logger.clone(),
llm_stopper.clone(),
)
.await?;
let inference_response_content = inference_response.response.clone();
.await
{
Ok(response) => (response.clone(), response.response),
Err(e) => {
let error_message = format!("{}", e);
// Create a minimal inference response with the error message
let error_response = InferenceChainResult {
response: error_message.clone(),
tps: None,
answer_duration: None,
tool_calls: None,
};
(error_response, error_message)
}
};

let duration = start.elapsed();
shinkai_log(
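Note on the hunk above: the `?` propagation after `inference_chain_router` is replaced with a `match`, so a failed inference now turns into a minimal `InferenceChainResult` carrying the error text instead of aborting the job step. A small standalone sketch of that fallback pattern (synchronous, with simplified stand-in types; `run_chain` and its error type are illustrative only, not the crate's API):

    struct InferenceChainResult {
        response: String,
        tps: Option<f64>,
        answer_duration: Option<String>,
        tool_calls: Option<Vec<String>>,
    }

    // Stand-in for JobManager::inference_chain_router(...), which is async in the real code.
    fn run_chain() -> Result<InferenceChainResult, String> {
        Err("LLM provider error: 401 invalid API key".to_string())
    }

    fn run_job_step() -> (InferenceChainResult, String) {
        match run_chain() {
            Ok(response) => {
                let content = response.response.clone();
                (response, content)
            }
            Err(e) => {
                // Surface the error text as the job's answer instead of failing the whole job.
                let error_response = InferenceChainResult {
                    response: e.clone(),
                    tps: None,
                    answer_duration: None,
                    tool_calls: None,
                };
                (error_response, e)
            }
        }
    }

    fn main() {
        let (result, content) = run_job_step();
        assert_eq!(result.response, content); // the error text becomes the visible response
    }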
@@ -42,8 +42,8 @@ impl JobManager {
let response = task_response?;
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Debug,
format!("inference_llm_provider_markdown> response: {:?}", response).as_str(),
ShinkaiLogLevel::Info,
format!("inference_with_llm_provider> response: {:?}", response).as_str(),
);

response
61 changes: 51 additions & 10 deletions shinkai-bin/shinkai-node/src/llm_provider/providers/openai.rs
@@ -6,7 +6,7 @@ use super::shared::openai_api::{openai_prepare_messages, MessageContent, OpenAIR
use super::LLMService;
use crate::llm_provider::execution::chains::inference_chain_trait::{FunctionCall, LLMInferenceResponse};
use crate::llm_provider::llm_stopper::LLMStopper;
use crate::managers::model_capabilities_manager::PromptResultEnum;
use crate::managers::model_capabilities_manager::{ModelCapabilitiesManager, PromptResultEnum};
use async_trait::async_trait;
use futures::StreamExt;
use reqwest::Client;
@@ -104,20 +104,32 @@ impl LLMService for OpenAI {
Err(e) => eprintln!("Failed to serialize tools_json: {:?}", e),
};

let mut payload = json!({
"model": self.model_type,
"messages": messages_json,
"max_tokens": result.remaining_output_tokens,
"stream": is_stream,
});
// Set up initial payload with appropriate token limit field based on model capabilities
let mut payload = if ModelCapabilitiesManager::has_reasoning_capabilities(&model) {
json!({
"model": self.model_type,
"messages": messages_json,
"max_completion_tokens": result.remaining_output_tokens,
"stream": is_stream,
})
} else {
json!({
"model": self.model_type,
"messages": messages_json,
"max_tokens": result.remaining_output_tokens,
"stream": is_stream,
})
};

// Conditionally add functions to the payload if tools_json is not empty
if !tools_json.is_empty() {
payload["functions"] = serde_json::Value::Array(tools_json.clone());
}

// Add options to payload
add_options_to_payload(&mut payload, config.as_ref());
// Only add options to payload for non-reasoning models
if !ModelCapabilitiesManager::has_reasoning_capabilities(&model) {
add_options_to_payload(&mut payload, config.as_ref());
}

// Print payload as a pretty JSON string
match serde_json::to_string_pretty(&payload) {
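Note on the payload change above: for reasoning-capable models the token limit is sent as `max_completion_tokens` (the field the o-series endpoints expect) instead of `max_tokens`, and the extra sampling options are not added at all. A minimal sketch of the two request shapes using serde_json (the model names and the 4096 limit are illustrative, not values from this PR):

    use serde_json::{json, Value};

    fn build_payload(model_type: &str, is_reasoning: bool, remaining_output_tokens: u64) -> Value {
        if is_reasoning {
            json!({
                "model": model_type,
                "messages": [{ "role": "user", "content": "Hello" }],
                // reasoning models take "max_completion_tokens" rather than "max_tokens"
                "max_completion_tokens": remaining_output_tokens,
                "stream": false,
            })
        } else {
            json!({
                "model": model_type,
                "messages": [{ "role": "user", "content": "Hello" }],
                "max_tokens": remaining_output_tokens,
                "stream": false,
            })
        }
    }

    fn main() {
        assert!(build_payload("o1-mini", true, 4096).get("max_completion_tokens").is_some());
        assert!(build_payload("gpt-4o", false, 4096).get("max_tokens").is_some());
    }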
@@ -248,6 +260,35 @@ pub async fn parse_openai_stream_chunk(
inbox_name: Option<InboxName>,
session_id: &str,
) -> Result<Option<String>, LLMProviderError> {
// If the buffer starts with '{', assume we might be receiving a JSON error.
if buffer.trim_start().starts_with('{') {
match serde_json::from_str::<JsonValue>(buffer) {
Ok(json_data) => {
// If it has an "error" field, record that and return immediately.
if let Some(error_obj) = json_data.get("error") {
let code = error_obj
.get("code")
.and_then(|c| c.as_str())
.unwrap_or("Unknown code")
.to_string();
let msg = error_obj
.get("message")
.and_then(|m| m.as_str())
.unwrap_or("Unknown error");
// Clear the buffer since we've consumed it
buffer.clear();
return Ok(Some(format!("{}: {}", code, msg)));
}
// Once parsed, clear the buffer since we've consumed it.
buffer.clear();
}
Err(_) => {
// It's not yet valid JSON (partial) - keep the buffer and wait for more data
return Ok(None);
}
}
}

let mut error_message: Option<String> = None;

loop {
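Note on the new error check above: if the accumulated stream buffer looks like a bare JSON object, it is parsed eagerly; a complete {"error": ...} body is turned into a "code: message" string, while a still-incomplete JSON fragment is kept in the buffer for the next chunk. A self-contained sketch of that decision (standalone helper, not the crate's actual function signature):

    use serde_json::Value as JsonValue;

    /// Some(error text) if `buffer` holds a complete JSON error object,
    /// None if the JSON is incomplete (keep buffering) or carries no "error" field.
    fn try_extract_error(buffer: &mut String) -> Option<String> {
        if !buffer.trim_start().starts_with('{') {
            return None;
        }
        match serde_json::from_str::<JsonValue>(buffer.as_str()) {
            Ok(json_data) => {
                let error_text = json_data.get("error").map(|error_obj| {
                    let code = error_obj.get("code").and_then(|c| c.as_str()).unwrap_or("Unknown code");
                    let msg = error_obj.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error");
                    format!("{}: {}", code, msg)
                });
                buffer.clear(); // the JSON document has been fully consumed
                error_text
            }
            // Partial JSON so far: leave the buffer untouched and wait for more data.
            Err(_) => None,
        }
    }

    fn main() {
        let mut buf = r#"{"error":{"code":"invalid_api_key","message":"Incorrect API key provided"}}"#.to_string();
        assert_eq!(
            try_extract_error(&mut buf).as_deref(),
            Some("invalid_api_key: Incorrect API key provided")
        );
        assert!(buf.is_empty());
    }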
@@ -571,7 +612,7 @@ pub async fn handle_streaming_response(
}

// Handle WebSocket updates for function calls
if let Some(ref manager) = ws_manager_trait {
if let Some(ref _manager) = ws_manager_trait {
if let Some(ref inbox_name) = inbox_name {
if let Some(last_function_call) = function_calls.last() {
send_tool_ws_update(&ws_manager_trait, Some(inbox_name.clone()), last_function_call)
@@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize};
use serde_json::{self};
use shinkai_message_primitives::schemas::llm_providers::serialized_llm_provider::LLMProviderInterface;
use shinkai_message_primitives::schemas::prompts::Prompt;
use shinkai_message_primitives::schemas::subprompts::{SubPrompt, SubPromptType};

use super::shared_model_logic;

@@ -106,6 +107,17 @@ pub struct Usage {
}

pub fn openai_prepare_messages(model: &LLMProviderInterface, prompt: Prompt) -> Result<PromptResult, LLMProviderError> {
let mut prompt = prompt.clone();

// If this is a reasoning model, filter out system prompts before any processing
if ModelCapabilitiesManager::has_reasoning_capabilities(model) {
prompt.sub_prompts.retain(|sp| match sp {
SubPrompt::Content(SubPromptType::System, _, _) => false,
SubPrompt::Omni(SubPromptType::System, _, _, _) => false,
_ => true,
});
}

let max_input_tokens = ModelCapabilitiesManager::get_max_input_tokens(model);

// Generate the messages and filter out images
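Note on the filtering above: for reasoning-capable models, system sub-prompts are removed before the OpenAI messages are built (the o-series endpoints have not reliably accepted a `system` role), and the new test added at the bottom of this file covers the behavior. A minimal self-contained sketch of the idea, with simplified stand-ins for `SubPrompt`/`SubPromptType`:

    #[derive(Debug, PartialEq)]
    enum Role { System, User }

    struct SimplePrompt { role: Role, text: String }

    fn drop_system_prompts(mut prompts: Vec<SimplePrompt>) -> Vec<SimplePrompt> {
        // Same shape as the retain() above: keep everything that is not a system prompt.
        prompts.retain(|p| p.role != Role::System);
        prompts
    }

    fn main() {
        let prompts = vec![
            SimplePrompt { role: Role::System, text: "You are a helpful assistant".into() },
            SimplePrompt { role: Role::User, text: "Hello".into() },
        ];
        let filtered = drop_system_prompts(prompts);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].text, "Hello"); // only the user message survives
    }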
@@ -125,25 +137,35 @@ pub fn openai_prepare_messages(model: &LLMProviderInterface, prompt: Prompt) ->
.into_iter()
.partition(|message| message.role.is_some());


// Convert both sets of messages to serde Value
let messages_json = serde_json::to_value(messages_with_role)?;

// Convert tools to serde Value with name transformation
let tools_json = serde_json::to_value(
tools.clone().into_iter().map(|mut tool| {
if let Some(functions) = tool.functions.as_mut() {
for function in functions {
// Replace any characters that aren't alphanumeric, underscore, or hyphen
function.name = function.name
.chars()
.map(|c| if c.is_alphanumeric() || c == '_' || c == '-' { c } else { '_' })
.collect::<String>()
.to_lowercase();
tools
.clone()
.into_iter()
.map(|mut tool| {
if let Some(functions) = tool.functions.as_mut() {
for function in functions {
// Replace any characters that aren't alphanumeric, underscore, or hyphen
function.name = function
.name
.chars()
.map(|c| {
if c.is_alphanumeric() || c == '_' || c == '-' {
c
} else {
'_'
}
})
.collect::<String>()
.to_lowercase();
}
}
}
tool
}).collect::<Vec<_>>()
tool
})
.collect::<Vec<_>>(),
)?;

// Convert messages_json and tools_json to Vec<serde_json::Value>
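The block above is only a formatting reflow of the existing tool-name sanitization: characters outside alphanumerics, `_`, and `-` become `_`, and the result is lowercased, which keeps function names inside the character set OpenAI accepts. A quick illustrative check of the transform (the input name is made up):

    fn sanitize_tool_name(name: &str) -> String {
        name.chars()
            .map(|c| if c.is_alphanumeric() || c == '_' || c == '-' { c } else { '_' })
            .collect::<String>()
            .to_lowercase()
    }

    fn main() {
        // space, '.' and '!' are replaced with '_', then everything is lowercased
        assert_eq!(sanitize_tool_name("My Tool.Name!"), "my_tool_name_");
    }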
@@ -211,7 +233,10 @@ pub fn openai_prepare_messages(model: &LLMProviderInterface, prompt: Prompt) ->
})
}

pub fn openai_prepare_messages_gemini(model: &LLMProviderInterface, prompt: Prompt) -> Result<PromptResult, LLMProviderError> {
pub fn openai_prepare_messages_gemini(
model: &LLMProviderInterface,
prompt: Prompt,
) -> Result<PromptResult, LLMProviderError> {
let max_input_tokens = ModelCapabilitiesManager::get_max_input_tokens(model);

// Generate the messages and filter out images
@@ -525,7 +550,10 @@ mod tests {
let response: OpenAIResponse = serde_json::from_str(response_text).expect("Failed to deserialize");

// Verify basic response fields
assert_eq!(response.id.clone().unwrap(), "chatcmpl-0cae310a-2b36-470a-9261-0f24d77b01bc");
assert_eq!(
response.id.clone().unwrap(),
"chatcmpl-0cae310a-2b36-470a-9261-0f24d77b01bc"
);
assert_eq!(response.object, "chat.completion");
assert_eq!(response.created, 1736736692);
assert_eq!(response.system_fingerprint, Some("fp_9cb648b966".to_string()));
@@ -540,10 +568,10 @@
let message = &choice.message;
assert_eq!(message.role, "assistant");
assert!(message.content.is_none());

let tool_calls = message.tool_calls.as_ref().expect("Should have tool_calls");
assert_eq!(tool_calls.len(), 1);

let tool_call = &tool_calls[0];
assert_eq!(tool_call.id, "call_sa3n");
assert_eq!(tool_call.call_type, "function");
@@ -563,4 +591,47 @@
let groq = response.groq.expect("Should have Groq info");
assert_eq!(groq.id, "req_01jhes5nvkedsb8hcw0x912fa6");
}

#[test]
fn test_system_prompt_filtering() {
// Create a prompt with both Content and Omni system prompts
let sub_prompts = vec![
SubPrompt::Content(
SubPromptType::System,
"System prompt that should be filtered".to_string(),
98,
),
SubPrompt::Content(SubPromptType::User, "User message that should remain".to_string(), 100),
SubPrompt::Omni(
SubPromptType::UserLastMessage,
"Last user message that should remain".to_string(),
vec![],
100,
),
];

let mut prompt = Prompt::new();
prompt.add_sub_prompts(sub_prompts);

// Create a mock model with reasoning capabilities
let model = SerializedLLMProvider::mock_provider_with_reasoning().model;

// Process the prompt
let result = openai_prepare_messages(&model, prompt).expect("Failed to prepare messages");

// Extract the messages from the result
let messages = match &result.messages {
PromptResultEnum::Value(value) => value.as_array().unwrap(),
_ => panic!("Expected Value variant"),
};

// Verify that only non-system messages remain
assert_eq!(messages.len(), 2, "Should only have 2 messages after filtering");

// Check that the remaining messages are the user messages
for message in messages {
let role = message["role"].as_str().unwrap();
assert_eq!(role, "user", "All remaining messages should be user messages");
}
}
}
34 changes: 21 additions & 13 deletions shinkai-bin/shinkai-node/src/managers/model_capabilities_manager.rs
@@ -1,23 +1,16 @@
use crate::llm_provider::{
error::LLMProviderError,
providers::shared::{openai_api::openai_prepare_messages, shared_model_logic::llama_prepare_messages},
error::LLMProviderError, providers::shared::{openai_api::openai_prepare_messages, shared_model_logic::llama_prepare_messages}
};
use shinkai_message_primitives::{
schemas::{
llm_message::LlmMessage,
llm_providers::{
common_agent_llm_provider::ProviderOrAgent,
serialized_llm_provider::{LLMProviderInterface, SerializedLLMProvider},
},
prompts::Prompt,
shinkai_name::ShinkaiName,
},
shinkai_utils::utils::count_tokens_from_message_llama3,
llm_message::LlmMessage, llm_providers::{
common_agent_llm_provider::ProviderOrAgent, serialized_llm_provider::{LLMProviderInterface, SerializedLLMProvider}
}, prompts::Prompt, shinkai_name::ShinkaiName
}, shinkai_utils::utils::count_tokens_from_message_llama3
};
use shinkai_sqlite::SqliteManager;
use std::{
fmt,
sync::{Arc, Weak},
fmt, sync::{Arc, Weak}
};

#[derive(Debug)]
@@ -726,6 +719,21 @@ impl ModelCapabilitiesManager {
_ => false,
}
}

/// Returns whether the given model has reasoning capabilities
pub fn has_reasoning_capabilities(model: &LLMProviderInterface) -> bool {
match model {
LLMProviderInterface::OpenAI(openai) => {
openai.model_type.starts_with("o1")
|| openai.model_type.starts_with("o2")
|| openai.model_type.starts_with("o3")
|| openai.model_type.starts_with("o4")
|| openai.model_type.starts_with("o5")
}
LLMProviderInterface::Ollama(ollama) => ollama.model_type.starts_with("deepseek-r1"),
_ => false,
}
}
}

// TODO: add a tokenizer library only in the dev env and test that the
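Note on the new helper above: reasoning support is detected purely by model-name prefix: OpenAI ids starting with `o1` through `o5`, plus Ollama's `deepseek-r1` variants. A small standalone re-statement of the prefix logic for illustration (not a call into the crate):

    fn is_reasoning_model_name(provider: &str, model_type: &str) -> bool {
        match provider {
            "openai" => ["o1", "o2", "o3", "o4", "o5"]
                .iter()
                .any(|prefix| model_type.starts_with(prefix)),
            "ollama" => model_type.starts_with("deepseek-r1"),
            _ => false,
        }
    }

    fn main() {
        assert!(is_reasoning_model_name("openai", "o1-mini"));
        assert!(is_reasoning_model_name("openai", "o3-mini"));
        assert!(is_reasoning_model_name("ollama", "deepseek-r1:14b"));
        // starts_with keeps "gpt-4o" out, even though it contains the letter 'o'
        assert!(!is_reasoning_model_name("openai", "gpt-4o"));
    }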