diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java index 81e69b25e5d..1ef7f433ba7 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java @@ -177,6 +177,9 @@ private NDArray processEmbedding(NDList list, NDArray attentionMask) { case "cls": embedding = embedding.get(new NDIndex(":, 0")); break; + case "lasttoken": + embedding = lastTokenPool(embedding, attentionMask); + break; default: throw new AssertionError("Unexpected pooling mode: " + pooling); } @@ -239,6 +242,20 @@ private static NDArray weightedMeanPool(NDArray embeddings, NDArray attentionMas return embeddingSum.div(maskSum); } + private static NDArray lastTokenPool(NDArray embeddings, NDArray attentionMask) { + long sum = attentionMask.get(":, -1").sum().getLong(); + if (sum == attentionMask.getShape().get(0)) { + // left padding + return embeddings.get(":, -1"); + } + + long sequenceLength = attentionMask.sum(new int[] {1}).getLong() - 1; + long batchSize = embeddings.getShape().get(0); + embeddings = embeddings.get(":, " + sequenceLength); + NDArray index = embeddings.getManager().arange(batchSize); + return embeddings.get(index); + } + /** * Creates a builder to build a {@code TextEmbeddingTranslator}. * @@ -313,10 +330,11 @@ public Builder optPoolingMode(String poolingMode) { && !"max".equals(poolingMode) && !"cls".equals(poolingMode) && !"mean_sqrt_len".equals(poolingMode) + && !"lasttoken".equals(poolingMode) && !"weightedmean".equals(poolingMode)) { throw new IllegalArgumentException( "Invalid pooling model, must be one of [mean, max, cls, mean_sqrt_len," - + " weightedmean]."); + + " weightedmean, lasttoken]."); } this.pooling = poolingMode; return this; diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TextEmbeddingTranslatorTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TextEmbeddingTranslatorTest.java index 15a670860f7..68a677202df 100644 --- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TextEmbeddingTranslatorTest.java +++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/TextEmbeddingTranslatorTest.java @@ -163,6 +163,52 @@ public void testTextEmbeddingTranslator() Assertions.assertAlmostEquals(res[0], 0.05103104); } + // pooling_lasttokens with left padding + criteria = + Criteria.builder() + .setTypes(String.class, float[].class) + .optModelPath(modelDir) + .optArgument("blockFactory", "ai.djl.nn.OnesBlockFactory") + .optArgument("block_shapes", "(1,7,384)") + .optArgument("block_names", "last_hidden_state") + .optEngine("PyTorch") + .optArgument("tokenizer", "intfloat/e5-mistral-7b-instruct") + .optArgument("pooling", "lasttoken") + .optOption("hasParameter", "false") + .optTranslatorFactory(new TextEmbeddingTranslatorFactory()) + .build(); + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + float[] res = predictor.predict(text); + Assert.assertEquals(res.length, 384); + Assertions.assertAlmostEquals(res[0], 0.05103104); + } + + // pooling_lasttokens + criteria = + Criteria.builder() + .setTypes(String.class, float[].class) + .optModelPath(modelDir) + .optArgument("blockFactory", "ai.djl.nn.OnesBlockFactory") + .optArgument("block_shapes", "(1,7,384)") + .optArgument("block_names", "last_hidden_state") + .optEngine("PyTorch") + .optArgument("tokenizer", "bert-base-uncased") + .optArgument("pooling", "lasttoken") + .optArgument("padding", "max_length") + .optArgument("maxLength", 10) + .optOption("hasParameter", "false") + .optTranslatorFactory(new TextEmbeddingTranslatorFactory()) + .build(); + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + float[] res = predictor.predict(text); + Assert.assertEquals(res.length, 384); + Assertions.assertAlmostEquals(res[0], 0.05103104); + } + // dense and layerNorm criteria = Criteria.builder()