Skip to content

AOT refinements for M7 #2664

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList;
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.aot.hint.RuntimeHints;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will clean this up when merging.

import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;

Expand All @@ -50,6 +52,8 @@
*
* @author Christian Tzolov
* @author Thomas Vitale
* @author Josh Long
*
*/
public class OpenAiEmbeddingModel extends AbstractEmbeddingModel {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,10 @@

package org.springframework.ai.openai.aot;

import java.util.Set;

import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiAudioApi;
import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.aot.hint.TypeReference;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;

Expand All @@ -40,23 +35,13 @@
*/
public class OpenAiRuntimeHints implements RuntimeHintsRegistrar {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joshlong what is the reasoning behind removing the OpenAI API types and just adding the OpenAIChatOptions here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can not hurt to add them. please do.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, I will add them back and merge.


private static Set<TypeReference> eval(Set<TypeReference> referenceSet) {
referenceSet.forEach(tr -> System.out.println(tr.toString()));
return referenceSet;
}

@Override
public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : eval(findJsonAnnotatedClassesInPackage(OpenAiApi.class))) {
hints.reflection().registerType(tr, mcs);
}
for (var tr : eval(findJsonAnnotatedClassesInPackage(OpenAiAudioApi.class))) {
hints.reflection().registerType(tr, mcs);
}
for (var tr : eval(findJsonAnnotatedClassesInPackage(OpenAiImageApi.class))) {
for (var tr : (findJsonAnnotatedClassesInPackage(OpenAiChatOptions.class))) {
hints.reflection().registerType(tr, mcs);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,41 @@

package org.springframework.ai.embedding;

import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.context.annotation.ImportRuntimeHints;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;

import java.io.IOException;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import org.springframework.core.io.DefaultResourceLoader;

/**
* Abstract implementation of the {@link EmbeddingModel} interface that provides
* dimensions calculation caching.
*
* @author Christian Tzolov
* @author Josh Long
*/
@ImportRuntimeHints(AbstractEmbeddingModel.Hints.class)
public abstract class AbstractEmbeddingModel implements EmbeddingModel {

private static final Resource EMBEDDING_MODEL_DIMENSIONS_PROPERTIES = new ClassPathResource(
"/embedding/embedding-model-dimensions.properties");

private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = loadKnownModelDimensions();

/**
* Default constructor.
*/
public AbstractEmbeddingModel() {
static class Hints implements RuntimeHintsRegistrar {

@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
hints.resources().registerResource(EMBEDDING_MODEL_DIMENSIONS_PROPERTIES);
}

}

/**
Expand Down Expand Up @@ -69,10 +82,13 @@ public static int dimensions(EmbeddingModel embeddingModel, String modelName, St

private static Map<String, Integer> loadKnownModelDimensions() {
try {
Properties properties = new Properties();
properties.load(new DefaultResourceLoader()
.getResource("classpath:/embedding/embedding-model-dimensions.properties")
.getInputStream());
var resource = EMBEDDING_MODEL_DIMENSIONS_PROPERTIES;
Assert.notNull(resource, "the embedding dimensions must be non-null");
Assert.state(resource.exists(), "the embedding dimensions properties file must exist");
var properties = new Properties();
try (var in = resource.getInputStream()) {
properties.load(in);
}
return properties.entrySet()
.stream()
.collect(Collectors.toMap(e -> e.getKey().toString(), e -> Integer.parseInt(e.getValue().toString())));
Expand Down