Skip to content

Commit 6cb7439

Browse files
author
wmz7year
committed
Amazon Bedrock Chat adds tool support.
1 parent 49b3326 commit 6cb7439

File tree

7 files changed

+785
-70
lines changed

7 files changed

+785
-70
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.bedrock;
17+
18+
import java.util.List;
19+
20+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
21+
22+
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock.Type;
23+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
24+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
25+
import software.amazon.awssdk.services.bedrockruntime.model.Message;
26+
import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent;
27+
import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock;
28+
29+
/**
30+
* Amazon Bedrock Chat model converse interface generation metadata, encapsulating
31+
* information on the completion.
32+
*
33+
* @author Wei Jiang
34+
* @since 1.0.0
35+
*/
36+
public class BedrockConverseChatGenerationMetadata implements ChatGenerationMetadata {
37+
38+
private String stopReason;
39+
40+
private Message message;
41+
42+
private ConverseStreamOutput event;
43+
44+
public BedrockConverseChatGenerationMetadata(String stopReason, ConverseStreamOutput event) {
45+
super();
46+
47+
this.stopReason = stopReason;
48+
this.event = event;
49+
}
50+
51+
public BedrockConverseChatGenerationMetadata(String stopReason, Message message) {
52+
super();
53+
54+
this.stopReason = stopReason;
55+
this.message = message;
56+
}
57+
58+
public static BedrockConverseChatGenerationMetadata from(ConverseResponse response, Message message) {
59+
return new BedrockConverseChatGenerationMetadata(response.stopReasonAsString(), message);
60+
}
61+
62+
public static BedrockConverseChatGenerationMetadata from(ConverseStreamOutput event) {
63+
String stopReason = null;
64+
65+
if (event instanceof MessageStopEvent messageStopEvent) {
66+
stopReason = messageStopEvent.stopReasonAsString();
67+
}
68+
69+
return new BedrockConverseChatGenerationMetadata(stopReason, event);
70+
}
71+
72+
@Override
73+
public <T> T getContentFilterMetadata() {
74+
return null;
75+
}
76+
77+
@Override
78+
public String getFinishReason() {
79+
return stopReason;
80+
}
81+
82+
public Message getMessage() {
83+
if (message != null) {
84+
return message;
85+
}
86+
else {
87+
System.out.println(event);
88+
return message;
89+
}
90+
}
91+
92+
public ConverseStreamOutput getEvent() {
93+
return event;
94+
}
95+
96+
public List<ToolUseBlock> getToolToUseList() {
97+
return message.content()
98+
.stream()
99+
.filter(content -> content.type() == Type.TOOL_USE)
100+
.map(content -> content.toolUse())
101+
.toList();
102+
}
103+
104+
}

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,20 @@
1515
*/
1616
package org.springframework.ai.bedrock.anthropic3;
1717

18+
import com.fasterxml.jackson.annotation.JsonIgnore;
1819
import com.fasterxml.jackson.annotation.JsonInclude;
1920
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2021
import com.fasterxml.jackson.annotation.JsonProperty;
2122
import org.springframework.ai.chat.prompt.ChatOptions;
23+
import org.springframework.ai.model.function.FunctionCallback;
24+
import org.springframework.ai.model.function.FunctionCallingOptions;
25+
import org.springframework.boot.context.properties.NestedConfigurationProperty;
26+
import org.springframework.util.Assert;
2227

28+
import java.util.ArrayList;
29+
import java.util.HashSet;
2330
import java.util.List;
31+
import java.util.Set;
2432

2533
/**
2634
* Java {@link ChatOptions} for the Bedrock Anthropic chat generative model chat options.
@@ -31,7 +39,7 @@
3139
* @since 1.0.0
3240
*/
3341
@JsonInclude(Include.NON_NULL)
34-
public class Anthropic3ChatOptions implements ChatOptions {
42+
public class Anthropic3ChatOptions implements ChatOptions, FunctionCallingOptions {
3543

3644
// @formatter:off
3745
/**
@@ -66,6 +74,31 @@ public class Anthropic3ChatOptions implements ChatOptions {
6674
*/
6775
private @JsonProperty("stop_sequences") List<String> stopSequences;
6876

77+
/**
78+
* Tool Function Callbacks to register with the ChatModel. For Prompt
79+
* Options the functionCallbacks are automatically enabled for the duration of the
80+
* prompt execution. For Default Options the functionCallbacks are registered but
81+
* disabled by default. Use the enableFunctions to set the functions from the registry
82+
* to be used by the ChatModel chat completion requests.
83+
*/
84+
@NestedConfigurationProperty
85+
@JsonIgnore
86+
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
87+
88+
/**
89+
* List of functions, identified by their names, to configure for function calling in
90+
* the chat completion requests. Functions with those names must exist in the
91+
* functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions
92+
* are automatically enabled for the duration of the prompt execution.
93+
*
94+
* Note that function enabled with the default options are enabled for all chat
95+
* completion requests. This could impact the token count and the billing. If the
96+
* functions is set in a prompt options, then the enabled functions are only active
97+
* for the duration of this prompt execution.
98+
*/
99+
@NestedConfigurationProperty
100+
@JsonIgnore
101+
private Set<String> functions = new HashSet<>();
69102
// @formatter:on
70103

71104
public static Builder builder() {
@@ -101,6 +134,23 @@ public Builder withStopSequences(List<String> stopSequences) {
101134
return this;
102135
}
103136

137+
public Builder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
138+
this.options.functionCallbacks = functionCallbacks;
139+
return this;
140+
}
141+
142+
public Builder withFunctions(Set<String> functionNames) {
143+
Assert.notNull(functionNames, "Function names must not be null");
144+
this.options.functions = functionNames;
145+
return this;
146+
}
147+
148+
public Builder withFunction(String functionName) {
149+
Assert.hasText(functionName, "Function name must not be empty");
150+
this.options.functions.add(functionName);
151+
return this;
152+
}
153+
104154
public Anthropic3ChatOptions build() {
105155
return this.options;
106156
}
@@ -150,12 +200,36 @@ public void setStopSequences(List<String> stopSequences) {
150200
this.stopSequences = stopSequences;
151201
}
152202

203+
@Override
204+
public List<FunctionCallback> getFunctionCallbacks() {
205+
return this.functionCallbacks;
206+
}
207+
208+
@Override
209+
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
210+
Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null");
211+
this.functionCallbacks = functionCallbacks;
212+
}
213+
214+
@Override
215+
public Set<String> getFunctions() {
216+
return this.functions;
217+
}
218+
219+
@Override
220+
public void setFunctions(Set<String> functions) {
221+
Assert.notNull(functions, "Function must not be null");
222+
this.functions = functions;
223+
}
224+
153225
public static Anthropic3ChatOptions fromOptions(Anthropic3ChatOptions fromOptions) {
154226
return builder().withTemperature(fromOptions.getTemperature())
155227
.withMaxTokens(fromOptions.getMaxTokens())
156228
.withTopK(fromOptions.getTopK())
157229
.withTopP(fromOptions.getTopP())
158230
.withStopSequences(fromOptions.getStopSequences())
231+
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
232+
.withFunctions(fromOptions.getFunctions())
159233
.build();
160234
}
161235

0 commit comments

Comments
 (0)