Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[POC] Optimize parallel context loading in the TestContext framework #34323

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version=7.0.0-SNAPSHOT
version=7.0.0b-SNAPSHOT

org.gradle.caching=true
org.gradle.jvmargs=-Xmx2048m
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
package org.springframework.test.context.cache;

import org.jspecify.annotations.Nullable;

import org.springframework.context.ApplicationContext;
import org.springframework.test.annotation.DirtiesContext.HierarchyMode;
import org.springframework.test.context.MergedContextConfiguration;

import java.util.concurrent.Future;
import java.util.function.Function;

/**
* {@code ContextCache} defines the SPI for caching Spring
* {@link ApplicationContext ApplicationContexts} within the
Expand Down Expand Up @@ -96,7 +98,10 @@ public interface ContextCache {
* if not found in the cache
* @see #remove
*/
@Nullable ApplicationContext get(MergedContextConfiguration key);
@Nullable
ApplicationContext get(MergedContextConfiguration key);

Future<ApplicationContext> computeIfAbsent(MergedContextConfiguration key, Function<MergedContextConfiguration, ApplicationContext> mappingFunction);

/**
* Explicitly add an {@code ApplicationContext} instance to the cache
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.test.context.cache;

import java.util.List;
import java.util.concurrent.ExecutionException;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand Down Expand Up @@ -125,43 +126,58 @@ private DefaultCacheAwareContextLoaderDelegate(ContextCache contextCache, int fa
@Override
public boolean isContextLoaded(MergedContextConfiguration mergedConfig) {
mergedConfig = replaceIfNecessary(mergedConfig);
synchronized (this.contextCache) {
return this.contextCache.contains(mergedConfig);
}
}

@Override
public ApplicationContext loadContext(MergedContextConfiguration mergedConfig) {
mergedConfig = replaceIfNecessary(mergedConfig);
synchronized (this.contextCache) {
ApplicationContext context = this.contextCache.get(mergedConfig);

try {
if (context == null) {
int failureCount = this.contextCache.getFailureCount(mergedConfig);
var contextLoader = this.contextCache.computeIfAbsent(mergedConfig, this::loadContextForReal);

var context = contextLoader.get();

if (context != null && logger.isTraceEnabled()) {
logger.trace("Retrieved ApplicationContext [%s] from cache with key %s".formatted(
System.identityHashCode(context), mergedConfig));
}

return context;
} catch (InterruptedException e) {
throw new RuntimeException(e); //FIXME: Better message
} catch (ExecutionException e) {
throw new RuntimeException(e); //FIXME: Better message
} finally {
this.contextCache.logStatistics();
}
}

private ApplicationContext loadContextForReal(MergedContextConfiguration k) {
int failureCount = this.contextCache.getFailureCount(k);
if (failureCount >= this.failureThreshold) {
throw new IllegalStateException("""
ApplicationContext failure threshold (%d) exceeded: \
skipping repeated attempt to load context for %s"""
.formatted(this.failureThreshold, mergedConfig));
.formatted(this.failureThreshold, k));
}
try {
if (mergedConfig instanceof AotMergedContextConfiguration aotMergedConfig) {
context = loadContextInAotMode(aotMergedConfig);
}
else {
context = loadContextInternal(mergedConfig);
ApplicationContext contextToReturn;
if (k instanceof AotMergedContextConfiguration aotMergedConfig) {
contextToReturn = loadContextInAotMode(aotMergedConfig);
} else {
contextToReturn = loadContextInternal(k);
}
if (logger.isTraceEnabled()) {
logger.trace("Storing ApplicationContext [%s] in cache under key %s".formatted(
System.identityHashCode(context), mergedConfig));
System.identityHashCode(contextToReturn), k));
}
this.contextCache.put(mergedConfig, context);
}
catch (Exception ex) {
return contextToReturn;
} catch (Exception ex) {
if (logger.isTraceEnabled()) {
logger.trace("Incrementing ApplicationContext failure count for " + mergedConfig);
logger.trace("Incrementing ApplicationContext failure count for " + k);
}
this.contextCache.incrementFailureCount(mergedConfig);
this.contextCache.incrementFailureCount(k);
Throwable cause = ex;
if (ex instanceof ContextLoadException cle) {
cause = cle.getCause();
Expand All @@ -178,38 +194,15 @@ ApplicationContext failure threshold (%d) exceeded: \
}
}
throw new IllegalStateException(
"Failed to load ApplicationContext for " + mergedConfig, cause);
}
}
else {
if (logger.isTraceEnabled()) {
logger.trace("Retrieved ApplicationContext [%s] from cache with key %s".formatted(
System.identityHashCode(context), mergedConfig));
}
}
}
finally {
this.contextCache.logStatistics();
}

return context;
"Failed to load ApplicationContext for " + k, cause);
}
}

@Override
public void closeContext(MergedContextConfiguration mergedConfig, @Nullable HierarchyMode hierarchyMode) {
mergedConfig = replaceIfNecessary(mergedConfig);
synchronized (this.contextCache) {
this.contextCache.remove(mergedConfig, hierarchyMode);
}
}

/**
* Get the {@link ContextCache} used by this context loader delegate.
*/
protected ContextCache getContextCache() {
return this.contextCache;
}

/**
* Load the {@code ApplicationContext} for the supplied merged context configuration.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;

import org.jspecify.annotations.Nullable;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.core.style.ToStringCreator;
Expand Down Expand Up @@ -57,11 +58,12 @@ public class DefaultContextCache implements ContextCache {

private static final Log statsLogger = LogFactory.getLog(CONTEXT_CACHE_LOGGING_CATEGORY);

ExecutorService executorService = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); //TODO: Make this parametric

/**
* Map of context keys to Spring {@code ApplicationContext} instances.
*/
private final Map<MergedContextConfiguration, ApplicationContext> contextMap =
private final Map<MergedContextConfiguration, Future<ApplicationContext>> contextMap =
Collections.synchronizedMap(new LruCache(32, 0.75f));

/**
Expand Down Expand Up @@ -120,25 +122,69 @@ public boolean contains(MergedContextConfiguration key) {
return this.contextMap.containsKey(key);
}

@Override
@Override//TODO: This is not used anymore in spring but make sense to keep it for retro compatibility, right?
public @Nullable ApplicationContext get(MergedContextConfiguration key) {
Assert.notNull(key, "Key must not be null");
ApplicationContext context = this.contextMap.get(key);

try {
Future<ApplicationContext> context = this.contextMap.get(key);

if (context == null) {
this.missCount.incrementAndGet();

return null;
}
else {
this.hitCount.incrementAndGet();

return context.get();
}
} catch (InterruptedException e) {
throw new RuntimeException(e);//FIXME: fix the message
} catch (ExecutionException e) {
throw new RuntimeException(e);//FIXME: fix the message
}
}


@Override
public Future<ApplicationContext> computeIfAbsent(MergedContextConfiguration key, Function<MergedContextConfiguration, ApplicationContext> mappingFunction) {
Assert.notNull(key, "Key must not be null");

if(contextMap.containsKey(key)) {
this.hitCount.incrementAndGet();
}
return context;

return contextMap.computeIfAbsent(key, (k) ->
{
this.missCount.incrementAndGet();
return CompletableFuture.supplyAsync(() -> mappingFunction.apply(k), executorService)
.thenApply(
(contextLoaded) -> {
MergedContextConfiguration child = key;
MergedContextConfiguration parent = child.getParent();
while (parent != null) {
Set<MergedContextConfiguration> list = this.hierarchyMap.computeIfAbsent(parent, k2 -> new HashSet<>());
list.add(child);
child = parent;
parent = child.getParent();
}

return contextLoaded;
}
);

}
);
}

//TODO: This is not used anymore in spring but make sense to keep it for retro compatibility, right?
@Override
public void put(MergedContextConfiguration key, ApplicationContext context) {
Assert.notNull(key, "Key must not be null");
Assert.notNull(context, "ApplicationContext must not be null");

this.contextMap.put(key, context);
this.contextMap.put(key, CompletableFuture.completedFuture(context));
MergedContextConfiguration child = key;
MergedContextConfiguration parent = child.getParent();
while (parent != null) {
Expand Down Expand Up @@ -198,10 +244,19 @@ private void remove(List<MergedContextConfiguration> removedContexts, MergedCont

// Physically remove and close leaf nodes first (i.e., on the way back up the
// stack as opposed to prior to the recursive call).
ApplicationContext context = this.contextMap.remove(key);
Future<ApplicationContext> contextLoader = this.contextMap.remove(key);

try {
ApplicationContext context = contextLoader.get();
if (context instanceof ConfigurableApplicationContext cac) {
cac.close();
}
} catch (InterruptedException e) {
throw new RuntimeException(e); //FIXME: fix the message
} catch (ExecutionException e) {
throw new RuntimeException(e); //FIXME: fix the message
}

removedContexts.add(key);
}

Expand Down Expand Up @@ -303,7 +358,7 @@ public String toString() {
* @since 4.3
*/
@SuppressWarnings("serial")
private class LruCache extends LinkedHashMap<MergedContextConfiguration, ApplicationContext> {
private class LruCache extends LinkedHashMap<MergedContextConfiguration, Future<ApplicationContext>> {

/**
* Create a new {@code LruCache} with the supplied initial capacity
Expand All @@ -316,7 +371,7 @@ private class LruCache extends LinkedHashMap<MergedContextConfiguration, Applica
}

@Override
protected boolean removeEldestEntry(Map.Entry<MergedContextConfiguration, ApplicationContext> eldest) {
protected boolean removeEldestEntry(Map.Entry<MergedContextConfiguration, Future<ApplicationContext>> eldest) {
if (this.size() > DefaultContextCache.this.getMaxSize()) {
// Do NOT delete "DefaultContextCache.this."; otherwise, we accidentally
// invoke java.util.Map.remove(Object, Object).
Expand Down