Skip to content

Commit

Permalink
feat: provider settings alpha version (#625)
Browse files Browse the repository at this point in the history
  • Loading branch information
lily-de authored and acekyd committed Jan 21, 2025
1 parent 5f047c9 commit e3d82df
Show file tree
Hide file tree
Showing 23 changed files with 1,379 additions and 295 deletions.
1 change: 1 addition & 0 deletions crates/goose-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ http = "1.0"
config = { version = "0.14.1", features = ["toml"] }
thiserror = "1.0"
clap = { version = "4.4", features = ["derive"] }
once_cell = "1.18"

[[bin]]
name = "goosed"
Expand Down
47 changes: 47 additions & 0 deletions crates/goose-server/src/routes/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use axum::{
};
use goose::{agents::AgentFactory, providers::factory};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Serialize)]
struct VersionsResponse {
Expand All @@ -25,6 +26,28 @@ struct CreateAgentResponse {
version: String,
}

#[derive(Deserialize)]
struct ProviderFile {
name: String,
description: String,
models: Vec<String>,
required_keys: Vec<String>,
}

#[derive(Serialize)]
struct ProviderDetails {
name: String,
description: String,
models: Vec<String>,
required_keys: Vec<String>,
}

#[derive(Serialize)]
struct ProviderList {
id: String,
details: ProviderDetails,
}

async fn get_versions() -> Json<VersionsResponse> {
let versions = AgentFactory::available_versions();
let default_version = AgentFactory::default_version().to_string();
Expand Down Expand Up @@ -64,9 +87,33 @@ async fn create_agent(
Ok(Json(CreateAgentResponse { version }))
}

async fn list_providers() -> Json<Vec<ProviderList>> {
let contents = include_str!("providers_and_keys.json");

let providers: HashMap<String, ProviderFile> =
serde_json::from_str(contents).expect("Failed to parse providers_and_keys.json");

let response: Vec<ProviderList> = providers
.into_iter()
.map(|(id, provider)| ProviderList {
id,
details: ProviderDetails {
name: provider.name,
description: provider.description,
models: provider.models,
required_keys: provider.required_keys,
},
})
.collect();

// Return the response as JSON.
Json(response)
}

pub fn routes(state: AppState) -> Router {
Router::new()
.route("/agent/versions", get(get_versions))
.route("/agent/providers", get(list_providers))
.route("/agent", post(create_agent))
.with_state(state)
}
38 changes: 38 additions & 0 deletions crates/goose-server/src/routes/providers_and_keys.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{
"openai": {
"name": "OpenAI",
"description": "Use GPT-4 and other OpenAI models",
"models": ["gpt-4o", "gpt-4-turbo","o1"],
"required_keys": ["OPENAI_API_KEY"]
},
"anthropic": {
"name": "Anthropic",
"description": "Use Claude and other Anthropic models",
"models": ["claude-3.5-sonnet-2"],
"required_keys": ["ANTHROPIC_API_KEY"]
},
"databricks": {
"name": "Databricks",
"description": "Connect to LLMs via Databricks",
"models": ["claude-3-5-sonnet-2"],
"required_keys": ["DATABRICKS_HOST"]
},
"google": {
"name": "Google",
"description": "Lorem ipsum",
"models": ["gemini-1.5-flash"],
"required_keys": ["GOOGLE_API_KEY"]
},
"grok": {
"name": "Grok",
"description": "Lorem ipsum",
"models": ["llama-3.3-70b-versatile"],
"required_keys": ["GROK_API_KEY"]
},
"ollama": {
"name": "Ollama",
"description": "Lorem ipsum",
"models": ["qwen2.5"],
"required_keys": []
}
}
153 changes: 151 additions & 2 deletions crates/goose-server/src/routes/secrets.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use crate::state::AppState;
use axum::{extract::State, routing::post, Json, Router};
use goose::key_manager::save_to_keyring;
use axum::{extract::State, routing::delete, routing::post, Json, Router};
use goose::key_manager::{
delete_from_keyring, get_keyring_secret, save_to_keyring, KeyRetrievalStrategy,
};
use http::{HeaderMap, StatusCode};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Serialize)]
struct SecretResponse {
Expand Down Expand Up @@ -36,8 +40,153 @@ async fn store_secret(
}
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ProviderSecretRequest {
pub providers: Vec<String>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct SecretStatus {
pub is_set: bool,
pub location: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct ProviderResponse {
pub supported: bool,
pub name: Option<String>,
pub description: Option<String>,
pub models: Option<Vec<String>>,
pub secret_status: HashMap<String, SecretStatus>,
}

#[derive(Debug, Serialize, Deserialize)]
struct ProviderConfig {
name: String,
description: String,
models: Vec<String>,
required_keys: Vec<String>,
}

static PROVIDER_ENV_REQUIREMENTS: Lazy<HashMap<String, ProviderConfig>> = Lazy::new(|| {
let contents = include_str!("providers_and_keys.json");
serde_json::from_str(contents).expect("Failed to parse providers_and_keys.json")
});

fn check_key_status(key: &str) -> (bool, Option<String>) {
if let Ok(_value) = std::env::var(key) {
(true, Some("env".to_string()))
} else if let Ok(_) = get_keyring_secret(key, KeyRetrievalStrategy::KeyringOnly) {
(true, Some("keyring".to_string()))
} else {
(false, None)
}
}

async fn check_provider_secrets(
Json(request): Json<ProviderSecretRequest>,
) -> Result<Json<HashMap<String, ProviderResponse>>, StatusCode> {
let mut response = HashMap::new();

for provider_name in request.providers {
if let Some(provider_config) = PROVIDER_ENV_REQUIREMENTS.get(&provider_name) {
let mut secret_status = HashMap::new();

for key in &provider_config.required_keys {
let (key_set, key_location) = check_key_status(key);
secret_status.insert(
key.to_string(),
SecretStatus {
is_set: key_set,
location: key_location,
},
);
}

response.insert(
provider_name,
ProviderResponse {
supported: true,
name: Some(provider_config.name.clone()),
description: Some(provider_config.description.clone()),
models: Some(provider_config.models.clone()),
secret_status,
},
);
} else {
response.insert(
provider_name,
ProviderResponse {
supported: false,
name: None,
description: None,
models: None,
secret_status: HashMap::new(),
},
);
}
}

Ok(Json(response))
}

#[derive(Deserialize)]
struct DeleteSecretRequest {
key: String,
}

async fn delete_secret(
State(state): State<AppState>,
headers: HeaderMap,
Json(request): Json<DeleteSecretRequest>,
) -> Result<StatusCode, StatusCode> {
// Verify secret key
let secret_key = headers
.get("X-Secret-Key")
.and_then(|value| value.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;

if secret_key != state.secret_key {
return Err(StatusCode::UNAUTHORIZED);
}

// Attempt to delete the key
match delete_from_keyring(&request.key) {
Ok(_) => Ok(StatusCode::NO_CONTENT),
Err(_) => Err(StatusCode::NOT_FOUND),
}
}

pub fn routes(state: AppState) -> Router {
Router::new()
.route("/secrets/providers", post(check_provider_secrets))
.route("/secrets/store", post(store_secret))
.route("/secrets/delete", delete(delete_secret))
.with_state(state)
}

#[cfg(test)]
mod tests {
use super::*;

#[tokio::test]
async fn test_unsupported_provider() {
// Setup
let request = ProviderSecretRequest {
providers: vec!["unsupported_provider".to_string()],
};

// Execute
let result = check_provider_secrets(Json(request)).await;

// Assert
assert!(result.is_ok());
let Json(response) = result.unwrap();

let provider_status = response
.get("unsupported_provider")
.expect("Provider should exist");
assert!(!provider_status.supported);
assert!(provider_status.secret_status.is_empty());
}
}
26 changes: 26 additions & 0 deletions crates/goose/src/key_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ pub fn save_to_keyring(key_name: &str, key_val: &str) -> std::result::Result<(),
kr.set_password(key_val).map_err(KeyManagerError::from)
}

pub fn delete_from_keyring(key_name: &str) -> std::result::Result<(), KeyManagerError> {
let kr = Entry::new("goose", key_name)?;
kr.delete_credential().map_err(KeyManagerError::from)
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -100,6 +105,27 @@ mod tests {
kr.delete_credential().map_err(KeyManagerError::from)
}

#[test]
fn test_delete_from_keyring() {
let key_name = format!("{}{}", TEST_ENV_PREFIX, "DELETE_KEY");

// Save a value to the keyring
save_to_keyring(&key_name, "test_value").unwrap();

// Verify it exists
let kr = Entry::new("goose", &key_name).unwrap();
assert_eq!(kr.get_password().unwrap(), "test_value");

// Delete the keyring entry
let result = delete_from_keyring(&key_name);
assert!(result.is_ok());

// Verify deletion
let kr = Entry::new("goose", &key_name).unwrap();
let password_result = kr.get_password();
assert!(password_result.is_err());
}

#[test]
fn test_get_key_environment_only() {
let key_name = format!("{}{}", TEST_ENV_PREFIX, "ENV_KEY");
Expand Down
Loading

0 comments on commit e3d82df

Please sign in to comment.