Skip to content

Commit 17d56b0

Browse files
author
wmz7year
committed
Add Bedrock Cohere Command R model support.
1 parent 76e7e0b commit 17d56b0

File tree

15 files changed

+2017
-0
lines changed

15 files changed

+2017
-0
lines changed

models/spring-ai-bedrock/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@
2929
<version>${project.parent.version}</version>
3030
</dependency>
3131

32+
<dependency>
33+
<groupId>org.springframework.ai</groupId>
34+
<artifactId>spring-ai-retry</artifactId>
35+
<version>${project.parent.version}</version>
36+
</dependency>
37+
3238
<dependency>
3339
<groupId>org.springframework</groupId>
3440
<artifactId>spring-web</artifactId>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.bedrock.cohere;
17+
18+
import java.util.List;
19+
20+
import org.springframework.ai.bedrock.BedrockUsage;
21+
import org.springframework.ai.bedrock.MessageToPromptConverter;
22+
import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi;
23+
import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatRequest;
24+
import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatResponse;
25+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
26+
import org.springframework.ai.chat.metadata.Usage;
27+
import org.springframework.ai.chat.model.ChatModel;
28+
import org.springframework.ai.chat.model.ChatResponse;
29+
import org.springframework.ai.chat.model.Generation;
30+
import org.springframework.ai.chat.model.StreamingChatModel;
31+
import org.springframework.ai.chat.prompt.ChatOptions;
32+
import org.springframework.ai.chat.prompt.Prompt;
33+
import org.springframework.ai.model.ModelOptionsUtils;
34+
import org.springframework.ai.retry.RetryUtils;
35+
import org.springframework.retry.support.RetryTemplate;
36+
import org.springframework.util.Assert;
37+
38+
import reactor.core.publisher.Flux;
39+
40+
/**
41+
* @author Wei Jiang
42+
* @since 1.0.0
43+
*/
44+
public class BedrockCohereCommandRChatModel implements ChatModel, StreamingChatModel {
45+
46+
private final CohereCommandRChatBedrockApi chatApi;
47+
48+
private final BedrockCohereCommandRChatOptions defaultOptions;
49+
50+
/**
51+
* The retry template used to retry the Bedrock API calls.
52+
*/
53+
private final RetryTemplate retryTemplate;
54+
55+
public BedrockCohereCommandRChatModel(CohereCommandRChatBedrockApi chatApi) {
56+
this(chatApi, BedrockCohereCommandRChatOptions.builder().build());
57+
}
58+
59+
public BedrockCohereCommandRChatModel(CohereCommandRChatBedrockApi chatApi,
60+
BedrockCohereCommandRChatOptions options) {
61+
this(chatApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
62+
}
63+
64+
public BedrockCohereCommandRChatModel(CohereCommandRChatBedrockApi chatApi,
65+
BedrockCohereCommandRChatOptions options, RetryTemplate retryTemplate) {
66+
Assert.notNull(chatApi, "CohereCommandRChatBedrockApi must not be null");
67+
Assert.notNull(options, "BedrockCohereCommandRChatOptions must not be null");
68+
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
69+
70+
this.chatApi = chatApi;
71+
this.defaultOptions = options;
72+
this.retryTemplate = retryTemplate;
73+
}
74+
75+
@Override
76+
public ChatResponse call(Prompt prompt) {
77+
CohereCommandRChatRequest request = this.createRequest(prompt);
78+
79+
return this.retryTemplate.execute(ctx -> {
80+
CohereCommandRChatResponse response = this.chatApi.chatCompletion(request);
81+
82+
Generation generation = new Generation(response.text());
83+
84+
return new ChatResponse(List.of(generation));
85+
});
86+
}
87+
88+
@Override
89+
public Flux<ChatResponse> stream(Prompt prompt) {
90+
CohereCommandRChatRequest request = this.createRequest(prompt);
91+
92+
return this.retryTemplate.execute(ctx -> {
93+
return this.chatApi.chatCompletionStream(request).map(g -> {
94+
if (g.isFinished()) {
95+
String finishReason = g.finishReason().name();
96+
Usage usage = BedrockUsage.from(g.amazonBedrockInvocationMetrics());
97+
return new ChatResponse(List.of(new Generation("")
98+
.withGenerationMetadata(ChatGenerationMetadata.from(finishReason, usage))));
99+
}
100+
return new ChatResponse(List.of(new Generation(g.text())));
101+
});
102+
});
103+
}
104+
105+
CohereCommandRChatRequest createRequest(Prompt prompt) {
106+
final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions());
107+
108+
var request = CohereCommandRChatRequest.builder(promptValue)
109+
.withSearchQueriesOnly(this.defaultOptions.getSearchQueriesOnly())
110+
.withPreamble(this.defaultOptions.getPreamble())
111+
.withMaxTokens(this.defaultOptions.getMaxTokens())
112+
.withTemperature(this.defaultOptions.getTemperature())
113+
.withTopP(this.defaultOptions.getTopP())
114+
.withTopK(this.defaultOptions.getTopK())
115+
.withPromptTruncation(this.defaultOptions.getPromptTruncation())
116+
.withFrequencyPenalty(this.defaultOptions.getFrequencyPenalty())
117+
.withPresencePenalty(this.defaultOptions.getPresencePenalty())
118+
.withSeed(this.defaultOptions.getSeed())
119+
.withReturnPrompt(this.defaultOptions.getReturnPrompt())
120+
.withStopSequences(this.defaultOptions.getStopSequences())
121+
.withRawPrompting(this.defaultOptions.getRawPrompting())
122+
.build();
123+
124+
if (prompt.getOptions() != null) {
125+
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
126+
BedrockCohereCommandRChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
127+
ChatOptions.class, BedrockCohereCommandRChatOptions.class);
128+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, CohereCommandRChatRequest.class);
129+
}
130+
else {
131+
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
132+
+ prompt.getOptions().getClass().getSimpleName());
133+
}
134+
}
135+
136+
return request;
137+
}
138+
139+
@Override
140+
public ChatOptions getDefaultOptions() {
141+
return BedrockCohereCommandRChatOptions.fromOptions(defaultOptions);
142+
}
143+
144+
}

0 commit comments

Comments
 (0)