Google Vertex AI toolcalling token usage #2677

Closed
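
The core of this change: `internalCall` and `internalStream` now carry the previous `ChatResponse` through each tool-calling round trip, and the per-call Gemini usage is combined with it via `UsageUtils.getCumulativeUsage` before being set on the response metadata, so the final response reports token usage for all round trips instead of only the last one. The sketch below illustrates that accumulation against the `Usage`, `DefaultUsage`, and `ChatResponse` types visible in this diff; `CumulativeUsageSketch` and its `accumulate` method are illustrative names only, not the actual `UsageUtils` implementation.

// Hedged sketch only: shows how a call's usage can be folded into the running total
// carried on the previous ChatResponse, mirroring what UsageUtils.getCumulativeUsage
// is used for in this change. CumulativeUsageSketch/accumulate are illustrative names.
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatResponse;

final class CumulativeUsageSketch {

	private CumulativeUsageSketch() {
	}

	// Returns the current usage plus whatever usage the previous response already
	// accumulated; on the first call previousResponse is null and currentUsage is
	// returned unchanged.
	static Usage accumulate(Usage currentUsage, ChatResponse previousResponse) {
		if (previousResponse == null || previousResponse.getMetadata() == null
				|| previousResponse.getMetadata().getUsage() == null) {
			return currentUsage;
		}
		Usage previous = previousResponse.getMetadata().getUsage();
		return new DefaultUsage(currentUsage.getPromptTokens() + previous.getPromptTokens(),
				currentUsage.getCompletionTokens() + previous.getCompletionTokens());
	}

}

With this in place, the integration tests below can assert that `getMetadata().getUsage().getTotalTokens()` covers both the initial request and the follow-up call that returns the tool result.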
@@ -48,7 +48,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import org.springframework.ai.chat.messages.AssistantMessage;
@@ -60,7 +59,9 @@
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
@@ -71,12 +72,11 @@
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ChatModelDescription;
import org.springframework.ai.content.Media;
import org.springframework.ai.model.ChatModelDescription;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.LegacyToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
@@ -136,12 +136,13 @@
* @author Soby Chacko
* @author Jihoon Kim
* @author Alexandros Pappas
* @author Ilayaperumal Gopinathan
* @since 0.8.1
* @see VertexAiGeminiChatOptions
* @see ToolCallingManager
* @see ChatModel
*/
public class VertexAiGeminiChatModel extends AbstractToolCallSupport implements ChatModel, DisposableBean {
public class VertexAiGeminiChatModel implements ChatModel, DisposableBean {

private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();

@@ -277,8 +278,6 @@ public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions defa
ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ObservationRegistry observationRegistry,
ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {

super(null, VertexAiGeminiChatOptions.builder().build(), List.of());

Assert.notNull(vertexAI, "VertexAI must not be null");
Assert.notNull(defaultOptions, "VertexAiGeminiChatOptions must not be null");
Assert.notNull(defaultOptions.getModel(), "VertexAiGeminiChatOptions.modelName must not be null");
@@ -425,10 +424,10 @@ private static Schema jsonToSchema(String json) {
@Override
public ChatResponse call(Prompt prompt) {
var requestPrompt = this.buildRequestPrompt(prompt);
return this.internalCall(requestPrompt);
return this.internalCall(requestPrompt, null);
}

private ChatResponse internalCall(Prompt prompt) {
private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
@@ -451,8 +450,12 @@ private ChatResponse internalCall(Prompt prompt) {
.flatMap(List::stream)
.toList();

ChatResponse chatResponse = new ChatResponse(generations,
toChatResponseMetadata(generateContentResponse));
GenerateContentResponse.UsageMetadata usage = generateContentResponse.getUsageMetadata();
Usage currentUsage = (usage != null)
? new DefaultUsage(usage.getPromptTokenCount(), usage.getCandidatesTokenCount())
: new EmptyUsage();
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(cumulativeUsage));

observationContext.setResponse(chatResponse);
return chatResponse;
@@ -469,7 +472,8 @@ private ChatResponse internalCall(Prompt prompt) {
}
else {
// Send the tool execution result back to the model.
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()));
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
response);
}
}

@@ -485,10 +489,6 @@ Prompt buildRequestPrompt(Prompt prompt) {
runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
VertexAiGeminiChatOptions.class);
}
else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class,
VertexAiGeminiChatOptions.class);
}
else {
runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
VertexAiGeminiChatOptions.class);
@@ -535,10 +535,10 @@ else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOp
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
var requestPrompt = this.buildRequestPrompt(prompt);
return this.internalStream(requestPrompt);
return this.internalStream(requestPrompt, null);
}

public Flux<ChatResponse> internalStream(Prompt prompt) {
public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
return Flux.deferContextual(contextView -> {

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
@@ -559,21 +559,22 @@ public Flux<ChatResponse> internalStream(Prompt prompt) {
ResponseStream<GenerateContentResponse> responseStream = request.model
.generateContentStream(request.contents);

Flux<ChatResponse> chatResponse1 = Flux.fromStream(responseStream.stream())
.switchMap(response2 -> Mono.just(response2).map(response -> {

List<Generation> generations = response.getCandidatesList()
.stream()
.map(this::responseCandidateToGeneration)
.flatMap(List::stream)
.toList();

return new ChatResponse(generations, toChatResponseMetadata(response));

}));
Flux<ChatResponse> chatResponseFlux = Flux.fromStream(responseStream.stream()).switchMap(response -> {
List<Generation> generations = response.getCandidatesList()
.stream()
.map(this::responseCandidateToGeneration)
.flatMap(List::stream)
.toList();

GenerateContentResponse.UsageMetadata usage = response.getUsageMetadata();
Usage currentUsage = (usage != null) ? getDefaultUsage(usage) : new EmptyUsage();
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(cumulativeUsage));
return Flux.just(chatResponse);
});

// @formatter:off
Flux<ChatResponse> chatResponseFlux = chatResponse1.flatMap(response -> {
Flux<ChatResponse> flux = chatResponseFlux.flatMap(response -> {
if (toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
Expand All @@ -586,7 +587,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt) {
.build());
} else {
// Send the tool execution result back to the model.
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()));
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), response);
}
}).subscribeOn(Schedulers.boundedElastic());
}
@@ -599,7 +600,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt) {
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
// @formatter:on;

return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse);
return new MessageAggregator().aggregate(flux, observationContext::setResponse);

}
catch (Exception e) {
@@ -653,8 +654,8 @@ protected List<Generation> responseCandidateToGeneration(Candidate candidate) {
}
}

private ChatResponseMetadata toChatResponseMetadata(GenerateContentResponse response) {
return ChatResponseMetadata.builder().usage(getDefaultUsage(response.getUsageMetadata())).build();
private ChatResponseMetadata toChatResponseMetadata(Usage usage) {
return ChatResponseMetadata.builder().usage(usage).build();
}

private DefaultUsage getDefaultUsage(GenerateContentResponse.UsageMetadata usageMetadata) {
@@ -118,11 +118,15 @@ public void functionCallTestInferredOpenApiSchema() {
.build()))
.build();

ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions));
ChatResponse chatResponse = this.chatModel.call(new Prompt(messages, promptOptions));

logger.info("Response: {}", response);
assertThat(chatResponse).isNotNull();
logger.info("Response: {}", chatResponse);
assertThat(chatResponse.getResult().getOutput().getText()).contains("30", "10", "15");

assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15");
assertThat(chatResponse.getMetadata()).isNotNull();
assertThat(chatResponse.getMetadata().getUsage()).isNotNull();
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isGreaterThan(150).isLessThan(310);

ChatResponse response2 = this.chatModel
.call(new Prompt("What is the payment status for transaction 696?", promptOptions));
@@ -166,6 +170,41 @@ public void functionCallTestInferredOpenApiSchemaStream() {

}

@Test
public void functionCallUsageTestInferredOpenApiSchemaStream() {

UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco, Paris and in Tokyo? Return the temperature in Celsius.");

List<Message> messages = new ArrayList<>(List.of(userMessage));

var promptOptions = VertexAiGeminiChatOptions.builder()
.model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH)
.toolCallbacks(List.of(
FunctionToolCallback.builder("get_current_weather", new MockWeatherService())
.description("Get the current weather in a given location.")
.inputType(MockWeatherService.Request.class)
.build(),
FunctionToolCallback.builder("get_payment_status", new PaymentStatus())
.description(
"Retrieves the payment status for transaction. For example what is the payment status for transaction 700?")
.inputType(PaymentInfoRequest.class)
.build()))
.build();

Flux<ChatResponse> response = this.chatModel.stream(new Prompt(messages, promptOptions));

ChatResponse chatResponse = response.blockLast();

logger.info("Response: {}", chatResponse);

assertThat(chatResponse).isNotNull();
assertThat(chatResponse.getMetadata()).isNotNull();
assertThat(chatResponse.getMetadata().getUsage()).isNotNull();
assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isGreaterThan(150).isLessThan(310);

}

public record PaymentInfoRequest(String id) {

}