diff --git a/pom.xml b/pom.xml
index de8699e6..f401c0c9 100644
--- a/pom.xml
+++ b/pom.xml
@@ -52,7 +52,7 @@
11
[2.0.9,3.0.0)
- 0.13.0
+ 1.1.0
[1.18.30,2.0.0)
[2.15.2,3.0.0)
[4.31.1,5.0.0)
diff --git a/src/demo/java/io/github/sashirestela/openai/demo/AbstractDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/AbstractDemo.java
index 97d3825a..55a01d35 100644
--- a/src/demo/java/io/github/sashirestela/openai/demo/AbstractDemo.java
+++ b/src/demo/java/io/github/sashirestela/openai/demo/AbstractDemo.java
@@ -1,9 +1,11 @@
package io.github.sashirestela.openai.demo;
+import io.github.sashirestela.cleverclient.http.HttpRequestData;
+import io.github.sashirestela.openai.SimpleOpenAI;
import java.util.ArrayList;
import java.util.List;
-
-import io.github.sashirestela.openai.SimpleOpenAI;
+import java.util.function.UnaryOperator;
+import lombok.NonNull;
public abstract class AbstractDemo {
@@ -23,6 +25,16 @@ protected AbstractDemo() {
.build();
}
+ protected AbstractDemo(@NonNull String baseUrl,
+ @NonNull String apiKey,
+ @NonNull UnaryOperator requestInterceptor) {
+ openAI = SimpleOpenAI.builder()
+ .apiKey(apiKey)
+ .baseUrl(baseUrl)
+ .requestInterceptor(requestInterceptor)
+ .build();
+ }
+
public void addTitleAction(String title, Action action) {
titleActions.add(new TitleAction(title, action));
}
diff --git a/src/demo/java/io/github/sashirestela/openai/demo/AzureOpenAIChatServiceDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/AzureOpenAIChatServiceDemo.java
new file mode 100644
index 00000000..23c384c7
--- /dev/null
+++ b/src/demo/java/io/github/sashirestela/openai/demo/AzureOpenAIChatServiceDemo.java
@@ -0,0 +1,81 @@
+package io.github.sashirestela.openai.demo;
+
+
+import io.github.sashirestela.cleverclient.support.ContentType;
+import io.github.sashirestela.openai.domain.chat.ChatRequest;
+import io.github.sashirestela.openai.domain.chat.message.ChatMsgSystem;
+import io.github.sashirestela.openai.domain.chat.message.ChatMsgUser;
+import java.util.Map;
+import java.util.Optional;
+
+public class AzureOpenAIChatServiceDemo extends AbstractDemo {
+ private static final String AZURE_OPENAI_API_KEY_HEADER = "api-key";
+ private final ChatRequest chatRequest;
+
+ @SuppressWarnings("unchecked")
+ public AzureOpenAIChatServiceDemo(String baseUrl, String apiKey, String model) {
+ super(baseUrl, apiKey, request -> {
+ var url = request.getUrl();
+ var contentType = request.getContentType();
+ var body = request.getBody();
+
+ // add a header to the request
+ var headers = request.getHeaders();
+ headers.put(AZURE_OPENAI_API_KEY_HEADER, apiKey);
+ request.setHeaders(headers);
+
+ // add a query parameter to url
+ url += (url.contains("?") ? "&" : "?") + "api-version=2023-05-15";
+ // remove '/vN' or '/vN.M' from url
+ url = url.replaceFirst("(\\/v\\d+\\.*\\d*)", "");
+ request.setUrl(url);
+
+ if (contentType != null) {
+ if (contentType.equals(ContentType.APPLICATION_JSON)) {
+ var bodyJson = (String) request.getBody();
+ // remove a field from body (as Json)
+ bodyJson = bodyJson.replaceFirst(",?\"model\":\"[^\"]*\",?", "");
+ bodyJson = bodyJson.replaceFirst("\"\"", "\",\"");
+ body = bodyJson;
+ }
+ if (contentType.equals(ContentType.MULTIPART_FORMDATA)) {
+ Map bodyMap = (Map) request.getBody();
+ // remove a field from body (as Map)
+ bodyMap.remove("model");
+ body = bodyMap;
+ }
+ request.setBody(body);
+ }
+
+ return request;
+ });
+
+ chatRequest = ChatRequest.builder()
+ .model(model)
+ .message(new ChatMsgSystem("You are an expert in AI."))
+ .message(
+ new ChatMsgUser("Write a technical article about ChatGPT, no more than 100 words."))
+ .temperature(0.0)
+ .maxTokens(300)
+ .build();
+ }
+
+ public void demoCallChatBlocking() {
+ var futureChat = openAI.chatCompletions().create(chatRequest);
+ var chatResponse = futureChat.join();
+ System.out.println(chatResponse.firstContent());
+ }
+
+ public static void main(String[] args) {
+ var baseUrl = System.getenv("CUSTOM_OPENAI_BASE_URL");
+ var apiKey = System.getenv("CUSTOM_OPENAI_API_KEY");
+ // Services like Azure OpenAI don't require a model (endpoints have built-in model)
+ var model = Optional.ofNullable(System.getenv("CUSTOM_OPENAI_MODEL"))
+ .orElse("N/A");
+ var demo = new AzureOpenAIChatServiceDemo(baseUrl, apiKey, model);
+
+ demo.addTitleAction("Call Completion (Blocking Approach)", demo::demoCallChatBlocking);
+
+ demo.run();
+ }
+}
diff --git a/src/main/java/io/github/sashirestela/openai/OpenAI.java b/src/main/java/io/github/sashirestela/openai/OpenAI.java
index 3408c75c..7c27b3d6 100644
--- a/src/main/java/io/github/sashirestela/openai/OpenAI.java
+++ b/src/main/java/io/github/sashirestela/openai/OpenAI.java
@@ -218,7 +218,7 @@ default CompletableFuture> createStream(@Body ChatRequest c
/**
* Given a prompt, the model will return one or more predicted completions. It
- * is recommend most users to use the Chat Completion.
+ * is recommended for most users to use the Chat Completion.
*
* @see OpenAI
diff --git a/src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java b/src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java
index 928f1f74..e7f74c77 100644
--- a/src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java
+++ b/src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java
@@ -1,10 +1,11 @@
package io.github.sashirestela.openai;
+import io.github.sashirestela.cleverclient.CleverClient;
+import io.github.sashirestela.cleverclient.http.HttpRequestData;
import java.net.http.HttpClient;
-import java.util.ArrayList;
+import java.util.HashMap;
import java.util.Optional;
-
-import io.github.sashirestela.cleverclient.CleverClient;
+import java.util.function.UnaryOperator;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
@@ -30,6 +31,7 @@ public class SimpleOpenAI {
private final String baseUrl;
@Deprecated
private final String urlBase = null;
+
private HttpClient httpClient;
private CleverClient cleverClient;
@@ -86,7 +88,8 @@ public SimpleOpenAI(
String organizationId,
String baseUrl,
String urlBase,
- HttpClient httpClient) {
+ HttpClient httpClient,
+ UnaryOperator requestInterceptor) {
this.apiKey = apiKey;
this.organizationId = organizationId;
this.baseUrl = Optional.ofNullable(baseUrl)
@@ -94,18 +97,17 @@ public SimpleOpenAI(
this.httpClient = Optional.ofNullable(httpClient).orElse(HttpClient.newHttpClient());
- var headers = new ArrayList();
- headers.add(AUTHORIZATION_HEADER);
- headers.add(BEARER_AUTHORIZATION + apiKey);
+ var headers = new HashMap();
+ headers.put(AUTHORIZATION_HEADER, BEARER_AUTHORIZATION + apiKey);
if (organizationId != null) {
- headers.add(ORGANIZATION_HEADER);
- headers.add(organizationId);
+ headers.put(ORGANIZATION_HEADER, organizationId);
}
this.cleverClient = CleverClient.builder()
.httpClient(this.httpClient)
.baseUrl(this.baseUrl)
.headers(headers)
.endOfStream(END_OF_STREAM)
+ .requestInterceptor(requestInterceptor)
.build();
}
diff --git a/src/main/java/io/github/sashirestela/openai/support/JsonSchemaUtil.java b/src/main/java/io/github/sashirestela/openai/support/JsonSchemaUtil.java
index da53db69..e47999f1 100644
--- a/src/main/java/io/github/sashirestela/openai/support/JsonSchemaUtil.java
+++ b/src/main/java/io/github/sashirestela/openai/support/JsonSchemaUtil.java
@@ -10,11 +10,11 @@
import com.github.victools.jsonschema.generator.SchemaVersion;
import com.github.victools.jsonschema.module.jackson.JacksonModule;
import com.github.victools.jsonschema.module.jackson.JacksonOption;
-
-import io.github.sashirestela.cleverclient.util.Constant;
import io.github.sashirestela.openai.SimpleUncheckedException;
public class JsonSchemaUtil {
+
+ public static final String JSON_EMPTY_CLASS = "{\"type\":\"object\",\"properties\":{}}";
private static ObjectMapper objectMapper = new ObjectMapper();
private JsonSchemaUtil() {
@@ -39,7 +39,7 @@ public static JsonNode classToJsonSchema(Class> clazz) {
}
} else {
try {
- jsonSchema = objectMapper.readTree(Constant.JSON_EMPTY_CLASS);
+ jsonSchema = objectMapper.readTree(JSON_EMPTY_CLASS);
} catch (JsonProcessingException e) {
throw new SimpleUncheckedException("Cannot generate the Json Schema for the class {0}.",
clazz.getName(), e);
diff --git a/src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java b/src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java
index 12ef68db..65b4c9f5 100644
--- a/src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java
+++ b/src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java
@@ -99,7 +99,7 @@ void shouldNotAddOrganizationToHeadersWhenBuilderIsCalledWithoutOrganizationId()
var openAI = SimpleOpenAI.builder()
.apiKey("apiKey")
.build();
- assertFalse(openAI.getCleverClient().getHeaders().contains(openAI.getOrganizationId()));
+ assertFalse(openAI.getCleverClient().getHeaders().containsValue(openAI.getOrganizationId()));
}
@Test
@@ -108,7 +108,7 @@ void shouldAddOrganizationToHeadersWhenBuilderIsCalledWithOrganizationId() {
.apiKey("apiKey")
.organizationId("orgId")
.build();
- assertTrue(openAI.getCleverClient().getHeaders().contains(openAI.getOrganizationId()));
+ assertTrue(openAI.getCleverClient().getHeaders().containsValue(openAI.getOrganizationId()));
}
@Test
@@ -165,7 +165,7 @@ void shouldInstanceAudioServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Audios.class,
- new HttpProcessor(null, null, null)));
+ HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.audios());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
@@ -175,7 +175,7 @@ void shouldInstanceChatCompletionServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.ChatCompletions.class,
- new HttpProcessor(null, null, null)));
+ HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.chatCompletions());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
@@ -185,7 +185,7 @@ void shouldInstanceCompletionServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Completions.class,
- new HttpProcessor(null, null, null)));
+ HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.completions());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
@@ -195,7 +195,7 @@ void shouldInstanceEmbeddingServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Embeddings.class,
- new HttpProcessor(null, null, null)));
+ HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.embeddings());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
@@ -205,7 +205,7 @@ void shouldInstanceFilesServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Files.class,
- new HttpProcessor(null, null, null)));
+ HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.files());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
@@ -215,7 +215,7 @@ void shouldInstanceFineTunningServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.FineTunings.class,
- new HttpProcessor(null, null, null)));
+ HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.fineTunings());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
@@ -225,7 +225,7 @@ void shouldInstanceImageServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Images.class,
- new HttpProcessor(null, null, null)));
+ HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.images());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
@@ -235,7 +235,7 @@ void shouldInstanceModelsServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Models.class,
- new HttpProcessor(null, null, null)));
+ HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.models());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
@@ -245,7 +245,7 @@ void shouldInstanceModerationServiceOnlyOnceWhenItIsCalledSeveralTimes() {
when(cleverClient.create(any()))
.thenReturn(ReflectUtil.createProxy(
OpenAI.Moderations.class,
- new HttpProcessor(null, null, null)));
+ HttpProcessor.builder().build()));
repeat(NUMBER_CALLINGS, () -> openAI.moderations());
verify(cleverClient, times(NUMBER_INVOCATIONS)).create(any());
}
diff --git a/src/test/java/io/github/sashirestela/openai/support/JsonSchemaUtilTest.java b/src/test/java/io/github/sashirestela/openai/support/JsonSchemaUtilTest.java
index 529df928..543bb507 100644
--- a/src/test/java/io/github/sashirestela/openai/support/JsonSchemaUtilTest.java
+++ b/src/test/java/io/github/sashirestela/openai/support/JsonSchemaUtilTest.java
@@ -1,15 +1,13 @@
package io.github.sashirestela.openai.support;
+import static io.github.sashirestela.openai.support.JsonSchemaUtil.JSON_EMPTY_CLASS;
import static org.junit.jupiter.api.Assertions.assertEquals;
-import org.junit.jupiter.api.Test;
-
import com.fasterxml.jackson.annotation.JsonProperty;
-
-import io.github.sashirestela.cleverclient.util.Constant;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
+import org.junit.jupiter.api.Test;
class JsonSchemaUtilTest {
@@ -24,7 +22,7 @@ void shouldGenerateFullJsonSchemaWhenClassHasSomeFields() {
@Test
void shouldGenerateEmptyJsonSchemaWhenClassHasNoFields() {
var actualJsonSchema = JsonSchemaUtil.classToJsonSchema(EmptyClass.class).toString();
- var expectedJsonSchema = Constant.JSON_EMPTY_CLASS;
+ var expectedJsonSchema = JSON_EMPTY_CLASS;
assertEquals(expectedJsonSchema, actualJsonSchema);
}