extracting the direct login logic to its own use case along with viewmodel test case

- will ensure we emit account sign in when going via direct login flow
This commit is contained in:
Adam Brown 2022-03-24 12:44:57 +00:00
parent 10974366fb
commit 88197991e1
4 changed files with 159 additions and 72 deletions

View File

@ -0,0 +1,91 @@
/*
* Copyright (c) 2022 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package im.vector.app.features.onboarding
import android.net.Uri
import im.vector.app.R
import im.vector.app.core.resources.StringProvider
import im.vector.app.features.onboarding.OnboardingAction.LoginOrRegister
import org.matrix.android.sdk.api.MatrixPatterns.getDomain
import org.matrix.android.sdk.api.auth.AuthenticationService
import org.matrix.android.sdk.api.auth.data.HomeServerConnectionConfig
import org.matrix.android.sdk.api.auth.wellknown.WellknownResult
import org.matrix.android.sdk.api.session.Session
import javax.inject.Inject
class DirectLoginUseCase @Inject constructor(
private val authenticationService: AuthenticationService,
private val stringProvider: StringProvider,
) {
suspend fun execute(action: LoginOrRegister, homeServerConnectionConfig: HomeServerConnectionConfig?): Result<Session> {
return fetchWellKnown(action.username, homeServerConnectionConfig)
.andThen { wellKnown -> createSessionFor(wellKnown, action, homeServerConnectionConfig) }
}
private suspend fun fetchWellKnown(matrixId: String, config: HomeServerConnectionConfig?) = runCatching {
authenticationService.getWellKnownData(matrixId, config)
}
private suspend fun createSessionFor(data: WellknownResult, action: LoginOrRegister, config: HomeServerConnectionConfig?) = when (data) {
is WellknownResult.Prompt -> loginDirect(action, data, config)
is WellknownResult.FailPrompt -> handleFailPrompt(data, action, config)
else -> onWellKnownError()
}
private suspend fun handleFailPrompt(data: WellknownResult.FailPrompt, action: LoginOrRegister, config: HomeServerConnectionConfig?): Result<Session> {
// Relax on IS discovery if homeserver is valid
val isMissingInformationToLogin = data.homeServerUrl == null || data.wellKnown == null
return when {
isMissingInformationToLogin -> onWellKnownError()
else -> loginDirect(action, WellknownResult.Prompt(data.homeServerUrl!!, null, data.wellKnown!!), config)
}
}
private suspend fun loginDirect(action: LoginOrRegister, wellKnownPrompt: WellknownResult.Prompt, config: HomeServerConnectionConfig?): Result<Session> {
val alteredHomeServerConnectionConfig = config?.updateWith(wellKnownPrompt) ?: fallbackConfig(action, wellKnownPrompt)
return runCatching {
authenticationService.directAuthentication(
alteredHomeServerConnectionConfig,
action.username,
action.password,
action.initialDeviceName
)
}
}
private fun HomeServerConnectionConfig.updateWith(wellKnownPrompt: WellknownResult.Prompt) = copy(
homeServerUriBase = Uri.parse(wellKnownPrompt.homeServerUrl),
identityServerUri = wellKnownPrompt.identityServerUrl?.let { Uri.parse(it) }
)
private fun fallbackConfig(action: LoginOrRegister, wellKnownPrompt: WellknownResult.Prompt) = HomeServerConnectionConfig(
homeServerUri = Uri.parse("https://${action.username.getDomain()}"),
homeServerUriBase = Uri.parse(wellKnownPrompt.homeServerUrl),
identityServerUri = wellKnownPrompt.identityServerUrl?.let { Uri.parse(it) }
)
private fun onWellKnownError() = Result.failure<Session>(Exception(stringProvider.getString(R.string.autodiscover_well_known_error)))
}
@Suppress("UNCHECKED_CAST") // We're casting null failure results to R
private inline fun <T, R> Result<T>.andThen(block: (T) -> Result<R>): Result<R> {
return when (val result = getOrNull()) {
null -> this as Result<R>
else -> block(result)
}
}

View File

@ -17,7 +17,6 @@
package im.vector.app.features.onboarding package im.vector.app.features.onboarding
import android.content.Context import android.content.Context
import android.net.Uri
import com.airbnb.mvrx.MavericksViewModelFactory import com.airbnb.mvrx.MavericksViewModelFactory
import dagger.assisted.Assisted import dagger.assisted.Assisted
import dagger.assisted.AssistedFactory import dagger.assisted.AssistedFactory
@ -45,7 +44,6 @@ import im.vector.app.features.login.SignMode
import kotlinx.coroutines.Job import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.firstOrNull import kotlinx.coroutines.flow.firstOrNull
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import org.matrix.android.sdk.api.MatrixPatterns.getDomain
import org.matrix.android.sdk.api.auth.AuthenticationService import org.matrix.android.sdk.api.auth.AuthenticationService
import org.matrix.android.sdk.api.auth.HomeServerHistoryService import org.matrix.android.sdk.api.auth.HomeServerHistoryService
import org.matrix.android.sdk.api.auth.data.HomeServerConnectionConfig import org.matrix.android.sdk.api.auth.data.HomeServerConnectionConfig
@ -55,9 +53,6 @@ import org.matrix.android.sdk.api.auth.registration.FlowResult
import org.matrix.android.sdk.api.auth.registration.RegistrationResult import org.matrix.android.sdk.api.auth.registration.RegistrationResult
import org.matrix.android.sdk.api.auth.registration.RegistrationWizard import org.matrix.android.sdk.api.auth.registration.RegistrationWizard
import org.matrix.android.sdk.api.auth.registration.Stage import org.matrix.android.sdk.api.auth.registration.Stage
import org.matrix.android.sdk.api.auth.wellknown.WellknownResult
import org.matrix.android.sdk.api.failure.Failure
import org.matrix.android.sdk.api.failure.MatrixIdFailure
import org.matrix.android.sdk.api.session.Session import org.matrix.android.sdk.api.session.Session
import timber.log.Timber import timber.log.Timber
import java.util.UUID import java.util.UUID
@ -79,6 +74,7 @@ class OnboardingViewModel @AssistedInject constructor(
private val analyticsTracker: AnalyticsTracker, private val analyticsTracker: AnalyticsTracker,
private val uriFilenameResolver: UriFilenameResolver, private val uriFilenameResolver: UriFilenameResolver,
private val registrationActionHandler: RegistrationActionHandler, private val registrationActionHandler: RegistrationActionHandler,
private val directLoginUseCase: DirectLoginUseCase,
private val vectorOverrides: VectorOverrides private val vectorOverrides: VectorOverrides
) : VectorViewModel<OnboardingViewState, OnboardingAction, OnboardingViewEvents>(initialState) { ) : VectorViewModel<OnboardingViewState, OnboardingAction, OnboardingViewEvents>(initialState) {
@ -470,74 +466,14 @@ class OnboardingViewModel @AssistedInject constructor(
private fun handleDirectLogin(action: OnboardingAction.LoginOrRegister, homeServerConnectionConfig: HomeServerConnectionConfig?) { private fun handleDirectLogin(action: OnboardingAction.LoginOrRegister, homeServerConnectionConfig: HomeServerConnectionConfig?) {
setState { copy(isLoading = true) } setState { copy(isLoading = true) }
currentJob = viewModelScope.launch { currentJob = viewModelScope.launch {
val data = try { directLoginUseCase.execute(action, homeServerConnectionConfig).fold(
authenticationService.getWellKnownData(action.username, homeServerConnectionConfig) onSuccess = { onSessionCreated(it, isAccountCreated = false) },
} catch (failure: Throwable) { onFailure = {
onDirectLoginError(failure) setState { copy(isLoading = false) }
return@launch _viewEvents.post(OnboardingViewEvents.Failure(it))
}
when (data) {
is WellknownResult.Prompt ->
directLoginOnWellknownSuccess(action, data, homeServerConnectionConfig)
is WellknownResult.FailPrompt ->
// Relax on IS discovery if homeserver is valid
if (data.homeServerUrl != null && data.wellKnown != null) {
directLoginOnWellknownSuccess(action, WellknownResult.Prompt(data.homeServerUrl!!, null, data.wellKnown!!), homeServerConnectionConfig)
} else {
onWellKnownError()
} }
else -> { )
onWellKnownError()
}
}
}
}
private fun onWellKnownError() {
setState { copy(isLoading = false) }
_viewEvents.post(OnboardingViewEvents.Failure(Exception(stringProvider.getString(R.string.autodiscover_well_known_error))))
}
private suspend fun directLoginOnWellknownSuccess(action: OnboardingAction.LoginOrRegister,
wellKnownPrompt: WellknownResult.Prompt,
homeServerConnectionConfig: HomeServerConnectionConfig?) {
val alteredHomeServerConnectionConfig = homeServerConnectionConfig
?.copy(
homeServerUriBase = Uri.parse(wellKnownPrompt.homeServerUrl),
identityServerUri = wellKnownPrompt.identityServerUrl?.let { Uri.parse(it) }
)
?: HomeServerConnectionConfig(
homeServerUri = Uri.parse("https://${action.username.getDomain()}"),
homeServerUriBase = Uri.parse(wellKnownPrompt.homeServerUrl),
identityServerUri = wellKnownPrompt.identityServerUrl?.let { Uri.parse(it) }
)
val data = try {
authenticationService.directAuthentication(
alteredHomeServerConnectionConfig,
action.username,
action.password,
action.initialDeviceName)
} catch (failure: Throwable) {
onDirectLoginError(failure)
return
}
onSessionCreated(data, isAccountCreated = false)
}
private fun onDirectLoginError(failure: Throwable) {
when (failure) {
is MatrixIdFailure.InvalidMatrixId,
is Failure.UnrecognizedCertificateFailure -> {
setState { copy(isLoading = false) }
// Display this error in a dialog
_viewEvents.post(OnboardingViewEvents.Failure(failure))
}
else -> {
setState { copy(isLoading = false) }
}
} }
} }

View File

@ -24,6 +24,7 @@ import im.vector.app.test.fakes.FakeActiveSessionHolder
import im.vector.app.test.fakes.FakeAnalyticsTracker import im.vector.app.test.fakes.FakeAnalyticsTracker
import im.vector.app.test.fakes.FakeAuthenticationService import im.vector.app.test.fakes.FakeAuthenticationService
import im.vector.app.test.fakes.FakeContext import im.vector.app.test.fakes.FakeContext
import im.vector.app.test.fakes.FakeDirectLoginUseCase
import im.vector.app.test.fakes.FakeHomeServerConnectionConfigFactory import im.vector.app.test.fakes.FakeHomeServerConnectionConfigFactory
import im.vector.app.test.fakes.FakeHomeServerHistoryService import im.vector.app.test.fakes.FakeHomeServerHistoryService
import im.vector.app.test.fakes.FakeRegisterActionHandler import im.vector.app.test.fakes.FakeRegisterActionHandler
@ -44,6 +45,7 @@ import org.matrix.android.sdk.api.auth.registration.FlowResult
import org.matrix.android.sdk.api.auth.registration.RegisterThreePid import org.matrix.android.sdk.api.auth.registration.RegisterThreePid
import org.matrix.android.sdk.api.auth.registration.RegistrationResult import org.matrix.android.sdk.api.auth.registration.RegistrationResult
import org.matrix.android.sdk.api.auth.registration.Stage import org.matrix.android.sdk.api.auth.registration.Stage
import org.matrix.android.sdk.api.session.Session
import org.matrix.android.sdk.api.session.homeserver.HomeServerCapabilities import org.matrix.android.sdk.api.session.homeserver.HomeServerCapabilities
private const val A_DISPLAY_NAME = "a display name" private const val A_DISPLAY_NAME = "a display name"
@ -55,6 +57,7 @@ private val A_RESULT_IGNORED_REGISTER_ACTION = RegisterAction.AddThreePid(Regist
private val A_HOMESERVER_CAPABILITIES = aHomeServerCapabilities(canChangeDisplayName = true, canChangeAvatar = true) private val A_HOMESERVER_CAPABILITIES = aHomeServerCapabilities(canChangeDisplayName = true, canChangeAvatar = true)
private val AN_IGNORED_FLOW_RESULT = FlowResult(missingStages = emptyList(), completedStages = emptyList()) private val AN_IGNORED_FLOW_RESULT = FlowResult(missingStages = emptyList(), completedStages = emptyList())
private val ANY_CONTINUING_REGISTRATION_RESULT = RegistrationResult.FlowResponse(AN_IGNORED_FLOW_RESULT) private val ANY_CONTINUING_REGISTRATION_RESULT = RegistrationResult.FlowResponse(AN_IGNORED_FLOW_RESULT)
private val A_LOGIN_OR_REGISTER_ACTION = OnboardingAction.LoginOrRegister("@a-user:id.org", "a-password", "a-device-name")
class OnboardingViewModelTest { class OnboardingViewModelTest {
@ -69,6 +72,7 @@ class OnboardingViewModelTest {
private val fakeActiveSessionHolder = FakeActiveSessionHolder(fakeSession) private val fakeActiveSessionHolder = FakeActiveSessionHolder(fakeSession)
private val fakeAuthenticationService = FakeAuthenticationService() private val fakeAuthenticationService = FakeAuthenticationService()
private val fakeRegisterActionHandler = FakeRegisterActionHandler() private val fakeRegisterActionHandler = FakeRegisterActionHandler()
private val fakeDirectLoginUseCase = FakeDirectLoginUseCase()
lateinit var viewModel: OnboardingViewModel lateinit var viewModel: OnboardingViewModel
@ -114,6 +118,26 @@ class OnboardingViewModelTest {
.finish() .finish()
} }
@Test
fun `given has sign in with matrix id sign mode, when handling login or register action, then logs in directly`() = runTest {
val initialState = initialState.copy(signMode = SignMode.SignInWithMatrixId)
viewModel = createViewModel(initialState)
fakeDirectLoginUseCase.givenSuccessResult(A_LOGIN_OR_REGISTER_ACTION, config = null, result = fakeSession)
givenInitialisesSession(fakeSession)
val test = viewModel.test()
viewModel.handle(A_LOGIN_OR_REGISTER_ACTION)
test
.assertStatesChanges(
initialState,
{ copy(isLoading = true) },
{ copy(isLoading = false) }
)
.assertEvents(OnboardingViewEvents.OnAccountSignedIn)
.finish()
}
@Test @Test
fun `when handling SignUp then sets sign mode to sign up and starts registration`() = runTest { fun `when handling SignUp then sets sign mode to sign up and starts registration`() = runTest {
givenRegistrationResultFor(RegisterAction.StartRegistration, ANY_CONTINUING_REGISTRATION_RESULT) givenRegistrationResultFor(RegisterAction.StartRegistration, ANY_CONTINUING_REGISTRATION_RESULT)
@ -344,6 +368,7 @@ class OnboardingViewModelTest {
FakeAnalyticsTracker(), FakeAnalyticsTracker(),
fakeUriFilenameResolver.instance, fakeUriFilenameResolver.instance,
fakeRegisterActionHandler.instance, fakeRegisterActionHandler.instance,
fakeDirectLoginUseCase.instance,
FakeVectorOverrides() FakeVectorOverrides()
) )
} }
@ -384,7 +409,11 @@ class OnboardingViewModelTest {
private fun givenSuccessfullyCreatesAccount(homeServerCapabilities: HomeServerCapabilities) { private fun givenSuccessfullyCreatesAccount(homeServerCapabilities: HomeServerCapabilities) {
fakeSession.fakeHomeServerCapabilitiesService.givenCapabilities(homeServerCapabilities) fakeSession.fakeHomeServerCapabilitiesService.givenCapabilities(homeServerCapabilities)
fakeActiveSessionHolder.expectSetsActiveSession(fakeSession) givenInitialisesSession(fakeSession)
}
private fun givenInitialisesSession(session: Session) {
fakeActiveSessionHolder.expectSetsActiveSession(session)
fakeAuthenticationService.expectReset() fakeAuthenticationService.expectReset()
fakeSession.expectStartsSyncing() fakeSession.expectStartsSyncing()
} }

View File

@ -0,0 +1,31 @@
/*
* Copyright (c) 2022 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package im.vector.app.test.fakes
import im.vector.app.features.onboarding.DirectLoginUseCase
import im.vector.app.features.onboarding.OnboardingAction
import io.mockk.coEvery
import io.mockk.mockk
import org.matrix.android.sdk.api.auth.data.HomeServerConnectionConfig
class FakeDirectLoginUseCase {
val instance = mockk<DirectLoginUseCase>()
fun givenSuccessResult(action: OnboardingAction.LoginOrRegister, config: HomeServerConnectionConfig?, result: FakeSession) {
coEvery { instance.execute(action, config) } returns Result.success(result)
}
}