92 lines
3.7 KiB
Python
92 lines
3.7 KiB
Python
from dataclasses import dataclass
|
||
from typing import List, Dict, Optional
|
||
|
||
from openrouter import OpenRouter
|
||
|
||
|
||
SYSTEM_PROMPT = """
|
||
Ты - помощник в групповом чате.
|
||
Отвечай на вопросы и поддерживай контекст беседы.
|
||
Ты не можешь обсуждать политику и религию.
|
||
Сообщения пользователей будут приходить в следующем формате: '[Имя]: текст сообщения'
|
||
При ответе НЕ нужно указывать пользователя, которому он предназначен.
|
||
"""
|
||
|
||
|
||
class ChatContext:
|
||
def __init__(self, max_messages: int):
|
||
self.max_messages: int = max_messages
|
||
self.messages: List[Dict[str, str]] = []
|
||
|
||
def add_message(self, role: str, content: str):
|
||
if len(self.messages) == self.max_messages:
|
||
# Всегда сохраняем в контексте системное сообщение
|
||
self.messages.pop(1)
|
||
self.messages.append({"role": role, "content": content})
|
||
|
||
def get_messages_for_api(self) -> List[Dict[str, str]]:
|
||
return self.messages
|
||
|
||
def remove_last_message(self):
|
||
self.messages.pop()
|
||
|
||
|
||
@dataclass()
|
||
class AiMessage:
|
||
user_name: str = None
|
||
text: str = None
|
||
|
||
|
||
class AiAgent:
|
||
def __init__(self, api_token: str):
|
||
self.client = OpenRouter(api_key=api_token)
|
||
self.chat_contexts: Dict[int, ChatContext] = {}
|
||
|
||
async def get_reply(self, chat_id: int, chat_prompt: str,
|
||
message: AiMessage, forwarded_messages: List[AiMessage]) -> str:
|
||
message_text = message.text
|
||
for fwd_message in forwarded_messages:
|
||
message_text += '\n<Цитируемое сообщение от {}>\n'.format(fwd_message.user_name)
|
||
message_text += fwd_message.text + '\n'
|
||
message_text += '<Конец цитаты>'
|
||
|
||
context = self._get_chat_context(chat_id, chat_prompt)
|
||
context.add_message(role="user", content=f"[{message.user_name}]: {message_text}")
|
||
|
||
try:
|
||
# Get response from OpenRouter
|
||
response = await self.client.chat.send_async(
|
||
model="cognitivecomputations/dolphin-mistral-24b-venice-edition:free",
|
||
messages=context.get_messages_for_api(),
|
||
max_tokens=500,
|
||
temperature=0.5
|
||
)
|
||
|
||
# Extract AI response
|
||
ai_response = response.choices[0].message.content
|
||
|
||
# Add AI response to context
|
||
context.add_message(role="assistant", content=ai_response)
|
||
|
||
return ai_response
|
||
|
||
except Exception as e:
|
||
context.remove_last_message()
|
||
if str(e).find("Rate limit exceeded") != -1:
|
||
return "Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."
|
||
else:
|
||
print(f"Ошибка выполнения запроса к ИИ: {e}")
|
||
return f"Извините, при обработке запроса произошла ошибка."
|
||
|
||
def clear_chat_context(self, chat_id: int):
|
||
self.chat_contexts.pop(chat_id, None)
|
||
|
||
def _get_chat_context(self, chat_id: int, chat_prompt: Optional[str]) -> ChatContext:
|
||
"""Get or create chat context for a specific chat"""
|
||
if chat_id not in self.chat_contexts:
|
||
self.chat_contexts[chat_id] = ChatContext(max_messages=20)
|
||
prompt = SYSTEM_PROMPT
|
||
if chat_prompt is not None:
|
||
prompt += '\n\n' + chat_prompt
|
||
self.chat_contexts[chat_id].add_message(role="system", content=prompt)
|
||
return self.chat_contexts[chat_id]
|