/* GoToSocial Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with this program. If not, see . */ package bundb import ( "context" "strings" "github.com/superseriousbusiness/gotosocial/internal/cache" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/uptrace/bun" ) type emojiDB struct { conn *DBConn cache *cache.EmojiCache } func (e *emojiDB) newEmojiQ(emoji *gtsmodel.Emoji) *bun.SelectQuery { return e.conn. NewSelect(). Model(emoji) } func (e *emojiDB) PutEmoji(ctx context.Context, emoji *gtsmodel.Emoji) db.Error { if _, err := e.conn.NewInsert().Model(emoji).Exec(ctx); err != nil { return e.conn.ProcessError(err) } e.cache.Put(emoji) return nil } func (e *emojiDB) GetCustomEmojis(ctx context.Context) ([]*gtsmodel.Emoji, db.Error) { emojiIDs := []string{} q := e.conn. NewSelect(). Table("emojis"). Column("id"). Where("visible_in_picker = true"). Where("disabled = false"). Where("domain IS NULL"). Order("shortcode ASC") if err := q.Scan(ctx, &emojiIDs); err != nil { return nil, e.conn.ProcessError(err) } return e.emojisFromIDs(ctx, emojiIDs) } func (e *emojiDB) GetEmojiByID(ctx context.Context, id string) (*gtsmodel.Emoji, db.Error) { return e.getEmoji( ctx, func() (*gtsmodel.Emoji, bool) { return e.cache.GetByID(id) }, func(emoji *gtsmodel.Emoji) error { return e.newEmojiQ(emoji).Where("emoji.id = ?", id).Scan(ctx) }, ) } func (e *emojiDB) GetEmojiByURI(ctx context.Context, uri string) (*gtsmodel.Emoji, db.Error) { return e.getEmoji( ctx, func() (*gtsmodel.Emoji, bool) { return e.cache.GetByURI(uri) }, func(emoji *gtsmodel.Emoji) error { return e.newEmojiQ(emoji).Where("emoji.uri = ?", uri).Scan(ctx) }, ) } func (e *emojiDB) GetEmojiByShortcodeDomain(ctx context.Context, shortcode string, domain string) (*gtsmodel.Emoji, db.Error) { return e.getEmoji( ctx, func() (*gtsmodel.Emoji, bool) { return e.cache.GetByShortcodeDomain(shortcode, domain) }, func(emoji *gtsmodel.Emoji) error { q := e.newEmojiQ(emoji) if domain != "" { q = q.Where("emoji.shortcode = ?", shortcode) q = q.Where("emoji.domain = ?", domain) } else { q = q.Where("emoji.shortcode = ?", strings.ToLower(shortcode)) q = q.Where("emoji.domain IS NULL") } return q.Scan(ctx) }, ) } func (e *emojiDB) getEmoji(ctx context.Context, cacheGet func() (*gtsmodel.Emoji, bool), dbQuery func(*gtsmodel.Emoji) error) (*gtsmodel.Emoji, db.Error) { // Attempt to fetch cached emoji emoji, cached := cacheGet() if !cached { emoji = >smodel.Emoji{} // Not cached! Perform database query err := dbQuery(emoji) if err != nil { return nil, e.conn.ProcessError(err) } // Place in the cache e.cache.Put(emoji) } return emoji, nil } func (e *emojiDB) emojisFromIDs(ctx context.Context, emojiIDs []string) ([]*gtsmodel.Emoji, db.Error) { // Catch case of no emojis early if len(emojiIDs) == 0 { return nil, db.ErrNoEntries } emojis := make([]*gtsmodel.Emoji, 0, len(emojiIDs)) for _, id := range emojiIDs { emoji, err := e.GetEmojiByID(ctx, id) if err != nil { log.Errorf("emojisFromIDs: error getting emoji %q: %v", id, err) } emojis = append(emojis, emoji) } return emojis, nil }