Skip to content

Commit 3c8fc31

Browse files
mariofuscojmartisk
authored andcommitted
Bump to langchain4j 1.0.0-beta3
1 parent 1aa3ef1 commit 3c8fc31

File tree

88 files changed

+426
-696
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

88 files changed

+426
-696
lines changed

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/BeansProcessor.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.MODERATION_MODEL;
88
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.SCORING_MODEL;
99
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.STREAMING_CHAT_MODEL;
10-
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.TOKEN_COUNT_ESTIMATOR;
1110

1211
import java.util.HashSet;
1312
import java.util.List;
@@ -144,8 +143,6 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
144143
requestedModerationModels.add(modelName);
145144
} else if (IMAGE_MODEL.equals(requiredName)) {
146145
requestedImageModels.add(modelName);
147-
} else if (TOKEN_COUNT_ESTIMATOR.equals(requiredName)) {
148-
tokenCountEstimators.add(modelName);
149146
}
150147
}
151148
for (

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import dev.langchain4j.memory.chat.ChatMemoryProvider;
99
import dev.langchain4j.model.chat.ChatLanguageModel;
1010
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
11-
import dev.langchain4j.model.chat.TokenCountEstimator;
1211
import dev.langchain4j.model.embedding.EmbeddingModel;
1312
import dev.langchain4j.model.image.ImageModel;
1413
import dev.langchain4j.model.input.structured.StructuredPrompt;
@@ -17,7 +16,7 @@
1716
import dev.langchain4j.model.output.structured.Description;
1817
import dev.langchain4j.model.scoring.ScoringModel;
1918
import dev.langchain4j.rag.RetrievalAugmentor;
20-
import dev.langchain4j.retriever.Retriever;
19+
import dev.langchain4j.rag.content.retriever.ContentRetriever;
2120
import dev.langchain4j.service.AiServices;
2221
import dev.langchain4j.service.MemoryId;
2322
import dev.langchain4j.service.Moderate;
@@ -46,7 +45,6 @@ public class LangChain4jDotNames {
4645
public static final DotName EMBEDDING_MODEL = DotName.createSimple(EmbeddingModel.class);
4746
public static final DotName MODERATION_MODEL = DotName.createSimple(ModerationModel.class);
4847
public static final DotName IMAGE_MODEL = DotName.createSimple(ImageModel.class);
49-
public static final DotName TOKEN_COUNT_ESTIMATOR = DotName.createSimple(TokenCountEstimator.class);
5048
public static final DotName CHAT_MESSAGE = DotName.createSimple(ChatMessage.class);
5149
public static final DotName TOKEN_STREAM = DotName.createSimple(TokenStream.class);
5250
public static final DotName OUTPUT_GUARDRAILS = DotName.createSimple(OutputGuardrails.class);
@@ -82,7 +80,7 @@ public class LangChain4jDotNames {
8280
static final DotName NO_CHAT_MEMORY_PROVIDER_SUPPLIER = DotName.createSimple(
8381
RegisterAiService.NoChatMemoryProviderSupplier.class);
8482

85-
static final DotName RETRIEVER = DotName.createSimple(Retriever.class);
83+
static final DotName RETRIEVER = DotName.createSimple(ContentRetriever.class);
8684
static final DotName NO_RETRIEVER = DotName.createSimple(
8785
RegisterAiService.NoRetriever.class);
8886

core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ public void handleTools(
244244

245245
builder.parameters(
246246
JsonObjectSchema.builder()
247-
.properties(properties)
247+
.addProperties(properties)
248248
.required(required)
249249
.build());
250250

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/AiServiceMethodParametersAnnotationsTest.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package io.quarkiverse.langchain4j.test;
22

3+
import static io.quarkiverse.langchain4j.runtime.LangChain4jUtil.chatMessageToText;
34
import static io.restassured.RestAssured.get;
45
import static org.hamcrest.Matchers.containsString;
56
import static org.hamcrest.Matchers.equalTo;
@@ -84,7 +85,8 @@ public ChatLanguageModel get() {
8485
return new ChatLanguageModel() {
8586
@Override
8687
public ChatResponse doChat(ChatRequest chatRequest) {
87-
return ChatResponse.builder().aiMessage(new AiMessage(chatRequest.messages().get(0).text())).build();
88+
return ChatResponse.builder().aiMessage(new AiMessage(chatMessageToText(chatRequest.messages().get(0))))
89+
.build();
8890
}
8991
};
9092
}

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/InMemoryEmbeddingStoreTest.java

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
2626
import dev.langchain4j.store.embedding.CosineSimilarity;
2727
import dev.langchain4j.store.embedding.EmbeddingMatch;
28+
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
2829
import dev.langchain4j.store.embedding.EmbeddingStore;
2930
import dev.langchain4j.store.embedding.RelevanceScore;
3031
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
@@ -59,7 +60,8 @@ void should_add_embedding() {
5960
String id = embeddingStore.add(embedding);
6061
assertThat(id).isNotNull();
6162

62-
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
63+
var request = EmbeddingSearchRequest.builder().queryEmbedding(embedding).maxResults(10).build();
64+
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.search(request).matches();
6365
assertThat(relevant).hasSize(1);
6466

6567
EmbeddingMatch<TextSegment> match = relevant.get(0);
@@ -79,7 +81,8 @@ void should_add_embedding_with_id() {
7981

8082
embeddingStore.add(id, embedding);
8183

82-
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
84+
var request = EmbeddingSearchRequest.builder().queryEmbedding(embedding).maxResults(10).build();
85+
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.search(request).matches();
8386
assertThat(relevant).hasSize(1);
8487

8588
EmbeddingMatch<TextSegment> match = relevant.get(0);
@@ -100,7 +103,8 @@ void should_add_embedding_with_segment() {
100103
String id = embeddingStore.add(embedding, segment);
101104
assertThat(id).isNotNull();
102105

103-
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
106+
var request = EmbeddingSearchRequest.builder().queryEmbedding(embedding).maxResults(10).build();
107+
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.search(request).matches();
104108
assertThat(relevant).hasSize(1);
105109

106110
EmbeddingMatch<TextSegment> match = relevant.get(0);
@@ -121,7 +125,8 @@ void should_add_embedding_with_segment_with_metadata() {
121125
String id = embeddingStore.add(embedding, segment);
122126
assertThat(id).isNotNull();
123127

124-
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
128+
var request = EmbeddingSearchRequest.builder().queryEmbedding(embedding).maxResults(10).build();
129+
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.search(request).matches();
125130
assertThat(relevant).hasSize(1);
126131

127132
EmbeddingMatch<TextSegment> match = relevant.get(0);
@@ -142,7 +147,8 @@ void should_add_multiple_embeddings() {
142147
List<String> ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding));
143148
assertThat(ids).hasSize(2);
144149

145-
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
150+
var request = EmbeddingSearchRequest.builder().queryEmbedding(firstEmbedding).maxResults(10).build();
151+
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.search(request).matches();
146152
assertThat(relevant).hasSize(2);
147153

148154
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
@@ -173,7 +179,8 @@ void should_add_multiple_embeddings_with_segments() {
173179
asList(firstSegment, secondSegment));
174180
assertThat(ids).hasSize(2);
175181

176-
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
182+
var request = EmbeddingSearchRequest.builder().queryEmbedding(firstEmbedding).maxResults(10).build();
183+
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.search(request).matches();
177184
assertThat(relevant).hasSize(2);
178185

179186
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
@@ -202,7 +209,8 @@ void should_find_with_min_score() {
202209
Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
203210
embeddingStore.add(secondId, secondEmbedding);
204211

205-
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
212+
var request = EmbeddingSearchRequest.builder().queryEmbedding(firstEmbedding).maxResults(10).build();
213+
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.search(request).matches();
206214
assertThat(relevant).hasSize(2);
207215
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
208216
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
@@ -211,26 +219,23 @@ void should_find_with_min_score() {
211219
assertThat(secondMatch.score()).isBetween(0d, 1d);
212220
assertThat(secondMatch.embeddingId()).isEqualTo(secondId);
213221

214-
List<EmbeddingMatch<TextSegment>> relevant2 = embeddingStore.findRelevant(
215-
firstEmbedding,
216-
10,
217-
secondMatch.score() - 0.01);
222+
var request2 = EmbeddingSearchRequest.builder().queryEmbedding(firstEmbedding).maxResults(10)
223+
.minScore(secondMatch.score() - 0.01).build();
224+
List<EmbeddingMatch<TextSegment>> relevant2 = embeddingStore.search(request2).matches();
218225
assertThat(relevant2).hasSize(2);
219226
assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId);
220227
assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId);
221228

222-
List<EmbeddingMatch<TextSegment>> relevant3 = embeddingStore.findRelevant(
223-
firstEmbedding,
224-
10,
225-
secondMatch.score());
229+
var request3 = EmbeddingSearchRequest.builder().queryEmbedding(firstEmbedding).maxResults(10)
230+
.minScore(secondMatch.score()).build();
231+
List<EmbeddingMatch<TextSegment>> relevant3 = embeddingStore.search(request3).matches();
226232
assertThat(relevant3).hasSize(2);
227233
assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId);
228234
assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId);
229235

230-
List<EmbeddingMatch<TextSegment>> relevant4 = embeddingStore.findRelevant(
231-
firstEmbedding,
232-
10,
233-
secondMatch.score() + 0.01);
236+
var request4 = EmbeddingSearchRequest.builder().queryEmbedding(firstEmbedding).maxResults(10)
237+
.minScore(secondMatch.score() + 0.01).build();
238+
List<EmbeddingMatch<TextSegment>> relevant4 = embeddingStore.search(request4).matches();
234239
assertThat(relevant4).hasSize(1);
235240
assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId);
236241
}
@@ -247,7 +252,8 @@ void should_return_correct_score() {
247252

248253
Embedding referenceEmbedding = embeddingModel.embed("hi").content();
249254

250-
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(referenceEmbedding, 1);
255+
var request = EmbeddingSearchRequest.builder().queryEmbedding(referenceEmbedding).maxResults(1).build();
256+
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.search(request).matches();
251257
assertThat(relevant).hasSize(1);
252258

253259
EmbeddingMatch<TextSegment> match = relevant.get(0);

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/GuardrailWithAugmentationTest.java

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package io.quarkiverse.langchain4j.test.guardrails;
22

3+
import static io.quarkiverse.langchain4j.runtime.LangChain4jUtil.chatMessageToText;
34
import static org.assertj.core.api.Assertions.assertThat;
45

56
import java.util.List;
@@ -28,7 +29,6 @@
2829
import dev.langchain4j.rag.AugmentationResult;
2930
import dev.langchain4j.rag.RetrievalAugmentor;
3031
import dev.langchain4j.rag.content.Content;
31-
import dev.langchain4j.rag.query.Metadata;
3232
import dev.langchain4j.service.MemoryId;
3333
import dev.langchain4j.service.UserMessage;
3434
import io.quarkiverse.langchain4j.RegisterAiService;
@@ -173,7 +173,7 @@ public static class MyChatModel implements ChatLanguageModel {
173173

174174
@Override
175175
public ChatResponse doChat(ChatRequest chatRequest) {
176-
assertThat(chatRequest.messages().get(chatRequest.messages().size() - 1).text()).isEqualTo("augmented");
176+
assertThat(chatMessageToText(chatRequest.messages().get(chatRequest.messages().size() - 1))).isEqualTo("augmented");
177177
return ChatResponse.builder().aiMessage(new AiMessage("Hi!")).build();
178178
}
179179
}
@@ -182,7 +182,7 @@ public static class MyStreamingChatModel implements StreamingChatLanguageModel {
182182

183183
@Override
184184
public void doChat(ChatRequest chatRequest, StreamingChatResponseHandler handler) {
185-
assertThat(chatRequest.messages().get(chatRequest.messages().size() - 1).text()).isEqualTo("augmented");
185+
assertThat(chatMessageToText(chatRequest.messages().get(chatRequest.messages().size() - 1))).isEqualTo("augmented");
186186
handler.onPartialResponse("Streaming hi");
187187
handler.onPartialResponse("!");
188188
handler.onCompleteResponse(ChatResponse.builder().aiMessage(new AiMessage("")).build());
@@ -205,13 +205,6 @@ public static class MyRetrievalAugmentor implements Supplier<RetrievalAugmentor>
205205
@Override
206206
public RetrievalAugmentor get() {
207207
return new RetrievalAugmentor() {
208-
@Override
209-
public dev.langchain4j.data.message.UserMessage augment(dev.langchain4j.data.message.UserMessage userMessage,
210-
Metadata metadata) {
211-
AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
212-
return (dev.langchain4j.data.message.UserMessage) augment(augmentationRequest).chatMessage();
213-
}
214-
215208
@Override
216209
public AugmentationResult augment(AugmentationRequest augmentationRequest) {
217210
List<Content> content = List.of(Content.from("content1"), Content.from("content2"));

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailValidationTest.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package io.quarkiverse.langchain4j.test.guardrails;
22

3+
import static io.quarkiverse.langchain4j.runtime.LangChain4jUtil.chatMessageToText;
34
import static org.assertj.core.api.Assertions.assertThat;
45
import static org.assertj.core.api.Assertions.assertThatThrownBy;
56

@@ -186,8 +187,8 @@ public InputGuardrailResult validate(InputGuardrailParams params) {
186187
assertThat(params.userMessage().singleText()).isEqualTo("foo");
187188
}
188189
if (params.memory().messages().size() == 2) {
189-
assertThat(params.memory().messages().get(0).text()).isEqualTo("foo");
190-
assertThat(params.memory().messages().get(1).text()).isEqualTo("Hi!");
190+
assertThat(chatMessageToText(params.memory().messages().get(0))).isEqualTo("foo");
191+
assertThat(chatMessageToText(params.memory().messages().get(1))).isEqualTo("Hi!");
191192
assertThat(params.userMessage().singleText()).isEqualTo("bar");
192193
}
193194
return success();

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailRepromptingTest.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package io.quarkiverse.langchain4j.test.guardrails;
22

3+
import static io.quarkiverse.langchain4j.runtime.LangChain4jUtil.chatMessageToText;
34
import static org.assertj.core.api.Assertions.assertThat;
45
import static org.assertj.core.api.Assertions.assertThatThrownBy;
56

@@ -143,7 +144,7 @@ public OutputGuardrailResult validate(OutputGuardrailParams params) {
143144
assertThat(((AiMessage) last).text()).isEqualTo("Hello");
144145
assertThat(params.responseFromLLM().text()).isEqualTo("Hello");
145146
assertThat(beforeLast).isInstanceOf(UserMessage.class);
146-
assertThat(beforeLast.text()).isEqualTo("Retry");
147+
assertThat(chatMessageToText(beforeLast)).isEqualTo("Retry");
147148

148149
return reprompt("Retry", "Retry");
149150
}
@@ -181,7 +182,7 @@ public OutputGuardrailResult validate(OutputGuardrailParams params) {
181182
assertThat(last).isInstanceOf(AiMessage.class);
182183
assertThat(((AiMessage) last).text()).isEqualTo("Hello");
183184
assertThat(beforeLast).isInstanceOf(UserMessage.class);
184-
assertThat(beforeLast.text()).isEqualTo("Retry Once");
185+
assertThat(chatMessageToText(beforeLast)).isEqualTo("Retry Once");
185186
return reprompt("Retry", "Retry Twice");
186187
}
187188
return reprompt("Retry", "Retry Again");
@@ -197,10 +198,10 @@ public static class MyChatModel implements ChatLanguageModel {
197198
@Override
198199
public ChatResponse doChat(ChatRequest request) {
199200
ChatMessage last = request.messages().get(request.messages().size() - 1);
200-
if (last instanceof UserMessage && last.text().equals("foo")) {
201+
if (last instanceof UserMessage && chatMessageToText(last).equals("foo")) {
201202
return ChatResponse.builder().aiMessage(new AiMessage("Nope")).build();
202203
}
203-
if (last instanceof UserMessage && last.text().contains("Retry")) {
204+
if (last instanceof UserMessage && chatMessageToText(last).contains("Retry")) {
204205
return ChatResponse.builder().aiMessage(new AiMessage("Hello")).build();
205206
}
206207
throw new IllegalArgumentException("Unexpected message: " + request.messages());

core/deployment/src/test/java/io/quarkiverse/langchain4j/test/response/ResponseAugmenterWithAugmentationResultTest.java

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import dev.langchain4j.rag.AugmentationResult;
1919
import dev.langchain4j.rag.RetrievalAugmentor;
2020
import dev.langchain4j.rag.content.Content;
21-
import dev.langchain4j.rag.query.Metadata;
2221
import dev.langchain4j.service.UserMessage;
2322
import io.quarkiverse.langchain4j.RegisterAiService;
2423
import io.quarkiverse.langchain4j.response.AiResponseAugmenter;
@@ -86,13 +85,6 @@ public static class MyRetrievalAugmentor implements Supplier<RetrievalAugmentor>
8685
@Override
8786
public RetrievalAugmentor get() {
8887
return new RetrievalAugmentor() {
89-
@Override
90-
public dev.langchain4j.data.message.UserMessage augment(dev.langchain4j.data.message.UserMessage userMessage,
91-
Metadata metadata) {
92-
AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
93-
return (dev.langchain4j.data.message.UserMessage) augment(augmentationRequest).chatMessage();
94-
}
95-
9688
@Override
9789
public AugmentationResult augment(AugmentationRequest augmentationRequest) {
9890
List<Content> content = List.of(Content.from("content1"), Content.from("content2"));

core/opentelemetry-tests/src/test/java/io/quarkiverse/langchain4j/opentelemetry/test/ListenersProcessorAbstractSpanChatModelListenerTest.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import dev.langchain4j.data.message.AiMessage;
1919
import dev.langchain4j.data.message.UserMessage;
20+
import dev.langchain4j.model.ModelProvider;
2021
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
2122
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
2223
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
@@ -108,16 +109,16 @@ record MockedContexts(
108109
ChatModelResponseContext responseContext,
109110
ChatModelErrorContext errorContext) {
110111
static MockedContexts create() {
111-
var attributes = new HashMap();
112+
var attributes = new HashMap<>();
112113
var request = ChatRequest.builder().messages(List.of(UserMessage.from("--test-message--")))
113114
.parameters(DefaultChatRequestParameters.builder().modelName("--mock-model-name--").temperature(0.0)
114115
.topP(0.0).build())
115116
.build();
116117
var response = ChatResponse.builder().aiMessage(AiMessage.from("--test-response--")).build();
117-
var requestCtx = new ChatModelRequestContext(request, attributes);
118-
var responseContext = new ChatModelResponseContext(response, request, attributes);
118+
var requestCtx = new ChatModelRequestContext(request, ModelProvider.OTHER, attributes);
119+
var responseContext = new ChatModelResponseContext(response, request, ModelProvider.OTHER, attributes);
119120
var errorCtx = new ChatModelErrorContext(
120-
new RuntimeException("--failed--"), request, attributes);
121+
new RuntimeException("--failed--"), request, ModelProvider.OTHER, attributes);
121122
return new MockedContexts(requestCtx, responseContext, errorCtx);
122123
}
123124
}

0 commit comments

Comments
 (0)