add type hinting
This commit is contained in:
parent
34811800f7
commit
b5e876f393
|
@ -1,16 +1,42 @@
|
|||
import logging
|
||||
from typing import List
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.web_routedef import RouteDef
|
||||
from telethon.tl.types import Channel, Chat, User
|
||||
|
||||
from .config import index_settings
|
||||
from .views import Views
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def setup_routes(app, handler):
|
||||
h = handler
|
||||
client = h.client
|
||||
def get_common_routes(handler: Views, alias_id: str) -> List[RouteDef]:
|
||||
p = "/{chat:" + alias_id + "}"
|
||||
return [
|
||||
web.get(p, handler.index, name=f"index_{alias_id}"),
|
||||
web.get(p + r"/logo", handler.logo, name=f"logo_{alias_id}"),
|
||||
web.get(p + r"/{id:\d+}/view", handler.info, name=f"info_{alias_id}"),
|
||||
web.get(
|
||||
p + r"/{id:\d+}/thumbnail",
|
||||
handler.thumbnail_get,
|
||||
name=f"thumbnail_get_{alias_id}",
|
||||
),
|
||||
web.get(
|
||||
p + r"/{id:\d+}/{filename}",
|
||||
handler.download_get,
|
||||
name=f"download_get_{alias_id}",
|
||||
),
|
||||
web.head(
|
||||
p + r"/{id:\d+}/{filename}",
|
||||
handler.download_head,
|
||||
name=f"download_head_{alias_id}",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
async def setup_routes(app: web.Application, handler: Views):
|
||||
client = handler.client
|
||||
index_all = index_settings["index_all"]
|
||||
index_private = index_settings["index_private"]
|
||||
index_group = index_settings["index_group"]
|
||||
|
@ -18,36 +44,13 @@ async def setup_routes(app, handler):
|
|||
exclude_chats = index_settings["exclude_chats"]
|
||||
include_chats = index_settings["include_chats"]
|
||||
routes = [
|
||||
web.get("/", h.home, name="home"),
|
||||
web.get("/login", h.login_get, name="login_page"),
|
||||
web.post("/login", h.login_post, name="login_handle"),
|
||||
web.get("/logout", h.logout_get, name="logout"),
|
||||
web.get("/favicon.ico", h.faviconicon, name="favicon"),
|
||||
web.get("/", handler.home, name="home"),
|
||||
web.get("/login", handler.login_get, name="login_page"),
|
||||
web.post("/login", handler.login_post, name="login_handle"),
|
||||
web.get("/logout", handler.logout_get, name="logout"),
|
||||
web.get("/favicon.ico", handler.faviconicon, name="favicon"),
|
||||
]
|
||||
|
||||
def get_common_routes(alias_id):
|
||||
p = "/{chat:" + alias_id + "}"
|
||||
return [
|
||||
web.get(p, h.index, name=f"index_{alias_id}"),
|
||||
web.get(p + r"/logo", h.logo, name=f"logo_{alias_id}"),
|
||||
web.get(p + r"/{id:\d+}/view", h.info, name=f"info_{alias_id}"),
|
||||
web.get(
|
||||
p + r"/{id:\d+}/thumbnail",
|
||||
h.thumbnail_get,
|
||||
name=f"thumbnail_get_{alias_id}",
|
||||
),
|
||||
web.get(
|
||||
p + r"/{id:\d+}/{filename}",
|
||||
h.download_get,
|
||||
name=f"download_get_{alias_id}",
|
||||
),
|
||||
web.head(
|
||||
p + r"/{id:\d+}/{filename}",
|
||||
h.download_head,
|
||||
name=f"download_head_{alias_id}",
|
||||
),
|
||||
]
|
||||
|
||||
if index_all:
|
||||
# print(await client.get_dialogs())
|
||||
# dialogs = await client.get_dialogs()
|
||||
|
@ -69,17 +72,17 @@ async def setup_routes(app, handler):
|
|||
log.debug(f"{chat.title}, group: {index_group}")
|
||||
continue
|
||||
|
||||
alias_id = h.generate_alias_id(chat)
|
||||
routes.extend(get_common_routes(alias_id))
|
||||
alias_id = handler.generate_alias_id(chat)
|
||||
routes.extend(get_common_routes(handler, alias_id))
|
||||
log.debug(f"Index added for {chat.id} at /{alias_id}")
|
||||
|
||||
else:
|
||||
for chat_id in include_chats:
|
||||
chat = await client.get_entity(chat_id)
|
||||
alias_id = h.generate_alias_id(chat)
|
||||
alias_id = handler.generate_alias_id(chat)
|
||||
routes.extend(
|
||||
get_common_routes(alias_id)
|
||||
get_common_routes(handler, alias_id)
|
||||
) # returns list() of common routes
|
||||
log.debug(f"Index added for {chat.id} at /{alias_id}")
|
||||
routes.append(web.view(r"/{wildcard:.*}", h.wildcard, name="wildcard"))
|
||||
routes.append(web.view(r"/{wildcard:.*}", handler.wildcard, name="wildcard"))
|
||||
app.add_routes(routes)
|
||||
|
|
|
@ -7,7 +7,7 @@ from telethon.sessions import StringSession
|
|||
|
||||
|
||||
class Client(TelegramClient):
|
||||
def __init__(self, session_string, *args, **kwargs):
|
||||
def __init__(self, session_string: str, *args, **kwargs):
|
||||
super().__init__(StringSession(session_string), *args, **kwargs)
|
||||
self.log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -25,7 +25,7 @@ class Client(TelegramClient):
|
|||
first_part = {first_part}, cut = {first_part_cut}(length={part_size-first_part_cut}),
|
||||
last_part = {last_part}, cut = {last_part_cut}(length={last_part_cut}),
|
||||
parts_count = {part_count}
|
||||
"""
|
||||
"""
|
||||
)
|
||||
try:
|
||||
async for chunk in self.iter_download(
|
||||
|
@ -42,7 +42,7 @@ class Client(TelegramClient):
|
|||
|
||||
part += 1
|
||||
|
||||
self.log.debug(f"serving finished")
|
||||
self.log.debug("serving finished")
|
||||
except (GeneratorExit, StopAsyncIteration, asyncio.CancelledError):
|
||||
self.log.debug("file serve interrupted")
|
||||
raise
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
from typing import Union
|
||||
from urllib.parse import quote
|
||||
|
||||
from telethon.tl.custom import Message
|
||||
|
||||
def get_file_name(message, quote_name=True):
|
||||
|
||||
def get_file_name(message: Message, quote_name: bool = True) -> str:
|
||||
if message.file.name:
|
||||
name = message.file.name
|
||||
else:
|
||||
|
@ -10,7 +13,7 @@ def get_file_name(message, quote_name=True):
|
|||
return quote(name) if quote_name else name
|
||||
|
||||
|
||||
def get_human_size(num):
|
||||
def get_human_size(num: Union[int, float]) -> str:
|
||||
base = 1024.0
|
||||
sufix_list = ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"]
|
||||
for unit in sufix_list:
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
import base64
|
||||
import hashlib
|
||||
from typing import Dict, Union
|
||||
|
||||
from telethon.tl.types import Chat, User, Channel
|
||||
|
||||
from ..config import SHORT_URL_LEN
|
||||
|
||||
from ..telegram import Client
|
||||
from .home_view import HomeView
|
||||
from .wildcard_view import WildcardView
|
||||
from .download import Download
|
||||
|
@ -16,6 +19,9 @@ from .faviconicon_view import FaviconIconView
|
|||
from .middlewhere import middleware_factory
|
||||
|
||||
|
||||
TELEGRAM_CHAT = Union[Chat, User, Channel]
|
||||
|
||||
|
||||
class Views(
|
||||
HomeView,
|
||||
Download,
|
||||
|
@ -28,12 +34,12 @@ class Views(
|
|||
LogoutView,
|
||||
FaviconIconView,
|
||||
):
|
||||
def __init__(self, client):
|
||||
def __init__(self, client: Client):
|
||||
self.client = client
|
||||
self.url_len = SHORT_URL_LEN
|
||||
self.chat_ids = {}
|
||||
self.chat_ids: Dict[str, Dict[str, str]] = {}
|
||||
|
||||
def generate_alias_id(self, chat):
|
||||
def generate_alias_id(self, chat: TELEGRAM_CHAT) -> str:
|
||||
chat_id = chat.id
|
||||
title = chat.title
|
||||
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
from typing import Dict, Union
|
||||
|
||||
from telethon.tl.types import Chat, User, Channel
|
||||
|
||||
from ..telegram import Client
|
||||
|
||||
|
||||
TELEGRAM_CHAT = Union[Chat, User, Channel]
|
||||
|
||||
|
||||
class BaseView:
|
||||
client: Client
|
||||
url_len: int
|
||||
chat_ids: Dict[str, Dict[str, str]]
|
|
@ -1,22 +1,26 @@
|
|||
import logging
|
||||
|
||||
from aiohttp import web
|
||||
from telethon.tl.custom import Message
|
||||
|
||||
from app.util import get_file_name
|
||||
from app.config import block_downloads
|
||||
from .base import BaseView
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Download:
|
||||
async def download_get(self, req):
|
||||
class Download(BaseView):
|
||||
async def download_get(self, req: web.Request) -> web.Response:
|
||||
return await self.handle_request(req)
|
||||
|
||||
async def download_head(self, req):
|
||||
async def download_head(self, req: web.Request) -> web.Response:
|
||||
return await self.handle_request(req, head=True)
|
||||
|
||||
async def handle_request(self, req, head=False):
|
||||
async def handle_request(
|
||||
self, req: web.Request, head: bool = False
|
||||
) -> web.Response:
|
||||
if block_downloads:
|
||||
return web.Response(status=403, text="403: Forbiden" if not head else None)
|
||||
|
||||
|
@ -26,7 +30,9 @@ class Download:
|
|||
chat_id = chat["chat_id"]
|
||||
|
||||
try:
|
||||
message = await self.client.get_messages(entity=chat_id, ids=file_id)
|
||||
message: Message = await self.client.get_messages(
|
||||
entity=chat_id, ids=file_id
|
||||
)
|
||||
except Exception:
|
||||
log.debug(f"Error in getting message {file_id} in {chat_id}", exc_info=True)
|
||||
message = None
|
||||
|
|
|
@ -4,10 +4,11 @@ from PIL import Image, ImageDraw, ImageFont
|
|||
from aiohttp import web
|
||||
|
||||
from app.config import logo_folder
|
||||
from .base import BaseView
|
||||
|
||||
|
||||
class FaviconIconView:
|
||||
async def faviconicon(self, req):
|
||||
class FaviconIconView(BaseView):
|
||||
async def faviconicon(self, req: web.Request) -> web.Response:
|
||||
favicon_path = logo_folder.joinpath("favicon.ico")
|
||||
text = "T"
|
||||
if not favicon_path.exists():
|
||||
|
@ -15,10 +16,10 @@ class FaviconIconView:
|
|||
color = tuple((random.randint(0, 255) for _ in range(3)))
|
||||
im = Image.new("RGB", (W, H), color)
|
||||
draw = ImageDraw.Draw(im)
|
||||
font = ImageFont.truetype("arial.ttf", 50)
|
||||
font = ImageFont.truetype("arial.ttf", 100)
|
||||
w, h = draw.textsize(text, font=font)
|
||||
draw.text(((W - w) / 2, (H - h) / 2), text, fill="white", font=font)
|
||||
im.save(favicon_path)
|
||||
im.save(favicon_path, "JPEG")
|
||||
|
||||
with open(favicon_path, "rb") as fp:
|
||||
body = fp.read()
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
from aiohttp import web
|
||||
import aiohttp_jinja2
|
||||
|
||||
from .base import BaseView
|
||||
|
||||
class HomeView:
|
||||
|
||||
class HomeView(BaseView):
|
||||
@aiohttp_jinja2.template("home.html")
|
||||
async def home(self, req):
|
||||
print(self.chat_ids)
|
||||
async def home(self, req: web.Request) -> web.Response:
|
||||
if len(self.chat_ids) == 1:
|
||||
(chat,) = self.chat_ids.values()
|
||||
return web.HTTPFound(f"{chat['alias_id']}")
|
||||
|
|
|
@ -1,19 +1,22 @@
|
|||
import logging
|
||||
from typing import List
|
||||
from urllib.parse import quote
|
||||
|
||||
import aiohttp_jinja2
|
||||
from telethon.tl import types
|
||||
from aiohttp import web
|
||||
from telethon.tl import types, custom
|
||||
|
||||
from app.config import results_per_page, block_downloads
|
||||
from app.util import get_file_name, get_human_size
|
||||
from .base import BaseView
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IndexView:
|
||||
class IndexView(BaseView):
|
||||
@aiohttp_jinja2.template("index.html")
|
||||
async def index(self, req):
|
||||
async def index(self, req: web.Request) -> web.Response:
|
||||
alias_id = req.match_info["chat"]
|
||||
chat = self.chat_ids[alias_id]
|
||||
log_msg = ""
|
||||
|
@ -39,7 +42,9 @@ class IndexView:
|
|||
if search_query:
|
||||
kwargs.update({"search": search_query})
|
||||
|
||||
messages = (await self.client.get_messages(**kwargs)) or []
|
||||
messages: List[custom.Message] = (
|
||||
await self.client.get_messages(**kwargs)
|
||||
) or []
|
||||
|
||||
except Exception:
|
||||
log.debug("failed to get messages", exc_info=True)
|
||||
|
|
|
@ -2,20 +2,22 @@ import logging
|
|||
from urllib.parse import unquote
|
||||
|
||||
import aiohttp_jinja2
|
||||
from aiohttp import web
|
||||
from telethon.tl import types
|
||||
from telethon.tl.custom import Message
|
||||
from jinja2 import Markup
|
||||
|
||||
from app.util import get_file_name, get_human_size
|
||||
from app.config import block_downloads
|
||||
from .base import BaseView
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InfoView:
|
||||
class InfoView(BaseView):
|
||||
@aiohttp_jinja2.template("info.html")
|
||||
async def info(self, req):
|
||||
async def info(self, req: web.Request) -> web.Response:
|
||||
file_id = int(req.match_info["id"])
|
||||
alias_id = req.match_info["chat"]
|
||||
chat = self.chat_ids[alias_id]
|
||||
|
|
|
@ -3,14 +3,15 @@ import time
|
|||
from aiohttp import web
|
||||
import aiohttp_jinja2
|
||||
from aiohttp_session import new_session
|
||||
from .base import BaseView
|
||||
|
||||
|
||||
class LoginView:
|
||||
class LoginView(BaseView):
|
||||
@aiohttp_jinja2.template("login.html")
|
||||
async def login_get(self, req):
|
||||
async def login_get(self, req: web.Request) -> web.Response:
|
||||
return dict(authenticated=False, **req.query)
|
||||
|
||||
async def login_post(self, req):
|
||||
async def login_post(self, req: web.Request) -> web.Response:
|
||||
post_data = await req.post()
|
||||
redirect_to = post_data.get("redirect_to") or "/"
|
||||
location = req.app.router["login_page"].url_for()
|
||||
|
|
|
@ -5,22 +5,29 @@ import random
|
|||
from aiohttp import web
|
||||
from telethon.tl import types
|
||||
|
||||
|
||||
from app.config import logo_folder
|
||||
from .base import BaseView
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LogoView:
|
||||
async def logo(self, req):
|
||||
class LogoView(BaseView):
|
||||
async def logo(self, req: web.Request) -> web.Response:
|
||||
alias_id = req.match_info["chat"]
|
||||
chat = self.chat_ids[alias_id]
|
||||
chat_id = chat["chat_id"]
|
||||
chat_name = " ".join(map(lambda x: x[0].upper(), chat["title"].split(" ")))
|
||||
chat_name = " ".join(
|
||||
map(lambda x: x[0].upper(), (chat["title"] or "_").split(" "))
|
||||
)
|
||||
logo_path = logo_folder.joinpath(f"{alias_id}.jpg")
|
||||
if not logo_path.exists():
|
||||
try:
|
||||
(photo,) = await self.client.get_profile_photos(chat_id, limit=1)
|
||||
|
||||
photo: types.Photo = (
|
||||
await self.client.get_profile_photos(chat_id, limit=1)
|
||||
)[0]
|
||||
except Exception:
|
||||
log.debug(
|
||||
f"Error in getting profile picture in {chat_id}", exc_info=True
|
||||
|
@ -40,7 +47,7 @@ class LogoView:
|
|||
im.save(logo_path)
|
||||
else:
|
||||
pos = -1 if req.query.get("big", None) else int(len(photo.sizes) / 2)
|
||||
size = self.client._get_thumb(photo.sizes, pos)
|
||||
size: types.PhotoSize = self.client._get_thumb(photo.sizes, pos)
|
||||
if isinstance(size, (types.PhotoCachedSize, types.PhotoStrippedSize)):
|
||||
await self.client._download_cached_photo_size(size, logo_path)
|
||||
else:
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
from aiohttp_session import get_session
|
||||
from aiohttp import web
|
||||
|
||||
from .base import BaseView
|
||||
|
||||
class LogoutView:
|
||||
async def logout_get(self, req):
|
||||
|
||||
class LogoutView(BaseView):
|
||||
async def logout_get(self, req: web.Request) -> web.Response:
|
||||
session = await get_session(req)
|
||||
session["logged_in"] = False
|
||||
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import time
|
||||
import logging
|
||||
from typing import Coroutine, Union
|
||||
|
||||
from aiohttp.web import middleware, HTTPFound, Response
|
||||
from aiohttp.web import middleware, HTTPFound, Response, Request
|
||||
from aiohttp import BasicAuth, hdrs
|
||||
from aiohttp_session import get_session
|
||||
|
||||
|
@ -9,7 +10,7 @@ from aiohttp_session import get_session
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _do_basic_auth_check(request):
|
||||
def _do_basic_auth_check(request: Request) -> Union[None, bool]:
|
||||
if "download_" not in request.match_info.route.name:
|
||||
return
|
||||
|
||||
|
@ -47,7 +48,7 @@ def _do_basic_auth_check(request):
|
|||
return True
|
||||
|
||||
|
||||
async def _do_cookies_auth_check(request):
|
||||
async def _do_cookies_auth_check(request: Request) -> Union[None, bool]:
|
||||
session = await get_session(request)
|
||||
if not session.get("logged_in", False):
|
||||
return
|
||||
|
@ -56,9 +57,9 @@ async def _do_cookies_auth_check(request):
|
|||
return True
|
||||
|
||||
|
||||
def middleware_factory():
|
||||
def middleware_factory() -> Coroutine:
|
||||
@middleware
|
||||
async def factory(request, handler):
|
||||
async def factory(request: Request, handler: Coroutine) -> Response:
|
||||
if request.app["is_authenticated"] and str(request.rel_url.path) not in [
|
||||
"/login",
|
||||
"/logout",
|
||||
|
|
|
@ -4,20 +4,24 @@ import random
|
|||
import io
|
||||
|
||||
from aiohttp import web
|
||||
from telethon.tl import types
|
||||
from telethon.tl import types, custom
|
||||
|
||||
from .base import BaseView
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ThumbnailView:
|
||||
async def thumbnail_get(self, req):
|
||||
class ThumbnailView(BaseView):
|
||||
async def thumbnail_get(self, req: web.Request) -> web.Response:
|
||||
file_id = int(req.match_info["id"])
|
||||
alias_id = req.match_info["chat"]
|
||||
chat = self.chat_ids[alias_id]
|
||||
chat_id = chat["chat_id"]
|
||||
try:
|
||||
message = await self.client.get_messages(entity=chat_id, ids=file_id)
|
||||
message: custom.Message = await self.client.get_messages(
|
||||
entity=chat_id, ids=file_id
|
||||
)
|
||||
except Exception:
|
||||
log.debug(f"Error in getting message {file_id} in {chat_id}", exc_info=True)
|
||||
message = None
|
||||
|
@ -47,7 +51,9 @@ class ThumbnailView:
|
|||
else:
|
||||
thumb_pos = int(len(thumbnails) / 2)
|
||||
try:
|
||||
thumbnail = self.client._get_thumb(thumbnails, thumb_pos)
|
||||
thumbnail: types.PhotoSize = self.client._get_thumb(
|
||||
thumbnails, thumb_pos
|
||||
)
|
||||
except Exception as e:
|
||||
logging.debug(e)
|
||||
thumbnail = None
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
from aiohttp import web
|
||||
|
||||
from .base import BaseView
|
||||
|
||||
class WildcardView:
|
||||
async def wildcard(self, req):
|
||||
|
||||
class WildcardView(BaseView):
|
||||
async def wildcard(self, req: web.Request) -> web.Response:
|
||||
return web.HTTPFound("/")
|
||||
|
|
Loading…
Reference in New Issue