mirror of
				https://github.com/superseriousbusiness/gotosocial
				synced 2025-06-05 21:59:39 +02:00 
			
		
		
		
	[bugfix] return early in websocket upgrade handler (#1315)
* launch websocket streaming in goroutine to allow upgrade handler to return * don't send any message on ping, improved close check on failed read * use context to signal wsconn close, ensure canceled in read goroutine Signed-off-by: kim <grufwub@gmail.com>
This commit is contained in:
		| @@ -19,6 +19,8 @@ | ||||
| package api | ||||
|  | ||||
| import ( | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/api/client/accounts" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/api/client/admin" | ||||
| @@ -122,7 +124,7 @@ func NewClient(db db.DB, p processing.Processor) *Client { | ||||
| 		notifications:  notifications.New(p), | ||||
| 		search:         search.New(p), | ||||
| 		statuses:       statuses.New(p), | ||||
| 		streaming:      streaming.New(p), | ||||
| 		streaming:      streaming.New(p, time.Second*30, 4096), | ||||
| 		timelines:      timelines.New(p), | ||||
| 		user:           user.New(p), | ||||
| 	} | ||||
|   | ||||
| @@ -19,8 +19,9 @@ | ||||
| package streaming | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
|  | ||||
| 	"codeberg.org/gruf/go-kv" | ||||
| @@ -32,16 +33,6 @@ import ( | ||||
| 	"github.com/gorilla/websocket" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	wsUpgrader = websocket.Upgrader{ | ||||
| 		ReadBufferSize:  1024, | ||||
| 		WriteBufferSize: 1024, | ||||
| 		// we expect cors requests (via eg., pinafore.social) so be lenient | ||||
| 		CheckOrigin: func(r *http.Request) bool { return true }, | ||||
| 	} | ||||
| 	errNoToken = fmt.Errorf("no access token provided under query key %s or under header %s", AccessTokenQueryKey, AccessTokenHeader) | ||||
| ) | ||||
|  | ||||
| // StreamGETHandler swagger:operation GET /api/v1/streaming streamGet | ||||
| // | ||||
| // Initiate a websocket connection for live streaming of statuses and notifications. | ||||
| @@ -150,21 +141,20 @@ func (m *Module) StreamGETHandler(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var accessToken string | ||||
| 	if t := c.Query(AccessTokenQueryKey); t != "" { | ||||
| 		// try query param first | ||||
| 		accessToken = t | ||||
| 	} else if t := c.GetHeader(AccessTokenHeader); t != "" { | ||||
| 		// fall back to Sec-Websocket-Protocol | ||||
| 		accessToken = t | ||||
| 	} else { | ||||
| 		// no token | ||||
| 		err := errNoToken | ||||
| 		apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) | ||||
| 	var token string | ||||
|  | ||||
| 	// First we check for a query param provided access token | ||||
| 	if token = c.Query(AccessTokenQueryKey); token == "" { | ||||
| 		// Else we check the HTTP header provided token | ||||
| 		if token = c.GetHeader(AccessTokenHeader); token == "" { | ||||
| 			const errStr = "no access token provided" | ||||
| 			err := gtserror.NewErrorUnauthorized(errors.New(errStr), errStr) | ||||
| 			apiutil.ErrorHandler(c, err, m.processor.InstanceGet) | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	account, errWithCode := m.processor.AuthorizeStreamingRequest(c.Request.Context(), accessToken) | ||||
| 	account, errWithCode := m.processor.AuthorizeStreamingRequest(c.Request.Context(), token) | ||||
| 	if errWithCode != nil { | ||||
| 		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet) | ||||
| 		return | ||||
| @@ -178,51 +168,97 @@ func (m *Module) StreamGETHandler(c *gin.Context) { | ||||
|  | ||||
| 	l := log.WithFields(kv.Fields{ | ||||
| 		{"account", account.Username}, | ||||
| 		{"path", BasePath}, | ||||
| 		{"streamID", stream.ID}, | ||||
| 		{"streamType", streamType}, | ||||
| 	}...) | ||||
|  | ||||
| 	wsConn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil) | ||||
| 	// Upgrade the incoming HTTP request, which hijacks the underlying | ||||
| 	// connection and reuses it for the websocket (non-http) protocol. | ||||
| 	wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil) | ||||
| 	if err != nil { | ||||
| 		// If the upgrade fails, then Upgrade replies to the client with an HTTP error response. | ||||
| 		// Because websocket issues are a pretty common source of headaches, we should also log | ||||
| 		// this at Error to make this plenty visible and help admins out a bit. | ||||
| 		l.Errorf("error upgrading websocket connection: %s", err) | ||||
| 		l.Errorf("error upgrading websocket connection: %v", err) | ||||
| 		close(stream.Hangup) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	go func() { | ||||
| 		// We perform the main websocket send loop in a separate | ||||
| 		// goroutine in order to let the upgrade handler return. | ||||
| 		// This prevents the upgrade handler from holding open any | ||||
| 		// throttle / rate-limit request tokens which could become | ||||
| 		// problematic on instances with multiple users. | ||||
| 		l.Info("opened websocket connection") | ||||
| 		defer l.Info("closed websocket connection") | ||||
|  | ||||
| 		// Create new context for lifetime of the connection | ||||
| 		ctx, cncl := context.WithCancel(context.Background()) | ||||
|  | ||||
| 		// Create ticker to send alive pings | ||||
| 		pinger := time.NewTicker(m.dTicker) | ||||
|  | ||||
| 		defer func() { | ||||
| 		// cleanup | ||||
| 		wsConn.Close() | ||||
| 			// Signal done | ||||
| 			cncl() | ||||
|  | ||||
| 			// Close websocket conn | ||||
| 			_ = wsConn.Close() | ||||
|  | ||||
| 			// Close processor stream | ||||
| 			close(stream.Hangup) | ||||
|  | ||||
| 			// Stop ping ticker | ||||
| 			pinger.Stop() | ||||
| 		}() | ||||
|  | ||||
| 	streamTicker := time.NewTicker(m.tickDuration) | ||||
| 	defer streamTicker.Stop() | ||||
| 		go func() { | ||||
| 			// Signal done | ||||
| 			defer cncl() | ||||
|  | ||||
| 	// We want to stay in the loop as long as possible while the client is connected. | ||||
| 	// The only thing that should break the loop is if the client leaves or the connection becomes unhealthy. | ||||
| 			for { | ||||
| 				// We have to listen for received websocket messages in | ||||
| 				// order to trigger the underlying wsConn.PingHandler(). | ||||
| 				// | ||||
| 	// If the loop does break, we expect the client to reattempt connection, so it's cheap to leave + try again | ||||
| wsLoop: | ||||
| 				// So we wait on received messages but only act on errors. | ||||
| 				_, _, err := wsConn.ReadMessage() | ||||
| 				if err != nil { | ||||
| 					if ctx.Err() == nil { | ||||
| 						// Only log error if the connection was not closed | ||||
| 						// by us. Uncanceled context indicates this is the case. | ||||
| 						l.Errorf("error reading from websocket: %v", err) | ||||
| 					} | ||||
| 					return | ||||
| 				} | ||||
| 			} | ||||
| 		}() | ||||
|  | ||||
| 		for { | ||||
| 			select { | ||||
| 		case m := <-stream.Messages: | ||||
| 			l.Trace("received message from stream") | ||||
| 			if err := wsConn.WriteJSON(m); err != nil { | ||||
| 				l.Debugf("error writing json to websocket connection; breaking off: %s", err) | ||||
| 				break wsLoop | ||||
| 			// Connection closed | ||||
| 			case <-ctx.Done(): | ||||
| 				return | ||||
|  | ||||
| 			// Received next stream message | ||||
| 			case msg := <-stream.Messages: | ||||
| 				l.Tracef("sending message to websocket: %+v", msg) | ||||
| 				if err := wsConn.WriteJSON(msg); err != nil { | ||||
| 					l.Errorf("error writing json to websocket: %v", err) | ||||
| 					return | ||||
| 				} | ||||
| 			l.Trace("wrote message into websocket connection") | ||||
| 		case <-streamTicker.C: | ||||
| 			l.Trace("received TICK from ticker") | ||||
| 			if err := wsConn.WriteMessage(websocket.PingMessage, []byte(": ping")); err != nil { | ||||
| 				l.Debugf("error writing ping to websocket connection; breaking off: %s", err) | ||||
| 				break wsLoop | ||||
| 			} | ||||
| 			l.Trace("wrote ping message into websocket connection") | ||||
|  | ||||
| 				// Reset on each successful send. | ||||
| 				pinger.Reset(m.dTicker) | ||||
|  | ||||
| 			// Send keep-alive "ping" | ||||
| 			case <-pinger.C: | ||||
| 				l.Trace("pinging websocket ...") | ||||
| 				if err := wsConn.WriteMessage( | ||||
| 					websocket.PingMessage, | ||||
| 					[]byte{}, | ||||
| 				); err != nil { | ||||
| 					l.Errorf("error writing ping to websocket: %v", err) | ||||
| 					return | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| } | ||||
|   | ||||
| @@ -23,6 +23,7 @@ import ( | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"github.com/superseriousbusiness/gotosocial/internal/processing" | ||||
| ) | ||||
|  | ||||
| @@ -42,20 +43,21 @@ const ( | ||||
|  | ||||
| type Module struct { | ||||
| 	processor processing.Processor | ||||
| 	tickDuration time.Duration | ||||
| 	dTicker   time.Duration | ||||
| 	wsUpgrade websocket.Upgrader | ||||
| } | ||||
|  | ||||
| func New(processor processing.Processor) *Module { | ||||
| func New(processor processing.Processor, dTicker time.Duration, wsBuf int) *Module { | ||||
| 	return &Module{ | ||||
| 		processor: processor, | ||||
| 		tickDuration: 30 * time.Second, | ||||
| 	} | ||||
| } | ||||
| 		dTicker:   dTicker, | ||||
| 		wsUpgrade: websocket.Upgrader{ | ||||
| 			ReadBufferSize:  wsBuf, // we don't expect reads | ||||
| 			WriteBufferSize: wsBuf, | ||||
|  | ||||
| func NewWithTickDuration(processor processing.Processor, tickDuration time.Duration) *Module { | ||||
| 	return &Module{ | ||||
| 		processor:    processor, | ||||
| 		tickDuration: tickDuration, | ||||
| 			// we expect cors requests (via eg., pinafore.social) so be lenient | ||||
| 			CheckOrigin: func(r *http.Request) bool { return true }, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -99,7 +99,7 @@ func (suite *StreamingTestSuite) SetupTest() { | ||||
| 	suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) | ||||
| 	suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) | ||||
| 	suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) | ||||
| 	suite.streamingModule = streaming.NewWithTickDuration(suite.processor, 1) | ||||
| 	suite.streamingModule = streaming.New(suite.processor, 1, 4096) | ||||
| 	suite.NoError(suite.processor.Start()) | ||||
| } | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user