Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ldelalande/model settings #651

Merged
merged 19 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions crates/goose-server/src/routes/agent.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::state::AppState;
use std::env;
use axum::{
extract::State,
http::{HeaderMap, StatusCode},
Expand All @@ -19,6 +20,7 @@ struct VersionsResponse {
struct CreateAgentRequest {
version: Option<String>,
provider: String,
model: Option<String>,
}

#[derive(Serialize)]
Expand Down Expand Up @@ -73,6 +75,13 @@ async fn create_agent(
return Err(StatusCode::UNAUTHORIZED);
}

// Set the environment variable for the model if provided
if let Some(model) = &payload.model {
let env_var_key = format!("{}_MODEL", payload.provider.to_uppercase());
Copy link
Collaborator

@alexhancock alexhancock Jan 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be payload.model.to_uppercase()?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's provider pretty sure ... like OPENAI_MODEL, ANTHROPIC_MODEL, DATABRICKS_MODEL etc

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems to work based on log statements!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I guess I was just confused by the surface level param mapping

We have payload.model and payload.provider but we set {}_MODEL to payload.provider

Not a big deal though!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yeah i see how that is confusing.. it's just to match what we are looking for in the backend when we set up an agent with a specific provider -- eg here's where we set up anthropic with a model or default to sonnet

env::set_var(env_var_key.clone(), model);
println!("Set environment variable: {}={}", env_var_key, model);
}

let provider = factory::get_provider(&payload.provider).expect("Failed to create provider");

let version = payload
Expand Down
5,450 changes: 2,437 additions & 3,013 deletions ui/desktop/package-lock.json

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions ui/desktop/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,14 @@
"express": "^4.21.1",
"framer-motion": "^11.11.11",
"lucide-react": "^0.454.0",
"react": "^18.3.1",
"react-dom": "^18.3.1",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-icons": "^5.3.0",
"react-markdown": "^9.0.1",
"react-router-dom": "^6.28.0",
"react-select": "^5.9.0",
"react-syntax-highlighter": "^15.6.1",
"react-toastify": "^8.0.0",
"tailwind-merge": "^2.5.4",
"tailwindcss-animate": "^1.0.7",
"unist-util-visit": "^5.0.0"
Expand Down
21 changes: 19 additions & 2 deletions ui/desktop/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import LauncherWindow from './LauncherWindow';
import ChatWindow from './ChatWindow';
import ErrorScreen from './components/ErrorScreen';
import 'react-toastify/dist/ReactToastify.css';
import { ToastContainer } from 'react-toastify';
import { ModelProvider} from "./components/settings/models/ModelContext";
import { ActiveKeysProvider } from "./components/settings/api_keys/ActiveKeysContext";

export default function App() {
const [fatalError, setFatalError] = useState<string | null>(null);
Expand All @@ -9,7 +13,7 @@
const isLauncher = searchParams.get('window') === 'launcher';

useEffect(() => {
const handleFatalError = (_: any, errorMessage: string) => {

Check warning on line 16 in ui/desktop/src/App.tsx

View workflow job for this annotation

GitHub Actions / Lint Electron Desktop App

Unexpected any. Specify a different type
setFatalError(errorMessage);
};

Expand All @@ -24,6 +28,19 @@
if (fatalError) {
return <ErrorScreen error={fatalError} onReload={() => window.electron.reloadApp()} />;
}

return isLauncher ? <LauncherWindow /> : <ChatWindow />;

return (
<ModelProvider>
<ActiveKeysProvider>
{isLauncher ? <LauncherWindow /> : <ChatWindow />}
<ToastContainer
aria-label="Toast notifications"
position="top-right"
autoClose={3000}
closeOnClick
pauseOnHover
/>
</ActiveKeysProvider>
</ModelProvider>
);
}
92 changes: 64 additions & 28 deletions ui/desktop/src/ChatWindow.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import React, { useEffect, useRef, useState } from "react";
import { Message, useChat } from "./ai-sdk-fork/useChat";
import { Route, Routes, Navigate } from "react-router-dom";

Check warning on line 3 in ui/desktop/src/ChatWindow.tsx

View workflow job for this annotation

GitHub Actions / Lint Electron Desktop App

'Route' is defined but never used. Allowed unused vars must match /^_/u

Check warning on line 3 in ui/desktop/src/ChatWindow.tsx

View workflow job for this annotation

GitHub Actions / Lint Electron Desktop App

'Routes' is defined but never used. Allowed unused vars must match /^_/u

Check warning on line 3 in ui/desktop/src/ChatWindow.tsx

View workflow job for this annotation

GitHub Actions / Lint Electron Desktop App

'Navigate' is defined but never used. Allowed unused vars must match /^_/u
import { getApiUrl, getSecretKey, extendGoosed, extendGoosedFromUrl } from "./config";

Check warning on line 4 in ui/desktop/src/ChatWindow.tsx

View workflow job for this annotation

GitHub Actions / Lint Electron Desktop App

'extendGoosed' is defined but never used. Allowed unused vars must match /^_/u
import BottomMenu from "./components/BottomMenu";
import FlappyGoose from "./components/FlappyGoose";
import GooseMessage from "./components/GooseMessage";
Expand All @@ -15,12 +15,17 @@
import WingToWing, { Working } from "./components/WingToWing";
import { askAi } from "./utils/askAI";
import {
getStoredModel,
Provider,
} from "./utils/providerUtils";
import { ChatLayout } from "./components/chat_window/ChatLayout"
import { ChatRoutes } from "./components/chat_window/ChatRoutes"
import { WelcomeModal } from "./components/welcome_screen/WelcomeModal"
import { getStoredProvider, initializeSystem } from './utils/providerUtils'
import {useModel} from "./components/settings/models/ModelContext";
import {useRecentModels} from "./components/settings/models/RecentModels";
import {createSelectedModel} from "./components/settings/models/utils";
import {getDefaultModel} from "./components/settings/models/hardcoded_stuff";

declare global {
interface Window {
Expand All @@ -33,10 +38,10 @@
logInfo: (message: string) => void;
showNotification: (opts: { title: string; body: string }) => void;
getBinaryPath: (binary: string) => Promise<string>;
app: any;

Check warning on line 41 in ui/desktop/src/ChatWindow.tsx

View workflow job for this annotation

GitHub Actions / Lint Electron Desktop App

Unexpected any. Specify a different type
};
appConfig: {
get: (key: string) => any;

Check warning on line 44 in ui/desktop/src/ChatWindow.tsx

View workflow job for this annotation

GitHub Actions / Lint Electron Desktop App

Unexpected any. Specify a different type
};
}
}
Expand All @@ -57,7 +62,7 @@
chats,
setChats,
selectedChatId,
setSelectedChatId,

Check warning on line 65 in ui/desktop/src/ChatWindow.tsx

View workflow job for this annotation

GitHub Actions / Lint Electron Desktop App

'setSelectedChatId' is defined but never used. Allowed unused args must match /^_/u
initialQuery,
setProgressMessage,
setWorking,
Expand Down Expand Up @@ -138,7 +143,7 @@
c.id === selectedChatId ? { ...c, messages } : c
);
setChats(updatedChats);
}, [messages, selectedChatId]);

Check warning on line 146 in ui/desktop/src/ChatWindow.tsx

View workflow job for this annotation

GitHub Actions / Lint Electron Desktop App

React Hook useEffect has missing dependencies: 'chats' and 'setChats'. Either include them or remove the dependency array. If 'setChats' changes too often, find the parent component that defines it and wrap that definition in useCallback

const initialQueryAppended = useRef(false);
useEffect(() => {
Expand All @@ -146,7 +151,7 @@
append({ role: "user", content: initialQuery });
initialQueryAppended.current = true;
}
}, [initialQuery]);

Check warning on line 154 in ui/desktop/src/ChatWindow.tsx

View workflow job for this annotation

GitHub Actions / Lint Electron Desktop App

React Hook useEffect has a missing dependency: 'append'. Either include it or remove the dependency array

useEffect(() => {
if (messages.length > 0) {
Expand Down Expand Up @@ -332,6 +337,8 @@
const openNewChatWindow = () => {
window.electron.createChatWindow();
};
const { switchModel, currentModel} = useModel(); // Access switchModel via useModel
const { addRecentModel } = useRecentModels(); // Access addRecentModel from useRecentModels

// Add keyboard shortcut handler
useEffect(() => {
Expand Down Expand Up @@ -365,8 +372,8 @@
const initialQuery = searchParams.get("initialQuery");
const historyParam = searchParams.get("history");
const initialHistory = historyParam
? JSON.parse(decodeURIComponent(historyParam))
: [];
? JSON.parse(decodeURIComponent(historyParam))
: [];

const [chats, setChats] = useState<Chat[]>(() => {
const firstChat = {
Expand All @@ -379,7 +386,7 @@

const [selectedChatId, setSelectedChatId] = useState(1);
const [mode, setMode] = useState<"expanded" | "compact">(
initialQuery ? "compact" : "expanded"
initialQuery ? "compact" : "expanded"
);
const [working, setWorking] = useState<Working>(Working.Idle);
const [progressMessage, setProgressMessage] = useState<string>("");
Expand Down Expand Up @@ -415,7 +422,7 @@
"Content-Type": "application/json",
"X-Secret-Key": getSecretKey(),
},
body: JSON.stringify({ key, value }),
body: JSON.stringify({key, value}),
});

if (!response.ok) {
Expand All @@ -438,10 +445,24 @@
await storeSecret(secretKey, trimmedKey);

// Initialize the system with the selected provider
await initializeSystem(selectedProvider.id);
await initializeSystem(selectedProvider.id, null);

// get the default model
const modelName = getDefaultModel(selectedProvider.id)

// create model object
const model = createSelectedModel(selectedProvider.id, modelName)

// Call the context's switchModel to track the set model state in the front end
switchModel(model);

// Keep track of the recently used models
addRecentModel(model);


// Save provider selection and close modal
localStorage.setItem("GOOSE_PROVIDER", selectedProvider.id);
console.log("set up provider with default model", selectedProvider.id, modelName)
setShowWelcomeModal(false);
} catch (error) {
console.error("Failed to setup provider:", error);
Expand All @@ -454,9 +475,15 @@
const setupStoredProvider = async () => {
const config = window.electron.getConfig();
const storedProvider = getStoredProvider(config);
const storedModel = getStoredModel()
if (storedProvider) {
try {
await initializeSystem(storedProvider);
// Call the context's switchModel to update the model
switchModel(model);

Check failure on line 482 in ui/desktop/src/ChatWindow.tsx

View workflow job for this annotation

GitHub Actions / Lint Electron Desktop App

'model' is not defined

// Keep track of the recently used models
addRecentModel(model);

Check failure on line 485 in ui/desktop/src/ChatWindow.tsx

View workflow job for this annotation

GitHub Actions / Lint Electron Desktop App

'model' is not defined
await initializeSystem(storedProvider, storedModel);
} catch (error) {
console.error("Failed to initialize with stored provider:", error);
}
Expand All @@ -467,27 +494,36 @@
}, []);

return (
<ChatLayout mode={mode}>
<ChatRoutes
chats={chats}
setChats={setChats}
selectedChatId={selectedChatId}
setSelectedChatId={setSelectedChatId}
setProgressMessage={setProgressMessage}
setWorking={setWorking}
/>
<WingToWing
onExpand={toggleMode}
progressMessage={progressMessage}
working={working}
/>
{showWelcomeModal && (
<WelcomeModal
selectedProvider={selectedProvider}
setSelectedProvider={setSelectedProvider}
onSubmit={handleModalSubmit}
/>
)}
</ChatLayout>
<div>
<ChatLayout mode={mode}>
<ChatRoutes
chats={chats}
setChats={setChats}
selectedChatId={selectedChatId}
setSelectedChatId={setSelectedChatId}
setProgressMessage={setProgressMessage}
setWorking={setWorking}
/>
<WingToWing
onExpand={toggleMode}
progressMessage={progressMessage}
working={working}
/>
{showWelcomeModal && (
<WelcomeModal
selectedProvider={selectedProvider}
setSelectedProvider={setSelectedProvider}
onSubmit={handleModalSubmit}
/>
)}
</ChatLayout>
{/*<ToastContainer*/}
{/* aria-label="Notification container"*/}
{/* position="top-right"*/}
{/* autoClose={3000}*/}
{/* closeOnClick*/}
{/* pauseOnHover*/}
{/*/>*/}
</div>
);
}
3 changes: 2 additions & 1 deletion ui/desktop/src/components/chat_window/ChatRoutes.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { Routes, Route, Navigate } from "react-router-dom";
import { ChatContent } from "../../ChatWindow"
import Settings from "../settings/Settings"
import Keys from "../settings/Keys"
import MoreModelsSettings from "../settings/models/MoreModels";

export const ChatRoutes = ({
chats,
Expand All @@ -17,7 +18,6 @@ export const ChatRoutes = ({
path="/chat/:id"
element={
<ChatContent
key={selectedChatId}
chats={chats}
setChats={setChats}
selectedChatId={selectedChatId}
Expand All @@ -29,6 +29,7 @@ export const ChatRoutes = ({
}
/>
<Route path="/settings" element={<Settings />} />
<Route path="/settings/more-models" element={<MoreModelsSettings />} />
<Route path="/keys" element={<Keys />} />
<Route path="*" element={<Navigate to="/chat/1" replace />} />
</Routes>
Expand Down
47 changes: 21 additions & 26 deletions ui/desktop/src/components/settings/Settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ import { Modal, ModalContent, ModalHeader, ModalTitle } from "../ui/modal";
import { Button } from "../ui/button";
import { RevealKeysDialog } from "./modals/RevealKeysDialog";
import { showToast } from "../ui/toast";
import { Back } from "../icons";
import BackButton from "../ui/BackButton";
import {RecentModelsRadio, useRecentModels} from "./models/RecentModels";
import { useHandleModelSelection} from "./models/utils";


const EXTENSIONS_DESCRIPTION =
"The Model Context Protocol (MCP) is a system that allows AI models to securely connect with local or remote resources using standard server setups. It works like a client-server setup and expands AI capabilities using three main components: Prompts, Resources, and Tools.";
Expand Down Expand Up @@ -70,6 +73,8 @@ const DEFAULT_SETTINGS: SettingsType = {

export default function Settings() {
const navigate = useNavigate();
const { recentModels } = useRecentModels(); // Access recent models
const handleModelSelection = useHandleModelSelection();

const [settings, setSettings] = React.useState<SettingsType>(() => {
const saved = localStorage.getItem("user_settings");
Expand All @@ -81,14 +86,12 @@ export default function Settings() {
localStorage.setItem("user_settings", JSON.stringify(settings));
}, [settings]);

const handleModelToggle = (modelId: string) => {
setSettings((prev) => ({
...prev,
models: prev.models.map((model) => ({
...model,
enabled: model.id === modelId,
})),
}));
const handleModelToggle = async (model: Model) => {
try {
await handleModelSelection(model, "Settings"); // Use the provided model selection logic
} catch (error) {
console.error("Failed to switch model:", error);
}
};

const handleExtensionToggle = (extensionId: string) => {
Expand Down Expand Up @@ -183,14 +186,12 @@ export default function Settings() {
{/* Left Navigation */}
<div className="w-48 border-r border-gray-100 dark:border-gray-700 px-2 pt-2">
<div className="sticky top-8">
<button
onClick={handleExit}
className="flex items-center gap-2 text-gray-600 hover:text-gray-800
dark:text-gray-400 dark:hover:text-gray-200 mb-16 mt-4"
>
<Back className="w-4 h-4" />
<span>Back</span>
</button>
<BackButton
onClick={() => {
handleExit();
}}
className="mb-4"
/>
<div className="space-y-2">
{["Models", "Extensions", "Keys"].map((section) => (
<button
Expand All @@ -214,19 +215,13 @@ export default function Settings() {
<div className="flex justify-between items-center mb-4">
<h2 className="text-2xl font-semibold">Models</h2>
<button
onClick={() => setAddModelOpen(true)}
onClick={() => navigate("/settings/more-models")}
className="text-indigo-500 hover:text-indigo-600 font-medium"
>
Add Models
More Models
</button>
</div>
{settings.models.map((model) => (
<ToggleableItem
key={model.id}
{...model}
onToggle={handleModelToggle}
/>
))}
<RecentModelsRadio/>
</section>

{/* Extensions Section */}
Expand Down
46 changes: 46 additions & 0 deletions ui/desktop/src/components/settings/api_keys/ActiveKeysContext.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import React, {createContext, useContext, useState, ReactNode, useEffect} from "react";
import {getActiveProviders} from "./utils";

// Create a context for active keys
const ActiveKeysContext = createContext<{
activeKeys: string[];
setActiveKeys: (keys: string[]) => void;
} | undefined>(undefined);

export const ActiveKeysProvider = ({ children }: { children: ReactNode }) => {
const [activeKeys, setActiveKeys] = useState<string[]>([]); // Start with an empty list
const [isLoading, setIsLoading] = useState(true); // Track loading state

// Fetch active keys from the backend
useEffect(() => {
const fetchActiveProviders = async () => {
try {
const providers = await getActiveProviders(); // Fetch the active providers
console.log("Fetched providers:", providers);
setActiveKeys(providers); // Update state with fetched providers
} catch (error) {
console.error("Error fetching active providers:", error);
} finally {
setIsLoading(false); // Ensure loading is marked as complete
}
};

fetchActiveProviders(); // Call the async function
}, []);

// Provide active keys and ability to update them
return (
<ActiveKeysContext.Provider value={{ activeKeys, setActiveKeys }}>
{!isLoading ? children : <div>Loading...</div>} {/* Conditional rendering */}
</ActiveKeysContext.Provider>
);
};

// Custom hook to access active keys
export const useActiveKeys = () => {
const context = useContext(ActiveKeysContext);
if (!context) {
throw new Error("useActiveKeys must be used within an ActiveKeysProvider");
}
return context;
};
Loading
Loading