-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapriori.py
166 lines (131 loc) · 5.26 KB
/
apriori.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import sys
from itertools import combinations
# Transactions parsed from the input file; main() appends one list of
# tab-separated item strings per line.
transaction_list = []
# Minimum number of transactions an itemset must appear in to be kept as
# frequent; main() derives it from the support percentage argument.
min_support_count = 0
def initialize_item_set(transactions=None, min_count=None):
    """Count single items and keep the ones that meet the minimum support.

    Args:
        transactions: Iterable of transactions, each an iterable of hashable
            items. Defaults to the module-level ``transaction_list``.
        min_count: Minimum number of transactions an item must occur in.
            Defaults to the module-level ``min_support_count``.

    Returns:
        Dict mapping each frequent item to the number of transactions that
        contain it, in first-seen order.
    """
    if transactions is None:
        transactions = transaction_list
    if min_count is None:
        min_count = min_support_count
    item_set_1 = {}
    for transaction in transactions:
        # Support counts transactions, not occurrences. An item listed twice in
        # one transaction must be counted once, matching the subset test used
        # for larger itemsets. dict.fromkeys de-duplicates while keeping order.
        for item in dict.fromkeys(transaction):
            item_set_1[item] = item_set_1.get(item, 0) + 1
    return {item: count for item, count in item_set_1.items()
            if count >= min_count}
def make_candidate_set_list(prev_keys_list, item_length):
    """Return every size-``item_length`` combination of the known items.

    For ``item_length == 2`` the previous level's keys are bare items and are
    paired directly. Otherwise the previous keys are itemsets. Their distinct
    members are pooled in first-seen order, and all combinations of that pool
    are returned. Pruning against the previous level is left to
    ``make_frequent_set``.

    Returns:
        List of ``set`` objects in ``itertools.combinations`` order.
    """
    if item_length == 2:
        pool = prev_keys_list
    else:
        # dict keys are unique and insertion-ordered: an ordered de-dup.
        pool = list(dict.fromkeys(
            member for itemset in prev_keys_list for member in itemset))
    return [set(combo) for combo in combinations(pool, item_length)]
def make_frequent_set(prev_keys_list, item_length, candidate_set_list,
                      transactions=None, min_count=None):
    """Prune candidates with the Apriori property, then filter by support.

    A candidate survives pruning only if it has exactly ``item_length`` items
    and every ``item_length - 1``-item subset is among the previous level's
    frequent itemsets. Survivors are counted against the transactions and
    kept when their count reaches the minimum.

    Args:
        prev_keys_list: Keys of the previous level. These are bare items when
            ``item_length == 2`` and item tuples otherwise.
        item_length: Size of the itemsets being built.
        candidate_set_list: Candidate itemsets (sets), e.g. from
            ``make_candidate_set_list``.
        transactions: Transactions to count against. Defaults to the
            module-level ``transaction_list``.
        min_count: Minimum transaction count. Defaults to the module-level
            ``min_support_count``.

    Returns:
        Dict mapping ``tuple(candidate_set)`` to its transaction count, in
        candidate order. Tuple element order follows set iteration order.
    """
    if transactions is None:
        transactions = transaction_list
    if min_count is None:
        min_count = min_support_count

    # A set of frozensets gives O(1) subset lookups; a list of sets is O(P).
    if item_length == 2:
        prev_frequent = {frozenset([prev_keys]) for prev_keys in prev_keys_list}
    else:
        prev_frequent = {frozenset(prev_keys) for prev_keys in prev_keys_list}

    survivors = []
    for candidate_set in candidate_set_list:
        if len(candidate_set) != item_length:
            continue
        if all(frozenset(sub) in prev_frequent
               for sub in combinations(candidate_set, item_length - 1)):
            survivors.append(tuple(candidate_set))

    # Convert each transaction once instead of once per candidate.
    transaction_sets = [set(transaction) for transaction in transactions]
    frequent_set = {}
    for key in survivors:
        key_set = set(key)
        count = sum(1 for t in transaction_sets if key_set <= t)
        if count >= min_count:
            frequent_set[key] = count
    return frequent_set
def apriori():
    """Mine frequent itemsets level by level.

    Returns:
        List whose entry ``i`` maps the frequent itemsets of size ``i + 1`` to
        their transaction counts. Entry 0 comes from ``initialize_item_set``.
        Mining stops when a level yields no candidates or no frequent itemsets.
    """
    levels = [initialize_item_set()]
    size = 1
    while True:
        # The level being extended is always the most recently appended one.
        previous_keys = list(levels[-1].keys())
        size += 1
        candidates = make_candidate_set_list(previous_keys, size)
        if not candidates:
            return levels
        frequent = make_frequent_set(previous_keys, size, candidates)
        if not frequent:
            return levels
        levels.append(frequent)
def from_set_to_form(target_set):
    """Format the items of *target_set* as ``{a,b,...}``.

    Items are converted with ``str`` and joined by commas, without spaces, in
    the iterable's own order.
    """
    return "{%s}" % ",".join(map(str, target_set))
def make_output_text_by_association(item_set_total_list, transactions=None):
    """Build one output line per association rule ``X -> Y``.

    For every frequent itemset of two or more items (entries 1.. of
    *item_set_total_list*; entry 0 holds single items and yields no rules),
    each non-empty proper subset ``X`` is used as the antecedent and the
    remaining items ``Y`` as the consequent.

    Args:
        item_set_total_list: List as returned by ``apriori()``. Entry ``i``
            maps itemsets of size ``i + 1`` to their transaction counts.
        transactions: Sequence of transactions used for the support and
            confidence denominators. Defaults to the module-level
            ``transaction_list``.

    Returns:
        List of tab-separated, newline-terminated lines with four fields:
        ``{X}``, ``{Y}``, support and confidence. Support is the itemset's
        percentage of all transactions. Confidence is its percentage of the
        transactions containing ``X``. Both are formatted to two decimals.
        Returns an empty list when there are no transactions.
    """
    if transactions is None:
        transactions = transaction_list
    num_of_transaction = len(transactions)
    if num_of_transaction == 0:
        return []
    transaction_sets = [set(transaction) for transaction in transactions]
    # Many rules share an antecedent; cache its transaction count instead of
    # rescanning all transactions for every rule.
    antecedent_counts = {}
    output_file_text = []
    for list_index in range(1, len(item_set_total_list)):
        for item_set, item_set_count in item_set_total_list[list_index].items():
            support = (int(item_set_count) / num_of_transaction) * 100
            # Antecedent sizes run from list_index (one less than the itemset
            # size) down to 1.
            for sub_itemset_length in range(list_index, 0, -1):
                for item in combinations(item_set, sub_itemset_length):
                    key = frozenset(item)
                    item_count = antecedent_counts.get(key)
                    if item_count is None:
                        item_count = sum(1 for t in transaction_sets if key <= t)
                        antecedent_counts[key] = item_count
                    if item_count == 0:
                        # Confidence is undefined. This can happen when
                        # zero-count itemsets were kept, e.g. with a minimum
                        # support count of 0.
                        continue
                    associative_item = set(item_set).difference(item)
                    confidence = (int(item_set_count) / item_count) * 100
                    output_file_text.append(
                        '{}\t{}\t{:.2f}\t{:.2f}\n'.format(
                            from_set_to_form(item),
                            from_set_to_form(associative_item),
                            support, confidence))
    return output_file_text
def main():
    """Run Apriori from the command line.

    Usage: ``apriori.py <min_support_percent> <input_file> <output_file>``

    Each non-empty line of *input_file* is one transaction with tab-separated
    items. The association rules derived from the frequent itemsets are
    written to *output_file*, one per line.
    """
    global min_support_count
    if len(sys.argv) < 4:
        sys.exit("usage: {} <min_support_percent> <input_file> <output_file>"
                 .format(sys.argv[0]))
    min_support = int(sys.argv[1])
    input_file_name = sys.argv[2]
    output_file_name = sys.argv[3]

    # Paths are opened as given. Relative paths resolve against the working
    # directory, and absolute paths work too.
    new_transactions = []
    with open(input_file_name, 'r') as input_file:
        for line in input_file:
            refined_transaction = line.rstrip("\r\n")
            if not refined_transaction:
                # A blank line is not a transaction. Counting it would inflate
                # the denominator and add a spurious "" item.
                continue
            new_transactions.append(refined_transaction.split('\t'))
    transaction_list.extend(new_transactions)
    transaction_count = len(new_transactions)

    # An itemset qualifies when count / transaction_count * 100 >= min_support,
    # i.e. count >= ceil(transaction_count * min_support / 100). Exact integer
    # ceiling; truncation would admit itemsets below the requested percentage.
    min_support_count = -(-(transaction_count * min_support) // 100)

    item_set_total_list = apriori()
    output_file_context = make_output_text_by_association(item_set_total_list)
    with open(output_file_name, 'w') as output_file:
        output_file.writelines(output_file_context)


if __name__ == "__main__":
    main()