Skip to content

Commit c4897b7

Browse files
committed
eclat changes
1 parent d3c7b0b commit c4897b7

File tree

3 files changed

+45
-80
lines changed

3 files changed

+45
-80
lines changed

examples/EclatExample.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,23 @@ void testEclat() {
2222
Eclat eclat(min_support);
2323

2424
// Run Eclat algorithm
25-
std::vector<std::set<int>> frequent_itemsets = eclat.run(transactions);
25+
std::vector<std::vector<int>> frequent_itemsets = eclat.run(transactions);
2626

2727
// Get support counts
2828
auto support_counts = eclat.get_support_counts();
2929

3030
// Display frequent itemsets and their support counts
3131
std::cout << "Frequent Itemsets:\n";
3232
for (const auto& itemset : frequent_itemsets) {
33-
std::string itemset_str;
33+
std::cout << "Itemset: { ";
3434
for (int item : itemset) {
35-
itemset_str += std::to_string(item) + " ";
35+
std::cout << item << " ";
3636
}
37-
std::string key = eclat.itemset_to_string(itemset);
38-
int support = support_counts[key];
39-
std::cout << "Itemset: {" << itemset_str << "} - Support: " << support << "\n";
37+
std::cout << "} - Support: " << support_counts.at(itemset) << "\n";
4038
}
41-
4239
}
4340

44-
int main(){
41+
int main() {
4542
testEclat();
4643
return 0;
47-
}
44+
}
ml_library_include/ml/association/Eclat.hpp

Lines changed: 33 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
#ifndef ECLAT_HPP
22
#define ECLAT_HPP
33

4-
#include <unordered_map>
5-
#include <unordered_set>
4+
#include <map>
65
#include <vector>
7-
#include <set>
86
#include <algorithm>
9-
#include <functional>
107
#include <iostream>
118
#include <string>
129
#include <cmath>
10+
#include <stdexcept>
1311

1412
/**
1513
* @file Eclat.hpp
16-
* @brief Implementation of the Eclat algorithm for frequent itemset mining.
14+
* @brief Optimized Implementation of the Eclat algorithm for frequent itemset mining.
1715
*/
1816

1917
/**
@@ -31,38 +29,31 @@ class Eclat {
3129
/**
3230
* @brief Runs the Eclat algorithm on the provided dataset.
3331
* @param transactions A vector of transactions, each transaction is a vector of items.
34-
* @return A vector of frequent itemsets, where each itemset is represented as a set of items.
32+
* @return A vector of frequent itemsets, where each itemset is represented as a vector of items.
3533
*/
36-
std::vector<std::set<int>> run(const std::vector<std::vector<int>>& transactions);
34+
std::vector<std::vector<int>> run(const std::vector<std::vector<int>>& transactions);
3735

3836
/**
3937
* @brief Gets the support counts for all frequent itemsets found.
40-
* @return An unordered_map where keys are itemsets (as strings) and values are support counts.
38+
* @return A map where keys are itemsets (as vectors) and values are support counts.
4139
*/
42-
std::unordered_map<std::string, int> get_support_counts() const;
43-
/**
44-
* @brief Converts an itemset to a string representation for use as a key.
45-
* @param itemset The itemset to convert.
46-
* @return A string representation of the itemset.
47-
*/
48-
std::string itemset_to_string(const std::set<int>& itemset) const;
40+
std::map<std::vector<int>, int> get_support_counts() const;
4941

5042
private:
5143
/**
5244
* @brief Recursively mines frequent itemsets using the Eclat algorithm.
5345
* @param prefix The current itemset prefix.
5446
* @param items A vector of items to consider.
55-
* @param tid_sets A map from items to their transaction ID sets.
47+
* @param tid_sets A map from items to their transaction ID vectors.
5648
*/
57-
void eclat_recursive(const std::set<int>& prefix,
49+
void eclat_recursive(const std::vector<int>& prefix,
5850
const std::vector<int>& items,
59-
const std::unordered_map<int, std::unordered_set<int>>& tid_sets);
60-
51+
const std::map<int, std::vector<int>>& tid_sets);
6152

6253
double min_support; ///< Minimum support threshold.
6354
int min_support_count; ///< Minimum support count (absolute number of transactions).
6455
int total_transactions; ///< Total number of transactions.
65-
std::unordered_map<std::string, int> support_counts; ///< Support counts for itemsets.
56+
std::map<std::vector<int>, int> support_counts; ///< Support counts for itemsets.
6657
};
6758

6859
Eclat::Eclat(double min_support)
@@ -72,18 +63,23 @@ Eclat::Eclat(double min_support)
7263
}
7364
}
7465

75-
std::vector<std::set<int>> Eclat::run(const std::vector<std::vector<int>>& transactions) {
66+
std::vector<std::vector<int>> Eclat::run(const std::vector<std::vector<int>>& transactions) {
7667
total_transactions = static_cast<int>(transactions.size());
7768
min_support_count = static_cast<int>(std::ceil(min_support * total_transactions));
7869

79-
// Map each item to its TID set
80-
std::unordered_map<int, std::unordered_set<int>> item_tidsets;
70+
// Map each item to its TID vector
71+
std::map<int, std::vector<int>> item_tidsets;
8172
for (int tid = 0; tid < total_transactions; ++tid) {
8273
for (int item : transactions[tid]) {
83-
item_tidsets[item].insert(tid);
74+
item_tidsets[item].push_back(tid);
8475
}
8576
}
8677

78+
// Sort TID vectors
79+
for (auto& [item, tids] : item_tidsets) {
80+
std::sort(tids.begin(), tids.end());
81+
}
82+
8783
// Filter items that meet the minimum support
8884
std::vector<int> frequent_items;
8985
for (const auto& [item, tidset] : item_tidsets) {
@@ -97,64 +93,51 @@ std::vector<std::set<int>> Eclat::run(const std::vector<std::vector<int>>& trans
9793

9894
// Initialize support counts for single items
9995
for (int item : frequent_items) {
100-
std::set<int> itemset = {item};
101-
std::string itemset_str = itemset_to_string(itemset);
102-
support_counts[itemset_str] = static_cast<int>(item_tidsets[item].size());
96+
std::vector<int> itemset = {item};
97+
support_counts[itemset] = static_cast<int>(item_tidsets[item].size());
10398
}
10499

105100
// Start recursive mining
106101
eclat_recursive({}, frequent_items, item_tidsets);
107102

108103
// Collect frequent itemsets from support counts
109-
std::vector<std::set<int>> frequent_itemsets;
110-
for (const auto& [itemset_str, count] : support_counts) {
104+
std::vector<std::vector<int>> frequent_itemsets;
105+
for (const auto& [itemset, count] : support_counts) {
111106
if (count >= min_support_count) {
112-
// Convert string back to itemset
113-
std::set<int> itemset;
114-
size_t pos = 0;
115-
std::string token;
116-
std::string s = itemset_str;
117-
while ((pos = s.find(',')) != std::string::npos) {
118-
token = s.substr(0, pos);
119-
itemset.insert(std::stoi(token));
120-
s.erase(0, pos + 1);
121-
}
122-
itemset.insert(std::stoi(s));
123107
frequent_itemsets.push_back(itemset);
124108
}
125109
}
126110

127111
return frequent_itemsets;
128112
}
129113

130-
void Eclat::eclat_recursive(const std::set<int>& prefix,
114+
void Eclat::eclat_recursive(const std::vector<int>& prefix,
131115
const std::vector<int>& items,
132-
const std::unordered_map<int, std::unordered_set<int>>& tid_sets) {
116+
const std::map<int, std::vector<int>>& tid_sets) {
133117
size_t n = items.size();
134118
for (size_t i = 0; i < n; ++i) {
135119
int item = items[i];
136-
std::set<int> new_prefix = prefix;
137-
new_prefix.insert(item);
138-
std::string itemset_str = itemset_to_string(new_prefix);
120+
std::vector<int> new_prefix = prefix;
121+
new_prefix.push_back(item);
139122

140123
// Update support counts
141124
int support = static_cast<int>(tid_sets.at(item).size());
142-
support_counts[itemset_str] = support;
125+
support_counts[new_prefix] = support;
143126

144127
// Generate new combinations
145128
std::vector<int> remaining_items;
146-
std::unordered_map<int, std::unordered_set<int>> new_tid_sets;
129+
std::map<int, std::vector<int>> new_tid_sets;
147130

148131
for (size_t j = i + 1; j < n; ++j) {
149132
int next_item = items[j];
150133

151134
// Intersect TID sets
152-
std::unordered_set<int> intersect_tid_set;
135+
std::vector<int> intersect_tid_set;
153136
const auto& tid_set1 = tid_sets.at(item);
154137
const auto& tid_set2 = tid_sets.at(next_item);
155138
std::set_intersection(tid_set1.begin(), tid_set1.end(),
156139
tid_set2.begin(), tid_set2.end(),
157-
std::inserter(intersect_tid_set, intersect_tid_set.begin()));
140+
std::back_inserter(intersect_tid_set));
158141

159142
if (static_cast<int>(intersect_tid_set.size()) >= min_support_count) {
160143
remaining_items.push_back(next_item);
@@ -169,19 +152,8 @@ void Eclat::eclat_recursive(const std::set<int>& prefix,
169152
}
170153
}
171154

172-
std::unordered_map<std::string, int> Eclat::get_support_counts() const {
155+
std::map<std::vector<int>, int> Eclat::get_support_counts() const {
173156
return support_counts;
174157
}
175158

176-
std::string Eclat::itemset_to_string(const std::set<int>& itemset) const {
177-
std::string s;
178-
for (auto it = itemset.begin(); it != itemset.end(); ++it) {
179-
s += std::to_string(*it);
180-
if (std::next(it) != itemset.end()) {
181-
s += ",";
182-
}
183-
}
184-
return s;
185-
}
186-
187159
#endif // ECLAT_HPP

tests/association/EclatTest.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
#include "../../ml_library_include/ml/association/Eclat.hpp"
22
#include <iostream>
33
#include <vector>
4-
#include <set>
54
#include <cassert>
65
#include <string>
7-
#include "../TestUtils.hpp"
86

97
int main() {
108
// Sample dataset with transactions
@@ -27,13 +25,13 @@ int main() {
2725
Eclat eclat(min_support);
2826

2927
// Run Eclat algorithm to obtain frequent itemsets
30-
std::vector<std::set<int>> frequent_itemsets = eclat.run(transactions);
28+
std::vector<std::vector<int>> frequent_itemsets = eclat.run(transactions);
3129

3230
// Get support counts
3331
auto support_counts = eclat.get_support_counts();
3432

3533
// Expected frequent itemsets for validation (sample expected output)
36-
std::vector<std::set<int>> expected_frequent_itemsets = {
34+
std::vector<std::vector<int>> expected_frequent_itemsets = {
3735
{1, 2}, {2, 3}, {1, 3}, {1, 2, 3}
3836
// Add other expected itemsets based on expected results for the given min_support
3937
};
@@ -47,16 +45,14 @@ int main() {
4745
// Display the results for verification
4846
std::cout << "Frequent Itemsets:\n";
4947
for (const auto& itemset : frequent_itemsets) {
50-
std::string itemset_str;
48+
std::cout << "Itemset: { ";
5149
for (int item : itemset) {
52-
itemset_str += std::to_string(item) + " ";
50+
std::cout << item << " ";
5351
}
54-
std::string key = eclat.itemset_to_string(itemset);
55-
int support = support_counts[key];
56-
std::cout << "Itemset: {" << itemset_str << "} - Support: " << support << "\n";
52+
std::cout << "} - Support: " << support_counts.at(itemset) << "\n";
5753

5854
// Verify support is above the minimum support threshold
59-
double support_ratio = static_cast<double>(support) / transactions.size();
55+
double support_ratio = static_cast<double>(support_counts.at(itemset)) / transactions.size();
6056
assert(support_ratio >= min_support && "Frequent itemset does not meet minimum support threshold.");
6157
}
6258

0 commit comments

Comments
 (0)