Skip to content

Commit

Permalink
Merge pull request #18 from dcSpark/nico/add_leiden
Browse files Browse the repository at this point in the history
new tool (Leiden Algo)
  • Loading branch information
nicarq authored Jul 13, 2024
2 parents e1bd4cb + 686ee12 commit 4b7455e
Show file tree
Hide file tree
Showing 11 changed files with 374 additions and 3 deletions.
4 changes: 4 additions & 0 deletions apps/shinkai-tool-leiden/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"name": "@shinkai_protocol/shinkai-tool-leiden",
"type": "commonjs"
}
35 changes: 35 additions & 0 deletions apps/shinkai-tool-leiden/project.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
{
"name": "@shinkai_protocol/shinkai-tool-leiden",
"$schema": "../../node_modules/nx/schemas/project-schema.json",
"sourceRoot": "apps/shinkai-tool-leiden/src",
"projectType": "library",
"tags": ["tool"],
"targets": {
"build": {
"executor": "@nx/webpack:webpack",
"outputs": ["{options.outputPath}"],
"defaultConfiguration": "production",
"options": {
"compiler": "tsc",
"outputPath": "dist/apps/shinkai-tool-leiden",
"main": "apps/shinkai-tool-leiden/src/index.ts",
"tsConfig": "apps/shinkai-tool-leiden/tsconfig.app.json",
"webpackConfig": "apps/shinkai-tool-leiden/webpack.config.ts"
},
"configurations": {
"development": {},
"production": {}
}
},
"lint": {
"executor": "@nx/linter:eslint",
"outputs": ["{options.outputFile}"],
"options": {
"lintFilePatterns": [
"apps/shinkai-tool-leiden/**/*.ts",
"apps/shinkai-tool-leiden/package.json"
]
}
}
}
}
237 changes: 237 additions & 0 deletions apps/shinkai-tool-leiden/src/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
import { BaseTool, RunResult } from '@shinkai_protocol/shinkai-tools-builder';
import { ToolDefinition } from 'libs/shinkai-tools-builder/src/tool-definition';
import {
Clustering,
GradientDescentVOSLayoutAlgorithm,
Layout,
LeidenAlgorithm,
Network,
} from 'networkanalysis-ts';

type Config = {};
type Params = {
edges: [number, number, number][];
resolution?: number; // Make resolution optional
nIterations?: number | null; // Make nIterations optional
nRandomStarts?: number;
convergenceThreshold?: number; // Add optional convergence threshold
};
type Result = {
bestClustering: { nNodes: number; clusters: number[]; nClusters: number };
bestLayout: { nNodes: number; coordinates: number[][] };
};

export class Tool extends BaseTool<Config, Params, Result> {
definition: ToolDefinition<Config, Params, Result> = {
id: 'shinkai-tool-leiden',
name: 'Shinkai: Leiden Algorithm',
description: 'Runs the Leiden algorithm on the input edges',
author: 'Shinkai',
keywords: ['leiden', 'clustering', 'network analysis'],
configurations: {
type: 'object',
properties: {},
required: [],
},
parameters: {
type: 'object',
properties: {
edges: {
type: 'array',
items: {
type: 'array',
items: [{ type: 'number' }, { type: 'number' }, { type: 'number' }],
minItems: 3,
maxItems: 3,
},
},
resolution: { type: 'number', nullable: true },
nIterations: { type: 'number', nullable: true },
nRandomStarts: { type: 'number', nullable: true },
convergenceThreshold: { type: 'number', nullable: true },
},
required: ['edges'],
},
result: {
type: 'object',
properties: {
bestClustering: {
type: 'object',
properties: {
nNodes: { type: 'number' },
clusters: { type: 'array', items: { type: 'number' } },
nClusters: { type: 'number' },
},
required: ['nNodes', 'clusters', 'nClusters'],
},
bestLayout: {
type: 'object',
properties: {
nNodes: { type: 'number' },
coordinates: {
type: 'array',
items: { type: 'array', items: { type: 'number' } },
},
},
required: ['nNodes', 'coordinates'],
},
},
required: ['bestClustering', 'bestLayout'],
},
};

async run(params: Params): Promise<RunResult<Result>> {
const {
edges,
resolution = 1.0, // Set default resolution to 1.0
nIterations = null, // Set default nIterations to null
nRandomStarts = 10, // Set default nRandomStarts to 10
convergenceThreshold = 0.0001, // Set default convergenceThreshold to 0.0001
} = params;

function runLeidenAlgorithm(
edges: [number, number, number][],
resolution: number,
nIterations: number | null,
nRandomStarts: number,
convergenceThreshold?: number, // Add optional convergence threshold
) {
const adjustedEdges = edges.map((edge) => [
edge[0] - 1,
edge[1] - 1,
edge[2],
]);
const nNodes =
Math.max(...adjustedEdges.flatMap((edge) => [edge[0], edge[1]])) + 1;
const network = new Network({
nNodes: nNodes,
setNodeWeightsToTotalEdgeWeights: true,
edges: [
adjustedEdges.map((edge) => edge[0]),
adjustedEdges.map((edge) => edge[1]),
],
edgeWeights: adjustedEdges.map((edge) => edge[2]),
sortedEdges: false,
checkIntegrity: true,
});

const normalizedNetwork =
network.createNormalizedNetworkUsingAssociationStrength();

let bestClustering: Clustering = new Clustering({ nNodes: 0 });
let maxQuality = Number.NEGATIVE_INFINITY;
const clusteringAlgorithm = new LeidenAlgorithm();
clusteringAlgorithm.setResolution(resolution);
let previousQuality = Number.NEGATIVE_INFINITY;
let iteration = 0;

if (nIterations !== null) {
clusteringAlgorithm.setNIterations(nIterations);
for (let i = 0; i < nRandomStarts; i++) {
const clustering = new Clustering({
nNodes: normalizedNetwork.getNNodes(),
});
clusteringAlgorithm.improveClustering(normalizedNetwork, clustering);
const quality = clusteringAlgorithm.calcQuality(
normalizedNetwork,
clustering,
);
if (quality > maxQuality) {
bestClustering = clustering;
maxQuality = quality;
}
}
} else {
while (true) {
const clustering = new Clustering({
nNodes: normalizedNetwork.getNNodes(),
});
clusteringAlgorithm.improveClustering(normalizedNetwork, clustering);
const quality = clusteringAlgorithm.calcQuality(
normalizedNetwork,
clustering,
);
if (quality > maxQuality) {
bestClustering = clustering;
maxQuality = quality;
}
if (
convergenceThreshold &&
Math.abs(quality - previousQuality) < convergenceThreshold
) {
break;
}
previousQuality = quality;
iteration++;
}
}
bestClustering.orderClustersByNNodes();

let bestLayout: Layout = new Layout({ nNodes: 0 });
let minQuality = Number.POSITIVE_INFINITY;
const layoutAlgorithm = new GradientDescentVOSLayoutAlgorithm();
layoutAlgorithm.setAttraction(2);
layoutAlgorithm.setRepulsion(1);
previousQuality = Number.POSITIVE_INFINITY;
iteration = 0;

if (nIterations !== null) {
for (let i = 0; i < nRandomStarts; i++) {
const layout = new Layout({ nNodes: normalizedNetwork.getNNodes() });
layoutAlgorithm.improveLayout(normalizedNetwork, layout);
const quality = layoutAlgorithm.calcQuality(
normalizedNetwork,
layout,
);
if (quality < minQuality) {
bestLayout = layout;
minQuality = quality;
}
}
} else {
while (true) {
const layout = new Layout({ nNodes: normalizedNetwork.getNNodes() });
layoutAlgorithm.improveLayout(normalizedNetwork, layout);
const quality = layoutAlgorithm.calcQuality(
normalizedNetwork,
layout,
);
if (quality < minQuality) {
bestLayout = layout;
minQuality = quality;
}
if (
convergenceThreshold &&
Math.abs(quality - previousQuality) < convergenceThreshold
) {
break;
}
previousQuality = quality;
iteration++;
}
}
bestLayout.standardize(true);

return {
bestClustering: {
nNodes: bestClustering.getNNodes(),
clusters: bestClustering.getClusters(),
nClusters: bestClustering.getNClusters(),
},
bestLayout: {
nNodes: bestLayout.getNNodes(),
coordinates: bestLayout.getCoordinates(),
},
};
}

const results = runLeidenAlgorithm(
edges,
resolution,
nIterations,
nRandomStarts,
convergenceThreshold, // Pass convergence threshold
);
return Promise.resolve({ data: results });
}
}
4 changes: 4 additions & 0 deletions apps/shinkai-tool-leiden/tsconfig.app.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"extends": "./tsconfig.json",
"include": ["./src/**/*.ts"]
}
10 changes: 10 additions & 0 deletions apps/shinkai-tool-leiden/tsconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"extends": "../../tsconfig.base.json",
"compilerOptions": {

},
"include": [
"./src/**/*.ts",
"webpack.config.ts"
],
}
17 changes: 17 additions & 0 deletions apps/shinkai-tool-leiden/webpack.config.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import * as path from 'path';

import { composePlugins, withNx } from '@nx/webpack';
import { merge } from 'lodash';

import { withToolWebpackConfig } from '@shinkai_protocol/shinkai-tools-bundler';

module.exports = composePlugins(withNx(), (config, { options, context }) => {
return merge(
config,
withToolWebpackConfig({
entry: path.join(__dirname, 'src/index.ts'),
outputPath: path.join(__dirname, '../../dist/apps/shinkai-tool-leiden'),
tsConfigFile: path.join(__dirname, 'tsconfig.app.json'),
}),
);
});
10 changes: 10 additions & 0 deletions libs/shinkai-tools-runner/src/built_in_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ lazy_static! {
.unwrap(),
)),
);
m.insert(
"shinkai-tool-leiden",
&*Box::leak(Box::new(
serde_json::from_str::<ToolDefinition>(include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/tools/shinkai-tool-leiden/definition.json"
)))
.unwrap(),
)),
);
// ntim: New tools will be inserted here, don't remove this comment
m
};
Expand Down
37 changes: 35 additions & 2 deletions libs/shinkai-tools-runner/src/lib.test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,9 @@ async fn max_execution_time() {
setTimeout(() => {
resolve();
}, 500);
});
});
}
return { data: true };
}
}
Expand Down Expand Up @@ -417,3 +417,36 @@ async fn shinkai_tool_download_page_stack_overflow() {
.unwrap();
assert!(run_result.is_ok());
}

#[tokio::test]
async fn shinkai_tool_leiden() {
let tool_definition = get_tool("shinkai-tool-leiden").unwrap();
let mut tool = Tool::new();
let _ = tool
.load_from_code(&tool_definition.code.clone().unwrap(), "")
.await;
let edges = vec![
(2, 1, 1), (3, 1, 1), (3, 2, 1), (4, 1, 1), (4, 2, 1), (4, 3, 1),
(5, 1, 1), (6, 1, 1), (7, 1, 1), (7, 5, 1), (7, 6, 1), (8, 1, 1),
(8, 2, 1), (8, 3, 1), (8, 4, 1), (9, 1, 1), (9, 3, 1), (10, 3, 1),
(11, 1, 1), (11, 5, 1), (11, 6, 1), (12, 1, 1), (13, 1, 1), (13, 4, 1),
(14, 1, 1), (14, 2, 1), (14, 3, 1), (14, 4, 1), (17, 6, 1), (17, 7, 1),
(18, 1, 1), (18, 2, 1), (20, 1, 1), (20, 2, 1), (22, 1, 1), (22, 2, 1),
(26, 24, 1), (26, 25, 1), (28, 3, 1), (28, 24, 1), (28, 25, 1), (29, 3, 1),
(30, 24, 1), (30, 27, 1), (31, 2, 1), (31, 9, 1), (32, 1, 1), (32, 25, 1),
(32, 26, 1), (32, 29, 1), (33, 3, 1), (33, 9, 1), (33, 15, 1), (33, 16, 1),
(33, 19, 1), (33, 21, 1), (33, 23, 1), (33, 24, 1), (33, 30, 1), (33, 31, 1),
(33, 32, 1), (34, 9, 1), (34, 10, 1), (34, 14, 1), (34, 15, 1), (34, 16, 1),
(34, 19, 1), (34, 20, 1), (34, 21, 1), (34, 23, 1), (34, 24, 1), (34, 27, 1),
(34, 28, 1), (34, 29, 1), (34, 30, 1), (34, 31, 1), (34, 32, 1), (34, 33, 1)
];
let params = serde_json::json!({
"edges": edges
});
let start_time = std::time::Instant::now(); // Start measuring time
let run_result = tool.run(&params.to_string(), None).await.unwrap();
let elapsed_time = start_time.elapsed(); // Measure elapsed time

println!("Execution time: {:?}", elapsed_time); // Print the elapsed time
assert!(run_result.data["bestClustering"]["nClusters"].as_u64().unwrap() > 0);
}
Loading

0 comments on commit 4b7455e

Please sign in to comment.