Skip to content

Adapt to TFJava 0.5.0 #29

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tensorflow-examples/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
<!-- The maven compiler plugin defaults to a lower version -->
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<tensorflow.version>0.4.0</tensorflow.version>
<tensorflow.version>0.5.0</tensorflow.version>
</properties>

<dependencies>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,17 @@ The given SavedModel SignatureDef contains the following output(s):
but again the actual tensor is DT_FLOAT according to saved_model_cli.
*/


import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;

import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Result;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
Expand Down Expand Up @@ -228,16 +229,16 @@ public class FasterRcnnInception {
};

public static void main(String[] params) {

if (params.length != 2) {
throw new IllegalArgumentException("Exactly 2 parameters required !");
}

//my output image
String outputImagePath = params[1];
//my test image
String imagePath = params[0];
// get path to model folder
String modelPath = "models/faster_rcnn_inception_resnet_v2_1024x1024";
String modelPath = "models/faster_rcnn_inception_resnet_v2_1024x1024_1";
// load saved model
SavedModelBundle model = SavedModelBundle.load(modelPath, "serve");
//create a map of the COCO 2017 labels
Expand Down Expand Up @@ -268,17 +269,14 @@ public static void main(String[] params) {
Map<String, Tensor> feedDict = new HashMap<>();
//The given SavedModel SignatureDef input
feedDict.put("input_tensor", reshapeTensor);
//The given SavedModel MetaGraphDef key
Map<String, Tensor> outputTensorMap = model.function("serving_default").call(feedDict);
//detection_classes, detectionBoxes etc. are model output names
try (TFloat32 detectionClasses = (TFloat32) outputTensorMap.get("detection_classes");
TFloat32 detectionBoxes = (TFloat32) outputTensorMap.get("detection_boxes");
TFloat32 rawDetectionBoxes = (TFloat32) outputTensorMap.get("raw_detection_boxes");
TFloat32 numDetections = (TFloat32) outputTensorMap.get("num_detections");
TFloat32 detectionScores = (TFloat32) outputTensorMap.get("detection_scores");
TFloat32 rawDetectionScores = (TFloat32) outputTensorMap.get("raw_detection_scores");
TFloat32 detectionAnchorIndices = (TFloat32) outputTensorMap.get("detection_anchor_indices");
TFloat32 detectionMulticlassScores = (TFloat32) outputTensorMap.get("detection_multiclass_scores")) {
//detection_classes, detectionBoxes, num_detections. are model output names
try (Result result = model.function("serving_default").call(feedDict);
TFloat32 detectionBoxes = (TFloat32) result.get("detection_boxes")
.orElseThrow(() -> new RuntimeException("model output exception detection_boxes key is null"));
TFloat32 numDetections = (TFloat32) result.get("num_detections")
.orElseThrow(() -> new RuntimeException("model output exception num_detections key is null"));
TFloat32 detectionScores = (TFloat32) result.get("detection_scores")
.orElseThrow(() -> new RuntimeException("model output exception detection_scores key is null"))) {
int numDetects = (int) numDetections.getFloat(0);
if (numDetects > 0) {
ArrayList<FloatNdArray> boxArray = new ArrayList<>();
Expand Down Expand Up @@ -320,9 +318,9 @@ public static void main(String[] params) {
tf.dtypes.cast(tf.reshape(
tf.math.mul(
tf.image.drawBoundingBoxes(tf.math.div(
tf.dtypes.cast(tf.constant(reshapeTensor),
TFloat32.class),
tf.constant(255.0f)
tf.dtypes.cast(tf.constant(reshapeTensor),
TFloat32.class),
tf.constant(255.0f)
),
boxesPlaceHolder, colors),
tf.constant(255.0f)
Expand All @@ -344,4 +342,4 @@ public static void main(String[] params) {
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
*/
package org.tensorflow.model.examples.regression.linear;

import java.util.List;
import java.util.Random;
import org.tensorflow.Graph;
import org.tensorflow.Result;
import org.tensorflow.Session;
import org.tensorflow.framework.optimizers.GradientDescent;
import org.tensorflow.framework.optimizers.Optimizer;
Expand Down Expand Up @@ -108,13 +108,13 @@ public static void main(String[] args) {
}

// Extract linear regression model weight and bias values
List<?> tensorList = session.runner()
Result tensorList = session.runner()
.fetch(WEIGHT_VARIABLE_NAME)
.fetch(BIAS_VARIABLE_NAME)
.run();

try (TFloat32 weightValue = (TFloat32)tensorList.get(0);
TFloat32 biasValue = (TFloat32)tensorList.get(1)) {
try (TFloat32 weightValue = (TFloat32) tensorList.get(0);
TFloat32 biasValue = (TFloat32) tensorList.get(1)) {

System.out.println("Weight is " + weightValue.getFloat());
System.out.println("Bias is " + biasValue.getFloat());
Expand All @@ -126,7 +126,7 @@ public static void main(String[] args) {

try (TFloat32 xTensor = TFloat32.scalarOf(x);
TFloat32 yTensor = TFloat32.scalarOf(predictedY);
TFloat32 yPredictedTensor = (TFloat32)session.runner()
TFloat32 yPredictedTensor = (TFloat32) session.runner()
.feed(xData.asOutput(), xTensor)
.feed(yData.asOutput(), yTensor)
.fetch(yPredicted)
Expand Down