Fully implemented batcher (without testing)
This commit is contained in:
parent
504adfb263
commit
62cb52a8ae
0
batcher/__init__.py
Normal file
0
batcher/__init__.py
Normal file
27
batcher/app/main.py
Normal file
27
batcher/app/main.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
from fastapi import Depends, FastAPI, Request, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from .src.routers.api import router as router_api
|
||||
from .src.routers.handlers import http_error_handler
|
||||
|
||||
|
||||
def get_application() -> FastAPI:
|
||||
application = FastAPI()
|
||||
|
||||
application.include_router(router_api, prefix='/api')
|
||||
|
||||
application.add_exception_handler(HTTPException, http_error_handler)
|
||||
|
||||
application.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
return application
|
||||
|
||||
|
||||
app = get_application()
|
29
batcher/app/src/config.py
Normal file
29
batcher/app/src/config.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
from starlette.config import Config
|
||||
from starlette.datastructures import Secret
|
||||
from functools import lru_cache
|
||||
|
||||
config = Config('.env')
|
||||
|
||||
|
||||
REDIS_USER = config('REDIS_USER')
|
||||
REDIS_PASSWORD = config('REDIS_PASSWORD', cast=Secret)
|
||||
REDIS_PORT = config('REDIS_PORT', cast=int)
|
||||
REDIS_HOST = config('REDIS_HOST')
|
||||
REDIS_DB = config('REDIS_DB')
|
||||
|
||||
HTTP_PORT = config('HTTP_PORT', cast=int)
|
||||
|
||||
PG_HOST = config('POSTGRES_HOST')
|
||||
PG_PORT = config('POSTGRES_PORT', cast=int)
|
||||
PG_USER = config('POSTGRES_USER')
|
||||
PG_PASSWORD = config('POSTGRES_PASSWORD', cast=Secret)
|
||||
PG_DB = config('POSTGRES_DB')
|
||||
|
||||
RMQ_HOST = config('RABBITMQ_HOST')
|
||||
RMQ_PORT = config('RABBITMQ_PORT', cast=int)
|
||||
RMQ_USER = config('RABBITMQ_DEFAULT_USER')
|
||||
RMQ_PASSWORD = config('RABBITMQ_DEFAULT_PASSWORD', cast=Secret)
|
||||
|
||||
TG_TOKEN = config('TG_TOKEN', cast=Secret)
|
||||
|
||||
BACKEND_URL = config('BACKEND_URL', default='http://backend:8000')
|
3
batcher/app/src/db/__init__.py
Normal file
3
batcher/app/src/db/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
from .pg import get_pg
|
||||
from .redis import get_redis
|
||||
from .rmq import get_rmq
|
19
batcher/app/src/db/pg.py
Normal file
19
batcher/app/src/db/pg.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
import asyncpg
|
||||
import asyncio
|
||||
|
||||
from ..config import PG_HOST, PG_PORT, PG_USER, PG_PASSWORD, PG_DB
|
||||
|
||||
|
||||
DB_URL = f'postgresql://{PG_USER}:{str(PG_PASSWORD)}@{PG_HOST}:{PG_PORT}/{PG_DB}'
|
||||
|
||||
|
||||
async def connect_db() -> asyncpg.Pool:
|
||||
return await asyncpg.create_pool(DB_URL)
|
||||
|
||||
|
||||
pool = asyncio.run(connect_db())
|
||||
|
||||
|
||||
async def get_pg() -> asyncpg.Connection:
|
||||
async with pool.acquire() as conn:
|
||||
yield conn
|
11
batcher/app/src/db/redis.py
Normal file
11
batcher/app/src/db/redis.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
import asyncio
|
||||
import redis.asyncio as redis
|
||||
|
||||
from ..config import REDIS_HOST, REDIS_PORT, REDIS_USER, REDIS_PASSWORD, REDIS_DB
|
||||
|
||||
|
||||
r = asyncio.run(redis.Redis(host=REDIS_HOST, port=REDIS_PORT, username=REDIS_USER, password=REDIS_PASSWORD, db=REDIS_DB))
|
||||
|
||||
|
||||
def get_redis() -> redis.Redis:
|
||||
yield r
|
26
batcher/app/src/db/rmq.py
Normal file
26
batcher/app/src/db/rmq.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
import aio_pika
|
||||
from aio_pika.abc import AbstractRobustConnection
|
||||
import asyncio
|
||||
|
||||
from ..config import RMQ_HOST, RMQ_PORT, RMQ_USER, RMQ_PASSWORD
|
||||
|
||||
|
||||
async def get_connection() -> AbstractRobustConnection:
|
||||
return await aio_pika.connect_robust(f'amqp://{RMQ_USER}:{RMQ_PASSWORD}@{RMQ_HOST}:{RMQ_PORT}/')
|
||||
|
||||
|
||||
conn_pool = aio_pika.pool.Pool(get_connection, max_size=2)
|
||||
|
||||
|
||||
async def get_channel() -> aio_pika.Channel:
|
||||
async with conn_pool.acquire() as connection:
|
||||
return await connection.channel()
|
||||
|
||||
|
||||
chan_pool = aio_pika.pool.Pool(get_channel, max_size=10)
|
||||
|
||||
|
||||
async def get_rmq() -> aio_pika.Channel:
|
||||
async with chan_pool.acquire() as chan:
|
||||
yield chan
|
||||
|
52
batcher/app/src/dependencies.py
Normal file
52
batcher/app/src/dependencies.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
import time
|
||||
import hmac
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
from fastapi import Header, HTTPException
|
||||
|
||||
from .config import TG_TOKEN
|
||||
|
||||
|
||||
async def get_token_header(authorization: str = Header()) -> (int, str):
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=403, detail='Unauthorized')
|
||||
|
||||
if not authorization.startswith('TelegramToken '):
|
||||
raise HTTPException(status_code=403, detail='Unauthorized')
|
||||
|
||||
token = ' '.join(authorization.split()[1:])
|
||||
|
||||
split_res = base64.b64decode(token).decode('utf-8').split(':')
|
||||
try:
|
||||
data_check_string = ':'.join(split_res[:-1]).strip().replace('/', '\\/')
|
||||
_hash = split_res[-1]
|
||||
except IndexError:
|
||||
raise HTTPException(status_code=403, detail='Unauthorized')
|
||||
secret = hmac.new(
|
||||
'WebAppData'.encode(),
|
||||
TG_TOKEN.encode('utf-8'),
|
||||
digestmod=hashlib.sha256
|
||||
).digest()
|
||||
actual_hash = hmac.new(
|
||||
secret,
|
||||
msg=data_check_string.encode('utf-8'),
|
||||
digestmod=hashlib.sha256
|
||||
).hexdigest()
|
||||
if hash != actual_hash:
|
||||
raise HTTPException(status_code=403, detail='Unauthorized')
|
||||
|
||||
data_dict = dict([x.split('=') for x in data_check_string.split('\n')])
|
||||
try:
|
||||
auth_date = int(data_dict['auth_date'])
|
||||
except KeyError:
|
||||
raise HTTPException(status_code=403, detail='Unauthorized')
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=403, detail='Unauthorized')
|
||||
|
||||
if auth_date + 60 * 30 < int(time.time()):
|
||||
raise HTTPException(status_code=403, detail='Unauthorized')
|
||||
|
||||
user_info = json.loads(data_dict['user'])
|
||||
return user_info['id'], authorization
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
from .schemas import ClickResponse, BatchClickRequest, EnergyResponse, ClickValueResponse
|
||||
from .usecase import add_click_batch_copy, check_registration, check_energy, get_energy, click_value, delete_user_info
|
|
@ -1,29 +1,22 @@
|
|||
import json
|
||||
import kombu
|
||||
import aio_pika
|
||||
import uuid
|
||||
|
||||
from ..models import Click
|
||||
|
||||
|
||||
CELERY_QUEUE_NAME = "celery"
|
||||
SETTING_QUEUE_NAME = "settings"
|
||||
CLICK_TASK_NAME = "clicks.celery.click.handle_click"
|
||||
SETTING_TASK_NAME = "misc.celery.deliver_setting.deliver_setting"
|
||||
|
||||
|
||||
def send_click_batch_copy(conn: kombu.Connection, click: Click, count: int):
|
||||
producer = kombu.Producer(conn)
|
||||
producer.publish(
|
||||
json.dumps({
|
||||
def send_click_batch_copy(chan: aio_pika.Channel, click: Click, count: int):
|
||||
await chan.default_exchange.publish(
|
||||
message=aio_pika.Message(json.dumps({
|
||||
'id': str(uuid.uuid4()),
|
||||
'task': CLICK_TASK_NAME,
|
||||
'args': [click.UserID, int(click.DateTime.timestamp() * 1e3), str(click.Value), count],
|
||||
'kwargs': dict(),
|
||||
}),
|
||||
}).encode('utf-8')),
|
||||
routing_key=CELERY_QUEUE_NAME,
|
||||
delivery_mode='persistent',
|
||||
mandatory=False,
|
||||
immediate=False,
|
||||
content_type='application/json',
|
||||
serializer='json',
|
||||
)
|
||||
|
|
|
@ -14,5 +14,9 @@ class ClickValueResponse(pydantic.BaseModel):
|
|||
value: decimal.Decimal
|
||||
|
||||
|
||||
class EnergyResponse(pydantic.BaseModel):
|
||||
energy: int
|
||||
|
||||
|
||||
class BatchClickRequest(pydantic.BaseModel):
|
||||
count: int
|
||||
|
|
|
@ -3,9 +3,10 @@ import decimal
|
|||
from typing import Tuple
|
||||
import aiohttp
|
||||
import redis.asyncio as redis
|
||||
import kombu
|
||||
import aio_pika
|
||||
import asyncpg
|
||||
|
||||
from batcher.app.src.domain.setting import get_setting
|
||||
from .repos.redis import (
|
||||
get_period_sum, incr_period_sum, get_max_period_sum, get_user_total, get_global_average,
|
||||
incr_user_count_if_no_clicks, update_global_average, incr_user_total, compare_max_period_sum,
|
||||
|
@ -14,22 +15,13 @@ from .repos.redis import (
|
|||
)
|
||||
from .repos.pg import update_click_expiry, bulk_store_copy, delete_by_user_id
|
||||
from .repos.rmq import send_click_batch_copy
|
||||
|
||||
from .models import Click
|
||||
|
||||
|
||||
SETTING_DICT = {
|
||||
'PRICE_PER_CLICK': decimal.Decimal(1),
|
||||
'DAY_MULT': decimal.Decimal(1),
|
||||
'WEEK_MULT': decimal.Decimal(1),
|
||||
'PROGRESS_MULT': decimal.Decimal(1),
|
||||
'SESSION_ENERGY': decimal.Decimal(500),
|
||||
}
|
||||
|
||||
PRECISION = 2
|
||||
|
||||
|
||||
async def add_click_batch_copy(r: redis.Redis, pg: asyncpg.Connection, rmq: kombu.Connection, user_id: int, count: int) -> Click:
|
||||
async def add_click_batch_copy(r: redis.Redis, pg: asyncpg.Connection, rmq: aio_pika.Channel, user_id: int, count: int) -> Click:
|
||||
_click_value = await click_value(r, pg, user_id)
|
||||
click_value_sum = _click_value * count
|
||||
|
||||
|
@ -131,7 +123,3 @@ async def check_energy(r: redis.Redis, user_id: int, amount: int, _token: str) -
|
|||
|
||||
async def get_energy(r: redis.Redis, user_id: int, _token: str) -> int:
|
||||
return await _get_refresh_energy(r, user_id, _token)
|
||||
|
||||
|
||||
def get_setting(name: str) -> decimal.Decimal:
|
||||
return SETTING_DICT[name]
|
||||
|
|
1
batcher/app/src/domain/setting/__init__.py
Normal file
1
batcher/app/src/domain/setting/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
from .usecase import get_setting, launch_consumer
|
0
batcher/app/src/domain/setting/repos/__init__.py
Normal file
0
batcher/app/src/domain/setting/repos/__init__.py
Normal file
21
batcher/app/src/domain/setting/repos/in_memory_storage.py
Normal file
21
batcher/app/src/domain/setting/repos/in_memory_storage.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
import decimal
|
||||
import threading
|
||||
|
||||
_settings = dict()
|
||||
mx = threading.Lock()
|
||||
|
||||
|
||||
def get_setting(name: str) -> decimal.Decimal:
|
||||
try:
|
||||
mx.acquire()
|
||||
return _settings[name]
|
||||
finally:
|
||||
mx.release()
|
||||
|
||||
|
||||
def set_setting(name: str, value: decimal.Decimal):
|
||||
try:
|
||||
mx.acquire()
|
||||
_settings[name] = value
|
||||
finally:
|
||||
mx.release()
|
20
batcher/app/src/domain/setting/repos/rmq.py
Normal file
20
batcher/app/src/domain/setting/repos/rmq.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
import decimal
|
||||
import json
|
||||
|
||||
import aio_pika
|
||||
from typing import Callable
|
||||
|
||||
|
||||
SETTING_QUEUE_NAME = "settings"
|
||||
SETTING_TASK_NAME = "misc.celery.deliver_setting.deliver_setting"
|
||||
|
||||
|
||||
async def consume_setting_updates(update_setting_func: Callable[[str, decimal.Decimal], None], chan: aio_pika.Channel):
|
||||
queue = await chan.get_queue(SETTING_QUEUE_NAME)
|
||||
|
||||
async with queue.iterator() as queue_iter:
|
||||
async for msg in queue_iter:
|
||||
async with msg.process():
|
||||
settings = json.loads(msg.body.decode('utf-8'))
|
||||
for name, value in settings.items():
|
||||
update_setting_func(name, value)
|
16
batcher/app/src/domain/setting/usecase.py
Normal file
16
batcher/app/src/domain/setting/usecase.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
import decimal
|
||||
import threading
|
||||
|
||||
import aio_pika
|
||||
|
||||
from .repos.in_memory_storage import get_setting as ims_get_setting
|
||||
from .repos.rmq import consume_setting_updates
|
||||
|
||||
|
||||
def get_setting(name: str) -> decimal.Decimal:
|
||||
return ims_get_setting(name)
|
||||
|
||||
|
||||
def launch_consumer(rmq: aio_pika.Connection):
|
||||
t = threading.Thread(target=consume_setting_updates, args=(ims_get_setting, rmq))
|
||||
t.start()
|
2
batcher/app/src/routers/__init__.py
Normal file
2
batcher/app/src/routers/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
|||
from .api import router
|
||||
from .handlers import http_error_handler
|
11
batcher/app/src/routers/api.py
Normal file
11
batcher/app/src/routers/api.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
from fastapi import APIRouter
|
||||
from . import click
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def include_api_routes():
|
||||
router.include_router(click.router, prefix='/v1')
|
||||
|
||||
|
||||
include_api_routes()
|
71
batcher/app/src/routers/click.py
Normal file
71
batcher/app/src/routers/click.py
Normal file
|
@ -0,0 +1,71 @@
|
|||
import aio_pika
|
||||
import asyncpg
|
||||
import redis
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Annotated
|
||||
from ..domain.click import (
|
||||
ClickResponse, BatchClickRequest, EnergyResponse, ClickValueResponse,
|
||||
add_click_batch_copy, check_registration, check_energy, get_energy, click_value, delete_user_info
|
||||
)
|
||||
|
||||
from ..dependencies import get_token_header
|
||||
from ..db import get_pg, get_redis, get_rmq
|
||||
from ..config import BACKEND_URL
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="",
|
||||
tags=['click'],
|
||||
dependencies=[],
|
||||
responses={404: {'description': 'Not found'}},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/batch-click/", response_model=ClickResponse, status_code=200)
|
||||
async def batch_click(req: BatchClickRequest, auth_info: Annotated[(int, str), Depends(get_token_header)], pg: Annotated[asyncpg.Connection, Depends(get_pg)], r: Annotated[redis.Redis, Depends(get_redis)], rmq: Annotated[aio_pika.Channel, Depends(get_rmq)]):
|
||||
user_id, token = auth_info
|
||||
if not check_registration(r, user_id, token, BACKEND_URL):
|
||||
raise HTTPException(status_code=403, detail='Unauthorized')
|
||||
|
||||
_energy, spent = await check_energy(r, user_id, req.count, token)
|
||||
if spent == 0:
|
||||
raise HTTPException(status_code=400, detail='No energy')
|
||||
|
||||
click = await add_click_batch_copy(r, pg, rmq, user_id, spent)
|
||||
return ClickResponse(
|
||||
click=click,
|
||||
energy=_energy
|
||||
)
|
||||
|
||||
|
||||
@router.get("/energy", response_model=EnergyResponse, status_code=200)
|
||||
async def energy(auth_info: Annotated[(int, str), Depends(get_token_header)], r: Annotated[redis.Redis, Depends(get_redis)]):
|
||||
user_id, token = auth_info
|
||||
if not check_registration(r, user_id, token, BACKEND_URL):
|
||||
raise HTTPException(status_code=403, detail='Unauthorized')
|
||||
|
||||
_energy = await get_energy(r, user_id, token)
|
||||
return EnergyResponse(
|
||||
energy=_energy
|
||||
)
|
||||
|
||||
|
||||
@router.get('/coefficient', response_model=ClickValueResponse, status_code=200)
|
||||
async def coefficient(auth_info: Annotated[(int, str), Depends(get_token_header)], r: Annotated[redis.Redis, Depends(get_redis)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]):
|
||||
user_id, token = auth_info
|
||||
if not check_registration(r, user_id, token, BACKEND_URL):
|
||||
raise HTTPException(status_code=403, detail='Unauthorized')
|
||||
|
||||
value = await click_value(r, pg, user_id)
|
||||
return ClickValueResponse(
|
||||
value=value
|
||||
)
|
||||
|
||||
|
||||
@router.delete('/internal/user', status_code=204)
|
||||
async def delete_user(auth_info: Annotated[(int, str), Depends(get_token_header())], r: Annotated[redis.Redis, Depends(get_redis)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]):
|
||||
user_id, token = auth_info
|
||||
if not check_registration(r, user_id, token, BACKEND_URL):
|
||||
raise HTTPException(status_code=403, detail='Unauthorized')
|
||||
|
||||
await delete_user_info(r, pg, user_id)
|
1
batcher/app/src/routers/handlers/__init__.py
Normal file
1
batcher/app/src/routers/handlers/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
from .http_error_handler import http_error_handler
|
7
batcher/app/src/routers/handlers/http_error_handler.py
Normal file
7
batcher/app/src/routers/handlers/http_error_handler.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
from fastapi import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
|
||||
async def http_error_handler(_: Request, exc: HTTPException) -> JSONResponse:
|
||||
return JSONResponse({"error": exc.detail}, status_code=exc.status_code)
|
Loading…
Reference in New Issue
Block a user