diff --git a/ai_agent.py b/ai_agent.py index 2ebf44d..4bc4a60 100644 --- a/ai_agent.py +++ b/ai_agent.py @@ -5,10 +5,12 @@ import json from dataclasses import dataclass from io import BytesIO from PIL import Image -from typing import List, Tuple, Any, Optional +from typing import List, Tuple, Any, Optional, Union from openrouter import OpenRouter, RetryConfig -from openrouter.components import ToolDefinitionJSONTypedDict, MessageTypedDict +from openrouter.components import AssistantMessage, ChatMessageToolCall, \ + MessageTypedDict, ToolDefinitionJSONTypedDict, AssistantMessageTypedDict + from openrouter.utils import BackoffStrategy from database import BasicDatabase @@ -19,24 +21,38 @@ GROUP_CHAT_SYSTEM_PROMPT = """ Ты не можешь обсуждать политику и религию.\n Сообщения пользователей будут приходить в следующем формате: '[дата время, имя]: текст сообщения'\n При ответе НЕ нужно указывать ни время, ни пользователя, которому предназначен ответ, ни свое имя.\n -НЕ используй разметку Markdown. +НЕ используй разметку Markdown, она не поддерживается мессенджером.\n +Если нужно нарисовать изображение, используй вызов инструмента. +Запрещено генерировать ASCII-арты, а также добавлять теги/маркеры вроде , [image]. """ -GROUP_CHAT_MAX_MESSAGES = 20 PRIVATE_CHAT_SYSTEM_PROMPT = """ Ты - ИИ-помощник в чате c пользователем.\n Отвечай на вопросы и поддерживай контекст беседы.\n Сообщения пользователя будут приходить в следующем формате: '[дата время]: текст сообщения'\n При ответе НЕ нужно указывать время.\n -Никогда не используй разметку Markdown.\n -Никогда не добавляй ASCII-арты в ответ. +НЕ используй разметку Markdown, она не поддерживается мессенджером.\n +Если нужно нарисовать изображение, используй вызов инструмента. +Запрещено генерировать ASCII-арты, а также добавлять теги/маркеры вроде , [image]. """ -PRIVATE_CHAT_MAX_MESSAGES = 40 -OPENROUTER_HEADERS = { - 'HTTP-Referer': 'https://ultracoder.org', - 'X-Title': 'TG/VK Chat Bot' -} +GENERATE_IMAGE_TOOL_DESCRIPTION = """ +Генерация изображения по описанию. +Используй этот инструмент, если пользователь просит сгенерировать изображение ('нарисуй', 'покажи' и т.п.), +или если это улучшит ответ (например, в ролевой игре для визуализации сцены). +""" + +GENERATE_IMAGE_TOOL_PROMPT_ARG_DESCRIPTION = """ +Детальное описание на русском языке. +Добавь детали для стиля, цвета, композиции, если нужно. +""" + +OPENROUTER_X_TITLE = "TG/VK Chat Bot" +OPENROUTER_HTTP_REFERER = "https://ultracoder.org" + +GROUP_CHAT_MAX_MESSAGES = 20 +PRIVATE_CHAT_MAX_MESSAGES = 40 +MAX_OUTPUT_TOKENS = 500 @dataclass() @@ -64,42 +80,61 @@ def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) - serialized["content"].append({"type": "text", "text": text}) if image is not None: serialized["content"].append({"type": "image_url", "image_url": {"url": _encode_image(image)}}) - return serialized +def _remove_none_recursive(data: Union[dict, list, any]) -> Union[dict, list, any]: + if isinstance(data, dict): + return { + k: _remove_none_recursive(v) + for k, v in data.items() + if v is not None + } + elif isinstance(data, list): + return [ + _remove_none_recursive(item) + for item in data + if item is not None + ] + else: + return data + + +def _serialize_assistant_message(message: AssistantMessage) -> AssistantMessageTypedDict: + return _remove_none_recursive(message.model_dump(by_alias=True)) + + def _get_tools_description() -> List[ToolDefinitionJSONTypedDict]: return [{ "type": "function", "function": { "name": "generate_image", - "description": """ -Генерация изображения по описанию. -Используй этот инструмент, если пользователь просит сгенерировать изображение ('нарисуй', 'покажи' и т.п.), -ИЛИ если это улучшит ответ (например, в ролевой игре для визуализации сцены). -""", + "description": GENERATE_IMAGE_TOOL_DESCRIPTION, "parameters": { "type": "object", "properties": { "prompt": { "type": "string", - "description": """ -Детальное описание на английском (рекомендуется). -Добавь детали для стиля, цвета, композиции, если нужно. -""" + "description": GENERATE_IMAGE_TOOL_PROMPT_ARG_DESCRIPTION }, "aspect_ratio": { "type": "string", "enum": ["1:1", "3:4", "4:3", "9:16", "16:9"], "description": "Соотношение сторон (опционально)." } - } + }, + "required": ["prompt"] } } }] class AiAgent: + @dataclass() + class ToolCallResult: + tools_called: bool = False + generated_image: Optional[bytes] = None + def __init__(self, api_token_main: str, model_main: str, api_token_image: str, model_image: str, @@ -113,11 +148,14 @@ class AiAgent: self.model_main = model_main self.model_image = model_image self.platform = platform - self.client = OpenRouter(api_key=api_token_main, retry_config=retry_config) - self.client_image = OpenRouter(api_key=api_token_image) + self.client = OpenRouter(api_key=api_token_main, + x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER, + retry_config=retry_config) + self.client_image = OpenRouter(api_key=api_token_image, + x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER) async def get_group_chat_reply(self, bot_id: int, chat_id: int, - message: Message, forwarded_messages: List[Message]) -> Tuple[str, bool]: + message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]: context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id) message.text = _add_message_prefix(message.text, message.user_name) @@ -131,31 +169,33 @@ class AiAgent: context.append(_serialize_message(role="user", text=fwd_message.text, image=fwd_message.image)) try: - response = await self.client.chat.send_async( - model=self.model_main, - messages=context, - max_tokens=500, - user=f'{self.platform}_{bot_id}_{chat_id}', - http_headers=OPENROUTER_HEADERS - ) - ai_response = response.choices[0].message.content + response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True) + ai_response = response.content + + tools_call_result = await self._process_tool_calls(bot_id, chat_id, + tool_calls=response.tool_calls, context=context) + if tools_call_result.tools_called: + response2 = await self._generate_reply(bot_id, chat_id, context=context) + ai_response = response2.content self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image, message_id=message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES) for fwd_message in forwarded_messages: - self.db.context_add_message(bot_id, chat_id, role="user", text=fwd_message.text, image=fwd_message.image, + self.db.context_add_message(bot_id, chat_id, + role="user", text=fwd_message.text, image=fwd_message.image, message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES) - self.db.context_add_message(bot_id, chat_id, role="assistant", text=ai_response, image=None, + self.db.context_add_message(bot_id, chat_id, + role="assistant", text=ai_response, image=tools_call_result.generated_image, message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES) - return ai_response, True + return Message(text=ai_response, image=tools_call_result.generated_image), True except Exception as e: if str(e).find("Rate limit exceeded") != -1: - return "Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК).", False + return Message(text="Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."), False else: print(f"Ошибка выполнения запроса к ИИ: {e}") - return f"Извините, при обработке запроса произошла ошибка.", False + return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: Message) -> Tuple[Message, bool]: context = self._get_chat_context(is_group_chat=False, bot_id=bot_id, chat_id=chat_id) @@ -168,72 +208,30 @@ class AiAgent: context.append({"role": "user", "content": content}) try: - user_tag = f'{self.platform}_{bot_id}_{chat_id}' + response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True) + context.append(_serialize_assistant_message(response)) + ai_response = response.content - response = await self.client.chat.send_async( - model=self.model_main, - messages=context, - tools=_get_tools_description(), - tool_choice="auto", - max_tokens=500, - user=user_tag, - http_headers=OPENROUTER_HEADERS - ) - ai_response = response.choices[0].message.content - context.append(response.choices[0].message) - - image_response_image: Optional[bytes] = None - if response.choices[0].message.tool_calls is not None: - for tool_call in response.choices[0].message.tool_calls: - tool_name = tool_call.function.name - tool_args = json.loads(tool_call.function.arguments) - if tool_name == "generate_image": - prompt = tool_args.get("prompt", "") - aspect_ratio = tool_args.get("aspect_ratio", None) - image_response_text, image_response_image =\ - await self._generate_image(prompt=prompt, aspect_ratio=aspect_ratio, user_tag=user_tag) - context.append({ - "role": "tool", - "tool_call_id": tool_call.id, - "content": [ - {"type": "text", - "text": """ -Изображение сгенерировано и будет показано пользователю. -НЕ добавляй никаких тегов или маркеров вроде , [image] — они запрещены и не нужны. -"""}, - {"type": "image_url", "image_url": {"url": _encode_image(image_response_image)}} - ] - }) - - response2 = await self.client.chat.send_async( - model=self.model_main, - messages=context, - max_tokens=500, - user=user_tag, - http_headers=OPENROUTER_HEADERS - ) - ai_response = response2.choices[0].message.content - break + tools_call_result = await self._process_tool_calls(bot_id, chat_id, + tool_calls=response.tool_calls, context=context) + if tools_call_result.tools_called: + response2 = await self._generate_reply(bot_id, chat_id, context=context) + ai_response = response2.content self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image, message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES) + self.db.context_add_message(bot_id, chat_id, role="assistant", + text=ai_response, image=tools_call_result.generated_image, + message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES) - if image_response_image is not None: - self.db.context_add_message(bot_id, chat_id, role="assistant", - text=ai_response, image=image_response_image, - message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES) - else: - self.db.context_add_message(bot_id, chat_id, role="assistant", text=ai_response, image=None, - message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES) - - return Message(text=ai_response, image=image_response_image), True + return Message(text=ai_response, image=tools_call_result.generated_image), True except Exception as e: if str(e).find("Rate limit exceeded") != -1: return Message(text="Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."), False else: print(f"Ошибка выполнения запроса к ИИ: {e}") - return Message(text=f"Извините, при обработке запроса произошла ошибка."), False + return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False def get_last_assistant_message_id(self, bot_id: int, chat_id: int): return self.db.context_get_last_assistant_message_id(bot_id, chat_id) @@ -262,38 +260,78 @@ class AiAgent: context.append(_serialize_message(message["role"], message["text"], message["image"])) return context - async def _generate_image(self, prompt: str, aspect_ratio: Optional[str], user_tag: str) -> Tuple[str, bytes]: + async def _generate_reply(self, bot_id: int, chat_id: int, + context: List[MessageTypedDict], allow_tools: bool = False) -> AssistantMessage: + response = await self.client.chat.send_async( + model=self.model_main, + messages=context, + tools=_get_tools_description() if allow_tools else None, + tool_choice="auto" if allow_tools else None, + max_tokens=MAX_OUTPUT_TOKENS, + user=f'{self.platform}_{bot_id}_{chat_id}' + ) + return response.choices[0].message + + async def _process_tool_calls(self, bot_id: int, chat_id: int, tool_calls: List[ChatMessageToolCall], + context: List[MessageTypedDict]) -> ToolCallResult: + result = AiAgent.ToolCallResult() + if tool_calls is not None: + for tool_call in tool_calls: + tool_name = tool_call.function.name + tool_args = json.loads(tool_call.function.arguments) + if tool_name == "generate_image": + prompt = tool_args.get("prompt", "") + aspect_ratio = tool_args.get("aspect_ratio", None) + result.generated_image, success = \ + await self._generate_image(bot_id, chat_id, prompt=prompt, aspect_ratio=aspect_ratio) + tool_result_content = [] + if success: + tool_result_content.append( + {"type": "text", + "text": "Изображение сгенерировано и будет показано пользователю."}) + tool_result_content.append( + {"type": "image_url", "image_url": {"url": _encode_image(result.generated_image)}}) + else: + tool_result_content.append( + {"type": "text", + "text": "Не удалось сгенерировать изображение."}) + context.append({ + "role": "tool", + "tool_call_id": tool_call.id, + "content": tool_result_content + }) + result.tools_called = True + break + return result + + async def _generate_image(self, bot_id: int, chat_id: int, prompt: str, aspect_ratio: Optional[str]) \ + -> Tuple[Optional[bytes], bool]: + print(f"Генерация изображения: {prompt}") context = [{"role": "user", "content": prompt}] - if aspect_ratio is not None: + + try: response = await self.client_image.chat.send_async( model=self.model_image, messages=context, - user=user_tag, + user=f'{self.platform}_{bot_id}_{chat_id}', modalities=["image"], - image_config={"aspect_ratio": aspect_ratio} - ) - else: - response = await self.client_image.chat.send_async( - model=self.model_image, - messages=context, - user=user_tag, - modalities=["image"] + image_config={"aspect_ratio": aspect_ratio} if aspect_ratio is not None else None ) - content = response.choices[0].message.content + image_url = response.choices[0].message.images[0].image_url.url + header, image_base64 = image_url.split(",", 1) + mime_type = header.split(";")[0].replace("data:", "") + image_bytes = base64.b64decode(image_base64) - image_url = response.choices[0].message.images[0].image_url.url - header, image_base64 = image_url.split(",", 1) - mime_type = header.split(";")[0].replace("data:", "") - image_bytes = base64.b64decode(image_base64) + if mime_type != "image/jpeg": + image = Image.open(BytesIO(image_bytes)).convert("RGB") + output = BytesIO() + image.save(output, format="JPEG", quality=80, optimize=True) + image_bytes = output.getvalue() - if mime_type != "image/jpeg": - image = Image.open(BytesIO(image_bytes)).convert("RGB") - output = BytesIO() - image.save(output, format="JPEG", quality=80, optimize=True) - image_bytes = output.getvalue() - - return content, image_bytes + return image_bytes, True + except Exception: + return None, False agent: AiAgent diff --git a/tg/handlers/default.py b/tg/handlers/default.py index 535bb8b..62fdade 100644 --- a/tg/handlers/default.py +++ b/tg/handlers/default.py @@ -77,11 +77,16 @@ async def any_message_handler(message: Message, bot: Bot): ai_message = await create_ai_message(message, bot) ai_message.text = message_text + answer: ai_agent.Message + success: bool answer, success = await utils.run_with_progress( partial(ai_agent.agent.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages), partial(message.bot.send_chat_action, chat_id, 'typing'), interval=4) - answer_id = (await message.reply(answer)).message_id + if answer.image is not None: + answer_id = (await message.reply_photo(photo=wrap_photo(answer.image), caption=answer.text)).message_id + else: + answer_id = (await message.reply(answer.text)).message_id if success: ai_agent.agent.set_last_response_id(bot.id, chat_id, answer_id) diff --git a/tg/handlers/user.py b/tg/handlers/user.py index 69bdb57..677e263 100644 --- a/tg/handlers/user.py +++ b/tg/handlers/user.py @@ -213,11 +213,13 @@ async def check_rules_violation_handler(message: Message, bot: Bot): ai_fwd_messages = [ai_agent.Message(user_name=await get_user_name_for_ai(message.reply_to_message.from_user), text=message.reply_to_message.text)] + answer: ai_agent.Message + success: bool answer, success = await utils.run_with_progress( partial(ai_agent.agent.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages), partial(bot.send_chat_action, chat_id, 'typing'), interval=4) - answer_id = (await message.answer(answer)).message_id + answer_id = (await message.answer(answer.text)).message_id if success: ai_agent.agent.set_last_response_id(bot.id, chat_id, answer_id) diff --git a/vk/handlers/default.py b/vk/handlers/default.py index d7451f3..065dd5f 100644 --- a/vk/handlers/default.py +++ b/vk/handlers/default.py @@ -86,11 +86,13 @@ async def any_message_handler(message: Message): ai_message = await create_ai_message(message) ai_message.text = message_text + answer: ai_agent.Message + success: bool answer, success = await utils.run_with_progress( partial(ai_agent.agent.get_group_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages), partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'), interval=4) - answer_id = (await message.reply(answer)).conversation_message_id + answer_id = (await message.reply(answer.text)).conversation_message_id if success: ai_agent.agent.set_last_response_id(bot_id, chat_id, answer_id) diff --git a/vk/handlers/user.py b/vk/handlers/user.py index df8d7e4..3aacc19 100644 --- a/vk/handlers/user.py +++ b/vk/handlers/user.py @@ -264,11 +264,13 @@ async def check_rules_violation_handler(message: Message): await message.answer(MESSAGE_NEED_REPLY_OR_FORWARD) return + answer: ai_agent.Message + success: bool answer, success = await utils.run_with_progress( partial(ai_agent.agent.get_group_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages), partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'), interval=4) - answer_id = (await message.answer(answer)).message_id + answer_id = (await message.answer(answer.text)).message_id if success: ai_agent.agent.set_last_response_id(bot_id, chat_id, answer_id)