diff --git a/ai_agent.py b/ai_agent.py index ab254d7..77a5d59 100644 --- a/ai_agent.py +++ b/ai_agent.py @@ -1,5 +1,6 @@ +import base64 from dataclasses import dataclass -from typing import List, Tuple +from typing import List, Tuple, Any, Optional from openrouter import OpenRouter, RetryConfig from openrouter.utils import BackoffStrategy @@ -26,13 +27,26 @@ OPENROUTER_HEADERS = { 'X-Title': 'TG/VK Chat Bot' } + @dataclass() class Message: user_name: str = None text: str = None + image: bytes = None message_id: int = None +def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict: + json = {"role": role, "content": []} + if text is not None: + json["content"].append({"type": "text", "text": text}) + if image is not None: + encoded_image = base64.b64encode(image).decode('utf-8') + image_url = f"data:image/jpeg;base64,{encoded_image}" + json["content"].append({"type": "image_url", "image_url": {"url": image_url}}) + return json + + class AiAgent: def __init__(self, api_token: str, model: str, db: BasicDatabase, platform: str): retry_config = RetryConfig(strategy="backoff", @@ -46,14 +60,20 @@ class AiAgent: async def get_group_chat_reply(self, bot_id: int, chat_id: int, message: Message, forwarded_messages: List[Message]) -> Tuple[str, bool]: - message_text = f"[{message.user_name}]: {message.text}" - for fwd_message in forwarded_messages: - message_text += '\n<Цитируемое сообщение от {}>\n'.format(fwd_message.user_name) - message_text += fwd_message.text + '\n' - message_text += '<Конец цитаты>' - context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id) - context.append({"role": "user", "content": message_text}) + + if message.text is not None: + message.text = f"[{message.user_name}]: {message.text}" + else: + message.text = f"[{message.user_name}]:" + context.append(_serialize_message(role="user", text=message.text, image=message.image)) + + for fwd_message in forwarded_messages: + message_text = '<Цитируемое сообщение от {}>'.format(fwd_message.user_name) + if fwd_message.text is not None: + message_text += '\n' + fwd_message.text + fwd_message.text = message_text + context.append(_serialize_message(role="user", text=fwd_message.text, image=fwd_message.image)) try: # Get response from OpenRouter @@ -68,10 +88,13 @@ class AiAgent: # Extract AI response ai_response = response.choices[0].message.content - # Add message and AI response to context - self.db.context_add_message(bot_id, chat_id, role="user", content=message_text, + # Add input messages and AI response to context + self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image, message_id=message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES) - self.db.context_add_message(bot_id, chat_id, role="assistant", content=ai_response, + for fwd_message in forwarded_messages: + self.db.context_add_message(bot_id, chat_id, role="user", text=fwd_message.text, image=fwd_message.image, + message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES) + self.db.context_add_message(bot_id, chat_id, role="assistant", text=ai_response, image=None, message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES) return ai_response, True @@ -83,9 +106,16 @@ class AiAgent: print(f"Ошибка выполнения запроса к ИИ: {e}") return f"Извините, при обработке запроса произошла ошибка.", False - async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: str, message_id: int) -> Tuple[str, bool]: + async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: Message) -> Tuple[str, bool]: context = self._get_chat_context(is_group_chat=False, bot_id=bot_id, chat_id=chat_id) - context.append({"role": "user", "content": message}) + content: list[dict[str, Any]] = [] + if message.text is not None: + content.append({"type": "text", "text": message.text}) + if message.image is not None: + encoded_image = base64.b64encode(message.image).decode('utf-8') + image_url = f"data:image/jpeg;base64,{encoded_image}" + content.append({"type": "image_url", "image_url": {"url": image_url}}) + context.append({"role": "user", "content": content}) try: # Get response from OpenRouter @@ -101,9 +131,9 @@ class AiAgent: ai_response = response.choices[0].message.content # Add message and AI response to context - self.db.context_add_message(bot_id, chat_id, role="user", content=message, - message_id=message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES) - self.db.context_add_message(bot_id, chat_id, role="assistant", content=ai_response, + self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image, + message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES) + self.db.context_add_message(bot_id, chat_id, role="assistant", text=ai_response, image=None, message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES) return ai_response, True @@ -136,7 +166,11 @@ class AiAgent: prompt += '\n\n' + chat['ai_prompt'] messages = self.db.context_get_messages(bot_id, chat_id) - return [{"role": "system", "content": prompt}] + messages + + context = [{"role": "system", "content": prompt}] + for message in messages: + context.append(_serialize_message(message["role"], message["text"], message["image"])) + return context agent: AiAgent diff --git a/database.py b/database.py index c771588..3b2e357 100644 --- a/database.py +++ b/database.py @@ -120,7 +120,7 @@ class BasicDatabase: def context_get_messages(self, bot_id: int, chat_id: int) -> list[dict]: self.cursor.execute(""" - SELECT role, content FROM contexts + SELECT role, text, image FROM contexts WHERE bot_id = ? AND chat_id = ? AND message_id IS NOT NULL ORDER BY message_id """, bot_id, chat_id) @@ -138,16 +138,27 @@ class BasicDatabase: LIMIT 1 """, bot_id, chat_id).fetchval() - def context_add_message(self, bot_id: int, chat_id: int, role: str, content: str, message_id: Optional[int], max_messages: int): + def context_add_message(self, bot_id: int, chat_id: int, role: str, + text: Optional[str], image: Optional[bytes], + message_id: Optional[int], max_messages: int): + assert (text or image) + self._context_trim(bot_id, chat_id, max_messages) - if message_id is not None: - self.cursor.execute( - "INSERT INTO contexts (bot_id, chat_id, message_id, role, content) VALUES (?, ?, ?, ?, ?)", - bot_id, chat_id, message_id, role, content) - else: - self.cursor.execute("INSERT INTO contexts (bot_id, chat_id, role, content) VALUES (?, ?, ?, ?)", - bot_id, chat_id, role, content) + # Подготовка данных для вставки + data = { + "bot_id": bot_id, "chat_id": chat_id, + "message_id": message_id, "role": role, + "text": text, "image": image + } + + # Формирование SQL-запроса и параметров вставки + columns = [k for k, v in data.items() if v is not None] + placeholders = ', '.join(['?' for _ in columns]) + values = tuple(data[k] for k in columns) + + query = f"INSERT INTO contexts ({', '.join(columns)}) VALUES ({placeholders})" + self.cursor.execute(query, values) def context_set_last_message_id(self, bot_id: int, chat_id: int, message_id: int): self.cursor.execute("UPDATE contexts SET message_id = ? WHERE bot_id = ? AND chat_id = ? AND message_id IS NULL", diff --git a/messages.py b/messages.py index d9953c5..b4bf3dc 100644 --- a/messages.py +++ b/messages.py @@ -2,7 +2,7 @@ MESSAGE_CHAT_NOT_ACTIVE = 'Извините, но я пока не работа MESSAGE_PERMISSION_DENIED = 'Извините, но о таком меня может попросить только администратор чата.' MESSAGE_NEED_REPLY = 'Извините, но эту команду нужно вызывать в ответ на текстовое сообщение.' MESSAGE_NEED_REPLY_OR_FORWARD = 'Извините, но эту команду нужно вызывать в ответ на текстовое сообщение или с пересылкой текстовых сообщений.' -MESSAGE_NOT_TEXT = 'Извините, но я понимаю только текст.' +MESSAGE_UNSUPPORTED_CONTENT_TYPE = 'Извините, но я понимаю только текст и изображения.' MESSAGE_DEFAULT_RULES = 'Правила не установлены. Просто ведите себя хорошо.' MESSAGE_DEFAULT_CHECK_RULES = 'Правила чата не установлены. Проверка невозможна.' MESSAGE_DEFAULT_GREETING_JOIN = 'Добро пожаловать, {name}!' diff --git a/tg/handlers/default.py b/tg/handlers/default.py index 1d03000..de44e36 100644 --- a/tg/handlers/default.py +++ b/tg/handlers/default.py @@ -9,7 +9,7 @@ import utils from messages import * import tg.tg_database as database -from tg.utils import get_user_name_for_ai +from tg.utils import * router = Router() @@ -51,46 +51,31 @@ async def any_message_handler(message: Message, bot: Bot): bot_user = await bot.me() - ai_message = ai_agent.Message() ai_fwd_messages: list[ai_agent.Message] = [] - bot_username_mention = '@' + bot_user.username - if message.content_type == ContentType.TEXT and message.text.find(bot_username_mention) != -1: - # Сообщение содержит @bot_username - ai_message.text = message.text.replace(bot_username_mention, bot_user.first_name) - - if message.reply_to_message: - # Сообщение является ответом - if message.reply_to_message.content_type == ContentType.TEXT: - ai_fwd_messages = [ - ai_agent.Message(user_name=await get_user_name_for_ai(message.reply_to_message.from_user), - text=message.reply_to_message.text)] - else: - await message.reply(MESSAGE_NOT_TEXT) - return - elif message.reply_to_message and message.reply_to_message.from_user.id == bot_user.id: - # Ответ на сообщение бота - if message.content_type == ContentType.TEXT: - ai_message.text = message.text + try: + message_text = get_message_text(message) + bot_username_mention = '@' + bot_user.username + if message_text is not None and message_text.find(bot_username_mention) != -1: + # Сообщение содержит @bot_username + message_text = message_text.replace(bot_username_mention, bot_user.first_name) + if message.reply_to_message: + # Сообщение также является ответом -> переслать оригинальное сообщение + ai_fwd_messages = [await create_ai_message(message.reply_to_message, bot)] + elif message.reply_to_message and message.reply_to_message.from_user.id == bot_user.id: + # Ответ на сообщение бота + last_id = ai_agent.agent.get_last_assistant_message_id(bot.id, chat_id) + if message.reply_to_message.message_id != last_id: + # Оригинального сообщения нет в контексте, или оно не последнее -> переслать его + ai_fwd_messages = [await create_ai_message(message.reply_to_message, bot)] else: - await message.reply(MESSAGE_NOT_TEXT) return - - last_id = ai_agent.agent.get_last_assistant_message_id(bot.id, chat_id) - if message.reply_to_message.message_id != last_id: - # Оригинального сообщения нет в контексте, или оно не последнее - if message.content_type == ContentType.TEXT: - ai_fwd_messages = [ - ai_agent.Message(user_name=await get_user_name_for_ai(message.reply_to_message.from_user), - text=message.reply_to_message.text)] - else: - await message.reply(MESSAGE_NOT_TEXT) - return - else: + except UnsupportedContentType: + await message.reply(MESSAGE_UNSUPPORTED_CONTENT_TYPE) return - ai_message.user_name = await get_user_name_for_ai(message.from_user) - ai_message.message_id = message.message_id + ai_message = await create_ai_message(message, bot) + ai_message.text = message_text answer, success = await utils.run_with_progress( partial(ai_agent.agent.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages), diff --git a/tg/handlers/private.py b/tg/handlers/private.py index b3dba58..15488b4 100644 --- a/tg/handlers/private.py +++ b/tg/handlers/private.py @@ -10,6 +10,7 @@ import utils from messages import * import tg.tg_database as database +from tg.utils import * from .default import ACCEPTED_CONTENT_TYPES router = Router() @@ -45,12 +46,14 @@ async def reset_context_handler(message: Message, bot: Bot): async def any_message_handler(message: Message, bot: Bot): chat_id = message.chat.id - if message.content_type != ContentType.TEXT: - await message.answer(MESSAGE_NOT_TEXT) + try: + ai_message = await create_ai_message(message, bot) + except UnsupportedContentType: + await message.answer(MESSAGE_UNSUPPORTED_CONTENT_TYPE) return answer, success = await utils.run_with_progress( - partial(ai_agent.agent.get_private_chat_reply, bot.id, chat_id, message.text, message.message_id), + partial(ai_agent.agent.get_private_chat_reply, bot.id, chat_id, ai_message), partial(message.bot.send_chat_action, chat_id, 'typing'), interval=4) diff --git a/tg/tg_database.py b/tg/tg_database.py index 06d6bb7..dc8c050 100644 --- a/tg/tg_database.py +++ b/tg/tg_database.py @@ -48,7 +48,8 @@ class TgDatabase(database.BasicDatabase): chat_id BIGINT NOT NULL, message_id BIGINT, role VARCHAR(16) NOT NULL, - content VARCHAR(4000) NOT NULL, + text VARCHAR(4000), + image MEDIUMBLOB, UNIQUE KEY contexts_unique (bot_id, chat_id, message_id), CONSTRAINT fk_contexts_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) ON UPDATE CASCADE ON DELETE CASCADE) """) diff --git a/tg/utils.py b/tg/utils.py index 5f25783..eef3b85 100644 --- a/tg/utils.py +++ b/tg/utils.py @@ -1,4 +1,11 @@ -from aiogram.types import User +import io +from typing import Optional + +from aiogram import Bot +from aiogram.enums import ContentType +from aiogram.types import User, PhotoSize, Message + +import ai_agent async def get_user_name_for_ai(user: User): @@ -10,3 +17,40 @@ async def get_user_name_for_ai(user: User): return user.username else: return str(user.id) + + +async def download_photo(photo: PhotoSize, bot: Bot) -> bytes: + # noinspection PyTypeChecker + photo_bytes: io.BytesIO = await bot.download(photo.file_id) + return photo_bytes.getvalue() + + +def get_message_text(message: Message) -> Optional[str]: + if message.content_type == ContentType.TEXT: + return message.text + elif message.content_type == ContentType.PHOTO: + return message.caption + else: + return None + + +class UnsupportedContentType(RuntimeError): + def __init__(self): + pass + + +async def create_ai_message(message: Message, bot: Bot) -> ai_agent.Message: + ai_message = ai_agent.Message() + ai_message.message_id = message.message_id + ai_message.user_name = await get_user_name_for_ai(message.from_user) + if message.content_type == ContentType.TEXT: + ai_message.text = message.text + elif message.content_type == ContentType.PHOTO: + if message.media_group_id is None: + ai_message.text = message.caption + ai_message.image = await download_photo(message.photo[-1], bot) + else: + raise UnsupportedContentType() + else: + raise UnsupportedContentType() + return ai_message diff --git a/vk/handlers/default.py b/vk/handlers/default.py index ba8b0eb..2d8084e 100644 --- a/vk/handlers/default.py +++ b/vk/handlers/default.py @@ -63,7 +63,7 @@ async def any_message_handler(message: Message): if len(message.text) > 0: ai_message.text = message.text else: - await message.reply(MESSAGE_NOT_TEXT) + await message.reply(MESSAGE_UNSUPPORTED_CONTENT_TYPE) return last_id = ai_agent.agent.get_last_assistant_message_id(bot_id, chat_id) @@ -75,7 +75,7 @@ async def any_message_handler(message: Message): message.reply_message.from_id), text=message.reply_message.text)] else: - await message.reply(MESSAGE_NOT_TEXT) + await message.reply(MESSAGE_UNSUPPORTED_CONTENT_TYPE) return else: return @@ -88,7 +88,7 @@ async def any_message_handler(message: Message): user_name=await get_user_name_for_ai(message.ctx_api, message.reply_message.from_id), text=message.reply_message.text)) else: - await message.reply(MESSAGE_NOT_TEXT) + await message.reply(MESSAGE_UNSUPPORTED_CONTENT_TYPE) return else: for fwd_message in message.fwd_messages: @@ -97,7 +97,7 @@ async def any_message_handler(message: Message): ai_agent.Message(user_name=await get_user_name_for_ai(message.ctx_api, fwd_message.from_id), text=fwd_message.text)) else: - await message.reply(MESSAGE_NOT_TEXT) + await message.reply(MESSAGE_UNSUPPORTED_CONTENT_TYPE) return ai_message.user_name = await get_user_name_for_ai(message.ctx_api, message.from_id) diff --git a/vk/handlers/private.py b/vk/handlers/private.py index 5aee140..5382909 100644 --- a/vk/handlers/private.py +++ b/vk/handlers/private.py @@ -49,11 +49,15 @@ async def any_message_handler(message: Message): chat_id = message.peer_id if len(message.text) == 0: - await message.answer(MESSAGE_NOT_TEXT) + await message.answer(MESSAGE_UNSUPPORTED_CONTENT_TYPE) return + ai_message = ai_agent.Message() + ai_message.text = message.text + ai_message.message_id = message.message_id + answer, success = await utils.run_with_progress( - partial(ai_agent.agent.get_private_chat_reply, bot_id, chat_id, message.text, message.message_id), + partial(ai_agent.agent.get_private_chat_reply, bot_id, chat_id, ai_message), partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'), interval=4) diff --git a/vk/vk_database.py b/vk/vk_database.py index 6abc5a1..5d169bc 100644 --- a/vk/vk_database.py +++ b/vk/vk_database.py @@ -51,7 +51,8 @@ class VkDatabase(database.BasicDatabase): chat_id BIGINT NOT NULL, message_id BIGINT, role VARCHAR(16) NOT NULL, - content VARCHAR(4000) NOT NULL, + text VARCHAR(4000), + image MEDIUMBLOB, UNIQUE KEY contexts_unique (bot_id, chat_id, message_id), CONSTRAINT fk_contexts_chats FOREIGN KEY (bot_id, chat_id) REFERENCES chats (bot_id, chat_id) ON UPDATE CASCADE ON DELETE CASCADE) """)