Рефакторинг AiAgent.
Исправлена грамматическая ошибка в системном запросе.
This commit is contained in:
parent
07c804c338
commit
1d6fd6f612
2 changed files with 33 additions and 45 deletions
76
ai_agent.py
76
ai_agent.py
|
|
@ -8,11 +8,11 @@ from dataclasses import dataclass
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from result import Ok, Err, Result
|
from result import Ok, Err, Result
|
||||||
from typing import List, Tuple, Any, Optional, Union, Dict, Awaitable
|
from typing import List, Tuple, Optional, Union, Dict, Awaitable
|
||||||
|
|
||||||
from openrouter import OpenRouter, RetryConfig
|
from openrouter import OpenRouter, RetryConfig
|
||||||
from openrouter.components import AssistantMessage, AssistantMessageTypedDict, ChatMessageContentItemTypedDict, \
|
from openrouter.components import AssistantMessage, AssistantMessageTypedDict, ChatMessageContentItemTypedDict, \
|
||||||
ChatMessageToolCall, MessageTypedDict
|
ChatMessageToolCall, MessageTypedDict, SystemMessageTypedDict
|
||||||
from openrouter.errors import ResponseValidationError, ChatError
|
from openrouter.errors import ResponseValidationError, ChatError
|
||||||
from openrouter.utils import BackoffStrategy
|
from openrouter.utils import BackoffStrategy
|
||||||
|
|
||||||
|
|
@ -68,9 +68,9 @@ class AiAgent:
|
||||||
|
|
||||||
async def get_group_chat_reply(self, bot_id: int, chat_id: int,
|
async def get_group_chat_reply(self, bot_id: int, chat_id: int,
|
||||||
message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
|
message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
|
||||||
context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id)
|
|
||||||
|
|
||||||
message.text = _add_message_prefix(message.text, message.user_name)
|
message.text = _add_message_prefix(message.text, message.user_name)
|
||||||
|
|
||||||
|
context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id)
|
||||||
context.append(_serialize_message(role="user", text=message.text, image=message.image))
|
context.append(_serialize_message(role="user", text=message.text, image=message.image))
|
||||||
|
|
||||||
for fwd_message in forwarded_messages:
|
for fwd_message in forwarded_messages:
|
||||||
|
|
@ -111,14 +111,10 @@ class AiAgent:
|
||||||
return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False
|
return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False
|
||||||
|
|
||||||
async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: Message) -> Tuple[Message, bool]:
|
async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: Message) -> Tuple[Message, bool]:
|
||||||
context = self._get_chat_context(is_group_chat=False, bot_id=bot_id, chat_id=chat_id)
|
|
||||||
message.text = _add_message_prefix(message.text)
|
message.text = _add_message_prefix(message.text)
|
||||||
content: list[dict[str, Any]] = []
|
|
||||||
if message.text is not None:
|
context = self._get_chat_context(is_group_chat=False, bot_id=bot_id, chat_id=chat_id)
|
||||||
content.append({"type": "text", "text": message.text})
|
context.append(_serialize_message(role="user", text=message.text, image=message.image))
|
||||||
if message.image is not None:
|
|
||||||
content.append({"type": "image_url", "image_url": {"url": _encode_image(message.image)}})
|
|
||||||
context.append({"role": "user", "content": content})
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
|
response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
|
||||||
|
|
@ -156,9 +152,17 @@ class AiAgent:
|
||||||
def clear_chat_context(self, bot_id: int, chat_id: int):
|
def clear_chat_context(self, bot_id: int, chat_id: int):
|
||||||
self.db.context_clear(bot_id, chat_id)
|
self.db.context_clear(bot_id, chat_id)
|
||||||
|
|
||||||
#######################################
|
####################################################################################
|
||||||
|
|
||||||
def _get_chat_context(self, is_group_chat: bool, bot_id: int, chat_id: int) -> List[MessageTypedDict]:
|
def _get_chat_context(self, is_group_chat: bool, bot_id: int, chat_id: int) -> List[MessageTypedDict]:
|
||||||
|
context: List[MessageTypedDict] = [
|
||||||
|
self._construct_system_prompt(is_group_chat=is_group_chat, bot_id=bot_id, chat_id=chat_id)
|
||||||
|
]
|
||||||
|
for message in self.db.context_get_messages(bot_id, chat_id):
|
||||||
|
context.append(_serialize_message(message["role"], message["text"], message["image"]))
|
||||||
|
return context
|
||||||
|
|
||||||
|
def _construct_system_prompt(self, is_group_chat: bool, bot_id: int, chat_id: int) -> SystemMessageTypedDict:
|
||||||
prompt = self.system_prompt_group_chat if is_group_chat else self.system_prompt_private_chat
|
prompt = self.system_prompt_group_chat if is_group_chat else self.system_prompt_private_chat
|
||||||
prompt = prompt.replace('{platform}', 'Telegram' if self.platform == 'tg' else 'VK')
|
prompt = prompt.replace('{platform}', 'Telegram' if self.platform == 'tg' else 'VK')
|
||||||
prompt += '\n' + self.system_prompt_image_generation
|
prompt += '\n' + self.system_prompt_image_generation
|
||||||
|
|
@ -171,12 +175,7 @@ class AiAgent:
|
||||||
if chat['ai_prompt'] is not None:
|
if chat['ai_prompt'] is not None:
|
||||||
prompt += '\n' + chat['ai_prompt']
|
prompt += '\n' + chat['ai_prompt']
|
||||||
|
|
||||||
messages = self.db.context_get_messages(bot_id, chat_id)
|
return {"role": "system", "content": prompt}
|
||||||
|
|
||||||
context: List[MessageTypedDict] = [{"role": "system", "content": prompt}]
|
|
||||||
for message in messages:
|
|
||||||
context.append(_serialize_message(message["role"], message["text"], message["image"]))
|
|
||||||
return context
|
|
||||||
|
|
||||||
async def _generate_reply(self, bot_id: int, chat_id: int,
|
async def _generate_reply(self, bot_id: int, chat_id: int,
|
||||||
context: List[MessageTypedDict], allow_tools: bool = False) -> AssistantMessage:
|
context: List[MessageTypedDict], allow_tools: bool = False) -> AssistantMessage:
|
||||||
|
|
@ -213,7 +212,6 @@ class AiAgent:
|
||||||
"tool_call_id": tool_call.id,
|
"tool_call_id": tool_call.id,
|
||||||
"content": tool_result
|
"content": tool_result
|
||||||
})
|
})
|
||||||
artifacts.tools_called = True
|
|
||||||
return artifacts
|
return artifacts
|
||||||
|
|
||||||
async def _process_tool_generate_image(self, _bot_id: int, _chat_id: int, args: dict, artifacts: _ToolsArtifacts) \
|
async def _process_tool_generate_image(self, _bot_id: int, _chat_id: int, args: dict, artifacts: _ToolsArtifacts) \
|
||||||
|
|
@ -222,19 +220,12 @@ class AiAgent:
|
||||||
aspect_ratio = args.get("aspect_ratio", None)
|
aspect_ratio = args.get("aspect_ratio", None)
|
||||||
result = await self._generate_image(prompt=prompt, aspect_ratio=aspect_ratio)
|
result = await self._generate_image(prompt=prompt, aspect_ratio=aspect_ratio)
|
||||||
|
|
||||||
content = []
|
|
||||||
if result.is_ok():
|
if result.is_ok():
|
||||||
content.append(
|
|
||||||
{"type": "text",
|
|
||||||
"text": "Изображение сгенерировано и будет показано пользователю."})
|
|
||||||
content.append(
|
|
||||||
{"type": "image_url", "image_url": {"url": _encode_image(result.ok_value)}})
|
|
||||||
artifacts.generated_image = result.ok_value
|
artifacts.generated_image = result.ok_value
|
||||||
|
return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.",
|
||||||
|
image=result.ok_value)
|
||||||
else:
|
else:
|
||||||
content.append(
|
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {result.err_value}")
|
||||||
{"type": "text",
|
|
||||||
"text": f"Не удалось сгенерировать изображение: {result.err_value}"})
|
|
||||||
return content
|
|
||||||
|
|
||||||
async def _generate_image(self, prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]:
|
async def _generate_image(self, prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]:
|
||||||
width, height = _get_resolution_for_aspect_ratio(aspect_ratio)
|
width, height = _get_resolution_for_aspect_ratio(aspect_ratio)
|
||||||
|
|
@ -275,24 +266,17 @@ class AiAgent:
|
||||||
result = await self._generate_image_anime(prompt=prompt, negative_prompt=negative_prompt,
|
result = await self._generate_image_anime(prompt=prompt, negative_prompt=negative_prompt,
|
||||||
aspect_ratio=aspect_ratio)
|
aspect_ratio=aspect_ratio)
|
||||||
|
|
||||||
content = []
|
|
||||||
if result.is_ok():
|
if result.is_ok():
|
||||||
content.append(
|
|
||||||
{"type": "text",
|
|
||||||
"text": "Изображение сгенерировано и будет показано пользователю."})
|
|
||||||
content.append(
|
|
||||||
{"type": "image_url", "image_url": {"url": _encode_image(result.ok_value)}})
|
|
||||||
artifacts.generated_image = result.ok_value
|
artifacts.generated_image = result.ok_value
|
||||||
|
return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.",
|
||||||
|
image=result.ok_value)
|
||||||
else:
|
else:
|
||||||
content.append(
|
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {result.err_value}")
|
||||||
{"type": "text",
|
|
||||||
"text": f"Не удалось сгенерировать изображение: {result.err_value}"})
|
|
||||||
return content
|
|
||||||
|
|
||||||
async def _generate_image_anime(self, prompt: str, negative_prompt: str, aspect_ratio: Optional[str]) \
|
async def _generate_image_anime(self, prompt: str, negative_prompt: str, aspect_ratio: Optional[str]) \
|
||||||
-> Result[bytes, str]:
|
-> Result[bytes, str]:
|
||||||
width, height = _get_resolution_for_aspect_ratio(aspect_ratio)
|
width, height = _get_resolution_for_aspect_ratio(aspect_ratio)
|
||||||
print(f"Генерация изображения {width}x{height}: positive='{prompt}', negative='{negative_prompt}'")
|
print(f"Генерация аниме-изображения {width}x{height}:\n+ {prompt}\n- {negative_prompt}")
|
||||||
|
|
||||||
arguments = {
|
arguments = {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
|
|
@ -379,12 +363,16 @@ def _encode_image(image: bytes) -> str:
|
||||||
|
|
||||||
|
|
||||||
def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict:
|
def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict:
|
||||||
serialized = {"role": role, "content": []}
|
return {"role": role, "content": _serialize_message_content(text, image)}
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_message_content(text: Optional[str], image: Optional[bytes] = None) -> list[dict]:
|
||||||
|
content = []
|
||||||
if text is not None:
|
if text is not None:
|
||||||
serialized["content"].append({"type": "text", "text": text})
|
content.append({"type": "text", "text": text})
|
||||||
if image is not None:
|
if image is not None:
|
||||||
serialized["content"].append({"type": "image_url", "image_url": {"url": _encode_image(image)}})
|
content.append({"type": "image_url", "detail": "high", "image_url": {"url": _encode_image(image)}})
|
||||||
return serialized
|
return content
|
||||||
|
|
||||||
|
|
||||||
def _remove_none_recursive(data: Union[dict, list, any]) -> Union[dict, list, any]:
|
def _remove_none_recursive(data: Union[dict, list, any]) -> Union[dict, list, any]:
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@
|
||||||
Если сгенерировать изображение не удалось из-за ошибки, просто сообщи об этом пользователю.
|
Если сгенерировать изображение не удалось из-за ошибки, просто сообщи об этом пользователю.
|
||||||
|
|
||||||
## Генерация обычных (не аниме) изображений
|
## Генерация обычных (не аниме) изображений
|
||||||
Для генерации используй функцию `generate_image` и составляй запрос на естесственном языке по следующей формуле:
|
Для генерации используй функцию `generate_image` и составляй запрос на естественном языке по следующей формуле:
|
||||||
1. Объекты сцены.
|
1. Объекты сцены.
|
||||||
2. Действие/поза.
|
2. Действие/поза.
|
||||||
3. Окружение.
|
3. Окружение.
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue