Remove redis, refactor click batching #1

Merged
s1lur merged 1 commit from refactor-redis-and-click-batching into dev 2025-05-11 10:48:10 +03:00
17 changed files with 319 additions and 350 deletions

View File

@@ -19,4 +19,3 @@ def deliver_setting(setting_name):
routing_key=settings.SETTINGS_QUEUE_NAME,
declare=[queue],
)

View File

@@ -0,0 +1,68 @@
[
{
"model": "misc.setting",
"pk": 1,
"fields": {
"name": "SESSION_ENERGY",
"description": "Энергия на сессию",
"value": {
"value": 300
}
}
},
{
"model": "misc.setting",
"pk": 2,
"fields": {
"name": "PRICE_PER_CLICK",
"description": "Награда за клик",
"value": {
"value": 1
}
}
},
{
"model": "misc.setting",
"pk": 3,
"fields": {
"name": "DAY_MULT",
"description": "Дневной мультипликатор",
"value": {
"value": "1.5"
}
}
},
{
"model": "misc.setting",
"pk": 4,
"fields": {
"name": "WEEK_MULT",
"description": "Недельный мультипликатор",
"value": {
"value": "1.5"
}
}
},
{
"model": "misc.setting",
"pk": 5,
"fields": {
"name": "PROGRESS_MULT",
"description": "Мультипликатор прогресса",
"value": {
"value": "1.5"
}
}
},
{
"model": "misc.setting",
"pk": 6,
"fields": {
"name": "SESSION_COOLDOWN",
"description": "Кулдаун сессии",
"value": {
"value": 30
}
}
}
]
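
These look like Django fixtures for a misc.setting model, presumably loaded with manage.py loaddata. Note the mixed types: the multipliers are stored as strings (so they survive as exact decimals), while SESSION_ENERGY and SESSION_COOLDOWN are plain integers, so a consumer has to normalize. A minimal sketch of that normalization, assuming a plain JSON read (the file name and helper below are assumptions, not part of this PR):

import decimal
import json

def load_settings(path='settings.json'):  # path is hypothetical
    with open(path) as f:
        fixtures = json.load(f)
    # flatten {"fields": {"name": ..., "value": {"value": ...}}} into a dict
    return {item['fields']['name']: item['fields']['value']['value'] for item in fixtures}

settings = load_settings()
day_mult = decimal.Decimal(str(settings['DAY_MULT']))  # "1.5" -> Decimal('1.5')
session_energy = int(settings['SESSION_ENERGY'])       # 300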

View File

@@ -7,7 +7,7 @@ from starlette.exceptions import HTTPException
from app.src.routers.api import router as router_api
from app.src.routers.handlers import http_error_handler
from app.src.domain.setting import launch_consumer
from app.src.db import connect_pg, connect_redis, get_connection, get_channel, get_rmq
from app.src.db import connect_pg, get_connection, get_channel, get_rmq
def get_application() -> FastAPI:
@@ -35,8 +35,6 @@ async def startup():
app.state.pg_pool = await connect_pg()
app.state.redis_pool = connect_redis()
rmq_conn_pool = aio_pika.pool.Pool(get_connection, max_size=2)
rmq_chan_pool = aio_pika.pool.Pool(partial(get_channel, conn_pool=rmq_conn_pool), max_size=10)
app.state.rmq_chan_pool = rmq_chan_pool
@@ -45,5 +43,4 @@ async def startup():
@app.on_event("shutdown")
async def shutdown():
await app.state.pg_pool.close()
await app.state.redis.close()

View File

@@ -4,13 +4,6 @@ from functools import lru_cache
config = Config()
REDIS_USER = config('REDIS_USER')
REDIS_PASSWORD = config('REDIS_PASSWORD', cast=Secret)
REDIS_PORT = config('REDIS_PORT', cast=int)
REDIS_HOST = config('REDIS_HOST')
REDIS_DB = config('REDIS_DB')
PG_HOST = config('POSTGRES_HOST')
PG_PORT = config('POSTGRES_PORT', cast=int)
PG_USER = config('POSTGRES_USER')

View File

@@ -1,3 +1,2 @@
from .pg import get_pg, connect_pg
from .redis import get_redis, connect_redis
from .rmq import get_rmq, get_channel, get_connection

View File

@@ -1,2 +1,12 @@
DROP INDEX clicks_user_id_time_idx;
DROP VIEW coefficients;
DROP TABLE clicks;
DROP TABLE users;
DROP TABLE global_stat;
DROP FUNCTION raise_error;
DROP FUNCTION handle_new_click;
DROP FUNCTION handle_new_user;
DROP FUNCTION handle_user_deletion;

View File

@@ -1,8 +1,120 @@
CREATE TABLE IF NOT EXISTS clicks(
id BIGSERIAL PRIMARY KEY,
user_id BIGINT,
time TIMESTAMP,
value DECIMAL(100, 2),
expiry_info JSONB
CREATE TABLE users(
id BIGINT PRIMARY KEY,
energy INTEGER NOT NULL CONSTRAINT non_negative_energy CHECK (energy >= 0),
session VARCHAR(255) NOT NULL
);
CREATE INDEX IF NOT EXISTS clicks_user_id_time_idx ON clicks(user_id, time);
CREATE TABLE clicks(
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users (id) ON DELETE CASCADE,
time TIMESTAMP NOT NULL,
value DECIMAL(100, 2) NOT NULL,
count BIGINT NOT NULL
);
CREATE INDEX clicks_user_id_time_idx ON clicks(user_id, time);
CREATE MATERIALIZED VIEW coefficients AS
SELECT
user_id,
COUNT(*) FILTER (WHERE now() - time < interval '24 hours') AS period_24,
COUNT(*) FILTER (WHERE now() - time < interval '168 hours') AS period_168,
SUM(value * count) AS total
FROM clicks
GROUP BY user_id
;
CREATE TABLE global_stat(
id BIGINT PRIMARY KEY DEFAULT 1,
user_count BIGINT NOT NULL,
global_average DECIMAL(100, 2) NOT NULL,
max_period_24 DECIMAL(100, 2) NOT NULL,
max_period_168 DECIMAL(100, 2) NOT NULL
);
INSERT INTO global_stat (user_count, global_average, max_period_24, max_period_168) VALUES (0, 0, 0, 0);
CREATE OR REPLACE FUNCTION raise_error()
RETURNS TRIGGER
AS $body$
BEGIN
RAISE EXCEPTION 'No changes allowed';
RETURN NULL;
END;
$body$
LANGUAGE PLPGSQL;
CREATE TRIGGER singleton_trg
BEFORE INSERT OR DELETE OR TRUNCATE ON global_stat
FOR EACH STATEMENT EXECUTE PROCEDURE raise_error();
CREATE OR REPLACE FUNCTION handle_new_click()
RETURNS TRIGGER
AS $body$
BEGIN
WITH user_stats AS (
SELECT period_24, period_168 FROM coefficients AS c WHERE c.user_id=new.user_id
)
UPDATE global_stat AS gs SET
global_average=(gs.global_average * gs.user_count + new.value * new.count) / gs.user_count,
max_period_24=GREATEST(us.period_24, gs.max_period_24),
max_period_168=GREATEST(us.period_168, gs.max_period_168)
FROM user_stats AS us
;
RETURN NULL;
END;
$body$
LANGUAGE PLPGSQL;
CREATE TRIGGER new_click_trg
AFTER INSERT ON clicks
FOR EACH ROW EXECUTE PROCEDURE handle_new_click();
CREATE OR REPLACE FUNCTION handle_new_user()
RETURNS TRIGGER
AS $body$
BEGIN
UPDATE global_stat SET
global_average=global_average * user_count / (user_count + 1),
user_count=user_count + 1
;
RETURN NULL;
END;
$body$
LANGUAGE PLPGSQL;
CREATE TRIGGER new_user_trg
AFTER INSERT ON users
FOR EACH ROW EXECUTE PROCEDURE handle_new_user();
CREATE OR REPLACE FUNCTION handle_user_deletion()
RETURNS TRIGGER
AS $body$
BEGIN
UPDATE global_stat SET
global_average=CASE WHEN user_count > 1 THEN global_average * user_count / (user_count - 1) ELSE 0 END,
user_count=user_count - 1
;
RETURN NULL;
END;
$body$
LANGUAGE PLPGSQL;
CREATE TRIGGER delete_user_trg
AFTER DELETE ON users
FOR EACH ROW EXECUTE PROCEDURE handle_user_deletion();
CREATE OR REPLACE FUNCTION handle_user_truncate()
RETURNS TRIGGER
AS $body$
BEGIN
UPDATE global_stat SET
global_average=0,
user_count=0
;
RETURN NULL;
END;
$body$
LANGUAGE PLPGSQL;
CREATE TRIGGER truncate_user_trg
AFTER TRUNCATE ON users
FOR EACH STATEMENT EXECUTE PROCEDURE handle_user_truncate();
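
One operational caveat: coefficients is a materialized view, and PostgreSQL never refreshes those on its own, so the period_24/period_168/total values read by handle_new_click and by the repo queries below go stale as clicks arrive. A minimal refresh sketch, assuming asyncpg and a pool like the one created in startup (this task is an assumption, not part of this PR):

import asyncio

async def refresh_coefficients(pg_pool, interval_sec=60):
    # Periodically recompute the per-user sums; REFRESH ... CONCURRENTLY
    # would avoid blocking readers but requires a UNIQUE index on the view.
    while True:
        async with pg_pool.acquire() as conn:
            await conn.execute('REFRESH MATERIALIZED VIEW coefficients')
        await asyncio.sleep(interval_sec)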

View File

@@ -1,14 +0,0 @@
from starlette.requests import Request
import redis.asyncio as redis
from ..config import REDIS_HOST, REDIS_PORT, REDIS_USER, REDIS_PASSWORD, REDIS_DB
def connect_redis() -> redis.ConnectionPool:
return redis.ConnectionPool(host=REDIS_HOST, port=REDIS_PORT, username=REDIS_USER, password=str(REDIS_PASSWORD), db=REDIS_DB)
async def get_redis(request: Request) -> redis.Redis:
r = redis.Redis(connection_pool=request.app.state.redis_pool)
yield r
await r.aclose()

View File

@@ -7,3 +7,4 @@ class Click(pydantic.BaseModel):
userId: int
dateTime: datetime.datetime
value: decimal.Decimal
count: int

View File

@@ -1,50 +1,78 @@
from datetime import datetime, timedelta
from typing import Tuple
import decimal
from asyncpg import Connection
from ..models import Click
async def update_click_expiry(conn: Connection, user_id: int, period: int) -> decimal.Decimal:
cur_time = datetime.now()
cutoff_time = cur_time - timedelta(hours=period)
query = '''
WITH expired_values AS(
UPDATE clicks
SET expiry_info=jsonb_set(expiry_info, $1, 'true')
WHERE 1=1
AND time < $2
AND user_id =$3
AND not (expiry_info->>$4)::bool
RETURNING value
)
SELECT COALESCE(SUM(value), 0)
FROM expired_values
;
'''
period_key = f'period_{period}'
return await conn.fetchval(query, [period_key], cutoff_time, user_id, period_key)
async def store(conn: Connection, click: Click) -> int:
query = '''
INSERT INTO clicks(user_id, time, value, expiry_info)
VALUES($1, $2, $3, '{"period_24": false, "period_168": false}')
INSERT INTO clicks(user_id, time, value, count)
VALUES($1, $2, $3, $4)
RETURNING id
;
'''
return await conn.fetchval(query, click.userId, click.dateTime, click.value)
return await conn.fetchval(query, click.userId, click.dateTime, click.value, click.count)
async def bulk_store_copy(conn: Connection, click: Click, count: int) -> None:
args = [(click.userId, click.dateTime, click.value) for _ in range(count)]
query = '''
INSERT INTO clicks(user_id, time, value, expiry_info)
VALUES($1, $2, $3, '{"period_24": false, "period_168": false}')
;
'''
await conn.executemany(query, args)
async def delete_by_user_id(conn: Connection, user_id: int):
async def delete_user_info(conn: Connection, user_id: int):
async with conn.transaction():
await conn.execute('DELETE FROM clicks WHERE user_id=$1', user_id)
await conn.execute('DELETE FROM users WHERE id=$1', user_id)
async def get_period_sum(conn: Connection, user_id: int, period: int) -> decimal.Decimal:
if not isinstance(period, int):
raise ValueError('period must be an integer')
return (await conn.fetchval(f'SELECT period_{period} FROM coefficients WHERE user_id=$1', user_id)) or decimal.Decimal(0)
async def get_max_period_sum(conn: Connection, period: int) -> decimal.Decimal:
if not isinstance(period, int):
raise ValueError('period must be an integer')
return (await conn.fetchval(f'SELECT max_period_{period} FROM global_stat')) or decimal.Decimal(0)
async def get_global_average(conn: Connection) -> decimal.Decimal:
return (await conn.fetchval('SELECT global_average FROM global_stat')) or decimal.Decimal(0)
async def get_user_total(conn: Connection, user_id: int) -> decimal.Decimal:
return (await conn.fetchval('SELECT total FROM coefficients WHERE user_id=$1', user_id)) or decimal.Decimal(0)
async def user_exists(conn: Connection, user_id: int) -> bool:
return await conn.fetchval('SELECT EXISTS(SELECT 1 FROM users WHERE id=$1)', user_id)
async def add_user(conn: Connection, user_id: int, req_token: str, session_energy: int):
await conn.execute(
'INSERT INTO users (id, session, energy) VALUES ($1, $2, $3)',
user_id, req_token, session_energy
)
async def get_user_session(conn: Connection, user_id: int) -> str:
return await conn.fetchval('SELECT session FROM users WHERE id=$1', user_id)
async def set_new_session(conn: Connection, user_id: int, req_token: str, session_energy: int):
await conn.execute('UPDATE users SET session=$1, energy=$2 WHERE id=$3', req_token, session_energy, user_id)
async def get_energy(conn: Connection, user_id: int) -> int:
return await conn.fetchval('SELECT energy FROM users WHERE id=$1', user_id)
async def decr_energy(conn: Connection, user_id: int, amount: int) -> Tuple[int, int]:
new_energy, spent = await conn.fetchrow('''
WITH energy_cte AS (
SELECT energy, (CASE WHEN energy < $2 THEN energy ELSE $2 END) AS delta FROM users WHERE id=$1
)
UPDATE users AS u SET
energy=u.energy - e.delta
FROM energy_cte AS e
WHERE id=$1
RETURNING u.energy as new_energy, e.delta
''', user_id, amount)
return new_energy, spent
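
The CTE above clamps the spend to whatever energy is left, so the caller learns both the new balance and how much was actually deducted. A hedged usage sketch (connection setup omitted; values illustrative):

# A user with 30 energy who asks to spend 50 is drained to 0:
new_energy, spent = await decr_energy(conn, user_id=42, amount=50)
# -> new_energy == 0, spent == 30; check_energy() treats spent == 0 as "no energy"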

View File

@@ -1,172 +0,0 @@
import decimal
from typing import Optional, List, Tuple
import redis.asyncio as redis
async def get_period_sum(r: redis.Redis, user_id: int, period: int) -> decimal.Decimal:
sum_bytes = await r.get(f'period_{period}_user_{user_id}')
if sum_bytes is None:
return decimal.Decimal(0)
return decimal.Decimal(sum_bytes.decode())
async def incr_period_sum(r: redis.Redis, user_id: int, _period: int, value: decimal.Decimal) -> decimal.Decimal:
return await r.incrbyfloat(f'period_{_period}_user_{user_id}', float(value))
async def get_max_period_sum(r: redis.Redis, _period: int) -> decimal.Decimal:
max_sum_bytes = await r.get(f'max_period_{_period}')
if max_sum_bytes is None:
return decimal.Decimal(0)
return decimal.Decimal(max_sum_bytes.decode())
async def compare_max_period_sum(r: redis.Redis, _period: int, _sum: decimal.Decimal) -> None:
_script = r.register_script('''
local currentValue = tonumber(redis.call('GET', KEYS[1]))
local cmpValue = tonumber(ARGV[1])
if not currentValue or cmpValue > currentValue then
redis.call('SET', KEYS[1], ARGV[1])
return cmpValue
else
return currentValue
end
''')
await _script(keys=[f'max_period_{_period}'], args=[str(_sum)])
async def get_energy(r: redis.Redis, user_id: int) -> int:
energy_str = await r.get(f'energy_{user_id}')
if energy_str is None:
return 0
return int(energy_str)
async def set_energy(r: redis.Redis, user_id: int, energy: int) -> int:
await r.set(f'energy_{user_id}', energy)
async def decr_energy(r: redis.Redis, user_id: int, amount: int) -> Tuple[int, int]:
_script = r.register_script('''
local energy = tonumber(redis.call('GET', KEYS[1]))
local delta = tonumber(ARGV[1])
if energy < delta then
redis.call('SET', KEYS[1], 0)
return {0, energy}
else
local newEnergy = tonumber(redis.call('DECRBY', KEYS[1], ARGV[1]))
return {newEnergy, delta}
end
''')
new_energy, spent = map(int, await _script(keys=[f'energy_{user_id}'], args=[amount]))
return new_energy, spent
async def get_global_average(r: redis.Redis) -> decimal.Decimal:
avg_bytes = await r.get('global_average')
if avg_bytes is None:
return decimal.Decimal(0)
return decimal.Decimal(avg_bytes.decode())
async def update_global_average(r: redis.Redis, value_to_add: decimal.Decimal) -> decimal.Decimal:
_script = r.register_script('''
local delta = tonumber(ARGV[1]) / tonumber(redis.call('GET', KEYS[1]))
return redis.call('INCRBYFLOAT', KEYS[2], delta)
''')
return decimal.Decimal((await _script(keys=["user_count", "global_average"], args=[float(value_to_add)])).decode())
async def get_user_total(r: redis.Redis, user_id: int) -> decimal.Decimal:
total_bytes = await r.get(f'total_{user_id}')
if total_bytes is None:
return decimal.Decimal(0)
return decimal.Decimal(total_bytes.decode())
async def incr_user_count_if_no_clicks(r: redis.Redis, user_id: int) -> int:
_script = r.register_script('''
local clickCount = tonumber(redis.call('GET', KEYS[1]))
local userCount = tonumber(redis.call('GET', KEYS[2]))
if (not clickCount) then
local oldUserCount = redis.call('GET', KEYS[2])
if (not oldUserCount) then
redis.call('SET', KEYS[2], 1)
redis.call('SET', KEYS[3], 0)
return 1
end
userCount = tonumber(redis.call('INCR', KEYS[2]))
oldUserCount = tonumber(oldUserCount)
local globalAverage = tonumber(redis.call('GET', KEYS[3]))
redis.call('SET', KEYS[3], globalAverage / userCount * oldUserCount)
end
return userCount
''')
return int(await _script(keys=[f'total_{user_id}', 'user_count', 'global_average'], args=[]))
async def incr_user_total(r: redis.Redis, user_id: int, value: decimal.Decimal) -> decimal.Decimal:
return await r.incrbyfloat(f'total_{user_id}', float(value))
async def get_user_session(r: redis.Redis, user_id: int) -> Optional[str]:
session_bytes = await r.get(f'session_{user_id}')
if session_bytes is None:
return None
return session_bytes.decode()
async def set_user_session(r: redis.Redis, user_id: int, token: str) -> None:
await r.set(f'session_{user_id}', token, ex=30 * 60)
async def get_user_count(r: redis.Redis) -> int:
user_count_str = await r.get('user_count')
if user_count_str is None:
return 0
return int(user_count_str)
async def incr_user_count(r: redis.Redis) -> int:
_script = r.register_script('''
local oldCount = redis.call('GET', KEYS[1])
if (not oldCount) then
redis.call('SET', KEYS[1], 1)
redis.call('SET', KEYS[2], 0)
return 1
end
local newCount = tonumber(redis.call('INCR', KEYS[1]))
local globalAverage = tonumber(redis.call('GET', KEYS[2]))
redis.call('SET', KEYS[2], globalAverage / newCount * oldCount)
return newCount
''')
return int(await _script(keys=['user_count', 'global_average'], args=[]))
async def delete_user_info(r: redis.Redis, user_id: int, periods: List[int]):
_script = r.register_script('''
local userTotal = redis.call('GET', KEYS[3])
if (not userTotal) then
return
end
local oldUserCount = tonumber(redis.call('GET', KEYS[1]))
local newUserCount = tonumber(redis.call('DECR', KEYS[1]))
local globalAverage = tonumber(redis.call('GET', KEYS[2]))
redis.call('SET', KEYS[2], (globalAverage * oldUserCount - userTotal) / newUserCount)
for i, v in ipairs(KEYS) do
if (i > 2) then
redis.call('DEL', v)
end
end
''')
keys = [
'user_count',
'global_average',
f'total_{user_id}',
f'energy_{user_id}',
f'session_{user_id}',
]
for period in periods:
keys.append(f'period_{period}_user_{user_id}')
await _script(keys=keys, args=[])

View File

@@ -10,8 +10,8 @@ CELERY_QUEUE_NAME = "celery"
CLICK_TASK_NAME = "clicks.celery.click.handle_click"
async def send_click_batch_copy(chan: aio_pika.Channel, click: Click, count: int):
args = (click.userId, int(click.dateTime.timestamp() * 1e3), str(click.value), count)
async def send_click(chan: aio_pika.Channel, click: Click):
args = (click.userId, int(click.dateTime.timestamp() * 1e3), str(click.value), click.count)
await chan.default_exchange.publish(
message=aio_pika.Message(
body=json.dumps([

View File

@@ -2,87 +2,70 @@ from datetime import datetime
import decimal
from typing import Tuple
import aiohttp
import redis.asyncio as redis
import aio_pika
import asyncpg
import base64
from fastapi.exceptions import HTTPException
from app.src.domain.setting import get_setting
from .repos.redis import (
get_period_sum, incr_period_sum, get_max_period_sum, get_user_total, get_global_average,
incr_user_count_if_no_clicks, update_global_average, incr_user_total, compare_max_period_sum,
delete_user_info as r_delete_user_info, get_user_session, set_user_session, set_energy, get_energy as r_get_energy,
decr_energy,
from .repos.pg import (
store, delete_user_info as pg_delete_user_info, get_period_sum, get_max_period_sum, get_global_average, get_user_total, user_exists, get_user_session,
set_new_session, get_energy as pg_get_energy, decr_energy, add_user
)
from .repos.pg import update_click_expiry, bulk_store_copy, delete_by_user_id
from .repos.rmq import send_click_batch_copy
from .repos.rmq import send_click
from .models import Click
PRECISION = 2
async def add_click_batch_copy(r: redis.Redis, pg: asyncpg.Connection, rmq: aio_pika.Channel, user_id: int, count: int) -> Click:
_click_value = await click_value(r, pg, user_id)
click_value_sum = _click_value * count
# update variables
await incr_user_count_if_no_clicks(r, user_id)
await update_global_average(r, click_value_sum)
await incr_user_total(r, user_id, click_value_sum)
for period in (24, 24*7):
new_period_sum = await incr_period_sum(r, user_id, period, click_value_sum)
await compare_max_period_sum(r, period, new_period_sum)
async def add_click_batch_copy(pg: asyncpg.Connection, rmq: aio_pika.Channel, user_id: int, count: int) -> Click:
_click_value = await click_value(pg, user_id)
click = Click(
userId=user_id,
dateTime=datetime.now(),
value=_click_value,
count=count,
)
# insert click
await bulk_store_copy(pg, click, count)
await store(pg, click)
# send click to backend
await send_click_batch_copy(rmq, click, count)
await send_click(rmq, click)
return click
async def delete_user_info(r: redis.Redis, pg: asyncpg.Connection, user_id: int) -> None:
await r_delete_user_info(r, user_id, [24, 168])
await delete_by_user_id(pg, user_id)
async def delete_user_info(pg: asyncpg.Connection, user_id: int) -> None:
await pg_delete_user_info(pg, user_id)
async def click_value(r: redis.Redis, pg: asyncpg.Connection, user_id: int) -> decimal.Decimal:
async def click_value(pg: asyncpg.Connection, user_id: int) -> decimal.Decimal:
price_per_click = get_setting('PRICE_PER_CLICK')
day_multiplier = get_setting('DAY_MULT')
week_multiplier = get_setting('WEEK_MULT')
progress_multiplier = get_setting('PROGRESS_MULT')
# period coefficients
day_coef = await period_coefficient(r, pg, user_id, 24, day_multiplier)
week_coef = await period_coefficient(r, pg, user_id, 24*7, week_multiplier)
day_coef = await period_coefficient(pg, user_id, 24, day_multiplier)
week_coef = await period_coefficient(pg, user_id, 24*7, week_multiplier)
# progress coefficient
user_total = await get_user_total(r, user_id)
global_avg = await get_global_average(r)
user_total = await get_user_total(pg, user_id)
global_avg = await get_global_average(pg)
progress_coef = progress_coefficient(user_total, global_avg, progress_multiplier)
return round(price_per_click * day_coef * week_coef * progress_coef, PRECISION)
async def period_coefficient(r: redis.Redis, pg: asyncpg.Connection, user_id: int, period: int, multiplier: decimal.Decimal) -> decimal.Decimal:
current_sum = await get_period_sum(r, user_id, period)
expired_sum = await update_click_expiry(pg, user_id, period)
new_sum = current_sum - expired_sum
await incr_period_sum(r, user_id, period, -expired_sum)
max_period_sum = await get_max_period_sum(r, period)
async def period_coefficient(pg: asyncpg.Connection, user_id: int, period: int, multiplier: decimal.Decimal) -> decimal.Decimal:
current_period_sum = await get_period_sum(pg, user_id, period)
max_period_sum = await get_max_period_sum(pg, period)
if max_period_sum == decimal.Decimal(0):
return decimal.Decimal(1)
return new_sum * multiplier / max_period_sum + 1
return current_period_sum * multiplier / max_period_sum + 1
def progress_coefficient(user_total: decimal.Decimal, global_avg: decimal.Decimal, multiplier: decimal.Decimal) -> decimal.Decimal:
@@ -91,34 +74,31 @@ def progress_coefficient(user_total: decimal.Decimal, global_avg: decimal.Decima
return min(global_avg * multiplier / user_total + 1, decimal.Decimal(2))
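
Putting the pieces together, the value of one click as computed by click_value() is, in LaTeX notation (symbols named after the code; c_p falls back to 1 when max_p = 0, as shown above):

\[
\text{value} = \operatorname{round}\bigl(P \cdot c_{24} \cdot c_{168} \cdot c_{\text{prog}},\; 2\bigr),
\qquad
c_p = \frac{s_p \, m_p}{\max_p} + 1,
\qquad
c_{\text{prog}} = \min\Bigl(\frac{\bar{g} \, m_{\text{prog}}}{t_u} + 1,\; 2\Bigr),
\]

where P is PRICE_PER_CLICK, s_p the user's click sum over period p (24 or 168 hours), max_p the global maximum for that period, \bar{g} the global average, t_u the user's lifetime total, and the m's the DAY_MULT/WEEK_MULT/PROGRESS_MULT settings.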
async def check_registration(r: redis.Redis, user_id: int, _token: str, backend_url: str) -> bool:
if await _has_any_clicks(r, user_id):
async def check_registration(pg: asyncpg.Connection, user_id: int, _token: str, backend_url: str) -> bool:
if await user_exists(pg, user_id):
return True
async with aiohttp.ClientSession() as session:
async with session.get(f'{backend_url}/api/v1/users/{user_id}', headers={'Authorization': f'TelegramToken {_token}'}) as resp:
return resp.status == 200
async def _has_any_clicks(r: redis.Redis, user_id: int) -> bool:
total_value = await get_user_total(r, user_id)
return total_value > decimal.Decimal(0)
async def _get_refresh_energy(r: redis.Redis, user_id: int, req_token: str) -> int:
async def _get_refresh_energy(pg: asyncpg.Connection, user_id: int, req_token: str) -> int:
new_auth_date = _auth_date_from_token(req_token)
current_token = await get_user_session(r, user_id)
current_token = await get_user_session(pg, user_id)
if current_token is None:
session_energy = int(get_setting('SESSION_ENERGY'))
await add_user(pg, user_id, req_token, session_energy)
return session_energy
if current_token != req_token:
if current_token is not None:
last_auth_date = _auth_date_from_token(current_token)
session_cooldown = get_setting('SESSION_COOLDOWN')
if new_auth_date - last_auth_date < session_cooldown:
raise HTTPException(status_code=403, detail='Unauthorized')
session_energy = int(get_setting('SESSION_ENERGY'))
await set_user_session(r, user_id, req_token)
await set_energy(r, user_id, session_energy)
await set_new_session(pg, user_id, req_token, session_energy)
return session_energy
else:
return await r_get_energy(r, user_id)
return await pg_get_energy(pg, user_id)
def _auth_date_from_token(token):
split_res = base64.b64decode(token).decode('utf-8').split(':')
@@ -127,12 +107,12 @@ def _auth_date_from_token(token):
return int(data_dict['auth_date'])
async def check_energy(r: redis.Redis, user_id: int, amount: int, _token: str) -> Tuple[int, int]:
_energy = await _get_refresh_energy(r, user_id, _token)
async def check_energy(pg: asyncpg.Connection, user_id: int, amount: int, _token: str) -> Tuple[int, int]:
_energy = await _get_refresh_energy(pg, user_id, _token)
if _energy == 0:
return 0, 0
return await decr_energy(r, user_id, amount)
return await decr_energy(pg, user_id, amount)
async def get_energy(r: redis.Redis, user_id: int, _token: str) -> int:
return await _get_refresh_energy(r, user_id, _token)
async def get_energy(pg: asyncpg.Connection, user_id: int, _token: str) -> int:
return await _get_refresh_energy(pg, user_id, _token)

View File

@@ -1,6 +1,5 @@
import aio_pika
import asyncpg
import redis
from fastapi import APIRouter, Depends, HTTPException
from typing import Annotated, Tuple
from ..domain.click import (
@@ -9,7 +8,7 @@ from ..domain.click import (
)
from ..dependencies import get_token_header
from ..db import get_pg, get_redis, get_rmq
from ..db import get_pg, get_rmq
from ..config import BACKEND_URL
@@ -22,16 +21,16 @@ router = APIRouter(
@router.post("/batch-click/", response_model=ClickResponse, status_code=200)
async def batch_click(req: BatchClickRequest, auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], pg: Annotated[asyncpg.Connection, Depends(get_pg)], r: Annotated[redis.Redis, Depends(get_redis)], rmq: Annotated[aio_pika.Channel, Depends(get_rmq)]):
async def batch_click(req: BatchClickRequest, auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], pg: Annotated[asyncpg.Connection, Depends(get_pg)], rmq: Annotated[aio_pika.Channel, Depends(get_rmq)]):
user_id, token = auth_info
if not await check_registration(r, user_id, token, BACKEND_URL):
if not await check_registration(pg, user_id, token, BACKEND_URL):
raise HTTPException(status_code=403, detail='Unauthorized')
_energy, spent = await check_energy(r, user_id, req.count, token)
_energy, spent = await check_energy(pg, user_id, req.count, token)
if spent == 0:
raise HTTPException(status_code=400, detail='No energy')
click = await add_click_batch_copy(r, pg, rmq, user_id, spent)
click = await add_click_batch_copy(pg, rmq, user_id, spent)
return ClickResponse(
click=click,
energy=_energy
@@ -39,33 +38,33 @@ async def batch_click(req: BatchClickRequest, auth_info: Annotated[Tuple[int, st
@router.get("/energy", response_model=EnergyResponse, status_code=200)
async def energy(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], r: Annotated[redis.Redis, Depends(get_redis)]):
async def energy(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]):
user_id, token = auth_info
if not await check_registration(r, user_id, token, BACKEND_URL):
if not await check_registration(pg, user_id, token, BACKEND_URL):
raise HTTPException(status_code=403, detail='Unauthorized')
_energy = await get_energy(r, user_id, token)
_energy = await get_energy(pg, user_id, token)
return EnergyResponse(
energy=_energy
)
@router.get('/coefficient', response_model=ClickValueResponse, status_code=200)
async def coefficient(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], r: Annotated[redis.Redis, Depends(get_redis)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]):
async def coefficient(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]):
user_id, token = auth_info
if not await check_registration(r, user_id, token, BACKEND_URL):
if not await check_registration(pg, user_id, token, BACKEND_URL):
raise HTTPException(status_code=403, detail='Unauthorized')
value = await click_value(r, pg, user_id)
value = await click_value(pg, user_id)
return ClickValueResponse(
value=value
)
@router.delete('/internal/user', status_code=204)
async def delete_user(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], r: Annotated[redis.Redis, Depends(get_redis)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]):
async def delete_user(auth_info: Annotated[Tuple[int, str], Depends(get_token_header)], pg: Annotated[asyncpg.Connection, Depends(get_pg)]):
user_id, token = auth_info
if not await check_registration(r, user_id, token, BACKEND_URL):
if not await check_registration(pg, user_id, token, BACKEND_URL):
raise HTTPException(status_code=403, detail='Unauthorized')
await delete_user_info(r, pg, user_id)
await delete_user_info(pg, user_id)
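
For reference, a hedged example of exercising the new batch endpoint; the URL prefix and the Authorization scheme here are assumptions, since neither the router prefix nor get_token_header is shown in this diff:

import asyncio
import aiohttp

async def main():
    async with aiohttp.ClientSession() as session:
        async with session.post(
            'http://localhost:8000/clicks/batch-click/',  # prefix assumed
            json={'count': 10},  # BatchClickRequest.count
            headers={'Authorization': 'TelegramToken <base64-token>'},  # scheme assumed
        ) as resp:
            print(resp.status, await resp.json())

asyncio.run(main())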

View File

@@ -20,7 +20,6 @@ pamqp==3.3.0
propcache==0.2.0
pydantic==2.9.2
pydantic_core==2.23.4
redis==5.1.1
requests==2.32.3
sniffio==1.3.1
starlette==0.40.0

View File

@@ -1,7 +1,6 @@
volumes:
db_data: {}
batcher_db_data: {}
redis_data: {}
backend_media: {}
backend_static: {}
@@ -84,26 +83,13 @@ services:
interval: 10s
timeout: 2s
redis:
env_file:
- .env/prod/redis
image: redis
command: bash -c "redis-server --appendonly yes --requirepass $${REDIS_PASSWORD}"
volumes:
- redis_data:/data
healthcheck:
<<: *pg-healthcheck
test: "[ $$(redis-cli -a $${REDIS_PASSWORD} ping) = 'PONG' ]"
batcher:
build: ./batcher
depends_on:
redis: *healthy-dependency
batcher-postgres: *healthy-dependency
rabbitmq: *healthy-dependency
env_file:
- .env/prod/rmq
- .env/prod/redis
- .env/prod/batcher-pg
- .env/prod/batcher
- .env/prod/bot

View File

@@ -1,7 +1,6 @@
volumes:
db_data: {}
batcher_db_data: {}
redis_data: {}
backend_media: {}
backend_static: {}
@@ -66,28 +65,13 @@ services:
interval: 10s
timeout: 2s
redis:
env_file:
- .env/dev/redis
image: redis
command: bash -c "redis-server --appendonly yes --requirepass $${REDIS_PASSWORD}"
ports:
- '6379:6379'
volumes:
- redis_data:/data
healthcheck:
<<: *pg-healthcheck
test: "[ $$(redis-cli -a $$REDIS_PASSWORD ping) = 'PONG' ]"
batcher:
build: ./batcher
depends_on:
redis: *healthy-dependency
batcher-postgres: *healthy-dependency
rabbitmq: *healthy-dependency
env_file:
- .env/dev/rmq
- .env/dev/redis
- .env/dev/batcher-pg
- .env/dev/batcher
- .env/dev/bot