[bugfix] fix possible mutex lockup during streaming code (#2633)

* rewrite Stream{} to use much less mutex locking, update related code

* use new context for the stream context

* ensure stream gets closed on return of writeTo / readFrom WSConn()

* ensure stream write timeout gets cancelled

* remove embedded context type from Stream{}, reformat log messages for consistency

* use c.Request.Context() for context passed into Stream().Open()

* only return 1 boolean, fix tests to expect multiple stream types in messages

* changes to ping logic

* further improved ping logic

* don't export unused function types, update message sending to only include relevant stream type

* ensure stream gets closed 🤦

* update to error log on failed json marshal (instead of panic)

* inverse websocket read error checking to _ignore_ expected close errors
This commit is contained in:
kim
2024-02-20 18:07:49 +00:00
committed by GitHub
parent 8cafa6b74b
commit 291e180990
14 changed files with 535 additions and 451 deletions

View File

@@ -22,10 +22,10 @@ import (
"slices"
"time"
"codeberg.org/gruf/go-kv"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
streampkg "github.com/superseriousbusiness/gotosocial/internal/stream"
@@ -202,7 +202,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
// functions pass messages into a channel, which we can
// then read from and put into a websockets connection.
stream, errWithCode := m.processor.Stream().Open(
c.Request.Context(),
c.Request.Context(), // this ctx is only used for logging
account,
streamType,
)
@@ -213,10 +213,8 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
l := log.
WithContext(c.Request.Context()).
WithFields(kv.Fields{
{"username", account.Username},
{"streamID", stream.ID},
}...)
WithField("streamID", id.NewULID()).
WithField("username", account.Username)
// Upgrade the incoming HTTP request. This hijacks the
// underlying connection and reuses it for the websocket
@@ -227,18 +225,16 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil)
if err != nil {
l.Errorf("error upgrading websocket connection: %v", err)
close(stream.Hangup)
stream.Close()
return
}
l.Info("opened websocket connection")
// We perform the main websocket rw loops in a separate
// goroutine in order to let the upgrade handler return.
// This prevents the upgrade handler from holding open any
// throttle / rate-limit request tokens which could become
// problematic on instances with multiple users.
go m.handleWSConn(account.Username, wsConn, stream)
go m.handleWSConn(&l, wsConn, stream)
}
// handleWSConn handles a two-way websocket streaming connection.
@@ -246,48 +242,39 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
// into the connection. If any errors are encountered while reading
// or writing (including expected errors like clients leaving), the
// connection will be closed.
func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *streampkg.Stream) {
// Create new context for the lifetime of this connection.
ctx, cancel := context.WithCancel(context.Background())
func (m *Module) handleWSConn(l *log.Entry, wsConn *websocket.Conn, stream *streampkg.Stream) {
l.Info("opened websocket connection")
l := log.
WithContext(ctx).
WithFields(kv.Fields{
{"username", username},
{"streamID", stream.ID},
}...)
// Create new async context with cancel.
ctx, cncl := context.WithCancel(context.Background())
// Create ticker to send keepalive pings
pinger := time.NewTicker(m.dTicker)
// Read messages coming from the Websocket client connection into the server.
go func() {
defer cancel()
m.readFromWSConn(ctx, username, wsConn, stream)
defer cncl()
// Read messages from websocket to server.
m.readFromWSConn(ctx, wsConn, stream, l)
}()
// Write messages coming from the processor into the Websocket client connection.
go func() {
defer cancel()
m.writeToWSConn(ctx, username, wsConn, stream, pinger)
defer cncl()
// Write messages from processor in websocket conn.
m.writeToWSConn(ctx, wsConn, stream, m.dTicker, l)
}()
// Wait for either the read or write functions to close, to indicate
// that the client has left, or something else has gone wrong.
// Wait for ctx
// to be closed.
<-ctx.Done()
// Close stream
// straightaway.
stream.Close()
// Tidy up underlying websocket connection.
if err := wsConn.Close(); err != nil {
l.Errorf("error closing websocket connection: %v", err)
}
// Close processor channel so the processor knows
// not to send any more messages to this stream.
close(stream.Hangup)
// Stop ping ticker (tiny resource saving).
pinger.Stop()
l.Info("closed websocket connection")
}
@@ -299,89 +286,64 @@ func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *s
// if the given context is canceled.
func (m *Module) readFromWSConn(
ctx context.Context,
username string,
wsConn *websocket.Conn,
stream *streampkg.Stream,
l *log.Entry,
) {
l := log.
WithContext(ctx).
WithFields(kv.Fields{
{"username", username},
{"streamID", stream.ID},
}...)
readLoop:
for {
select {
case <-ctx.Done():
// Connection closed.
break readLoop
var msg struct {
Type string `json:"type"`
Stream string `json:"stream"`
List string `json:"list,omitempty"`
}
// Read JSON objects from the client and act on them.
if err := wsConn.ReadJSON(&msg); err != nil {
// Only log an error if something weird happened.
// See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7
if !websocket.IsCloseError(err, []int{
websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived,
}...) {
l.Errorf("error during websocket read: %v", err)
}
// The connection is gone; no
// further streaming possible.
break
}
// Messages *from* the WS connection are infrequent
// and usually interesting, so log this at info.
l.Infof("received websocket message: %+v", msg)
// Ignore if the updateStreamType is unknown (or missing),
// so a bad client can't cause extra memory allocations
if !slices.Contains(streampkg.AllStatusTimelines, msg.Stream) {
l.Warnf("unknown 'stream' field: %v", msg)
continue
}
if msg.List != "" {
// If a list is given, add this to
// the stream name as this is how we
// we track stream types internally.
msg.Stream += ":" + msg.List
}
switch msg.Type {
case "subscribe":
stream.Subscribe(msg.Stream)
case "unsubscribe":
stream.Unsubscribe(msg.Stream)
default:
// Read JSON objects from the client and act on them.
var msg map[string]string
if err := wsConn.ReadJSON(&msg); err != nil {
// Only log an error if something weird happened.
// See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7
if websocket.IsUnexpectedCloseError(err, []int{
websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived,
}...) {
l.Errorf("error reading from websocket: %v", err)
}
// The connection is gone; no
// further streaming possible.
break readLoop
}
// Messages *from* the WS connection are infrequent
// and usually interesting, so log this at info.
l.Infof("received message from websocket: %v", msg)
// If the message contains 'stream' and 'type' fields, we can
// update the set of timelines that are subscribed for events.
updateType, ok := msg["type"]
if !ok {
l.Warn("'type' field not provided")
continue
}
updateStream, ok := msg["stream"]
if !ok {
l.Warn("'stream' field not provided")
continue
}
// Ignore if the updateStreamType is unknown (or missing),
// so a bad client can't cause extra memory allocations
if !slices.Contains(streampkg.AllStatusTimelines, updateStream) {
l.Warnf("unknown 'stream' field: %v", msg)
continue
}
updateList, ok := msg["list"]
if ok {
updateStream += ":" + updateList
}
switch updateType {
case "subscribe":
stream.Lock()
stream.StreamTypes[updateStream] = true
stream.Unlock()
case "unsubscribe":
stream.Lock()
delete(stream.StreamTypes, updateStream)
stream.Unlock()
default:
l.Warnf("invalid 'type' field: %v", msg)
}
l.Warnf("invalid 'type' field: %v", msg)
}
}
l.Debug("finished reading from websocket connection")
l.Debug("finished websocket read")
}
// writeToWSConn receives messages coming from the processor via the
@@ -393,46 +355,47 @@ readLoop:
// if the given context is canceled.
func (m *Module) writeToWSConn(
ctx context.Context,
username string,
wsConn *websocket.Conn,
stream *streampkg.Stream,
pinger *time.Ticker,
ping time.Duration,
l *log.Entry,
) {
l := log.
WithContext(ctx).
WithFields(kv.Fields{
{"username", username},
{"streamID", stream.ID},
}...)
writeLoop:
for {
select {
case <-ctx.Done():
// Connection closed.
break writeLoop
// Wrap context with timeout to send a ping.
pingctx, cncl := context.WithTimeout(ctx, ping)
case msg := <-stream.Messages:
// Received a new message from the processor.
l.Tracef("writing message to websocket: %+v", msg)
if err := wsConn.WriteJSON(msg); err != nil {
l.Debugf("error writing json to websocket: %v", err)
break writeLoop
}
// Block on receipt of msg.
msg, ok := stream.Recv(pingctx)
// Reset pinger on successful send, since
// we know the connection is still there.
pinger.Reset(m.dTicker)
// Check if cancel because ping.
pinged := (pingctx.Err() != nil)
cncl()
case <-pinger.C:
// Time to send a keep-alive "ping".
l.Trace("writing ping control message to websocket")
switch {
case !ok && pinged:
// The ping context timed out!
l.Trace("writing websocket ping")
// Wrapped context time-out, send a keep-alive "ping".
if err := wsConn.WriteControl(websocket.PingMessage, nil, time.Time{}); err != nil {
l.Debugf("error writing ping to websocket: %v", err)
break writeLoop
l.Debugf("error writing websocket ping: %v", err)
break
}
case !ok:
// Stream was
// closed.
return
}
l.Trace("writing websocket message: %+v", msg)
// Received a new message from the processor.
if err := wsConn.WriteJSON(msg); err != nil {
l.Debugf("error writing websocket message: %v", err)
break
}
}
l.Debug("finished writing to websocket connection")
l.Debug("finished websocket write")
}