adding support for signing in to homeservers without wellknown setup

This commit is contained in:
Adam Brown 2022-04-13 21:04:23 +01:00
parent 32494b961b
commit e1fd79de02
14 changed files with 218 additions and 68 deletions

View File

@ -22,7 +22,7 @@
### Feature list
- Login with username/password (homeservers must serve `https://${domain}/.well-known/matrix/client`)
- Login with Matrix ID/Password
- Combined Room and DM interface
- End to end encryption
- Message bubbles, supporting text, replies and edits

View File

@ -1,5 +1,6 @@
package app.dapk.st.login
import android.widget.Toast
import androidx.compose.foundation.layout.*
import androidx.compose.foundation.text.KeyboardActions
import androidx.compose.foundation.text.KeyboardOptions
@ -7,6 +8,7 @@ import androidx.compose.material.*
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.Visibility
import androidx.compose.material.icons.filled.VisibilityOff
import androidx.compose.material.icons.filled.Web
import androidx.compose.material.icons.outlined.Lock
import androidx.compose.runtime.*
import androidx.compose.runtime.saveable.rememberSaveable
@ -14,6 +16,7 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.ExperimentalComposeUiApi
import androidx.compose.ui.Modifier
import androidx.compose.ui.focus.FocusDirection
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalFocusManager
import androidx.compose.ui.platform.LocalSoftwareKeyboardController
import androidx.compose.ui.text.font.FontWeight
@ -37,9 +40,10 @@ fun LoginScreen(loginViewModel: LoginViewModel, onLoggedIn: () -> Unit) {
var userName by rememberSaveable { mutableStateOf("") }
var password by rememberSaveable { mutableStateOf("") }
var serverUrl by rememberSaveable { mutableStateOf("") }
val keyboardController = LocalSoftwareKeyboardController.current
when (loginViewModel.state) {
when (val state = loginViewModel.state) {
is Error -> {
Box(contentAlignment = Alignment.Center, modifier = Modifier.fillMaxSize()) {
Column(horizontalAlignment = Alignment.CenterHorizontally) {
@ -58,7 +62,7 @@ fun LoginScreen(loginViewModel: LoginViewModel, onLoggedIn: () -> Unit) {
CircularProgressIndicator()
}
}
Idle ->
is Content ->
Row {
Spacer(modifier = Modifier.weight(0.1f))
Column(
@ -91,7 +95,11 @@ fun LoginScreen(loginViewModel: LoginViewModel, onLoggedIn: () -> Unit) {
keyboardOptions = KeyboardOptions(autoCorrect = false, keyboardType = KeyboardType.Email, imeAction = ImeAction.Next)
)
val canDoLoginAttempt = userName.isNotEmpty() && password.isNotEmpty()
val canDoLoginAttempt = if (state.showServerUrl) {
userName.isNotEmpty() && password.isNotEmpty() && serverUrl.isNotEmpty()
} else {
userName.isNotEmpty() && password.isNotEmpty()
}
TextField(
modifier = Modifier.fillMaxWidth(),
@ -102,10 +110,13 @@ fun LoginScreen(loginViewModel: LoginViewModel, onLoggedIn: () -> Unit) {
leadingIcon = {
Icon(imageVector = Icons.Outlined.Lock, contentDescription = null)
},
keyboardActions = KeyboardActions(onDone = { loginViewModel.login(userName, password) }),
keyboardActions = KeyboardActions(
onDone = { loginViewModel.login(userName, password, serverUrl) },
onNext = { focusManager.moveFocus(FocusDirection.Down) },
),
keyboardOptions = KeyboardOptions(
autoCorrect = false,
imeAction = ImeAction.Done.takeIf { canDoLoginAttempt } ?: ImeAction.None,
imeAction = ImeAction.Done.takeIf { canDoLoginAttempt } ?: ImeAction.Next.takeIf { state.showServerUrl } ?: ImeAction.None,
keyboardType = KeyboardType.Password
),
visualTransformation = if (passwordVisibility) VisualTransformation.None else PasswordVisualTransformation(),
@ -117,13 +128,32 @@ fun LoginScreen(loginViewModel: LoginViewModel, onLoggedIn: () -> Unit) {
}
)
if (state.showServerUrl) {
TextField(
modifier = Modifier.fillMaxWidth(),
value = serverUrl,
onValueChange = { serverUrl = it },
label = { Text("Server URL") },
singleLine = true,
leadingIcon = {
Icon(imageVector = Icons.Default.Web, contentDescription = null)
},
keyboardActions = KeyboardActions(onDone = { loginViewModel.login(userName, password, serverUrl) }),
keyboardOptions = KeyboardOptions(
autoCorrect = false,
imeAction = ImeAction.Done.takeIf { canDoLoginAttempt } ?: ImeAction.None,
keyboardType = KeyboardType.Uri
),
)
}
Spacer(Modifier.height(4.dp))
Button(
modifier = Modifier.fillMaxWidth(),
onClick = {
keyboardController?.hide()
loginViewModel.login(userName, password)
loginViewModel.login(userName, password, serverUrl)
},
enabled = canDoLoginAttempt
) {
@ -137,10 +167,14 @@ fun LoginScreen(loginViewModel: LoginViewModel, onLoggedIn: () -> Unit) {
@Composable
private fun LoginViewModel.ObserveEvents(onLoggedIn: () -> Unit) {
val context = LocalContext.current
StartObserving {
this@ObserveEvents.events.launch {
when (it) {
LoginComplete -> onLoggedIn()
LoginEvent.WellKnownMissing -> {
Toast.makeText(context, "Couldn't find the homeserver, please enter the server URL", Toast.LENGTH_LONG).show()
}
}
}
}

View File

@ -2,12 +2,13 @@ package app.dapk.st.login
sealed interface LoginScreenState {
object Idle : LoginScreenState
data class Content(val showServerUrl: Boolean) : LoginScreenState
object Loading : LoginScreenState
data class Error(val cause: Throwable) : LoginScreenState
}
sealed interface LoginEvent {
object LoginComplete : LoginEvent
object WellKnownMissing : LoginEvent
}

View File

@ -19,33 +19,45 @@ class LoginViewModel(
private val profileService: ProfileService,
private val errorTracker: ErrorTracker,
) : DapkViewModel<LoginScreenState, LoginEvent>(
initialState = Idle
initialState = Content(showServerUrl = false)
) {
fun login(userName: String, password: String) {
private var previousState: LoginScreenState? = null
fun login(userName: String, password: String, serverUrl: String?) {
state = Loading
viewModelScope.launch {
kotlin.runCatching {
logP("login") {
authService.login(userName, password).also {
when (val result = authService.login(AuthService.LoginRequest(userName, password, serverUrl.takeIfNotEmpty()))) {
is AuthService.LoginResult.Success -> {
runCatching {
listOf(
async { firebasePushTokenUseCase.registerCurrentToken() },
async { preloadMe() },
).awaitAll()
}
}
}.onFailure {
errorTracker.track(it)
state = Error(it)
}.onSuccess {
_events.tryEmit(LoginComplete)
}
is AuthService.LoginResult.Error -> {
errorTracker.track(result.cause)
state = Error(result.cause)
}
AuthService.LoginResult.MissingWellKnown -> {
_events.tryEmit(LoginEvent.WellKnownMissing)
state = Content(showServerUrl = true)
}
}
}
}
}
private suspend fun preloadMe() = profileService.me(forceRefresh = false)
fun start() {
state = Idle
val showServerUrl = previousState?.let { it is Content && it.showServerUrl } ?: false
state = Content(showServerUrl = showServerUrl)
}
}
private fun String?.takeIfNotEmpty() = this?.takeIf { it.isNotEmpty() }

View File

@ -4,6 +4,6 @@ fun String.ensureTrailingSlash(): String {
return if (this.endsWith("/")) this else "$this/"
}
fun String.ensureHttps(): String {
return if (this.startsWith("https")) this else "https://$this"
fun String.ensureHttpsIfMissing(): String {
return if (this.startsWith("http")) this else "https://$this"
}

View File

@ -10,16 +10,24 @@ import app.dapk.st.matrix.common.UserCredentials
private val SERVICE_KEY = AuthService::class
interface AuthService : MatrixService {
suspend fun login(userName: String, password: String): UserCredentials
suspend fun login(request: LoginRequest): LoginResult
suspend fun register(userName: String, password: String, homeServer: String): UserCredentials
sealed interface LoginResult {
data class Success(val userCredentials: UserCredentials) : LoginResult
object MissingWellKnown : LoginResult
data class Error(val cause: Throwable) : LoginResult
}
data class LoginRequest(val userName: String, val password: String, val serverUrl: String?)
}
fun MatrixServiceInstaller.installAuthService(
credentialsStore: CredentialsStore,
authConfig: AuthConfig = AuthConfig(),
) {
this.install { (httpClient, json) ->
SERVICE_KEY to DefaultAuthService(httpClient, credentialsStore, json, authConfig)
SERVICE_KEY to DefaultAuthService(httpClient, credentialsStore, json)
}
}

View File

@ -49,6 +49,10 @@ internal fun wellKnownRequest(baseUrl: String) = httpRequest<String>(
authenticated = false,
)
@JvmInline
@Serializable
internal value class RawResponse(val value: String)
internal data class Auth(
val session: String,
val type: String,

View File

@ -1,25 +1,33 @@
package app.dapk.st.matrix.auth.internal
import app.dapk.st.matrix.auth.AuthConfig
import app.dapk.st.matrix.auth.AuthService
import app.dapk.st.matrix.common.CredentialsStore
import app.dapk.st.matrix.common.HomeServerUrl
import app.dapk.st.matrix.common.UserCredentials
import app.dapk.st.matrix.http.MatrixHttpClient
import app.dapk.st.matrix.http.ensureHttpsIfMissing
import app.dapk.st.matrix.http.ensureTrailingSlash
import kotlinx.serialization.json.Json
internal class DefaultAuthService(
httpClient: MatrixHttpClient,
credentialsStore: CredentialsStore,
json: Json,
authConfig: AuthConfig,
) : AuthService {
private val fetchWellKnownUseCase = FetchWellKnownUseCaseImpl(httpClient, json)
private val loginUseCase = LoginUseCase(httpClient, credentialsStore, fetchWellKnownUseCase, authConfig)
private val registerCase = RegisterUseCase(httpClient, credentialsStore, json, fetchWellKnownUseCase, authConfig)
private val loginUseCase = LoginWithUserPasswordUseCase(httpClient, credentialsStore, fetchWellKnownUseCase)
private val loginServerUseCase = LoginWithUserPasswordServerUseCase(httpClient, credentialsStore)
private val registerCase = RegisterUseCase(httpClient, credentialsStore, json, fetchWellKnownUseCase)
override suspend fun login(userName: String, password: String): UserCredentials {
return loginUseCase.login(userName, password)
override suspend fun login(request: AuthService.LoginRequest): AuthService.LoginResult {
return when {
request.serverUrl == null -> loginUseCase.login(request.userName, request.password)
else -> {
val serverUrl = HomeServerUrl(request.serverUrl.ensureHttpsIfMissing().ensureTrailingSlash())
loginServerUseCase.login(request.userName, request.password, serverUrl)
}
}
}
override suspend fun register(userName: String, password: String, homeServer: String): UserCredentials {

View File

@ -1,19 +1,50 @@
package app.dapk.st.matrix.auth.internal
import app.dapk.st.matrix.http.MatrixHttpClient
import io.ktor.client.plugins.*
import io.ktor.http.*
import kotlinx.serialization.SerializationException
import kotlinx.serialization.json.Json
import java.net.UnknownHostException
internal typealias FetchWellKnownUseCase = suspend (String) -> ApiWellKnown
internal typealias FetchWellKnownUseCase = suspend (String) -> WellKnownResult
internal class FetchWellKnownUseCaseImpl(
private val httpClient: MatrixHttpClient,
private val json: Json,
) : FetchWellKnownUseCase {
override suspend fun invoke(domainUrl: String): ApiWellKnown {
// workaround for matrix.org not returning a content-type
val raw = httpClient.execute(wellKnownRequest(domainUrl))
return json.decodeFromString(ApiWellKnown.serializer(), raw)
override suspend fun invoke(domainUrl: String): WellKnownResult {
return runCatching {
val rawResponse = httpClient.execute(rawWellKnownRequestForServersWithoutContentTypes(domainUrl))
json.decodeFromString(ApiWellKnown.serializer(), rawResponse)
}
.fold(
onSuccess = { WellKnownResult.Success(it) },
onFailure = {
when (it) {
is UnknownHostException -> WellKnownResult.MissingWellKnown
is ClientRequestException -> when {
it.response.status.is404() -> WellKnownResult.MissingWellKnown
else -> WellKnownResult.Error(it)
}
is SerializationException -> WellKnownResult.InvalidWellKnown
else -> WellKnownResult.Error(it)
}
},
)
}
private fun rawWellKnownRequestForServersWithoutContentTypes(domainUrl: String) = wellKnownRequest(domainUrl)
}
sealed interface WellKnownResult {
data class Success(val wellKnown: ApiWellKnown) : WellKnownResult
object MissingWellKnown : WellKnownResult
object InvalidWellKnown : WellKnownResult
data class Error(val cause: Throwable) : WellKnownResult
}
fun HttpStatusCode.is404() = this.value == 404

View File

@ -0,0 +1,33 @@
package app.dapk.st.matrix.auth.internal
import app.dapk.st.matrix.auth.AuthService
import app.dapk.st.matrix.common.CredentialsStore
import app.dapk.st.matrix.common.HomeServerUrl
import app.dapk.st.matrix.common.UserCredentials
import app.dapk.st.matrix.common.UserId
import app.dapk.st.matrix.http.MatrixHttpClient
class LoginWithUserPasswordServerUseCase(
private val httpClient: MatrixHttpClient,
private val credentialsProvider: CredentialsStore,
) {
suspend fun login(userName: String, password: String, serverUrl: HomeServerUrl): AuthService.LoginResult {
return runCatching {
authenticate(serverUrl, UserId(userName.substringBefore(":")), password)
}.fold(
onSuccess = { AuthService.LoginResult.Success(it) },
onFailure = { AuthService.LoginResult.Error(it) }
)
}
private suspend fun authenticate(baseUrl: HomeServerUrl, fullUserId: UserId, password: String): UserCredentials {
val authResponse = httpClient.execute(loginRequest(fullUserId, password, baseUrl.value))
return UserCredentials(
authResponse.accessToken,
baseUrl,
authResponse.userId,
authResponse.deviceId,
).also { credentialsProvider.update(it) }
}
}

View File

@ -1,6 +1,6 @@
package app.dapk.st.matrix.auth.internal
import app.dapk.st.matrix.auth.AuthConfig
import app.dapk.st.matrix.auth.AuthService
import app.dapk.st.matrix.common.CredentialsStore
import app.dapk.st.matrix.common.HomeServerUrl
import app.dapk.st.matrix.common.UserCredentials
@ -10,23 +10,27 @@ import app.dapk.st.matrix.http.ensureTrailingSlash
private const val MATRIX_DOT_ORG_DOMAIN = "matrix.org"
class LoginUseCase(
class LoginWithUserPasswordUseCase(
private val httpClient: MatrixHttpClient,
private val credentialsProvider: CredentialsStore,
private val fetchWellKnownUseCase: FetchWellKnownUseCase,
private val authConfig: AuthConfig
) {
suspend fun login(userName: String, password: String): UserCredentials {
suspend fun login(userName: String, password: String): AuthService.LoginResult {
val (domainUrl, fullUserId) = generateUserAccessInfo(userName)
val baseUrl = fetchWellKnownUseCase(domainUrl).homeServer.baseUrl.ensureTrailingSlash()
val authResponse = httpClient.execute(loginRequest(fullUserId, password, baseUrl.value))
return UserCredentials(
authResponse.accessToken,
baseUrl,
authResponse.userId,
authResponse.deviceId,
).also { credentialsProvider.update(it) }
return when (val wellKnownResult = fetchWellKnownUseCase(domainUrl)) {
is WellKnownResult.Success -> {
runCatching {
authenticate(wellKnownResult.wellKnown.homeServer.baseUrl.ensureTrailingSlash(), fullUserId, password)
}.fold(
onSuccess = { AuthService.LoginResult.Success(it) },
onFailure = { AuthService.LoginResult.Error(it) }
)
}
WellKnownResult.InvalidWellKnown -> AuthService.LoginResult.MissingWellKnown
WellKnownResult.MissingWellKnown -> AuthService.LoginResult.MissingWellKnown
is WellKnownResult.Error -> AuthService.LoginResult.Error(wellKnownResult.cause)
}
}
private fun generateUserAccessInfo(userName: String): Pair<String, UserId> {
@ -37,14 +41,20 @@ class LoginUseCase(
return Pair(domainUrl, UserId(fullUserId))
}
private suspend fun authenticate(baseUrl: HomeServerUrl, fullUserId: UserId, password: String): UserCredentials {
val authResponse = httpClient.execute(loginRequest(fullUserId, password, baseUrl.value))
return UserCredentials(
authResponse.accessToken,
baseUrl,
authResponse.userId,
authResponse.deviceId,
).also { credentialsProvider.update(it) }
}
private fun String.findDomain(fallback: String) = this.substringAfter(":", missingDelimiterValue = fallback)
private fun String.asHttpsUrl(): String {
val schema = when (authConfig.forceHttp) {
true -> "http://"
false -> "https://"
}
return "$schema$this".ensureTrailingSlash()
return "https://$this".ensureTrailingSlash()
}
}

View File

@ -1,6 +1,5 @@
package app.dapk.st.matrix.auth.internal
import app.dapk.st.matrix.auth.AuthConfig
import app.dapk.st.matrix.common.CredentialsStore
import app.dapk.st.matrix.common.UserCredentials
import app.dapk.st.matrix.http.MatrixHttpClient
@ -16,7 +15,6 @@ class RegisterUseCase(
private val credentialsProvider: CredentialsStore,
private val json: Json,
private val fetchWellKnownUseCase: FetchWellKnownUseCase,
private val authConfig: AuthConfig,
) {
suspend fun register(userName: String, password: String, homeServer: String): UserCredentials {
@ -46,7 +44,12 @@ class RegisterUseCase(
registerRequest(userName, password, baseUrl, Auth(session, "m.login.dummy"))
)
val homeServerUrl = when (authResponse.wellKnown == null) {
true -> fetchWellKnownUseCase(baseUrl).homeServer.baseUrl
true -> when (val wellKnownResult = fetchWellKnownUseCase(baseUrl)) {
is WellKnownResult.Error, -> TODO()
WellKnownResult.InvalidWellKnown -> TODO()
WellKnownResult.MissingWellKnown -> TODO()
is WellKnownResult.Success -> wellKnownResult.wellKnown.homeServer.baseUrl
}
false -> authResponse.wellKnown.homeServer.baseUrl
}
return UserCredentials(

View File

@ -1,3 +1,4 @@
import app.dapk.st.matrix.auth.AuthService
import app.dapk.st.matrix.auth.authService
import app.dapk.st.matrix.common.HomeServerUrl
import app.dapk.st.matrix.common.RoomId
@ -11,6 +12,7 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.test.runTest
import org.amshove.kluent.shouldBeEqualTo
import org.amshove.kluent.shouldBeInstanceOf
import org.amshove.kluent.shouldNotBeEqualTo
import org.junit.jupiter.api.MethodOrderer
import org.junit.jupiter.api.Order
@ -131,13 +133,17 @@ private suspend fun login(user: TestUser) {
val result = testMatrix
.client
.authService()
.login(userName = user.roomMember.id.value, password = user.password)
.login(AuthService.LoginRequest(userName = user.roomMember.id.value, password = user.password, serverUrl = null))
result.accessToken shouldNotBeEqualTo null
result.homeServer shouldBeEqualTo HomeServerUrl(HTTPS_TEST_SERVER_URL)
result.userId shouldBeEqualTo user.roomMember.id
testMatrix.saveLogin(result)
result shouldBeInstanceOf AuthService.LoginResult.Success::class.java
(result as AuthService.LoginResult.Success).userCredentials.let { credentials ->
credentials.accessToken shouldNotBeEqualTo null
credentials.homeServer shouldBeEqualTo HomeServerUrl(HTTPS_TEST_SERVER_URL)
credentials.userId shouldBeEqualTo user.roomMember.id
testMatrix.saveLogin(credentials)
}
}
object SharedState {

View File

@ -6,7 +6,7 @@ import app.dapk.st.core.CoroutineDispatchers
import app.dapk.st.core.SingletonFlows
import app.dapk.st.domain.StoreModule
import app.dapk.st.matrix.MatrixClient
import app.dapk.st.matrix.auth.AuthConfig
import app.dapk.st.matrix.auth.AuthService
import app.dapk.st.matrix.auth.authService
import app.dapk.st.matrix.auth.installAuthService
import app.dapk.st.matrix.common.*
@ -79,7 +79,7 @@ class TestMatrix(
logger
).also {
it.install {
installAuthService(storeModule.credentialsStore(), AuthConfig(forceHttp = false))
installAuthService(storeModule.credentialsStore())
installEncryptionService(storeModule.knownDevicesStore())
val base64 = JavaBase64()
@ -255,7 +255,7 @@ class TestMatrix(
suspend fun newlogin() {
client.authService()
.login(user.roomMember.id.value, user.password)
.login(AuthService.LoginRequest(user.roomMember.id.value, user.password, null))
}
suspend fun restoreLogin() {