Skip to content

Commit

Permalink
Merge pull request #866 from dcSpark/feature/tool-version-fixes
Browse files Browse the repository at this point in the history
version fixes
  • Loading branch information
nicarq authored Feb 12, 2025
2 parents e62ebe7 + 8371779 commit a784ceb
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 13 deletions.
6 changes: 6 additions & 0 deletions shinkai-bin/shinkai-node/src/network/handle_commands_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1898,6 +1898,12 @@ impl Node {
let _ = Node::v2_api_list_all_shinkai_tools(db_clone, bearer, res).await;
});
}
NodeCommand::V2ApiListAllShinkaiToolsVersions { bearer, res } => {
let db_clone = Arc::clone(&self.db);
tokio::spawn(async move {
let _ = Node::v2_api_list_all_shinkai_tools_versions(db_clone, bearer, res).await;
});
}
NodeCommand::V2ApiSetShinkaiTool {
bearer,
tool_key,
Expand Down
113 changes: 108 additions & 5 deletions shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
llm_provider::job_manager::JobManager,
managers::IdentityManager,
managers::{tool_router, IdentityManager},
network::{node_error::NodeError, node_shareable_logic::download_zip_file, Node},
tools::{
tool_definitions::definition_generation::{generate_tool_definitions, get_all_deno_tools},
Expand Down Expand Up @@ -35,7 +35,6 @@ use shinkai_message_primitives::{
shinkai_message::shinkai_message_schemas::JobMessage,
};
use shinkai_sqlite::{errors::SqliteManagerError, SqliteManager};
use shinkai_tools_primitives::tools::tool_types::{OperatingSystem, RunnerType, ToolResult};
use shinkai_tools_primitives::tools::{
deno_tools::DenoTool,
error::ToolError,
Expand All @@ -45,6 +44,10 @@ use shinkai_tools_primitives::tools::{
tool_output_arg::ToolOutputArg,
tool_playground::{ToolPlayground, ToolPlaygroundMetadata},
};
use shinkai_tools_primitives::tools::{
shinkai_tool::ShinkaiToolHeader,
tool_types::{OperatingSystem, RunnerType, ToolResult},
};
use std::path::PathBuf;
use std::{
collections::HashMap,
Expand Down Expand Up @@ -214,8 +217,36 @@ impl Node {
// List all tools
match db.get_all_tool_headers() {
Ok(tools) => {
let response = json!(tools);
let _ = res.send(Ok(response)).await;
// Group tools by their base key (without version)
use std::collections::HashMap;
let mut tool_groups: HashMap<String, Vec<ShinkaiToolHeader>> = HashMap::new();

for tool in tools {
let tool_router_key = tool.tool_router_key.clone();
tool_groups.entry(tool_router_key).or_default().push(tool);
}

// For each group, keep only the tool with the highest version
let mut latest_tools = Vec::new();
for (_, mut group) in tool_groups {
if group.len() == 1 {
latest_tools.push(group.pop().unwrap());
} else {
// Sort by version in descending order
group.sort_by(|a, b| {
let a_version = IndexableVersion::from_string(&a.version.clone())
.unwrap_or(IndexableVersion::from_number(0));
let b_version = IndexableVersion::from_string(&b.version.clone())
.unwrap_or(IndexableVersion::from_number(0));
b_version.cmp(&a_version)
});

// Take the first one (highest version)
latest_tools.push(group.remove(0));
}
}
let t = latest_tools.iter().map(|tool| json!(tool)).collect();
let _ = res.send(Ok(t)).await;
Ok(())
}
Err(err) => {
Expand Down Expand Up @@ -2229,11 +2260,25 @@ impl Node {
// Acquire a write lock on the database
let db_write = db;

let tool_router_key = ToolRouterKey::from_string(&tool_key);
if tool_router_key.is_err() {
let api_error = APIError {
code: StatusCode::BAD_REQUEST.as_u16(),
error: "Bad Request".to_string(),
message: format!("Invalid tool key: {}", tool_router_key.err().unwrap()),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
let tool_router_key = tool_router_key.unwrap();

let version = tool_router_key.version;

// Attempt to remove the playground tool first
let _ = db_write.remove_tool_playground(&tool_key);

// Remove the tool from the database
match db_write.remove_tool(&tool_key, None) {
match db_write.remove_tool(&tool_key, version) {
Ok(_) => {
let response = json!({ "status": "success", "message": "Tool and associated playground (if any) removed successfully" });
let _ = res.send(Ok(response)).await;
Expand Down Expand Up @@ -3416,6 +3461,64 @@ Happy coding!"#,
"files": files_created
}))
}

pub async fn v2_api_list_all_shinkai_tools_versions(
db: Arc<SqliteManager>,
bearer: String,
res: Sender<Result<Value, APIError>>,
) -> Result<(), NodeError> {
// Validate the bearer token
if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() {
return Ok(());
}

// List all tools
match db.get_all_tool_headers() {
Ok(tools) => {
// Group tools by their base key (without version)
use std::collections::HashMap;
let mut tool_groups: HashMap<String, Vec<ShinkaiToolHeader>> = HashMap::new();

for tool in tools {
let tool_router_key = tool.tool_router_key.clone();
tool_groups.entry(tool_router_key).or_default().push(tool);
}

// For each group, sort versions and create the response structure
let mut result = Vec::new();
for (key, mut group) in tool_groups {
// Sort by version in descending order
group.sort_by(|a, b| {
let a_version = IndexableVersion::from_string(&a.version.clone())
.unwrap_or(IndexableVersion::from_number(0));
let b_version = IndexableVersion::from_string(&b.version.clone())
.unwrap_or(IndexableVersion::from_number(0));
b_version.cmp(&a_version)
});

// Extract versions
let versions: Vec<String> = group.iter().map(|tool| tool.version.clone()).collect();

result.push(json!({
"tool_router_key": key,
"versions": versions,
}));
}

let _ = res.send(Ok(json!(result))).await;
Ok(())
}
Err(err) => {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to list tools: {}", err),
};
let _ = res.send(Err(api_error)).await;
Ok(())
}
}
}
}

#[cfg(test)]
Expand Down
45 changes: 45 additions & 0 deletions shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,12 @@ pub fn tool_routes(
.and(warp::body::json())
.and_then(standalone_playground_handler);

let list_all_shinkai_tools_versions_route = warp::path("list_all_shinkai_tools_versions")
.and(warp::get())
.and(with_sender(node_commands_sender.clone()))
.and(warp::header::<String>("authorization"))
.and_then(list_all_shinkai_tools_versions_handler);

tool_execution_route
.or(code_execution_route)
.or(tool_definitions_route)
Expand Down Expand Up @@ -277,6 +283,7 @@ pub fn tool_routes(
.or(disable_all_tools_route)
.or(tool_store_proxy_route)
.or(standalone_playground_route)
.or(list_all_shinkai_tools_versions_route)
}

pub fn safe_folder_name(tool_router_key: &str) -> String {
Expand Down Expand Up @@ -2068,6 +2075,44 @@ pub async fn standalone_playground_handler(
}
}


#[utoipa::path(
get,
path = "/v2/list_all_shinkai_tools_versions",
responses(
(status = 200, description = "Successfully listed all Shinkai tools with versions", body = Value),
(status = 400, description = "Bad request", body = APIError),
(status = 500, description = "Internal server error", body = APIError)
)
)]
pub async fn list_all_shinkai_tools_versions_handler(
sender: Sender<NodeCommand>,
authorization: String,
) -> Result<impl warp::Reply, warp::Rejection> {
let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string();
let (res_sender, res_receiver) = async_channel::bounded(1);
sender
.send(NodeCommand::V2ApiListAllShinkaiToolsVersions {
bearer,
res: res_sender,
})
.await
.map_err(|_| warp::reject::reject())?;
let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?;

match result {
Ok(response) => {
let response = create_success_response(response);
Ok(warp::reply::with_status(warp::reply::json(&response), StatusCode::OK))
}
Err(error) => Ok(warp::reply::with_status(
warp::reply::json(&error),
StatusCode::from_u16(error.code).unwrap(),
)),
}
}


#[derive(OpenApi)]
#[openapi(
paths(
Expand Down
46 changes: 38 additions & 8 deletions shinkai-libs/shinkai-http-api/src/node_commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,40 @@ use ed25519_dalek::VerifyingKey;
use serde_json::{Map, Value};
use shinkai_message_primitives::{
schemas::{
coinbase_mpc_config::CoinbaseMPCWalletConfig, crontab::{CronTask, CronTaskAction}, custom_prompt::CustomPrompt, identity::{Identity, StandardIdentity}, job_config::JobConfig, llm_providers::{agent::Agent, serialized_llm_provider::SerializedLLMProvider}, shinkai_name::ShinkaiName, shinkai_subscription::ShinkaiSubscription, shinkai_tool_offering::{ShinkaiToolOffering, UsageTypeInquiry}, shinkai_tools::{CodeLanguage, DynamicToolType}, smart_inbox::{SmartInbox, V2SmartInbox}, tool_router_key::ToolRouterKey, wallet_complementary::{WalletRole, WalletSource}, wallet_mixed::NetworkIdentifier
}, shinkai_message::{
shinkai_message::ShinkaiMessage, shinkai_message_schemas::{
APIAddOllamaModels, APIAvailableSharedItems, APIChangeJobAgentRequest, APIExportSheetPayload, APIImportSheetPayload, APISetSheetUploadedFilesPayload, APIVecFsCopyFolder, APIVecFsCopyItem, APIVecFsCreateFolder, APIVecFsDeleteFolder, APIVecFsDeleteItem, APIVecFsMoveFolder, APIVecFsMoveItem, APIVecFsRetrievePathSimplifiedJson, APIVecFsRetrieveSourceFile, APIVecFsSearchItems, ExportInboxMessagesFormat, IdentityPermissions, JobCreationInfo, JobMessage, RegistrationCodeType, V2ChatMessage
}
}, shinkai_utils::job_scope::MinimalJobScope
coinbase_mpc_config::CoinbaseMPCWalletConfig,
crontab::{CronTask, CronTaskAction},
custom_prompt::CustomPrompt,
identity::{Identity, StandardIdentity},
job_config::JobConfig,
llm_providers::{agent::Agent, serialized_llm_provider::SerializedLLMProvider},
shinkai_name::ShinkaiName,
shinkai_subscription::ShinkaiSubscription,
shinkai_tool_offering::{ShinkaiToolOffering, UsageTypeInquiry},
shinkai_tools::{CodeLanguage, DynamicToolType},
smart_inbox::{SmartInbox, V2SmartInbox},
tool_router_key::ToolRouterKey,
wallet_complementary::{WalletRole, WalletSource},
wallet_mixed::NetworkIdentifier,
},
shinkai_message::{
shinkai_message::ShinkaiMessage,
shinkai_message_schemas::{
APIAddOllamaModels, APIAvailableSharedItems, APIChangeJobAgentRequest, APIExportSheetPayload,
APIImportSheetPayload, APISetSheetUploadedFilesPayload, APIVecFsCopyFolder, APIVecFsCopyItem,
APIVecFsCreateFolder, APIVecFsDeleteFolder, APIVecFsDeleteItem, APIVecFsMoveFolder, APIVecFsMoveItem,
APIVecFsRetrievePathSimplifiedJson, APIVecFsRetrieveSourceFile, APIVecFsSearchItems,
ExportInboxMessagesFormat, IdentityPermissions, JobCreationInfo, JobMessage, RegistrationCodeType,
V2ChatMessage,
},
},
shinkai_utils::job_scope::MinimalJobScope,
};

use shinkai_tools_primitives::tools::{
shinkai_tool::{ShinkaiTool, ShinkaiToolHeader, ShinkaiToolWithAssets}, tool_config::OAuth, tool_playground::ToolPlayground, tool_types::{OperatingSystem, RunnerType}
shinkai_tool::{ShinkaiTool, ShinkaiToolHeader, ShinkaiToolWithAssets},
tool_config::OAuth,
tool_playground::ToolPlayground,
tool_types::{OperatingSystem, RunnerType},
};
// use crate::{
// prompts::custom_prompt::CustomPrompt, tools::shinkai_tool::{ShinkaiTool, ShinkaiToolHeader}, wallet::{
Expand All @@ -27,7 +51,9 @@ use x25519_dalek::PublicKey as EncryptionPublicKey;
use crate::node_api_router::SendResponseBody;

use super::{
api_v1::api_v1_handlers::APIUseRegistrationCodeSuccessResponse, api_v2::api_v2_handlers_general::InitialRegistrationRequest, node_api_router::{APIError, GetPublicKeysResponse, SendResponseBodyData}
api_v1::api_v1_handlers::APIUseRegistrationCodeSuccessResponse,
api_v2::api_v2_handlers_general::InitialRegistrationRequest,
node_api_router::{APIError, GetPublicKeysResponse, SendResponseBodyData},
};

pub enum NodeCommand {
Expand Down Expand Up @@ -673,6 +699,10 @@ pub enum NodeCommand {
bearer: String,
res: Sender<Result<Value, APIError>>,
},
V2ApiListAllShinkaiToolsVersions {
bearer: String,
res: Sender<Result<Value, APIError>>,
},
V2ApiSetShinkaiTool {
bearer: String,
tool_key: String,
Expand Down

0 comments on commit a784ceb

Please sign in to comment.