Skip to content

Commit b8fab92

Browse files
committed
code after part 2
1 parent 81abcb9 commit b8fab92

File tree

6 files changed

+1617
-50
lines changed

6 files changed

+1617
-50
lines changed

Diff for: .gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
data/
22
*.png
3+
*.jpg
34
node_modules/
45
.env
56
.DS_Store
+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
int counter = 0;
2+
void setup() {
3+
size(280, 280);
4+
}
5+
6+
void draw() {
7+
background(255);
8+
float r = random(100, 200);
9+
strokeWeight(16);
10+
rectMode(CENTER);
11+
square(width/2, height/2, r);
12+
PImage img = get();
13+
img.resize(28, 28);
14+
img.save("data/square" + nf(counter,3) + ".png");
15+
counter++;
16+
if (counter == 550) {
17+
exit();
18+
}
19+
// noLoop();
20+
}

Diff for: index.js

+107-46
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,119 @@
11
console.log("Hello Autoencoder 🚂");
22

3-
import * as tf from '@tensorflow/tfjs-node'
4-
5-
const autoencoder = tf.sequential();
6-
7-
const encoder = tf.layers.dense({
8-
units: 32,
9-
inputShape: [784],
10-
activation: 'relu'
11-
});
12-
const decoder = tf.layers.dense({
13-
units: 784,
14-
activation: 'sigmoid'
15-
// inputShape: [32]
16-
});
17-
18-
autoencoder.add(encoder);
19-
autoencoder.add(decoder);
20-
21-
autoencoder.compile({
22-
optimizer: 'adam',
23-
loss: 'binaryCrossentropy',
24-
metrics: ['accuracy'],
25-
});
26-
27-
28-
function generateImage() {
29-
const img = [];
30-
for (let i = 0; i < 784; i++) {
31-
img[i] = Math.random();
32-
}
33-
return img;
3+
import * as tf from "@tensorflow/tfjs-node";
4+
// import canvas from "canvas";
5+
// const { loadImage } = canvas;
6+
import Jimp from "jimp";
7+
import numeral from "numeral";
8+
9+
main();
10+
11+
async function main() {
12+
// Build the model
13+
const autoencoder = buildModel();
14+
// load all image data
15+
const images = await loadImages(550);
16+
17+
// train the model
18+
const x_train = tf.tensor2d(images.slice(500));
19+
await trainModel(autoencoder, x_train, 250);
20+
21+
// test the model
22+
const x_test = tf.tensor2d(images.slice(500, 550));
23+
await generateTests(autoencoder, x_test);
3424
}
3525

36-
const x_inputs = [];
37-
for (let i = 0; i < 1000; i++) {
38-
x_inputs[i] = generateImage();
26+
async function generateTests(autoencoder, x_test) {
27+
const output = autoencoder.predict(x_test);
28+
// output.print();
29+
30+
const newImages = await output.array();
31+
for (let i = 0; i < newImages.length; i++) {
32+
const img = newImages[i];
33+
const buffer = [];
34+
for (let n = 0; n < img.length; n++) {
35+
const val = Math.floor(img[n] * 255);
36+
buffer[n * 4 + 0] = val;
37+
buffer[n * 4 + 1] = val;
38+
buffer[n * 4 + 2] = val;
39+
buffer[n * 4 + 3] = 255;
40+
}
41+
const image = new Jimp(
42+
{
43+
data: Buffer.from(buffer),
44+
width: 28,
45+
height: 28,
46+
},
47+
(err, image) => {
48+
const num = numeral(i).format("000");
49+
image.write(`output/square${num}.png`);
50+
}
51+
);
52+
}
3953
}
4054

41-
const x_train = tf.tensor2d(x_inputs);
42-
x_train.print();
43-
44-
trainModel();
55+
function buildModel() {
56+
const autoencoder = tf.sequential();
57+
// Build the model
58+
autoencoder.add(
59+
tf.layers.dense({
60+
units: 256,
61+
inputShape: [784],
62+
activation: "relu",
63+
})
64+
);
65+
autoencoder.add(
66+
tf.layers.dense({
67+
units: 128,
68+
activation: "relu",
69+
})
70+
);
71+
72+
autoencoder.add(
73+
tf.layers.dense({
74+
units: 256,
75+
activation: "sigmoid",
76+
})
77+
);
78+
79+
autoencoder.add(
80+
tf.layers.dense({
81+
units: 784,
82+
activation: "sigmoid",
83+
})
84+
);
85+
autoencoder.compile({
86+
optimizer: "adam",
87+
loss: "binaryCrossentropy",
88+
metrics: ["accuracy"],
89+
});
90+
return autoencoder;
91+
}
4592

46-
async function trainModel() {
93+
async function trainModel(autoencoder, x_train, epochs) {
4794
await autoencoder.fit(x_train, x_train, {
48-
epochs: 50,
49-
batch_size: 256,
95+
epochs: epochs,
96+
batch_size: 32,
5097
shuffle: true,
51-
verbose: true
98+
verbose: true,
5299
});
53100
}
54101

55-
56-
57-
58-
102+
async function loadImages(total) {
103+
const allImages = [];
104+
for (let i = 0; i < total; i++) {
105+
const num = numeral(i).format("000");
106+
const img = await Jimp.read(`data/square${num}.png`);
107+
108+
let rawData = [];
109+
for (let n = 0; n < 28 * 28; n++) {
110+
let index = n * 4;
111+
let r = img.bitmap.data[index + 0];
112+
// let g = img.bitmap.data[n + 1];
113+
// let b = img.bitmap.data[n + 2];
114+
rawData[n] = r / 255.0;
115+
}
116+
allImages[i] = rawData;
117+
}
118+
return allImages;
119+
}

0 commit comments

Comments
 (0)