-
Notifications
You must be signed in to change notification settings - Fork 41
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added AlexNet implementation for Extended MNIST #19
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,294 @@ | ||
/* | ||
* Copyright 2020 The TensorFlow Authors. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* ======================================================================= | ||
*/ | ||
package org.tensorflow.model.examples.cnn.alexnet; | ||
|
||
import org.tensorflow.Graph; | ||
import org.tensorflow.Operand; | ||
import org.tensorflow.Session; | ||
import org.tensorflow.Tensor; | ||
import org.tensorflow.framework.optimizers.Adam; | ||
import org.tensorflow.framework.optimizers.Optimizer; | ||
import org.tensorflow.model.examples.datasets.ImageBatch; | ||
import org.tensorflow.model.examples.datasets.mnist.MnistDataset; | ||
import org.tensorflow.ndarray.ByteNdArray; | ||
import org.tensorflow.ndarray.FloatNdArray; | ||
import org.tensorflow.ndarray.Shape; | ||
import org.tensorflow.ndarray.index.Indices; | ||
import org.tensorflow.op.Ops; | ||
import org.tensorflow.op.core.*; | ||
import org.tensorflow.op.math.Add; | ||
import org.tensorflow.op.math.Mean; | ||
import org.tensorflow.op.nn.Conv2d; | ||
import org.tensorflow.op.nn.MaxPool; | ||
import org.tensorflow.op.nn.Relu; | ||
import org.tensorflow.op.nn.LocalResponseNormalization; | ||
import org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits; | ||
import org.tensorflow.op.random.TruncatedNormal; | ||
import org.tensorflow.types.TFloat32; | ||
import org.tensorflow.types.TUint8; | ||
|
||
import java.util.Arrays; | ||
import java.util.logging.Level; | ||
import java.util.logging.Logger; | ||
|
||
/** | ||
* Describes the AlexNet Model. | ||
*/ | ||
public class AlexNetModel implements AutoCloseable { | ||
private static final int PIXEL_DEPTH = 255; | ||
private static final int NUM_CHANNELS = 1; | ||
private static final int IMAGE_SIZE = 28; | ||
private static final int NUM_LABELS = 26; | ||
private static final long SEED = 123456789L; | ||
|
||
private static final String PADDING_TYPE = "SAME"; | ||
public static final String INPUT_NAME = "input"; | ||
public static final String OUTPUT_NAME = "output"; | ||
public static final String TARGET = "target"; | ||
public static final String TRAIN = "train"; | ||
public static final String TRAINING_LOSS = "training_loss"; | ||
public static final String INIT = "init"; | ||
|
||
private static final Logger logger = Logger.getLogger(AlexNetModel.class.getName()); | ||
|
||
private final Graph graph; | ||
|
||
private final Session session; | ||
|
||
public AlexNetModel() { | ||
graph = compile(); | ||
session = new Session(graph); | ||
} | ||
|
||
public static Graph compile() { | ||
Graph graph = new Graph(); | ||
|
||
Ops tf = Ops.create(graph); | ||
|
||
// Inputs | ||
Placeholder<TUint8> input = tf.withName(INPUT_NAME).placeholder(TUint8.DTYPE, | ||
Placeholder.shape(Shape.of(-1, IMAGE_SIZE, IMAGE_SIZE))); | ||
Reshape<TUint8> input_reshaped = tf | ||
.reshape(input, tf.array(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)); | ||
Placeholder<TUint8> labels = tf.withName(TARGET).placeholder(TUint8.DTYPE); | ||
|
||
// Scaling the features | ||
Constant<TFloat32> centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f); | ||
Constant<TFloat32> scalingFactor = tf.constant((float) PIXEL_DEPTH); | ||
Operand<TFloat32> scaledInput = tf.math | ||
.div(tf.math.sub(tf.dtypes.cast(input_reshaped, TFloat32.DTYPE), centeringFactor), | ||
scalingFactor); | ||
|
||
//Layer 1 | ||
Relu<TFloat32> relu1 = alexNetConv2DLayer("1", tf, scaledInput, new int[]{11, 11, NUM_CHANNELS, 96}, 96); | ||
MaxPool<TFloat32> pool1 = alexNetMaxPool(tf, relu1); | ||
LocalResponseNormalization<TFloat32> norm1 = alexNetModelLRN(tf, pool1); | ||
|
||
//Layer 2 | ||
Relu<TFloat32> relu2 = alexNetConv2DLayer("2", tf, norm1, new int[]{5, 5, 96, 256}, 256); | ||
MaxPool<TFloat32> pool2 = alexNetMaxPool(tf, relu2); | ||
LocalResponseNormalization<TFloat32> norm2 = alexNetModelLRN(tf, pool2); | ||
|
||
//Layer 3 | ||
Relu<TFloat32> relu3 = alexNetConv2DLayer("3", tf, norm2, new int[]{3, 3, 256, 384}, 384); | ||
LocalResponseNormalization<TFloat32> norm3 = alexNetModelLRN(tf, relu3); | ||
|
||
//Layer 4 | ||
Relu<TFloat32> relu4 = alexNetConv2DLayer("4", tf, norm3, new int[]{3, 3, 384, 384}, 384); | ||
|
||
//Layer 5 | ||
Relu<TFloat32> relu5 = alexNetConv2DLayer("2", tf, relu4, new int[]{3, 3, 384, 256}, 256); | ||
MaxPool<TFloat32> pool5 = alexNetMaxPool(tf, relu5); | ||
LocalResponseNormalization<TFloat32> norm5 = alexNetModelLRN(tf, pool5); | ||
|
||
Reshape<TFloat32> flatten = alexNetFlatten(tf, pool5); | ||
|
||
Add<TFloat32> loss = buildFCLayersAndRegularization(tf, labels, flatten); | ||
|
||
Optimizer optimizer = new Adam(graph, 0.001f, 0.9f, 0.999f, 1e-8f); | ||
|
||
optimizer.minimize(loss, TRAIN); | ||
|
||
tf.init(); | ||
|
||
return graph; | ||
} | ||
|
||
public static Add<TFloat32> buildFCLayersAndRegularization(Ops tf, Placeholder<TUint8> labels, Reshape<TFloat32> flatten) { | ||
int fcBiasShape = 500; | ||
int[] fcWeightShape = {4096, fcBiasShape}; | ||
|
||
Variable<TFloat32> fc1Weights = tf.variable(tf.math.mul(tf.random | ||
.truncatedNormal(tf.array(fcWeightShape), TFloat32.DTYPE, | ||
TruncatedNormal.seed(SEED)), tf.constant(0.1f))); | ||
Variable<TFloat32> fc1Biases = tf | ||
.variable(tf.fill(tf.array(new int[]{fcBiasShape}), tf.constant(0.1f))); | ||
Relu<TFloat32> fcRelu = tf.nn | ||
.relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases)); | ||
|
||
// Softmax layer | ||
Variable<TFloat32> fc2Weights = tf.variable(tf.math.mul(tf.random | ||
.truncatedNormal(tf.array(fcBiasShape, NUM_LABELS), TFloat32.DTYPE, | ||
TruncatedNormal.seed(SEED)), tf.constant(0.1f))); | ||
Variable<TFloat32> fc2Biases = tf | ||
.variable(tf.fill(tf.array(new int[]{NUM_LABELS}), tf.constant(0.1f))); | ||
|
||
Add<TFloat32> logits = tf.math.add(tf.linalg.matMul(fcRelu, fc2Weights), fc2Biases); | ||
|
||
// Predicted outputs | ||
tf.withName(OUTPUT_NAME).nn.softmax(logits); | ||
|
||
// Loss function & regularization | ||
OneHot<TFloat32> oneHot = tf | ||
.oneHot(labels, tf.constant(NUM_LABELS), tf.constant(1.0f), tf.constant(0.0f)); | ||
SoftmaxCrossEntropyWithLogits<TFloat32> batchLoss = tf.nn.raw | ||
.softmaxCrossEntropyWithLogits(logits, oneHot); | ||
Mean<TFloat32> labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0)); | ||
Add<TFloat32> regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's think, do we need regularization here, maybe better to remove it and refactor dense layers in separate procedures |
||
.add(tf.nn.l2Loss(fc1Biases), | ||
tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases)))); | ||
return tf.withName(TRAINING_LOSS).math | ||
.add(labelLoss, tf.math.mul(regularizers, tf.constant(5e-4f))); | ||
} | ||
|
||
public static Reshape<TFloat32> alexNetFlatten(Ops tf, MaxPool<TFloat32> pool5) { | ||
return tf.reshape(pool5, tf.concat(Arrays | ||
.asList(tf.slice(tf.shape(pool5), tf.array(new int[]{0}), tf.array(new int[]{1})), | ||
tf.array(new int[]{-1})), tf.constant(0))); | ||
} | ||
|
||
public static MaxPool<TFloat32> alexNetMaxPool(Ops tf, Relu<TFloat32> relu) { | ||
return tf.nn | ||
.maxPool(relu, tf.array(1, 2, 2, 1), tf.array(1, 2, 2, 1), | ||
PADDING_TYPE); | ||
} | ||
|
||
private static LocalResponseNormalization<TFloat32> alexNetModelLRN(Ops tf, MaxPool<TFloat32> pool) { | ||
return tf.nn.localResponseNormalization(pool); | ||
} | ||
|
||
private static LocalResponseNormalization<TFloat32> alexNetModelLRN(Ops tf, Relu<TFloat32> relu) { | ||
return tf.nn.localResponseNormalization(relu); | ||
} | ||
|
||
public static Relu<TFloat32> alexNetConv2DLayer(String layerName, Ops tf, Operand<TFloat32> scaledInput, int[] convWeightsL1Shape, int convBiasL1Shape) { | ||
Variable<TFloat32> conv1Weights = tf.withName("conv2d_" + layerName).variable(tf.math.mul(tf.random | ||
.truncatedNormal(tf.array(convWeightsL1Shape), TFloat32.DTYPE, | ||
TruncatedNormal.seed(SEED)), tf.constant(0.1f))); | ||
Conv2d<TFloat32> conv = tf.nn | ||
.conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); | ||
Variable<TFloat32> convBias = tf | ||
.withName("bias2d_" + layerName).variable(tf.fill(tf.array(new int[]{convBiasL1Shape}), tf.constant(0.0f))); | ||
return tf.nn.relu(tf.withName("biasAdd_" + layerName).nn.biasAdd(conv, convBias)); | ||
} | ||
|
||
public void train(MnistDataset dataset, int epochs, int minibatchSize) { | ||
// Initialises the parameters. | ||
session.runner().addTarget(INIT).run(); | ||
logger.info("Initialised the model parameters"); | ||
|
||
int interval = 0; | ||
// Train the model | ||
for (int i = 0; i < epochs; i++) { | ||
for (ImageBatch trainingBatch : dataset.trainingBatches(minibatchSize)) { | ||
try (Tensor<TUint8> batchImages = TUint8.tensorOf(trainingBatch.images()); | ||
Tensor<TUint8> batchLabels = TUint8.tensorOf(trainingBatch.labels()); | ||
Tensor<TFloat32> loss = session.runner() | ||
.feed(TARGET, batchLabels) | ||
.feed(INPUT_NAME, batchImages) | ||
.addTarget(TRAIN) | ||
.fetch(TRAINING_LOSS) | ||
.run().get(0).expect(TFloat32.DTYPE)) { | ||
|
||
logger.log(Level.INFO, | ||
"Iteration = " + interval + ", training loss = " + loss.data().getFloat()); | ||
} | ||
interval++; | ||
} | ||
} | ||
} | ||
|
||
public void test(MnistDataset dataset, int minibatchSize) { | ||
int correctCount = 0; | ||
int[][] confusionMatrix = new int[NUM_LABELS + 1][NUM_LABELS + 1]; | ||
|
||
for (ImageBatch trainingBatch : dataset.testBatches(minibatchSize)) { | ||
try (Tensor<TUint8> transformedInput = TUint8.tensorOf(trainingBatch.images()); | ||
Tensor<TFloat32> outputTensor = session.runner() | ||
.feed(INPUT_NAME, transformedInput) | ||
.fetch(OUTPUT_NAME).run().get(0).expect(TFloat32.DTYPE)) { | ||
|
||
ByteNdArray labelBatch = trainingBatch.labels(); | ||
for (int k = 0; k < labelBatch.shape().size(0); k++) { | ||
byte trueLabel = labelBatch.getByte(k); | ||
int predLabel; | ||
|
||
predLabel = argmax(outputTensor.data().slice(Indices.at(k), Indices.all())); | ||
if (predLabel == trueLabel) { | ||
correctCount++; | ||
} | ||
|
||
confusionMatrix[trueLabel][predLabel]++; | ||
} | ||
} | ||
} | ||
|
||
logger.info("Final accuracy = " + ((float) correctCount) / dataset.numTestingExamples()); | ||
|
||
StringBuilder sb = new StringBuilder(); | ||
sb.append("Label"); | ||
for (int i = 0; i < confusionMatrix.length; i++) { | ||
sb.append(String.format("%1$5s", "" + i)); | ||
} | ||
sb.append("\n"); | ||
|
||
for (int i = 0; i < confusionMatrix.length; i++) { | ||
sb.append(String.format("%1$5s", "" + i)); | ||
for (int j = 0; j < confusionMatrix[i].length; j++) { | ||
sb.append(String.format("%1$5s", "" + confusionMatrix[i][j])); | ||
} | ||
sb.append("\n"); | ||
} | ||
|
||
System.out.println(sb.toString()); | ||
} | ||
|
||
/** | ||
* Find the maximum probability and return it's index. | ||
* | ||
* @param probabilities The probabilites. | ||
* @return The index of the max. | ||
*/ | ||
public static int argmax(FloatNdArray probabilities) { | ||
float maxVal = Float.NEGATIVE_INFINITY; | ||
int idx = 0; | ||
for (int i = 0; i < probabilities.shape().size(0); i++) { | ||
float curVal = probabilities.getFloat(i); | ||
if (curVal > maxVal) { | ||
maxVal = curVal; | ||
idx = i; | ||
} | ||
} | ||
return idx; | ||
} | ||
|
||
@Override | ||
public void close() { | ||
session.close(); | ||
graph.close(); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
/* | ||
* Copyright 2020 The TensorFlow Authors. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* ======================================================================= | ||
*/ | ||
package org.tensorflow.model.examples.cnn.alexnet; | ||
|
||
import org.tensorflow.model.examples.datasets.mnist.MnistDataset; | ||
|
||
import java.util.logging.Logger; | ||
|
||
/** | ||
* Trains and evaluates AlexNet model on Extended-MNIST dataset. | ||
*/ | ||
public class AlexNetOnEMNIST { | ||
// Hyper-parameters | ||
public static final int EPOCHS = 1; | ||
public static final int BATCH_SIZE = 500; | ||
|
||
// Fashion MNIST dataset paths | ||
public static final String TRAINING_IMAGES_ARCHIVE = "emnist/emnist-letters-train-images-idx3-ubyte.gz"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have no strong position on new dataset addition. But I suppose the best cases here - to add a link on dataset and its creators (for example on paper https://arxiv.org/abs/1702.05373) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we keep training on Mnist dataset? |
||
public static final String TRAINING_LABELS_ARCHIVE = "emnist/emnist-letters-train-labels-idx1-ubyte.gz"; | ||
public static final String TEST_IMAGES_ARCHIVE = "emnist/emnist-letters-test-images-idx3-ubyte.gz"; | ||
public static final String TEST_LABELS_ARCHIVE = "emnist/emnist-letters-test-labels-idx1-ubyte.gz"; | ||
|
||
private static final Logger logger = Logger.getLogger(AlexNetOnEMNIST.class.getName()); | ||
|
||
public static void main(String[] args) { | ||
logger.info("Data loading."); | ||
MnistDataset dataset = MnistDataset.create(0, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE, TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE); | ||
|
||
try (AlexNetModel alexNetModel = new AlexNetModel()) { | ||
logger.info("Model training."); | ||
alexNetModel.train(dataset, EPOCHS, BATCH_SIZE); | ||
|
||
logger.info("Model evaluation."); | ||
alexNetModel.test(dataset, BATCH_SIZE); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dear @akshaybahadur21 could you please explain, why are you using LRN here (please share the link on AlexNet references)?