Display images

This commit is contained in:
Daniel Schwarz 2024-04-13 08:14:36 +02:00 committed by Ivan Habunek
parent 07ad41960f
commit 0fc2ec12f5
No known key found for this signature in database
GPG Key ID: F5F0623FF5EBCB3D
10 changed files with 498 additions and 21 deletions

View File

@ -39,9 +39,14 @@ setup(
"beautifulsoup4>=4.5.0,<5.0",
"wcwidth>=0.1.7",
"urwid>=2.0.0,<3.0",
"tomlkit>=0.10.0,<1.0"
"tomlkit>=0.10.0,<1.0",
],
extras_require={
# Required to display images in the TUI
"images": [
"pillow>=9.5.0",
"term-image==0.7.0",
],
# Required to display rich text in the TUI
"richtext": [
"urwidgets>=0.1,<0.2"
@ -60,6 +65,7 @@ setup(
"setuptools",
"vermin",
"typing-extensions",
"pillow>=9.5.0",
],
},
entry_points={

View File

@ -1,8 +1,12 @@
import click
import pytest
import sys
from toot.cli.validators import validate_duration
from toot.wcstring import wc_wrap, trunc, pad, fit_text
from toot.tui.utils import LRUCache
from PIL import Image
from collections import namedtuple
from toot.utils import urlencode_url
@ -207,6 +211,111 @@ def test_duration():
duration("banana")
def test_cache_null():
"""Null dict is null."""
cache = LRUCache(cache_max_bytes=1024)
assert cache.__len__() == 0
Case = namedtuple("Case", ["cache_len", "len", "init"])
img = Image.new('RGB', (100, 100))
img_size = sys.getsizeof(img.tobytes())
@pytest.mark.parametrize(
"case",
[
Case(9, 0, []),
Case(9, 1, [("one", img)]),
Case(9, 2, [("one", img), ("two", img)]),
Case(2, 2, [("one", img), ("two", img)]),
Case(1, 1, [("one", img), ("two", img)]),
],
)
@pytest.mark.parametrize("method", ["assign", "init"])
def test_cache_init(case, method):
"""Check that the # of elements is right, given # given and cache_len."""
if method == "init":
cache = LRUCache(case.init, cache_max_bytes=img_size * case.cache_len)
elif method == "assign":
cache = LRUCache(cache_max_bytes=img_size * case.cache_len)
for (key, val) in case.init:
cache[key] = val
else:
assert False
# length is max(#entries, cache_len)
assert cache.__len__() == case.len
# make sure the first entry is the one ejected
if case.cache_len > 1 and case.init:
assert "one" in cache.keys()
else:
assert "one" not in cache.keys()
@pytest.mark.parametrize("method", ["init", "assign"])
def test_cache_overflow_default(method):
"""Test default overflow logic."""
if method == "init":
cache = LRUCache([("one", img), ("two", img), ("three", img)], cache_max_bytes=img_size * 2)
elif method == "assign":
cache = LRUCache(cache_max_bytes=img_size * 2)
cache["one"] = img
cache["two"] = img
cache["three"] = img
else:
assert False
assert "one" not in cache.keys()
assert "two" in cache.keys()
assert "three" in cache.keys()
@pytest.mark.parametrize("mode", ["get", "set"])
@pytest.mark.parametrize("add_third", [False, True])
def test_cache_lru_overflow(mode, add_third):
img = Image.new('RGB', (100, 100))
img_size = sys.getsizeof(img.tobytes())
"""Test that key access resets LRU logic."""
cache = LRUCache([("one", img), ("two", img)], cache_max_bytes=img_size * 2)
if mode == "get":
dummy = cache["one"]
elif mode == "set":
cache["one"] = img
else:
assert False
if add_third:
cache["three"] = img
assert "one" in cache.keys()
assert "two" not in cache.keys()
assert "three" in cache.keys()
else:
assert "one" in cache.keys()
assert "two" in cache.keys()
assert "three" not in cache.keys()
def test_cache_keyerror():
cache = LRUCache()
with pytest.raises(KeyError):
cache["foo"]
def test_cache_miss_doesnt_eject():
cache = LRUCache([("one", img), ("two", img)], cache_max_bytes=img_size * 3)
with pytest.raises(KeyError):
cache["foo"]
assert len(cache) == 2
assert "one" in cache.keys()
assert "two" in cache.keys()
def test_urlencode_url():
assert urlencode_url("https://www.example.com") == "https://www.example.com"
assert urlencode_url("https://www.example.com/url%20with%20spaces") == "https://www.example.com/url%20with%20spaces"

View File

@ -22,7 +22,7 @@ T = t.TypeVar("T")
PRIVACY_CHOICES = ["public", "unlisted", "private"]
VISIBILITY_CHOICES = ["public", "unlisted", "private", "direct"]
IMAGE_FORMAT_CHOICES = ["block", "iterm", "kitty"]
TUI_COLORS = {
"1": 1,
"16": 16,

View File

@ -1,8 +1,8 @@
import click
from typing import Optional
from toot.cli import TUI_COLORS, VISIBILITY_CHOICES, Context, cli, pass_context
from toot.cli.validators import validate_tui_colors
from toot.cli import TUI_COLORS, VISIBILITY_CHOICES, IMAGE_FORMAT_CHOICES, Context, cli, pass_context
from toot.cli.validators import validate_tui_colors, validate_cache_size
from toot.tui.app import TUI, TuiOptions
COLOR_OPTIONS = ", ".join(TUI_COLORS.keys())
@ -24,6 +24,12 @@ COLOR_OPTIONS = ", ".join(TUI_COLORS.keys())
help=f"""Number of colors to use, one of {COLOR_OPTIONS}, defaults to 16 if
using --color, and 1 if using --no-color."""
)
@click.option(
"-s", "--cache-size",
callback=validate_cache_size,
help="""Specify the image cache maximum size in megabytes. Default: 10MB.
Minimum: 1MB."""
)
@click.option(
"-v", "--default-visibility",
type=click.Choice(VISIBILITY_CHOICES),
@ -34,6 +40,11 @@ COLOR_OPTIONS = ", ".join(TUI_COLORS.keys())
is_flag=True,
help="Expand toots with content warnings automatically"
)
@click.option(
"-f", "--image-format",
type=click.Choice(IMAGE_FORMAT_CHOICES),
help="Image output format; support varies across terminals. Default: block"
)
@pass_context
def tui(
ctx: Context,
@ -41,7 +52,9 @@ def tui(
media_viewer: Optional[str],
always_show_sensitive: bool,
relative_datetimes: bool,
default_visibility: Optional[str]
cache_size: Optional[int],
default_visibility: Optional[str],
image_format: Optional[str]
):
"""Launches the toot terminal user interface"""
if colors is None:
@ -51,8 +64,10 @@ def tui(
colors=colors,
media_viewer=media_viewer,
relative_datetimes=relative_datetimes,
cache_size=cache_size,
default_visibility=default_visibility,
always_show_sensitive=always_show_sensitive,
image_format=image_format,
)
tui = TUI.create(ctx.app, ctx.user, options)
tui.run()

View File

@ -73,3 +73,21 @@ def validate_tui_colors(ctx, param, value) -> Optional[int]:
return TUI_COLORS[value]
raise click.BadParameter(f"Invalid value: {value}. Expected one of: {', '.join(TUI_COLORS)}")
def validate_cache_size(ctx: click.Context, param: str, value: Optional[str]) -> Optional[int]:
"""validates the cache size parameter"""
if value is None:
return 1024 * 1024 * 10 # default 10MB
else:
if value.isdigit():
size = int(value)
else:
raise click.BadParameter("Cache size must be numeric.")
if size > 1024:
raise click.BadParameter("Cache size too large: 1024MB maximum.")
elif size < 1:
raise click.BadParameter("Cache size too small: 1MB minimum.")
return size

View File

@ -2,6 +2,7 @@ import logging
import subprocess
import urwid
from concurrent.futures import ThreadPoolExecutor
from typing import NamedTuple, Optional
from datetime import datetime, timezone
@ -15,11 +16,12 @@ from toot.utils.datetime import parse_datetime
from .compose import StatusComposer
from .constants import PALETTE
from .entities import Status
from .images import TuiScreen, load_image
from .overlays import ExceptionStackTrace, GotoMenu, Help, StatusSource, StatusLinks, StatusZoom
from .overlays import StatusDeleteConfirmation, Account
from .poll import Poll
from .timeline import Timeline
from .utils import get_max_toot_chars, parse_content_links, copy_to_clipboard
from .utils import get_max_toot_chars, parse_content_links, copy_to_clipboard, LRUCache
from .widgets import ModalBox, RoundedLineBox
logger = logging.getLogger(__name__)
@ -35,7 +37,9 @@ class TuiOptions(NamedTuple):
media_viewer: Optional[str]
always_show_sensitive: bool
relative_datetimes: bool
cache_size: int
default_visibility: Optional[str]
image_format: Optional[str]
class Header(urwid.WidgetWrap):
@ -95,7 +99,7 @@ class TUI(urwid.Frame):
@staticmethod
def create(app: App, user: User, args: TuiOptions):
"""Factory method, sets up TUI and an event loop."""
screen = urwid.raw_display.Screen()
screen = TuiScreen()
screen.set_terminal_properties(args.colors)
tui = TUI(app, user, screen, args)
@ -144,6 +148,11 @@ class TUI(urwid.Frame):
self.followed_accounts = []
self.preferences = {}
if self.options.cache_size:
self.cache_max = 1024 * 1024 * self.options.cache_size
else:
self.cache_max = 1024 * 1024 * 10 # default 10MB
super().__init__(self.body, header=self.header, footer=self.footer)
def run(self):
@ -648,7 +657,7 @@ class TUI(urwid.Frame):
account = api.whois(self.app, self.user, account_id)
relationship = api.get_relationship(self.app, self.user, account_id)
self.open_overlay(
widget=Account(self.app, self.user, account, relationship),
widget=Account(self.app, self.user, account, relationship, self.options),
title="Account",
)
@ -757,6 +766,27 @@ class TUI(urwid.Frame):
return self.run_in_thread(_delete, done_callback=_done)
def async_load_image(self, timeline, status, path, placeholder_index):
def _load():
# don't bother loading images for statuses we are not viewing now
if timeline.get_focused_status().id != status.id:
return
if not hasattr(timeline, "images"):
timeline.images = LRUCache(cache_max_bytes=self.cache_max)
img = load_image(path)
if img:
timeline.images[str(hash(path))] = img
def _done(loop):
# don't bother loading images for statuses we are not viewing now
if timeline.get_focused_status().id != status.id:
return
timeline.update_status_image(status, path, placeholder_index)
return self.run_in_thread(_load, done_callback=_done)
def copy_status(self, status):
# TODO: copy a better version of status content
# including URLs

104
toot/tui/images.py Normal file
View File

@ -0,0 +1,104 @@
import urwid
import math
import requests
import warnings
# If term_image is loaded use their screen implementation which handles images
try:
from term_image.widget import UrwidImageScreen, UrwidImage
from term_image.image import BaseImage, KittyImage, ITerm2Image, BlockImage
from term_image import disable_queries # prevent phantom keystrokes
from PIL import Image, ImageDraw
TuiScreen = UrwidImageScreen
disable_queries()
def image_support_enabled():
return True
def can_render_pixels(image_format):
return image_format in ['kitty', 'iterm']
def get_base_image(image, image_format) -> BaseImage:
# we don't autodetect kitty, iterm; we choose based on option switches
BaseImage.forced_support = True
if image_format == 'kitty':
return KittyImage(image)
elif image_format == 'iterm':
return ITerm2Image(image)
else:
return BlockImage(image)
def resize_image(basewidth: int, baseheight: int, img: Image.Image) -> Image.Image:
if baseheight and not basewidth:
hpercent = baseheight / float(img.size[1])
width = math.ceil(img.size[0] * hpercent)
img = img.resize((width, baseheight), Image.Resampling.LANCZOS)
elif basewidth and not baseheight:
wpercent = (basewidth / float(img.size[0]))
hsize = int((float(img.size[1]) * float(wpercent)))
img = img.resize((basewidth, hsize), Image.Resampling.LANCZOS)
else:
img = img.resize((basewidth, baseheight), Image.Resampling.LANCZOS)
if img.mode != 'P':
img = img.convert('RGB')
return img
def add_corners(img, rad):
circle = Image.new('L', (rad * 2, rad * 2), 0)
draw = ImageDraw.Draw(circle)
draw.ellipse((0, 0, rad * 2, rad * 2), fill=255)
alpha = Image.new('L', img.size, "white")
w, h = img.size
alpha.paste(circle.crop((0, 0, rad, rad)), (0, 0))
alpha.paste(circle.crop((0, rad, rad, rad * 2)), (0, h - rad))
alpha.paste(circle.crop((rad, 0, rad * 2, rad)), (w - rad, 0))
alpha.paste(circle.crop((rad, rad, rad * 2, rad * 2)), (w - rad, h - rad))
img.putalpha(alpha)
return img
def load_image(url):
with warnings.catch_warnings():
warnings.simplefilter("ignore") # suppress "corrupt exif" output from PIL
try:
img = Image.open(requests.get(url, stream=True).raw)
if img.format == 'PNG' and img.mode != 'RGBA':
img = img.convert("RGBA")
return img
except Exception:
return None
def graphics_widget(img, image_format="block", corner_radius=0) -> urwid.Widget:
if not img:
return urwid.SolidFill(fill_char=" ")
if can_render_pixels(image_format) and corner_radius > 0:
render_img = add_corners(img, 10)
else:
render_img = img
return UrwidImage(get_base_image(render_img, image_format), '<', upscale=True)
# "<" means left-justify the image
except ImportError:
from urwid.raw_display import Screen
TuiScreen = Screen
def image_support_enabled():
return False
def can_render_pixels(image_format: str):
return False
def get_base_image(image, image_format: str):
return None
def add_corners(img, rad):
return None
def load_image(url):
return None
def graphics_widget(img, image_format="block", corner_radius=0) -> urwid.Widget:
return urwid.SolidFill(fill_char=" ")

View File

@ -5,7 +5,9 @@ import webbrowser
from toot import __version__
from toot import api
from toot.tui.utils import highlight_keys
from toot.tui.images import image_support_enabled, load_image, graphics_widget
from toot.tui.widgets import Button, EditBox, SelectableText
from toot.tui.richtext import html_to_widgets
@ -242,11 +244,12 @@ class Help(urwid.Padding):
class Account(urwid.ListBox):
"""Shows account data and provides various actions"""
def __init__(self, app, user, account, relationship):
def __init__(self, app, user, account, relationship, options):
self.app = app
self.user = user
self.account = account
self.relationship = relationship
self.options = options
self.last_action = None
self.setup_listbox()
@ -255,6 +258,30 @@ class Account(urwid.ListBox):
walker = urwid.SimpleListWalker(actions)
super().__init__(walker)
def account_header(self, account):
if image_support_enabled() and account['avatar'] and not account["avatar"].endswith("missing.png"):
img = load_image(account['avatar'])
aimg = urwid.BoxAdapter(
graphics_widget(img, image_format=self.options.image_format, corner_radius=10), 10)
else:
aimg = urwid.BoxAdapter(urwid.SolidFill(" "), 10)
if image_support_enabled() and account['header'] and not account["header"].endswith("missing.png"):
img = load_image(account['header'])
himg = (urwid.BoxAdapter(
graphics_widget(img, image_format=self.options.image_format, corner_radius=10), 10))
else:
himg = urwid.BoxAdapter(urwid.SolidFill(" "), 10)
atxt = urwid.Pile([urwid.Divider(),
(urwid.Text(("account", account["display_name"]))),
(urwid.Text(("highlight", "@" + self.account['acct'])))])
columns = urwid.Columns([aimg, ("weight", 9999, himg)], dividechars=2, min_width=20)
header = urwid.Pile([columns, urwid.Divider(), atxt])
return header
def generate_contents(self, account, relationship=None, last_action=None):
if self.last_action and not self.last_action.startswith("Confirm"):
yield Button(f"Confirm {self.last_action}", on_press=take_action, user_data=self)
@ -276,11 +303,11 @@ class Account(urwid.ListBox):
yield urwid.Divider("")
yield urwid.Divider()
yield urwid.Text([("account", f"@{account['acct']}"), f" {account['display_name']}"])
yield self.account_header(account)
if account["note"]:
yield urwid.Divider()
widgetlist = html_to_widgets(account["note"])
for line in widgetlist:
yield (line)

View File

@ -1,26 +1,33 @@
import logging
import math
import urwid
import webbrowser
from typing import List, Optional
from toot.tui import app
from toot.tui.richtext import html_to_widgets, url_to_widget
from toot.utils.datetime import parse_datetime, time_ago
from toot.utils.language import language_name
from toot.entities import Status
from toot.tui.scroll import Scrollable, ScrollBar
from toot.tui.utils import highlight_keys
from toot.tui.images import image_support_enabled, graphics_widget, can_render_pixels
from toot.tui.widgets import SelectableText, SelectableColumns, RoundedLineBox
logger = logging.getLogger("toot")
screen = urwid.raw_display.Screen()
class Timeline(urwid.Columns):
"""
Displays a list of statuses to the left, and status details on the right.
"""
signals = [
"close", # Close thread
"focus", # Focus changed
@ -41,6 +48,7 @@ class Timeline(urwid.Columns):
self.is_thread = is_thread
self.statuses = statuses
self.status_list = self.build_status_list(statuses, focus=focus)
self.can_render_pixels = can_render_pixels(self.tui.options.image_format)
try:
focused_status = statuses[focus]
@ -141,6 +149,16 @@ class Timeline(urwid.Columns):
def modified(self):
"""Called when the list focus switches to a new status"""
status, index, count = self.get_focused_status_with_counts()
if image_support_enabled:
clear_op = getattr(self.tui.screen, "clear_images", None)
# term-image's screen implementation has clear_images(),
# urwid's implementation does not.
# TODO: it would be nice not to check this each time thru
if callable(clear_op):
self.tui.screen.clear_images()
self.draw_status_details(status)
self._emit("focus")
@ -282,7 +300,7 @@ class Timeline(urwid.Columns):
def get_status_index(self, id):
# TODO: This is suboptimal, consider a better way
for n, status in enumerate(self.statuses):
for n, status in enumerate(self.statuses.copy()):
if status.id == id:
return n
raise ValueError("Status with ID {} not found".format(id))
@ -306,6 +324,27 @@ class Timeline(urwid.Columns):
if index == self.status_list.body.focus:
self.draw_status_details(status)
def update_status_image(self, status, path, placeholder_index):
"""Replace image placeholder with image widget and redraw"""
index = self.get_status_index(status.id)
assert self.statuses[index].id == status.id # Sanity check
# get the image and replace the placeholder with a graphics widget
img = None
if hasattr(self, "images"):
try:
img = self.images[(str(hash(path)))]
except KeyError:
pass
if img:
try:
status.placeholders[placeholder_index]._set_original_widget(
graphics_widget(img, image_format=self.tui.options.image_format, corner_radius=10))
except IndexError:
# ignore IndexErrors.
pass
def remove_status(self, status):
index = self.get_status_index(status.id)
assert self.statuses[index].id == status.id # Sanity check
@ -318,6 +357,9 @@ class Timeline(urwid.Columns):
class StatusDetails(urwid.Pile):
def __init__(self, timeline: Timeline, status: Optional[Status]):
self.status = status
self.timeline = timeline
if self.status:
self.status.placeholders = []
self.followed_accounts = timeline.tui.followed_accounts
self.options = timeline.tui.options
@ -326,17 +368,83 @@ class StatusDetails(urwid.Pile):
if status else ())
return super().__init__(widget_list)
def image_widget(self, path, rows=None, aspect=None) -> urwid.Widget:
"""Returns a widget capable of displaying the image
path is required; URL to image
rows, if specfied, sets a fixed number of rows. Or:
aspect, if specified, calculates rows based on pane width
and the aspect ratio provided"""
if not rows:
if not aspect:
aspect = 3 / 2 # reasonable default
screen_rows = screen.get_cols_rows()[1]
if self.timeline.can_render_pixels:
# for pixel-rendered images,
# image rows should be 33% of the available screen
# but in no case fewer than 10
rows = max(10, math.floor(screen_rows * .33))
else:
# for cell-rendered images,
# use the max available columns
# and calculate rows based on the image
# aspect ratio
cols = math.floor(0.55 * screen.get_cols_rows()[0])
rows = math.ceil((cols / 2) / aspect)
# if the calculated rows are more than will
# fit on one screen, reduce to one screen of rows
rows = min(screen_rows - 6, rows)
# but in no case fewer than 10 rows
rows = max(rows, 10)
img = None
if hasattr(self.timeline, "images"):
try:
img = self.timeline.images[(str(hash(path)))]
except KeyError:
pass
if img:
return (urwid.BoxAdapter(
graphics_widget(img, image_format=self.timeline.tui.options.image_format, corner_radius=10), rows))
else:
placeholder = urwid.BoxAdapter(urwid.SolidFill(fill_char=" "), rows)
self.status.placeholders.append(placeholder)
if image_support_enabled():
self.timeline.tui.async_load_image(self.timeline, self.status, path, len(self.status.placeholders) - 1)
return placeholder
def author_header(self, reblogged_by):
avatar_url = self.status.original.data["account"]["avatar"]
if avatar_url and image_support_enabled():
aimg = self.image_widget(avatar_url, 2)
account_color = ("highlight" if self.status.original.author.account in
self.timeline.tui.followed_accounts else "account")
atxt = urwid.Pile([("pack", urwid.Text(("bold", self.status.original.author.display_name))),
("pack", urwid.Text((account_color, self.status.original.author.account)))])
if image_support_enabled():
columns = urwid.Columns([aimg, ("weight", 9999, atxt)], dividechars=1, min_width=5)
else:
columns = urwid.Columns([("weight", 9999, atxt)], dividechars=1, min_width=5)
return columns
def content_generator(self, status, reblogged_by):
if reblogged_by:
text = "{} boosted".format(reblogged_by.display_name or reblogged_by.username)
yield ("pack", urwid.Text(("dim", text)))
reblogger_name = (reblogged_by.display_name
if reblogged_by.display_name
else reblogged_by.username)
text = f"{reblogger_name} boosted"
yield urwid.Text(("dim", text))
yield ("pack", urwid.AttrMap(urwid.Divider("-"), "dim"))
if status.author.display_name:
yield ("pack", urwid.Text(("bold", status.author.display_name)))
account_color = "highlight" if status.author.account in self.followed_accounts else "account"
yield ("pack", urwid.Text((account_color, status.author.account)))
yield self.author_header(reblogged_by)
yield ("pack", urwid.Divider())
if status.data["spoiler_text"]:
@ -363,7 +471,27 @@ class StatusDetails(urwid.Pile):
yield ("pack", urwid.Text([("bold", "Media attachment"), " (", m["type"], ")"]))
if m["description"]:
yield ("pack", urwid.Text(m["description"]))
yield ("pack", url_to_widget(m["url"]))
if m["url"]:
if m["url"].lower().endswith(('.jpg', '.jpeg', '.png', '.gif', '.svg', '.webp')):
yield urwid.Text("")
try:
aspect = float(m["meta"]["original"]["aspect"])
except Exception:
aspect = None
if image_support_enabled():
yield self.image_widget(m["url"], aspect=aspect)
yield urwid.Divider()
# video media may include a preview URL, show that as a fallback
elif m["preview_url"].lower().endswith(('.jpg', '.jpeg', '.png', '.gif', '.svg', '.webp')):
yield urwid.Text("")
try:
aspect = float(m["meta"]["small"]["aspect"])
except Exception:
aspect = None
if image_support_enabled():
yield self.image_widget(m["preview_url"], aspect=aspect)
yield urwid.Divider()
yield ("pack", url_to_widget(m["url"]))
poll = status.original.data.get("poll")
if poll:
@ -427,6 +555,15 @@ class StatusDetails(urwid.Pile):
yield urwid.Text("")
yield url_to_widget(card["url"])
if card["image"] and image_support_enabled():
if card["image"].lower().endswith(('.jpg', '.jpeg', '.png', '.gif', '.svg', '.webp')):
yield urwid.Text("")
try:
aspect = int(card["width"]) / int(card["height"])
except Exception:
aspect = None
yield self.image_widget(card["image"], aspect=aspect)
def poll_generator(self, poll):
for idx, option in enumerate(poll["options"]):
perc = (round(100 * option["votes_count"] / poll["votes_count"])

View File

@ -1,7 +1,8 @@
import base64
import re
import sys
import urwid
from collections import OrderedDict
from functools import reduce
from html.parser import HTMLParser
from typing import List
@ -109,3 +110,33 @@ def deep_get(adict: dict, path: List[str], default=None):
path,
adict
)
class LRUCache(OrderedDict):
"""Dict with a limited size, ejecting LRUs as needed.
Default max size = 10Mb"""
def __init__(self, *args, cache_max_bytes: int = 1024 * 1024 * 10, **kwargs):
assert cache_max_bytes > 0
self.total_value_size = 0
self.cache_max_bytes = cache_max_bytes
super().__init__(*args, **kwargs)
def __setitem__(self, key: str, value):
if key in self:
self.total_value_size -= sys.getsizeof(super().__getitem__(key).tobytes())
self.total_value_size += sys.getsizeof(value.tobytes())
super().__setitem__(key, value)
super().move_to_end(key)
while self.total_value_size > self.cache_max_bytes:
old_key, value = next(iter(self.items()))
sz = sys.getsizeof(value.tobytes())
super().__delitem__(old_key)
self.total_value_size -= sz
def __getitem__(self, key: str):
val = super().__getitem__(key)
super().move_to_end(key)
return val