vk_chat_bot/ai_agent.py

181 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
from dataclasses import dataclass
from typing import List, Tuple, Any, Optional
from openrouter import OpenRouter, RetryConfig
from openrouter.utils import BackoffStrategy
from database import BasicDatabase
GROUP_CHAT_SYSTEM_PROMPT = """
Ты - ИИ-помощник в групповом чате.\n
Отвечай на вопросы и поддерживай контекст беседы.\n
Ты не можешь обсуждать политику и религию.\n
Сообщения пользователей будут приходить в следующем формате: '[Имя]: текст сообщения'\n
При ответе НЕ нужно указывать ни пользователя, которому он предназначен, ни свое имя.
"""
GROUP_CHAT_MAX_MESSAGES = 20
PRIVATE_CHAT_SYSTEM_PROMPT = """
Ты - ИИ-помощник в чате c пользователем.\n
Отвечай на вопросы и поддерживай контекст беседы.
"""
PRIVATE_CHAT_MAX_MESSAGES = 40
OPENROUTER_HEADERS = {
'HTTP-Referer': 'https://ultracoder.org',
'X-Title': 'TG/VK Chat Bot'
}
@dataclass()
class Message:
user_name: str = None
text: str = None
image: bytes = None
message_id: int = None
def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict:
json = {"role": role, "content": []}
if text is not None:
json["content"].append({"type": "text", "text": text})
if image is not None:
encoded_image = base64.b64encode(image).decode('utf-8')
image_url = f"data:image/jpeg;base64,{encoded_image}"
json["content"].append({"type": "image_url", "image_url": {"url": image_url}})
return json
class AiAgent:
def __init__(self, api_token: str, model: str, db: BasicDatabase, platform: str):
retry_config = RetryConfig(strategy="backoff",
backoff=BackoffStrategy(
initial_interval=2000, max_interval=8000, exponent=2, max_elapsed_time=14000),
retry_connection_errors=True)
self.db = db
self.model = model
self.platform = platform
self.client = OpenRouter(api_key=api_token, retry_config=retry_config)
async def get_group_chat_reply(self, bot_id: int, chat_id: int,
message: Message, forwarded_messages: List[Message]) -> Tuple[str, bool]:
context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id)
if message.text is not None:
message.text = f"[{message.user_name}]: {message.text}"
else:
message.text = f"[{message.user_name}]:"
context.append(_serialize_message(role="user", text=message.text, image=message.image))
for fwd_message in forwarded_messages:
message_text = '<Цитируемое сообщение от {}>'.format(fwd_message.user_name)
if fwd_message.text is not None:
message_text += '\n' + fwd_message.text
fwd_message.text = message_text
context.append(_serialize_message(role="user", text=fwd_message.text, image=fwd_message.image))
try:
# Get response from OpenRouter
response = await self.client.chat.send_async(
model=self.model,
messages=context,
max_tokens=500,
user=f'{self.platform}_{bot_id}_{chat_id}',
http_headers=OPENROUTER_HEADERS
)
# Extract AI response
ai_response = response.choices[0].message.content
# Add input messages and AI response to context
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
message_id=message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
for fwd_message in forwarded_messages:
self.db.context_add_message(bot_id, chat_id, role="user", text=fwd_message.text, image=fwd_message.image,
message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
self.db.context_add_message(bot_id, chat_id, role="assistant", text=ai_response, image=None,
message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)
return ai_response, True
except Exception as e:
if str(e).find("Rate limit exceeded") != -1:
return "Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК).", False
else:
print(f"Ошибка выполнения запроса к ИИ: {e}")
return f"Извините, при обработке запроса произошла ошибка.", False
async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: Message) -> Tuple[str, bool]:
context = self._get_chat_context(is_group_chat=False, bot_id=bot_id, chat_id=chat_id)
content: list[dict[str, Any]] = []
if message.text is not None:
content.append({"type": "text", "text": message.text})
if message.image is not None:
encoded_image = base64.b64encode(message.image).decode('utf-8')
image_url = f"data:image/jpeg;base64,{encoded_image}"
content.append({"type": "image_url", "image_url": {"url": image_url}})
context.append({"role": "user", "content": content})
try:
# Get response from OpenRouter
response = await self.client.chat.send_async(
model=self.model,
messages=context,
max_tokens=500,
user=f'{self.platform}_{bot_id}_{chat_id}',
http_headers=OPENROUTER_HEADERS
)
# Extract AI response
ai_response = response.choices[0].message.content
# Add message and AI response to context
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
self.db.context_add_message(bot_id, chat_id, role="assistant", text=ai_response, image=None,
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
return ai_response, True
except Exception as e:
if str(e).find("Rate limit exceeded") != -1:
return "Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК).", False
else:
print(f"Ошибка выполнения запроса к ИИ: {e}")
return f"Извините, при обработке запроса произошла ошибка.", False
def get_last_assistant_message_id(self, bot_id: int, chat_id: int):
return self.db.context_get_last_assistant_message_id(bot_id, chat_id)
def set_last_response_id(self, bot_id: int, chat_id: int, message_id: int):
self.db.context_set_last_message_id(bot_id, chat_id, message_id)
def clear_chat_context(self, bot_id: int, chat_id: int):
self.db.context_clear(bot_id, chat_id)
def _get_chat_context(self, is_group_chat: bool, bot_id: int, chat_id: int) -> list[dict]:
prompt = GROUP_CHAT_SYSTEM_PROMPT if is_group_chat else PRIVATE_CHAT_SYSTEM_PROMPT
bot = self.db.get_bot(bot_id)
if bot['ai_prompt'] is not None:
prompt += '\n\n' + bot['ai_prompt']
chat = self.db.create_chat_if_not_exists(bot_id, chat_id)
if chat['ai_prompt'] is not None:
prompt += '\n\n' + chat['ai_prompt']
messages = self.db.context_get_messages(bot_id, chat_id)
context = [{"role": "system", "content": prompt}]
for message in messages:
context.append(_serialize_message(message["role"], message["text"], message["image"]))
return context
agent: AiAgent
def create_ai_agent(api_token: str, model: str, db: BasicDatabase, platform: str):
global agent
agent = AiAgent(api_token, model, db, platform)