23
23
"red" ,
24
24
"rainbow" ,
25
25
]
26
+ dont_include = [
27
+ ("line" , "filled" , "green" ),
28
+ ("square" , "filled" , "green" ),
29
+ ("circle" , "filled" , "green" ),
30
+
31
+ ("line" , "filled" , "green" ),
32
+ ("square" , "filled" , "green" ),
33
+ ("circle" , "filled" , "green" ),
34
+ ]
26
35
27
36
def all_classes ():
28
37
return itertools .product (shape_types , line_types , colors )
@@ -35,8 +44,8 @@ def rainbow(shape, center_x, center_y, angle):
35
44
magnitudes = np .linalg .norm (coords , axis = - 1 )
36
45
37
46
h = angles / (2 * np .pi ) + 0.5
38
- s = np .ones ( h . shape )
39
- v = np .ones ( h . shape )
47
+ s = np .clip ( magnitudes * 2 , 0 , 1 )
48
+ v = np .ones_like ( angles )
40
49
41
50
hsv = np .stack ([h , s , v ], axis = - 1 )
42
51
hsv = (hsv * 255 ).astype (np .uint8 )
@@ -48,6 +57,14 @@ def rainbow(shape, center_x, center_y, angle):
48
57
49
58
return rgb
50
59
60
+ def new_line (d , center_x , center_y , radius , angle , fill , line_width ):
61
+ point1x = center_x + radius * np .cos (angle )
62
+ point1y = center_y + radius * np .sin (angle )
63
+ point2x = center_x + radius * np .cos (angle + np .pi )
64
+ point2y = center_y + radius * np .sin (angle + np .pi )
65
+
66
+ d .line ([(point1x , point1y ), (point2x ,point2y )], fill = fill , width = int (line_width ))
67
+
51
68
def square (d , center_x , center_y , radius , angle , fill ):
52
69
point1x = center_x + radius * np .cos (angle )
53
70
point1y = center_y + radius * np .sin (angle )
@@ -74,10 +91,16 @@ def circle(d, center_x, center_y, radius, angle, fill):
74
91
d .ellipse ([(center_x - radius , center_y - radius ), (center_x + radius , center_y + radius )], fill = fill )
75
92
76
93
# assumes 3 channels
77
- def shapes (images , params , min_radius , max_radius , line_width ):
78
- n_images = images .shape [0 ]
79
- image_width = images .shape [1 ]
80
- image_height = images .shape [2 ]
94
+ def shapes (params , draw_size , resize_to , min_radius , max_radius , line_width ):
95
+ n_images = len (params )
96
+ image_width = draw_size
97
+ image_height = draw_size
98
+
99
+
100
+ # images = np.zeros((len(params), resize_to, resize_to, 3), dtype=np.float)
101
+ images = np .random .random ([len (params ), resize_to , resize_to , 3 ])
102
+ # background = np.random.random([len(params), resize_to, resize_to, 3])
103
+ background = np .zeros ([len (params ), resize_to , resize_to , 3 ])
81
104
82
105
white = (255 , 255 , 255 )
83
106
red = (255 , 0 , 0 )
@@ -96,33 +119,41 @@ def shapes(images, params, min_radius, max_radius, line_width):
96
119
center_y = np .random .uniform (0 + radius + 2 , image_height - radius - 2 )
97
120
angle = np .random .uniform (- np .pi , np .pi )
98
121
99
- img = Image .new ("RGB" , (image_width , image_height ))
122
+ # img = Image.fromarray(noiseimages[i], "RGB")
123
+ img = Image .new ("RGB" , (draw_size , draw_size ))
100
124
d = ImageDraw .Draw (img )
101
125
102
126
fill = white
103
- if color == "red" :
104
- fill = red
105
- elif color == "green" :
106
- fill = green
107
- elif color == "blue" :
108
- fill = blue
127
+
128
+ if shape == "line" :
129
+ if line_type == "single" :
130
+ new_line (d , center_x , center_y , radius , angle , fill , line_width )
131
+ elif line_type == "filled" :
132
+ new_line (d , center_x , center_y , radius , angle , fill , line_width * 4 )
133
+ elif line_type == "double" :
134
+ center_offset_x = line_width * np .cos (angle + np .pi / 2 )
135
+ center_offset_y = line_width * np .sin (angle + np .pi / 2 )
136
+ new_line (d , center_x + center_offset_x , center_y + center_offset_y , radius , angle , fill , line_width )
137
+ new_line (d , center_x - center_offset_x , center_y - center_offset_y , radius , angle , fill , line_width )
138
+ pass
109
139
110
140
if shape == "tri" :
111
- print ( "tri" )
141
+ tri_line_width = line_width * 2
112
142
tri (d , center_x , center_y , radius , angle , fill )
113
143
if line_type != "filled" :
114
- tri (d , center_x , center_y , radius - line_width , angle , black )
144
+ tri (d , center_x , center_y , radius - tri_line_width , angle , black )
115
145
if line_type == "double" :
116
- tri (d , center_x , center_y , radius - line_width * 2 - 1 , angle , fill )
117
- tri (d , center_x , center_y , radius - line_width * 3 - 1 , angle , black )
146
+ tri (d , center_x , center_y , radius - tri_line_width * 2 - 1 , angle , fill )
147
+ tri (d , center_x , center_y , radius - tri_line_width * 3 - 1 , angle , black )
118
148
119
149
elif shape == "square" :
150
+ sq_line_width = line_width * 1.41
120
151
square (d , center_x , center_y , radius , angle , fill )
121
152
if line_type != "filled" :
122
- square (d , center_x , center_y , radius - line_width , angle , black )
153
+ square (d , center_x , center_y , radius - sq_line_width , angle , black )
123
154
if line_type == "double" :
124
- square (d , center_x , center_y , radius - line_width * 2 - 1 , angle , fill )
125
- square (d , center_x , center_y , radius - line_width * 3 - 1 , angle , black )
155
+ square (d , center_x , center_y , radius - sq_line_width * 2 - 1 , angle , fill )
156
+ square (d , center_x , center_y , radius - sq_line_width * 3 - 1 , angle , black )
126
157
127
158
elif shape == "circle" :
128
159
circle (d , center_x , center_y , radius , angle , fill )
@@ -131,24 +162,31 @@ def shapes(images, params, min_radius, max_radius, line_width):
131
162
if line_type == "double" :
132
163
circle (d , center_x , center_y , radius - line_width * 2 - 1 , angle , fill )
133
164
circle (d , center_x , center_y , radius - line_width * 3 - 1 , angle , black )
134
-
135
- display (img )
136
-
137
- imgdata = np .asarray (img )
165
+
166
+ img = img .resize ((resize_to , resize_to ))
167
+ mask = np .asarray (img ).astype (np .float ) / 255.0
138
168
139
169
if color == "rainbow" :
140
- x = imgdata .astype (np .float ) / 255.0
141
- x *= rbow
142
- imgdata = (x * 255 ).astype (np .uint8 )
170
+ images [i ] = rbow * mask + background [i ] * (1 - mask )
171
+ elif color == "red" :
172
+ images [i ] = np .array ([1 , 0 , 0 ]) * mask + background [i ] * (1 - mask )
173
+ elif color == "green" :
174
+ images [i ] = np .array ([0 , 1 , 0 ]) * mask + background [i ] * (1 - mask )
175
+ elif color == "blue" :
176
+ images [i ] = np .array ([0 , 0 , 1 ]) * mask + background [i ] * (1 - mask )
177
+ elif color == "white" :
178
+ images [i ] = mask + background [i ] * (1 - mask )
143
179
144
- rgbx = imgdata .astype (np .float ) / 255.0
145
- rgb = rgbx [:, :, :3 ]
146
- images [i ] = rgb
180
+ return images
147
181
148
182
def example_shapes ():
149
183
par = list (all_classes ())
150
- images = np .zeros ((len (par ), 100 , 100 , 3 ), dtype = np .float )
151
- shapes (images , par , min_radius = 30 , max_radius = 50 , line_width = 3 )
184
+ draw_size = 200
185
+ resize_to = 48
186
+ line_width = draw_size / 25
187
+ min_radius = line_width * 6
188
+ max_radius = min_radius * 1.5
189
+ images = shapes (par , draw_size , resize_to , min_radius = min_radius , max_radius = max_radius , line_width = line_width )
152
190
return images
153
191
154
192
def line (images , min_length = 48 , max_length = 48 ):
@@ -276,20 +314,24 @@ def triangle(images, min_size=48, max_size=48):
276
314
d .line ([tuple (starts [i ]), tuple (point2 [i ])], fill = 255 , width = 6 )
277
315
278
316
# make a convenient structure for our data
279
- def create_dataset_obj (x_all , y_all , z_all ):
317
+ def create_dataset_obj (x_all , y_all , z_all , n_classes ):
280
318
x_all = tf .convert_to_tensor (x_all )
281
319
y_all = tf .convert_to_tensor (y_all )
282
320
z_all = tf .convert_to_tensor (z_all )
283
- x_all = tf .expand_dims (x_all , - 1 )
284
321
285
- inds = np .indices ([len (x_all )])
322
+ inds = np .random .permutation (len (x_all ))
323
+ n_all = len (x_all )
324
+ n_test = len (x_all ) // 10
325
+ n_val = len (x_all ) // 10
326
+ n_train = n_all - n_test - n_val
286
327
# 80% train : 10% val : 10% test split
287
- train_indices = inds [inds % 10 >= 2 ]
288
- val_indices = inds [inds % 10 == 0 ]
289
- test_indices = inds [inds % 10 == 1 ]
328
+ train_indices = inds [: n_train ]
329
+ val_indices = inds [n_train : n_train + n_val ]
330
+ test_indices = inds [n_train + n_val : n_train + n_val + n_test ]
290
331
291
332
return {
292
333
"image_size" : x_all .shape [1 ],
334
+ "n_classes" : n_classes ,
293
335
294
336
"n_all" : len (x_all ),
295
337
"x_all" : x_all ,
@@ -312,33 +354,31 @@ def create_dataset_obj(x_all, y_all, z_all):
312
354
}
313
355
314
356
315
- def make_image_dataset (n_x_data = 10000 , n_z_data = 2000 , image_size = 24 , latent_dims = 6 , pixel_dtype = np .uint8 ):
316
-
317
- fns = [
318
- line ,
319
- rect ,
320
- circleold ,
321
- triangle ,
322
- ]
357
+ def make_image_dataset (n_x_data , n_z_data = 2000 , image_size = 24 , latent_dims = 6 , pixel_dtype = np .uint8 ):
323
358
324
- n_classes = len (fns )
359
+
360
+ n_classes = len (list (all_classes ()))
325
361
n_per_class = n_x_data // n_classes
326
- images = np .zeros ([n_x_data , image_size * 3 , image_size * 3 ], dtype = pixel_dtype )
327
-
328
- for i , fn in enumerate (fns ):
329
- start = i * n_per_class
330
- end = (i + 1 )* n_per_class
331
- fn (images [start :end ])
332
-
362
+
363
+ params = [par for par in all_classes ()] * n_per_class
364
+
333
365
class_labels = np .identity (n_classes )
334
- classes = class_labels [np .concatenate ([np .repeat (i , n_per_class ) for i in range (n_classes )])]
366
+ classes = [class_labels [i ] for i , par in enumerate (all_classes ())] * n_per_class
367
+
368
+ # class1_labels = np.identity(len(shape_types))
369
+ # class2_labels = np.identity(len(line_types))
370
+ # class3_labels = np.identity(len(colors))
371
+ # tclasses = [
372
+ # (class1_labels[shape_types.index(shape_type)], class2_labels[line_types.index(line_type)], class3_labels[colors.index(color)]) for shape_type, line_type, color in all_classes()
373
+ # ] * n_per_class
374
+
375
+ draw_size = 200
376
+ resize_to = image_size
377
+ line_width = draw_size / 25
378
+ min_radius = line_width * 6
379
+ max_radius = min_radius * 1.5
380
+ images = shapes (params , draw_size , resize_to , min_radius = min_radius , max_radius = max_radius , line_width = line_width )
335
381
336
382
gaussian_z = tf .random .normal ([n_z_data , latent_dims ])
337
383
338
- resized_images = np .zeros ([n_x_data , image_size , image_size ], dtype = np .float32 )
339
- for i in range (len (images )):
340
- img = Image .frombuffer ("L" , images [i ].shape , images [i ])
341
- resized = img .resize ((image_size , image_size ))
342
- resized_images [i ] = np .array (resized ).astype (np .float32 ) / 255.
343
-
344
- return create_dataset_obj (resized_images , classes , gaussian_z )
384
+ return create_dataset_obj (images , classes , gaussian_z , n_classes )
0 commit comments