141 changes: 136 additions & 5 deletions codex-rs/core/src/chat_completions.rs
@@ -1,3 +1,4 @@
use std::sync::Arc;
use std::time::Duration;

use crate::ModelProviderInfo;
@@ -12,6 +13,7 @@ use crate::error::Result;
use crate::error::RetryLimitReachedError;
use crate::error::UnexpectedResponseError;
use crate::model_family::ModelFamily;
use crate::protocol::TokenUsage;
use crate::tools::spec::create_tools_json_for_chat_completions_api;
use crate::util::backoff;
use bytes::Bytes;
@@ -20,6 +22,7 @@ use codex_protocol::models::ContentItem;
use codex_protocol::models::FunctionCallOutputContentItem;
use codex_protocol::models::ReasoningItemContent;
use codex_protocol::models::ResponseItem;
use codex_utils_tokenizer::Tokenizer;
use eventsource_stream::Eventsource;
use futures::Stream;
use futures::StreamExt;
@@ -34,6 +37,102 @@ use tokio::time::timeout;
use tracing::debug;
use tracing::trace;

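/// Rough client-side estimate of token usage for a Chat Completions stream,
/// tallied from the request messages up front and from streamed deltas.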
struct ChatUsageHeuristic {
tokenizer: Arc<Tokenizer>,
input_tokens: i64,
output_tokens: i64,
reasoning_tokens: i64,
}

impl ChatUsageHeuristic {
fn new(model: &str, messages: &[serde_json::Value]) -> Option<Self> {
let tokenizer = match Tokenizer::for_model(model) {
Ok(tok) => tok,
Err(err) => {
debug!(
"failed to build tokenizer for model {model}; falling back to default: {err:?}"
);
match Tokenizer::try_default() {
Ok(tok) => tok,
Err(fallback_err) => {
debug!(
"failed to fall back to default tokenizer for model {model}: {fallback_err:?}"
);
return None;
}
}
}
};

let tokenizer = Arc::new(tokenizer);
let mut input_tokens =
4_i64.saturating_mul(i64::try_from(messages.len()).unwrap_or(i64::MAX));

for message in messages {
input_tokens =
input_tokens.saturating_add(Self::count_value_tokens(tokenizer.as_ref(), message));

if let Some(tool_calls) = message.get("tool_calls").and_then(|v| v.as_array()) {
input_tokens = input_tokens.saturating_add(
8_i64.saturating_mul(i64::try_from(tool_calls.len()).unwrap_or(i64::MAX)),
);
}
}

Some(Self {
tokenizer,
input_tokens,
output_tokens: 0,
reasoning_tokens: 0,
})
}

fn record_output(&mut self, text: &str) {
if text.is_empty() {
return;
}
self.output_tokens = self
.output_tokens
.saturating_add(self.tokenizer.count(text));
}

fn record_reasoning(&mut self, text: &str) {
if text.is_empty() {
return;
}
self.reasoning_tokens = self
.reasoning_tokens
.saturating_add(self.tokenizer.count(text));
}

fn to_usage(&self) -> TokenUsage {
let total = self
.input_tokens
.saturating_add(self.output_tokens)
.saturating_add(self.reasoning_tokens);
TokenUsage {
input_tokens: self.input_tokens,
cached_input_tokens: 0,
output_tokens: self.output_tokens,
reasoning_output_tokens: self.reasoning_tokens,
total_tokens: total,
}
}

fn count_value_tokens(tokenizer: &Tokenizer, value: &serde_json::Value) -> i64 {
match value {
serde_json::Value::String(s) => tokenizer.count(s),
serde_json::Value::Array(items) => items.iter().fold(0_i64, |acc, item| {
acc.saturating_add(Self::count_value_tokens(tokenizer, item))
}),
serde_json::Value::Object(map) => map.values().fold(0_i64, |acc, item| {
acc.saturating_add(Self::count_value_tokens(tokenizer, item))
}),
_ => 0,
}
}
}

/// Implementation for the classic Chat Completions API.
pub(crate) async fn stream_chat_completions(
prompt: &Prompt,
@@ -325,6 +424,8 @@ pub(crate) async fn stream_chat_completions(
}

let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
let usage_heuristic = ChatUsageHeuristic::new(model_family.slug.as_str(), &messages);

let payload = json!({
"model": model_family.slug,
"messages": messages,
@@ -368,6 +469,7 @@ pub(crate) async fn stream_chat_completions(
tx_event,
provider.stream_idle_timeout(),
otel_event_manager.clone(),
usage_heuristic,
));
return Ok(ResponseStream { rx_event });
}
@@ -421,6 +523,7 @@ async fn process_chat_sse<S>(
tx_event: mpsc::Sender<Result<ResponseEvent>>,
idle_timeout: Duration,
otel_event_manager: OtelEventManager,
mut usage_heuristic: Option<ChatUsageHeuristic>,
) where
S: Stream<Item = Result<Bytes>> + Unpin,
{
@@ -459,10 +562,11 @@ async fn process_chat_sse<S>(
}
Ok(None) => {
// Stream closed gracefully – emit Completed with dummy id.
let token_usage = usage_heuristic.as_ref().map(ChatUsageHeuristic::to_usage);
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
token_usage: None,
token_usage,
}))
.await;
return;
@@ -505,10 +609,11 @@ async fn process_chat_sse<S>(
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}

let token_usage = usage_heuristic.as_ref().map(ChatUsageHeuristic::to_usage);
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
token_usage: None,
token_usage,
}))
.await;
return;
@@ -532,6 +637,9 @@ async fn process_chat_sse<S>(
&& !content.is_empty()
{
assistant_text.push_str(content);
if let Some(usage) = usage_heuristic.as_mut() {
usage.record_output(content);
}
let _ = tx_event
.send(Ok(ResponseEvent::OutputTextDelta(content.to_string())))
.await;
@@ -565,6 +673,9 @@ async fn process_chat_sse<S>(
if let Some(reasoning) = maybe_text {
// Accumulate so we can emit a terminal Reasoning item at the end.
reasoning_text.push_str(&reasoning);
if let Some(usage) = usage_heuristic.as_mut() {
usage.record_reasoning(&reasoning);
}
let _ = tx_event
.send(Ok(ResponseEvent::ReasoningContentDelta(reasoning)))
.await;
@@ -578,6 +689,9 @@ async fn process_chat_sse<S>(
if let Some(s) = message_reasoning.as_str() {
if !s.is_empty() {
reasoning_text.push_str(s);
if let Some(usage) = usage_heuristic.as_mut() {
usage.record_reasoning(s);
}
let _ = tx_event
.send(Ok(ResponseEvent::ReasoningContentDelta(s.to_string())))
.await;
@@ -590,6 +704,9 @@ async fn process_chat_sse<S>(
&& !s.is_empty()
{
reasoning_text.push_str(s);
if let Some(usage) = usage_heuristic.as_mut() {
usage.record_reasoning(s);
}
let _ = tx_event
.send(Ok(ResponseEvent::ReasoningContentDelta(s.to_string())))
.await;
@@ -608,18 +725,31 @@ async fn process_chat_sse<S>(

// Extract call_id if present.
if let Some(id) = tool_call.get("id").and_then(|v| v.as_str()) {
fn_call_state.call_id.get_or_insert_with(|| id.to_string());
if fn_call_state.call_id.is_none() {
if let Some(usage) = usage_heuristic.as_mut() {
usage.record_output(id);
}
fn_call_state.call_id = Some(id.to_string());
}
}

// Extract function details if present.
if let Some(function) = tool_call.get("function") {
if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
fn_call_state.name.get_or_insert_with(|| name.to_string());
if fn_call_state.name.is_none() {
if let Some(usage) = usage_heuristic.as_mut() {
usage.record_output(name);
}
fn_call_state.name = Some(name.to_string());
}
}

if let Some(args_fragment) = function.get("arguments").and_then(|a| a.as_str())
{
fn_call_state.arguments.push_str(args_fragment);
if let Some(usage) = usage_heuristic.as_mut() {
usage.record_output(args_fragment);
}
}
}
}
@@ -682,10 +812,11 @@ async fn process_chat_sse<S>(
}

// Emit Completed regardless of reason so the agent can advance.
let token_usage = usage_heuristic.as_ref().map(ChatUsageHeuristic::to_usage);
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
token_usage: None,
token_usage,
}))
.await;

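Note: the heuristic above amounts to a recursive walk over the request JSON plus fixed framing overhead, with every streamed output and reasoning delta counted as it arrives. The sketch below reproduces the input-side arithmetic outside the crate. StubTokenizer is a hypothetical stand-in for codex_utils_tokenizer::Tokenizer (which counts real BPE tokens), so the absolute numbers are illustrative only; the constants and the recursion mirror ChatUsageHeuristic::new and count_value_tokens.

use serde_json::Value;
use serde_json::json;

// Hypothetical stand-in for the real tokenizer: counts whitespace-separated
// words instead of BPE tokens.
struct StubTokenizer;

impl StubTokenizer {
    fn count(&self, text: &str) -> i64 {
        text.split_whitespace().count() as i64
    }
}

// Same shape as `ChatUsageHeuristic::count_value_tokens`: sum tokens over
// every string leaf; numbers, booleans, and nulls contribute nothing.
fn count_value_tokens(tokenizer: &StubTokenizer, value: &Value) -> i64 {
    match value {
        Value::String(s) => tokenizer.count(s),
        Value::Array(items) => items.iter().fold(0_i64, |acc, item| {
            acc.saturating_add(count_value_tokens(tokenizer, item))
        }),
        Value::Object(map) => map.values().fold(0_i64, |acc, item| {
            acc.saturating_add(count_value_tokens(tokenizer, item))
        }),
        _ => 0,
    }
}

fn main() {
    let tokenizer = StubTokenizer;
    let messages = vec![
        json!({"role": "user", "content": "What is the weather in Paris?"}),
        json!({"role": "assistant", "tool_calls": [{
            "id": "call_1",
            "function": {"name": "get_weather", "arguments": "{\"city\":\"Paris\"}"}
        }]}),
    ];

    // Same constants as `ChatUsageHeuristic::new`: 4 tokens of framing
    // overhead per message, plus 8 per tool call.
    let mut input_tokens = 4_i64.saturating_mul(messages.len() as i64);
    for message in &messages {
        input_tokens = input_tokens.saturating_add(count_value_tokens(&tokenizer, message));
        if let Some(calls) = message.get("tool_calls").and_then(Value::as_array) {
            input_tokens =
                input_tokens.saturating_add(8_i64.saturating_mul(calls.len() as i64));
        }
    }
    println!("estimated input tokens: {input_tokens}");
}

Output-side counting works the same way: record_output and record_reasoning run the tokenizer over each streamed delta, and to_usage reports input + output + reasoning as the total.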
6 changes: 4 additions & 2 deletions codex-rs/core/src/openai_model_info.rs
@@ -37,8 +37,10 @@ impl ModelInfo {
}

pub(crate) fn get_model_info(model_family: &ModelFamily) -> Option<ModelInfo> {
let slug = model_family.slug.as_str();
match slug {
let raw_slug = model_family.slug.as_str();
let slug = raw_slug.strip_prefix("openai/").unwrap_or(raw_slug);
let normalized_slug = slug.replace(':', "-");
match normalized_slug.as_str() {
// OSS models have a 128k shared token pool.
// Arbitrarily splitting it: 3/4 input context, 1/4 output.
// https://openai.com/index/gpt-oss-model-card/
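The lookup change above is easiest to see with concrete slugs: provider-prefixed, tag-style model names now collapse onto the canonical entries in the match. A minimal sketch of the normalization, with illustrative inputs (only the openai/ prefix strip and the ':' to '-' mapping come from the diff):

// Normalize a model slug the way `get_model_info` now does before matching.
fn normalize_slug(raw_slug: &str) -> String {
    let slug = raw_slug.strip_prefix("openai/").unwrap_or(raw_slug);
    slug.replace(':', "-")
}

fn main() {
    // An Ollama-style tagged name maps onto the canonical gpt-oss slug.
    assert_eq!(normalize_slug("openai/gpt-oss:20b"), "gpt-oss-20b");
    // Already-canonical slugs pass through unchanged.
    assert_eq!(normalize_slug("gpt-4o"), "gpt-4o");
}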
43 changes: 43 additions & 0 deletions codex-rs/core/tests/chat_completions_sse.rs
@@ -185,6 +185,49 @@ async fn streams_text_without_reasoning() {
assert_matches!(events[2], ResponseEvent::Completed { .. });
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn completed_event_includes_usage_estimate() {
if network_disabled() {
println!(
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
);
return;
}

let sse = concat!(
"data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n",
"data: {\"choices\":[{\"delta\":{}}]}\n\n",
"data: [DONE]\n\n",
);

let events = run_stream(sse).await;
assert_eq!(events.len(), 3, "unexpected events: {events:?}");

let usage = events
.iter()
.find_map(|event| match event {
ResponseEvent::Completed {
token_usage: Some(usage),
..
} => Some(usage.clone()),
_ => None,
})
.expect("missing usage estimate on Completed event");

assert!(
usage.input_tokens > 0,
"expected input tokens > 0, got {usage:?}"
);
assert!(
usage.output_tokens > 0,
"expected output tokens > 0, got {usage:?}"
);
assert!(
usage.total_tokens >= usage.input_tokens + usage.output_tokens,
"expected total tokens to cover input + output, got {usage:?}"
);
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn streams_reasoning_from_string_delta() {
if network_disabled() {