Skip to content

Commit

Permalink
let to accept mfa channel and mfa method to mfa sender and verifier a…
Browse files Browse the repository at this point in the history
…bstraction
  • Loading branch information
mrFlick72 committed Jul 26, 2024
1 parent 9aaa715 commit 735fba0
Show file tree
Hide file tree
Showing 16 changed files with 129 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import com.vauthenticator.server.mfa.domain.MfaMethod
import com.vauthenticator.server.mfa.domain.MfaMethodsEnrollment
import com.vauthenticator.server.oauth2.clientapp.ClientAppId
import com.vauthenticator.server.ticket.Ticket
import com.vauthenticator.server.ticket.Ticket.Companion.MFA_AUTO_ASSOCIATION_CONTEXT_VALUE
import com.vauthenticator.server.ticket.Ticket.Companion.MFA_SELF_ASSOCIATION_CONTEXT_VALUE
import com.vauthenticator.server.ticket.TicketId
import org.slf4j.LoggerFactory

Expand All @@ -31,7 +31,7 @@ class SendVerifyEMailChallenge(
account.email,
ClientAppId.empty(),
false,
mapOf(Ticket.MFA_AUTO_ASSOCIATION_CONTEXT_KEY to MFA_AUTO_ASSOCIATION_CONTEXT_VALUE)
mapOf(Ticket.MFA_SELF_ASSOCIATION_CONTEXT_KEY to MFA_SELF_ASSOCIATION_CONTEXT_VALUE)
)
val mailContext = mailContextFrom(verificationTicket)
mailVerificationMailSender.sendFor(account, mailContext)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.vauthenticator.server.mfa.api

import com.vauthenticator.server.mfa.domain.MfaMethod
import com.vauthenticator.server.mfa.domain.OtpMfaSender
import org.springframework.security.core.Authentication
import org.springframework.web.bind.annotation.PutMapping
Expand All @@ -10,7 +11,7 @@ class MfaChallengeEndPoint(private val otpMfaSender: OtpMfaSender) {

@PutMapping("/api/mfa/challenge")
fun sendMfaChallenge(authentication: Authentication) {
otpMfaSender.sendMfaChallenge(authentication.name, authentication.name)
otpMfaSender.sendMfaChallenge(authentication.name, MfaMethod.EMAIL_MFA_METHOD, authentication.name)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@ import com.vauthenticator.server.account.Account
import com.vauthenticator.server.mfa.repository.MfaAccountMethodsRepository
import com.vauthenticator.server.oauth2.clientapp.ClientAppId
import com.vauthenticator.server.ticket.*
import com.vauthenticator.server.ticket.Ticket.Companion.MFA_AUTO_ASSOCIATION_CONTEXT_KEY
import com.vauthenticator.server.ticket.Ticket.Companion.MFA_AUTO_ASSOCIATION_CONTEXT_VALUE
import com.vauthenticator.server.ticket.Ticket.Companion.MFA_CHANNEL_CONTEXT_KEY
import com.vauthenticator.server.ticket.Ticket.Companion.MFA_METHOD_CONTEXT_KEY
import com.vauthenticator.server.ticket.Ticket.Companion.MFA_SELF_ASSOCIATION_CONTEXT_KEY

typealias MfaAssociationVerifier = (ticket: Ticket) -> Unit

Expand All @@ -21,8 +18,8 @@ class MfaMethodsEnrollmentAssociation(
associate(
ticketId,
) {
if(it.context.content[MFA_AUTO_ASSOCIATION_CONTEXT_KEY] != MFA_AUTO_ASSOCIATION_CONTEXT_VALUE){
throw InvalidTicketException("Mfa association without code is allowed only if in the ticket context there is $MFA_AUTO_ASSOCIATION_CONTEXT_KEY feature enabled")
if (it.context.isMfaNotSelfAssociable()) {
throw InvalidTicketException("Mfa association without code is allowed only if in the ticket context there is $MFA_SELF_ASSOCIATION_CONTEXT_KEY feature enabled")
}
}

Expand All @@ -31,7 +28,14 @@ class MfaMethodsEnrollmentAssociation(
fun associate(ticketId: String, code: String) {
associate(
ticketId,
) { otpMfaVerifier.verifyMfaChallengeFor(it.userName, MfaChallenge(code)) }
) {
otpMfaVerifier.verifyMfaChallengeFor(
it.userName,
it.context.mfaMethod(),
it.context.mfaChannel(),
MfaChallenge(code)
)
}
}

private fun associate(ticket: String, verifier: MfaAssociationVerifier) {
Expand Down Expand Up @@ -70,17 +74,16 @@ class MfaMethodsEnrollment(
)

if (sendChallengeCode) {
mfaSender.sendMfaChallenge(email, mfaChannel)
mfaSender.sendMfaChallenge(email, mfaMethod, mfaChannel)
}

return ticketCreator.createTicketFor(
account,
clientAppId,
TicketContext(
mapOf(
MFA_CHANNEL_CONTEXT_KEY to mfaChannel,
MFA_METHOD_CONTEXT_KEY to mfaMethod.name
) + ticketContextAdditionalProperties
TicketContext.mfaContextFor(
mfaMethod = mfaMethod,
mfaChannel = mfaChannel,
ticketContextAdditionalProperties = ticketContextAdditionalProperties
)
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ import com.vauthenticator.server.email.EMailSenderService


interface OtpMfaSender {
fun sendMfaChallenge(userName: String, challengeChannel: String)
fun sendMfaChallenge(userName: String, mfaMethod: MfaMethod, mfaChannel: String)
}

interface OtpMfaVerifier {
fun verifyMfaChallengeFor(userName: String, challenge: MfaChallenge)
fun verifyMfaChallengeFor(userName: String, mfaMethod: MfaMethod, mfaChannel: String, challenge: MfaChallenge)
}

class OtpMfaEmailSender(
Expand All @@ -18,21 +18,26 @@ class OtpMfaEmailSender(
private val mfaMailSender: EMailSenderService
) : OtpMfaSender {

override fun sendMfaChallenge(userName: String, challengeChannel: String) {
override fun sendMfaChallenge(userName: String, mfaMethod: MfaMethod, mfaChannel: String) {
val account = accountRepository.accountFor(userName).get()
val mfaSecret = otpMfa.generateSecretKeyFor(account)
val mfaSecret = otpMfa.generateSecretKeyFor(account, mfaMethod, mfaChannel)
val mfaCode = otpMfa.getTOTPCode(mfaSecret).content()
mfaMailSender.sendFor(account, mapOf("email" to challengeChannel, "mfaCode" to mfaCode))
mfaMailSender.sendFor(account, mapOf("email" to mfaChannel, "mfaCode" to mfaCode))
}
}

class AccountAwareOtpMfaVerifier(
private val accountRepository: AccountRepository,
private val otpMfa: OtpMfa
) : OtpMfaVerifier {
override fun verifyMfaChallengeFor(userName: String, challenge: MfaChallenge) {
override fun verifyMfaChallengeFor(
userName: String,
mfaMethod: MfaMethod,
mfaChannel: String,
challenge: MfaChallenge
) {
val account = accountRepository.accountFor(userName).get()
otpMfa.verify(account, challenge)
otpMfa.verify(account, mfaMethod, mfaChannel, challenge)
}

}
19 changes: 9 additions & 10 deletions src/main/kotlin/com/vauthenticator/server/mfa/domain/OtpMfa.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import org.apache.commons.codec.binary.Hex

//todo the interface has to take in account the enrolled method
interface OtpMfa {
fun generateSecretKeyFor(account: Account): MfaSecret
fun generateSecretKeyFor(account: Account, mfaMethod: MfaMethod, mfaChannel: String): MfaSecret
fun getTOTPCode(secretKey: MfaSecret): MfaChallenge
fun verify(account: Account, optCode: MfaChallenge)
fun verify(account: Account, mfaMethod: MfaMethod, mfaChannel: String, optCode: MfaChallenge)
}

class TaimosOtpMfa(
Expand All @@ -26,12 +26,11 @@ class TaimosOtpMfa(
private val tokenTimeWindow: Int = properties.timeToLiveInSeconds
private val tokenTimeWindowMillis: Long = (tokenTimeWindow * 1000).toLong()

// todo to be improved
override fun generateSecretKeyFor(account: Account): MfaSecret {
//todo
val mfatMethod =
mfaAccountMethodsRepository.findOne(account.email, MfaMethod.EMAIL_MFA_METHOD, account.email).orElseGet { null }
val encryptedSecret = keyRepository.keyFor(mfatMethod.key, KeyPurpose.MFA)
override fun generateSecretKeyFor(account: Account, mfaMethod: MfaMethod, mfaChannel: String): MfaSecret {
val mfaAccountMethod =
mfaAccountMethodsRepository.findOne(account.email, mfaMethod, mfaChannel)
.orElseGet { null }
val encryptedSecret = keyRepository.keyFor(mfaAccountMethod.key, KeyPurpose.MFA)
val decryptKeyAsByteArray = keyDecrypter.decryptKey(encryptedSecret.dataKey.encryptedPrivateKeyAsString())
val decryptedKey = Hex.encodeHexString(decoder.decode(decryptKeyAsByteArray))
return MfaSecret(decryptedKey)
Expand All @@ -48,8 +47,8 @@ class TaimosOtpMfa(
)
}

override fun verify(account: Account, optCode: MfaChallenge) {
val mfaSecret = generateSecretKeyFor(account)
override fun verify(account: Account, mfaMethod: MfaMethod, mfaChannel: String, optCode: MfaChallenge) {
val mfaSecret = generateSecretKeyFor(account, mfaMethod, mfaChannel)
try {
val validated =
TimeBasedOneTimePasswordUtil.validateCurrentNumberHex(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class MfaController(

@GetMapping("/mfa-challenge/send")
fun view(authentication: Authentication): String {
otpMfaSender.sendMfaChallenge(authentication.name, authentication.name)
otpMfaSender.sendMfaChallenge(authentication.name, MfaMethod.EMAIL_MFA_METHOD, authentication.name)
return "redirect:/mfa-challenge"
}

Expand All @@ -48,13 +48,15 @@ class MfaController(
@PostMapping("/mfa-challenge")
fun processSecondFactor(
@RequestParam("mfa-code") mfaCode: String,
@RequestParam("mfa-method") mfaMethod: MfaMethod,
@RequestParam("mfa-channel") mfaChannel: String,
authentication: Authentication,
request: HttpServletRequest,
response: HttpServletResponse
) {
try {
otpMfaVerifier.verifyMfaChallengeFor(authentication.name, MfaChallenge(mfaCode))
publisher.publishEvent(MfaSuccessEvent( authentication))
otpMfaVerifier.verifyMfaChallengeFor(authentication.name, mfaMethod, mfaChannel, MfaChallenge(mfaCode))
publisher.publishEvent(MfaSuccessEvent(authentication))
nextHopeLoginWorkflowSuccessHandler.onAuthenticationSuccess(request, response, authentication)
} catch (e: Exception) {
logger.error(e.message, e)
Expand Down
31 changes: 29 additions & 2 deletions src/main/kotlin/com/vauthenticator/server/ticket/Ticket.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
package com.vauthenticator.server.ticket

import com.vauthenticator.server.mfa.domain.MfaMethod
import com.vauthenticator.server.ticket.Ticket.Companion.MFA_CHANNEL_CONTEXT_KEY
import com.vauthenticator.server.ticket.Ticket.Companion.MFA_METHOD_CONTEXT_KEY
import com.vauthenticator.server.ticket.Ticket.Companion.MFA_NOT_SELF_ASSOCIATION_CONTEXT_VALUE
import com.vauthenticator.server.ticket.Ticket.Companion.MFA_SELF_ASSOCIATION_CONTEXT_KEY
import com.vauthenticator.server.ticket.Ticket.Companion.MFA_SELF_ASSOCIATION_CONTEXT_VALUE
import java.time.Duration

data class Ticket(
Expand All @@ -12,16 +18,37 @@ data class Ticket(
companion object {
const val MFA_CHANNEL_CONTEXT_KEY = "mfaChannel"
const val MFA_METHOD_CONTEXT_KEY = "mfaMethod"
const val MFA_AUTO_ASSOCIATION_CONTEXT_KEY = "auto-association"
const val MFA_AUTO_ASSOCIATION_CONTEXT_VALUE = "true"
const val MFA_SELF_ASSOCIATION_CONTEXT_KEY = "selfAssociation"
const val MFA_SELF_ASSOCIATION_CONTEXT_VALUE = "true"
const val MFA_NOT_SELF_ASSOCIATION_CONTEXT_VALUE = "false"
}
}

data class TicketContext(val content: Map<String, String>) {

companion object {
fun empty() = TicketContext(emptyMap())
fun mfaContextFor(
mfaMethod: MfaMethod,
mfaChannel: String,
autoAssociation: Boolean = false,
ticketContextAdditionalProperties: Map<String, String>
) = TicketContext(
mapOf(
MFA_CHANNEL_CONTEXT_KEY to mfaChannel,
MFA_METHOD_CONTEXT_KEY to mfaMethod.name,
MFA_SELF_ASSOCIATION_CONTEXT_KEY to if (autoAssociation) {
MFA_SELF_ASSOCIATION_CONTEXT_VALUE
} else {
MFA_NOT_SELF_ASSOCIATION_CONTEXT_VALUE
}
) + ticketContextAdditionalProperties
)
}

fun isMfaNotSelfAssociable() = content[MFA_SELF_ASSOCIATION_CONTEXT_KEY] != MFA_SELF_ASSOCIATION_CONTEXT_VALUE
fun mfaMethod() = MfaMethod.valueOf(content[MFA_METHOD_CONTEXT_KEY]!!)
fun mfaChannel() = content[MFA_CHANNEL_CONTEXT_KEY]!!
}

data class TicketId(val content: String)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import com.vauthenticator.server.oauth2.clientapp.Scope
import com.vauthenticator.server.oauth2.clientapp.Scopes
import com.vauthenticator.server.support.AccountTestFixture.anAccount
import com.vauthenticator.server.ticket.Ticket
import com.vauthenticator.server.ticket.Ticket.Companion.MFA_AUTO_ASSOCIATION_CONTEXT_VALUE
import com.vauthenticator.server.ticket.Ticket.Companion.MFA_SELF_ASSOCIATION_CONTEXT_VALUE
import com.vauthenticator.server.ticket.TicketId
import io.mockk.every
import io.mockk.impl.annotations.MockK
Expand Down Expand Up @@ -69,7 +69,7 @@ internal class SendVerifyEMailChallengeTest {
account.email,
ClientAppId.empty(),
false,
mapOf(Ticket.MFA_AUTO_ASSOCIATION_CONTEXT_KEY to MFA_AUTO_ASSOCIATION_CONTEXT_VALUE)
mapOf(Ticket.MFA_SELF_ASSOCIATION_CONTEXT_KEY to MFA_SELF_ASSOCIATION_CONTEXT_VALUE)
)
} returns ticketId
every { mailVerificationMailSender.sendFor(account, requestContext) } just runs
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.vauthenticator.server.mfa.api

import com.vauthenticator.server.mfa.domain.MfaMethod
import com.vauthenticator.server.mfa.domain.OtpMfaSender
import com.vauthenticator.server.support.AccountTestFixture
import com.vauthenticator.server.support.SecurityFixture
Expand Down Expand Up @@ -35,7 +36,7 @@ internal class MfaChallengeEndPointTest {

@Test
internal fun `when an mfa challenge is sent`() {
every { otpMfaSender.sendMfaChallenge(account.email,account.email) } just runs
every { otpMfaSender.sendMfaChallenge(account.email, MfaMethod.EMAIL_MFA_METHOD, account.email) } just runs

mokMvc.perform(
put("/api/mfa/challenge")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ internal class AccountAwareOtpMfaVerifierTest {
val underTest = AccountAwareOtpMfaVerifier(accountRepository, otpMfa)

every { accountRepository.accountFor(account.email) } returns Optional.of(account)
every { otpMfa.verify(account, challenge) } just runs
every { otpMfa.verify(account, MfaMethod.EMAIL_MFA_METHOD, account.email, challenge) } just runs

underTest.verifyMfaChallengeFor(account.email, challenge)
underTest.verifyMfaChallengeFor(account.email, MfaMethod.EMAIL_MFA_METHOD, account.email, challenge)
}

@Test
Expand All @@ -39,10 +39,17 @@ internal class AccountAwareOtpMfaVerifierTest {
val challenge = MfaChallenge("AN_MFA_CHALLENGE")

every { accountRepository.accountFor(account.email) } returns Optional.of(account)
every { otpMfa.verify(account, challenge) } throws MfaException("")
every { otpMfa.verify(account, MfaMethod.EMAIL_MFA_METHOD, account.email, challenge) } throws MfaException("")

val underTest = AccountAwareOtpMfaVerifier(accountRepository, otpMfa)

assertThrows(MfaException::class.java) { underTest.verifyMfaChallengeFor(account.email, challenge) }
assertThrows(MfaException::class.java) {
underTest.verifyMfaChallengeFor(
account.email,
MfaMethod.EMAIL_MFA_METHOD,
account.email,
challenge
)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class MfaMethodsEnrollmentTest {
)
} returns Optional.of(emailMfaAccountMethod)
every { ticketCreator.createTicketFor(account, clientAppId, ticketContext(emailMfaChannel)) } returns ticketId
every { mfaSender.sendMfaChallenge(account.email, emailMfaChannel) } just runs
every { mfaSender.sendMfaChallenge(account.email, EMAIL_MFA_METHOD,emailMfaChannel) } just runs

val actual = uut.enroll(account, EMAIL_MFA_METHOD, emailMfaChannel, clientAppId, true)

Expand All @@ -102,7 +102,7 @@ class MfaMethodsEnrollmentTest {
)
}
verify { ticketCreator.createTicketFor(account, clientAppId, ticketContext(emailMfaChannel)) }
verify { mfaSender.sendMfaChallenge(account.email, emailMfaChannel) }
verify { mfaSender.sendMfaChallenge(account.email, EMAIL_MFA_METHOD,emailMfaChannel) }

assertEquals(ticketId, actual)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ internal class OtpMfaEmailSenderTest {
val underTest = OtpMfaEmailSender(accountRepository, otp, mfaMailSender)

every { accountRepository.accountFor(account.email) } returns Optional.of(account)
every { otp.generateSecretKeyFor(account) } returns mfaSecret
every { otp.generateSecretKeyFor(account, MfaMethod.EMAIL_MFA_METHOD, account.email) } returns mfaSecret
every { otp.getTOTPCode(mfaSecret) } returns mfaChallenge
every {
mfaMailSender.sendFor(
Expand All @@ -41,6 +41,6 @@ internal class OtpMfaEmailSenderTest {
)
} just runs

underTest.sendMfaChallenge(account.email, account.email)
underTest.sendMfaChallenge(account.email, MfaMethod.EMAIL_MFA_METHOD, account.email)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class TaimosOtpMfaTest {

every { keyRepository.keyFor(Kid("A_KID"), KeyPurpose.MFA) } returns key
every { keyDecrypter.decryptKey("QV9FTkNSWVBURURfS0VZ") } returns "QV9ERUNSWVBURURfU1lNTUVUUklDX0tFWQ=="
val actual = underTest.generateSecretKeyFor(account)
val actual = underTest.generateSecretKeyFor(account, MfaMethod.EMAIL_MFA_METHOD, email)
val expectedSecret = Hex.encodeHexString(decoder.decode("QV9ERUNSWVBURURfU1lNTUVUUklDX0tFWQ=="))
assertEquals(MfaSecret(expectedSecret), actual)
}
Expand Down
Loading

0 comments on commit 735fba0

Please sign in to comment.