mirror of
https://github.com/ihabunek/toot
synced 2024-12-22 15:06:05 +01:00
Improve types
This commit is contained in:
parent
2ba90fc2d2
commit
927fdc3026
@ -14,12 +14,18 @@ import typing as t
|
||||
from dataclasses import dataclass, is_dataclass
|
||||
from datetime import date, datetime
|
||||
from functools import lru_cache
|
||||
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
|
||||
from typing import Any, Dict, NamedTuple, Optional, Type, TypeVar, Union
|
||||
from typing import get_args, get_origin, get_type_hints
|
||||
|
||||
from toot.utils import get_text
|
||||
from toot.utils.datetime import parse_datetime
|
||||
|
||||
# Generic data class instance
|
||||
T = TypeVar("T")
|
||||
|
||||
# A dict decoded from JSON
|
||||
Data = Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AccountField:
|
||||
@ -76,7 +82,7 @@ class Account:
|
||||
source: Optional[dict]
|
||||
|
||||
@staticmethod
|
||||
def __toot_prepare__(obj: Dict) -> Dict:
|
||||
def __toot_prepare__(obj: Data) -> Data:
|
||||
# Pleroma has not yet converted last_status_at from datetime to date
|
||||
# so trim it here so it doesn't break when converting to date.
|
||||
# See: https://git.pleroma.social/pleroma/pleroma/-/issues/1470
|
||||
@ -266,7 +272,7 @@ class Status:
|
||||
return self.reblog or self
|
||||
|
||||
@staticmethod
|
||||
def __toot_prepare__(obj: Dict) -> Dict:
|
||||
def __toot_prepare__(obj: Data) -> Data:
|
||||
# Pleroma has a bug where created_at is set to an empty string.
|
||||
# To avoid marking created_at as optional, which would require work
|
||||
# because we count on it always existing, set it to current datetime.
|
||||
@ -457,27 +463,25 @@ class List:
|
||||
# see: https://git.pleroma.social/pleroma/pleroma/-/issues/2918
|
||||
replies_policy: Optional[str]
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Generic data class instance
|
||||
T = TypeVar("T")
|
||||
|
||||
class Field(NamedTuple):
|
||||
name: str
|
||||
type: Any
|
||||
default: Any
|
||||
|
||||
|
||||
class ConversionError(Exception):
|
||||
"""Raised when conversion fails from JSON value to data class field."""
|
||||
def __init__(
|
||||
self,
|
||||
data_class: Type,
|
||||
field_name: str,
|
||||
field_type: Type,
|
||||
field_value: Optional[str]
|
||||
):
|
||||
def __init__(self, data_class: type, field: Field, field_value: Optional[str]):
|
||||
super().__init__(
|
||||
f"Failed converting field `{data_class.__name__}.{field_name}` "
|
||||
+ f"of type `{field_type.__name__}` from value {field_value!r}"
|
||||
f"Failed converting field `{data_class.__name__}.{field.name}` "
|
||||
+ f"of type `{field.type.__name__}` from value {field_value!r}"
|
||||
)
|
||||
|
||||
|
||||
def from_dict(cls: Type[T], data: Dict) -> T:
|
||||
def from_dict(cls: Type[T], data: Data) -> T:
|
||||
"""Convert a nested dict into an instance of `cls`."""
|
||||
# Apply __toot_prepare__ if it exists
|
||||
prepare = getattr(cls, '__toot_prepare__', None)
|
||||
@ -485,19 +489,19 @@ def from_dict(cls: Type[T], data: Dict) -> T:
|
||||
data = prepare(data)
|
||||
|
||||
def _fields():
|
||||
for name, type, default in get_fields(cls):
|
||||
value = data.get(name, default)
|
||||
converted = _convert_with_error_handling(cls, name, type, value)
|
||||
yield name, converted
|
||||
for field in _get_fields(cls):
|
||||
value = data.get(field.name, field.default)
|
||||
converted = _convert_with_error_handling(cls, field, value)
|
||||
yield field.name, converted
|
||||
|
||||
return cls(**dict(_fields()))
|
||||
|
||||
|
||||
@lru_cache(maxsize=100)
|
||||
def get_fields(cls: Type) -> t.List[Tuple[str, Type, Any]]:
|
||||
@lru_cache
|
||||
def _get_fields(cls: type) -> t.List[Field]:
|
||||
hints = get_type_hints(cls)
|
||||
return [
|
||||
(
|
||||
Field(
|
||||
field.name,
|
||||
_prune_optional(hints[field.name]),
|
||||
_get_default_value(field)
|
||||
@ -506,11 +510,11 @@ def get_fields(cls: Type) -> t.List[Tuple[str, Type, Any]]:
|
||||
]
|
||||
|
||||
|
||||
def from_dict_list(cls: Type[T], data: t.List[Dict]) -> t.List[T]:
|
||||
def from_dict_list(cls: Type[T], data: t.List[Data]) -> t.List[T]:
|
||||
return [from_dict(cls, x) for x in data]
|
||||
|
||||
|
||||
def _get_default_value(field):
|
||||
def _get_default_value(field: dataclasses.Field[Any]):
|
||||
if field.default is not dataclasses.MISSING:
|
||||
return field.default
|
||||
|
||||
@ -520,21 +524,16 @@ def _get_default_value(field):
|
||||
return None
|
||||
|
||||
|
||||
def _convert_with_error_handling(
|
||||
data_class: Type,
|
||||
field_name: str,
|
||||
field_type: Type,
|
||||
field_value: Optional[str]
|
||||
):
|
||||
def _convert_with_error_handling(data_class: type, field: Field, field_value: Any) -> Any:
|
||||
try:
|
||||
return _convert(field_type, field_value)
|
||||
return _convert(field.type, field_value)
|
||||
except ConversionError:
|
||||
raise
|
||||
except Exception:
|
||||
raise ConversionError(data_class, field_name, field_type, field_value)
|
||||
raise ConversionError(data_class, field, field_value)
|
||||
|
||||
|
||||
def _convert(field_type, value):
|
||||
def _convert(field_type: Any, value: Any) -> Any:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
@ -557,7 +556,7 @@ def _convert(field_type, value):
|
||||
raise ValueError(f"Not implemented for type '{field_type}'")
|
||||
|
||||
|
||||
def _prune_optional(field_type: Type) -> Type:
|
||||
def _prune_optional(field_type: type) -> type:
|
||||
"""For `Optional[<type>]` returns the encapsulated `<type>`."""
|
||||
if get_origin(field_type) == Union:
|
||||
args = get_args(field_type)
|
||||
|
Loading…
Reference in New Issue
Block a user