2019-12-23 20:30:32 +01:00
package writefreely
import (
"context"
"database/sql"
"github.com/stretchr/testify/assert"
"testing"
)
func TestOAuthDatastore ( t * testing . T ) {
if ! runMySQLTests ( ) {
t . Skip ( "skipping mysql tests" )
}
withTestDB ( t , func ( db * sql . DB ) {
ctx := context . Background ( )
ds := & datastore {
DB : db ,
driverName : "" ,
}
2020-04-21 00:18:23 +02:00
state , err := ds . GenerateOAuthState ( ctx , "test" , "development" , 0 , "" )
2019-12-23 20:30:32 +01:00
assert . NoError ( t , err )
assert . Len ( t , state , 24 )
2019-12-31 17:28:05 +01:00
countRows ( t , ctx , db , 1 , "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false" , state )
2019-12-23 20:30:32 +01:00
2020-04-21 00:18:23 +02:00
_ , _ , _ , _ , err = ds . ValidateOAuthState ( ctx , state )
2019-12-23 20:30:32 +01:00
assert . NoError ( t , err )
2019-12-31 17:28:05 +01:00
countRows ( t , ctx , db , 1 , "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true" , state )
2019-12-23 20:30:32 +01:00
var localUserID int64 = 99
2019-12-28 21:15:47 +01:00
var remoteUserID = "100"
2019-12-30 19:32:06 +01:00
err = ds . RecordRemoteUserID ( ctx , localUserID , remoteUserID , "test" , "test" , "access_token_a" )
2019-12-23 20:30:32 +01:00
assert . NoError ( t , err )
2019-12-31 17:48:08 +01:00
countRows ( t , ctx , db , 1 , "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ? AND access_token = 'access_token_a'" , localUserID , remoteUserID )
2019-12-23 20:30:32 +01:00
2019-12-30 19:32:06 +01:00
err = ds . RecordRemoteUserID ( ctx , localUserID , remoteUserID , "test" , "test" , "access_token_b" )
assert . NoError ( t , err )
2019-12-31 17:48:08 +01:00
countRows ( t , ctx , db , 1 , "SELECT COUNT(*) FROM `oauth_users` WHERE `user_id` = ? AND `remote_user_id` = ? AND access_token = 'access_token_b'" , localUserID , remoteUserID )
2019-12-30 19:32:06 +01:00
2019-12-31 17:48:08 +01:00
countRows ( t , ctx , db , 1 , "SELECT COUNT(*) FROM `oauth_users`" )
2019-12-30 19:32:06 +01:00
foundUserID , err := ds . GetIDForRemoteUser ( ctx , remoteUserID , "test" , "test" )
2019-12-23 20:30:32 +01:00
assert . NoError ( t , err )
assert . Equal ( t , localUserID , foundUserID )
} )
}