Skip to content

Commit 9eb2196

Browse files
authored
Add Cosine Similarity Algorithm for Strings (#459)
1 parent b0838cb commit 9eb2196

File tree

3 files changed

+221
-0
lines changed

3 files changed

+221
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
using System;
2+
using Algorithms.Strings.Similarity;
3+
using NUnit.Framework;
4+
5+
namespace Algorithms.Tests.Strings.Similarity;
6+
7+
[TestFixture]
public class CosineSimilarityTests
{
    /// <summary>
    /// Absolute tolerance applied when comparing the computed similarity with the expected value.
    /// </summary>
    private const double Tolerance = 1e-6;

    /// <summary>
    /// Two equal strings produce identical frequency vectors, so the similarity is 1.
    /// </summary>
    [Test]
    public void Calculate_IdenticalStrings_ReturnsOne()
    {
        AssertSimilarity("test", "test", 1.0, "Identical strings should have a cosine similarity of 1.");
    }

    /// <summary>
    /// Strings that share no character have orthogonal frequency vectors, so the similarity is 0.
    /// </summary>
    [Test]
    public void Calculate_CompletelyDifferentStrings_ReturnsZero()
    {
        AssertSimilarity("abc", "xyz", 0.0, "Completely different strings should have a cosine similarity of 0.");
    }

    /// <summary>
    /// Two empty strings are reported as similarity 0.
    /// </summary>
    [Test]
    public void Calculate_EmptyStrings_ReturnsZero()
    {
        AssertSimilarity(string.Empty, string.Empty, 0.0, "Empty strings should have a cosine similarity of 0.");
    }

    /// <summary>
    /// A non-empty string compared with an empty one is reported as similarity 0.
    /// </summary>
    [Test]
    public void Calculate_OneEmptyString_ReturnsZero()
    {
        AssertSimilarity("test", string.Empty, 0.0, "Empty string should have a cosine similarity of 0.");
    }

    /// <summary>
    /// Strings that differ only by letter case are treated as identical.
    /// </summary>
    [Test]
    public void Calculate_SameCharactersDifferentCases_ReturnsOne()
    {
        AssertSimilarity("Test", "test", 1.0, "The method should be case-insensitive.");
    }

    /// <summary>
    /// Punctuation is counted like any other character.
    /// </summary>
    [Test]
    public void Calculate_SpecialCharacters_ReturnsCorrectValue()
    {
        AssertSimilarity("hello!", "hello!", 1.0, "Strings with special characters should have a cosine similarity of 1.");
    }

    /// <summary>
    /// "hello" has counts h1 e1 l2 o1 (squared magnitude 7); "hello world" has
    /// h1 e1 l3 o2 plus space, w, r, d once each (squared magnitude 19).
    /// The dot product over the shared characters is 1 + 1 + 6 + 2 = 10.
    /// </summary>
    [Test]
    public void Calculate_DifferentLengthWithCommonCharacters_ReturnsCorrectValue()
    {
        var expected = 10 / (Math.Sqrt(7) * Math.Sqrt(19));

        AssertSimilarity(
            "hello",
            "hello world",
            expected,
            "Strings with different lengths but some common characters should have the correct cosine similarity.");
    }

    /// <summary>
    /// "night" and "nacht" each consist of five distinct characters (squared magnitude 5)
    /// and share n, h and t, giving a dot product of 3 and a similarity of 3/5.
    /// </summary>
    [Test]
    public void Calculate_PartiallyMatchingStrings_ReturnsCorrectValue()
    {
        var expected = 3.0 / 5.0;

        AssertSimilarity("night", "nacht", expected, "Partially matching strings should have the correct cosine similarity.");
    }

    /// <summary>
    /// Runs <see cref="CosineSimilarity.Calculate"/> on the two inputs and checks the result
    /// against <paramref name="expected"/> within <see cref="Tolerance"/>.
    /// </summary>
    /// <param name="first">The left-hand input string.</param>
    /// <param name="second">The right-hand input string.</param>
    /// <param name="expected">The similarity the calculation should produce.</param>
    /// <param name="because">The message reported when the assertion fails.</param>
    private static void AssertSimilarity(string first, string second, double expected, string because)
    {
        var actual = CosineSimilarity.Calculate(first, second);

        Assert.That(actual, Is.EqualTo(expected).Within(Tolerance), because);
    }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
using System;
2+
using System.Collections.Generic;
3+
4+
namespace Algorithms.Strings.Similarity;
5+
6+
public static class CosineSimilarity
{
    /// <summary>
    /// Calculates the Cosine Similarity between two strings.
    /// Each string is lower-cased with the invariant culture and converted into a vector of
    /// character frequencies; the result is the cosine of the angle between the two vectors.
    /// Character order is ignored: only how often each character occurs matters.
    /// </summary>
    /// <param name="left">The first string.</param>
    /// <param name="right">The second string.</param>
    /// <returns>
    /// A double value between 0 and 1 that represents the similarity of the two strings.
    /// The result is 0 when the strings share no character or when either string is empty.
    /// </returns>
    /// <exception cref="ArgumentNullException">
    /// Thrown when <paramref name="left"/> or <paramref name="right"/> is <c>null</c>.
    /// </exception>
    public static double Calculate(string left, string right)
    {
        if (left is null)
        {
            throw new ArgumentNullException(nameof(left));
        }

        if (right is null)
        {
            throw new ArgumentNullException(nameof(right));
        }

        // Step 1: Build the character-frequency vector of each (case-folded) string.
        var (leftVector, rightVector) = GetVectors(left.ToLowerInvariant(), right.ToLowerInvariant());

        // Step 2: Only characters present in both strings contribute to the dot product.
        var intersection = GetIntersection(leftVector, rightVector);

        // Step 3: Sum the products of the frequencies of the shared characters.
        var dotProduct = DotProduct(leftVector, rightVector, intersection);

        // Step 4: Squared magnitude of each vector (sum of squared frequencies).
        var leftSquaredMagnitude = SquaredMagnitude(leftVector);
        var rightSquaredMagnitude = SquaredMagnitude(rightVector);

        // Step 5: A zero vector means the corresponding string is empty. The cosine is
        // undefined in that case (division by zero), so the similarity is reported as 0.
        if (leftSquaredMagnitude <= 0 || rightSquaredMagnitude <= 0)
        {
            return 0.0;
        }

        // Step 6: The Cosine Similarity is the dot product divided by the product of the magnitudes.
        return dotProduct / (Math.Sqrt(leftSquaredMagnitude) * Math.Sqrt(rightSquaredMagnitude));
    }

    /// <summary>
    /// Calculates the character-frequency vectors for the given strings.
    /// </summary>
    /// <param name="left">The first string.</param>
    /// <param name="right">The second string.</param>
    /// <returns>A tuple containing the vectors for the two strings.</returns>
    private static (Dictionary<char, int> leftVector, Dictionary<char, int> rightVector) GetVectors(string left, string right)
    {
        return (CountCharacters(left), CountCharacters(right));
    }

    /// <summary>
    /// Counts how many times each character occurs in the given text.
    /// </summary>
    /// <param name="text">The text to analyse.</param>
    /// <returns>A dictionary mapping every character of the text to its number of occurrences.</returns>
    private static Dictionary<char, int> CountCharacters(string text)
    {
        var frequencies = new Dictionary<char, int>();

        foreach (var character in text)
        {
            // TryGetValue leaves the count at 0 for a character that has not been seen yet.
            frequencies.TryGetValue(character, out var count);
            frequencies[character] = count + 1;
        }

        return frequencies;
    }

    /// <summary>
    /// Calculates the squared magnitude of a frequency vector, i.e. the sum of its squared values.
    /// </summary>
    /// <param name="vector">The vector of character frequencies.</param>
    /// <returns>The sum of the squared frequencies; 0 for an empty vector.</returns>
    private static double SquaredMagnitude(Dictionary<char, int> vector)
    {
        var sum = 0.0;

        foreach (var frequency in vector.Values)
        {
            // Widen to double before multiplying: an int * int product silently wraps around
            // once a character occurs more than 46340 times.
            double widened = frequency;
            sum += widened * widened;
        }

        return sum;
    }

    /// <summary>
    /// Calculates the dot product between two vectors represented as dictionaries of character frequencies.
    /// The dot product is the sum of the products of the frequencies of the characters in the intersection of the two vectors.
    /// </summary>
    /// <param name="leftVector">The vector of the left string.</param>
    /// <param name="rightVector">The vector of the right string.</param>
    /// <param name="intersection">The intersection of the two vectors, represented as a set of characters.</param>
    /// <returns>The dot product of the two vectors.</returns>
    private static double DotProduct(Dictionary<char, int> leftVector, Dictionary<char, int> rightVector, HashSet<char> intersection)
    {
        double dotProduct = 0;

        foreach (var character in intersection)
        {
            // The cast forces a floating-point multiplication so that large frequencies cannot overflow int.
            dotProduct += (double)leftVector[character] * rightVector[character];
        }

        return dotProduct;
    }

    /// <summary>
    /// Calculates the intersection of two vectors, represented as dictionaries of character frequencies.
    /// </summary>
    /// <param name="leftVector">The vector of the left string.</param>
    /// <param name="rightVector">The vector of the right string.</param>
    /// <returns>A HashSet containing the characters that appear in both vectors.</returns>
    private static HashSet<char> GetIntersection(Dictionary<char, int> leftVector, Dictionary<char, int> rightVector)
    {
        var intersection = new HashSet<char>();

        foreach (var character in leftVector.Keys)
        {
            if (rightVector.ContainsKey(character))
            {
                intersection.Add(character);
            }
        }

        return intersection;
    }
}

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ find more than one implementation for the same objective but using different alg
178178
* [A181391 Van Eck's](./Algorithms/Sequences/VanEcksSequence.cs)
179179
* [String](./Algorithms/Strings)
180180
* [Similarity](./Algorithms/Strings/Similarity/)
181+
* [Cosine Similarity](./Algorithms/Strings/Similarity/CosineSimilarity.cs)
181182
* [Hamming Distance](./Algorithms/Strings/Similarity/HammingDistance.cs)
182183
* [Jaro Similarity](./Algorithms/Strings/Similarity/JaroSimilarity.cs)
183184
* [Jaro-Winkler Distance](./Algorithms/Strings/Similarity/JaroWinklerDistance.cs)

0 commit comments

Comments
 (0)