fix some issues in authentication

This commit is contained in:
odysseusmax 2021-06-14 21:23:33 +05:30
parent bf9348699a
commit e566d5048c
8 changed files with 106 additions and 28 deletions

View File

@ -33,18 +33,20 @@ class Indexer:
TEMPLATES_ROOT = pathlib.Path(__file__).parent / "templates" TEMPLATES_ROOT = pathlib.Path(__file__).parent / "templates"
def __init__(self): def __init__(self):
self.server = web.Application( middlewares = []
middlewares=[ if authenticated:
middlewares.append(
session_middleware( session_middleware(
EncryptedCookieStorage( EncryptedCookieStorage(
secret_key=SECRET_KEY.encode(), secret_key=SECRET_KEY.encode(),
max_age=60 * SESSION_COOKIE_LIFETIME, max_age=60 * SESSION_COOKIE_LIFETIME,
cookie_name="TG_INDEX_SESSION" cookie_name="TG_INDEX_SESSION",
) )
), )
middleware_factory(), )
]
) middlewares.append(middleware_factory())
self.server = web.Application(middlewares=middlewares)
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self.tg_client = Client(session_string, api_id, api_hash) self.tg_client = Client(session_string, api_id, api_hash)

View File

@ -22,16 +22,30 @@ async def setup_routes(app, handler):
web.get("/login", h.login_get, name="login_page"), web.get("/login", h.login_get, name="login_page"),
web.post("/login", h.login_post, name="login_handle"), web.post("/login", h.login_post, name="login_handle"),
web.get("/logout", h.logout_get, name="logout"), web.get("/logout", h.logout_get, name="logout"),
web.get("/favicon.ico", h.faviconicon, name="favicon"),
] ]
def get_common_routes(p): def get_common_routes(alias_id):
p = "/{chat:" + alias_id + "}"
return [ return [
web.get(p, h.index), web.get(p, h.index, name=f"index_{alias_id}"),
web.get(p + r"/logo", h.logo), web.get(p + r"/logo", h.logo, name=f"logo_{alias_id}"),
web.get(p + r"/{id:\d+}/view", h.info), web.get(p + r"/{id:\d+}/view", h.info, name=f"info_{alias_id}"),
web.get(p + r"/{id:\d+}/thumbnail", h.thumbnail_get), web.get(
web.get(p + r"/{id:\d+}/{filename}", h.download_get), p + r"/{id:\d+}/thumbnail",
web.head(p + r"/{id:\d+}/{filename}", h.download_head), h.thumbnail_get,
name=f"thumbnail_get_{alias_id}",
),
web.get(
p + r"/{id:\d+}/{filename}",
h.download_get,
name=f"download_get_{alias_id}",
),
web.head(
p + r"/{id:\d+}/{filename}",
h.download_head,
name=f"download_head_{alias_id}",
),
] ]
if index_all: if index_all:
@ -56,16 +70,14 @@ async def setup_routes(app, handler):
continue continue
alias_id = h.generate_alias_id(chat) alias_id = h.generate_alias_id(chat)
p = "/{chat:" + alias_id + "}" routes.extend(get_common_routes(alias_id))
routes.extend(get_common_routes(p))
log.debug(f"Index added for {chat.id} at /{alias_id}") log.debug(f"Index added for {chat.id} at /{alias_id}")
else: else:
for chat_id in include_chats: for chat_id in include_chats:
chat = await client.get_entity(chat_id) chat = await client.get_entity(chat_id)
alias_id = h.generate_alias_id(chat) alias_id = h.generate_alias_id(chat)
p = "/{chat:" + alias_id + "}" routes.extend(get_common_routes(alias_id)) # returns list() of common routes
routes.extend(get_common_routes(p)) # returns list() of common routes
log.debug(f"Index added for {chat.id} at /{alias_id}") log.debug(f"Index added for {chat.id} at /{alias_id}")
routes.append(web.view(r"/{wildcard:.*}", h.wildcard)) routes.append(web.view(r"/{wildcard:.*}", h.wildcard))
app.add_routes(routes) app.add_routes(routes)

View File

@ -12,27 +12,37 @@ class Client(TelegramClient):
self.log = logging.getLogger(__name__) self.log = logging.getLogger(__name__)
async def download(self, file, file_size, offset, limit): async def download(self, file, file_size, offset, limit):
part_size_kb = utils.get_appropriated_part_size(file_size) part_size = utils.get_appropriated_part_size(file_size) * 1024
part_size = int(part_size_kb * 1024)
first_part_cut = offset % part_size first_part_cut = offset % part_size
first_part = math.floor(offset / part_size) first_part = math.floor(offset / part_size)
last_part_cut = part_size - (limit % part_size) last_part_cut = part_size - (limit % part_size)
last_part = math.ceil(limit / part_size) last_part = math.ceil(limit / part_size)
part_count = math.ceil(file_size / part_size) part_count = math.ceil(file_size / part_size)
part = first_part part = first_part
self.log.debug(
f"""Request Details
part_size(bytes) = {part_size},
first_part = {first_part}, cut = {first_part_cut}(length={part_size-first_part_cut}),
last_part = {last_part}, cut = {last_part_cut}(length={last_part_cut}),
parts_count = {part_count}
"""
)
try: try:
async for chunk in self.iter_download( async for chunk in self.iter_download(
file, offset=first_part * part_size, request_size=part_size file, offset=first_part * part_size, request_size=part_size
): ):
self.log.debug(f"Part {part}/{last_part} (total {part_count}) served!")
if part == first_part: if part == first_part:
yield chunk[first_part_cut:] yield chunk[first_part_cut:]
elif part == last_part - 1: elif part == last_part:
yield chunk[:last_part_cut] yield chunk[:last_part_cut]
break
else: else:
yield chunk yield chunk
self.log.debug(f"Part {part}/{last_part} (total {part_count}) served!")
part += 1 part += 1
self.log.debug("serving finished")
self.log.debug(f"serving finished")
except (GeneratorExit, StopAsyncIteration, asyncio.CancelledError): except (GeneratorExit, StopAsyncIteration, asyncio.CancelledError):
self.log.debug("file serve interrupted") self.log.debug("file serve interrupted")
raise raise

View File

@ -12,6 +12,7 @@ from .logo_view import LogoView
from .thumbnail_view import ThumbnailView from .thumbnail_view import ThumbnailView
from .login_view import LoginView from .login_view import LoginView
from .logout_view import LogoutView from .logout_view import LogoutView
from .faviconicon_view import FaviconIconView
from .middlewhere import middleware_factory from .middlewhere import middleware_factory
@ -25,6 +26,7 @@ class Views(
WildcardView, WildcardView,
LoginView, LoginView,
LogoutView, LogoutView,
FaviconIconView,
): ):
def __init__(self, client): def __init__(self, client):
self.client = client self.client = client
@ -38,7 +40,7 @@ class Views(
while True: while True:
orig_id = f"{chat_id}" # the original id orig_id = f"{chat_id}" # the original id
unique_hash = hashlib.md5(orig_id.encode()).digest() unique_hash = hashlib.md5(orig_id.encode()).digest()
alias_id = base64.urlsafe_b64encode(unique_hash).decode()[: self.url_len] alias_id = base64.b64encode(unique_hash, b"__").decode()[: self.url_len]
if alias_id in self.chat_ids: if alias_id in self.chat_ids:
self.url_len += ( self.url_len += (

View File

@ -0,0 +1,33 @@
import random
from PIL import Image, ImageDraw, ImageFont
from aiohttp import web
from app.config import logo_folder
class FaviconIconView:
async def faviconicon(self, req):
favicon_path = logo_folder.joinpath("favicon.ico")
text = "T"
if not favicon_path.exists():
W, H = (360, 360)
color = tuple((random.randint(0, 255) for _ in range(3)))
im = Image.new("RGB", (W, H), color)
draw = ImageDraw.Draw(im)
font = ImageFont.truetype("arial.ttf", 50)
w, h = draw.textsize(text, font=font)
draw.text(((W - w) / 2, (H - h) / 2), text, fill="white", font=font)
im.save(favicon_path)
with open(favicon_path, "rb") as fp:
body = fp.read()
return web.Response(
status=200,
body=body,
headers={
"Content-Type": "image/jpeg",
"Content-Disposition": 'inline; filename="favicon.ico"',
},
)

View File

@ -105,5 +105,5 @@ class IndexView:
"block_downloads": block_downloads, "block_downloads": block_downloads,
"m3u_option": "" "m3u_option": ""
if not req.app["is_authenticated"] if not req.app["is_authenticated"]
else f"{req.app['is_authenticated']}:{req.app['is_authenticated']}@", else f"{req.app['username']}:{req.app['password']}@",
} }

View File

@ -1,5 +1,4 @@
import logging import logging
import math
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
import random import random

View File

@ -1,7 +1,7 @@
import time import time
import logging import logging
from aiohttp.web import middleware, HTTPFound from aiohttp.web import middleware, HTTPFound, Response
from aiohttp import BasicAuth, hdrs from aiohttp import BasicAuth, hdrs
from aiohttp_session import get_session from aiohttp_session import get_session
@ -12,6 +12,17 @@ log = logging.getLogger(__name__)
def _do_basic_auth_check(request): def _do_basic_auth_check(request):
auth_header = request.headers.get(hdrs.AUTHORIZATION) auth_header = request.headers.get(hdrs.AUTHORIZATION)
if not auth_header: if not auth_header:
if "download_" in request.match_info.route.name:
return Response(
body=b"",
status=401,
reason="UNAUTHORIZED",
headers={
hdrs.WWW_AUTHENTICATE: 'Basic realm=""',
hdrs.CONTENT_TYPE: "text/html; charset=utf-8",
hdrs.CONNECTION: "keep-alive",
},
)
return return
try: try:
@ -25,6 +36,12 @@ def _do_basic_auth_check(request):
if auth.login is None or auth.password is None: if auth.login is None or auth.password is None:
return return
if (
auth.login != request.app["username"]
or auth.password != request.app["password"]
):
return
return True return True
@ -50,7 +67,7 @@ def middleware_factory():
basic_auth_check_resp = _do_basic_auth_check(request) basic_auth_check_resp = _do_basic_auth_check(request)
if basic_auth_check_resp is not None: if basic_auth_check_resp is True:
return await handler(request) return await handler(request)
cookies_auth_check_resp = await _do_cookies_auth_check(request) cookies_auth_check_resp = await _do_cookies_auth_check(request)
@ -58,6 +75,9 @@ def middleware_factory():
if cookies_auth_check_resp is not None: if cookies_auth_check_resp is not None:
return await handler(request) return await handler(request)
if isinstance(basic_auth_check_resp, Response):
return basic_auth_check_resp
return HTTPFound(url) return HTTPFound(url)
return await handler(request) return await handler(request)