Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add endpoint enable disable tool #875

Merged
merged 3 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions shinkai-bin/shinkai-node/src/network/handle_commands_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3008,7 +3008,15 @@ impl Node {
let encryption_public_key_clone = self.encryption_public_key.clone();
let identity_public_key_clone = self.identity_public_key.clone();
tokio::spawn(async move {
let _ = Node::v2_api_compute_quests_status(db_clone, node_name_clone, encryption_public_key_clone, identity_public_key_clone, bearer, res).await;
let _ = Node::v2_api_compute_quests_status(
db_clone,
node_name_clone,
encryption_public_key_clone,
identity_public_key_clone,
bearer,
res,
)
.await;
});
}
NodeCommand::V2ApiComputeAndSendQuestsStatus { bearer, res } => {
Expand All @@ -3017,7 +3025,26 @@ impl Node {
let encryption_public_key_clone = self.encryption_public_key.clone();
let identity_public_key_clone = self.identity_public_key.clone();
tokio::spawn(async move {
let _ = Node::v2_api_compute_and_send_quests_status(db_clone, node_name_clone, encryption_public_key_clone, identity_public_key_clone, bearer, res).await;
let _ = Node::v2_api_compute_and_send_quests_status(
db_clone,
node_name_clone,
encryption_public_key_clone,
identity_public_key_clone,
bearer,
res,
)
.await;
});
}
NodeCommand::V2ApiSetToolEnabled {
bearer,
tool_router_key,
enabled,
res,
} => {
let db_clone = Arc::clone(&self.db);
tokio::spawn(async move {
let _ = Node::v2_api_set_tool_enabled(db_clone, bearer, tool_router_key, enabled, res).await;
});
}
_ => (),
Expand Down
107 changes: 69 additions & 38 deletions shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
use crate::{
llm_provider::job_manager::JobManager,
managers::{tool_router, IdentityManager},
network::{node_error::NodeError, node_shareable_logic::download_zip_file, Node},
tools::{
tool_definitions::definition_generation::{generate_tool_definitions, get_all_deno_tools},
tool_execution::execution_coordinator::{execute_code, execute_tool_cmd},
tool_generation::v2_create_and_send_job_message,
tool_prompts::{generate_code_prompt, tool_metadata_implementation_prompt},
},
utils::environment::NodeEnvironment,
llm_provider::job_manager::JobManager, managers::{tool_router, IdentityManager}, network::{node_error::NodeError, node_shareable_logic::download_zip_file, Node}, tools::{
tool_definitions::definition_generation::{generate_tool_definitions, get_all_deno_tools}, tool_execution::execution_coordinator::{execute_code, execute_tool_cmd}, tool_generation::v2_create_and_send_job_message, tool_prompts::{generate_code_prompt, tool_metadata_implementation_prompt}
}, utils::environment::NodeEnvironment
};
use async_channel::Sender;
use chrono::Utc;
Expand All @@ -18,45 +11,26 @@ use serde_json::{json, Map, Value};
use shinkai_http_api::node_api_router::{APIError, SendResponseBodyData};
use shinkai_message_primitives::{
schemas::{
inbox_name::InboxName, indexable_version::IndexableVersion, job::JobLike, job_config::JobConfig,
shinkai_name::ShinkaiSubidentityType, tool_router_key::ToolRouterKey,
},
shinkai_message::shinkai_message_schemas::{CallbackAction, JobCreationInfo, MessageSchemaType},
shinkai_utils::{
job_scope::MinimalJobScope, shinkai_message_builder::ShinkaiMessageBuilder,
signatures::clone_signature_secret_key,
},
inbox_name::InboxName, indexable_version::IndexableVersion, job::JobLike, job_config::JobConfig, shinkai_name::ShinkaiSubidentityType, tool_router_key::ToolRouterKey
}, shinkai_message::shinkai_message_schemas::{CallbackAction, JobCreationInfo, MessageSchemaType}, shinkai_utils::{
job_scope::MinimalJobScope, shinkai_message_builder::ShinkaiMessageBuilder, signatures::clone_signature_secret_key
}
};
use shinkai_message_primitives::{
schemas::{
shinkai_name::ShinkaiName,
shinkai_tools::{CodeLanguage, DynamicToolType},
},
shinkai_message::shinkai_message_schemas::JobMessage,
shinkai_name::ShinkaiName, shinkai_tools::{CodeLanguage, DynamicToolType}
}, shinkai_message::shinkai_message_schemas::JobMessage
};
use shinkai_sqlite::{errors::SqliteManagerError, SqliteManager};
use shinkai_tools_primitives::tools::{
deno_tools::DenoTool,
error::ToolError,
python_tools::PythonTool,
shinkai_tool::{ShinkaiTool, ShinkaiToolWithAssets},
tool_config::{OAuth, ToolConfig},
tool_output_arg::ToolOutputArg,
tool_playground::{ToolPlayground, ToolPlaygroundMetadata},
deno_tools::DenoTool, error::ToolError, python_tools::PythonTool, shinkai_tool::{ShinkaiTool, ShinkaiToolWithAssets}, tool_config::{OAuth, ToolConfig}, tool_output_arg::ToolOutputArg, tool_playground::{ToolPlayground, ToolPlaygroundMetadata}
};
use shinkai_tools_primitives::tools::{
shinkai_tool::ShinkaiToolHeader,
tool_types::{OperatingSystem, RunnerType, ToolResult},
shinkai_tool::ShinkaiToolHeader, tool_types::{OperatingSystem, RunnerType, ToolResult}
};
use std::path::PathBuf;
use std::{
collections::HashMap,
env,
fs::File,
io::{Read, Write},
result,
sync::Arc,
time::Instant,
collections::HashMap, env, fs::File, io::{Read, Write}, result, sync::Arc, time::Instant
};
use tokio::fs;
use tokio::{process::Command, sync::Mutex};
Expand Down Expand Up @@ -3519,6 +3493,63 @@ Happy coding!"#,
}
}
}

pub async fn v2_api_set_tool_enabled(
db: Arc<SqliteManager>,
bearer: String,
tool_router_key: String,
enabled: bool,
res: Sender<Result<Value, APIError>>,
) -> Result<(), NodeError> {
// Validate the bearer token
if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() {
return Ok(());
}

// Get the tool first to verify it exists
match db.get_tool_by_key(&tool_router_key) {
Ok(mut tool) => {
// Update the enabled status using the appropriate method
if enabled {
tool.enable();
} else {
tool.disable();
}

// Save the updated tool
match db.update_tool(tool).await {
Ok(_) => {
let response = json!({
"tool_router_key": tool_router_key,
"enabled": enabled,
"success": true
});
let _ = res.send(Ok(response)).await;
}
Err(e) => {
let _ = res
.send(Err(APIError {
code: 500,
error: "Failed to update tool".to_string(),
message: format!("Failed to update tool: {}", e),
}))
.await;
}
}
}
Err(_) => {
let _ = res
.send(Err(APIError {
code: 404,
error: "Tool not found".to_string(),
message: format!("Tool with key '{}' not found", tool_router_key),
}))
.await;
}
}

Ok(())
}
}

#[cfg(test)]
Expand Down
57 changes: 57 additions & 0 deletions shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,13 @@ pub fn tool_routes(
.and(warp::header::<String>("authorization"))
.and_then(list_all_shinkai_tools_versions_handler);

let set_tool_enabled_route = warp::path("set_tool_enabled")
.and(warp::post())
.and(with_sender(node_commands_sender.clone()))
.and(warp::header::<String>("authorization"))
.and(warp::body::json())
.and_then(set_tool_enabled_handler);

tool_execution_route
.or(code_execution_route)
.or(tool_definitions_route)
Expand Down Expand Up @@ -284,6 +291,7 @@ pub fn tool_routes(
.or(tool_store_proxy_route)
.or(standalone_playground_route)
.or(list_all_shinkai_tools_versions_route)
.or(set_tool_enabled_route)
}

pub fn safe_folder_name(tool_router_key: &str) -> String {
Expand Down Expand Up @@ -2112,6 +2120,53 @@ pub async fn list_all_shinkai_tools_versions_handler(
}
}

#[derive(Deserialize, ToSchema)]
pub struct SetToolEnabledRequest {
pub tool_router_key: String,
pub enabled: bool,
}

#[utoipa::path(
post,
path = "/v2/set_tool_enabled",
request_body = SetToolEnabledRequest,
responses(
(status = 200, description = "Successfully enabled/disabled tool", body = Value),
(status = 400, description = "Bad request", body = APIError),
(status = 500, description = "Internal server error", body = APIError)
)
)]
pub async fn set_tool_enabled_handler(
sender: Sender<NodeCommand>,
authorization: String,
payload: SetToolEnabledRequest,
) -> Result<impl warp::Reply, warp::Rejection> {
let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string();
let (res_sender, res_receiver) = async_channel::bounded(1);

sender
.send(NodeCommand::V2ApiSetToolEnabled {
bearer,
tool_router_key: payload.tool_router_key,
enabled: payload.enabled,
res: res_sender,
})
.await
.map_err(|_| warp::reject::reject())?;

let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?;

match result {
Ok(response) => {
let response = create_success_response(response);
Ok(warp::reply::with_status(warp::reply::json(&response), StatusCode::OK))
}
Err(error) => Ok(warp::reply::with_status(
warp::reply::json(&error),
StatusCode::from_u16(error.code).unwrap(),
)),
}
}

#[derive(OpenApi)]
#[openapi(
Expand Down Expand Up @@ -2144,11 +2199,13 @@ pub async fn list_all_shinkai_tools_versions_handler(
disable_all_tools_handler,
tool_store_proxy_handler,
standalone_playground_handler,
set_tool_enabled_handler,
),
components(
schemas(
APIError,
ToolExecutionRequest,
SetToolEnabledRequest,
)
),
tags(
Expand Down
48 changes: 14 additions & 34 deletions shinkai-libs/shinkai-http-api/src/node_commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,40 +6,16 @@ use ed25519_dalek::VerifyingKey;
use serde_json::{Map, Value};
use shinkai_message_primitives::{
schemas::{
coinbase_mpc_config::CoinbaseMPCWalletConfig,
crontab::{CronTask, CronTaskAction},
custom_prompt::CustomPrompt,
identity::{Identity, StandardIdentity},
job_config::JobConfig,
llm_providers::{agent::Agent, serialized_llm_provider::SerializedLLMProvider},
shinkai_name::ShinkaiName,
shinkai_subscription::ShinkaiSubscription,
shinkai_tool_offering::{ShinkaiToolOffering, UsageTypeInquiry},
shinkai_tools::{CodeLanguage, DynamicToolType},
smart_inbox::{SmartInbox, V2SmartInbox},
tool_router_key::ToolRouterKey,
wallet_complementary::{WalletRole, WalletSource},
wallet_mixed::NetworkIdentifier,
},
shinkai_message::{
shinkai_message::ShinkaiMessage,
shinkai_message_schemas::{
APIAddOllamaModels, APIAvailableSharedItems, APIChangeJobAgentRequest, APIExportSheetPayload,
APIImportSheetPayload, APISetSheetUploadedFilesPayload, APIVecFsCopyFolder, APIVecFsCopyItem,
APIVecFsCreateFolder, APIVecFsDeleteFolder, APIVecFsDeleteItem, APIVecFsMoveFolder, APIVecFsMoveItem,
APIVecFsRetrievePathSimplifiedJson, APIVecFsRetrieveSourceFile, APIVecFsSearchItems,
ExportInboxMessagesFormat, IdentityPermissions, JobCreationInfo, JobMessage, RegistrationCodeType,
V2ChatMessage,
},
},
shinkai_utils::job_scope::MinimalJobScope,
coinbase_mpc_config::CoinbaseMPCWalletConfig, crontab::{CronTask, CronTaskAction}, custom_prompt::CustomPrompt, identity::{Identity, StandardIdentity}, job_config::JobConfig, llm_providers::{agent::Agent, serialized_llm_provider::SerializedLLMProvider}, shinkai_name::ShinkaiName, shinkai_subscription::ShinkaiSubscription, shinkai_tool_offering::{ShinkaiToolOffering, UsageTypeInquiry}, shinkai_tools::{CodeLanguage, DynamicToolType}, smart_inbox::{SmartInbox, V2SmartInbox}, tool_router_key::ToolRouterKey, wallet_complementary::{WalletRole, WalletSource}, wallet_mixed::NetworkIdentifier
}, shinkai_message::{
shinkai_message::ShinkaiMessage, shinkai_message_schemas::{
APIAddOllamaModels, APIAvailableSharedItems, APIChangeJobAgentRequest, APIExportSheetPayload, APIImportSheetPayload, APISetSheetUploadedFilesPayload, APIVecFsCopyFolder, APIVecFsCopyItem, APIVecFsCreateFolder, APIVecFsDeleteFolder, APIVecFsDeleteItem, APIVecFsMoveFolder, APIVecFsMoveItem, APIVecFsRetrievePathSimplifiedJson, APIVecFsRetrieveSourceFile, APIVecFsSearchItems, ExportInboxMessagesFormat, IdentityPermissions, JobCreationInfo, JobMessage, RegistrationCodeType, V2ChatMessage
}
}, shinkai_utils::job_scope::MinimalJobScope
};

use shinkai_tools_primitives::tools::{
shinkai_tool::{ShinkaiTool, ShinkaiToolHeader, ShinkaiToolWithAssets},
tool_config::OAuth,
tool_playground::ToolPlayground,
tool_types::{OperatingSystem, RunnerType},
shinkai_tool::{ShinkaiTool, ShinkaiToolHeader, ShinkaiToolWithAssets}, tool_config::OAuth, tool_playground::ToolPlayground, tool_types::{OperatingSystem, RunnerType}
};
// use crate::{
// prompts::custom_prompt::CustomPrompt, tools::shinkai_tool::{ShinkaiTool, ShinkaiToolHeader}, wallet::{
Expand All @@ -51,9 +27,7 @@ use x25519_dalek::PublicKey as EncryptionPublicKey;
use crate::node_api_router::SendResponseBody;

use super::{
api_v1::api_v1_handlers::APIUseRegistrationCodeSuccessResponse,
api_v2::api_v2_handlers_general::InitialRegistrationRequest,
node_api_router::{APIError, GetPublicKeysResponse, SendResponseBodyData},
api_v1::api_v1_handlers::APIUseRegistrationCodeSuccessResponse, api_v2::api_v2_handlers_general::InitialRegistrationRequest, node_api_router::{APIError, GetPublicKeysResponse, SendResponseBodyData}
};

pub enum NodeCommand {
Expand Down Expand Up @@ -1267,4 +1241,10 @@ pub enum NodeCommand {
bearer: String,
res: Sender<Result<Value, APIError>>,
},
V2ApiSetToolEnabled {
bearer: String,
tool_router_key: String,
enabled: bool,
res: Sender<Result<Value, APIError>>,
},
}