diff --git a/app/routes.py b/app/routes.py index 971647b..942723a 100644 --- a/app/routes.py +++ b/app/routes.py @@ -1,16 +1,42 @@ import logging +from typing import List from aiohttp import web +from aiohttp.web_routedef import RouteDef from telethon.tl.types import Channel, Chat, User from .config import index_settings +from .views import Views log = logging.getLogger(__name__) -async def setup_routes(app, handler): - h = handler - client = h.client +def get_common_routes(handler: Views, alias_id: str) -> List[RouteDef]: + p = "/{chat:" + alias_id + "}" + return [ + web.get(p, handler.index, name=f"index_{alias_id}"), + web.get(p + r"/logo", handler.logo, name=f"logo_{alias_id}"), + web.get(p + r"/{id:\d+}/view", handler.info, name=f"info_{alias_id}"), + web.get( + p + r"/{id:\d+}/thumbnail", + handler.thumbnail_get, + name=f"thumbnail_get_{alias_id}", + ), + web.get( + p + r"/{id:\d+}/{filename}", + handler.download_get, + name=f"download_get_{alias_id}", + ), + web.head( + p + r"/{id:\d+}/{filename}", + handler.download_head, + name=f"download_head_{alias_id}", + ), + ] + + +async def setup_routes(app: web.Application, handler: Views): + client = handler.client index_all = index_settings["index_all"] index_private = index_settings["index_private"] index_group = index_settings["index_group"] @@ -18,36 +44,13 @@ async def setup_routes(app, handler): exclude_chats = index_settings["exclude_chats"] include_chats = index_settings["include_chats"] routes = [ - web.get("/", h.home, name="home"), - web.get("/login", h.login_get, name="login_page"), - web.post("/login", h.login_post, name="login_handle"), - web.get("/logout", h.logout_get, name="logout"), - web.get("/favicon.ico", h.faviconicon, name="favicon"), + web.get("/", handler.home, name="home"), + web.get("/login", handler.login_get, name="login_page"), + web.post("/login", handler.login_post, name="login_handle"), + web.get("/logout", handler.logout_get, name="logout"), + web.get("/favicon.ico", handler.faviconicon, name="favicon"), ] - def get_common_routes(alias_id): - p = "/{chat:" + alias_id + "}" - return [ - web.get(p, h.index, name=f"index_{alias_id}"), - web.get(p + r"/logo", h.logo, name=f"logo_{alias_id}"), - web.get(p + r"/{id:\d+}/view", h.info, name=f"info_{alias_id}"), - web.get( - p + r"/{id:\d+}/thumbnail", - h.thumbnail_get, - name=f"thumbnail_get_{alias_id}", - ), - web.get( - p + r"/{id:\d+}/{filename}", - h.download_get, - name=f"download_get_{alias_id}", - ), - web.head( - p + r"/{id:\d+}/{filename}", - h.download_head, - name=f"download_head_{alias_id}", - ), - ] - if index_all: # print(await client.get_dialogs()) # dialogs = await client.get_dialogs() @@ -69,17 +72,17 @@ async def setup_routes(app, handler): log.debug(f"{chat.title}, group: {index_group}") continue - alias_id = h.generate_alias_id(chat) - routes.extend(get_common_routes(alias_id)) + alias_id = handler.generate_alias_id(chat) + routes.extend(get_common_routes(handler, alias_id)) log.debug(f"Index added for {chat.id} at /{alias_id}") else: for chat_id in include_chats: chat = await client.get_entity(chat_id) - alias_id = h.generate_alias_id(chat) + alias_id = handler.generate_alias_id(chat) routes.extend( - get_common_routes(alias_id) + get_common_routes(handler, alias_id) ) # returns list() of common routes log.debug(f"Index added for {chat.id} at /{alias_id}") - routes.append(web.view(r"/{wildcard:.*}", h.wildcard, name="wildcard")) + routes.append(web.view(r"/{wildcard:.*}", handler.wildcard, name="wildcard")) app.add_routes(routes) diff --git a/app/telegram.py b/app/telegram.py index 083ede9..5dcd145 100644 --- a/app/telegram.py +++ b/app/telegram.py @@ -7,7 +7,7 @@ from telethon.sessions import StringSession class Client(TelegramClient): - def __init__(self, session_string, *args, **kwargs): + def __init__(self, session_string: str, *args, **kwargs): super().__init__(StringSession(session_string), *args, **kwargs) self.log = logging.getLogger(__name__) @@ -25,7 +25,7 @@ class Client(TelegramClient): first_part = {first_part}, cut = {first_part_cut}(length={part_size-first_part_cut}), last_part = {last_part}, cut = {last_part_cut}(length={last_part_cut}), parts_count = {part_count} - """ + """ ) try: async for chunk in self.iter_download( @@ -42,7 +42,7 @@ class Client(TelegramClient): part += 1 - self.log.debug(f"serving finished") + self.log.debug("serving finished") except (GeneratorExit, StopAsyncIteration, asyncio.CancelledError): self.log.debug("file serve interrupted") raise diff --git a/app/util.py b/app/util.py index e903f8f..ebd11a8 100644 --- a/app/util.py +++ b/app/util.py @@ -1,7 +1,10 @@ +from typing import Union from urllib.parse import quote +from telethon.tl.custom import Message -def get_file_name(message, quote_name=True): + +def get_file_name(message: Message, quote_name: bool = True) -> str: if message.file.name: name = message.file.name else: @@ -10,7 +13,7 @@ def get_file_name(message, quote_name=True): return quote(name) if quote_name else name -def get_human_size(num): +def get_human_size(num: Union[int, float]) -> str: base = 1024.0 sufix_list = ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"] for unit in sufix_list: diff --git a/app/views/__init__.py b/app/views/__init__.py index 947f08a..f8033ca 100644 --- a/app/views/__init__.py +++ b/app/views/__init__.py @@ -1,8 +1,11 @@ import base64 import hashlib +from typing import Dict, Union + +from telethon.tl.types import Chat, User, Channel from ..config import SHORT_URL_LEN - +from ..telegram import Client from .home_view import HomeView from .wildcard_view import WildcardView from .download import Download @@ -16,6 +19,9 @@ from .faviconicon_view import FaviconIconView from .middlewhere import middleware_factory +TELEGRAM_CHAT = Union[Chat, User, Channel] + + class Views( HomeView, Download, @@ -28,12 +34,12 @@ class Views( LogoutView, FaviconIconView, ): - def __init__(self, client): + def __init__(self, client: Client): self.client = client self.url_len = SHORT_URL_LEN - self.chat_ids = {} + self.chat_ids: Dict[str, Dict[str, str]] = {} - def generate_alias_id(self, chat): + def generate_alias_id(self, chat: TELEGRAM_CHAT) -> str: chat_id = chat.id title = chat.title diff --git a/app/views/base.py b/app/views/base.py new file mode 100644 index 0000000..2b6a83d --- /dev/null +++ b/app/views/base.py @@ -0,0 +1,14 @@ +from typing import Dict, Union + +from telethon.tl.types import Chat, User, Channel + +from ..telegram import Client + + +TELEGRAM_CHAT = Union[Chat, User, Channel] + + +class BaseView: + client: Client + url_len: int + chat_ids: Dict[str, Dict[str, str]] diff --git a/app/views/download.py b/app/views/download.py index 33180d0..33dd919 100644 --- a/app/views/download.py +++ b/app/views/download.py @@ -1,22 +1,26 @@ import logging from aiohttp import web +from telethon.tl.custom import Message from app.util import get_file_name from app.config import block_downloads +from .base import BaseView log = logging.getLogger(__name__) -class Download: - async def download_get(self, req): +class Download(BaseView): + async def download_get(self, req: web.Request) -> web.Response: return await self.handle_request(req) - async def download_head(self, req): + async def download_head(self, req: web.Request) -> web.Response: return await self.handle_request(req, head=True) - async def handle_request(self, req, head=False): + async def handle_request( + self, req: web.Request, head: bool = False + ) -> web.Response: if block_downloads: return web.Response(status=403, text="403: Forbiden" if not head else None) @@ -26,7 +30,9 @@ class Download: chat_id = chat["chat_id"] try: - message = await self.client.get_messages(entity=chat_id, ids=file_id) + message: Message = await self.client.get_messages( + entity=chat_id, ids=file_id + ) except Exception: log.debug(f"Error in getting message {file_id} in {chat_id}", exc_info=True) message = None diff --git a/app/views/faviconicon_view.py b/app/views/faviconicon_view.py index 2663293..6502a3f 100644 --- a/app/views/faviconicon_view.py +++ b/app/views/faviconicon_view.py @@ -4,10 +4,11 @@ from PIL import Image, ImageDraw, ImageFont from aiohttp import web from app.config import logo_folder +from .base import BaseView -class FaviconIconView: - async def faviconicon(self, req): +class FaviconIconView(BaseView): + async def faviconicon(self, req: web.Request) -> web.Response: favicon_path = logo_folder.joinpath("favicon.ico") text = "T" if not favicon_path.exists(): @@ -15,10 +16,10 @@ class FaviconIconView: color = tuple((random.randint(0, 255) for _ in range(3))) im = Image.new("RGB", (W, H), color) draw = ImageDraw.Draw(im) - font = ImageFont.truetype("arial.ttf", 50) + font = ImageFont.truetype("arial.ttf", 100) w, h = draw.textsize(text, font=font) draw.text(((W - w) / 2, (H - h) / 2), text, fill="white", font=font) - im.save(favicon_path) + im.save(favicon_path, "JPEG") with open(favicon_path, "rb") as fp: body = fp.read() diff --git a/app/views/home_view.py b/app/views/home_view.py index e903797..747e496 100644 --- a/app/views/home_view.py +++ b/app/views/home_view.py @@ -1,11 +1,12 @@ from aiohttp import web import aiohttp_jinja2 +from .base import BaseView -class HomeView: + +class HomeView(BaseView): @aiohttp_jinja2.template("home.html") - async def home(self, req): - print(self.chat_ids) + async def home(self, req: web.Request) -> web.Response: if len(self.chat_ids) == 1: (chat,) = self.chat_ids.values() return web.HTTPFound(f"{chat['alias_id']}") diff --git a/app/views/index_view.py b/app/views/index_view.py index 5ca86e7..67b7069 100644 --- a/app/views/index_view.py +++ b/app/views/index_view.py @@ -1,19 +1,22 @@ import logging +from typing import List from urllib.parse import quote import aiohttp_jinja2 -from telethon.tl import types +from aiohttp import web +from telethon.tl import types, custom from app.config import results_per_page, block_downloads from app.util import get_file_name, get_human_size +from .base import BaseView log = logging.getLogger(__name__) -class IndexView: +class IndexView(BaseView): @aiohttp_jinja2.template("index.html") - async def index(self, req): + async def index(self, req: web.Request) -> web.Response: alias_id = req.match_info["chat"] chat = self.chat_ids[alias_id] log_msg = "" @@ -39,7 +42,9 @@ class IndexView: if search_query: kwargs.update({"search": search_query}) - messages = (await self.client.get_messages(**kwargs)) or [] + messages: List[custom.Message] = ( + await self.client.get_messages(**kwargs) + ) or [] except Exception: log.debug("failed to get messages", exc_info=True) diff --git a/app/views/info_view.py b/app/views/info_view.py index f5850dc..26b8c0a 100644 --- a/app/views/info_view.py +++ b/app/views/info_view.py @@ -2,20 +2,22 @@ import logging from urllib.parse import unquote import aiohttp_jinja2 +from aiohttp import web from telethon.tl import types from telethon.tl.custom import Message from jinja2 import Markup from app.util import get_file_name, get_human_size from app.config import block_downloads +from .base import BaseView log = logging.getLogger(__name__) -class InfoView: +class InfoView(BaseView): @aiohttp_jinja2.template("info.html") - async def info(self, req): + async def info(self, req: web.Request) -> web.Response: file_id = int(req.match_info["id"]) alias_id = req.match_info["chat"] chat = self.chat_ids[alias_id] diff --git a/app/views/login_view.py b/app/views/login_view.py index 74ce97e..ced1fbc 100644 --- a/app/views/login_view.py +++ b/app/views/login_view.py @@ -3,14 +3,15 @@ import time from aiohttp import web import aiohttp_jinja2 from aiohttp_session import new_session +from .base import BaseView -class LoginView: +class LoginView(BaseView): @aiohttp_jinja2.template("login.html") - async def login_get(self, req): + async def login_get(self, req: web.Request) -> web.Response: return dict(authenticated=False, **req.query) - async def login_post(self, req): + async def login_post(self, req: web.Request) -> web.Response: post_data = await req.post() redirect_to = post_data.get("redirect_to") or "/" location = req.app.router["login_page"].url_for() diff --git a/app/views/logo_view.py b/app/views/logo_view.py index bc1ea25..2d8e9c9 100644 --- a/app/views/logo_view.py +++ b/app/views/logo_view.py @@ -5,22 +5,29 @@ import random from aiohttp import web from telethon.tl import types + from app.config import logo_folder +from .base import BaseView log = logging.getLogger(__name__) -class LogoView: - async def logo(self, req): +class LogoView(BaseView): + async def logo(self, req: web.Request) -> web.Response: alias_id = req.match_info["chat"] chat = self.chat_ids[alias_id] chat_id = chat["chat_id"] - chat_name = " ".join(map(lambda x: x[0].upper(), chat["title"].split(" "))) + chat_name = " ".join( + map(lambda x: x[0].upper(), (chat["title"] or "_").split(" ")) + ) logo_path = logo_folder.joinpath(f"{alias_id}.jpg") if not logo_path.exists(): try: - (photo,) = await self.client.get_profile_photos(chat_id, limit=1) + + photo: types.Photo = ( + await self.client.get_profile_photos(chat_id, limit=1) + )[0] except Exception: log.debug( f"Error in getting profile picture in {chat_id}", exc_info=True @@ -40,7 +47,7 @@ class LogoView: im.save(logo_path) else: pos = -1 if req.query.get("big", None) else int(len(photo.sizes) / 2) - size = self.client._get_thumb(photo.sizes, pos) + size: types.PhotoSize = self.client._get_thumb(photo.sizes, pos) if isinstance(size, (types.PhotoCachedSize, types.PhotoStrippedSize)): await self.client._download_cached_photo_size(size, logo_path) else: diff --git a/app/views/logout_view.py b/app/views/logout_view.py index db851c9..ab2058a 100644 --- a/app/views/logout_view.py +++ b/app/views/logout_view.py @@ -1,9 +1,11 @@ from aiohttp_session import get_session from aiohttp import web +from .base import BaseView -class LogoutView: - async def logout_get(self, req): + +class LogoutView(BaseView): + async def logout_get(self, req: web.Request) -> web.Response: session = await get_session(req) session["logged_in"] = False diff --git a/app/views/middlewhere.py b/app/views/middlewhere.py index 937040c..9e63d39 100644 --- a/app/views/middlewhere.py +++ b/app/views/middlewhere.py @@ -1,7 +1,8 @@ import time import logging +from typing import Coroutine, Union -from aiohttp.web import middleware, HTTPFound, Response +from aiohttp.web import middleware, HTTPFound, Response, Request from aiohttp import BasicAuth, hdrs from aiohttp_session import get_session @@ -9,7 +10,7 @@ from aiohttp_session import get_session log = logging.getLogger(__name__) -def _do_basic_auth_check(request): +def _do_basic_auth_check(request: Request) -> Union[None, bool]: if "download_" not in request.match_info.route.name: return @@ -47,7 +48,7 @@ def _do_basic_auth_check(request): return True -async def _do_cookies_auth_check(request): +async def _do_cookies_auth_check(request: Request) -> Union[None, bool]: session = await get_session(request) if not session.get("logged_in", False): return @@ -56,9 +57,9 @@ async def _do_cookies_auth_check(request): return True -def middleware_factory(): +def middleware_factory() -> Coroutine: @middleware - async def factory(request, handler): + async def factory(request: Request, handler: Coroutine) -> Response: if request.app["is_authenticated"] and str(request.rel_url.path) not in [ "/login", "/logout", diff --git a/app/views/thumbnail_view.py b/app/views/thumbnail_view.py index 7bbaff7..f50d2d2 100644 --- a/app/views/thumbnail_view.py +++ b/app/views/thumbnail_view.py @@ -4,20 +4,24 @@ import random import io from aiohttp import web -from telethon.tl import types +from telethon.tl import types, custom + +from .base import BaseView log = logging.getLogger(__name__) -class ThumbnailView: - async def thumbnail_get(self, req): +class ThumbnailView(BaseView): + async def thumbnail_get(self, req: web.Request) -> web.Response: file_id = int(req.match_info["id"]) alias_id = req.match_info["chat"] chat = self.chat_ids[alias_id] chat_id = chat["chat_id"] try: - message = await self.client.get_messages(entity=chat_id, ids=file_id) + message: custom.Message = await self.client.get_messages( + entity=chat_id, ids=file_id + ) except Exception: log.debug(f"Error in getting message {file_id} in {chat_id}", exc_info=True) message = None @@ -47,7 +51,9 @@ class ThumbnailView: else: thumb_pos = int(len(thumbnails) / 2) try: - thumbnail = self.client._get_thumb(thumbnails, thumb_pos) + thumbnail: types.PhotoSize = self.client._get_thumb( + thumbnails, thumb_pos + ) except Exception as e: logging.debug(e) thumbnail = None diff --git a/app/views/wildcard_view.py b/app/views/wildcard_view.py index f0d12d4..3ca6764 100644 --- a/app/views/wildcard_view.py +++ b/app/views/wildcard_view.py @@ -1,6 +1,8 @@ from aiohttp import web +from .base import BaseView -class WildcardView: - async def wildcard(self, req): + +class WildcardView(BaseView): + async def wildcard(self, req: web.Request) -> web.Response: return web.HTTPFound("/")