From 9a052700dd41e2c1194f1496ed79de77b6136665 Mon Sep 17 00:00:00 2001 From: Michail Kostochka Date: Sun, 11 May 2025 10:49:39 +0300 Subject: [PATCH] Refactor settings in batcher --- backend/scripts/start_celery.sh | 2 +- batcher/app/main.py | 4 +-- .../20241023_first_down_initial.sql | 2 ++ .../migrations/20241023_initial_up_first.sql | 5 ++++ batcher/app/src/db/rmq.py | 1 - batcher/app/src/domain/click/usecase.py | 14 +++++----- .../domain/setting/repos/in_memory_storage.py | 21 --------------- batcher/app/src/domain/setting/repos/pg.py | 16 ++++++++++++ batcher/app/src/domain/setting/repos/rmq.py | 10 ++++--- batcher/app/src/domain/setting/usecase.py | 26 ++++++++++++------- 10 files changed, 55 insertions(+), 46 deletions(-) delete mode 100644 batcher/app/src/domain/setting/repos/in_memory_storage.py create mode 100644 batcher/app/src/domain/setting/repos/pg.py diff --git a/backend/scripts/start_celery.sh b/backend/scripts/start_celery.sh index d94e00c..707965b 100755 --- a/backend/scripts/start_celery.sh +++ b/backend/scripts/start_celery.sh @@ -1,7 +1,7 @@ #!/bin/sh for i in $(seq 1 "${CELERY_WORKER_COUNT}"); do - celery -A clicker worker -l info --concurrency=10 -n "worker${i}@$(%h)" + celery -A clicker worker -l info --concurrency=10 -n "worker${i}" done celery -A clicker beat -l info diff --git a/batcher/app/main.py b/batcher/app/main.py index 1742d7e..606302d 100644 --- a/batcher/app/main.py +++ b/batcher/app/main.py @@ -7,7 +7,7 @@ from starlette.exceptions import HTTPException from app.src.routers.api import router as router_api from app.src.routers.handlers import http_error_handler from app.src.domain.setting import launch_consumer -from app.src.db import connect_pg, get_connection, get_channel, get_rmq +from app.src.db import connect_pg, get_connection, get_channel, get_rmq, get_pg def get_application() -> FastAPI: @@ -31,7 +31,7 @@ app = get_application() @app.on_event("startup") async def startup(): - launch_consumer(get_connection) + launch_consumer(connect_pg, get_connection) app.state.pg_pool = await connect_pg() diff --git a/batcher/app/src/db/pg/migrations/20241023_first_down_initial.sql b/batcher/app/src/db/pg/migrations/20241023_first_down_initial.sql index 1d7df0c..09fafa9 100644 --- a/batcher/app/src/db/pg/migrations/20241023_first_down_initial.sql +++ b/batcher/app/src/db/pg/migrations/20241023_first_down_initial.sql @@ -2,6 +2,8 @@ DROP VIEW coefficients; DROP TABLE clicks; +DROP TABLE settings; + DROP TABLE users; DROP TABLE global_stat; diff --git a/batcher/app/src/db/pg/migrations/20241023_initial_up_first.sql b/batcher/app/src/db/pg/migrations/20241023_initial_up_first.sql index 92c40fd..08e64c9 100644 --- a/batcher/app/src/db/pg/migrations/20241023_initial_up_first.sql +++ b/batcher/app/src/db/pg/migrations/20241023_initial_up_first.sql @@ -13,6 +13,11 @@ CREATE TABLE clicks( ); CREATE INDEX clicks_user_id_time_idx ON clicks(user_id, time); +CREATE TABLE settings( + name VARCHAR(255) PRIMARY KEY, + value DECIMAL(100, 2) NOT NULL +); + CREATE MATERIALIZED VIEW coefficients AS SELECT user_id, diff --git a/batcher/app/src/db/rmq.py b/batcher/app/src/db/rmq.py index 3287a67..55aeb04 100644 --- a/batcher/app/src/db/rmq.py +++ b/batcher/app/src/db/rmq.py @@ -8,7 +8,6 @@ from ..config import RMQ_HOST, RMQ_PORT, RMQ_USER, RMQ_PASSWORD fqdn = f'amqp://{RMQ_USER}:{str(RMQ_PASSWORD)}@{RMQ_HOST}:{RMQ_PORT}/' -logger = logging.getLogger("uvicorn") async def get_connection() -> AbstractRobustConnection: return await aio_pika.connect_robust(fqdn) diff --git a/batcher/app/src/domain/click/usecase.py b/batcher/app/src/domain/click/usecase.py index 1bec6e6..9976f7c 100644 --- a/batcher/app/src/domain/click/usecase.py +++ b/batcher/app/src/domain/click/usecase.py @@ -43,10 +43,10 @@ async def delete_user_info(pg: asyncpg.Connection, user_id: int) -> None: async def click_value(pg: asyncpg.Connection, user_id: int) -> decimal.Decimal: - price_per_click = get_setting('PRICE_PER_CLICK') - day_multiplier = get_setting('DAY_MULT') - week_multiplier = get_setting('WEEK_MULT') - progress_multiplier = get_setting('PROGRESS_MULT') + price_per_click = await get_setting(pg, 'PRICE_PER_CLICK') + day_multiplier = await get_setting(pg, 'DAY_MULT') + week_multiplier = await get_setting(pg, 'WEEK_MULT') + progress_multiplier = await get_setting(pg, 'PROGRESS_MULT') # period coefficients day_coef = await period_coefficient(pg, user_id, 24, day_multiplier) @@ -86,15 +86,15 @@ async def _get_refresh_energy(pg: asyncpg.Connection, user_id: int, req_token: s new_auth_date = _auth_date_from_token(req_token) current_token = await get_user_session(pg, user_id) if current_token is None: - session_energy = int(get_setting('SESSION_ENERGY')) + session_energy = int(await get_setting(pg, 'SESSION_ENERGY')) await add_user(pg, user_id, req_token, session_energy) return session_energy if current_token != req_token: last_auth_date = _auth_date_from_token(current_token) - session_cooldown = get_setting('SESSION_COOLDOWN') + session_cooldown = await get_setting(pg, 'SESSION_COOLDOWN') if new_auth_date - last_auth_date < session_cooldown: raise HTTPException(status_code=403, detail='Unauthorized') - session_energy = int(get_setting('SESSION_ENERGY')) + session_energy = int(await get_setting(pg, 'SESSION_ENERGY')) await set_new_session(pg, user_id, req_token, session_energy) return session_energy else: diff --git a/batcher/app/src/domain/setting/repos/in_memory_storage.py b/batcher/app/src/domain/setting/repos/in_memory_storage.py deleted file mode 100644 index ac6e46c..0000000 --- a/batcher/app/src/domain/setting/repos/in_memory_storage.py +++ /dev/null @@ -1,21 +0,0 @@ -import decimal -import threading - -_settings = dict() -mx = threading.Lock() - - -def get_setting(name: str) -> decimal.Decimal: - try: - mx.acquire() - return _settings[name] - finally: - mx.release() - - -def set_setting(name: str, value: decimal.Decimal): - try: - mx.acquire() - _settings[name] = value - finally: - mx.release() diff --git a/batcher/app/src/domain/setting/repos/pg.py b/batcher/app/src/domain/setting/repos/pg.py new file mode 100644 index 0000000..164eea5 --- /dev/null +++ b/batcher/app/src/domain/setting/repos/pg.py @@ -0,0 +1,16 @@ +from decimal import Decimal +from asyncpg import Connection + + +async def get_setting(conn: Connection, name: str) -> Decimal: + return await conn.fetchval('SELECT value FROM settings WHERE name=$1', name) + + +async def set_setting(conn: Connection, name: str, value: Decimal): + query = ''' + INSERT INTO settings (name, value) + VALUES ($1, $2) + ON CONFLICT(name) DO UPDATE + SET value=$2 + ''' + await conn.execute(query, name, value) \ No newline at end of file diff --git a/batcher/app/src/domain/setting/repos/rmq.py b/batcher/app/src/domain/setting/repos/rmq.py index f2396f6..1bf1c92 100644 --- a/batcher/app/src/domain/setting/repos/rmq.py +++ b/batcher/app/src/domain/setting/repos/rmq.py @@ -2,16 +2,18 @@ import decimal import json import aio_pika from typing import Callable - +import asyncpg SETTING_QUEUE_NAME = "settings" -async def consume_setting_updates(set_setting_func: Callable[[str, decimal.Decimal], None], chan: aio_pika.abc.AbstractChannel): + +async def consume_setting_updates(pg_pool: asyncpg.Pool, set_setting_func: Callable[[str, decimal.Decimal], None], chan: aio_pika.abc.AbstractChannel): queue = await chan.declare_queue(SETTING_QUEUE_NAME, durable=True) async with queue.iterator() as queue_iter: async for msg in queue_iter: async with msg.process(): settings = json.loads(msg.body.decode('utf-8')) - for name, value in settings.items(): - set_setting_func(name, decimal.Decimal(value)) + async with pg_pool.acquire() as pg_conn: + for name, value in settings.items(): + await set_setting_func(pg_conn, name, decimal.Decimal(value)) diff --git a/batcher/app/src/domain/setting/usecase.py b/batcher/app/src/domain/setting/usecase.py index babb8a1..2c6ee07 100644 --- a/batcher/app/src/domain/setting/usecase.py +++ b/batcher/app/src/domain/setting/usecase.py @@ -3,22 +3,28 @@ import threading import asyncio from collections.abc import Callable, Awaitable import aio_pika +import asyncpg -from .repos.in_memory_storage import set_setting, get_setting as ims_get_setting +from .repos.pg import set_setting, get_setting as pg_get_setting from .repos.rmq import consume_setting_updates -def get_setting(name: str) -> decimal.Decimal: - return ims_get_setting(name) +def get_setting(pg: asyncpg.Connection, name: str) -> decimal.Decimal: + return pg_get_setting(pg, name) -async def start_thread(rmq_connect_func: Callable[[], Awaitable[aio_pika.abc.AbstractRobustConnection]], *args): + +async def start_thread(connect_pg: Callable[[], Awaitable[asyncpg.Pool]], rmq_connect_func: Callable[[], Awaitable[aio_pika.abc.AbstractRobustConnection]], *args): + pg_pool = await connect_pg() conn = await rmq_connect_func() - async with conn: - chan = await conn.channel() - await consume_setting_updates(set_setting, chan) + try: + async with conn: + chan = await conn.channel() + await consume_setting_updates(pg_pool, set_setting, chan) + finally: + await pg_pool.close() + - -def launch_consumer(rmq_connect_func: Callable[[], Awaitable[aio_pika.abc.AbstractRobustConnection]]): - t = threading.Thread(target=asyncio.run, args=(start_thread(rmq_connect_func),)) +def launch_consumer(connect_pg: Callable[[], Awaitable[asyncpg.Pool]], rmq_connect_func: Callable[[], Awaitable[aio_pika.abc.AbstractRobustConnection]]): + t = threading.Thread(target=asyncio.run, args=(start_thread(connect_pg, rmq_connect_func),)) t.start() -- 2.34.1