1
+ /*
2
+ * @Author : victorsun
3
+ * @Date : 2019-12-04 20:15:29
4
+ * @LastEditors : victorsun - csxiaoyao
5
+ * @LastEditTime : 2020-03-22 20:13:25
6
+
7
+ */
8
+ import * as tf from '@tensorflow/tfjs' ;
9
+ import * as tfvis from '@tensorflow/tfjs-vis' ;
10
+ import { getInputs } from './data' ;
11
+ import { img2x , file2img } from './utils' ;
12
+
13
+ /**
14
+ * 【 迁移学习 】
15
+ * 把已训练好的模型参数迁移到新的模型来帮助新模型训练
16
+ * 深度学习模型参数多,从头训练成本高
17
+ * 删除原始模型的最后一层,基于此截断模型的输出训练一个新的(通常相当浅的)模型
18
+ * 本案例,在 mobilenet 基础上,最后输出 ['android', 'apple', 'windows'] 三选一
19
+ * 模型的保存
20
+ */
21
+ const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json' ;
22
+ const NUM_CLASSES = 3 ;
23
+ const BRAND_CLASSES = [ 'android' , 'apple' , 'windows' ] ;
24
+
25
+ window . onload = async ( ) => {
26
+ // 1. 获取输入数据并在 visor 面板中展示
27
+ const { inputs, labels } = await getInputs ( ) ;
28
+ const surface = tfvis . visor ( ) . surface ( { name : '输入示例' , styles : { height : 250 } } ) ;
29
+ inputs . forEach ( img => {
30
+ surface . drawArea . appendChild ( img ) ;
31
+ } ) ;
32
+
33
+ // 加载mobilenet 模型并截断 构建双层神经网络 截断模型作为输入,双层神经网络作为输出
34
+ // 2. 模型迁移
35
+ // 2.1 加载 mobilenet 模型, tfjs_layers_model 格式
36
+ const mobilenet = await tf . loadLayersModel ( MOBILENET_MODEL_PATH ) ;
37
+ // 查看模型概况
38
+ mobilenet . summary ( ) ;
39
+
40
+ // 2.2 获取模型中间层并截断
41
+ const layer = mobilenet . getLayer ( 'conv_pw_13_relu' ) ; // 根据层名获取层
42
+ // 生成新的截断模型
43
+ const truncatedMobilenet = tf . model ( {
44
+ inputs : mobilenet . inputs ,
45
+ outputs : layer . output
46
+ } ) ;
47
+
48
+ // 3. 构建双层神经网络,tensor数据从 mobilenet 模型 flow 到 构建到双层神经网络模型
49
+ // 初始化神经网络模型
50
+ const model = tf . sequential ( ) ;
51
+ // flatten输入
52
+ model . add ( tf . layers . flatten ( {
53
+ inputShape : layer . outputShape . slice ( 1 ) // [null,7,7,256] => [7,7,256],null表示个数不定,此处删除
54
+ } ) ) ;
55
+ // 双层神经网络
56
+ model . add ( tf . layers . dense ( {
57
+ units : 10 ,
58
+ activation : 'relu'
59
+ } ) ) ;
60
+ model . add ( tf . layers . dense ( {
61
+ units : NUM_CLASSES , // 输出类别数量
62
+ activation : 'softmax'
63
+ } ) ) ;
64
+
65
+ // 4. 训练
66
+ // 4.1 定义损失函数和优化器
67
+ model . compile ( {
68
+ loss : 'categoricalCrossentropy' , // 交叉熵
69
+ optimizer : tf . train . adam ( )
70
+ } ) ;
71
+ // 4.2 数据预处理: 处理输入为截断模型接受的数据格式,即 mobilenet 接受的格式
72
+ const { xs, ys } = tf . tidy ( ( ) => {
73
+ // img2x: img 转 mobilenet 接受的tensor格式,并合并单个 tensor 为一个大 tensor
74
+ const xs = tf . concat ( inputs . map ( imgEl => truncatedMobilenet . predict ( img2x ( imgEl ) ) ) ) ;
75
+ const ys = tf . tensor ( labels ) ;
76
+ return { xs, ys } ;
77
+ } ) ;
78
+ // 4.3 通过 fit 方法训练
79
+ await model . fit ( xs , ys , {
80
+ epochs : 20 ,
81
+ callbacks : tfvis . show . fitCallbacks (
82
+ { name : '训练效果' } ,
83
+ [ 'loss' ] ,
84
+ { callbacks : [ 'onEpochEnd' ] }
85
+ )
86
+ } ) ;
87
+
88
+ // 5. 迁移学习下的模型预测
89
+ window . predict = async ( file ) => {
90
+ const img = await file2img ( file ) ;
91
+ document . body . appendChild ( img ) ;
92
+ const pred = tf . tidy ( ( ) => {
93
+ // img 转 tensor
94
+ const x = img2x ( img ) ;
95
+ // 截断模型先执行
96
+ const input = truncatedMobilenet . predict ( x ) ;
97
+ // 再用新模型预测出最终结果
98
+ return model . predict ( input ) ;
99
+ } ) ;
100
+ const index = pred . argMax ( 1 ) . dataSync ( ) [ 0 ] ;
101
+ setTimeout ( ( ) => {
102
+ alert ( `预测结果:${ BRAND_CLASSES [ index ] } ` ) ;
103
+ } , 0 ) ;
104
+ } ;
105
+
106
+ // 6. 模型的保存 tfjs_layers_model
107
+ // json + 权重bin
108
+ window . download = async ( ) => {
109
+ await model . save ( 'downloads://model' ) ;
110
+ } ;
111
+ } ;
0 commit comments