|
17 | 17 | package aiplatform;
|
18 | 18 |
|
19 | 19 | // [START aiplatform_sdk_embedding]
|
20 |
| - |
21 |
| -import com.google.cloud.aiplatform.util.ValueConverter; |
22 | 20 | import com.google.cloud.aiplatform.v1beta1.EndpointName;
|
| 21 | +import com.google.cloud.aiplatform.v1beta1.PredictRequest; |
23 | 22 | import com.google.cloud.aiplatform.v1beta1.PredictResponse;
|
24 | 23 | import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
|
25 | 24 | import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
|
| 25 | +import com.google.protobuf.Struct; |
26 | 26 | import com.google.protobuf.Value;
|
27 |
| -import com.google.protobuf.util.JsonFormat; |
28 | 27 | import java.io.IOException;
|
29 |
| -import java.util.ArrayList; |
30 | 28 | import java.util.List;
|
| 29 | +import java.util.regex.Matcher; |
| 30 | +import java.util.regex.Pattern; |
31 | 31 |
|
32 | 32 | public class PredictTextEmbeddingsSample {
|
33 |
| - |
34 | 33 | public static void main(String[] args) throws IOException {
|
35 | 34 | // TODO(developer): Replace these variables before running the sample.
|
36 | 35 | // Details about text embedding request structure and supported models are available in:
|
37 | 36 | // https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings
|
38 |
| - String instance = "{ \"content\": \"What is life?\"}"; |
| 37 | + String endpoint = "us-central1-aiplatform.googleapis.com:443"; |
39 | 38 | String project = "YOUR_PROJECT_ID";
|
40 |
| - String location = "us-central1"; |
41 |
| - String publisher = "google"; |
42 |
| - String model = "textembedding-gecko@001"; |
43 |
| - |
44 |
| - predictTextEmbeddings(instance, project, location, publisher, model); |
| 39 | + String model = "textembedding-gecko@003"; |
| 40 | + predictTextEmbeddings( |
| 41 | + endpoint, |
| 42 | + project, |
| 43 | + model, |
| 44 | + List.of("banana bread?", "banana muffins?"), |
| 45 | + "RETRIEVAL_DOCUMENT"); |
45 | 46 | }
|
46 | 47 |
|
47 |
| - // Get text embeddings from a supported embedding model |
| 48 | + // Gets text embeddings from a pretrained, foundational model. |
48 | 49 | public static void predictTextEmbeddings(
|
49 |
| - String instance, String project, String location, String publisher, String model) |
| 50 | + String endpoint, |
| 51 | + String project, |
| 52 | + String model, |
| 53 | + List<String> texts, |
| 54 | + String task) |
50 | 55 | throws IOException {
|
51 |
| - String endpoint = String.format("%s-aiplatform.googleapis.com:443", location); |
52 |
| - PredictionServiceSettings predictionServiceSettings = |
53 |
| - PredictionServiceSettings.newBuilder() |
54 |
| - .setEndpoint(endpoint) |
55 |
| - .build(); |
56 |
| - |
57 |
| - // Initialize client that will be used to send requests. This client only needs to be created |
58 |
| - // once, and can be reused for multiple requests. |
59 |
| - try (PredictionServiceClient predictionServiceClient = |
60 |
| - PredictionServiceClient.create(predictionServiceSettings)) { |
61 |
| - EndpointName endpointName = |
62 |
| - EndpointName.ofProjectLocationPublisherModelName(project, location, publisher, model); |
| 56 | + PredictionServiceSettings settings = |
| 57 | + PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build(); |
| 58 | + Matcher matcher = Pattern.compile("^(?<Location>\\w+-\\w+)").matcher(endpoint); |
| 59 | + String location = matcher.matches() ? matcher.group("Location") : "us-central1"; |
| 60 | + EndpointName endpointName = |
| 61 | + EndpointName.ofProjectLocationPublisherModelName(project, location, "google", model); |
63 | 62 |
|
64 |
| - // Use Value.Builder to convert instance to a dynamically typed value that can be |
65 |
| - // processed by the service. |
66 |
| - Value.Builder instanceValue = Value.newBuilder(); |
67 |
| - JsonFormat.parser().merge(instance, instanceValue); |
68 |
| - List<Value> instances = new ArrayList<>(); |
69 |
| - instances.add(instanceValue.build()); |
70 |
| - |
71 |
| - PredictResponse predictResponse = |
72 |
| - predictionServiceClient.predict(endpointName, instances, ValueConverter.EMPTY_VALUE); |
73 |
| - System.out.println("Predict Response"); |
74 |
| - for (Value prediction : predictResponse.getPredictionsList()) { |
75 |
| - System.out.format("\tPrediction: %s\n", prediction); |
| 63 | + // You can use this prediction service client for multiple requests. |
| 64 | + try (PredictionServiceClient client = PredictionServiceClient.create(settings)) { |
| 65 | + PredictRequest.Builder request = |
| 66 | + PredictRequest.newBuilder().setEndpoint(endpointName.toString()); |
| 67 | + for (int i = 0; i < texts.size(); i++) { |
| 68 | + request.addInstances( |
| 69 | + Value.newBuilder() |
| 70 | + .setStructValue( |
| 71 | + Struct.newBuilder() |
| 72 | + .putFields("content", valueOf(texts.get(i))) |
| 73 | + .putFields("taskType", valueOf(task)) |
| 74 | + .build())); |
| 75 | + } |
| 76 | + PredictResponse response = client.predict(request.build()); |
| 77 | + System.out.println("Got predict response:\n"); |
| 78 | + for (Value prediction : response.getPredictionsList()) { |
| 79 | + System.out.format("Got prediction: %s\n", prediction); |
76 | 80 | }
|
77 | 81 | }
|
78 | 82 | }
|
| 83 | + |
| 84 | + private static Value valueOf(String s) { |
| 85 | + return Value.newBuilder().setStringValue(s).build(); |
| 86 | + } |
79 | 87 | }
|
80 | 88 | // [END aiplatform_sdk_embedding]
|
0 commit comments