Skip to content

Commit 1562840

Browse files
committed
first cut of aot improvements
Signed-off-by: Josh Long <[email protected]>
1 parent 4be1002 commit 1562840

File tree

4 files changed

+39
-29
lines changed

4 files changed

+39
-29
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java

+10-1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList;
4343
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
4444
import org.springframework.ai.retry.RetryUtils;
45+
import org.springframework.aot.hint.RuntimeHints;
46+
import org.springframework.aot.hint.RuntimeHintsRegistrar;
4547
import org.springframework.retry.support.RetryTemplate;
4648
import org.springframework.util.Assert;
4749

@@ -50,8 +52,10 @@
5052
*
5153
* @author Christian Tzolov
5254
* @author Thomas Vitale
55+
* @author Josh Long
56+
*
5357
*/
54-
public class OpenAiEmbeddingModel extends AbstractEmbeddingModel {
58+
public class OpenAiEmbeddingModel extends AbstractEmbeddingModel implements RuntimeHintsRegistrar {
5559

5660
private static final Logger logger = LoggerFactory.getLogger(OpenAiEmbeddingModel.class);
5761

@@ -229,4 +233,9 @@ public void setObservationConvention(EmbeddingModelObservationConvention observa
229233
this.observationConvention = observationConvention;
230234
}
231235

236+
@Override
237+
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
238+
239+
}
240+
232241
}

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/aot/OpenAiRuntimeHints.java

+3-18
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,10 @@
1616

1717
package org.springframework.ai.openai.aot;
1818

19-
import java.util.Set;
20-
21-
import org.springframework.ai.openai.api.OpenAiApi;
22-
import org.springframework.ai.openai.api.OpenAiAudioApi;
23-
import org.springframework.ai.openai.api.OpenAiImageApi;
19+
import org.springframework.ai.openai.OpenAiChatOptions;
2420
import org.springframework.aot.hint.MemberCategory;
2521
import org.springframework.aot.hint.RuntimeHints;
2622
import org.springframework.aot.hint.RuntimeHintsRegistrar;
27-
import org.springframework.aot.hint.TypeReference;
2823
import org.springframework.lang.NonNull;
2924
import org.springframework.lang.Nullable;
3025

@@ -40,23 +35,13 @@
4035
*/
4136
public class OpenAiRuntimeHints implements RuntimeHintsRegistrar {
4237

43-
private static Set<TypeReference> eval(Set<TypeReference> referenceSet) {
44-
referenceSet.forEach(tr -> System.out.println(tr.toString()));
45-
return referenceSet;
46-
}
47-
4838
@Override
4939
public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) {
5040
var mcs = MemberCategory.values();
51-
for (var tr : eval(findJsonAnnotatedClassesInPackage(OpenAiApi.class))) {
52-
hints.reflection().registerType(tr, mcs);
53-
}
54-
for (var tr : eval(findJsonAnnotatedClassesInPackage(OpenAiAudioApi.class))) {
55-
hints.reflection().registerType(tr, mcs);
56-
}
57-
for (var tr : eval(findJsonAnnotatedClassesInPackage(OpenAiImageApi.class))) {
41+
for (var tr : (findJsonAnnotatedClassesInPackage(OpenAiChatOptions.class))) {
5842
hints.reflection().registerType(tr, mcs);
5943
}
44+
6045
}
6146

6247
}

spring-ai-model/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java

+26-10
Original file line numberDiff line numberDiff line change
@@ -16,28 +16,41 @@
1616

1717
package org.springframework.ai.embedding;
1818

19+
import org.springframework.aot.hint.RuntimeHints;
20+
import org.springframework.aot.hint.RuntimeHintsRegistrar;
21+
import org.springframework.context.annotation.ImportRuntimeHints;
22+
import org.springframework.core.io.ClassPathResource;
23+
import org.springframework.core.io.Resource;
24+
import org.springframework.util.Assert;
25+
1926
import java.io.IOException;
2027
import java.util.Map;
2128
import java.util.Properties;
2229
import java.util.concurrent.atomic.AtomicInteger;
2330
import java.util.stream.Collectors;
2431

25-
import org.springframework.core.io.DefaultResourceLoader;
26-
2732
/**
2833
* Abstract implementation of the {@link EmbeddingModel} interface that provides
2934
* dimensions calculation caching.
3035
*
3136
* @author Christian Tzolov
37+
* @author Josh Long
3238
*/
39+
@ImportRuntimeHints(AbstractEmbeddingModel.Hints.class)
3340
public abstract class AbstractEmbeddingModel implements EmbeddingModel {
3441

42+
private static final Resource EMBEDDING_MODEL_DIMENSIONS_PROPERTIES = new ClassPathResource(
43+
"/embedding/embedding-model-dimensions.properties");
44+
3545
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = loadKnownModelDimensions();
3646

37-
/**
38-
* Default constructor.
39-
*/
40-
public AbstractEmbeddingModel() {
47+
static class Hints implements RuntimeHintsRegistrar {
48+
49+
@Override
50+
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
51+
hints.resources().registerResource(EMBEDDING_MODEL_DIMENSIONS_PROPERTIES);
52+
}
53+
4154
}
4255

4356
/**
@@ -69,10 +82,13 @@ public static int dimensions(EmbeddingModel embeddingModel, String modelName, St
6982

7083
private static Map<String, Integer> loadKnownModelDimensions() {
7184
try {
72-
Properties properties = new Properties();
73-
properties.load(new DefaultResourceLoader()
74-
.getResource("classpath:/embedding/embedding-model-dimensions.properties")
75-
.getInputStream());
85+
var resource = EMBEDDING_MODEL_DIMENSIONS_PROPERTIES;
86+
Assert.notNull(resource, "the embedding dimensions must be non-null");
87+
Assert.state(resource.exists(), "the embedding dimensions properties file must exist");
88+
var properties = new Properties();
89+
try (var in = resource.getInputStream()) {
90+
properties.load(in);
91+
}
7692
return properties.entrySet()
7793
.stream()
7894
.collect(Collectors.toMap(e -> e.getKey().toString(), e -> Integer.parseInt(e.getValue().toString())));

0 commit comments

Comments
 (0)