134 lines
5.9 KiB
Python
134 lines
5.9 KiB
Python
from dataclasses import dataclass
|
||
from typing import List, Tuple
|
||
|
||
from openrouter import OpenRouter, RetryConfig
|
||
from openrouter.utils import BackoffStrategy
|
||
|
||
from database import BasicDatabase
|
||
|
||
GROUP_CHAT_SYSTEM_PROMPT = """
|
||
Ты - ИИ-помощник в групповом чате.\n
|
||
Отвечай на вопросы и поддерживай контекст беседы.\n
|
||
Ты не можешь обсуждать политику и религию.\n
|
||
Сообщения пользователей будут приходить в следующем формате: '[Имя]: текст сообщения'\n
|
||
При ответе НЕ нужно указывать пользователя, которому он предназначен.
|
||
"""
|
||
GROUP_CHAT_MAX_MESSAGES = 20
|
||
|
||
PRIVATE_CHAT_SYSTEM_PROMPT = """
|
||
Ты - ИИ-помощник в чате c пользователем.\n
|
||
Отвечай на вопросы и поддерживай контекст беседы.
|
||
"""
|
||
PRIVATE_CHAT_MAX_MESSAGES = 40
|
||
|
||
|
||
@dataclass()
|
||
class Message:
|
||
user_name: str = None
|
||
text: str = None
|
||
message_id: int = None
|
||
|
||
|
||
class AiAgent:
|
||
def __init__(self, api_token: str, model: str, model_temp: float, db: BasicDatabase):
|
||
retry_config = RetryConfig(strategy="backoff",
|
||
backoff=BackoffStrategy(
|
||
initial_interval=2000, max_interval=8000, exponent=2, max_elapsed_time=14000),
|
||
retry_connection_errors=True)
|
||
self.db = db
|
||
self.model = model
|
||
self.model_temp = model_temp
|
||
self.client = OpenRouter(api_key=api_token, retry_config=retry_config)
|
||
|
||
async def get_group_chat_reply(self, chat_id: int,
|
||
message: Message, forwarded_messages: List[Message]) -> Tuple[str, bool]:
|
||
message_text = f"[{message.user_name}]: {message.text}"
|
||
for fwd_message in forwarded_messages:
|
||
message_text += '\n<Цитируемое сообщение от {}>\n'.format(fwd_message.user_name)
|
||
message_text += fwd_message.text + '\n'
|
||
message_text += '<Конец цитаты>'
|
||
|
||
context = self._get_chat_context(is_group_chat=True, chat_id=chat_id)
|
||
context.append({"role": "user", "content": message_text})
|
||
|
||
try:
|
||
# Get response from OpenRouter
|
||
response = await self.client.chat.send_async(
|
||
model=self.model,
|
||
messages=context,
|
||
max_tokens=500,
|
||
temperature=self.model_temp
|
||
)
|
||
|
||
# Extract AI response
|
||
ai_response = response.choices[0].message.content
|
||
|
||
# Add message and AI response to context
|
||
self.db.context_add_message(chat_id=chat_id, role="user", content=message_text,
|
||
message_id=message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
||
self.db.context_add_message(chat_id=chat_id, role="assistant", content=ai_response,
|
||
message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
||
|
||
return ai_response, True
|
||
|
||
except Exception as e:
|
||
if str(e).find("Rate limit exceeded") != -1:
|
||
return "Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК).", False
|
||
else:
|
||
print(f"Ошибка выполнения запроса к ИИ: {e}")
|
||
return f"Извините, при обработке запроса произошла ошибка.", False
|
||
|
||
async def get_private_chat_reply(self, chat_id: int, message: str, message_id: int) -> Tuple[str, bool]:
|
||
context = self._get_chat_context(is_group_chat=False, chat_id=chat_id)
|
||
context.append({"role": "user", "content": message})
|
||
|
||
try:
|
||
# Get response from OpenRouter
|
||
response = await self.client.chat.send_async(
|
||
model=self.model,
|
||
messages=context,
|
||
max_tokens=500,
|
||
temperature=0.5
|
||
)
|
||
|
||
# Extract AI response
|
||
ai_response = response.choices[0].message.content
|
||
|
||
# Add message and AI response to context
|
||
self.db.context_add_message(chat_id=chat_id, role="user", content=message,
|
||
message_id=message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
||
self.db.context_add_message(chat_id=chat_id, role="assistant", content=ai_response,
|
||
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
||
|
||
return ai_response, True
|
||
|
||
except Exception as e:
|
||
if str(e).find("Rate limit exceeded") != -1:
|
||
return "Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК).", False
|
||
else:
|
||
print(f"Ошибка выполнения запроса к ИИ: {e}")
|
||
return f"Извините, при обработке запроса произошла ошибка.", False
|
||
|
||
def set_last_response_id(self, chat_id: int, message_id: int):
|
||
self.db.context_set_last_message_id(chat_id, message_id)
|
||
|
||
def clear_chat_context(self, chat_id: int):
|
||
self.db.context_clear(chat_id)
|
||
|
||
def _get_chat_context(self, is_group_chat: bool, chat_id: int) -> list[dict]:
|
||
prompt = GROUP_CHAT_SYSTEM_PROMPT if is_group_chat else PRIVATE_CHAT_SYSTEM_PROMPT
|
||
|
||
chat = self.db.create_chat_if_not_exists(chat_id)
|
||
if chat['ai_prompt'] is not None:
|
||
prompt += '\n\n' + chat['ai_prompt']
|
||
|
||
messages = self.db.context_get_messages(chat_id)
|
||
return [{"role": "system", "content": prompt}] + messages
|
||
|
||
|
||
agent: AiAgent
|
||
|
||
|
||
def create_ai_agent(api_token: str, model: str, model_temp: float, db: BasicDatabase):
|
||
global agent
|
||
agent = AiAgent(api_token, model, model_temp, db)
|