
Commit 0325321

Author: Siarnold (committed)
Li Cheng Hand Segmentation
0 parents  commit 0325321

File tree

1,708 files changed: +523,198 −0 lines changed


Classifier.cpp

+373
@@ -0,0 +1,373 @@
#include "Classifier.h"


LcValidator::LcValidator( float _tp, float _fp, float _fn , float _tn)
{
    tp = _tp; fp = _fp; fn = _fn; tn = _tn;
}

LcValidator LcValidator::operator +(const LcValidator & a)
{
    return LcValidator( a.tp + tp , a.fp + fp , a.fn + fn , a.tn + tn);
}

void LcValidator::display()
{
    //cout << "tp:" << tp << " fp:" << fp << " tn:"<< tn << " fn:" << fn << endl;

    cout << "Precision:" << getPrecision(1) << " " << getPrecision(0) << "(back) " << endl;
    cout << " Recall :" << getRecall(1) << " " << getRecall(0) << "(back) " << endl;

    cout << "F:" << getF1() << " 0-1:" << getZeroOne() << endl;
}

float LcValidator::getZeroOne()
{
    return (tp+tn)/(tp+tn+fp+fn);
}

float LcValidator::getPrecision(int i)
{
    if(i){return tp/(1e-5f+tp+fp);}
    else {return tn/(1e-5f+tn+fn);}
}

float LcValidator::getRecall(int i)
{
    if(i){return tp/(1e-5f+tp+fn);}
    else {return tn/(1e-5f+fp+tn);}
}

float LcValidator::getF1(int i)
{
    float p = getPrecision(i);
    float r = getRecall(i);
    return 2*p*r/(1e-5f+p+r);
}

LcValidator::LcValidator( Mat & res, Mat & lab)
{
    count( res, lab, 0.5, tp, fp, tn, fn);
}

void LcValidator::count( Mat & res, Mat & lab, float th, float & tp, float & fp, float & tn, float & fn)
{
    if( res.rows != lab.rows){ cout << " size mismatch while predicting " << endl; return;}

    tp = fp = tn = fn = 0.0f;

    float * p_res = (float*) res.data;
    float * p_lab = (float*) lab.data;

    for(int sz = res.rows * res.cols ; sz>0; sz--, p_res++, p_lab++)
    {
        if( *p_res > th)
        {
            if( *p_lab >th) tp += 1.0f;
            else fp += 1.0f;
        }
        else
        {
            if( *p_lab >th) fn += 1.0f;
            else tn += 1.0f;
        }
    }

    // normalize the counts to fractions of the total sample count
    {
        float n = float( res.rows*res.cols);
        fp/=n; tp/=n; tn/=n; fn/=n;
    }
}

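For reference, count() normalizes tp, fp, tn, and fn to fractions of the total sample count, so the methods above reduce to the standard definitions (the 1e-5f terms only guard against division by zero): getPrecision(1) = tp/(tp+fp), getRecall(1) = tp/(tp+fn), getF1 = 2*P*R/(P+R), and getZeroOne() is the overall accuracy (tp+tn)/(tp+tn+fp+fn). Passing i = 0 reports the same quantities for the background class.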
//==============================

void LcRandomTreesR::train(Mat & feature, Mat & label)
{
    //_params.max_depth = 10;
    //_params.regression_accuracy = 0.1f;
    //_params.use_1se_rule = true;
    //_params.use_surrogates = true;
    //_params.truncate_pruned_tree = false;
    //_params.min_sample_count = 10;

    _params.max_depth = 10;
    _params.regression_accuracy = 0.00f;
    _params.min_sample_count = 10;

    double t = double(getTickCount());

    cout << " Training random forest regression model ...";

    Mat varType = Mat::ones(feature.cols+1,1,CV_8UC1) * CV_VAR_NUMERICAL;

    _random_tree.train(feature,CV_ROW_SAMPLE,label,Mat(),Mat(),varType,Mat(), _params);

    t = (getTickCount()-t)/getTickFrequency();
    cout << " time to train:" << t << " secs." << endl;
}

LcValidator LcRandomTreesR::predict( Mat & feature, Mat & res, Mat & label)
{
    int n = feature.rows;
    res = Mat::zeros( n, 1, CV_32FC1);
    for(int i = 0; i< n ; i++)
    {
        res.at<float>(i,0) = _random_tree.predict( feature.row(i) );
        //res.at<float>(i,0) = _random_tree.predict_prob( feature.row(i) );
    }

    if( label.rows == feature.rows ) return LcValidator( res, label);
    else return LcValidator();
}

LcValidator LcRandomTreesR::predict( Mat & feature, Mat & res)
{
    Mat label;
    return predict(feature,res,label);
}

void LcRandomTreesR::save( string filename_prefix ){
    string filename = filename_prefix + "_rdtr.xml";
    cout << " Classifier: Saving " << filename << endl;
    _random_tree.save( filename.c_str());
}

void LcRandomTreesR::load( string filename_prefix ){
    string filename = filename_prefix + "_rdtr.xml";
    cout << " Classifier: Loading " << filename << endl;
    _random_tree.load( filename.c_str());
}

void LcRandomTreesR::load_full( string full_filename ){
    cout << " Classifier: Loading " << full_filename << endl;
    _random_tree.load( full_filename.c_str());
}

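A minimal usage sketch for the regression forest above, assuming Classifier.h declares LcRandomTreesR and LcValidator with the methods defined in this file; the feature and label matrices and the "hand_model" prefix are hypothetical placeholders (one CV_32FC1 sample per row, as required by CV_ROW_SAMPLE):

    // sketch only: train, persist, and evaluate a regression forest
    Mat features(1000, 16, CV_32FC1);   // hypothetical 16-D descriptors, one sample per row
    Mat labels(1000, 1, CV_32FC1);      // hypothetical targets (e.g. 0 = background, 1 = hand)
    // ... fill features and labels ...

    LcRandomTreesR rf;
    rf.train(features, labels);         // trains CvRTrees with the parameters set in train()
    rf.save("hand_model");              // writes hand_model_rdtr.xml

    Mat response;
    LcValidator v = rf.predict(features, response, labels);  // LcValidator thresholds at 0.5
    v.display();                        // prints precision / recall / F1 / 0-1 accuracy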
//==============================

void LcRandomTreesC::train(Mat & feature, Mat & label) // Multi-class Classifier
{
    //_params.max_depth = 10;
    //_params.regression_accuracy = 0.1f;
    //_params.use_1se_rule = true;
    //_params.use_surrogates = true;
    //_params.truncate_pruned_tree = false;
    //_params.min_sample_count = 10;

    _params.max_depth = 100;
    _params.min_sample_count = 40;
    //_params.use_1se_rule = true;
    //_params.use_surrogates = true;

    double t = double(getTickCount());

    if( veb ) cout << "Train Random Tree Multi-Class Classifier model ...";

    Mat varType = Mat::ones(feature.cols+1,1,CV_8UC1) * CV_VAR_NUMERICAL; // all floats
    varType.at<uchar>(feature.cols,0) = CV_VAR_CATEGORICAL;               // response column is a class label

    _random_tree.train(feature,CV_ROW_SAMPLE,label,Mat(),Mat(),varType,Mat(), _params);

    t = (getTickCount()-t)/getTickFrequency();
    if( veb ) cout << " time:" << t << " secs." << endl;
}

LcValidator LcRandomTreesC::predict( Mat & feature, Mat & res, Mat & label)
{
    int n = feature.rows;
    res = Mat::zeros( n, 1, CV_32FC1);
    for(int i = 0; i< n ; i++)
    {
        res.at<float>(i,0) = _random_tree.predict( feature.row(i) );
        //res.at<float>(i,0) = _random_tree.predict_prob( feature.row(i) );
    }

    if( label.rows == feature.rows ) return LcValidator( res, label);
    else return LcValidator();
}

LcValidator LcRandomTreesC::predict( Mat & feature, Mat & res)
{
    Mat label;
    return predict(feature,res,label);
}

void LcRandomTreesC::save( string filename_prefix ){
    string filename = filename_prefix + "_rdtc.xml";
    _random_tree.save( filename.c_str());
}

void LcRandomTreesC::load( string filename_prefix ){
    string filename = filename_prefix + "_rdtc.xml";
    _random_tree.load( filename.c_str());
}

//==============================

void LcDecisionTree::train(Mat & feature, Mat & label)
{
    int TREE_DEPTH = 10;

    _params = CvDTreeParams(TREE_DEPTH,10,0.0,true,TREE_DEPTH,4,true,true,0);

    double t = double(getTickCount());

    if( veb ) cout << "Train decision tree model ...";

    Mat varType = Mat::ones(feature.cols+1,1,CV_8UC1) * CV_VAR_NUMERICAL; // all floats
    varType.at<uchar>(feature.cols,0) = CV_VAR_CATEGORICAL;

    _tree.train(feature,CV_ROW_SAMPLE,label,Mat(),Mat(),varType,Mat(),_params);

    t = (getTickCount()-t)/getTickFrequency();
    if( veb ) cout << " time:" << t << " secs." << endl;
}

LcValidator LcDecisionTree::predict( Mat & feature, Mat & res, Mat & label)
{
    int n = feature.rows;
    res = Mat::zeros( n, 1, CV_32FC1);
    for(int i = 0; i< n ; i++)
    {
        CvDTreeNode *node = _tree.predict( feature.row(i) ,Mat(),false);
        res.at<float>(i,0) = float(node->value);
    }

    if( label.rows == feature.rows ) return LcValidator( res, label);
    else return LcValidator();
}

void LcDecisionTree::save( string filename_prefix ){
    string filename = filename_prefix + "_dt.xml";
    _tree.save( filename.c_str() );
}

void LcDecisionTree::load( string filename_prefix ){
    string filename = filename_prefix + "_dt.xml";
    _tree.load( filename.c_str() );
}

//==============================

void LcAdaBoosting::train(Mat & feature, Mat & label)
{
    int boost_type = CvBoost::GENTLE; //CvBoost::REAL; //CvBoost::GENTLE;
    int weak_count = 100;
    double weight_trim_rate = 0.95;
    int max_depth = 1;
    bool use_surrogates = false;
    const float* priors = NULL;
    _params = CvBoostParams(boost_type, weak_count,weight_trim_rate,max_depth,use_surrogates,priors);

    Mat varType = Mat::ones(feature.cols+1,1,CV_8UC1) * CV_VAR_NUMERICAL; // all floats
    varType.at<uchar>(feature.cols,0) = CV_VAR_CATEGORICAL;

    //lab = lab*2-1;
    //cout << lab << endl;
    //lab.convertTo(lab,CV_8UC1);

    double t = (double)getTickCount();
    if(veb) cout << "Train (Gentle) AdaBoost model ...";
    _boost.train(feature,CV_ROW_SAMPLE,label,Mat(),Mat(),varType,Mat(),_params,false);
    t = (getTickCount()-t)/getTickFrequency();
    if(veb) cout << " time:" << t << " secs." << endl;
}

LcValidator LcAdaBoosting::predict( Mat & feature, Mat & res, Mat & label)
{
    int n = feature.rows;
    res = Mat::zeros( n, 1, CV_32FC1);

    for(int i = 0; i< n ; i++)
    {
        res.at<float>(i,0) = _boost.predict( feature.row(i) );
    }

    if( label.rows == feature.rows ) return LcValidator( res, label);
    else return LcValidator();
}

void LcAdaBoosting::save( string filename_prefix ){
    string filename = filename_prefix + "_ada.xml";
    _boost.save( filename.c_str() );
}

void LcAdaBoosting::load( string filename_prefix ){
    string filename = filename_prefix + "_ada.xml";
    _boost.load( filename.c_str() );
}

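The boosted classifier follows the same train/save/load/predict pattern; a hedged round-trip sketch, reusing the hypothetical features and labels from the earlier sketch:

    // sketch only: train 100 gentle-boost stumps, save, reload, and evaluate
    LcAdaBoosting ada;
    ada.train(features, labels);        // GENTLE boosting, weak_count = 100, max_depth = 1
    ada.save("hand_model");             // writes hand_model_ada.xml

    LcAdaBoosting ada2;
    ada2.load("hand_model");            // reads hand_model_ada.xml back into CvBoost
    Mat response;
    ada2.predict(features, response, labels).display();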
//==============================

LcKNN::LcKNN()
{
    rotation_kernel = Mat();
}

LcValidator LcKNN::predict(Mat & feature, Mat & res, Mat & label)
{
    // build a FLANN KD-tree over the stored training features on every call
    cv::flann::Index _flann(_feat, cv::flann::KDTreeIndexParams(4));

    int n = feature.rows;
    res = Mat::zeros( n, 1, CV_32FC1);

    Mat inds; Mat dists;

    _flann.knnSearch(feature, inds, dists, knn, cv::flann::SearchParams(64));

    for(int i = 0; i< n ; i++)
    {
        float sum_weight = 0.0f;
        float sum_ans = 0.0f;

        for(int k = 0;k< knn ;k++)
        {
            float m_weight = 1; //exp(- dists.at<float>(i,k)/scale);
            int & id = inds.at<int>(i,k);
            sum_weight += m_weight;
            sum_ans += m_weight * _lab.at<float>(id,0);
        }

        res.at<float>( i,0) = float( sum_ans/sum_weight);
    }

    if( label.rows == feature.rows ) return LcValidator( res, label);
    else return LcValidator();
}

void LcKNN::train(Mat & feature, Mat & label)
{
    knn = 5;

    feature.copyTo(_feat);
    label.copyTo(_lab);
}

void LcKNN::save( string filename_prefix ){
    string feature_name = filename_prefix + "_knn_feat.bin";
    lc::LcMat2Bin( feature_name.c_str(), _feat);
    string label_name = filename_prefix + "_knn_lab.bin";
    lc::LcMat2Bin( label_name.c_str(), _lab);
}

void LcKNN::load( string filename_prefix ){
    string feature_name = filename_prefix + "_knn_feat.bin";
    lc::LcBin2Mat( feature_name.c_str(), _feat);
    string label_name = filename_prefix + "_knn_lab.bin";
    lc::LcBin2Mat( label_name.c_str(), _lab);
}
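LcKNN simply memorizes the training set in train() and rebuilds a FLANN KD-tree inside every predict() call, averaging the labels of the 5 nearest neighbours with uniform weights. A hedged usage sketch, again reusing the hypothetical matrices from above (lc::LcMat2Bin and lc::LcBin2Mat are this repository's own binary Mat serializers, not OpenCV functions):

    // sketch only: nearest-neighbour regression over the stored training set
    LcKNN knn_clf;
    knn_clf.train(features, labels);    // copies the data; knn is fixed to 5
    knn_clf.save("hand_model");         // writes hand_model_knn_feat.bin and hand_model_knn_lab.bin

    Mat response;
    knn_clf.predict(features, response, labels).display();  // uniform-weight average of the 5 nearest labels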
