Skip to content

Commit caa1095

Browse files
perhapszzyperhapszzy
perhapszzy
authored and
perhapszzy
committed
add more examples
1 parent ab35000 commit caa1095

9 files changed

+1586
-30
lines changed

.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ Deep_Learning_with_TensorFlow/datasets/flower_processed_data.npy
77
Deep_Learning_with_TensorFlow/1.4.0/Chapter05/5. MNIST\346\234\200\344\275\263\345\256\236\350\267\265/MNIST_model/*
88
Deep_Learning_with_TensorFlow/1.4.0/Chapter07/output.tfrecords
99
Deep_Learning_with_TensorFlow/1.4.0/Chapter07/data.tfrecords*
10+
Deep_Learning_with_TensorFlow/1.4.0/Chapter07/output_test.tfrecords
11+
Deep_Learning_with_TensorFlow/1.4.0/Chapter07/test1.txt
12+
Deep_Learning_with_TensorFlow/1.4.0/Chapter07/test2.txt
13+
Deep_Learning_with_TensorFlow/1.4.0/Chapter08/sin.png
1014
Deep_Learning_with_TensorFlow/1.4.0/Chapter10/log/*
1115
Deep_Learning_with_TensorFlow/1.4.0/Chapter11/log/*
1216
.DS_Store

Deep_Learning_with_TensorFlow/1.4.0/Chapter07/.ipynb_checkpoints/1. TFRecord样例程序-checkpoint.ipynb

+32-15
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@
3333
"Extracting ../../datasets/MNIST_data/train-labels-idx1-ubyte.gz\n",
3434
"Extracting ../../datasets/MNIST_data/t10k-images-idx3-ubyte.gz\n",
3535
"Extracting ../../datasets/MNIST_data/t10k-labels-idx1-ubyte.gz\n",
36-
"TFRecord文件已保存。\n"
36+
"TFRecord训练文件已保存。\n",
37+
"TFRecord测试文件已保存。\n"
3738
]
3839
}
3940
],
@@ -45,27 +46,43 @@
4546
"def _bytes_feature(value):\n",
4647
" return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))\n",
4748
"\n",
48-
"# 读取mnist数据。\n",
49+
"# 将数据转化为tf.train.Example格式。\n",
50+
"def _make_example(pixels, label, image):\n",
51+
" image_raw = image.tostring()\n",
52+
" example = tf.train.Example(features=tf.train.Features(feature={\n",
53+
" 'pixels': _int64_feature(pixels),\n",
54+
" 'label': _int64_feature(np.argmax(label)),\n",
55+
" 'image_raw': _bytes_feature(image_raw)\n",
56+
" }))\n",
57+
" return example\n",
58+
"\n",
59+
"# 读取mnist训练数据。\n",
4960
"mnist = input_data.read_data_sets(\"../../datasets/MNIST_data\",dtype=tf.uint8, one_hot=True)\n",
5061
"images = mnist.train.images\n",
5162
"labels = mnist.train.labels\n",
5263
"pixels = images.shape[1]\n",
5364
"num_examples = mnist.train.num_examples\n",
5465
"\n",
55-
"# 输出TFRecord文件的地址。\n",
56-
"filename = \"output.tfrecords\"\n",
57-
"writer = tf.python_io.TFRecordWriter(filename)\n",
58-
"for index in range(num_examples):\n",
59-
" image_raw = images[index].tostring()\n",
66+
"# 输出包含训练数据的TFRecord文件。\n",
67+
"with tf.python_io.TFRecordWriter(\"output.tfrecords\") as writer:\n",
68+
" for index in range(num_examples):\n",
69+
" example = _make_example(pixels, labels[index], images[index])\n",
70+
" writer.write(example.SerializeToString())\n",
71+
"print(\"TFRecord训练文件已保存。\")\n",
6072
"\n",
61-
" example = tf.train.Example(features=tf.train.Features(feature={\n",
62-
" 'pixels': _int64_feature(pixels),\n",
63-
" 'label': _int64_feature(np.argmax(labels[index])),\n",
64-
" 'image_raw': _bytes_feature(image_raw)\n",
65-
" }))\n",
66-
" writer.write(example.SerializeToString())\n",
67-
"writer.close()\n",
68-
"print \"TFRecord文件已保存。\""
73+
"# 读取mnist测试数据。\n",
74+
"images_test = mnist.test.images\n",
75+
"labels_test = mnist.test.labels\n",
76+
"pixels_test = images_test.shape[1]\n",
77+
"num_examples_test = mnist.test.num_examples\n",
78+
"\n",
79+
"# 输出包含测试数据的TFRecord文件。\n",
80+
"with tf.python_io.TFRecordWriter(\"output_test.tfrecords\") as writer:\n",
81+
" for index in range(num_examples_test):\n",
82+
" example = _make_example(\n",
83+
" pixels_test, labels_test[index], images_test[index])\n",
84+
" writer.write(example.SerializeToString())\n",
85+
"print(\"TFRecord测试文件已保存。\")"
6986
]
7087
},
7188
{
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {
7+
"collapsed": true
8+
},
9+
"outputs": [],
10+
"source": [
11+
"import tempfile\n",
12+
"import tensorflow as tf"
13+
]
14+
},
15+
{
16+
"cell_type": "markdown",
17+
"metadata": {},
18+
"source": [
19+
"#### 1. 从数组创建数据集。"
20+
]
21+
},
22+
{
23+
"cell_type": "code",
24+
"execution_count": 2,
25+
"metadata": {},
26+
"outputs": [
27+
{
28+
"name": "stdout",
29+
"output_type": "stream",
30+
"text": [
31+
"1\n",
32+
"4\n",
33+
"9\n",
34+
"25\n",
35+
"64\n"
36+
]
37+
}
38+
],
39+
"source": [
40+
"input_data = [1, 2, 3, 5, 8]\n",
41+
"dataset = tf.data.Dataset.from_tensor_slices(input_data)\n",
42+
"\n",
43+
"# 定义迭代器。\n",
44+
"iterator = dataset.make_one_shot_iterator()\n",
45+
"\n",
46+
"# get_next() 返回代表一个输入数据的张量。\n",
47+
"x = iterator.get_next()\n",
48+
"y = x * x\n",
49+
"\n",
50+
"with tf.Session() as sess:\n",
51+
" for i in range(len(input_data)):\n",
52+
" print(sess.run(y))\n"
53+
]
54+
},
55+
{
56+
"cell_type": "markdown",
57+
"metadata": {},
58+
"source": [
59+
"#### 2. 读取文本文件里的数据。"
60+
]
61+
},
62+
{
63+
"cell_type": "code",
64+
"execution_count": 3,
65+
"metadata": {},
66+
"outputs": [
67+
{
68+
"name": "stdout",
69+
"output_type": "stream",
70+
"text": [
71+
"File1, line1.\n",
72+
"File1, line2.\n",
73+
"File2, line1.\n",
74+
"File2, line2.\n"
75+
]
76+
}
77+
],
78+
"source": [
79+
"# 创建文本文件作为本例的输入。\n",
80+
"with open(\"./test1.txt\", \"w\") as file:\n",
81+
" file.write(\"File1, line1.\\n\") \n",
82+
" file.write(\"File1, line2.\\n\")\n",
83+
"with open(\"./test2.txt\", \"w\") as file:\n",
84+
" file.write(\"File2, line1.\\n\") \n",
85+
" file.write(\"File2, line2.\\n\")\n",
86+
"\n",
87+
"# 从文本文件创建数据集。这里可以提供多个文件。\n",
88+
"input_files = [\"./test1.txt\", \"./test2.txt\"]\n",
89+
"dataset = tf.data.TextLineDataset(input_files)\n",
90+
"\n",
91+
"# 定义迭代器。\n",
92+
"iterator = dataset.make_one_shot_iterator()\n",
93+
"\n",
94+
"# 这里get_next()返回一个字符串类型的张量,代表文件中的一行。\n",
95+
"x = iterator.get_next() \n",
96+
"with tf.Session() as sess:\n",
97+
" for i in range(4):\n",
98+
" print(sess.run(x))\n"
99+
]
100+
},
101+
{
102+
"cell_type": "markdown",
103+
"metadata": {},
104+
"source": [
105+
"#### 3. 解析TFRecord文件里的数据。读取文件为本章第一节创建的文件。"
106+
]
107+
},
108+
{
109+
"cell_type": "code",
110+
"execution_count": 4,
111+
"metadata": {},
112+
"outputs": [
113+
{
114+
"name": "stdout",
115+
"output_type": "stream",
116+
"text": [
117+
"7\n",
118+
"3\n",
119+
"4\n",
120+
"6\n",
121+
"1\n",
122+
"8\n",
123+
"1\n",
124+
"0\n",
125+
"9\n",
126+
"8\n"
127+
]
128+
}
129+
],
130+
"source": [
131+
"# 解析一个TFRecord的方法。\n",
132+
"def parser(record):\n",
133+
" features = tf.parse_single_example(\n",
134+
" record,\n",
135+
" features={\n",
136+
" 'image_raw':tf.FixedLenFeature([],tf.string),\n",
137+
" 'pixels':tf.FixedLenFeature([],tf.int64),\n",
138+
" 'label':tf.FixedLenFeature([],tf.int64)\n",
139+
" })\n",
140+
" decoded_images = tf.decode_raw(features['image_raw'],tf.uint8)\n",
141+
" retyped_images = tf.cast(decoded_images, tf.float32)\n",
142+
" images = tf.reshape(retyped_images, [784])\n",
143+
" labels = tf.cast(features['label'],tf.int32)\n",
144+
" #pixels = tf.cast(features['pixels'],tf.int32)\n",
145+
" return images, labels\n",
146+
"\n",
147+
"# 从TFRecord文件创建数据集。这里可以提供多个文件。\n",
148+
"input_files = [\"output.tfrecords\"]\n",
149+
"dataset = tf.data.TFRecordDataset(input_files)\n",
150+
"\n",
151+
"# map()函数表示对数据集中的每一条数据进行调用解析方法。\n",
152+
"dataset = dataset.map(parser)\n",
153+
"\n",
154+
"# 定义遍历数据集的迭代器。\n",
155+
"iterator = dataset.make_one_shot_iterator()\n",
156+
"\n",
157+
"# 读取数据,可用于进一步计算\n",
158+
"image, label = iterator.get_next()\n",
159+
"\n",
160+
"with tf.Session() as sess:\n",
161+
" for i in range(10):\n",
162+
" x, y = sess.run([image, label]) \n",
163+
" print(y)\n"
164+
]
165+
},
166+
{
167+
"cell_type": "markdown",
168+
"metadata": {},
169+
"source": [
170+
"#### 4. 使用initializable_iterator来动态初始化数据集。"
171+
]
172+
},
173+
{
174+
"cell_type": "code",
175+
"execution_count": 5,
176+
"metadata": {
177+
"collapsed": true
178+
},
179+
"outputs": [],
180+
"source": [
181+
"# 从TFRecord文件创建数据集,具体文件路径是一个placeholder,稍后再提供具体路径。\n",
182+
"input_files = tf.placeholder(tf.string)\n",
183+
"dataset = tf.data.TFRecordDataset(input_files)\n",
184+
"dataset = dataset.map(parser)\n",
185+
"\n",
186+
"# 定义遍历dataset的initializable_iterator。\n",
187+
"iterator = dataset.make_initializable_iterator()\n",
188+
"image, label = iterator.get_next()\n",
189+
"\n",
190+
"with tf.Session() as sess:\n",
191+
" # 首先初始化iterator,并给出input_files的值。\n",
192+
" sess.run(iterator.initializer,\n",
193+
" feed_dict={input_files: [\"output.tfrecords\"]})\n",
194+
" # 遍历所有数据一个epoch。当遍历结束时,程序会抛出OutOfRangeError。\n",
195+
" while True:\n",
196+
" try:\n",
197+
" x, y = sess.run([image, label])\n",
198+
" except tf.errors.OutOfRangeError:\n",
199+
" break \n"
200+
]
201+
}
202+
],
203+
"metadata": {
204+
"kernelspec": {
205+
"display_name": "Python 2",
206+
"language": "python",
207+
"name": "python2"
208+
},
209+
"language_info": {
210+
"codemirror_mode": {
211+
"name": "ipython",
212+
"version": 2
213+
},
214+
"file_extension": ".py",
215+
"mimetype": "text/x-python",
216+
"name": "python",
217+
"nbconvert_exporter": "python",
218+
"pygments_lexer": "ipython2",
219+
"version": "2.7.13"
220+
}
221+
},
222+
"nbformat": 4,
223+
"nbformat_minor": 1
224+
}

0 commit comments

Comments
 (0)