add type hinting

This commit is contained in:
odysseusmax 2021-11-01 09:32:46 +05:30
parent 34811800f7
commit b5e876f393
16 changed files with 145 additions and 85 deletions

View File

@ -1,16 +1,42 @@
import logging
from typing import List
from aiohttp import web
from aiohttp.web_routedef import RouteDef
from telethon.tl.types import Channel, Chat, User
from .config import index_settings
from .views import Views
log = logging.getLogger(__name__)
async def setup_routes(app, handler):
h = handler
client = h.client
def get_common_routes(handler: Views, alias_id: str) -> List[RouteDef]:
p = "/{chat:" + alias_id + "}"
return [
web.get(p, handler.index, name=f"index_{alias_id}"),
web.get(p + r"/logo", handler.logo, name=f"logo_{alias_id}"),
web.get(p + r"/{id:\d+}/view", handler.info, name=f"info_{alias_id}"),
web.get(
p + r"/{id:\d+}/thumbnail",
handler.thumbnail_get,
name=f"thumbnail_get_{alias_id}",
),
web.get(
p + r"/{id:\d+}/{filename}",
handler.download_get,
name=f"download_get_{alias_id}",
),
web.head(
p + r"/{id:\d+}/{filename}",
handler.download_head,
name=f"download_head_{alias_id}",
),
]
async def setup_routes(app: web.Application, handler: Views):
client = handler.client
index_all = index_settings["index_all"]
index_private = index_settings["index_private"]
index_group = index_settings["index_group"]
@ -18,36 +44,13 @@ async def setup_routes(app, handler):
exclude_chats = index_settings["exclude_chats"]
include_chats = index_settings["include_chats"]
routes = [
web.get("/", h.home, name="home"),
web.get("/login", h.login_get, name="login_page"),
web.post("/login", h.login_post, name="login_handle"),
web.get("/logout", h.logout_get, name="logout"),
web.get("/favicon.ico", h.faviconicon, name="favicon"),
web.get("/", handler.home, name="home"),
web.get("/login", handler.login_get, name="login_page"),
web.post("/login", handler.login_post, name="login_handle"),
web.get("/logout", handler.logout_get, name="logout"),
web.get("/favicon.ico", handler.faviconicon, name="favicon"),
]
def get_common_routes(alias_id):
p = "/{chat:" + alias_id + "}"
return [
web.get(p, h.index, name=f"index_{alias_id}"),
web.get(p + r"/logo", h.logo, name=f"logo_{alias_id}"),
web.get(p + r"/{id:\d+}/view", h.info, name=f"info_{alias_id}"),
web.get(
p + r"/{id:\d+}/thumbnail",
h.thumbnail_get,
name=f"thumbnail_get_{alias_id}",
),
web.get(
p + r"/{id:\d+}/{filename}",
h.download_get,
name=f"download_get_{alias_id}",
),
web.head(
p + r"/{id:\d+}/{filename}",
h.download_head,
name=f"download_head_{alias_id}",
),
]
if index_all:
# print(await client.get_dialogs())
# dialogs = await client.get_dialogs()
@ -69,17 +72,17 @@ async def setup_routes(app, handler):
log.debug(f"{chat.title}, group: {index_group}")
continue
alias_id = h.generate_alias_id(chat)
routes.extend(get_common_routes(alias_id))
alias_id = handler.generate_alias_id(chat)
routes.extend(get_common_routes(handler, alias_id))
log.debug(f"Index added for {chat.id} at /{alias_id}")
else:
for chat_id in include_chats:
chat = await client.get_entity(chat_id)
alias_id = h.generate_alias_id(chat)
alias_id = handler.generate_alias_id(chat)
routes.extend(
get_common_routes(alias_id)
get_common_routes(handler, alias_id)
) # returns list() of common routes
log.debug(f"Index added for {chat.id} at /{alias_id}")
routes.append(web.view(r"/{wildcard:.*}", h.wildcard, name="wildcard"))
routes.append(web.view(r"/{wildcard:.*}", handler.wildcard, name="wildcard"))
app.add_routes(routes)

View File

@ -7,7 +7,7 @@ from telethon.sessions import StringSession
class Client(TelegramClient):
def __init__(self, session_string, *args, **kwargs):
def __init__(self, session_string: str, *args, **kwargs):
super().__init__(StringSession(session_string), *args, **kwargs)
self.log = logging.getLogger(__name__)
@ -25,7 +25,7 @@ class Client(TelegramClient):
first_part = {first_part}, cut = {first_part_cut}(length={part_size-first_part_cut}),
last_part = {last_part}, cut = {last_part_cut}(length={last_part_cut}),
parts_count = {part_count}
"""
"""
)
try:
async for chunk in self.iter_download(
@ -42,7 +42,7 @@ class Client(TelegramClient):
part += 1
self.log.debug(f"serving finished")
self.log.debug("serving finished")
except (GeneratorExit, StopAsyncIteration, asyncio.CancelledError):
self.log.debug("file serve interrupted")
raise

View File

@ -1,7 +1,10 @@
from typing import Union
from urllib.parse import quote
from telethon.tl.custom import Message
def get_file_name(message, quote_name=True):
def get_file_name(message: Message, quote_name: bool = True) -> str:
if message.file.name:
name = message.file.name
else:
@ -10,7 +13,7 @@ def get_file_name(message, quote_name=True):
return quote(name) if quote_name else name
def get_human_size(num):
def get_human_size(num: Union[int, float]) -> str:
base = 1024.0
sufix_list = ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"]
for unit in sufix_list:

View File

@ -1,8 +1,11 @@
import base64
import hashlib
from typing import Dict, Union
from telethon.tl.types import Chat, User, Channel
from ..config import SHORT_URL_LEN
from ..telegram import Client
from .home_view import HomeView
from .wildcard_view import WildcardView
from .download import Download
@ -16,6 +19,9 @@ from .faviconicon_view import FaviconIconView
from .middlewhere import middleware_factory
TELEGRAM_CHAT = Union[Chat, User, Channel]
class Views(
HomeView,
Download,
@ -28,12 +34,12 @@ class Views(
LogoutView,
FaviconIconView,
):
def __init__(self, client):
def __init__(self, client: Client):
self.client = client
self.url_len = SHORT_URL_LEN
self.chat_ids = {}
self.chat_ids: Dict[str, Dict[str, str]] = {}
def generate_alias_id(self, chat):
def generate_alias_id(self, chat: TELEGRAM_CHAT) -> str:
chat_id = chat.id
title = chat.title

14
app/views/base.py Normal file
View File

@ -0,0 +1,14 @@
from typing import Dict, Union
from telethon.tl.types import Chat, User, Channel
from ..telegram import Client
TELEGRAM_CHAT = Union[Chat, User, Channel]
class BaseView:
client: Client
url_len: int
chat_ids: Dict[str, Dict[str, str]]

View File

@ -1,22 +1,26 @@
import logging
from aiohttp import web
from telethon.tl.custom import Message
from app.util import get_file_name
from app.config import block_downloads
from .base import BaseView
log = logging.getLogger(__name__)
class Download:
async def download_get(self, req):
class Download(BaseView):
async def download_get(self, req: web.Request) -> web.Response:
return await self.handle_request(req)
async def download_head(self, req):
async def download_head(self, req: web.Request) -> web.Response:
return await self.handle_request(req, head=True)
async def handle_request(self, req, head=False):
async def handle_request(
self, req: web.Request, head: bool = False
) -> web.Response:
if block_downloads:
return web.Response(status=403, text="403: Forbiden" if not head else None)
@ -26,7 +30,9 @@ class Download:
chat_id = chat["chat_id"]
try:
message = await self.client.get_messages(entity=chat_id, ids=file_id)
message: Message = await self.client.get_messages(
entity=chat_id, ids=file_id
)
except Exception:
log.debug(f"Error in getting message {file_id} in {chat_id}", exc_info=True)
message = None

View File

@ -4,10 +4,11 @@ from PIL import Image, ImageDraw, ImageFont
from aiohttp import web
from app.config import logo_folder
from .base import BaseView
class FaviconIconView:
async def faviconicon(self, req):
class FaviconIconView(BaseView):
async def faviconicon(self, req: web.Request) -> web.Response:
favicon_path = logo_folder.joinpath("favicon.ico")
text = "T"
if not favicon_path.exists():
@ -15,10 +16,10 @@ class FaviconIconView:
color = tuple((random.randint(0, 255) for _ in range(3)))
im = Image.new("RGB", (W, H), color)
draw = ImageDraw.Draw(im)
font = ImageFont.truetype("arial.ttf", 50)
font = ImageFont.truetype("arial.ttf", 100)
w, h = draw.textsize(text, font=font)
draw.text(((W - w) / 2, (H - h) / 2), text, fill="white", font=font)
im.save(favicon_path)
im.save(favicon_path, "JPEG")
with open(favicon_path, "rb") as fp:
body = fp.read()

View File

@ -1,11 +1,12 @@
from aiohttp import web
import aiohttp_jinja2
from .base import BaseView
class HomeView:
class HomeView(BaseView):
@aiohttp_jinja2.template("home.html")
async def home(self, req):
print(self.chat_ids)
async def home(self, req: web.Request) -> web.Response:
if len(self.chat_ids) == 1:
(chat,) = self.chat_ids.values()
return web.HTTPFound(f"{chat['alias_id']}")

View File

@ -1,19 +1,22 @@
import logging
from typing import List
from urllib.parse import quote
import aiohttp_jinja2
from telethon.tl import types
from aiohttp import web
from telethon.tl import types, custom
from app.config import results_per_page, block_downloads
from app.util import get_file_name, get_human_size
from .base import BaseView
log = logging.getLogger(__name__)
class IndexView:
class IndexView(BaseView):
@aiohttp_jinja2.template("index.html")
async def index(self, req):
async def index(self, req: web.Request) -> web.Response:
alias_id = req.match_info["chat"]
chat = self.chat_ids[alias_id]
log_msg = ""
@ -39,7 +42,9 @@ class IndexView:
if search_query:
kwargs.update({"search": search_query})
messages = (await self.client.get_messages(**kwargs)) or []
messages: List[custom.Message] = (
await self.client.get_messages(**kwargs)
) or []
except Exception:
log.debug("failed to get messages", exc_info=True)

View File

@ -2,20 +2,22 @@ import logging
from urllib.parse import unquote
import aiohttp_jinja2
from aiohttp import web
from telethon.tl import types
from telethon.tl.custom import Message
from jinja2 import Markup
from app.util import get_file_name, get_human_size
from app.config import block_downloads
from .base import BaseView
log = logging.getLogger(__name__)
class InfoView:
class InfoView(BaseView):
@aiohttp_jinja2.template("info.html")
async def info(self, req):
async def info(self, req: web.Request) -> web.Response:
file_id = int(req.match_info["id"])
alias_id = req.match_info["chat"]
chat = self.chat_ids[alias_id]

View File

@ -3,14 +3,15 @@ import time
from aiohttp import web
import aiohttp_jinja2
from aiohttp_session import new_session
from .base import BaseView
class LoginView:
class LoginView(BaseView):
@aiohttp_jinja2.template("login.html")
async def login_get(self, req):
async def login_get(self, req: web.Request) -> web.Response:
return dict(authenticated=False, **req.query)
async def login_post(self, req):
async def login_post(self, req: web.Request) -> web.Response:
post_data = await req.post()
redirect_to = post_data.get("redirect_to") or "/"
location = req.app.router["login_page"].url_for()

View File

@ -5,22 +5,29 @@ import random
from aiohttp import web
from telethon.tl import types
from app.config import logo_folder
from .base import BaseView
log = logging.getLogger(__name__)
class LogoView:
async def logo(self, req):
class LogoView(BaseView):
async def logo(self, req: web.Request) -> web.Response:
alias_id = req.match_info["chat"]
chat = self.chat_ids[alias_id]
chat_id = chat["chat_id"]
chat_name = " ".join(map(lambda x: x[0].upper(), chat["title"].split(" ")))
chat_name = " ".join(
map(lambda x: x[0].upper(), (chat["title"] or "_").split(" "))
)
logo_path = logo_folder.joinpath(f"{alias_id}.jpg")
if not logo_path.exists():
try:
(photo,) = await self.client.get_profile_photos(chat_id, limit=1)
photo: types.Photo = (
await self.client.get_profile_photos(chat_id, limit=1)
)[0]
except Exception:
log.debug(
f"Error in getting profile picture in {chat_id}", exc_info=True
@ -40,7 +47,7 @@ class LogoView:
im.save(logo_path)
else:
pos = -1 if req.query.get("big", None) else int(len(photo.sizes) / 2)
size = self.client._get_thumb(photo.sizes, pos)
size: types.PhotoSize = self.client._get_thumb(photo.sizes, pos)
if isinstance(size, (types.PhotoCachedSize, types.PhotoStrippedSize)):
await self.client._download_cached_photo_size(size, logo_path)
else:

View File

@ -1,9 +1,11 @@
from aiohttp_session import get_session
from aiohttp import web
from .base import BaseView
class LogoutView:
async def logout_get(self, req):
class LogoutView(BaseView):
async def logout_get(self, req: web.Request) -> web.Response:
session = await get_session(req)
session["logged_in"] = False

View File

@ -1,7 +1,8 @@
import time
import logging
from typing import Coroutine, Union
from aiohttp.web import middleware, HTTPFound, Response
from aiohttp.web import middleware, HTTPFound, Response, Request
from aiohttp import BasicAuth, hdrs
from aiohttp_session import get_session
@ -9,7 +10,7 @@ from aiohttp_session import get_session
log = logging.getLogger(__name__)
def _do_basic_auth_check(request):
def _do_basic_auth_check(request: Request) -> Union[None, bool]:
if "download_" not in request.match_info.route.name:
return
@ -47,7 +48,7 @@ def _do_basic_auth_check(request):
return True
async def _do_cookies_auth_check(request):
async def _do_cookies_auth_check(request: Request) -> Union[None, bool]:
session = await get_session(request)
if not session.get("logged_in", False):
return
@ -56,9 +57,9 @@ async def _do_cookies_auth_check(request):
return True
def middleware_factory():
def middleware_factory() -> Coroutine:
@middleware
async def factory(request, handler):
async def factory(request: Request, handler: Coroutine) -> Response:
if request.app["is_authenticated"] and str(request.rel_url.path) not in [
"/login",
"/logout",

View File

@ -4,20 +4,24 @@ import random
import io
from aiohttp import web
from telethon.tl import types
from telethon.tl import types, custom
from .base import BaseView
log = logging.getLogger(__name__)
class ThumbnailView:
async def thumbnail_get(self, req):
class ThumbnailView(BaseView):
async def thumbnail_get(self, req: web.Request) -> web.Response:
file_id = int(req.match_info["id"])
alias_id = req.match_info["chat"]
chat = self.chat_ids[alias_id]
chat_id = chat["chat_id"]
try:
message = await self.client.get_messages(entity=chat_id, ids=file_id)
message: custom.Message = await self.client.get_messages(
entity=chat_id, ids=file_id
)
except Exception:
log.debug(f"Error in getting message {file_id} in {chat_id}", exc_info=True)
message = None
@ -47,7 +51,9 @@ class ThumbnailView:
else:
thumb_pos = int(len(thumbnails) / 2)
try:
thumbnail = self.client._get_thumb(thumbnails, thumb_pos)
thumbnail: types.PhotoSize = self.client._get_thumb(
thumbnails, thumb_pos
)
except Exception as e:
logging.debug(e)
thumbnail = None

View File

@ -1,6 +1,8 @@
from aiohttp import web
from .base import BaseView
class WildcardView:
async def wildcard(self, req):
class WildcardView(BaseView):
async def wildcard(self, req: web.Request) -> web.Response:
return web.HTTPFound("/")