From 78f2ebe4d5b56219541568fde9e37e1ddaa6c329 Mon Sep 17 00:00:00 2001 From: Michail Kostochka Date: Sun, 16 Mar 2025 10:56:57 +0300 Subject: [PATCH] Remove redis, refactor click batching (tested on api level) --- backend/misc/celery/deliver_setting.py | 1 - backend/misc/fixtures/setting.json | 68 +++++++ batcher/app/main.py | 5 +- batcher/app/src/config.py | 7 - batcher/app/src/db/__init__.py | 1 - .../20241023_first_down_initial.sql | 12 +- .../migrations/20241023_initial_up_first.sql | 126 ++++++++++++- batcher/app/src/db/redis.py | 14 -- batcher/app/src/domain/click/models.py | 1 + batcher/app/src/domain/click/repos/pg.py | 98 ++++++---- batcher/app/src/domain/click/repos/redis.py | 172 ------------------ batcher/app/src/domain/click/repos/rmq.py | 4 +- batcher/app/src/domain/click/usecase.py | 100 ++++------ batcher/app/src/routers/click.py | 29 ++- batcher/requirements.txt | 1 - docker-compose-prod.yml | 14 -- docker-compose.yml | 16 -- 17 files changed, 319 insertions(+), 350 deletions(-) create mode 100644 backend/misc/fixtures/setting.json delete mode 100644 batcher/app/src/db/redis.py delete mode 100644 batcher/app/src/domain/click/repos/redis.py diff --git a/backend/misc/celery/deliver_setting.py b/backend/misc/celery/deliver_setting.py index 63fa529..00a78d9 100644 --- a/backend/misc/celery/deliver_setting.py +++ b/backend/misc/celery/deliver_setting.py @@ -19,4 +19,3 @@ def deliver_setting(setting_name): routing_key=settings.SETTINGS_QUEUE_NAME, declare=[queue], ) - diff --git a/backend/misc/fixtures/setting.json b/backend/misc/fixtures/setting.json new file mode 100644 index 0000000..d86f81d --- /dev/null +++ b/backend/misc/fixtures/setting.json @@ -0,0 +1,68 @@ +[ +{ + "model": "misc.setting", + "pk": 1, + "fields": { + "name": "SESSION_ENERGY", + "description": "Энергия на сессию", + "value": { + "value": 300 + } + } +}, +{ + "model": "misc.setting", + "pk": 2, + "fields": { + "name": "PRICE_PER_CLICK", + "description": "Награда за клик", + "value": { + "value": 1 + } + } +}, +{ + "model": "misc.setting", + "pk": 3, + "fields": { + "name": "DAY_MULT", + "description": "Дневной мультипликатор", + "value": { + "value": "1.5" + } + } +}, +{ + "model": "misc.setting", + "pk": 4, + "fields": { + "name": "WEEK_MULT", + "description": "Недельный мультипликатор", + "value": { + "value": "1.5" + } + } +}, +{ + "model": "misc.setting", + "pk": 5, + "fields": { + "name": "PROGRESS_MULT", + "description": "Мультипликатор прогресса", + "value": { + "value": "1.5" + } + } +}, +{ + "model": "misc.setting", + "pk": 6, + "fields": { + "name": "SESSION_COOLDOWN", + "description": "Кулдаун сессии", + "value": { + "value": 30 + } + } +} +] diff --git a/batcher/app/main.py b/batcher/app/main.py index 76f5593..1742d7e 100644 --- a/batcher/app/main.py +++ b/batcher/app/main.py @@ -7,7 +7,7 @@ from starlette.exceptions import HTTPException from app.src.routers.api import router as router_api from app.src.routers.handlers import http_error_handler from app.src.domain.setting import launch_consumer -from app.src.db import connect_pg, connect_redis, get_connection, get_channel, get_rmq +from app.src.db import connect_pg, get_connection, get_channel, get_rmq def get_application() -> FastAPI: @@ -35,8 +35,6 @@ async def startup(): app.state.pg_pool = await connect_pg() - app.state.redis_pool = connect_redis() - rmq_conn_pool = aio_pika.pool.Pool(get_connection, max_size=2) rmq_chan_pool = aio_pika.pool.Pool(partial(get_channel, conn_pool=rmq_conn_pool), max_size=10) app.state.rmq_chan_pool = rmq_chan_pool @@ -45,5 +43,4 @@ async def startup(): @app.on_event("shutdown") async def shutdown(): await app.state.pg_pool.close() - await app.state.redis.close() diff --git a/batcher/app/src/config.py b/batcher/app/src/config.py index bb2ec0d..d5cde01 100644 --- a/batcher/app/src/config.py +++ b/batcher/app/src/config.py @@ -4,13 +4,6 @@ from functools import lru_cache config = Config() - -REDIS_USER = config('REDIS_USER') -REDIS_PASSWORD = config('REDIS_PASSWORD', cast=Secret) -REDIS_PORT = config('REDIS_PORT', cast=int) -REDIS_HOST = config('REDIS_HOST') -REDIS_DB = config('REDIS_DB') - PG_HOST = config('POSTGRES_HOST') PG_PORT = config('POSTGRES_PORT', cast=int) PG_USER = config('POSTGRES_USER') diff --git a/batcher/app/src/db/__init__.py b/batcher/app/src/db/__init__.py index ce5d7cb..92fb85b 100644 --- a/batcher/app/src/db/__init__.py +++ b/batcher/app/src/db/__init__.py @@ -1,3 +1,2 @@ from .pg import get_pg, connect_pg -from .redis import get_redis, connect_redis from .rmq import get_rmq, get_channel, get_connection \ No newline at end of file diff --git a/batcher/app/src/db/pg/migrations/20241023_first_down_initial.sql b/batcher/app/src/db/pg/migrations/20241023_first_down_initial.sql index 8929739..1d7df0c 100644 --- a/batcher/app/src/db/pg/migrations/20241023_first_down_initial.sql +++ b/batcher/app/src/db/pg/migrations/20241023_first_down_initial.sql @@ -1,2 +1,12 @@ -DROP INDEX clicks_user_id_time_idx; +DROP VIEW coefficients; + DROP TABLE clicks; + +DROP TABLE users; + +DROP TABLE global_stat; + +DROP FUNCTION raise_error; +DROP FUNCTION handle_new_click; +DROP FUNCTION handle_new_user; +DROP FUNCTION handle_user_deletion; \ No newline at end of file diff --git a/batcher/app/src/db/pg/migrations/20241023_initial_up_first.sql b/batcher/app/src/db/pg/migrations/20241023_initial_up_first.sql index bb20941..92c40fd 100644 --- a/batcher/app/src/db/pg/migrations/20241023_initial_up_first.sql +++ b/batcher/app/src/db/pg/migrations/20241023_initial_up_first.sql @@ -1,8 +1,120 @@ -CREATE TABLE IF NOT EXISTS clicks( - id BIGSERIAL PRIMARY KEY, - user_id BIGINT, - time TIMESTAMP, - value DECIMAL(100, 2), - expiry_info JSONB +CREATE TABLE users( + id BIGINT PRIMARY KEY, + energy INTEGER NOT NULL CONSTRAINT non_negative_energy CHECK (energy >= 0), + session VARCHAR(255) NOT NULL ); -CREATE INDEX IF NOT EXISTS clicks_user_id_time_idx ON clicks(user_id, time); + +CREATE TABLE clicks( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users (id) ON DELETE CASCADE, + time TIMESTAMP NOT NULL, + value DECIMAL(100, 2) NOT NULL, + count BIGINT NOT NULL +); +CREATE INDEX clicks_user_id_time_idx ON clicks(user_id, time); + +CREATE MATERIALIZED VIEW coefficients AS +SELECT + user_id, + (SELECT COUNT(*) FROM clicks AS c1 WHERE now() - time < interval '24 hours' AND c1.user_id = user_id) AS period_24, + (SELECT COUNT(*) FROM clicks AS c1 WHERE now() - time < interval '168 hours' AND c1.user_id = user_id) AS period_168, + (SELECT SUM(value * count) FROM clicks AS c1 WHERE c1.user_id = user_id) as total +FROM clicks +; + +CREATE TABLE global_stat( + id BIGINT PRIMARY KEY DEFAULT 1, + user_count BIGINT NOT NULL, + global_average DECIMAL(100, 2) NOT NULL, + max_period_24 DECIMAL(100, 2) NOT NULL, + max_period_168 DECIMAL(100, 2) NOT NULL +); + +INSERT INTO global_stat (user_count, global_average, max_period_24, max_period_168) VALUES (0, 0, 0, 0); + +CREATE OR REPLACE FUNCTION raise_error() + RETURNS TRIGGER +AS $body$ +BEGIN + RAISE EXCEPTION 'No changes allowed'; + RETURN NULL; +END; +$body$ +LANGUAGE PLPGSQL; + +CREATE TRIGGER singleton_trg + BEFORE INSERT OR DELETE OR TRUNCATE ON global_stat + FOR EACH STATEMENT EXECUTE PROCEDURE raise_error(); + +CREATE OR REPLACE FUNCTION handle_new_click() + RETURNS TRIGGER +AS $body$ +BEGIN + WITH user_stats AS ( + SELECT period_24, period_168 FROM coefficients AS c WHERE c.user_id=new.user_id + ) + UPDATE global_stat AS gs SET + global_average=(gs.global_average * gs.user_count + new.value * new.count) / gs.user_count, + max_period_24=GREATEST(us.period_24, gs.max_period_24), + max_period_168=GREATEST(us.period_168, gs.max_period_168) + FROM user_stats AS us + ; + RETURN NULL; +END; +$body$ +LANGUAGE PLPGSQL; + +CREATE TRIGGER new_click_trg + AFTER INSERT ON clicks + FOR EACH ROW EXECUTE PROCEDURE handle_new_click(); + +CREATE OR REPLACE FUNCTION handle_new_user() + RETURNS TRIGGER +AS $body$ +BEGIN + UPDATE global_stat SET + global_average=global_average * user_count / (user_count + 1), + user_count=user_count + 1 + ; + RETURN NULL; +END; +$body$ +LANGUAGE PLPGSQL; + +CREATE TRIGGER new_user_trg + AFTER INSERT ON users + FOR EACH ROW EXECUTE PROCEDURE handle_new_user(); + +CREATE OR REPLACE FUNCTION handle_user_deletion() + RETURNS TRIGGER +AS $body$ +BEGIN + UPDATE global_stat SET + global_average=global_average * user_count / (user_count - 1), + user_count=user_count - 1 + ; +END; +$body$ +LANGUAGE PLPGSQL; + +CREATE TRIGGER delete_user_trg + AFTER DELETE ON users + FOR EACH ROW EXECUTE PROCEDURE handle_user_deletion(); + +CREATE OR REPLACE FUNCTION handle_user_truncate() + RETURNS TRIGGER +AS $body$ +BEGIN + UPDATE global_stat SET + global_average=0, + user_count=0 + ; + RETURN NULL; +END; +$body$ +LANGUAGE PLPGSQL; + +CREATE TRIGGER truncate_user_trg + AFTER TRUNCATE ON users + FOR STATEMENT EXECUTE PROCEDURE handle_user_truncate(); + diff --git a/batcher/app/src/db/redis.py b/batcher/app/src/db/redis.py deleted file mode 100644 index 10a0577..0000000 --- a/batcher/app/src/db/redis.py +++ /dev/null @@ -1,14 +0,0 @@ -from starlette.requests import Request -import redis.asyncio as redis - -from ..config import REDIS_HOST, REDIS_PORT, REDIS_USER, REDIS_PASSWORD, REDIS_DB - - -def connect_redis() -> redis.ConnectionPool: - return redis.ConnectionPool(host=REDIS_HOST, port=REDIS_PORT, username=REDIS_USER, password=str(REDIS_PASSWORD), db=REDIS_DB) - - -async def get_redis(request: Request) -> redis.Redis: - r = redis.Redis(connection_pool=request.app.state.redis_pool) - yield r - await r.aclose() diff --git a/batcher/app/src/domain/click/models.py b/batcher/app/src/domain/click/models.py index bc2c90c..feed776 100644 --- a/batcher/app/src/domain/click/models.py +++ b/batcher/app/src/domain/click/models.py @@ -7,3 +7,4 @@ class Click(pydantic.BaseModel): userId: int dateTime: datetime.datetime value: decimal.Decimal + count: int diff --git a/batcher/app/src/domain/click/repos/pg.py b/batcher/app/src/domain/click/repos/pg.py index 0c1ae28..7dedd79 100644 --- a/batcher/app/src/domain/click/repos/pg.py +++ b/batcher/app/src/domain/click/repos/pg.py @@ -1,50 +1,78 @@ -from datetime import datetime, timedelta +from typing import Tuple import decimal from asyncpg import Connection from ..models import Click -async def update_click_expiry(conn: Connection, user_id: int, period: int) -> decimal.Decimal: - cur_time = datetime.now() - cutoff_time = cur_time - timedelta(hours=period) - query = ''' - WITH expired_values AS( - UPDATE clicks - SET expiry_info=jsonb_set(expiry_info, $1, 'true') - WHERE 1=1 - AND time < $2 - AND user_id =$3 - AND not (expiry_info->>$4)::bool - RETURNING value - ) - SELECT COALESCE(SUM(value), 0) - FROM expired_values - ; - ''' - period_key = f'period_{period}' - return await conn.fetchval(query, [period_key], cutoff_time, user_id, period_key) - - async def store(conn: Connection, click: Click) -> int: query = ''' - INSERT INTO clicks(user_id, time, value, expiry_info) - VALUES($1, $2, $3, '{"period_24": false, "period_168": false}') + INSERT INTO clicks(user_id, time, value, count) + VALUES($1, $2, $3, $4) RETURNING id ; ''' - return await conn.fetchval(query, click.userId, click.dateTime, click.value) + return await conn.fetchval(query, click.userId, click.dateTime, click.value, click.count) -async def bulk_store_copy(conn: Connection, click: Click, count: int) -> None: - args = [(click.userId, click.dateTime, click.value) for _ in range(count)] - query = ''' - INSERT INTO clicks(user_id, time, value, expiry_info) - VALUES($1, $2, $3, '{"period_24": false, "period_168": false}') - ; - ''' - await conn.executemany(query, args) +async def delete_user_info(conn: Connection, user_id: int): + async with conn.transaction(): + await conn.execute('DELETE FROM clicks WHERE user_id=$1', user_id) + await conn.execute('DELTE FROM users WHERE id=$1', user_id) -async def delete_by_user_id(conn: Connection, user_id: int): - await conn.execute('DELETE FROM clicks WHERE user_id=$1', user_id) +async def get_period_sum(conn: Connection, user_id: int, period: int) -> decimal.Decimal: + if not isinstance(period, int): + raise ValueError('period must be an integer') + return (await conn.fetchval(f'SELECT period_{period} FROM coefficients WHERE user_id=$1', user_id)) or decimal.Decimal(0) + + +async def get_max_period_sum(conn: Connection, period: int) -> decimal.Decimal: + if not isinstance(period, int): + raise ValueError('period must be an integer') + return (await conn.fetchval(f'SELECT max_period_{period} FROM global_stat')) or decimal.Decimal(0) + + +async def get_global_average(conn: Connection) -> decimal.Decimal: + return (await conn.fetchval('SELECT global_average FROM global_stat')) or decimal.Decimal(0) + + +async def get_user_total(conn: Connection, user_id: int) -> decimal.Decimal: + return (await conn.fetchval('SELECT total FROM coefficients WHERE user_id=$1', user_id)) or decimal.Decimal(0) + + +async def user_exists(conn: Connection, user_id: int) -> bool: + return await conn.fetchval('SELECT EXISTS(SELECT 1 FROM users WHERE id=$1)', user_id)\ + + +async def add_user(conn: Connection, user_id: int, req_token: str, session_energy: int): + await conn.execute( + 'INSERT INTO users (id, session, energy) VALUES ($1, $2, $3)', + user_id, req_token, session_energy + ) + + +async def get_user_session(conn: Connection, user_id: int) -> str: + return await conn.fetchval('SELECT session FROM users WHERE id=$1', user_id) + + +async def set_new_session(conn: Connection, user_id: int, req_token: str, session_energy: int): + await conn.execute('UPDATE users SET session=$1, energy=$2 WHERE id=$3', req_token, session_energy, user_id) + + +async def get_energy(conn: Connection, user_id: int) -> int: + return await conn.fetchval('SELECT energy FROM users WHERE id=$1', user_id) + + +async def decr_energy(conn: Connection, user_id: int, amount: int) -> Tuple[int, int]: + new_energy, spent = await conn.fetchrow(''' + WITH energy_cte AS ( + SELECT energy, (CASE WHEN energy < $2 THEN energy ELSE $2 END) AS delta FROM users WHERE id=$1 + ) + UPDATE users AS u SET + energy=u.energy - e.delta + FROM energy_cte AS e + WHERE id=$1 + RETURNING u.energy as new_energy, e.delta + ''', user_id, amount) + return new_energy, spent \ No newline at end of file diff --git a/batcher/app/src/domain/click/repos/redis.py b/batcher/app/src/domain/click/repos/redis.py deleted file mode 100644 index 68b67c3..0000000 --- a/batcher/app/src/domain/click/repos/redis.py +++ /dev/null @@ -1,172 +0,0 @@ -import decimal -from typing import Optional, List - -import redis.asyncio as redis - - -async def get_period_sum(r: redis.Redis, user_id: int, period: int) -> decimal.Decimal: - sum_bytes = await r.get(f'period_{period}_user_{user_id}') - if sum_bytes is None: - return decimal.Decimal(0) - return decimal.Decimal(sum_bytes.decode()) - - -async def incr_period_sum(r: redis.Redis, user_id: int, _period: int, value: decimal.Decimal) -> decimal.Decimal: - return await r.incrbyfloat(f'period_{_period}_user_{user_id}', float(value)) - - -async def get_max_period_sum(r: redis.Redis, _period: int) -> decimal.Decimal: - max_sum_bytes = await r.get(f'max_period_{_period}') - if max_sum_bytes is None: - return decimal.Decimal(0) - return decimal.Decimal(max_sum_bytes.decode()) - - -async def compare_max_period_sum(r: redis.Redis, _period: int, _sum: decimal.Decimal) -> None: - _script = r.register_script(''' - local currentValue = tonumber(redis.call('GET', KEYS[1])) - local cmpValue = tonumber(ARGV[1]) - if not currentValue or cmpValue > currentValue then - redis.call('SET', KEYS[1], ARGV[1]) - return cmpValue - else - return currentValue - end - ''') - await _script(keys=[f'max_period_{_period}'], args=[str(_sum)]) - - -async def get_energy(r: redis.Redis, user_id: int) -> int: - energy_str = await r.get(f'energy_{user_id}') - if energy_str is None: - return 0 - return int(energy_str) - - -async def set_energy(r: redis.Redis, user_id: int, energy: int) -> int: - await r.set(f'energy_{user_id}', energy) - - -async def decr_energy(r: redis.Redis, user_id: int, amount: int) -> (int, int): - _script = r.register_script(''' - local energy = tonumber(redis.call('GET', KEYS[1])) - local delta = tonumber(ARGV[1]) - if energy < delta then - redis.call('SET', KEYS[1], 0) - return {0, energy} - else - local newEnergy = tonumber(redis.call('DECRBY', KEYS[1], ARGV[1])) - return {newEnergy, delta} - end - ''') - new_energy, spent= map(int, await _script(keys=[f'energy_{user_id}'], args=[amount])) - return new_energy, spent - - -async def get_global_average(r: redis.Redis) -> decimal.Decimal: - avg_bytes = await r.get('global_average') - if avg_bytes is None: - return decimal.Decimal(0) - return decimal.Decimal(avg_bytes.decode()) - - -async def update_global_average(r: redis.Redis, value_to_add: decimal.Decimal) -> decimal.Decimal: - _script = r.register_script(''' - local delta = tonumber(ARGV[1]) / tonumber(redis.call('GET', KEYS[1])) - return redis.call('INCRBYFLOAT', KEYS[2], delta) - ''') - return decimal.Decimal((await _script(keys=["user_count", "global_average"], args=[float(value_to_add)])).decode()) - - -async def get_user_total(r: redis.Redis, user_id: int) -> decimal.Decimal: - total_bytes = await r.get(f'total_{user_id}') - if total_bytes is None: - return decimal.Decimal(0) - return decimal.Decimal(total_bytes.decode()) - - -async def incr_user_count_if_no_clicks(r: redis.Redis, user_id: int) -> int: - _script = r.register_script(''' - local clickCount = tonumber(redis.call('GET', KEYS[1])) - local userCount = tonumber(redis.call('GET', KEYS[2])) - if (not clickCount) then - local oldUserCount = redis.call('GET', KEYS[2]) - if (not oldUserCount) then - redis.call('SET', KEYS[2], 1) - redis.call('SET', KEYS[3], 0) - return 1 - end - userCount = tonumber(redis.call('INCR', KEYS[2])) - oldUserCount = tonumber(oldUserCount) - local globalAverage = tonumber(redis.call('GET', KEYS[3])) - redis.call('SET', KEYS[3], globalAverage / userCount * oldUserCount) - end - return userCount - ''') - return int(await _script(keys=[f'total_{user_id}', 'user_count', 'global_average'], args=[])) - - -async def incr_user_total(r: redis.Redis, user_id: int, value: decimal.Decimal) -> decimal.Decimal: - return await r.incrbyfloat(f'total_{user_id}', float(value)) - - -async def get_user_session(r: redis.Redis, user_id: int) -> Optional[str]: - session_bytes = await r.get(f'session_{user_id}') - if session_bytes is None: - return None - return session_bytes.decode() - - -async def set_user_session(r: redis.Redis, user_id: int, token: str) -> None: - await r.set(f'session_{user_id}', token, ex=30 * 60) - - -async def get_user_count(r: redis.Redis) -> int: - user_count_str = await r.get('user_count') - if user_count_str is None: - return 0 - return int(user_count_str) - - -async def incr_user_count(r: redis.Redis) -> int: - _script = r.register_script(''' - local oldCount = redis.call('GET', KEYS[1]) - if (not oldCount) then - redis.call('SET', KEYS[1], 1) - redis.call('SET', KEYS[2], 0) - return 1 - end - local newCount = tonumber(redis.call('INCR', KEYS[1])) - local globalAverage = tonumber(redis.call('GET', KEYS[2])) - redis.call('SET', KEYS[2], globalAverage / newCount * oldCount) - return newCount - ''') - return int(await _script(keys=['user_count', 'global_average'], args=[])) - - -async def delete_user_info(r: redis.Redis, user_id: int, periods: List[int]): - _script = r.register_script(''' - local userTotal = redis.call('GET', KEYS[3]) - if (not userTotal) then - return - end - local oldUserCount = tonumber(redis.call('GET', KEYS[1])) - local newUserCount = tonumber(redis.call('DECR', KEYS[1])) - local globalAverage = tonumber(redis.call('GET', KEYS[2])) - redis.call('SET', KEYS[2], (globalAverage * oldUserCount - userTotal) / newUserCount) - for i, v in ipairs(KEYS) do - if (i > 2) then - redis.call('DEL', v) - end - end - ''') - keys = [ - 'user_count', - 'global_average', - f'total_{user_id}' - f'energy_{user_id}', - f'session_{user_id}', - ] - for period in periods: - keys.append(f'period_{period}_user_{user_id}') - await _script(keys=keys, args=[]) diff --git a/batcher/app/src/domain/click/repos/rmq.py b/batcher/app/src/domain/click/repos/rmq.py index 598bc16..64fab94 100644 --- a/batcher/app/src/domain/click/repos/rmq.py +++ b/batcher/app/src/domain/click/repos/rmq.py @@ -10,8 +10,8 @@ CELERY_QUEUE_NAME = "celery" CLICK_TASK_NAME = "clicks.celery.click.handle_click" -async def send_click_batch_copy(chan: aio_pika.Channel, click: Click, count: int): - args = (click.userId, int(click.dateTime.timestamp() * 1e3), str(click.value), count) +async def send_click(chan: aio_pika.Channel, click: Click): + args = (click.userId, int(click.dateTime.timestamp() * 1e3), str(click.value), click.count) await chan.default_exchange.publish( message=aio_pika.Message( body=json.dumps([ diff --git a/batcher/app/src/domain/click/usecase.py b/batcher/app/src/domain/click/usecase.py index 774218d..1bec6e6 100644 --- a/batcher/app/src/domain/click/usecase.py +++ b/batcher/app/src/domain/click/usecase.py @@ -2,87 +2,70 @@ from datetime import datetime import decimal from typing import Tuple import aiohttp -import redis.asyncio as redis import aio_pika import asyncpg import base64 from fastapi.exceptions import HTTPException from app.src.domain.setting import get_setting -from .repos.redis import ( - get_period_sum, incr_period_sum, get_max_period_sum, get_user_total, get_global_average, - incr_user_count_if_no_clicks, update_global_average, incr_user_total, compare_max_period_sum, - delete_user_info as r_delete_user_info, get_user_session, set_user_session, set_energy, get_energy as r_get_energy, - decr_energy, +from .repos.pg import ( + store, delete_user_info as pg_delete_user_info, get_period_sum, get_max_period_sum, get_global_average, get_user_total, user_exists, get_user_session, + set_new_session, get_energy as pg_get_energy, decr_energy, add_user ) -from .repos.pg import update_click_expiry, bulk_store_copy, delete_by_user_id -from .repos.rmq import send_click_batch_copy +from .repos.rmq import send_click from .models import Click PRECISION = 2 -async def add_click_batch_copy(r: redis.Redis, pg: asyncpg.Connection, rmq: aio_pika.Channel, user_id: int, count: int) -> Click: - _click_value = await click_value(r, pg, user_id) - click_value_sum = _click_value * count - - # update variables - await incr_user_count_if_no_clicks(r, user_id) - await update_global_average(r, click_value_sum) - await incr_user_total(r, user_id, click_value_sum) - - for period in (24, 24*7): - new_period_sum = await incr_period_sum(r, user_id, period, click_value_sum) - await compare_max_period_sum(r, period, new_period_sum) +async def add_click_batch_copy(pg: asyncpg.Connection, rmq: aio_pika.Channel, user_id: int, count: int) -> Click: + _click_value = await click_value(pg, user_id) click = Click( userId=user_id, dateTime=datetime.now(), value=_click_value, + count=count, ) # insert click - await bulk_store_copy(pg, click, count) + await store(pg, click) # send click to backend - await send_click_batch_copy(rmq, click, count) + await send_click(rmq, click) return click -async def delete_user_info(r: redis.Redis, pg: asyncpg.Connection, user_id: int) -> None: - await r_delete_user_info(r, user_id, [24, 168]) - await delete_by_user_id(pg, user_id) +async def delete_user_info(pg: asyncpg.Connection, user_id: int) -> None: + await pg_delete_user_info(pg, user_id) -async def click_value(r: redis.Redis, pg: asyncpg.Connection, user_id: int) -> decimal.Decimal: +async def click_value(pg: asyncpg.Connection, user_id: int) -> decimal.Decimal: price_per_click = get_setting('PRICE_PER_CLICK') day_multiplier = get_setting('DAY_MULT') week_multiplier = get_setting('WEEK_MULT') progress_multiplier = get_setting('PROGRESS_MULT') # period coefficients - day_coef = await period_coefficient(r, pg, user_id, 24, day_multiplier) - week_coef = await period_coefficient(r, pg, user_id, 24*7, week_multiplier) + day_coef = await period_coefficient(pg, user_id, 24, day_multiplier) + week_coef = await period_coefficient(pg, user_id, 24*7, week_multiplier) # progress coefficient - user_total = await get_user_total(r, user_id) - global_avg = await get_global_average(r) + user_total = await get_user_total(pg, user_id) + global_avg = await get_global_average(pg) progress_coef = progress_coefficient(user_total, global_avg, progress_multiplier) return round(price_per_click * day_coef * week_coef * progress_coef, PRECISION) -async def period_coefficient(r: redis.Redis, pg: asyncpg.Connection, user_id: int, period: int, multiplier: decimal.Decimal) -> decimal.Decimal: - current_sum = await get_period_sum(r, user_id, period) - expired_sum = await update_click_expiry(pg, user_id, period) - new_sum = current_sum - expired_sum - await incr_period_sum(r, user_id, period, -expired_sum) - max_period_sum = await get_max_period_sum(r, period) +async def period_coefficient(pg: asyncpg.Connection, user_id: int, period: int, multiplier: decimal.Decimal) -> decimal.Decimal: + current_period_sum = await get_period_sum(pg, user_id, period) + max_period_sum = await get_max_period_sum(pg, period) if max_period_sum == decimal.Decimal(0): return decimal.Decimal(1) - return new_sum * multiplier / max_period_sum + 1 + return current_period_sum * multiplier / max_period_sum + 1 def progress_coefficient(user_total: decimal.Decimal, global_avg: decimal.Decimal, multiplier: decimal.Decimal) -> decimal.Decimal: @@ -91,34 +74,31 @@ def progress_coefficient(user_total: decimal.Decimal, global_avg: decimal.Decima return min(global_avg * multiplier / user_total + 1, decimal.Decimal(2)) -async def check_registration(r: redis.Redis, user_id: int, _token: str, backend_url: str) -> bool: - if await _has_any_clicks(r, user_id): +async def check_registration(pg: asyncpg.Connection, user_id: int, _token: str, backend_url: str) -> bool: + if await user_exists(pg, user_id): return True async with aiohttp.ClientSession() as session: async with session.get(f'{backend_url}/api/v1/users/{user_id}', headers={'Authorization': f'TelegramToken {_token}'}) as resp: return resp.status == 200 -async def _has_any_clicks(r: redis.Redis, user_id: int) -> bool: - total_value = await get_user_total(r, user_id) - return total_value > decimal.Decimal(0) - - -async def _get_refresh_energy(r: redis.Redis, user_id: int, req_token: str) -> int: +async def _get_refresh_energy(pg: asyncpg.Connection, user_id: int, req_token: str) -> int: new_auth_date = _auth_date_from_token(req_token) - current_token = await get_user_session(r, user_id) - if current_token != req_token: - if current_token is not None: - last_auth_date = _auth_date_from_token(current_token) - session_cooldown = get_setting('SESSION_COOLDOWN') - if new_auth_date - last_auth_date < session_cooldown: - raise HTTPException(status_code=403, detail='Unauthorized') + current_token = await get_user_session(pg, user_id) + if current_token is None: session_energy = int(get_setting('SESSION_ENERGY')) - await set_user_session(r, user_id, req_token) - await set_energy(r, user_id, session_energy) + await add_user(pg, user_id, req_token, session_energy) + return session_energy + if current_token != req_token: + last_auth_date = _auth_date_from_token(current_token) + session_cooldown = get_setting('SESSION_COOLDOWN') + if new_auth_date - last_auth_date < session_cooldown: + raise HTTPException(status_code=403, detail='Unauthorized') + session_energy = int(get_setting('SESSION_ENERGY')) + await set_new_session(pg, user_id, req_token, session_energy) return session_energy else: - return await r_get_energy(r, user_id) + return await pg_get_energy(pg, user_id) def _auth_date_from_token(token): split_res = base64.b64decode(token).decode('utf-8').split(':') @@ -127,12 +107,12 @@ def _auth_date_from_token(token): return int(data_dict['auth_date']) -async def check_energy(r: redis.Redis, user_id: int, amount: int, _token: str) -> Tuple[int, int]: - _energy = await _get_refresh_energy(r, user_id, _token) +async def check_energy(pg: asyncpg.Connection, user_id: int, amount: int, _token: str) -> Tuple[int, int]: + _energy = await _get_refresh_energy(pg, user_id, _token) if _energy == 0: return 0, 0 - return await decr_energy(r, user_id, amount) + return await decr_energy(pg, user_id, amount) -async def get_energy(r: redis.Redis, user_id: int, _token: str) -> int: - return await _get_refresh_energy(r, user_id, _token) +async def get_energy(pg: asyncpg.Connection, user_id: int, _token: str) -> int: + return await _get_refresh_energy(pg, user_id, _token) diff --git a/batcher/app/src/routers/click.py b/batcher/app/src/routers/click.py index 9d66857..809aafd 100644 --- a/batcher/app/src/routers/click.py +++ b/batcher/app/src/routers/click.py @@ -1,6 +1,5 @@ import aio_pika import asyncpg -import redis from fastapi import APIRouter, Depends, HTTPException from typing import Annotated, Tuple from ..domain.click import ( @@ -9,7 +8,7 @@ from ..domain.click import ( ) from ..dependencies import get_token_header -from ..db import get_pg, get_redis, get_rmq +from ..db import get_pg, get_rmq from ..config import BACKEND_URL @@ -22,16 +21,16 @@ router = APIRouter( @router.post("/batch-click/", response_model=ClickResponse, status_code=200) -async def batch_click(req: BatchClickRequest, auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], pg: Annotated[asyncpg.Connection, Depends(get_pg)], r: Annotated[redis.Redis, Depends(get_redis)], rmq: Annotated[aio_pika.Channel, Depends(get_rmq)]): +async def batch_click(req: BatchClickRequest, auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], pg: Annotated[asyncpg.Connection, Depends(get_pg)], rmq: Annotated[aio_pika.Channel, Depends(get_rmq)]): user_id, token = auth_info - if not await check_registration(r, user_id, token, BACKEND_URL): + if not await check_registration(pg, user_id, token, BACKEND_URL): raise HTTPException(status_code=403, detail='Unauthorized') - _energy, spent = await check_energy(r, user_id, req.count, token) + _energy, spent = await check_energy(pg, user_id, req.count, token) if spent == 0: raise HTTPException(status_code=400, detail='No energy') - click = await add_click_batch_copy(r, pg, rmq, user_id, spent) + click = await add_click_batch_copy(pg, rmq, user_id, spent) return ClickResponse( click=click, energy=_energy @@ -39,33 +38,33 @@ async def batch_click(req: BatchClickRequest, auth_info: Annotated[Tuple[int, st @router.get("/energy", response_model=EnergyResponse, status_code=200) -async def energy(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], r: Annotated[redis.Redis, Depends(get_redis)]): +async def energy(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]): user_id, token = auth_info - if not await check_registration(r, user_id, token, BACKEND_URL): + if not await check_registration(pg, user_id, token, BACKEND_URL): raise HTTPException(status_code=403, detail='Unauthorized') - _energy = await get_energy(r, user_id, token) + _energy = await get_energy(pg, user_id, token) return EnergyResponse( energy=_energy ) @router.get('/coefficient', response_model=ClickValueResponse, status_code=200) -async def coefficient(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], r: Annotated[redis.Redis, Depends(get_redis)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]): +async def coefficient(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]): user_id, token = auth_info - if not await check_registration(r, user_id, token, BACKEND_URL): + if not await check_registration(pg, user_id, token, BACKEND_URL): raise HTTPException(status_code=403, detail='Unauthorized') - value = await click_value(r, pg, user_id) + value = await click_value(pg, pg, user_id) return ClickValueResponse( value=value ) @router.delete('/internal/user', status_code=204) -async def delete_user(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], r: Annotated[redis.Redis, Depends(get_redis)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]): +async def delete_user(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]): user_id, token = auth_info - if not await check_registration(r, user_id, token, BACKEND_URL): + if not await check_registration(pg, user_id, token, BACKEND_URL): raise HTTPException(status_code=403, detail='Unauthorized') - await delete_user_info(r, pg, user_id) \ No newline at end of file + await delete_user_info(pg, user_id) \ No newline at end of file diff --git a/batcher/requirements.txt b/batcher/requirements.txt index 9849d19..b264ac9 100644 --- a/batcher/requirements.txt +++ b/batcher/requirements.txt @@ -20,7 +20,6 @@ pamqp==3.3.0 propcache==0.2.0 pydantic==2.9.2 pydantic_core==2.23.4 -redis==5.1.1 requests==2.32.3 sniffio==1.3.1 starlette==0.40.0 diff --git a/docker-compose-prod.yml b/docker-compose-prod.yml index 8702e85..52add9c 100644 --- a/docker-compose-prod.yml +++ b/docker-compose-prod.yml @@ -1,7 +1,6 @@ volumes: db_data: {} batcher_db_data: {} - redis_data: {} backend_media: {} backend_static: {} @@ -84,26 +83,13 @@ services: interval: 10s timeout: 2s - redis: - env_file: - - .env/prod/redis - image: redis - command: bash -c "redis-server --appendonly yes --requirepass $${REDIS_PASSWORD}" - volumes: - - redis_data:/data - healthcheck: - <<: *pg-healthcheck - test: "[ $$(redis-cli -a $${REDIS_PASSWORD} ping) = 'PONG' ]" - batcher: build: ./batcher depends_on: - redis: *healthy-dependency batcher-postgres: *healthy-dependency rabbitmq: *healthy-dependency env_file: - .env/prod/rmq - - .env/prod/redis - .env/prod/batcher-pg - .env/prod/batcher - .env/prod/bot diff --git a/docker-compose.yml b/docker-compose.yml index 4419e1c..6057bf9 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,7 +1,6 @@ volumes: db_data: {} batcher_db_data: {} - redis_data: {} backend_media: {} backend_static: {} @@ -66,28 +65,13 @@ services: interval: 10s timeout: 2s - redis: - env_file: - - .env/dev/redis - image: redis - command: bash -c "redis-server --appendonly yes --requirepass $${REDIS_PASSWORD}" - ports: - - '6379:6379' - volumes: - - redis_data:/data - healthcheck: - <<: *pg-healthcheck - test: "[ $$(redis-cli -a $$REDIS_PASSWORD ping) = 'PONG' ]" - batcher: build: ./batcher depends_on: - redis: *healthy-dependency batcher-postgres: *healthy-dependency rabbitmq: *healthy-dependency env_file: - .env/dev/rmq - - .env/dev/redis - .env/dev/batcher-pg - .env/dev/batcher - .env/dev/bot -- 2.34.1