diff --git a/ai_agent.py b/ai_agent.py index eb5113e..756765e 100644 --- a/ai_agent.py +++ b/ai_agent.py @@ -8,11 +8,11 @@ from dataclasses import dataclass from io import BytesIO from PIL import Image from result import Ok, Err, Result -from typing import List, Tuple, Any, Optional, Union, Dict, Awaitable +from typing import List, Tuple, Optional, Union, Dict, Awaitable from openrouter import OpenRouter, RetryConfig from openrouter.components import AssistantMessage, AssistantMessageTypedDict, ChatMessageContentItemTypedDict, \ - ChatMessageToolCall, MessageTypedDict + ChatMessageToolCall, MessageTypedDict, SystemMessageTypedDict from openrouter.errors import ResponseValidationError, ChatError from openrouter.utils import BackoffStrategy @@ -68,9 +68,9 @@ class AiAgent: async def get_group_chat_reply(self, bot_id: int, chat_id: int, message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]: - context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id) - message.text = _add_message_prefix(message.text, message.user_name) + + context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id) context.append(_serialize_message(role="user", text=message.text, image=message.image)) for fwd_message in forwarded_messages: @@ -111,14 +111,10 @@ class AiAgent: return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: Message) -> Tuple[Message, bool]: - context = self._get_chat_context(is_group_chat=False, bot_id=bot_id, chat_id=chat_id) message.text = _add_message_prefix(message.text) - content: list[dict[str, Any]] = [] - if message.text is not None: - content.append({"type": "text", "text": message.text}) - if message.image is not None: - content.append({"type": "image_url", "image_url": {"url": _encode_image(message.image)}}) - context.append({"role": "user", "content": content}) + + context = self._get_chat_context(is_group_chat=False, bot_id=bot_id, chat_id=chat_id) + context.append(_serialize_message(role="user", text=message.text, image=message.image)) try: response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True) @@ -156,9 +152,17 @@ class AiAgent: def clear_chat_context(self, bot_id: int, chat_id: int): self.db.context_clear(bot_id, chat_id) - ####################################### + #################################################################################### def _get_chat_context(self, is_group_chat: bool, bot_id: int, chat_id: int) -> List[MessageTypedDict]: + context: List[MessageTypedDict] = [ + self._construct_system_prompt(is_group_chat=is_group_chat, bot_id=bot_id, chat_id=chat_id) + ] + for message in self.db.context_get_messages(bot_id, chat_id): + context.append(_serialize_message(message["role"], message["text"], message["image"])) + return context + + def _construct_system_prompt(self, is_group_chat: bool, bot_id: int, chat_id: int) -> SystemMessageTypedDict: prompt = self.system_prompt_group_chat if is_group_chat else self.system_prompt_private_chat prompt = prompt.replace('{platform}', 'Telegram' if self.platform == 'tg' else 'VK') prompt += '\n' + self.system_prompt_image_generation @@ -171,12 +175,7 @@ class AiAgent: if chat['ai_prompt'] is not None: prompt += '\n' + chat['ai_prompt'] - messages = self.db.context_get_messages(bot_id, chat_id) - - context: List[MessageTypedDict] = [{"role": "system", "content": prompt}] - for message in messages: - context.append(_serialize_message(message["role"], message["text"], message["image"])) - return context + return {"role": "system", "content": prompt} async def _generate_reply(self, bot_id: int, chat_id: int, context: List[MessageTypedDict], allow_tools: bool = False) -> AssistantMessage: @@ -213,7 +212,6 @@ class AiAgent: "tool_call_id": tool_call.id, "content": tool_result }) - artifacts.tools_called = True return artifacts async def _process_tool_generate_image(self, _bot_id: int, _chat_id: int, args: dict, artifacts: _ToolsArtifacts) \ @@ -222,19 +220,12 @@ class AiAgent: aspect_ratio = args.get("aspect_ratio", None) result = await self._generate_image(prompt=prompt, aspect_ratio=aspect_ratio) - content = [] if result.is_ok(): - content.append( - {"type": "text", - "text": "Изображение сгенерировано и будет показано пользователю."}) - content.append( - {"type": "image_url", "image_url": {"url": _encode_image(result.ok_value)}}) artifacts.generated_image = result.ok_value + return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.", + image=result.ok_value) else: - content.append( - {"type": "text", - "text": f"Не удалось сгенерировать изображение: {result.err_value}"}) - return content + return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {result.err_value}") async def _generate_image(self, prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]: width, height = _get_resolution_for_aspect_ratio(aspect_ratio) @@ -275,24 +266,17 @@ class AiAgent: result = await self._generate_image_anime(prompt=prompt, negative_prompt=negative_prompt, aspect_ratio=aspect_ratio) - content = [] if result.is_ok(): - content.append( - {"type": "text", - "text": "Изображение сгенерировано и будет показано пользователю."}) - content.append( - {"type": "image_url", "image_url": {"url": _encode_image(result.ok_value)}}) artifacts.generated_image = result.ok_value + return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.", + image=result.ok_value) else: - content.append( - {"type": "text", - "text": f"Не удалось сгенерировать изображение: {result.err_value}"}) - return content + return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {result.err_value}") async def _generate_image_anime(self, prompt: str, negative_prompt: str, aspect_ratio: Optional[str]) \ -> Result[bytes, str]: width, height = _get_resolution_for_aspect_ratio(aspect_ratio) - print(f"Генерация изображения {width}x{height}: positive='{prompt}', negative='{negative_prompt}'") + print(f"Генерация аниме-изображения {width}x{height}:\n+ {prompt}\n- {negative_prompt}") arguments = { "prompt": prompt, @@ -379,12 +363,16 @@ def _encode_image(image: bytes) -> str: def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict: - serialized = {"role": role, "content": []} + return {"role": role, "content": _serialize_message_content(text, image)} + + +def _serialize_message_content(text: Optional[str], image: Optional[bytes] = None) -> list[dict]: + content = [] if text is not None: - serialized["content"].append({"type": "text", "text": text}) + content.append({"type": "text", "text": text}) if image is not None: - serialized["content"].append({"type": "image_url", "image_url": {"url": _encode_image(image)}}) - return serialized + content.append({"type": "image_url", "detail": "high", "image_url": {"url": _encode_image(image)}}) + return content def _remove_none_recursive(data: Union[dict, list, any]) -> Union[dict, list, any]: diff --git a/prompts/image_generation.md b/prompts/image_generation.md index 15b3b5b..d2cdedd 100644 --- a/prompts/image_generation.md +++ b/prompts/image_generation.md @@ -13,7 +13,7 @@ Если сгенерировать изображение не удалось из-за ошибки, просто сообщи об этом пользователю. ## Генерация обычных (не аниме) изображений -Для генерации используй функцию `generate_image` и составляй запрос на естесственном языке по следующей формуле: +Для генерации используй функцию `generate_image` и составляй запрос на естественном языке по следующей формуле: 1. Объекты сцены. 2. Действие/поза. 3. Окружение.