Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 58 additions & 15 deletions codex-rs/codex-api/src/endpoint/responses_websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,7 @@ impl ResponsesWebsocketConnection {
let models_etag = self.models_etag.clone();
let server_model = self.server_model.clone();
let telemetry = self.telemetry.clone();
let request_body = serde_json::to_value(&request).map_err(|err| {
ApiError::Stream(format!("failed to encode websocket request: {err}"))
})?;
let request_text = serialize_websocket_request(&request)?;

let current_span = Span::current();
tokio::spawn(
Expand Down Expand Up @@ -261,7 +259,7 @@ impl ResponsesWebsocketConnection {
run_websocket_response_stream(
ws_stream,
tx_event.clone(),
request_body,
request_text,
idle_timeout,
telemetry,
connection_reused,
Expand Down Expand Up @@ -629,7 +627,7 @@ fn json_header_value(value: Value) -> Option<HeaderValue> {
async fn run_websocket_response_stream(
ws_stream: &mut WsStream,
tx_event: mpsc::Sender<std::result::Result<ResponseEvent, ApiError>>,
request_body: Value,
request_text: String,
idle_timeout: Duration,
telemetry: Option<Arc<dyn WebsocketTelemetry>>,
connection_reused: bool,
Expand All @@ -638,7 +636,7 @@ async fn run_websocket_response_stream(
let mut last_server_model: Option<String> = None;
send_websocket_request(
ws_stream,
request_body,
request_text,
idle_timeout,
telemetry.as_ref(),
connection_reused,
Expand Down Expand Up @@ -758,19 +756,11 @@ async fn run_websocket_response_stream(

async fn send_websocket_request(
ws_stream: &WsStream,
request_body: Value,
request_text: String,
idle_timeout: Duration,
telemetry: Option<&Arc<dyn WebsocketTelemetry>>,
connection_reused: bool,
) -> Result<(), ApiError> {
let request_text = match serde_json::to_string(&request_body) {
Ok(text) => text,
Err(err) => {
return Err(ApiError::Stream(format!(
"failed to encode websocket request: {err}"
)));
}
};
trace!("websocket request: {request_text}");

let request_start = Instant::now();
Expand All @@ -797,11 +787,64 @@ async fn send_websocket_request(
Ok(())
}

fn serialize_websocket_request(request: &ResponsesWsRequest) -> Result<String, ApiError> {
Comment thread
jif-oai marked this conversation as resolved.
serde_json::to_string(request)
.map_err(|err| ApiError::Stream(format!("failed to encode websocket request: {err}")))
}

#[cfg(test)]
mod tests {
use super::*;
use crate::common::ResponseCreateWsRequest;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::collections::HashMap;

#[test]
fn direct_serialization_preserves_websocket_request_payload() {
let request = ResponsesWsRequest::ResponseCreate(ResponseCreateWsRequest {
model: "gpt-test".to_string(),
instructions: "Use the available tools.".to_string(),
previous_response_id: Some("resp-1".to_string()),
input: vec![ResponseItem::Message {
id: Some("msg-1".to_string()),
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "hello".to_string(),
}],
phase: None,
}],
tools: vec![json!({
"type": "function",
"name": "lookup",
"parameters": {"type": "object"}
})],
tool_choice: "auto".to_string(),
parallel_tool_calls: true,
reasoning: None,
store: false,
stream: true,
include: vec!["reasoning.encrypted_content".to_string()],
service_tier: Some("priority".to_string()),
prompt_cache_key: Some("cache-key".to_string()),
text: None,
generate: Some(false),
client_metadata: Some(HashMap::from([(
"traceparent".to_string(),
"00-0123456789abcdef0123456789abcdef-0123456789abcdef-01".to_string(),
)])),
});

let previous_payload = serde_json::to_value(&request).expect("serialize previous payload");
let request_text =
serialize_websocket_request(&request).expect("serialize websocket request");
let wire_payload =
serde_json::from_str::<Value>(&request_text).expect("parse websocket request");

assert_eq!(wire_payload, previous_payload);
}

#[test]
fn websocket_config_enables_permessage_deflate() {
Expand Down
Loading