Skip to content

Commit f52058c

Browse files
authored
Add files via upload
1 parent 71d0dc4 commit f52058c

10 files changed

+1647
-0
lines changed

Diff for: Resnet/README.md

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# tensorflow_models_learning
2+
> 老铁要是觉得不错,给个“star”
3+
## 1.生成record训练数据
4+
dataset已经包含了训练和测试的图片,请直接运行create_tf_record.py</br>
5+
> 对于InceptionNet V1:设置resize_height和resize_width = 224 </br>
6+
> 对于InceptionNet V3:设置resize_height和resize_width = 299 </br>
7+
> 其他模型,请根据输入需要设置resize_height和resize_width的大小</br>
8+
9+
```
10+
if __name__ == '__main__':
11+
# 参数设置
12+
resize_height = 224 # 指定存储图片高度
13+
resize_width = 224 # 指定存储图片宽度
14+
shuffle=True
15+
log=5
16+
# 产生train.record文件
17+
image_dir='dataset/train'
18+
train_labels = 'dataset/train.txt' # 图片路径
19+
train_record_output = 'dataset/record/train{}.tfrecords'.format(resize_height)
20+
create_records(image_dir,train_labels, train_record_output, resize_height, resize_width,shuffle,log)
21+
train_nums=get_example_nums(train_record_output)
22+
print("save train example nums={}".format(train_nums))
23+
24+
# 产生val.record文件
25+
image_dir='dataset/val'
26+
val_labels = 'dataset/val.txt' # 图片路径
27+
val_record_output = 'dataset/record/val{}.tfrecords'.format(resize_height)
28+
create_records(image_dir,val_labels, val_record_output, resize_height, resize_width,shuffle,log)
29+
val_nums=get_example_nums(val_record_output)
30+
print("save val example nums={}".format(val_nums))
31+
32+
# 测试显示函数
33+
# disp_records(train_record_output,resize_height, resize_width)
34+
batch_test(train_record_output,resize_height, resize_width)
35+
36+
```
37+
## 2.训练过程
38+
目前提供VGG、inception_v1、inception_v3、mobilenet_v以及resnet_v1的训练文件,只需要生成tfrecord数据,即可开始训练
39+
> 训练VGG请直接运行:vgg_train_val.py </br>
40+
> 训练inception_v1请直接运行:inception_v1_train_val.py </br>
41+
> 训练inception_v3请直接运行:inception_v3_train_val.py </br>
42+
> 训练mobilenet_v1请直接运行:mobilenet_train_val.py </br>
43+
> 其他模型,请参考训练文件进行修改</br>
44+
45+
## 3.资源下载
46+
- 本项目详细说明,请参考鄙人博客资料:
47+
> 《使用自己的数据集训练GoogLenet InceptionNet V1 V2 V3模型》: https://panjinquan.blog.csdn.net/article/details/81560537 </br>
48+
> 《tensorflow实现将ckpt转pb文件》: https://panjinquan.blog.csdn.net/article/details/82218092 </br>
49+
> 《使用自己的数据集训练MobileNet、ResNet实现图像分类(TensorFlow)》https://panjinquan.blog.csdn.net/article/details/88252699
50+
> 预训练模型下载地址: https://download.csdn.net/download/guyuealian/10610847 </br>
51+
- 老铁要是觉得不错,给个“star”
52+
- tensorflow-gpu==1.4.0

Diff for: Resnet/convert_pb.py

+129
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# -*-coding: utf-8 -*-
2+
"""
3+
@Project: tensorflow_models_nets
4+
@File : convert_pb.py
5+
@Author : panjq
6+
7+
@Date : 2018-08-29 17:46:50
8+
@info :
9+
-通过传入 CKPT 模型的路径得到模型的图和变量数据
10+
-通过 import_meta_graph 导入模型中的图
11+
-通过 saver.restore 从模型中恢复图中各个变量的数据
12+
-通过 graph_util.convert_variables_to_constants 将模型持久化
13+
"""
14+
15+
import tensorflow as tf
16+
from create_tf_record import *
17+
from tensorflow.python.framework import graph_util
18+
19+
resize_height = 299 # 指定图片高度
20+
resize_width = 299 # 指定图片宽度
21+
depths = 3
22+
23+
def freeze_graph_test(pb_path, image_path):
24+
'''
25+
:param pb_path:pb文件的路径
26+
:param image_path:测试图片的路径
27+
:return:
28+
'''
29+
with tf.Graph().as_default():
30+
output_graph_def = tf.GraphDef()
31+
with open(pb_path, "rb") as f:
32+
output_graph_def.ParseFromString(f.read())
33+
tf.import_graph_def(output_graph_def, name="")
34+
with tf.Session() as sess:
35+
sess.run(tf.global_variables_initializer())
36+
37+
# 定义输入的张量名称,对应网络结构的输入张量
38+
# input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
39+
input_image_tensor = sess.graph.get_tensor_by_name("input:0")
40+
input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
41+
input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
42+
43+
# 定义输出的张量名称
44+
output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
45+
46+
# 读取测试图片
47+
im=read_image(image_path,resize_height,resize_width,normalization=True)
48+
im=im[np.newaxis,:]
49+
# 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
50+
# out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
51+
out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
52+
input_keep_prob_tensor:1.0,
53+
input_is_training_tensor:False})
54+
print("out:{}".format(out))
55+
score = tf.nn.softmax(out, name='pre')
56+
class_id = tf.argmax(score, 1)
57+
print("pre class_id:{}".format(sess.run(class_id)))
58+
59+
60+
def freeze_graph(input_checkpoint,output_graph):
61+
'''
62+
63+
:param input_checkpoint:
64+
:param output_graph: PB模型保存路径
65+
:return:
66+
'''
67+
# checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
68+
# input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
69+
70+
# 指定输出的节点名称,该节点名称必须是原模型中存在的节点
71+
output_node_names = "InceptionV3/Logits/SpatialSqueeze"
72+
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
73+
74+
with tf.Session() as sess:
75+
saver.restore(sess, input_checkpoint) #恢复图并得到数据
76+
output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
77+
sess=sess,
78+
input_graph_def=sess.graph_def,# 等于:sess.graph_def
79+
output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
80+
81+
with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
82+
f.write(output_graph_def.SerializeToString()) #序列化输出
83+
print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
84+
85+
# for op in sess.graph.get_operations():
86+
# print(op.name, op.values())
87+
88+
def freeze_graph2(input_checkpoint,output_graph):
89+
'''
90+
91+
:param input_checkpoint:
92+
:param output_graph: PB模型保存路径
93+
:return:
94+
'''
95+
# checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
96+
# input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
97+
98+
# 指定输出的节点名称,该节点名称必须是原模型中存在的节点
99+
output_node_names = "InceptionV3/Logits/SpatialSqueeze"
100+
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
101+
graph = tf.get_default_graph() # 获得默认的图
102+
input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
103+
104+
with tf.Session() as sess:
105+
saver.restore(sess, input_checkpoint) #恢复图并得到数据
106+
output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
107+
sess=sess,
108+
input_graph_def=input_graph_def,# 等于:sess.graph_def
109+
output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
110+
111+
with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
112+
f.write(output_graph_def.SerializeToString()) #序列化输出
113+
print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
114+
115+
# for op in graph.get_operations():
116+
# print(op.name, op.values())
117+
118+
119+
if __name__ == '__main__':
120+
# 输入ckpt模型路径
121+
input_checkpoint='models/model.ckpt-10000'
122+
# 输出pb模型的路径
123+
out_pb_path="models/pb/frozen_model.pb"
124+
# 调用freeze_graph将ckpt转为pb
125+
freeze_graph(input_checkpoint,out_pb_path)
126+
127+
# 测试pb模型
128+
image_path = 'test_image/animal.jpg'
129+
freeze_graph_test(pb_path=out_pb_path, image_path=image_path)

Diff for: Resnet/create_labels_files.py

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# -*-coding:utf-8-*-
2+
"""
3+
@Project: googlenet_classification
4+
@File : create_labels_files.py
5+
@Author : panjq
6+
7+
@Date : 2018-08-11 10:15:28
8+
"""
9+
10+
import os
11+
import os.path
12+
13+
14+
def write_txt(content, filename, mode='w'):
15+
"""保存txt数据
16+
:param content:需要保存的数据,type->list
17+
:param filename:文件名
18+
:param mode:读写模式:'w' or 'a'
19+
:return: void
20+
"""
21+
with open(filename, mode) as f:
22+
for line in content:
23+
str_line = ""
24+
for col, data in enumerate(line):
25+
if not col == len(line) - 1:
26+
# 以空格作为分隔符
27+
str_line = str_line + str(data) + " "
28+
else:
29+
# 每行最后一个数据用换行符“\n”
30+
str_line = str_line + str(data) + "\n"
31+
f.write(str_line)
32+
33+
34+
def get_files_list(dir):
35+
'''
36+
实现遍历dir目录下,所有文件(包含子文件夹的文件)
37+
:param dir:指定文件夹目录
38+
:return:包含所有文件的列表->list
39+
'''
40+
# parent:父目录, filenames:该目录下所有文件夹,filenames:该目录下的文件名
41+
files_list = []
42+
for parent, dirnames, filenames in os.walk(dir):
43+
for filename in filenames:
44+
# print("parent is: " + parent)
45+
# print("filename is: " + filename)
46+
# print(os.path.join(parent, filename)) # 输出rootdir路径下所有文件(包含子文件)信息
47+
curr_file = parent.split(os.sep)[-1]
48+
if curr_file == 'flower':
49+
labels = 0
50+
elif curr_file == 'guitar':
51+
labels = 1
52+
elif curr_file == 'animal':
53+
labels = 2
54+
elif curr_file == 'houses':
55+
labels = 3
56+
elif curr_file == 'plane':
57+
labels = 4
58+
files_list.append([os.path.join(curr_file, filename), labels])
59+
return files_list
60+
61+
62+
if __name__ == '__main__':
63+
train_dir = 'dataset/train'
64+
train_txt = 'dataset/train.txt'
65+
train_data = get_files_list(train_dir)
66+
write_txt(train_data, train_txt, mode='w')
67+
68+
val_dir = 'dataset/val'
69+
val_txt = 'dataset/val.txt'
70+
val_data = get_files_list(val_dir)
71+
write_txt(val_data, val_txt, mode='w')
72+

0 commit comments

Comments
 (0)