Skip to content

Commit 4eae63a

Browse files
tensorflow.js
1 parent b826685 commit 4eae63a

File tree

19 files changed

+580
-26
lines changed

19 files changed

+580
-26
lines changed

20-tensorflow.js/.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
node_modules
2+
dist
3+
.cache

20-tensorflow.js/05-logistic-regression/script.js

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
* @Author: victorsun
33
* @Date: 2019-12-04 20:15:29
44
* @LastEditors: victorsun - csxiaoyao
5-
* @LastEditTime: 2020-03-22 00:05:54
5+
* @LastEditTime: 2020-03-22 16:24:51
66
* @Description: [email protected]
77
*/
88
import * as tf from '@tensorflow/tfjs';
@@ -31,11 +31,11 @@ window.onload = async () => {
3131

3232
// 2. 初始化神经网络模型
3333
const model = tf.sequential();
34-
// 添加层,dense: y=ax+b,设置激活函数sigmoid(防止输入超过100%,对过大过小值收敛,保证数据在 0 - 1 之间)
34+
// 添加层,dense: y=ax+b
3535
model.add(tf.layers.dense({
3636
units: 1, // 输出值为一个概率值,1个神经元即可
3737
inputShape: [2], // 坐标 x,y 两个值,特征数量为2
38-
activation: 'sigmoid' // 设置激活函数 sigmoid 0-1
38+
activation: 'sigmoid' // 设置激活函数 sigmoid 0-1,(防止输入超过100%,对过大过小值收敛,保证数据在 0 - 1 之间)
3939
}));
4040
// 设置损失函数和优化器
4141
model.compile({

20-tensorflow.js/11-mobilenet/script.js

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
* @Author: victorsun
33
* @Date: 2019-12-04 20:15:29
44
* @LastEditors: victorsun - csxiaoyao
5-
* @LastEditTime: 2020-03-22 16:10:04
5+
* @LastEditTime: 2020-03-22 16:17:47
66
* @Description: [email protected]
77
*/
88
import * as tf from '@tensorflow/tfjs';
@@ -15,6 +15,7 @@ import { file2img } from './utils';
1515
* 在 tensorflow.js 中调用web格式的模型文件
1616
*
1717
* 【 MobileNet模型 】
18+
* 图像分类模型
1819
* 卷积神经网络模型的一种,轻量、速度快,但是准确性一般
1920
*
2021
* 【 文件说明 】

20-tensorflow.js/12-brand/data.js

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* @Author: victorsun
3+
* @Date: 2019-12-04 20:15:29
4+
* @LastEditors: victorsun - csxiaoyao
5+
* @LastEditTime: 2020-03-22 16:32:52
6+
* @Description: 图片加载
7+
*/
8+
9+
const IMAGE_SIZE = 224;
10+
11+
const loadImg = (src) => {
12+
return new Promise(resolve => {
13+
const img = new Image();
14+
// 允许跨域
15+
img.crossOrigin = "anonymous";
16+
img.src = src;
17+
img.width = IMAGE_SIZE;
18+
img.height = IMAGE_SIZE;
19+
img.onload = () => resolve(img);
20+
});
21+
};
22+
export const getInputs = async () => {
23+
const loadImgs = []; // img - promise
24+
const labels = [];
25+
for (let i = 0; i < 30; i += 1) {
26+
['android', 'apple', 'windows'].forEach(label => {
27+
const src = `http://127.0.0.1:8080/brand/train/${label}-${i}.jpg`;
28+
const img = loadImg(src);
29+
loadImgs.push(img);
30+
labels.push([
31+
label === 'android' ? 1 : 0,
32+
label === 'apple' ? 1 : 0,
33+
label === 'windows' ? 1 : 0,
34+
]);
35+
});
36+
}
37+
const inputs = await Promise.all(loadImgs);
38+
return {
39+
inputs,
40+
labels,
41+
};
42+
}

20-tensorflow.js/12-brand/index.html

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
<script src="script.js"></script>
2+
<input type="file" onchange="predict(this.files[0])">
3+
<button onclick="download()">下载模型</button>

20-tensorflow.js/12-brand/script.js

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
* @Author: victorsun
3+
* @Date: 2019-12-04 20:15:29
4+
* @LastEditors: victorsun - csxiaoyao
5+
* @LastEditTime: 2020-03-22 20:13:25
6+
* @Description: [email protected]
7+
*/
8+
import * as tf from '@tensorflow/tfjs';
9+
import * as tfvis from '@tensorflow/tfjs-vis';
10+
import { getInputs } from './data';
11+
import { img2x, file2img } from './utils';
12+
13+
/**
14+
* 【 迁移学习 】
15+
* 把已训练好的模型参数迁移到新的模型来帮助新模型训练
16+
* 深度学习模型参数多,从头训练成本高
17+
* 删除原始模型的最后一层,基于此截断模型的输出训练一个新的(通常相当浅的)模型
18+
* 本案例,在 mobilenet 基础上,最后输出 ['android', 'apple', 'windows'] 三选一
19+
* 模型的保存
20+
*/
21+
const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';
22+
const NUM_CLASSES = 3;
23+
const BRAND_CLASSES = ['android', 'apple', 'windows'];
24+
25+
window.onload = async () => {
26+
// 1. 获取输入数据并在 visor 面板中展示
27+
const { inputs, labels } = await getInputs();
28+
const surface = tfvis.visor().surface({ name: '输入示例', styles: { height: 250 } });
29+
inputs.forEach(img => {
30+
surface.drawArea.appendChild(img);
31+
});
32+
33+
// 加载mobilenet 模型并截断 构建双层神经网络 截断模型作为输入,双层神经网络作为输出
34+
// 2. 模型迁移
35+
// 2.1 加载 mobilenet 模型, tfjs_layers_model 格式
36+
const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
37+
// 查看模型概况
38+
mobilenet.summary();
39+
40+
// 2.2 获取模型中间层并截断
41+
const layer = mobilenet.getLayer('conv_pw_13_relu'); // 根据层名获取层
42+
// 生成新的截断模型
43+
const truncatedMobilenet = tf.model({
44+
inputs: mobilenet.inputs,
45+
outputs: layer.output
46+
});
47+
48+
// 3. 构建双层神经网络,tensor数据从 mobilenet 模型 flow 到 构建到双层神经网络模型
49+
// 初始化神经网络模型
50+
const model = tf.sequential();
51+
// flatten输入
52+
model.add(tf.layers.flatten({
53+
inputShape: layer.outputShape.slice(1) // [null,7,7,256] => [7,7,256],null表示个数不定,此处删除
54+
}));
55+
// 双层神经网络
56+
model.add(tf.layers.dense({
57+
units: 10,
58+
activation: 'relu'
59+
}));
60+
model.add(tf.layers.dense({
61+
units: NUM_CLASSES, // 输出类别数量
62+
activation: 'softmax'
63+
}));
64+
65+
// 4. 训练
66+
// 4.1 定义损失函数和优化器
67+
model.compile({
68+
loss: 'categoricalCrossentropy', // 交叉熵
69+
optimizer: tf.train.adam()
70+
});
71+
// 4.2 数据预处理: 处理输入为截断模型接受的数据格式,即 mobilenet 接受的格式
72+
const { xs, ys } = tf.tidy(() => {
73+
// img2x: img 转 mobilenet 接受的tensor格式,并合并单个 tensor 为一个大 tensor
74+
const xs = tf.concat(inputs.map(imgEl => truncatedMobilenet.predict(img2x(imgEl))));
75+
const ys = tf.tensor(labels);
76+
return { xs, ys };
77+
});
78+
// 4.3 通过 fit 方法训练
79+
await model.fit(xs, ys, {
80+
epochs: 20,
81+
callbacks: tfvis.show.fitCallbacks(
82+
{ name: '训练效果' },
83+
['loss'],
84+
{ callbacks: ['onEpochEnd'] }
85+
)
86+
});
87+
88+
// 5. 迁移学习下的模型预测
89+
window.predict = async (file) => {
90+
const img = await file2img(file);
91+
document.body.appendChild(img);
92+
const pred = tf.tidy(() => {
93+
// img 转 tensor
94+
const x = img2x(img);
95+
// 截断模型先执行
96+
const input = truncatedMobilenet.predict(x);
97+
// 再用新模型预测出最终结果
98+
return model.predict(input);
99+
});
100+
const index = pred.argMax(1).dataSync()[0];
101+
setTimeout(() => {
102+
alert(`预测结果:${BRAND_CLASSES[index]}`);
103+
}, 0);
104+
};
105+
106+
// 6. 模型的保存 tfjs_layers_model
107+
// json + 权重bin
108+
window.download = async () => {
109+
await model.save('downloads://model');
110+
};
111+
};

20-tensorflow.js/12-brand/utils.js

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import * as tf from '@tensorflow/tfjs';
2+
3+
// img 转 mobilenet 接受的格式
4+
export function img2x(imgEl){
5+
return tf.tidy(() => {
6+
const input = tf.browser.fromPixels(imgEl)
7+
.toFloat()
8+
.sub(255 / 2)
9+
.div(255 / 2)
10+
.reshape([1, 224, 224, 3]);
11+
return input;
12+
});
13+
}
14+
15+
export function file2img(f) {
16+
return new Promise(resolve => {
17+
const reader = new FileReader();
18+
reader.readAsDataURL(f);
19+
reader.onload = (e) => {
20+
const img = document.createElement('img');
21+
img.src = e.target.result;
22+
img.width = 224;
23+
img.height = 224;
24+
img.onload = () => resolve(img);
25+
};
26+
});
27+
}

20-tensorflow.js/13-speech/index.html

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
<script src="script.js"></script>
2+
<style>
3+
#result>div {
4+
float: left;
5+
padding: 20px;
6+
}
7+
</style>
8+
<div id="result"></div>

20-tensorflow.js/13-speech/script.js

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* @Author: victorsun
3+
* @Date: 2019-12-04 20:15:29
4+
* @LastEditors: victorsun - csxiaoyao
5+
* @LastEditTime: 2020-03-22 18:03:26
6+
* @Description: [email protected]
7+
*/
8+
import * as speechCommands from '@tensorflow-models/speech-commands';
9+
10+
/**
11+
* 【 使用预训练模型进行语音识别 】
12+
* 语音识别的本质是分类
13+
*
14+
* $ npm i @tensorflow-models/speech-commands
15+
*/
16+
const MODEL_PATH = 'http://127.0.0.1:8080/speech';
17+
18+
window.onload = async () => {
19+
// 创建识别器
20+
const recognizer = speechCommands.create(
21+
'BROWSER_FFT', // 语音识别需要用到傅立叶变换,此处使用浏览器自带的傅立叶
22+
null, // 识别的单词,null为默认单词
23+
MODEL_PATH + '/model.json', // 模型
24+
MODEL_PATH + '/metadata.json' // 自定义源信息
25+
);
26+
// 确保模型加载
27+
await recognizer.ensureModelLoaded();
28+
29+
// 获取模型能够识别的单词
30+
const labels = recognizer.wordLabels().slice(2); // 去掉前两个无意义的单词
31+
console.log(labels);
32+
33+
// 绘制交互界面
34+
const resultEl = document.querySelector('#result');
35+
resultEl.innerHTML = labels.map(l => `
36+
<div>${l}</div>
37+
`).join('');
38+
39+
// 打开设备麦克风监听,可以不用编写 h5 代码
40+
recognizer.listen(result => {
41+
const { scores } = result;
42+
const maxValue = Math.max(...scores);
43+
const index = scores.indexOf(maxValue) - 2; // 去掉前两个无意义的单词
44+
resultEl.innerHTML = labels.map((l, i) => `
45+
<div style="background: ${i === index && 'green'}">${l}</div>
46+
`).join('');
47+
}, {
48+
overlapFactor: 0.3, // 识别频率
49+
probabilityThreshold: 0.9 // 识别阈值,超过指定的准确度即执行上面的回调
50+
});
51+
};
+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
<script src="script.js"></script>
2+
<button onclick="collect(this)">上一张</button>
3+
<button onclick="collect(this)">下一张</button>
4+
<button onclick="collect(this)">背景噪音</button>
5+
<button onclick="save()">保存</button>
6+
<pre id="count"></pre>
7+
<button onclick="train()">训练</button>
8+
<br><br>
9+
监听开关:<input type="checkbox" onchange="toggle(this.checked)">

0 commit comments

Comments
 (0)