diff --git a/src/main/kotlin/com/vauthenticator/server/mfa/adapter/dynamodb/DynamoMfaAccountMethodsRepository.kt b/src/main/kotlin/com/vauthenticator/server/mfa/adapter/dynamodb/DynamoMfaAccountMethodsRepository.kt index 09cfa26c..7203b9d1 100644 --- a/src/main/kotlin/com/vauthenticator/server/mfa/adapter/dynamodb/DynamoMfaAccountMethodsRepository.kt +++ b/src/main/kotlin/com/vauthenticator/server/mfa/adapter/dynamodb/DynamoMfaAccountMethodsRepository.kt @@ -83,11 +83,16 @@ class DynamoMfaAccountMethodsRepository( mfaChannel: String, associated: Boolean ): MfaAccountMethod { - //todo kid and device id should be the same across save request - - val kid = keyRepository.createKeyFrom(masterKid, KeyType.SYMMETRIC, KeyPurpose.MFA) - val mfaDeviceId = mfaDeviceIdGenerator.invoke() - storeOnDynamo(userName, mfaMfaMethod, mfaChannel, mfaDeviceId, kid, associated) + val (kid, mfaDeviceId) = findBy(userName, mfaMfaMethod, mfaChannel) + .map { + listOf(it.key, it.mfaDeviceId) + }.orElseGet { + val kid = keyRepository.createKeyFrom(masterKid, KeyType.SYMMETRIC, KeyPurpose.MFA) + val mfaDeviceId = mfaDeviceIdGenerator.invoke() + listOf(kid, mfaDeviceId) + } + + storeOnDynamo(userName, mfaMfaMethod, mfaChannel, mfaDeviceId as MfaDeviceId, kid as Kid, associated) return MfaAccountMethod(userName, mfaDeviceId, kid, mfaMfaMethod, mfaChannel, associated) } diff --git a/src/main/kotlin/com/vauthenticator/server/mfa/adapter/jdbc/JdbcMfaAccountMethodsRepository.kt b/src/main/kotlin/com/vauthenticator/server/mfa/adapter/jdbc/JdbcMfaAccountMethodsRepository.kt index c1b44b34..ab27890f 100644 --- a/src/main/kotlin/com/vauthenticator/server/mfa/adapter/jdbc/JdbcMfaAccountMethodsRepository.kt +++ b/src/main/kotlin/com/vauthenticator/server/mfa/adapter/jdbc/JdbcMfaAccountMethodsRepository.kt @@ -47,13 +47,34 @@ class JdbcMfaAccountMethodsRepository( mfaChannel: String, associated: Boolean ): MfaAccountMethod { - val kid = keyRepository.createKeyFrom(masterKid, KeyType.SYMMETRIC, KeyPurpose.MFA) - val mfaDeviceId = mfaDeviceIdGenerator.invoke() + val (kid, mfaDeviceId) = findBy(userName, mfaMfaMethod, mfaChannel) + .map { + listOf(it.key, it.mfaDeviceId) + }.orElseGet { + val kid = keyRepository.createKeyFrom(masterKid, KeyType.SYMMETRIC, KeyPurpose.MFA) + val mfaDeviceId = mfaDeviceIdGenerator.invoke() + listOf(kid, mfaDeviceId) + } jdbcTemplate.update( - "INSERT INTO MFA_ACCOUNT_METHODS (user_name, mfa_device_id, mfa_method, mfa_channel, key_id, associated) VALUES (?,?,?,?,?,?)", - userName, mfaDeviceId.content, mfaMfaMethod.name, mfaChannel, kid.content(), associated - ) + """INSERT INTO MFA_ACCOUNT_METHODS (user_name, mfa_device_id, mfa_method, mfa_channel, key_id, associated) VALUES (?,?,?,?,?,?) + ON CONFLICT(user_name, mfa_channel) + DO UPDATE SET mfa_device_id=?, + mfa_method=?, + key_id=?, + associated=? + """, + userName, + (mfaDeviceId as MfaDeviceId).content, + mfaMfaMethod.name, + mfaChannel, + (kid as Kid).content(), + associated, + mfaDeviceId.content, + mfaMfaMethod.name, + kid.content(), + associated + ) return MfaAccountMethod(userName, mfaDeviceId, kid, mfaMfaMethod, mfaChannel, associated) } diff --git a/src/test/kotlin/com/vauthenticator/server/mfa/adapter/AbstractMfaAccountMethodsRepositoryTest.kt b/src/test/kotlin/com/vauthenticator/server/mfa/adapter/AbstractMfaAccountMethodsRepositoryTest.kt index fb2687b6..0a999065 100644 --- a/src/test/kotlin/com/vauthenticator/server/mfa/adapter/AbstractMfaAccountMethodsRepositoryTest.kt +++ b/src/test/kotlin/com/vauthenticator/server/mfa/adapter/AbstractMfaAccountMethodsRepositoryTest.kt @@ -35,7 +35,7 @@ abstract class AbstractMfaAccountMethodsRepositoryTest { @Test fun `when a mfa account method is stored`() { - every { keyRepository.createKeyFrom(masterKid, KeyType.SYMMETRIC, KeyPurpose.MFA) } returns key + whenAKeyIsStored() uut.save(email, MfaMethod.EMAIL_MFA_METHOD, email, true) val mfaAccountMethods = uut.findAll(email) @@ -86,7 +86,7 @@ abstract class AbstractMfaAccountMethodsRepositoryTest { @Test fun `when decide what mfa use as default`() { - every { keyRepository.createKeyFrom(masterKid, KeyType.SYMMETRIC, KeyPurpose.MFA) } returns key + whenAKeyIsStored() uut.save(email, MfaMethod.EMAIL_MFA_METHOD, email, true) val expected = Optional.of(mfaDeviceId) @@ -95,4 +95,28 @@ abstract class AbstractMfaAccountMethodsRepositoryTest { assertEquals(expected, defaultDevice) } + + @Test + fun `when a mfa account method is stored and then enabled`() { + whenAKeyIsStored() + + uut.save(email, MfaMethod.EMAIL_MFA_METHOD, email, false) + val beforeToBeAssociated = uut.findBy(email, MfaMethod.EMAIL_MFA_METHOD, email).get() + + uut.save(email, MfaMethod.EMAIL_MFA_METHOD, email, true) + val afterAssociated = uut.findBy(email, MfaMethod.EMAIL_MFA_METHOD, email).get() + + + assertEquals(afterAssociated.mfaDeviceId, beforeToBeAssociated.mfaDeviceId) + assertEquals(afterAssociated.mfaChannel, beforeToBeAssociated.mfaChannel) + assertEquals(afterAssociated.mfaMethod, beforeToBeAssociated.mfaMethod) + assertEquals(beforeToBeAssociated.associated,false) + assertEquals(afterAssociated.associated, true) + assertEquals(afterAssociated.key, beforeToBeAssociated.key) + assertEquals(afterAssociated.userName, beforeToBeAssociated.userName) + } + + private fun whenAKeyIsStored() { + every { keyRepository.createKeyFrom(masterKid, KeyType.SYMMETRIC, KeyPurpose.MFA) } returns key + } } \ No newline at end of file