Skip to content

Commit

Permalink
Add documentation on loading TF extension libraries for running certa…
Browse files Browse the repository at this point in the history
…in TF models in DJL (#1776)
  • Loading branch information
siddvenk authored Jul 6, 2022
1 parent 7cecdfd commit 1d9f56b
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions docs/tensorflow/how_to_import_tensorflow_models_in_DJL.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,44 @@ Please refer to these two examples:
1. [Object Detection with TensorFlow](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/ObjectDetection.java) for loading from TensorFlow Hub url.
2. [BERT Classification](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/BertClassification.java) for loading from local downloaded model.

## How to import TensorFlow models that use [TensorFlow Extensions](https://www.tensorflow.org/resources/libraries-extensions)

You can use DJL to run TensorFlow models that require ops available in TensorFlow extensions (not in core), such as models that use ops in [TensorFlow Text](https://www.tensorflow.org/text).

To use such models in DJL you must know:
* Which extension contains the additional required ops
* The TensorFlow extension library version that is compatible with the version of Tensorflow currently being used in DJL

Here we show an example that uses DJL to run the [Multilingual Universal Sentence Encoder](https://tfhub.dev/google/universal-sentence-encoder-multilingual/3).
This is a text encoder from tensorflow hub that uses ops from the TensorFlow Text extension.
* You need to download the library files for the extension so that we can load them for use. [Here](https://github.com/tensorflow/text#install-using-pip) are instructions for downloading TensorFlow text. We'll use version `2.7.0` in this example.
* After the library files are downloaded, you are ready to load them for use in DJL. You do this by using
TensorFlow Java `loadLibrary(...)` API before loading the model in DJL.
* We can update the existing [UniversalSentenceEncoder](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/UniversalSentenceEncoder.java) example to use the Multilingual model:

```java
// Build Criteria for Multilingual Universal Sentence Encoder from TensorFlow Hub
String modelUrl = "https://storage.googleapis.com/tfhub-modules/google/universal-sentence-encoder-multilingual/3.tar.gz"
Criteria<String[], float[][]> criteria =
Criteria.builder()
.optApplication(Application.NLP.TEXT_EMBEDDING)
.setTypes(String[].class, float[][].class)
.optModelUrls(modelUrl)
.optTranslator(new MyTranslator())
.optEngine("TensorFlow")
.optProgress(new ProgressBar())
.build();

// Load TensorFlow Text libraries for necessary ops
// Example Path of installation on Linux may look like: /home/<user>/.local/python-3.8.3/lib/python3.8/site-packages/tensorflow_text/
TensorFlow.loadLibrary("<path_to_library_installation>/python/ops/_sentencepiece_tokenizer.so");

// Run the model
try (ZooModel<String[], float[][]> model = criteria.loadModel();
Predictor<String[], float[][]> predictor = model.newPredictor()) {
return predictor.predict(inputs.toArray(new String[0]));
}
```
## How to load TensorFlow Checkpoints

To load an TensorFlow Estimator checkpoint, you need to convert it to SavedModel format in using Python.
Expand Down

0 comments on commit 1d9f56b

Please sign in to comment.