Рефакторинг агента.

This commit is contained in:
Kirill Kirilenko 2026-02-15 03:04:21 +03:00
parent e05e0d4c82
commit ec18d3ffa4

View file

@ -2,15 +2,16 @@ import base64
import datetime
import json
from collections.abc import Callable
from dataclasses import dataclass
from io import BytesIO
from PIL import Image
from typing import List, Tuple, Any, Optional, Union
from result import Ok, Err, Result
from typing import List, Tuple, Any, Optional, Union, Dict, Awaitable
from openrouter import OpenRouter, RetryConfig
from openrouter.components import AssistantMessage, ChatMessageToolCall, \
MessageTypedDict, ToolDefinitionJSONTypedDict, AssistantMessageTypedDict
from openrouter.components import AssistantMessage, AssistantMessageTypedDict, ChatMessageContentItemTypedDict, \
ChatMessageToolCall, MessageTypedDict, ToolDefinitionJSONTypedDict
from openrouter.utils import BackoffStrategy
from database import BasicDatabase
@ -130,11 +131,6 @@ def _get_tools_description() -> List[ToolDefinitionJSONTypedDict]:
class AiAgent:
@dataclass()
class ToolCallResult:
tools_called: bool = False
generated_image: Optional[bytes] = None
def __init__(self,
api_token_main: str, model_main: str,
api_token_image: str, model_image: str,
@ -154,6 +150,10 @@ class AiAgent:
self.client_image = OpenRouter(api_key=api_token_image,
x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER)
@dataclass()
class _ToolsArtifacts:
generated_image: Optional[bytes] = None
async def get_group_chat_reply(self, bot_id: int, chat_id: int,
message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id)
@ -172,9 +172,10 @@ class AiAgent:
response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
ai_response = response.content
tools_call_result = await self._process_tool_calls(bot_id, chat_id,
tool_calls=response.tool_calls, context=context)
if tools_call_result.tools_called:
tools_artifacts = AiAgent._ToolsArtifacts()
if response.tool_calls is not None:
tools_artifacts = await self._process_tool_calls(bot_id, chat_id,
tool_calls=response.tool_calls, context=context)
response2 = await self._generate_reply(bot_id, chat_id, context=context)
ai_response = response2.content
@ -185,10 +186,10 @@ class AiAgent:
role="user", text=fwd_message.text, image=fwd_message.image,
message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
self.db.context_add_message(bot_id, chat_id,
role="assistant", text=ai_response, image=tools_call_result.generated_image,
role="assistant", text=ai_response, image=tools_artifacts.generated_image,
message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)
return Message(text=ai_response, image=tools_call_result.generated_image), True
return Message(text=ai_response, image=tools_artifacts.generated_image), True
except Exception as e:
if str(e).find("Rate limit exceeded") != -1:
@ -212,19 +213,20 @@ class AiAgent:
context.append(_serialize_assistant_message(response))
ai_response = response.content
tools_call_result = await self._process_tool_calls(bot_id, chat_id,
tool_calls=response.tool_calls, context=context)
if tools_call_result.tools_called:
tools_artifacts = AiAgent._ToolsArtifacts()
if response.tool_calls is not None:
tools_artifacts = await self._process_tool_calls(bot_id, chat_id,
tool_calls=response.tool_calls, context=context)
response2 = await self._generate_reply(bot_id, chat_id, context=context)
ai_response = response2.content
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
self.db.context_add_message(bot_id, chat_id, role="assistant",
text=ai_response, image=tools_call_result.generated_image,
text=ai_response, image=tools_artifacts.generated_image,
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
return Message(text=ai_response, image=tools_call_result.generated_image), True
return Message(text=ai_response, image=tools_artifacts.generated_image), True
except Exception as e:
if str(e).find("Rate limit exceeded") != -1:
@ -273,39 +275,52 @@ class AiAgent:
return response.choices[0].message
async def _process_tool_calls(self, bot_id: int, chat_id: int, tool_calls: List[ChatMessageToolCall],
context: List[MessageTypedDict]) -> ToolCallResult:
result = AiAgent.ToolCallResult()
if tool_calls is not None:
for tool_call in tool_calls:
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
if tool_name == "generate_image":
prompt = tool_args.get("prompt", "")
aspect_ratio = tool_args.get("aspect_ratio", None)
result.generated_image, success = \
await self._generate_image(bot_id, chat_id, prompt=prompt, aspect_ratio=aspect_ratio)
tool_result_content = []
if success:
tool_result_content.append(
{"type": "text",
"text": "Изображение сгенерировано и будет показано пользователю."})
tool_result_content.append(
{"type": "image_url", "image_url": {"url": _encode_image(result.generated_image)}})
else:
tool_result_content.append(
{"type": "text",
"text": "Не удалось сгенерировать изображение."})
context.append({
"role": "tool",
"tool_call_id": tool_call.id,
"content": tool_result_content
})
result.tools_called = True
break
return result
context: List[MessageTypedDict]) -> _ToolsArtifacts:
artifacts = AiAgent._ToolsArtifacts()
if tool_calls is None:
return artifacts
async def _generate_image(self, bot_id: int, chat_id: int, prompt: str, aspect_ratio: Optional[str]) \
-> Tuple[Optional[bytes], bool]:
functions_map: Dict[str,
Callable[[int, int, Dict, AiAgent._ToolsArtifacts],
Awaitable[List[ChatMessageContentItemTypedDict]]]] = {
"generate_image": self._process_tool_generate_image
}
for tool_call in tool_calls:
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
if tool_name in functions_map:
tool_result = await functions_map[tool_name](bot_id, chat_id, tool_args, artifacts)
context.append({
"role": "tool",
"tool_call_id": tool_call.id,
"content": tool_result
})
artifacts.tools_called = True
return artifacts
async def _process_tool_generate_image(self, bot_id: int, chat_id: int, args: dict, artifacts: _ToolsArtifacts) \
-> List[ChatMessageContentItemTypedDict]:
prompt = args.get("prompt", "")
aspect_ratio = args.get("aspect_ratio", None)
result = await self._generate_image(bot_id, chat_id, prompt=prompt, aspect_ratio=aspect_ratio)
content = []
if result.is_ok():
content.append(
{"type": "text",
"text": "Изображение сгенерировано и будет показано пользователю."})
content.append(
{"type": "image_url", "image_url": {"url": _encode_image(result.ok_value)}})
artifacts.generated_image = result.ok_value
else:
content.append(
{"type": "text",
"text": f"Не удалось сгенерировать изображение: {result.err_value}"})
return content
async def _generate_image(self, bot_id: int, chat_id: int,
prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]:
print(f"Генерация изображения: {prompt}")
context = [{"role": "user", "content": prompt}]
@ -329,9 +344,10 @@ class AiAgent:
image.save(output, format="JPEG", quality=80, optimize=True)
image_bytes = output.getvalue()
return image_bytes, True
except Exception:
return None, False
return Ok(image_bytes)
except Exception as e:
print(f"Ошибка генерации изображения: {e}")
return Err(str(e))
agent: AiAgent