-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmultiLabelMetrics.py
148 lines (98 loc) · 3.83 KB
/
multiLabelMetrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import random
import re
from operator import add
import csv
import pyspark
import pyspark.sql.functions as f
from pyspark.ml.feature import StringIndexer
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.types import Row
from pyspark.sql.types import StringType
from pyspark.sql.types import StructField
from pyspark.sql.types import StructType
from pyspark.mllib.evaluation import MultilabelMetrics
from pyspark.sql.window import Window
from pyspark.storagelevel import StorageLevel
import cv2
import mysql.connector
import requests
from vcgImageAI.comm.sparkBase import *
from vcgImageAI.comm.vcgUtils import *
# Module-level side effect: build the shared YARN-backed Spark session.
# The script entry point below passes this session into computeMultiLabelMetrics.
sparkBase = SparkBase()
spark = sparkBase.createYarnSparkEnv()
# Translation of the module note below: compute recall, precision, F-measure
# and per-label accuracy from the mongo collection `evaluations`.
# One globalIdentity contains several modelName values (the different training
# steps of a model); one globalIdentity denotes a unique model + parameter set.
"""
从mongo数据表:evaluations计算recall, Precision 以及F-measure, 还有per label accuracy
一个globalIdentity 包含多个modelName(表示模型训练的不同steps),
一个globalIdentity 表示一组唯一的模型以及参数
"""
def computeMultiLabelMetrics(pickedNum=5, topNum=None, globalIdentity=None, batchNum=None, spark=None, excludeLabels=None):
    """Compute multi-label Recall / Precision / F1 / Accuracy for one evaluation batch.

    Reads prediction documents from the MongoDB collection ``vcg.evaluations``,
    filtered by ``batchNum`` and ``globalIdentity``, keeps the ``topNum``
    highest-scored labels of each document as its predictions, and feeds
    (predictions, labels) pairs to ``MultilabelMetrics``.  Results are printed.

    :param pickedNum: unused here; kept for caller compatibility (the
        commented-out logic that derived topNum from it was removed).
    :param topNum: number of highest-scored labels kept as predictions;
        ``None`` keeps all of them (a [0:None] slice is the whole list).
    :param globalIdentity: identifies one unique model + parameter set.
    :param batchNum: training step of the model checkpoint being evaluated.
    :param spark: active SparkSession used to read from MongoDB.
    :param excludeLabels: label ids (as strings) dropped from both the ground
        truth and the ranked predictions before scoring.
    """
    # Avoid the shared-mutable-default pitfall (was ``excludeLabels=[]``).
    if excludeLabels is None:
        excludeLabels = []
    # NOTE(review): the pipeline is built by string concatenation from
    # ``batchNum``/``globalIdentity`` and the URI embeds credentials --
    # both should move to configuration / parameterized input.
    pipeline = "[{'$match':{'$and':[{'batchNum':" + str(
        batchNum) + "},{'globalIdentity':'" + globalIdentity + "'}]}},{'$project': {'scores': 1,'labelId':1}}]"
    print(pipeline)
    df = spark.read.format("com.mongodb.spark.sql.DefaultSource") \
        .option('uri', 'mongodb://zhaoyufei:[email protected]/') \
        .option('database', 'vcg') \
        .option('collection', 'evaluations') \
        .option("pipeline", pipeline).load()
    df.printSchema()
    df.show(100, False)
    # Ship the exclusion list to the executors once instead of per-task.
    excludeLabels_broadcast = spark.sparkContext.broadcast(excludeLabels)

    def flatMapToLabelId(row):
        """Map one evaluation row to (predictions, labels), or None when either side is empty."""
        excluded = excludeLabels_broadcast.value
        # Ground-truth label ids, minus the excluded ones (compared as strings).
        kept_labels = [lid for lid in row.labelId if str(lid) not in excluded]
        # Score entries sorted best-first; str() keeps the exclusion check
        # consistent with the label filter above (the original compared raw
        # keys here, silently missing non-string keys).
        ranked = [kv for kv in sorted(row.scores.items(), key=lambda kv: -kv[1])
                  if str(kv[0]) not in excluded]
        predictions = [float(kv[0]) for kv in ranked[0:topNum]]
        labels = [float(lid) for lid in kept_labels]
        if not predictions or not labels:
            return None
        return (predictions, labels)

    rdd = df.rdd.map(flatMapToLabelId).filter(lambda pair: pair is not None)
    # Sanity check: show how many predictions each row actually kept.
    topCounts = rdd.map(lambda pair: Row(topNum=len(pair[0]))).toDF()
    topCounts.show(100, False)
    # Summary statistics over all (predictions, labels) pairs.
    metrics = MultilabelMetrics(rdd)
    print("batchNum = %s" % batchNum)
    print("Recall = %s" % metrics.recall())
    print("Precision = %s" % metrics.precision())
    print("F1 measure = %s" % metrics.f1Measure())
    print("Accuracy = %s" % metrics.accuracy)
if __name__ == '__main__':
    # Evaluate the gettyml checkpoint at step 320000 using the top-100 predictions.
    computeMultiLabelMetrics(
        topNum=100,
        globalIdentity='pt_gettyml_labelCountAbove300',
        batchNum=320000,
        spark=spark,
    )