diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go
index 7bb65f7a1..5d7f17f94 100644
--- a/internal/api/client/streaming/stream.go
+++ b/internal/api/client/streaming/stream.go
@@ -170,7 +170,6 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
 
 	// Prefer query token else use header token.
 	token := cmp.Or(queryToken, headerToken)
-
 	if token != "" {
 		// Token was provided, use it to authorize stream.
diff --git a/internal/cache/timeline/status.go b/internal/cache/timeline/status.go
index 56d90e422..c0c394042 100644
--- a/internal/cache/timeline/status.go
+++ b/internal/cache/timeline/status.go
@@ -336,6 +336,14 @@ func (t *StatusTimeline) Load(
 	limit := page.Limit
 	order := page.Order()
 	dir := toDirection(order)
+	if limit <= 0 {
+
+		// a page limit MUST be set!
+		// this shouldn't be possible
+		// but we check anyway to stop
+		// any chance of limitless db calls!
+		panic("invalid page limit")
+	}
 
 	// Use a copy of current page so
 	// we can repeatedly update it.
@@ -344,11 +352,11 @@ func (t *StatusTimeline) Load(
 	nextPg.Min.Value = lo
 	nextPg.Max.Value = hi
 
-	// Interstitial meta objects.
-	var metas []*StatusMeta
+	// Preallocate slice of interstitial models.
+	metas := make([]*StatusMeta, 0, limit)
 
-	// Returned frontend API statuses.
-	var apiStatuses []*apimodel.Status
+	// Preallocate slice of required status API models.
+	apiStatuses := make([]*apimodel.Status, 0, limit)
 
 	// TODO: we can remove this nil
 	// check when we've updated all
@@ -362,13 +370,17 @@ func (t *StatusTimeline) Load(
 			return nil, "", "", err
 		}
 
+		// Load a little more than limit to
+		// reduce chance of db calls below.
+		limitPtr := util.Ptr(limit + 10)
+
 		// First we attempt to load status
 		// metadata entries from the timeline
 		// cache, up to given limit.
 		metas = t.cache.Select(
 			util.PtrIf(lo),
 			util.PtrIf(hi),
-			util.PtrIf(limit),
+			limitPtr,
 			dir,
 		)
 
@@ -384,9 +396,6 @@ func (t *StatusTimeline) Load(
 			lo = metas[len(metas)-1].ID
 			hi = metas[0].ID
 
-			// Allocate slice of expected required API models.
-			apiStatuses = make([]*apimodel.Status, 0, len(metas))
-
 			// Prepare frontend API models for
 			// the cached statuses. For now this
 			// also does its own extra filtering.
@@ -399,10 +408,10 @@ func (t *StatusTimeline) Load(
 		}
 	}
 
-	// If no cached timeline statuses
-	// were found for page, we need to
-	// call through to the database.
-	if len(apiStatuses) == 0 {
+	// If not enough cached timeline
+	// statuses were found for page,
+	// we need to call the database.
+	if len(apiStatuses) < limit {
 
 		// Pass through to main timeline db load function.
 		apiStatuses, lo, hi, err = loadStatusTimeline(ctx,
@@ -460,25 +469,31 @@ func loadStatusTimeline(
 	// vals of loaded statuses.
 	var lo, hi string
 
-	// Extract paging params.
+	// Extract paging params, in particular
+	// limit is used separately from nextPg to
+	// determine the *expected* return limit,
+	// not just what we use in db queries.
+	returnLimit := nextPg.Limit
 	order := nextPg.Order()
-	limit := nextPg.Limit
-
-	// Load a little more than
-	// limit to reduce db calls.
-	nextPg.Limit += 10
-
-	// Ensure we have a slice of meta objects to
-	// use in later preparation of the API models.
-	metas = xslices.GrowJust(metas[:0], nextPg.Limit)
-
-	// Ensure we have a slice of required frontend API models.
-	apiStatuses = xslices.GrowJust(apiStatuses[:0], nextPg.Limit)
 
 	// Perform maximum of 5 load
 	// attempts fetching statuses.
 	for i := 0; i < 5; i++ {
 
+		// Update page limit to the *remaining*
+		// number of the total we're expected to return.
+		nextPg.Limit = returnLimit - len(apiStatuses)
+		if nextPg.Limit <= 0 {
+
+			// We reached the end! Set lo paging value.
+			lo = apiStatuses[len(apiStatuses)-1].ID
+			break
+		}
+
+		// But load a bit more than
+		// limit to reduce db calls.
+		nextPg.Limit += 10
+
 		// Load next timeline statuses.
 		statuses, err := loadPage(nextPg)
 		if err != nil {
@@ -519,17 +534,8 @@ func loadStatusTimeline(
 			metas,
 			prepareAPI,
 			apiStatuses,
-			limit,
+			returnLimit,
 		)
-
-		// If we have anything, return
-		// here. Even if below limit.
-		if len(apiStatuses) > 0 {
-
-			// Set returned lo status paging value.
-			lo = apiStatuses[len(apiStatuses)-1].ID
-			break
-		}
 	}
 
 	return apiStatuses, lo, hi, nil
diff --git a/internal/cache/timeline/status_test.go b/internal/cache/timeline/status_test.go
index 6a288d2ea..fc7e43da8 100644
--- a/internal/cache/timeline/status_test.go
+++ b/internal/cache/timeline/status_test.go
@@ -18,11 +18,16 @@
 package timeline
 
 import (
+	"context"
+	"fmt"
 	"slices"
 	"testing"
 
 	apimodel "code.superseriousbusiness.org/gotosocial/internal/api/model"
 	"code.superseriousbusiness.org/gotosocial/internal/gtsmodel"
+	"code.superseriousbusiness.org/gotosocial/internal/id"
+	"code.superseriousbusiness.org/gotosocial/internal/log"
+	"code.superseriousbusiness.org/gotosocial/internal/paging"
 	"codeberg.org/gruf/go-structr"
 	"github.com/stretchr/testify/assert"
 )
@@ -60,6 +65,46 @@ var testStatusMeta = []*StatusMeta{
 	},
 }
 
+func TestStatusTimelineLoadLimit(t *testing.T) {
+	var tt StatusTimeline
+	tt.Init(1000)
+
+	// Prepare new context for the duration of this test.
+	ctx, cncl := context.WithCancel(context.Background())
+	defer cncl()
+
+	// Clone the input test status data.
+	data := slices.Clone(testStatusMeta)
+
+	// Insert test data into timeline.
+	_ = tt.cache.Insert(data...)
+
+	// Manually mark timeline as 'preloaded'.
+	tt.preloader.CheckPreload(tt.preloader.Done)
+
+	// Craft a new page for selection,
+	// setting placeholder min / max values
+	// but in particular setting a limit
+	// HIGHER than the number of cached entries.
+	page := new(paging.Page)
+	page.Min = paging.MinID(id.Lowest)
+	page.Max = paging.MaxID(id.Highest)
+	page.Limit = len(data) + 10
+
+	// Load crafted page from the cache. This
+	// SHOULD load all cached entries, then
+	// generate an extra 10 statuses up to limit.
+	apiStatuses, _, _, err := tt.Load(ctx,
+		page,
+		loadGeneratedStatusPage,
+		loadStatusIDsFrom(data),
+		nil, // no filtering
+		func(status *gtsmodel.Status) (*apimodel.Status, error) { return new(apimodel.Status), nil },
+	)
+	assert.NoError(t, err)
+	assert.Len(t, apiStatuses, page.Limit)
+}
+
 func TestStatusTimelineUnprepare(t *testing.T) {
 	var tt StatusTimeline
 	tt.Init(1000)
@@ -301,6 +346,44 @@ func TestStatusTimelineTrim(t *testing.T) {
 	assert.Equal(t, before, tt.cache.Len())
 }
 
+// loadStatusIDsFrom imitates loading statuses with the given IDs from the database, instead selecting
+// status meta entries with matching IDs from the given slice and converting them to statuses.
+func loadStatusIDsFrom(data []*StatusMeta) func(ids []string) ([]*gtsmodel.Status, error) {
+	return func(ids []string) ([]*gtsmodel.Status, error) {
+		var statuses []*gtsmodel.Status
+		for _, id := range ids {
+			i := slices.IndexFunc(data, func(s *StatusMeta) bool {
+				return s.ID == id
+			})
+			if i < 0 {
+				panic(fmt.Sprintf("could not find %s in %v", id, log.VarDump(data)))
+			}
+			statuses = append(statuses, &gtsmodel.Status{
+				ID:               data[i].ID,
+				AccountID:        data[i].AccountID,
+				BoostOfID:        data[i].BoostOfID,
+				BoostOfAccountID: data[i].BoostOfAccountID,
+			})
+		}
+		return statuses, nil
+	}
+}
+
+// loadGeneratedStatusPage imitates loading of a given page of statuses,
+// simply generating new statuses until the given page's limit is reached.
+func loadGeneratedStatusPage(page *paging.Page) ([]*gtsmodel.Status, error) {
+	var statuses []*gtsmodel.Status
+	for range page.Limit {
+		statuses = append(statuses, &gtsmodel.Status{
+			ID:               id.NewULID(),
+			AccountID:        id.NewULID(),
+			BoostOfID:        id.NewULID(),
+			BoostOfAccountID: id.NewULID(),
+		})
+	}
+	return statuses, nil
+}
+
 // containsStatusID returns whether timeline contains a status with ID.
 func containsStatusID(t *StatusTimeline, id string) bool {
 	return getStatusByID(t, id) != nil
diff --git a/internal/middleware/logger.go b/internal/middleware/logger.go
index 9fa245666..00e940992 100644
--- a/internal/middleware/logger.go
+++ b/internal/middleware/logger.go
@@ -21,6 +21,7 @@ import (
 	"fmt"
 	"net/http"
 	"runtime"
+	"strings"
 	"time"
 
 	"code.superseriousbusiness.org/gotosocial/internal/gtscontext"
@@ -35,19 +36,21 @@
 // Logger returns a gin middleware which provides request logging and panic recovery.
 func Logger(logClientIP bool) gin.HandlerFunc {
 	return func(c *gin.Context) {
-		// Initialize the logging fields
-		fields := make(kv.Fields, 5, 7)
-
 		// Determine pre-handler time
 		before := time.Now()
 
-		// defer so that we log *after the request has completed*
+		// defer so that we log *after
+		// the request has completed*
 		defer func() {
+
+			// Get response status code.
 			code := c.Writer.Status()
-			path := c.Request.URL.Path
+
+			// Get request context.
+			ctx := c.Request.Context()
 
 			if r := recover(); r != nil {
-				if c.Writer.Status() == 0 {
+				if code == 0 {
 					// No response was written, send a generic Internal Error
 					c.Writer.WriteHeader(http.StatusInternalServerError)
 				}
@@ -65,37 +68,51 @@ func Logger(logClientIP bool) gin.HandlerFunc {
 					WithField("stacktrace", callers).Error(err)
 			}
 
-			// NOTE:
-			// It is very important here that we are ONLY logging
-			// the request path, and none of the query parameters.
-			// Query parameters can contain sensitive information
-			// and could lead to storing plaintext API keys in logs
+			// Initialize the logging fields
+			fields := make(kv.Fields, 5, 8)
 
 			// Set request logging fields
 			fields[0] = kv.Field{"latency", time.Since(before)}
 			fields[1] = kv.Field{"userAgent", c.Request.UserAgent()}
 			fields[2] = kv.Field{"method", c.Request.Method}
 			fields[3] = kv.Field{"statusCode", code}
-			fields[4] = kv.Field{"path", path}
 
-			// Set optional request logging fields.
+			// If the request contains sensitive query
+			// data, log only the path, else the entire URI.
+			if sensitiveQuery(c.Request.URL.RawQuery) {
+				path := c.Request.URL.Path
+				fields[4] = kv.Field{"uri", path}
+			} else {
+				uri := c.Request.RequestURI
+				fields[4] = kv.Field{"uri", uri}
+			}
+
 			if logClientIP {
+				// Append IP only if configured to.
 				fields = append(fields, kv.Field{
 					"clientIP", c.ClientIP(),
 				})
 			}
 
-			ctx := c.Request.Context()
 			if pubKeyID := gtscontext.HTTPSignaturePubKeyID(ctx); pubKeyID != nil {
+				// Append public key ID if attached.
 				fields = append(fields, kv.Field{
 					"pubKeyID", pubKeyID.String(),
 				})
 			}
 
-			// Create log entry with fields
-			l := log.New()
-			l = l.WithContext(ctx)
-			l = l.WithFields(fields...)
+			if len(c.Errors) > 0 {
+				// Always attach any found errors.
+				fields = append(fields, kv.Field{
+					"errors", c.Errors,
+				})
+			}
+
+			// Create entry
+			// with fields.
+			l := log.New().
+				WithContext(ctx).
+				WithFields(fields...)
 
 			// Default is info
 			lvl := log.INFO
@@ -105,11 +122,6 @@ func Logger(logClientIP bool) gin.HandlerFunc {
 				lvl = log.ERROR
 			}
 
-			if len(c.Errors) > 0 {
-				// Always attach any found errors.
-				l = l.WithField("errors", c.Errors)
-			}
-
 			// Get appropriate text for this code.
 			statusText := http.StatusText(code)
 			if statusText == "" {
@@ -125,15 +137,22 @@ func Logger(logClientIP bool) gin.HandlerFunc {
 			// Generate a nicer looking bytecount
 			size := bytesize.Size(c.Writer.Size()) // #nosec G115 -- Just logging
 
-			// Finally, write log entry with status text + body size.
+			// Write log entry with status text + body size.
 			l.Logf(lvl, "%s: wrote %s", statusText, size)
 		}()
 
-		// Process request
+		// Process
+		// request.
 		c.Next()
 	}
 }
 
+// sensitiveQuery checks whether given query string
+// contains sensitive data that shouldn't be logged.
+func sensitiveQuery(query string) bool {
+	return strings.Contains(query, "token")
+}
+
 // gatherFrames gathers runtime frames from a frame iterator.
 func gatherFrames(iter *runtime.Frames, n int) []runtime.Frame {
 	if iter == nil {
diff --git a/internal/paging/page.go b/internal/paging/page.go
index 6c91da6b2..708ab1bd7 100644
--- a/internal/paging/page.go
+++ b/internal/paging/page.go
@@ -278,10 +278,10 @@ func (p *Page) ToLinkURL(proto, host, path string, queryParams url.Values) *url.
 
 	if queryParams == nil {
 		// Allocate new query parameters.
-		queryParams = make(url.Values)
+		queryParams = make(url.Values, 2)
 	} else {
 		// Before edit clone existing params.
-		queryParams = cloneQuery(queryParams)
+		queryParams = cloneQuery(queryParams, 2)
 	}
 
 	if p.Min.Value != "" {
@@ -309,8 +309,9 @@ func (p *Page) ToLinkURL(proto, host, path string, queryParams url.Values) *url.
 }
 
-// cloneQuery clones input map of url values.
-func cloneQuery(src url.Values) url.Values {
-	dst := make(url.Values, len(src))
+// cloneQuery clones input map of url values,
+// reserving capacity for extra additional keys.
+func cloneQuery(src url.Values, extra int) url.Values {
+	dst := make(url.Values, len(src)+extra)
 	for k, vs := range src {
 		dst[k] = slices.Clone(vs)
 	}
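
The reworked loop in loadStatusTimeline boils down to a general "fill to the expected return count" pattern: each attempt requests only the items still missing (plus a small margin to cut down on round trips), filters what comes back, and stops once the target count is reached or the attempts run out. Below is a minimal standalone sketch of that pattern; fetchPage, keep, want, and cursor are made-up stand-ins for the db loader, status filter, page limit, and paging values, not actual GoToSocial APIs.

package main

import "fmt"

// fetchPage imitates a paged db query: it returns up to
// limit sequential items starting after the given cursor.
func fetchPage(after, limit int) []int {
	items := make([]int, 0, limit)
	for i := after + 1; i <= after+limit; i++ {
		items = append(items, i)
	}
	return items
}

// keep is a stand-in filter; it drops every third item,
// much like statuses dropped during API model preparation.
func keep(item int) bool { return item%3 != 0 }

func main() {
	const want = 10 // expected return count

	var (
		results []int
		cursor  int
	)

	// Maximum of 5 load attempts.
	for i := 0; i < 5; i++ {

		// Request only what is still missing ...
		limit := want - len(results)
		if limit <= 0 {
			// We reached the end!
			break
		}

		// ... plus a margin, so post-filter
		// shortfalls rarely need another call.
		limit += 10

		// Load the next page of items.
		page := fetchPage(cursor, limit)
		if len(page) == 0 {
			// Data source exhausted.
			break
		}

		// Update cursor to last-seen item.
		cursor = page[len(page)-1]

		// Filter, keeping at most 'want' items.
		for _, item := range page {
			if keep(item) && len(results) < want {
				results = append(results, item)
			}
		}
	}

	fmt.Println(len(results), results)
}

The +10 margin mirrors the patch's limit + 10 over-fetch in both Load and loadStatusTimeline: a slightly larger query is cheap compared to a repeat round trip when filtering discards entries.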
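The logger.go change can likewise be exercised on its own: with the patch, the full request URI is logged unless the raw query looks sensitive, in which case only the path is kept. A small sketch of that decision, reusing the patch's sensitiveQuery check on a couple of invented example URLs:

package main

import (
	"fmt"
	"net/url"
	"strings"
)

// sensitiveQuery mirrors the check added in logger.go.
func sensitiveQuery(query string) bool {
	return strings.Contains(query, "token")
}

func main() {
	for _, raw := range []string{
		"/api/v1/streaming?access_token=abc123",
		"/api/v1/timelines/home?limit=20",
	} {
		u, err := url.Parse(raw)
		if err != nil {
			panic(err)
		}

		// Same decision as the middleware: drop the
		// query from the logged URI when sensitive.
		logged := u.RequestURI()
		if sensitiveQuery(u.RawQuery) {
			logged = u.Path
		}
		fmt.Println(logged)
	}
}

Here the first URI is logged as just /api/v1/streaming, while the second keeps its harmless ?limit=20 query, which is exactly the trade-off the removed NOTE comment was guarding against.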