refactor: store cache

This commit is contained in:
Steven
2025-05-27 22:06:41 +08:00
parent c23aebd648
commit ad2c5f0d05
14 changed files with 889 additions and 115 deletions

121
server/profiler/profiler.go Normal file
View File

@@ -0,0 +1,121 @@
package profiler
import (
"context"
"fmt"
"log/slog"
"net/http"
"net/http/pprof"
"runtime"
"time"
"github.com/labstack/echo/v4"
)
// Profiler provides HTTP endpoints for memory profiling
type Profiler struct {
memStatsLogInterval time.Duration
}
// NewProfiler creates a new profiler
func NewProfiler() *Profiler {
return &Profiler{
memStatsLogInterval: 1 * time.Minute,
}
}
// RegisterRoutes adds profiling endpoints to the Echo server
func (p *Profiler) RegisterRoutes(e *echo.Echo) {
// Register pprof handlers
g := e.Group("/debug/pprof")
g.GET("", echo.WrapHandler(http.HandlerFunc(pprof.Index)))
g.GET("/cmdline", echo.WrapHandler(http.HandlerFunc(pprof.Cmdline)))
g.GET("/profile", echo.WrapHandler(http.HandlerFunc(pprof.Profile)))
g.POST("/symbol", echo.WrapHandler(http.HandlerFunc(pprof.Symbol)))
g.GET("/symbol", echo.WrapHandler(http.HandlerFunc(pprof.Symbol)))
g.GET("/trace", echo.WrapHandler(http.HandlerFunc(pprof.Trace)))
g.GET("/allocs", echo.WrapHandler(http.HandlerFunc(pprof.Handler("allocs").ServeHTTP)))
g.GET("/block", echo.WrapHandler(http.HandlerFunc(pprof.Handler("block").ServeHTTP)))
g.GET("/goroutine", echo.WrapHandler(http.HandlerFunc(pprof.Handler("goroutine").ServeHTTP)))
g.GET("/heap", echo.WrapHandler(http.HandlerFunc(pprof.Handler("heap").ServeHTTP)))
g.GET("/mutex", echo.WrapHandler(http.HandlerFunc(pprof.Handler("mutex").ServeHTTP)))
g.GET("/threadcreate", echo.WrapHandler(http.HandlerFunc(pprof.Handler("threadcreate").ServeHTTP)))
// Add a custom memory stats endpoint
g.GET("/memstats", func(c echo.Context) error {
var m runtime.MemStats
runtime.ReadMemStats(&m)
return c.JSON(http.StatusOK, map[string]interface{}{
"alloc": m.Alloc,
"totalAlloc": m.TotalAlloc,
"sys": m.Sys,
"numGC": m.NumGC,
"heapAlloc": m.HeapAlloc,
"heapSys": m.HeapSys,
"heapInuse": m.HeapInuse,
"heapObjects": m.HeapObjects,
})
})
}
// StartMemoryMonitor starts a goroutine that periodically logs memory stats
func (p *Profiler) StartMemoryMonitor(ctx context.Context) {
go func() {
ticker := time.NewTicker(p.memStatsLogInterval)
defer ticker.Stop()
// Store previous heap allocation to track growth
var lastHeapAlloc uint64
var lastNumGC uint32
for {
select {
case <-ticker.C:
var m runtime.MemStats
runtime.ReadMemStats(&m)
// Calculate heap growth since last check
heapGrowth := int64(m.HeapAlloc) - int64(lastHeapAlloc)
gcCount := m.NumGC - lastNumGC
slog.Info("memory stats",
"heapAlloc", byteCountIEC(m.HeapAlloc),
"heapSys", byteCountIEC(m.HeapSys),
"heapObjects", m.HeapObjects,
"heapGrowth", byteCountIEC(uint64(heapGrowth)),
"numGoroutine", runtime.NumGoroutine(),
"numGC", m.NumGC,
"gcSince", gcCount,
"nextGC", byteCountIEC(m.NextGC),
"gcPause", time.Duration(m.PauseNs[(m.NumGC+255)%256]).String(),
)
// Track values for next iteration
lastHeapAlloc = m.HeapAlloc
lastNumGC = m.NumGC
// Force GC if memory usage is high to see if objects can be reclaimed
if m.HeapAlloc > 500*1024*1024 { // 500 MB threshold
slog.Info("forcing garbage collection due to high memory usage")
runtime.GC()
}
case <-ctx.Done():
return
}
}
}()
}
// byteCountIEC converts bytes to a human-readable string (MiB, GiB)
func byteCountIEC(b uint64) string {
const unit = 1024
if b < unit {
return fmt.Sprintf("%d B", b)
}
div, exp := uint64(unit), 0
for n := b / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB", float64(b)/float64(div), "KMGTPE"[exp])
}

View File

@@ -26,6 +26,7 @@ import (
"github.com/usememos/memos/plugin/storage/s3"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/store"
)
@@ -71,7 +72,7 @@ func (s *APIV1Service) CreateResource(ctx context.Context, request *v1pb.CreateR
}
create.Size = int64(size)
create.Blob = request.Resource.Content
if err := SaveResourceBlob(ctx, s.Store, create); err != nil {
if err := SaveResourceBlob(ctx, s.Profile, s.Store, create); err != nil {
return nil, status.Errorf(codes.Internal, "failed to save resource blob: %v", err)
}
@@ -286,8 +287,8 @@ func (s *APIV1Service) convertResourceFromStore(ctx context.Context, resource *s
}
// SaveResourceBlob save the blob of resource based on the storage config.
func SaveResourceBlob(ctx context.Context, s *store.Store, create *store.Resource) error {
workspaceStorageSetting, err := s.GetWorkspaceStorageSetting(ctx)
func SaveResourceBlob(ctx context.Context, profile *profile.Profile, stores *store.Store, create *store.Resource) error {
workspaceStorageSetting, err := stores.GetWorkspaceStorageSetting(ctx)
if err != nil {
return errors.Wrap(err, "Failed to find workspace storage setting")
}
@@ -308,7 +309,7 @@ func SaveResourceBlob(ctx context.Context, s *store.Store, create *store.Resourc
// Ensure the directory exists.
osPath := filepath.FromSlash(internalPath)
if !filepath.IsAbs(osPath) {
osPath = filepath.Join(s.Profile.Data, osPath)
osPath = filepath.Join(profile.Data, osPath)
}
dir := filepath.Dir(osPath)
if err = os.MkdirAll(dir, os.ModePerm); err != nil {

View File

@@ -3,6 +3,7 @@ package memopayload
import (
"context"
"log/slog"
"runtime"
"slices"
"github.com/pkg/errors"
@@ -26,23 +27,52 @@ func NewRunner(store *store.Store) *Runner {
// RunOnce rebuilds the payload of all memos.
func (r *Runner) RunOnce(ctx context.Context) {
memos, err := r.Store.ListMemos(ctx, &store.FindMemo{})
if err != nil {
slog.Error("failed to list memos", "err", err)
return
}
// Process memos in batches to avoid loading all memos into memory at once
const batchSize = 100
offset := 0
processed := 0
for _, memo := range memos {
if err := RebuildMemoPayload(memo); err != nil {
slog.Error("failed to rebuild memo payload", "err", err)
continue
for {
limit := batchSize
memos, err := r.Store.ListMemos(ctx, &store.FindMemo{
Limit: &limit,
Offset: &offset,
})
if err != nil {
slog.Error("failed to list memos", "err", err)
return
}
if err := r.Store.UpdateMemo(ctx, &store.UpdateMemo{
ID: memo.ID,
Payload: memo.Payload,
}); err != nil {
slog.Error("failed to update memo", "err", err)
// Break if no more memos
if len(memos) == 0 {
break
}
// Process batch
batchSuccessCount := 0
for _, memo := range memos {
if err := RebuildMemoPayload(memo); err != nil {
slog.Error("failed to rebuild memo payload", "err", err, "memoID", memo.ID)
continue
}
if err := r.Store.UpdateMemo(ctx, &store.UpdateMemo{
ID: memo.ID,
Payload: memo.Payload,
}); err != nil {
slog.Error("failed to update memo", "err", err, "memoID", memo.ID)
continue
}
batchSuccessCount++
}
processed += len(memos)
slog.Info("Processed memo batch", "batchSize", len(memos), "successCount", batchSuccessCount, "totalProcessed", processed)
// Move to next batch
offset += len(memos)
// Force garbage collection between batches to prevent memory accumulation
runtime.GC()
}
}

View File

@@ -3,6 +3,7 @@ package s3presign
import (
"context"
"log/slog"
"runtime"
"time"
"google.golang.org/protobuf/types/known/timestamppb"
@@ -50,59 +51,88 @@ func (r *Runner) CheckAndPresign(ctx context.Context) {
}
s3StorageType := storepb.ResourceStorageType_S3
resources, err := r.Store.ListResources(ctx, &store.FindResource{
GetBlob: false,
StorageType: &s3StorageType,
})
if err != nil {
return
}
// Limit resources to a reasonable batch size
const batchSize = 100
offset := 0
for _, resource := range resources {
s3ObjectPayload := resource.Payload.GetS3Object()
if s3ObjectPayload == nil {
continue
for {
limit := batchSize
resources, err := r.Store.ListResources(ctx, &store.FindResource{
GetBlob: false,
StorageType: &s3StorageType,
Limit: &limit,
Offset: &offset,
})
if err != nil {
slog.Error("Failed to list resources for presigning", "error", err)
return
}
if s3ObjectPayload.LastPresignedTime != nil {
// Skip if the presigned URL is still valid for the next 4 days.
// The expiration time is set to 5 days.
if time.Now().Before(s3ObjectPayload.LastPresignedTime.AsTime().Add(4 * 24 * time.Hour)) {
// Break if no more resources
if len(resources) == 0 {
break
}
// Process batch of resources
presignCount := 0
for _, resource := range resources {
s3ObjectPayload := resource.Payload.GetS3Object()
if s3ObjectPayload == nil {
continue
}
}
s3Config := workspaceStorageSetting.GetS3Config()
if s3ObjectPayload.S3Config != nil {
s3Config = s3ObjectPayload.S3Config
}
if s3Config == nil {
slog.Error("S3 config is not found")
continue
}
if s3ObjectPayload.LastPresignedTime != nil {
// Skip if the presigned URL is still valid for the next 4 days.
// The expiration time is set to 5 days.
if time.Now().Before(s3ObjectPayload.LastPresignedTime.AsTime().Add(4 * 24 * time.Hour)) {
continue
}
}
s3Client, err := s3.NewClient(ctx, s3Config)
if err != nil {
slog.Error("Failed to create S3 client", "error", err)
continue
}
s3Config := workspaceStorageSetting.GetS3Config()
if s3ObjectPayload.S3Config != nil {
s3Config = s3ObjectPayload.S3Config
}
if s3Config == nil {
slog.Error("S3 config is not found")
continue
}
presignURL, err := s3Client.PresignGetObject(ctx, s3ObjectPayload.Key)
if err != nil {
return
}
s3ObjectPayload.S3Config = s3Config
s3ObjectPayload.LastPresignedTime = timestamppb.New(time.Now())
if err := r.Store.UpdateResource(ctx, &store.UpdateResource{
ID: resource.ID,
Reference: &presignURL,
Payload: &storepb.ResourcePayload{
Payload: &storepb.ResourcePayload_S3Object_{
S3Object: s3ObjectPayload,
s3Client, err := s3.NewClient(ctx, s3Config)
if err != nil {
slog.Error("Failed to create S3 client", "error", err)
continue
}
presignURL, err := s3Client.PresignGetObject(ctx, s3ObjectPayload.Key)
if err != nil {
slog.Error("Failed to presign URL", "error", err, "resourceID", resource.ID)
continue
}
s3ObjectPayload.S3Config = s3Config
s3ObjectPayload.LastPresignedTime = timestamppb.New(time.Now())
if err := r.Store.UpdateResource(ctx, &store.UpdateResource{
ID: resource.ID,
Reference: &presignURL,
Payload: &storepb.ResourcePayload{
Payload: &storepb.ResourcePayload_S3Object_{
S3Object: s3ObjectPayload,
},
},
},
}); err != nil {
return
}); err != nil {
slog.Error("Failed to update resource", "error", err, "resourceID", resource.ID)
continue
}
presignCount++
}
slog.Info("Presigned batch of S3 resources", "batchSize", len(resources), "presigned", presignCount)
// Move to next batch
offset += len(resources)
// Prevent memory accumulation between batches
runtime.GC()
}
}

View File

@@ -7,6 +7,7 @@ import (
"math"
"net"
"net/http"
"runtime"
"time"
"github.com/google/uuid"
@@ -19,6 +20,7 @@ import (
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/server/profiler"
apiv1 "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/server/router/frontend"
"github.com/usememos/memos/server/router/rss"
@@ -32,8 +34,10 @@ type Server struct {
Profile *profile.Profile
Store *store.Store
echoServer *echo.Echo
grpcServer *grpc.Server
echoServer *echo.Echo
grpcServer *grpc.Server
profiler *profiler.Profiler
runnerCancelFuncs []context.CancelFunc
}
func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store) (*Server, error) {
@@ -49,6 +53,11 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
echoServer.Use(middleware.Recover())
s.echoServer = echoServer
// Initialize profiler
s.profiler = profiler.NewProfiler()
s.profiler.RegisterRoutes(echoServer)
s.profiler.StartMemoryMonitor(ctx)
workspaceBasicSetting, err := s.getOrUpsertWorkspaceBasicSetting(ctx)
if err != nil {
return nil, errors.Wrap(err, "failed to get workspace basic setting")
@@ -134,6 +143,15 @@ func (s *Server) Shutdown(ctx context.Context) {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
slog.Info("server shutting down")
// Cancel all background runners
for _, cancelFunc := range s.runnerCancelFuncs {
if cancelFunc != nil {
cancelFunc()
}
}
// Shutdown echo server.
if err := s.echoServer.Shutdown(ctx); err != nil {
slog.Error("failed to shutdown server", slog.String("error", err.Error()))
@@ -142,6 +160,23 @@ func (s *Server) Shutdown(ctx context.Context) {
// Shutdown gRPC server.
s.grpcServer.GracefulStop()
// Stop the profiler
if s.profiler != nil {
slog.Info("stopping profiler")
// Force one last garbage collection to clean up remaining objects
runtime.GC()
// Log final memory stats
var m runtime.MemStats
runtime.ReadMemStats(&m)
slog.Info("final memory stats before exit",
"heapAlloc", m.Alloc,
"heapSys", m.Sys,
"heapObjects", m.HeapObjects,
"numGoroutine", runtime.NumGoroutine(),
)
}
// Close database connection.
if err := s.Store.Close(); err != nil {
slog.Error("failed to close database", slog.String("error", err.Error()))
@@ -151,13 +186,30 @@ func (s *Server) Shutdown(ctx context.Context) {
}
func (s *Server) StartBackgroundRunners(ctx context.Context) {
// Create a separate context for each background runner
// This allows us to control cancellation for each runner independently
s3Context, s3Cancel := context.WithCancel(ctx)
// Store the cancel function so we can properly shut down runners
s.runnerCancelFuncs = append(s.runnerCancelFuncs, s3Cancel)
// Create and start S3 presign runner
s3presignRunner := s3presign.NewRunner(s.Store)
s3presignRunner.RunOnce(ctx)
// Create and start memo payload runner just once
memopayloadRunner := memopayload.NewRunner(s.Store)
// Rebuild all memos' payload after server starts.
memopayloadRunner.RunOnce(ctx)
go s3presignRunner.Run(ctx)
// Start continuous S3 presign runner
go func() {
s3presignRunner.Run(s3Context)
slog.Info("s3presign runner stopped")
}()
// Log the number of goroutines running
slog.Info("background runners started", "goroutines", runtime.NumGoroutine())
}
func (s *Server) getOrUpsertWorkspaceBasicSetting(ctx context.Context) (*storepb.WorkspaceBasicSetting, error) {

311
store/cache/cache.go vendored Normal file
View File

@@ -0,0 +1,311 @@
package cache
import (
"context"
"sync"
"sync/atomic"
"time"
)
// Interface defines the operations a cache must support
type Interface interface {
// Set adds a value to the cache with the default TTL
Set(ctx context.Context, key string, value interface{})
// SetWithTTL adds a value to the cache with a custom TTL
SetWithTTL(ctx context.Context, key string, value interface{}, ttl time.Duration)
// Get retrieves a value from the cache
Get(ctx context.Context, key string) (interface{}, bool)
// Delete removes a value from the cache
Delete(ctx context.Context, key string)
// Clear removes all values from the cache
Clear(ctx context.Context)
// Size returns the number of items in the cache
Size() int64
// Close stops all background tasks and releases resources
Close() error
}
// item represents a cached value with metadata
type item struct {
value interface{}
expiration time.Time
size int // Approximate size in bytes
}
// Config contains options for configuring a cache
type Config struct {
// DefaultTTL is the default time-to-live for cache entries
DefaultTTL time.Duration
// CleanupInterval is how often the cache runs cleanup
CleanupInterval time.Duration
// MaxItems is the maximum number of items allowed in the cache
MaxItems int
// OnEviction is called when an item is evicted from the cache
OnEviction func(key string, value interface{})
}
// DefaultConfig returns a default configuration for the cache
func DefaultConfig() Config {
return Config{
DefaultTTL: 10 * time.Minute,
CleanupInterval: 5 * time.Minute,
MaxItems: 1000,
OnEviction: nil,
}
}
// Cache is a thread-safe in-memory cache with TTL and memory management
type Cache struct {
data sync.Map
config Config
itemCount int64 // Use atomic operations to track item count
stopChan chan struct{}
closedChan chan struct{}
}
// New creates a new memory cache with the given configuration
func New(config Config) *Cache {
c := &Cache{
config: config,
stopChan: make(chan struct{}),
closedChan: make(chan struct{}),
}
go c.cleanupLoop()
return c
}
// NewDefault creates a new memory cache with default configuration
func NewDefault() *Cache {
return New(DefaultConfig())
}
// Set adds a value to the cache with the default TTL
func (c *Cache) Set(ctx context.Context, key string, value interface{}) {
c.SetWithTTL(ctx, key, value, c.config.DefaultTTL)
}
// SetWithTTL adds a value to the cache with a custom TTL
func (c *Cache) SetWithTTL(ctx context.Context, key string, value interface{}, ttl time.Duration) {
// Estimate size of the item (very rough approximation)
size := estimateSize(value)
// Check if item already exists to avoid double counting
if _, exists := c.data.Load(key); exists {
c.data.Delete(key)
// Don't decrement count - we'll replace it
} else {
// Only increment if this is a new key
atomic.AddInt64(&c.itemCount, 1)
}
c.data.Store(key, item{
value: value,
expiration: time.Now().Add(ttl),
size: size,
})
// If we're over the max items, clean up old items
if c.config.MaxItems > 0 && atomic.LoadInt64(&c.itemCount) > int64(c.config.MaxItems) {
c.cleanupOldest()
}
}
// Get retrieves a value from the cache
func (c *Cache) Get(ctx context.Context, key string) (interface{}, bool) {
value, ok := c.data.Load(key)
if !ok {
return nil, false
}
itm := value.(item)
if time.Now().After(itm.expiration) {
c.data.Delete(key)
atomic.AddInt64(&c.itemCount, -1)
if c.config.OnEviction != nil {
c.config.OnEviction(key, itm.value)
}
return nil, false
}
return itm.value, true
}
// Delete removes a value from the cache
func (c *Cache) Delete(ctx context.Context, key string) {
if value, loaded := c.data.LoadAndDelete(key); loaded {
atomic.AddInt64(&c.itemCount, -1)
if c.config.OnEviction != nil {
itm := value.(item)
c.config.OnEviction(key, itm.value)
}
}
}
// Clear removes all values from the cache
func (c *Cache) Clear(ctx context.Context) {
if c.config.OnEviction != nil {
c.data.Range(func(key, value interface{}) bool {
itm := value.(item)
c.config.OnEviction(key.(string), itm.value)
return true
})
}
c.data = sync.Map{}
atomic.StoreInt64(&c.itemCount, 0)
}
// Size returns the number of items in the cache
func (c *Cache) Size() int64 {
return atomic.LoadInt64(&c.itemCount)
}
// Close stops the cache cleanup goroutine
func (c *Cache) Close() error {
select {
case <-c.stopChan:
// Already closed
return nil
default:
close(c.stopChan)
<-c.closedChan // Wait for cleanup goroutine to exit
return nil
}
}
// cleanupLoop periodically cleans up expired items
func (c *Cache) cleanupLoop() {
ticker := time.NewTicker(c.config.CleanupInterval)
defer func() {
ticker.Stop()
close(c.closedChan)
}()
for {
select {
case <-ticker.C:
c.cleanup()
case <-c.stopChan:
return
}
}
}
// cleanup removes expired items
func (c *Cache) cleanup() {
evicted := make(map[string]interface{})
count := 0
c.data.Range(func(key, value interface{}) bool {
itm := value.(item)
if time.Now().After(itm.expiration) {
c.data.Delete(key)
count++
if c.config.OnEviction != nil {
evicted[key.(string)] = itm.value
}
}
return true
})
if count > 0 {
atomic.AddInt64(&c.itemCount, -int64(count))
// Call eviction callbacks outside the loop to avoid blocking the range
if c.config.OnEviction != nil {
for k, v := range evicted {
c.config.OnEviction(k, v)
}
}
}
}
// cleanupOldest removes the oldest items if we're over the max items
func (c *Cache) cleanupOldest() {
threshold := c.config.MaxItems / 5 // Remove 20% of max items at once
if threshold < 1 {
threshold = 1
}
currentCount := atomic.LoadInt64(&c.itemCount)
// If we're not over the threshold, don't do anything
if currentCount <= int64(c.config.MaxItems) {
return
}
// Find the oldest items
type keyExpPair struct {
key string
value interface{}
expiration time.Time
}
candidates := make([]keyExpPair, 0, threshold)
c.data.Range(func(key, value interface{}) bool {
itm := value.(item)
if len(candidates) < threshold {
candidates = append(candidates, keyExpPair{key.(string), itm.value, itm.expiration})
return true
}
// Find the newest item in candidates
newestIdx := 0
for i := 1; i < len(candidates); i++ {
if candidates[i].expiration.After(candidates[newestIdx].expiration) {
newestIdx = i
}
}
// Replace it if this item is older
if itm.expiration.Before(candidates[newestIdx].expiration) {
candidates[newestIdx] = keyExpPair{key.(string), itm.value, itm.expiration}
}
return true
})
// Delete the oldest items
deletedCount := 0
for _, candidate := range candidates {
c.data.Delete(candidate.key)
deletedCount++
if c.config.OnEviction != nil {
c.config.OnEviction(candidate.key, candidate.value)
}
}
// Update count
if deletedCount > 0 {
atomic.AddInt64(&c.itemCount, -int64(deletedCount))
}
}
// estimateSize attempts to estimate the memory footprint of a value
func estimateSize(value interface{}) int {
switch v := value.(type) {
case string:
return len(v) + 24 // base size + string overhead
case []byte:
return len(v) + 24 // base size + slice overhead
case map[string]interface{}:
return len(v) * 64 // rough estimate
default:
return 64 // default conservative estimate
}
}

209
store/cache/cache_test.go vendored Normal file
View File

@@ -0,0 +1,209 @@
package cache
import (
"context"
"fmt"
"sync"
"testing"
"time"
)
func TestCacheBasicOperations(t *testing.T) {
ctx := context.Background()
config := DefaultConfig()
config.DefaultTTL = 100 * time.Millisecond
config.CleanupInterval = 50 * time.Millisecond
cache := New(config)
defer cache.Close()
// Test Set and Get
cache.Set(ctx, "key1", "value1")
if val, ok := cache.Get(ctx, "key1"); !ok || val != "value1" {
t.Errorf("Expected 'value1', got %v, exists: %v", val, ok)
}
// Test SetWithTTL
cache.SetWithTTL(ctx, "key2", "value2", 200*time.Millisecond)
if val, ok := cache.Get(ctx, "key2"); !ok || val != "value2" {
t.Errorf("Expected 'value2', got %v, exists: %v", val, ok)
}
// Test Delete
cache.Delete(ctx, "key1")
if _, ok := cache.Get(ctx, "key1"); ok {
t.Errorf("Key 'key1' should have been deleted")
}
// Test automatic expiration
time.Sleep(150 * time.Millisecond)
if _, ok := cache.Get(ctx, "key1"); ok {
t.Errorf("Key 'key1' should have expired")
}
// key2 should still be valid (200ms TTL)
if _, ok := cache.Get(ctx, "key2"); !ok {
t.Errorf("Key 'key2' should still be valid")
}
// Wait for key2 to expire
time.Sleep(100 * time.Millisecond)
if _, ok := cache.Get(ctx, "key2"); ok {
t.Errorf("Key 'key2' should have expired")
}
// Test Clear
cache.Set(ctx, "key3", "value3")
cache.Clear(ctx)
if _, ok := cache.Get(ctx, "key3"); ok {
t.Errorf("Cache should be empty after Clear()")
}
}
func TestCacheEviction(t *testing.T) {
ctx := context.Background()
config := DefaultConfig()
config.MaxItems = 5
cache := New(config)
defer cache.Close()
// Add 5 items (max capacity)
for i := 0; i < 5; i++ {
key := fmt.Sprintf("key%d", i)
cache.Set(ctx, key, i)
}
// Verify all 5 items are in the cache
for i := 0; i < 5; i++ {
key := fmt.Sprintf("key%d", i)
if _, ok := cache.Get(ctx, key); !ok {
t.Errorf("Key '%s' should be in the cache", key)
}
}
// Add 2 more items to trigger eviction
cache.Set(ctx, "keyA", "valueA")
cache.Set(ctx, "keyB", "valueB")
// Verify size is still within limits
if cache.Size() > int64(config.MaxItems) {
t.Errorf("Cache size %d exceeds limit %d", cache.Size(), config.MaxItems)
}
// Some of the original keys should have been evicted
evictedCount := 0
for i := 0; i < 5; i++ {
key := fmt.Sprintf("key%d", i)
if _, ok := cache.Get(ctx, key); !ok {
evictedCount++
}
}
if evictedCount == 0 {
t.Errorf("No keys were evicted despite exceeding max items")
}
// The newer keys should still be present
if _, ok := cache.Get(ctx, "keyA"); !ok {
t.Errorf("Key 'keyA' should be in the cache")
}
if _, ok := cache.Get(ctx, "keyB"); !ok {
t.Errorf("Key 'keyB' should be in the cache")
}
}
func TestCacheConcurrency(t *testing.T) {
ctx := context.Background()
cache := NewDefault()
defer cache.Close()
const goroutines = 10
const operationsPerGoroutine = 100
var wg sync.WaitGroup
wg.Add(goroutines)
for i := 0; i < goroutines; i++ {
go func(id int) {
defer wg.Done()
baseKey := fmt.Sprintf("worker%d-", id)
// Set operations
for j := 0; j < operationsPerGoroutine; j++ {
key := fmt.Sprintf("%skey%d", baseKey, j)
value := fmt.Sprintf("value%d-%d", id, j)
cache.Set(ctx, key, value)
}
// Get operations
for j := 0; j < operationsPerGoroutine; j++ {
key := fmt.Sprintf("%skey%d", baseKey, j)
val, ok := cache.Get(ctx, key)
if !ok {
t.Errorf("Key '%s' should exist in cache", key)
continue
}
expected := fmt.Sprintf("value%d-%d", id, j)
if val != expected {
t.Errorf("For key '%s', expected '%s', got '%s'", key, expected, val)
}
}
// Delete half the keys
for j := 0; j < operationsPerGoroutine/2; j++ {
key := fmt.Sprintf("%skey%d", baseKey, j)
cache.Delete(ctx, key)
}
}(i)
}
wg.Wait()
// Verify size and deletion
var totalKeysExpected int64 = goroutines * operationsPerGoroutine / 2
if cache.Size() != totalKeysExpected {
t.Errorf("Expected cache size to be %d, got %d", totalKeysExpected, cache.Size())
}
}
func TestEvictionCallback(t *testing.T) {
ctx := context.Background()
evicted := make(map[string]interface{})
evictedMu := sync.Mutex{}
config := DefaultConfig()
config.DefaultTTL = 50 * time.Millisecond
config.CleanupInterval = 25 * time.Millisecond
config.OnEviction = func(key string, value interface{}) {
evictedMu.Lock()
evicted[key] = value
evictedMu.Unlock()
}
cache := New(config)
defer cache.Close()
// Add items
cache.Set(ctx, "key1", "value1")
cache.Set(ctx, "key2", "value2")
// Manually delete
cache.Delete(ctx, "key1")
// Verify manual deletion triggered callback
time.Sleep(10 * time.Millisecond) // Small delay to ensure callback processed
evictedMu.Lock()
if evicted["key1"] != "value1" {
t.Errorf("Eviction callback not triggered for manual deletion")
}
evictedMu.Unlock()
// Wait for automatic expiration
time.Sleep(60 * time.Millisecond)
// Verify TTL expiration triggered callback
evictedMu.Lock()
if evicted["key2"] != "value2" {
t.Errorf("Eviction callback not triggered for TTL expiration")
}
evictedMu.Unlock()
}

View File

@@ -46,7 +46,6 @@ func (s *Store) CreateIdentityProvider(ctx context.Context, create *storepb.Iden
if err != nil {
return nil, err
}
s.idpCache.Store(identityProvider.Id, identityProvider)
return identityProvider, nil
}
@@ -63,21 +62,11 @@ func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityPro
return nil, err
}
identityProviders = append(identityProviders, identityProvider)
s.idpCache.Store(identityProvider.Id, identityProvider)
}
return identityProviders, nil
}
func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*storepb.IdentityProvider, error) {
if find.ID != nil {
if cache, ok := s.idpCache.Load(*find.ID); ok {
identityProvider, ok := cache.(*storepb.IdentityProvider)
if ok {
return identityProvider, nil
}
}
}
list, err := s.ListIdentityProviders(ctx, find)
if err != nil {
return nil, err
@@ -127,7 +116,6 @@ func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdenti
if err != nil {
return nil, err
}
s.idpCache.Store(identityProvider.Id, identityProvider)
return identityProvider, nil
}
@@ -136,8 +124,6 @@ func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdenti
if err != nil {
return err
}
s.idpCache.Delete(delete.ID)
return nil
}

View File

@@ -39,7 +39,7 @@ func (s *Store) Migrate(ctx context.Context) error {
return errors.Wrap(err, "failed to pre-migrate")
}
if s.Profile.Mode == "prod" {
if s.profile.Mode == "prod" {
migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{})
if err != nil {
return errors.Wrap(err, "failed to find migration history")
@@ -107,7 +107,7 @@ func (s *Store) Migrate(ctx context.Context) error {
return errors.Wrap(err, "failed to update current schema version")
}
}
} else if s.Profile.Mode == "demo" {
} else if s.profile.Mode == "demo" {
// In demo mode, we should seed the database.
if err := s.seed(ctx); err != nil {
return errors.Wrap(err, "failed to seed")
@@ -157,7 +157,7 @@ func (s *Store) preMigrate(ctx context.Context) error {
return errors.Wrap(err, "failed to update current schema version")
}
}
if s.Profile.Mode == "prod" {
if s.profile.Mode == "prod" {
if err := s.normalizedMigrationHistoryList(ctx); err != nil {
return errors.Wrap(err, "failed to normalize migration history list")
}
@@ -166,16 +166,16 @@ func (s *Store) preMigrate(ctx context.Context) error {
}
func (s *Store) getMigrationBasePath() string {
return fmt.Sprintf("migration/%s/", s.Profile.Driver)
return fmt.Sprintf("migration/%s/", s.profile.Driver)
}
func (s *Store) getSeedBasePath() string {
return fmt.Sprintf("seed/%s/", s.Profile.Driver)
return fmt.Sprintf("seed/%s/", s.profile.Driver)
}
func (s *Store) seed(ctx context.Context) error {
// Only seed for SQLite.
if s.Profile.Driver != "sqlite" {
if s.profile.Driver != "sqlite" {
slog.Warn("seed is only supported for SQLite")
return nil
}
@@ -207,7 +207,7 @@ func (s *Store) seed(ctx context.Context) error {
}
func (s *Store) GetCurrentSchemaVersion() (string, error) {
currentVersion := version.GetCurrentVersion(s.Profile.Mode)
currentVersion := version.GetCurrentVersion(s.profile.Mode)
minorVersion := version.GetMinorVersion(currentVersion)
filePaths, err := fs.Glob(migrationFS, fmt.Sprintf("%s%s/*.sql", s.getMigrationBasePath(), minorVersion))
if err != nil {

View File

@@ -74,6 +74,17 @@ func (s *Store) CreateResource(ctx context.Context, create *Resource) (*Resource
}
func (s *Store) ListResources(ctx context.Context, find *FindResource) ([]*Resource, error) {
// Set default limits to prevent loading too many resources at once
if find.Limit == nil && find.GetBlob {
// When fetching blobs, we should be especially careful with limits
defaultLimit := 10
find.Limit = &defaultLimit
} else if find.Limit == nil {
// Even without blobs, let's default to a reasonable limit
defaultLimit := 100
find.Limit = &defaultLimit
}
return s.driver.ListResources(ctx, find)
}
@@ -110,7 +121,7 @@ func (s *Store) DeleteResource(ctx context.Context, delete *DeleteResource) erro
if err := func() error {
p := filepath.FromSlash(resource.Reference)
if !filepath.IsAbs(p) {
p = filepath.Join(s.Profile.Data, p)
p = filepath.Join(s.profile.Data, p)
}
err := os.Remove(p)
if err != nil {

View File

@@ -1,27 +1,46 @@
package store
import (
"sync"
"time"
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/store/cache"
)
// Store provides database access to all raw objects.
type Store struct {
Profile *profile.Profile
driver Driver
workspaceSettingCache sync.Map // map[string]*storepb.WorkspaceSetting
userCache sync.Map // map[int]*User
userSettingCache sync.Map // map[string]*storepb.UserSetting
idpCache sync.Map // map[int]*storepb.IdentityProvider
profile *profile.Profile
driver Driver
// Cache settings
cacheConfig cache.Config
// Caches
workspaceSettingCache *cache.Cache // cache for workspace settings
userCache *cache.Cache // cache for users
userSettingCache *cache.Cache // cache for user settings
}
// New creates a new instance of Store.
func New(driver Driver, profile *profile.Profile) *Store {
return &Store{
driver: driver,
Profile: profile,
// Default cache settings
cacheConfig := cache.Config{
DefaultTTL: 10 * time.Minute,
CleanupInterval: 5 * time.Minute,
MaxItems: 1000,
OnEviction: nil,
}
store := &Store{
driver: driver,
profile: profile,
cacheConfig: cacheConfig,
workspaceSettingCache: cache.New(cacheConfig),
userCache: cache.New(cacheConfig),
userSettingCache: cache.New(cacheConfig),
}
return store
}
func (s *Store) GetDriver() Driver {
@@ -29,5 +48,10 @@ func (s *Store) GetDriver() Driver {
}
func (s *Store) Close() error {
// Stop all cache cleanup goroutines
s.workspaceSettingCache.Close()
s.userCache.Close()
s.userSettingCache.Close()
return s.driver.Close()
}

View File

@@ -97,7 +97,7 @@ func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) {
return nil, err
}
s.userCache.Store(user.ID, user)
s.userCache.Set(ctx, string(user.ID), user)
return user, nil
}
@@ -107,7 +107,7 @@ func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, erro
return nil, err
}
s.userCache.Store(user.ID, user)
s.userCache.Set(ctx, string(user.ID), user)
return user, nil
}
@@ -118,7 +118,7 @@ func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error)
}
for _, user := range list {
s.userCache.Store(user.ID, user)
s.userCache.Set(ctx, string(user.ID), user)
}
return list, nil
}
@@ -128,7 +128,7 @@ func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) {
if *find.ID == SystemBotID {
return SystemBot, nil
}
if cache, ok := s.userCache.Load(*find.ID); ok {
if cache, ok := s.userCache.Get(ctx, string(*find.ID)); ok {
user, ok := cache.(*User)
if ok {
return user, nil
@@ -145,7 +145,7 @@ func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) {
}
user := list[0]
s.userCache.Store(user.ID, user)
s.userCache.Set(ctx, string(user.ID), user)
return user, nil
}
@@ -154,7 +154,6 @@ func (s *Store) DeleteUser(ctx context.Context, delete *DeleteUser) error {
if err != nil {
return err
}
s.userCache.Delete(delete.ID)
s.userCache.Delete(ctx, string(delete.ID))
return nil
}

View File

@@ -37,7 +37,7 @@ func (s *Store) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetti
if userSetting == nil {
return nil, errors.New("unexpected nil user setting")
}
s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting)
s.userSettingCache.Set(ctx, getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting)
return userSetting, nil
}
@@ -56,7 +56,7 @@ func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]
if userSetting == nil {
continue
}
s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting)
s.userSettingCache.Set(ctx, getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting)
userSettings = append(userSettings, userSetting)
}
return userSettings, nil
@@ -64,7 +64,7 @@ func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]
func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*storepb.UserSetting, error) {
if find.UserID != nil {
if cache, ok := s.userSettingCache.Load(getUserSettingCacheKey(*find.UserID, find.Key.String())); ok {
if cache, ok := s.userSettingCache.Get(ctx, getUserSettingCacheKey(*find.UserID, find.Key.String())); ok {
userSetting, ok := cache.(*storepb.UserSetting)
if ok {
return userSetting, nil
@@ -84,7 +84,7 @@ func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*sto
}
userSetting := list[0]
s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting)
s.userSettingCache.Set(ctx, getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting)
return userSetting, nil
}

View File

@@ -53,7 +53,7 @@ func (s *Store) UpsertWorkspaceSetting(ctx context.Context, upsert *storepb.Work
if err != nil {
return nil, errors.Wrap(err, "Failed to convert workspace setting")
}
s.workspaceSettingCache.Store(workspaceSetting.Key.String(), workspaceSetting)
s.workspaceSettingCache.Set(ctx, workspaceSetting.Key.String(), workspaceSetting)
return workspaceSetting, nil
}
@@ -72,14 +72,14 @@ func (s *Store) ListWorkspaceSettings(ctx context.Context, find *FindWorkspaceSe
if workspaceSetting == nil {
continue
}
s.workspaceSettingCache.Store(workspaceSetting.Key.String(), workspaceSetting)
s.workspaceSettingCache.Set(ctx, workspaceSetting.Key.String(), workspaceSetting)
workspaceSettings = append(workspaceSettings, workspaceSetting)
}
return workspaceSettings, nil
}
func (s *Store) GetWorkspaceSetting(ctx context.Context, find *FindWorkspaceSetting) (*storepb.WorkspaceSetting, error) {
if cache, ok := s.workspaceSettingCache.Load(find.Name); ok {
if cache, ok := s.workspaceSettingCache.Get(ctx, find.Name); ok {
workspaceSetting, ok := cache.(*storepb.WorkspaceSetting)
if ok {
return workspaceSetting, nil
@@ -111,7 +111,7 @@ func (s *Store) GetWorkspaceBasicSetting(ctx context.Context) (*storepb.Workspac
if workspaceSetting != nil {
workspaceBasicSetting = workspaceSetting.GetBasicSetting()
}
s.workspaceSettingCache.Store(storepb.WorkspaceSettingKey_BASIC.String(), &storepb.WorkspaceSetting{
s.workspaceSettingCache.Set(ctx, storepb.WorkspaceSettingKey_BASIC.String(), &storepb.WorkspaceSetting{
Key: storepb.WorkspaceSettingKey_BASIC,
Value: &storepb.WorkspaceSetting_BasicSetting{BasicSetting: workspaceBasicSetting},
})
@@ -130,7 +130,7 @@ func (s *Store) GetWorkspaceGeneralSetting(ctx context.Context) (*storepb.Worksp
if workspaceSetting != nil {
workspaceGeneralSetting = workspaceSetting.GetGeneralSetting()
}
s.workspaceSettingCache.Store(storepb.WorkspaceSettingKey_GENERAL.String(), &storepb.WorkspaceSetting{
s.workspaceSettingCache.Set(ctx, storepb.WorkspaceSettingKey_GENERAL.String(), &storepb.WorkspaceSetting{
Key: storepb.WorkspaceSettingKey_GENERAL,
Value: &storepb.WorkspaceSetting_GeneralSetting{GeneralSetting: workspaceGeneralSetting},
})
@@ -167,7 +167,7 @@ func (s *Store) GetWorkspaceMemoRelatedSetting(ctx context.Context) (*storepb.Wo
if len(workspaceMemoRelatedSetting.NsfwTags) == 0 {
workspaceMemoRelatedSetting.NsfwTags = append(workspaceMemoRelatedSetting.NsfwTags, DefaultNsfwTags...)
}
s.workspaceSettingCache.Store(storepb.WorkspaceSettingKey_MEMO_RELATED.String(), &storepb.WorkspaceSetting{
s.workspaceSettingCache.Set(ctx, storepb.WorkspaceSettingKey_MEMO_RELATED.String(), &storepb.WorkspaceSetting{
Key: storepb.WorkspaceSettingKey_MEMO_RELATED,
Value: &storepb.WorkspaceSetting_MemoRelatedSetting{MemoRelatedSetting: workspaceMemoRelatedSetting},
})
@@ -201,7 +201,7 @@ func (s *Store) GetWorkspaceStorageSetting(ctx context.Context) (*storepb.Worksp
if workspaceStorageSetting.FilepathTemplate == "" {
workspaceStorageSetting.FilepathTemplate = defaultWorkspaceFilepathTemplate
}
s.workspaceSettingCache.Store(storepb.WorkspaceSettingKey_STORAGE.String(), &storepb.WorkspaceSetting{
s.workspaceSettingCache.Set(ctx, storepb.WorkspaceSettingKey_STORAGE.String(), &storepb.WorkspaceSetting{
Key: storepb.WorkspaceSettingKey_STORAGE,
Value: &storepb.WorkspaceSetting_StorageSetting{StorageSetting: workspaceStorageSetting},
})