From ec18d3ffa45d16aa5f9b1ba57f26c257d45ca3a8 Mon Sep 17 00:00:00 2001 From: Kirill Kirilenko Date: Sun, 15 Feb 2026 03:04:21 +0300 Subject: [PATCH] =?UTF-8?q?=D0=A0=D0=B5=D1=84=D0=B0=D0=BA=D1=82=D0=BE?= =?UTF-8?q?=D1=80=D0=B8=D0=BD=D0=B3=20=D0=B0=D0=B3=D0=B5=D0=BD=D1=82=D0=B0?= =?UTF-8?q?.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ai_agent.py | 124 +++++++++++++++++++++++++++++----------------------- 1 file changed, 70 insertions(+), 54 deletions(-) diff --git a/ai_agent.py b/ai_agent.py index 4bc4a60..861763d 100644 --- a/ai_agent.py +++ b/ai_agent.py @@ -2,15 +2,16 @@ import base64 import datetime import json +from collections.abc import Callable from dataclasses import dataclass from io import BytesIO from PIL import Image -from typing import List, Tuple, Any, Optional, Union +from result import Ok, Err, Result +from typing import List, Tuple, Any, Optional, Union, Dict, Awaitable from openrouter import OpenRouter, RetryConfig -from openrouter.components import AssistantMessage, ChatMessageToolCall, \ - MessageTypedDict, ToolDefinitionJSONTypedDict, AssistantMessageTypedDict - +from openrouter.components import AssistantMessage, AssistantMessageTypedDict, ChatMessageContentItemTypedDict, \ + ChatMessageToolCall, MessageTypedDict, ToolDefinitionJSONTypedDict from openrouter.utils import BackoffStrategy from database import BasicDatabase @@ -130,11 +131,6 @@ def _get_tools_description() -> List[ToolDefinitionJSONTypedDict]: class AiAgent: - @dataclass() - class ToolCallResult: - tools_called: bool = False - generated_image: Optional[bytes] = None - def __init__(self, api_token_main: str, model_main: str, api_token_image: str, model_image: str, @@ -154,6 +150,10 @@ class AiAgent: self.client_image = OpenRouter(api_key=api_token_image, x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER) + @dataclass() + class _ToolsArtifacts: + generated_image: Optional[bytes] = None + async def get_group_chat_reply(self, bot_id: int, chat_id: int, message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]: context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id) @@ -172,9 +172,10 @@ class AiAgent: response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True) ai_response = response.content - tools_call_result = await self._process_tool_calls(bot_id, chat_id, - tool_calls=response.tool_calls, context=context) - if tools_call_result.tools_called: + tools_artifacts = AiAgent._ToolsArtifacts() + if response.tool_calls is not None: + tools_artifacts = await self._process_tool_calls(bot_id, chat_id, + tool_calls=response.tool_calls, context=context) response2 = await self._generate_reply(bot_id, chat_id, context=context) ai_response = response2.content @@ -185,10 +186,10 @@ class AiAgent: role="user", text=fwd_message.text, image=fwd_message.image, message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES) self.db.context_add_message(bot_id, chat_id, - role="assistant", text=ai_response, image=tools_call_result.generated_image, + role="assistant", text=ai_response, image=tools_artifacts.generated_image, message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES) - return Message(text=ai_response, image=tools_call_result.generated_image), True + return Message(text=ai_response, image=tools_artifacts.generated_image), True except Exception as e: if str(e).find("Rate limit exceeded") != -1: @@ -212,19 +213,20 @@ class AiAgent: context.append(_serialize_assistant_message(response)) ai_response = response.content - tools_call_result = await self._process_tool_calls(bot_id, chat_id, - tool_calls=response.tool_calls, context=context) - if tools_call_result.tools_called: + tools_artifacts = AiAgent._ToolsArtifacts() + if response.tool_calls is not None: + tools_artifacts = await self._process_tool_calls(bot_id, chat_id, + tool_calls=response.tool_calls, context=context) response2 = await self._generate_reply(bot_id, chat_id, context=context) ai_response = response2.content self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image, message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES) self.db.context_add_message(bot_id, chat_id, role="assistant", - text=ai_response, image=tools_call_result.generated_image, + text=ai_response, image=tools_artifacts.generated_image, message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES) - return Message(text=ai_response, image=tools_call_result.generated_image), True + return Message(text=ai_response, image=tools_artifacts.generated_image), True except Exception as e: if str(e).find("Rate limit exceeded") != -1: @@ -273,39 +275,52 @@ class AiAgent: return response.choices[0].message async def _process_tool_calls(self, bot_id: int, chat_id: int, tool_calls: List[ChatMessageToolCall], - context: List[MessageTypedDict]) -> ToolCallResult: - result = AiAgent.ToolCallResult() - if tool_calls is not None: - for tool_call in tool_calls: - tool_name = tool_call.function.name - tool_args = json.loads(tool_call.function.arguments) - if tool_name == "generate_image": - prompt = tool_args.get("prompt", "") - aspect_ratio = tool_args.get("aspect_ratio", None) - result.generated_image, success = \ - await self._generate_image(bot_id, chat_id, prompt=prompt, aspect_ratio=aspect_ratio) - tool_result_content = [] - if success: - tool_result_content.append( - {"type": "text", - "text": "Изображение сгенерировано и будет показано пользователю."}) - tool_result_content.append( - {"type": "image_url", "image_url": {"url": _encode_image(result.generated_image)}}) - else: - tool_result_content.append( - {"type": "text", - "text": "Не удалось сгенерировать изображение."}) - context.append({ - "role": "tool", - "tool_call_id": tool_call.id, - "content": tool_result_content - }) - result.tools_called = True - break - return result + context: List[MessageTypedDict]) -> _ToolsArtifacts: + artifacts = AiAgent._ToolsArtifacts() + if tool_calls is None: + return artifacts - async def _generate_image(self, bot_id: int, chat_id: int, prompt: str, aspect_ratio: Optional[str]) \ - -> Tuple[Optional[bytes], bool]: + functions_map: Dict[str, + Callable[[int, int, Dict, AiAgent._ToolsArtifacts], + Awaitable[List[ChatMessageContentItemTypedDict]]]] = { + "generate_image": self._process_tool_generate_image + } + + for tool_call in tool_calls: + tool_name = tool_call.function.name + tool_args = json.loads(tool_call.function.arguments) + if tool_name in functions_map: + tool_result = await functions_map[tool_name](bot_id, chat_id, tool_args, artifacts) + context.append({ + "role": "tool", + "tool_call_id": tool_call.id, + "content": tool_result + }) + artifacts.tools_called = True + return artifacts + + async def _process_tool_generate_image(self, bot_id: int, chat_id: int, args: dict, artifacts: _ToolsArtifacts) \ + -> List[ChatMessageContentItemTypedDict]: + prompt = args.get("prompt", "") + aspect_ratio = args.get("aspect_ratio", None) + result = await self._generate_image(bot_id, chat_id, prompt=prompt, aspect_ratio=aspect_ratio) + + content = [] + if result.is_ok(): + content.append( + {"type": "text", + "text": "Изображение сгенерировано и будет показано пользователю."}) + content.append( + {"type": "image_url", "image_url": {"url": _encode_image(result.ok_value)}}) + artifacts.generated_image = result.ok_value + else: + content.append( + {"type": "text", + "text": f"Не удалось сгенерировать изображение: {result.err_value}"}) + return content + + async def _generate_image(self, bot_id: int, chat_id: int, + prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]: print(f"Генерация изображения: {prompt}") context = [{"role": "user", "content": prompt}] @@ -329,9 +344,10 @@ class AiAgent: image.save(output, format="JPEG", quality=80, optimize=True) image_bytes = output.getvalue() - return image_bytes, True - except Exception: - return None, False + return Ok(image_bytes) + except Exception as e: + print(f"Ошибка генерации изображения: {e}") + return Err(str(e)) agent: AiAgent