Skip to content

Commit

Permalink
fix: Handle 404 and non 200 Status Code from MDS Identity Token calls (
Browse files Browse the repository at this point in the history
…#1636)

* fix: Handle 404 Status Code from MDS Identity Token calls

* chore: Fix tests
  • Loading branch information
lqiu96 authored Feb 3, 2025
1 parent 26785bf commit 152c851
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@
public class ComputeEngineCredentials extends GoogleCredentials
implements ServiceAccountSigner, IdTokenProvider {

static final String METADATA_RESPONSE_EMPTY_CONTENT_ERROR_MESSAGE =
"Empty content from metadata token server request.";
// Decrease timing margins on GCE.
// This is needed because GCE VMs maintain their own OAuth cache that expires T-4 mins, attempting
// to refresh a token before then, will yield the same stale token. To enable pre-emptive
Expand Down Expand Up @@ -366,7 +368,7 @@ public AccessToken refreshAccessToken() throws IOException {
if (content == null) {
// Throw explicitly here on empty content to avoid NullPointerException from parseAs call.
// Mock transports will have success code with empty content by default.
throw new IOException("Empty content from metadata token server request.");
throw new IOException(METADATA_RESPONSE_EMPTY_CONTENT_ERROR_MESSAGE);
}
GenericData responseData = response.parseAs(GenericData.class);
String accessToken =
Expand Down Expand Up @@ -408,9 +410,24 @@ public IdToken idTokenWithAudience(String targetAudience, List<IdTokenProvider.O
documentUrl.set("audience", targetAudience);
HttpResponse response =
getMetadataResponse(documentUrl.toString(), RequestType.ID_TOKEN_REQUEST, true);
int statusCode = response.getStatusCode();
if (statusCode == HttpStatusCodes.STATUS_CODE_NOT_FOUND) {
throw new IOException(
String.format(
"Error code %s trying to get identity token from"
+ " Compute Engine metadata. This may be because the virtual machine instance"
+ " does not have permission scopes specified.",
statusCode));
}
if (statusCode != HttpStatusCodes.STATUS_CODE_OK) {
throw new IOException(
String.format(
"Unexpected Error code %s trying to get identity token from Compute Engine metadata: %s",
statusCode, response.parseAsString()));
}
InputStream content = response.getContent();
if (content == null) {
throw new IOException("Empty content from metadata token server request.");
throw new IOException(METADATA_RESPONSE_EMPTY_CONTENT_ERROR_MESSAGE);
}
String rawToken = response.parseAsString();
return IdToken.create(rawToken);
Expand Down Expand Up @@ -710,7 +727,7 @@ private String getDefaultServiceAccount() throws IOException {
if (content == null) {
// Throw explicitly here on empty content to avoid NullPointerException from parseAs call.
// Mock transports will have success code with empty content by default.
throw new IOException("Empty content from metadata token server request.");
throw new IOException(METADATA_RESPONSE_EMPTY_CONTENT_ERROR_MESSAGE);
}
GenericData responseData = response.parseAs(GenericData.class);
Map<String, Object> defaultAccount =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@

package com.google.auth.oauth2;

import static com.google.auth.oauth2.ComputeEngineCredentials.METADATA_RESPONSE_EMPTY_CONTENT_ERROR_MESSAGE;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

Expand Down Expand Up @@ -420,7 +422,7 @@ public void getRequestMetadata_shouldInvalidateAccessTokenWhenScoped_newAccessTo
@Test
public void getRequestMetadata_missingServiceAccount_throws() {
MockMetadataServerTransportFactory transportFactory = new MockMetadataServerTransportFactory();
transportFactory.transport.setRequestStatusCode(HttpStatusCodes.STATUS_CODE_NOT_FOUND);
transportFactory.transport.setStatusCode(HttpStatusCodes.STATUS_CODE_NOT_FOUND);
ComputeEngineCredentials credentials =
ComputeEngineCredentials.newBuilder().setHttpTransportFactory(transportFactory).build();
try {
Expand All @@ -437,7 +439,7 @@ public void getRequestMetadata_missingServiceAccount_throws() {
@Test
public void getRequestMetadata_serverError_throws() {
MockMetadataServerTransportFactory transportFactory = new MockMetadataServerTransportFactory();
transportFactory.transport.setRequestStatusCode(HttpStatusCodes.STATUS_CODE_SERVER_ERROR);
transportFactory.transport.setStatusCode(HttpStatusCodes.STATUS_CODE_SERVER_ERROR);
ComputeEngineCredentials credentials =
ComputeEngineCredentials.newBuilder().setHttpTransportFactory(transportFactory).build();
try {
Expand Down Expand Up @@ -668,7 +670,7 @@ public void sign_getUniverseException() {
ComputeEngineCredentials credentials =
ComputeEngineCredentials.newBuilder().setHttpTransportFactory(transportFactory).build();

transportFactory.transport.setRequestStatusCode(501);
transportFactory.transport.setStatusCode(501);
Assert.assertThrows(IOException.class, credentials::getUniverseDomain);

byte[] expectedSignature = {0xD, 0xE, 0xA, 0xD};
Expand Down Expand Up @@ -962,7 +964,7 @@ public void getUniverseDomain_fromMetadata_non404error_throws() throws IOExcepti
continue;
}
try {
transportFactory.transport.setRequestStatusCode(status);
transportFactory.transport.setStatusCode(status);
credentials.getUniverseDomain();
fail("Should not be able to use credential without exception.");
} catch (GoogleAuthException ex) {
Expand Down Expand Up @@ -1095,6 +1097,45 @@ public void idTokenWithAudience_license() throws IOException {
assertTrue(googleClaim.containsKey("license"));
}

@Test
public void idTokenWithAudience_404StatusCode() {
int statusCode = HttpStatusCodes.STATUS_CODE_NOT_FOUND;
MockMetadataServerTransportFactory transportFactory = new MockMetadataServerTransportFactory();
transportFactory.transport.setStatusCode(HttpStatusCodes.STATUS_CODE_NOT_FOUND);
ComputeEngineCredentials credentials =
ComputeEngineCredentials.newBuilder().setHttpTransportFactory(transportFactory).build();
IOException exception =
assertThrows(IOException.class, () -> credentials.idTokenWithAudience("Audience", null));
assertEquals(
String.format(
"Error code %s trying to get identity token from"
+ " Compute Engine metadata. This may be because the virtual machine instance"
+ " does not have permission scopes specified.",
statusCode),
exception.getMessage());
}

@Test
public void idTokenWithAudience_emptyContent() {
MockMetadataServerTransportFactory transportFactory = new MockMetadataServerTransportFactory();
transportFactory.transport.setEmptyContent(true);
ComputeEngineCredentials credentials =
ComputeEngineCredentials.newBuilder().setHttpTransportFactory(transportFactory).build();
IOException exception =
assertThrows(IOException.class, () -> credentials.idTokenWithAudience("Audience", null));
assertEquals(METADATA_RESPONSE_EMPTY_CONTENT_ERROR_MESSAGE, exception.getMessage());
}

@Test
public void idTokenWithAudience_503StatusCode() {
MockMetadataServerTransportFactory transportFactory = new MockMetadataServerTransportFactory();
transportFactory.transport.setStatusCode(HttpStatusCodes.STATUS_CODE_SERVICE_UNAVAILABLE);
ComputeEngineCredentials credentials =
ComputeEngineCredentials.newBuilder().setHttpTransportFactory(transportFactory).build();
assertThrows(
GoogleAuthException.class, () -> credentials.idTokenWithAudience("Audience", null));
}

static class MockMetadataServerTransportFactory implements HttpTransportFactory {

MockMetadataServerTransport transport =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

package com.google.auth.oauth2;

import com.google.api.client.http.HttpStatusCodes;
import com.google.api.client.http.LowLevelHttpRequest;
import com.google.api.client.http.LowLevelHttpResponse;
import com.google.api.client.json.GenericJson;
Expand All @@ -55,7 +56,7 @@ public class MockMetadataServerTransport extends MockHttpTransport {

// key are scopes as in request url string following "?scopes="
private Map<String, String> scopesToAccessToken;
private Integer requestStatusCode;
private Integer statusCode;

private String serviceAccountEmail;

Expand Down Expand Up @@ -91,8 +92,8 @@ public void setAccessToken(String scopes, String accessToken) {
scopesToAccessToken.put(scopes, accessToken);
}

public void setRequestStatusCode(Integer requestStatusCode) {
this.requestStatusCode = requestStatusCode;
public void setStatusCode(Integer statusCode) {
this.statusCode = statusCode;
}

public void setServiceAccountEmail(String serviceAccountEmail) {
Expand Down Expand Up @@ -140,14 +141,15 @@ public LowLevelHttpRequest buildRequest(String method, String url) throws IOExce
new MockLowLevelHttpRequest(url) {
@Override
public LowLevelHttpResponse execute() {
if (requestStatusCode != null) {
if (statusCode != null && (statusCode >= 400 && statusCode < 600)) {
return new MockLowLevelHttpResponse()
.setStatusCode(requestStatusCode)
.setStatusCode(statusCode)
.setContent("Metadata Error");
}

MockLowLevelHttpResponse response = new MockLowLevelHttpResponse();
response.addHeader("Metadata-Flavor", "Google");
response.setStatusCode(HttpStatusCodes.STATUS_CODE_OK);
return response;
}
};
Expand Down Expand Up @@ -195,9 +197,9 @@ private MockLowLevelHttpRequest getMockRequestForTokenEndpoint(String url) {
@Override
public LowLevelHttpResponse execute() throws IOException {

if (requestStatusCode != null) {
if (statusCode != null && (statusCode >= 400 && statusCode < 600)) {
return new MockLowLevelHttpResponse()
.setStatusCode(requestStatusCode)
.setStatusCode(statusCode)
.setContent("Token Fetch Error");
}

Expand All @@ -224,20 +226,35 @@ public LowLevelHttpResponse execute() throws IOException {

return new MockLowLevelHttpResponse()
.setContentType(Json.MEDIA_TYPE)
.setStatusCode(HttpStatusCodes.STATUS_CODE_OK)
.setContent(refreshText);
}
};
}

private MockLowLevelHttpRequest getMockRequestForIdentityDocument(String url)
throws MalformedURLException, UnsupportedEncodingException {
if (idToken != null) {
if (statusCode != null && statusCode != HttpStatusCodes.STATUS_CODE_OK) {
return new MockLowLevelHttpRequest(url) {
@Override
public LowLevelHttpResponse execute() throws IOException {
public LowLevelHttpResponse execute() {
return new MockLowLevelHttpResponse().setStatusCode(statusCode);
}
};
} else if (idToken != null) {
return new MockLowLevelHttpRequest(url) {
@Override
public LowLevelHttpResponse execute() {
return new MockLowLevelHttpResponse().setContent(idToken);
}
};
} else if (emptyContent) {
return new MockLowLevelHttpRequest(url) {
@Override
public LowLevelHttpResponse execute() {
return new MockLowLevelHttpResponse();
}
};
}

// https://cloud.google.com/compute/docs/instances/verifying-instance-identity#token_format
Expand Down Expand Up @@ -299,15 +316,15 @@ public LowLevelHttpResponse execute() throws IOException {
// Create the JSON response
GenericJson content = new GenericJson();
content.setFactory(OAuth2Utils.JSON_FACTORY);
if (requestStatusCode == 200) {
if (statusCode == HttpStatusCodes.STATUS_CODE_OK) {
content.put(SecureSessionAgent.S2A_JSON_KEY, s2aContentMap);
}
String contentText = content.toPrettyString();

MockLowLevelHttpResponse response = new MockLowLevelHttpResponse();

if (requestStatusCode != null) {
response.setStatusCode(requestStatusCode);
if (statusCode != null) {
response.setStatusCode(statusCode);
}
if (emptyContent == true) {
return response.setZeroContent();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public void getS2AAddress_validAddress() {
S2A_PLAINTEXT_ADDRESS,
SecureSessionAgent.S2A_MTLS_ADDRESS_JSON_KEY,
S2A_MTLS_ADDRESS));
transportFactory.transport.setRequestStatusCode(HttpStatusCodes.STATUS_CODE_OK);
transportFactory.transport.setStatusCode(HttpStatusCodes.STATUS_CODE_OK);

SecureSessionAgent s2aUtils =
SecureSessionAgent.newBuilder().setHttpTransportFactory(transportFactory).build();
Expand All @@ -77,8 +77,7 @@ public void getS2AAddress_queryEndpointResponseErrorCode_emptyAddress() {
S2A_PLAINTEXT_ADDRESS,
SecureSessionAgent.S2A_MTLS_ADDRESS_JSON_KEY,
S2A_MTLS_ADDRESS));
transportFactory.transport.setRequestStatusCode(
HttpStatusCodes.STATUS_CODE_SERVICE_UNAVAILABLE);
transportFactory.transport.setStatusCode(HttpStatusCodes.STATUS_CODE_SERVICE_UNAVAILABLE);

SecureSessionAgent s2aUtils =
SecureSessionAgent.newBuilder().setHttpTransportFactory(transportFactory).build();
Expand All @@ -98,7 +97,7 @@ public void getS2AAddress_queryEndpointResponseEmpty_emptyAddress() {
S2A_PLAINTEXT_ADDRESS,
SecureSessionAgent.S2A_MTLS_ADDRESS_JSON_KEY,
S2A_MTLS_ADDRESS));
transportFactory.transport.setRequestStatusCode(HttpStatusCodes.STATUS_CODE_OK);
transportFactory.transport.setStatusCode(HttpStatusCodes.STATUS_CODE_OK);
transportFactory.transport.setEmptyContent(true);

SecureSessionAgent s2aUtils =
Expand All @@ -119,7 +118,7 @@ public void getS2AAddress_queryEndpointResponseInvalidPlaintextJsonKey_plaintext
S2A_PLAINTEXT_ADDRESS,
SecureSessionAgent.S2A_MTLS_ADDRESS_JSON_KEY,
S2A_MTLS_ADDRESS));
transportFactory.transport.setRequestStatusCode(HttpStatusCodes.STATUS_CODE_OK);
transportFactory.transport.setStatusCode(HttpStatusCodes.STATUS_CODE_OK);

SecureSessionAgent s2aUtils =
SecureSessionAgent.newBuilder().setHttpTransportFactory(transportFactory).build();
Expand All @@ -139,7 +138,7 @@ public void getS2AAddress_queryEndpointResponseInvalidMtlsJsonKey_mtlsEmptyAddre
S2A_PLAINTEXT_ADDRESS,
INVALID_JSON_KEY,
S2A_MTLS_ADDRESS));
transportFactory.transport.setRequestStatusCode(HttpStatusCodes.STATUS_CODE_OK);
transportFactory.transport.setStatusCode(HttpStatusCodes.STATUS_CODE_OK);

SecureSessionAgent s2aUtils =
SecureSessionAgent.newBuilder().setHttpTransportFactory(transportFactory).build();
Expand Down

0 comments on commit 152c851

Please sign in to comment.