
Commit 6dba636

author 王小辉 committed
Updated XGBoost training script
1 parent 49b089c commit 6dba636

File tree: 2 files changed, +280 -160 lines

Dataset/README.md (+136, -133)
@@ -5,18 +5,19 @@

You can download the dataset from either of the links below and extract it into the `dataset` directory:

-- [Google Drive](https://drive.google.com/open?id=0B2hKiPsUlgibMmQ0TWJHTjBmQXc)
-- [Baidu 网盘](https://pan.baidu.com/s/1qYntFR2) password: xnae
+- [Google Drive](https://drive.google.com/open?id=0B2hKiPsUlgibNElYNmFEWmFpbjA)
+- [Baidu 网盘](https://pan.baidu.com/s/1eSb3X0u) password: dgyf

### Data Structure

-The dataset contains sensor data for the following five activities:
+The dataset contains sensor data for the following six activities:

- Walking
- Running
- Bus
- Subway
- Car
+- Stationary

Because data collection is costly, and to keep later experiments flexible, the sensor data was recorded at 100 Hz; during processing it can be downsampled (resampled at a lower rate) to test how well the recognition model performs at lower sampling frequencies. The sensor types used are:
@@ -27,23 +28,23 @@

From the accelerometer, gyroscope, and magnetic-field sensor data we can compute the rotation matrix that maps the device (body) coordinate system to the world coordinate system. Applying this rotation matrix converts the body-frame acceleration into world-frame acceleration, which removes the effect that different device orientations (across different activities) would otherwise have on the sensor data:

```java
public static void calculateWorldAcce(SensorData sd) {
    float[] Rotate = new float[16];
    float[] I = new float[16];
    float[] currOrientation = new float[3];
    // Rotation matrix from the device frame to the world frame,
    // derived from the gravity and magnetic-field readings.
    SensorManager.getRotationMatrix(Rotate, I, sd.gravity, sd.magnetic);
    SensorManager.getOrientation(Rotate, currOrientation);
    System.arraycopy(currOrientation, 0, sd.orientation, 0, 3);

    float[] relativeAcc = new float[4];
    float[] earthAcc = new float[4];
    float[] inv = new float[16];
    System.arraycopy(sd.accelerate, 0, relativeAcc, 0, 3);
    relativeAcc[3] = 0;
    // Invert the rotation matrix and apply it to the body-frame acceleration
    // to obtain the acceleration in world coordinates.
    android.opengl.Matrix.invertM(inv, 0, Rotate, 0);
    android.opengl.Matrix.multiplyMV(earthAcc, 0, inv, 0, relativeAcc, 0);
    System.arraycopy(earthAcc, 0, sd.world_accelerometer, 0, 3);
}
```

@@ -58,50 +59,50 @@

During downsampling, duplicate readings can be filtered out at the same time:

```python
def get_resample_dataset(file_path):
    re_sampled = []
    with open(file_path, "r") as lines:
        index = 0
        last_value = ""
        for line in lines:
            index += 1
            # Keep every 5th valid line (e.g. resampling 100 Hz down to 20 Hz).
            if index == 5:
                values = line.split(",")
                if len(values) == 6:
                    current_value = "{},{},{}".format(values[3], values[4], values[5])
                    # Skip consecutive duplicate readings.
                    if current_value != last_value:
                        re_sampled.append(current_value)
                        last_value = current_value
                        index = 0
                    else:
                        index -= 1
                else:
                    index -= 1
    print("\tAfter re-sampling, the count of the lines are: {}".format(len(re_sampled)))
    return re_sampled
```

Additionally, the sensor data collected for activity recognition is a continuous time series. To make recognition more responsive, the data can be grouped into half-overlapping windows, so that the first half of each group is identical to the second half of the previous group; this way a recognition result can be produced after only half a window of new data:

```python
def get_half_overlap_dataset(dataset):
    overlapped = []
    # Slide a window of batch_size samples forward by half a window at a time.
    for i in range(0, len(dataset) - batch_size, batch_size // 2):
        overlapped.append(dataset[i: i + batch_size])
    print("\tThe number of the groups after half-overlapping is: {}".format(len(overlapped)))
    return overlapped
```
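The snippet above relies on a module-level `batch_size` (the window length in samples), which is not shown in this hunk. A minimal, purely illustrative driver might look like the sketch below; the value 40 (a 2-second window at the 20 Hz resampled rate) and the file name are assumptions, not taken from the repository:

```python
# Illustrative values only: batch_size and the file path are assumptions.
batch_size = 40  # 2-second window at the 20 Hz resampled rate

walking_samples = get_resample_dataset("dataset/walking.csv")  # hypothetical file name
walking_windows = get_half_overlap_dataset(walking_samples)
```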
Next the data is split at random: 70% is used for training and 30% for testing. An alternative cross-validation scheme is to split the data of each activity into 10 parts (each activity contributes roughly the same amount of data, so the dataset is essentially balanced), repeatedly draw 7 parts at random for training and use the remaining 3 for testing, and then average the model's precision and recall over the repetitions; a sketch of that repeated-split variant is given after the snippet below.

```python
import numpy as np


def split_train_test(dataset):
    total_dataset = np.array(dataset)
    # Boolean mask: roughly 70% of the rows go to the training set.
    train_test_split = np.random.rand(len(total_dataset)) < 0.70
    train_dataset = total_dataset[train_test_split]
    test_dataset = total_dataset[~train_test_split]
    print("\t\tCount of train dataset: {}\n\t\tCount of test dataset: {}".format(len(train_dataset), len(test_dataset)))
    return train_dataset.tolist(), test_dataset.tolist()
```
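The repeated 7/3 split described above is not implemented in the snippet; a minimal sketch, assuming the samples are grouped per activity and that a `train_and_evaluate(train, test)` helper (hypothetical) returns precision and recall, could look like this:

```python
import random


def repeated_split_cv(samples_by_activity, rounds=10, train_parts=7, total_parts=10):
    """samples_by_activity: dict mapping an activity name to its list of samples."""
    precisions, recalls = [], []
    for _ in range(rounds):
        train, test = [], []
        for samples in samples_by_activity.values():
            # Split this activity's samples into total_parts chunks and draw
            # train_parts of them for training; the rest are used for testing.
            chunks = [samples[i::total_parts] for i in range(total_parts)]
            train_idx = set(random.sample(range(total_parts), train_parts))
            for i, chunk in enumerate(chunks):
                (train if i in train_idx else test).extend(chunk)
        precision, recall = train_and_evaluate(train, test)  # hypothetical helper
        precisions.append(precision)
        recalls.append(recall)
    return sum(precisions) / rounds, sum(recalls) / rounds
```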

### Feature Extraction
@@ -126,51 +127,53 @@

We train the model with XGBoost, an algorithm that is very popular in Kaggle competitions:

```python
import os

import numpy as np


def xgTestSelfDataset(train_X, train_Y, test_X, test_Y):
    import xgboost as xgb
    import time

    # Labels need to be 0 .. num_class - 1; the updated dataset has 6 classes.
    xg_train = xgb.DMatrix(train_X, label=train_Y)
    xg_test = xgb.DMatrix(test_X, label=test_Y)
    # Setup parameters for xgboost.
    param = {'objective': 'multi:softprob',
             'eta': 0.15,
             'max_depth': 6,
             'silent': 1,
             'num_class': 6,
             "n_estimators": 1000,
             "subsample": 0.7,
             "scale_pos_weight": 0.5,
             "seed": 32}

    watchlist = [(xg_train, 'train'), (xg_test, 'test')]
    num_round = 50

    start = time.time()
    bst = xgb.train(param, xg_train, num_round, watchlist)
    trainDuration = time.time() - start
    start = time.time()
    # One probability per class for every test sample.
    yprob = bst.predict(xg_test).reshape(test_Y.shape[0], 6)
    testDuration = time.time() - start
    ylabel = np.argmax(yprob, axis=1)

    if os.path.exists("rhar.model"):
        os.remove("rhar.model")
    bst.save_model("rhar.model")
```
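The precision, recall, f1 score, and confusion matrix reported below can be derived from `ylabel` and `test_Y`. The original script does not show that step; a minimal sketch using scikit-learn (an assumption — the project may compute these differently, and the averaging mode is a guess) would be:

```python
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score

# ylabel: predicted class indices, test_Y: true class indices.
print("Precision", precision_score(test_Y, ylabel, average='weighted'))
print("Recall", recall_score(test_Y, ylabel, average='weighted'))
print("f1_score", f1_score(test_Y, ylabel, average='weighted'))
print("confusion_matrix")
print(confusion_matrix(test_Y, ylabel))
```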
After training we save the resulting model, which is later loaded and used from the Android code. For this dataset we obtained the following metrics:

-    Precision 0.959531251114
-    Recall 0.959706959707
-    f1_score 0.959480543169
+    Precision 0.909969288145
+    Recall 0.908256880734
+    f1_score 0.90816711949
     confusion_matrix
-    [[59  1  0  0  0]
-     [ 0 70  0  0  0]
-     [ 0  0 42  0  3]
-     [ 0  0  0 52  1]
-     [ 0  0  4  2 39]]
+    [[ 93   0   1   0   0   0]
+     [  0 115   1   1   1   0]
+     [  0   0 102   3   5   0]
+     [  3   0   9  84   6   2]
+     [  2   0   7  11  85   1]
+     [  0   0   4   3   0 115]]
+    predicting, classification error=0.091743
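As a sanity check before bundling `rhar.model` into the Android assets, the saved file can be reloaded in Python. This is a minimal sketch and not part of the original pipeline; `FEATURE_COUNT` is a placeholder mirroring `SensorFeature.FEATURE_COUNT` from the Android code:

```python
import numpy as np
import xgboost as xgb

FEATURE_COUNT = 40  # placeholder; must match the number of features used in training

bst = xgb.Booster()
bst.load_model("rhar.model")

# One feature vector per window; the output is one probability per activity class.
window_features = np.random.rand(1, FEATURE_COUNT)
print(bst.predict(xgb.DMatrix(window_features)))
```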

### Inference in the Android app

@@ -181,56 +184,56 @@ XGBoost's official Java implementation calls a native module through JNI, so here we …

Our XGBoost classifier is implemented as:

```java
public class XGBoostClassifier implements ClassifierInterface {

    private Predictor predictor;
    private double[] features;

    public static final String TYPE = "xgboost";

    public XGBoostClassifier(Context ctx) {
        try {
            InputStream is = ctx.getAssets().open("rhar.model");
            predictor = new Predictor(is);
            is.close();
        } catch (Throwable t) {
            t.printStackTrace();
        }
    }

    /**
     * Extract and select features from the raw sensor data points.
     * These data points are collected with certain sampling frequency and windows.
     * @param sensorData Raw sensor data points.
     * @return Extracted features.
     */
    private double[] prepareFeatures(SensorData[] sensorData, final int sampleFreq, final int sampleCount) {
        double[] matrix = new double[SensorFeature.FEATURE_COUNT];
        Feature aFeature = new Feature();
        aFeature.extractFeatures(sensorData, sampleFreq, sampleCount);
        System.arraycopy(aFeature.getFeaturesAsArray(), 0, matrix, 0, SensorFeature.FEATURE_COUNT);
        return matrix;
    }

    /**
     * Recognize the current human activity from the raw sensor data points.
     * @param sensorData Raw sensor data points.
     */
    @Override
    public double[] recognize(SensorData[] sensorData, final int sampleFreq, final int sampleCount) {
        features = prepareFeatures(sensorData, sampleFreq, sampleCount);
        return predict();
    }

    @Override
    public double[] getCurrentFeatures() {
        return features;
    }

    private double[] predict() {
        FVec vector = FVec.Transformer.fromArray(features, true);
        return predictor.predict(vector);
    }
}
```
