add type hinting

This commit is contained in:
odysseusmax 2021-11-01 09:32:46 +05:30
parent 34811800f7
commit b5e876f393
16 changed files with 145 additions and 85 deletions

View File

@ -1,16 +1,42 @@
import logging import logging
from typing import List
from aiohttp import web from aiohttp import web
from aiohttp.web_routedef import RouteDef
from telethon.tl.types import Channel, Chat, User from telethon.tl.types import Channel, Chat, User
from .config import index_settings from .config import index_settings
from .views import Views
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
async def setup_routes(app, handler): def get_common_routes(handler: Views, alias_id: str) -> List[RouteDef]:
h = handler p = "/{chat:" + alias_id + "}"
client = h.client return [
web.get(p, handler.index, name=f"index_{alias_id}"),
web.get(p + r"/logo", handler.logo, name=f"logo_{alias_id}"),
web.get(p + r"/{id:\d+}/view", handler.info, name=f"info_{alias_id}"),
web.get(
p + r"/{id:\d+}/thumbnail",
handler.thumbnail_get,
name=f"thumbnail_get_{alias_id}",
),
web.get(
p + r"/{id:\d+}/{filename}",
handler.download_get,
name=f"download_get_{alias_id}",
),
web.head(
p + r"/{id:\d+}/{filename}",
handler.download_head,
name=f"download_head_{alias_id}",
),
]
async def setup_routes(app: web.Application, handler: Views):
client = handler.client
index_all = index_settings["index_all"] index_all = index_settings["index_all"]
index_private = index_settings["index_private"] index_private = index_settings["index_private"]
index_group = index_settings["index_group"] index_group = index_settings["index_group"]
@ -18,36 +44,13 @@ async def setup_routes(app, handler):
exclude_chats = index_settings["exclude_chats"] exclude_chats = index_settings["exclude_chats"]
include_chats = index_settings["include_chats"] include_chats = index_settings["include_chats"]
routes = [ routes = [
web.get("/", h.home, name="home"), web.get("/", handler.home, name="home"),
web.get("/login", h.login_get, name="login_page"), web.get("/login", handler.login_get, name="login_page"),
web.post("/login", h.login_post, name="login_handle"), web.post("/login", handler.login_post, name="login_handle"),
web.get("/logout", h.logout_get, name="logout"), web.get("/logout", handler.logout_get, name="logout"),
web.get("/favicon.ico", h.faviconicon, name="favicon"), web.get("/favicon.ico", handler.faviconicon, name="favicon"),
] ]
def get_common_routes(alias_id):
p = "/{chat:" + alias_id + "}"
return [
web.get(p, h.index, name=f"index_{alias_id}"),
web.get(p + r"/logo", h.logo, name=f"logo_{alias_id}"),
web.get(p + r"/{id:\d+}/view", h.info, name=f"info_{alias_id}"),
web.get(
p + r"/{id:\d+}/thumbnail",
h.thumbnail_get,
name=f"thumbnail_get_{alias_id}",
),
web.get(
p + r"/{id:\d+}/{filename}",
h.download_get,
name=f"download_get_{alias_id}",
),
web.head(
p + r"/{id:\d+}/{filename}",
h.download_head,
name=f"download_head_{alias_id}",
),
]
if index_all: if index_all:
# print(await client.get_dialogs()) # print(await client.get_dialogs())
# dialogs = await client.get_dialogs() # dialogs = await client.get_dialogs()
@ -69,17 +72,17 @@ async def setup_routes(app, handler):
log.debug(f"{chat.title}, group: {index_group}") log.debug(f"{chat.title}, group: {index_group}")
continue continue
alias_id = h.generate_alias_id(chat) alias_id = handler.generate_alias_id(chat)
routes.extend(get_common_routes(alias_id)) routes.extend(get_common_routes(handler, alias_id))
log.debug(f"Index added for {chat.id} at /{alias_id}") log.debug(f"Index added for {chat.id} at /{alias_id}")
else: else:
for chat_id in include_chats: for chat_id in include_chats:
chat = await client.get_entity(chat_id) chat = await client.get_entity(chat_id)
alias_id = h.generate_alias_id(chat) alias_id = handler.generate_alias_id(chat)
routes.extend( routes.extend(
get_common_routes(alias_id) get_common_routes(handler, alias_id)
) # returns list() of common routes ) # returns list() of common routes
log.debug(f"Index added for {chat.id} at /{alias_id}") log.debug(f"Index added for {chat.id} at /{alias_id}")
routes.append(web.view(r"/{wildcard:.*}", h.wildcard, name="wildcard")) routes.append(web.view(r"/{wildcard:.*}", handler.wildcard, name="wildcard"))
app.add_routes(routes) app.add_routes(routes)

View File

@ -7,7 +7,7 @@ from telethon.sessions import StringSession
class Client(TelegramClient): class Client(TelegramClient):
def __init__(self, session_string, *args, **kwargs): def __init__(self, session_string: str, *args, **kwargs):
super().__init__(StringSession(session_string), *args, **kwargs) super().__init__(StringSession(session_string), *args, **kwargs)
self.log = logging.getLogger(__name__) self.log = logging.getLogger(__name__)
@ -25,7 +25,7 @@ class Client(TelegramClient):
first_part = {first_part}, cut = {first_part_cut}(length={part_size-first_part_cut}), first_part = {first_part}, cut = {first_part_cut}(length={part_size-first_part_cut}),
last_part = {last_part}, cut = {last_part_cut}(length={last_part_cut}), last_part = {last_part}, cut = {last_part_cut}(length={last_part_cut}),
parts_count = {part_count} parts_count = {part_count}
""" """
) )
try: try:
async for chunk in self.iter_download( async for chunk in self.iter_download(
@ -42,7 +42,7 @@ class Client(TelegramClient):
part += 1 part += 1
self.log.debug(f"serving finished") self.log.debug("serving finished")
except (GeneratorExit, StopAsyncIteration, asyncio.CancelledError): except (GeneratorExit, StopAsyncIteration, asyncio.CancelledError):
self.log.debug("file serve interrupted") self.log.debug("file serve interrupted")
raise raise

View File

@ -1,7 +1,10 @@
from typing import Union
from urllib.parse import quote from urllib.parse import quote
from telethon.tl.custom import Message
def get_file_name(message, quote_name=True):
def get_file_name(message: Message, quote_name: bool = True) -> str:
if message.file.name: if message.file.name:
name = message.file.name name = message.file.name
else: else:
@ -10,7 +13,7 @@ def get_file_name(message, quote_name=True):
return quote(name) if quote_name else name return quote(name) if quote_name else name
def get_human_size(num): def get_human_size(num: Union[int, float]) -> str:
base = 1024.0 base = 1024.0
sufix_list = ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"] sufix_list = ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"]
for unit in sufix_list: for unit in sufix_list:

View File

@ -1,8 +1,11 @@
import base64 import base64
import hashlib import hashlib
from typing import Dict, Union
from telethon.tl.types import Chat, User, Channel
from ..config import SHORT_URL_LEN from ..config import SHORT_URL_LEN
from ..telegram import Client
from .home_view import HomeView from .home_view import HomeView
from .wildcard_view import WildcardView from .wildcard_view import WildcardView
from .download import Download from .download import Download
@ -16,6 +19,9 @@ from .faviconicon_view import FaviconIconView
from .middlewhere import middleware_factory from .middlewhere import middleware_factory
TELEGRAM_CHAT = Union[Chat, User, Channel]
class Views( class Views(
HomeView, HomeView,
Download, Download,
@ -28,12 +34,12 @@ class Views(
LogoutView, LogoutView,
FaviconIconView, FaviconIconView,
): ):
def __init__(self, client): def __init__(self, client: Client):
self.client = client self.client = client
self.url_len = SHORT_URL_LEN self.url_len = SHORT_URL_LEN
self.chat_ids = {} self.chat_ids: Dict[str, Dict[str, str]] = {}
def generate_alias_id(self, chat): def generate_alias_id(self, chat: TELEGRAM_CHAT) -> str:
chat_id = chat.id chat_id = chat.id
title = chat.title title = chat.title

14
app/views/base.py Normal file
View File

@ -0,0 +1,14 @@
from typing import Dict, Union
from telethon.tl.types import Chat, User, Channel
from ..telegram import Client
TELEGRAM_CHAT = Union[Chat, User, Channel]
class BaseView:
client: Client
url_len: int
chat_ids: Dict[str, Dict[str, str]]

View File

@ -1,22 +1,26 @@
import logging import logging
from aiohttp import web from aiohttp import web
from telethon.tl.custom import Message
from app.util import get_file_name from app.util import get_file_name
from app.config import block_downloads from app.config import block_downloads
from .base import BaseView
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class Download: class Download(BaseView):
async def download_get(self, req): async def download_get(self, req: web.Request) -> web.Response:
return await self.handle_request(req) return await self.handle_request(req)
async def download_head(self, req): async def download_head(self, req: web.Request) -> web.Response:
return await self.handle_request(req, head=True) return await self.handle_request(req, head=True)
async def handle_request(self, req, head=False): async def handle_request(
self, req: web.Request, head: bool = False
) -> web.Response:
if block_downloads: if block_downloads:
return web.Response(status=403, text="403: Forbiden" if not head else None) return web.Response(status=403, text="403: Forbiden" if not head else None)
@ -26,7 +30,9 @@ class Download:
chat_id = chat["chat_id"] chat_id = chat["chat_id"]
try: try:
message = await self.client.get_messages(entity=chat_id, ids=file_id) message: Message = await self.client.get_messages(
entity=chat_id, ids=file_id
)
except Exception: except Exception:
log.debug(f"Error in getting message {file_id} in {chat_id}", exc_info=True) log.debug(f"Error in getting message {file_id} in {chat_id}", exc_info=True)
message = None message = None

View File

@ -4,10 +4,11 @@ from PIL import Image, ImageDraw, ImageFont
from aiohttp import web from aiohttp import web
from app.config import logo_folder from app.config import logo_folder
from .base import BaseView
class FaviconIconView: class FaviconIconView(BaseView):
async def faviconicon(self, req): async def faviconicon(self, req: web.Request) -> web.Response:
favicon_path = logo_folder.joinpath("favicon.ico") favicon_path = logo_folder.joinpath("favicon.ico")
text = "T" text = "T"
if not favicon_path.exists(): if not favicon_path.exists():
@ -15,10 +16,10 @@ class FaviconIconView:
color = tuple((random.randint(0, 255) for _ in range(3))) color = tuple((random.randint(0, 255) for _ in range(3)))
im = Image.new("RGB", (W, H), color) im = Image.new("RGB", (W, H), color)
draw = ImageDraw.Draw(im) draw = ImageDraw.Draw(im)
font = ImageFont.truetype("arial.ttf", 50) font = ImageFont.truetype("arial.ttf", 100)
w, h = draw.textsize(text, font=font) w, h = draw.textsize(text, font=font)
draw.text(((W - w) / 2, (H - h) / 2), text, fill="white", font=font) draw.text(((W - w) / 2, (H - h) / 2), text, fill="white", font=font)
im.save(favicon_path) im.save(favicon_path, "JPEG")
with open(favicon_path, "rb") as fp: with open(favicon_path, "rb") as fp:
body = fp.read() body = fp.read()

View File

@ -1,11 +1,12 @@
from aiohttp import web from aiohttp import web
import aiohttp_jinja2 import aiohttp_jinja2
from .base import BaseView
class HomeView:
class HomeView(BaseView):
@aiohttp_jinja2.template("home.html") @aiohttp_jinja2.template("home.html")
async def home(self, req): async def home(self, req: web.Request) -> web.Response:
print(self.chat_ids)
if len(self.chat_ids) == 1: if len(self.chat_ids) == 1:
(chat,) = self.chat_ids.values() (chat,) = self.chat_ids.values()
return web.HTTPFound(f"{chat['alias_id']}") return web.HTTPFound(f"{chat['alias_id']}")

View File

@ -1,19 +1,22 @@
import logging import logging
from typing import List
from urllib.parse import quote from urllib.parse import quote
import aiohttp_jinja2 import aiohttp_jinja2
from telethon.tl import types from aiohttp import web
from telethon.tl import types, custom
from app.config import results_per_page, block_downloads from app.config import results_per_page, block_downloads
from app.util import get_file_name, get_human_size from app.util import get_file_name, get_human_size
from .base import BaseView
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class IndexView: class IndexView(BaseView):
@aiohttp_jinja2.template("index.html") @aiohttp_jinja2.template("index.html")
async def index(self, req): async def index(self, req: web.Request) -> web.Response:
alias_id = req.match_info["chat"] alias_id = req.match_info["chat"]
chat = self.chat_ids[alias_id] chat = self.chat_ids[alias_id]
log_msg = "" log_msg = ""
@ -39,7 +42,9 @@ class IndexView:
if search_query: if search_query:
kwargs.update({"search": search_query}) kwargs.update({"search": search_query})
messages = (await self.client.get_messages(**kwargs)) or [] messages: List[custom.Message] = (
await self.client.get_messages(**kwargs)
) or []
except Exception: except Exception:
log.debug("failed to get messages", exc_info=True) log.debug("failed to get messages", exc_info=True)

View File

@ -2,20 +2,22 @@ import logging
from urllib.parse import unquote from urllib.parse import unquote
import aiohttp_jinja2 import aiohttp_jinja2
from aiohttp import web
from telethon.tl import types from telethon.tl import types
from telethon.tl.custom import Message from telethon.tl.custom import Message
from jinja2 import Markup from jinja2 import Markup
from app.util import get_file_name, get_human_size from app.util import get_file_name, get_human_size
from app.config import block_downloads from app.config import block_downloads
from .base import BaseView
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class InfoView: class InfoView(BaseView):
@aiohttp_jinja2.template("info.html") @aiohttp_jinja2.template("info.html")
async def info(self, req): async def info(self, req: web.Request) -> web.Response:
file_id = int(req.match_info["id"]) file_id = int(req.match_info["id"])
alias_id = req.match_info["chat"] alias_id = req.match_info["chat"]
chat = self.chat_ids[alias_id] chat = self.chat_ids[alias_id]

View File

@ -3,14 +3,15 @@ import time
from aiohttp import web from aiohttp import web
import aiohttp_jinja2 import aiohttp_jinja2
from aiohttp_session import new_session from aiohttp_session import new_session
from .base import BaseView
class LoginView: class LoginView(BaseView):
@aiohttp_jinja2.template("login.html") @aiohttp_jinja2.template("login.html")
async def login_get(self, req): async def login_get(self, req: web.Request) -> web.Response:
return dict(authenticated=False, **req.query) return dict(authenticated=False, **req.query)
async def login_post(self, req): async def login_post(self, req: web.Request) -> web.Response:
post_data = await req.post() post_data = await req.post()
redirect_to = post_data.get("redirect_to") or "/" redirect_to = post_data.get("redirect_to") or "/"
location = req.app.router["login_page"].url_for() location = req.app.router["login_page"].url_for()

View File

@ -5,22 +5,29 @@ import random
from aiohttp import web from aiohttp import web
from telethon.tl import types from telethon.tl import types
from app.config import logo_folder from app.config import logo_folder
from .base import BaseView
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class LogoView: class LogoView(BaseView):
async def logo(self, req): async def logo(self, req: web.Request) -> web.Response:
alias_id = req.match_info["chat"] alias_id = req.match_info["chat"]
chat = self.chat_ids[alias_id] chat = self.chat_ids[alias_id]
chat_id = chat["chat_id"] chat_id = chat["chat_id"]
chat_name = " ".join(map(lambda x: x[0].upper(), chat["title"].split(" "))) chat_name = " ".join(
map(lambda x: x[0].upper(), (chat["title"] or "_").split(" "))
)
logo_path = logo_folder.joinpath(f"{alias_id}.jpg") logo_path = logo_folder.joinpath(f"{alias_id}.jpg")
if not logo_path.exists(): if not logo_path.exists():
try: try:
(photo,) = await self.client.get_profile_photos(chat_id, limit=1)
photo: types.Photo = (
await self.client.get_profile_photos(chat_id, limit=1)
)[0]
except Exception: except Exception:
log.debug( log.debug(
f"Error in getting profile picture in {chat_id}", exc_info=True f"Error in getting profile picture in {chat_id}", exc_info=True
@ -40,7 +47,7 @@ class LogoView:
im.save(logo_path) im.save(logo_path)
else: else:
pos = -1 if req.query.get("big", None) else int(len(photo.sizes) / 2) pos = -1 if req.query.get("big", None) else int(len(photo.sizes) / 2)
size = self.client._get_thumb(photo.sizes, pos) size: types.PhotoSize = self.client._get_thumb(photo.sizes, pos)
if isinstance(size, (types.PhotoCachedSize, types.PhotoStrippedSize)): if isinstance(size, (types.PhotoCachedSize, types.PhotoStrippedSize)):
await self.client._download_cached_photo_size(size, logo_path) await self.client._download_cached_photo_size(size, logo_path)
else: else:

View File

@ -1,9 +1,11 @@
from aiohttp_session import get_session from aiohttp_session import get_session
from aiohttp import web from aiohttp import web
from .base import BaseView
class LogoutView:
async def logout_get(self, req): class LogoutView(BaseView):
async def logout_get(self, req: web.Request) -> web.Response:
session = await get_session(req) session = await get_session(req)
session["logged_in"] = False session["logged_in"] = False

View File

@ -1,7 +1,8 @@
import time import time
import logging import logging
from typing import Coroutine, Union
from aiohttp.web import middleware, HTTPFound, Response from aiohttp.web import middleware, HTTPFound, Response, Request
from aiohttp import BasicAuth, hdrs from aiohttp import BasicAuth, hdrs
from aiohttp_session import get_session from aiohttp_session import get_session
@ -9,7 +10,7 @@ from aiohttp_session import get_session
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def _do_basic_auth_check(request): def _do_basic_auth_check(request: Request) -> Union[None, bool]:
if "download_" not in request.match_info.route.name: if "download_" not in request.match_info.route.name:
return return
@ -47,7 +48,7 @@ def _do_basic_auth_check(request):
return True return True
async def _do_cookies_auth_check(request): async def _do_cookies_auth_check(request: Request) -> Union[None, bool]:
session = await get_session(request) session = await get_session(request)
if not session.get("logged_in", False): if not session.get("logged_in", False):
return return
@ -56,9 +57,9 @@ async def _do_cookies_auth_check(request):
return True return True
def middleware_factory(): def middleware_factory() -> Coroutine:
@middleware @middleware
async def factory(request, handler): async def factory(request: Request, handler: Coroutine) -> Response:
if request.app["is_authenticated"] and str(request.rel_url.path) not in [ if request.app["is_authenticated"] and str(request.rel_url.path) not in [
"/login", "/login",
"/logout", "/logout",

View File

@ -4,20 +4,24 @@ import random
import io import io
from aiohttp import web from aiohttp import web
from telethon.tl import types from telethon.tl import types, custom
from .base import BaseView
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class ThumbnailView: class ThumbnailView(BaseView):
async def thumbnail_get(self, req): async def thumbnail_get(self, req: web.Request) -> web.Response:
file_id = int(req.match_info["id"]) file_id = int(req.match_info["id"])
alias_id = req.match_info["chat"] alias_id = req.match_info["chat"]
chat = self.chat_ids[alias_id] chat = self.chat_ids[alias_id]
chat_id = chat["chat_id"] chat_id = chat["chat_id"]
try: try:
message = await self.client.get_messages(entity=chat_id, ids=file_id) message: custom.Message = await self.client.get_messages(
entity=chat_id, ids=file_id
)
except Exception: except Exception:
log.debug(f"Error in getting message {file_id} in {chat_id}", exc_info=True) log.debug(f"Error in getting message {file_id} in {chat_id}", exc_info=True)
message = None message = None
@ -47,7 +51,9 @@ class ThumbnailView:
else: else:
thumb_pos = int(len(thumbnails) / 2) thumb_pos = int(len(thumbnails) / 2)
try: try:
thumbnail = self.client._get_thumb(thumbnails, thumb_pos) thumbnail: types.PhotoSize = self.client._get_thumb(
thumbnails, thumb_pos
)
except Exception as e: except Exception as e:
logging.debug(e) logging.debug(e)
thumbnail = None thumbnail = None

View File

@ -1,6 +1,8 @@
from aiohttp import web from aiohttp import web
from .base import BaseView
class WildcardView:
async def wildcard(self, req): class WildcardView(BaseView):
async def wildcard(self, req: web.Request) -> web.Response:
return web.HTTPFound("/") return web.HTTPFound("/")