[chore/bugfix] Break Websockets logic into smaller read/write functions, don't log expected errors (#1932)

* [chore/bugfix] Break Websockets logic into smaller read/write functions, don't log expected errors

* tweak

* tidy up, use control message
This commit is contained in:
tobi 2023-07-04 12:55:10 +02:00 committed by GitHub
parent ba0bc06b8c
commit 3d16962173
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 237 additions and 133 deletions

View File

@ -149,60 +149,78 @@ import (
// '400':
// description: bad request
func (m *Module) StreamGETHandler(c *gin.Context) {
var (
account *gtsmodel.Account
errWithCode gtserror.WithCode
)
// First we check for a query param provided access token
// Try query param access token.
token := c.Query(AccessTokenQueryKey)
if token == "" {
// Else we check the HTTP header provided token
// Try fallback HTTP header provided token.
token = c.GetHeader(AccessTokenHeader)
}
var account *gtsmodel.Account
if token != "" {
// Check the explicit token
var errWithCode gtserror.WithCode
// Token was provided, use it to authorize stream.
account, errWithCode = m.processor.Stream().Authorize(c.Request.Context(), token)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
} else {
// If no explicit token was provided, try regular oauth
auth, errStr := oauth.Authed(c, true, true, true, true)
if errStr != nil {
err := gtserror.NewErrorUnauthorized(errStr, errStr.Error())
apiutil.ErrorHandler(c, err, m.processor.InstanceGetV1)
return
}
account = auth.Account
// No explicit token was provided:
// try regular oauth as a last resort.
account, errWithCode = func() (*gtsmodel.Account, gtserror.WithCode) {
authed, err := oauth.Authed(c, true, true, true, true)
if err != nil {
return nil, gtserror.NewErrorUnauthorized(err, err.Error())
}
return authed.Account, nil
}()
}
// Get the initial stream type, if there is one.
// By appending other query params to the streamType,
// we can allow for streaming for specific list IDs
// or hashtags.
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
// Get the initial requested stream type, if there is one.
streamType := c.Query(StreamQueryKey)
// By appending other query params to the streamType, we
// can allow streaming for specific list IDs or hashtags.
// The streamType in this case will end up looking like
// `hashtag:example` or `list:01H3YF48G8B7KTPQFS8D2QBVG8`.
if list := c.Query(StreamListKey); list != "" {
streamType += ":" + list
} else if tag := c.Query(StreamTagKey); tag != "" {
streamType += ":" + tag
}
stream, errWithCode := m.processor.Stream().Open(c.Request.Context(), account, streamType)
// Open a stream with the processor; this lets processor
// functions pass messages into a channel, which we can
// then read from and put into a websockets connection.
stream, errWithCode := m.processor.Stream().Open(
c.Request.Context(),
account,
streamType,
)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
l := log.WithContext(c.Request.Context()).
l := log.
WithContext(c.Request.Context()).
WithFields(kv.Fields{
{"account", account.Username},
{"username", account.Username},
{"streamID", stream.ID},
{"streamType", streamType},
}...)
// Upgrade the incoming HTTP request, which hijacks the underlying
// connection and reuses it for the websocket (non-http) protocol.
// Upgrade the incoming HTTP request. This hijacks the
// underlying connection and reuses it for the websocket
// (non-http) protocol.
//
// If the upgrade fails, then Upgrade replies to the client
// with an HTTP error response.
wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil)
if err != nil {
l.Errorf("error upgrading websocket connection: %v", err)
@ -210,125 +228,208 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
return
}
l.Info("opened websocket connection")
// We perform the main websocket rw loops in a separate
// goroutine in order to let the upgrade handler return.
// This prevents the upgrade handler from holding open any
// throttle / rate-limit request tokens which could become
// problematic on instances with multiple users.
go m.handleWSConn(account.Username, wsConn, stream)
}
// handleWSConn handles a two-way websocket streaming connection.
// It will both read messages from the connection, and push messages
// into the connection. If any errors are encountered while reading
// or writing (including expected errors like clients leaving), the
// connection will be closed.
func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *streampkg.Stream) {
// Create new context for the lifetime of this connection.
ctx, cancel := context.WithCancel(context.Background())
l := log.
WithContext(ctx).
WithFields(kv.Fields{
{"username", username},
{"streamID", stream.ID},
}...)
// Create ticker to send keepalive pings
pinger := time.NewTicker(m.dTicker)
// Read messages coming from the Websocket client connection into the server.
go func() {
// We perform the main websocket send loop in a separate
// goroutine in order to let the upgrade handler return.
// This prevents the upgrade handler from holding open any
// throttle / rate-limit request tokens which could become
// problematic on instances with multiple users.
l.Info("opened websocket connection")
defer l.Info("closed websocket connection")
defer cancel()
m.readFromWSConn(ctx, username, wsConn, stream)
}()
// Create new context for lifetime of the connection
ctx, cncl := context.WithCancel(context.Background())
// Write messages coming from the processor into the Websocket client connection.
go func() {
defer cancel()
m.writeToWSConn(ctx, username, wsConn, stream, pinger)
}()
// Create ticker to send alive pings
pinger := time.NewTicker(m.dTicker)
// Wait for either the read or write functions to close, to indicate
// that the client has left, or something else has gone wrong.
<-ctx.Done()
defer func() {
// Signal done
cncl()
// Tidy up underlying websocket connection.
if err := wsConn.Close(); err != nil {
l.Errorf("error closing websocket connection: %v", err)
}
// Close websocket conn
_ = wsConn.Close()
// Close processor channel so the processor knows
// not to send any more messages to this stream.
close(stream.Hangup)
// Close processor stream
close(stream.Hangup)
// Stop ping ticker (tiny resource saving).
pinger.Stop()
// Stop ping ticker
pinger.Stop()
}()
l.Info("closed websocket connection")
}
go func() {
// Signal done
defer cncl()
// readFromWSConn reads control messages coming in from the given
// websockets connection, and modifies the subscription StreamTypes
// of the given stream accordingly after acquiring a lock on it.
//
// This is a blocking function; will return only on read error or
// if the given context is canceled.
func (m *Module) readFromWSConn(
ctx context.Context,
username string,
wsConn *websocket.Conn,
stream *streampkg.Stream,
) {
l := log.
WithContext(ctx).
WithFields(kv.Fields{
{"username", username},
{"streamID", stream.ID},
}...)
for {
// We have to listen for received websocket messages in
// order to trigger the underlying wsConn.PingHandler().
//
// Read JSON objects from the client and act on them
var msg map[string]string
err := wsConn.ReadJSON(&msg)
if err != nil {
if ctx.Err() == nil {
// Only log error if the connection was not closed
// by us. Uncanceled context indicates this is the case.
l.Errorf("error reading from websocket: %v", err)
}
return
}
l.Tracef("received message from websocket: %v", msg)
readLoop:
for {
select {
case <-ctx.Done():
// Connection closed.
break readLoop
// If the message contains 'stream' and 'type' fields, we can
// update the set of timelines that are subscribed for events.
updateType, ok := msg["type"]
if !ok {
l.Warn("'type' field not provided")
continue
default:
// Read JSON objects from the client and act on them.
var msg map[string]string
if err := wsConn.ReadJSON(&msg); err != nil {
// Only log an error if something weird happened.
// See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7
if websocket.IsUnexpectedCloseError(err, []int{
websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived,
}...) {
l.Errorf("error reading from websocket: %v", err)
}
updateStream, ok := msg["stream"]
if !ok {
l.Warn("'stream' field not provided")
continue
}
// Ignore if the updateStreamType is unknown (or missing),
// so a bad client can't cause extra memory allocations
if !slices.Contains(streampkg.AllStatusTimelines, updateStream) {
l.Warnf("unknown 'stream' field: %v", msg)
continue
}
updateList, ok := msg["list"]
if ok {
updateStream += ":" + updateList
}
switch updateType {
case "subscribe":
stream.Lock()
stream.StreamTypes[updateStream] = true
stream.Unlock()
case "unsubscribe":
stream.Lock()
delete(stream.StreamTypes, updateStream)
stream.Unlock()
default:
l.Warnf("invalid 'type' field: %v", msg)
}
// The connection is gone; no
// further streaming possible.
break readLoop
}
}()
for {
select {
// Connection closed
case <-ctx.Done():
return
// Messages *from* the WS connection are infrequent
// and usually interesting, so log this at info.
l.Infof("received message from websocket: %v", msg)
// Received next stream message
case msg := <-stream.Messages:
l.Tracef("sending message to websocket: %+v", msg)
if err := wsConn.WriteJSON(msg); err != nil {
l.Debugf("error writing json to websocket: %v", err)
return
}
// If the message contains 'stream' and 'type' fields, we can
// update the set of timelines that are subscribed for events.
updateType, ok := msg["type"]
if !ok {
l.Warn("'type' field not provided")
continue
}
// Reset on each successful send.
pinger.Reset(m.dTicker)
updateStream, ok := msg["stream"]
if !ok {
l.Warn("'stream' field not provided")
continue
}
// Send keep-alive "ping"
case <-pinger.C:
l.Trace("pinging websocket ...")
if err := wsConn.WriteMessage(
websocket.PingMessage,
[]byte{},
); err != nil {
l.Debugf("error writing ping to websocket: %v", err)
return
}
// Ignore if the updateStreamType is unknown (or missing),
// so a bad client can't cause extra memory allocations
if !slices.Contains(streampkg.AllStatusTimelines, updateStream) {
l.Warnf("unknown 'stream' field: %v", msg)
continue
}
updateList, ok := msg["list"]
if ok {
updateStream += ":" + updateList
}
switch updateType {
case "subscribe":
stream.Lock()
stream.StreamTypes[updateStream] = true
stream.Unlock()
case "unsubscribe":
stream.Lock()
delete(stream.StreamTypes, updateStream)
stream.Unlock()
default:
l.Warnf("invalid 'type' field: %v", msg)
}
}
}()
}
l.Debug("finished reading from websocket connection")
}
// writeToWSConn receives messages coming from the processor via the
// given stream, and writes them into the given websockets connection.
// This function also handles sending ping messages into the websockets
// connection to keep it alive when no other activity occurs.
//
// This is a blocking function; will return only on write error or
// if the given context is canceled.
func (m *Module) writeToWSConn(
ctx context.Context,
username string,
wsConn *websocket.Conn,
stream *streampkg.Stream,
pinger *time.Ticker,
) {
l := log.
WithContext(ctx).
WithFields(kv.Fields{
{"username", username},
{"streamID", stream.ID},
}...)
writeLoop:
for {
select {
case <-ctx.Done():
// Connection closed.
break writeLoop
case msg := <-stream.Messages:
// Received a new message from the processor.
l.Tracef("writing message to websocket: %+v", msg)
if err := wsConn.WriteJSON(msg); err != nil {
l.Debugf("error writing json to websocket: %v", err)
break writeLoop
}
// Reset pinger on successful send, since
// we know the connection is still there.
pinger.Reset(m.dTicker)
case <-pinger.C:
// Time to send a keep-alive "ping".
l.Trace("writing ping control message to websocket")
if err := wsConn.WriteControl(websocket.PingMessage, nil, time.Time{}); err != nil {
l.Debugf("error writing ping to websocket: %v", err)
break writeLoop
}
}
}
l.Debug("finished writing to websocket connection")
}

View File

@ -42,15 +42,18 @@ type Module struct {
}
func New(processor *processing.Processor, dTicker time.Duration, wsBuf int) *Module {
// We expect CORS requests for websockets,
// (via eg., semaphore.social) so be lenient.
// TODO: make this customizable?
checkOrigin := func(r *http.Request) bool { return true }
return &Module{
processor: processor,
dTicker: dTicker,
wsUpgrade: websocket.Upgrader{
ReadBufferSize: wsBuf, // we don't expect reads
ReadBufferSize: wsBuf,
WriteBufferSize: wsBuf,
// we expect cors requests (via eg., semaphore.social) so be lenient
CheckOrigin: func(r *http.Request) bool { return true },
CheckOrigin: checkOrigin,
},
}
}