diff --git a/toot/entities.py b/toot/entities.py index 3706bac..30d7323 100644 --- a/toot/entities.py +++ b/toot/entities.py @@ -14,12 +14,18 @@ import typing as t from dataclasses import dataclass, is_dataclass from datetime import date, datetime from functools import lru_cache -from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Dict, NamedTuple, Optional, Type, TypeVar, Union from typing import get_args, get_origin, get_type_hints from toot.utils import get_text from toot.utils.datetime import parse_datetime +# Generic data class instance +T = TypeVar("T") + +# A dict decoded from JSON +Data = Dict[str, Any] + @dataclass class AccountField: @@ -76,7 +82,7 @@ class Account: source: Optional[dict] @staticmethod - def __toot_prepare__(obj: Dict) -> Dict: + def __toot_prepare__(obj: Data) -> Data: # Pleroma has not yet converted last_status_at from datetime to date # so trim it here so it doesn't break when converting to date. # See: https://git.pleroma.social/pleroma/pleroma/-/issues/1470 @@ -266,7 +272,7 @@ class Status: return self.reblog or self @staticmethod - def __toot_prepare__(obj: Dict) -> Dict: + def __toot_prepare__(obj: Data) -> Data: # Pleroma has a bug where created_at is set to an empty string. # To avoid marking created_at as optional, which would require work # because we count on it always existing, set it to current datetime. @@ -457,27 +463,25 @@ class List: # see: https://git.pleroma.social/pleroma/pleroma/-/issues/2918 replies_policy: Optional[str] +# ------------------------------------------------------------------------------ -# Generic data class instance -T = TypeVar("T") + +class Field(NamedTuple): + name: str + type: Any + default: Any class ConversionError(Exception): """Raised when conversion fails from JSON value to data class field.""" - def __init__( - self, - data_class: Type, - field_name: str, - field_type: Type, - field_value: Optional[str] - ): + def __init__(self, data_class: type, field: Field, field_value: Optional[str]): super().__init__( - f"Failed converting field `{data_class.__name__}.{field_name}` " - + f"of type `{field_type.__name__}` from value {field_value!r}" + f"Failed converting field `{data_class.__name__}.{field.name}` " + + f"of type `{field.type.__name__}` from value {field_value!r}" ) -def from_dict(cls: Type[T], data: Dict) -> T: +def from_dict(cls: Type[T], data: Data) -> T: """Convert a nested dict into an instance of `cls`.""" # Apply __toot_prepare__ if it exists prepare = getattr(cls, '__toot_prepare__', None) @@ -485,19 +489,19 @@ def from_dict(cls: Type[T], data: Dict) -> T: data = prepare(data) def _fields(): - for name, type, default in get_fields(cls): - value = data.get(name, default) - converted = _convert_with_error_handling(cls, name, type, value) - yield name, converted + for field in _get_fields(cls): + value = data.get(field.name, field.default) + converted = _convert_with_error_handling(cls, field, value) + yield field.name, converted return cls(**dict(_fields())) -@lru_cache(maxsize=100) -def get_fields(cls: Type) -> t.List[Tuple[str, Type, Any]]: +@lru_cache +def _get_fields(cls: type) -> t.List[Field]: hints = get_type_hints(cls) return [ - ( + Field( field.name, _prune_optional(hints[field.name]), _get_default_value(field) @@ -506,11 +510,11 @@ def get_fields(cls: Type) -> t.List[Tuple[str, Type, Any]]: ] -def from_dict_list(cls: Type[T], data: t.List[Dict]) -> t.List[T]: +def from_dict_list(cls: Type[T], data: t.List[Data]) -> t.List[T]: return [from_dict(cls, x) for x in data] -def _get_default_value(field): +def _get_default_value(field: dataclasses.Field[Any]): if field.default is not dataclasses.MISSING: return field.default @@ -520,21 +524,16 @@ def _get_default_value(field): return None -def _convert_with_error_handling( - data_class: Type, - field_name: str, - field_type: Type, - field_value: Optional[str] -): +def _convert_with_error_handling(data_class: type, field: Field, field_value: Any) -> Any: try: - return _convert(field_type, field_value) + return _convert(field.type, field_value) except ConversionError: raise except Exception: - raise ConversionError(data_class, field_name, field_type, field_value) + raise ConversionError(data_class, field, field_value) -def _convert(field_type, value): +def _convert(field_type: Any, value: Any) -> Any: if value is None: return None @@ -557,7 +556,7 @@ def _convert(field_type, value): raise ValueError(f"Not implemented for type '{field_type}'") -def _prune_optional(field_type: Type) -> Type: +def _prune_optional(field_type: type) -> type: """For `Optional[]` returns the encapsulated ``.""" if get_origin(field_type) == Union: args = get_args(field_type)