Сохранение контекста ИИ в БД.

Имя ИИ-модели и температура задаются в конфиге.
This commit is contained in:
Kirill Kirilenko 2026-01-24 01:40:30 +03:00
parent 73cd047c82
commit 01de81cd3f
8 changed files with 176 additions and 102 deletions

View file

@ -1,9 +1,10 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Dict, Optional from typing import List, Tuple
from openrouter import OpenRouter, RetryConfig from openrouter import OpenRouter, RetryConfig
from openrouter.utils import BackoffStrategy from openrouter.utils import BackoffStrategy
from database import BasicDatabase
GROUP_CHAT_SYSTEM_PROMPT = """ GROUP_CHAT_SYSTEM_PROMPT = """
Ты - ИИ-помощник в групповом чате.\n Ты - ИИ-помощник в групповом чате.\n
@ -12,63 +13,80 @@ GROUP_CHAT_SYSTEM_PROMPT = """
Сообщения пользователей будут приходить в следующем формате: '[Имя]: текст сообщения'\n Сообщения пользователей будут приходить в следующем формате: '[Имя]: текст сообщения'\n
При ответе НЕ нужно указывать пользователя, которому он предназначен. При ответе НЕ нужно указывать пользователя, которому он предназначен.
""" """
GROUP_CHAT_MAX_MESSAGES = 20
PRIVATE_CHAT_SYSTEM_PROMPT = """ PRIVATE_CHAT_SYSTEM_PROMPT = """
Ты - ИИ-помощник в чате c пользователем.\n Ты - ИИ-помощник в чате c пользователем.\n
Отвечай на вопросы и поддерживай контекст беседы. Отвечай на вопросы и поддерживай контекст беседы.
""" """
PRIVATE_CHAT_MAX_MESSAGES = 40
class ChatContext:
def __init__(self, max_messages: int):
self.max_messages: int = max_messages
self.messages: List[Dict[str, str]] = []
def add_message(self, role: str, content: str):
if len(self.messages) == self.max_messages:
# Всегда сохраняем в контексте системное сообщение
self.messages.pop(1)
self.messages.append({"role": role, "content": content})
def get_messages_for_api(self) -> List[Dict[str, str]]:
return self.messages
def remove_last_message(self):
self.messages.pop()
@dataclass() @dataclass()
class Message: class Message:
user_name: str = None user_name: str = None
text: str = None text: str = None
message_id: int = None
class AiAgent: class AiAgent:
def __init__(self, api_token: str): def __init__(self, api_token: str, model: str, model_temp: float, db: BasicDatabase):
retry_config = RetryConfig(strategy="backoff", retry_config = RetryConfig(strategy="backoff",
backoff=BackoffStrategy( backoff=BackoffStrategy(
initial_interval=2000, max_interval=8000, exponent=2, max_elapsed_time=14000), initial_interval=2000, max_interval=8000, exponent=2, max_elapsed_time=14000),
retry_connection_errors=True) retry_connection_errors=True)
self.model = "meta-llama/llama-3.3-70b-instruct:free" self.db = db
self.model = model
self.model_temp = model_temp
self.client = OpenRouter(api_key=api_token, retry_config=retry_config) self.client = OpenRouter(api_key=api_token, retry_config=retry_config)
self.chat_contexts: Dict[int, ChatContext] = {}
async def get_group_chat_reply(self, chat_id: int, chat_prompt: str, async def get_group_chat_reply(self, chat_id: int,
message: Message, forwarded_messages: List[Message]) -> str: message: Message, forwarded_messages: List[Message]) -> Tuple[str, bool]:
message_text = message.text message_text = f"[{message.user_name}]: {message.text}"
for fwd_message in forwarded_messages: for fwd_message in forwarded_messages:
message_text += '\n<Цитируемое сообщение от {}>\n'.format(fwd_message.user_name) message_text += '\n<Цитируемое сообщение от {}>\n'.format(fwd_message.user_name)
message_text += fwd_message.text + '\n' message_text += fwd_message.text + '\n'
message_text += '<Конец цитаты>' message_text += '<Конец цитаты>'
context = self._get_chat_context(is_group_chat=True, chat_id=chat_id, chat_prompt=chat_prompt) context = self._get_chat_context(is_group_chat=True, chat_id=chat_id)
context.add_message(role="user", content=f"[{message.user_name}]: {message_text}") context.append({"role": "user", "content": message_text})
try: try:
# Get response from OpenRouter # Get response from OpenRouter
response = await self.client.chat.send_async( response = await self.client.chat.send_async(
model=self.model, model=self.model,
messages=context.get_messages_for_api(), messages=context,
max_tokens=500,
temperature=self.model_temp
)
# Extract AI response
ai_response = response.choices[0].message.content
# Add message and AI response to context
self.db.context_add_message(chat_id=chat_id, role="user", content=message_text,
message_id=message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
self.db.context_add_message(chat_id=chat_id, role="assistant", content=ai_response,
message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)
return ai_response, True
except Exception as e:
if str(e).find("Rate limit exceeded") != -1:
return "Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК).", False
else:
print(f"Ошибка выполнения запроса к ИИ: {e}")
return f"Извините, при обработке запроса произошла ошибка.", False
async def get_private_chat_reply(self, chat_id: int, message: str, message_id: int) -> Tuple[str, bool]:
context = self._get_chat_context(is_group_chat=False, chat_id=chat_id)
context.append({"role": "user", "content": message})
try:
# Get response from OpenRouter
response = await self.client.chat.send_async(
model=self.model,
messages=context,
max_tokens=500, max_tokens=500,
temperature=0.5 temperature=0.5
) )
@ -76,71 +94,41 @@ class AiAgent:
# Extract AI response # Extract AI response
ai_response = response.choices[0].message.content ai_response = response.choices[0].message.content
# Add AI response to context # Add message and AI response to context
context.add_message(role="assistant", content=ai_response) self.db.context_add_message(chat_id=chat_id, role="user", content=message,
message_id=message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
self.db.context_add_message(chat_id=chat_id, role="assistant", content=ai_response,
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
return ai_response return ai_response, True
except Exception as e: except Exception as e:
context.remove_last_message()
if str(e).find("Rate limit exceeded") != -1: if str(e).find("Rate limit exceeded") != -1:
return "Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)." return "Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК).", False
else: else:
print(f"Ошибка выполнения запроса к ИИ: {e}") print(f"Ошибка выполнения запроса к ИИ: {e}")
return f"Извините, при обработке запроса произошла ошибка." return f"Извините, при обработке запроса произошла ошибка.", False
async def get_private_chat_reply(self, chat_id: int, chat_prompt: str, message: str) -> str: def set_last_response_id(self, chat_id: int, message_id: int):
context = self._get_chat_context(is_group_chat=False, chat_id=chat_id, chat_prompt=chat_prompt) self.db.context_set_last_message_id(chat_id, message_id)
context.add_message(role="user", content=message)
try:
# Get response from OpenRouter
response = await self.client.chat.send_async(
model=self.model,
messages=context.get_messages_for_api(),
max_tokens=500,
temperature=0.5
)
# Extract AI response
ai_response = response.choices[0].message.content
# Add AI response to context
context.add_message(role="assistant", content=ai_response)
return ai_response
except Exception as e:
context.remove_last_message()
if str(e).find("Rate limit exceeded") != -1:
return "Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."
else:
print(f"Ошибка выполнения запроса к ИИ: {e}")
return f"Извините, при обработке запроса произошла ошибка."
def clear_chat_context(self, chat_id: int): def clear_chat_context(self, chat_id: int):
self.chat_contexts.pop(chat_id, None) self.db.context_clear(chat_id)
def _get_chat_context(self, is_group_chat: bool, chat_id: int, chat_prompt: Optional[str]) -> ChatContext: def _get_chat_context(self, is_group_chat: bool, chat_id: int) -> list[dict]:
"""Get or create chat context for a specific chat""" prompt = GROUP_CHAT_SYSTEM_PROMPT if is_group_chat else PRIVATE_CHAT_SYSTEM_PROMPT
if chat_id not in self.chat_contexts:
if is_group_chat:
self.chat_contexts[chat_id] = ChatContext(max_messages=20)
prompt = GROUP_CHAT_SYSTEM_PROMPT
else:
self.chat_contexts[chat_id] = ChatContext(max_messages=40)
prompt = PRIVATE_CHAT_SYSTEM_PROMPT
if chat_prompt is not None: chat = self.db.create_chat_if_not_exists(chat_id)
prompt += '\n\n' + chat_prompt if chat['ai_prompt'] is not None:
prompt += '\n\n' + chat['ai_prompt']
self.chat_contexts[chat_id].add_message(role="system", content=prompt) messages = self.db.context_get_messages(chat_id)
return self.chat_contexts[chat_id] return [{"role": "system", "content": prompt}] + messages
agent: AiAgent agent: AiAgent
def create_ai_agent(api_token: str): def create_ai_agent(api_token: str, model: str, model_temp: float, db: BasicDatabase):
global agent global agent
agent = AiAgent(api_token) agent = AiAgent(api_token, model, model_temp, db)

View file

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import List, Union from typing import List, Optional, Union
from pyodbc import connect, SQL_CHAR, SQL_WCHAR, Row from pyodbc import connect, SQL_CHAR, SQL_WCHAR, Row
@ -106,6 +106,58 @@ class BasicDatabase:
def reset_messages_month(self): def reset_messages_month(self):
self.cursor.execute("UPDATE users SET messages_month = 0") self.cursor.execute("UPDATE users SET messages_month = 0")
def context_get_messages(self, chat_id: int) -> list[dict]:
self.cursor.execute("""
SELECT role, content FROM contexts
WHERE chat_id = ? AND message_id IS NOT NULL
ORDER BY message_id
""", chat_id)
return self._to_dict(self.cursor.fetchall())
def context_get_count(self, chat_id: int) -> int:
self.cursor.execute("SELECT COUNT(*) FROM contexts WHERE chat_id = ?", chat_id)
return self.cursor.fetchval()
def context_add_message(self, chat_id: int, role: str, content: str, message_id: Optional[int], max_messages: int):
self._context_trim(chat_id, max_messages)
if message_id is not None:
self.cursor.execute("""
INSERT INTO contexts (chat_id, message_id, role, content)
VALUES (?, ?, ?, ?)
""", chat_id, message_id, role, content)
else:
self.cursor.execute("""
INSERT INTO contexts (chat_id, role, content)
VALUES (?, ?, ?)
""", chat_id, role, content)
def context_set_last_message_id(self, chat_id: int, message_id: int):
self.cursor.execute("""
UPDATE contexts SET message_id = ?
WHERE chat_id = ? AND message_id IS NULL
""", message_id, chat_id)
def _context_trim(self, chat_id: int, max_messages: int):
current_count = self.context_get_count(chat_id)
while current_count >= max_messages:
oldest_message_id = self.cursor.execute("""
SELECT message_id FROM contexts
WHERE chat_id = ? AND message_id IS NOT NULL
ORDER BY message_id ASC
LIMIT 1
""", chat_id).fetchval()
if oldest_message_id:
self.cursor.execute("DELETE FROM contexts WHERE chat_id = ? AND message_id = ?",
chat_id, oldest_message_id)
current_count -= 1
else:
break
def context_clear(self, chat_id: int):
self.cursor.execute("DELETE FROM contexts WHERE chat_id = ?", chat_id)
def create_chat_if_not_exists(self, chat_id: int): def create_chat_if_not_exists(self, chat_id: int):
chat = self.get_chat(chat_id) chat = self.get_chat(chat_id)
if chat is None: if chat is None:
@ -130,7 +182,7 @@ class BasicDatabase:
result[column] = args[i] result[column] = args[i]
return result return result
elif isinstance(args, list) and all(isinstance(item, Row) for item in args): elif isinstance(args, list) and all(isinstance(item, Row) for item in args):
results = [] results: list[dict] = []
for row in args: for row in args:
row_dict = {} row_dict = {}
for i, column in enumerate(columns): for i, column in enumerate(columns):

View file

@ -4,7 +4,8 @@ import json
from aiogram import Bot, Dispatcher from aiogram import Bot, Dispatcher
from ai_agent import create_ai_agent from ai_agent import create_ai_agent
from tg.tg_database import create_database
import tg.tg_database as database
from . import handlers from . import handlers
from . import tasks from . import tasks
@ -16,8 +17,11 @@ async def main() -> None:
print('Конфигурация загружена.') print('Конфигурация загружена.')
bot = Bot(token=config['api_token']) bot = Bot(token=config['api_token'])
create_database(config['db_connection_string']) database.create_database(config['db_connection_string'])
create_ai_agent(config['openrouter_token']) create_ai_agent(config['openrouter_token'],
config['openrouter_model'],
config['openrouter_model_temp'],
database.DB)
dp = Dispatcher() dp = Dispatcher()
dp.include_router(handlers.router) dp.include_router(handlers.router)

View file

@ -85,10 +85,13 @@ async def any_message_handler(message: Message):
return return
ai_message.user_name = await get_user_name_for_ai(message.from_user) ai_message.user_name = await get_user_name_for_ai(message.from_user)
chat_prompt = chat['ai_prompt'] ai_message.message_id = message.message_id
await message.reply( answer, success = await utils.run_with_progress(
await utils.run_with_progress( partial(ai_agent.agent.get_group_chat_reply, chat_id, ai_message, ai_fwd_messages),
partial(ai_agent.agent.get_group_chat_reply, chat_id, chat_prompt, ai_message, ai_fwd_messages),
partial(message.bot.send_chat_action, chat_id, 'typing'), partial(message.bot.send_chat_action, chat_id, 'typing'),
interval=4)) interval=4)
answer_id = (await message.reply(answer)).message_id
if success:
ai_agent.agent.set_last_response_id(chat_id, answer_id)

View file

@ -7,7 +7,7 @@ class TgDatabase(database.BasicDatabase):
self.cursor.execute(""" self.cursor.execute("""
CREATE TABLE IF NOT EXISTS chats ( CREATE TABLE IF NOT EXISTS chats (
id BIGINT, id BIGINT NOT NULL,
active TINYINT NOT NULL DEFAULT 0, active TINYINT NOT NULL DEFAULT 0,
rules VARCHAR(4000), rules VARCHAR(4000),
greeting_join VARCHAR(2000), greeting_join VARCHAR(2000),
@ -17,8 +17,8 @@ class TgDatabase(database.BasicDatabase):
self.cursor.execute(""" self.cursor.execute("""
CREATE TABLE IF NOT EXISTS users ( CREATE TABLE IF NOT EXISTS users (
chat_id BIGINT, chat_id BIGINT NOT NULL,
user_id BIGINT, user_id BIGINT NOT NULL,
last_message BIGINT NOT NULL DEFAULT 0, last_message BIGINT NOT NULL DEFAULT 0,
messages_today SMALLINT NOT NULL DEFAULT 0, messages_today SMALLINT NOT NULL DEFAULT 0,
messages_month SMALLINT NOT NULL DEFAULT 0, messages_month SMALLINT NOT NULL DEFAULT 0,
@ -28,6 +28,16 @@ class TgDatabase(database.BasicDatabase):
CONSTRAINT fk_users_chats FOREIGN KEY (chat_id) REFERENCES chats (id) ON UPDATE CASCADE ON DELETE CASCADE) CONSTRAINT fk_users_chats FOREIGN KEY (chat_id) REFERENCES chats (id) ON UPDATE CASCADE ON DELETE CASCADE)
""") """)
self.cursor.execute("""
CREATE TABLE IF NOT EXISTS contexts (
chat_id BIGINT NOT NULL,
message_id BIGINT,
role VARCHAR(16) NOT NULL,
content VARCHAR(2000) NOT NULL,
UNIQUE KEY contexts_unique (chat_id, message_id),
CONSTRAINT fk_contexts_chats FOREIGN KEY (chat_id) REFERENCES chats (id) ON UPDATE CASCADE ON DELETE CASCADE)
""")
self.conn.commit() self.conn.commit()

View file

@ -3,7 +3,8 @@ import json
from vkbottle.bot import Bot as VkBot from vkbottle.bot import Bot as VkBot
from ai_agent import create_ai_agent from ai_agent import create_ai_agent
from vk.vk_database import create_database
import vk.vk_database as database
from . import handlers from . import handlers
from . import tasks from . import tasks
@ -15,8 +16,11 @@ if __name__ == '__main__':
print('Конфигурация загружена.') print('Конфигурация загружена.')
bot = VkBot(config['api_token'], labeler=handlers.labeler) bot = VkBot(config['api_token'], labeler=handlers.labeler)
create_database(config['db_connection_string']) database.create_database(config['db_connection_string'])
create_ai_agent(config["openrouter_token"]) create_ai_agent(config['openrouter_token'],
config['openrouter_model'],
config['openrouter_model_temp'],
database.DB)
bot.loop_wrapper.on_startup.append(tasks.startup_task(bot.api)) bot.loop_wrapper.on_startup.append(tasks.startup_task(bot.api))
bot.loop_wrapper.add_task(tasks.daily_maintenance_task(bot.api)) bot.loop_wrapper.add_task(tasks.daily_maintenance_task(bot.api))

View file

@ -87,10 +87,13 @@ async def any_message_handler(message: Message):
return return
ai_message.user_name = await get_user_name_for_ai(message.ctx_api, message.from_id) ai_message.user_name = await get_user_name_for_ai(message.ctx_api, message.from_id)
chat_prompt = chat['ai_prompt'] ai_message.message_id = message.conversation_message_id
await message.reply( answer, success = await utils.run_with_progress(
await utils.run_with_progress( partial(ai_agent.agent.get_group_chat_reply, chat_id, ai_message, ai_fwd_messages),
partial(ai_agent.agent.get_group_chat_reply, chat_id, chat_prompt, ai_message, ai_fwd_messages),
partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'), partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'),
interval=4)) interval=4)
answer_id = (await message.reply(answer)).conversation_message_id
if success:
ai_agent.agent.set_last_response_id(chat_id, answer_id)

View file

@ -7,7 +7,7 @@ class VkDatabase(database.BasicDatabase):
self.cursor.execute(""" self.cursor.execute("""
CREATE TABLE IF NOT EXISTS chats ( CREATE TABLE IF NOT EXISTS chats (
id BIGINT, id BIGINT NOT NULL,
active TINYINT NOT NULL DEFAULT 0, active TINYINT NOT NULL DEFAULT 0,
rules VARCHAR(4000), rules VARCHAR(4000),
greeting_join VARCHAR(2000), greeting_join VARCHAR(2000),
@ -19,8 +19,8 @@ class VkDatabase(database.BasicDatabase):
self.cursor.execute(""" self.cursor.execute("""
CREATE TABLE IF NOT EXISTS users ( CREATE TABLE IF NOT EXISTS users (
chat_id BIGINT, chat_id BIGINT NOT NULL,
user_id BIGINT, user_id BIGINT NOT NULL,
last_message BIGINT NOT NULL DEFAULT 0, last_message BIGINT NOT NULL DEFAULT 0,
messages_today SMALLINT NOT NULL DEFAULT 0, messages_today SMALLINT NOT NULL DEFAULT 0,
messages_month SMALLINT NOT NULL DEFAULT 0, messages_month SMALLINT NOT NULL DEFAULT 0,
@ -31,6 +31,16 @@ class VkDatabase(database.BasicDatabase):
CONSTRAINT fk_users_chats FOREIGN KEY (chat_id) REFERENCES chats (id) ON UPDATE CASCADE ON DELETE CASCADE) CONSTRAINT fk_users_chats FOREIGN KEY (chat_id) REFERENCES chats (id) ON UPDATE CASCADE ON DELETE CASCADE)
""") """)
self.cursor.execute("""
CREATE TABLE IF NOT EXISTS contexts (
chat_id BIGINT NOT NULL,
message_id BIGINT,
role VARCHAR(16) NOT NULL,
content VARCHAR(2000) NOT NULL,
UNIQUE KEY contexts_unique (chat_id, message_id),
CONSTRAINT fk_contexts_chats FOREIGN KEY (chat_id) REFERENCES chats (id) ON UPDATE CASCADE ON DELETE CASCADE)
""")
self.conn.commit() self.conn.commit()
def user_toggle_happy_birthday(self, chat_id: int, user_id: int, happy_birthday: int): def user_toggle_happy_birthday(self, chat_id: int, user_id: int, happy_birthday: int):