vk_chat_bot/ai_agent.py

137 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from dataclasses import dataclass
from typing import List, Tuple
from openrouter import OpenRouter, RetryConfig
from openrouter.utils import BackoffStrategy
from database import BasicDatabase
GROUP_CHAT_SYSTEM_PROMPT = """
Ты - ИИ-помощник в групповом чате.\n
Отвечай на вопросы и поддерживай контекст беседы.\n
Ты не можешь обсуждать политику и религию.\n
Сообщения пользователей будут приходить в следующем формате: '[Имя]: текст сообщения'\n
При ответе НЕ нужно указывать ни пользователя, которому он предназначен, ни свое имя.
"""
GROUP_CHAT_MAX_MESSAGES = 20
PRIVATE_CHAT_SYSTEM_PROMPT = """
Ты - ИИ-помощник в чате c пользователем.\n
Отвечай на вопросы и поддерживай контекст беседы.
"""
PRIVATE_CHAT_MAX_MESSAGES = 40
@dataclass()
class Message:
user_name: str = None
text: str = None
message_id: int = None
class AiAgent:
def __init__(self, api_token: str, model: str, model_temp: float, db: BasicDatabase):
retry_config = RetryConfig(strategy="backoff",
backoff=BackoffStrategy(
initial_interval=2000, max_interval=8000, exponent=2, max_elapsed_time=14000),
retry_connection_errors=True)
self.db = db
self.model = model
self.model_temp = model_temp
self.client = OpenRouter(api_key=api_token, retry_config=retry_config)
async def get_group_chat_reply(self, chat_id: int,
message: Message, forwarded_messages: List[Message]) -> Tuple[str, bool]:
message_text = f"[{message.user_name}]: {message.text}"
for fwd_message in forwarded_messages:
message_text += '\n<Цитируемое сообщение от {}>\n'.format(fwd_message.user_name)
message_text += fwd_message.text + '\n'
message_text += '<Конец цитаты>'
context = self._get_chat_context(is_group_chat=True, chat_id=chat_id)
context.append({"role": "user", "content": message_text})
try:
# Get response from OpenRouter
response = await self.client.chat.send_async(
model=self.model,
messages=context,
max_tokens=500,
temperature=self.model_temp
)
# Extract AI response
ai_response = response.choices[0].message.content
# Add message and AI response to context
self.db.context_add_message(chat_id=chat_id, role="user", content=message_text,
message_id=message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
self.db.context_add_message(chat_id=chat_id, role="assistant", content=ai_response,
message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)
return ai_response, True
except Exception as e:
if str(e).find("Rate limit exceeded") != -1:
return "Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК).", False
else:
print(f"Ошибка выполнения запроса к ИИ: {e}")
return f"Извините, при обработке запроса произошла ошибка.", False
async def get_private_chat_reply(self, chat_id: int, message: str, message_id: int) -> Tuple[str, bool]:
context = self._get_chat_context(is_group_chat=False, chat_id=chat_id)
context.append({"role": "user", "content": message})
try:
# Get response from OpenRouter
response = await self.client.chat.send_async(
model=self.model,
messages=context,
max_tokens=500,
temperature=0.5
)
# Extract AI response
ai_response = response.choices[0].message.content
# Add message and AI response to context
self.db.context_add_message(chat_id=chat_id, role="user", content=message,
message_id=message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
self.db.context_add_message(chat_id=chat_id, role="assistant", content=ai_response,
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
return ai_response, True
except Exception as e:
if str(e).find("Rate limit exceeded") != -1:
return "Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК).", False
else:
print(f"Ошибка выполнения запроса к ИИ: {e}")
return f"Извините, при обработке запроса произошла ошибка.", False
def get_last_assistant_message_id(self, chat_id: int):
return self.db.context_get_last_assistant_message_id(chat_id)
def set_last_response_id(self, chat_id: int, message_id: int):
self.db.context_set_last_message_id(chat_id, message_id)
def clear_chat_context(self, chat_id: int):
self.db.context_clear(chat_id)
def _get_chat_context(self, is_group_chat: bool, chat_id: int) -> list[dict]:
prompt = GROUP_CHAT_SYSTEM_PROMPT if is_group_chat else PRIVATE_CHAT_SYSTEM_PROMPT
chat = self.db.create_chat_if_not_exists(chat_id)
if chat['ai_prompt'] is not None:
prompt += '\n\n' + chat['ai_prompt']
messages = self.db.context_get_messages(chat_id)
return [{"role": "system", "content": prompt}] + messages
agent: AiAgent
def create_ai_agent(api_token: str, model: str, model_temp: float, db: BasicDatabase):
global agent
agent = AiAgent(api_token, model, model_temp, db)