Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 151 additions & 43 deletions src/pipelines/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,33 +78,73 @@ pub async fn chat_completions(
) -> Result<impl IntoResponse, StatusCode> {
let mut tracer = OtelTracer::start("chat", &payload);

for model_key in model_keys {
let model = model_registry.get(&model_key).unwrap();

if payload.model == model.model_type {
let response = model
.chat_completions(payload.clone())
.await
.inspect_err(|e| {
eprintln!("Chat completion error for model {}: {:?}", model_key, e);
})?;

if let ChatCompletionResponse::NonStream(completion) = response {
tracer.log_success(&completion);
return Ok(Json(completion).into_response());
let matching_models: Vec<_> = model_keys
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The filtering of matching models (lines 81-90) is repeated in three functions. Consider extracting this logic into a helper function to reduce code duplication.

.iter()
.filter_map(|key| {
let model = model_registry.get(key)?;
if payload.model == model.model_type {
Some((key.clone(), model))
} else {
None
}
})
.collect();

if let ChatCompletionResponse::Stream(stream) = response {
return Ok(Sse::new(trace_and_stream(tracer, stream))
.keep_alive(KeepAlive::default())
.into_response());
if matching_models.is_empty() {
tracer.log_error("No matching model found".to_string());
eprintln!("No matching model found for: {}", payload.model);
return Err(StatusCode::NOT_FOUND);
}

let mut last_error = None;

for (model_key, model) in matching_models {
match model.chat_completions(payload.clone()).await {
Ok(response) => match response {
ChatCompletionResponse::NonStream(completion) => {
tracer.log_success(&completion);
return Ok(Json(completion).into_response());
}
ChatCompletionResponse::Stream(stream) => {
return Ok(Sse::new(trace_and_stream(tracer, stream))
.keep_alive(KeepAlive::default())
.into_response());
}
},
Err(status_code) => {
eprintln!(
"Chat completion error for model {}: {:?}",
model_key, status_code
);

if is_transient_error(status_code) {
eprintln!(
"Transient error for model {}, trying next model...",
model_key
);
last_error = Some(status_code);
continue;
} else {
return Err(status_code);
}
}
}
}

tracer.log_error("No matching model found".to_string());
eprintln!("No matching model found for: {}", payload.model);
Err(StatusCode::NOT_FOUND)
let error = last_error.unwrap();
tracer.log_error(format!("All models failed with error: {}", error));
Err(error)
}

/// Reports whether `status_code` is one of the HTTP statuses this module
/// treats as transient.
///
/// The request handlers in this file use the result to decide between moving
/// on to the next matching model (`true`) and returning the error to the
/// caller immediately (`false`).
fn is_transient_error(status_code: StatusCode) -> bool {
    // The complete set of statuses considered transient; anything not listed
    // here is reported as non-transient.
    const TRANSIENT_STATUSES: [StatusCode; 5] = [
        StatusCode::REQUEST_TIMEOUT,     // 408
        StatusCode::TOO_MANY_REQUESTS,   // 429
        StatusCode::BAD_GATEWAY,         // 502
        StatusCode::SERVICE_UNAVAILABLE, // 503
        StatusCode::GATEWAY_TIMEOUT,     // 504
    ];

    TRANSIENT_STATUSES.contains(&status_code)
}

pub async fn completions(
Expand All @@ -114,21 +154,55 @@ pub async fn completions(
) -> impl IntoResponse {
let mut tracer = OtelTracer::start("completion", &payload);

for model_key in model_keys {
let model = model_registry.get(&model_key).unwrap();
let matching_models: Vec<_> = model_keys
.iter()
.filter_map(|key| {
let model = model_registry.get(key)?;
if payload.model == model.model_type {
Some((key.clone(), model))
} else {
None
}
})
.collect();

if matching_models.is_empty() {
tracer.log_error("No matching model found".to_string());
eprintln!("No matching model found for: {}", payload.model);
return Err(StatusCode::NOT_FOUND);
}

let mut last_error = None;

for (model_key, model) in matching_models {
match model.completions(payload.clone()).await {
Ok(response) => {
tracer.log_success(&response);
return Ok(Json(response));
}
Err(status_code) => {
eprintln!(
"Completion error for model {}: {:?}",
model_key, status_code
);

if payload.model == model.model_type {
let response = model.completions(payload.clone()).await.inspect_err(|e| {
eprintln!("Completion error for model {}: {:?}", model_key, e);
})?;
tracer.log_success(&response);
return Ok(Json(response));
if is_transient_error(status_code) {
eprintln!(
"Transient error for model {}, trying next model...",
model_key
);
last_error = Some(status_code);
continue;
} else {
return Err(status_code);
}
}
}
}

tracer.log_error("No matching model found".to_string());
eprintln!("No matching model found for: {}", payload.model);
Err(StatusCode::NOT_FOUND)
let error = last_error.unwrap();
tracer.log_error(format!("All models failed with error: {}", error));
Err(error)
}

pub async fn embeddings(
Expand All @@ -138,19 +212,53 @@ pub async fn embeddings(
) -> impl IntoResponse {
let mut tracer = OtelTracer::start("embeddings", &payload);

for model_key in model_keys {
let model = model_registry.get(&model_key).unwrap();
let matching_models: Vec<_> = model_keys
.iter()
.filter_map(|key| {
let model = model_registry.get(key)?;
if payload.model == model.model_type {
Some((key.clone(), model))
} else {
None
}
})
.collect();

if payload.model == model.model_type {
let response = model.embeddings(payload.clone()).await.inspect_err(|e| {
eprintln!("Embeddings error for model {}: {:?}", model_key, e);
})?;
tracer.log_success(&response);
return Ok(Json(response));
if matching_models.is_empty() {
tracer.log_error("No matching model found".to_string());
eprintln!("No matching model found for: {}", payload.model);
return Err(StatusCode::NOT_FOUND);
}

let mut last_error = None;

for (model_key, model) in matching_models {
match model.embeddings(payload.clone()).await {
Ok(response) => {
tracer.log_success(&response);
return Ok(Json(response));
}
Err(status_code) => {
eprintln!(
"Embeddings error for model {}: {:?}",
model_key, status_code
);

if is_transient_error(status_code) {
eprintln!(
"Transient error for model {}, trying next model...",
model_key
);
last_error = Some(status_code);
continue;
} else {
return Err(status_code);
}
}
}
}

tracer.log_error("No matching model found".to_string());
eprintln!("No matching model found for: {}", payload.model);
Err(StatusCode::NOT_FOUND)
let error = last_error.unwrap();
tracer.log_error(format!("All models failed with error: {}", error));
Err(error)
}