Добавлена обработка личных сообщений.
This commit is contained in:
parent
aad3a73ea0
commit
e61b2d1c14
16 changed files with 275 additions and 133 deletions
70
ai_agent.py
70
ai_agent.py
|
|
@ -4,14 +4,20 @@ from typing import List, Dict, Optional
|
|||
from openrouter import OpenRouter, RetryConfig
|
||||
from openrouter.utils import BackoffStrategy
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
Ты - помощник в групповом чате.\n
|
||||
|
||||
GROUP_CHAT_SYSTEM_PROMPT = """
|
||||
Ты - ИИ-помощник в групповом чате.\n
|
||||
Отвечай на вопросы и поддерживай контекст беседы.\n
|
||||
Ты не можешь обсуждать политику и религию.\n
|
||||
Сообщения пользователей будут приходить в следующем формате: '[Имя]: текст сообщения'\n
|
||||
При ответе НЕ нужно указывать пользователя, которому он предназначен.
|
||||
"""
|
||||
|
||||
PRIVATE_CHAT_SYSTEM_PROMPT = """
|
||||
Ты - ИИ-помощник в чате c пользователем.\n
|
||||
Отвечай на вопросы и поддерживай контекст беседы.
|
||||
"""
|
||||
|
||||
|
||||
class ChatContext:
|
||||
def __init__(self, max_messages: int):
|
||||
|
|
@ -32,7 +38,7 @@ class ChatContext:
|
|||
|
||||
|
||||
@dataclass()
|
||||
class AiMessage:
|
||||
class Message:
|
||||
user_name: str = None
|
||||
text: str = None
|
||||
|
||||
|
|
@ -43,24 +49,54 @@ class AiAgent:
|
|||
backoff=BackoffStrategy(
|
||||
initial_interval=2000, max_interval=8000, exponent=2, max_elapsed_time=14000),
|
||||
retry_connection_errors=True)
|
||||
self.model = "meta-llama/llama-3.3-70b-instruct:free"
|
||||
self.client = OpenRouter(api_key=api_token, retry_config=retry_config)
|
||||
self.chat_contexts: Dict[int, ChatContext] = {}
|
||||
|
||||
async def get_reply(self, chat_id: int, chat_prompt: str,
|
||||
message: AiMessage, forwarded_messages: List[AiMessage]) -> str:
|
||||
async def get_group_chat_reply(self, chat_id: int, chat_prompt: str,
|
||||
message: Message, forwarded_messages: List[Message]) -> str:
|
||||
message_text = message.text
|
||||
for fwd_message in forwarded_messages:
|
||||
message_text += '\n<Цитируемое сообщение от {}>\n'.format(fwd_message.user_name)
|
||||
message_text += fwd_message.text + '\n'
|
||||
message_text += '<Конец цитаты>'
|
||||
|
||||
context = self._get_chat_context(chat_id, chat_prompt)
|
||||
context = self._get_chat_context(is_group_chat=True, chat_id=chat_id, chat_prompt=chat_prompt)
|
||||
context.add_message(role="user", content=f"[{message.user_name}]: {message_text}")
|
||||
|
||||
try:
|
||||
# Get response from OpenRouter
|
||||
response = await self.client.chat.send_async(
|
||||
model="meta-llama/llama-3.3-70b-instruct:free",
|
||||
model=self.model,
|
||||
messages=context.get_messages_for_api(),
|
||||
max_tokens=500,
|
||||
temperature=0.5
|
||||
)
|
||||
|
||||
# Extract AI response
|
||||
ai_response = response.choices[0].message.content
|
||||
|
||||
# Add AI response to context
|
||||
context.add_message(role="assistant", content=ai_response)
|
||||
|
||||
return ai_response
|
||||
|
||||
except Exception as e:
|
||||
context.remove_last_message()
|
||||
if str(e).find("Rate limit exceeded") != -1:
|
||||
return "Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."
|
||||
else:
|
||||
print(f"Ошибка выполнения запроса к ИИ: {e}")
|
||||
return f"Извините, при обработке запроса произошла ошибка."
|
||||
|
||||
async def get_private_chat_reply(self, chat_id: int, chat_prompt: str, message: str) -> str:
|
||||
context = self._get_chat_context(is_group_chat=False, chat_id=chat_id, chat_prompt=chat_prompt)
|
||||
context.add_message(role="user", content=message)
|
||||
|
||||
try:
|
||||
# Get response from OpenRouter
|
||||
response = await self.client.chat.send_async(
|
||||
model=self.model,
|
||||
messages=context.get_messages_for_api(),
|
||||
max_tokens=500,
|
||||
temperature=0.5
|
||||
|
|
@ -85,12 +121,26 @@ class AiAgent:
|
|||
def clear_chat_context(self, chat_id: int):
|
||||
self.chat_contexts.pop(chat_id, None)
|
||||
|
||||
def _get_chat_context(self, chat_id: int, chat_prompt: Optional[str]) -> ChatContext:
|
||||
def _get_chat_context(self, is_group_chat: bool, chat_id: int, chat_prompt: Optional[str]) -> ChatContext:
|
||||
"""Get or create chat context for a specific chat"""
|
||||
if chat_id not in self.chat_contexts:
|
||||
self.chat_contexts[chat_id] = ChatContext(max_messages=20)
|
||||
prompt = SYSTEM_PROMPT
|
||||
if is_group_chat:
|
||||
self.chat_contexts[chat_id] = ChatContext(max_messages=20)
|
||||
prompt = GROUP_CHAT_SYSTEM_PROMPT
|
||||
else:
|
||||
self.chat_contexts[chat_id] = ChatContext(max_messages=40)
|
||||
prompt = PRIVATE_CHAT_SYSTEM_PROMPT
|
||||
|
||||
if chat_prompt is not None:
|
||||
prompt += '\n\n' + chat_prompt
|
||||
|
||||
self.chat_contexts[chat_id].add_message(role="system", content=prompt)
|
||||
return self.chat_contexts[chat_id]
|
||||
|
||||
|
||||
agent: AiAgent
|
||||
|
||||
|
||||
def create_ai_agent(api_token: str):
|
||||
global agent
|
||||
agent = AiAgent(api_token)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ import json
|
|||
|
||||
from aiogram import Bot, Dispatcher
|
||||
|
||||
from ai_agent import create_ai_agent
|
||||
|
||||
from . import handlers
|
||||
from . import tasks
|
||||
|
||||
|
|
@ -13,7 +15,7 @@ async def main() -> None:
|
|||
print('Конфигурация загружена.')
|
||||
|
||||
bot = Bot(token=config['api_token'])
|
||||
bot.config = config
|
||||
create_ai_agent(config['openrouter_token'])
|
||||
|
||||
dp = Dispatcher()
|
||||
dp.include_router(handlers.router)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from aiogram import Router
|
||||
|
||||
from . import user, admin, action, default
|
||||
from . import private, user, admin, action, default
|
||||
|
||||
router = Router()
|
||||
router.include_router(private.router)
|
||||
router.include_router(user.router)
|
||||
router.include_router(admin.router)
|
||||
router.include_router(action.router)
|
||||
|
|
|
|||
|
|
@ -2,10 +2,11 @@ from aiogram import Bot, Router, F
|
|||
from aiogram.types import Message
|
||||
from aiogram.utils.formatting import Bold
|
||||
|
||||
import ai_agent
|
||||
import utils
|
||||
from messages import *
|
||||
|
||||
import tg.tg_database as database
|
||||
from .default import clear_ai_chat_context
|
||||
|
||||
router = Router()
|
||||
|
||||
|
|
@ -90,7 +91,7 @@ async def set_ai_prompt_handler(message: Message, bot: Bot):
|
|||
return
|
||||
|
||||
database.DB.chat_update(chat_id, ai_prompt=message.reply_to_message.text)
|
||||
clear_ai_chat_context(chat_id)
|
||||
ai_agent.agent.clear_chat_context(chat_id)
|
||||
await message.answer('Личность ИИ изменена.')
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,18 +1,19 @@
|
|||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
from aiogram import Router, F, Bot
|
||||
from aiogram import Router, F
|
||||
from aiogram.types import Message
|
||||
from aiogram.types.user import User
|
||||
from aiogram.enums.content_type import ContentType
|
||||
|
||||
import ai_agent
|
||||
import utils
|
||||
from ai_agent import AiAgent, AiMessage
|
||||
|
||||
import tg.tg_database as database
|
||||
from tg.utils import get_user_name_for_ai
|
||||
|
||||
router = Router()
|
||||
|
||||
agent: Optional[AiAgent] = None
|
||||
bot_user: Optional[User] = None
|
||||
|
||||
ACCEPTED_CONTENT_TYPES: list[ContentType] = [
|
||||
|
|
@ -35,17 +36,6 @@ ACCEPTED_CONTENT_TYPES: list[ContentType] = [
|
|||
]
|
||||
|
||||
|
||||
async def get_user_name_for_ai(user: User):
|
||||
if user.first_name and user.last_name:
|
||||
return "{} {}".format(user.first_name, user.last_name)
|
||||
elif user.first_name:
|
||||
return user.first_name
|
||||
elif user.username:
|
||||
return user.username
|
||||
else:
|
||||
return str(user.id)
|
||||
|
||||
|
||||
@router.message(F.content_type.in_(ACCEPTED_CONTENT_TYPES))
|
||||
async def any_message_handler(message: Message):
|
||||
chat_id = message.chat.id
|
||||
|
|
@ -66,8 +56,8 @@ async def any_message_handler(message: Message):
|
|||
if bot_user is None:
|
||||
bot_user = await message.bot.get_me()
|
||||
|
||||
ai_message = AiMessage()
|
||||
ai_fwd_messages: list[AiMessage] = []
|
||||
ai_message = ai_agent.Message()
|
||||
ai_fwd_messages: list[ai_agent.Message] = []
|
||||
|
||||
# Ответ на сообщение бота
|
||||
if message.reply_to_message and message.reply_to_message.from_user.id == bot_user.id:
|
||||
|
|
@ -82,41 +72,14 @@ async def any_message_handler(message: Message):
|
|||
|
||||
if message.reply_to_message and message.reply_to_message.content_type == ContentType.TEXT:
|
||||
ai_fwd_messages = [
|
||||
AiMessage(user_name=await get_user_name_for_ai(message.reply_to_message.from_user),
|
||||
text=message.reply_to_message.text)]
|
||||
ai_agent.Message(user_name=await get_user_name_for_ai(message.reply_to_message.from_user),
|
||||
text=message.reply_to_message.text)]
|
||||
|
||||
ai_message.user_name = await get_user_name_for_ai(message.from_user)
|
||||
chat_prompt = chat['ai_prompt']
|
||||
|
||||
global agent
|
||||
if agent is None:
|
||||
# noinspection PyUnresolvedReferences
|
||||
agent = AiAgent(message.bot.config['openrouter_token'])
|
||||
|
||||
await message.reply(
|
||||
await utils.run_with_progress(partial(agent.get_reply, chat_id, chat_prompt, ai_message, ai_fwd_messages),
|
||||
partial(message.bot.send_chat_action, chat_id, 'typing'),
|
||||
interval=4))
|
||||
|
||||
|
||||
def clear_ai_chat_context(chat_id: int):
|
||||
global agent
|
||||
if agent is not None:
|
||||
agent.clear_chat_context(chat_id)
|
||||
|
||||
|
||||
async def get_ai_reply(bot: Bot, chat_id, user: User, message: str, fwd_user: User, fwd_message: str) -> str:
|
||||
chat = database.DB.create_chat_if_not_exists(chat_id)
|
||||
chat_prompt = chat['ai_prompt']
|
||||
|
||||
ai_message = AiMessage(user_name=await get_user_name_for_ai(user), text=message)
|
||||
ai_fwd_messages = [AiMessage(user_name=await get_user_name_for_ai(fwd_user), text=fwd_message)]
|
||||
|
||||
global agent
|
||||
if agent is None:
|
||||
# noinspection PyUnresolvedReferences
|
||||
agent = AiAgent(bot.config['openrouter_token'])
|
||||
|
||||
return await utils.run_with_progress(partial(agent.get_reply, chat_id, chat_prompt, ai_message, ai_fwd_messages),
|
||||
partial(bot.send_chat_action(chat_id, 'typing')),
|
||||
interval=4)
|
||||
await utils.run_with_progress(
|
||||
partial(ai_agent.agent.get_group_chat_reply, chat_id, chat_prompt, ai_message, ai_fwd_messages),
|
||||
partial(message.bot.send_chat_action, chat_id, 'typing'),
|
||||
interval=4))
|
||||
|
|
|
|||
53
tg/handlers/private.py
Normal file
53
tg/handlers/private.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
from functools import partial
|
||||
|
||||
from aiogram import Router, F
|
||||
from aiogram.enums import ChatType
|
||||
from aiogram.filters import Command, CommandObject, CommandStart
|
||||
from aiogram.types import Message
|
||||
|
||||
import ai_agent
|
||||
import utils
|
||||
|
||||
import tg.tg_database as database
|
||||
from .default import ACCEPTED_CONTENT_TYPES
|
||||
|
||||
router = Router()
|
||||
|
||||
|
||||
@router.message(CommandStart(), F.chat.type == ChatType.PRIVATE)
|
||||
async def start_handler(message: Message):
|
||||
chat_id = message.chat.id
|
||||
database.DB.create_chat_if_not_exists(chat_id)
|
||||
database.DB.chat_update(chat_id, active=1)
|
||||
await message.answer("Привет!")
|
||||
|
||||
|
||||
@router.message(Command("личность", prefix="!"), F.chat.type == ChatType.PRIVATE)
|
||||
async def set_prompt_handler(message: Message, command: CommandObject):
|
||||
chat_id = message.chat.id
|
||||
database.DB.create_chat_if_not_exists(chat_id)
|
||||
|
||||
database.DB.chat_update(chat_id, ai_prompt=command.args)
|
||||
await message.answer("Личность ИИ изменена.")
|
||||
|
||||
|
||||
@router.message(Command("сброс", prefix="!"), F.chat.type == ChatType.PRIVATE)
|
||||
async def reset_context_handler(message: Message):
|
||||
chat_id = message.chat.id
|
||||
database.DB.create_chat_if_not_exists(chat_id)
|
||||
|
||||
ai_agent.agent.clear_chat_context(chat_id)
|
||||
await message.answer("Контекст очищен.")
|
||||
|
||||
|
||||
@router.message(F.content_type.in_(ACCEPTED_CONTENT_TYPES), F.chat.type == ChatType.PRIVATE)
|
||||
async def any_message_handler(message: Message):
|
||||
chat_id = message.chat.id
|
||||
chat = database.DB.create_chat_if_not_exists(chat_id)
|
||||
chat_prompt = chat['ai_prompt']
|
||||
|
||||
await message.reply(
|
||||
await utils.run_with_progress(
|
||||
partial(ai_agent.agent.get_private_chat_reply, chat_id, chat_prompt, message.text),
|
||||
partial(message.bot.send_chat_action, chat_id, 'typing'),
|
||||
interval=4))
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
from functools import partial
|
||||
from typing import List, Any
|
||||
|
||||
from aiogram import Bot, Router, F
|
||||
|
|
@ -5,10 +6,12 @@ from aiogram.enums import ContentType
|
|||
from aiogram.types import Message
|
||||
from aiogram.utils.formatting import Bold, Italic
|
||||
|
||||
import ai_agent
|
||||
import utils
|
||||
from messages import *
|
||||
|
||||
import tg.tg_database as database
|
||||
from .default import get_ai_reply
|
||||
from tg.utils import get_user_name_for_ai
|
||||
|
||||
router = Router()
|
||||
|
||||
|
|
@ -135,7 +138,7 @@ async def warnings_handler(message: Message, bot: Bot):
|
|||
|
||||
|
||||
@router.message(F.text == "!проверка")
|
||||
async def check_rules_violation_handler(message: Message):
|
||||
async def check_rules_violation_handler(message: Message, bot: Bot):
|
||||
chat_id = message.chat.id
|
||||
chat = database.DB.create_chat_if_not_exists(chat_id)
|
||||
if chat['active'] == 0:
|
||||
|
|
@ -155,5 +158,14 @@ async def check_rules_violation_handler(message: Message):
|
|||
prompt += chat_rules + '\n\n'
|
||||
prompt += 'Проверь, не нарушают ли правила следующие сообщения (если нарушают, то укажи пункты правил):'
|
||||
|
||||
await message.answer(await get_ai_reply(message.bot, chat_id, message.from_user, prompt,
|
||||
message.reply_to_message.from_user, message.reply_to_message.text))
|
||||
chat_prompt = chat['ai_prompt']
|
||||
|
||||
ai_message = ai_agent.Message(user_name=await get_user_name_for_ai(message.from_user), text=prompt)
|
||||
ai_fwd_messages = [ai_agent.Message(user_name=await get_user_name_for_ai(message.reply_to_message.from_user),
|
||||
text=message.reply_to_message.text)]
|
||||
|
||||
await message.answer(
|
||||
await utils.run_with_progress(
|
||||
partial(ai_agent.agent.get_group_chat_reply, chat_id, chat_prompt, ai_message, ai_fwd_messages),
|
||||
partial(bot.send_chat_action, chat_id, 'typing'),
|
||||
interval=4))
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ from aiogram.exceptions import TelegramBadRequest
|
|||
from aiogram.types import ChatMemberBanned, ChatMemberLeft
|
||||
from aiogram.utils.formatting import Bold
|
||||
|
||||
from messages import *
|
||||
import tg.tg_database as database
|
||||
from tg.handlers.user import format_rating
|
||||
|
||||
|
|
|
|||
12
tg/utils.py
Normal file
12
tg/utils.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
from aiogram.types import User
|
||||
|
||||
|
||||
async def get_user_name_for_ai(user: User):
|
||||
if user.first_name and user.last_name:
|
||||
return "{} {}".format(user.first_name, user.last_name)
|
||||
elif user.first_name:
|
||||
return user.first_name
|
||||
elif user.username:
|
||||
return user.username
|
||||
else:
|
||||
return str(user.id)
|
||||
|
|
@ -1,10 +1,12 @@
|
|||
import json
|
||||
|
||||
from vkbottle.bot import Bot as VkBot
|
||||
|
||||
from ai_agent import create_ai_agent
|
||||
|
||||
from . import handlers
|
||||
from . import tasks
|
||||
|
||||
from vkbottle.bot import Bot as VkBot
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with open('vk.json', 'r') as file:
|
||||
|
|
@ -12,7 +14,8 @@ if __name__ == '__main__':
|
|||
print('Конфигурация загружена.')
|
||||
|
||||
bot = VkBot(config['api_token'], labeler=handlers.labeler)
|
||||
bot.api.config = config
|
||||
create_ai_agent(config["openrouter_token"])
|
||||
|
||||
bot.loop_wrapper.on_startup.append(tasks.startup_task(bot.api))
|
||||
bot.loop_wrapper.add_task(tasks.daily_maintenance_task(bot.api))
|
||||
bot.run_forever()
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from vkbottle.framework.labeler import BotLabeler
|
||||
|
||||
from . import user, admin, action, default
|
||||
from . import private, user, admin, action, default
|
||||
|
||||
labeler = BotLabeler()
|
||||
labeler.load(private.labeler)
|
||||
labeler.load(user.labeler)
|
||||
labeler.load(admin.labeler)
|
||||
labeler.load(action.labeler)
|
||||
|
|
|
|||
|
|
@ -3,10 +3,11 @@ from vkbottle.bot import Message
|
|||
from vkbottle.framework.labeler import BotLabeler
|
||||
from vkbottle_types.codegen.objects import MessagesGetConversationMembers
|
||||
|
||||
from messages import *
|
||||
import ai_agent
|
||||
import utils
|
||||
from messages import *
|
||||
|
||||
import vk.vk_database as database
|
||||
from .default import clear_ai_chat_context
|
||||
|
||||
labeler = BotLabeler()
|
||||
|
||||
|
|
@ -145,7 +146,7 @@ async def set_ai_prompt_handler(message: Message):
|
|||
return
|
||||
|
||||
database.DB.chat_update(chat_id, ai_prompt=message.reply_message.text)
|
||||
clear_ai_chat_context(chat_id)
|
||||
ai_agent.agent.clear_chat_context(chat_id)
|
||||
await message.answer('Личность ИИ изменена.')
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,30 +1,22 @@
|
|||
import re
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple, List
|
||||
from typing import Optional
|
||||
|
||||
from vkbottle import API
|
||||
from vkbottle.bot import Message
|
||||
from vkbottle.framework.labeler import BotLabeler
|
||||
from vkbottle_types.codegen.objects import GroupsGroup
|
||||
|
||||
import ai_agent
|
||||
import utils
|
||||
from ai_agent import AiAgent, AiMessage
|
||||
|
||||
import vk.vk_database as database
|
||||
from vk.utils import get_user_name_for_ai
|
||||
|
||||
labeler = BotLabeler()
|
||||
|
||||
agent: Optional[AiAgent] = None
|
||||
bot_user: Optional[GroupsGroup] = None
|
||||
|
||||
|
||||
async def get_user_name_for_ai(api: API, user_id: int):
|
||||
user = await api.users.get(user_ids=[user_id])
|
||||
if len(user) == 1:
|
||||
return "{} {}".format(user[0].first_name, user[0].last_name)
|
||||
else:
|
||||
return '@id' + str(user_id)
|
||||
|
||||
|
||||
# Обычные сообщения (не команды и не действия)
|
||||
@labeler.chat_message()
|
||||
async def any_message_handler(message: Message):
|
||||
|
|
@ -33,6 +25,9 @@ async def any_message_handler(message: Message):
|
|||
if chat['active'] == 0:
|
||||
return
|
||||
|
||||
if len(message.text) == 0:
|
||||
return
|
||||
|
||||
# Игнорировать ботов
|
||||
if message.from_id < 0:
|
||||
return
|
||||
|
|
@ -50,8 +45,8 @@ async def any_message_handler(message: Message):
|
|||
if bot_user is None:
|
||||
bot_user = (await message.ctx_api.groups.get_by_id()).groups[0]
|
||||
|
||||
ai_message = AiMessage()
|
||||
ai_fwd_messages: list[AiMessage] = []
|
||||
ai_message = ai_agent.Message()
|
||||
ai_fwd_messages: list[ai_agent.Message] = []
|
||||
|
||||
# Ответ на сообщение бота
|
||||
if message.reply_message and message.reply_message.from_id == -bot_user.id:
|
||||
|
|
@ -71,50 +66,20 @@ async def any_message_handler(message: Message):
|
|||
|
||||
if message.reply_message and len(message.reply_message.text) > 0:
|
||||
ai_fwd_messages.append(
|
||||
AiMessage(user_name=await get_user_name_for_ai(message.ctx_api, message.reply_message.from_id),
|
||||
text=message.reply_message.text))
|
||||
ai_agent.Message(user_name=await get_user_name_for_ai(message.ctx_api, message.reply_message.from_id),
|
||||
text=message.reply_message.text))
|
||||
else:
|
||||
for fwd_message in message.fwd_messages:
|
||||
if len(fwd_message.text) > 0:
|
||||
ai_fwd_messages.append(
|
||||
AiMessage(user_name=await get_user_name_for_ai(message.ctx_api, fwd_message.from_id),
|
||||
text=fwd_message.text))
|
||||
ai_agent.Message(user_name=await get_user_name_for_ai(message.ctx_api, fwd_message.from_id),
|
||||
text=fwd_message.text))
|
||||
|
||||
ai_message.user_name = await get_user_name_for_ai(message.ctx_api, message.from_id)
|
||||
chat_prompt = chat['ai_prompt']
|
||||
|
||||
global agent
|
||||
if agent is None:
|
||||
# noinspection PyUnresolvedReferences
|
||||
agent = AiAgent(message.ctx_api.config['openrouter_token'])
|
||||
|
||||
await message.reply(
|
||||
await utils.run_with_progress(partial(agent.get_reply, chat_id, chat_prompt, ai_message, ai_fwd_messages),
|
||||
partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'),
|
||||
interval=4))
|
||||
|
||||
|
||||
def clear_ai_chat_context(chat_id: int):
|
||||
global agent
|
||||
if agent is not None:
|
||||
agent.clear_chat_context(chat_id)
|
||||
|
||||
|
||||
async def get_ai_reply(api: API, chat_id, message: Tuple[int, str], fwd_messages: List[Tuple[int, str]]) -> str:
|
||||
chat = database.DB.create_chat_if_not_exists(chat_id)
|
||||
chat_prompt = chat['ai_prompt']
|
||||
|
||||
ai_message = AiMessage(user_name=await get_user_name_for_ai(api, message[0]), text=message[1])
|
||||
ai_fwd_messages: List[AiMessage] = []
|
||||
for fwd_message in fwd_messages:
|
||||
ai_fwd_messages.append(
|
||||
AiMessage(user_name=await get_user_name_for_ai(api, fwd_message[0]), text=fwd_message[1]))
|
||||
|
||||
global agent
|
||||
if agent is None:
|
||||
# noinspection PyUnresolvedReferences
|
||||
agent = AiAgent(api.config['openrouter_token'])
|
||||
|
||||
return await utils.run_with_progress(partial(agent.get_reply, chat_id, chat_prompt, ai_message, ai_fwd_messages),
|
||||
partial(api.messages.set_activity, peer_id=chat_id, type='typing'),
|
||||
interval=4)
|
||||
await utils.run_with_progress(
|
||||
partial(ai_agent.agent.get_group_chat_reply, chat_id, chat_prompt, ai_message, ai_fwd_messages),
|
||||
partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'),
|
||||
interval=4))
|
||||
|
|
|
|||
55
vk/handlers/private.py
Normal file
55
vk/handlers/private.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
from functools import partial
|
||||
|
||||
from vkbottle.bot import Message
|
||||
from vkbottle.dispatch.rules.base import RegexRule
|
||||
from vkbottle.framework.labeler import BotLabeler
|
||||
|
||||
import ai_agent
|
||||
import utils
|
||||
|
||||
import vk.vk_database as database
|
||||
|
||||
labeler = BotLabeler()
|
||||
|
||||
|
||||
@labeler.private_message(text="!старт")
|
||||
async def start_handler(message: Message):
|
||||
chat_id = message.peer_id
|
||||
database.DB.create_chat_if_not_exists(chat_id)
|
||||
database.DB.chat_update(chat_id, active=1)
|
||||
await message.answer("Привет!")
|
||||
|
||||
|
||||
@labeler.private_message(RegexRule(r"^!личность ((?:.|\n)+)"))
|
||||
async def set_prompt_handler(message: Message, match):
|
||||
chat_id = message.peer_id
|
||||
database.DB.create_chat_if_not_exists(chat_id)
|
||||
|
||||
database.DB.chat_update(chat_id, ai_prompt=match[0])
|
||||
await message.answer("Личность ИИ изменена.")
|
||||
|
||||
|
||||
@labeler.private_message(text="!сброс")
|
||||
async def reset_context_handler(message: Message):
|
||||
chat_id = message.peer_id
|
||||
database.DB.create_chat_if_not_exists(chat_id)
|
||||
|
||||
ai_agent.agent.clear_chat_context(chat_id)
|
||||
await message.answer("Контекст очищен.")
|
||||
|
||||
|
||||
@labeler.private_message()
|
||||
async def any_message_handler(message: Message):
|
||||
chat_id = message.peer_id
|
||||
chat = database.DB.create_chat_if_not_exists(chat_id)
|
||||
|
||||
if len(message.text) == 0:
|
||||
return
|
||||
|
||||
chat_prompt = chat['ai_prompt']
|
||||
|
||||
await message.reply(
|
||||
await utils.run_with_progress(
|
||||
partial(ai_agent.agent.get_private_chat_reply, chat_id, chat_prompt, message.text),
|
||||
partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'),
|
||||
interval=4))
|
||||
|
|
@ -1,12 +1,16 @@
|
|||
from functools import partial
|
||||
from typing import List, Any
|
||||
|
||||
from vkbottle import bold, italic, API
|
||||
from vkbottle.bot import Message
|
||||
from vkbottle.framework.labeler import BotLabeler
|
||||
|
||||
import ai_agent
|
||||
import utils
|
||||
from messages import *
|
||||
|
||||
import vk.vk_database as database
|
||||
from .default import get_ai_reply
|
||||
from vk.utils import get_user_name_for_ai
|
||||
|
||||
labeler = BotLabeler()
|
||||
|
||||
|
|
@ -178,16 +182,27 @@ async def check_rules_violation_handler(message: Message):
|
|||
prompt += chat_rules + '\n\n'
|
||||
prompt += 'Проверь, не нарушают ли правила следующие сообщения (если нарушают, то укажи пункты правил):'
|
||||
|
||||
fwd_messages: list[tuple[int, str]] = []
|
||||
chat_prompt = chat['ai_prompt']
|
||||
|
||||
ai_message = ai_agent.Message(user_name=await get_user_name_for_ai(message.ctx_api, message.from_id), text=prompt)
|
||||
ai_fwd_messages: list[ai_agent.Message] = []
|
||||
if message.reply_message is not None and len(message.reply_message.text) > 0:
|
||||
fwd_messages.append((message.reply_message.from_id, message.reply_message.text))
|
||||
ai_fwd_messages.append(
|
||||
ai_agent.Message(user_name=await get_user_name_for_ai(message.ctx_api, message.reply_message.from_id),
|
||||
text=message.reply_message.text))
|
||||
else:
|
||||
for fwd_message in message.fwd_messages:
|
||||
if len(fwd_message.text) > 0:
|
||||
fwd_messages.append((fwd_message.from_id, fwd_message.text))
|
||||
ai_fwd_messages.append(
|
||||
ai_agent.Message(user_name=await get_user_name_for_ai(message.ctx_api, fwd_message.from_id),
|
||||
text=fwd_message.text))
|
||||
|
||||
if len(fwd_messages) == 0:
|
||||
if len(ai_fwd_messages) == 0:
|
||||
await message.answer(MESSAGE_NEED_REPLY_OR_FORWARD)
|
||||
return
|
||||
|
||||
await message.answer(await get_ai_reply(message.ctx_api, chat_id, (message.from_id, prompt), fwd_messages))
|
||||
await message.answer(
|
||||
await utils.run_with_progress(
|
||||
partial(ai_agent.agent.get_group_chat_reply, chat_id, chat_prompt, ai_message, ai_fwd_messages),
|
||||
partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'),
|
||||
interval=4))
|
||||
|
|
|
|||
9
vk/utils.py
Normal file
9
vk/utils.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
from vkbottle import API
|
||||
|
||||
|
||||
async def get_user_name_for_ai(api: API, user_id: int):
|
||||
user = await api.users.get(user_ids=[user_id])
|
||||
if len(user) == 1:
|
||||
return "{} {}".format(user[0].first_name, user[0].last_name)
|
||||
else:
|
||||
return '@id' + str(user_id)
|
||||
Loading…
Add table
Reference in a new issue