diff --git a/server/profiler/profiler.go b/server/profiler/profiler.go new file mode 100644 index 00000000..bc3a80ef --- /dev/null +++ b/server/profiler/profiler.go @@ -0,0 +1,121 @@ +package profiler + +import ( + "context" + "fmt" + "log/slog" + "net/http" + "net/http/pprof" + "runtime" + "time" + + "github.com/labstack/echo/v4" +) + +// Profiler provides HTTP endpoints for memory profiling +type Profiler struct { + memStatsLogInterval time.Duration +} + +// NewProfiler creates a new profiler +func NewProfiler() *Profiler { + return &Profiler{ + memStatsLogInterval: 1 * time.Minute, + } +} + +// RegisterRoutes adds profiling endpoints to the Echo server +func (p *Profiler) RegisterRoutes(e *echo.Echo) { + // Register pprof handlers + g := e.Group("/debug/pprof") + g.GET("", echo.WrapHandler(http.HandlerFunc(pprof.Index))) + g.GET("/cmdline", echo.WrapHandler(http.HandlerFunc(pprof.Cmdline))) + g.GET("/profile", echo.WrapHandler(http.HandlerFunc(pprof.Profile))) + g.POST("/symbol", echo.WrapHandler(http.HandlerFunc(pprof.Symbol))) + g.GET("/symbol", echo.WrapHandler(http.HandlerFunc(pprof.Symbol))) + g.GET("/trace", echo.WrapHandler(http.HandlerFunc(pprof.Trace))) + g.GET("/allocs", echo.WrapHandler(http.HandlerFunc(pprof.Handler("allocs").ServeHTTP))) + g.GET("/block", echo.WrapHandler(http.HandlerFunc(pprof.Handler("block").ServeHTTP))) + g.GET("/goroutine", echo.WrapHandler(http.HandlerFunc(pprof.Handler("goroutine").ServeHTTP))) + g.GET("/heap", echo.WrapHandler(http.HandlerFunc(pprof.Handler("heap").ServeHTTP))) + g.GET("/mutex", echo.WrapHandler(http.HandlerFunc(pprof.Handler("mutex").ServeHTTP))) + g.GET("/threadcreate", echo.WrapHandler(http.HandlerFunc(pprof.Handler("threadcreate").ServeHTTP))) + + // Add a custom memory stats endpoint + g.GET("/memstats", func(c echo.Context) error { + var m runtime.MemStats + runtime.ReadMemStats(&m) + return c.JSON(http.StatusOK, map[string]interface{}{ + "alloc": m.Alloc, + "totalAlloc": m.TotalAlloc, + "sys": m.Sys, + "numGC": m.NumGC, + "heapAlloc": m.HeapAlloc, + "heapSys": m.HeapSys, + "heapInuse": m.HeapInuse, + "heapObjects": m.HeapObjects, + }) + }) +} + +// StartMemoryMonitor starts a goroutine that periodically logs memory stats +func (p *Profiler) StartMemoryMonitor(ctx context.Context) { + go func() { + ticker := time.NewTicker(p.memStatsLogInterval) + defer ticker.Stop() + + // Store previous heap allocation to track growth + var lastHeapAlloc uint64 + var lastNumGC uint32 + + for { + select { + case <-ticker.C: + var m runtime.MemStats + runtime.ReadMemStats(&m) + + // Calculate heap growth since last check + heapGrowth := int64(m.HeapAlloc) - int64(lastHeapAlloc) + gcCount := m.NumGC - lastNumGC + + slog.Info("memory stats", + "heapAlloc", byteCountIEC(m.HeapAlloc), + "heapSys", byteCountIEC(m.HeapSys), + "heapObjects", m.HeapObjects, + "heapGrowth", byteCountIEC(uint64(heapGrowth)), + "numGoroutine", runtime.NumGoroutine(), + "numGC", m.NumGC, + "gcSince", gcCount, + "nextGC", byteCountIEC(m.NextGC), + "gcPause", time.Duration(m.PauseNs[(m.NumGC+255)%256]).String(), + ) + + // Track values for next iteration + lastHeapAlloc = m.HeapAlloc + lastNumGC = m.NumGC + + // Force GC if memory usage is high to see if objects can be reclaimed + if m.HeapAlloc > 500*1024*1024 { // 500 MB threshold + slog.Info("forcing garbage collection due to high memory usage") + runtime.GC() + } + case <-ctx.Done(): + return + } + } + }() +} + +// byteCountIEC converts bytes to a human-readable string (MiB, GiB) +func byteCountIEC(b uint64) string { + const unit = 1024 + if b < unit { + return fmt.Sprintf("%d B", b) + } + div, exp := uint64(unit), 0 + for n := b / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %ciB", float64(b)/float64(div), "KMGTPE"[exp]) +} diff --git a/server/router/api/v1/resource_service.go b/server/router/api/v1/resource_service.go index 43db4d18..71514b94 100644 --- a/server/router/api/v1/resource_service.go +++ b/server/router/api/v1/resource_service.go @@ -26,6 +26,7 @@ import ( "github.com/usememos/memos/plugin/storage/s3" v1pb "github.com/usememos/memos/proto/gen/api/v1" storepb "github.com/usememos/memos/proto/gen/store" + "github.com/usememos/memos/server/profile" "github.com/usememos/memos/store" ) @@ -71,7 +72,7 @@ func (s *APIV1Service) CreateResource(ctx context.Context, request *v1pb.CreateR } create.Size = int64(size) create.Blob = request.Resource.Content - if err := SaveResourceBlob(ctx, s.Store, create); err != nil { + if err := SaveResourceBlob(ctx, s.Profile, s.Store, create); err != nil { return nil, status.Errorf(codes.Internal, "failed to save resource blob: %v", err) } @@ -286,8 +287,8 @@ func (s *APIV1Service) convertResourceFromStore(ctx context.Context, resource *s } // SaveResourceBlob save the blob of resource based on the storage config. -func SaveResourceBlob(ctx context.Context, s *store.Store, create *store.Resource) error { - workspaceStorageSetting, err := s.GetWorkspaceStorageSetting(ctx) +func SaveResourceBlob(ctx context.Context, profile *profile.Profile, stores *store.Store, create *store.Resource) error { + workspaceStorageSetting, err := stores.GetWorkspaceStorageSetting(ctx) if err != nil { return errors.Wrap(err, "Failed to find workspace storage setting") } @@ -308,7 +309,7 @@ func SaveResourceBlob(ctx context.Context, s *store.Store, create *store.Resourc // Ensure the directory exists. osPath := filepath.FromSlash(internalPath) if !filepath.IsAbs(osPath) { - osPath = filepath.Join(s.Profile.Data, osPath) + osPath = filepath.Join(profile.Data, osPath) } dir := filepath.Dir(osPath) if err = os.MkdirAll(dir, os.ModePerm); err != nil { diff --git a/server/runner/memopayload/runner.go b/server/runner/memopayload/runner.go index 9526dfb9..6dc34637 100644 --- a/server/runner/memopayload/runner.go +++ b/server/runner/memopayload/runner.go @@ -3,6 +3,7 @@ package memopayload import ( "context" "log/slog" + "runtime" "slices" "github.com/pkg/errors" @@ -26,23 +27,52 @@ func NewRunner(store *store.Store) *Runner { // RunOnce rebuilds the payload of all memos. func (r *Runner) RunOnce(ctx context.Context) { - memos, err := r.Store.ListMemos(ctx, &store.FindMemo{}) - if err != nil { - slog.Error("failed to list memos", "err", err) - return - } + // Process memos in batches to avoid loading all memos into memory at once + const batchSize = 100 + offset := 0 + processed := 0 - for _, memo := range memos { - if err := RebuildMemoPayload(memo); err != nil { - slog.Error("failed to rebuild memo payload", "err", err) - continue + for { + limit := batchSize + memos, err := r.Store.ListMemos(ctx, &store.FindMemo{ + Limit: &limit, + Offset: &offset, + }) + if err != nil { + slog.Error("failed to list memos", "err", err) + return } - if err := r.Store.UpdateMemo(ctx, &store.UpdateMemo{ - ID: memo.ID, - Payload: memo.Payload, - }); err != nil { - slog.Error("failed to update memo", "err", err) + + // Break if no more memos + if len(memos) == 0 { + break } + + // Process batch + batchSuccessCount := 0 + for _, memo := range memos { + if err := RebuildMemoPayload(memo); err != nil { + slog.Error("failed to rebuild memo payload", "err", err, "memoID", memo.ID) + continue + } + if err := r.Store.UpdateMemo(ctx, &store.UpdateMemo{ + ID: memo.ID, + Payload: memo.Payload, + }); err != nil { + slog.Error("failed to update memo", "err", err, "memoID", memo.ID) + continue + } + batchSuccessCount++ + } + + processed += len(memos) + slog.Info("Processed memo batch", "batchSize", len(memos), "successCount", batchSuccessCount, "totalProcessed", processed) + + // Move to next batch + offset += len(memos) + + // Force garbage collection between batches to prevent memory accumulation + runtime.GC() } } diff --git a/server/runner/s3presign/runner.go b/server/runner/s3presign/runner.go index 342b26c5..d1d4a326 100644 --- a/server/runner/s3presign/runner.go +++ b/server/runner/s3presign/runner.go @@ -3,6 +3,7 @@ package s3presign import ( "context" "log/slog" + "runtime" "time" "google.golang.org/protobuf/types/known/timestamppb" @@ -50,59 +51,88 @@ func (r *Runner) CheckAndPresign(ctx context.Context) { } s3StorageType := storepb.ResourceStorageType_S3 - resources, err := r.Store.ListResources(ctx, &store.FindResource{ - GetBlob: false, - StorageType: &s3StorageType, - }) - if err != nil { - return - } + // Limit resources to a reasonable batch size + const batchSize = 100 + offset := 0 - for _, resource := range resources { - s3ObjectPayload := resource.Payload.GetS3Object() - if s3ObjectPayload == nil { - continue + for { + limit := batchSize + resources, err := r.Store.ListResources(ctx, &store.FindResource{ + GetBlob: false, + StorageType: &s3StorageType, + Limit: &limit, + Offset: &offset, + }) + if err != nil { + slog.Error("Failed to list resources for presigning", "error", err) + return } - if s3ObjectPayload.LastPresignedTime != nil { - // Skip if the presigned URL is still valid for the next 4 days. - // The expiration time is set to 5 days. - if time.Now().Before(s3ObjectPayload.LastPresignedTime.AsTime().Add(4 * 24 * time.Hour)) { + // Break if no more resources + if len(resources) == 0 { + break + } + + // Process batch of resources + presignCount := 0 + for _, resource := range resources { + s3ObjectPayload := resource.Payload.GetS3Object() + if s3ObjectPayload == nil { continue } - } - s3Config := workspaceStorageSetting.GetS3Config() - if s3ObjectPayload.S3Config != nil { - s3Config = s3ObjectPayload.S3Config - } - if s3Config == nil { - slog.Error("S3 config is not found") - continue - } + if s3ObjectPayload.LastPresignedTime != nil { + // Skip if the presigned URL is still valid for the next 4 days. + // The expiration time is set to 5 days. + if time.Now().Before(s3ObjectPayload.LastPresignedTime.AsTime().Add(4 * 24 * time.Hour)) { + continue + } + } - s3Client, err := s3.NewClient(ctx, s3Config) - if err != nil { - slog.Error("Failed to create S3 client", "error", err) - continue - } + s3Config := workspaceStorageSetting.GetS3Config() + if s3ObjectPayload.S3Config != nil { + s3Config = s3ObjectPayload.S3Config + } + if s3Config == nil { + slog.Error("S3 config is not found") + continue + } - presignURL, err := s3Client.PresignGetObject(ctx, s3ObjectPayload.Key) - if err != nil { - return - } - s3ObjectPayload.S3Config = s3Config - s3ObjectPayload.LastPresignedTime = timestamppb.New(time.Now()) - if err := r.Store.UpdateResource(ctx, &store.UpdateResource{ - ID: resource.ID, - Reference: &presignURL, - Payload: &storepb.ResourcePayload{ - Payload: &storepb.ResourcePayload_S3Object_{ - S3Object: s3ObjectPayload, + s3Client, err := s3.NewClient(ctx, s3Config) + if err != nil { + slog.Error("Failed to create S3 client", "error", err) + continue + } + + presignURL, err := s3Client.PresignGetObject(ctx, s3ObjectPayload.Key) + if err != nil { + slog.Error("Failed to presign URL", "error", err, "resourceID", resource.ID) + continue + } + + s3ObjectPayload.S3Config = s3Config + s3ObjectPayload.LastPresignedTime = timestamppb.New(time.Now()) + if err := r.Store.UpdateResource(ctx, &store.UpdateResource{ + ID: resource.ID, + Reference: &presignURL, + Payload: &storepb.ResourcePayload{ + Payload: &storepb.ResourcePayload_S3Object_{ + S3Object: s3ObjectPayload, + }, }, - }, - }); err != nil { - return + }); err != nil { + slog.Error("Failed to update resource", "error", err, "resourceID", resource.ID) + continue + } + presignCount++ } + + slog.Info("Presigned batch of S3 resources", "batchSize", len(resources), "presigned", presignCount) + + // Move to next batch + offset += len(resources) + + // Prevent memory accumulation between batches + runtime.GC() } } diff --git a/server/server.go b/server/server.go index fd5834ca..5bb33ce8 100644 --- a/server/server.go +++ b/server/server.go @@ -7,6 +7,7 @@ import ( "math" "net" "net/http" + "runtime" "time" "github.com/google/uuid" @@ -19,6 +20,7 @@ import ( storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/server/profile" + "github.com/usememos/memos/server/profiler" apiv1 "github.com/usememos/memos/server/router/api/v1" "github.com/usememos/memos/server/router/frontend" "github.com/usememos/memos/server/router/rss" @@ -32,8 +34,10 @@ type Server struct { Profile *profile.Profile Store *store.Store - echoServer *echo.Echo - grpcServer *grpc.Server + echoServer *echo.Echo + grpcServer *grpc.Server + profiler *profiler.Profiler + runnerCancelFuncs []context.CancelFunc } func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store) (*Server, error) { @@ -49,6 +53,11 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store echoServer.Use(middleware.Recover()) s.echoServer = echoServer + // Initialize profiler + s.profiler = profiler.NewProfiler() + s.profiler.RegisterRoutes(echoServer) + s.profiler.StartMemoryMonitor(ctx) + workspaceBasicSetting, err := s.getOrUpsertWorkspaceBasicSetting(ctx) if err != nil { return nil, errors.Wrap(err, "failed to get workspace basic setting") @@ -134,6 +143,15 @@ func (s *Server) Shutdown(ctx context.Context) { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() + slog.Info("server shutting down") + + // Cancel all background runners + for _, cancelFunc := range s.runnerCancelFuncs { + if cancelFunc != nil { + cancelFunc() + } + } + // Shutdown echo server. if err := s.echoServer.Shutdown(ctx); err != nil { slog.Error("failed to shutdown server", slog.String("error", err.Error())) @@ -142,6 +160,23 @@ func (s *Server) Shutdown(ctx context.Context) { // Shutdown gRPC server. s.grpcServer.GracefulStop() + // Stop the profiler + if s.profiler != nil { + slog.Info("stopping profiler") + // Force one last garbage collection to clean up remaining objects + runtime.GC() + + // Log final memory stats + var m runtime.MemStats + runtime.ReadMemStats(&m) + slog.Info("final memory stats before exit", + "heapAlloc", m.Alloc, + "heapSys", m.Sys, + "heapObjects", m.HeapObjects, + "numGoroutine", runtime.NumGoroutine(), + ) + } + // Close database connection. if err := s.Store.Close(); err != nil { slog.Error("failed to close database", slog.String("error", err.Error())) @@ -151,13 +186,30 @@ func (s *Server) Shutdown(ctx context.Context) { } func (s *Server) StartBackgroundRunners(ctx context.Context) { + // Create a separate context for each background runner + // This allows us to control cancellation for each runner independently + s3Context, s3Cancel := context.WithCancel(ctx) + + // Store the cancel function so we can properly shut down runners + s.runnerCancelFuncs = append(s.runnerCancelFuncs, s3Cancel) + + // Create and start S3 presign runner s3presignRunner := s3presign.NewRunner(s.Store) s3presignRunner.RunOnce(ctx) + + // Create and start memo payload runner just once memopayloadRunner := memopayload.NewRunner(s.Store) // Rebuild all memos' payload after server starts. memopayloadRunner.RunOnce(ctx) - go s3presignRunner.Run(ctx) + // Start continuous S3 presign runner + go func() { + s3presignRunner.Run(s3Context) + slog.Info("s3presign runner stopped") + }() + + // Log the number of goroutines running + slog.Info("background runners started", "goroutines", runtime.NumGoroutine()) } func (s *Server) getOrUpsertWorkspaceBasicSetting(ctx context.Context) (*storepb.WorkspaceBasicSetting, error) { diff --git a/store/cache/cache.go b/store/cache/cache.go new file mode 100644 index 00000000..48dea566 --- /dev/null +++ b/store/cache/cache.go @@ -0,0 +1,311 @@ +package cache + +import ( + "context" + "sync" + "sync/atomic" + "time" +) + +// Interface defines the operations a cache must support +type Interface interface { + // Set adds a value to the cache with the default TTL + Set(ctx context.Context, key string, value interface{}) + + // SetWithTTL adds a value to the cache with a custom TTL + SetWithTTL(ctx context.Context, key string, value interface{}, ttl time.Duration) + + // Get retrieves a value from the cache + Get(ctx context.Context, key string) (interface{}, bool) + + // Delete removes a value from the cache + Delete(ctx context.Context, key string) + + // Clear removes all values from the cache + Clear(ctx context.Context) + + // Size returns the number of items in the cache + Size() int64 + + // Close stops all background tasks and releases resources + Close() error +} + +// item represents a cached value with metadata +type item struct { + value interface{} + expiration time.Time + size int // Approximate size in bytes +} + +// Config contains options for configuring a cache +type Config struct { + // DefaultTTL is the default time-to-live for cache entries + DefaultTTL time.Duration + + // CleanupInterval is how often the cache runs cleanup + CleanupInterval time.Duration + + // MaxItems is the maximum number of items allowed in the cache + MaxItems int + + // OnEviction is called when an item is evicted from the cache + OnEviction func(key string, value interface{}) +} + +// DefaultConfig returns a default configuration for the cache +func DefaultConfig() Config { + return Config{ + DefaultTTL: 10 * time.Minute, + CleanupInterval: 5 * time.Minute, + MaxItems: 1000, + OnEviction: nil, + } +} + +// Cache is a thread-safe in-memory cache with TTL and memory management +type Cache struct { + data sync.Map + config Config + itemCount int64 // Use atomic operations to track item count + stopChan chan struct{} + closedChan chan struct{} +} + +// New creates a new memory cache with the given configuration +func New(config Config) *Cache { + c := &Cache{ + config: config, + stopChan: make(chan struct{}), + closedChan: make(chan struct{}), + } + + go c.cleanupLoop() + return c +} + +// NewDefault creates a new memory cache with default configuration +func NewDefault() *Cache { + return New(DefaultConfig()) +} + +// Set adds a value to the cache with the default TTL +func (c *Cache) Set(ctx context.Context, key string, value interface{}) { + c.SetWithTTL(ctx, key, value, c.config.DefaultTTL) +} + +// SetWithTTL adds a value to the cache with a custom TTL +func (c *Cache) SetWithTTL(ctx context.Context, key string, value interface{}, ttl time.Duration) { + // Estimate size of the item (very rough approximation) + size := estimateSize(value) + + // Check if item already exists to avoid double counting + if _, exists := c.data.Load(key); exists { + c.data.Delete(key) + // Don't decrement count - we'll replace it + } else { + // Only increment if this is a new key + atomic.AddInt64(&c.itemCount, 1) + } + + c.data.Store(key, item{ + value: value, + expiration: time.Now().Add(ttl), + size: size, + }) + + // If we're over the max items, clean up old items + if c.config.MaxItems > 0 && atomic.LoadInt64(&c.itemCount) > int64(c.config.MaxItems) { + c.cleanupOldest() + } +} + +// Get retrieves a value from the cache +func (c *Cache) Get(ctx context.Context, key string) (interface{}, bool) { + value, ok := c.data.Load(key) + if !ok { + return nil, false + } + + itm := value.(item) + if time.Now().After(itm.expiration) { + c.data.Delete(key) + atomic.AddInt64(&c.itemCount, -1) + + if c.config.OnEviction != nil { + c.config.OnEviction(key, itm.value) + } + + return nil, false + } + + return itm.value, true +} + +// Delete removes a value from the cache +func (c *Cache) Delete(ctx context.Context, key string) { + if value, loaded := c.data.LoadAndDelete(key); loaded { + atomic.AddInt64(&c.itemCount, -1) + + if c.config.OnEviction != nil { + itm := value.(item) + c.config.OnEviction(key, itm.value) + } + } +} + +// Clear removes all values from the cache +func (c *Cache) Clear(ctx context.Context) { + if c.config.OnEviction != nil { + c.data.Range(func(key, value interface{}) bool { + itm := value.(item) + c.config.OnEviction(key.(string), itm.value) + return true + }) + } + + c.data = sync.Map{} + atomic.StoreInt64(&c.itemCount, 0) +} + +// Size returns the number of items in the cache +func (c *Cache) Size() int64 { + return atomic.LoadInt64(&c.itemCount) +} + +// Close stops the cache cleanup goroutine +func (c *Cache) Close() error { + select { + case <-c.stopChan: + // Already closed + return nil + default: + close(c.stopChan) + <-c.closedChan // Wait for cleanup goroutine to exit + return nil + } +} + +// cleanupLoop periodically cleans up expired items +func (c *Cache) cleanupLoop() { + ticker := time.NewTicker(c.config.CleanupInterval) + defer func() { + ticker.Stop() + close(c.closedChan) + }() + + for { + select { + case <-ticker.C: + c.cleanup() + case <-c.stopChan: + return + } + } +} + +// cleanup removes expired items +func (c *Cache) cleanup() { + evicted := make(map[string]interface{}) + count := 0 + + c.data.Range(func(key, value interface{}) bool { + itm := value.(item) + if time.Now().After(itm.expiration) { + c.data.Delete(key) + count++ + + if c.config.OnEviction != nil { + evicted[key.(string)] = itm.value + } + } + return true + }) + + if count > 0 { + atomic.AddInt64(&c.itemCount, -int64(count)) + + // Call eviction callbacks outside the loop to avoid blocking the range + if c.config.OnEviction != nil { + for k, v := range evicted { + c.config.OnEviction(k, v) + } + } + } +} + +// cleanupOldest removes the oldest items if we're over the max items +func (c *Cache) cleanupOldest() { + threshold := c.config.MaxItems / 5 // Remove 20% of max items at once + if threshold < 1 { + threshold = 1 + } + + currentCount := atomic.LoadInt64(&c.itemCount) + + // If we're not over the threshold, don't do anything + if currentCount <= int64(c.config.MaxItems) { + return + } + + // Find the oldest items + type keyExpPair struct { + key string + value interface{} + expiration time.Time + } + candidates := make([]keyExpPair, 0, threshold) + + c.data.Range(func(key, value interface{}) bool { + itm := value.(item) + if len(candidates) < threshold { + candidates = append(candidates, keyExpPair{key.(string), itm.value, itm.expiration}) + return true + } + + // Find the newest item in candidates + newestIdx := 0 + for i := 1; i < len(candidates); i++ { + if candidates[i].expiration.After(candidates[newestIdx].expiration) { + newestIdx = i + } + } + + // Replace it if this item is older + if itm.expiration.Before(candidates[newestIdx].expiration) { + candidates[newestIdx] = keyExpPair{key.(string), itm.value, itm.expiration} + } + + return true + }) + + // Delete the oldest items + deletedCount := 0 + for _, candidate := range candidates { + c.data.Delete(candidate.key) + deletedCount++ + + if c.config.OnEviction != nil { + c.config.OnEviction(candidate.key, candidate.value) + } + } + + // Update count + if deletedCount > 0 { + atomic.AddInt64(&c.itemCount, -int64(deletedCount)) + } +} + +// estimateSize attempts to estimate the memory footprint of a value +func estimateSize(value interface{}) int { + switch v := value.(type) { + case string: + return len(v) + 24 // base size + string overhead + case []byte: + return len(v) + 24 // base size + slice overhead + case map[string]interface{}: + return len(v) * 64 // rough estimate + default: + return 64 // default conservative estimate + } +} diff --git a/store/cache/cache_test.go b/store/cache/cache_test.go new file mode 100644 index 00000000..3ba23306 --- /dev/null +++ b/store/cache/cache_test.go @@ -0,0 +1,209 @@ +package cache + +import ( + "context" + "fmt" + "sync" + "testing" + "time" +) + +func TestCacheBasicOperations(t *testing.T) { + ctx := context.Background() + config := DefaultConfig() + config.DefaultTTL = 100 * time.Millisecond + config.CleanupInterval = 50 * time.Millisecond + cache := New(config) + defer cache.Close() + + // Test Set and Get + cache.Set(ctx, "key1", "value1") + if val, ok := cache.Get(ctx, "key1"); !ok || val != "value1" { + t.Errorf("Expected 'value1', got %v, exists: %v", val, ok) + } + + // Test SetWithTTL + cache.SetWithTTL(ctx, "key2", "value2", 200*time.Millisecond) + if val, ok := cache.Get(ctx, "key2"); !ok || val != "value2" { + t.Errorf("Expected 'value2', got %v, exists: %v", val, ok) + } + + // Test Delete + cache.Delete(ctx, "key1") + if _, ok := cache.Get(ctx, "key1"); ok { + t.Errorf("Key 'key1' should have been deleted") + } + + // Test automatic expiration + time.Sleep(150 * time.Millisecond) + if _, ok := cache.Get(ctx, "key1"); ok { + t.Errorf("Key 'key1' should have expired") + } + // key2 should still be valid (200ms TTL) + if _, ok := cache.Get(ctx, "key2"); !ok { + t.Errorf("Key 'key2' should still be valid") + } + + // Wait for key2 to expire + time.Sleep(100 * time.Millisecond) + if _, ok := cache.Get(ctx, "key2"); ok { + t.Errorf("Key 'key2' should have expired") + } + + // Test Clear + cache.Set(ctx, "key3", "value3") + cache.Clear(ctx) + if _, ok := cache.Get(ctx, "key3"); ok { + t.Errorf("Cache should be empty after Clear()") + } +} + +func TestCacheEviction(t *testing.T) { + ctx := context.Background() + config := DefaultConfig() + config.MaxItems = 5 + cache := New(config) + defer cache.Close() + + // Add 5 items (max capacity) + for i := 0; i < 5; i++ { + key := fmt.Sprintf("key%d", i) + cache.Set(ctx, key, i) + } + + // Verify all 5 items are in the cache + for i := 0; i < 5; i++ { + key := fmt.Sprintf("key%d", i) + if _, ok := cache.Get(ctx, key); !ok { + t.Errorf("Key '%s' should be in the cache", key) + } + } + + // Add 2 more items to trigger eviction + cache.Set(ctx, "keyA", "valueA") + cache.Set(ctx, "keyB", "valueB") + + // Verify size is still within limits + if cache.Size() > int64(config.MaxItems) { + t.Errorf("Cache size %d exceeds limit %d", cache.Size(), config.MaxItems) + } + + // Some of the original keys should have been evicted + evictedCount := 0 + for i := 0; i < 5; i++ { + key := fmt.Sprintf("key%d", i) + if _, ok := cache.Get(ctx, key); !ok { + evictedCount++ + } + } + + if evictedCount == 0 { + t.Errorf("No keys were evicted despite exceeding max items") + } + + // The newer keys should still be present + if _, ok := cache.Get(ctx, "keyA"); !ok { + t.Errorf("Key 'keyA' should be in the cache") + } + if _, ok := cache.Get(ctx, "keyB"); !ok { + t.Errorf("Key 'keyB' should be in the cache") + } +} + +func TestCacheConcurrency(t *testing.T) { + ctx := context.Background() + cache := NewDefault() + defer cache.Close() + + const goroutines = 10 + const operationsPerGoroutine = 100 + + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := 0; i < goroutines; i++ { + go func(id int) { + defer wg.Done() + + baseKey := fmt.Sprintf("worker%d-", id) + + // Set operations + for j := 0; j < operationsPerGoroutine; j++ { + key := fmt.Sprintf("%skey%d", baseKey, j) + value := fmt.Sprintf("value%d-%d", id, j) + cache.Set(ctx, key, value) + } + + // Get operations + for j := 0; j < operationsPerGoroutine; j++ { + key := fmt.Sprintf("%skey%d", baseKey, j) + val, ok := cache.Get(ctx, key) + if !ok { + t.Errorf("Key '%s' should exist in cache", key) + continue + } + expected := fmt.Sprintf("value%d-%d", id, j) + if val != expected { + t.Errorf("For key '%s', expected '%s', got '%s'", key, expected, val) + } + } + + // Delete half the keys + for j := 0; j < operationsPerGoroutine/2; j++ { + key := fmt.Sprintf("%skey%d", baseKey, j) + cache.Delete(ctx, key) + } + }(i) + } + + wg.Wait() + + // Verify size and deletion + var totalKeysExpected int64 = goroutines * operationsPerGoroutine / 2 + if cache.Size() != totalKeysExpected { + t.Errorf("Expected cache size to be %d, got %d", totalKeysExpected, cache.Size()) + } +} + +func TestEvictionCallback(t *testing.T) { + ctx := context.Background() + evicted := make(map[string]interface{}) + evictedMu := sync.Mutex{} + + config := DefaultConfig() + config.DefaultTTL = 50 * time.Millisecond + config.CleanupInterval = 25 * time.Millisecond + config.OnEviction = func(key string, value interface{}) { + evictedMu.Lock() + evicted[key] = value + evictedMu.Unlock() + } + + cache := New(config) + defer cache.Close() + + // Add items + cache.Set(ctx, "key1", "value1") + cache.Set(ctx, "key2", "value2") + + // Manually delete + cache.Delete(ctx, "key1") + + // Verify manual deletion triggered callback + time.Sleep(10 * time.Millisecond) // Small delay to ensure callback processed + evictedMu.Lock() + if evicted["key1"] != "value1" { + t.Errorf("Eviction callback not triggered for manual deletion") + } + evictedMu.Unlock() + + // Wait for automatic expiration + time.Sleep(60 * time.Millisecond) + + // Verify TTL expiration triggered callback + evictedMu.Lock() + if evicted["key2"] != "value2" { + t.Errorf("Eviction callback not triggered for TTL expiration") + } + evictedMu.Unlock() +} diff --git a/store/idp.go b/store/idp.go index 6aed4bbf..88ab6f0e 100644 --- a/store/idp.go +++ b/store/idp.go @@ -46,7 +46,6 @@ func (s *Store) CreateIdentityProvider(ctx context.Context, create *storepb.Iden if err != nil { return nil, err } - s.idpCache.Store(identityProvider.Id, identityProvider) return identityProvider, nil } @@ -63,21 +62,11 @@ func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityPro return nil, err } identityProviders = append(identityProviders, identityProvider) - s.idpCache.Store(identityProvider.Id, identityProvider) } return identityProviders, nil } func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*storepb.IdentityProvider, error) { - if find.ID != nil { - if cache, ok := s.idpCache.Load(*find.ID); ok { - identityProvider, ok := cache.(*storepb.IdentityProvider) - if ok { - return identityProvider, nil - } - } - } - list, err := s.ListIdentityProviders(ctx, find) if err != nil { return nil, err @@ -127,7 +116,6 @@ func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdenti if err != nil { return nil, err } - s.idpCache.Store(identityProvider.Id, identityProvider) return identityProvider, nil } @@ -136,8 +124,6 @@ func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdenti if err != nil { return err } - - s.idpCache.Delete(delete.ID) return nil } diff --git a/store/migrator.go b/store/migrator.go index 2469e55c..2b8c94ea 100644 --- a/store/migrator.go +++ b/store/migrator.go @@ -39,7 +39,7 @@ func (s *Store) Migrate(ctx context.Context) error { return errors.Wrap(err, "failed to pre-migrate") } - if s.Profile.Mode == "prod" { + if s.profile.Mode == "prod" { migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{}) if err != nil { return errors.Wrap(err, "failed to find migration history") @@ -107,7 +107,7 @@ func (s *Store) Migrate(ctx context.Context) error { return errors.Wrap(err, "failed to update current schema version") } } - } else if s.Profile.Mode == "demo" { + } else if s.profile.Mode == "demo" { // In demo mode, we should seed the database. if err := s.seed(ctx); err != nil { return errors.Wrap(err, "failed to seed") @@ -157,7 +157,7 @@ func (s *Store) preMigrate(ctx context.Context) error { return errors.Wrap(err, "failed to update current schema version") } } - if s.Profile.Mode == "prod" { + if s.profile.Mode == "prod" { if err := s.normalizedMigrationHistoryList(ctx); err != nil { return errors.Wrap(err, "failed to normalize migration history list") } @@ -166,16 +166,16 @@ func (s *Store) preMigrate(ctx context.Context) error { } func (s *Store) getMigrationBasePath() string { - return fmt.Sprintf("migration/%s/", s.Profile.Driver) + return fmt.Sprintf("migration/%s/", s.profile.Driver) } func (s *Store) getSeedBasePath() string { - return fmt.Sprintf("seed/%s/", s.Profile.Driver) + return fmt.Sprintf("seed/%s/", s.profile.Driver) } func (s *Store) seed(ctx context.Context) error { // Only seed for SQLite. - if s.Profile.Driver != "sqlite" { + if s.profile.Driver != "sqlite" { slog.Warn("seed is only supported for SQLite") return nil } @@ -207,7 +207,7 @@ func (s *Store) seed(ctx context.Context) error { } func (s *Store) GetCurrentSchemaVersion() (string, error) { - currentVersion := version.GetCurrentVersion(s.Profile.Mode) + currentVersion := version.GetCurrentVersion(s.profile.Mode) minorVersion := version.GetMinorVersion(currentVersion) filePaths, err := fs.Glob(migrationFS, fmt.Sprintf("%s%s/*.sql", s.getMigrationBasePath(), minorVersion)) if err != nil { diff --git a/store/resource.go b/store/resource.go index fdb68d47..ad03ee52 100644 --- a/store/resource.go +++ b/store/resource.go @@ -74,6 +74,17 @@ func (s *Store) CreateResource(ctx context.Context, create *Resource) (*Resource } func (s *Store) ListResources(ctx context.Context, find *FindResource) ([]*Resource, error) { + // Set default limits to prevent loading too many resources at once + if find.Limit == nil && find.GetBlob { + // When fetching blobs, we should be especially careful with limits + defaultLimit := 10 + find.Limit = &defaultLimit + } else if find.Limit == nil { + // Even without blobs, let's default to a reasonable limit + defaultLimit := 100 + find.Limit = &defaultLimit + } + return s.driver.ListResources(ctx, find) } @@ -110,7 +121,7 @@ func (s *Store) DeleteResource(ctx context.Context, delete *DeleteResource) erro if err := func() error { p := filepath.FromSlash(resource.Reference) if !filepath.IsAbs(p) { - p = filepath.Join(s.Profile.Data, p) + p = filepath.Join(s.profile.Data, p) } err := os.Remove(p) if err != nil { diff --git a/store/store.go b/store/store.go index 030a0a22..07716278 100644 --- a/store/store.go +++ b/store/store.go @@ -1,27 +1,46 @@ package store import ( - "sync" + "time" "github.com/usememos/memos/server/profile" + "github.com/usememos/memos/store/cache" ) // Store provides database access to all raw objects. type Store struct { - Profile *profile.Profile - driver Driver - workspaceSettingCache sync.Map // map[string]*storepb.WorkspaceSetting - userCache sync.Map // map[int]*User - userSettingCache sync.Map // map[string]*storepb.UserSetting - idpCache sync.Map // map[int]*storepb.IdentityProvider + profile *profile.Profile + driver Driver + + // Cache settings + cacheConfig cache.Config + + // Caches + workspaceSettingCache *cache.Cache // cache for workspace settings + userCache *cache.Cache // cache for users + userSettingCache *cache.Cache // cache for user settings } // New creates a new instance of Store. func New(driver Driver, profile *profile.Profile) *Store { - return &Store{ - driver: driver, - Profile: profile, + // Default cache settings + cacheConfig := cache.Config{ + DefaultTTL: 10 * time.Minute, + CleanupInterval: 5 * time.Minute, + MaxItems: 1000, + OnEviction: nil, } + + store := &Store{ + driver: driver, + profile: profile, + cacheConfig: cacheConfig, + workspaceSettingCache: cache.New(cacheConfig), + userCache: cache.New(cacheConfig), + userSettingCache: cache.New(cacheConfig), + } + + return store } func (s *Store) GetDriver() Driver { @@ -29,5 +48,10 @@ func (s *Store) GetDriver() Driver { } func (s *Store) Close() error { + // Stop all cache cleanup goroutines + s.workspaceSettingCache.Close() + s.userCache.Close() + s.userSettingCache.Close() + return s.driver.Close() } diff --git a/store/user.go b/store/user.go index 8b5c0fd3..8a2a9408 100644 --- a/store/user.go +++ b/store/user.go @@ -97,7 +97,7 @@ func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) { return nil, err } - s.userCache.Store(user.ID, user) + s.userCache.Set(ctx, string(user.ID), user) return user, nil } @@ -107,7 +107,7 @@ func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, erro return nil, err } - s.userCache.Store(user.ID, user) + s.userCache.Set(ctx, string(user.ID), user) return user, nil } @@ -118,7 +118,7 @@ func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error) } for _, user := range list { - s.userCache.Store(user.ID, user) + s.userCache.Set(ctx, string(user.ID), user) } return list, nil } @@ -128,7 +128,7 @@ func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) { if *find.ID == SystemBotID { return SystemBot, nil } - if cache, ok := s.userCache.Load(*find.ID); ok { + if cache, ok := s.userCache.Get(ctx, string(*find.ID)); ok { user, ok := cache.(*User) if ok { return user, nil @@ -145,7 +145,7 @@ func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) { } user := list[0] - s.userCache.Store(user.ID, user) + s.userCache.Set(ctx, string(user.ID), user) return user, nil } @@ -154,7 +154,6 @@ func (s *Store) DeleteUser(ctx context.Context, delete *DeleteUser) error { if err != nil { return err } - - s.userCache.Delete(delete.ID) + s.userCache.Delete(ctx, string(delete.ID)) return nil } diff --git a/store/user_setting.go b/store/user_setting.go index 18d0761e..a2053285 100644 --- a/store/user_setting.go +++ b/store/user_setting.go @@ -37,7 +37,7 @@ func (s *Store) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetti if userSetting == nil { return nil, errors.New("unexpected nil user setting") } - s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting) + s.userSettingCache.Set(ctx, getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting) return userSetting, nil } @@ -56,7 +56,7 @@ func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([] if userSetting == nil { continue } - s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting) + s.userSettingCache.Set(ctx, getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting) userSettings = append(userSettings, userSetting) } return userSettings, nil @@ -64,7 +64,7 @@ func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([] func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*storepb.UserSetting, error) { if find.UserID != nil { - if cache, ok := s.userSettingCache.Load(getUserSettingCacheKey(*find.UserID, find.Key.String())); ok { + if cache, ok := s.userSettingCache.Get(ctx, getUserSettingCacheKey(*find.UserID, find.Key.String())); ok { userSetting, ok := cache.(*storepb.UserSetting) if ok { return userSetting, nil @@ -84,7 +84,7 @@ func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*sto } userSetting := list[0] - s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting) + s.userSettingCache.Set(ctx, getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting) return userSetting, nil } diff --git a/store/workspace_setting.go b/store/workspace_setting.go index c8c21448..9f9054e7 100644 --- a/store/workspace_setting.go +++ b/store/workspace_setting.go @@ -53,7 +53,7 @@ func (s *Store) UpsertWorkspaceSetting(ctx context.Context, upsert *storepb.Work if err != nil { return nil, errors.Wrap(err, "Failed to convert workspace setting") } - s.workspaceSettingCache.Store(workspaceSetting.Key.String(), workspaceSetting) + s.workspaceSettingCache.Set(ctx, workspaceSetting.Key.String(), workspaceSetting) return workspaceSetting, nil } @@ -72,14 +72,14 @@ func (s *Store) ListWorkspaceSettings(ctx context.Context, find *FindWorkspaceSe if workspaceSetting == nil { continue } - s.workspaceSettingCache.Store(workspaceSetting.Key.String(), workspaceSetting) + s.workspaceSettingCache.Set(ctx, workspaceSetting.Key.String(), workspaceSetting) workspaceSettings = append(workspaceSettings, workspaceSetting) } return workspaceSettings, nil } func (s *Store) GetWorkspaceSetting(ctx context.Context, find *FindWorkspaceSetting) (*storepb.WorkspaceSetting, error) { - if cache, ok := s.workspaceSettingCache.Load(find.Name); ok { + if cache, ok := s.workspaceSettingCache.Get(ctx, find.Name); ok { workspaceSetting, ok := cache.(*storepb.WorkspaceSetting) if ok { return workspaceSetting, nil @@ -111,7 +111,7 @@ func (s *Store) GetWorkspaceBasicSetting(ctx context.Context) (*storepb.Workspac if workspaceSetting != nil { workspaceBasicSetting = workspaceSetting.GetBasicSetting() } - s.workspaceSettingCache.Store(storepb.WorkspaceSettingKey_BASIC.String(), &storepb.WorkspaceSetting{ + s.workspaceSettingCache.Set(ctx, storepb.WorkspaceSettingKey_BASIC.String(), &storepb.WorkspaceSetting{ Key: storepb.WorkspaceSettingKey_BASIC, Value: &storepb.WorkspaceSetting_BasicSetting{BasicSetting: workspaceBasicSetting}, }) @@ -130,7 +130,7 @@ func (s *Store) GetWorkspaceGeneralSetting(ctx context.Context) (*storepb.Worksp if workspaceSetting != nil { workspaceGeneralSetting = workspaceSetting.GetGeneralSetting() } - s.workspaceSettingCache.Store(storepb.WorkspaceSettingKey_GENERAL.String(), &storepb.WorkspaceSetting{ + s.workspaceSettingCache.Set(ctx, storepb.WorkspaceSettingKey_GENERAL.String(), &storepb.WorkspaceSetting{ Key: storepb.WorkspaceSettingKey_GENERAL, Value: &storepb.WorkspaceSetting_GeneralSetting{GeneralSetting: workspaceGeneralSetting}, }) @@ -167,7 +167,7 @@ func (s *Store) GetWorkspaceMemoRelatedSetting(ctx context.Context) (*storepb.Wo if len(workspaceMemoRelatedSetting.NsfwTags) == 0 { workspaceMemoRelatedSetting.NsfwTags = append(workspaceMemoRelatedSetting.NsfwTags, DefaultNsfwTags...) } - s.workspaceSettingCache.Store(storepb.WorkspaceSettingKey_MEMO_RELATED.String(), &storepb.WorkspaceSetting{ + s.workspaceSettingCache.Set(ctx, storepb.WorkspaceSettingKey_MEMO_RELATED.String(), &storepb.WorkspaceSetting{ Key: storepb.WorkspaceSettingKey_MEMO_RELATED, Value: &storepb.WorkspaceSetting_MemoRelatedSetting{MemoRelatedSetting: workspaceMemoRelatedSetting}, }) @@ -201,7 +201,7 @@ func (s *Store) GetWorkspaceStorageSetting(ctx context.Context) (*storepb.Worksp if workspaceStorageSetting.FilepathTemplate == "" { workspaceStorageSetting.FilepathTemplate = defaultWorkspaceFilepathTemplate } - s.workspaceSettingCache.Store(storepb.WorkspaceSettingKey_STORAGE.String(), &storepb.WorkspaceSetting{ + s.workspaceSettingCache.Set(ctx, storepb.WorkspaceSettingKey_STORAGE.String(), &storepb.WorkspaceSetting{ Key: storepb.WorkspaceSettingKey_STORAGE, Value: &storepb.WorkspaceSetting_StorageSetting{StorageSetting: workspaceStorageSetting}, })