diff --git a/batcher/__init__.py b/batcher/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/batcher/app/main.py b/batcher/app/main.py new file mode 100644 index 0000000..9e6f5bd --- /dev/null +++ b/batcher/app/main.py @@ -0,0 +1,27 @@ +from fastapi import Depends, FastAPI, Request, Response +from fastapi.middleware.cors import CORSMiddleware +from starlette.exceptions import HTTPException + +from .src.routers.api import router as router_api +from .src.routers.handlers import http_error_handler + + +def get_application() -> FastAPI: + application = FastAPI() + + application.include_router(router_api, prefix='/api') + + application.add_exception_handler(HTTPException, http_error_handler) + + application.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + return application + + +app = get_application() \ No newline at end of file diff --git a/batcher/app/src/config.py b/batcher/app/src/config.py new file mode 100644 index 0000000..dce7c4f --- /dev/null +++ b/batcher/app/src/config.py @@ -0,0 +1,29 @@ +from starlette.config import Config +from starlette.datastructures import Secret +from functools import lru_cache + +config = Config('.env') + + +REDIS_USER = config('REDIS_USER') +REDIS_PASSWORD = config('REDIS_PASSWORD', cast=Secret) +REDIS_PORT = config('REDIS_PORT', cast=int) +REDIS_HOST = config('REDIS_HOST') +REDIS_DB = config('REDIS_DB') + +HTTP_PORT = config('HTTP_PORT', cast=int) + +PG_HOST = config('POSTGRES_HOST') +PG_PORT = config('POSTGRES_PORT', cast=int) +PG_USER = config('POSTGRES_USER') +PG_PASSWORD = config('POSTGRES_PASSWORD', cast=Secret) +PG_DB = config('POSTGRES_DB') + +RMQ_HOST = config('RABBITMQ_HOST') +RMQ_PORT = config('RABBITMQ_PORT', cast=int) +RMQ_USER = config('RABBITMQ_DEFAULT_USER') +RMQ_PASSWORD = config('RABBITMQ_DEFAULT_PASSWORD', cast=Secret) + +TG_TOKEN = config('TG_TOKEN', cast=Secret) + +BACKEND_URL = config('BACKEND_URL', default='http://backend:8000') diff --git a/batcher/app/src/db/__init__.py b/batcher/app/src/db/__init__.py new file mode 100644 index 0000000..ac05ff3 --- /dev/null +++ b/batcher/app/src/db/__init__.py @@ -0,0 +1,3 @@ +from .pg import get_pg +from .redis import get_redis +from .rmq import get_rmq \ No newline at end of file diff --git a/batcher/app/src/db/pg.py b/batcher/app/src/db/pg.py new file mode 100644 index 0000000..23f9ba5 --- /dev/null +++ b/batcher/app/src/db/pg.py @@ -0,0 +1,19 @@ +import asyncpg +import asyncio + +from ..config import PG_HOST, PG_PORT, PG_USER, PG_PASSWORD, PG_DB + + +DB_URL = f'postgresql://{PG_USER}:{str(PG_PASSWORD)}@{PG_HOST}:{PG_PORT}/{PG_DB}' + + +async def connect_db() -> asyncpg.Pool: + return await asyncpg.create_pool(DB_URL) + + +pool = asyncio.run(connect_db()) + + +async def get_pg() -> asyncpg.Connection: + async with pool.acquire() as conn: + yield conn \ No newline at end of file diff --git a/batcher/app/src/db/redis.py b/batcher/app/src/db/redis.py new file mode 100644 index 0000000..4430e3f --- /dev/null +++ b/batcher/app/src/db/redis.py @@ -0,0 +1,11 @@ +import asyncio +import redis.asyncio as redis + +from ..config import REDIS_HOST, REDIS_PORT, REDIS_USER, REDIS_PASSWORD, REDIS_DB + + +r = asyncio.run(redis.Redis(host=REDIS_HOST, port=REDIS_PORT, username=REDIS_USER, password=REDIS_PASSWORD, db=REDIS_DB)) + + +def get_redis() -> redis.Redis: + yield r diff --git a/batcher/app/src/db/rmq.py b/batcher/app/src/db/rmq.py new file mode 100644 index 0000000..eac9504 --- /dev/null +++ b/batcher/app/src/db/rmq.py @@ -0,0 +1,26 @@ +import aio_pika +from aio_pika.abc import AbstractRobustConnection +import asyncio + +from ..config import RMQ_HOST, RMQ_PORT, RMQ_USER, RMQ_PASSWORD + + +async def get_connection() -> AbstractRobustConnection: + return await aio_pika.connect_robust(f'amqp://{RMQ_USER}:{RMQ_PASSWORD}@{RMQ_HOST}:{RMQ_PORT}/') + + +conn_pool = aio_pika.pool.Pool(get_connection, max_size=2) + + +async def get_channel() -> aio_pika.Channel: + async with conn_pool.acquire() as connection: + return await connection.channel() + + +chan_pool = aio_pika.pool.Pool(get_channel, max_size=10) + + +async def get_rmq() -> aio_pika.Channel: + async with chan_pool.acquire() as chan: + yield chan + diff --git a/batcher/app/src/dependencies.py b/batcher/app/src/dependencies.py new file mode 100644 index 0000000..faee168 --- /dev/null +++ b/batcher/app/src/dependencies.py @@ -0,0 +1,52 @@ +import time +import hmac +import base64 +import hashlib +import json +from fastapi import Header, HTTPException + +from .config import TG_TOKEN + + +async def get_token_header(authorization: str = Header()) -> (int, str): + if not authorization: + raise HTTPException(status_code=403, detail='Unauthorized') + + if not authorization.startswith('TelegramToken '): + raise HTTPException(status_code=403, detail='Unauthorized') + + token = ' '.join(authorization.split()[1:]) + + split_res = base64.b64decode(token).decode('utf-8').split(':') + try: + data_check_string = ':'.join(split_res[:-1]).strip().replace('/', '\\/') + _hash = split_res[-1] + except IndexError: + raise HTTPException(status_code=403, detail='Unauthorized') + secret = hmac.new( + 'WebAppData'.encode(), + TG_TOKEN.encode('utf-8'), + digestmod=hashlib.sha256 + ).digest() + actual_hash = hmac.new( + secret, + msg=data_check_string.encode('utf-8'), + digestmod=hashlib.sha256 + ).hexdigest() + if hash != actual_hash: + raise HTTPException(status_code=403, detail='Unauthorized') + + data_dict = dict([x.split('=') for x in data_check_string.split('\n')]) + try: + auth_date = int(data_dict['auth_date']) + except KeyError: + raise HTTPException(status_code=403, detail='Unauthorized') + except ValueError: + raise HTTPException(status_code=403, detail='Unauthorized') + + if auth_date + 60 * 30 < int(time.time()): + raise HTTPException(status_code=403, detail='Unauthorized') + + user_info = json.loads(data_dict['user']) + return user_info['id'], authorization + diff --git a/batcher/app/src/domain/click/__init__.py b/batcher/app/src/domain/click/__init__.py index e69de29..5d30c07 100644 --- a/batcher/app/src/domain/click/__init__.py +++ b/batcher/app/src/domain/click/__init__.py @@ -0,0 +1,2 @@ +from .schemas import ClickResponse, BatchClickRequest, EnergyResponse, ClickValueResponse +from .usecase import add_click_batch_copy, check_registration, check_energy, get_energy, click_value, delete_user_info diff --git a/batcher/app/src/domain/click/repos/rmq.py b/batcher/app/src/domain/click/repos/rmq.py index 01cdece..a016a0d 100644 --- a/batcher/app/src/domain/click/repos/rmq.py +++ b/batcher/app/src/domain/click/repos/rmq.py @@ -1,29 +1,22 @@ import json -import kombu +import aio_pika import uuid from ..models import Click CELERY_QUEUE_NAME = "celery" -SETTING_QUEUE_NAME = "settings" CLICK_TASK_NAME = "clicks.celery.click.handle_click" -SETTING_TASK_NAME = "misc.celery.deliver_setting.deliver_setting" -def send_click_batch_copy(conn: kombu.Connection, click: Click, count: int): - producer = kombu.Producer(conn) - producer.publish( - json.dumps({ +def send_click_batch_copy(chan: aio_pika.Channel, click: Click, count: int): + await chan.default_exchange.publish( + message=aio_pika.Message(json.dumps({ 'id': str(uuid.uuid4()), 'task': CLICK_TASK_NAME, 'args': [click.UserID, int(click.DateTime.timestamp() * 1e3), str(click.Value), count], 'kwargs': dict(), - }), + }).encode('utf-8')), routing_key=CELERY_QUEUE_NAME, - delivery_mode='persistent', mandatory=False, - immediate=False, - content_type='application/json', - serializer='json', ) diff --git a/batcher/app/src/domain/click/schemas.py b/batcher/app/src/domain/click/schemas.py index 717c377..3eec6df 100644 --- a/batcher/app/src/domain/click/schemas.py +++ b/batcher/app/src/domain/click/schemas.py @@ -14,5 +14,9 @@ class ClickValueResponse(pydantic.BaseModel): value: decimal.Decimal +class EnergyResponse(pydantic.BaseModel): + energy: int + + class BatchClickRequest(pydantic.BaseModel): count: int diff --git a/batcher/app/src/domain/click/usecase.py b/batcher/app/src/domain/click/usecase.py index 4c87148..2158dc8 100644 --- a/batcher/app/src/domain/click/usecase.py +++ b/batcher/app/src/domain/click/usecase.py @@ -3,9 +3,10 @@ import decimal from typing import Tuple import aiohttp import redis.asyncio as redis -import kombu +import aio_pika import asyncpg +from batcher.app.src.domain.setting import get_setting from .repos.redis import ( get_period_sum, incr_period_sum, get_max_period_sum, get_user_total, get_global_average, incr_user_count_if_no_clicks, update_global_average, incr_user_total, compare_max_period_sum, @@ -14,22 +15,13 @@ from .repos.redis import ( ) from .repos.pg import update_click_expiry, bulk_store_copy, delete_by_user_id from .repos.rmq import send_click_batch_copy - from .models import Click -SETTING_DICT = { - 'PRICE_PER_CLICK': decimal.Decimal(1), - 'DAY_MULT': decimal.Decimal(1), - 'WEEK_MULT': decimal.Decimal(1), - 'PROGRESS_MULT': decimal.Decimal(1), - 'SESSION_ENERGY': decimal.Decimal(500), -} - PRECISION = 2 -async def add_click_batch_copy(r: redis.Redis, pg: asyncpg.Connection, rmq: kombu.Connection, user_id: int, count: int) -> Click: +async def add_click_batch_copy(r: redis.Redis, pg: asyncpg.Connection, rmq: aio_pika.Channel, user_id: int, count: int) -> Click: _click_value = await click_value(r, pg, user_id) click_value_sum = _click_value * count @@ -131,7 +123,3 @@ async def check_energy(r: redis.Redis, user_id: int, amount: int, _token: str) - async def get_energy(r: redis.Redis, user_id: int, _token: str) -> int: return await _get_refresh_energy(r, user_id, _token) - - -def get_setting(name: str) -> decimal.Decimal: - return SETTING_DICT[name] diff --git a/batcher/app/src/domain/setting/__init__.py b/batcher/app/src/domain/setting/__init__.py new file mode 100644 index 0000000..5d5610c --- /dev/null +++ b/batcher/app/src/domain/setting/__init__.py @@ -0,0 +1 @@ +from .usecase import get_setting, launch_consumer \ No newline at end of file diff --git a/batcher/app/src/domain/setting/repos/__init__.py b/batcher/app/src/domain/setting/repos/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/batcher/app/src/domain/setting/repos/in_memory_storage.py b/batcher/app/src/domain/setting/repos/in_memory_storage.py new file mode 100644 index 0000000..ac6e46c --- /dev/null +++ b/batcher/app/src/domain/setting/repos/in_memory_storage.py @@ -0,0 +1,21 @@ +import decimal +import threading + +_settings = dict() +mx = threading.Lock() + + +def get_setting(name: str) -> decimal.Decimal: + try: + mx.acquire() + return _settings[name] + finally: + mx.release() + + +def set_setting(name: str, value: decimal.Decimal): + try: + mx.acquire() + _settings[name] = value + finally: + mx.release() diff --git a/batcher/app/src/domain/setting/repos/rmq.py b/batcher/app/src/domain/setting/repos/rmq.py new file mode 100644 index 0000000..52d9de8 --- /dev/null +++ b/batcher/app/src/domain/setting/repos/rmq.py @@ -0,0 +1,20 @@ +import decimal +import json + +import aio_pika +from typing import Callable + + +SETTING_QUEUE_NAME = "settings" +SETTING_TASK_NAME = "misc.celery.deliver_setting.deliver_setting" + + +async def consume_setting_updates(update_setting_func: Callable[[str, decimal.Decimal], None], chan: aio_pika.Channel): + queue = await chan.get_queue(SETTING_QUEUE_NAME) + + async with queue.iterator() as queue_iter: + async for msg in queue_iter: + async with msg.process(): + settings = json.loads(msg.body.decode('utf-8')) + for name, value in settings.items(): + update_setting_func(name, value) diff --git a/batcher/app/src/domain/setting/usecase.py b/batcher/app/src/domain/setting/usecase.py new file mode 100644 index 0000000..11865a4 --- /dev/null +++ b/batcher/app/src/domain/setting/usecase.py @@ -0,0 +1,16 @@ +import decimal +import threading + +import aio_pika + +from .repos.in_memory_storage import get_setting as ims_get_setting +from .repos.rmq import consume_setting_updates + + +def get_setting(name: str) -> decimal.Decimal: + return ims_get_setting(name) + + +def launch_consumer(rmq: aio_pika.Connection): + t = threading.Thread(target=consume_setting_updates, args=(ims_get_setting, rmq)) + t.start() diff --git a/batcher/app/src/routers/__init__.py b/batcher/app/src/routers/__init__.py new file mode 100644 index 0000000..e94737c --- /dev/null +++ b/batcher/app/src/routers/__init__.py @@ -0,0 +1,2 @@ +from .api import router +from .handlers import http_error_handler \ No newline at end of file diff --git a/batcher/app/src/routers/api.py b/batcher/app/src/routers/api.py new file mode 100644 index 0000000..4db588a --- /dev/null +++ b/batcher/app/src/routers/api.py @@ -0,0 +1,11 @@ +from fastapi import APIRouter +from . import click + +router = APIRouter() + + +def include_api_routes(): + router.include_router(click.router, prefix='/v1') + + +include_api_routes() \ No newline at end of file diff --git a/batcher/app/src/routers/click.py b/batcher/app/src/routers/click.py new file mode 100644 index 0000000..65a2412 --- /dev/null +++ b/batcher/app/src/routers/click.py @@ -0,0 +1,71 @@ +import aio_pika +import asyncpg +import redis +from fastapi import APIRouter, Depends, HTTPException +from typing import Annotated +from ..domain.click import ( + ClickResponse, BatchClickRequest, EnergyResponse, ClickValueResponse, + add_click_batch_copy, check_registration, check_energy, get_energy, click_value, delete_user_info +) + +from ..dependencies import get_token_header +from ..db import get_pg, get_redis, get_rmq +from ..config import BACKEND_URL + + +router = APIRouter( + prefix="", + tags=['click'], + dependencies=[], + responses={404: {'description': 'Not found'}}, +) + + +@router.post("/batch-click/", response_model=ClickResponse, status_code=200) +async def batch_click(req: BatchClickRequest, auth_info: Annotated[(int, str), Depends(get_token_header)], pg: Annotated[asyncpg.Connection, Depends(get_pg)], r: Annotated[redis.Redis, Depends(get_redis)], rmq: Annotated[aio_pika.Channel, Depends(get_rmq)]): + user_id, token = auth_info + if not check_registration(r, user_id, token, BACKEND_URL): + raise HTTPException(status_code=403, detail='Unauthorized') + + _energy, spent = await check_energy(r, user_id, req.count, token) + if spent == 0: + raise HTTPException(status_code=400, detail='No energy') + + click = await add_click_batch_copy(r, pg, rmq, user_id, spent) + return ClickResponse( + click=click, + energy=_energy + ) + + +@router.get("/energy", response_model=EnergyResponse, status_code=200) +async def energy(auth_info: Annotated[(int, str), Depends(get_token_header)], r: Annotated[redis.Redis, Depends(get_redis)]): + user_id, token = auth_info + if not check_registration(r, user_id, token, BACKEND_URL): + raise HTTPException(status_code=403, detail='Unauthorized') + + _energy = await get_energy(r, user_id, token) + return EnergyResponse( + energy=_energy + ) + + +@router.get('/coefficient', response_model=ClickValueResponse, status_code=200) +async def coefficient(auth_info: Annotated[(int, str), Depends(get_token_header)], r: Annotated[redis.Redis, Depends(get_redis)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]): + user_id, token = auth_info + if not check_registration(r, user_id, token, BACKEND_URL): + raise HTTPException(status_code=403, detail='Unauthorized') + + value = await click_value(r, pg, user_id) + return ClickValueResponse( + value=value + ) + + +@router.delete('/internal/user', status_code=204) +async def delete_user(auth_info: Annotated[(int, str), Depends(get_token_header())], r: Annotated[redis.Redis, Depends(get_redis)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]): + user_id, token = auth_info + if not check_registration(r, user_id, token, BACKEND_URL): + raise HTTPException(status_code=403, detail='Unauthorized') + + await delete_user_info(r, pg, user_id) \ No newline at end of file diff --git a/batcher/app/src/routers/handlers/__init__.py b/batcher/app/src/routers/handlers/__init__.py new file mode 100644 index 0000000..ca85473 --- /dev/null +++ b/batcher/app/src/routers/handlers/__init__.py @@ -0,0 +1 @@ +from .http_error_handler import http_error_handler \ No newline at end of file diff --git a/batcher/app/src/routers/handlers/http_error_handler.py b/batcher/app/src/routers/handlers/http_error_handler.py new file mode 100644 index 0000000..a37fdd3 --- /dev/null +++ b/batcher/app/src/routers/handlers/http_error_handler.py @@ -0,0 +1,7 @@ +from fastapi import HTTPException +from starlette.requests import Request +from starlette.responses import JSONResponse + + +async def http_error_handler(_: Request, exc: HTTPException) -> JSONResponse: + return JSONResponse({"error": exc.detail}, status_code=exc.status_code)