8
8
from caicloud .clever .tensorflow import dist_base
9
9
from caicloud .clever .tensorflow import model_exporter
10
10
11
+ tf .app .flags .DEFINE_string ("export_dir" ,
12
+ "/tmp/mnist/saved_model" ,
13
+ "model export directory path." )
14
+ tf .app .flags .DEFINE_string ("checkpoint_dir" ,
15
+ "" ,
16
+ "model checkpoint directory path." )
17
+ FLAGS = tf .app .flags .FLAGS
18
+
19
+ _train_op = None
20
+
11
21
def model_fn (sync , num_replicas ):
12
22
"""TensorFlow 模型定义函数。
13
23
@@ -18,13 +28,33 @@ def model_fn(sync, num_replicas):
18
28
- `sync`:当前是否采用参数同步更新模式。
19
29
- `num_replicas`:分布式 TensorFlow 的计算节点(worker)个数。
20
30
"""
31
+ global _train_op
21
32
22
33
# TODO:添加业务模型定义操作。
23
-
34
+ # global_step = ...
35
+ # _train_op = ...
36
+
37
+ # 添加模型评估配置:
38
+ # accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
39
+ # def accuracy_evalute_fn(session):
40
+ # return session.run(accuracy, ...)
41
+ # model_metric_ops = {
42
+ # "accuracy": accuracy_evalute_fn
43
+ # }
44
+
45
+ # 定义模型导出配置
46
+ # model_export_spec = model_exporter.ModelExportSpec(
47
+ # export_dir=FLAGS.export_dir,
48
+ # input_tensors={"image": _input_images},
49
+ # output_tensors={"logits": logits})
50
+
24
51
# model_fn 函数需要返回 ModelFnHandler 对象告知 TaaS 平台所构建的模型的一些信息,
25
52
# 例如 global_step、优化器 Optimizer、模型评估指标以及模型导出的相关配置等等。
26
53
# 详细信息请参考 docs.caicloud.io。
27
- return dist_base .ModelFnHandler ()
54
+ return dist_base .ModelFnHandler (
55
+ global_step = global_step ,
56
+ model_metric_ops = model_metric_ops ,
57
+ model_export_spec = model_export_spec )
28
58
29
59
def train_fn (session , num_global_step ):
30
60
"""模型训练的每一轮操作。
@@ -38,21 +68,41 @@ def train_fn(session, num_global_step):
38
68
39
69
# TODO:添加业务模型训练操作。
40
70
41
- # train_fn 函数返回一个 bool 值,用于告知 TaaS 平台是否要提前终止模型训练。返回 True,
42
- # 表示终止训练;否则,TaaS 将继续下一轮训练。
43
- # 例如,为了防止训练模型过拟合,在训练过程中定时使用验证数据评测模型效果。如果发现模型
44
- # 在训练数据集上的效果有优化,而在验证数据集上的效果却开始劣化,则说明模型可能出现了过
45
- # 拟合,此时我们就可以通过返回 True 来告知 TaaS 平台提前终止模型训练。
71
+ # train_fn 函数返回一个 bool 值,用于告知 TaaS 平台是否要提前终止模型训练。
72
+ # 返回 True,表示终止训练;否则,TaaS 将继续下一轮训练。
73
+ # 例如,为了防止训练模型过拟合,在训练过程中定时使用验证数据评测模型效果。当模型效果
74
+ # 达到预期效果,便可以通过返回 True 来结束模型训练。
46
75
return False
47
76
48
77
def gen_init_fn ():
49
78
"""获取自定义初始化函数。
50
79
51
80
有些情况下,我需要从某个事先训练好的 checkpoint 文件中加载模型的参数。此时,我们需要自
52
81
己实现使用 tf.Saver() 从该 checkpoint 中加载模型参数进行自定义初始化的函数。
82
+
83
+ 注:如果不需要自定义初始化,可以不提供 gen_init_fn 实现,或者 gen_init_fn 返回 None。
53
84
"""
54
- return None
55
85
86
+ # TODO: 添加自己的处理逻辑
87
+
88
+ # 定义 tf.train.Saver 会修改 TensorFlow 的 Graph 结构,
89
+ # 而当 Base 框架调用自定义初始化函数 init_from_checkpoint 的时候,
90
+ # TensorFlow 模型的 Graph 结构已经变成 finalized,不再允许修改 Graph 结构。
91
+ # 所以,这个定义必须放在 init_from_checkpoint 函数外面。
92
+ saver = tf .train .Saver (tf .trainable_variables ())
93
+
94
+ def init_from_checkpoint (scaffold , sess ):
95
+ """执行自定义初始化的函数。
96
+
97
+ TaaS 平台会优先从设置的日志保存路径中获取最新的 checkpoint 来 restore 模型参数,
98
+ 如果日志保存路径中找不到 checkpoint 文件,才会调用本函数来进行模型初始化。
99
+
100
+ 本函数必须接收两个参数:
101
+ - scafford: tf.train.Scaffold 对象;
102
+ - sess: tf.Session 对象。
103
+ """
104
+ saver .restore (sess , checkpoint_path )
105
+ return init_from_checkpoint
56
106
57
107
def after_train_hook (session ):
58
108
"""模型训练操作。
0 commit comments