Skip to content

Commit 9a0feaa

Browse files
author
wmz7year
committed
Add Bedrock Cohere Command R model support.
1 parent 7252ba1 commit 9a0feaa

File tree

15 files changed

+1962
-0
lines changed

15 files changed

+1962
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.bedrock.cohere;
17+
18+
import java.util.List;
19+
20+
import org.springframework.ai.bedrock.BedrockUsage;
21+
import org.springframework.ai.bedrock.MessageToPromptConverter;
22+
import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi;
23+
import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatRequest;
24+
import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatResponse;
25+
import org.springframework.ai.chat.ChatClient;
26+
import org.springframework.ai.chat.ChatResponse;
27+
import org.springframework.ai.chat.Generation;
28+
import org.springframework.ai.chat.StreamingChatClient;
29+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
30+
import org.springframework.ai.chat.metadata.Usage;
31+
import org.springframework.ai.chat.prompt.ChatOptions;
32+
import org.springframework.ai.chat.prompt.Prompt;
33+
import org.springframework.ai.model.ModelOptionsUtils;
34+
import org.springframework.util.Assert;
35+
36+
import reactor.core.publisher.Flux;
37+
38+
/**
39+
* @author Wei Jiang
40+
* @since 1.0.0
41+
*/
42+
public class BedrockCohereCommandRChatClient implements ChatClient, StreamingChatClient {
43+
44+
private final CohereCommandRChatBedrockApi chatApi;
45+
46+
private final BedrockCohereCommandRChatOptions defaultOptions;
47+
48+
public BedrockCohereCommandRChatClient(CohereCommandRChatBedrockApi chatApi) {
49+
this(chatApi, BedrockCohereCommandRChatOptions.builder().build());
50+
}
51+
52+
public BedrockCohereCommandRChatClient(CohereCommandRChatBedrockApi chatApi,
53+
BedrockCohereCommandRChatOptions options) {
54+
Assert.notNull(chatApi, "CohereCommandRChatBedrockApi must not be null");
55+
Assert.notNull(options, "BedrockCohereCommandRChatOptions must not be null");
56+
57+
this.chatApi = chatApi;
58+
this.defaultOptions = options;
59+
}
60+
61+
@Override
62+
public ChatResponse call(Prompt prompt) {
63+
CohereCommandRChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt));
64+
65+
Generation generation = new Generation(response.text());
66+
67+
return new ChatResponse(List.of(generation));
68+
}
69+
70+
@Override
71+
public Flux<ChatResponse> stream(Prompt prompt) {
72+
return this.chatApi.chatCompletionStream(this.createRequest(prompt)).map(g -> {
73+
if (g.isFinished()) {
74+
String finishReason = g.finishReason().name();
75+
Usage usage = BedrockUsage.from(g.amazonBedrockInvocationMetrics());
76+
return new ChatResponse(List
77+
.of(new Generation("").withGenerationMetadata(ChatGenerationMetadata.from(finishReason, usage))));
78+
}
79+
return new ChatResponse(List.of(new Generation(g.text())));
80+
});
81+
}
82+
83+
CohereCommandRChatRequest createRequest(Prompt prompt) {
84+
final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions());
85+
86+
var request = CohereCommandRChatRequest.builder(promptValue)
87+
.withSearchQueriesOnly(this.defaultOptions.getSearchQueriesOnly())
88+
.withPreamble(this.defaultOptions.getPreamble())
89+
.withMaxTokens(this.defaultOptions.getMaxTokens())
90+
.withTemperature(this.defaultOptions.getTemperature())
91+
.withTopP(this.defaultOptions.getTopP())
92+
.withTopK(this.defaultOptions.getTopK())
93+
.withPromptTruncation(this.defaultOptions.getPromptTruncation())
94+
.withFrequencyPenalty(this.defaultOptions.getFrequencyPenalty())
95+
.withPresencePenalty(this.defaultOptions.getPresencePenalty())
96+
.withSeed(this.defaultOptions.getSeed())
97+
.withReturnPrompt(this.defaultOptions.getReturnPrompt())
98+
.withStopSequences(this.defaultOptions.getStopSequences())
99+
.withRawPrompting(this.defaultOptions.getRawPrompting())
100+
.build();
101+
102+
if (prompt.getOptions() != null) {
103+
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
104+
BedrockCohereCommandRChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
105+
ChatOptions.class, BedrockCohereCommandRChatOptions.class);
106+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, CohereCommandRChatRequest.class);
107+
}
108+
else {
109+
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
110+
+ prompt.getOptions().getClass().getSimpleName());
111+
}
112+
}
113+
114+
return request;
115+
}
116+
117+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.bedrock.cohere;
17+
18+
import java.util.List;
19+
20+
import com.fasterxml.jackson.annotation.JsonInclude;
21+
import com.fasterxml.jackson.annotation.JsonProperty;
22+
import com.fasterxml.jackson.annotation.JsonInclude.Include;
23+
24+
import org.springframework.ai.bedrock.cohere.api.CohereCommandRChatBedrockApi.CohereCommandRChatRequest.PromptTruncation;
25+
import org.springframework.ai.chat.prompt.ChatOptions;
26+
27+
/**
28+
* @author Wei Jiang
29+
* @since 1.0.0
30+
*/
31+
@JsonInclude(Include.NON_NULL)
32+
public class BedrockCohereCommandRChatOptions implements ChatOptions {
33+
34+
// @formatter:off
35+
/**
36+
* (optional) When enabled, it will only generate potential search queries without performing
37+
* searches or providing a response.
38+
*/
39+
@JsonProperty("search_queries_only") Boolean searchQueriesOnly;
40+
/**
41+
* (optional) Overrides the default preamble for search query generation.
42+
*/
43+
@JsonProperty("preamble") String preamble;
44+
/**
45+
* (optional) Specify the maximum number of tokens to use in the generated response.
46+
*/
47+
@JsonProperty("max_tokens") Integer maxTokens;
48+
/**
49+
* (optional) Use a lower value to decrease randomness in the response.
50+
*/
51+
@JsonProperty("temperature") Float temperature;
52+
/**
53+
* Top P. Use a lower value to ignore less probable options. Set to 0 or 1.0 to disable.
54+
*/
55+
@JsonProperty("p") Float topP;
56+
/**
57+
* Top K. Specify the number of token choices the model uses to generate the next token.
58+
*/
59+
@JsonProperty("k") Integer topK;
60+
/**
61+
* (optional) Dictates how the prompt is constructed.
62+
*/
63+
@JsonProperty("prompt_truncation") PromptTruncation promptTruncation;
64+
/**
65+
* (optional) Used to reduce repetitiveness of generated tokens.
66+
*/
67+
@JsonProperty("frequency_penalty") Float frequencyPenalty;
68+
/**
69+
* (optional) Used to reduce repetitiveness of generated tokens.
70+
*/
71+
@JsonProperty("presence_penalty") Float presencePenalty;
72+
/**
73+
* (optional) Specify the best effort to sample tokens deterministically.
74+
*/
75+
@JsonProperty("seed") Integer seed;
76+
/**
77+
* (optional) Specify true to return the full prompt that was sent to the model.
78+
*/
79+
@JsonProperty("return_prompt") Boolean returnPrompt;
80+
/**
81+
* (optional) A list of stop sequences.
82+
*/
83+
@JsonProperty("stop_sequences") List<String> stopSequences;
84+
/**
85+
* (optional) Specify true, to send the user’s message to the model without any preprocessing.
86+
*/
87+
@JsonProperty("raw_prompting") Boolean rawPrompting;
88+
// @formatter:on
89+
90+
public static Builder builder() {
91+
return new Builder();
92+
}
93+
94+
public static class Builder {
95+
96+
private final BedrockCohereCommandRChatOptions options = new BedrockCohereCommandRChatOptions();
97+
98+
public Builder withSearchQueriesOnly(Boolean searchQueriesOnly) {
99+
options.setSearchQueriesOnly(searchQueriesOnly);
100+
return this;
101+
}
102+
103+
public Builder withPreamble(String preamble) {
104+
options.setPreamble(preamble);
105+
return this;
106+
}
107+
108+
public Builder withMaxTokens(Integer maxTokens) {
109+
options.setMaxTokens(maxTokens);
110+
return this;
111+
}
112+
113+
public Builder withTemperature(Float temperature) {
114+
options.setTemperature(temperature);
115+
return this;
116+
}
117+
118+
public Builder withTopP(Float topP) {
119+
options.setTopP(topP);
120+
return this;
121+
}
122+
123+
public Builder withTopK(Integer topK) {
124+
options.setTopK(topK);
125+
return this;
126+
}
127+
128+
public Builder withPromptTruncation(PromptTruncation promptTruncation) {
129+
options.setPromptTruncation(promptTruncation);
130+
return this;
131+
}
132+
133+
public Builder withFrequencyPenalty(Float frequencyPenalty) {
134+
options.setFrequencyPenalty(frequencyPenalty);
135+
return this;
136+
}
137+
138+
public Builder withPresencePenalty(Float presencePenalty) {
139+
options.setPresencePenalty(presencePenalty);
140+
return this;
141+
}
142+
143+
public Builder withSeed(Integer seed) {
144+
options.setSeed(seed);
145+
return this;
146+
}
147+
148+
public Builder withReturnPrompt(Boolean returnPrompt) {
149+
options.setReturnPrompt(returnPrompt);
150+
return this;
151+
}
152+
153+
public Builder withStopSequences(List<String> stopSequences) {
154+
options.setStopSequences(stopSequences);
155+
return this;
156+
}
157+
158+
public Builder withRawPrompting(Boolean rawPrompting) {
159+
options.setRawPrompting(rawPrompting);
160+
return this;
161+
}
162+
163+
public BedrockCohereCommandRChatOptions build() {
164+
return this.options;
165+
}
166+
167+
}
168+
169+
public Boolean getSearchQueriesOnly() {
170+
return searchQueriesOnly;
171+
}
172+
173+
public void setSearchQueriesOnly(Boolean searchQueriesOnly) {
174+
this.searchQueriesOnly = searchQueriesOnly;
175+
}
176+
177+
public String getPreamble() {
178+
return preamble;
179+
}
180+
181+
public void setPreamble(String preamble) {
182+
this.preamble = preamble;
183+
}
184+
185+
public Integer getMaxTokens() {
186+
return maxTokens;
187+
}
188+
189+
public void setMaxTokens(Integer maxTokens) {
190+
this.maxTokens = maxTokens;
191+
}
192+
193+
@Override
194+
public Float getTemperature() {
195+
return temperature;
196+
}
197+
198+
public void setTemperature(Float temperature) {
199+
this.temperature = temperature;
200+
}
201+
202+
@Override
203+
public Float getTopP() {
204+
return topP;
205+
}
206+
207+
public void setTopP(Float topP) {
208+
this.topP = topP;
209+
}
210+
211+
@Override
212+
public Integer getTopK() {
213+
return topK;
214+
}
215+
216+
public void setTopK(Integer topK) {
217+
this.topK = topK;
218+
}
219+
220+
public PromptTruncation getPromptTruncation() {
221+
return promptTruncation;
222+
}
223+
224+
public void setPromptTruncation(PromptTruncation promptTruncation) {
225+
this.promptTruncation = promptTruncation;
226+
}
227+
228+
public Float getFrequencyPenalty() {
229+
return frequencyPenalty;
230+
}
231+
232+
public void setFrequencyPenalty(Float frequencyPenalty) {
233+
this.frequencyPenalty = frequencyPenalty;
234+
}
235+
236+
public Float getPresencePenalty() {
237+
return presencePenalty;
238+
}
239+
240+
public void setPresencePenalty(Float presencePenalty) {
241+
this.presencePenalty = presencePenalty;
242+
}
243+
244+
public Integer getSeed() {
245+
return seed;
246+
}
247+
248+
public void setSeed(Integer seed) {
249+
this.seed = seed;
250+
}
251+
252+
public Boolean getReturnPrompt() {
253+
return returnPrompt;
254+
}
255+
256+
public void setReturnPrompt(Boolean returnPrompt) {
257+
this.returnPrompt = returnPrompt;
258+
}
259+
260+
public List<String> getStopSequences() {
261+
return stopSequences;
262+
}
263+
264+
public void setStopSequences(List<String> stopSequences) {
265+
this.stopSequences = stopSequences;
266+
}
267+
268+
public Boolean getRawPrompting() {
269+
return rawPrompting;
270+
}
271+
272+
public void setRawPrompting(Boolean rawPrompting) {
273+
this.rawPrompting = rawPrompting;
274+
}
275+
276+
}

0 commit comments

Comments
 (0)