Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/main/java/ee/carlrobert/llm/client/azure/AzureClient.java
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we create a new overloaded method instead? This is a breaking change and would require other apps to rewrite their code as well.

Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.fasterxml.jackson.core.JsonProcessingException;
import ee.carlrobert.llm.PropertiesLoader;
import ee.carlrobert.llm.client.DeserializationUtil;
import ee.carlrobert.llm.client.openai.completion.ChatCompletionResponseData;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest;
Expand Down Expand Up @@ -46,7 +47,7 @@ private AzureClient(Builder builder, OkHttpClient.Builder httpClientBuilder) {

public EventSource getChatCompletionAsync(
OpenAIChatCompletionRequest request,
CompletionEventListener<String> completionEventListener) {
CompletionEventListener<ChatCompletionResponseData> completionEventListener) {
return EventSources.createFactory(httpClient)
.newEventSource(buildChatRequest(request), getEventSourceListener(completionEventListener));
}
Expand Down Expand Up @@ -139,7 +140,7 @@ private String getImageGenerationPath(OpenAIImageGenerationRequest request) {
}

private OpenAIChatCompletionEventSourceListener getEventSourceListener(
CompletionEventListener<String> listener) {
CompletionEventListener<ChatCompletionResponseData> listener) {
return new OpenAIChatCompletionEventSourceListener(listener) {
@Override
protected ErrorDetails getErrorDetails(String data) throws JsonProcessingException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import ee.carlrobert.llm.client.codegpt.response.AutoApplyResponse;
import ee.carlrobert.llm.client.codegpt.response.CodeGPTException;
import ee.carlrobert.llm.client.codegpt.response.PredictionResponse;
import ee.carlrobert.llm.client.openai.completion.ChatCompletionResponseData;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener;
import ee.carlrobert.llm.client.openai.completion.OpenAITextCompletionEventSourceListener;
Expand Down Expand Up @@ -59,7 +60,7 @@ public CodeGPTUserDetails getUserDetails(String apiKey) {

public EventSource getChatCompletionAsync(
ChatCompletionRequest request,
CompletionEventListener<String> eventListener) {
CompletionEventListener<ChatCompletionResponseData> eventListener) {
return createNewEventSource(
buildChatCompletionRequest(request),
getChatCompletionEventSourceListener(eventListener));
Expand Down Expand Up @@ -201,7 +202,7 @@ private Map<String, String> getRequiredHeaders() {
}

private OpenAIChatCompletionEventSourceListener getChatCompletionEventSourceListener(
CompletionEventListener<String> listener) {
CompletionEventListener<ChatCompletionResponseData> listener) {
return new OpenAIChatCompletionEventSourceListener(listener) {
@Override
protected ErrorDetails getErrorDetails(String data) throws JsonProcessingException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import ee.carlrobert.llm.client.ollama.completion.response.OllamaModelInfoResponse;
import ee.carlrobert.llm.client.ollama.completion.response.OllamaPullResponse;
import ee.carlrobert.llm.client.ollama.completion.response.OllamaTagsResponse;
import ee.carlrobert.llm.client.openai.completion.ChatCompletionResponseData;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest;
Expand Down Expand Up @@ -67,7 +68,7 @@ public EventSource getCompletionAsync(

public EventSource getChatCompletionAsync(
OpenAIChatCompletionRequest request,
CompletionEventListener<String> eventListener) {
CompletionEventListener<ChatCompletionResponseData> eventListener) {
return EventSources.createFactory(httpClient)
.newEventSource(
buildPostRequest(request, "/v1/chat/completions", true),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import ee.carlrobert.llm.client.DeserializationUtil;
import ee.carlrobert.llm.client.codegpt.response.CodeGPTException;
import ee.carlrobert.llm.client.openai.completion.ApiResponseError;
import ee.carlrobert.llm.client.openai.completion.ChatCompletionResponseData;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionEventSourceListener;
import ee.carlrobert.llm.client.openai.completion.OpenAITextCompletionEventSourceListener;
Expand Down Expand Up @@ -69,7 +70,7 @@ public EventSource getCompletionAsync(

public EventSource getChatCompletionAsync(
OpenAIChatCompletionRequest request,
CompletionEventListener<String> eventListener) {
CompletionEventListener<ChatCompletionResponseData> eventListener) {
return getChatCompletionAsync(
request,
new OpenAIChatCompletionEventSourceListener(eventListener));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package ee.carlrobert.llm.client.openai.completion;

import java.util.Objects;
import org.jetbrains.annotations.Nullable;

/**
 * Immutable holder for the two text parts of a chat completion response: the message content and
 * the reasoning content.
 *
 * <p>Both parts are optional and may be {@code null}. Instances compare by value, so two holders
 * with equal content and equal reasoning content are equal and share a hash code.
 */
public class ChatCompletionResponseData {

  @Nullable private final String content;
  @Nullable private final String reasoningContent;

  /**
   * Creates a holder for the given parts.
   *
   * @param content the message content, or {@code null} if there is none
   * @param reasoningContent the reasoning content, or {@code null} if there is none
   */
  public ChatCompletionResponseData(@Nullable String content, @Nullable String reasoningContent) {
    this.content = content;
    this.reasoningContent = reasoningContent;
  }

  /**
   * Returns the message content.
   *
   * @return the content passed to the constructor, possibly {@code null}
   */
  @Nullable
  public String getContent() {
    return content;
  }

  /**
   * Returns the reasoning content.
   *
   * @return the reasoning content passed to the constructor, possibly {@code null}
   */
  @Nullable
  public String getReasoningContent() {
    return reasoningContent;
  }

  /**
   * Compares this holder with another object by value.
   *
   * @param other the object to compare with
   * @return {@code true} if {@code other} is of the same class and both fields are equal
   */
  @Override
  public boolean equals(Object other) {
    if (this == other) {
      return true;
    }
    // Exact class match (rather than instanceof) keeps equals symmetric if the class is extended.
    if (other == null || getClass() != other.getClass()) {
      return false;
    }
    ChatCompletionResponseData that = (ChatCompletionResponseData) other;
    return Objects.equals(content, that.content)
        && Objects.equals(reasoningContent, that.reasoningContent);
  }

  /**
   * Returns a hash code derived from both fields, consistent with {@link #equals(Object)}.
   *
   * @return the hash code of this holder
   */
  @Override
  public int hashCode() {
    return Objects.hash(content, reasoningContent);
  }

  /**
   * Renders both parts as a single string.
   *
   * <p>A non-empty reasoning content is emitted first, wrapped in {@code <think>} tags and
   * followed by a blank line; a non-null content is appended after it. When neither applies the
   * result is the empty string.
   *
   * @return the combined textual form, never {@code null}
   */
  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder();
    if (reasoningContent != null && !reasoningContent.isEmpty()) {
      sb.append("<think>").append(reasoningContent).append("</think>\n\n");
    }
    if (content != null) {
      sb.append(content);
    }
    return sb.toString();
  }
}
Original file line number Diff line number Diff line change
@@ -1,45 +1,39 @@
package ee.carlrobert.llm.client.openai.completion;

import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER;

import com.fasterxml.jackson.core.JsonProcessingException;
import ee.carlrobert.llm.client.openai.completion.response.OpenAIChatCompletionResponse;
import ee.carlrobert.llm.client.openai.completion.response.OpenAIChatCompletionResponseChoice;
import ee.carlrobert.llm.client.openai.completion.response.OpenAIChatCompletionResponseChoiceDelta;
import ee.carlrobert.llm.completion.CompletionEventListener;
import ee.carlrobert.llm.completion.CompletionEventSourceListener;
import java.util.Objects;
import java.util.stream.Stream;

public class OpenAIChatCompletionEventSourceListener extends CompletionEventSourceListener<String> {
import static ee.carlrobert.llm.client.DeserializationUtil.OBJECT_MAPPER;

public class OpenAIChatCompletionEventSourceListener extends CompletionEventSourceListener<ChatCompletionResponseData> {

public OpenAIChatCompletionEventSourceListener(CompletionEventListener<String> listener) {

public OpenAIChatCompletionEventSourceListener(CompletionEventListener<ChatCompletionResponseData> listener) {
super(listener);
}

/**
* Content of the first choice.
* <ul>
* <li>Search all choices which are not null</li>
* <li>Search all deltas which are not null</li>
* <li>Use first content which is not null or blank (whitespace)</li>
* <li>Otherwise use "" (empty string) if no match can be found</li>
* </ul>
* Returns the first valid message content extracted from the OpenAI Chat Completion response.
*
* @return First non-blank content which can be found, otherwise {@code ""}
* @param data the JSON string received from the OpenAI Chat Completion API.
* @return ChatCompletionResponseData object containing content and reasoningContent
* @throws JsonProcessingException if an error occurs during JSON processing.
*/
protected String getMessage(String data) throws JsonProcessingException {
var choices = OBJECT_MAPPER
.readValue(data, OpenAIChatCompletionResponse.class)
.getChoices();
return (choices == null ? Stream.<OpenAIChatCompletionResponseChoice>empty() : choices.stream())
protected ChatCompletionResponseData getMessage(String data) throws JsonProcessingException {
var response = OBJECT_MAPPER.readValue(data, OpenAIChatCompletionResponse.class);
var choices = response.getChoices();

return choices == null ? new ChatCompletionResponseData(null, null) : choices.stream()
.filter(Objects::nonNull)
.map(OpenAIChatCompletionResponseChoice::getDelta)
.filter(Objects::nonNull)
.map(OpenAIChatCompletionResponseChoiceDelta::getContent)
.filter(Objects::nonNull)
.map(delta -> new ChatCompletionResponseData(delta.getContent(), delta.getReasoningContent()))
.findFirst()
.orElse("");
.orElse(new ChatCompletionResponseData(null, null));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,19 @@ public class OpenAIChatCompletionResponseChoiceDelta {

private final String role;
private final String content;
private final String reasoningContent;
private final List<ToolCall> toolCalls;

@JsonCreator(mode = JsonCreator.Mode.PROPERTIES)
public OpenAIChatCompletionResponseChoiceDelta(
@JsonProperty("role") String role,
@JsonProperty("content") String content,
@JsonProperty("reasoning_content") String reasoningContent,
@JsonProperty("tool_calls") List<ToolCall> toolCalls) {
this.role = role;
this.content = content;
this.toolCalls = toolCalls;
this.reasoningContent = reasoningContent;
}

public String getRole() {
Expand All @@ -33,4 +36,8 @@ public String getContent() {
public List<ToolCall> getToolCalls() {
return toolCalls;
}

public String getReasoningContent() {
return reasoningContent;
}
}
13 changes: 7 additions & 6 deletions src/test/java/ee/carlrobert/llm/client/AzureClientTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ee.carlrobert.llm.client.http.ResponseEntity;
import ee.carlrobert.llm.client.http.exchange.BasicHttpExchange;
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange;
import ee.carlrobert.llm.client.openai.completion.ChatCompletionResponseData;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionStandardMessage;
Expand Down Expand Up @@ -75,9 +76,9 @@ void shouldStreamAzureChatCompletion() {
.setPresencePenalty(0.1)
.setFrequencyPenalty(0.1)
.build(),
new CompletionEventListener<String>() {
new CompletionEventListener<ChatCompletionResponseData>() {
@Override
public void onMessage(String message, EventSource eventSource) {
public void onMessage(ChatCompletionResponseData message, EventSource eventSource) {
resultMessageBuilder.append(message);
}

Expand Down Expand Up @@ -245,9 +246,9 @@ void shouldStreamAzureChatCompletionWithCustomURL() {
.setFrequencyPenalty(0.1)
.setOverriddenPath("/v1/deployments/%s/completions?api_version=%s")
.build(),
new CompletionEventListener<String>() {
new CompletionEventListener<ChatCompletionResponseData>() {
@Override
public void onMessage(String message, EventSource eventSource) {
public void onMessage(ChatCompletionResponseData message, EventSource eventSource) {
resultMessageBuilder.append(message);
}

Expand Down Expand Up @@ -282,7 +283,7 @@ void shouldListenForInvalidTokenErrorResponse() {
new OpenAIChatCompletionRequest.Builder(
List.of(new OpenAIChatCompletionStandardMessage("user", "TEST_PROMPT")))
.build(),
new CompletionEventListener<String>() {
new CompletionEventListener<ChatCompletionResponseData>() {
@Override
public void onError(ErrorDetails error, Throwable t) {
errorMessageBuilder.append(error.getMessage());
Expand Down Expand Up @@ -314,7 +315,7 @@ void shouldListenForInvalidResourceErrorResponse() {
new OpenAIChatCompletionRequest.Builder(
List.of(new OpenAIChatCompletionStandardMessage("user", "TEST_PROMPT")))
.build(),
new CompletionEventListener<String>() {
new CompletionEventListener<ChatCompletionResponseData>() {
@Override
public void onError(ErrorDetails error, Throwable t) {
errorMessageBuilder.append(error.getMessage());
Expand Down
5 changes: 3 additions & 2 deletions src/test/java/ee/carlrobert/llm/client/CodeGPTClientTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ee.carlrobert.llm.client.http.ResponseEntity;
import ee.carlrobert.llm.client.http.exchange.BasicHttpExchange;
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange;
import ee.carlrobert.llm.client.openai.completion.ChatCompletionResponseData;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionStandardMessage;
import ee.carlrobert.llm.completion.CompletionEventListener;
import java.util.List;
Expand Down Expand Up @@ -85,9 +86,9 @@ void shouldStreamChatCompletion() {
new ContextFile("TEST_FILE_NAME", "TEST_FILE_CONTENT"))))
.setDocumentationDetails(new DocumentationDetails("TEST_DOC_NAME", "TEST_DOC_URL"))
.build(),
new CompletionEventListener<String>() {
new CompletionEventListener<ChatCompletionResponseData>() {
@Override
public void onMessage(String message, EventSource eventSource) {
public void onMessage(ChatCompletionResponseData message, EventSource eventSource) {
resultMessageBuilder.append(message);
}
});
Expand Down
17 changes: 9 additions & 8 deletions src/test/java/ee/carlrobert/llm/client/OpenAIClientTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import ee.carlrobert.llm.client.http.exchange.BasicHttpExchange;
import ee.carlrobert.llm.client.http.exchange.StreamHttpExchange;
import ee.carlrobert.llm.client.openai.OpenAIClient;
import ee.carlrobert.llm.client.openai.completion.ChatCompletionResponseData;
import ee.carlrobert.llm.client.openai.completion.ErrorDetails;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel;
import ee.carlrobert.llm.client.openai.completion.request.OpenAIChatCompletionRequest;
Expand Down Expand Up @@ -90,10 +91,10 @@ void shouldStreamChatCompletion() {
.setFrequencyPenalty(0.1)
.setResponseFormat(responseFormat)
.build(),
new CompletionEventListener<String>() {
new CompletionEventListener<ChatCompletionResponseData>() {
@Override
public void onMessage(String message, EventSource eventSource) {
resultMessageBuilder.append(message);
public void onMessage(ChatCompletionResponseData message, EventSource eventSource) {
resultMessageBuilder.append(message.getContent());
}
});

Expand Down Expand Up @@ -219,10 +220,10 @@ void shouldStreamChatCompletionWithCustomURL() {
.setFrequencyPenalty(0.1)
.setOverriddenPath("/v1/test/segment")
.build(),
new CompletionEventListener<String>() {
new CompletionEventListener<ChatCompletionResponseData>() {
@Override
public void onMessage(String message, EventSource eventSource) {
resultMessageBuilder.append(message);
public void onMessage(ChatCompletionResponseData message, EventSource eventSource) {
resultMessageBuilder.append(message.getContent());
}
});

Expand Down Expand Up @@ -407,7 +408,7 @@ void shouldHandleInvalidApiKeyError() {
List.of(new OpenAIChatCompletionStandardMessage("user", "TEST_PROMPT")))
.setModel(OpenAIChatCompletionModel.GPT_3_5)
.build(),
new CompletionEventListener<String>() {
new CompletionEventListener<ChatCompletionResponseData>() {
@Override
public void onError(ErrorDetails error, Throwable t) {
assertThat(error.getCode()).isEqualTo("invalid_api_key");
Expand Down Expand Up @@ -436,7 +437,7 @@ void shouldHandleUnknownApiError() {
List.of(new OpenAIChatCompletionStandardMessage("user", "TEST_PROMPT")))
.setModel(OpenAIChatCompletionModel.GPT_3_5)
.build(),
new CompletionEventListener<String>() {
new CompletionEventListener<ChatCompletionResponseData>() {
@Override
public void onError(ErrorDetails error, Throwable t) {
errorMessageBuilder.append(error.getMessage());
Expand Down
Loading