Реализована отправка изображений в запросе к ИИ (пока только для Telegram).
This commit is contained in:
parent
4f35663784
commit
e20c2a7d28
10 changed files with 157 additions and 74 deletions
68
ai_agent.py
68
ai_agent.py
|
|
@ -1,5 +1,6 @@
|
|||
import base64
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
from typing import List, Tuple, Any, Optional
|
||||
|
||||
from openrouter import OpenRouter, RetryConfig
|
||||
from openrouter.utils import BackoffStrategy
|
||||
|
|
@ -26,13 +27,26 @@ OPENROUTER_HEADERS = {
|
|||
'X-Title': 'TG/VK Chat Bot'
|
||||
}
|
||||
|
||||
|
||||
@dataclass()
|
||||
class Message:
|
||||
user_name: str = None
|
||||
text: str = None
|
||||
image: bytes = None
|
||||
message_id: int = None
|
||||
|
||||
|
||||
def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict:
|
||||
json = {"role": role, "content": []}
|
||||
if text is not None:
|
||||
json["content"].append({"type": "text", "text": text})
|
||||
if image is not None:
|
||||
encoded_image = base64.b64encode(image).decode('utf-8')
|
||||
image_url = f"data:image/jpeg;base64,{encoded_image}"
|
||||
json["content"].append({"type": "image_url", "image_url": {"url": image_url}})
|
||||
return json
|
||||
|
||||
|
||||
class AiAgent:
|
||||
def __init__(self, api_token: str, model: str, db: BasicDatabase, platform: str):
|
||||
retry_config = RetryConfig(strategy="backoff",
|
||||
|
|
@ -46,14 +60,20 @@ class AiAgent:
|
|||
|
||||
async def get_group_chat_reply(self, bot_id: int, chat_id: int,
|
||||
message: Message, forwarded_messages: List[Message]) -> Tuple[str, bool]:
|
||||
message_text = f"[{message.user_name}]: {message.text}"
|
||||
for fwd_message in forwarded_messages:
|
||||
message_text += '\n<Цитируемое сообщение от {}>\n'.format(fwd_message.user_name)
|
||||
message_text += fwd_message.text + '\n'
|
||||
message_text += '<Конец цитаты>'
|
||||
|
||||
context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id)
|
||||
context.append({"role": "user", "content": message_text})
|
||||
|
||||
if message.text is not None:
|
||||
message.text = f"[{message.user_name}]: {message.text}"
|
||||
else:
|
||||
message.text = f"[{message.user_name}]:"
|
||||
context.append(_serialize_message(role="user", text=message.text, image=message.image))
|
||||
|
||||
for fwd_message in forwarded_messages:
|
||||
message_text = '<Цитируемое сообщение от {}>'.format(fwd_message.user_name)
|
||||
if fwd_message.text is not None:
|
||||
message_text += '\n' + fwd_message.text
|
||||
fwd_message.text = message_text
|
||||
context.append(_serialize_message(role="user", text=fwd_message.text, image=fwd_message.image))
|
||||
|
||||
try:
|
||||
# Get response from OpenRouter
|
||||
|
|
@ -68,10 +88,13 @@ class AiAgent:
|
|||
# Extract AI response
|
||||
ai_response = response.choices[0].message.content
|
||||
|
||||
# Add message and AI response to context
|
||||
self.db.context_add_message(bot_id, chat_id, role="user", content=message_text,
|
||||
# Add input messages and AI response to context
|
||||
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
|
||||
message_id=message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
||||
self.db.context_add_message(bot_id, chat_id, role="assistant", content=ai_response,
|
||||
for fwd_message in forwarded_messages:
|
||||
self.db.context_add_message(bot_id, chat_id, role="user", text=fwd_message.text, image=fwd_message.image,
|
||||
message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
||||
self.db.context_add_message(bot_id, chat_id, role="assistant", text=ai_response, image=None,
|
||||
message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
||||
|
||||
return ai_response, True
|
||||
|
|
@ -83,9 +106,16 @@ class AiAgent:
|
|||
print(f"Ошибка выполнения запроса к ИИ: {e}")
|
||||
return f"Извините, при обработке запроса произошла ошибка.", False
|
||||
|
||||
async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: str, message_id: int) -> Tuple[str, bool]:
|
||||
async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: Message) -> Tuple[str, bool]:
|
||||
context = self._get_chat_context(is_group_chat=False, bot_id=bot_id, chat_id=chat_id)
|
||||
context.append({"role": "user", "content": message})
|
||||
content: list[dict[str, Any]] = []
|
||||
if message.text is not None:
|
||||
content.append({"type": "text", "text": message.text})
|
||||
if message.image is not None:
|
||||
encoded_image = base64.b64encode(message.image).decode('utf-8')
|
||||
image_url = f"data:image/jpeg;base64,{encoded_image}"
|
||||
content.append({"type": "image_url", "image_url": {"url": image_url}})
|
||||
context.append({"role": "user", "content": content})
|
||||
|
||||
try:
|
||||
# Get response from OpenRouter
|
||||
|
|
@ -101,9 +131,9 @@ class AiAgent:
|
|||
ai_response = response.choices[0].message.content
|
||||
|
||||
# Add message and AI response to context
|
||||
self.db.context_add_message(bot_id, chat_id, role="user", content=message,
|
||||
message_id=message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
||||
self.db.context_add_message(bot_id, chat_id, role="assistant", content=ai_response,
|
||||
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
|
||||
message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
||||
self.db.context_add_message(bot_id, chat_id, role="assistant", text=ai_response, image=None,
|
||||
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
||||
|
||||
return ai_response, True
|
||||
|
|
@ -136,7 +166,11 @@ class AiAgent:
|
|||
prompt += '\n\n' + chat['ai_prompt']
|
||||
|
||||
messages = self.db.context_get_messages(bot_id, chat_id)
|
||||
return [{"role": "system", "content": prompt}] + messages
|
||||
|
||||
context = [{"role": "system", "content": prompt}]
|
||||
for message in messages:
|
||||
context.append(_serialize_message(message["role"], message["text"], message["image"]))
|
||||
return context
|
||||
|
||||
|
||||
agent: AiAgent
|
||||
|
|
|
|||
29
database.py
29
database.py
|
|
@ -120,7 +120,7 @@ class BasicDatabase:
|
|||
|
||||
def context_get_messages(self, bot_id: int, chat_id: int) -> list[dict]:
|
||||
self.cursor.execute("""
|
||||
SELECT role, content FROM contexts
|
||||
SELECT role, text, image FROM contexts
|
||||
WHERE bot_id = ? AND chat_id = ? AND message_id IS NOT NULL
|
||||
ORDER BY message_id
|
||||
""", bot_id, chat_id)
|
||||
|
|
@ -138,16 +138,27 @@ class BasicDatabase:
|
|||
LIMIT 1
|
||||
""", bot_id, chat_id).fetchval()
|
||||
|
||||
def context_add_message(self, bot_id: int, chat_id: int, role: str, content: str, message_id: Optional[int], max_messages: int):
|
||||
def context_add_message(self, bot_id: int, chat_id: int, role: str,
|
||||
text: Optional[str], image: Optional[bytes],
|
||||
message_id: Optional[int], max_messages: int):
|
||||
assert (text or image)
|
||||
|
||||
self._context_trim(bot_id, chat_id, max_messages)
|
||||
|
||||
if message_id is not None:
|
||||
self.cursor.execute(
|
||||
"INSERT INTO contexts (bot_id, chat_id, message_id, role, content) VALUES (?, ?, ?, ?, ?)",
|
||||
bot_id, chat_id, message_id, role, content)
|
||||
else:
|
||||
self.cursor.execute("INSERT INTO contexts (bot_id, chat_id, role, content) VALUES (?, ?, ?, ?)",
|
||||
bot_id, chat_id, role, content)
|
||||
# Подготовка данных для вставки
|
||||
data = {
|
||||
"bot_id": bot_id, "chat_id": chat_id,
|
||||
"message_id": message_id, "role": role,
|
||||
"text": text, "image": image
|
||||
}
|
||||
|
||||
# Формирование SQL-запроса и параметров вставки
|
||||
columns = [k for k, v in data.items() if v is not None]
|
||||
placeholders = ', '.join(['?' for _ in columns])
|
||||
values = tuple(data[k] for k in columns)
|
||||
|
||||
query = f"INSERT INTO contexts ({', '.join(columns)}) VALUES ({placeholders})"
|
||||
self.cursor.execute(query, values)
|
||||
|
||||
def context_set_last_message_id(self, bot_id: int, chat_id: int, message_id: int):
|
||||
self.cursor.execute("UPDATE contexts SET message_id = ? WHERE bot_id = ? AND chat_id = ? AND message_id IS NULL",
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ MESSAGE_CHAT_NOT_ACTIVE = 'Извините, но я пока не работа
|
|||
MESSAGE_PERMISSION_DENIED = 'Извините, но о таком меня может попросить только администратор чата.'
|
||||
MESSAGE_NEED_REPLY = 'Извините, но эту команду нужно вызывать в ответ на текстовое сообщение.'
|
||||
MESSAGE_NEED_REPLY_OR_FORWARD = 'Извините, но эту команду нужно вызывать в ответ на текстовое сообщение или с пересылкой текстовых сообщений.'
|
||||
MESSAGE_NOT_TEXT = 'Извините, но я понимаю только текст.'
|
||||
MESSAGE_UNSUPPORTED_CONTENT_TYPE = 'Извините, но я понимаю только текст и изображения.'
|
||||
MESSAGE_DEFAULT_RULES = 'Правила не установлены. Просто ведите себя хорошо.'
|
||||
MESSAGE_DEFAULT_CHECK_RULES = 'Правила чата не установлены. Проверка невозможна.'
|
||||
MESSAGE_DEFAULT_GREETING_JOIN = 'Добро пожаловать, {name}!'
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import utils
|
|||
from messages import *
|
||||
|
||||
import tg.tg_database as database
|
||||
from tg.utils import get_user_name_for_ai
|
||||
from tg.utils import *
|
||||
|
||||
router = Router()
|
||||
|
||||
|
|
@ -51,46 +51,31 @@ async def any_message_handler(message: Message, bot: Bot):
|
|||
|
||||
bot_user = await bot.me()
|
||||
|
||||
ai_message = ai_agent.Message()
|
||||
ai_fwd_messages: list[ai_agent.Message] = []
|
||||
|
||||
bot_username_mention = '@' + bot_user.username
|
||||
if message.content_type == ContentType.TEXT and message.text.find(bot_username_mention) != -1:
|
||||
# Сообщение содержит @bot_username
|
||||
ai_message.text = message.text.replace(bot_username_mention, bot_user.first_name)
|
||||
|
||||
if message.reply_to_message:
|
||||
# Сообщение является ответом
|
||||
if message.reply_to_message.content_type == ContentType.TEXT:
|
||||
ai_fwd_messages = [
|
||||
ai_agent.Message(user_name=await get_user_name_for_ai(message.reply_to_message.from_user),
|
||||
text=message.reply_to_message.text)]
|
||||
else:
|
||||
await message.reply(MESSAGE_NOT_TEXT)
|
||||
return
|
||||
elif message.reply_to_message and message.reply_to_message.from_user.id == bot_user.id:
|
||||
# Ответ на сообщение бота
|
||||
if message.content_type == ContentType.TEXT:
|
||||
ai_message.text = message.text
|
||||
try:
|
||||
message_text = get_message_text(message)
|
||||
bot_username_mention = '@' + bot_user.username
|
||||
if message_text is not None and message_text.find(bot_username_mention) != -1:
|
||||
# Сообщение содержит @bot_username
|
||||
message_text = message_text.replace(bot_username_mention, bot_user.first_name)
|
||||
if message.reply_to_message:
|
||||
# Сообщение также является ответом -> переслать оригинальное сообщение
|
||||
ai_fwd_messages = [await create_ai_message(message.reply_to_message, bot)]
|
||||
elif message.reply_to_message and message.reply_to_message.from_user.id == bot_user.id:
|
||||
# Ответ на сообщение бота
|
||||
last_id = ai_agent.agent.get_last_assistant_message_id(bot.id, chat_id)
|
||||
if message.reply_to_message.message_id != last_id:
|
||||
# Оригинального сообщения нет в контексте, или оно не последнее -> переслать его
|
||||
ai_fwd_messages = [await create_ai_message(message.reply_to_message, bot)]
|
||||
else:
|
||||
await message.reply(MESSAGE_NOT_TEXT)
|
||||
return
|
||||
|
||||
last_id = ai_agent.agent.get_last_assistant_message_id(bot.id, chat_id)
|
||||
if message.reply_to_message.message_id != last_id:
|
||||
# Оригинального сообщения нет в контексте, или оно не последнее
|
||||
if message.content_type == ContentType.TEXT:
|
||||
ai_fwd_messages = [
|
||||
ai_agent.Message(user_name=await get_user_name_for_ai(message.reply_to_message.from_user),
|
||||
text=message.reply_to_message.text)]
|
||||
else:
|
||||
await message.reply(MESSAGE_NOT_TEXT)
|
||||
return
|
||||
else:
|
||||
except UnsupportedContentType:
|
||||
await message.reply(MESSAGE_UNSUPPORTED_CONTENT_TYPE)
|
||||
return
|
||||
|
||||
ai_message.user_name = await get_user_name_for_ai(message.from_user)
|
||||
ai_message.message_id = message.message_id
|
||||
ai_message = await create_ai_message(message, bot)
|
||||
ai_message.text = message_text
|
||||
|
||||
answer, success = await utils.run_with_progress(
|
||||
partial(ai_agent.agent.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages),
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import utils
|
|||
from messages import *
|
||||
|
||||
import tg.tg_database as database
|
||||
from tg.utils import *
|
||||
from .default import ACCEPTED_CONTENT_TYPES
|
||||
|
||||
router = Router()
|
||||
|
|
@ -45,12 +46,14 @@ async def reset_context_handler(message: Message, bot: Bot):
|
|||
async def any_message_handler(message: Message, bot: Bot):
|
||||
chat_id = message.chat.id
|
||||
|
||||
if message.content_type != ContentType.TEXT:
|
||||
await message.answer(MESSAGE_NOT_TEXT)
|
||||
try:
|
||||
ai_message = await create_ai_message(message, bot)
|
||||
except UnsupportedContentType:
|
||||
await message.answer(MESSAGE_UNSUPPORTED_CONTENT_TYPE)
|
||||
return
|
||||
|
||||
answer, success = await utils.run_with_progress(
|
||||
partial(ai_agent.agent.get_private_chat_reply, bot.id, chat_id, message.text, message.message_id),
|
||||
partial(ai_agent.agent.get_private_chat_reply, bot.id, chat_id, ai_message),
|
||||
partial(message.bot.send_chat_action, chat_id, 'typing'),
|
||||
interval=4)
|
||||
|
||||
|
|
|
|||
|
|
@ -48,7 +48,8 @@ class TgDatabase(database.BasicDatabase):
|
|||
chat_id BIGINT NOT NULL,
|
||||
message_id BIGINT,
|
||||
role VARCHAR(16) NOT NULL,
|
||||
content VARCHAR(4000) NOT NULL,
|
||||
text VARCHAR(4000),
|
||||
image MEDIUMBLOB,
|
||||
UNIQUE KEY contexts_unique (bot_id, chat_id, message_id),
|
||||
CONSTRAINT fk_contexts_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) ON UPDATE CASCADE ON DELETE CASCADE)
|
||||
""")
|
||||
|
|
|
|||
46
tg/utils.py
46
tg/utils.py
|
|
@ -1,4 +1,11 @@
|
|||
from aiogram.types import User
|
||||
import io
|
||||
from typing import Optional
|
||||
|
||||
from aiogram import Bot
|
||||
from aiogram.enums import ContentType
|
||||
from aiogram.types import User, PhotoSize, Message
|
||||
|
||||
import ai_agent
|
||||
|
||||
|
||||
async def get_user_name_for_ai(user: User):
|
||||
|
|
@ -10,3 +17,40 @@ async def get_user_name_for_ai(user: User):
|
|||
return user.username
|
||||
else:
|
||||
return str(user.id)
|
||||
|
||||
|
||||
async def download_photo(photo: PhotoSize, bot: Bot) -> bytes:
|
||||
# noinspection PyTypeChecker
|
||||
photo_bytes: io.BytesIO = await bot.download(photo.file_id)
|
||||
return photo_bytes.getvalue()
|
||||
|
||||
|
||||
def get_message_text(message: Message) -> Optional[str]:
|
||||
if message.content_type == ContentType.TEXT:
|
||||
return message.text
|
||||
elif message.content_type == ContentType.PHOTO:
|
||||
return message.caption
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class UnsupportedContentType(RuntimeError):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
async def create_ai_message(message: Message, bot: Bot) -> ai_agent.Message:
|
||||
ai_message = ai_agent.Message()
|
||||
ai_message.message_id = message.message_id
|
||||
ai_message.user_name = await get_user_name_for_ai(message.from_user)
|
||||
if message.content_type == ContentType.TEXT:
|
||||
ai_message.text = message.text
|
||||
elif message.content_type == ContentType.PHOTO:
|
||||
if message.media_group_id is None:
|
||||
ai_message.text = message.caption
|
||||
ai_message.image = await download_photo(message.photo[-1], bot)
|
||||
else:
|
||||
raise UnsupportedContentType()
|
||||
else:
|
||||
raise UnsupportedContentType()
|
||||
return ai_message
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ async def any_message_handler(message: Message):
|
|||
if len(message.text) > 0:
|
||||
ai_message.text = message.text
|
||||
else:
|
||||
await message.reply(MESSAGE_NOT_TEXT)
|
||||
await message.reply(MESSAGE_UNSUPPORTED_CONTENT_TYPE)
|
||||
return
|
||||
|
||||
last_id = ai_agent.agent.get_last_assistant_message_id(bot_id, chat_id)
|
||||
|
|
@ -75,7 +75,7 @@ async def any_message_handler(message: Message):
|
|||
message.reply_message.from_id),
|
||||
text=message.reply_message.text)]
|
||||
else:
|
||||
await message.reply(MESSAGE_NOT_TEXT)
|
||||
await message.reply(MESSAGE_UNSUPPORTED_CONTENT_TYPE)
|
||||
return
|
||||
else:
|
||||
return
|
||||
|
|
@ -88,7 +88,7 @@ async def any_message_handler(message: Message):
|
|||
user_name=await get_user_name_for_ai(message.ctx_api, message.reply_message.from_id),
|
||||
text=message.reply_message.text))
|
||||
else:
|
||||
await message.reply(MESSAGE_NOT_TEXT)
|
||||
await message.reply(MESSAGE_UNSUPPORTED_CONTENT_TYPE)
|
||||
return
|
||||
else:
|
||||
for fwd_message in message.fwd_messages:
|
||||
|
|
@ -97,7 +97,7 @@ async def any_message_handler(message: Message):
|
|||
ai_agent.Message(user_name=await get_user_name_for_ai(message.ctx_api, fwd_message.from_id),
|
||||
text=fwd_message.text))
|
||||
else:
|
||||
await message.reply(MESSAGE_NOT_TEXT)
|
||||
await message.reply(MESSAGE_UNSUPPORTED_CONTENT_TYPE)
|
||||
return
|
||||
|
||||
ai_message.user_name = await get_user_name_for_ai(message.ctx_api, message.from_id)
|
||||
|
|
|
|||
|
|
@ -49,11 +49,15 @@ async def any_message_handler(message: Message):
|
|||
chat_id = message.peer_id
|
||||
|
||||
if len(message.text) == 0:
|
||||
await message.answer(MESSAGE_NOT_TEXT)
|
||||
await message.answer(MESSAGE_UNSUPPORTED_CONTENT_TYPE)
|
||||
return
|
||||
|
||||
ai_message = ai_agent.Message()
|
||||
ai_message.text = message.text
|
||||
ai_message.message_id = message.message_id
|
||||
|
||||
answer, success = await utils.run_with_progress(
|
||||
partial(ai_agent.agent.get_private_chat_reply, bot_id, chat_id, message.text, message.message_id),
|
||||
partial(ai_agent.agent.get_private_chat_reply, bot_id, chat_id, ai_message),
|
||||
partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'),
|
||||
interval=4)
|
||||
|
||||
|
|
|
|||
|
|
@ -51,7 +51,8 @@ class VkDatabase(database.BasicDatabase):
|
|||
chat_id BIGINT NOT NULL,
|
||||
message_id BIGINT,
|
||||
role VARCHAR(16) NOT NULL,
|
||||
content VARCHAR(4000) NOT NULL,
|
||||
text VARCHAR(4000),
|
||||
image MEDIUMBLOB,
|
||||
UNIQUE KEY contexts_unique (bot_id, chat_id, message_id),
|
||||
CONSTRAINT fk_contexts_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) ON UPDATE CASCADE ON DELETE CASCADE)
|
||||
""")
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue