diff --git a/pom.xml b/pom.xml index de8699e6..f401c0c9 100644 --- a/pom.xml +++ b/pom.xml @@ -52,7 +52,7 @@ 11 [2.0.9,3.0.0) - 0.13.0 + 1.1.0 [1.18.30,2.0.0) [2.15.2,3.0.0) [4.31.1,5.0.0) diff --git a/src/demo/java/io/github/sashirestela/openai/demo/AbstractDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/AbstractDemo.java index 97d3825a..55a01d35 100644 --- a/src/demo/java/io/github/sashirestela/openai/demo/AbstractDemo.java +++ b/src/demo/java/io/github/sashirestela/openai/demo/AbstractDemo.java @@ -1,9 +1,11 @@ package io.github.sashirestela.openai.demo; +import io.github.sashirestela.cleverclient.http.HttpRequestData; +import io.github.sashirestela.openai.SimpleOpenAI; import java.util.ArrayList; import java.util.List; - -import io.github.sashirestela.openai.SimpleOpenAI; +import java.util.function.UnaryOperator; +import lombok.NonNull; public abstract class AbstractDemo { @@ -23,6 +25,16 @@ protected AbstractDemo() { .build(); } + protected AbstractDemo(@NonNull String baseUrl, + @NonNull String apiKey, + @NonNull UnaryOperator requestInterceptor) { + openAI = SimpleOpenAI.builder() + .apiKey(apiKey) + .baseUrl(baseUrl) + .requestInterceptor(requestInterceptor) + .build(); + } + public void addTitleAction(String title, Action action) { titleActions.add(new TitleAction(title, action)); } diff --git a/src/demo/java/io/github/sashirestela/openai/demo/AzureOpenAIChatServiceDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/AzureOpenAIChatServiceDemo.java new file mode 100644 index 00000000..23c384c7 --- /dev/null +++ b/src/demo/java/io/github/sashirestela/openai/demo/AzureOpenAIChatServiceDemo.java @@ -0,0 +1,81 @@ +package io.github.sashirestela.openai.demo; + + +import io.github.sashirestela.cleverclient.support.ContentType; +import io.github.sashirestela.openai.domain.chat.ChatRequest; +import io.github.sashirestela.openai.domain.chat.message.ChatMsgSystem; +import io.github.sashirestela.openai.domain.chat.message.ChatMsgUser; +import java.util.Map; +import java.util.Optional; + +public class AzureOpenAIChatServiceDemo extends AbstractDemo { + private static final String AZURE_OPENAI_API_KEY_HEADER = "api-key"; + private final ChatRequest chatRequest; + + @SuppressWarnings("unchecked") + public AzureOpenAIChatServiceDemo(String baseUrl, String apiKey, String model) { + super(baseUrl, apiKey, request -> { + var url = request.getUrl(); + var contentType = request.getContentType(); + var body = request.getBody(); + + // add a header to the request + var headers = request.getHeaders(); + headers.put(AZURE_OPENAI_API_KEY_HEADER, apiKey); + request.setHeaders(headers); + + // add a query parameter to url + url += (url.contains("?") ? "&" : "?") + "api-version=2023-05-15"; + // remove '/vN' or '/vN.M' from url + url = url.replaceFirst("(\\/v\\d+\\.*\\d*)", ""); + request.setUrl(url); + + if (contentType != null) { + if (contentType.equals(ContentType.APPLICATION_JSON)) { + var bodyJson = (String) request.getBody(); + // remove a field from body (as Json) + bodyJson = bodyJson.replaceFirst(",?\"model\":\"[^\"]*\",?", ""); + bodyJson = bodyJson.replaceFirst("\"\"", "\",\""); + body = bodyJson; + } + if (contentType.equals(ContentType.MULTIPART_FORMDATA)) { + Map bodyMap = (Map) request.getBody(); + // remove a field from body (as Map) + bodyMap.remove("model"); + body = bodyMap; + } + request.setBody(body); + } + + return request; + }); + + chatRequest = ChatRequest.builder() + .model(model) + .message(new ChatMsgSystem("You are an expert in AI.")) + .message( + new ChatMsgUser("Write a technical article about ChatGPT, no more than 100 words.")) + .temperature(0.0) + .maxTokens(300) + .build(); + } + + public void demoCallChatBlocking() { + var futureChat = openAI.chatCompletions().create(chatRequest); + var chatResponse = futureChat.join(); + System.out.println(chatResponse.firstContent()); + } + + public static void main(String[] args) { + var baseUrl = System.getenv("CUSTOM_OPENAI_BASE_URL"); + var apiKey = System.getenv("CUSTOM_OPENAI_API_KEY"); + // Services like Azure OpenAI don't require a model (endpoints have built-in model) + var model = Optional.ofNullable(System.getenv("CUSTOM_OPENAI_MODEL")) + .orElse("N/A"); + var demo = new AzureOpenAIChatServiceDemo(baseUrl, apiKey, model); + + demo.addTitleAction("Call Completion (Blocking Approach)", demo::demoCallChatBlocking); + + demo.run(); + } +} diff --git a/src/main/java/io/github/sashirestela/openai/OpenAI.java b/src/main/java/io/github/sashirestela/openai/OpenAI.java index 3408c75c..7c27b3d6 100644 --- a/src/main/java/io/github/sashirestela/openai/OpenAI.java +++ b/src/main/java/io/github/sashirestela/openai/OpenAI.java @@ -218,7 +218,7 @@ default CompletableFuture> createStream(@Body ChatRequest c /** * Given a prompt, the model will return one or more predicted completions. It - * is recommend most users to use the Chat Completion. + * is recommended for most users to use the Chat Completion. * * @see OpenAI diff --git a/src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java b/src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java index 928f1f74..e7f74c77 100644 --- a/src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java +++ b/src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java @@ -1,10 +1,11 @@ package io.github.sashirestela.openai; +import io.github.sashirestela.cleverclient.CleverClient; +import io.github.sashirestela.cleverclient.http.HttpRequestData; import java.net.http.HttpClient; -import java.util.ArrayList; +import java.util.HashMap; import java.util.Optional; - -import io.github.sashirestela.cleverclient.CleverClient; +import java.util.function.UnaryOperator; import lombok.AccessLevel; import lombok.Builder; import lombok.Getter; @@ -30,6 +31,7 @@ public class SimpleOpenAI { private final String baseUrl; @Deprecated private final String urlBase = null; + private HttpClient httpClient; private CleverClient cleverClient; @@ -86,7 +88,8 @@ public SimpleOpenAI( String organizationId, String baseUrl, String urlBase, - HttpClient httpClient) { + HttpClient httpClient, + UnaryOperator requestInterceptor) { this.apiKey = apiKey; this.organizationId = organizationId; this.baseUrl = Optional.ofNullable(baseUrl) @@ -94,18 +97,17 @@ public SimpleOpenAI( this.httpClient = Optional.ofNullable(httpClient).orElse(HttpClient.newHttpClient()); - var headers = new ArrayList(); - headers.add(AUTHORIZATION_HEADER); - headers.add(BEARER_AUTHORIZATION + apiKey); + var headers = new HashMap(); + headers.put(AUTHORIZATION_HEADER, BEARER_AUTHORIZATION + apiKey); if (organizationId != null) { - headers.add(ORGANIZATION_HEADER); - headers.add(organizationId); + headers.put(ORGANIZATION_HEADER, organizationId); } this.cleverClient = CleverClient.builder() .httpClient(this.httpClient) .baseUrl(this.baseUrl) .headers(headers) .endOfStream(END_OF_STREAM) + .requestInterceptor(requestInterceptor) .build(); } diff --git a/src/main/java/io/github/sashirestela/openai/support/JsonSchemaUtil.java b/src/main/java/io/github/sashirestela/openai/support/JsonSchemaUtil.java index da53db69..e47999f1 100644 --- a/src/main/java/io/github/sashirestela/openai/support/JsonSchemaUtil.java +++ b/src/main/java/io/github/sashirestela/openai/support/JsonSchemaUtil.java @@ -10,11 +10,11 @@ import com.github.victools.jsonschema.generator.SchemaVersion; import com.github.victools.jsonschema.module.jackson.JacksonModule; import com.github.victools.jsonschema.module.jackson.JacksonOption; - -import io.github.sashirestela.cleverclient.util.Constant; import io.github.sashirestela.openai.SimpleUncheckedException; public class JsonSchemaUtil { + + public static final String JSON_EMPTY_CLASS = "{\"type\":\"object\",\"properties\":{}}"; private static ObjectMapper objectMapper = new ObjectMapper(); private JsonSchemaUtil() { @@ -39,7 +39,7 @@ public static JsonNode classToJsonSchema(Class clazz) { } } else { try { - jsonSchema = objectMapper.readTree(Constant.JSON_EMPTY_CLASS); + jsonSchema = objectMapper.readTree(JSON_EMPTY_CLASS); } catch (JsonProcessingException e) { throw new SimpleUncheckedException("Cannot generate the Json Schema for the class {0}.", clazz.getName(), e); diff --git a/src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java b/src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java index 12ef68db..65b4c9f5 100644 --- a/src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java +++ b/src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java @@ -99,7 +99,7 @@ void shouldNotAddOrganizationToHeadersWhenBuilderIsCalledWithoutOrganizationId() var openAI = SimpleOpenAI.builder() .apiKey("apiKey") .build(); - assertFalse(openAI.getCleverClient().getHeaders().contains(openAI.getOrganizationId())); + assertFalse(openAI.getCleverClient().getHeaders().containsValue(openAI.getOrganizationId())); } @Test @@ -108,7 +108,7 @@ void shouldAddOrganizationToHeadersWhenBuilderIsCalledWithOrganizationId() { .apiKey("apiKey") .organizationId("orgId") .build(); - assertTrue(openAI.getCleverClient().getHeaders().contains(openAI.getOrganizationId())); + assertTrue(openAI.getCleverClient().getHeaders().containsValue(openAI.getOrganizationId())); } @Test @@ -165,7 +165,7 @@ void shouldInstanceAudioServiceOnlyOnceWhenItIsCalledSeveralTimes() { when(cleverClient.create(any())) .thenReturn(ReflectUtil.createProxy( OpenAI.Audios.class, - new HttpProcessor(null, null, null))); + HttpProcessor.builder().build())); repeat(NUMBER_CALLINGS, () -> openAI.audios()); verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any()); } @@ -175,7 +175,7 @@ void shouldInstanceChatCompletionServiceOnlyOnceWhenItIsCalledSeveralTimes() { when(cleverClient.create(any())) .thenReturn(ReflectUtil.createProxy( OpenAI.ChatCompletions.class, - new HttpProcessor(null, null, null))); + HttpProcessor.builder().build())); repeat(NUMBER_CALLINGS, () -> openAI.chatCompletions()); verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any()); } @@ -185,7 +185,7 @@ void shouldInstanceCompletionServiceOnlyOnceWhenItIsCalledSeveralTimes() { when(cleverClient.create(any())) .thenReturn(ReflectUtil.createProxy( OpenAI.Completions.class, - new HttpProcessor(null, null, null))); + HttpProcessor.builder().build())); repeat(NUMBER_CALLINGS, () -> openAI.completions()); verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any()); } @@ -195,7 +195,7 @@ void shouldInstanceEmbeddingServiceOnlyOnceWhenItIsCalledSeveralTimes() { when(cleverClient.create(any())) .thenReturn(ReflectUtil.createProxy( OpenAI.Embeddings.class, - new HttpProcessor(null, null, null))); + HttpProcessor.builder().build())); repeat(NUMBER_CALLINGS, () -> openAI.embeddings()); verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any()); } @@ -205,7 +205,7 @@ void shouldInstanceFilesServiceOnlyOnceWhenItIsCalledSeveralTimes() { when(cleverClient.create(any())) .thenReturn(ReflectUtil.createProxy( OpenAI.Files.class, - new HttpProcessor(null, null, null))); + HttpProcessor.builder().build())); repeat(NUMBER_CALLINGS, () -> openAI.files()); verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any()); } @@ -215,7 +215,7 @@ void shouldInstanceFineTunningServiceOnlyOnceWhenItIsCalledSeveralTimes() { when(cleverClient.create(any())) .thenReturn(ReflectUtil.createProxy( OpenAI.FineTunings.class, - new HttpProcessor(null, null, null))); + HttpProcessor.builder().build())); repeat(NUMBER_CALLINGS, () -> openAI.fineTunings()); verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any()); } @@ -225,7 +225,7 @@ void shouldInstanceImageServiceOnlyOnceWhenItIsCalledSeveralTimes() { when(cleverClient.create(any())) .thenReturn(ReflectUtil.createProxy( OpenAI.Images.class, - new HttpProcessor(null, null, null))); + HttpProcessor.builder().build())); repeat(NUMBER_CALLINGS, () -> openAI.images()); verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any()); } @@ -235,7 +235,7 @@ void shouldInstanceModelsServiceOnlyOnceWhenItIsCalledSeveralTimes() { when(cleverClient.create(any())) .thenReturn(ReflectUtil.createProxy( OpenAI.Models.class, - new HttpProcessor(null, null, null))); + HttpProcessor.builder().build())); repeat(NUMBER_CALLINGS, () -> openAI.models()); verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any()); } @@ -245,7 +245,7 @@ void shouldInstanceModerationServiceOnlyOnceWhenItIsCalledSeveralTimes() { when(cleverClient.create(any())) .thenReturn(ReflectUtil.createProxy( OpenAI.Moderations.class, - new HttpProcessor(null, null, null))); + HttpProcessor.builder().build())); repeat(NUMBER_CALLINGS, () -> openAI.moderations()); verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any()); } diff --git a/src/test/java/io/github/sashirestela/openai/support/JsonSchemaUtilTest.java b/src/test/java/io/github/sashirestela/openai/support/JsonSchemaUtilTest.java index 529df928..543bb507 100644 --- a/src/test/java/io/github/sashirestela/openai/support/JsonSchemaUtilTest.java +++ b/src/test/java/io/github/sashirestela/openai/support/JsonSchemaUtilTest.java @@ -1,15 +1,13 @@ package io.github.sashirestela.openai.support; +import static io.github.sashirestela.openai.support.JsonSchemaUtil.JSON_EMPTY_CLASS; import static org.junit.jupiter.api.Assertions.assertEquals; -import org.junit.jupiter.api.Test; - import com.fasterxml.jackson.annotation.JsonProperty; - -import io.github.sashirestela.cleverclient.util.Constant; import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NoArgsConstructor; +import org.junit.jupiter.api.Test; class JsonSchemaUtilTest { @@ -24,7 +22,7 @@ void shouldGenerateFullJsonSchemaWhenClassHasSomeFields() { @Test void shouldGenerateEmptyJsonSchemaWhenClassHasNoFields() { var actualJsonSchema = JsonSchemaUtil.classToJsonSchema(EmptyClass.class).toString(); - var expectedJsonSchema = Constant.JSON_EMPTY_CLASS; + var expectedJsonSchema = JSON_EMPTY_CLASS; assertEquals(expectedJsonSchema, actualJsonSchema); }