|
1 | 1 | /* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved.
|
2 | 2 |
|
3 |
| - Licensed under the Apache License, Version 2.0 (the "License"); |
4 |
| - you may not use this file except in compliance with the License. |
5 |
| - You may obtain a copy of the License at |
6 |
| -
|
7 |
| - http://www.apache.org/licenses/LICENSE-2.0 |
8 |
| -
|
9 |
| - Unless required by applicable law or agreed to in writing, software |
10 |
| - distributed under the License is distributed on an "AS IS" BASIS, |
11 |
| - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 |
| - See the License for the specific language governing permissions and |
13 |
| - limitations under the License. |
14 |
| - ======================================================================= |
15 |
| - */ |
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. |
| 14 | +======================================================================= |
| 15 | +*/ |
16 | 16 | package org.tensorflow;
|
17 | 17 |
|
18 |
| -import static org.tensorflow.internal.c_api.global.tensorflow.TF_LoadSessionFromSavedModel; |
19 | 18 | import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewGraph;
|
20 | 19 | import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig;
|
21 | 20 |
|
|
34 | 33 | import java.util.Map.Entry;
|
35 | 34 | import java.util.stream.Collectors;
|
36 | 35 | import org.bytedeco.javacpp.BytePointer;
|
37 |
| -import org.bytedeco.javacpp.PointerPointer; |
38 | 36 | import org.bytedeco.javacpp.PointerScope;
|
39 | 37 | import org.tensorflow.exceptions.TensorFlowException;
|
40 | 38 | import org.tensorflow.internal.c_api.TF_Buffer;
|
@@ -510,21 +508,18 @@ private static SavedModelBundle load(
|
510 | 508 | TF_Graph graph = TF_NewGraph();
|
511 | 509 | TF_Buffer metagraphDef = TF_Buffer.newBuffer();
|
512 | 510 | TF_Session session =
|
513 |
| - TF_LoadSessionFromSavedModel( |
514 |
| - opts, |
515 |
| - runOpts, |
516 |
| - new BytePointer(exportDir), |
517 |
| - new PointerPointer(tags), |
518 |
| - tags.length, |
519 |
| - graph, |
520 |
| - metagraphDef, |
521 |
| - status); |
| 511 | + TF_Session.loadSessionFromSavedModel( |
| 512 | + opts, runOpts, exportDir, tags, graph, metagraphDef, status); |
522 | 513 | status.throwExceptionIfNotOK();
|
523 | 514 |
|
524 | 515 | // handle the result
|
525 | 516 | try {
|
526 | 517 | bundle =
|
527 | 518 | fromHandle(graph, session, MetaGraphDef.parseFrom(metagraphDef.dataAsByteBuffer()));
|
| 519 | + // Only retain the references if the metagraphdef parses correctly, |
| 520 | + // otherwise allow the pointer scope to clean them up |
| 521 | + graph.retainReference(); |
| 522 | + session.retainReference(); |
528 | 523 | } catch (InvalidProtocolBufferException e) {
|
529 | 524 | throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", e);
|
530 | 525 | }
|
|
0 commit comments