Skip to content

Ensure ID Token is updated after refresh token (Reactive) #17246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -362,6 +362,8 @@ public final class RefreshTokenGrantBuilder implements Builder {

private ReactiveOAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient;

private ReactiveOAuth2AuthorizationSuccessHandler refreshTokenSuccessHandler;

private Duration clockSkew;

private Clock clock;
@@ -382,6 +384,21 @@ public RefreshTokenGrantBuilder accessTokenResponseClient(
return this;
}

/**
* Sets a {@link ReactiveOAuth2AuthorizationSuccessHandler} that is called after
* the client is re-authorized, defaults to
* {@link RefreshTokenReactiveOAuth2AuthorizationSuccessHandler}.
* @param refreshTokenSuccessHandler the
* {@link ReactiveOAuth2AuthorizationSuccessHandler} to use
* @return the {@link RefreshTokenGrantBuilder}
* @since 7.0
*/
public RefreshTokenGrantBuilder refreshTokenSuccessHandler(
ReactiveOAuth2AuthorizationSuccessHandler refreshTokenSuccessHandler) {
this.refreshTokenSuccessHandler = refreshTokenSuccessHandler;
return this;
}

/**
* Sets the maximum acceptable clock skew, which is used when checking the access
* token expiry. An access token is considered expired if
@@ -418,6 +435,9 @@ public ReactiveOAuth2AuthorizedClientProvider build() {
if (this.accessTokenResponseClient != null) {
authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient);
}
if (this.refreshTokenSuccessHandler != null) {
authorizedClientProvider.setRefreshTokenSuccessHandler(this.refreshTokenSuccessHandler);
}
if (this.clockSkew != null) {
authorizedClientProvider.setClockSkew(this.clockSkew);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
/*
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.client;

import java.time.Duration;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import reactor.core.publisher.Mono;

import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.oidc.authentication.ReactiveOidcIdTokenDecoderFactory;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcReactiveOAuth2UserService;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.jwt.JwtException;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory;
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange;

/**
* A {@link ReactiveOAuth2AuthorizationSuccessHandler} that refreshes an {@link OidcUser}
* in the {@link SecurityContext} if the refreshed {@link OidcIdToken} is valid according
* to <a href=
* "https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse">OpenID
* Connect Core 1.0 - Section 12.2 Successful Refresh Response</a>
*
* @author Evgeniy Cheban
* @since 7.0
*/
public final class RefreshTokenReactiveOAuth2AuthorizationSuccessHandler
implements ReactiveOAuth2AuthorizationSuccessHandler {

private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token";

private static final String INVALID_NONCE_ERROR_CODE = "invalid_nonce";

private static final String REFRESH_TOKEN_RESPONSE_ERROR_URI = "https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse";

// @formatter:off
private static final Mono<ServerWebExchange> currentServerWebExchangeMono = Mono.deferContextual(Mono::just)
.filter((c) -> c.hasKey(ServerWebExchange.class))
.map((c) -> c.get(ServerWebExchange.class));
// @formatter:on

private ServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();

private ReactiveJwtDecoderFactory<ClientRegistration> jwtDecoderFactory = new ReactiveOidcIdTokenDecoderFactory();

private ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = new OidcReactiveOAuth2UserService();

private GrantedAuthoritiesMapper authoritiesMapper = (authorities) -> authorities;

private Duration clockSkew = Duration.ofSeconds(60);

@Override
public Mono<Void> onAuthorizationSuccess(OAuth2AuthorizedClient authorizedClient, Authentication principal,
Map<String, Object> attributes) {
if (!(principal instanceof OAuth2AuthenticationToken authenticationToken)
|| authenticationToken.getClass() != OAuth2AuthenticationToken.class) {
// If the application customizes the authentication result, then a custom
// handler should be provided.
return Mono.empty();
}
// The current principal must be an OidcUser.
if (!(authenticationToken.getPrincipal() instanceof OidcUser existingOidcUser)) {
return Mono.empty();
}
ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
// The registrationId must match the one used to log in.
if (!authenticationToken.getAuthorizedClientRegistrationId().equals(clientRegistration.getRegistrationId())) {
return Mono.empty();
}
// Create, validate OidcIdToken and refresh OidcUser in the SecurityContext.
return Mono.zip(serverWebExchange(attributes), accessTokenResponse(attributes)).flatMap((t2) -> {
ReactiveJwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration);
Map<String, Object> additionalParameters = t2.getT2().getAdditionalParameters();
return jwtDecoder.decode((String) additionalParameters.get(OidcParameterNames.ID_TOKEN))
.onErrorMap(JwtException.class, (ex) -> {
OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(),
null);
return new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex);
})
.map((jwt) -> new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(),
jwt.getClaims()))
.doOnNext((idToken) -> validateIdToken(existingOidcUser, idToken))
.flatMap((idToken) -> {
OidcUserRequest userRequest = new OidcUserRequest(clientRegistration,
authorizedClient.getAccessToken(), idToken);
return this.userService.loadUser(userRequest);
})
.flatMap((oidcUser) -> refreshSecurityContext(t2.getT1(), clientRegistration, authenticationToken,
oidcUser));
});
}

private Mono<ServerWebExchange> serverWebExchange(Map<String, Object> attributes) {
if (attributes.get(ServerWebExchange.class.getName()) instanceof ServerWebExchange exchange) {
return Mono.just(exchange);
}
return currentServerWebExchangeMono;
}

private Mono<OAuth2AccessTokenResponse> accessTokenResponse(Map<String, Object> attributes) {
if (attributes.get(OAuth2AccessTokenResponse.class.getName()) instanceof OAuth2AccessTokenResponse response) {
return Mono.just(response);
}
return Mono.empty();
}

private void validateIdToken(OidcUser existingOidcUser, OidcIdToken idToken) {
// OpenID Connect Core 1.0 - Section 12.2 Successful Refresh Response
// If an ID Token is returned as a result of a token refresh request, the
// following requirements apply:
// its iss Claim Value MUST be the same as in the ID Token issued when the
// original authentication occurred,
validateIssuer(existingOidcUser, idToken);
// its sub Claim Value MUST be the same as in the ID Token issued when the
// original authentication occurred,
validateSubject(existingOidcUser, idToken);
// its iat Claim MUST represent the time that the new ID Token is issued,
validateIssuedAt(existingOidcUser, idToken);
// its aud Claim Value MUST be the same as in the ID Token issued when the
// original authentication occurred,
validateAudience(existingOidcUser, idToken);
// if the ID Token contains an auth_time Claim, its value MUST represent the time
// of the original authentication - not the time that the new ID token is issued,
validateAuthenticatedAt(existingOidcUser, idToken);
// it SHOULD NOT have a nonce Claim, even when the ID Token issued at the time of
// the original authentication contained nonce; however, if it is present, its
// value MUST be the same as in the ID Token issued at the time of the original
// authentication,
validateNonce(existingOidcUser, idToken);
}

private void validateIssuer(OidcUser existingOidcUser, OidcIdToken idToken) {
if (!idToken.getIssuer().toString().equals(existingOidcUser.getIdToken().getIssuer().toString())) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid issuer",
REFRESH_TOKEN_RESPONSE_ERROR_URI);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
}

private void validateSubject(OidcUser existingOidcUser, OidcIdToken idToken) {
if (!idToken.getSubject().equals(existingOidcUser.getIdToken().getSubject())) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid subject",
REFRESH_TOKEN_RESPONSE_ERROR_URI);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
}

private void validateIssuedAt(OidcUser existingOidcUser, OidcIdToken idToken) {
if (!idToken.getIssuedAt().isAfter(existingOidcUser.getIdToken().getIssuedAt().minus(this.clockSkew))) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid issued at time",
REFRESH_TOKEN_RESPONSE_ERROR_URI);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
}

private void validateAudience(OidcUser existingOidcUser, OidcIdToken idToken) {
if (!isValidAudience(existingOidcUser, idToken)) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid audience",
REFRESH_TOKEN_RESPONSE_ERROR_URI);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
}

private boolean isValidAudience(OidcUser existingOidcUser, OidcIdToken idToken) {
List<String> idTokenAudiences = idToken.getAudience();
Set<String> oidcUserAudiences = new HashSet<>(existingOidcUser.getIdToken().getAudience());
if (idTokenAudiences.size() != oidcUserAudiences.size()) {
return false;
}
for (String audience : idTokenAudiences) {
if (!oidcUserAudiences.contains(audience)) {
return false;
}
}
return true;
}

private void validateAuthenticatedAt(OidcUser existingOidcUser, OidcIdToken idToken) {
if (idToken.getAuthenticatedAt() == null) {
return;
}
if (!idToken.getAuthenticatedAt().equals(existingOidcUser.getIdToken().getAuthenticatedAt())) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, "Invalid authenticated at time",
REFRESH_TOKEN_RESPONSE_ERROR_URI);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
}

private void validateNonce(OidcUser existingOidcUser, OidcIdToken idToken) {
if (!StringUtils.hasText(idToken.getNonce())) {
return;
}
if (!idToken.getNonce().equals(existingOidcUser.getIdToken().getNonce())) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_NONCE_ERROR_CODE, "Invalid nonce",
REFRESH_TOKEN_RESPONSE_ERROR_URI);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
}

private Mono<Void> refreshSecurityContext(ServerWebExchange exchange, ClientRegistration clientRegistration,
OAuth2AuthenticationToken authenticationToken, OidcUser oidcUser) {
Collection<? extends GrantedAuthority> mappedAuthorities = this.authoritiesMapper
.mapAuthorities(oidcUser.getAuthorities());
OAuth2AuthenticationToken authenticationResult = new OAuth2AuthenticationToken(oidcUser, mappedAuthorities,
clientRegistration.getRegistrationId());
authenticationResult.setDetails(authenticationToken.getDetails());
SecurityContextImpl securityContext = new SecurityContextImpl(authenticationResult);
return this.serverSecurityContextRepository.save(exchange, securityContext);
}

/**
* Sets a {@link ServerSecurityContextRepository} to use for refreshing a
* {@link SecurityContext}, defaults to
* {@link WebSessionServerSecurityContextRepository}.
* @param serverSecurityContextRepository the {@link ServerSecurityContextRepository}
* to use
*/
public void setServerSecurityContextRepository(ServerSecurityContextRepository serverSecurityContextRepository) {
Assert.notNull(serverSecurityContextRepository, "serverSecurityContextRepository cannot be null");
this.serverSecurityContextRepository = serverSecurityContextRepository;
}

/**
* Sets a {@link ReactiveJwtDecoderFactory} to use for decoding refreshed oidc
* id-token, defaults to {@link ReactiveOidcIdTokenDecoderFactory}.
* @param jwtDecoderFactory the {@link ReactiveJwtDecoderFactory} to use
*/
public void setJwtDecoderFactory(ReactiveJwtDecoderFactory<ClientRegistration> jwtDecoderFactory) {
Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null");
this.jwtDecoderFactory = jwtDecoderFactory;
}

/**
* Sets a {@link GrantedAuthoritiesMapper} to use for mapping
* {@link GrantedAuthority}s, defaults to no-op implementation.
* @param authoritiesMapper the {@link GrantedAuthoritiesMapper} to use
*/
public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) {
Assert.notNull(authoritiesMapper, "authoritiesMapper cannot be null");
this.authoritiesMapper = authoritiesMapper;
}

/**
* Sets a {@link ReactiveOAuth2UserService} to use for loading an {@link OidcUser}
* from refreshed oidc id-token, defaults to {@link OidcReactiveOAuth2UserService}.
* @param userService the {@link ReactiveOAuth2UserService} to use
*/
public void setUserService(ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService) {
Assert.notNull(userService, "userService cannot be null");
this.userService = userService;
}

/**
* Sets the maximum acceptable clock skew, which is used when checking the
* {@link OidcIdToken#getIssuedAt()} to match the existing
* {@link OidcUser#getIdToken()}'s issuedAt time, defaults to 60 seconds.
* @param clockSkew the maximum acceptable clock skew to use
*/
public void setClockSkew(Duration clockSkew) {
Assert.notNull(clockSkew, "clockSkew cannot be null");
Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0");
this.clockSkew = clockSkew;
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -21,7 +21,9 @@
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import reactor.core.publisher.Mono;
@@ -33,13 +35,15 @@
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.util.Assert;

/**
* An implementation of a {@link ReactiveOAuth2AuthorizedClientProvider} for the
* {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant.
*
* @author Joe Grandja
* @author Evgeniy Cheban
* @since 5.2
* @see ReactiveOAuth2AuthorizedClientProvider
* @see WebClientReactiveRefreshTokenTokenResponseClient
@@ -49,6 +53,8 @@ public final class RefreshTokenReactiveOAuth2AuthorizedClientProvider

private ReactiveOAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient = new WebClientReactiveRefreshTokenTokenResponseClient();

private ReactiveOAuth2AuthorizationSuccessHandler refreshTokenSuccessHandler = new RefreshTokenReactiveOAuth2AuthorizationSuccessHandler();

private Duration clockSkew = Duration.ofSeconds(60);

private Clock clock = Clock.systemUTC();
@@ -96,8 +102,16 @@ public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizationContext context
.flatMap(this.accessTokenResponseClient::getTokenResponse)
.onErrorMap(OAuth2AuthorizationException.class,
(e) -> new ClientAuthorizationException(e.getError(), clientRegistration.getRegistrationId(), e))
.map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(),
tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()));
.flatMap((tokenResponse) -> {
OAuth2AuthorizedClient refreshedAuthorizedClient = new OAuth2AuthorizedClient(clientRegistration,
context.getPrincipal().getName(), tokenResponse.getAccessToken(),
tokenResponse.getRefreshToken());
Map<String, Object> attributes = new HashMap<>(context.getAttributes());
attributes.put(OAuth2AccessTokenResponse.class.getName(), tokenResponse);
return this.refreshTokenSuccessHandler
.onAuthorizationSuccess(refreshedAuthorizedClient, context.getPrincipal(), attributes)
.then(Mono.just(refreshedAuthorizedClient));
});
}

private boolean hasTokenExpired(OAuth2Token token) {
@@ -116,6 +130,19 @@ public void setAccessTokenResponseClient(
this.accessTokenResponseClient = accessTokenResponseClient;
}

/**
* Sets a {@link ReactiveOAuth2AuthorizationSuccessHandler} that is called after the
* client is re-authorized, defaults to
* {@link RefreshTokenReactiveOAuth2AuthorizationSuccessHandler}.
* @param refreshTokenSuccessHandler the
* {@link ReactiveOAuth2AuthorizationSuccessHandler} to use
* @since 7.0
*/
public void setRefreshTokenSuccessHandler(ReactiveOAuth2AuthorizationSuccessHandler refreshTokenSuccessHandler) {
Assert.notNull(refreshTokenSuccessHandler, "refreshTokenSuccessHandler cannot be null");
this.refreshTokenSuccessHandler = refreshTokenSuccessHandler;
}

/**
* Sets the maximum acceptable clock skew, which is used when checking the
* {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -85,6 +85,7 @@
*
* @author Joe Grandja
* @author Phil Clay
* @author Evgeniy Cheban
* @since 5.2
* @see ReactiveOAuth2AuthorizedClientManager
* @see ReactiveOAuth2AuthorizedClientProvider
@@ -319,10 +320,10 @@ public Mono<Map<String, Object>> apply(OAuth2AuthorizeRequest authorizeRequest)
return Mono.justOrEmpty(serverWebExchange)
.switchIfEmpty(currentServerWebExchangeMono)
.flatMap((exchange) -> {
Map<String, Object> contextAttributes = Collections.emptyMap();
Map<String, Object> contextAttributes = new HashMap<>();
contextAttributes.put(ServerWebExchange.class.getName(), serverWebExchange);
String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE);
if (StringUtils.hasText(scope)) {
contextAttributes = new HashMap<>();
contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME,
StringUtils.delimitedListToStringArray(scope, " "));
}
Original file line number Diff line number Diff line change
@@ -51,6 +51,8 @@
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.function.client.ClientRequest;
@@ -96,6 +98,7 @@
* @author Rob Winch
* @author Joe Grandja
* @author Phil Clay
* @author Evgeniy Cheban
* @since 5.1
*/
public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
@@ -139,6 +142,8 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements

private ClientResponseHandler clientResponseHandler;

private ServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();

/**
* Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the
* provided parameters.
@@ -330,8 +335,11 @@ public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next)
}

private Mono<ClientResponse> exchangeAndHandleResponse(ClientRequest request, ExchangeFunction next) {
return next.exchange(request)
.transform((responseMono) -> this.clientResponseHandler.handleResponse(request, responseMono));
// Re-request an Authentication from serverSecurityContextRepository since it
// might have been changed during provider invocation.
return effectiveAuthentication(request).flatMap((authentication) -> next.exchange(request)
.transform((responseMono) -> this.clientResponseHandler.handleResponse(request, responseMono))
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(authentication)));
}

private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request) {
@@ -362,6 +370,17 @@ private Mono<OAuth2AuthorizeRequest> authorizeRequest(ClientRequest request) {
// @formatter:on
}

private Mono<Authentication> effectiveAuthentication(ClientRequest request) {
// @formatter:off
return effectiveServerWebExchange(request)
.filter(Optional::isPresent)
.map(Optional::get)
.flatMap(this.serverSecurityContextRepository::load)
.map(SecurityContext::getAuthentication)
.switchIfEmpty(this.currentAuthenticationMono);
// @formatter:on
}

/**
* Returns a {@link Mono} the emits the {@code clientRegistrationId} that is active
* for the given request.
@@ -445,6 +464,19 @@ public void setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHan
this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler);
}

/**
* Sets a {@link ServerSecurityContextRepository} to use for re-obtaining a
* {@link SecurityContext} if it has been refreshed during provider invocation,
* defaults to {@link WebSessionServerSecurityContextRepository}.
* @param serverSecurityContextRepository the {@link ServerSecurityContextRepository}
* to use
* @since 7.0
*/
public void setServerSecurityContextRepository(ServerSecurityContextRepository serverSecurityContextRepository) {
Assert.notNull(serverSecurityContextRepository, "serverSecurityContextRepository cannot be null");
this.serverSecurityContextRepository = serverSecurityContextRepository;
}

@FunctionalInterface
private interface ClientResponseHandler {

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -49,6 +49,7 @@
* Tests for {@link RefreshTokenReactiveOAuth2AuthorizedClientProvider}.
*
* @author Joe Grandja
* @author Evgeniy Cheban
*/
public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests {

@@ -84,6 +85,15 @@ public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgument
.withMessage("accessTokenResponseClient cannot be null");
}

@Test
public void setRefreshTokenSuccessHandlerWhenHandlerIsNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setRefreshTokenSuccessHandler(null))
.withMessage("refreshTokenSuccessHandler cannot be null");
// @formatter:on
}

@Test
public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,9 +18,13 @@

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

import reactor.core.publisher.Mono;

import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFunction;
@@ -29,14 +33,21 @@

/**
* @author Rob Winch
* @author Evgeniy Cheban
* @since 5.1
*/
public class MockExchangeFunction implements ExchangeFunction {

private final AtomicReference<Authentication> authenticationCaptor = new AtomicReference<>();

private List<ClientRequest> requests = new ArrayList<>();

private ClientResponse response = mock(ClientResponse.class);

public Authentication getCapturedAuthentication() {
return this.authenticationCaptor.get();
}

public ClientRequest getRequest() {
return this.requests.get(this.requests.size() - 1);
}
@@ -53,8 +64,14 @@ public ClientResponse getResponse() {
public Mono<ClientResponse> exchange(ClientRequest request) {
return Mono.defer(() -> {
this.requests.add(request);
return Mono.just(this.response);
return captureAuthentication().then(Mono.just(this.response));
});
}

private Mono<Authentication> captureAuthentication() {
return ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.doOnNext(this.authenticationCaptor::set);
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -60,6 +60,7 @@
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.oauth2.client.AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.ClientCredentialsReactiveOAuth2AuthorizedClientProvider;
@@ -68,6 +69,7 @@
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationFailureHandler;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationSuccessHandler;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
@@ -93,8 +95,10 @@
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.core.user.TestOAuth2Users;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.TestJwts;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.client.ClientRequest;
@@ -118,6 +122,7 @@

/**
* @author Rob Winch
* @author Evgeniy Cheban
* @since 5.1
*/
@ExtendWith(MockitoExtension.class)
@@ -144,6 +149,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Mock
private ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler;

@Mock
private ReactiveOAuth2AuthorizationSuccessHandler refreshTokenSuccessHandler;

@Captor
private ArgumentCaptor<OAuth2AuthorizationException> authorizationExceptionCaptor;

@@ -178,7 +186,8 @@ public void setup() {
.builder()
.authorizationCode()
.refreshToken(
(configurer) -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient))
(configurer) -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)
.refreshTokenSuccessHandler(this.refreshTokenSuccessHandler))
.clientCredentials(
(configurer) -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient))
.password((configurer) -> configurer.accessTokenResponseClient(this.passwordTokenResponseClient))
@@ -210,6 +219,13 @@ public void constructorWhenAuthorizedClientManagerIsNullThenThrowIllegalArgument
.isThrownBy(() -> new ServerOAuth2AuthorizedClientExchangeFilterFunction(null));
}

@Test
public void setServerSecurityContextRepositoryWhenHandlerIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientManager)
.setServerSecurityContextRepository(null));
}

@Test
public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() {
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build();
@@ -335,14 +351,23 @@ public void filterWhenRefreshRequiredThenRefresh() {
// @formatter:on
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
// @formatter:off
WebSessionServerSecurityContextRepository securityContextRepository = new WebSessionServerSecurityContextRepository();
DefaultOAuth2User refreshedUser = TestOAuth2Users.create();
OAuth2AuthenticationToken refreshedAuthentication = new OAuth2AuthenticationToken(refreshedUser, refreshedUser.getAuthorities(), this.registration.getRegistrationId());
SecurityContextImpl securityContext = new SecurityContextImpl(refreshedAuthentication);
given(this.refreshTokenSuccessHandler.onAuthorizationSuccess(any(), eq(authentication), any()))
.willReturn(securityContextRepository.save(this.serverWebExchange, securityContext));
this.function.filter(request, this.exchange)
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(authentication))
.contextWrite(serverWebExchange())
.block();
Authentication currentAuthentication = this.exchange.getCapturedAuthentication();
assertThat(currentAuthentication).isSameAs(refreshedAuthentication);
// @formatter:on
verify(this.refreshTokenTokenResponseClient).getTokenResponse(any());
verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(),
eq(authentication), any());
verify(this.refreshTokenSuccessHandler).onAuthorizationSuccess(any(), eq(authentication), any());
OAuth2AuthorizedClient newAuthorizedClient = this.authorizedClientCaptor.getValue();
assertThat(newAuthorizedClient.getAccessToken()).isEqualTo(response.getAccessToken());
assertThat(newAuthorizedClient.getRefreshToken()).isEqualTo(response.getRefreshToken());
@@ -364,6 +389,7 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved()
.refreshToken("refresh-1")
.build();
given(this.refreshTokenTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(response));
given(this.refreshTokenSuccessHandler.onAuthorizationSuccess(any(), any(), any())).willReturn(Mono.empty());
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(),