Skip to content

Commit c54658b

Browse files
feat: Elastic Text-Embedding demo. (#9271)
* feat: Elastic Text-Embedding demo. * feat: Elastic Text-Embedding demo. * feat: Elastic Text-Embedding demo. * feat: Elastic Text-Embedding demo. * feat: Elastic Text-Embedding demo.
1 parent 4613891 commit c54658b

File tree

3 files changed

+169
-49
lines changed

3 files changed

+169
-49
lines changed

aiplatform/src/main/java/aiplatform/PredictTextEmbeddingsSample.java

Lines changed: 45 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -17,64 +17,72 @@
1717
package aiplatform;
1818

1919
// [START aiplatform_sdk_embedding]
20-
21-
import com.google.cloud.aiplatform.util.ValueConverter;
2220
import com.google.cloud.aiplatform.v1beta1.EndpointName;
21+
import com.google.cloud.aiplatform.v1beta1.PredictRequest;
2322
import com.google.cloud.aiplatform.v1beta1.PredictResponse;
2423
import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
2524
import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
25+
import com.google.protobuf.Struct;
2626
import com.google.protobuf.Value;
27-
import com.google.protobuf.util.JsonFormat;
2827
import java.io.IOException;
29-
import java.util.ArrayList;
3028
import java.util.List;
29+
import java.util.regex.Matcher;
30+
import java.util.regex.Pattern;
3131

3232
public class PredictTextEmbeddingsSample {
33-
3433
public static void main(String[] args) throws IOException {
3534
// TODO(developer): Replace these variables before running the sample.
3635
// Details about text embedding request structure and supported models are available in:
3736
// https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings
38-
String instance = "{ \"content\": \"What is life?\"}";
37+
String endpoint = "us-central1-aiplatform.googleapis.com:443";
3938
String project = "YOUR_PROJECT_ID";
40-
String location = "us-central1";
41-
String publisher = "google";
42-
String model = "textembedding-gecko@001";
43-
44-
predictTextEmbeddings(instance, project, location, publisher, model);
39+
String model = "textembedding-gecko@003";
40+
predictTextEmbeddings(
41+
endpoint,
42+
project,
43+
model,
44+
List.of("banana bread?", "banana muffins?"),
45+
"RETRIEVAL_DOCUMENT");
4546
}
4647

47-
// Get text embeddings from a supported embedding model
48+
// Gets text embeddings from a pretrained, foundational model.
4849
public static void predictTextEmbeddings(
49-
String instance, String project, String location, String publisher, String model)
50+
String endpoint,
51+
String project,
52+
String model,
53+
List<String> texts,
54+
String task)
5055
throws IOException {
51-
String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);
52-
PredictionServiceSettings predictionServiceSettings =
53-
PredictionServiceSettings.newBuilder()
54-
.setEndpoint(endpoint)
55-
.build();
56-
57-
// Initialize client that will be used to send requests. This client only needs to be created
58-
// once, and can be reused for multiple requests.
59-
try (PredictionServiceClient predictionServiceClient =
60-
PredictionServiceClient.create(predictionServiceSettings)) {
61-
EndpointName endpointName =
62-
EndpointName.ofProjectLocationPublisherModelName(project, location, publisher, model);
56+
PredictionServiceSettings settings =
57+
PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build();
58+
Matcher matcher = Pattern.compile("^(?<Location>\\w+-\\w+)").matcher(endpoint);
59+
String location = matcher.matches() ? matcher.group("Location") : "us-central1";
60+
EndpointName endpointName =
61+
EndpointName.ofProjectLocationPublisherModelName(project, location, "google", model);
6362

64-
// Use Value.Builder to convert instance to a dynamically typed value that can be
65-
// processed by the service.
66-
Value.Builder instanceValue = Value.newBuilder();
67-
JsonFormat.parser().merge(instance, instanceValue);
68-
List<Value> instances = new ArrayList<>();
69-
instances.add(instanceValue.build());
70-
71-
PredictResponse predictResponse =
72-
predictionServiceClient.predict(endpointName, instances, ValueConverter.EMPTY_VALUE);
73-
System.out.println("Predict Response");
74-
for (Value prediction : predictResponse.getPredictionsList()) {
75-
System.out.format("\tPrediction: %s\n", prediction);
63+
// You can use this prediction service client for multiple requests.
64+
try (PredictionServiceClient client = PredictionServiceClient.create(settings)) {
65+
PredictRequest.Builder request =
66+
PredictRequest.newBuilder().setEndpoint(endpointName.toString());
67+
for (int i = 0; i < texts.size(); i++) {
68+
request.addInstances(
69+
Value.newBuilder()
70+
.setStructValue(
71+
Struct.newBuilder()
72+
.putFields("content", valueOf(texts.get(i)))
73+
.putFields("taskType", valueOf(task))
74+
.build()));
75+
}
76+
PredictResponse response = client.predict(request.build());
77+
System.out.println("Got predict response:\n");
78+
for (Value prediction : response.getPredictionsList()) {
79+
System.out.format("Got prediction: %s\n", prediction);
7680
}
7781
}
7882
}
83+
84+
private static Value valueOf(String s) {
85+
return Value.newBuilder().setStringValue(s).build();
86+
}
7987
}
8088
// [END aiplatform_sdk_embedding]
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform;
18+
19+
// [START generativeaionvertexai_sdk_embedding]
20+
import com.google.cloud.aiplatform.v1beta1.EndpointName;
21+
import com.google.cloud.aiplatform.v1beta1.PredictRequest;
22+
import com.google.cloud.aiplatform.v1beta1.PredictResponse;
23+
import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
24+
import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
25+
import com.google.protobuf.Struct;
26+
import com.google.protobuf.Value;
27+
import java.io.IOException;
28+
import java.util.List;
29+
import java.util.OptionalInt;
30+
import java.util.regex.Matcher;
31+
import java.util.regex.Pattern;
32+
33+
public class PredictTextEmbeddingsSamplePreview {
34+
public static void main(String[] args) throws IOException {
35+
// TODO(developer): Replace these variables before running the sample.
36+
// Details about text embedding request structure and supported models are
37+
// available in:
38+
// https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings
39+
String endpoint = "us-central1-aiplatform.googleapis.com";
40+
String project = "YOUR_PROJECT_ID";
41+
String model = "text-embedding-preview-0409";
42+
predictTextEmbeddings(
43+
endpoint,
44+
project,
45+
model,
46+
List.of("banana bread?", "banana muffins?"),
47+
"QUESTION_ANSWERING",
48+
OptionalInt.of(256));
49+
}
50+
51+
// Gets text embeddings from a pretrained, foundational model.
52+
public static void predictTextEmbeddings(
53+
String endpoint,
54+
String project,
55+
String model,
56+
List<String> texts,
57+
String task,
58+
OptionalInt outputDimensionality)
59+
throws IOException {
60+
PredictionServiceSettings settings =
61+
PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build();
62+
Matcher matcher = Pattern.compile("^(?<Location>\\w+-\\w+)").matcher(endpoint);
63+
String location = matcher.matches() ? matcher.group("Location") : "us-central1";
64+
EndpointName endpointName =
65+
EndpointName.ofProjectLocationPublisherModelName(project, location, "google", model);
66+
67+
// You can use this prediction service client for multiple requests.
68+
try (PredictionServiceClient client = PredictionServiceClient.create(settings)) {
69+
PredictRequest.Builder request =
70+
PredictRequest.newBuilder().setEndpoint(endpointName.toString());
71+
if (outputDimensionality.isPresent()) {
72+
request.setParameters(
73+
Value.newBuilder()
74+
.setStructValue(
75+
Struct.newBuilder()
76+
.putFields("outputDimensionality", valueOf(outputDimensionality.getAsInt()))
77+
.build()));
78+
}
79+
for (int i = 0; i < texts.size(); i++) {
80+
request.addInstances(
81+
Value.newBuilder()
82+
.setStructValue(
83+
Struct.newBuilder()
84+
.putFields("content", valueOf(texts.get(i)))
85+
.putFields("taskType", valueOf(task))
86+
.build()));
87+
}
88+
PredictResponse response = client.predict(request.build());
89+
System.out.println("Got predict response:\n");
90+
for (Value prediction : response.getPredictionsList()) {
91+
System.out.format("Got prediction: %s\n", prediction);
92+
}
93+
}
94+
}
95+
96+
private static Value valueOf(String s) {
97+
return Value.newBuilder().setStringValue(s).build();
98+
}
99+
100+
private static Value valueOf(int n) {
101+
return Value.newBuilder().setNumberValue(n).build();
102+
}
103+
}
104+
// [END generativeaionvertexai_sdk_embedding]

aiplatform/src/test/java/aiplatform/PredictTextEmbeddingsSampleTest.java

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,18 @@
2323
import java.io.ByteArrayOutputStream;
2424
import java.io.IOException;
2525
import java.io.PrintStream;
26+
import java.util.List;
27+
import java.util.OptionalInt;
2628
import org.junit.After;
2729
import org.junit.Before;
2830
import org.junit.BeforeClass;
2931
import org.junit.Rule;
3032
import org.junit.Test;
3133

3234
public class PredictTextEmbeddingsSampleTest {
33-
3435
@Rule public final MultipleAttemptsRule multipleAttemptsRule = new MultipleAttemptsRule(3);
35-
36+
private static final String APIS_ENDPOINT = "us-central1-aiplatform.googleapis.com:443";
3637
private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
37-
private static final String LOCATION = "us-central1";
38-
private static final String INSTANCE = "{ \"content\": \"What is life?\"}";
39-
private static final String PUBLISHER = "google";
40-
private static final String MODEL = "textembedding-gecko@001";
41-
4238
private ByteArrayOutputStream bout;
4339
private PrintStream out;
4440
private PrintStream originalPrintStream;
@@ -71,12 +67,24 @@ public void tearDown() {
7167

7268
@Test
7369
public void testPredictTextEmbeddings() throws IOException {
74-
// Act
70+
List<String> texts =
71+
List.of("banana bread?", "banana muffin?", "banana?", "recipe?", "muffin recipe?");
7572
PredictTextEmbeddingsSample.predictTextEmbeddings(
76-
INSTANCE, PROJECT, LOCATION, PUBLISHER, MODEL);
73+
APIS_ENDPOINT, PROJECT, "textembedding-gecko@003", texts, "RETRIEVAL_DOCUMENT");
74+
assertThat(bout.toString()).contains("Got predict response");
75+
}
7776

78-
// Assert
79-
String got = bout.toString();
80-
assertThat(got).contains("Predict Response");
77+
@Test
78+
public void testPredictTextEmbeddingsPreview() throws IOException {
79+
List<String> texts =
80+
List.of("banana bread?", "banana muffin?", "banana?", "recipe?", "muffin recipe?");
81+
PredictTextEmbeddingsSamplePreview.predictTextEmbeddings(
82+
APIS_ENDPOINT,
83+
PROJECT,
84+
"text-embedding-preview-0409",
85+
texts,
86+
"QUESTION_ANSWERING",
87+
OptionalInt.of(256));
88+
assertThat(bout.toString()).contains("Got predict response");
8189
}
8290
}

0 commit comments

Comments
 (0)