Рефакторинг ИИ-агента.
Добавлена обработка ответов боту (вместо прямого упоминания).
This commit is contained in:
parent
bca3f640ae
commit
7fc6373fca
5 changed files with 145 additions and 208 deletions
74
ai_agent.py
Normal file
74
ai_agent.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from openrouter import OpenRouter
|
||||
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
Ты - помощник в групповом чате Telegram.
|
||||
Отвечай на вопросы и поддерживай контекст беседы.
|
||||
Ты не можешь обсуждать политику и религию.
|
||||
Сообщения пользователей будут приходить в следующем формате: '[Имя]: текст сообщения'
|
||||
При ответе НЕ нужно указывать пользователя, которому он предназначен.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatContext:
|
||||
def __init__(self, max_messages: int):
|
||||
self.max_messages: int = max_messages
|
||||
self.messages: List[Dict[str, str]] = []
|
||||
|
||||
def add_message(self, role: str, content: str):
|
||||
if len(self.messages) == self.max_messages:
|
||||
# Всегда сохраняем в контексте системное сообщение
|
||||
self.messages.pop(1)
|
||||
self.messages.append({"role": role, "content": content})
|
||||
|
||||
def get_messages_for_api(self) -> List[Dict[str, str]]:
|
||||
return self.messages
|
||||
|
||||
|
||||
class AiAgent:
|
||||
def __init__(self, api_token: str):
|
||||
self.client = OpenRouter(api_key=api_token)
|
||||
self.chat_contexts: Dict[int, ChatContext] = {}
|
||||
|
||||
async def get_reply(self, chat_id: int, chat_prompt: str, user_name: str, message: str) -> str:
|
||||
context = self._get_chat_context(chat_id, chat_prompt)
|
||||
context.add_message(role="user", content=f"[{user_name}]: {message}")
|
||||
messages_for_api = context.get_messages_for_api()
|
||||
|
||||
try:
|
||||
# Get response from OpenRouter
|
||||
response = await self.client.chat.send_async(
|
||||
model="meta-llama/llama-3.3-70b-instruct:free",
|
||||
messages=messages_for_api,
|
||||
max_tokens=500,
|
||||
temperature=0.7
|
||||
)
|
||||
|
||||
# Extract AI response
|
||||
ai_response = response.choices[0].message.content
|
||||
|
||||
# Add AI response to context
|
||||
context.add_message(role="assistant", content=ai_response)
|
||||
|
||||
return ai_response
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing message: {e}")
|
||||
return "Извините, при обработке запроса произошла ошибка."
|
||||
|
||||
def clear_chat_context(self, chat_id: int):
|
||||
self.chat_contexts.pop(chat_id, None)
|
||||
|
||||
def _get_chat_context(self, chat_id: int, chat_prompt: Optional[str]) -> ChatContext:
|
||||
"""Get or create chat context for a specific chat"""
|
||||
if chat_id not in self.chat_contexts:
|
||||
self.chat_contexts[chat_id] = ChatContext(10)
|
||||
prompt = SYSTEM_PROMPT
|
||||
if chat_prompt is not None:
|
||||
prompt += '\n\n' + chat_prompt
|
||||
self.chat_contexts[chat_id].add_message(role="system", content=prompt)
|
||||
return self.chat_contexts[chat_id]
|
||||
|
|
@ -5,7 +5,7 @@ from aiogram.utils.formatting import Bold
|
|||
import utils
|
||||
from messages import *
|
||||
import tg.tg_database as database
|
||||
from .default import reset_ai_chat_context
|
||||
from .default import clear_ai_chat_context
|
||||
|
||||
router = Router()
|
||||
|
||||
|
|
@ -90,7 +90,7 @@ async def set_ai_prompt_handler(message: Message, bot: Bot):
|
|||
return
|
||||
|
||||
database.DB.chat_update(chat_id, ai_prompt=message.reply_to_message.text)
|
||||
reset_ai_chat_context(chat_id)
|
||||
clear_ai_chat_context(chat_id)
|
||||
await message.answer('Личность ИИ изменена.')
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,117 +1,19 @@
|
|||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from aiogram import Router, F
|
||||
from aiogram.types import Message
|
||||
from aiogram.types.user import User
|
||||
from aiogram.enums.content_type import ContentType
|
||||
|
||||
from openrouter import OpenRouter
|
||||
from dataclasses import dataclass
|
||||
|
||||
import utils
|
||||
from ai_agent import AiAgent
|
||||
import tg.tg_database as database
|
||||
|
||||
router = Router()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatContext:
|
||||
def __init__(self, max_messages: int):
|
||||
self.max_messages: int = max_messages
|
||||
self.messages: List[Dict[str, str]] = []
|
||||
|
||||
def add_message(self, role: str, content: str):
|
||||
if len(self.messages) == self.max_messages:
|
||||
# Всегда сохраняем в контексте системное сообщение
|
||||
self.messages.pop(1)
|
||||
self.messages.append({"role": role, "content": content})
|
||||
|
||||
def get_messages_for_api(self) -> List[Dict[str, str]]:
|
||||
return self.messages
|
||||
|
||||
|
||||
chat_contexts: Dict[int, ChatContext] = {}
|
||||
agent: Optional[AiAgent] = None
|
||||
bot_user: Optional[User] = None
|
||||
|
||||
system_prompt = """
|
||||
Ты - помощник в групповом чате Telegram.
|
||||
Отвечай на вопросы и поддерживай контекст беседы.
|
||||
Ты не можешь обсуждать политику и религию.
|
||||
Сообщения пользователей будут приходить в следующем формате: '[Имя]: текст сообщения'
|
||||
При ответе НЕ нужно указывать пользователя, которому он предназначен.
|
||||
"""
|
||||
|
||||
|
||||
def get_ai_chat_context(chat_id: int) -> ChatContext:
|
||||
"""Get or create chat context for a specific chat"""
|
||||
if chat_id not in chat_contexts:
|
||||
chat_contexts[chat_id] = ChatContext(10)
|
||||
chat = database.DB.get_chat(chat_id)
|
||||
prompt = system_prompt
|
||||
if chat['ai_prompt'] is not None:
|
||||
prompt += '\n\n' + chat['ai_prompt']
|
||||
chat_contexts[chat_id].add_message(role="system", content=prompt)
|
||||
return chat_contexts[chat_id]
|
||||
|
||||
|
||||
def reset_ai_chat_context(chat_id: int):
|
||||
chat_contexts.pop(chat_id, None)
|
||||
get_ai_chat_context(chat_id)
|
||||
|
||||
|
||||
async def ai_message_handler(message: Message):
|
||||
chat_id = message.chat.id
|
||||
|
||||
# Extract user information and message content
|
||||
if message.from_user.first_name and message.from_user.last_name:
|
||||
user_name = "{} {}".format(message.from_user.first_name, message.from_user.last_name)
|
||||
elif message.from_user.first_name:
|
||||
user_name = message.from_user.first_name
|
||||
elif message.from_user.username:
|
||||
user_name = message.from_user.username
|
||||
else:
|
||||
user_name = str(message.from_user.id)
|
||||
|
||||
bot_mention = '@' + bot_user.username
|
||||
message_text = message.text.replace(bot_mention, bot_user.first_name)
|
||||
|
||||
context = get_ai_chat_context(chat_id)
|
||||
context.add_message(
|
||||
role="user",
|
||||
content=f"[{user_name}]: {message_text}"
|
||||
)
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
api_key = message.bot.config['openrouter_token']
|
||||
|
||||
client = OpenRouter(api_key=api_key)
|
||||
messages_for_api = context.get_messages_for_api()
|
||||
|
||||
await message.bot.send_chat_action(chat_id, 'typing')
|
||||
|
||||
try:
|
||||
# Get response from OpenRouter
|
||||
response = await client.chat.send_async(
|
||||
model="meta-llama/llama-3.3-70b-instruct:free",
|
||||
messages=messages_for_api,
|
||||
max_tokens=500,
|
||||
temperature=0.7
|
||||
)
|
||||
|
||||
# Extract AI response
|
||||
ai_response = response.choices[0].message.content
|
||||
|
||||
# Add AI response to context
|
||||
context.add_message(role="assistant", content=ai_response)
|
||||
|
||||
# Send response back to chat
|
||||
await message.reply(ai_response)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing message: {e}")
|
||||
await message.reply("Извините, при обработке запроса произошла ошибка.")
|
||||
|
||||
|
||||
ACCEPTED_CONTENT_TYPES: list[ContentType] = [
|
||||
ContentType.TEXT,
|
||||
ContentType.ANIMATION,
|
||||
|
|
@ -152,6 +54,34 @@ async def any_message_handler(message: Message):
|
|||
if bot_user is None:
|
||||
bot_user = await message.bot.get_me()
|
||||
|
||||
bot_mention = '@' + bot_user.username
|
||||
if message.content_type == ContentType.TEXT and message.text.find(bot_mention) != -1:
|
||||
await ai_message_handler(message)
|
||||
if message.content_type == ContentType.TEXT and message.text.find('@' + bot_user.username) != -1:
|
||||
message_text = message.text.replace('@' + bot_user.username, bot_user.first_name)
|
||||
elif message.reply_to_message and message.reply_to_message.from_user.id == bot_user.id:
|
||||
message_text = message.text
|
||||
else:
|
||||
return
|
||||
|
||||
if message.from_user.first_name and message.from_user.last_name:
|
||||
user_name = "{} {}".format(message.from_user.first_name, message.from_user.last_name)
|
||||
elif message.from_user.first_name:
|
||||
user_name = message.from_user.first_name
|
||||
elif message.from_user.username:
|
||||
user_name = message.from_user.username
|
||||
else:
|
||||
user_name = str(message.from_user.id)
|
||||
|
||||
chat_prompt = chat['ai_prompt']
|
||||
|
||||
global agent
|
||||
if agent is None:
|
||||
# noinspection PyUnresolvedReferences
|
||||
agent = AiAgent(message.bot.config['openrouter_token'])
|
||||
|
||||
await message.bot.send_chat_action(chat_id, 'typing')
|
||||
await message.reply(await agent.get_reply(chat_id, chat_prompt, user_name, message_text))
|
||||
|
||||
|
||||
def clear_ai_chat_context(chat_id: int):
|
||||
global agent
|
||||
if agent is not None:
|
||||
agent.clear_chat_context(chat_id)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from vkbottle_types.codegen.objects import MessagesGetConversationMembers
|
|||
from messages import *
|
||||
import utils
|
||||
import vk.vk_database as database
|
||||
from .default import reset_ai_chat_context
|
||||
from .default import clear_ai_chat_context
|
||||
|
||||
labeler = BotLabeler()
|
||||
|
||||
|
|
@ -145,7 +145,7 @@ async def set_ai_prompt_handler(message: Message):
|
|||
return
|
||||
|
||||
database.DB.chat_update(chat_id, ai_prompt=message.reply_message.text)
|
||||
reset_ai_chat_context(chat_id)
|
||||
clear_ai_chat_context(chat_id)
|
||||
await message.answer('Личность ИИ изменена.')
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,112 +1,18 @@
|
|||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from vkbottle.bot import Message
|
||||
from vkbottle.framework.labeler import BotLabeler
|
||||
from vkbottle_types.codegen.objects import GroupsGroup
|
||||
|
||||
from openrouter import OpenRouter
|
||||
from dataclasses import dataclass
|
||||
|
||||
import utils
|
||||
from ai_agent import AiAgent
|
||||
import vk.vk_database as database
|
||||
|
||||
labeler = BotLabeler()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatContext:
|
||||
def __init__(self, max_messages: int):
|
||||
self.max_messages: int = max_messages
|
||||
self.messages: List[Dict[str, str]] = []
|
||||
|
||||
def add_message(self, role: str, content: str):
|
||||
if len(self.messages) == self.max_messages:
|
||||
# Всегда сохраняем в контексте системное сообщение
|
||||
self.messages.pop(1)
|
||||
self.messages.append({"role": role, "content": content})
|
||||
|
||||
def get_messages_for_api(self) -> List[Dict[str, str]]:
|
||||
return self.messages
|
||||
|
||||
|
||||
chat_contexts: Dict[int, ChatContext] = {}
|
||||
agent: Optional[AiAgent] = None
|
||||
bot_user: Optional[GroupsGroup] = None
|
||||
|
||||
system_prompt = """
|
||||
Ты - помощник в групповом чате Telegram.
|
||||
Отвечай на вопросы и поддерживай контекст беседы.
|
||||
Ты не можешь обсуждать политику и религию.
|
||||
Сообщения пользователей будут приходить в следующем формате: '[Имя]: текст сообщения'
|
||||
При ответе НЕ нужно указывать пользователя, которому он предназначен.
|
||||
"""
|
||||
|
||||
|
||||
def get_ai_chat_context(chat_id: int) -> ChatContext:
|
||||
"""Get or create chat context for a specific chat"""
|
||||
if chat_id not in chat_contexts:
|
||||
chat_contexts[chat_id] = ChatContext(10)
|
||||
chat = database.DB.get_chat(chat_id)
|
||||
prompt = system_prompt
|
||||
if chat['ai_prompt'] is not None:
|
||||
prompt += '\n\n' + chat['ai_prompt']
|
||||
chat_contexts[chat_id].add_message(role="system", content=prompt)
|
||||
return chat_contexts[chat_id]
|
||||
|
||||
|
||||
def reset_ai_chat_context(chat_id: int):
|
||||
chat_contexts.pop(chat_id, None)
|
||||
get_ai_chat_context(chat_id)
|
||||
|
||||
|
||||
async def ai_message_handler(message: Message):
|
||||
chat_id = message.peer_id
|
||||
|
||||
# Extract user information and message content
|
||||
user = await message.ctx_api.users.get(user_ids=[message.from_id])
|
||||
if len(user) == 1:
|
||||
user_name = "{} {}".format(user[0].first_name, user[0].last_name)
|
||||
else:
|
||||
user_name = '@id' + str(message.from_id)
|
||||
|
||||
bot_mention = '@' + bot_user.screen_name
|
||||
message_text = message.text.replace(bot_mention, bot_user.name)
|
||||
|
||||
context = get_ai_chat_context(chat_id)
|
||||
context.add_message(
|
||||
role="user",
|
||||
content=f"[{user_name}]: {message_text}"
|
||||
)
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
api_key = message.ctx_api.config['openrouter_token']
|
||||
|
||||
client = OpenRouter(api_key=api_key)
|
||||
messages_for_api = context.get_messages_for_api()
|
||||
|
||||
await message.ctx_api.messages.set_activity(peer_id=chat_id, type='typing')
|
||||
|
||||
try:
|
||||
# Get response from OpenRouter
|
||||
response = await client.chat.send_async(
|
||||
model="meta-llama/llama-3.3-70b-instruct:free",
|
||||
messages=messages_for_api,
|
||||
max_tokens=500,
|
||||
temperature=0.7
|
||||
)
|
||||
|
||||
# Extract AI response
|
||||
ai_response = response.choices[0].message.content
|
||||
|
||||
# Add AI response to context
|
||||
context.add_message(role="assistant", content=ai_response)
|
||||
|
||||
# Send response back to chat
|
||||
await message.reply(ai_response)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing message: {e}")
|
||||
await message.reply("Извините, при обработке запроса произошла ошибка.")
|
||||
|
||||
|
||||
# Обычные сообщения (не команды и не действия)
|
||||
@labeler.chat_message()
|
||||
|
|
@ -133,6 +39,33 @@ async def any_message_handler(message: Message):
|
|||
if bot_user is None:
|
||||
bot_user = (await message.ctx_api.groups.get_by_id()).groups[0]
|
||||
|
||||
bot_mention = '@' + bot_user.screen_name
|
||||
if message.text is not None and message.text.find(bot_mention) != -1:
|
||||
await ai_message_handler(message)
|
||||
if message.text is not None and message.text.find('@' + bot_user.screen_name) != -1:
|
||||
message_text = message.text.replace('@' + bot_user.screen_name, bot_user.name)
|
||||
elif message.text is not None and message.text.find('club' + str(bot_user.id)) != -1:
|
||||
message_text = message.text.replace('club' + str(bot_user.id), bot_user.name)
|
||||
elif message.reply_message and message.reply_message.from_id == -bot_user.id:
|
||||
message_text = message.text
|
||||
else:
|
||||
return
|
||||
|
||||
user = await message.ctx_api.users.get(user_ids=[message.from_id])
|
||||
if len(user) == 1:
|
||||
user_name = "{} {}".format(user[0].first_name, user[0].last_name)
|
||||
else:
|
||||
user_name = '@id' + str(message.from_id)
|
||||
|
||||
chat_prompt = chat['ai_prompt']
|
||||
|
||||
global agent
|
||||
if agent is None:
|
||||
# noinspection PyUnresolvedReferences
|
||||
agent = AiAgent(message.ctx_api.config['openrouter_token'])
|
||||
|
||||
await message.ctx_api.messages.set_activity(peer_id=chat_id, type='typing')
|
||||
await message.reply(await agent.get_reply(chat_id, chat_prompt, user_name, message_text))
|
||||
|
||||
|
||||
def clear_ai_chat_context(chat_id: int):
|
||||
global agent
|
||||
if agent is not None:
|
||||
agent.clear_chat_context(chat_id)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue