Skip to content

Commit

Permalink
feat: general purpose config (#701)
Browse files Browse the repository at this point in the history
Co-authored-by: Salman Mohammed <[email protected]>
  • Loading branch information
baxen and salman1993 authored Jan 23, 2025
1 parent c53412b commit d230c6e
Show file tree
Hide file tree
Showing 37 changed files with 1,230 additions and 1,104 deletions.
337 changes: 161 additions & 176 deletions crates/goose-cli/src/commands/configure.rs

Large diffs are not rendered by default.

213 changes: 60 additions & 153 deletions crates/goose-cli/src/commands/session.rs
Original file line number Diff line number Diff line change
@@ -1,97 +1,43 @@
use console::style;
use goose::agents::AgentFactory;
use goose::providers::factory;
use rand::{distributions::Alphanumeric, Rng};
use std::path::{Path, PathBuf};
use std::process;

use crate::config::Config;
use crate::prompt::rustyline::RustylinePrompt;
use crate::prompt::Prompt;
use crate::session::{ensure_session_dir, get_most_recent_session, Session};
use goose::agents::extension::ExtensionError;
use mcp_client::transport::Error as McpClientError;

/// Get the provider and model to use, following priority:
/// 1. CLI arguments
/// 2. Environment variables
/// 3. Config file
fn get_provider_and_model(
cli_provider: Option<String>,
cli_model: Option<String>,
config: &Config,
) -> (String, String) {
let provider = cli_provider
.or_else(|| std::env::var("GOOSE_PROVIDER").ok())
.unwrap_or_else(|| config.default_provider.clone());
use goose::agents::AgentFactory;
use goose::config::{Config, ExtensionManager};
use goose::providers::create;

let model = cli_model
.or_else(|| std::env::var("GOOSE_MODEL").ok())
.unwrap_or_else(|| config.default_model.clone());
use mcp_client::transport::Error as McpClientError;

(provider, model)
}
pub async fn build_session(name: Option<String>, resume: bool) -> Session<'static> {
// Load config and get provider/model
let config = Config::global();

pub async fn build_session<'a>(
session: Option<String>,
provider: Option<String>,
model: Option<String>,
agent_version: Option<String>,
resume: bool,
) -> Box<Session<'a>> {
let provider_name: String = config
.get("GOOSE_PROVIDER")
.expect("No provider configured. Run 'goose configure' first");
let session_dir = ensure_session_dir().expect("Failed to create session directory");
let session_file = if resume && session.is_none() {
// When resuming without a specific session name, use the most recent session
get_most_recent_session().expect("Failed to get most recent session")
} else {
session_path(session.clone(), &session_dir, session.is_none() && !resume)
};

// Guard against resuming a non-existent session
if resume && !session_file.exists() {
panic!(
"Cannot resume session: file {} does not exist",
session_file.display()
);
}

// Guard against running a new session with a file that already exists
if !resume && session_file.exists() {
panic!(
"Session file {} already exists. Use --resume to continue an existing session",
session_file.display()
);
let model = config
.get("GOOSE_MODEL")
.expect("No model configured. Run 'goose configure' first");
let model_config = goose::model::ModelConfig::new(model);
let provider = create(&provider_name, model_config).expect("Failed to create provider");

// Create the agent
let agent_version: Option<String> = config.get("GOOSE_AGENT").ok();
let mut agent = match agent_version {
Some(version) => AgentFactory::create(&version, provider),
None => AgentFactory::create(AgentFactory::default_version(), provider),
}
.expect("Failed to create agent");

let config_path = Config::config_path().expect("should identify default config path");

if !config_path.exists() {
println!("No configuration found. Please run 'goose configure' first.");
process::exit(1);
}

let config = Config::load().unwrap_or_else(|_| {
println!("The loaded configuration from {} was invalid", config_path.display());
println!(" please edit the file to make it valid or consider deleting and recreating it via `goose configure`");
process::exit(1);
});

let (provider_name, model_name) = get_provider_and_model(provider, model, &config);
let provider = factory::get_provider(&provider_name).unwrap();

let mut agent = AgentFactory::create(
agent_version
.as_deref()
.unwrap_or(AgentFactory::default_version()),
provider,
)
.unwrap();

// Add configured extensions
for (name, _) in config.extensions.iter() {
if let Some(extension_config) = config.get_extension_config(name) {
// Setup extensions for the agent
for (name, extension) in ExtensionManager::get_all().expect("should load extensions") {
if extension.enabled {
agent
.add_extension(extension_config.clone())
.add_extension(extension.config.clone())
.await
.unwrap_or_else(|e| {
let err = match e {
Expand All @@ -107,82 +53,43 @@ pub async fn build_session<'a>(
}
}

let prompt = match std::env::var("GOOSE_INPUT") {
Ok(val) => match val.as_str() {
"rustyline" => Box::new(RustylinePrompt::new()) as Box<dyn Prompt>,
_ => Box::new(RustylinePrompt::new()) as Box<dyn Prompt>,
},
Err(_) => Box::new(RustylinePrompt::new()),
};

display_session_info(resume, provider_name, model_name, session_file.as_path());
Box::new(Session::new(agent, prompt, session_file))
}

fn session_path(
provided_session_name: Option<String>,
session_dir: &Path,
retry_on_conflict: bool,
) -> PathBuf {
let session_name = provided_session_name.unwrap_or(random_session_name());
let session_file = session_dir.join(format!("{}.jsonl", session_name));

if session_file.exists() && retry_on_conflict {
generate_new_session_path(session_dir)
} else {
session_file
// If resuming, try to find the session
if resume {
if let Some(ref session_name) = name {
// Try to resume specific session
let session_file = session_dir.join(format!("{}.jsonl", session_name));
if session_file.exists() {
let prompt = Box::new(RustylinePrompt::new());
return Session::new(agent, prompt, session_file);
} else {
eprintln!("Session '{}' not found, starting new session", session_name);
}
} else {
// Try to resume most recent session
if let Ok(session_file) = get_most_recent_session() {
let prompt = Box::new(RustylinePrompt::new());
return Session::new(agent, prompt, session_file);
} else {
eprintln!("No previous sessions found, starting new session");
}
}
}
}

fn random_session_name() -> String {
rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(8)
.map(char::from)
.collect::<String>()
.to_lowercase()
}

// For auto-generated names, try up to 5 times to get a unique name
fn generate_new_session_path(session_dir: &Path) -> PathBuf {
let mut attempts = 0;
let max_attempts = 5;

loop {
let generated_name = random_session_name();
let generated_file = session_dir.join(format!("{}.jsonl", generated_name));

if !generated_file.exists() {
break generated_file;
}
// Generate session name if not provided
let name = name.unwrap_or_else(|| {
rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(8)
.map(char::from)
.collect()
});

attempts += 1;
if attempts >= max_attempts {
panic!(
"Failed to generate unique session name after {} attempts",
max_attempts
);
}
let session_file = session_dir.join(format!("{}.jsonl", name));
if session_file.exists() {
eprintln!("Session '{}' already exists", name);
process::exit(1);
}
}

fn display_session_info(resume: bool, provider: String, model: String, session_file: &Path) {
let start_session_msg = if resume {
"resuming session |"
} else {
"starting session |"
};
println!(
"{} {} {} {} {}",
style(start_session_msg).dim(),
style("provider:").dim(),
style(provider).cyan().dim(),
style("model:").dim(),
style(model).cyan().dim(),
);
println!(
" {} {}",
style("logging to").dim(),
style(session_file.display()).dim().cyan(),
);
let prompt = Box::new(RustylinePrompt::new());
Session::new(agent, prompt, session_file)
}
Loading

0 comments on commit d230c6e

Please sign in to comment.