Skip to content

Commit 031a0c1

Browse files
authored
SavedModelBundle leak fix (#335)
1 parent b997f12 commit 031a0c1

File tree

2 files changed

+62
-56
lines changed

2 files changed

+62
-56
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
11
/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved.
22
3-
Licensed under the Apache License, Version 2.0 (the "License");
4-
you may not use this file except in compliance with the License.
5-
You may obtain a copy of the License at
6-
7-
http://www.apache.org/licenses/LICENSE-2.0
8-
9-
Unless required by applicable law or agreed to in writing, software
10-
distributed under the License is distributed on an "AS IS" BASIS,
11-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
See the License for the specific language governing permissions and
13-
limitations under the License.
14-
=======================================================================
15-
*/
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================
15+
*/
1616
package org.tensorflow;
1717

18-
import static org.tensorflow.internal.c_api.global.tensorflow.TF_LoadSessionFromSavedModel;
1918
import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewGraph;
2019
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig;
2120

@@ -34,7 +33,6 @@
3433
import java.util.Map.Entry;
3534
import java.util.stream.Collectors;
3635
import org.bytedeco.javacpp.BytePointer;
37-
import org.bytedeco.javacpp.PointerPointer;
3836
import org.bytedeco.javacpp.PointerScope;
3937
import org.tensorflow.exceptions.TensorFlowException;
4038
import org.tensorflow.internal.c_api.TF_Buffer;
@@ -510,21 +508,18 @@ private static SavedModelBundle load(
510508
TF_Graph graph = TF_NewGraph();
511509
TF_Buffer metagraphDef = TF_Buffer.newBuffer();
512510
TF_Session session =
513-
TF_LoadSessionFromSavedModel(
514-
opts,
515-
runOpts,
516-
new BytePointer(exportDir),
517-
new PointerPointer(tags),
518-
tags.length,
519-
graph,
520-
metagraphDef,
521-
status);
511+
TF_Session.loadSessionFromSavedModel(
512+
opts, runOpts, exportDir, tags, graph, metagraphDef, status);
522513
status.throwExceptionIfNotOK();
523514

524515
// handle the result
525516
try {
526517
bundle =
527518
fromHandle(graph, session, MetaGraphDef.parseFrom(metagraphDef.dataAsByteBuffer()));
519+
// Only retain the references if the metagraphdef parses correctly,
520+
// otherwise allow the pointer scope to clean them up
521+
graph.retainReference();
522+
session.retainReference();
528523
} catch (InvalidProtocolBufferException e) {
529524
throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", e);
530525
}
Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
/*
2-
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
33
4-
Licensed under the Apache License, Version 2.0 (the "License");
5-
you may not use this file except in compliance with the License.
6-
You may obtain a copy of the License at
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
77
8-
http://www.apache.org/licenses/LICENSE-2.0
8+
http://www.apache.org/licenses/LICENSE-2.0
99
10-
Unless required by applicable law or agreed to in writing, software
11-
distributed under the License is distributed on an "AS IS" BASIS,
12-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13-
See the License for the specific language governing permissions and
14-
limitations under the License.
15-
=======================================================================
16-
*/
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
=======================================================================
16+
*/
1717

1818
package org.tensorflow.internal.c_api;
1919

@@ -25,29 +25,40 @@
2525

2626
@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
2727
public abstract class AbstractTF_Graph extends Pointer {
28-
protected static class DeleteDeallocator extends TF_Graph implements Pointer.Deallocator {
29-
DeleteDeallocator(TF_Graph s) { super(s); }
30-
@Override public void deallocate() { if (!isNull()) TF_DeleteGraph(this); setNull(); }
28+
protected static class DeleteDeallocator extends TF_Graph implements Pointer.Deallocator {
29+
DeleteDeallocator(TF_Graph s) {
30+
super(s);
3131
}
3232

33-
public AbstractTF_Graph(Pointer p) { super(p); }
34-
35-
/**
36-
* Calls TF_NewGraph(), and registers a deallocator.
37-
* @return TF_Graph created. Do not call TF_DeleteGraph() on it.
38-
*/
39-
public static TF_Graph newGraph() {
40-
TF_Graph g = TF_NewGraph();
41-
if (g != null) {
42-
g.deallocator(new DeleteDeallocator(g));
43-
}
44-
return g;
33+
@Override
34+
public void deallocate() {
35+
if (!isNull()) TF_DeleteGraph(this);
36+
setNull();
4537
}
38+
}
4639

47-
/**
48-
* Calls the deallocator, if registered, otherwise has no effect.
49-
*/
50-
public void delete() {
51-
deallocate();
40+
public AbstractTF_Graph(Pointer p) {
41+
super(p);
42+
}
43+
44+
/**
45+
* Calls TF_NewGraph(), and registers a deallocator.
46+
*
47+
* <p>Note {@link org.tensorflow.Graph} will call TF_DeleteGraph on close, so do not use this
48+
* method when constructing a reference for use inside a {@code Graph} object.
49+
*
50+
* @return TF_Graph created. Do not call TF_DeleteGraph() on it.
51+
*/
52+
public static TF_Graph newGraph() {
53+
TF_Graph g = TF_NewGraph();
54+
if (g != null) {
55+
g.deallocator(new DeleteDeallocator(g));
5256
}
57+
return g;
58+
}
59+
60+
/** Calls the deallocator, if registered, otherwise has no effect. */
61+
public void delete() {
62+
deallocate();
63+
}
5364
}

0 commit comments

Comments
 (0)