Skip to content

Commit fb65ed0

Browse files
anders-swansonmarkpollack
authored andcommitted
Add OCI GenAI Cohere Chat integration
Adds Oracle Cloud Infrastructure (OCI) Generative AI's Cohere chat model support to expand Spring AI's cloud provider capabilities. This allows developers to use OCI's managed Cohere models through both dedicated and on-demand serving modes. The integration provides auto-configuration for simple setup while allowing full customization of model parameters through OCICohereChatOptions. Teams can now use OCI's Cohere models alongside other providers in Spring AI applications. This change complements the existing OCI embedding support, offering a complete set of GenAI capabilities for Oracle Cloud users. Signed-off-by: Anders Swanson <[email protected]>
1 parent 1cdec7b commit fb65ed0

File tree

17 files changed

+1212
-72
lines changed

17 files changed

+1212
-72
lines changed

models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,8 @@
2222
import java.util.concurrent.atomic.AtomicInteger;
2323

2424
import com.oracle.bmc.generativeaiinference.GenerativeAiInference;
25-
import com.oracle.bmc.generativeaiinference.model.DedicatedServingMode;
2625
import com.oracle.bmc.generativeaiinference.model.EmbedTextDetails;
2726
import com.oracle.bmc.generativeaiinference.model.EmbedTextResult;
28-
import com.oracle.bmc.generativeaiinference.model.OnDemandServingMode;
2927
import com.oracle.bmc.generativeaiinference.model.ServingMode;
3028
import com.oracle.bmc.generativeaiinference.requests.EmbedTextRequest;
3129
import io.micrometer.observation.ObservationRegistry;
@@ -128,15 +126,6 @@ private EmbeddingResponse embedAllWithContext(List<EmbedTextRequest> embedTextRe
128126
return embeddingResponse;
129127
}
130128

131-
private ServingMode servingMode(OCIEmbeddingOptions embeddingOptions) {
132-
return switch (embeddingOptions.getServingMode()) {
133-
case "dedicated" -> DedicatedServingMode.builder().endpointId(embeddingOptions.getModel()).build();
134-
case "on-demand" -> OnDemandServingMode.builder().modelId(embeddingOptions.getModel()).build();
135-
default -> throw new IllegalArgumentException(
136-
"unknown serving mode for OCI embedding model: " + embeddingOptions.getServingMode());
137-
};
138-
}
139-
140129
private List<EmbedTextRequest> createRequests(List<String> inputs, OCIEmbeddingOptions embeddingOptions) {
141130
int size = inputs.size();
142131
List<EmbedTextRequest> requests = new ArrayList<>();
@@ -148,8 +137,9 @@ private List<EmbedTextRequest> createRequests(List<String> inputs, OCIEmbeddingO
148137
}
149138

150139
private EmbedTextRequest createRequest(List<String> inputs, OCIEmbeddingOptions embeddingOptions) {
140+
ServingMode servingMode = ServingModeHelper.get(this.options.getServingMode(), this.options.getModel());
151141
EmbedTextDetails embedTextDetails = EmbedTextDetails.builder()
152-
.servingMode(servingMode(embeddingOptions))
142+
.servingMode(servingMode)
153143
.compartmentId(embeddingOptions.getCompartment())
154144
.inputs(inputs)
155145
.truncate(Objects.requireNonNullElse(embeddingOptions.getTruncate(), EmbedTextDetails.Truncate.End))
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.oci;
18+
19+
import com.oracle.bmc.generativeaiinference.model.DedicatedServingMode;
20+
import com.oracle.bmc.generativeaiinference.model.OnDemandServingMode;
21+
import com.oracle.bmc.generativeaiinference.model.ServingMode;
22+
23+
/**
24+
* Helper class to load the OCI Gen AI
25+
* {@link com.oracle.bmc.generativeaiinference.model.ServingMode}
26+
*
27+
* @author Anders Swanson
28+
*/
29+
public final class ServingModeHelper {
30+
31+
private ServingModeHelper() {
32+
}
33+
34+
/**
35+
* Retrieves a specific type of ServingMode based on the provided serving mode string.
36+
* @param servingMode The serving mode as a string. Supported options are 'dedicated'
37+
* and 'on-demand'.
38+
* @param model The model identifier to be used with the serving mode.
39+
* @return A ServingMode instance configured according to the provided parameters.
40+
* @throws IllegalArgumentException If the specified serving mode is not supported.
41+
*/
42+
public static ServingMode get(String servingMode, String model) {
43+
return switch (servingMode) {
44+
case "dedicated" -> DedicatedServingMode.builder().endpointId(model).build();
45+
case "on-demand" -> OnDemandServingMode.builder().modelId(model).build();
46+
default -> throw new IllegalArgumentException(String.format(
47+
"Unknown serving mode for OCI Gen AI: %s. Supported options are 'dedicated' and 'on-demand'",
48+
servingMode));
49+
};
50+
}
51+
52+
}
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.oci.cohere;
18+
19+
import java.util.ArrayList;
20+
import java.util.List;
21+
import java.util.Map;
22+
import java.util.Objects;
23+
24+
import com.oracle.bmc.generativeaiinference.GenerativeAiInference;
25+
import com.oracle.bmc.generativeaiinference.model.BaseChatRequest;
26+
import com.oracle.bmc.generativeaiinference.model.BaseChatResponse;
27+
import com.oracle.bmc.generativeaiinference.model.ChatDetails;
28+
import com.oracle.bmc.generativeaiinference.model.CohereChatBotMessage;
29+
import com.oracle.bmc.generativeaiinference.model.CohereChatRequest;
30+
import com.oracle.bmc.generativeaiinference.model.CohereChatResponse;
31+
import com.oracle.bmc.generativeaiinference.model.CohereMessage;
32+
import com.oracle.bmc.generativeaiinference.model.CohereSystemMessage;
33+
import com.oracle.bmc.generativeaiinference.model.CohereToolCall;
34+
import com.oracle.bmc.generativeaiinference.model.CohereToolMessage;
35+
import com.oracle.bmc.generativeaiinference.model.CohereToolResult;
36+
import com.oracle.bmc.generativeaiinference.model.CohereUserMessage;
37+
import com.oracle.bmc.generativeaiinference.model.ServingMode;
38+
import com.oracle.bmc.generativeaiinference.requests.ChatRequest;
39+
import io.micrometer.observation.ObservationRegistry;
40+
41+
import org.springframework.ai.chat.messages.AssistantMessage;
42+
import org.springframework.ai.chat.messages.Message;
43+
import org.springframework.ai.chat.messages.ToolResponseMessage;
44+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
45+
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
46+
import org.springframework.ai.chat.model.ChatModel;
47+
import org.springframework.ai.chat.model.ChatResponse;
48+
import org.springframework.ai.chat.model.Generation;
49+
import org.springframework.ai.chat.observation.ChatModelObservationContext;
50+
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
51+
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
52+
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
53+
import org.springframework.ai.chat.prompt.ChatOptions;
54+
import org.springframework.ai.chat.prompt.Prompt;
55+
import org.springframework.ai.model.ModelOptionsUtils;
56+
import org.springframework.ai.observation.conventions.AiProvider;
57+
import org.springframework.ai.oci.ServingModeHelper;
58+
import org.springframework.util.Assert;
59+
import org.springframework.util.StringUtils;
60+
61+
/**
62+
* {@link ChatModel} implementation that uses the OCI GenAI Chat API.
63+
*
64+
* @author Anders Swanson
65+
* @since 1.0.0
66+
*/
67+
public class OCICohereChatModel implements ChatModel {
68+
69+
private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();
70+
71+
private static final Double DEFAULT_TEMPERATURE = 0.7;
72+
73+
/**
74+
* The {@link GenerativeAiInference} client used to interact with OCI GenAI service.
75+
*/
76+
private final GenerativeAiInference genAi;
77+
78+
/**
79+
* The configuration information for a chat completions request.
80+
*/
81+
private final OCICohereChatOptions defaultOptions;
82+
83+
private final ObservationRegistry observationRegistry;
84+
85+
/**
86+
* Conventions to use for generating observations.
87+
*/
88+
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
89+
90+
public OCICohereChatModel(GenerativeAiInference genAi, OCICohereChatOptions options) {
91+
this(genAi, options, null);
92+
}
93+
94+
public OCICohereChatModel(GenerativeAiInference genAi, OCICohereChatOptions options,
95+
ObservationRegistry observationRegistry) {
96+
Assert.notNull(genAi, "com.oracle.bmc.generativeaiinference.GenerativeAiInference must not be null");
97+
Assert.notNull(options, "OCIChatOptions must not be null");
98+
99+
this.genAi = genAi;
100+
this.defaultOptions = options;
101+
this.observationRegistry = observationRegistry;
102+
}
103+
104+
@Override
105+
public ChatResponse call(Prompt prompt) {
106+
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
107+
.prompt(prompt)
108+
.provider(AiProvider.OCI_GENAI.value())
109+
.requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
110+
.build();
111+
112+
return ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
113+
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
114+
this.observationRegistry)
115+
.observe(() -> {
116+
ChatResponse chatResponse = doChatRequest(prompt);
117+
observationContext.setResponse(chatResponse);
118+
return chatResponse;
119+
});
120+
}
121+
122+
@Override
123+
public ChatOptions getDefaultOptions() {
124+
return OCICohereChatOptions.fromOptions(this.defaultOptions);
125+
}
126+
127+
/**
128+
* Use the provided convention for reporting observation data
129+
* @param observationConvention The provided convention
130+
*/
131+
public void setObservationConvention(ChatModelObservationConvention observationConvention) {
132+
Assert.notNull(observationConvention, "observationConvention cannot be null");
133+
this.observationConvention = observationConvention;
134+
}
135+
136+
private ChatResponse doChatRequest(Prompt prompt) {
137+
OCICohereChatOptions options = mergeOptions(prompt.getOptions(), this.defaultOptions);
138+
validateChatOptions(options);
139+
140+
ChatResponseMetadata metadata = ChatResponseMetadata.builder()
141+
.withModel(options.getModel())
142+
.withKeyValue("compartment", options.getCompartment())
143+
.build();
144+
return new ChatResponse(getGenerations(prompt, options), metadata);
145+
146+
}
147+
148+
private OCICohereChatOptions mergeOptions(ChatOptions chatOptions, OCICohereChatOptions defaultOptions) {
149+
if (chatOptions instanceof OCICohereChatOptions override) {
150+
OCICohereChatOptions dynamicOptions = ModelOptionsUtils.merge(override, defaultOptions,
151+
OCICohereChatOptions.class);
152+
153+
if (dynamicOptions != null) {
154+
return dynamicOptions;
155+
}
156+
}
157+
return defaultOptions;
158+
}
159+
160+
private void validateChatOptions(OCICohereChatOptions options) {
161+
if (!StringUtils.hasText(options.getModel())) {
162+
throw new IllegalArgumentException("Model is not set!");
163+
}
164+
if (!StringUtils.hasText(options.getCompartment())) {
165+
throw new IllegalArgumentException("Compartment is not set!");
166+
}
167+
if (!StringUtils.hasText(options.getServingMode())) {
168+
throw new IllegalArgumentException("ServingMode is not set!");
169+
}
170+
}
171+
172+
private List<Generation> getGenerations(Prompt prompt, OCICohereChatOptions options) {
173+
com.oracle.bmc.generativeaiinference.responses.ChatResponse cr = genAi
174+
.chat(toCohereChatRequest(prompt, options));
175+
return toGenerations(cr, options);
176+
177+
}
178+
179+
private List<Generation> toGenerations(com.oracle.bmc.generativeaiinference.responses.ChatResponse ociChatResponse,
180+
OCICohereChatOptions options) {
181+
BaseChatResponse cr = ociChatResponse.getChatResult().getChatResponse();
182+
if (cr instanceof CohereChatResponse resp) {
183+
List<Generation> generations = new ArrayList<>();
184+
ChatGenerationMetadata metadata = ChatGenerationMetadata.from(resp.getFinishReason().getValue(), null);
185+
AssistantMessage message = new AssistantMessage(resp.getText(), Map.of());
186+
generations.add(new Generation(message, metadata));
187+
return generations;
188+
}
189+
throw new IllegalStateException(String.format("Unexpected chat response type: %s", cr.getClass().getName()));
190+
}
191+
192+
private ChatRequest toCohereChatRequest(Prompt prompt, OCICohereChatOptions options) {
193+
List<Message> messages = prompt.getInstructions();
194+
Message message = messages.get(0);
195+
List<CohereMessage> chatHistory = getCohereMessages(messages);
196+
return newChatRequest(options, message, chatHistory);
197+
}
198+
199+
private List<CohereMessage> getCohereMessages(List<Message> messages) {
200+
List<CohereMessage> chatHistory = new ArrayList<>();
201+
for (int i = 1; i < messages.size(); i++) {
202+
Message message = messages.get(i);
203+
switch (message.getMessageType()) {
204+
case USER -> chatHistory.add(CohereUserMessage.builder().message(message.getContent()).build());
205+
case ASSISTANT -> chatHistory.add(CohereChatBotMessage.builder().message(message.getContent()).build());
206+
case SYSTEM -> chatHistory.add(CohereSystemMessage.builder().message(message.getContent()).build());
207+
case TOOL -> {
208+
if (message instanceof ToolResponseMessage tm) {
209+
chatHistory.add(toToolMessage(tm));
210+
}
211+
}
212+
}
213+
}
214+
return chatHistory;
215+
}
216+
217+
private CohereToolMessage toToolMessage(ToolResponseMessage tm) {
218+
List<CohereToolResult> results = tm.getResponses().stream().map(r -> {
219+
CohereToolCall call = CohereToolCall.builder().name(r.name()).build();
220+
return CohereToolResult.builder().call(call).outputs(List.of(r.responseData())).build();
221+
}).toList();
222+
return CohereToolMessage.builder().toolResults(results).build();
223+
}
224+
225+
private ChatRequest newChatRequest(OCICohereChatOptions options, Message message, List<CohereMessage> chatHistory) {
226+
BaseChatRequest baseChatRequest = CohereChatRequest.builder()
227+
.frequencyPenalty(options.getFrequencyPenalty())
228+
.presencePenalty(options.getPresencePenalty())
229+
.maxTokens(options.getMaxTokens())
230+
.topK(options.getTopK())
231+
.topP(options.getTopP())
232+
.temperature(Objects.requireNonNullElse(options.getTemperature(), DEFAULT_TEMPERATURE))
233+
.preambleOverride(options.getPreambleOverride())
234+
.stopSequences(options.getStopSequences())
235+
.documents(options.getDocuments())
236+
.tools(options.getTools())
237+
.chatHistory(chatHistory)
238+
.message(message.getContent())
239+
.build();
240+
ServingMode servingMode = ServingModeHelper.get(options.getServingMode(), options.getModel());
241+
ChatDetails chatDetails = ChatDetails.builder()
242+
.compartmentId(options.getCompartment())
243+
.servingMode(servingMode)
244+
.chatRequest(baseChatRequest)
245+
.build();
246+
return ChatRequest.builder().body$(chatDetails).build();
247+
}
248+
249+
}

0 commit comments

Comments
 (0)