|
1 | 1 | console.log("Hello Autoencoder 🚂");
|
2 | 2 |
|
3 |
| -import * as tf from '@tensorflow/tfjs-node' |
4 |
| - |
5 |
| -const autoencoder = tf.sequential(); |
6 |
| - |
7 |
| -const encoder = tf.layers.dense({ |
8 |
| - units: 32, |
9 |
| - inputShape: [784], |
10 |
| - activation: 'relu' |
11 |
| -}); |
12 |
| -const decoder = tf.layers.dense({ |
13 |
| - units: 784, |
14 |
| - activation: 'sigmoid' |
15 |
| - // inputShape: [32] |
16 |
| -}); |
17 |
| - |
18 |
| -autoencoder.add(encoder); |
19 |
| -autoencoder.add(decoder); |
20 |
| - |
21 |
| -autoencoder.compile({ |
22 |
| - optimizer: 'adam', |
23 |
| - loss: 'binaryCrossentropy', |
24 |
| - metrics: ['accuracy'], |
25 |
| -}); |
26 |
| - |
27 |
| - |
28 |
| -function generateImage() { |
29 |
| - const img = []; |
30 |
| - for (let i = 0; i < 784; i++) { |
31 |
| - img[i] = Math.random(); |
32 |
| - } |
33 |
| - return img; |
| 3 | +import * as tf from "@tensorflow/tfjs-node"; |
| 4 | +// import canvas from "canvas"; |
| 5 | +// const { loadImage } = canvas; |
| 6 | +import Jimp from "jimp"; |
| 7 | +import numeral from "numeral"; |
| 8 | + |
| 9 | +main(); |
| 10 | + |
| 11 | +async function main() { |
| 12 | + // Build the model |
| 13 | + const autoencoder = buildModel(); |
| 14 | + // load all image data |
| 15 | + const images = await loadImages(550); |
| 16 | + |
| 17 | + // train the model |
| 18 | + const x_train = tf.tensor2d(images.slice(500)); |
| 19 | + await trainModel(autoencoder, x_train, 250); |
| 20 | + |
| 21 | + // test the model |
| 22 | + const x_test = tf.tensor2d(images.slice(500, 550)); |
| 23 | + await generateTests(autoencoder, x_test); |
34 | 24 | }
|
35 | 25 |
|
36 |
| -const x_inputs = []; |
37 |
| -for (let i = 0; i < 1000; i++) { |
38 |
| - x_inputs[i] = generateImage(); |
| 26 | +async function generateTests(autoencoder, x_test) { |
| 27 | + const output = autoencoder.predict(x_test); |
| 28 | + // output.print(); |
| 29 | + |
| 30 | + const newImages = await output.array(); |
| 31 | + for (let i = 0; i < newImages.length; i++) { |
| 32 | + const img = newImages[i]; |
| 33 | + const buffer = []; |
| 34 | + for (let n = 0; n < img.length; n++) { |
| 35 | + const val = Math.floor(img[n] * 255); |
| 36 | + buffer[n * 4 + 0] = val; |
| 37 | + buffer[n * 4 + 1] = val; |
| 38 | + buffer[n * 4 + 2] = val; |
| 39 | + buffer[n * 4 + 3] = 255; |
| 40 | + } |
| 41 | + const image = new Jimp( |
| 42 | + { |
| 43 | + data: Buffer.from(buffer), |
| 44 | + width: 28, |
| 45 | + height: 28, |
| 46 | + }, |
| 47 | + (err, image) => { |
| 48 | + const num = numeral(i).format("000"); |
| 49 | + image.write(`output/square${num}.png`); |
| 50 | + } |
| 51 | + ); |
| 52 | + } |
39 | 53 | }
|
40 | 54 |
|
41 |
| -const x_train = tf.tensor2d(x_inputs); |
42 |
| -x_train.print(); |
43 |
| - |
44 |
| -trainModel(); |
| 55 | +function buildModel() { |
| 56 | + const autoencoder = tf.sequential(); |
| 57 | + // Build the model |
| 58 | + autoencoder.add( |
| 59 | + tf.layers.dense({ |
| 60 | + units: 256, |
| 61 | + inputShape: [784], |
| 62 | + activation: "relu", |
| 63 | + }) |
| 64 | + ); |
| 65 | + autoencoder.add( |
| 66 | + tf.layers.dense({ |
| 67 | + units: 128, |
| 68 | + activation: "relu", |
| 69 | + }) |
| 70 | + ); |
| 71 | + |
| 72 | + autoencoder.add( |
| 73 | + tf.layers.dense({ |
| 74 | + units: 256, |
| 75 | + activation: "sigmoid", |
| 76 | + }) |
| 77 | + ); |
| 78 | + |
| 79 | + autoencoder.add( |
| 80 | + tf.layers.dense({ |
| 81 | + units: 784, |
| 82 | + activation: "sigmoid", |
| 83 | + }) |
| 84 | + ); |
| 85 | + autoencoder.compile({ |
| 86 | + optimizer: "adam", |
| 87 | + loss: "binaryCrossentropy", |
| 88 | + metrics: ["accuracy"], |
| 89 | + }); |
| 90 | + return autoencoder; |
| 91 | +} |
45 | 92 |
|
46 |
| -async function trainModel() { |
| 93 | +async function trainModel(autoencoder, x_train, epochs) { |
47 | 94 | await autoencoder.fit(x_train, x_train, {
|
48 |
| - epochs: 50, |
49 |
| - batch_size: 256, |
| 95 | + epochs: epochs, |
| 96 | + batch_size: 32, |
50 | 97 | shuffle: true,
|
51 |
| - verbose: true |
| 98 | + verbose: true, |
52 | 99 | });
|
53 | 100 | }
|
54 | 101 |
|
55 |
| - |
56 |
| - |
57 |
| - |
58 |
| - |
| 102 | +async function loadImages(total) { |
| 103 | + const allImages = []; |
| 104 | + for (let i = 0; i < total; i++) { |
| 105 | + const num = numeral(i).format("000"); |
| 106 | + const img = await Jimp.read(`data/square${num}.png`); |
| 107 | + |
| 108 | + let rawData = []; |
| 109 | + for (let n = 0; n < 28 * 28; n++) { |
| 110 | + let index = n * 4; |
| 111 | + let r = img.bitmap.data[index + 0]; |
| 112 | + // let g = img.bitmap.data[n + 1]; |
| 113 | + // let b = img.bitmap.data[n + 2]; |
| 114 | + rawData[n] = r / 255.0; |
| 115 | + } |
| 116 | + allImages[i] = rawData; |
| 117 | + } |
| 118 | + return allImages; |
| 119 | +} |
0 commit comments