diff --git a/src/main/kotlin/com/vauthenticator/server/mfa/adapter/dynamodb/DynamoMfaAccountMethodsRepository.kt b/src/main/kotlin/com/vauthenticator/server/mfa/adapter/dynamodb/DynamoMfaAccountMethodsRepository.kt index 09cfa26c..7203b9d1 100644 --- a/src/main/kotlin/com/vauthenticator/server/mfa/adapter/dynamodb/DynamoMfaAccountMethodsRepository.kt +++ b/src/main/kotlin/com/vauthenticator/server/mfa/adapter/dynamodb/DynamoMfaAccountMethodsRepository.kt @@ -83,11 +83,16 @@ class DynamoMfaAccountMethodsRepository( mfaChannel: String, associated: Boolean ): MfaAccountMethod { - //todo kid and device id should be the same across save request - - val kid = keyRepository.createKeyFrom(masterKid, KeyType.SYMMETRIC, KeyPurpose.MFA) - val mfaDeviceId = mfaDeviceIdGenerator.invoke() - storeOnDynamo(userName, mfaMfaMethod, mfaChannel, mfaDeviceId, kid, associated) + val (kid, mfaDeviceId) = findBy(userName, mfaMfaMethod, mfaChannel) + .map { + listOf(it.key, it.mfaDeviceId) + }.orElseGet { + val kid = keyRepository.createKeyFrom(masterKid, KeyType.SYMMETRIC, KeyPurpose.MFA) + val mfaDeviceId = mfaDeviceIdGenerator.invoke() + listOf(kid, mfaDeviceId) + } + + storeOnDynamo(userName, mfaMfaMethod, mfaChannel, mfaDeviceId as MfaDeviceId, kid as Kid, associated) return MfaAccountMethod(userName, mfaDeviceId, kid, mfaMfaMethod, mfaChannel, associated) } diff --git a/src/main/kotlin/com/vauthenticator/server/mfa/adapter/jdbc/JdbcMfaAccountMethodsRepository.kt b/src/main/kotlin/com/vauthenticator/server/mfa/adapter/jdbc/JdbcMfaAccountMethodsRepository.kt index c1b44b34..ab27890f 100644 --- a/src/main/kotlin/com/vauthenticator/server/mfa/adapter/jdbc/JdbcMfaAccountMethodsRepository.kt +++ b/src/main/kotlin/com/vauthenticator/server/mfa/adapter/jdbc/JdbcMfaAccountMethodsRepository.kt @@ -47,13 +47,34 @@ class JdbcMfaAccountMethodsRepository( mfaChannel: String, associated: Boolean ): MfaAccountMethod { - val kid = keyRepository.createKeyFrom(masterKid, KeyType.SYMMETRIC, KeyPurpose.MFA) - val mfaDeviceId = mfaDeviceIdGenerator.invoke() + val (kid, mfaDeviceId) = findBy(userName, mfaMfaMethod, mfaChannel) + .map { + listOf(it.key, it.mfaDeviceId) + }.orElseGet { + val kid = keyRepository.createKeyFrom(masterKid, KeyType.SYMMETRIC, KeyPurpose.MFA) + val mfaDeviceId = mfaDeviceIdGenerator.invoke() + listOf(kid, mfaDeviceId) + } jdbcTemplate.update( - "INSERT INTO MFA_ACCOUNT_METHODS (user_name, mfa_device_id, mfa_method, mfa_channel, key_id, associated) VALUES (?,?,?,?,?,?)", - userName, mfaDeviceId.content, mfaMfaMethod.name, mfaChannel, kid.content(), associated - ) + """INSERT INTO MFA_ACCOUNT_METHODS (user_name, mfa_device_id, mfa_method, mfa_channel, key_id, associated) VALUES (?,?,?,?,?,?) + ON CONFLICT(user_name, mfa_channel) + DO UPDATE SET mfa_device_id=?, + mfa_method=?, + key_id=?, + associated=? + """, + userName, + (mfaDeviceId as MfaDeviceId).content, + mfaMfaMethod.name, + mfaChannel, + (kid as Kid).content(), + associated, + mfaDeviceId.content, + mfaMfaMethod.name, + kid.content(), + associated + ) return MfaAccountMethod(userName, mfaDeviceId, kid, mfaMfaMethod, mfaChannel, associated) } diff --git a/src/test/kotlin/com/vauthenticator/server/mfa/adapter/AbstractMfaAccountMethodsRepositoryTest.kt b/src/test/kotlin/com/vauthenticator/server/mfa/adapter/AbstractMfaAccountMethodsRepositoryTest.kt index fb2687b6..1eddcce8 100644 --- a/src/test/kotlin/com/vauthenticator/server/mfa/adapter/AbstractMfaAccountMethodsRepositoryTest.kt +++ b/src/test/kotlin/com/vauthenticator/server/mfa/adapter/AbstractMfaAccountMethodsRepositoryTest.kt @@ -26,8 +26,10 @@ abstract class AbstractMfaAccountMethodsRepositoryTest { lateinit var uut: MfaAccountMethodsRepository abstract fun initMfaAccountMethodsRepository(): MfaAccountMethodsRepository abstract fun resetDatabase() + @BeforeEach fun setUp() { + every { keyRepository.createKeyFrom(masterKid, KeyType.SYMMETRIC, KeyPurpose.MFA) } returns key resetDatabase() uut = initMfaAccountMethodsRepository() @@ -35,8 +37,6 @@ abstract class AbstractMfaAccountMethodsRepositoryTest { @Test fun `when a mfa account method is stored`() { - every { keyRepository.createKeyFrom(masterKid, KeyType.SYMMETRIC, KeyPurpose.MFA) } returns key - uut.save(email, MfaMethod.EMAIL_MFA_METHOD, email, true) val mfaAccountMethods = uut.findAll(email) assertEquals( @@ -86,7 +86,6 @@ abstract class AbstractMfaAccountMethodsRepositoryTest { @Test fun `when decide what mfa use as default`() { - every { keyRepository.createKeyFrom(masterKid, KeyType.SYMMETRIC, KeyPurpose.MFA) } returns key uut.save(email, MfaMethod.EMAIL_MFA_METHOD, email, true) val expected = Optional.of(mfaDeviceId) @@ -95,4 +94,23 @@ abstract class AbstractMfaAccountMethodsRepositoryTest { assertEquals(expected, defaultDevice) } + + @Test + fun `when a mfa account method is stored and then enabled`() { + uut.save(email, MfaMethod.EMAIL_MFA_METHOD, email, false) + val beforeToBeAssociated = uut.findBy(email, MfaMethod.EMAIL_MFA_METHOD, email).get() + + uut.save(email, MfaMethod.EMAIL_MFA_METHOD, email, true) + val afterAssociated = uut.findBy(email, MfaMethod.EMAIL_MFA_METHOD, email).get() + + + assertEquals(afterAssociated.mfaDeviceId, beforeToBeAssociated.mfaDeviceId) + assertEquals(afterAssociated.mfaChannel, beforeToBeAssociated.mfaChannel) + assertEquals(afterAssociated.mfaMethod, beforeToBeAssociated.mfaMethod) + assertEquals(beforeToBeAssociated.associated, false) + assertEquals(afterAssociated.associated, true) + assertEquals(afterAssociated.key, beforeToBeAssociated.key) + assertEquals(afterAssociated.userName, beforeToBeAssociated.userName) + } + } \ No newline at end of file