Рефакторинг AiAgent.

Исправлена грамматическая ошибка в системном запросе.
This commit is contained in:
Kirill Kirilenko 2026-03-05 20:46:53 +03:00
parent 07c804c338
commit 1d6fd6f612
2 changed files with 33 additions and 45 deletions

View file

@ -8,11 +8,11 @@ from dataclasses import dataclass
from io import BytesIO
from PIL import Image
from result import Ok, Err, Result
from typing import List, Tuple, Any, Optional, Union, Dict, Awaitable
from typing import List, Tuple, Optional, Union, Dict, Awaitable
from openrouter import OpenRouter, RetryConfig
from openrouter.components import AssistantMessage, AssistantMessageTypedDict, ChatMessageContentItemTypedDict, \
ChatMessageToolCall, MessageTypedDict
ChatMessageToolCall, MessageTypedDict, SystemMessageTypedDict
from openrouter.errors import ResponseValidationError, ChatError
from openrouter.utils import BackoffStrategy
@ -68,9 +68,9 @@ class AiAgent:
async def get_group_chat_reply(self, bot_id: int, chat_id: int,
message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id)
message.text = _add_message_prefix(message.text, message.user_name)
context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id)
context.append(_serialize_message(role="user", text=message.text, image=message.image))
for fwd_message in forwarded_messages:
@ -111,14 +111,10 @@ class AiAgent:
return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False
async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: Message) -> Tuple[Message, bool]:
context = self._get_chat_context(is_group_chat=False, bot_id=bot_id, chat_id=chat_id)
message.text = _add_message_prefix(message.text)
content: list[dict[str, Any]] = []
if message.text is not None:
content.append({"type": "text", "text": message.text})
if message.image is not None:
content.append({"type": "image_url", "image_url": {"url": _encode_image(message.image)}})
context.append({"role": "user", "content": content})
context = self._get_chat_context(is_group_chat=False, bot_id=bot_id, chat_id=chat_id)
context.append(_serialize_message(role="user", text=message.text, image=message.image))
try:
response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
@ -156,9 +152,17 @@ class AiAgent:
def clear_chat_context(self, bot_id: int, chat_id: int):
self.db.context_clear(bot_id, chat_id)
#######################################
####################################################################################
def _get_chat_context(self, is_group_chat: bool, bot_id: int, chat_id: int) -> List[MessageTypedDict]:
context: List[MessageTypedDict] = [
self._construct_system_prompt(is_group_chat=is_group_chat, bot_id=bot_id, chat_id=chat_id)
]
for message in self.db.context_get_messages(bot_id, chat_id):
context.append(_serialize_message(message["role"], message["text"], message["image"]))
return context
def _construct_system_prompt(self, is_group_chat: bool, bot_id: int, chat_id: int) -> SystemMessageTypedDict:
prompt = self.system_prompt_group_chat if is_group_chat else self.system_prompt_private_chat
prompt = prompt.replace('{platform}', 'Telegram' if self.platform == 'tg' else 'VK')
prompt += '\n' + self.system_prompt_image_generation
@ -171,12 +175,7 @@ class AiAgent:
if chat['ai_prompt'] is not None:
prompt += '\n' + chat['ai_prompt']
messages = self.db.context_get_messages(bot_id, chat_id)
context: List[MessageTypedDict] = [{"role": "system", "content": prompt}]
for message in messages:
context.append(_serialize_message(message["role"], message["text"], message["image"]))
return context
return {"role": "system", "content": prompt}
async def _generate_reply(self, bot_id: int, chat_id: int,
context: List[MessageTypedDict], allow_tools: bool = False) -> AssistantMessage:
@ -213,7 +212,6 @@ class AiAgent:
"tool_call_id": tool_call.id,
"content": tool_result
})
artifacts.tools_called = True
return artifacts
async def _process_tool_generate_image(self, _bot_id: int, _chat_id: int, args: dict, artifacts: _ToolsArtifacts) \
@ -222,19 +220,12 @@ class AiAgent:
aspect_ratio = args.get("aspect_ratio", None)
result = await self._generate_image(prompt=prompt, aspect_ratio=aspect_ratio)
content = []
if result.is_ok():
content.append(
{"type": "text",
"text": "Изображение сгенерировано и будет показано пользователю."})
content.append(
{"type": "image_url", "image_url": {"url": _encode_image(result.ok_value)}})
artifacts.generated_image = result.ok_value
return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.",
image=result.ok_value)
else:
content.append(
{"type": "text",
"text": f"Не удалось сгенерировать изображение: {result.err_value}"})
return content
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {result.err_value}")
async def _generate_image(self, prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]:
width, height = _get_resolution_for_aspect_ratio(aspect_ratio)
@ -275,24 +266,17 @@ class AiAgent:
result = await self._generate_image_anime(prompt=prompt, negative_prompt=negative_prompt,
aspect_ratio=aspect_ratio)
content = []
if result.is_ok():
content.append(
{"type": "text",
"text": "Изображение сгенерировано и будет показано пользователю."})
content.append(
{"type": "image_url", "image_url": {"url": _encode_image(result.ok_value)}})
artifacts.generated_image = result.ok_value
return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.",
image=result.ok_value)
else:
content.append(
{"type": "text",
"text": f"Не удалось сгенерировать изображение: {result.err_value}"})
return content
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {result.err_value}")
async def _generate_image_anime(self, prompt: str, negative_prompt: str, aspect_ratio: Optional[str]) \
-> Result[bytes, str]:
width, height = _get_resolution_for_aspect_ratio(aspect_ratio)
print(f"Генерация изображения {width}x{height}: positive='{prompt}', negative='{negative_prompt}'")
print(f"Генерация аниме-изображения {width}x{height}:\n+ {prompt}\n- {negative_prompt}")
arguments = {
"prompt": prompt,
@ -379,12 +363,16 @@ def _encode_image(image: bytes) -> str:
def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict:
serialized = {"role": role, "content": []}
return {"role": role, "content": _serialize_message_content(text, image)}
def _serialize_message_content(text: Optional[str], image: Optional[bytes] = None) -> list[dict]:
content = []
if text is not None:
serialized["content"].append({"type": "text", "text": text})
content.append({"type": "text", "text": text})
if image is not None:
serialized["content"].append({"type": "image_url", "image_url": {"url": _encode_image(image)}})
return serialized
content.append({"type": "image_url", "detail": "high", "image_url": {"url": _encode_image(image)}})
return content
def _remove_none_recursive(data: Union[dict, list, any]) -> Union[dict, list, any]:

View file

@ -13,7 +13,7 @@
Если сгенерировать изображение не удалось из-за ошибки, просто сообщи об этом пользователю.
## Генерация обычных (не аниме) изображений
Для генерации используй функцию `generate_image` и составляй запрос на естесственном языке по следующей формуле:
Для генерации используй функцию `generate_image` и составляй запрос на естественном языке по следующей формуле:
1. Объекты сцены.
2. Действие/поза.
3. Окружение.