diff --git a/internal/cache/timeline/preload.go b/internal/cache/timeline/preload.go index c2daa2fc6..41f826eaf 100644 --- a/internal/cache/timeline/preload.go +++ b/internal/cache/timeline/preload.go @@ -60,23 +60,23 @@ func (p *preloader) Check() bool { // CheckPreload will safely check the preload state, // and if needed call the provided function. if a // preload is in progress, it will wait until complete. -func (p *preloader) CheckPreload(preload func(*any)) { +func (p *preloader) CheckPreload(preload func() error) error { for { // Get state ptr. ptr := p.p.Load() if ptr == nil || *ptr == false { // Needs preloading, start it. - ok := p.start(ptr, preload) - + ok, err := p.start(ptr, preload) if !ok { + // Failed to acquire start, // other thread beat us to it. continue } - // Success! - return + // We ran! + return err } // Check for a preload currently in progress. @@ -85,52 +85,65 @@ func (p *preloader) CheckPreload(preload func(*any)) { continue } - // Anything else - // means success. - return + // Anything else means + // already preloaded. + return nil } } -// start attempts to start the given preload function, by performing -// a compare and swap operation with 'old'. return is success. -func (p *preloader) start(old *any, preload func(*any)) bool { +// start will attempt to acquire start state of the preloader, on success calling 'preload'. +// this returns whether start was acquired, and if called returns 'preload' error. in the +// case that 'preload' is called, returned error determines the next state that preloader +// will update itself to. (err == nil) => "preloaded", (err != nil) => "needs preload". +// NOTE: this is the only function that may unset an in-progress sync.WaitGroup value. +func (p *preloader) start(old *any, preload func() error) (started bool, err error) { // Optimistically setup a // new waitgroup to set as // the preload waiter. var wg sync.WaitGroup wg.Add(1) - defer wg.Done() // Wrap waitgroup in // 'any' for pointer. - new := any(&wg) - ptr := &new + a := any(&wg) + ptr := &a // Attempt CAS operation to claim start. - started := p.p.CompareAndSwap(old, ptr) + started = p.p.CompareAndSwap(old, ptr) if !started { - return false + return false, nil } - // Start. - preload(ptr) - return true -} + defer func() { + // Release. + wg.Done() -// done marks state as preloaded, -// i.e. no more preload required. -func (p *preloader) Done(ptr *any) { - if !p.p.CompareAndSwap(ptr, new(any)) { - log.Errorf(nil, "BUG: invalid preloader state: %#v", (*p.p.Load())) - } + var ok bool + if err != nil { + // Preload failed, + // drop waiter ptr. + a := any(false) + ok = p.p.CompareAndSwap(ptr, &a) + } else { + // Preload success, set success value. + ok = p.p.CompareAndSwap(ptr, new(any)) + } + + if !ok { + log.Errorf(nil, "BUG: invalid preloader state: %#v", (*p.p.Load())) + } + }() + + // Perform preload. + err = preload() + return } // clear will clear the state, marking a "preload" as required. // i.e. next call to Check() will call provided preload func. func (p *preloader) Clear() { - b := false - a := any(b) + a := any(false) for { // Load current ptr. ptr := p.p.Load() diff --git a/internal/cache/timeline/status.go b/internal/cache/timeline/status.go index c0c394042..59c339c91 100644 --- a/internal/cache/timeline/status.go +++ b/internal/cache/timeline/status.go @@ -196,14 +196,9 @@ func (t *StatusTimeline) Preload( n int, err error, ) { - t.preloader.CheckPreload(func(ptr *any) { + err = t.preloader.CheckPreload(func() error { n, err = t.preload(loadPage, filter) - if err != nil { - return - } - - // Mark as preloaded. - t.preloader.Done(ptr) + return err }) return } diff --git a/internal/cache/timeline/status_test.go b/internal/cache/timeline/status_test.go index fc7e43da8..6d513032a 100644 --- a/internal/cache/timeline/status_test.go +++ b/internal/cache/timeline/status_test.go @@ -19,9 +19,12 @@ package timeline import ( "context" + "errors" "fmt" "slices" + "sync/atomic" "testing" + "time" apimodel "code.superseriousbusiness.org/gotosocial/internal/api/model" "code.superseriousbusiness.org/gotosocial/internal/gtsmodel" @@ -65,6 +68,87 @@ var testStatusMeta = []*StatusMeta{ }, } +func TestStatusTimelinePreloader(t *testing.T) { + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + var tt StatusTimeline + tt.Init(1000) + + // Start goroutine to add some + // concurrent usage to preloader. + var started atomic.Int32 + go func() { + for { + select { + case <-ctx.Done(): + return + default: + } + tt.preloader.Check() + started.Add(1) + } + }() + + // Wait until goroutine running. + for started.Load() == 0 { + time.Sleep(time.Millisecond) + } + + // Variable to check whether + // our hook funcs are called. + var called bool + reset := func() { called = false } + + // "no error" preloader hook. + preloadNoErr := func() error { + called = true + return nil + } + + // "error return" preloader hook. + preloadErr := func() error { + called = true + return errors.New("oh no") + } + + // Check that on fail does not mark as preloaded. + err := tt.preloader.CheckPreload(preloadErr) + assert.Error(t, err) + assert.False(t, tt.preloader.Check()) + assert.True(t, called) + reset() + + // Check that on success marks itself as preloaded. + err = tt.preloader.CheckPreload(preloadNoErr) + assert.NoError(t, err) + assert.True(t, tt.preloader.Check()) + assert.True(t, called) + reset() + + // Check that preload func not called again + // if it's already in the 'preloaded' state. + err = tt.preloader.CheckPreload(preloadErr) + assert.NoError(t, err) + assert.True(t, tt.preloader.Check()) + assert.False(t, called) + reset() + + // Ensure that a clear operation + // successfully unsets preloader. + tt.preloader.Clear() + assert.False(t, tt.preloader.Check()) + assert.False(t, called) + reset() + + // Ensure that it can be marked as preloaded again. + err = tt.preloader.CheckPreload(preloadNoErr) + assert.NoError(t, err) + assert.True(t, tt.preloader.Check()) + assert.True(t, called) + reset() +} + func TestStatusTimelineLoadLimit(t *testing.T) { var tt StatusTimeline tt.Init(1000) @@ -80,7 +164,7 @@ func TestStatusTimelineLoadLimit(t *testing.T) { _ = tt.cache.Insert(data...) // Manually mark timeline as 'preloaded'. - tt.preloader.CheckPreload(tt.preloader.Done) + tt.preloader.CheckPreload(func() error { return nil }) // Craft a new page for selection, // setting placeholder min / max values @@ -251,7 +335,7 @@ func TestStatusTimelineInserts(t *testing.T) { assert.Equal(t, maxID, maxStatus(&tt).ID) // Manually mark timeline as 'preloaded'. - tt.preloader.CheckPreload(tt.preloader.Done) + tt.preloader.CheckPreload(func() error { return nil }) // Specifically craft a boost of latest (i.e. max) status in timeline. boost := >smodel.Status{ID: "06B1A00PQWDZZH9WK9P5VND35C", BoostOfID: maxID}