Skip to content

Commit

Permalink
Merge pull request #872 from dcSpark/nico/allow_retry_on_first_message
Browse files Browse the repository at this point in the history
allow retry on first message
  • Loading branch information
nicarq authored Feb 15, 2025
2 parents 1341202 + 9d51343 commit 8361ff2
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 27 deletions.
2 changes: 1 addition & 1 deletion scripts/run_agent_provider.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ export INITIAL_AGENT_URLS="https://api.openai.com"
export INITIAL_AGENT_MODELS="openai:gpt-4o-mini"

export CONTRACT_ADDRESS="0x425fb20ba3874e887336aaa7f3fab32d08135ba9"
export ADD_TESTING_NETWORK_ECHO="true"
export ADD_TESTING_NETWORK_ECHO="false"

# Add these lines to enable all log options
export LOG_ALL=1
Expand Down
2 changes: 1 addition & 1 deletion shinkai-bin/shinkai-node/src/managers/tool_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use shinkai_message_primitives::schemas::llm_providers::common_agent_llm_provide
use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName;
use shinkai_message_primitives::schemas::shinkai_preferences::ShinkaiInternalComms;
use shinkai_message_primitives::schemas::shinkai_tool_offering::{
AssetPayment, ToolPrice, UsageType, UsageTypeInquiry,
AssetPayment, ToolPrice, UsageType, UsageTypeInquiry
};
use shinkai_message_primitives::schemas::shinkai_tools::CodeLanguage;
use shinkai_message_primitives::schemas::wallet_mixed::{Asset, NetworkIdentifier};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,27 @@ impl Node {
None
};

// Check if the parent message is the first message in the job
let is_parent_first_message_in_job =
if let Ok(Some(first_message)) = db.get_first_message_from_inbox(inbox_name.clone()) {
first_message.calculate_message_hash_for_pagination() == parent_message_hash
} else {
false
};

// If the parent message is the first message in the job, clear all messages before adding the new one
if is_parent_first_message_in_job {
if let Err(err) = db.clear_inbox_messages(&inbox_name) {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to clear inbox messages: {}", err),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
}

job_message.parent = parent_parent_key;

let shinkai_message = match Self::api_v2_create_shinkai_message(
Expand Down
160 changes: 135 additions & 25 deletions shinkai-libs/shinkai-sqlite/src/inbox_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,10 @@ use rusqlite::params;
use serde_json::Value;
use shinkai_message_primitives::{
schemas::{
identity::StandardIdentity,
inbox_name::InboxName,
inbox_permission::InboxPermission,
job_config::JobConfig,
shinkai_name::ShinkaiName,
smart_inbox::{LLMProviderSubset, ProviderType, SmartInbox},
ws_types::{WSMessageType, WSUpdateHandler},
},
shinkai_message::{
shinkai_message::{NodeApiData, ShinkaiMessage},
shinkai_message_schemas::WSTopic,
},
shinkai_utils::shinkai_time::ShinkaiStringTime,
identity::StandardIdentity, inbox_name::InboxName, inbox_permission::InboxPermission, job_config::JobConfig, shinkai_name::ShinkaiName, smart_inbox::{LLMProviderSubset, ProviderType, SmartInbox}, ws_types::{WSMessageType, WSUpdateHandler}
}, shinkai_message::{
shinkai_message::{NodeApiData, ShinkaiMessage}, shinkai_message_schemas::WSTopic
}, shinkai_utils::shinkai_time::ShinkaiStringTime
};
use tokio::sync::Mutex;

Expand Down Expand Up @@ -752,6 +743,51 @@ impl SqliteManager {
)?;
Ok(())
}

pub fn get_first_message_from_inbox(
&self,
inbox_name: String,
) -> Result<Option<ShinkaiMessage>, SqliteManagerError> {
let conn = self.get_connection()?;
let mut stmt = conn.prepare(
"SELECT shinkai_message FROM inbox_messages
WHERE inbox_name = ?1
ORDER BY time_key ASC
LIMIT 1",
)?;

let mut rows = stmt.query(params![inbox_name])?;

if let Some(row) = rows.next()? {
let encoded_message: Vec<u8> = row.get(0)?;
let message = ShinkaiMessage::decode_message_result(encoded_message)
.map_err(|e| SqliteManagerError::SomeError(e.to_string()))?;
Ok(Some(message))
} else {
Ok(None)
}
}

pub fn clear_inbox_messages(&self, inbox_name: &str) -> Result<(), SqliteManagerError> {
if !self.does_inbox_exist(inbox_name)? {
return Err(SqliteManagerError::InboxNotFound(inbox_name.to_string()));
}

let mut conn = self.get_connection()?;
let tx = conn.transaction()?;

// Delete all messages from the inbox
tx.execute("DELETE FROM inbox_messages WHERE inbox_name = ?1", params![inbox_name])?;

// Reset the read_up_to_message_hash to null since there are no messages
tx.execute(
"UPDATE inboxes SET read_up_to_message_hash = NULL WHERE inbox_name = ?1",
params![inbox_name],
)?;

tx.commit()?;
Ok(())
}
}

#[cfg(test)]
Expand All @@ -760,18 +796,11 @@ mod tests {
use ed25519_dalek::SigningKey;
use shinkai_embedding::model_type::{EmbeddingModelType, OllamaTextEmbeddingsInference};
use shinkai_message_primitives::{
schemas::identity::StandardIdentityType,
shinkai_message::{
shinkai_message::MessageBody,
shinkai_message_schemas::{IdentityPermissions, MessageSchemaType},
},
shinkai_utils::{
encryption::{unsafe_deterministic_encryption_keypair, EncryptionMethod},
job_scope::MinimalJobScope,
search_mode::VectorSearchMode,
shinkai_message_builder::ShinkaiMessageBuilder,
signatures::{clone_signature_secret_key, unsafe_deterministic_signature_keypair},
},
schemas::identity::StandardIdentityType, shinkai_message::{
shinkai_message::MessageBody, shinkai_message_schemas::{IdentityPermissions, MessageSchemaType}
}, shinkai_utils::{
encryption::{unsafe_deterministic_encryption_keypair, EncryptionMethod}, job_scope::MinimalJobScope, search_mode::VectorSearchMode, shinkai_message_builder::ShinkaiMessageBuilder, signatures::{clone_signature_secret_key, unsafe_deterministic_signature_keypair}
}
};
use std::path::PathBuf;
use tempfile::NamedTempFile;
Expand Down Expand Up @@ -1414,6 +1443,87 @@ mod tests {
assert_eq!(last_messages_inbox.len(), 1);
}

#[tokio::test]
async fn test_get_first_message_from_inbox() {
let db = setup_test_db();

let node_identity_name = "@@node.shinkai";
let subidentity_name = "main";
let (node_identity_sk, _) = unsafe_deterministic_signature_keypair(0);
let (node_encryption_sk, _node_encryption_pk) = unsafe_deterministic_encryption_keypair(0);

let (_, node_subencryption_pk) = unsafe_deterministic_encryption_keypair(100);

// Create and insert multiple messages with different timestamps
let messages = vec![
("First Message", "2023-07-02T20:53:34.812Z"),
("Second Message", "2023-07-02T20:54:34.812Z"),
("Third Message", "2023-07-02T20:55:34.812Z"),
];

let mut inbox_name = String::new();

for (content, timestamp) in messages {
let message = generate_message_with_text(
content.to_string(),
node_encryption_sk.clone(),
clone_signature_secret_key(&node_identity_sk),
node_subencryption_pk,
subidentity_name.to_string(),
node_identity_name.to_string(),
timestamp.to_string(),
);

if inbox_name.is_empty() {
inbox_name = InboxName::from_message(&message).unwrap().to_string();
}

db.unsafe_insert_inbox_message(&message, None, None).await.unwrap();
}

// Test getting the first message
let first_message = db.get_first_message_from_inbox(inbox_name.clone()).unwrap();
assert!(first_message.is_some());
assert_eq!(first_message.unwrap().get_message_content().unwrap(), "First Message");

// Test with non-existent inbox
let non_existent = db
.get_first_message_from_inbox("non_existent_inbox".to_string())
.unwrap();
assert!(non_existent.is_none());

// Test clearing messages from the inbox
db.clear_inbox_messages(&inbox_name).unwrap();

// Verify the inbox is empty after clearing
let first_message_after_clear = db.get_first_message_from_inbox(inbox_name.clone()).unwrap();
assert!(
first_message_after_clear.is_none(),
"Inbox should be empty after clearing messages"
);

// Verify we can still add new messages after clearing
let new_message = generate_message_with_text(
"New Message After Clear".to_string(),
node_encryption_sk.clone(),
clone_signature_secret_key(&node_identity_sk),
node_subencryption_pk,
subidentity_name.to_string(),
node_identity_name.to_string(),
"2023-07-02T21:00:00.000Z".to_string(),
);

db.unsafe_insert_inbox_message(&new_message, None, None).await.unwrap();

// Verify the new message is now the first message
let first_message_after_new = db.get_first_message_from_inbox(inbox_name.clone()).unwrap();
assert!(first_message_after_new.is_some());
assert_eq!(
first_message_after_new.unwrap().get_message_content().unwrap(),
"New Message After Clear"
);
}

// For benchmarking purposes
// #[tokio::test]
async fn benchmark_get_all_smart_inboxes_for_profile() {
Expand Down

0 comments on commit 8361ff2

Please sign in to comment.