Рефакторинг агента.
This commit is contained in:
parent
e05e0d4c82
commit
ec18d3ffa4
1 changed files with 70 additions and 54 deletions
106
ai_agent.py
106
ai_agent.py
|
|
@ -2,15 +2,16 @@ import base64
|
|||
import datetime
|
||||
import json
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from typing import List, Tuple, Any, Optional, Union
|
||||
from result import Ok, Err, Result
|
||||
from typing import List, Tuple, Any, Optional, Union, Dict, Awaitable
|
||||
|
||||
from openrouter import OpenRouter, RetryConfig
|
||||
from openrouter.components import AssistantMessage, ChatMessageToolCall, \
|
||||
MessageTypedDict, ToolDefinitionJSONTypedDict, AssistantMessageTypedDict
|
||||
|
||||
from openrouter.components import AssistantMessage, AssistantMessageTypedDict, ChatMessageContentItemTypedDict, \
|
||||
ChatMessageToolCall, MessageTypedDict, ToolDefinitionJSONTypedDict
|
||||
from openrouter.utils import BackoffStrategy
|
||||
|
||||
from database import BasicDatabase
|
||||
|
|
@ -130,11 +131,6 @@ def _get_tools_description() -> List[ToolDefinitionJSONTypedDict]:
|
|||
|
||||
|
||||
class AiAgent:
|
||||
@dataclass()
|
||||
class ToolCallResult:
|
||||
tools_called: bool = False
|
||||
generated_image: Optional[bytes] = None
|
||||
|
||||
def __init__(self,
|
||||
api_token_main: str, model_main: str,
|
||||
api_token_image: str, model_image: str,
|
||||
|
|
@ -154,6 +150,10 @@ class AiAgent:
|
|||
self.client_image = OpenRouter(api_key=api_token_image,
|
||||
x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER)
|
||||
|
||||
@dataclass()
|
||||
class _ToolsArtifacts:
|
||||
generated_image: Optional[bytes] = None
|
||||
|
||||
async def get_group_chat_reply(self, bot_id: int, chat_id: int,
|
||||
message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
|
||||
context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id)
|
||||
|
|
@ -172,9 +172,10 @@ class AiAgent:
|
|||
response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
|
||||
ai_response = response.content
|
||||
|
||||
tools_call_result = await self._process_tool_calls(bot_id, chat_id,
|
||||
tools_artifacts = AiAgent._ToolsArtifacts()
|
||||
if response.tool_calls is not None:
|
||||
tools_artifacts = await self._process_tool_calls(bot_id, chat_id,
|
||||
tool_calls=response.tool_calls, context=context)
|
||||
if tools_call_result.tools_called:
|
||||
response2 = await self._generate_reply(bot_id, chat_id, context=context)
|
||||
ai_response = response2.content
|
||||
|
||||
|
|
@ -185,10 +186,10 @@ class AiAgent:
|
|||
role="user", text=fwd_message.text, image=fwd_message.image,
|
||||
message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
||||
self.db.context_add_message(bot_id, chat_id,
|
||||
role="assistant", text=ai_response, image=tools_call_result.generated_image,
|
||||
role="assistant", text=ai_response, image=tools_artifacts.generated_image,
|
||||
message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
||||
|
||||
return Message(text=ai_response, image=tools_call_result.generated_image), True
|
||||
return Message(text=ai_response, image=tools_artifacts.generated_image), True
|
||||
|
||||
except Exception as e:
|
||||
if str(e).find("Rate limit exceeded") != -1:
|
||||
|
|
@ -212,19 +213,20 @@ class AiAgent:
|
|||
context.append(_serialize_assistant_message(response))
|
||||
ai_response = response.content
|
||||
|
||||
tools_call_result = await self._process_tool_calls(bot_id, chat_id,
|
||||
tools_artifacts = AiAgent._ToolsArtifacts()
|
||||
if response.tool_calls is not None:
|
||||
tools_artifacts = await self._process_tool_calls(bot_id, chat_id,
|
||||
tool_calls=response.tool_calls, context=context)
|
||||
if tools_call_result.tools_called:
|
||||
response2 = await self._generate_reply(bot_id, chat_id, context=context)
|
||||
ai_response = response2.content
|
||||
|
||||
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
|
||||
message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
||||
self.db.context_add_message(bot_id, chat_id, role="assistant",
|
||||
text=ai_response, image=tools_call_result.generated_image,
|
||||
text=ai_response, image=tools_artifacts.generated_image,
|
||||
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
||||
|
||||
return Message(text=ai_response, image=tools_call_result.generated_image), True
|
||||
return Message(text=ai_response, image=tools_artifacts.generated_image), True
|
||||
|
||||
except Exception as e:
|
||||
if str(e).find("Rate limit exceeded") != -1:
|
||||
|
|
@ -273,39 +275,52 @@ class AiAgent:
|
|||
return response.choices[0].message
|
||||
|
||||
async def _process_tool_calls(self, bot_id: int, chat_id: int, tool_calls: List[ChatMessageToolCall],
|
||||
context: List[MessageTypedDict]) -> ToolCallResult:
|
||||
result = AiAgent.ToolCallResult()
|
||||
if tool_calls is not None:
|
||||
context: List[MessageTypedDict]) -> _ToolsArtifacts:
|
||||
artifacts = AiAgent._ToolsArtifacts()
|
||||
if tool_calls is None:
|
||||
return artifacts
|
||||
|
||||
functions_map: Dict[str,
|
||||
Callable[[int, int, Dict, AiAgent._ToolsArtifacts],
|
||||
Awaitable[List[ChatMessageContentItemTypedDict]]]] = {
|
||||
"generate_image": self._process_tool_generate_image
|
||||
}
|
||||
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
if tool_name == "generate_image":
|
||||
prompt = tool_args.get("prompt", "")
|
||||
aspect_ratio = tool_args.get("aspect_ratio", None)
|
||||
result.generated_image, success = \
|
||||
await self._generate_image(bot_id, chat_id, prompt=prompt, aspect_ratio=aspect_ratio)
|
||||
tool_result_content = []
|
||||
if success:
|
||||
tool_result_content.append(
|
||||
{"type": "text",
|
||||
"text": "Изображение сгенерировано и будет показано пользователю."})
|
||||
tool_result_content.append(
|
||||
{"type": "image_url", "image_url": {"url": _encode_image(result.generated_image)}})
|
||||
else:
|
||||
tool_result_content.append(
|
||||
{"type": "text",
|
||||
"text": "Не удалось сгенерировать изображение."})
|
||||
if tool_name in functions_map:
|
||||
tool_result = await functions_map[tool_name](bot_id, chat_id, tool_args, artifacts)
|
||||
context.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": tool_result_content
|
||||
"content": tool_result
|
||||
})
|
||||
result.tools_called = True
|
||||
break
|
||||
return result
|
||||
artifacts.tools_called = True
|
||||
return artifacts
|
||||
|
||||
async def _generate_image(self, bot_id: int, chat_id: int, prompt: str, aspect_ratio: Optional[str]) \
|
||||
-> Tuple[Optional[bytes], bool]:
|
||||
async def _process_tool_generate_image(self, bot_id: int, chat_id: int, args: dict, artifacts: _ToolsArtifacts) \
|
||||
-> List[ChatMessageContentItemTypedDict]:
|
||||
prompt = args.get("prompt", "")
|
||||
aspect_ratio = args.get("aspect_ratio", None)
|
||||
result = await self._generate_image(bot_id, chat_id, prompt=prompt, aspect_ratio=aspect_ratio)
|
||||
|
||||
content = []
|
||||
if result.is_ok():
|
||||
content.append(
|
||||
{"type": "text",
|
||||
"text": "Изображение сгенерировано и будет показано пользователю."})
|
||||
content.append(
|
||||
{"type": "image_url", "image_url": {"url": _encode_image(result.ok_value)}})
|
||||
artifacts.generated_image = result.ok_value
|
||||
else:
|
||||
content.append(
|
||||
{"type": "text",
|
||||
"text": f"Не удалось сгенерировать изображение: {result.err_value}"})
|
||||
return content
|
||||
|
||||
async def _generate_image(self, bot_id: int, chat_id: int,
|
||||
prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]:
|
||||
print(f"Генерация изображения: {prompt}")
|
||||
context = [{"role": "user", "content": prompt}]
|
||||
|
||||
|
|
@ -329,9 +344,10 @@ class AiAgent:
|
|||
image.save(output, format="JPEG", quality=80, optimize=True)
|
||||
image_bytes = output.getvalue()
|
||||
|
||||
return image_bytes, True
|
||||
except Exception:
|
||||
return None, False
|
||||
return Ok(image_bytes)
|
||||
except Exception as e:
|
||||
print(f"Ошибка генерации изображения: {e}")
|
||||
return Err(str(e))
|
||||
|
||||
|
||||
agent: AiAgent
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue