Skip to content

Commit 5727d71

Browse files
authored
Allow output targets in Session runner (#388)
+ return `SessionFunction` typed instances from `SavedModelBundle`
1 parent 2d449a6 commit 5727d71

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,7 @@ gradleBuild
5656
**/target
5757
.tf_configure.bazelrc
5858
.clwb/
59+
60+
# Deployment Files
61+
settings.xml
62+
pom.xml.asc

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ public List<Signature> signatures() {
430430
* @return object that can be used to make calls to a function
431431
* @throws IllegalArgumentException if {@code signatureKey} is not found in this saved model.
432432
*/
433-
public TensorFunction function(String signatureKey) {
433+
public SessionFunction function(String signatureKey) {
434434
SessionFunction function = functions.get(signatureKey);
435435
if (function == null) {
436436
throw new IllegalArgumentException(
@@ -444,7 +444,7 @@ public TensorFunction function(String signatureKey) {
444444
*
445445
* <p><b>All functions use the bundle's underlying session.</b>
446446
*/
447-
public List<TensorFunction> functions() {
447+
public List<SessionFunction> functions() {
448448
return new ArrayList<>(functions.values());
449449
}
450450

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,12 +397,13 @@ public Runner fetch(Operand<?> operand) {
397397
* Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor
398398
* Tensors}.
399399
*
400-
* @param operation the string name of the operation to execute
400+
* @param operation Is either the string name of the operation or it is a string of the form
401+
* <tt>operation_name:output_index</tt>, where <tt>output_index</tt> will simply be ignored.
401402
* @return this session runner
402403
* @throws IllegalArgumentException if no operation exists with the provided name
403404
*/
404405
public Runner addTarget(String operation) {
405-
return addTarget(graph.operationOrThrow(operation));
406+
return addTarget(graph.outputOrThrow(operation));
406407
}
407408

408409
/**

0 commit comments

Comments
 (0)