Skip to content

Commit

Permalink
Merge pull request #827 from dcSpark/feature/duplicate-tool
Browse files Browse the repository at this point in the history
Duplicate Tool
  • Loading branch information
nicarq authored Feb 8, 2025
2 parents 629a88a + 4280c03 commit c6905ca
Show file tree
Hide file tree
Showing 19 changed files with 1,497 additions and 241 deletions.
39 changes: 32 additions & 7 deletions shinkai-bin/shinkai-node/src/managers/tool_router.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::env;
use std::sync::Arc;
use std::time::Instant;
use std::{env, fs};

use crate::llm_provider::error::LLMProviderError;
use crate::llm_provider::execution::chains::generic_chain::generic_inference_chain::GenericInferenceChain;
Expand All @@ -20,8 +20,9 @@ use shinkai_message_primitives::schemas::invoices::{Invoice, InvoiceStatusEnum};
use shinkai_message_primitives::schemas::job::JobLike;
use shinkai_message_primitives::schemas::llm_providers::common_agent_llm_provider::ProviderOrAgent;
use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName;
use shinkai_message_primitives::schemas::shinkai_preferences::ShinkaiInternalComms;
use shinkai_message_primitives::schemas::shinkai_tool_offering::{
AssetPayment, ToolPrice, UsageType, UsageTypeInquiry,
AssetPayment, ToolPrice, UsageType, UsageTypeInquiry
};
use shinkai_message_primitives::schemas::shinkai_tools::CodeLanguage;
use shinkai_message_primitives::schemas::wallet_mixed::{Asset, NetworkIdentifier};
Expand Down Expand Up @@ -95,16 +96,16 @@ impl ToolRouter {
.map_err(|e| ToolError::DatabaseError(e.to_string()))?;
}

if let Err(e) = self.add_rust_tools().await {
eprintln!("Error adding rust tools: {}", e);
}

if let Err(e) =
Self::import_tools_from_directory(self.sqlite_manager.clone(), self.signing_secret_key.clone()).await
{
eprintln!("Error importing tools from directory: {}", e);
}

if let Err(e) = self.add_rust_tools().await {
eprintln!("Error adding rust tools: {}", e);
}

if is_empty {
if let Err(e) = self.add_static_prompts(&generator).await {
eprintln!("Error adding static prompts: {}", e);
Expand Down Expand Up @@ -162,6 +163,18 @@ impl ToolRouter {
return Ok(());
}

// Set the sync status to false at the start
let internal_comms = ShinkaiInternalComms {
internal_has_sync_default_tools: false,
};
if let Err(e) = db.set_preference(
"internal_comms",
&internal_comms,
Some("Internal communication preferences"),
) {
eprintln!("Error setting internal_comms preference: {}", e);
}

let start_time = Instant::now();
let node_env = fetch_node_environment();

Expand Down Expand Up @@ -324,6 +337,18 @@ impl ToolRouter {
let duration = start_time.elapsed();
println!("Total time taken to import/upgrade tools: {:?}", duration);

// Set the sync status to true after successful completion
let internal_comms = ShinkaiInternalComms {
internal_has_sync_default_tools: true,
};
if let Err(e) = db.set_preference(
"internal_comms",
&internal_comms,
Some("Internal communication preferences"),
) {
eprintln!("Error setting internal_comms preference: {}", e);
}

Ok(())
}

Expand Down Expand Up @@ -1209,7 +1234,7 @@ impl ToolRouter {
}

// Check if the top vector search result has a score under 0.2
if let Some((tool, score)) = vector_tools.first() {
if let Some((tool, _score)) = vector_tools.first() {
if seen_ids.insert(tool.tool_router_key.clone()) {
combined_tools.push(tool.clone());
}
Expand Down
35 changes: 35 additions & 0 deletions shinkai-bin/shinkai-node/src/network/handle_commands_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2889,6 +2889,35 @@ impl Node {
let _ = Node::v2_api_disable_all_tools(db_clone, bearer, res).await;
});
}
NodeCommand::V2ApiDuplicateTool {
bearer,
tool_key_path,
res,
} => {
let db_clone = Arc::clone(&self.db);
let node_name_clone = self.node_name.clone();
let identity_manager = self.identity_manager.clone();
let job_manager = self.job_manager.clone();
let encryption_secret_key = self.encryption_secret_key.clone();
let encryption_public_key = self.encryption_public_key.clone();
let signing_secret_key = self.identity_secret_key.clone();

tokio::spawn(async move {
let _ = Node::v2_api_duplicate_tool(
db_clone,
bearer,
tool_key_path,
node_name_clone,
identity_manager,
job_manager,
encryption_secret_key,
encryption_public_key,
signing_secret_key,
res,
)
.await;
});
}
NodeCommand::V2ApiAddRegexPattern {
bearer,
provider_name,
Expand Down Expand Up @@ -2961,6 +2990,12 @@ impl Node {
.await;
});
}
NodeCommand::V2ApiCheckDefaultToolsSync { bearer, res } => {
let db_clone = Arc::clone(&self.db);
tokio::spawn(async move {
let _ = Node::v2_api_check_default_tools_sync(db_clone, bearer, res).await;
});
}
NodeCommand::V2ApiComputeQuestsStatus { bearer, res } => {
let db_clone = Arc::clone(&self.db);
let node_name_clone = self.node_name.clone();
Expand Down
27 changes: 27 additions & 0 deletions shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ use crate::{
llm_provider::{job_manager::JobManager, llm_stopper::LLMStopper}, managers::{identity_manager::IdentityManagerTrait, IdentityManager}, network::{node_error::NodeError, node_shareable_logic::download_zip_file, Node}, tools::tool_generation, utils::update_global_identity::update_global_identity_name
};

use shinkai_message_primitives::schemas::shinkai_preferences::ShinkaiInternalComms;
use std::time::Instant;
use tokio::time::Duration;
use x25519_dalek::StaticSecret as EncryptionStaticKey;
Expand Down Expand Up @@ -1712,6 +1713,32 @@ impl Node {
Ok(())
}

pub async fn v2_api_check_default_tools_sync(
db: Arc<SqliteManager>,
bearer: String,
res: Sender<Result<bool, APIError>>,
) -> Result<(), NodeError> {
// Validate bearer token
if let Err(_) = Self::validate_bearer_token(&bearer, db.clone(), &res).await {
return Ok(());
}

// Get the internal_comms preference from the database
match db.get_preference::<ShinkaiInternalComms>("internal_comms") {
Ok(Some(internal_comms)) => {
let _ = res.send(Ok(internal_comms.internal_has_sync_default_tools)).await;
}
Ok(None) => {
let _ = res.send(Ok(false)).await;
}
Err(e) => {
eprintln!("Error getting internal_comms preference: {}", e);
let _ = res.send(Ok(false)).await;
}
}
Ok(())
}

pub async fn v2_api_compute_quests_status(
db: Arc<SqliteManager>,
node_name: ShinkaiName,
Expand Down
Loading

0 comments on commit c6905ca

Please sign in to comment.