Remove redis, refactor click batching (tested on api level)

This commit is contained in:
Michail Kostochka 2025-03-16 10:56:57 +03:00
parent dc58c8db45
commit 78f2ebe4d5
17 changed files with 319 additions and 350 deletions

View File

@ -19,4 +19,3 @@ def deliver_setting(setting_name):
routing_key=settings.SETTINGS_QUEUE_NAME, routing_key=settings.SETTINGS_QUEUE_NAME,
declare=[queue], declare=[queue],
) )

View File

@ -0,0 +1,68 @@
[
{
"model": "misc.setting",
"pk": 1,
"fields": {
"name": "SESSION_ENERGY",
"description": "Энергия на сессию",
"value": {
"value": 300
}
}
},
{
"model": "misc.setting",
"pk": 2,
"fields": {
"name": "PRICE_PER_CLICK",
"description": "Награда за клик",
"value": {
"value": 1
}
}
},
{
"model": "misc.setting",
"pk": 3,
"fields": {
"name": "DAY_MULT",
"description": "Дневной мультипликатор",
"value": {
"value": "1.5"
}
}
},
{
"model": "misc.setting",
"pk": 4,
"fields": {
"name": "WEEK_MULT",
"description": "Недельный мультипликатор",
"value": {
"value": "1.5"
}
}
},
{
"model": "misc.setting",
"pk": 5,
"fields": {
"name": "PROGRESS_MULT",
"description": "Мультипликатор прогресса",
"value": {
"value": "1.5"
}
}
},
{
"model": "misc.setting",
"pk": 6,
"fields": {
"name": "SESSION_COOLDOWN",
"description": "Кулдаун сессии",
"value": {
"value": 30
}
}
}
]

View File

@ -7,7 +7,7 @@ from starlette.exceptions import HTTPException
from app.src.routers.api import router as router_api from app.src.routers.api import router as router_api
from app.src.routers.handlers import http_error_handler from app.src.routers.handlers import http_error_handler
from app.src.domain.setting import launch_consumer from app.src.domain.setting import launch_consumer
from app.src.db import connect_pg, connect_redis, get_connection, get_channel, get_rmq from app.src.db import connect_pg, get_connection, get_channel, get_rmq
def get_application() -> FastAPI: def get_application() -> FastAPI:
@ -35,8 +35,6 @@ async def startup():
app.state.pg_pool = await connect_pg() app.state.pg_pool = await connect_pg()
app.state.redis_pool = connect_redis()
rmq_conn_pool = aio_pika.pool.Pool(get_connection, max_size=2) rmq_conn_pool = aio_pika.pool.Pool(get_connection, max_size=2)
rmq_chan_pool = aio_pika.pool.Pool(partial(get_channel, conn_pool=rmq_conn_pool), max_size=10) rmq_chan_pool = aio_pika.pool.Pool(partial(get_channel, conn_pool=rmq_conn_pool), max_size=10)
app.state.rmq_chan_pool = rmq_chan_pool app.state.rmq_chan_pool = rmq_chan_pool
@ -45,5 +43,4 @@ async def startup():
@app.on_event("shutdown") @app.on_event("shutdown")
async def shutdown(): async def shutdown():
await app.state.pg_pool.close() await app.state.pg_pool.close()
await app.state.redis.close()

View File

@ -4,13 +4,6 @@ from functools import lru_cache
config = Config() config = Config()
REDIS_USER = config('REDIS_USER')
REDIS_PASSWORD = config('REDIS_PASSWORD', cast=Secret)
REDIS_PORT = config('REDIS_PORT', cast=int)
REDIS_HOST = config('REDIS_HOST')
REDIS_DB = config('REDIS_DB')
PG_HOST = config('POSTGRES_HOST') PG_HOST = config('POSTGRES_HOST')
PG_PORT = config('POSTGRES_PORT', cast=int) PG_PORT = config('POSTGRES_PORT', cast=int)
PG_USER = config('POSTGRES_USER') PG_USER = config('POSTGRES_USER')

View File

@ -1,3 +1,2 @@
from .pg import get_pg, connect_pg from .pg import get_pg, connect_pg
from .redis import get_redis, connect_redis
from .rmq import get_rmq, get_channel, get_connection from .rmq import get_rmq, get_channel, get_connection

View File

@ -1,2 +1,12 @@
DROP INDEX clicks_user_id_time_idx; DROP VIEW coefficients;
DROP TABLE clicks; DROP TABLE clicks;
DROP TABLE users;
DROP TABLE global_stat;
DROP FUNCTION raise_error;
DROP FUNCTION handle_new_click;
DROP FUNCTION handle_new_user;
DROP FUNCTION handle_user_deletion;

View File

@ -1,8 +1,120 @@
CREATE TABLE IF NOT EXISTS clicks( CREATE TABLE users(
id BIGSERIAL PRIMARY KEY, id BIGINT PRIMARY KEY,
user_id BIGINT, energy INTEGER NOT NULL CONSTRAINT non_negative_energy CHECK (energy >= 0),
time TIMESTAMP, session VARCHAR(255) NOT NULL
value DECIMAL(100, 2),
expiry_info JSONB
); );
CREATE INDEX IF NOT EXISTS clicks_user_id_time_idx ON clicks(user_id, time);
CREATE TABLE clicks(
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users (id) ON DELETE CASCADE,
time TIMESTAMP NOT NULL,
value DECIMAL(100, 2) NOT NULL,
count BIGINT NOT NULL
);
CREATE INDEX clicks_user_id_time_idx ON clicks(user_id, time);
CREATE MATERIALIZED VIEW coefficients AS
SELECT
user_id,
(SELECT COUNT(*) FROM clicks AS c1 WHERE now() - time < interval '24 hours' AND c1.user_id = user_id) AS period_24,
(SELECT COUNT(*) FROM clicks AS c1 WHERE now() - time < interval '168 hours' AND c1.user_id = user_id) AS period_168,
(SELECT SUM(value * count) FROM clicks AS c1 WHERE c1.user_id = user_id) as total
FROM clicks
;
CREATE TABLE global_stat(
id BIGINT PRIMARY KEY DEFAULT 1,
user_count BIGINT NOT NULL,
global_average DECIMAL(100, 2) NOT NULL,
max_period_24 DECIMAL(100, 2) NOT NULL,
max_period_168 DECIMAL(100, 2) NOT NULL
);
INSERT INTO global_stat (user_count, global_average, max_period_24, max_period_168) VALUES (0, 0, 0, 0);
CREATE OR REPLACE FUNCTION raise_error()
RETURNS TRIGGER
AS $body$
BEGIN
RAISE EXCEPTION 'No changes allowed';
RETURN NULL;
END;
$body$
LANGUAGE PLPGSQL;
CREATE TRIGGER singleton_trg
BEFORE INSERT OR DELETE OR TRUNCATE ON global_stat
FOR EACH STATEMENT EXECUTE PROCEDURE raise_error();
CREATE OR REPLACE FUNCTION handle_new_click()
RETURNS TRIGGER
AS $body$
BEGIN
WITH user_stats AS (
SELECT period_24, period_168 FROM coefficients AS c WHERE c.user_id=new.user_id
)
UPDATE global_stat AS gs SET
global_average=(gs.global_average * gs.user_count + new.value * new.count) / gs.user_count,
max_period_24=GREATEST(us.period_24, gs.max_period_24),
max_period_168=GREATEST(us.period_168, gs.max_period_168)
FROM user_stats AS us
;
RETURN NULL;
END;
$body$
LANGUAGE PLPGSQL;
CREATE TRIGGER new_click_trg
AFTER INSERT ON clicks
FOR EACH ROW EXECUTE PROCEDURE handle_new_click();
CREATE OR REPLACE FUNCTION handle_new_user()
RETURNS TRIGGER
AS $body$
BEGIN
UPDATE global_stat SET
global_average=global_average * user_count / (user_count + 1),
user_count=user_count + 1
;
RETURN NULL;
END;
$body$
LANGUAGE PLPGSQL;
CREATE TRIGGER new_user_trg
AFTER INSERT ON users
FOR EACH ROW EXECUTE PROCEDURE handle_new_user();
CREATE OR REPLACE FUNCTION handle_user_deletion()
RETURNS TRIGGER
AS $body$
BEGIN
UPDATE global_stat SET
global_average=global_average * user_count / (user_count - 1),
user_count=user_count - 1
;
END;
$body$
LANGUAGE PLPGSQL;
CREATE TRIGGER delete_user_trg
AFTER DELETE ON users
FOR EACH ROW EXECUTE PROCEDURE handle_user_deletion();
CREATE OR REPLACE FUNCTION handle_user_truncate()
RETURNS TRIGGER
AS $body$
BEGIN
UPDATE global_stat SET
global_average=0,
user_count=0
;
RETURN NULL;
END;
$body$
LANGUAGE PLPGSQL;
CREATE TRIGGER truncate_user_trg
AFTER TRUNCATE ON users
FOR STATEMENT EXECUTE PROCEDURE handle_user_truncate();

View File

@ -1,14 +0,0 @@
from starlette.requests import Request
import redis.asyncio as redis
from ..config import REDIS_HOST, REDIS_PORT, REDIS_USER, REDIS_PASSWORD, REDIS_DB
def connect_redis() -> redis.ConnectionPool:
return redis.ConnectionPool(host=REDIS_HOST, port=REDIS_PORT, username=REDIS_USER, password=str(REDIS_PASSWORD), db=REDIS_DB)
async def get_redis(request: Request) -> redis.Redis:
r = redis.Redis(connection_pool=request.app.state.redis_pool)
yield r
await r.aclose()

View File

@ -7,3 +7,4 @@ class Click(pydantic.BaseModel):
userId: int userId: int
dateTime: datetime.datetime dateTime: datetime.datetime
value: decimal.Decimal value: decimal.Decimal
count: int

View File

@ -1,50 +1,78 @@
from datetime import datetime, timedelta from typing import Tuple
import decimal import decimal
from asyncpg import Connection from asyncpg import Connection
from ..models import Click from ..models import Click
async def update_click_expiry(conn: Connection, user_id: int, period: int) -> decimal.Decimal:
cur_time = datetime.now()
cutoff_time = cur_time - timedelta(hours=period)
query = '''
WITH expired_values AS(
UPDATE clicks
SET expiry_info=jsonb_set(expiry_info, $1, 'true')
WHERE 1=1
AND time < $2
AND user_id =$3
AND not (expiry_info->>$4)::bool
RETURNING value
)
SELECT COALESCE(SUM(value), 0)
FROM expired_values
;
'''
period_key = f'period_{period}'
return await conn.fetchval(query, [period_key], cutoff_time, user_id, period_key)
async def store(conn: Connection, click: Click) -> int: async def store(conn: Connection, click: Click) -> int:
query = ''' query = '''
INSERT INTO clicks(user_id, time, value, expiry_info) INSERT INTO clicks(user_id, time, value, count)
VALUES($1, $2, $3, '{"period_24": false, "period_168": false}') VALUES($1, $2, $3, $4)
RETURNING id RETURNING id
; ;
''' '''
return await conn.fetchval(query, click.userId, click.dateTime, click.value) return await conn.fetchval(query, click.userId, click.dateTime, click.value, click.count)
async def bulk_store_copy(conn: Connection, click: Click, count: int) -> None: async def delete_user_info(conn: Connection, user_id: int):
args = [(click.userId, click.dateTime, click.value) for _ in range(count)] async with conn.transaction():
query = ''' await conn.execute('DELETE FROM clicks WHERE user_id=$1', user_id)
INSERT INTO clicks(user_id, time, value, expiry_info) await conn.execute('DELTE FROM users WHERE id=$1', user_id)
VALUES($1, $2, $3, '{"period_24": false, "period_168": false}')
;
'''
await conn.executemany(query, args)
async def delete_by_user_id(conn: Connection, user_id: int): async def get_period_sum(conn: Connection, user_id: int, period: int) -> decimal.Decimal:
await conn.execute('DELETE FROM clicks WHERE user_id=$1', user_id) if not isinstance(period, int):
raise ValueError('period must be an integer')
return (await conn.fetchval(f'SELECT period_{period} FROM coefficients WHERE user_id=$1', user_id)) or decimal.Decimal(0)
async def get_max_period_sum(conn: Connection, period: int) -> decimal.Decimal:
if not isinstance(period, int):
raise ValueError('period must be an integer')
return (await conn.fetchval(f'SELECT max_period_{period} FROM global_stat')) or decimal.Decimal(0)
async def get_global_average(conn: Connection) -> decimal.Decimal:
return (await conn.fetchval('SELECT global_average FROM global_stat')) or decimal.Decimal(0)
async def get_user_total(conn: Connection, user_id: int) -> decimal.Decimal:
return (await conn.fetchval('SELECT total FROM coefficients WHERE user_id=$1', user_id)) or decimal.Decimal(0)
async def user_exists(conn: Connection, user_id: int) -> bool:
return await conn.fetchval('SELECT EXISTS(SELECT 1 FROM users WHERE id=$1)', user_id)\
async def add_user(conn: Connection, user_id: int, req_token: str, session_energy: int):
await conn.execute(
'INSERT INTO users (id, session, energy) VALUES ($1, $2, $3)',
user_id, req_token, session_energy
)
async def get_user_session(conn: Connection, user_id: int) -> str:
return await conn.fetchval('SELECT session FROM users WHERE id=$1', user_id)
async def set_new_session(conn: Connection, user_id: int, req_token: str, session_energy: int):
await conn.execute('UPDATE users SET session=$1, energy=$2 WHERE id=$3', req_token, session_energy, user_id)
async def get_energy(conn: Connection, user_id: int) -> int:
return await conn.fetchval('SELECT energy FROM users WHERE id=$1', user_id)
async def decr_energy(conn: Connection, user_id: int, amount: int) -> Tuple[int, int]:
new_energy, spent = await conn.fetchrow('''
WITH energy_cte AS (
SELECT energy, (CASE WHEN energy < $2 THEN energy ELSE $2 END) AS delta FROM users WHERE id=$1
)
UPDATE users AS u SET
energy=u.energy - e.delta
FROM energy_cte AS e
WHERE id=$1
RETURNING u.energy as new_energy, e.delta
''', user_id, amount)
return new_energy, spent

View File

@ -1,172 +0,0 @@
import decimal
from typing import Optional, List
import redis.asyncio as redis
async def get_period_sum(r: redis.Redis, user_id: int, period: int) -> decimal.Decimal:
sum_bytes = await r.get(f'period_{period}_user_{user_id}')
if sum_bytes is None:
return decimal.Decimal(0)
return decimal.Decimal(sum_bytes.decode())
async def incr_period_sum(r: redis.Redis, user_id: int, _period: int, value: decimal.Decimal) -> decimal.Decimal:
return await r.incrbyfloat(f'period_{_period}_user_{user_id}', float(value))
async def get_max_period_sum(r: redis.Redis, _period: int) -> decimal.Decimal:
max_sum_bytes = await r.get(f'max_period_{_period}')
if max_sum_bytes is None:
return decimal.Decimal(0)
return decimal.Decimal(max_sum_bytes.decode())
async def compare_max_period_sum(r: redis.Redis, _period: int, _sum: decimal.Decimal) -> None:
_script = r.register_script('''
local currentValue = tonumber(redis.call('GET', KEYS[1]))
local cmpValue = tonumber(ARGV[1])
if not currentValue or cmpValue > currentValue then
redis.call('SET', KEYS[1], ARGV[1])
return cmpValue
else
return currentValue
end
''')
await _script(keys=[f'max_period_{_period}'], args=[str(_sum)])
async def get_energy(r: redis.Redis, user_id: int) -> int:
energy_str = await r.get(f'energy_{user_id}')
if energy_str is None:
return 0
return int(energy_str)
async def set_energy(r: redis.Redis, user_id: int, energy: int) -> int:
await r.set(f'energy_{user_id}', energy)
async def decr_energy(r: redis.Redis, user_id: int, amount: int) -> (int, int):
_script = r.register_script('''
local energy = tonumber(redis.call('GET', KEYS[1]))
local delta = tonumber(ARGV[1])
if energy < delta then
redis.call('SET', KEYS[1], 0)
return {0, energy}
else
local newEnergy = tonumber(redis.call('DECRBY', KEYS[1], ARGV[1]))
return {newEnergy, delta}
end
''')
new_energy, spent= map(int, await _script(keys=[f'energy_{user_id}'], args=[amount]))
return new_energy, spent
async def get_global_average(r: redis.Redis) -> decimal.Decimal:
avg_bytes = await r.get('global_average')
if avg_bytes is None:
return decimal.Decimal(0)
return decimal.Decimal(avg_bytes.decode())
async def update_global_average(r: redis.Redis, value_to_add: decimal.Decimal) -> decimal.Decimal:
_script = r.register_script('''
local delta = tonumber(ARGV[1]) / tonumber(redis.call('GET', KEYS[1]))
return redis.call('INCRBYFLOAT', KEYS[2], delta)
''')
return decimal.Decimal((await _script(keys=["user_count", "global_average"], args=[float(value_to_add)])).decode())
async def get_user_total(r: redis.Redis, user_id: int) -> decimal.Decimal:
total_bytes = await r.get(f'total_{user_id}')
if total_bytes is None:
return decimal.Decimal(0)
return decimal.Decimal(total_bytes.decode())
async def incr_user_count_if_no_clicks(r: redis.Redis, user_id: int) -> int:
_script = r.register_script('''
local clickCount = tonumber(redis.call('GET', KEYS[1]))
local userCount = tonumber(redis.call('GET', KEYS[2]))
if (not clickCount) then
local oldUserCount = redis.call('GET', KEYS[2])
if (not oldUserCount) then
redis.call('SET', KEYS[2], 1)
redis.call('SET', KEYS[3], 0)
return 1
end
userCount = tonumber(redis.call('INCR', KEYS[2]))
oldUserCount = tonumber(oldUserCount)
local globalAverage = tonumber(redis.call('GET', KEYS[3]))
redis.call('SET', KEYS[3], globalAverage / userCount * oldUserCount)
end
return userCount
''')
return int(await _script(keys=[f'total_{user_id}', 'user_count', 'global_average'], args=[]))
async def incr_user_total(r: redis.Redis, user_id: int, value: decimal.Decimal) -> decimal.Decimal:
return await r.incrbyfloat(f'total_{user_id}', float(value))
async def get_user_session(r: redis.Redis, user_id: int) -> Optional[str]:
session_bytes = await r.get(f'session_{user_id}')
if session_bytes is None:
return None
return session_bytes.decode()
async def set_user_session(r: redis.Redis, user_id: int, token: str) -> None:
await r.set(f'session_{user_id}', token, ex=30 * 60)
async def get_user_count(r: redis.Redis) -> int:
user_count_str = await r.get('user_count')
if user_count_str is None:
return 0
return int(user_count_str)
async def incr_user_count(r: redis.Redis) -> int:
_script = r.register_script('''
local oldCount = redis.call('GET', KEYS[1])
if (not oldCount) then
redis.call('SET', KEYS[1], 1)
redis.call('SET', KEYS[2], 0)
return 1
end
local newCount = tonumber(redis.call('INCR', KEYS[1]))
local globalAverage = tonumber(redis.call('GET', KEYS[2]))
redis.call('SET', KEYS[2], globalAverage / newCount * oldCount)
return newCount
''')
return int(await _script(keys=['user_count', 'global_average'], args=[]))
async def delete_user_info(r: redis.Redis, user_id: int, periods: List[int]):
_script = r.register_script('''
local userTotal = redis.call('GET', KEYS[3])
if (not userTotal) then
return
end
local oldUserCount = tonumber(redis.call('GET', KEYS[1]))
local newUserCount = tonumber(redis.call('DECR', KEYS[1]))
local globalAverage = tonumber(redis.call('GET', KEYS[2]))
redis.call('SET', KEYS[2], (globalAverage * oldUserCount - userTotal) / newUserCount)
for i, v in ipairs(KEYS) do
if (i > 2) then
redis.call('DEL', v)
end
end
''')
keys = [
'user_count',
'global_average',
f'total_{user_id}'
f'energy_{user_id}',
f'session_{user_id}',
]
for period in periods:
keys.append(f'period_{period}_user_{user_id}')
await _script(keys=keys, args=[])

View File

@ -10,8 +10,8 @@ CELERY_QUEUE_NAME = "celery"
CLICK_TASK_NAME = "clicks.celery.click.handle_click" CLICK_TASK_NAME = "clicks.celery.click.handle_click"
async def send_click_batch_copy(chan: aio_pika.Channel, click: Click, count: int): async def send_click(chan: aio_pika.Channel, click: Click):
args = (click.userId, int(click.dateTime.timestamp() * 1e3), str(click.value), count) args = (click.userId, int(click.dateTime.timestamp() * 1e3), str(click.value), click.count)
await chan.default_exchange.publish( await chan.default_exchange.publish(
message=aio_pika.Message( message=aio_pika.Message(
body=json.dumps([ body=json.dumps([

View File

@ -2,87 +2,70 @@ from datetime import datetime
import decimal import decimal
from typing import Tuple from typing import Tuple
import aiohttp import aiohttp
import redis.asyncio as redis
import aio_pika import aio_pika
import asyncpg import asyncpg
import base64 import base64
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from app.src.domain.setting import get_setting from app.src.domain.setting import get_setting
from .repos.redis import ( from .repos.pg import (
get_period_sum, incr_period_sum, get_max_period_sum, get_user_total, get_global_average, store, delete_user_info as pg_delete_user_info, get_period_sum, get_max_period_sum, get_global_average, get_user_total, user_exists, get_user_session,
incr_user_count_if_no_clicks, update_global_average, incr_user_total, compare_max_period_sum, set_new_session, get_energy as pg_get_energy, decr_energy, add_user
delete_user_info as r_delete_user_info, get_user_session, set_user_session, set_energy, get_energy as r_get_energy,
decr_energy,
) )
from .repos.pg import update_click_expiry, bulk_store_copy, delete_by_user_id from .repos.rmq import send_click
from .repos.rmq import send_click_batch_copy
from .models import Click from .models import Click
PRECISION = 2 PRECISION = 2
async def add_click_batch_copy(r: redis.Redis, pg: asyncpg.Connection, rmq: aio_pika.Channel, user_id: int, count: int) -> Click: async def add_click_batch_copy(pg: asyncpg.Connection, rmq: aio_pika.Channel, user_id: int, count: int) -> Click:
_click_value = await click_value(r, pg, user_id) _click_value = await click_value(pg, user_id)
click_value_sum = _click_value * count
# update variables
await incr_user_count_if_no_clicks(r, user_id)
await update_global_average(r, click_value_sum)
await incr_user_total(r, user_id, click_value_sum)
for period in (24, 24*7):
new_period_sum = await incr_period_sum(r, user_id, period, click_value_sum)
await compare_max_period_sum(r, period, new_period_sum)
click = Click( click = Click(
userId=user_id, userId=user_id,
dateTime=datetime.now(), dateTime=datetime.now(),
value=_click_value, value=_click_value,
count=count,
) )
# insert click # insert click
await bulk_store_copy(pg, click, count) await store(pg, click)
# send click to backend # send click to backend
await send_click_batch_copy(rmq, click, count) await send_click(rmq, click)
return click return click
async def delete_user_info(r: redis.Redis, pg: asyncpg.Connection, user_id: int) -> None: async def delete_user_info(pg: asyncpg.Connection, user_id: int) -> None:
await r_delete_user_info(r, user_id, [24, 168]) await pg_delete_user_info(pg, user_id)
await delete_by_user_id(pg, user_id)
async def click_value(r: redis.Redis, pg: asyncpg.Connection, user_id: int) -> decimal.Decimal: async def click_value(pg: asyncpg.Connection, user_id: int) -> decimal.Decimal:
price_per_click = get_setting('PRICE_PER_CLICK') price_per_click = get_setting('PRICE_PER_CLICK')
day_multiplier = get_setting('DAY_MULT') day_multiplier = get_setting('DAY_MULT')
week_multiplier = get_setting('WEEK_MULT') week_multiplier = get_setting('WEEK_MULT')
progress_multiplier = get_setting('PROGRESS_MULT') progress_multiplier = get_setting('PROGRESS_MULT')
# period coefficients # period coefficients
day_coef = await period_coefficient(r, pg, user_id, 24, day_multiplier) day_coef = await period_coefficient(pg, user_id, 24, day_multiplier)
week_coef = await period_coefficient(r, pg, user_id, 24*7, week_multiplier) week_coef = await period_coefficient(pg, user_id, 24*7, week_multiplier)
# progress coefficient # progress coefficient
user_total = await get_user_total(r, user_id) user_total = await get_user_total(pg, user_id)
global_avg = await get_global_average(r) global_avg = await get_global_average(pg)
progress_coef = progress_coefficient(user_total, global_avg, progress_multiplier) progress_coef = progress_coefficient(user_total, global_avg, progress_multiplier)
return round(price_per_click * day_coef * week_coef * progress_coef, PRECISION) return round(price_per_click * day_coef * week_coef * progress_coef, PRECISION)
async def period_coefficient(r: redis.Redis, pg: asyncpg.Connection, user_id: int, period: int, multiplier: decimal.Decimal) -> decimal.Decimal: async def period_coefficient(pg: asyncpg.Connection, user_id: int, period: int, multiplier: decimal.Decimal) -> decimal.Decimal:
current_sum = await get_period_sum(r, user_id, period) current_period_sum = await get_period_sum(pg, user_id, period)
expired_sum = await update_click_expiry(pg, user_id, period) max_period_sum = await get_max_period_sum(pg, period)
new_sum = current_sum - expired_sum
await incr_period_sum(r, user_id, period, -expired_sum)
max_period_sum = await get_max_period_sum(r, period)
if max_period_sum == decimal.Decimal(0): if max_period_sum == decimal.Decimal(0):
return decimal.Decimal(1) return decimal.Decimal(1)
return new_sum * multiplier / max_period_sum + 1 return current_period_sum * multiplier / max_period_sum + 1
def progress_coefficient(user_total: decimal.Decimal, global_avg: decimal.Decimal, multiplier: decimal.Decimal) -> decimal.Decimal: def progress_coefficient(user_total: decimal.Decimal, global_avg: decimal.Decimal, multiplier: decimal.Decimal) -> decimal.Decimal:
@ -91,34 +74,31 @@ def progress_coefficient(user_total: decimal.Decimal, global_avg: decimal.Decima
return min(global_avg * multiplier / user_total + 1, decimal.Decimal(2)) return min(global_avg * multiplier / user_total + 1, decimal.Decimal(2))
async def check_registration(r: redis.Redis, user_id: int, _token: str, backend_url: str) -> bool: async def check_registration(pg: asyncpg.Connection, user_id: int, _token: str, backend_url: str) -> bool:
if await _has_any_clicks(r, user_id): if await user_exists(pg, user_id):
return True return True
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(f'{backend_url}/api/v1/users/{user_id}', headers={'Authorization': f'TelegramToken {_token}'}) as resp: async with session.get(f'{backend_url}/api/v1/users/{user_id}', headers={'Authorization': f'TelegramToken {_token}'}) as resp:
return resp.status == 200 return resp.status == 200
async def _has_any_clicks(r: redis.Redis, user_id: int) -> bool: async def _get_refresh_energy(pg: asyncpg.Connection, user_id: int, req_token: str) -> int:
total_value = await get_user_total(r, user_id)
return total_value > decimal.Decimal(0)
async def _get_refresh_energy(r: redis.Redis, user_id: int, req_token: str) -> int:
new_auth_date = _auth_date_from_token(req_token) new_auth_date = _auth_date_from_token(req_token)
current_token = await get_user_session(r, user_id) current_token = await get_user_session(pg, user_id)
if current_token != req_token: if current_token is None:
if current_token is not None:
last_auth_date = _auth_date_from_token(current_token)
session_cooldown = get_setting('SESSION_COOLDOWN')
if new_auth_date - last_auth_date < session_cooldown:
raise HTTPException(status_code=403, detail='Unauthorized')
session_energy = int(get_setting('SESSION_ENERGY')) session_energy = int(get_setting('SESSION_ENERGY'))
await set_user_session(r, user_id, req_token) await add_user(pg, user_id, req_token, session_energy)
await set_energy(r, user_id, session_energy) return session_energy
if current_token != req_token:
last_auth_date = _auth_date_from_token(current_token)
session_cooldown = get_setting('SESSION_COOLDOWN')
if new_auth_date - last_auth_date < session_cooldown:
raise HTTPException(status_code=403, detail='Unauthorized')
session_energy = int(get_setting('SESSION_ENERGY'))
await set_new_session(pg, user_id, req_token, session_energy)
return session_energy return session_energy
else: else:
return await r_get_energy(r, user_id) return await pg_get_energy(pg, user_id)
def _auth_date_from_token(token): def _auth_date_from_token(token):
split_res = base64.b64decode(token).decode('utf-8').split(':') split_res = base64.b64decode(token).decode('utf-8').split(':')
@ -127,12 +107,12 @@ def _auth_date_from_token(token):
return int(data_dict['auth_date']) return int(data_dict['auth_date'])
async def check_energy(r: redis.Redis, user_id: int, amount: int, _token: str) -> Tuple[int, int]: async def check_energy(pg: asyncpg.Connection, user_id: int, amount: int, _token: str) -> Tuple[int, int]:
_energy = await _get_refresh_energy(r, user_id, _token) _energy = await _get_refresh_energy(pg, user_id, _token)
if _energy == 0: if _energy == 0:
return 0, 0 return 0, 0
return await decr_energy(r, user_id, amount) return await decr_energy(pg, user_id, amount)
async def get_energy(r: redis.Redis, user_id: int, _token: str) -> int: async def get_energy(pg: asyncpg.Connection, user_id: int, _token: str) -> int:
return await _get_refresh_energy(r, user_id, _token) return await _get_refresh_energy(pg, user_id, _token)

View File

@ -1,6 +1,5 @@
import aio_pika import aio_pika
import asyncpg import asyncpg
import redis
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from typing import Annotated, Tuple from typing import Annotated, Tuple
from ..domain.click import ( from ..domain.click import (
@ -9,7 +8,7 @@ from ..domain.click import (
) )
from ..dependencies import get_token_header from ..dependencies import get_token_header
from ..db import get_pg, get_redis, get_rmq from ..db import get_pg, get_rmq
from ..config import BACKEND_URL from ..config import BACKEND_URL
@ -22,16 +21,16 @@ router = APIRouter(
@router.post("/batch-click/", response_model=ClickResponse, status_code=200) @router.post("/batch-click/", response_model=ClickResponse, status_code=200)
async def batch_click(req: BatchClickRequest, auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], pg: Annotated[asyncpg.Connection, Depends(get_pg)], r: Annotated[redis.Redis, Depends(get_redis)], rmq: Annotated[aio_pika.Channel, Depends(get_rmq)]): async def batch_click(req: BatchClickRequest, auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], pg: Annotated[asyncpg.Connection, Depends(get_pg)], rmq: Annotated[aio_pika.Channel, Depends(get_rmq)]):
user_id, token = auth_info user_id, token = auth_info
if not await check_registration(r, user_id, token, BACKEND_URL): if not await check_registration(pg, user_id, token, BACKEND_URL):
raise HTTPException(status_code=403, detail='Unauthorized') raise HTTPException(status_code=403, detail='Unauthorized')
_energy, spent = await check_energy(r, user_id, req.count, token) _energy, spent = await check_energy(pg, user_id, req.count, token)
if spent == 0: if spent == 0:
raise HTTPException(status_code=400, detail='No energy') raise HTTPException(status_code=400, detail='No energy')
click = await add_click_batch_copy(r, pg, rmq, user_id, spent) click = await add_click_batch_copy(pg, rmq, user_id, spent)
return ClickResponse( return ClickResponse(
click=click, click=click,
energy=_energy energy=_energy
@ -39,33 +38,33 @@ async def batch_click(req: BatchClickRequest, auth_info: Annotated[Tuple[int, st
@router.get("/energy", response_model=EnergyResponse, status_code=200) @router.get("/energy", response_model=EnergyResponse, status_code=200)
async def energy(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], r: Annotated[redis.Redis, Depends(get_redis)]): async def energy(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]):
user_id, token = auth_info user_id, token = auth_info
if not await check_registration(r, user_id, token, BACKEND_URL): if not await check_registration(pg, user_id, token, BACKEND_URL):
raise HTTPException(status_code=403, detail='Unauthorized') raise HTTPException(status_code=403, detail='Unauthorized')
_energy = await get_energy(r, user_id, token) _energy = await get_energy(pg, user_id, token)
return EnergyResponse( return EnergyResponse(
energy=_energy energy=_energy
) )
@router.get('/coefficient', response_model=ClickValueResponse, status_code=200) @router.get('/coefficient', response_model=ClickValueResponse, status_code=200)
async def coefficient(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], r: Annotated[redis.Redis, Depends(get_redis)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]): async def coefficient(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]):
user_id, token = auth_info user_id, token = auth_info
if not await check_registration(r, user_id, token, BACKEND_URL): if not await check_registration(pg, user_id, token, BACKEND_URL):
raise HTTPException(status_code=403, detail='Unauthorized') raise HTTPException(status_code=403, detail='Unauthorized')
value = await click_value(r, pg, user_id) value = await click_value(pg, pg, user_id)
return ClickValueResponse( return ClickValueResponse(
value=value value=value
) )
@router.delete('/internal/user', status_code=204) @router.delete('/internal/user', status_code=204)
async def delete_user(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], r: Annotated[redis.Redis, Depends(get_redis)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]): async def delete_user(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]):
user_id, token = auth_info user_id, token = auth_info
if not await check_registration(r, user_id, token, BACKEND_URL): if not await check_registration(pg, user_id, token, BACKEND_URL):
raise HTTPException(status_code=403, detail='Unauthorized') raise HTTPException(status_code=403, detail='Unauthorized')
await delete_user_info(r, pg, user_id) await delete_user_info(pg, user_id)

View File

@ -20,7 +20,6 @@ pamqp==3.3.0
propcache==0.2.0 propcache==0.2.0
pydantic==2.9.2 pydantic==2.9.2
pydantic_core==2.23.4 pydantic_core==2.23.4
redis==5.1.1
requests==2.32.3 requests==2.32.3
sniffio==1.3.1 sniffio==1.3.1
starlette==0.40.0 starlette==0.40.0

View File

@ -1,7 +1,6 @@
volumes: volumes:
db_data: {} db_data: {}
batcher_db_data: {} batcher_db_data: {}
redis_data: {}
backend_media: {} backend_media: {}
backend_static: {} backend_static: {}
@ -84,26 +83,13 @@ services:
interval: 10s interval: 10s
timeout: 2s timeout: 2s
redis:
env_file:
- .env/prod/redis
image: redis
command: bash -c "redis-server --appendonly yes --requirepass $${REDIS_PASSWORD}"
volumes:
- redis_data:/data
healthcheck:
<<: *pg-healthcheck
test: "[ $$(redis-cli -a $${REDIS_PASSWORD} ping) = 'PONG' ]"
batcher: batcher:
build: ./batcher build: ./batcher
depends_on: depends_on:
redis: *healthy-dependency
batcher-postgres: *healthy-dependency batcher-postgres: *healthy-dependency
rabbitmq: *healthy-dependency rabbitmq: *healthy-dependency
env_file: env_file:
- .env/prod/rmq - .env/prod/rmq
- .env/prod/redis
- .env/prod/batcher-pg - .env/prod/batcher-pg
- .env/prod/batcher - .env/prod/batcher
- .env/prod/bot - .env/prod/bot

View File

@ -1,7 +1,6 @@
volumes: volumes:
db_data: {} db_data: {}
batcher_db_data: {} batcher_db_data: {}
redis_data: {}
backend_media: {} backend_media: {}
backend_static: {} backend_static: {}
@ -66,28 +65,13 @@ services:
interval: 10s interval: 10s
timeout: 2s timeout: 2s
redis:
env_file:
- .env/dev/redis
image: redis
command: bash -c "redis-server --appendonly yes --requirepass $${REDIS_PASSWORD}"
ports:
- '6379:6379'
volumes:
- redis_data:/data
healthcheck:
<<: *pg-healthcheck
test: "[ $$(redis-cli -a $$REDIS_PASSWORD ping) = 'PONG' ]"
batcher: batcher:
build: ./batcher build: ./batcher
depends_on: depends_on:
redis: *healthy-dependency
batcher-postgres: *healthy-dependency batcher-postgres: *healthy-dependency
rabbitmq: *healthy-dependency rabbitmq: *healthy-dependency
env_file: env_file:
- .env/dev/rmq - .env/dev/rmq
- .env/dev/redis
- .env/dev/batcher-pg - .env/dev/batcher-pg
- .env/dev/batcher - .env/dev/batcher
- .env/dev/bot - .env/dev/bot