Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/config/entities/providers-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"bedrock",
"cohere",
"deepseek",
"fireworks-ai",
"gemini",
"groq",
"xai",
Expand Down Expand Up @@ -54,6 +55,7 @@
"anthropic",
"cohere",
"deepseek",
"fireworks-ai",
"gemini",
"groq",
"xai",
Expand Down
8 changes: 8 additions & 0 deletions src/config/entities/providers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ pub enum ProviderConfig {
Cohere(configs::CohereProviderConfig),
#[serde(rename = "deepseek")]
DeepSeek(configs::DeepSeekProviderConfig),
#[serde(rename = "fireworks-ai")]
FireworksAi(configs::FireworksAiProviderConfig),
#[serde(rename = "gemini")]
Gemini(configs::GeminiProviderConfig),
#[serde(rename = "groq")]
Expand All @@ -54,6 +56,7 @@ impl ProviderConfig {
Self::Bedrock(_) => identifiers::BEDROCK,
Self::Cohere(_) => identifiers::COHERE,
Self::DeepSeek(_) => identifiers::DEEPSEEK,
Self::FireworksAi(_) => identifiers::FIREWORKS_AI,
Self::Gemini(_) => identifiers::GEMINI,
Self::Groq(_) => identifiers::GROQ,
Self::Xai(_) => identifiers::XAI,
Expand Down Expand Up @@ -157,6 +160,11 @@ mod tests {
"type": "cohere",
"config": { "api_key": "test_key" }
}), true, None)]
#[case::fireworks_ai_ok(json!({
"name": "fireworks-primary",
"type": "fireworks-ai",
"config": { "api_key": "test_key" }
}), true, None)]
#[case::openrouter_ok(json!({
"name": "openrouter-primary",
"type": "openrouter",
Expand Down
171 changes: 171 additions & 0 deletions src/gateway/providers/fireworks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
use http::{HeaderMap, HeaderValue, header::AUTHORIZATION};
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::gateway::{
error::{GatewayError, Result},
provider_instance::ProviderAuth,
traits::{ChatTransform, EmbedTransform, ProviderCapabilities, ProviderMeta},
types::{
embed::{EmbedRequestBody, EmbeddingRequest},
openai::ChatCompletionRequest,
},
};

/// Fireworks AI currently uses its OpenAI-compatible inference API.
/// Docs: https://docs.fireworks.ai/tools-sdks/openai-compatibility
pub const IDENTIFIER: &str = "fireworks-ai";

#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct FireworksAiProviderConfig {
pub api_key: String,

#[serde(skip_serializing_if = "Option::is_none")]
pub api_base: Option<String>,
}

pub struct FireworksAi;

impl ProviderMeta for FireworksAi {
fn name(&self) -> &'static str {
IDENTIFIER
}

fn default_base_url(&self) -> &'static str {
"https://api.fireworks.ai/inference/v1"
}

fn build_auth_headers(&self, auth: &ProviderAuth) -> Result<HeaderMap> {
let mut headers = HeaderMap::new();
let value = HeaderValue::from_str(&format!("Bearer {}", auth.api_key_for(self.name())?))
.map_err(|error| GatewayError::Validation(error.to_string()))?;
headers.insert(AUTHORIZATION, value);
Ok(headers)
}
}

impl ChatTransform for FireworksAi {
fn transform_request(&self, request: &ChatCompletionRequest) -> Result<Value> {
let mut body = serde_json::to_value(request)
.map_err(|error| GatewayError::Transform(error.to_string()))?;

if let Value::Object(map) = &mut body {
// Fireworks defaults to truncating max_tokens on context overflow.
// Set the documented override so requests keep OpenAI-style error semantics.
map.entry("context_length_exceeded_behavior")
.or_insert_with(|| Value::String("error".into()));
}

Ok(body)
}
}

impl EmbedTransform for FireworksAi {
fn transform_embeddings_request(&self, request: &EmbeddingRequest) -> Result<EmbedRequestBody> {
let mut body = serde_json::to_value(request)
.map_err(|error| GatewayError::Transform(error.to_string()))?;

if let Value::Object(map) = &mut body {
// Fireworks documents dimensions, return_logits, normalize, and prompt_template,
// but not OpenAI's encoding_format or user fields on /embeddings.
map.remove("encoding_format");
map.remove("user");
}

Ok(EmbedRequestBody::Json(body))
}
}

impl ProviderCapabilities for FireworksAi {
fn as_embed_transform(&self) -> Option<&dyn EmbedTransform> {
Some(self)
}
}

#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use serde_json::json;

use super::FireworksAi;
use crate::gateway::{
provider_instance::ProviderAuth,
traits::{ChatTransform, EmbedTransform, ProviderCapabilities, ProviderMeta},
types::{
embed::{EmbedRequestBody, EmbeddingRequest},
openai::ChatCompletionRequest,
},
};

#[test]
fn provider_metadata_and_urls_are_correct() {
let provider = FireworksAi;
let headers = provider
.build_auth_headers(&ProviderAuth::ApiKey("fw-key".into()))
.unwrap();

assert_eq!(provider.name(), "fireworks-ai");
assert_eq!(
provider.default_base_url(),
"https://api.fireworks.ai/inference/v1"
);
assert_eq!(headers["authorization"], "Bearer fw-key");
assert_eq!(
provider.build_url(provider.default_base_url(), "ignored"),
"https://api.fireworks.ai/inference/v1/chat/completions"
);
assert!(provider.as_embed_transform().is_some());
}

#[test]
fn transform_request_defaults_to_openai_context_length_behavior() {
let provider = FireworksAi;
let request: ChatCompletionRequest = serde_json::from_value(json!({
"model": "accounts/fireworks/models/kimi-k2-instruct-0905",
"messages": [{"role": "user", "content": "hello"}]
}))
.unwrap();

let transformed = provider.transform_request(&request).unwrap();

assert_eq!(transformed["context_length_exceeded_behavior"], "error");
}

#[test]
fn transform_request_preserves_explicit_context_length_behavior() {
let provider = FireworksAi;
let request: ChatCompletionRequest = serde_json::from_value(json!({
"model": "accounts/fireworks/models/kimi-k2-instruct-0905",
"messages": [{"role": "user", "content": "hello"}],
"context_length_exceeded_behavior": "truncate"
}))
.unwrap();

let transformed = provider.transform_request(&request).unwrap();

assert_eq!(transformed["context_length_exceeded_behavior"], "truncate");
}

#[test]
fn transform_embeddings_request_strips_unsupported_fields() {
let provider = FireworksAi;
let request: EmbeddingRequest = serde_json::from_value(json!({
"model": "fireworks/qwen3-embedding-8b",
"input": ["hello"],
"dimensions": 128,
"encoding_format": "float",
"user": "user-123"
}))
.unwrap();

let body = provider.transform_embeddings_request(&request).unwrap();

match body {
EmbedRequestBody::Json(value) => {
assert_eq!(value["dimensions"], 128);
assert_eq!(value.get("encoding_format"), None);
assert_eq!(value.get("user"), None);
}
}
}
}
14 changes: 10 additions & 4 deletions src/gateway/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod azure;
pub mod bedrock;
pub mod cohere;
pub mod deepseek;
pub mod fireworks;
pub mod gemini;
pub mod groq;
pub mod macros;
Expand All @@ -16,6 +17,7 @@ pub use azure::AzureDef;
pub use bedrock::BedrockDef;
pub use cohere::Cohere;
pub use deepseek::DeepSeek;
pub use fireworks::FireworksAi;
pub use gemini::GoogleDef;
pub use groq::Groq;
pub use mistral::Mistral;
Expand All @@ -25,14 +27,16 @@ pub use xai::Xai;

pub mod identifiers {
use super::{
anthropic, azure, bedrock, cohere, deepseek, gemini, groq, mistral, openai, openrouter, xai,
anthropic, azure, bedrock, cohere, deepseek, fireworks, gemini, groq, mistral, openai,
openrouter, xai,
};

pub const ANTHROPIC: &str = anthropic::IDENTIFIER;
pub const AZURE: &str = azure::IDENTIFIER;
pub const BEDROCK: &str = bedrock::IDENTIFIER;
pub const COHERE: &str = cohere::IDENTIFIER;
pub const DEEPSEEK: &str = deepseek::IDENTIFIER;
pub const FIREWORKS_AI: &str = fireworks::IDENTIFIER;
pub const GEMINI: &str = gemini::IDENTIFIER;
pub const GROQ: &str = groq::IDENTIFIER;
pub const MISTRAL: &str = mistral::IDENTIFIER;
Expand All @@ -45,9 +49,9 @@ pub mod configs {
pub use super::{
anthropic::AnthropicProviderConfig, azure::AzureProviderConfig,
bedrock::BedrockProviderConfig, cohere::CohereProviderConfig,
deepseek::DeepSeekProviderConfig, gemini::GeminiProviderConfig, groq::GroqProviderConfig,
mistral::MistralProviderConfig, openai::OpenAIProviderConfig,
openrouter::OpenRouterProviderConfig, xai::XaiProviderConfig,
deepseek::DeepSeekProviderConfig, fireworks::FireworksAiProviderConfig,
gemini::GeminiProviderConfig, groq::GroqProviderConfig, mistral::MistralProviderConfig,
openai::OpenAIProviderConfig, openrouter::OpenRouterProviderConfig, xai::XaiProviderConfig,
};
}

Expand All @@ -60,6 +64,7 @@ pub fn default_provider_registry() -> Result<ProviderRegistry> {
.register(BedrockDef)?
.register(Cohere)?
.register(DeepSeek)?
.register(FireworksAi)?
.register(GoogleDef)?
.register(Groq)?
.register(Mistral)?
Expand All @@ -84,6 +89,7 @@ mod tests {
assert_eq!(registry.get("anthropic").unwrap().name(), "anthropic");
assert_eq!(registry.get("bedrock").unwrap().name(), "bedrock");
assert_eq!(registry.get("cohere").unwrap().name(), "cohere");
assert_eq!(registry.get("fireworks-ai").unwrap().name(), "fireworks-ai");
assert_eq!(registry.get("gemini").unwrap().name(), "gemini");
assert_eq!(registry.get("groq").unwrap().name(), "groq");
assert_eq!(registry.get("mistral").unwrap().name(), "mistral");
Expand Down
25 changes: 23 additions & 2 deletions src/proxy/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ fn provider_auth_and_base_url(config: &ProviderConfig) -> Result<(ProviderAuth,
ProviderAuth::ApiKey(config.api_key.clone()),
parse_base_url(config.api_base.as_deref())?,
),
ProviderConfig::FireworksAi(config) => (
ProviderAuth::ApiKey(config.api_key.clone()),
parse_base_url(config.api_base.as_deref())?,
),
ProviderConfig::Gemini(config) => (
ProviderAuth::ApiKey(config.api_key.clone()),
parse_base_url(config.api_base.as_deref())?,
Expand Down Expand Up @@ -161,8 +165,9 @@ mod tests {
use crate::{
config::entities::providers::ProviderConfig,
gateway::providers::configs::{
AzureProviderConfig, BedrockProviderConfig, CohereProviderConfig, GroqProviderConfig,
MistralProviderConfig, OpenRouterProviderConfig, XaiProviderConfig,
AzureProviderConfig, BedrockProviderConfig, CohereProviderConfig,
FireworksAiProviderConfig, GroqProviderConfig, MistralProviderConfig,
OpenRouterProviderConfig, XaiProviderConfig,
},
};

Expand Down Expand Up @@ -237,6 +242,22 @@ mod tests {
);
}

#[test]
fn provider_auth_and_base_url_returns_fireworks_api_key_and_optional_base_url() {
let config = ProviderConfig::FireworksAi(FireworksAiProviderConfig {
api_key: "fireworks-key".into(),
api_base: Some("https://api.fireworks.ai/inference/v1".into()),
});

let (auth, base_url_override) = provider_auth_and_base_url(&config).unwrap();

assert_eq!(auth.api_key_for("fireworks-ai").unwrap(), "fireworks-key");
assert_eq!(
base_url_override.as_ref().map(Url::as_str),
Some("https://api.fireworks.ai/inference/v1")
);
}

#[test]
fn provider_auth_and_base_url_returns_groq_api_key_and_optional_base_url() {
let config = ProviderConfig::Groq(GroqProviderConfig {
Expand Down
2 changes: 2 additions & 0 deletions ui/src/i18n/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@
"openai": "OpenAI",
"openrouter": "OpenRouter",
"cohere": "Cohere",
"fireworks-ai": "Fireworks AI",
"groq": "Groq",
"xai": "xAI",
"mistral": "Mistral",
Expand Down Expand Up @@ -223,6 +224,7 @@
"openai": "OpenAI",
"openrouter": "OpenRouter",
"cohere": "Cohere",
"fireworks-ai": "Fireworks AI",
Comment thread
coderabbitai[bot] marked this conversation as resolved.
"groq": "Groq",
"xai": "xAI",
"azure": "Azure OpenAI",
Expand Down
2 changes: 2 additions & 0 deletions ui/src/i18n/locales/zh-CN.json
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@
"openai": "OpenAI",
"openrouter": "OpenRouter",
"cohere": "Cohere",
"fireworks-ai": "Fireworks AI",
"groq": "Groq",
"xai": "xAI",
"mistral": "Mistral",
Expand Down Expand Up @@ -223,6 +224,7 @@
"openai": "OpenAI",
"openrouter": "OpenRouter",
"cohere": "Cohere",
"fireworks-ai": "Fireworks AI",
"groq": "Groq",
"xai": "xAI",
"azure": "Azure OpenAI",
Expand Down
6 changes: 6 additions & 0 deletions ui/src/lib/api/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ export const PROVIDER_TYPE_VARIANTS = [
'openai',
'openrouter',
'cohere',
'fireworks-ai',
'groq',
'xai',
'mistral',
Expand Down Expand Up @@ -108,6 +109,11 @@ export type Provider =
type: 'cohere';
config: ApiBaseProviderConfig;
}
| {
name: string;
type: 'fireworks-ai';
config: ApiBaseProviderConfig;
}
| {
name: string;
type: 'groq';
Expand Down