Рефакторинг агента.

This commit is contained in:
Kirill Kirilenko 2026-02-15 03:04:21 +03:00
parent e05e0d4c82
commit ec18d3ffa4

View file

@ -2,15 +2,16 @@ import base64
import datetime import datetime
import json import json
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from io import BytesIO from io import BytesIO
from PIL import Image from PIL import Image
from typing import List, Tuple, Any, Optional, Union from result import Ok, Err, Result
from typing import List, Tuple, Any, Optional, Union, Dict, Awaitable
from openrouter import OpenRouter, RetryConfig from openrouter import OpenRouter, RetryConfig
from openrouter.components import AssistantMessage, ChatMessageToolCall, \ from openrouter.components import AssistantMessage, AssistantMessageTypedDict, ChatMessageContentItemTypedDict, \
MessageTypedDict, ToolDefinitionJSONTypedDict, AssistantMessageTypedDict ChatMessageToolCall, MessageTypedDict, ToolDefinitionJSONTypedDict
from openrouter.utils import BackoffStrategy from openrouter.utils import BackoffStrategy
from database import BasicDatabase from database import BasicDatabase
@ -130,11 +131,6 @@ def _get_tools_description() -> List[ToolDefinitionJSONTypedDict]:
class AiAgent: class AiAgent:
@dataclass()
class ToolCallResult:
tools_called: bool = False
generated_image: Optional[bytes] = None
def __init__(self, def __init__(self,
api_token_main: str, model_main: str, api_token_main: str, model_main: str,
api_token_image: str, model_image: str, api_token_image: str, model_image: str,
@ -154,6 +150,10 @@ class AiAgent:
self.client_image = OpenRouter(api_key=api_token_image, self.client_image = OpenRouter(api_key=api_token_image,
x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER) x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER)
@dataclass()
class _ToolsArtifacts:
generated_image: Optional[bytes] = None
async def get_group_chat_reply(self, bot_id: int, chat_id: int, async def get_group_chat_reply(self, bot_id: int, chat_id: int,
message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]: message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id) context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id)
@ -172,9 +172,10 @@ class AiAgent:
response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True) response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
ai_response = response.content ai_response = response.content
tools_call_result = await self._process_tool_calls(bot_id, chat_id, tools_artifacts = AiAgent._ToolsArtifacts()
tool_calls=response.tool_calls, context=context) if response.tool_calls is not None:
if tools_call_result.tools_called: tools_artifacts = await self._process_tool_calls(bot_id, chat_id,
tool_calls=response.tool_calls, context=context)
response2 = await self._generate_reply(bot_id, chat_id, context=context) response2 = await self._generate_reply(bot_id, chat_id, context=context)
ai_response = response2.content ai_response = response2.content
@ -185,10 +186,10 @@ class AiAgent:
role="user", text=fwd_message.text, image=fwd_message.image, role="user", text=fwd_message.text, image=fwd_message.image,
message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES) message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
self.db.context_add_message(bot_id, chat_id, self.db.context_add_message(bot_id, chat_id,
role="assistant", text=ai_response, image=tools_call_result.generated_image, role="assistant", text=ai_response, image=tools_artifacts.generated_image,
message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES) message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)
return Message(text=ai_response, image=tools_call_result.generated_image), True return Message(text=ai_response, image=tools_artifacts.generated_image), True
except Exception as e: except Exception as e:
if str(e).find("Rate limit exceeded") != -1: if str(e).find("Rate limit exceeded") != -1:
@ -212,19 +213,20 @@ class AiAgent:
context.append(_serialize_assistant_message(response)) context.append(_serialize_assistant_message(response))
ai_response = response.content ai_response = response.content
tools_call_result = await self._process_tool_calls(bot_id, chat_id, tools_artifacts = AiAgent._ToolsArtifacts()
tool_calls=response.tool_calls, context=context) if response.tool_calls is not None:
if tools_call_result.tools_called: tools_artifacts = await self._process_tool_calls(bot_id, chat_id,
tool_calls=response.tool_calls, context=context)
response2 = await self._generate_reply(bot_id, chat_id, context=context) response2 = await self._generate_reply(bot_id, chat_id, context=context)
ai_response = response2.content ai_response = response2.content
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image, self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES) message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
self.db.context_add_message(bot_id, chat_id, role="assistant", self.db.context_add_message(bot_id, chat_id, role="assistant",
text=ai_response, image=tools_call_result.generated_image, text=ai_response, image=tools_artifacts.generated_image,
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES) message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
return Message(text=ai_response, image=tools_call_result.generated_image), True return Message(text=ai_response, image=tools_artifacts.generated_image), True
except Exception as e: except Exception as e:
if str(e).find("Rate limit exceeded") != -1: if str(e).find("Rate limit exceeded") != -1:
@ -273,39 +275,52 @@ class AiAgent:
return response.choices[0].message return response.choices[0].message
async def _process_tool_calls(self, bot_id: int, chat_id: int, tool_calls: List[ChatMessageToolCall], async def _process_tool_calls(self, bot_id: int, chat_id: int, tool_calls: List[ChatMessageToolCall],
context: List[MessageTypedDict]) -> ToolCallResult: context: List[MessageTypedDict]) -> _ToolsArtifacts:
result = AiAgent.ToolCallResult() artifacts = AiAgent._ToolsArtifacts()
if tool_calls is not None: if tool_calls is None:
for tool_call in tool_calls: return artifacts
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
if tool_name == "generate_image":
prompt = tool_args.get("prompt", "")
aspect_ratio = tool_args.get("aspect_ratio", None)
result.generated_image, success = \
await self._generate_image(bot_id, chat_id, prompt=prompt, aspect_ratio=aspect_ratio)
tool_result_content = []
if success:
tool_result_content.append(
{"type": "text",
"text": "Изображение сгенерировано и будет показано пользователю."})
tool_result_content.append(
{"type": "image_url", "image_url": {"url": _encode_image(result.generated_image)}})
else:
tool_result_content.append(
{"type": "text",
"text": "Не удалось сгенерировать изображение."})
context.append({
"role": "tool",
"tool_call_id": tool_call.id,
"content": tool_result_content
})
result.tools_called = True
break
return result
async def _generate_image(self, bot_id: int, chat_id: int, prompt: str, aspect_ratio: Optional[str]) \ functions_map: Dict[str,
-> Tuple[Optional[bytes], bool]: Callable[[int, int, Dict, AiAgent._ToolsArtifacts],
Awaitable[List[ChatMessageContentItemTypedDict]]]] = {
"generate_image": self._process_tool_generate_image
}
for tool_call in tool_calls:
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
if tool_name in functions_map:
tool_result = await functions_map[tool_name](bot_id, chat_id, tool_args, artifacts)
context.append({
"role": "tool",
"tool_call_id": tool_call.id,
"content": tool_result
})
artifacts.tools_called = True
return artifacts
async def _process_tool_generate_image(self, bot_id: int, chat_id: int, args: dict, artifacts: _ToolsArtifacts) \
-> List[ChatMessageContentItemTypedDict]:
prompt = args.get("prompt", "")
aspect_ratio = args.get("aspect_ratio", None)
result = await self._generate_image(bot_id, chat_id, prompt=prompt, aspect_ratio=aspect_ratio)
content = []
if result.is_ok():
content.append(
{"type": "text",
"text": "Изображение сгенерировано и будет показано пользователю."})
content.append(
{"type": "image_url", "image_url": {"url": _encode_image(result.ok_value)}})
artifacts.generated_image = result.ok_value
else:
content.append(
{"type": "text",
"text": f"Не удалось сгенерировать изображение: {result.err_value}"})
return content
async def _generate_image(self, bot_id: int, chat_id: int,
prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]:
print(f"Генерация изображения: {prompt}") print(f"Генерация изображения: {prompt}")
context = [{"role": "user", "content": prompt}] context = [{"role": "user", "content": prompt}]
@ -329,9 +344,10 @@ class AiAgent:
image.save(output, format="JPEG", quality=80, optimize=True) image.save(output, format="JPEG", quality=80, optimize=True)
image_bytes = output.getvalue() image_bytes = output.getvalue()
return image_bytes, True return Ok(image_bytes)
except Exception: except Exception as e:
return None, False print(f"Ошибка генерации изображения: {e}")
return Err(str(e))
agent: AiAgent agent: AiAgent