microblog.pub/app/utils/workers.py

68 lines
2.2 KiB
Python
Raw Normal View History

2022-08-10 20:39:19 +02:00
import asyncio
import signal
from typing import Generic
from typing import TypeVar
from loguru import logger
from app.database import AsyncSession
from app.database import async_session
# Type of the messages a Worker subclass fetches (get_next_message) and
# handles (process_message); fixed per subclass via Worker[...].
T = TypeVar("T")
class Worker(Generic[T]):
    """Base class for a long-running, signal-aware polling worker.

    Subclasses implement ``get_next_message`` (return the next unit of work,
    or ``None`` when there is nothing to do) and ``process_message`` (handle
    one unit of work). ``startup`` is an optional hook run once before polling
    begins. ``run_forever`` drives the polling loop until SIGHUP, SIGTERM or
    SIGINT is received, then shuts down.
    """

    def __init__(self) -> None:
        # NOTE(review): asyncio.get_event_loop() is deprecated when no loop is
        # running; presumably instances are created from inside a running
        # loop — confirm with callers. run_forever() rebinds this attribute to
        # the running loop in any case.
        self._loop = asyncio.get_event_loop()
        # Set by _shutdown() when a handled signal arrives; polled by
        # _main_loop() and awaited by _until_stopped().
        self._stop_event = asyncio.Event()

    async def process_message(self, db_session: AsyncSession, message: T) -> None:
        """Handle a single message. Must be overridden by subclasses."""
        raise NotImplementedError

    async def get_next_message(self, db_session: AsyncSession) -> T | None:
        """Return the next message to process, or ``None`` if there is none.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    async def startup(self, db_session: AsyncSession) -> None:
        """Optional hook run once, before the polling loop starts."""
        return None

    async def _main_loop(self, db_session: AsyncSession) -> None:
        """Poll for messages and process them until a stop is requested.

        When no message is available, sleep one second before polling again.
        """
        while not self._stop_event.is_set():
            next_message = await self.get_next_message(db_session)
            # Compare against None explicitly: the contract is ``T | None``,
            # so a falsy-but-valid message must still be processed.
            if next_message is not None:
                await self.process_message(db_session, next_message)
            else:
                await asyncio.sleep(1)

    async def _until_stopped(self) -> None:
        """Block until a stop has been requested."""
        await self._stop_event.wait()

    async def run_forever(self) -> None:
        """Run the worker until SIGHUP, SIGTERM or SIGINT is received.

        Installs signal handlers, opens a database session, calls ``startup``,
        then polls via ``_main_loop``. When the main loop ends or a stop is
        requested, waits a 5-second grace period, cancels every other
        outstanding task and waits for them to finish.
        """
        # This coroutine is always running on a loop; install the handlers on
        # that loop so they fire even if the instance was built elsewhere.
        self._loop = asyncio.get_running_loop()
        for sig in (signal.SIGHUP, signal.SIGTERM, signal.SIGINT):
            self._loop.add_signal_handler(
                sig,
                # Bind ``sig`` as a default argument to avoid late binding.
                lambda sig=sig: asyncio.create_task(self._shutdown(sig)),
            )

        async with async_session() as db_session:
            await self.startup(db_session)
            main_task = self._loop.create_task(self._main_loop(db_session))
            stop_task = self._loop.create_task(self._until_stopped())

            done, pending = await asyncio.wait(
                {main_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
            )
            # If the main loop ended on its own because of an error, report it
            # rather than leaving the exception unretrieved.
            if main_task in done and not main_task.cancelled():
                exc = main_task.exception()
                if exc is not None:
                    logger.opt(exception=exc).error("Worker main loop crashed")
            logger.info(f"Waiting for tasks to finish {done=}/{pending=}")
            # Grace period so in-flight work can complete before cancellation.
            await asyncio.sleep(5)

            current = asyncio.current_task()
            tasks = [t for t in asyncio.all_tasks() if t is not current]
            logger.info(f"Cancelling {len(tasks)} tasks")
            for outstanding in tasks:
                outstanding.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)

        logger.info("stopping loop")

    async def _shutdown(self, sig: signal.Signals) -> None:
        """Log the received signal and request the worker to stop."""
        # Log the received signal itself (previously this interpolated the
        # ``signal`` module by mistake).
        logger.info(f"Caught {sig=}")
        self._stop_event.set()