Skip to content

Commit cfabfe0

Browse files
committed
support responseSchema in VertexAiGeminiChatOptions
Closes #2087 Signed-off-by: Andrei Sumin <[email protected]>
1 parent 9d1e1b5 commit cfabfe0

File tree

6 files changed

+232
-20
lines changed

6 files changed

+232
-20
lines changed

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
import com.google.cloud.vertexai.api.GenerationConfig;
3636
import com.google.cloud.vertexai.api.Part;
3737
import com.google.cloud.vertexai.api.SafetySetting;
38-
import com.google.cloud.vertexai.api.Schema;
3938
import com.google.cloud.vertexai.api.Tool;
4039
import com.google.cloud.vertexai.api.Tool.GoogleSearch;
4140
import com.google.cloud.vertexai.generativeai.GenerativeModel;
@@ -88,6 +87,7 @@
8887
import org.springframework.ai.vertexai.gemini.api.VertexAiGeminiApi;
8988
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiConstants;
9089
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
90+
import org.springframework.ai.vertexai.gemini.schema.VertexAiSchemaConverter;
9191
import org.springframework.ai.vertexai.gemini.schema.VertexToolCallingManager;
9292
import org.springframework.beans.factory.DisposableBean;
9393
import org.springframework.lang.NonNull;
@@ -376,17 +376,6 @@ else if (rootNode.isArray()) {
376376
}
377377
}
378378

379-
// Removed by this commit (every line carries a "-" gutter marker); the tool-declaration
// call site now uses VertexAiSchemaConverter.fromOpenApiSchema instead.
// Parses a JSON schema string into a Vertex AI Schema; any failure is rethrown
// wrapped in a RuntimeException.
private static Schema jsonToSchema(String json) {
380-
try {
381-
var schemaBuilder = Schema.newBuilder();
382-
// JSON keys that have no matching Schema field are ignored rather than rejected.
JsonFormat.parser().ignoringUnknownFields().merge(json, schemaBuilder);
383-
return schemaBuilder.build();
384-
}
385-
catch (Exception e) {
386-
throw new RuntimeException(e);
387-
}
388-
}
389-
390379
// https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini
391380
@Override
392381
public ChatResponse call(Prompt prompt) {
@@ -697,7 +686,7 @@ GeminiRequest createGeminiRequest(Prompt prompt) {
697686
.map(toolDefinition -> FunctionDeclaration.newBuilder()
698687
.setName(toolDefinition.name())
699688
.setDescription(toolDefinition.description())
700-
.setParameters(jsonToSchema(toolDefinition.inputSchema()))
689+
.setParameters(VertexAiSchemaConverter.fromOpenApiSchema(toolDefinition.inputSchema()))
701690
.build())
702691
.toList();
703692
tools.add(Tool.newBuilder().addAllFunctionDeclarations(functionDeclarations).build());
@@ -759,6 +748,10 @@ private GenerationConfig toGenerationConfig(VertexAiGeminiChatOptions options) {
759748
if (options.getResponseMimeType() != null) {
760749
generationConfigBuilder.setResponseMimeType(options.getResponseMimeType());
761750
}
751+
if (options.getResponseSchema() != null) {
752+
generationConfigBuilder
753+
.setResponseSchema(VertexAiSchemaConverter.fromOpenApiSchema(options.getResponseSchema()));
754+
}
762755
if (options.getFrequencyPenalty() != null) {
763756
generationConfigBuilder.setFrequencyPenalty(options.getFrequencyPenalty().floatValue());
764757
}

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,11 @@ public class VertexAiGeminiChatOptions implements ToolCallingChatOptions {
110110
*/
111111
private @JsonProperty("responseMimeType") String responseMimeType;
112112

113+
/**
114+
* Optional. Output response schema, given as an OpenAPI schema in JSON format.
115+
*/
116+
private @JsonProperty("responseSchema") String responseSchema;
117+
113118
/**
114119
* Optional. Frequency penalties.
115120
*/
@@ -170,8 +175,8 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr
170175
options.setModel(fromOptions.getModel());
171176
options.setToolCallbacks(fromOptions.getToolCallbacks());
172177
options.setResponseMimeType(fromOptions.getResponseMimeType());
178+
options.setResponseSchema(fromOptions.getResponseSchema());
173179
options.setToolNames(fromOptions.getToolNames());
174-
options.setResponseMimeType(fromOptions.getResponseMimeType());
175180
options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval());
176181
options.setSafetySettings(fromOptions.getSafetySettings());
177182
options.setInternalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled());
@@ -265,6 +270,14 @@ public void setResponseMimeType(String mimeType) {
265270
this.responseMimeType = mimeType;
266271
}
267272

273+
/**
 * Returns the response schema string held by these options. May be {@code null};
 * VertexAiGeminiChatModel only applies the schema to the generation config when
 * this value is non-null.
 * @return the response schema string
 */
public String getResponseSchema() {
274+
return this.responseSchema;
275+
}
276+
277+
/**
 * Sets the response schema string. The value is stored as given; no parsing or
 * validation happens here.
 * @param responseSchema the response schema string to store
 */
public void setResponseSchema(String responseSchema) {
278+
this.responseSchema = responseSchema;
279+
}
280+
268281
@Override
269282
public List<ToolCallback> getToolCallbacks() {
270283
return this.toolCallbacks;
@@ -374,6 +387,7 @@ public boolean equals(Object o) {
374387
&& Objects.equals(this.presencePenalty, that.presencePenalty)
375388
&& Objects.equals(this.maxOutputTokens, that.maxOutputTokens) && Objects.equals(this.model, that.model)
376389
&& Objects.equals(this.responseMimeType, that.responseMimeType)
390+
&& Objects.equals(this.responseSchema, that.responseSchema)
377391
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
378392
&& Objects.equals(this.toolNames, that.toolNames)
379393
&& Objects.equals(this.safetySettings, that.safetySettings)
@@ -386,8 +400,9 @@ public boolean equals(Object o) {
386400
public int hashCode() {
387401
return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount,
388402
this.frequencyPenalty, this.presencePenalty, this.maxOutputTokens, this.model, this.responseMimeType,
389-
this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, this.safetySettings,
390-
this.internalToolExecutionEnabled, this.toolContext, this.logprobs, this.responseLogprobs);
403+
this.responseSchema, this.toolCallbacks, this.toolNames, this.googleSearchRetrieval,
404+
this.safetySettings, this.internalToolExecutionEnabled, this.toolContext, this.logprobs,
405+
this.responseLogprobs);
391406
}
392407

393408
@Override
@@ -396,10 +411,10 @@ public String toString() {
396411
+ this.temperature + ", topP=" + this.topP + ", topK=" + this.topK + ", frequencyPenalty="
397412
+ this.frequencyPenalty + ", presencePenalty=" + this.presencePenalty + ", candidateCount="
398413
+ this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\''
399-
+ ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks
400-
+ ", toolNames=" + this.toolNames + ", googleSearchRetrieval=" + this.googleSearchRetrieval
401-
+ ", safetySettings=" + this.safetySettings + ", logProbs=" + this.logprobs + ", responseLogprobs="
402-
+ this.responseLogprobs + '}';
414+
+ ", responseMimeType='" + this.responseMimeType + '\'' + ", responseSchema='" + this.responseSchema
415+
+ ", toolCallbacks=" + this.toolCallbacks + ", toolNames=" + this.toolNames + ", googleSearchRetrieval="
416+
+ this.googleSearchRetrieval + ", safetySettings=" + this.safetySettings + ", logProbs=" + this.logprobs
417+
+ ", responseLogprobs=" + this.responseLogprobs + '}';
403418
}
404419

405420
@Override
@@ -473,6 +488,11 @@ public Builder responseMimeType(String mimeType) {
473488
return this;
474489
}
475490

491+
/**
 * Sets the response schema on the options being built by delegating to
 * {@code setResponseSchema}.
 * @param responseSchema the response schema string
 * @return this builder, to allow call chaining
 */
public Builder responseSchema(String responseSchema) {
492+
this.options.setResponseSchema(responseSchema);
493+
return this;
494+
}
495+
476496
public Builder toolCallbacks(List<ToolCallback> toolCallbacks) {
477497
this.options.toolCallbacks = toolCallbacks;
478498
return this;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.vertexai.gemini.schema;
18+
19+
import com.google.cloud.vertexai.api.Schema;
20+
import com.google.protobuf.util.JsonFormat;
21+
22+
/**
23+
* Utility class for converting OpenAPI schemas, supplied as JSON strings, to Vertex AI
* {@link Schema} objects. It exposes a single static method and cannot be instantiated.
24+
*
25+
* @since 1.1.0
26+
*/
27+
public final class VertexAiSchemaConverter {
28+
29+
private VertexAiSchemaConverter() {
30+
// Prevent instantiation: this class only hosts a static conversion method.
31+
}
32+
33+
/**
34+
* Converts an OpenAPI schema string to a Vertex AI Schema object. The JSON is merged
* into a fresh {@code Schema} builder; JSON fields unknown to {@code Schema} are ignored.
35+
* @param openApiSchema The OpenAPI schema in JSON format
36+
* @return A Schema object representing the OpenAPI schema
37+
* @throws RuntimeException wrapping the original exception if the schema cannot be parsed
38+
*/
39+
public static Schema fromOpenApiSchema(String openApiSchema) {
40+
try {
41+
var schemaBuilder = Schema.newBuilder();
42+
// ignoringUnknownFields(): keys with no matching Schema field are skipped, not rejected.
JsonFormat.parser().ignoringUnknownFields().merge(openApiSchema, schemaBuilder);
43+
return schemaBuilder.build();
44+
}
45+
catch (Exception e) {
46+
// Every failure is rethrown unchecked with the original exception kept as the cause.
throw new RuntimeException(e);
47+
}
48+
}
49+
50+
}

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/CreateGeminiRequestTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import com.google.cloud.vertexai.VertexAI;
2424
import com.google.cloud.vertexai.api.Content;
2525
import com.google.cloud.vertexai.api.Part;
26+
import com.google.cloud.vertexai.api.Schema;
27+
import com.google.cloud.vertexai.api.Type;
2628
import org.junit.jupiter.api.Test;
2729
import org.junit.jupiter.api.extension.ExtendWith;
2830
import org.mockito.Mock;
@@ -264,6 +266,9 @@ public void createRequestWithGenerationConfigOptions() {
264266
.responseMimeType("application/json")
265267
.responseLogprobs(true)
266268
.logprobs(2)
269+
.responseSchema("""
270+
{"type": "OBJECT"}
271+
""")
267272
.build())
268273
.build();
269274

@@ -284,6 +289,8 @@ public void createRequestWithGenerationConfigOptions() {
284289
assertThat(request.model().getGenerationConfig().getResponseMimeType()).isEqualTo("application/json");
285290
assertThat(request.model().getGenerationConfig().getLogprobs()).isEqualTo(2);
286291
assertThat(request.model().getGenerationConfig().getResponseLogprobs()).isEqualTo(true);
292+
assertThat(request.model().getGenerationConfig().getResponseSchema())
293+
.isEqualTo(Schema.newBuilder().setType(Type.OBJECT).build());
287294
}
288295

289296
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.vertexai.gemini.schema;
18+
19+
import java.util.List;
20+
21+
import com.google.cloud.vertexai.api.Schema;
22+
import com.google.cloud.vertexai.api.Type;
23+
import org.junit.jupiter.api.Test;
24+
25+
import static org.junit.jupiter.api.Assertions.assertEquals;
26+
import static org.junit.jupiter.api.Assertions.assertTrue;
27+
28+
/**
 * Unit tests for {@code VertexAiSchemaConverter.fromOpenApiSchema}. Each test feeds a
 * small JSON schema to the converter and asserts that the resulting {@code Schema}
 * exposes the same values through its getters. The JSON inputs spell type names the
 * same way as the {@code Type} enum constants they are compared against (for example
 * {@code "OBJECT"} and {@code Type.OBJECT}).
 */
public class VertexAiSchemaConverterTests {
29+
30+
// Generic fields: type, format, title, description, nullable, example and default.
@Test
31+
public void fromOpenApiSchemaShouldConvertGenericFields() {
32+
String openApiSchema = """
33+
{
34+
"type": "OBJECT",
35+
"format": "date-time",
36+
"title": "Title",
37+
"description": "Description",
38+
"nullable": true,
39+
"example": "Example",
40+
"default": "0"
41+
}""";
42+
43+
Schema schema = VertexAiSchemaConverter.fromOpenApiSchema(openApiSchema);
44+
45+
assertEquals(Type.OBJECT, schema.getType());
46+
assertEquals("date-time", schema.getFormat());
47+
assertEquals("Title", schema.getTitle());
48+
assertEquals("Description", schema.getDescription());
49+
assertTrue(schema.getNullable());
50+
assertEquals("Example", schema.getExample().getStringValue());
51+
assertEquals("0", schema.getDefault().getStringValue());
52+
}
53+
54+
// String-specific fields: enum, minLength, maxLength and pattern.
@Test
55+
public void fromOpenApiSchemaShouldConvertStringFields() {
56+
String openApiSchema = """
57+
{
58+
"type": "STRING",
59+
"enum": ["a", "b", "c"],
60+
"minLength": 1,
61+
"maxLength": 10,
62+
"pattern": "[0-9.]+"
63+
}""";
64+
65+
Schema schema = VertexAiSchemaConverter.fromOpenApiSchema(openApiSchema);
66+
67+
assertEquals(Type.STRING, schema.getType());
68+
assertEquals(List.of("a", "b", "c"), schema.getEnumList());
69+
assertEquals(1, schema.getMinLength());
70+
assertEquals(10, schema.getMaxLength());
71+
assertEquals("[0-9.]+", schema.getPattern());
72+
}
73+
74+
// Numeric fields: anyOf alternatives plus minimum/maximum. The input has no top-level
// "type", and the test expects the converted type to be Type.TYPE_UNSPECIFIED.
@Test
75+
public void fromOpenApiSchemaShouldConvertIntegerAndNumberFields() {
76+
String openApiSchema = """
77+
{
78+
"anyOf": [{"type": "INTEGER"}, {"type": "NUMBER"}],
79+
"minimum": 0,
80+
"maximum": 100
81+
}""";
82+
83+
Schema schema = VertexAiSchemaConverter.fromOpenApiSchema(openApiSchema);
84+
85+
assertEquals(Type.TYPE_UNSPECIFIED, schema.getType());
86+
assertEquals(Type.INTEGER, schema.getAnyOf(0).getType());
87+
assertEquals(Type.NUMBER, schema.getAnyOf(1).getType());
88+
assertEquals(0, schema.getMinimum());
89+
assertEquals(100, schema.getMaximum());
90+
}
91+
92+
// Array fields: the nested items schema plus minItems and maxItems.
@Test
93+
public void fromOpenApiSchemaShouldConvertArrayFields() {
94+
String openApiSchema = """
95+
{
96+
"type": "ARRAY",
97+
"items": {
98+
"type": "BOOLEAN"
99+
},
100+
"minItems": 1,
101+
"maxItems": 5
102+
}""";
103+
104+
Schema schema = VertexAiSchemaConverter.fromOpenApiSchema(openApiSchema);
105+
106+
assertEquals(Type.ARRAY, schema.getType());
107+
assertEquals(Type.BOOLEAN, schema.getItems().getType());
108+
assertEquals(1, schema.getMinItems());
109+
assertEquals(5, schema.getMaxItems());
110+
}
111+
112+
// Object fields: properties, minProperties, maxProperties, required and propertyOrdering.
@Test
113+
public void fromOpenApiSchemaShouldConvertObjectFields() {
114+
String openApiSchema = """
115+
{
116+
"type": "OBJECT",
117+
"properties": {
118+
"property1": {
119+
"type": "STRING"
120+
},
121+
"property2": {
122+
"type": "INTEGER"
123+
}
124+
},
125+
"minProperties": 1,
126+
"maxProperties": 2,
127+
"required": ["property1"],
128+
"propertyOrdering": ["property1", "property2"]
129+
}""";
130+
131+
Schema schema = VertexAiSchemaConverter.fromOpenApiSchema(openApiSchema);
132+
133+
assertEquals(Type.OBJECT, schema.getType());
134+
assertEquals(2, schema.getPropertiesMap().size());
135+
assertEquals(1, schema.getMinProperties());
136+
assertEquals(2, schema.getMaxProperties());
137+
assertEquals(List.of("property1"), schema.getRequiredList());
138+
assertEquals(List.of("property1", "property2"), schema.getPropertyOrderingList());
139+
}
140+
141+
}

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/vertexai-gemini-chat.adoc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ The prefix `spring.ai.vertex.ai.gemini.chat` is the property prefix that lets yo
9292

9393
| spring.ai.vertex.ai.gemini.chat.options.model | Supported https://cloud.google.com/vertex-ai/generative-ai/docs/models#gemini-models[Vertex AI Gemini Chat model] to use include the `gemini-2.0-flash`, `gemini-2.0-flash-lite` and the new `gemini-2.5-pro-preview-03-25`, `gemini-2.5-flash-preview-04-17` models. | gemini-2.0-flash
9494
| spring.ai.vertex.ai.gemini.chat.options.response-mime-type | Output response mimetype of the generated candidate text. | `text/plain`: (default) Text output or `application/json`: JSON response.
95+
| spring.ai.vertex.ai.gemini.chat.options.response-schema | JSON string containing the output response schema in OpenAPI format, as described in https://ai.google.dev/gemini-api/docs/structured-output#json-schemas[the Gemini structured output documentation]. | -
9596
| spring.ai.vertex.ai.gemini.chat.options.google-search-retrieval | Use Google search Grounding feature | `true` or `false`, default `false`.
9697
| spring.ai.vertex.ai.gemini.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that are more varied, while a value closer to 0.0 will typically result in less surprising responses from the generative. This value specifies default to be used by the backend while making the call to the generative. | 0.7
9798
| spring.ai.vertex.ai.gemini.chat.options.top-k | The maximum number of tokens to consider when sampling. The generative uses combined Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens. | -

0 commit comments

Comments
 (0)