Skip to content

Add support one-time token value customization #16946

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
package org.springframework.security.authentication.ott;

import java.time.Duration;
import java.util.UUID;

import org.springframework.util.Assert;

/**
* Class to store information related to an One-Time Token authentication request
*
* @author Marcus da Coregio
* @author Max Batiscev
* @since 6.4
*/
public class GenerateOneTimeTokenRequest {
Expand All @@ -34,6 +36,8 @@ public class GenerateOneTimeTokenRequest {

private final Duration expiresIn;

private final String tokenValue;

public GenerateOneTimeTokenRequest(String username) {
this(username, DEFAULT_EXPIRES_IN);
}
Expand All @@ -43,6 +47,16 @@ public GenerateOneTimeTokenRequest(String username, Duration expiresIn) {
Assert.notNull(expiresIn, "expiresIn cannot be null");
this.username = username;
this.expiresIn = expiresIn;
this.tokenValue = UUID.randomUUID().toString();
}

public GenerateOneTimeTokenRequest(String username, Duration expiresIn, String tokenValue) {
Assert.hasText(username, "username cannot be empty");
Assert.hasText(tokenValue, "tokenValue cannot be empty");
Assert.notNull(expiresIn, "expiresIn cannot be null");
this.username = username;
this.expiresIn = expiresIn;
this.tokenValue = tokenValue;
}

public String getUsername() {
Expand All @@ -53,4 +67,8 @@ public Duration getExpiresIn() {
return this.expiresIn;
}

public String getTokenValue() {
return this.tokenValue;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
* there is more or equal than 100 tokens stored in the map.
*
* @author Marcus da Coregio
* @author Max Batischev
* @since 6.4
*/
public final class InMemoryOneTimeTokenService implements OneTimeTokenService {
Expand All @@ -43,10 +44,9 @@ public final class InMemoryOneTimeTokenService implements OneTimeTokenService {
@Override
@NonNull
public OneTimeToken generate(GenerateOneTimeTokenRequest request) {
String token = UUID.randomUUID().toString();
Instant expiresAt = this.clock.instant().plus(request.getExpiresIn());
OneTimeToken ott = new DefaultOneTimeToken(token, request.getUsername(), expiresAt);
this.oneTimeTokenByToken.put(token, ott);
OneTimeToken ott = new DefaultOneTimeToken(request.getTokenValue(), request.getUsername(), expiresAt);
this.oneTimeTokenByToken.put(request.getTokenValue(), ott);
cleanExpiredTokensIfNeeded();
return ott;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.function.Function;

import org.apache.commons.logging.Log;
Expand Down Expand Up @@ -130,9 +129,8 @@ public void setCleanupCron(String cleanupCron) {
@Override
public OneTimeToken generate(GenerateOneTimeTokenRequest request) {
Assert.notNull(request, "generateOneTimeTokenRequest cannot be null");
String token = UUID.randomUUID().toString();
Instant expiresAt = this.clock.instant().plus(request.getExpiresIn());
OneTimeToken oneTimeToken = new DefaultOneTimeToken(token, request.getUsername(), expiresAt);
OneTimeToken oneTimeToken = new DefaultOneTimeToken(request.getTokenValue(), request.getUsername(), expiresAt);
insertOneTimeToken(oneTimeToken);
return oneTimeToken;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.springframework.security.web.authentication.ott;

import java.time.Duration;
import java.util.UUID;
import java.util.function.Supplier;

import jakarta.servlet.http.HttpServletRequest;

Expand All @@ -37,13 +39,15 @@ public final class DefaultGenerateOneTimeTokenRequestResolver implements Generat

private Duration expiresIn = DEFAULT_EXPIRES_IN;

private Supplier<String> tokenValueFactory = () -> UUID.randomUUID().toString();

@Override
public GenerateOneTimeTokenRequest resolve(HttpServletRequest request) {
String username = request.getParameter("username");
if (!StringUtils.hasText(username)) {
return null;
}
return new GenerateOneTimeTokenRequest(username, this.expiresIn);
return new GenerateOneTimeTokenRequest(username, this.expiresIn, this.tokenValueFactory.get());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should not generate token value here, since it breaks the semantic of GenerateOneTimeTokenRequestResolver which is:
"A strategy for resolving a GenerateOneTimeTokenRequest from the HttpServletRequest".

It might also be confusing for users who provide a custom tokenValueFactory and notice that their factory gets called before OneTimeTokenService#generate method call.

So I would consider passing Supplier<String> tokenValueFactory to GenerateOneTimeTokenRequest and call it in the OneTimeTokenService#generate method.

}

/**
Expand All @@ -55,4 +59,14 @@ public void setExpiresIn(Duration expiresIn) {
this.expiresIn = expiresIn;
}

/**
* Sets factory for token value generation
* @param tokenValueFactory factory for token value generation
* @since 6.5
*/
public void setTokenValueFactory(Supplier<String> tokenValueFactory) {
Assert.notNull(tokenValueFactory, "tokenValueFactory cannot be null");
this.tokenValueFactory = tokenValueFactory;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.springframework.security.web.server.authentication.ott;

import java.time.Duration;
import java.util.UUID;
import java.util.function.Supplier;

import reactor.core.publisher.Mono;

Expand All @@ -40,13 +42,15 @@ public final class DefaultServerGenerateOneTimeTokenRequestResolver

private Duration expiresIn = DEFAULT_EXPIRES_IN;

private Supplier<String> tokenValueFactory = () -> UUID.randomUUID().toString();

@Override
public Mono<GenerateOneTimeTokenRequest> resolve(ServerWebExchange exchange) {
// @formatter:off
return exchange.getFormData()
.mapNotNull((data) -> data.getFirst(USERNAME))
.switchIfEmpty(Mono.empty())
.map((username) -> new GenerateOneTimeTokenRequest(username, this.expiresIn));
.map((username) -> new GenerateOneTimeTokenRequest(username, this.expiresIn, this.tokenValueFactory.get()));
// @formatter:on
}

Expand All @@ -59,4 +63,14 @@ public void setExpiresIn(Duration expiresIn) {
this.expiresIn = expiresIn;
}

/**
* Sets factory for token value generation
* @param tokenValueFactory factory for token value generation
* @since 6.5
*/
public void setTokenValueFactory(Supplier<String> tokenValueFactory) {
Assert.notNull(tokenValueFactory, "tokenValueFactory cannot be null");
this.tokenValueFactory = tokenValueFactory;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,15 @@ void resolveWhenExpiresInSetThenResolvesGenerateRequest() {
assertThat(generateRequest.getExpiresIn()).isEqualTo(Duration.ofSeconds(600));
}

@Test
void resolveWhenTokenValueFactorySetThenResolvesGenerateRequest() {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setParameter("username", "test");
this.requestResolver.setTokenValueFactory(() -> "tokenValue");

GenerateOneTimeTokenRequest generateRequest = this.requestResolver.resolve(request);

assertThat(generateRequest.getTokenValue()).isEqualTo("tokenValue");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,16 @@ void resolveWhenExpiresInSetThenResolvesGenerateRequest() {
assertThat(generateRequest.getExpiresIn()).isEqualTo(Duration.ofSeconds(600));
}

@Test
void resolveWhenTokenValueFactorySetThenResolvesGenerateRequest() {
MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.post("/ott/generate")
.contentType(MediaType.APPLICATION_FORM_URLENCODED)
.body("username=user"));
this.resolver.setTokenValueFactory(() -> "tokenValue");

GenerateOneTimeTokenRequest generateRequest = this.resolver.resolve(exchange).block();

assertThat(generateRequest.getTokenValue()).isEqualTo("tokenValue");
}

}