Skip to content

xds: float LRU cache across interceptors (v1.72.x backport) #12096

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 40 additions & 20 deletions xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,17 @@

static final String TYPE_URL =
"type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig";

private final LruCache<String, CallCredentials> callCredentialsCache;
final String filterInstanceName;

GcpAuthenticationFilter(String name) {
GcpAuthenticationFilter(String name, int cacheSize) {
filterInstanceName = checkNotNull(name, "name");
this.callCredentialsCache = new LruCache<>(cacheSize);
}


static final class Provider implements Filter.Provider {
private final int cacheSize = 10;

@Override
public String[] typeUrls() {
return new String[]{TYPE_URL};
Expand All @@ -80,7 +82,7 @@

@Override
public GcpAuthenticationFilter newInstance(String name) {
return new GcpAuthenticationFilter(name);
return new GcpAuthenticationFilter(name, cacheSize);

Check warning on line 85 in xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java

View check run for this annotation

Codecov / codecov/patch

xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java#L85

Added line #L85 was not covered by tests
}

@Override
Expand All @@ -101,11 +103,14 @@
// Validate cache_config
if (gcpAuthnProto.hasCacheConfig()) {
TokenCacheConfig cacheConfig = gcpAuthnProto.getCacheConfig();
cacheSize = cacheConfig.getCacheSize().getValue();
if (cacheSize == 0) {
return ConfigOrError.fromError(
"cache_config.cache_size must be greater than zero");
if (cacheConfig.hasCacheSize()) {
cacheSize = cacheConfig.getCacheSize().getValue();
if (cacheSize == 0) {
return ConfigOrError.fromError(
"cache_config.cache_size must be greater than zero");
}
}

// LruCache's size is an int and briefly exceeds its maximum size before evicting entries
cacheSize = UnsignedLongs.min(cacheSize, Integer.MAX_VALUE - 1);
}
Expand All @@ -127,8 +132,9 @@
@Nullable FilterConfig overrideConfig, ScheduledExecutorService scheduler) {

ComputeEngineCredentials credentials = ComputeEngineCredentials.create();
LruCache<String, CallCredentials> callCredentialsCache =
new LruCache<>(((GcpAuthenticationConfig) config).getCacheSize());
synchronized (callCredentialsCache) {
callCredentialsCache.resizeCache(((GcpAuthenticationConfig) config).getCacheSize());
}
return new ClientInterceptor() {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
Expand Down Expand Up @@ -254,23 +260,37 @@

private static final class LruCache<K, V> {

private final Map<K, V> cache;
private Map<K, V> cache;
private int maxSize;

LruCache(int maxSize) {
this.cache = new LinkedHashMap<K, V>(
maxSize,
0.75f,
true) {
@Override
protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
return size() > maxSize;
}
};
this.maxSize = maxSize;
this.cache = createEvictingMap(maxSize);
}

V getOrInsert(K key, Function<K, V> create) {
return cache.computeIfAbsent(key, create);
}

private void resizeCache(int newSize) {
if (newSize >= maxSize) {
maxSize = newSize;
return;
}
Map<K, V> newCache = createEvictingMap(newSize);
maxSize = newSize;
newCache.putAll(cache);
cache = newCache;
}

private Map<K, V> createEvictingMap(int size) {
return new LinkedHashMap<K, V>(size, 0.75f, true) {
@Override
protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
return size() > LruCache.this.maxSize;
}
};
}
}

static class AudienceMetadataParser implements MetadataValueParser {
Expand Down
146 changes: 135 additions & 11 deletions xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@
import static io.grpc.xds.XdsTestUtils.getWrrLbConfigAsMap;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -89,8 +91,8 @@ public class GcpAuthenticationFilterTest {

@Test
public void testNewFilterInstancesPerFilterName() {
assertThat(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1"))
.isNotEqualTo(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1"));
assertThat(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1", 10))
.isNotEqualTo(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1", 10));
}

@Test
Expand Down Expand Up @@ -152,7 +154,7 @@ public void testClientInterceptor_success() throws IOException, ResourceInvalidE
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = Mockito.mock(Channel.class);
Expand Down Expand Up @@ -181,7 +183,7 @@ public void testClientInterceptor_createsAndReusesCachedCredentials()
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = Mockito.mock(Channel.class);
Expand All @@ -190,7 +192,7 @@ public void testClientInterceptor_createsAndReusesCachedCredentials()
interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);

verify(mockChannel, Mockito.times(2))
verify(mockChannel, times(2))
.newCall(eq(methodDescriptor), callOptionsCaptor.capture());
CallOptions firstCapturedOptions = callOptionsCaptor.getAllValues().get(0);
CallOptions secondCapturedOptions = callOptionsCaptor.getAllValues().get(1);
Expand All @@ -202,7 +204,7 @@ public void testClientInterceptor_createsAndReusesCachedCredentials()
@Test
public void testClientInterceptor_withoutClusterSelectionKey() throws Exception {
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = mock(Channel.class);
Expand Down Expand Up @@ -233,7 +235,7 @@ public void testClientInterceptor_clusterSelectionKeyWithoutPrefix() throws Exce
Channel mockChannel = mock(Channel.class);

GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
Expand All @@ -244,7 +246,7 @@ public void testClientInterceptor_clusterSelectionKeyWithoutPrefix() throws Exce
@Test
public void testClientInterceptor_xdsConfigDoesNotExist() throws Exception {
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = mock(Channel.class);
Expand Down Expand Up @@ -274,7 +276,7 @@ public void testClientInterceptor_incorrectClusterName() throws Exception {
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = mock(Channel.class);
Expand All @@ -300,7 +302,7 @@ public void testClientInterceptor_statusOrError() throws Exception {
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = mock(Channel.class);
Expand Down Expand Up @@ -329,7 +331,7 @@ public void testClientInterceptor_notAudienceWrapper()
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = Mockito.mock(Channel.class);
Expand All @@ -342,6 +344,115 @@ public void testClientInterceptor_notAudienceWrapper()
assertThat(clientCall.error.getDescription()).contains("GCP Authn found wrong type");
}

@Test
public void testLruCacheAcrossInterceptors() throws IOException, ResourceInvalidException {
XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig(
CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate)));
XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder()
.setListener(ldsUpdate)
.setRoute(rdsUpdate)
.setVirtualHost(rdsUpdate.virtualHosts.get(0))
.addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build();
CallOptions callOptionsWithXds = CallOptions.DEFAULT
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 2);
ClientInterceptor interceptor1
= filter.buildClientInterceptor(new GcpAuthenticationConfig(2), null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = Mockito.mock(Channel.class);
ArgumentCaptor<CallOptions> callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class);

interceptor1.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
verify(mockChannel).newCall(eq(methodDescriptor), callOptionsCaptor.capture());
CallOptions capturedOptions1 = callOptionsCaptor.getAllValues().get(0);
assertNotNull(capturedOptions1.getCredentials());
ClientInterceptor interceptor2
= filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null);
interceptor2.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
verify(mockChannel, times(2))
.newCall(eq(methodDescriptor), callOptionsCaptor.capture());
CallOptions capturedOptions2 = callOptionsCaptor.getAllValues().get(1);
assertNotNull(capturedOptions2.getCredentials());

assertSame(capturedOptions1.getCredentials(), capturedOptions2.getCredentials());
}

@Test
public void testLruCacheEvictionOnResize() throws IOException, ResourceInvalidException {
XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig(
CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate)));
XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder()
.setListener(ldsUpdate)
.setRoute(rdsUpdate)
.setVirtualHost(rdsUpdate.virtualHosts.get(0))
.addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build();
CallOptions callOptionsWithXds = CallOptions.DEFAULT
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 2);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();

ClientInterceptor interceptor1 =
filter.buildClientInterceptor(new GcpAuthenticationConfig(2), null, null);
Channel mockChannel1 = Mockito.mock(Channel.class);
ArgumentCaptor<CallOptions> captor = ArgumentCaptor.forClass(CallOptions.class);
interceptor1.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel1);
verify(mockChannel1).newCall(eq(methodDescriptor), captor.capture());
CallOptions options1 = captor.getValue();
// This will recreate the cache with max size of 1 and copy the credential for audience1.
ClientInterceptor interceptor2 =
filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null);
Channel mockChannel2 = Mockito.mock(Channel.class);
interceptor2.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel2);
verify(mockChannel2).newCall(eq(methodDescriptor), captor.capture());
CallOptions options2 = captor.getValue();

assertSame(options1.getCredentials(), options2.getCredentials());

clusterConfig = new XdsConfig.XdsClusterConfig(
CLUSTER_NAME, getCdsUpdate2(), new EndpointConfig(StatusOr.fromValue(edsUpdate)));
defaultXdsConfig = new XdsConfig.XdsConfigBuilder()
.setListener(ldsUpdate)
.setRoute(rdsUpdate)
.setVirtualHost(rdsUpdate.virtualHosts.get(0))
.addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build();
callOptionsWithXds = CallOptions.DEFAULT
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);

// This will evict the credential for audience1 and add new credential for audience2
ClientInterceptor interceptor3 =
filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null);
Channel mockChannel3 = Mockito.mock(Channel.class);
interceptor3.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel3);
verify(mockChannel3).newCall(eq(methodDescriptor), captor.capture());
CallOptions options3 = captor.getValue();

assertNotSame(options1.getCredentials(), options3.getCredentials());

clusterConfig = new XdsConfig.XdsClusterConfig(
CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate)));
defaultXdsConfig = new XdsConfig.XdsConfigBuilder()
.setListener(ldsUpdate)
.setRoute(rdsUpdate)
.setVirtualHost(rdsUpdate.virtualHosts.get(0))
.addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build();
callOptionsWithXds = CallOptions.DEFAULT
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);

// This will create new credential for audience1 because it has been evicted
ClientInterceptor interceptor4 =
filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null);
Channel mockChannel4 = Mockito.mock(Channel.class);
interceptor4.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel4);
verify(mockChannel4).newCall(eq(methodDescriptor), captor.capture());
CallOptions options4 = captor.getValue();

assertNotSame(options1.getCredentials(), options4.getCredentials());
}

private static LdsUpdate getLdsUpdate() {
Filter.NamedFilterConfig routerFilterConfig = new Filter.NamedFilterConfig(
serverName, RouterFilter.ROUTER_CONFIG);
Expand Down Expand Up @@ -384,6 +495,19 @@ private static CdsUpdate getCdsUpdate() {
}
}

private static CdsUpdate getCdsUpdate2() {
ImmutableMap.Builder<String, Object> parsedMetadata = ImmutableMap.builder();
parsedMetadata.put("FILTER_INSTANCE_NAME", new AudienceWrapper("NEW_TEST_AUDIENCE"));
try {
CdsUpdate.Builder cdsUpdate = CdsUpdate.forEds(
CLUSTER_NAME, EDS_NAME, null, null, null, null, false)
.lbPolicyConfig(getWrrLbConfigAsMap());
return cdsUpdate.parsedMetadata(parsedMetadata.build()).build();
} catch (IOException ex) {
return null;
}
}

private static CdsUpdate getCdsUpdateWithIncorrectAudienceWrapper() throws IOException {
ImmutableMap.Builder<String, Object> parsedMetadata = ImmutableMap.builder();
parsedMetadata.put("FILTER_INSTANCE_NAME", "TEST_AUDIENCE");
Expand Down