Skip to content

Support responseSchema in VertexAiGeminiChatOptions #3765

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.Part;
import com.google.cloud.vertexai.api.SafetySetting;
import com.google.cloud.vertexai.api.Schema;
import com.google.cloud.vertexai.api.Tool;
import com.google.cloud.vertexai.api.Tool.GoogleSearch;
import com.google.cloud.vertexai.generativeai.GenerativeModel;
Expand Down Expand Up @@ -88,6 +87,7 @@
import org.springframework.ai.vertexai.gemini.api.VertexAiGeminiApi;
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiConstants;
import org.springframework.ai.vertexai.gemini.common.VertexAiGeminiSafetySetting;
import org.springframework.ai.vertexai.gemini.schema.VertexAiSchemaConverter;
import org.springframework.ai.vertexai.gemini.schema.VertexToolCallingManager;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.lang.NonNull;
Expand Down Expand Up @@ -376,17 +376,6 @@ else if (rootNode.isArray()) {
}
}

/**
 * Builds a Vertex AI {@link Schema} from its JSON representation.
 * <p>
 * The parser is configured to ignore unknown fields, so JSON properties that the
 * {@code Schema} message does not declare are skipped instead of failing the
 * conversion. Any exception raised while building, parsing or finalising the schema
 * is rethrown wrapped in a {@link RuntimeException}.
 * @param json the schema as a JSON string
 * @return the schema built from {@code json}
 */
private static Schema jsonToSchema(String json) {
	try {
		Schema.Builder target = Schema.newBuilder();
		JsonFormat.Parser lenientParser = JsonFormat.parser().ignoringUnknownFields();
		lenientParser.merge(json, target);
		return target.build();
	}
	catch (Exception cause) {
		throw new RuntimeException(cause);
	}
}

// https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini
@Override
public ChatResponse call(Prompt prompt) {
Expand Down Expand Up @@ -697,7 +686,7 @@ GeminiRequest createGeminiRequest(Prompt prompt) {
.map(toolDefinition -> FunctionDeclaration.newBuilder()
.setName(toolDefinition.name())
.setDescription(toolDefinition.description())
.setParameters(jsonToSchema(toolDefinition.inputSchema()))
.setParameters(VertexAiSchemaConverter.fromOpenApiSchema(toolDefinition.inputSchema()))
.build())
.toList();
tools.add(Tool.newBuilder().addAllFunctionDeclarations(functionDeclarations).build());
Expand Down Expand Up @@ -759,6 +748,10 @@ private GenerationConfig toGenerationConfig(VertexAiGeminiChatOptions options) {
if (options.getResponseMimeType() != null) {
generationConfigBuilder.setResponseMimeType(options.getResponseMimeType());
}
if (options.getResponseSchema() != null) {
generationConfigBuilder
.setResponseSchema(VertexAiSchemaConverter.fromOpenApiSchema(options.getResponseSchema()));
}
if (options.getFrequencyPenalty() != null) {
generationConfigBuilder.setFrequencyPenalty(options.getFrequencyPenalty().floatValue());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ public class VertexAiGeminiChatOptions implements ToolCallingChatOptions {
*/
private @JsonProperty("responseMimeType") String responseMimeType;

/**
* Optional. OpenAPI response schema.
*/
private @JsonProperty("responseSchema") String responseSchema;

/**
* Optional. Frequency penalties.
*/
Expand Down Expand Up @@ -170,8 +175,8 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr
options.setModel(fromOptions.getModel());
options.setToolCallbacks(fromOptions.getToolCallbacks());
options.setResponseMimeType(fromOptions.getResponseMimeType());
options.setResponseSchema(fromOptions.getResponseSchema());
options.setToolNames(fromOptions.getToolNames());
options.setResponseMimeType(fromOptions.getResponseMimeType());
options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval());
options.setSafetySettings(fromOptions.getSafetySettings());
options.setInternalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled());
Expand Down Expand Up @@ -265,6 +270,14 @@ public void setResponseMimeType(String mimeType) {
this.responseMimeType = mimeType;
}

/**
 * Returns the OpenAPI response schema configured on these options.
 * @return the response schema string, or {@code null} when none has been set
 */
public String getResponseSchema() {
return this.responseSchema;
}

/**
 * Sets the OpenAPI response schema for these options.
 * @param responseSchema the response schema string; stored as given, without
 * validation
 */
public void setResponseSchema(String responseSchema) {
this.responseSchema = responseSchema;
}

@Override
public List<ToolCallback> getToolCallbacks() {
return this.toolCallbacks;
Expand Down Expand Up @@ -374,6 +387,7 @@ public boolean equals(Object o) {
&& Objects.equals(this.presencePenalty, that.presencePenalty)
&& Objects.equals(this.maxOutputTokens, that.maxOutputTokens) && Objects.equals(this.model, that.model)
&& Objects.equals(this.responseMimeType, that.responseMimeType)
&& Objects.equals(this.responseSchema, that.responseSchema)
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
&& Objects.equals(this.toolNames, that.toolNames)
&& Objects.equals(this.safetySettings, that.safetySettings)
Expand All @@ -386,8 +400,9 @@ public boolean equals(Object o) {
public int hashCode() {
return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount,
this.frequencyPenalty, this.presencePenalty, this.maxOutputTokens, this.model, this.responseMimeType,
this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, this.safetySettings,
this.internalToolExecutionEnabled, this.toolContext, this.logprobs, this.responseLogprobs);
this.responseSchema, this.toolCallbacks, this.toolNames, this.googleSearchRetrieval,
this.safetySettings, this.internalToolExecutionEnabled, this.toolContext, this.logprobs,
this.responseLogprobs);
}

@Override
Expand All @@ -396,10 +411,10 @@ public String toString() {
+ this.temperature + ", topP=" + this.topP + ", topK=" + this.topK + ", frequencyPenalty="
+ this.frequencyPenalty + ", presencePenalty=" + this.presencePenalty + ", candidateCount="
+ this.candidateCount + ", maxOutputTokens=" + this.maxOutputTokens + ", model='" + this.model + '\''
+ ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks=" + this.toolCallbacks
+ ", toolNames=" + this.toolNames + ", googleSearchRetrieval=" + this.googleSearchRetrieval
+ ", safetySettings=" + this.safetySettings + ", logProbs=" + this.logprobs + ", responseLogprobs="
+ this.responseLogprobs + '}';
+ ", responseMimeType='" + this.responseMimeType + '\'' + ", responseSchema='" + this.responseSchema + '\''
+ ", toolCallbacks=" + this.toolCallbacks + ", toolNames=" + this.toolNames + ", googleSearchRetrieval="
+ this.googleSearchRetrieval + ", safetySettings=" + this.safetySettings + ", logProbs=" + this.logprobs
+ ", responseLogprobs=" + this.responseLogprobs + '}';
}

@Override
Expand Down Expand Up @@ -473,6 +488,11 @@ public Builder responseMimeType(String mimeType) {
return this;
}

/**
 * Sets the OpenAPI response schema on the options being built.
 * @param responseSchema the response schema string, passed through to
 * {@code setResponseSchema} unchanged
 * @return this builder, to allow call chaining
 */
public Builder responseSchema(String responseSchema) {
this.options.setResponseSchema(responseSchema);
return this;
}

public Builder toolCallbacks(List<ToolCallback> toolCallbacks) {
this.options.toolCallbacks = toolCallbacks;
return this;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.vertexai.gemini.schema;

import com.google.cloud.vertexai.api.Schema;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;

/**
* Utility class for converting OpenAPI schemas to Vertex AI Schema objects.
*
* @since 1.1.0
*/
public final class VertexAiSchemaConverter {

private VertexAiSchemaConverter() {
// Prevent instantiation
}

/**
* Converts an OpenAPI schema string to a Vertex AI Schema object.
* @param openApiSchema The OpenAPI schema in JSON format
* @return A Schema object representing the OpenAPI schema
* @throws RuntimeException if the schema cannot be parsed
*/
public static Schema fromOpenApiSchema(String openApiSchema) {
try {
var schemaBuilder = Schema.newBuilder();
JsonFormat.parser().ignoringUnknownFields().merge(openApiSchema, schemaBuilder);
return schemaBuilder.build();
}
catch (Exception e) {
throw new RuntimeException(e);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import com.google.cloud.vertexai.VertexAI;
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.Part;
import com.google.cloud.vertexai.api.Schema;
import com.google.cloud.vertexai.api.Type;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
Expand Down Expand Up @@ -264,6 +266,9 @@ public void createRequestWithGenerationConfigOptions() {
.responseMimeType("application/json")
.responseLogprobs(true)
.logprobs(2)
.responseSchema("""
{"type": "OBJECT"}
""")
.build())
.build();

Expand All @@ -284,6 +289,8 @@ public void createRequestWithGenerationConfigOptions() {
assertThat(request.model().getGenerationConfig().getResponseMimeType()).isEqualTo("application/json");
assertThat(request.model().getGenerationConfig().getLogprobs()).isEqualTo(2);
assertThat(request.model().getGenerationConfig().getResponseLogprobs()).isEqualTo(true);
assertThat(request.model().getGenerationConfig().getResponseSchema())
.isEqualTo(Schema.newBuilder().setType(Type.OBJECT).build());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.vertexai.gemini.schema;

import java.util.List;

import com.google.cloud.vertexai.api.Schema;
import com.google.cloud.vertexai.api.Type;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
 * Unit tests for {@link VertexAiSchemaConverter#fromOpenApiSchema(String)}.
 * <p>
 * Each test feeds a small OpenAPI-style JSON document to the converter and checks that
 * the corresponding fields of the resulting {@link Schema} carry the expected values.
 */
public class VertexAiSchemaConverterTests {

/**
 * Checks the fields that are not specific to one schema type: type, format, title,
 * description, nullable, example and default.
 */
@Test
public void fromOpenApiSchemaShouldConvertGenericFields() {
String openApiSchema = """
{
"type": "OBJECT",
"format": "date-time",
"title": "Title",
"description": "Description",
"nullable": true,
"example": "Example",
"default": "0"
}""";

Schema schema = VertexAiSchemaConverter.fromOpenApiSchema(openApiSchema);

assertEquals(Type.OBJECT, schema.getType());
assertEquals("date-time", schema.getFormat());
assertEquals("Title", schema.getTitle());
assertEquals("Description", schema.getDescription());
assertTrue(schema.getNullable());
// example and default are read back through getStringValue()
assertEquals("Example", schema.getExample().getStringValue());
assertEquals("0", schema.getDefault().getStringValue());
}

/**
 * Checks the string-related constraints: enum values, minLength, maxLength and
 * pattern.
 */
@Test
public void fromOpenApiSchemaShouldConvertStringFields() {
String openApiSchema = """
{
"type": "STRING",
"enum": ["a", "b", "c"],
"minLength": 1,
"maxLength": 10,
"pattern": "[0-9.]+"
}""";

Schema schema = VertexAiSchemaConverter.fromOpenApiSchema(openApiSchema);

assertEquals(Type.STRING, schema.getType());
assertEquals(List.of("a", "b", "c"), schema.getEnumList());
assertEquals(1, schema.getMinLength());
assertEquals(10, schema.getMaxLength());
assertEquals("[0-9.]+", schema.getPattern());
}

/**
 * Checks the numeric constraints (minimum, maximum) together with an anyOf list of
 * INTEGER and NUMBER sub-schemas.
 */
@Test
public void fromOpenApiSchemaShouldConvertIntegerAndNumberFields() {
String openApiSchema = """
{
"anyOf": [{"type": "INTEGER"}, {"type": "NUMBER"}],
"minimum": 0,
"maximum": 100
}""";

Schema schema = VertexAiSchemaConverter.fromOpenApiSchema(openApiSchema);

// The input has no top-level "type", so the converted schema reports
// TYPE_UNSPECIFIED
assertEquals(Type.TYPE_UNSPECIFIED, schema.getType());
assertEquals(Type.INTEGER, schema.getAnyOf(0).getType());
assertEquals(Type.NUMBER, schema.getAnyOf(1).getType());
assertEquals(0, schema.getMinimum());
assertEquals(100, schema.getMaximum());
}

/**
 * Checks the array-related fields: the nested items schema, minItems and maxItems.
 */
@Test
public void fromOpenApiSchemaShouldConvertArrayFields() {
String openApiSchema = """
{
"type": "ARRAY",
"items": {
"type": "BOOLEAN"
},
"minItems": 1,
"maxItems": 5
}""";

Schema schema = VertexAiSchemaConverter.fromOpenApiSchema(openApiSchema);

assertEquals(Type.ARRAY, schema.getType());
assertEquals(Type.BOOLEAN, schema.getItems().getType());
assertEquals(1, schema.getMinItems());
assertEquals(5, schema.getMaxItems());
}

/**
 * Checks the object-related fields: properties, minProperties, maxProperties,
 * required and propertyOrdering.
 */
@Test
public void fromOpenApiSchemaShouldConvertObjectFields() {
String openApiSchema = """
{
"type": "OBJECT",
"properties": {
"property1": {
"type": "STRING"
},
"property2": {
"type": "INTEGER"
}
},
"minProperties": 1,
"maxProperties": 2,
"required": ["property1"],
"propertyOrdering": ["property1", "property2"]
}""";

Schema schema = VertexAiSchemaConverter.fromOpenApiSchema(openApiSchema);

assertEquals(Type.OBJECT, schema.getType());
// Only the number of converted properties is checked, not their individual types
assertEquals(2, schema.getPropertiesMap().size());
assertEquals(1, schema.getMinProperties());
assertEquals(2, schema.getMaxProperties());
assertEquals(List.of("property1"), schema.getRequiredList());
assertEquals(List.of("property1", "property2"), schema.getPropertyOrderingList());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ The prefix `spring.ai.vertex.ai.gemini.chat` is the property prefix that lets yo

| spring.ai.vertex.ai.gemini.chat.options.model | Supported https://cloud.google.com/vertex-ai/generative-ai/docs/models#gemini-models[Vertex AI Gemini Chat model] to use include the `gemini-2.0-flash`, `gemini-2.0-flash-lite` and the new `gemini-2.5-pro-preview-03-25`, `gemini-2.5-flash-preview-04-17` models. | gemini-2.0-flash
| spring.ai.vertex.ai.gemini.chat.options.response-mime-type | Output response mimetype of the generated candidate text. | `text/plain`: (default) Text output or `application/json`: JSON response.
| spring.ai.vertex.ai.gemini.chat.options.response-schema | String, containing the output response schema in OpenAPI format, as described in https://ai.google.dev/gemini-api/docs/structured-output#json-schemas. | -
| spring.ai.vertex.ai.gemini.chat.options.google-search-retrieval | Use Google search Grounding feature | `true` or `false`, default `false`.
| spring.ai.vertex.ai.gemini.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that are more varied, while a value closer to 0.0 will typically result in less surprising responses from the model. This value specifies the default to be used by the backend when calling the model. | 0.7
| spring.ai.vertex.ai.gemini.chat.options.top-k | The maximum number of tokens to consider when sampling. The model uses combined top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens. | -
Expand Down