Improve types

This commit is contained in:
Ivan Habunek 2024-04-13 15:30:52 +02:00
parent 2ba90fc2d2
commit 927fdc3026
No known key found for this signature in database
GPG Key ID: F5F0623FF5EBCB3D
1 changed files with 33 additions and 34 deletions

View File

@ -14,12 +14,18 @@ import typing as t
from dataclasses import dataclass, is_dataclass from dataclasses import dataclass, is_dataclass
from datetime import date, datetime from datetime import date, datetime
from functools import lru_cache from functools import lru_cache
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union from typing import Any, Dict, NamedTuple, Optional, Type, TypeVar, Union
from typing import get_args, get_origin, get_type_hints from typing import get_args, get_origin, get_type_hints
from toot.utils import get_text from toot.utils import get_text
from toot.utils.datetime import parse_datetime from toot.utils.datetime import parse_datetime
# Generic data class instance
T = TypeVar("T")
# A dict decoded from JSON
Data = Dict[str, Any]
@dataclass @dataclass
class AccountField: class AccountField:
@ -76,7 +82,7 @@ class Account:
source: Optional[dict] source: Optional[dict]
@staticmethod @staticmethod
def __toot_prepare__(obj: Dict) -> Dict: def __toot_prepare__(obj: Data) -> Data:
# Pleroma has not yet converted last_status_at from datetime to date # Pleroma has not yet converted last_status_at from datetime to date
# so trim it here so it doesn't break when converting to date. # so trim it here so it doesn't break when converting to date.
# See: https://git.pleroma.social/pleroma/pleroma/-/issues/1470 # See: https://git.pleroma.social/pleroma/pleroma/-/issues/1470
@ -266,7 +272,7 @@ class Status:
return self.reblog or self return self.reblog or self
@staticmethod @staticmethod
def __toot_prepare__(obj: Dict) -> Dict: def __toot_prepare__(obj: Data) -> Data:
# Pleroma has a bug where created_at is set to an empty string. # Pleroma has a bug where created_at is set to an empty string.
# To avoid marking created_at as optional, which would require work # To avoid marking created_at as optional, which would require work
# because we count on it always existing, set it to current datetime. # because we count on it always existing, set it to current datetime.
@ -457,27 +463,25 @@ class List:
# see: https://git.pleroma.social/pleroma/pleroma/-/issues/2918 # see: https://git.pleroma.social/pleroma/pleroma/-/issues/2918
replies_policy: Optional[str] replies_policy: Optional[str]
# ------------------------------------------------------------------------------
# Generic data class instance
T = TypeVar("T") class Field(NamedTuple):
name: str
type: Any
default: Any
class ConversionError(Exception): class ConversionError(Exception):
"""Raised when conversion fails from JSON value to data class field.""" """Raised when conversion fails from JSON value to data class field."""
def __init__( def __init__(self, data_class: type, field: Field, field_value: Optional[str]):
self,
data_class: Type,
field_name: str,
field_type: Type,
field_value: Optional[str]
):
super().__init__( super().__init__(
f"Failed converting field `{data_class.__name__}.{field_name}` " f"Failed converting field `{data_class.__name__}.{field.name}` "
+ f"of type `{field_type.__name__}` from value {field_value!r}" + f"of type `{field.type.__name__}` from value {field_value!r}"
) )
def from_dict(cls: Type[T], data: Dict) -> T: def from_dict(cls: Type[T], data: Data) -> T:
"""Convert a nested dict into an instance of `cls`.""" """Convert a nested dict into an instance of `cls`."""
# Apply __toot_prepare__ if it exists # Apply __toot_prepare__ if it exists
prepare = getattr(cls, '__toot_prepare__', None) prepare = getattr(cls, '__toot_prepare__', None)
@ -485,19 +489,19 @@ def from_dict(cls: Type[T], data: Dict) -> T:
data = prepare(data) data = prepare(data)
def _fields(): def _fields():
for name, type, default in get_fields(cls): for field in _get_fields(cls):
value = data.get(name, default) value = data.get(field.name, field.default)
converted = _convert_with_error_handling(cls, name, type, value) converted = _convert_with_error_handling(cls, field, value)
yield name, converted yield field.name, converted
return cls(**dict(_fields())) return cls(**dict(_fields()))
@lru_cache(maxsize=100) @lru_cache
def get_fields(cls: Type) -> t.List[Tuple[str, Type, Any]]: def _get_fields(cls: type) -> t.List[Field]:
hints = get_type_hints(cls) hints = get_type_hints(cls)
return [ return [
( Field(
field.name, field.name,
_prune_optional(hints[field.name]), _prune_optional(hints[field.name]),
_get_default_value(field) _get_default_value(field)
@ -506,11 +510,11 @@ def get_fields(cls: Type) -> t.List[Tuple[str, Type, Any]]:
] ]
def from_dict_list(cls: Type[T], data: t.List[Dict]) -> t.List[T]: def from_dict_list(cls: Type[T], data: t.List[Data]) -> t.List[T]:
return [from_dict(cls, x) for x in data] return [from_dict(cls, x) for x in data]
def _get_default_value(field): def _get_default_value(field: dataclasses.Field[Any]):
if field.default is not dataclasses.MISSING: if field.default is not dataclasses.MISSING:
return field.default return field.default
@ -520,21 +524,16 @@ def _get_default_value(field):
return None return None
def _convert_with_error_handling( def _convert_with_error_handling(data_class: type, field: Field, field_value: Any) -> Any:
data_class: Type,
field_name: str,
field_type: Type,
field_value: Optional[str]
):
try: try:
return _convert(field_type, field_value) return _convert(field.type, field_value)
except ConversionError: except ConversionError:
raise raise
except Exception: except Exception:
raise ConversionError(data_class, field_name, field_type, field_value) raise ConversionError(data_class, field, field_value)
def _convert(field_type, value): def _convert(field_type: Any, value: Any) -> Any:
if value is None: if value is None:
return None return None
@ -557,7 +556,7 @@ def _convert(field_type, value):
raise ValueError(f"Not implemented for type '{field_type}'") raise ValueError(f"Not implemented for type '{field_type}'")
def _prune_optional(field_type: Type) -> Type: def _prune_optional(field_type: type) -> type:
"""For `Optional[<type>]` returns the encapsulated `<type>`.""" """For `Optional[<type>]` returns the encapsulated `<type>`."""
if get_origin(field_type) == Union: if get_origin(field_type) == Union:
args = get_args(field_type) args = get_args(field_type)