Fully implemented batcher (without testing)

Michail Kostochka 2024-10-23 11:54:32 +03:00
parent 504adfb263
commit 62cb52a8ae
22 changed files with 331 additions and 27 deletions

batcher/__init__.py (new file, empty)

batcher/app/main.py (new file, 27 lines)

@@ -0,0 +1,27 @@
from fastapi import Depends, FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException

from .src.routers.api import router as router_api
from .src.routers.handlers import http_error_handler


def get_application() -> FastAPI:
    application = FastAPI()

    application.include_router(router_api, prefix='/api')
    application.add_exception_handler(HTTPException, http_error_handler)

    application.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    return application


app = get_application()
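Not part of the commit: a minimal sketch of serving this app locally, assuming uvicorn is installed and the package is importable as batcher.app.main; the host and port below are placeholders (the service reads HTTP_PORT from its environment in config.py).

# run_local.py (illustrative only, not in this commit)
import uvicorn

if __name__ == '__main__':
    # 8080 is a placeholder; the deployed service presumably uses HTTP_PORT
    uvicorn.run('batcher.app.main:app', host='0.0.0.0', port=8080)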

batcher/app/src/config.py (new file, 29 lines)

@@ -0,0 +1,29 @@
from starlette.config import Config
from starlette.datastructures import Secret
from functools import lru_cache

config = Config('.env')

REDIS_USER = config('REDIS_USER')
REDIS_PASSWORD = config('REDIS_PASSWORD', cast=Secret)
REDIS_PORT = config('REDIS_PORT', cast=int)
REDIS_HOST = config('REDIS_HOST')
REDIS_DB = config('REDIS_DB')

HTTP_PORT = config('HTTP_PORT', cast=int)

PG_HOST = config('POSTGRES_HOST')
PG_PORT = config('POSTGRES_PORT', cast=int)
PG_USER = config('POSTGRES_USER')
PG_PASSWORD = config('POSTGRES_PASSWORD', cast=Secret)
PG_DB = config('POSTGRES_DB')

RMQ_HOST = config('RABBITMQ_HOST')
RMQ_PORT = config('RABBITMQ_PORT', cast=int)
RMQ_USER = config('RABBITMQ_DEFAULT_USER')
RMQ_PASSWORD = config('RABBITMQ_DEFAULT_PASSWORD', cast=Secret)

TG_TOKEN = config('TG_TOKEN', cast=Secret)

BACKEND_URL = config('BACKEND_URL', default='http://backend:8000')

batcher/app/src/db/__init__.py (new file, 3 lines)

@@ -0,0 +1,3 @@
from .pg import get_pg
from .redis import get_redis
from .rmq import get_rmq

batcher/app/src/db/pg.py (new file, 19 lines)

@@ -0,0 +1,19 @@
import asyncpg

from ..config import PG_HOST, PG_PORT, PG_USER, PG_PASSWORD, PG_DB

DB_URL = f'postgresql://{PG_USER}:{str(PG_PASSWORD)}@{PG_HOST}:{PG_PORT}/{PG_DB}'
pool = None


async def connect_db() -> asyncpg.Pool:
    return await asyncpg.create_pool(DB_URL)


async def get_pg() -> asyncpg.Connection:
    # Create the pool lazily in the running event loop; asyncio.run(connect_db())
    # at import time would bind the pool to a loop that is closed right away.
    global pool
    if pool is None:
        pool = await connect_db()
    async with pool.acquire() as conn:
        yield conn

batcher/app/src/db/redis.py (new file, 11 lines)

@@ -0,0 +1,11 @@
import redis.asyncio as redis

from ..config import REDIS_HOST, REDIS_PORT, REDIS_USER, REDIS_PASSWORD, REDIS_DB

# redis.Redis() is a plain constructor (not a coroutine), so no asyncio.run() here; the client connects lazily
r = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, username=REDIS_USER, password=str(REDIS_PASSWORD), db=REDIS_DB)


def get_redis() -> redis.Redis:
    yield r

batcher/app/src/db/rmq.py (new file, 26 lines)

@@ -0,0 +1,26 @@
import aio_pika
from aio_pika.abc import AbstractRobustConnection
from aio_pika.pool import Pool

from ..config import RMQ_HOST, RMQ_PORT, RMQ_USER, RMQ_PASSWORD


async def get_connection() -> AbstractRobustConnection:
    return await aio_pika.connect_robust(f'amqp://{RMQ_USER}:{RMQ_PASSWORD}@{RMQ_HOST}:{RMQ_PORT}/')


conn_pool = Pool(get_connection, max_size=2)


async def get_channel() -> aio_pika.Channel:
    async with conn_pool.acquire() as connection:
        return await connection.channel()


chan_pool = Pool(get_channel, max_size=10)


async def get_rmq() -> aio_pika.Channel:
    async with chan_pool.acquire() as chan:
        yield chan

batcher/app/src/dependencies.py (new file, 52 lines)

@@ -0,0 +1,52 @@
import time
import hmac
import base64
import hashlib
import json
from typing import Tuple

from fastapi import Header, HTTPException

from .config import TG_TOKEN


async def get_token_header(authorization: str = Header()) -> Tuple[int, str]:
    if not authorization:
        raise HTTPException(status_code=403, detail='Unauthorized')
    if not authorization.startswith('TelegramToken '):
        raise HTTPException(status_code=403, detail='Unauthorized')
    token = ' '.join(authorization.split()[1:])
    split_res = base64.b64decode(token).decode('utf-8').split(':')
    try:
        data_check_string = ':'.join(split_res[:-1]).strip().replace('/', '\\/')
        _hash = split_res[-1]
    except IndexError:
        raise HTTPException(status_code=403, detail='Unauthorized')

    # secret = HMAC_SHA256(key="WebAppData", msg=bot_token), per Telegram's Web App validation scheme
    secret = hmac.new(
        'WebAppData'.encode(),
        str(TG_TOKEN).encode('utf-8'),
        digestmod=hashlib.sha256
    ).digest()
    actual_hash = hmac.new(
        secret,
        msg=data_check_string.encode('utf-8'),
        digestmod=hashlib.sha256
    ).hexdigest()
    if _hash != actual_hash:
        raise HTTPException(status_code=403, detail='Unauthorized')

    data_dict = dict([x.split('=', 1) for x in data_check_string.split('\n')])
    try:
        auth_date = int(data_dict['auth_date'])
    except KeyError:
        raise HTTPException(status_code=403, detail='Unauthorized')
    except ValueError:
        raise HTTPException(status_code=403, detail='Unauthorized')
    if auth_date + 60 * 30 < int(time.time()):
        raise HTTPException(status_code=403, detail='Unauthorized')

    user_info = json.loads(data_dict['user'])
    return user_info['id'], authorization
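Not part of the commit: a sketch of building an Authorization header that get_token_header would accept, useful for manual testing. The bot token and the initData fields are placeholders; the only assumption beyond the code above is that the client base64-encodes "data_check_string:hash".

# build_auth_header.py (illustrative only, not in this commit)
import base64
import hashlib
import hmac
import json
import time

BOT_TOKEN = '123456:ABC-placeholder'  # hypothetical bot token

pairs = {
    'auth_date': str(int(time.time())),
    'query_id': 'AAH-placeholder',
    'user': json.dumps({'id': 42, 'first_name': 'Test'}),
}
data_check_string = '\n'.join(f'{k}={v}' for k, v in pairs.items())

# same construction as in get_token_header above
secret = hmac.new('WebAppData'.encode(), BOT_TOKEN.encode(), hashlib.sha256).digest()
_hash = hmac.new(secret, data_check_string.encode(), hashlib.sha256).hexdigest()

token = base64.b64encode(f'{data_check_string}:{_hash}'.encode()).decode()
header = {'Authorization': f'TelegramToken {token}'}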

batcher/app/src/domain/click/__init__.py (new file, 2 lines)

@@ -0,0 +1,2 @@
from .schemas import ClickResponse, BatchClickRequest, EnergyResponse, ClickValueResponse
from .usecase import add_click_batch_copy, check_registration, check_energy, get_energy, click_value, delete_user_info

batcher/app/src/domain/click/repos/rmq.py (modified)

@@ -1,29 +1,22 @@
 import json
-import kombu
+import aio_pika
 import uuid
 
 from ..models import Click
 
 CELERY_QUEUE_NAME = "celery"
-SETTING_QUEUE_NAME = "settings"
 CLICK_TASK_NAME = "clicks.celery.click.handle_click"
-SETTING_TASK_NAME = "misc.celery.deliver_setting.deliver_setting"
 
 
-def send_click_batch_copy(conn: kombu.Connection, click: Click, count: int):
-    producer = kombu.Producer(conn)
-    producer.publish(
-        json.dumps({
+async def send_click_batch_copy(chan: aio_pika.Channel, click: Click, count: int):
+    await chan.default_exchange.publish(
+        message=aio_pika.Message(json.dumps({
             'id': str(uuid.uuid4()),
             'task': CLICK_TASK_NAME,
             'args': [click.UserID, int(click.DateTime.timestamp() * 1e3), str(click.Value), count],
             'kwargs': dict(),
-        }),
+        }).encode('utf-8')),
         routing_key=CELERY_QUEUE_NAME,
-        delivery_mode='persistent',
         mandatory=False,
-        immediate=False,
-        content_type='application/json',
-        serializer='json',
     )
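The publisher above hand-crafts a Celery-protocol message. Not part of this commit: a sketch of what the consuming task on the backend side presumably looks like; the Celery app, broker URL, and argument names are assumptions based on the 'args' list being published.

# Hypothetical consumer-side counterpart (not in this commit)
from celery import Celery

app = Celery('clicks', broker='amqp://guest:guest@rabbitmq:5672//')  # placeholder broker URL

@app.task(name='clicks.celery.click.handle_click')
def handle_click(user_id, timestamp_ms, value, count):
    # persist/aggregate the batched clicks on the backend side
    ...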

batcher/app/src/domain/click/schemas.py (modified)

@@ -14,5 +14,9 @@ class ClickValueResponse(pydantic.BaseModel):
     value: decimal.Decimal
 
 
+class EnergyResponse(pydantic.BaseModel):
+    energy: int
+
+
 class BatchClickRequest(pydantic.BaseModel):
     count: int

batcher/app/src/domain/click/usecase.py (modified)

@@ -3,9 +3,10 @@ import decimal
 from typing import Tuple
 
 import aiohttp
 import redis.asyncio as redis
-import kombu
+import aio_pika
 import asyncpg
 
+from batcher.app.src.domain.setting import get_setting
 from .repos.redis import (
     get_period_sum, incr_period_sum, get_max_period_sum, get_user_total, get_global_average,
     incr_user_count_if_no_clicks, update_global_average, incr_user_total, compare_max_period_sum,
@@ -14,22 +15,13 @@ from .repos.redis import (
 )
 from .repos.pg import update_click_expiry, bulk_store_copy, delete_by_user_id
 from .repos.rmq import send_click_batch_copy
 from .models import Click
 
-SETTING_DICT = {
-    'PRICE_PER_CLICK': decimal.Decimal(1),
-    'DAY_MULT': decimal.Decimal(1),
-    'WEEK_MULT': decimal.Decimal(1),
-    'PROGRESS_MULT': decimal.Decimal(1),
-    'SESSION_ENERGY': decimal.Decimal(500),
-}
-
 PRECISION = 2
 
 
-async def add_click_batch_copy(r: redis.Redis, pg: asyncpg.Connection, rmq: kombu.Connection, user_id: int, count: int) -> Click:
+async def add_click_batch_copy(r: redis.Redis, pg: asyncpg.Connection, rmq: aio_pika.Channel, user_id: int, count: int) -> Click:
     _click_value = await click_value(r, pg, user_id)
     click_value_sum = _click_value * count
@@ -131,7 +123,3 @@ async def check_energy(r: redis.Redis, user_id: int, amount: int, _token: str) -
 
 async def get_energy(r: redis.Redis, user_id: int, _token: str) -> int:
     return await _get_refresh_energy(r, user_id, _token)
-
-
-def get_setting(name: str) -> decimal.Decimal:
-    return SETTING_DICT[name]

batcher/app/src/domain/setting/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
from .usecase import get_setting, launch_consumer

batcher/app/src/domain/setting/repos/in_memory_storage.py (new file, 21 lines)

@@ -0,0 +1,21 @@
import decimal
import threading

_settings = dict()
mx = threading.Lock()


def get_setting(name: str) -> decimal.Decimal:
    with mx:
        return _settings[name]


def set_setting(name: str, value: decimal.Decimal):
    with mx:
        _settings[name] = value
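Not part of the commit: a quick usage sketch. The module path is assumed from the surrounding imports, and values are stored as decimal.Decimal per the type hints.

import decimal
from batcher.app.src.domain.setting.repos.in_memory_storage import get_setting, set_setting

set_setting('PRICE_PER_CLICK', decimal.Decimal('1.5'))
assert get_setting('PRICE_PER_CLICK') == decimal.Decimal('1.5')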

batcher/app/src/domain/setting/repos/rmq.py (new file, 20 lines)

@@ -0,0 +1,20 @@
import decimal
import json

import aio_pika
from typing import Callable

SETTING_QUEUE_NAME = "settings"
SETTING_TASK_NAME = "misc.celery.deliver_setting.deliver_setting"


async def consume_setting_updates(update_setting_func: Callable[[str, decimal.Decimal], None], chan: aio_pika.Channel):
    queue = await chan.get_queue(SETTING_QUEUE_NAME)
    async with queue.iterator() as queue_iter:
        async for msg in queue_iter:
            async with msg.process():
                settings = json.loads(msg.body.decode('utf-8'))
                for name, value in settings.items():
                    # convert via str() so JSON numbers become exact Decimals, matching the callback signature
                    update_setting_func(name, decimal.Decimal(str(value)))
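Not part of the commit: a sketch of a producer that this consumer would accept, assuming a plain JSON object is published to the settings queue. The SETTING_TASK_NAME constant hints the real producer may be Celery-based, in which case the body format would differ; the values below are placeholders.

# Hypothetical producer for the settings queue (not in this commit)
import json
import aio_pika

async def publish_settings_update(chan: aio_pika.Channel) -> None:
    # the consumer above expects a JSON object mapping setting names to numeric values
    body = json.dumps({'PRICE_PER_CLICK': '1.5', 'SESSION_ENERGY': '500'}).encode('utf-8')
    await chan.default_exchange.publish(
        aio_pika.Message(body),
        routing_key='settings',  # SETTING_QUEUE_NAME
    )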

batcher/app/src/domain/setting/usecase.py (new file, 16 lines)

@@ -0,0 +1,16 @@
import asyncio
import decimal
import threading

import aio_pika

from .repos.in_memory_storage import get_setting as ims_get_setting, set_setting as ims_set_setting
from .repos.rmq import consume_setting_updates


def get_setting(name: str) -> decimal.Decimal:
    return ims_get_setting(name)


def launch_consumer(rmq: aio_pika.Channel):
    # consume_setting_updates is a coroutine and needs the setter as its callback,
    # so run it in its own event loop on a daemon thread
    t = threading.Thread(target=lambda: asyncio.run(consume_setting_updates(ims_set_setting, rmq)), daemon=True)
    t.start()

batcher/app/src/routers/__init__.py (new file, 2 lines)

@@ -0,0 +1,2 @@
from .api import router
from .handlers import http_error_handler

batcher/app/src/routers/api.py (new file, 11 lines)

@@ -0,0 +1,11 @@
from fastapi import APIRouter

from . import click

router = APIRouter()


def include_api_routes():
    router.include_router(click.router, prefix='/v1')


include_api_routes()

batcher/app/src/routers/click.py (new file, 71 lines)

@@ -0,0 +1,71 @@
import aio_pika
import asyncpg
import redis.asyncio as redis
from fastapi import APIRouter, Depends, HTTPException
from typing import Annotated, Tuple

from ..domain.click import (
    ClickResponse, BatchClickRequest, EnergyResponse, ClickValueResponse,
    add_click_batch_copy, check_registration, check_energy, get_energy, click_value, delete_user_info
)
from ..dependencies import get_token_header
from ..db import get_pg, get_redis, get_rmq
from ..config import BACKEND_URL

router = APIRouter(
    prefix="",
    tags=['click'],
    dependencies=[],
    responses={404: {'description': 'Not found'}},
)


@router.post("/batch-click/", response_model=ClickResponse, status_code=200)
async def batch_click(req: BatchClickRequest, auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], pg: Annotated[asyncpg.Connection, Depends(get_pg)], r: Annotated[redis.Redis, Depends(get_redis)], rmq: Annotated[aio_pika.Channel, Depends(get_rmq)]):
    user_id, token = auth_info
    if not await check_registration(r, user_id, token, BACKEND_URL):
        raise HTTPException(status_code=403, detail='Unauthorized')

    _energy, spent = await check_energy(r, user_id, req.count, token)
    if spent == 0:
        raise HTTPException(status_code=400, detail='No energy')

    click = await add_click_batch_copy(r, pg, rmq, user_id, spent)
    return ClickResponse(
        click=click,
        energy=_energy
    )


@router.get("/energy", response_model=EnergyResponse, status_code=200)
async def energy(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], r: Annotated[redis.Redis, Depends(get_redis)]):
    user_id, token = auth_info
    if not await check_registration(r, user_id, token, BACKEND_URL):
        raise HTTPException(status_code=403, detail='Unauthorized')

    _energy = await get_energy(r, user_id, token)
    return EnergyResponse(
        energy=_energy
    )


@router.get('/coefficient', response_model=ClickValueResponse, status_code=200)
async def coefficient(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], r: Annotated[redis.Redis, Depends(get_redis)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]):
    user_id, token = auth_info
    if not await check_registration(r, user_id, token, BACKEND_URL):
        raise HTTPException(status_code=403, detail='Unauthorized')

    value = await click_value(r, pg, user_id)
    return ClickValueResponse(
        value=value
    )


@router.delete('/internal/user', status_code=204)
async def delete_user(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], r: Annotated[redis.Redis, Depends(get_redis)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]):
    user_id, token = auth_info
    if not await check_registration(r, user_id, token, BACKEND_URL):
        raise HTTPException(status_code=403, detail='Unauthorized')

    await delete_user_info(r, pg, user_id)
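Not part of the commit: a sketch of calling the batcher from a client. The route resolves to /api/v1/batch-click/ via the '/api' prefix in main.py and the '/v1' prefix in api.py; the base URL is a placeholder, and `header` is an Authorization mapping built as in the dependencies.py sketch above.

# Hypothetical client call (not in this commit)
import asyncio
import httpx

async def send_clicks(header: dict, count: int) -> dict:
    async with httpx.AsyncClient(base_url='http://localhost:8080') as client:  # placeholder host/port
        resp = await client.post('/api/v1/batch-click/', json={'count': count}, headers=header)
        resp.raise_for_status()
        return resp.json()

# asyncio.run(send_clicks(header, 10))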

batcher/app/src/routers/handlers/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
from .http_error_handler import http_error_handler

batcher/app/src/routers/handlers/http_error_handler.py (new file, 7 lines)

@@ -0,0 +1,7 @@
from fastapi import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse


async def http_error_handler(_: Request, exc: HTTPException) -> JSONResponse:
    return JSONResponse({"error": exc.detail}, status_code=exc.status_code)