diff --git a/ollama-examples/pom.xml b/ollama-examples/pom.xml
index 136fc75..1ee5adf 100644
--- a/ollama-examples/pom.xml
+++ b/ollama-examples/pom.xml
@@ -16,6 +16,12 @@
+        <dependency>
+            <groupId>dev.langchain4j</groupId>
+            <artifactId>langchain4j</artifactId>
+            <version>1.0.0-alpha1</version>
+        </dependency>
+
         <dependency>
             <groupId>dev.langchain4j</groupId>
             <artifactId>langchain4j-ollama</artifactId>
@@ -24,8 +30,14 @@
         <dependency>
             <groupId>org.testcontainers</groupId>
-            <artifactId>testcontainers</artifactId>
-            <version>1.19.6</version>
+            <artifactId>ollama</artifactId>
+            <version>1.20.4</version>
+        </dependency>
+
+        <dependency>
+            <groupId>com.fasterxml.jackson.core</groupId>
+            <artifactId>jackson-databind</artifactId>
+            <version>2.18.2</version>
         </dependency>
@@ -40,6 +52,12 @@
             <version>5.10.0</version>
         </dependency>
 
+        <dependency>
+            <groupId>org.assertj</groupId>
+            <artifactId>assertj-core</artifactId>
+            <version>3.27.0</version>
+        </dependency>
+
         <dependency>
             <groupId>org.tinylog</groupId>
             <artifactId>tinylog-impl</artifactId>
diff --git a/ollama-examples/src/main/java/OllamaChatModelTest.java b/ollama-examples/src/main/java/OllamaChatModelTest.java
index e36a077..f993abf 100644
--- a/ollama-examples/src/main/java/OllamaChatModelTest.java
+++ b/ollama-examples/src/main/java/OllamaChatModelTest.java
@@ -1,5 +1,5 @@
-import dev.langchain4j.agent.tool.ToolSpecification;
-import dev.langchain4j.data.message.AiMessage;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
 import dev.langchain4j.data.message.UserMessage;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.request.ChatRequest;
@@ -9,144 +9,143 @@ import dev.langchain4j.model.chat.request.json.JsonSchema;
 import dev.langchain4j.model.chat.response.ChatResponse;
 import dev.langchain4j.model.ollama.OllamaChatModel;
-import dev.langchain4j.model.output.Response;
+import dev.langchain4j.service.AiServices;
 import org.junit.jupiter.api.Test;
-import org.testcontainers.containers.GenericContainer;
-import org.testcontainers.junit.jupiter.Container;
-import org.testcontainers.junit.jupiter.Testcontainers;
+import utils.AbstractOllamaInfrastructure;
 
-import java.util.List;
+import java.util.Map;
 
-import static dev.langchain4j.model.chat.request.ResponseFormat.JSON;
+import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA;
+import static org.assertj.core.api.Assertions.assertThat;
 
-@Testcontainers
-class OllamaChatModelTest {
+class OllamaChatModelTest extends AbstractOllamaInfrastructure {
 
     /**
-     * The first time you run this test, it will download a Docker image with Ollama and a model.
+     * If you have Ollama running locally,
+     * please set the OLLAMA_BASE_URL environment variable (e.g., http://localhost:11434).
+     * If you do not set the OLLAMA_BASE_URL environment variable,
+     * Testcontainers will download and start an Ollama Docker container.
      * It might take a few minutes.
-     * <p>
-     * This test uses modified Ollama Docker images, which already contain models inside them.
-     * All images with pre-packaged models are available here: https://hub.docker.com/repositories/langchain4j
-     * <p>
-     * However, you are not restricted to these images.
-     * You can run any model from https://ollama.ai/library by following these steps:
-     * 1. Run "docker run -d -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama"
-     * 2. Run "docker exec -it ollama ollama run mistral" <- specify the desired model here
      */
 
-    static String MODEL_NAME = "orca-mini"; // try "mistral", "llama2", "codellama", "phi" or "tinyllama"
-
-    @Container
-    static GenericContainer<?> ollama = new GenericContainer<>("langchain4j/ollama-" + MODEL_NAME + ":latest")
-            .withExposedPorts(11434);
-
     @Test
     void simple_example() {
 
-        ChatLanguageModel model = OllamaChatModel.builder()
-                .baseUrl(baseUrl())
+        ChatLanguageModel chatModel = OllamaChatModel.builder()
+                .baseUrl(ollamaBaseUrl(ollama))
                 .modelName(MODEL_NAME)
+                .logRequests(true)
                 .build();
 
-        String answer = model.generate("Provide 3 short bullet points explaining why Java is awesome");
-
+        String answer = chatModel.chat("Provide 3 short bullet points explaining why Java is awesome");
         System.out.println(answer);
+
+        assertThat(answer).isNotBlank();
     }
 
     @Test
-    void json_output_example() {
+    void json_schema_with_AI_Service_example() {
+
+        record Person(String name, int age) {
+        }
+
+        interface PersonExtractor {
 
-        ChatLanguageModel model = OllamaChatModel.builder()
-                .baseUrl(baseUrl())
+            Person extractPersonFrom(String text);
+        }
+
+        ChatLanguageModel chatModel = OllamaChatModel.builder()
+                .baseUrl(ollamaBaseUrl(ollama))
                 .modelName(MODEL_NAME)
-                .responseFormat(JSON)
+                .temperature(0.0)
+                .supportedCapabilities(RESPONSE_FORMAT_JSON_SCHEMA)
+                .logRequests(true)
                 .build();
 
-        String json = model.generate("Give me a JSON with 2 fields: name and age of a John Doe, 42");
+        PersonExtractor personExtractor = AiServices.create(PersonExtractor.class, chatModel);
 
-        System.out.println(json);
+        Person person = personExtractor.extractPersonFrom("John Doe is 42 years old");
+        System.out.println(person);
+
+        assertThat(person).isEqualTo(new Person("John Doe", 42));
     }
 
     @Test
-    void json_schema_builder_example() {
+    void json_schema_with_low_level_chat_api_example() {
 
-        ChatLanguageModel model = OllamaChatModel.builder()
-                .baseUrl(baseUrl())
+        ChatLanguageModel chatModel = OllamaChatModel.builder()
+                .baseUrl(ollamaBaseUrl(ollama))
                 .modelName(MODEL_NAME)
+                .temperature(0.0)
+                .logRequests(true)
+                .build();
+
+        ChatRequest chatRequest = ChatRequest.builder()
+                .messages(UserMessage.from("John Doe is 42 years old"))
                 .responseFormat(ResponseFormat.builder()
                         .type(ResponseFormatType.JSON)
                         .jsonSchema(JsonSchema.builder()
-                                .name("Person")
                                 .rootElement(JsonObjectSchema.builder()
-                                        .addStringProperty("fullName")
+                                        .addStringProperty("name")
                                         .addIntegerProperty("age")
                                         .build())
                                 .build())
                         .build())
                 .build();
 
-        String json = model.generate("Extract data: John Doe, 42");
+        ChatResponse chatResponse = chatModel.chat(chatRequest);
+        System.out.println(chatResponse);
 
-        System.out.println(json);
+        assertThat(toMap(chatResponse.aiMessage().text())).isEqualTo(Map.of("name", "John Doe", "age", 42));
     }
 
     @Test
-    void json_schema_chat_api_example() {
+    void json_schema_with_low_level_model_builder_example() {
 
-        ChatLanguageModel model = OllamaChatModel.builder()
-                .baseUrl(baseUrl())
+        ChatLanguageModel chatModel = OllamaChatModel.builder()
+                .baseUrl(ollamaBaseUrl(ollama))
                 .modelName(MODEL_NAME)
-                .build();
-
-
-        ChatResponse chatResponse = model.chat(ChatRequest.builder()
-                .messages(UserMessage.from("Extract data: John Doe, 42"))
+                .temperature(0.0)
                 .responseFormat(ResponseFormat.builder()
                         .type(ResponseFormatType.JSON)
                         .jsonSchema(JsonSchema.builder()
-                                .name("Person")
                                 .rootElement(JsonObjectSchema.builder()
-                                        .addStringProperty("fullName")
+                                        .addStringProperty("name")
                                         .addIntegerProperty("age")
                                         .build())
                                 .build())
                         .build())
-                .build());
+                .logRequests(true)
+                .build();
 
-        System.out.println(chatResponse.aiMessage().text());
-    }
+        String json = chatModel.chat("Extract: John Doe is 42 years old");
+        System.out.println(json);
 
+        assertThat(toMap(json)).isEqualTo(Map.of("name", "John Doe", "age", 42));
+    }
 
     @Test
-    void ollama_tools_specification_example() {
+    void json_mode_with_low_level_model_builder_example() {
 
-        ChatLanguageModel model = OllamaChatModel.builder()
-                .baseUrl(baseUrl())
+        ChatLanguageModel chatModel = OllamaChatModel.builder()
+                .baseUrl(ollamaBaseUrl(ollama))
                 .modelName(MODEL_NAME)
+                .temperature(0.0)
+                .responseFormat(ResponseFormat.JSON)
+                .logRequests(true)
                 .build();
 
+        String json = chatModel.chat("Give me a JSON object with 2 fields: name and age of a John Doe, 42");
+        System.out.println(json);
 
-        List<ToolSpecification> toolSpecificationList = List.of(
-                ToolSpecification.builder()
-                        .name("get_fav_color")
-                        .description("Gets favorite color of user by ID")
-                        .parameters(JsonObjectSchema.builder()
-                                .addIntegerProperty("user_id")
-                                .required("user_id")
-                                .build())
-                        .build()
-        );
-
-        Response<AiMessage> aiMessageResponse = model.generate(
-                List.of(UserMessage.from("Find the favorite color of user Jim with ID 21")),
-                toolSpecificationList
-        );
-
-        System.out.println(aiMessageResponse.content().toolExecutionRequests());
+        assertThat(toMap(json)).isEqualTo(Map.of("name", "John Doe", "age", 42));
     }
 
-    static String baseUrl() {
-        return String.format("http://%s:%d", ollama.getHost(), ollama.getFirstMappedPort());
+    private static Map<String, Object> toMap(String json) {
+        try {
+            return new ObjectMapper().readValue(json, Map.class);
+        } catch (JsonProcessingException e) {
+            throw new RuntimeException(e);
+        }
     }
 }
diff --git a/ollama-examples/src/main/java/OllamaStreamingChatModelTest.java b/ollama-examples/src/main/java/OllamaStreamingChatModelTest.java
index 56a7f67..56ca7a4 100644
--- a/ollama-examples/src/main/java/OllamaStreamingChatModelTest.java
+++ b/ollama-examples/src/main/java/OllamaStreamingChatModelTest.java
@@ -4,49 +4,34 @@ import dev.langchain4j.model.ollama.OllamaStreamingChatModel;
 import dev.langchain4j.model.output.Response;
 import org.junit.jupiter.api.Test;
-import org.testcontainers.containers.GenericContainer;
-import org.testcontainers.junit.jupiter.Container;
 import org.testcontainers.junit.jupiter.Testcontainers;
+import utils.AbstractOllamaInfrastructure;
 
 import java.util.concurrent.CompletableFuture;
 
 @Testcontainers
-class OllamaStreamingChatModelTest {
+class OllamaStreamingChatModelTest extends AbstractOllamaInfrastructure {
 
     /**
-     * The first time you run this test, it will download a Docker image with Ollama and a model.
+     * If you have Ollama running locally,
+     * please set the OLLAMA_BASE_URL environment variable (e.g., http://localhost:11434).
+     * If you do not set the OLLAMA_BASE_URL environment variable,
+     * Testcontainers will download and start an Ollama Docker container.
      * It might take a few minutes.
-     * <p>
-     * This test uses modified Ollama Docker images, which already contain models inside them.
-     * All images with pre-packaged models are available here: https://hub.docker.com/repositories/langchain4j
-     * <p>
-     * However, you are not restricted to these images.
-     * You can run any model from https://ollama.ai/library by following these steps:
-     * 1. Run "docker run -d -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama"
-     * 2. Run "docker exec -it ollama ollama run llama2" <- specify the desired model here
      */
 
-    static String MODEL_NAME = "orca-mini"; // try "mistral", "llama2", "codellama" or "phi"
-    static String DOCKER_IMAGE_NAME = "langchain4j/ollama-" + MODEL_NAME + ":latest";
-    static Integer PORT = 11434;
-
-    @Container
-    static GenericContainer<?> ollama = new GenericContainer<>(DOCKER_IMAGE_NAME)
-            .withExposedPorts(PORT);
-
     @Test
     void streaming_example() {
 
         StreamingChatLanguageModel model = OllamaStreamingChatModel.builder()
-                .baseUrl(String.format("http://%s:%d", ollama.getHost(), ollama.getMappedPort(PORT)))
+                .baseUrl(ollamaBaseUrl(ollama))
                 .modelName(MODEL_NAME)
-                .temperature(0.0)
                 .build();
 
         String userMessage = "Write a 100-word poem about Java and AI";
 
         CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();
 
-        model.generate(userMessage, new StreamingResponseHandler<AiMessage>() {
+        model.generate(userMessage, new StreamingResponseHandler<>() {
 
             @Override
             public void onNext(String token) {
diff --git a/ollama-examples/src/main/java/utils/AbstractOllamaInfrastructure.java b/ollama-examples/src/main/java/utils/AbstractOllamaInfrastructure.java
new file mode 100644
index 0000000..a5bfc51
--- /dev/null
+++ b/ollama-examples/src/main/java/utils/AbstractOllamaInfrastructure.java
@@ -0,0 +1,31 @@
+package utils;
+
+import static dev.langchain4j.internal.Utils.isNullOrEmpty;
+import static utils.OllamaImage.LLAMA_3_1;
+import static utils.OllamaImage.localOllamaImage;
+
+public class AbstractOllamaInfrastructure {
+
+    public static final String OLLAMA_BASE_URL = System.getenv("OLLAMA_BASE_URL");
+    public static final String MODEL_NAME = LLAMA_3_1;
+
+    public static LangChain4jOllamaContainer ollama;
+
+    static {
+        if (isNullOrEmpty(OLLAMA_BASE_URL)) {
+            String localOllamaImage = localOllamaImage(MODEL_NAME);
+            ollama = new LangChain4jOllamaContainer(OllamaImage.resolve(OllamaImage.OLLAMA_IMAGE, localOllamaImage))
+                    .withModel(MODEL_NAME);
+            ollama.start();
+            ollama.commitToImage(localOllamaImage);
+        }
+    }
+
+    public static String ollamaBaseUrl(LangChain4jOllamaContainer ollama) {
+        if (isNullOrEmpty(OLLAMA_BASE_URL)) {
+            return ollama.getEndpoint();
+        } else {
+            return OLLAMA_BASE_URL;
+        }
+    }
+}
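Any test class can opt into this shared infrastructure by extending it: the static initializer only starts a container when OLLAMA_BASE_URL is unset, and `ollamaBaseUrl(ollama)` picks the right endpoint either way. A hypothetical example test (not part of this diff) showing the intended usage:

```java
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.ollama.OllamaChatModel;
import org.junit.jupiter.api.Test;
import utils.AbstractOllamaInfrastructure;

// Hypothetical test demonstrating how the shared infrastructure is consumed.
class MyOllamaTest extends AbstractOllamaInfrastructure {

    @Test
    void my_example() {
        ChatLanguageModel chatModel = OllamaChatModel.builder()
                .baseUrl(ollamaBaseUrl(ollama)) // env var if set, container endpoint otherwise
                .modelName(MODEL_NAME)          // "llama3.1", via OllamaImage.LLAMA_3_1
                .logRequests(true)
                .build();

        System.out.println(chatModel.chat("Say hello"));
    }
}
```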
diff --git a/ollama-examples/src/main/java/utils/LangChain4jOllamaContainer.java b/ollama-examples/src/main/java/utils/LangChain4jOllamaContainer.java
new file mode 100644
index 0000000..093892c
--- /dev/null
+++ b/ollama-examples/src/main/java/utils/LangChain4jOllamaContainer.java
@@ -0,0 +1,38 @@
+package utils;
+
+import com.github.dockerjava.api.command.InspectContainerResponse;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testcontainers.ollama.OllamaContainer;
+import org.testcontainers.utility.DockerImageName;
+
+import java.io.IOException;
+
+public class LangChain4jOllamaContainer extends OllamaContainer {
+
+    private static final Logger log = LoggerFactory.getLogger(LangChain4jOllamaContainer.class);
+
+    private String model;
+
+    public LangChain4jOllamaContainer(DockerImageName dockerImageName) {
+        super(dockerImageName);
+    }
+
+    public LangChain4jOllamaContainer withModel(String model) {
+        this.model = model;
+        return this;
+    }
+
+    @Override
+    protected void containerIsStarted(InspectContainerResponse containerInfo) {
+        if (this.model != null) {
+            try {
+                log.info("Starting to pull the '{}' model ... this may take several minutes ...", this.model);
+                ExecResult r = execInContainer("ollama", "pull", this.model);
+                log.info("Model pulling completed! {}", r);
+            } catch (IOException | InterruptedException e) {
+                throw new RuntimeException("Error pulling model", e);
+            }
+        }
+    }
+}
diff --git a/ollama-examples/src/main/java/utils/OllamaImage.java b/ollama-examples/src/main/java/utils/OllamaImage.java
new file mode 100644
index 0000000..ba7155c
--- /dev/null
+++ b/ollama-examples/src/main/java/utils/OllamaImage.java
@@ -0,0 +1,29 @@
+package utils;
+
+import com.github.dockerjava.api.DockerClient;
+import com.github.dockerjava.api.model.Image;
+import org.testcontainers.DockerClientFactory;
+import org.testcontainers.utility.DockerImageName;
+
+import java.util.List;
+
+public class OllamaImage {
+
+    public static final String OLLAMA_IMAGE = "ollama/ollama:latest";
+
+    public static String localOllamaImage(String modelName) {
+        return String.format("tc-%s-%s", OllamaImage.OLLAMA_IMAGE, modelName);
+    }
+
+    public static final String LLAMA_3_1 = "llama3.1";
+
+    public static DockerImageName resolve(String baseImage, String localImageName) {
+        DockerImageName dockerImageName = DockerImageName.parse(baseImage);
+        DockerClient dockerClient = DockerClientFactory.instance().client();
+        List<Image> images = dockerClient.listImagesCmd().withReferenceFilter(localImageName).exec();
+        if (images.isEmpty()) {
+            return dockerImageName;
+        }
+        return DockerImageName.parse(localImageName).asCompatibleSubstituteFor(baseImage);
+    }
+}
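Taken together, the three utility classes form a simple image cache: on the first run resolve() finds no committed image and returns ollama/ollama:latest, containerIsStarted() pulls the model, and commitToImage() snapshots the container as tc-ollama/ollama:latest-llama3.1; subsequent runs resolve to that snapshot, which already contains the model. A minimal sketch of the lookup, assuming a running Docker daemon:

```java
import org.testcontainers.utility.DockerImageName;
import utils.OllamaImage;

public class ImageCacheDemo {

    public static void main(String[] args) {
        // First run: no local snapshot exists, so this resolves to "ollama/ollama:latest".
        // Later runs: it resolves to "tc-ollama/ollama:latest-llama3.1", declared a
        // compatible substitute for the base image, with the model already baked in.
        DockerImageName image = OllamaImage.resolve(
                OllamaImage.OLLAMA_IMAGE,
                OllamaImage.localOllamaImage(OllamaImage.LLAMA_3_1));
        System.out.println(image);
    }
}
```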