Skip to content

Commit eae7ee8

Browse files
authored
Merge pull request #29 from jideoyelayo1/Eclat
2 parents 4c79d13 + c4897b7 commit eae7ee8

File tree

4 files changed

+273
-0
lines changed

4 files changed

+273
-0
lines changed

CMakeLists.txt

+7
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ add_executable(NeuralNetwork tests/neural_network/NeuralNetworkTest.cpp)
7777
target_compile_definitions(NeuralNetwork PRIVATE TEST_NEURAL_NETWORK)
7878
target_link_libraries(NeuralNetwork cpp_ml_library)
7979

80+
add_executable(Eclat tests/association/EclatTest.cpp)
81+
target_compile_definitions(Eclat PRIVATE TEST_ECLAT)
82+
target_link_libraries(Eclat cpp_ml_library)
83+
8084
# Register individual tests
8185
add_test(NAME LogisticRegressionTest COMMAND LogisticRegressionTest)
8286
add_test(NAME PolynomialRegressionTest COMMAND PolynomialRegressionTest)
@@ -91,6 +95,7 @@ add_test(NAME KNNRegressor COMMAND KNNRegressor)
9195
add_test(NAME HierarchicalClustering COMMAND HierarchicalClustering)
9296
add_test(NAME SupportVectorRegression COMMAND SupportVectorRegression)
9397
add_test(NAME NeuralNetwork COMMAND NeuralNetwork)
98+
add_test(NAME Eclat COMMAND Eclat)
9499

95100

96101
# Add example executables if BUILD_EXAMPLES is ON
@@ -130,6 +135,8 @@ if(BUILD_EXAMPLES)
130135
target_compile_definitions(${EXAMPLE_TARGET} PRIVATE TEST_SUPPORT_VECTOR_REGRESSION)
131136
elseif(EXAMPLE_NAME STREQUAL "NeuralNetworkExample")
132137
target_compile_definitions(${EXAMPLE_TARGET} PRIVATE TEST_NEURAL_NETWORK)
138+
elseif(EXAMPLE_NAME STREQUAL "EclatExample")
139+
target_compile_definitions(${EXAMPLE_TARGET} PRIVATE TEST_ECLAT)
133140
endif()
134141
endforeach()
135142
endif()

examples/EclatExample.cpp

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#include "../ml_library_include/ml/association/Eclat.hpp"
2+
#include <iostream>
3+
4+
void testEclat() {
5+
// Sample transactions
6+
std::vector<std::vector<int>> transactions = {
7+
{1, 2, 5},
8+
{2, 4},
9+
{2, 3},
10+
{1, 2, 4},
11+
{1, 3},
12+
{2, 3},
13+
{1, 3},
14+
{1, 2, 3, 5},
15+
{1, 2, 3}
16+
};
17+
18+
// Minimum support threshold (e.g., 22% of total transactions)
19+
double min_support = 0.22;
20+
21+
// Create Eclat object
22+
Eclat eclat(min_support);
23+
24+
// Run Eclat algorithm
25+
std::vector<std::vector<int>> frequent_itemsets = eclat.run(transactions);
26+
27+
// Get support counts
28+
auto support_counts = eclat.get_support_counts();
29+
30+
// Display frequent itemsets and their support counts
31+
std::cout << "Frequent Itemsets:\n";
32+
for (const auto& itemset : frequent_itemsets) {
33+
std::cout << "Itemset: { ";
34+
for (int item : itemset) {
35+
std::cout << item << " ";
36+
}
37+
std::cout << "} - Support: " << support_counts.at(itemset) << "\n";
38+
}
39+
}
40+
41+
int main() {
42+
testEclat();
43+
return 0;
44+
}
+159
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
#ifndef ECLAT_HPP
2+
#define ECLAT_HPP
3+
4+
#include <map>
5+
#include <vector>
6+
#include <algorithm>
7+
#include <iostream>
8+
#include <string>
9+
#include <cmath>
10+
#include <stdexcept>
11+
12+
/**
13+
* @file Eclat.hpp
14+
* @brief Optimized Implementation of the Eclat algorithm for frequent itemset mining.
15+
*/
16+
17+
/**
18+
* @class Eclat
19+
* @brief Class to perform frequent itemset mining using the Eclat algorithm.
20+
*/
21+
class Eclat {
22+
public:
23+
/**
24+
* @brief Constructor for the Eclat class.
25+
* @param min_support Minimum support threshold (as a fraction between 0 and 1).
26+
*/
27+
Eclat(double min_support);
28+
29+
/**
30+
* @brief Runs the Eclat algorithm on the provided dataset.
31+
* @param transactions A vector of transactions, each transaction is a vector of items.
32+
* @return A vector of frequent itemsets, where each itemset is represented as a vector of items.
33+
*/
34+
std::vector<std::vector<int>> run(const std::vector<std::vector<int>>& transactions);
35+
36+
/**
37+
* @brief Gets the support counts for all frequent itemsets found.
38+
* @return A map where keys are itemsets (as vectors) and values are support counts.
39+
*/
40+
std::map<std::vector<int>, int> get_support_counts() const;
41+
42+
private:
43+
/**
44+
* @brief Recursively mines frequent itemsets using the Eclat algorithm.
45+
* @param prefix The current itemset prefix.
46+
* @param items A vector of items to consider.
47+
* @param tid_sets A map from items to their transaction ID vectors.
48+
*/
49+
void eclat_recursive(const std::vector<int>& prefix,
50+
const std::vector<int>& items,
51+
const std::map<int, std::vector<int>>& tid_sets);
52+
53+
double min_support; ///< Minimum support threshold.
54+
int min_support_count; ///< Minimum support count (absolute number of transactions).
55+
int total_transactions; ///< Total number of transactions.
56+
std::map<std::vector<int>, int> support_counts; ///< Support counts for itemsets.
57+
};
58+
59+
Eclat::Eclat(double min_support)
60+
: min_support(min_support), min_support_count(0), total_transactions(0) {
61+
if (min_support <= 0.0 || min_support > 1.0) {
62+
throw std::invalid_argument("min_support must be between 0 and 1.");
63+
}
64+
}
65+
66+
std::vector<std::vector<int>> Eclat::run(const std::vector<std::vector<int>>& transactions) {
67+
total_transactions = static_cast<int>(transactions.size());
68+
min_support_count = static_cast<int>(std::ceil(min_support * total_transactions));
69+
70+
// Map each item to its TID vector
71+
std::map<int, std::vector<int>> item_tidsets;
72+
for (int tid = 0; tid < total_transactions; ++tid) {
73+
for (int item : transactions[tid]) {
74+
item_tidsets[item].push_back(tid);
75+
}
76+
}
77+
78+
// Sort TID vectors
79+
for (auto& [item, tids] : item_tidsets) {
80+
std::sort(tids.begin(), tids.end());
81+
}
82+
83+
// Filter items that meet the minimum support
84+
std::vector<int> frequent_items;
85+
for (const auto& [item, tidset] : item_tidsets) {
86+
if (static_cast<int>(tidset.size()) >= min_support_count) {
87+
frequent_items.push_back(item);
88+
}
89+
}
90+
91+
// Sort items for consistent order
92+
std::sort(frequent_items.begin(), frequent_items.end());
93+
94+
// Initialize support counts for single items
95+
for (int item : frequent_items) {
96+
std::vector<int> itemset = {item};
97+
support_counts[itemset] = static_cast<int>(item_tidsets[item].size());
98+
}
99+
100+
// Start recursive mining
101+
eclat_recursive({}, frequent_items, item_tidsets);
102+
103+
// Collect frequent itemsets from support counts
104+
std::vector<std::vector<int>> frequent_itemsets;
105+
for (const auto& [itemset, count] : support_counts) {
106+
if (count >= min_support_count) {
107+
frequent_itemsets.push_back(itemset);
108+
}
109+
}
110+
111+
return frequent_itemsets;
112+
}
113+
114+
void Eclat::eclat_recursive(const std::vector<int>& prefix,
115+
const std::vector<int>& items,
116+
const std::map<int, std::vector<int>>& tid_sets) {
117+
size_t n = items.size();
118+
for (size_t i = 0; i < n; ++i) {
119+
int item = items[i];
120+
std::vector<int> new_prefix = prefix;
121+
new_prefix.push_back(item);
122+
123+
// Update support counts
124+
int support = static_cast<int>(tid_sets.at(item).size());
125+
support_counts[new_prefix] = support;
126+
127+
// Generate new combinations
128+
std::vector<int> remaining_items;
129+
std::map<int, std::vector<int>> new_tid_sets;
130+
131+
for (size_t j = i + 1; j < n; ++j) {
132+
int next_item = items[j];
133+
134+
// Intersect TID sets
135+
std::vector<int> intersect_tid_set;
136+
const auto& tid_set1 = tid_sets.at(item);
137+
const auto& tid_set2 = tid_sets.at(next_item);
138+
std::set_intersection(tid_set1.begin(), tid_set1.end(),
139+
tid_set2.begin(), tid_set2.end(),
140+
std::back_inserter(intersect_tid_set));
141+
142+
if (static_cast<int>(intersect_tid_set.size()) >= min_support_count) {
143+
remaining_items.push_back(next_item);
144+
new_tid_sets[next_item] = std::move(intersect_tid_set);
145+
}
146+
}
147+
148+
// Recursive call
149+
if (!remaining_items.empty()) {
150+
eclat_recursive(new_prefix, remaining_items, new_tid_sets);
151+
}
152+
}
153+
}
154+
155+
std::map<std::vector<int>, int> Eclat::get_support_counts() const {
156+
return support_counts;
157+
}
158+
159+
#endif // ECLAT_HPP

tests/association/EclatTest.cpp

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#include "../../ml_library_include/ml/association/Eclat.hpp"
2+
#include <iostream>
3+
#include <vector>
4+
#include <cassert>
5+
#include <string>
6+
7+
int main() {
8+
// Sample dataset with transactions
9+
std::vector<std::vector<int>> transactions = {
10+
{1, 2, 5},
11+
{2, 4},
12+
{2, 3},
13+
{1, 2, 4},
14+
{1, 3},
15+
{2, 3},
16+
{1, 3},
17+
{1, 2, 3, 5},
18+
{1, 2, 3}
19+
};
20+
21+
// Minimum support threshold (e.g., 22% of total transactions)
22+
double min_support = 0.22;
23+
24+
// Create the Eclat model with the minimum support
25+
Eclat eclat(min_support);
26+
27+
// Run Eclat algorithm to obtain frequent itemsets
28+
std::vector<std::vector<int>> frequent_itemsets = eclat.run(transactions);
29+
30+
// Get support counts
31+
auto support_counts = eclat.get_support_counts();
32+
33+
// Expected frequent itemsets for validation (sample expected output)
34+
std::vector<std::vector<int>> expected_frequent_itemsets = {
35+
{1, 2}, {2, 3}, {1, 3}, {1, 2, 3}
36+
// Add other expected itemsets based on expected results for the given min_support
37+
};
38+
39+
// Verify that each expected itemset appears in the results
40+
for (const auto& expected_set : expected_frequent_itemsets) {
41+
assert(std::find(frequent_itemsets.begin(), frequent_itemsets.end(), expected_set) != frequent_itemsets.end() &&
42+
"Expected frequent itemset missing from results.");
43+
}
44+
45+
// Display the results for verification
46+
std::cout << "Frequent Itemsets:\n";
47+
for (const auto& itemset : frequent_itemsets) {
48+
std::cout << "Itemset: { ";
49+
for (int item : itemset) {
50+
std::cout << item << " ";
51+
}
52+
std::cout << "} - Support: " << support_counts.at(itemset) << "\n";
53+
54+
// Verify support is above the minimum support threshold
55+
double support_ratio = static_cast<double>(support_counts.at(itemset)) / transactions.size();
56+
assert(support_ratio >= min_support && "Frequent itemset does not meet minimum support threshold.");
57+
}
58+
59+
// Inform user of successful test
60+
std::cout << "Eclat Association Rule Mining Basic Test passed." << std::endl;
61+
62+
return 0;
63+
}

0 commit comments

Comments
 (0)