Skip to content

Commit 022d103

Browse files
committed
Fixes for review comments, and to ensure that tensors are named as expected.
1 parent ad9172c commit 022d103

File tree

5 files changed

+56
-16
lines changed

5 files changed

+56
-16
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ public Operand<?> call(Scope scope, Operand<?> argument) {
296296

297297
@Override
298298
public Result call(Map<String, Tensor> arguments) {
299-
// FIXME need to manage input/output operand lifetimes
299+
// FIXME need to manage input operand lifetimes
300300
Ops tf = Ops.create();
301301
Map<String, Operand<?>> inputs = new LinkedHashMap<>(arguments.size());
302302

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Result.java

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
*/
1818
package org.tensorflow;
1919

20+
import org.tensorflow.exceptions.TensorFlowException;
2021
import org.tensorflow.proto.framework.RunMetadata;
2122

2223
import java.util.ArrayList;
@@ -27,6 +28,7 @@
2728
import java.util.Map;
2829
import java.util.Optional;
2930
import java.util.Set;
31+
import java.util.logging.Level;
3032
import java.util.logging.Logger;
3133

3234
/**
@@ -35,15 +37,27 @@
3537
* <p>When this is closed it closes all the {@link Tensor}s inside it. If you maintain a
3638
* reference to a value after this object has been closed it will throw an {@link
3739
* IllegalStateException} upon access.
40+
*
41+
* <p>This class is not thread-safe with respect to the close operation. Multiple closers
42+
* or one thread closing a tensor while another is reading may throw exceptions.
43+
*
44+
* <p>Note this class is used to manage the lifetimes of tensors produced by the
45+
* TensorFlow runtime, from sessions and function calls. It is not used as an argument
46+
* to {@code session.run} or function calls as users are in control of the creation
47+
* of input tensors.
3848
*/
3949
public final class Result implements AutoCloseable, Iterable<Map.Entry<String, Tensor>> {
4050
@Override
4151
public void close() {
4252
if (!closed) {
43-
closed = true;
44-
for (Tensor t : map.values()) {
45-
t.close();
53+
for (Tensor t : list) {
54+
try {
55+
t.close();
56+
} catch (TensorFlowException e) {
57+
logger.log(Level.WARNING, "Exception raised when closing tensor inside result.", e);
58+
}
4659
}
60+
closed = true;
4761
} else {
4862
logger.warning("Closing an already closed Result");
4963
}
@@ -111,12 +125,7 @@ public Tensor get(int index) {
111125
*/
112126
public Optional<Tensor> get(String key) {
113127
if (!closed) {
114-
Tensor value = map.get(key);
115-
if (value != null) {
116-
return Optional.of(value);
117-
} else {
118-
return Optional.empty();
119-
}
128+
return Optional.ofNullable(map.get(key));
120129
} else {
121130
throw new IllegalStateException("Result is closed");
122131
}
@@ -153,7 +162,10 @@ public Optional<RunMetadata> getMetadata() {
153162
}
154163

155164
for (int i = 0; i < names.size(); i++) {
156-
this.map.put(names.get(i), values.get(i));
165+
Tensor old = this.map.put(names.get(i), values.get(i));
166+
if (old != null) {
167+
throw new IllegalArgumentException("Name collision in the result set, two outputs are named '" + names.get(i) + "'");
168+
}
157169
}
158170
this.metadata = metadata;
159171
this.closed = false;

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,9 @@ public Runner feed(Operand<?> operand, Tensor t) {
308308
* @throws IllegalArgumentException if no output exists with the provided name
309309
*/
310310
public Runner fetch(String operation) {
311-
return fetch(graph.outputOrThrow(operation));
311+
Runner r = fetch(graph.outputOrThrow(operation),false);
312+
outputNames.add(operation);
313+
return r;
312314
}
313315

314316
/**
@@ -338,6 +340,20 @@ public Runner fetch(String operation, int index) {
338340
* @return this session runner
339341
*/
340342
public Runner fetch(Output<?> output) {
343+
return fetch(output, true);
344+
}
345+
346+
/**
347+
* Makes {@link #run()} return the Tensor referred to by {@code output}.
348+
*
349+
* <p>If {@code output} is a resource variable, will fetch the value.
350+
*
351+
* @param output the node to fetch the tensor from
352+
* @param recordName Records the output name. If false the output name must be recorded by the
353+
* calling method as otherwise the result object will throw on construction.
354+
* @return this session runner
355+
*/
356+
private Runner fetch(Output<?> output, boolean recordName) {
341357
if (output.env() != graph) {
342358
throw new IllegalStateException(
343359
"Can't fetch output "
@@ -380,6 +396,9 @@ public Runner fetch(Output<?> output) {
380396
} else {
381397
outputs.add(output);
382398
}
399+
if (recordName) {
400+
outputNames.add(output.name());
401+
}
383402
return this;
384403
}
385404

@@ -523,7 +542,6 @@ private Result runHelper(boolean wantMetadata) {
523542
TF_Operation[] outputOpHandles = new TF_Operation[outputs.size()];
524543
int[] outputOpIndices = new int[outputs.size()];
525544
TF_Operation[] targetOpHandles = new TF_Operation[targets.size()];
526-
List<String> outputNames = new ArrayList<>();
527545

528546
// It's okay to use Operation.getUnsafeNativeHandle() here since the safety depends on the
529547
// validity of the Graph and graphRef ensures that.
@@ -541,7 +559,6 @@ private Result runHelper(boolean wantMetadata) {
541559
for (Output<?> o : outputs) {
542560
outputOpHandles[idx] = (TF_Operation) o.getUnsafeNativeHandle();
543561
outputOpIndices[idx] = o.index();
544-
outputNames.add(o.name());
545562
idx++;
546563
}
547564
idx = 0;
@@ -603,6 +620,7 @@ public void close() {
603620
private final ArrayList<Output<?>> inputs = new ArrayList<>();
604621
private final ArrayList<Tensor> inputTensors = new ArrayList<>();
605622
private final ArrayList<Output<?>> outputs = new ArrayList<>();
623+
private final ArrayList<String> outputNames = new ArrayList<>();
606624
private final ArrayList<GraphOperation> targets = new ArrayList<>();
607625
private RunOptions runOptions = null;
608626
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SessionFunction.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,16 @@ public Result call(Map<String, Tensor> arguments) {
112112

113113
signature.getOutputs().values().forEach(x -> runner.fetch(x.name));
114114

115-
return runner.run();
115+
Result results = runner.run();
116+
117+
// Unpack the result object and rebuild it with the expected names.
118+
LinkedHashMap<String, Tensor> outputs = new LinkedHashMap<>(results.size());
119+
int i = 0;
120+
for (String outputName : signature.outputNames()) {
121+
outputs.put(outputName, results.get(i));
122+
i++;
123+
}
124+
125+
return new Result(outputs);
116126
}
117127
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ static <T extends TType> T of(Class<T> type, Shape shape, ByteDataBuffer rawData
210210
* <p>When this methods retuns {@code true}, the tensor could be cast to a {@link SparseTensor
211211
* SparseTensor<T>} to access its <i>indices</i>, <i>values</i> and <i>denseShape</i> tensors.
212212
*
213-
* @retrun true if this tensor is a sparse
213+
* @return true if this tensor is a sparse
214214
*/
215215
default boolean isSparse() {
216216
return false;

0 commit comments

Comments
 (0)