Skip to content

Commit

Permalink
feat: support OpenAI reasoning effort configuration for O1/O3 models (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
salman1993 authored Feb 17, 2025
1 parent 855adfb commit aa32c4a
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 4 deletions.
123 changes: 121 additions & 2 deletions crates/goose/src/providers/formats/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,28 @@ pub fn create_request(
let is_o1 = model_config.model_name.starts_with("o1");
let is_o3 = model_config.model_name.starts_with("o3");

// Only extract reasoning effort for O1/O3 models
let (model_name, reasoning_effort) = if is_o1 || is_o3 {
let parts: Vec<&str> = model_config.model_name.split('-').collect();
let last_part = parts.last().unwrap();

match *last_part {
"low" | "medium" | "high" => {
let base_name = parts[..parts.len() - 1].join("-");
(base_name, Some(last_part.to_string()))
}
_ => (
model_config.model_name.to_string(),
Some("medium".to_string()),
),
}
} else {
// For non-O family models, use the model name as is and no reasoning effort
(model_config.model_name.to_string(), None)
};

let system_message = json!({
"role": if is_o1 { "developer" } else { "system" },
"role": if is_o1 || is_o3 { "developer" } else { "system" },
"content": system
});

Expand All @@ -349,10 +369,17 @@ pub fn create_request(
messages_array.extend(messages_spec);

let mut payload = json!({
"model": model_config.model_name,
"model": model_name,
"messages": messages_array
});

if let Some(effort) = reasoning_effort {
payload
.as_object_mut()
.unwrap()
.insert("reasoning_effort".to_string(), json!(effort));
}

if !tools_spec.is_empty() {
payload
.as_object_mut()
Expand Down Expand Up @@ -778,4 +805,96 @@ mod tests {

Ok(())
}

#[test]
fn test_create_request_gpt_4o() -> anyhow::Result<()> {
    // Non-O-family models must keep their name untouched, use the plain
    // "system" role and "max_tokens", and must NOT receive a
    // "reasoning_effort" field in the payload.
    let model_config = ModelConfig {
        model_name: "gpt-4o".to_string(),
        tokenizer_name: "gpt-4o".to_string(),
        context_limit: Some(4096),
        temperature: None,
        max_tokens: Some(1024),
    };
    let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
    let obj = request.as_object().unwrap();
    let expected = json!({
        "model": "gpt-4o",
        "messages": [
            {
                "role": "system",
                "content": "system"
            }
        ],
        "max_tokens": 1024
    });

    for (key, value) in expected.as_object().unwrap() {
        assert_eq!(obj.get(key).unwrap(), value);
    }

    // The subset comparison above cannot catch extra keys, so assert
    // explicitly that no reasoning effort leaked into a non-reasoning
    // model's request.
    assert!(obj.get("reasoning_effort").is_none());

    Ok(())
}

#[test]
fn test_create_request_o1_default() -> anyhow::Result<()> {
    // An O1 model name without an effort suffix should fall back to the
    // "medium" reasoning effort, use the "developer" role, and map
    // max_tokens onto "max_completion_tokens".
    let config = ModelConfig {
        model_name: "o1".to_string(),
        tokenizer_name: "o1".to_string(),
        context_limit: Some(4096),
        temperature: None,
        max_tokens: Some(1024),
    };

    let request = create_request(&config, "system", &[], &[], &ImageFormat::OpenAi)?;
    let actual = request.as_object().unwrap();

    let want = json!({
        "model": "o1",
        "messages": [{ "role": "developer", "content": "system" }],
        "reasoning_effort": "medium",
        "max_completion_tokens": 1024
    });

    // Verify every expected field is present with the expected value.
    for (field, expected_value) in want.as_object().unwrap() {
        assert_eq!(actual.get(field).unwrap(), expected_value);
    }

    Ok(())
}

#[test]
fn test_create_request_o3_custom_reasoning_effort() -> anyhow::Result<()> {
    // A trailing "-high" on an O3 model name should be stripped from the
    // model sent to the API and surfaced as reasoning_effort = "high".
    let config = ModelConfig {
        model_name: "o3-mini-high".to_string(),
        tokenizer_name: "o3-mini".to_string(),
        context_limit: Some(4096),
        temperature: None,
        max_tokens: Some(1024),
    };

    let request = create_request(&config, "system", &[], &[], &ImageFormat::OpenAi)?;
    let actual = request.as_object().unwrap();

    let want = json!({
        "model": "o3-mini",
        "messages": [{ "role": "developer", "content": "system" }],
        "reasoning_effort": "high",
        "max_completion_tokens": 1024
    });

    // Verify every expected field is present with the expected value.
    for (field, expected_value) in want.as_object().unwrap() {
        assert_eq!(actual.get(field).unwrap(), expected_value);
    }

    Ok(())
}
}
4 changes: 2 additions & 2 deletions crates/goose/tests/truncate_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ mod tests {
async fn test_truncate_agent_with_openai() -> Result<()> {
run_test_with_config(TestConfig {
provider_type: ProviderType::OpenAi,
model: "gpt-4o-mini",
context_window: 128_000,
model: "o3-mini-low",
context_window: 200_000,
})
.await
}
Expand Down

0 comments on commit aa32c4a

Please sign in to comment.