Skip to content

Commit e4ee01a

Browse files
author
wmz7year
committed
Add Amazon Bedrock Converse API support.
1 parent fa53e2a commit e4ee01a

File tree

7 files changed

+454
-1
lines changed

7 files changed

+454
-1
lines changed

models/spring-ai-bedrock/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@
2929
<version>${project.parent.version}</version>
3030
</dependency>
3131

32+
<dependency>
33+
<groupId>org.springframework.ai</groupId>
34+
<artifactId>spring-ai-retry</artifactId>
35+
<version>${project.parent.version}</version>
36+
</dependency>
37+
3238
<dependency>
3339
<groupId>org.springframework</groupId>
3440
<artifactId>spring-web</artifactId>
Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
// @formatter:off
17+
package org.springframework.ai.bedrock.api;
18+
19+
import java.time.Duration;
20+
21+
import org.slf4j.Logger;
22+
import org.slf4j.LoggerFactory;
23+
import org.springframework.ai.retry.RetryUtils;
24+
import org.springframework.retry.support.RetryTemplate;
25+
import org.springframework.util.Assert;
26+
27+
import reactor.core.publisher.Flux;
28+
import reactor.core.publisher.Sinks;
29+
import reactor.core.publisher.Sinks.EmitFailureHandler;
30+
import reactor.core.publisher.Sinks.EmitResult;
31+
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
32+
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
33+
import software.amazon.awssdk.regions.Region;
34+
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
35+
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
36+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
37+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
38+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
39+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
40+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
41+
42+
/**
43+
* Amazon Bedrock Converse API, It provides the basic functionality to invoke the Bedrock
44+
* AI model and receive the response for streaming and non-streaming requests.
45+
* The Converse API doesn't support any embedding models (such as Titan Embeddings G1 - Text)
46+
* or image generation models (such as Stability AI).
47+
*
48+
* <p>
49+
* https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
50+
* <p>
51+
* https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
52+
* <p>
53+
* https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseStream.html
54+
* <p>
55+
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
56+
* <p>
57+
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
58+
*
59+
* @author Wei Jiang
60+
* @since 1.0.0
61+
*/
62+
public class BedrockConverseApi {
63+
64+
private static final Logger logger = LoggerFactory.getLogger(BedrockConverseApi.class);
65+
66+
private final Region region;
67+
68+
private final BedrockRuntimeClient client;
69+
70+
private final BedrockRuntimeAsyncClient clientStreaming;
71+
72+
private final RetryTemplate retryTemplate;
73+
74+
/**
75+
* Create a new BedrockConverseApi instance using default credentials provider.
76+
*
77+
* @param region The AWS region to use.
78+
*/
79+
public BedrockConverseApi(String region) {
80+
this(ProfileCredentialsProvider.builder().build(), region, Duration.ofMinutes(5));
81+
}
82+
83+
/**
84+
* Create a new BedrockConverseApi instance using default credentials provider.
85+
*
86+
* @param region The AWS region to use.
87+
* @param timeout The timeout to use.
88+
*/
89+
public BedrockConverseApi(String region, Duration timeout) {
90+
this(ProfileCredentialsProvider.builder().build(), region, timeout);
91+
}
92+
93+
/**
94+
* Create a new BedrockConverseApi instance using the provided credentials provider,
95+
* region.
96+
*
97+
* @param credentialsProvider The credentials provider to connect to AWS.
98+
* @param region The AWS region to use.
99+
*/
100+
public BedrockConverseApi(AwsCredentialsProvider credentialsProvider, String region) {
101+
this(credentialsProvider, region, Duration.ofMinutes(5));
102+
}
103+
104+
/**
105+
* Create a new BedrockConverseApi instance using the provided credentials provider,
106+
* region.
107+
*
108+
* @param credentialsProvider The credentials provider to connect to AWS.
109+
* @param region The AWS region to use.
110+
* @param timeout Configure the amount of time to allow the client to complete the
111+
* execution of an API call. This timeout covers the entire client execution except
112+
* for marshalling. This includes request handler execution, all HTTP requests
113+
* including retries, unmarshalling, etc. This value should always be positive, if
114+
* present.
115+
*/
116+
public BedrockConverseApi(AwsCredentialsProvider credentialsProvider, String region, Duration timeout) {
117+
this(credentialsProvider, Region.of(region), timeout);
118+
}
119+
120+
/**
121+
* Create a new BedrockConverseApi instance using the provided credentials provider,
122+
* region.
123+
*
124+
* @param credentialsProvider The credentials provider to connect to AWS.
125+
* @param region The AWS region to use.
126+
* @param timeout Configure the amount of time to allow the client to complete the
127+
* execution of an API call. This timeout covers the entire client execution except
128+
* for marshalling. This includes request handler execution, all HTTP requests
129+
* including retries, unmarshalling, etc. This value should always be positive, if
130+
* present.
131+
*/
132+
public BedrockConverseApi(AwsCredentialsProvider credentialsProvider, Region region, Duration timeout) {
133+
this(credentialsProvider, region, timeout, RetryUtils.DEFAULT_RETRY_TEMPLATE);
134+
}
135+
136+
/**
137+
* Create a new BedrockConverseApi instance using the provided credentials provider,
138+
* region
139+
*
140+
* @param credentialsProvider The credentials provider to connect to AWS.
141+
* @param region The AWS region to use.
142+
* @param timeout Configure the amount of time to allow the client to complete the
143+
* execution of an API call. This timeout covers the entire client execution except
144+
* for marshalling. This includes request handler execution, all HTTP requests
145+
* including retries, unmarshalling, etc. This value should always be positive, if
146+
* present.
147+
* @param retryTemplate The retry template used to retry the Amazon Bedrock Converse
148+
* API calls.
149+
*/
150+
public BedrockConverseApi(AwsCredentialsProvider credentialsProvider, Region region, Duration timeout,
151+
RetryTemplate retryTemplate) {
152+
Assert.notNull(credentialsProvider, "Credentials provider must not be null");
153+
Assert.notNull(region, "Region must not be empty");
154+
Assert.notNull(timeout, "Timeout must not be null");
155+
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
156+
157+
this.region = region;
158+
this.retryTemplate = retryTemplate;
159+
160+
this.client = BedrockRuntimeClient.builder()
161+
.region(this.region)
162+
.credentialsProvider(credentialsProvider)
163+
.overrideConfiguration(c -> c.apiCallTimeout(timeout))
164+
.build();
165+
166+
this.clientStreaming = BedrockRuntimeAsyncClient.builder()
167+
.region(this.region)
168+
.credentialsProvider(credentialsProvider)
169+
.overrideConfiguration(c -> c.apiCallTimeout(timeout))
170+
.build();
171+
}
172+
173+
/**
174+
* @return The AWS region.
175+
*/
176+
public Region getRegion() {
177+
return this.region;
178+
}
179+
180+
/**
181+
* Invoke the model and return the response.
182+
*
183+
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
184+
* https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
185+
* https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient.html#converse
186+
* @param converseRequest Model invocation request.
187+
* @return The model invocation response.
188+
*/
189+
public ConverseResponse converse(ConverseRequest converseRequest) {
190+
Assert.notNull(converseRequest, "'converseRequest' must not be null");
191+
192+
return this.retryTemplate.execute(ctx -> {
193+
return client.converse(converseRequest);
194+
});
195+
}
196+
197+
/**
198+
* Invoke the model and return the response stream.
199+
*
200+
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
201+
* https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
202+
* https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient.html#converseStream
203+
* @param converseStreamRequest Model invocation request.
204+
* @return The model invocation response stream.
205+
*/
206+
public Flux<ConverseStreamOutput> converseStream(ConverseStreamRequest converseStreamRequest) {
207+
Assert.notNull(converseStreamRequest, "'converseStreamRequest' must not be null");
208+
209+
return this.retryTemplate.execute(ctx -> {
210+
Sinks.Many<ConverseStreamOutput> eventSink = Sinks.many().multicast().onBackpressureBuffer();
211+
212+
ConverseStreamResponseHandler.Visitor visitor = ConverseStreamResponseHandler.Visitor.builder()
213+
.onDefault((output) -> {
214+
logger.debug("Received converse stream output:{}", output);
215+
eventSink.tryEmitNext(output);
216+
})
217+
.build();
218+
219+
ConverseStreamResponseHandler responseHandler = ConverseStreamResponseHandler.builder()
220+
.onEventStream(stream -> stream.subscribe((e) -> e.accept(visitor)))
221+
.onComplete(() -> {
222+
EmitResult emitResult = eventSink.tryEmitComplete();
223+
224+
while (!emitResult.isSuccess()) {
225+
logger.debug("Emitting complete:{}", emitResult);
226+
emitResult = eventSink.tryEmitComplete();
227+
}
228+
229+
eventSink.emitComplete(EmitFailureHandler.busyLooping(Duration.ofSeconds(3)));
230+
logger.debug("Completed streaming response.");
231+
})
232+
.onError((error) -> {
233+
logger.error("Error handling Bedrock converse stream response", error);
234+
eventSink.tryEmitError(error);
235+
})
236+
.build();
237+
238+
clientStreaming.converseStream(converseStreamRequest, responseHandler);
239+
240+
return eventSink.asFlux();
241+
});
242+
}
243+
244+
}
245+
//@formatter:on
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.bedrock.api;
17+
18+
import static org.assertj.core.api.Assertions.assertThat;
19+
20+
import java.util.List;
21+
22+
import org.junit.jupiter.api.Test;
23+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
24+
25+
import reactor.core.publisher.Flux;
26+
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
27+
import software.amazon.awssdk.regions.Region;
28+
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
29+
import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole;
30+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
31+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
32+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
33+
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
34+
import software.amazon.awssdk.services.bedrockruntime.model.Message;
35+
36+
/**
37+
* @author Wei Jiang
38+
*/
39+
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
40+
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
41+
public class BedrockConverseApiIT {
42+
43+
private BedrockConverseApi converseApi = new BedrockConverseApi(EnvironmentVariableCredentialsProvider.create(),
44+
Region.US_EAST_1.id());
45+
46+
@Test
47+
public void testConverse() {
48+
ContentBlock contentBlock = ContentBlock.builder().text("Give me the names of 3 famous pirates?").build();
49+
50+
Message message = Message.builder().content(contentBlock).role(ConversationRole.USER).build();
51+
52+
ConverseRequest request = ConverseRequest.builder()
53+
.modelId("anthropic.claude-3-sonnet-20240229-v1:0")
54+
.messages(List.of(message))
55+
.build();
56+
57+
ConverseResponse response = converseApi.converse(request);
58+
59+
assertThat(response).isNotNull();
60+
assertThat(response.output()).isNotNull();
61+
assertThat(response.output().message()).isNotNull();
62+
assertThat(response.output().message().content()).isNotEmpty();
63+
assertThat(response.output().message().content().get(0).text()).contains("Blackbeard");
64+
assertThat(response.stopReason()).isNotNull();
65+
assertThat(response.usage()).isNotNull();
66+
assertThat(response.usage().inputTokens()).isGreaterThan(10);
67+
assertThat(response.usage().outputTokens()).isGreaterThan(30);
68+
}
69+
70+
@Test
71+
public void testConverseStream() {
72+
ContentBlock contentBlock = ContentBlock.builder().text("Give me the names of 3 famous pirates?").build();
73+
74+
Message message = Message.builder().content(contentBlock).role(ConversationRole.USER).build();
75+
76+
ConverseStreamRequest request = ConverseStreamRequest.builder()
77+
.modelId("anthropic.claude-3-sonnet-20240229-v1:0")
78+
.messages(List.of(message))
79+
.build();
80+
81+
Flux<ConverseStreamOutput> responseStream = converseApi.converseStream(request);
82+
83+
List<ConverseStreamOutput> responseOutputs = responseStream.collectList().block();
84+
85+
assertThat(responseOutputs).isNotNull();
86+
assertThat(responseOutputs).hasSizeGreaterThan(10);
87+
}
88+
89+
}

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@
136136
<azure-open-ai-client.version>1.0.0-beta.8</azure-open-ai-client.version>
137137
<jtokkit.version>1.0.0</jtokkit.version>
138138
<victools.version>4.31.1</victools.version>
139-
<bedrockruntime.version>2.25.3</bedrockruntime.version>
139+
<bedrockruntime.version>2.25.64</bedrockruntime.version>
140140
<jackson.version>2.16.1</jackson.version>
141141
<djl.version>0.26.0</djl.version>
142142
<onnxruntime.version>1.17.0</onnxruntime.version>

0 commit comments

Comments
 (0)