[bugfix] ensure timeline limit query is respected (#4141)

# Description

Fixes a bug in the new timeline code in which the limit query parameter wasn't always being fulfilled. When that happened, some clients such as Tusky would assume they didn't need to add a "load more" placeholder view, even when there were more statuses to be loaded. This also fiddles around a bit in the logging middleware handler, adding some more code comments and adding logging of full request URIs when it is safe to do so.
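
For context on why an under-filled page matters: Mastodon-API clients typically treat a response shorter than the requested `limit` as the end of a timeline. The sketch below is illustrative only (not Tusky's actual code; `fetchPage` and the status counts are made up) and shows how a page that comes back short makes everything older than it unreachable:

```go
// Illustrative sketch of the client-side paging heuristic this fix
// matters for. Not Tusky's real implementation; fetchPage and the
// status counts here are hypothetical.
package main

import "fmt"

// fetchPage simulates a server answering ?limit=N while `remaining`
// statuses are still available. A correct server fills the page up
// to the limit; the bug fixed here could return fewer.
func fetchPage(remaining, limit int) int {
	return min(remaining, limit)
}

func main() {
	const limit = 20
	remaining := 45 // statuses still on the server

	for page := 1; ; page++ {
		got := fetchPage(remaining, limit)
		remaining -= got
		fmt.Printf("page %d: got %d statuses\n", page, got)

		// The heuristic: a page shorter than the requested limit is
		// read as "nothing more to load", so no placeholder is shown.
		// If the server under-fills a page while statuses remain,
		// the older statuses are never offered to the user.
		if got < limit {
			break
		}
	}
}
```

The server-side fix below mirrors this: it tracks the *expected* return limit separately from the per-query page limit, over-fetches each database page by 10 to absorb filtering, and retries (up to 5 times) until the expected limit is filled or the data genuinely runs out.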

## Checklist

- [x] I/we have read the [GoToSocial contribution guidelines](https://codeberg.org/superseriousbusiness/gotosocial/src/branch/main/CONTRIBUTING.md).
- [x] I/we have discussed the proposed changes already, either in an issue on the repository, or in the Matrix chat.
- [x] I/we have not leveraged AI to create the proposed changes.
- [x] I/we have performed a self-review of added code.
- [x] I/we have written code that is legible and maintainable by others.
- [x] I/we have commented the added code, particularly in hard-to-understand areas.
- [ ] I/we have made any necessary changes to documentation.
- [x] I/we have added tests that cover new code.
- [x] I/we have run tests and they pass locally with the changes.
- [x] I/we have run `go fmt ./...` and `golangci-lint run`.

Reviewed-on: https://codeberg.org/superseriousbusiness/gotosocial/pulls/4141
Co-authored-by: kim <grufwub@gmail.com>
Co-committed-by: kim <grufwub@gmail.com>
kim authored on 2025-05-06 13:30:23 +00:00 (committed by tobi)
parent e464de1322, commit 8264b63337
5 changed files with 172 additions and 65 deletions

```diff
@@ -170,7 +170,6 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
 	// Prefer query token else use header token.
 	token := cmp.Or(queryToken, headerToken)
 	if token != "" {
-
 		// Token was provided, use it to authorize stream.
```

```diff
@@ -336,6 +336,14 @@ func (t *StatusTimeline) Load(
 	limit := page.Limit
 	order := page.Order()
 	dir := toDirection(order)

+	if limit <= 0 {
+		// a page limit MUST be set!
+		// this shouldn't be possible
+		// but we check anyway to stop
+		// chance of limitless db calls!
+		panic("invalid page limit")
+	}
+
 	// Use a copy of current page so
 	// we can repeatedly update it.
@@ -344,11 +352,11 @@ func (t *StatusTimeline) Load(
 	nextPg.Min.Value = lo
 	nextPg.Max.Value = hi

-	// Interstitial meta objects.
-	var metas []*StatusMeta
+	// Preallocate slice of interstitial models.
+	metas := make([]*StatusMeta, 0, limit)

-	// Returned frontend API statuses.
-	var apiStatuses []*apimodel.Status
+	// Preallocate slice of required status API models.
+	apiStatuses := make([]*apimodel.Status, 0, limit)

 	// TODO: we can remove this nil
 	// check when we've updated all
@@ -362,13 +370,17 @@ func (t *StatusTimeline) Load(
 			return nil, "", "", err
 		}

+		// Load a little more than limit to
+		// reduce chance of db calls below.
+		limitPtr := util.Ptr(limit + 10)
+
 		// First we attempt to load status
 		// metadata entries from the timeline
 		// cache, up to given limit.
 		metas = t.cache.Select(
 			util.PtrIf(lo),
 			util.PtrIf(hi),
-			util.PtrIf(limit),
+			limitPtr,
 			dir,
 		)
@@ -384,9 +396,6 @@ func (t *StatusTimeline) Load(
 		lo = metas[len(metas)-1].ID
 		hi = metas[0].ID

-		// Allocate slice of expected required API models.
-		apiStatuses = make([]*apimodel.Status, 0, len(metas))
-
 		// Prepare frontend API models for
 		// the cached statuses. For now this
 		// also does its own extra filtering.
@@ -399,10 +408,10 @@ func (t *StatusTimeline) Load(
 		}
 	}

-	// If no cached timeline statuses
-	// were found for page, we need to
-	// call through to the database.
-	if len(apiStatuses) == 0 {
+	// If not enough cached timeline
+	// statuses were found for page,
+	// we need to call to database.
+	if len(apiStatuses) < limit {

 		// Pass through to main timeline db load function.
 		apiStatuses, lo, hi, err = loadStatusTimeline(ctx,
@@ -460,25 +469,31 @@ func loadStatusTimeline(
 	// vals of loaded statuses.
 	var lo, hi string

-	// Extract paging params.
+	// Extract paging params, in particular
+	// limit is used separate to nextPg to
+	// determine the *expected* return limit,
+	// not just what we use in db queries.
+	returnLimit := nextPg.Limit
 	order := nextPg.Order()
-	limit := nextPg.Limit
-
-	// Load a little more than
-	// limit to reduce db calls.
-	nextPg.Limit += 10
-
-	// Ensure we have a slice of meta objects to
-	// use in later preparation of the API models.
-	metas = xslices.GrowJust(metas[:0], nextPg.Limit)
-
-	// Ensure we have a slice of required frontend API models.
-	apiStatuses = xslices.GrowJust(apiStatuses[:0], nextPg.Limit)

 	// Perform maximum of 5 load
 	// attempts fetching statuses.
 	for i := 0; i < 5; i++ {

+		// Update page limit to the *remaining*
+		// limit of total we're expected to return.
+		nextPg.Limit = returnLimit - len(apiStatuses)
+		if nextPg.Limit <= 0 {
+			// We reached the end! Set lo paging value.
+			lo = apiStatuses[len(apiStatuses)-1].ID
+			break
+		}
+
+		// But load a bit more than
+		// limit to reduce db calls.
+		nextPg.Limit += 10
+
 		// Load next timeline statuses.
 		statuses, err := loadPage(nextPg)
 		if err != nil {
@@ -519,17 +534,8 @@ func loadStatusTimeline(
 			metas,
 			prepareAPI,
 			apiStatuses,
-			limit,
+			returnLimit,
 		)
-
-		// If we have anything, return
-		// here. Even if below limit.
-		if len(apiStatuses) > 0 {
-			// Set returned lo status paging value.
-			lo = apiStatuses[len(apiStatuses)-1].ID
-			break
-		}
 	}

 	return apiStatuses, lo, hi, nil
```

```diff
@@ -18,11 +18,16 @@
 package timeline

 import (
+	"context"
+	"fmt"
 	"slices"
 	"testing"

 	apimodel "code.superseriousbusiness.org/gotosocial/internal/api/model"
 	"code.superseriousbusiness.org/gotosocial/internal/gtsmodel"
+	"code.superseriousbusiness.org/gotosocial/internal/id"
+	"code.superseriousbusiness.org/gotosocial/internal/log"
+	"code.superseriousbusiness.org/gotosocial/internal/paging"
 	"codeberg.org/gruf/go-structr"
 	"github.com/stretchr/testify/assert"
 )
@@ -60,6 +65,46 @@ var testStatusMeta = []*StatusMeta{
 	},
 }

+func TestStatusTimelineLoadLimit(t *testing.T) {
+	var tt StatusTimeline
+	tt.Init(1000)
+
+	// Prepare new context for the duration of this test.
+	ctx, cncl := context.WithCancel(context.Background())
+	defer cncl()
+
+	// Clone the input test status data.
+	data := slices.Clone(testStatusMeta)
+
+	// Insert test data into timeline.
+	_ = tt.cache.Insert(data...)
+
+	// Manually mark timeline as 'preloaded'.
+	tt.preloader.CheckPreload(tt.preloader.Done)
+
+	// Craft a new page for selection,
+	// setting placeholder min / max values
+	// but in particular setting a limit
+	// HIGHER than currently cached values.
+	page := new(paging.Page)
+	page.Min = paging.MinID(id.Lowest)
+	page.Max = paging.MaxID(id.Highest)
+	page.Limit = len(data) + 10
+
+	// Load crafted page from the cache. This
+	// SHOULD load all cached entries, then
+	// generate an extra 10 statuses up to limit.
+	apiStatuses, _, _, err := tt.Load(ctx,
+		page,
+		loadGeneratedStatusPage,
+		loadStatusIDsFrom(data),
+		nil, // no filtering
+		func(status *gtsmodel.Status) (*apimodel.Status, error) { return new(apimodel.Status), nil },
+	)
+	assert.NoError(t, err)
+	assert.Len(t, apiStatuses, page.Limit)
+}
+
 func TestStatusTimelineUnprepare(t *testing.T) {
 	var tt StatusTimeline
 	tt.Init(1000)
@@ -301,6 +346,44 @@ func TestStatusTimelineTrim(t *testing.T) {
 	assert.Equal(t, before, tt.cache.Len())
 }

+// loadStatusIDsFrom imitates loading of statuses of given IDs from the database, instead selecting
+// statuses with appropriate IDs from the given slice of status meta, converting them to statuses.
+func loadStatusIDsFrom(data []*StatusMeta) func(ids []string) ([]*gtsmodel.Status, error) {
+	return func(ids []string) ([]*gtsmodel.Status, error) {
+		var statuses []*gtsmodel.Status
+		for _, id := range ids {
+			i := slices.IndexFunc(data, func(s *StatusMeta) bool {
+				return s.ID == id
+			})
+			if i < 0 || i >= len(data) {
+				panic(fmt.Sprintf("could not find %s in %v", id, log.VarDump(data)))
+			}
+			statuses = append(statuses, &gtsmodel.Status{
+				ID:               data[i].ID,
+				AccountID:        data[i].AccountID,
+				BoostOfID:        data[i].BoostOfID,
+				BoostOfAccountID: data[i].BoostOfAccountID,
+			})
+		}
+		return statuses, nil
+	}
+}
+
+// loadGeneratedStatusPage imitates loading of a given page of statuses,
+// simply generating new statuses until the given page's limit is reached.
+func loadGeneratedStatusPage(page *paging.Page) ([]*gtsmodel.Status, error) {
+	var statuses []*gtsmodel.Status
+	for range page.Limit {
+		statuses = append(statuses, &gtsmodel.Status{
+			ID:               id.NewULID(),
+			AccountID:        id.NewULID(),
+			BoostOfID:        id.NewULID(),
+			BoostOfAccountID: id.NewULID(),
+		})
+	}
+	return statuses, nil
+}
+
 // containsStatusID returns whether timeline contains a status with ID.
 func containsStatusID(t *StatusTimeline, id string) bool {
 	return getStatusByID(t, id) != nil
```

```diff
@@ -21,6 +21,7 @@ import (
 	"fmt"
 	"net/http"
 	"runtime"
+	"strings"
 	"time"

 	"code.superseriousbusiness.org/gotosocial/internal/gtscontext"
@@ -35,19 +36,21 @@
 // Logger returns a gin middleware which provides request logging and panic recovery.
 func Logger(logClientIP bool) gin.HandlerFunc {
 	return func(c *gin.Context) {
-		// Initialize the logging fields
-		fields := make(kv.Fields, 5, 7)

 		// Determine pre-handler time
 		before := time.Now()

-		// defer so that we log *after the request has completed*
+		// defer so that we log *after
+		// the request has completed*
 		defer func() {
+			// Get response status code.
 			code := c.Writer.Status()
-			path := c.Request.URL.Path
+
+			// Get request context.
+			ctx := c.Request.Context()

 			if r := recover(); r != nil {
-				if c.Writer.Status() == 0 {
+				if code == 0 {
 					// No response was written, send a generic Internal Error
 					c.Writer.WriteHeader(http.StatusInternalServerError)
 				}
@@ -65,37 +68,51 @@ func Logger(logClientIP bool) gin.HandlerFunc {
 					WithField("stacktrace", callers).Error(err)
 			}

-			// NOTE:
-			// It is very important here that we are ONLY logging
-			// the request path, and none of the query parameters.
-			// Query parameters can contain sensitive information
-			// and could lead to storing plaintext API keys in logs
+			// Initialize the logging fields
+			fields := make(kv.Fields, 5, 8)

 			// Set request logging fields
 			fields[0] = kv.Field{"latency", time.Since(before)}
 			fields[1] = kv.Field{"userAgent", c.Request.UserAgent()}
 			fields[2] = kv.Field{"method", c.Request.Method}
 			fields[3] = kv.Field{"statusCode", code}
-			fields[4] = kv.Field{"path", path}

-			// Set optional request logging fields.
+			// If the request contains sensitive query
+			// data only log path, else log entire URI.
+			if sensitiveQuery(c.Request.URL.RawQuery) {
+				path := c.Request.URL.Path
+				fields[4] = kv.Field{"uri", path}
+			} else {
+				uri := c.Request.RequestURI
+				fields[4] = kv.Field{"uri", uri}
+			}

 			if logClientIP {
+				// Append IP only if configured to.
 				fields = append(fields, kv.Field{
 					"clientIP", c.ClientIP(),
 				})
 			}

-			ctx := c.Request.Context()
 			if pubKeyID := gtscontext.HTTPSignaturePubKeyID(ctx); pubKeyID != nil {
+				// Append public key ID if attached.
 				fields = append(fields, kv.Field{
 					"pubKeyID", pubKeyID.String(),
 				})
 			}

-			// Create log entry with fields
-			l := log.New()
-			l = l.WithContext(ctx)
-			l = l.WithFields(fields...)
+			if len(c.Errors) > 0 {
+				// Always attach any found errors.
+				fields = append(fields, kv.Field{
+					"errors", c.Errors,
+				})
+			}
+
+			// Create entry
+			// with fields.
+			l := log.New().
+				WithContext(ctx).
+				WithFields(fields...)

 			// Default is info
 			lvl := log.INFO
@@ -105,11 +122,6 @@
 				lvl = log.ERROR
 			}

-			if len(c.Errors) > 0 {
-				// Always attach any found errors.
-				l = l.WithField("errors", c.Errors)
-			}
-
 			// Get appropriate text for this code.
 			statusText := http.StatusText(code)
 			if statusText == "" {
@@ -125,15 +137,22 @@
 			// Generate a nicer looking bytecount
 			size := bytesize.Size(c.Writer.Size()) // #nosec G115 -- Just logging

-			// Finally, write log entry with status text + body size.
+			// Write log entry with status text + body size.
 			l.Logf(lvl, "%s: wrote %s", statusText, size)
 		}()

-		// Process request
+		// Process
+		// request.
 		c.Next()
 	}
 }

+// sensitiveQuery checks whether given query string
+// contains sensitive data that shouldn't be logged.
+func sensitiveQuery(query string) bool {
+	return strings.Contains(query, "token")
+}
+
 // gatherFrames gathers runtime frames from a frame iterator.
 func gatherFrames(iter *runtime.Frames, n int) []runtime.Frame {
 	if iter == nil {
```

```diff
@@ -278,10 +278,10 @@ func (p *Page) ToLinkURL(proto, host, path string, queryParams url.Values) *url.
 	if queryParams == nil {
 		// Allocate new query parameters.
-		queryParams = make(url.Values)
+		queryParams = make(url.Values, 2)
 	} else {
 		// Before edit clone existing params.
-		queryParams = cloneQuery(queryParams)
+		queryParams = cloneQuery(queryParams, 2)
 	}

 	if p.Min.Value != "" {
@@ -309,8 +309,8 @@ func (p *Page) ToLinkURL(proto, host, path string, queryParams url.Values) *url.
 }

 // cloneQuery clones input map of url values.
-func cloneQuery(src url.Values) url.Values {
-	dst := make(url.Values, len(src))
+func cloneQuery(src url.Values, extra int) url.Values {
+	dst := make(url.Values, len(src)+extra)
 	for k, vs := range src {
 		dst[k] = slices.Clone(vs)
 	}
```