diff --git a/chatmodel-springai/pom.xml b/chatmodel-springai/pom.xml
index aaabc93..be995b2 100644
--- a/chatmodel-springai/pom.xml
+++ b/chatmodel-springai/pom.xml
@@ -28,6 +28,10 @@
org.springframework.boot
spring-boot-starter-web
+
+ org.springframework.boot
+ spring-boot-starter-validation
+
org.springframework.boot
spring-boot-starter-webflux
diff --git a/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java b/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java
index 369675a..c664a08 100644
--- a/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java
+++ b/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java
@@ -5,6 +5,8 @@
import com.example.ai.model.response.AIStreamChatResponse;
import com.example.ai.model.response.ActorsFilms;
import com.example.ai.service.ChatService;
+import jakarta.validation.Valid;
+import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
@@ -15,6 +17,7 @@
@RestController
@RequestMapping("/api/ai")
+@Validated
public class ChatController {
private final ChatService chatService;
@@ -24,27 +27,27 @@ public class ChatController {
}
@PostMapping("/chat")
- AIChatResponse chat(@RequestBody AIChatRequest aiChatRequest) {
+ AIChatResponse chat(@RequestBody @Valid AIChatRequest aiChatRequest) {
return chatService.chat(aiChatRequest.query());
}
@PostMapping("/chat-with-prompt")
- AIChatResponse chatWithPrompt(@RequestBody AIChatRequest aiChatRequest) {
+ AIChatResponse chatWithPrompt(@RequestBody @Valid AIChatRequest aiChatRequest) {
return chatService.chatWithPrompt(aiChatRequest.query());
}
@PostMapping("/chat-with-system-prompt")
- AIChatResponse chatWithSystemPrompt(@RequestBody AIChatRequest aiChatRequest) {
+ AIChatResponse chatWithSystemPrompt(@RequestBody @Valid AIChatRequest aiChatRequest) {
return chatService.chatWithSystemPrompt(aiChatRequest.query());
}
@PostMapping("/sentiment/analyze")
- AIChatResponse sentimentAnalyzer(@RequestBody AIChatRequest aiChatRequest) {
+ AIChatResponse sentimentAnalyzer(@RequestBody @Valid AIChatRequest aiChatRequest) {
return chatService.analyzeSentiment(aiChatRequest.query());
}
- @PostMapping("/emebedding-client-conversion")
- AIChatResponse chatWithEmbeddingClient(@RequestBody AIChatRequest aiChatRequest) {
+ @PostMapping("/embedding-client-conversion")
+ AIChatResponse chatWithEmbeddingClient(@RequestBody @Valid AIChatRequest aiChatRequest) {
return chatService.getEmbeddings(aiChatRequest.query());
}
@@ -54,12 +57,12 @@ public ActorsFilms generate(@RequestParam(value = "actor", defaultValue = "Jr NT
}
@PostMapping("/rag")
- AIChatResponse chatUsingRag(@RequestBody AIChatRequest aiChatRequest) {
+ AIChatResponse chatUsingRag(@RequestBody @Valid AIChatRequest aiChatRequest) {
return chatService.ragGenerate(aiChatRequest.query());
}
@PostMapping("/chat/stream")
- AIStreamChatResponse streamChat(@RequestBody AIChatRequest aiChatRequest) {
+ AIStreamChatResponse streamChat(@RequestBody @Valid AIChatRequest aiChatRequest) {
Flux streamChat = chatService.streamChat(aiChatRequest.query());
return new AIStreamChatResponse(streamChat);
}
diff --git a/chatmodel-springai/src/main/java/com/example/ai/model/request/AIChatRequest.java b/chatmodel-springai/src/main/java/com/example/ai/model/request/AIChatRequest.java
index b0d7c3c..d4b4b01 100644
--- a/chatmodel-springai/src/main/java/com/example/ai/model/request/AIChatRequest.java
+++ b/chatmodel-springai/src/main/java/com/example/ai/model/request/AIChatRequest.java
@@ -1,3 +1,5 @@
package com.example.ai.model.request;
-public record AIChatRequest(String query) {}
+import jakarta.validation.constraints.NotBlank;
+
+public record AIChatRequest(@NotBlank(message = "Query cant be Blank") String query) {}
diff --git a/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java b/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java
index 9d7b46a..b82ddac 100644
--- a/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java
+++ b/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java
@@ -85,8 +85,9 @@ public ActorsFilms generateAsBean(String actor) {
BeanOutputConverter outputParser = new BeanOutputConverter<>(ActorsFilms.class);
String format = outputParser.getFormat();
- String template = """
- Generate the filmography for the actor {actor}.
+ String template =
+ """
+ Generate the filmography for the Indian actor {actor} as of today.
{format}
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("actor", actor, "format", format));
diff --git a/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java b/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java
index 581d991..e3963e2 100644
--- a/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java
+++ b/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java
@@ -1,7 +1,9 @@
package com.example.ai.controller;
import static io.restassured.RestAssured.given;
+import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.containsStringIgnoringCase;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
@@ -9,10 +11,15 @@
import com.example.ai.model.request.AIChatRequest;
import io.restassured.RestAssured;
import io.restassured.http.ContentType;
+import java.util.Arrays;
+import java.util.stream.Stream;
import org.apache.http.HttpStatus;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.junit.jupiter.params.provider.ValueSource;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.server.LocalServerPort;
@@ -20,6 +27,8 @@
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class ChatControllerTest {
+ private static final int OPENAI_EMBEDDING_DIMENSION = 1536;
+
@LocalServerPort
private int localServerPort;
@@ -31,7 +40,7 @@ void setUp() {
@Test
void testChat() {
given().contentType(ContentType.JSON)
- .body(new AIChatRequest("Hello?"))
+ .body(defaultChatRequest("Hello?"))
.when()
.post("/api/ai/chat")
.then()
@@ -41,21 +50,32 @@ void testChat() {
}
@Test
- void chatWithPrompt() {
+ void shouldReturnBadRequestForMalformedChatRequest() {
given().contentType(ContentType.JSON)
- .body(new AIChatRequest("java"))
+ .body("{}") // Empty or malformed request body
+ .when()
+ .post("/api/ai/chat")
+ .then()
+ .statusCode(HttpStatus.SC_BAD_REQUEST);
+ }
+
+ @ParameterizedTest
+ @MethodSource("chatPrompts")
+ void shouldChatWithMultiplePrompts(String prompt) {
+ given().contentType(ContentType.JSON)
+ .body(defaultChatRequest(prompt))
.when()
.post("/api/ai/chat-with-prompt")
.then()
.statusCode(HttpStatus.SC_OK)
.contentType(ContentType.JSON)
- .body("answer", containsString("Java"));
+ .body("answer", containsStringIgnoringCase(prompt));
}
@Test
void chatWithSystemPrompt() {
given().contentType(ContentType.JSON)
- .body(new AIChatRequest("cricket"))
+ .body(defaultChatRequest("cricket"))
.when()
.post("/api/ai/chat-with-system-prompt")
.then()
@@ -65,9 +85,19 @@ void chatWithSystemPrompt() {
}
@Test
- void sentimentAnalyzer() {
+ void shouldHandleErrorGracefullyForSystemPrompt() {
given().contentType(ContentType.JSON)
- .body(new AIChatRequest("Why did the Python programmer go broke? Because he couldn't C#"))
+ .body(defaultChatRequest(""))
+ .when()
+ .post("/api/ai/chat-with-system-prompt")
+ .then()
+ .statusCode(HttpStatus.SC_BAD_REQUEST);
+ }
+
+ @Test
+ void shouldAnalyzeSentimentAsSarcastic() {
+ given().contentType(ContentType.JSON)
+ .body(defaultChatRequest("Why did the Python programmer go broke? Because he couldn't C#"))
.when()
.post("/api/ai/sentiment/analyze")
.then()
@@ -76,10 +106,60 @@ void sentimentAnalyzer() {
.body("answer", is("SARCASTIC"));
}
+ @ParameterizedTest
+ @ValueSource(strings = {"This is a test sentence.", "Another different sentence.", "A third unique test case."})
+ void shouldGenerateValidEmbeddingsWithinExpectedRange(String input) {
+ String response = given().contentType(ContentType.JSON)
+ .body(defaultChatRequest(input))
+ .when()
+ .post("/api/ai/embedding-client-conversion")
+ .then()
+ .statusCode(HttpStatus.SC_OK)
+ .contentType(ContentType.JSON)
+ .extract()
+ .jsonPath()
+ .get("answer");
+
+ assertThat(response).isNotNull().startsWith("[").endsWith("]");
+
+ double[] doubles = Arrays.stream(response.replaceAll("[\\[\\]]", "").split(","))
+ .mapToDouble(Double::parseDouble)
+ .toArray();
+
+ assertThat(doubles.length)
+ .isEqualTo(OPENAI_EMBEDDING_DIMENSION)
+ .as("Dimensions for openai model is %d", OPENAI_EMBEDDING_DIMENSION);
+
+ assertThat(Arrays.stream(doubles).allMatch(value -> value >= -1.0 && value <= 1.0))
+ .isTrue()
+ .as("All embedding values should be between -1.0 and 1.0");
+ }
+
+ @Test
+ void shouldHandleErrorCasesGracefully() {
+ given().contentType(ContentType.JSON)
+ .body(defaultChatRequest(""))
+ .when()
+ .post("/api/ai/embedding-client-conversion")
+ .then()
+ .statusCode(HttpStatus.SC_BAD_REQUEST);
+ }
+
@Test
- void outputParser() {
- given().param("actor", "Jr NTR")
+ void outputParserWithParam() {
+ given().param("actor", "BalaKrishna")
.when()
+ .get("/api/ai/output")
+ .then()
+ .statusCode(HttpStatus.SC_OK)
+ .contentType(ContentType.JSON)
+ .body("actor", is("BalaKrishna"))
+ .body("movies", hasSize(greaterThanOrEqualTo(10)));
+ }
+
+ @Test
+ void outputParserDefaultParam() {
+ given().when()
.get("/api/ai/output")
.then()
.statusCode(HttpStatus.SC_OK)
@@ -89,9 +169,9 @@ void outputParser() {
}
@Test
- void ragWithSimpleStore() {
+ void testRagWithSimpleStoreProvidesValidResponse() {
given().contentType(ContentType.JSON)
- .body(new AIChatRequest(
+ .body(defaultChatRequest(
"Which is the restaurant with the highest grade that has a cuisine as American ?"))
.when()
.post("/api/ai/rag")
@@ -100,4 +180,12 @@ void ragWithSimpleStore() {
.contentType(ContentType.JSON)
.body("answer", containsString("Regina Caterers"));
}
+
+ static Stream chatPrompts() {
+ return Stream.of("java", "spring boot", "ai");
+ }
+
+ private AIChatRequest defaultChatRequest(String message) {
+ return new AIChatRequest(message);
+ }
}