1
+ import numpy as np
2
+ import tensorflow as tf
3
+ import os
4
+ from shutil import copyfile , rmtree
5
+ #数据打包处理,将原始图像数据集转换成Keras要求的格式:
6
+ #每一个子文件夹代表一类,其中有该类所有的图像数据
7
+ '''
8
+ data
9
+ customs folder
10
+ classA
11
+ image1
12
+ image2
13
+ ...
14
+ classB
15
+ image1
16
+ image2
17
+ ...
18
+ classC
19
+ ...
20
+ ...
21
+
22
+ research
23
+ test
24
+ 0
25
+ image1
26
+ image2
27
+ ...
28
+ 1
29
+ image1
30
+ image2
31
+ ...
32
+ 2
33
+ image1
34
+ image2
35
+ ...
36
+ train
37
+ 0
38
+ image1
39
+ image2
40
+ ...
41
+ 1
42
+ image1
43
+ image2
44
+ ...
45
+ 2
46
+ image1
47
+ image2
48
+ ...
49
+ validation
50
+ 0
51
+ image1
52
+ image2
53
+ ...
54
+ 1
55
+ image1
56
+ image2
57
+ ...
58
+ ...
59
+ '''
60
+ # Folder holding the original image data (one sub-folder per class)
61
+ source_data_folder = "F://ai_data/camelyon17/train_data"
62
+ # New folder that receives the reorganized train/test/validation data
63
+ research_data_folder = "F://ai_data/camelyon17/research_data"
64
+ # Text file listing "<image path> <class index>", one image per line
65
+ label_text_file = source_data_folder + "//labels.txt"
66
+
67
+ train_num = 210000 # number of images used for training
68
+ val_num = 719 # number of images used for validation during training
69
+ test_num = 4000 # number of images used for final testing
70
+
71
def convert_class_data():
    """Shuffle the label list and split it into train/test/validation sets.

    Reads ``label_text_file`` (one ``"<image path> <class index>"`` entry
    per line), shuffles the entries with a fixed seed so every run yields
    the same split, then hands consecutive slices of ``train_num``,
    ``test_num`` and ``val_num`` entries to ``save_images``.

    Raises:
        IndexError: if the label file holds fewer entries than the three
            split sizes require. Checked before anything is copied so a
            short label file does not leave a half-built dataset behind.
    """
    # Fixed seed: the shuffle, and therefore the split, is reproducible.
    np.random.seed(0)

    # The context manager closes the file even if reading fails.
    # Blank lines carry no "<path> <class>" pair, so they are dropped.
    with open(label_text_file) as label_file:
        labels = [line for line in label_file.readlines() if line.strip()]

    splits = (
        ("train", train_num),
        ("test", test_num),
        ("validation", val_num),
    )
    required = sum(d_size for _, d_size in splits)
    if len(labels) < required:
        raise IndexError(
            "label file {0} has {1} entries but {2} are required".format(
                label_text_file, len(labels), required
            )
        )

    # Shuffle in place so each split is a random sample of the whole set.
    np.random.shuffle(labels)

    # Each call returns the position at which the next split continues.
    current_i = 0
    for phase, d_size in splits:
        current_i = save_images(
            current_i=current_i, phase=phase, d_size=d_size, labels=labels
        )
84
+
85
+
86
def save_images(current_i, phase, d_size, labels, dst_root=None):
    """Copy one slice of the label list into ``<dst_root>/<phase>/<class>/``.

    For each of the ``d_size`` entries starting at ``labels[current_i]`` the
    image is copied into a folder named after its class, and a line
    ``"<new path> <class>"`` is appended to ``<dst_root>/<phase>_label.txt``.

    Args:
        current_i: index of the first entry of ``labels`` to process.
        phase: one of ``"train"``, ``"test"`` or ``"validation"``.
        d_size: number of entries to process (expected to be non-negative).
        labels: sequence of ``"<image path> <class name>"`` strings, each
            optionally terminated by a newline.
        dst_root: output root folder. Defaults to the module-level
            ``research_data_folder`` when omitted.

    Returns:
        The index of the first entry NOT consumed by this call, i.e. the
        ``current_i`` to pass for the next phase so that consecutive splits
        never share a sample.

    Raises:
        SystemExit: if ``phase`` is not a known phase name.
        IndexError: if ``labels`` has fewer than ``current_i + d_size``
            entries. Checked before any file is copied.
    """
    if phase not in ("train", "test", "validation"):
        # Exits with the message on stderr and a non-zero status.
        raise SystemExit("phase error : {0}".format(phase))

    if dst_root is None:
        dst_root = research_data_folder

    end_i = current_i + d_size
    if end_i > len(labels):
        raise IndexError(
            "{0} split needs entries [{1}, {2}) but only {3} labels exist".format(
                phase, current_i, end_i, len(labels)
            )
        )

    # Create the output tree first: the per-phase label file lives in
    # dst_root, which may not exist yet on a fresh run.
    dst_folder = os.path.join(dst_root, phase)
    os.makedirs(dst_folder, exist_ok=True)

    label_path = os.path.join(dst_root, phase + "_label.txt")
    with open(label_path, mode="w") as label_file:
        for item in labels[current_i:end_i]:
            # Split on the LAST space only, so image paths that contain
            # spaces stay intact. Strip the line terminator first so it
            # does not end up in the class name.
            img_source_path, img_class_name = item.rstrip("\r\n").rsplit(" ", 1)

            class_folder = os.path.join(dst_folder, img_class_name)
            os.makedirs(class_folder, exist_ok=True)

            img_dst_path = os.path.join(
                class_folder, os.path.basename(img_source_path)
            )
            copyfile(img_source_path, img_dst_path)
            print("{0} copied".format(img_dst_path))

            # Record where the copy went, for later use.
            label_file.write(img_dst_path + " " + img_class_name + "\n")

    return end_i
122
+
123
def image_labeling(source_folder=None, label_path=None):
    """Write a label file listing every image with its class index.

    Each immediate sub-folder of ``source_folder`` is treated as one class.
    Classes are numbered 0, 1, 2, ... in sorted folder-name order, and one
    line ``"<image path> <class index>"`` is written per file found in a
    class folder.

    Args:
        source_folder: root folder containing one sub-folder per class.
            Defaults to the module-level ``source_data_folder``.
        label_path: path of the label file to (over)write. Defaults to the
            module-level ``label_text_file``.
    """
    if source_folder is None:
        source_folder = source_data_folder
    if label_path is None:
        label_path = label_text_file

    # os.listdir returns entries in arbitrary order; sorting keeps the
    # class-index assignment identical from run to run. Plain files at the
    # top level (such as a previously written label file) are skipped.
    candidates = (
        os.path.join(source_folder, name) for name in os.listdir(source_folder)
    )
    directories = sorted(path for path in candidates if os.path.isdir(path))

    # The context manager closes the label file even if a listing fails.
    with open(label_path, mode="w") as label_file:
        for class_index, directory in enumerate(directories):
            for filename in sorted(os.listdir(directory)):
                path = os.path.join(directory, filename)
                # Nested folders are not images and cannot be copied later.
                if not os.path.isfile(path):
                    continue
                text = path + " " + str(class_index) + "\n"
                print(text)
                label_file.write(text)
156
+
157
def main():
    """Run the conversion: build the label file, then split and copy images."""
    print("Start to convert data")
    # The label file must exist before the split step reads it.
    for step in (image_labeling, convert_class_data):
        step()
161
+
162
+
163
# Run the conversion only when executed as a script, and only once: a
# second guard would repeat the whole labelling and copy pass.
if __name__ == '__main__':
    main()
0 commit comments