diff --git a/chatmodel-springai/pom.xml b/chatmodel-springai/pom.xml index aaabc93..be995b2 100644 --- a/chatmodel-springai/pom.xml +++ b/chatmodel-springai/pom.xml @@ -28,6 +28,10 @@ org.springframework.boot spring-boot-starter-web + + org.springframework.boot + spring-boot-starter-validation + org.springframework.boot spring-boot-starter-webflux diff --git a/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java b/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java index 369675a..c664a08 100644 --- a/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java +++ b/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java @@ -5,6 +5,8 @@ import com.example.ai.model.response.AIStreamChatResponse; import com.example.ai.model.response.ActorsFilms; import com.example.ai.service.ChatService; +import jakarta.validation.Valid; +import org.springframework.validation.annotation.Validated; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; @@ -15,6 +17,7 @@ @RestController @RequestMapping("/api/ai") +@Validated public class ChatController { private final ChatService chatService; @@ -24,27 +27,27 @@ public class ChatController { } @PostMapping("/chat") - AIChatResponse chat(@RequestBody AIChatRequest aiChatRequest) { + AIChatResponse chat(@RequestBody @Valid AIChatRequest aiChatRequest) { return chatService.chat(aiChatRequest.query()); } @PostMapping("/chat-with-prompt") - AIChatResponse chatWithPrompt(@RequestBody AIChatRequest aiChatRequest) { + AIChatResponse chatWithPrompt(@RequestBody @Valid AIChatRequest aiChatRequest) { return chatService.chatWithPrompt(aiChatRequest.query()); } @PostMapping("/chat-with-system-prompt") - AIChatResponse chatWithSystemPrompt(@RequestBody AIChatRequest aiChatRequest) { + AIChatResponse chatWithSystemPrompt(@RequestBody @Valid AIChatRequest aiChatRequest) { return chatService.chatWithSystemPrompt(aiChatRequest.query()); } @PostMapping("/sentiment/analyze") - AIChatResponse sentimentAnalyzer(@RequestBody AIChatRequest aiChatRequest) { + AIChatResponse sentimentAnalyzer(@RequestBody @Valid AIChatRequest aiChatRequest) { return chatService.analyzeSentiment(aiChatRequest.query()); } - @PostMapping("/emebedding-client-conversion") - AIChatResponse chatWithEmbeddingClient(@RequestBody AIChatRequest aiChatRequest) { + @PostMapping("/embedding-client-conversion") + AIChatResponse chatWithEmbeddingClient(@RequestBody @Valid AIChatRequest aiChatRequest) { return chatService.getEmbeddings(aiChatRequest.query()); } @@ -54,12 +57,12 @@ public ActorsFilms generate(@RequestParam(value = "actor", defaultValue = "Jr NT } @PostMapping("/rag") - AIChatResponse chatUsingRag(@RequestBody AIChatRequest aiChatRequest) { + AIChatResponse chatUsingRag(@RequestBody @Valid AIChatRequest aiChatRequest) { return chatService.ragGenerate(aiChatRequest.query()); } @PostMapping("/chat/stream") - AIStreamChatResponse streamChat(@RequestBody AIChatRequest aiChatRequest) { + AIStreamChatResponse streamChat(@RequestBody @Valid AIChatRequest aiChatRequest) { Flux streamChat = chatService.streamChat(aiChatRequest.query()); return new AIStreamChatResponse(streamChat); } diff --git a/chatmodel-springai/src/main/java/com/example/ai/model/request/AIChatRequest.java b/chatmodel-springai/src/main/java/com/example/ai/model/request/AIChatRequest.java index b0d7c3c..d4b4b01 100644 --- a/chatmodel-springai/src/main/java/com/example/ai/model/request/AIChatRequest.java +++ b/chatmodel-springai/src/main/java/com/example/ai/model/request/AIChatRequest.java @@ -1,3 +1,5 @@ package com.example.ai.model.request; -public record AIChatRequest(String query) {} +import jakarta.validation.constraints.NotBlank; + +public record AIChatRequest(@NotBlank(message = "Query cant be Blank") String query) {} diff --git a/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java b/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java index 9d7b46a..b82ddac 100644 --- a/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java +++ b/chatmodel-springai/src/main/java/com/example/ai/service/ChatService.java @@ -85,8 +85,9 @@ public ActorsFilms generateAsBean(String actor) { BeanOutputConverter outputParser = new BeanOutputConverter<>(ActorsFilms.class); String format = outputParser.getFormat(); - String template = """ - Generate the filmography for the actor {actor}. + String template = + """ + Generate the filmography for the Indian actor {actor} as of today. {format} """; PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("actor", actor, "format", format)); diff --git a/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java b/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java index 581d991..e3963e2 100644 --- a/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java +++ b/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java @@ -1,7 +1,9 @@ package com.example.ai.controller; import static io.restassured.RestAssured.given; +import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.containsStringIgnoringCase; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; @@ -9,10 +11,15 @@ import com.example.ai.model.request.AIChatRequest; import io.restassured.RestAssured; import io.restassured.http.ContentType; +import java.util.Arrays; +import java.util.stream.Stream; import org.apache.http.HttpStatus; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.web.server.LocalServerPort; @@ -20,6 +27,8 @@ @TestInstance(TestInstance.Lifecycle.PER_CLASS) class ChatControllerTest { + private static final int OPENAI_EMBEDDING_DIMENSION = 1536; + @LocalServerPort private int localServerPort; @@ -31,7 +40,7 @@ void setUp() { @Test void testChat() { given().contentType(ContentType.JSON) - .body(new AIChatRequest("Hello?")) + .body(defaultChatRequest("Hello?")) .when() .post("/api/ai/chat") .then() @@ -41,21 +50,32 @@ void testChat() { } @Test - void chatWithPrompt() { + void shouldReturnBadRequestForMalformedChatRequest() { given().contentType(ContentType.JSON) - .body(new AIChatRequest("java")) + .body("{}") // Empty or malformed request body + .when() + .post("/api/ai/chat") + .then() + .statusCode(HttpStatus.SC_BAD_REQUEST); + } + + @ParameterizedTest + @MethodSource("chatPrompts") + void shouldChatWithMultiplePrompts(String prompt) { + given().contentType(ContentType.JSON) + .body(defaultChatRequest(prompt)) .when() .post("/api/ai/chat-with-prompt") .then() .statusCode(HttpStatus.SC_OK) .contentType(ContentType.JSON) - .body("answer", containsString("Java")); + .body("answer", containsStringIgnoringCase(prompt)); } @Test void chatWithSystemPrompt() { given().contentType(ContentType.JSON) - .body(new AIChatRequest("cricket")) + .body(defaultChatRequest("cricket")) .when() .post("/api/ai/chat-with-system-prompt") .then() @@ -65,9 +85,19 @@ void chatWithSystemPrompt() { } @Test - void sentimentAnalyzer() { + void shouldHandleErrorGracefullyForSystemPrompt() { given().contentType(ContentType.JSON) - .body(new AIChatRequest("Why did the Python programmer go broke? Because he couldn't C#")) + .body(defaultChatRequest("")) + .when() + .post("/api/ai/chat-with-system-prompt") + .then() + .statusCode(HttpStatus.SC_BAD_REQUEST); + } + + @Test + void shouldAnalyzeSentimentAsSarcastic() { + given().contentType(ContentType.JSON) + .body(defaultChatRequest("Why did the Python programmer go broke? Because he couldn't C#")) .when() .post("/api/ai/sentiment/analyze") .then() @@ -76,10 +106,60 @@ void sentimentAnalyzer() { .body("answer", is("SARCASTIC")); } + @ParameterizedTest + @ValueSource(strings = {"This is a test sentence.", "Another different sentence.", "A third unique test case."}) + void shouldGenerateValidEmbeddingsWithinExpectedRange(String input) { + String response = given().contentType(ContentType.JSON) + .body(defaultChatRequest(input)) + .when() + .post("/api/ai/embedding-client-conversion") + .then() + .statusCode(HttpStatus.SC_OK) + .contentType(ContentType.JSON) + .extract() + .jsonPath() + .get("answer"); + + assertThat(response).isNotNull().startsWith("[").endsWith("]"); + + double[] doubles = Arrays.stream(response.replaceAll("[\\[\\]]", "").split(",")) + .mapToDouble(Double::parseDouble) + .toArray(); + + assertThat(doubles.length) + .isEqualTo(OPENAI_EMBEDDING_DIMENSION) + .as("Dimensions for openai model is %d", OPENAI_EMBEDDING_DIMENSION); + + assertThat(Arrays.stream(doubles).allMatch(value -> value >= -1.0 && value <= 1.0)) + .isTrue() + .as("All embedding values should be between -1.0 and 1.0"); + } + + @Test + void shouldHandleErrorCasesGracefully() { + given().contentType(ContentType.JSON) + .body(defaultChatRequest("")) + .when() + .post("/api/ai/embedding-client-conversion") + .then() + .statusCode(HttpStatus.SC_BAD_REQUEST); + } + @Test - void outputParser() { - given().param("actor", "Jr NTR") + void outputParserWithParam() { + given().param("actor", "BalaKrishna") .when() + .get("/api/ai/output") + .then() + .statusCode(HttpStatus.SC_OK) + .contentType(ContentType.JSON) + .body("actor", is("BalaKrishna")) + .body("movies", hasSize(greaterThanOrEqualTo(10))); + } + + @Test + void outputParserDefaultParam() { + given().when() .get("/api/ai/output") .then() .statusCode(HttpStatus.SC_OK) @@ -89,9 +169,9 @@ void outputParser() { } @Test - void ragWithSimpleStore() { + void testRagWithSimpleStoreProvidesValidResponse() { given().contentType(ContentType.JSON) - .body(new AIChatRequest( + .body(defaultChatRequest( "Which is the restaurant with the highest grade that has a cuisine as American ?")) .when() .post("/api/ai/rag") @@ -100,4 +180,12 @@ void ragWithSimpleStore() { .contentType(ContentType.JSON) .body("answer", containsString("Regina Caterers")); } + + static Stream chatPrompts() { + return Stream.of("java", "spring boot", "ai"); + } + + private AIChatRequest defaultChatRequest(String message) { + return new AIChatRequest(message); + } }