Skip to content

Commit a8bdae8

Browse files
authored
Merge pull request #40 from lienhua34/dist-base
调整支持原生 TensorFlow API 的 TaaS 分布式模型训练任务代码模板
2 parents 24fed83 + 87e1fba commit a8bdae8

File tree

1 file changed

+58
-8
lines changed

1 file changed

+58
-8
lines changed

caicloud.tensorflow/caicloud/clever/examples/dist-tf-template.py renamed to caicloud.tensorflow/caicloud/clever/templates/template-raw.py

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,16 @@
88
from caicloud.clever.tensorflow import dist_base
99
from caicloud.clever.tensorflow import model_exporter
1010

11+
tf.app.flags.DEFINE_string("export_dir",
12+
"/tmp/mnist/saved_model",
13+
"model export directory path.")
14+
tf.app.flags.DEFINE_string("checkpoint_dir",
15+
"",
16+
"model checkpoint directory path.")
17+
FLAGS = tf.app.flags.FLAGS
18+
19+
_train_op = None
20+
1121
def model_fn(sync, num_replicas):
1222
"""TensorFlow 模型定义函数。
1323
@@ -18,13 +28,33 @@ def model_fn(sync, num_replicas):
1828
- `sync`:当前是否采用参数同步更新模式。
1929
- `num_replicas`:分布式 TensorFlow 的计算节点(worker)个数。
2030
"""
31+
global _train_op
2132

2233
# TODO:添加业务模型定义操作。
23-
34+
# global_step = ...
35+
# _train_op = ...
36+
37+
# 添加模型评估配置:
38+
# accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
39+
# def accuracy_evalute_fn(session):
40+
# return session.run(accuracy, ...)
41+
# model_metric_ops = {
42+
# "accuracy": accuracy_evalute_fn
43+
# }
44+
45+
# 定义模型导出配置
46+
# model_export_spec = model_exporter.ModelExportSpec(
47+
# export_dir=FLAGS.export_dir,
48+
# input_tensors={"image": _input_images},
49+
# output_tensors={"logits": logits})
50+
2451
# model_fn 函数需要返回 ModelFnHandler 对象告知 TaaS 平台所构建的模型的一些信息,
2552
# 例如 global_step、优化器 Optimizer、模型评估指标以及模型导出的相关配置等等。
2653
# 详细信息请参考 docs.caicloud.io。
27-
return dist_base.ModelFnHandler()
54+
return dist_base.ModelFnHandler(
55+
global_step = global_step,
56+
model_metric_ops = model_metric_ops,
57+
model_export_spec = model_export_spec)
2858

2959
def train_fn(session, num_global_step):
3060
"""模型训练的每一轮操作。
@@ -38,21 +68,41 @@ def train_fn(session, num_global_step):
3868

3969
# TODO:添加业务模型训练操作。
4070

41-
# train_fn 函数返回一个 bool 值,用于告知 TaaS 平台是否要提前终止模型训练。返回 True,
42-
# 表示终止训练;否则,TaaS 将继续下一轮训练。
43-
# 例如,为了防止训练模型过拟合,在训练过程中定时使用验证数据评测模型效果。如果发现模型
44-
# 在训练数据集上的效果有优化,而在验证数据集上的效果却开始劣化,则说明模型可能出现了过
45-
# 拟合,此时我们就可以通过返回 True 来告知 TaaS 平台提前终止模型训练。
71+
# train_fn 函数返回一个 bool 值,用于告知 TaaS 平台是否要提前终止模型训练。
72+
# 返回 True,表示终止训练;否则,TaaS 将继续下一轮训练。
73+
# 例如,为了防止训练模型过拟合,在训练过程中定时使用验证数据评测模型效果。当模型效果
74+
# 达到预期效果,便可以通过返回 True 来结束模型训练。
4675
return False
4776

4877
def gen_init_fn():
4978
"""获取自定义初始化函数。
5079
5180
有些情况下,我需要从某个事先训练好的 checkpoint 文件中加载模型的参数。此时,我们需要自
5281
己实现使用 tf.Saver() 从该 checkpoint 中加载模型参数进行自定义初始化的函数。
82+
83+
注:如果不需要自定义初始化,可以不提供 gen_init_fn 实现,或者 gen_init_fn 返回 None。
5384
"""
54-
return None
5585

86+
# TODO: 添加自己的处理逻辑
87+
88+
# 定义 tf.train.Saver 会修改 TensorFlow 的 Graph 结构,
89+
# 而当 Base 框架调用自定义初始化函数 init_from_checkpoint 的时候,
90+
# TensorFlow 模型的 Graph 结构已经变成 finalized,不再允许修改 Graph 结构。
91+
# 所以,这个定义必须放在 init_from_checkpoint 函数外面。
92+
saver = tf.train.Saver(tf.trainable_variables())
93+
94+
def init_from_checkpoint(scaffold, sess):
95+
"""执行自定义初始化的函数。
96+
97+
TaaS 平台会优先从设置的日志保存路径中获取最新的 checkpoint 来 restore 模型参数,
98+
如果日志保存路径中找不到 checkpoint 文件,才会调用本函数来进行模型初始化。
99+
100+
本函数必须接收两个参数:
101+
- scafford: tf.train.Scaffold 对象;
102+
- sess: tf.Session 对象。
103+
"""
104+
saver.restore(sess, checkpoint_path)
105+
return init_from_checkpoint
56106

57107
def after_train_hook(session):
58108
"""模型训练操作。

0 commit comments

Comments
 (0)