diff --git a/internal/api/client.go b/internal/api/client.go index e948d74c5..b365441ef 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -19,6 +19,8 @@ package api import ( + "time" + "github.com/gin-gonic/gin" "github.com/superseriousbusiness/gotosocial/internal/api/client/accounts" "github.com/superseriousbusiness/gotosocial/internal/api/client/admin" @@ -122,7 +124,7 @@ func NewClient(db db.DB, p processing.Processor) *Client { notifications: notifications.New(p), search: search.New(p), statuses: statuses.New(p), - streaming: streaming.New(p), + streaming: streaming.New(p, time.Second*30, 4096), timelines: timelines.New(p), user: user.New(p), } diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index fc14e87e3..c175c8461 100644 --- a/internal/api/client/streaming/stream.go +++ b/internal/api/client/streaming/stream.go @@ -19,8 +19,9 @@ package streaming import ( + "context" + "errors" "fmt" - "net/http" "time" "codeberg.org/gruf/go-kv" @@ -32,16 +33,6 @@ import ( "github.com/gorilla/websocket" ) -var ( - wsUpgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - // we expect cors requests (via eg., pinafore.social) so be lenient - CheckOrigin: func(r *http.Request) bool { return true }, - } - errNoToken = fmt.Errorf("no access token provided under query key %s or under header %s", AccessTokenQueryKey, AccessTokenHeader) -) - // StreamGETHandler swagger:operation GET /api/v1/streaming streamGet // // Initiate a websocket connection for live streaming of statuses and notifications. @@ -150,21 +141,20 @@ func (m *Module) StreamGETHandler(c *gin.Context) { return } - var accessToken string - if t := c.Query(AccessTokenQueryKey); t != "" { - // try query param first - accessToken = t - } else if t := c.GetHeader(AccessTokenHeader); t != "" { - // fall back to Sec-Websocket-Protocol - accessToken = t - } else { - // no token - err := errNoToken - apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) - return + var token string + + // First we check for a query param provided access token + if token = c.Query(AccessTokenQueryKey); token == "" { + // Else we check the HTTP header provided token + if token = c.GetHeader(AccessTokenHeader); token == "" { + const errStr = "no access token provided" + err := gtserror.NewErrorUnauthorized(errors.New(errStr), errStr) + apiutil.ErrorHandler(c, err, m.processor.InstanceGet) + return + } } - account, errWithCode := m.processor.AuthorizeStreamingRequest(c.Request.Context(), accessToken) + account, errWithCode := m.processor.AuthorizeStreamingRequest(c.Request.Context(), token) if errWithCode != nil { apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) return @@ -178,51 +168,97 @@ func (m *Module) StreamGETHandler(c *gin.Context) { l := log.WithFields(kv.Fields{ {"account", account.Username}, - {"path", BasePath}, {"streamID", stream.ID}, {"streamType", streamType}, }...) - wsConn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil) + // Upgrade the incoming HTTP request, which hijacks the underlying + // connection and reuses it for the websocket (non-http) protocol. + wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil) if err != nil { - // If the upgrade fails, then Upgrade replies to the client with an HTTP error response. - // Because websocket issues are a pretty common source of headaches, we should also log - // this at Error to make this plenty visible and help admins out a bit. - l.Errorf("error upgrading websocket connection: %s", err) + l.Errorf("error upgrading websocket connection: %v", err) close(stream.Hangup) return } - defer func() { - // cleanup - wsConn.Close() - close(stream.Hangup) - }() + go func() { + // We perform the main websocket send loop in a separate + // goroutine in order to let the upgrade handler return. + // This prevents the upgrade handler from holding open any + // throttle / rate-limit request tokens which could become + // problematic on instances with multiple users. + l.Info("opened websocket connection") + defer l.Info("closed websocket connection") - streamTicker := time.NewTicker(m.tickDuration) - defer streamTicker.Stop() + // Create new context for lifetime of the connection + ctx, cncl := context.WithCancel(context.Background()) - // We want to stay in the loop as long as possible while the client is connected. - // The only thing that should break the loop is if the client leaves or the connection becomes unhealthy. - // - // If the loop does break, we expect the client to reattempt connection, so it's cheap to leave + try again -wsLoop: - for { - select { - case m := <-stream.Messages: - l.Trace("received message from stream") - if err := wsConn.WriteJSON(m); err != nil { - l.Debugf("error writing json to websocket connection; breaking off: %s", err) - break wsLoop + // Create ticker to send alive pings + pinger := time.NewTicker(m.dTicker) + + defer func() { + // Signal done + cncl() + + // Close websocket conn + _ = wsConn.Close() + + // Close processor stream + close(stream.Hangup) + + // Stop ping ticker + pinger.Stop() + }() + + go func() { + // Signal done + defer cncl() + + for { + // We have to listen for received websocket messages in + // order to trigger the underlying wsConn.PingHandler(). + // + // So we wait on received messages but only act on errors. + _, _, err := wsConn.ReadMessage() + if err != nil { + if ctx.Err() == nil { + // Only log error if the connection was not closed + // by us. Uncanceled context indicates this is the case. + l.Errorf("error reading from websocket: %v", err) + } + return + } } - l.Trace("wrote message into websocket connection") - case <-streamTicker.C: - l.Trace("received TICK from ticker") - if err := wsConn.WriteMessage(websocket.PingMessage, []byte(": ping")); err != nil { - l.Debugf("error writing ping to websocket connection; breaking off: %s", err) - break wsLoop + }() + + for { + select { + // Connection closed + case <-ctx.Done(): + return + + // Received next stream message + case msg := <-stream.Messages: + l.Tracef("sending message to websocket: %+v", msg) + if err := wsConn.WriteJSON(msg); err != nil { + l.Errorf("error writing json to websocket: %v", err) + return + } + + // Reset on each successful send. + pinger.Reset(m.dTicker) + + // Send keep-alive "ping" + case <-pinger.C: + l.Trace("pinging websocket ...") + if err := wsConn.WriteMessage( + websocket.PingMessage, + []byte{}, + ); err != nil { + l.Errorf("error writing ping to websocket: %v", err) + return + } } - l.Trace("wrote ping message into websocket connection") } - } + }() } diff --git a/internal/api/client/streaming/streaming.go b/internal/api/client/streaming/streaming.go index c23f03c81..d4c61f7a0 100644 --- a/internal/api/client/streaming/streaming.go +++ b/internal/api/client/streaming/streaming.go @@ -23,6 +23,7 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" "github.com/superseriousbusiness/gotosocial/internal/processing" ) @@ -41,21 +42,22 @@ const ( ) type Module struct { - processor processing.Processor - tickDuration time.Duration + processor processing.Processor + dTicker time.Duration + wsUpgrade websocket.Upgrader } -func New(processor processing.Processor) *Module { +func New(processor processing.Processor, dTicker time.Duration, wsBuf int) *Module { return &Module{ - processor: processor, - tickDuration: 30 * time.Second, - } -} + processor: processor, + dTicker: dTicker, + wsUpgrade: websocket.Upgrader{ + ReadBufferSize: wsBuf, // we don't expect reads + WriteBufferSize: wsBuf, -func NewWithTickDuration(processor processing.Processor, tickDuration time.Duration) *Module { - return &Module{ - processor: processor, - tickDuration: tickDuration, + // we expect cors requests (via eg., pinafore.social) so be lenient + CheckOrigin: func(r *http.Request) bool { return true }, + }, } } diff --git a/internal/api/client/streaming/streaming_test.go b/internal/api/client/streaming/streaming_test.go index 22b077464..f713607bb 100644 --- a/internal/api/client/streaming/streaming_test.go +++ b/internal/api/client/streaming/streaming_test.go @@ -99,7 +99,7 @@ func (suite *StreamingTestSuite) SetupTest() { suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) - suite.streamingModule = streaming.NewWithTickDuration(suite.processor, 1) + suite.streamingModule = streaming.New(suite.processor, 1, 4096) suite.NoError(suite.processor.Start()) }