Skip to content

Commit 9a373e8

Browse files
committed
Dump
1 parent 157bb9d commit 9a373e8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

81 files changed

+2860
-493
lines changed

bad_mine.kra

519 KB
Binary file not shown.

bad_theirs.kra

1.78 MB
Binary file not shown.

data.py

Lines changed: 102 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,15 @@
2323
"red",
2424
"rainbow",
2525
]
26+
dont_include = [
27+
("line", "filled", "green"),
28+
("square", "filled", "green"),
29+
("circle", "filled", "green"),
30+
31+
("line", "filled", "green"),
32+
("square", "filled", "green"),
33+
("circle", "filled", "green"),
34+
]
2635

2736
def all_classes():
2837
return itertools.product(shape_types, line_types, colors)
@@ -35,8 +44,8 @@ def rainbow(shape, center_x, center_y, angle):
3544
magnitudes = np.linalg.norm(coords, axis=-1)
3645

3746
h = angles / (2*np.pi) + 0.5
38-
s = np.ones(h.shape)
39-
v = np.ones(h.shape)
47+
s = np.clip(magnitudes*2, 0, 1)
48+
v = np.ones_like(angles)
4049

4150
hsv = np.stack([h, s, v], axis=-1)
4251
hsv = (hsv * 255).astype(np.uint8)
@@ -48,6 +57,14 @@ def rainbow(shape, center_x, center_y, angle):
4857

4958
return rgb
5059

60+
def new_line(d, center_x, center_y, radius, angle, fill, line_width):
61+
point1x = center_x + radius * np.cos(angle)
62+
point1y = center_y + radius * np.sin(angle)
63+
point2x = center_x + radius * np.cos(angle+np.pi)
64+
point2y = center_y + radius * np.sin(angle+np.pi)
65+
66+
d.line([(point1x, point1y), (point2x,point2y)], fill=fill, width=int(line_width))
67+
5168
def square(d, center_x, center_y, radius, angle, fill):
5269
point1x = center_x + radius * np.cos(angle)
5370
point1y = center_y + radius * np.sin(angle)
@@ -74,10 +91,16 @@ def circle(d, center_x, center_y, radius, angle, fill):
7491
d.ellipse([(center_x - radius, center_y - radius), (center_x + radius, center_y + radius)], fill=fill)
7592

7693
# assumes 3 channels
77-
def shapes(images, params, min_radius, max_radius, line_width):
78-
n_images = images.shape[0]
79-
image_width = images.shape[1]
80-
image_height = images.shape[2]
94+
def shapes(params, draw_size, resize_to, min_radius, max_radius, line_width):
95+
n_images = len(params)
96+
image_width = draw_size
97+
image_height = draw_size
98+
99+
100+
# images = np.zeros((len(params), resize_to, resize_to, 3), dtype=np.float)
101+
images = np.random.random([len(params), resize_to, resize_to, 3])
102+
# background = np.random.random([len(params), resize_to, resize_to, 3])
103+
background = np.zeros([len(params), resize_to, resize_to, 3])
81104

82105
white = (255, 255, 255)
83106
red = (255, 0, 0)
@@ -96,33 +119,41 @@ def shapes(images, params, min_radius, max_radius, line_width):
96119
center_y = np.random.uniform(0+radius+2, image_height-radius-2)
97120
angle = np.random.uniform(-np.pi, np.pi)
98121

99-
img = Image.new("RGB", (image_width, image_height))
122+
# img = Image.fromarray(noiseimages[i], "RGB")
123+
img = Image.new("RGB", (draw_size, draw_size))
100124
d = ImageDraw.Draw(img)
101125

102126
fill = white
103-
if color == "red":
104-
fill = red
105-
elif color == "green":
106-
fill = green
107-
elif color == "blue":
108-
fill = blue
127+
128+
if shape == "line":
129+
if line_type == "single":
130+
new_line(d, center_x, center_y, radius, angle, fill, line_width)
131+
elif line_type == "filled":
132+
new_line(d, center_x, center_y, radius, angle, fill, line_width * 4)
133+
elif line_type == "double":
134+
center_offset_x = line_width * np.cos(angle + np.pi/2)
135+
center_offset_y = line_width * np.sin(angle + np.pi/2)
136+
new_line(d, center_x + center_offset_x, center_y + center_offset_y, radius, angle, fill, line_width)
137+
new_line(d, center_x - center_offset_x, center_y - center_offset_y, radius, angle, fill, line_width)
138+
pass
109139

110140
if shape == "tri":
111-
print("tri")
141+
tri_line_width = line_width * 2
112142
tri(d, center_x, center_y, radius, angle, fill)
113143
if line_type != "filled":
114-
tri(d, center_x, center_y, radius-line_width, angle, black)
144+
tri(d, center_x, center_y, radius-tri_line_width, angle, black)
115145
if line_type == "double":
116-
tri(d, center_x, center_y, radius-line_width*2-1, angle, fill)
117-
tri(d, center_x, center_y, radius-line_width*3-1, angle, black)
146+
tri(d, center_x, center_y, radius-tri_line_width*2-1, angle, fill)
147+
tri(d, center_x, center_y, radius-tri_line_width*3-1, angle, black)
118148

119149
elif shape == "square":
150+
sq_line_width = line_width * 1.41
120151
square(d, center_x, center_y, radius, angle, fill)
121152
if line_type != "filled":
122-
square(d, center_x, center_y, radius-line_width, angle, black)
153+
square(d, center_x, center_y, radius-sq_line_width, angle, black)
123154
if line_type == "double":
124-
square(d, center_x, center_y, radius-line_width*2-1, angle, fill)
125-
square(d, center_x, center_y, radius-line_width*3-1, angle, black)
155+
square(d, center_x, center_y, radius-sq_line_width*2-1, angle, fill)
156+
square(d, center_x, center_y, radius-sq_line_width*3-1, angle, black)
126157

127158
elif shape == "circle":
128159
circle(d, center_x, center_y, radius, angle, fill)
@@ -131,24 +162,31 @@ def shapes(images, params, min_radius, max_radius, line_width):
131162
if line_type == "double":
132163
circle(d, center_x, center_y, radius-line_width*2-1, angle, fill)
133164
circle(d, center_x, center_y, radius-line_width*3-1, angle, black)
134-
135-
display(img)
136-
137-
imgdata = np.asarray(img)
165+
166+
img = img.resize((resize_to, resize_to))
167+
mask = np.asarray(img).astype(np.float) / 255.0
138168

139169
if color == "rainbow":
140-
x = imgdata.astype(np.float) / 255.0
141-
x *= rbow
142-
imgdata = (x * 255).astype(np.uint8)
170+
images[i] = rbow * mask + background[i] * (1 - mask)
171+
elif color == "red":
172+
images[i] = np.array([1, 0, 0]) * mask + background[i] * (1 - mask)
173+
elif color == "green":
174+
images[i] = np.array([0, 1, 0]) * mask + background[i] * (1 - mask)
175+
elif color == "blue":
176+
images[i] = np.array([0, 0, 1]) * mask + background[i] * (1 - mask)
177+
elif color == "white":
178+
images[i] = mask + background[i] * (1 - mask)
143179

144-
rgbx = imgdata.astype(np.float) / 255.0
145-
rgb = rgbx[:, :, :3]
146-
images[i] = rgb
180+
return images
147181

148182
def example_shapes():
149183
par = list(all_classes())
150-
images = np.zeros((len(par), 100, 100, 3), dtype=np.float)
151-
shapes(images, par, min_radius=30, max_radius=50, line_width=3)
184+
draw_size = 200
185+
resize_to = 48
186+
line_width = draw_size / 25
187+
min_radius = line_width * 6
188+
max_radius = min_radius * 1.5
189+
images = shapes(par, draw_size, resize_to, min_radius=min_radius, max_radius=max_radius, line_width=line_width)
152190
return images
153191

154192
def line(images, min_length=48, max_length=48):
@@ -276,20 +314,24 @@ def triangle(images, min_size=48, max_size=48):
276314
d.line([tuple(starts[i]), tuple(point2[i])], fill=255, width=6)
277315

278316
# make a convenient structure for our data
279-
def create_dataset_obj(x_all, y_all, z_all):
317+
def create_dataset_obj(x_all, y_all, z_all, n_classes):
280318
x_all = tf.convert_to_tensor(x_all)
281319
y_all = tf.convert_to_tensor(y_all)
282320
z_all = tf.convert_to_tensor(z_all)
283-
x_all = tf.expand_dims(x_all, -1)
284321

285-
inds = np.indices([len(x_all)])
322+
inds = np.random.permutation(len(x_all))
323+
n_all = len(x_all)
324+
n_test = len(x_all) // 10
325+
n_val = len(x_all) // 10
326+
n_train = n_all - n_test - n_val
286327
# 80% train : 10% val : 10% test split
287-
train_indices = inds[inds % 10 >= 2]
288-
val_indices = inds[inds % 10 == 0]
289-
test_indices = inds[inds % 10 == 1]
328+
train_indices = inds[:n_train]
329+
val_indices = inds[n_train:n_train+n_val]
330+
test_indices = inds[n_train+n_val:n_train+n_val+n_test]
290331

291332
return {
292333
"image_size": x_all.shape[1],
334+
"n_classes": n_classes,
293335

294336
"n_all": len(x_all),
295337
"x_all": x_all,
@@ -312,33 +354,31 @@ def create_dataset_obj(x_all, y_all, z_all):
312354
}
313355

314356

315-
def make_image_dataset(n_x_data=10000, n_z_data=2000, image_size=24, latent_dims=6, pixel_dtype=np.uint8):
316-
317-
fns = [
318-
line,
319-
rect,
320-
circleold,
321-
triangle,
322-
]
357+
def make_image_dataset(n_x_data, n_z_data=2000, image_size=24, latent_dims=6, pixel_dtype=np.uint8):
323358

324-
n_classes = len(fns)
359+
360+
n_classes = len(list(all_classes()))
325361
n_per_class = n_x_data // n_classes
326-
images = np.zeros([n_x_data, image_size*3, image_size*3], dtype=pixel_dtype)
327-
328-
for i, fn in enumerate(fns):
329-
start = i*n_per_class
330-
end = (i+1)*n_per_class
331-
fn(images[start:end])
332-
362+
363+
params = [par for par in all_classes()] * n_per_class
364+
333365
class_labels = np.identity(n_classes)
334-
classes = class_labels[np.concatenate([np.repeat(i, n_per_class) for i in range(n_classes)])]
366+
classes = [class_labels[i] for i, par in enumerate(all_classes())] * n_per_class
367+
368+
# class1_labels = np.identity(len(shape_types))
369+
# class2_labels = np.identity(len(line_types))
370+
# class3_labels = np.identity(len(colors))
371+
# tclasses = [
372+
# (class1_labels[shape_types.index(shape_type)], class2_labels[line_types.index(line_type)], class3_labels[colors.index(color)]) for shape_type, line_type, color in all_classes()
373+
# ] * n_per_class
374+
375+
draw_size = 200
376+
resize_to = image_size
377+
line_width = draw_size / 25
378+
min_radius = line_width * 6
379+
max_radius = min_radius * 1.5
380+
images = shapes(params, draw_size, resize_to, min_radius=min_radius, max_radius=max_radius, line_width=line_width)
335381

336382
gaussian_z = tf.random.normal([n_z_data, latent_dims])
337383

338-
resized_images = np.zeros([n_x_data, image_size, image_size], dtype=np.float32)
339-
for i in range(len(images)):
340-
img = Image.frombuffer("L", images[i].shape, images[i])
341-
resized = img.resize((image_size, image_size))
342-
resized_images[i] = np.array(resized).astype(np.float32) / 255.
343-
344-
return create_dataset_obj(resized_images, classes, gaussian_z)
384+
return create_dataset_obj(images, classes, gaussian_z, n_classes)

figures/bad_mine.png

89.8 KB

figures/chan1.png

8.79 KB

figures/chan2.png

9.06 KB

figures/chan3.png

7.13 KB

figures/chan4.png

4.74 KB

figures/chan5.png

8.5 KB

figures/chan6.png

7.29 KB

figures/chan7.png

8.26 KB

figures/chan8.png

6.14 KB

figures/circle.png

8.26 KB

figures/circle2.png

10 KB

figures/images.png

35.7 KB

figures/images_autoencoder.png

44.7 KB

figures/latent_space.png

2.33 MB

figures/line.png

7.54 KB

figures/line2.png

11.4 KB

figures/naive_chan1.png

2.31 KB

figures/naive_chan2.png

1.9 KB

figures/naive_chan3.png

1.68 KB

figures/naive_chan4.png

1.27 KB

figures/naive_chan5.png

2.39 KB

figures/naive_chan6.png

2.14 KB

figures/naive_chan7.png

3.25 KB

figures/naive_chan8.png

1.47 KB

figures/naive_circle.png

3.27 KB

figures/naive_line.png

3.62 KB

figures/naive_square.png

3.1 KB

figures/naive_triangle.png

3.68 KB

figures/search_data_1.png

1.48 KB

figures/search_data_2.png

1.34 KB

figures/search_data_3.png

1.27 KB

figures/search_data_4.png

829 Bytes

figures/search_data_targ.png

1.05 KB

figures/search_gen_1.png

1.55 KB

figures/search_gen_2.png

1.27 KB

figures/search_gen_3.png

1.21 KB

figures/search_gen_4.png

1.33 KB

figures/search_gen_targ.png

1.05 KB

figures/square.png

8.06 KB

figures/square2.png

10.6 KB

figures/train01.png

514 Bytes

figures/train02.png

545 Bytes

figures/train03.png

545 Bytes

figures/train04.png

461 Bytes

figures/train05.png

528 Bytes

figures/train06.png

494 Bytes

figures/train07.png

503 Bytes

figures/train08.png

492 Bytes

figures/train09.png

487 Bytes

figures/train10.png

473 Bytes

figures/train11.png

961 Bytes

figures/train12.png

934 Bytes

figures/train13.png

976 Bytes

figures/train14.png

974 Bytes

figures/triangle.png

7.33 KB

figures/triangle2.png

11.7 KB

figures/ttrain1.png

937 Bytes

figures/ttrain10.png

916 Bytes

figures/ttrain11.png

902 Bytes

figures/ttrain12.png

899 Bytes

figures/ttrain2.png

899 Bytes

figures/ttrain3.png

887 Bytes

figures/ttrain4.png

878 Bytes

figures/ttrain5.png

846 Bytes

figures/ttrain6.png

849 Bytes

figures/ttrain7.png

909 Bytes

figures/ttrain8.png

935 Bytes

figures/ttrain9.png

952 Bytes

0 commit comments

Comments
 (0)