Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support mcp setup on CLI #636

Merged
merged 3 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
392 changes: 341 additions & 51 deletions crates/goose-cli/src/commands/configure.rs

Large diffs are not rendered by default.

107 changes: 46 additions & 61 deletions crates/goose-cli/src/commands/session.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,39 @@
use console::style;
use goose::agents::AgentFactory;
use goose::agents::SystemConfig;
use goose::providers::factory;
use rand::{distributions::Alphanumeric, Rng};
use std::path::{Path, PathBuf};
use std::process;

use crate::profile::{load_profiles, set_provider_env_vars, Profile};
use crate::config::Config;
use crate::prompt::rustyline::RustylinePrompt;
use crate::prompt::Prompt;
use crate::session::{ensure_session_dir, get_most_recent_session, Session};

/// Get the provider and model to use, following priority:
/// 1. CLI arguments
/// 2. Environment variables
/// 3. Config file
fn get_provider_and_model(
cli_provider: Option<String>,
cli_model: Option<String>,
config: &Config,
) -> (String, String) {
let provider = cli_provider
.or_else(|| std::env::var("GOOSE_PROVIDER").ok())
.unwrap_or_else(|| config.default_provider.clone());

let model = cli_model
.or_else(|| std::env::var("GOOSE_MODEL").ok())
.unwrap_or_else(|| config.default_model.clone());

(provider, model)
}

pub async fn build_session<'a>(
session: Option<String>,
profile: Option<String>,
provider: Option<String>,
model: Option<String>,
agent_version: Option<String>,
resume: bool,
) -> Box<Session<'a>> {
Expand Down Expand Up @@ -41,34 +61,34 @@ pub async fn build_session<'a>(
);
}

let loaded_profile = load_profile(profile);
let config_path = Config::config_path().expect("should identify default config path");

if !config_path.exists() {
println!("No configuration found. Please run 'goose configure' first.");
process::exit(1);
}

// Set environment variables for provider configuration
set_provider_env_vars(&loaded_profile.provider, &loaded_profile);
let config = Config::load().unwrap_or_else(|_| {
println!("The loaded configuration from {} was invalid", config_path.display());
println!(" please edit the file to make it valid or consider deleting and recreating it via `goose configure`");
process::exit(1);
});

let provider = factory::get_provider(&loaded_profile.provider).unwrap();
let (provider_name, model_name) = get_provider_and_model(provider, model, &config);
let provider = factory::get_provider(&provider_name).unwrap();

let mut agent =
AgentFactory::create(agent_version.as_deref().unwrap_or("default"), provider).unwrap();

// We now add systems to the session based on configuration
// TODO update the profile system tracking
// TODO use systems from the profile
// TODO once the client/server for MCP has stabilized, we should probably add InProcess transport to each
// and avoid spawning here. But it is at least included in the CLI for portability

let system = std::env::var("GOOSE_SYSTEM").unwrap_or("developer2".to_string());
let config = SystemConfig::stdio(
std::env::current_exe()
.expect("should find the current executable")
.to_str()
.expect("should resolve executable to string path"),
)
.with_args(vec!["mcp", &system]);
agent
.add_system(config)
.await
.expect("should start developer server");
// Add configured systems
for (name, _) in config.systems.iter() {
if let Some(system_config) = config.get_system_config(name) {
agent
.add_system(system_config.clone())
.await
.expect(&format!("Failed to start system: {}", name));
}
}

let prompt = match std::env::var("GOOSE_INPUT") {
Ok(val) => match val.as_str() {
Expand All @@ -78,12 +98,7 @@ pub async fn build_session<'a>(
Err(_) => Box::new(RustylinePrompt::new()),
};

display_session_info(
resume,
loaded_profile.provider,
loaded_profile.model,
session_file.as_path(),
);
display_session_info(resume, provider_name, model_name, session_file.as_path());
Box::new(Session::new(agent, prompt, session_file))
}

Expand Down Expand Up @@ -134,36 +149,6 @@ fn generate_new_session_path(session_dir: &Path) -> PathBuf {
}
}

fn load_profile(profile_name: Option<String>) -> Box<Profile> {
let configure_profile_message = "Please create a profile first via goose configure.";
let profiles = load_profiles().unwrap();
let loaded_profile = if profiles.is_empty() {
println!("No profiles found. {}", configure_profile_message);
process::exit(1);
} else {
match profile_name {
Some(name) => match profiles.get(name.as_str()) {
Some(profile) => Box::new(profile.clone()),
None => {
println!(
"Profile '{}' not found. {}",
name, configure_profile_message
);
process::exit(1);
}
},
None => match profiles.get("default") {
Some(profile) => Box::new(profile.clone()),
None => {
println!("No 'default' profile found. {}", configure_profile_message);
process::exit(1);
}
},
}
};
loaded_profile
}

fn display_session_info(resume: bool, provider: String, model: String, session_file: &Path) {
let start_session_msg = if resume {
"resuming session |"
Expand Down
203 changes: 203 additions & 0 deletions crates/goose-cli/src/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;

use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};

use goose::agents::SystemConfig;

const DEFAULT_SYSTEM: &str = "developer2";

/// Core configuration for Goose CLI
#[derive(Debug, Deserialize, Serialize)]
pub struct Config {
pub default_provider: String,
pub default_model: String,
pub systems: HashMap<String, SystemEntry>,
}

/// A system configuration entry with an enabled flag and configuration
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct SystemEntry {
pub enabled: bool,
#[serde(flatten)]
pub config: SystemConfig,
}

impl Config {
/// Get the configuration file path
pub fn config_path() -> Result<PathBuf> {
let home_dir = dirs::home_dir().context("Could not determine home directory")?;
let config_dir = home_dir.join(".config").join("goose");
if !config_dir.exists() {
fs::create_dir_all(&config_dir)?;
}
Ok(config_dir.join("config.yaml"))
}

/// Load the configuration from disk
pub fn load() -> Result<Self> {
let path = Self::config_path()?;
if !path.exists() {
return Err(anyhow::anyhow!("Config has not yet been created"));
}
let content = fs::read_to_string(path)?;
Ok(serde_yaml::from_str(&content)?)
}

/// Save the configuration to disk
pub fn save(&self) -> Result<()> {
let path = Self::config_path()?;
let content = serde_yaml::to_string(self)?;
fs::write(path, content)?;
Ok(())
}

/// Get the system configuration if enabled
pub fn get_system_config(&self, name: &str) -> Option<SystemConfig> {
let entry = self.systems.get(name)?;
if entry.enabled {
Some(entry.config.clone())
} else {
None
}
}
}

impl Default for Config {
fn default() -> Self {
Self {
default_provider: "".to_string(),
default_model: "".to_string(),
systems: HashMap::from([(
DEFAULT_SYSTEM.to_string(),
SystemEntry {
enabled: true,
config: SystemConfig::Builtin {
name: DEFAULT_SYSTEM.to_string(),
},
},
)]),
}
}
}

#[cfg(test)]
mod tests {
use super::*;

/// This test provides a comprehensive example of all possible configuration options
/// and validates that they are correctly parsed
#[test]
fn test_comprehensive_config() {
let yaml = r#"
# Core settings for the default provider and model
default_provider: openai
default_model: gpt-4

# System configurations showing all possible variants
systems:
# Built-in system that just needs to be enabled
developer:
enabled: true
type: builtin
name: developer

# Built-in system that is disabled
unused:
enabled: false
type: builtin
name: unused

# Full stdio system configuration with all options
python:
enabled: true
type: stdio
cmd: python3
args:
- "-m"
- "goose.systems.python"
envs:
PYTHONPATH: /path/to/python
DEBUG: "true"

# Full SSE system configuration
remote:
enabled: true
type: sse
uri: http://localhost:8000/events
envs:
API_KEY: secret
DEBUG: "true"

# Disabled full system configuration
disabled_system:
enabled: false
type: stdio
cmd: test
args: []
envs: {}
"#;
let config: Config = serde_yaml::from_str(yaml).unwrap();

// Check core settings
assert_eq!(config.default_provider, "openai");
assert_eq!(config.default_model, "gpt-4");

// Check builtin enabled system
match &config.systems.get("developer").unwrap().config {
SystemConfig::Builtin { name } => assert_eq!(name, "developer"),
_ => panic!("Expected builtin system config"),
}
assert!(config.systems.get("developer").unwrap().enabled);

// Check builtin disabled system
match &config.systems.get("unused").unwrap().config {
SystemConfig::Builtin { name } => assert_eq!(name, "unused"),
_ => panic!("Expected builtin system config"),
}
assert!(!config.systems.get("unused").unwrap().enabled);

// Check full stdio system
let python = config.systems.get("python").unwrap();
assert!(python.enabled);
match &python.config {
SystemConfig::Stdio { cmd, args, envs } => {
assert_eq!(cmd, "python3");
assert_eq!(
args,
&vec!["-m".to_string(), "goose.systems.python".to_string()]
);
let env = envs.get_env();
assert_eq!(env.get("PYTHONPATH").unwrap(), "/path/to/python");
assert_eq!(env.get("DEBUG").unwrap(), "true");
}
_ => panic!("Expected stdio system config"),
}

// Check full SSE system
let remote = config.systems.get("remote").unwrap();
assert!(remote.enabled);
match &remote.config {
SystemConfig::Sse { uri, envs } => {
assert_eq!(uri, "http://localhost:8000/events");
let env = envs.get_env();
assert_eq!(env.get("API_KEY").unwrap(), "secret");
assert_eq!(env.get("DEBUG").unwrap(), "true");
}
_ => panic!("Expected sse system config"),
}

// Check disabled full system
assert!(!config.systems.get("disabled_system").unwrap().enabled);

// Test the get_system_config helper
assert!(config.get_system_config("developer").is_some());
assert!(config.get_system_config("unused").is_none());
assert!(config.get_system_config("python").is_some());
assert!(config.get_system_config("remote").is_some());
assert!(config.get_system_config("disabled_system").is_none());
assert!(config.get_system_config("nonexistent").is_none());
}
}
Loading
Loading