Рефакторинг агента.
This commit is contained in:
parent
e05e0d4c82
commit
ec18d3ffa4
1 changed files with 70 additions and 54 deletions
124
ai_agent.py
124
ai_agent.py
|
|
@ -2,15 +2,16 @@ import base64
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from typing import List, Tuple, Any, Optional, Union
|
from result import Ok, Err, Result
|
||||||
|
from typing import List, Tuple, Any, Optional, Union, Dict, Awaitable
|
||||||
|
|
||||||
from openrouter import OpenRouter, RetryConfig
|
from openrouter import OpenRouter, RetryConfig
|
||||||
from openrouter.components import AssistantMessage, ChatMessageToolCall, \
|
from openrouter.components import AssistantMessage, AssistantMessageTypedDict, ChatMessageContentItemTypedDict, \
|
||||||
MessageTypedDict, ToolDefinitionJSONTypedDict, AssistantMessageTypedDict
|
ChatMessageToolCall, MessageTypedDict, ToolDefinitionJSONTypedDict
|
||||||
|
|
||||||
from openrouter.utils import BackoffStrategy
|
from openrouter.utils import BackoffStrategy
|
||||||
|
|
||||||
from database import BasicDatabase
|
from database import BasicDatabase
|
||||||
|
|
@ -130,11 +131,6 @@ def _get_tools_description() -> List[ToolDefinitionJSONTypedDict]:
|
||||||
|
|
||||||
|
|
||||||
class AiAgent:
|
class AiAgent:
|
||||||
@dataclass()
|
|
||||||
class ToolCallResult:
|
|
||||||
tools_called: bool = False
|
|
||||||
generated_image: Optional[bytes] = None
|
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
api_token_main: str, model_main: str,
|
api_token_main: str, model_main: str,
|
||||||
api_token_image: str, model_image: str,
|
api_token_image: str, model_image: str,
|
||||||
|
|
@ -154,6 +150,10 @@ class AiAgent:
|
||||||
self.client_image = OpenRouter(api_key=api_token_image,
|
self.client_image = OpenRouter(api_key=api_token_image,
|
||||||
x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER)
|
x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER)
|
||||||
|
|
||||||
|
@dataclass()
|
||||||
|
class _ToolsArtifacts:
|
||||||
|
generated_image: Optional[bytes] = None
|
||||||
|
|
||||||
async def get_group_chat_reply(self, bot_id: int, chat_id: int,
|
async def get_group_chat_reply(self, bot_id: int, chat_id: int,
|
||||||
message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
|
message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
|
||||||
context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id)
|
context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id)
|
||||||
|
|
@ -172,9 +172,10 @@ class AiAgent:
|
||||||
response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
|
response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
|
||||||
ai_response = response.content
|
ai_response = response.content
|
||||||
|
|
||||||
tools_call_result = await self._process_tool_calls(bot_id, chat_id,
|
tools_artifacts = AiAgent._ToolsArtifacts()
|
||||||
tool_calls=response.tool_calls, context=context)
|
if response.tool_calls is not None:
|
||||||
if tools_call_result.tools_called:
|
tools_artifacts = await self._process_tool_calls(bot_id, chat_id,
|
||||||
|
tool_calls=response.tool_calls, context=context)
|
||||||
response2 = await self._generate_reply(bot_id, chat_id, context=context)
|
response2 = await self._generate_reply(bot_id, chat_id, context=context)
|
||||||
ai_response = response2.content
|
ai_response = response2.content
|
||||||
|
|
||||||
|
|
@ -185,10 +186,10 @@ class AiAgent:
|
||||||
role="user", text=fwd_message.text, image=fwd_message.image,
|
role="user", text=fwd_message.text, image=fwd_message.image,
|
||||||
message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
||||||
self.db.context_add_message(bot_id, chat_id,
|
self.db.context_add_message(bot_id, chat_id,
|
||||||
role="assistant", text=ai_response, image=tools_call_result.generated_image,
|
role="assistant", text=ai_response, image=tools_artifacts.generated_image,
|
||||||
message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
||||||
|
|
||||||
return Message(text=ai_response, image=tools_call_result.generated_image), True
|
return Message(text=ai_response, image=tools_artifacts.generated_image), True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if str(e).find("Rate limit exceeded") != -1:
|
if str(e).find("Rate limit exceeded") != -1:
|
||||||
|
|
@ -212,19 +213,20 @@ class AiAgent:
|
||||||
context.append(_serialize_assistant_message(response))
|
context.append(_serialize_assistant_message(response))
|
||||||
ai_response = response.content
|
ai_response = response.content
|
||||||
|
|
||||||
tools_call_result = await self._process_tool_calls(bot_id, chat_id,
|
tools_artifacts = AiAgent._ToolsArtifacts()
|
||||||
tool_calls=response.tool_calls, context=context)
|
if response.tool_calls is not None:
|
||||||
if tools_call_result.tools_called:
|
tools_artifacts = await self._process_tool_calls(bot_id, chat_id,
|
||||||
|
tool_calls=response.tool_calls, context=context)
|
||||||
response2 = await self._generate_reply(bot_id, chat_id, context=context)
|
response2 = await self._generate_reply(bot_id, chat_id, context=context)
|
||||||
ai_response = response2.content
|
ai_response = response2.content
|
||||||
|
|
||||||
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
|
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
|
||||||
message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
||||||
self.db.context_add_message(bot_id, chat_id, role="assistant",
|
self.db.context_add_message(bot_id, chat_id, role="assistant",
|
||||||
text=ai_response, image=tools_call_result.generated_image,
|
text=ai_response, image=tools_artifacts.generated_image,
|
||||||
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
||||||
|
|
||||||
return Message(text=ai_response, image=tools_call_result.generated_image), True
|
return Message(text=ai_response, image=tools_artifacts.generated_image), True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if str(e).find("Rate limit exceeded") != -1:
|
if str(e).find("Rate limit exceeded") != -1:
|
||||||
|
|
@ -273,39 +275,52 @@ class AiAgent:
|
||||||
return response.choices[0].message
|
return response.choices[0].message
|
||||||
|
|
||||||
async def _process_tool_calls(self, bot_id: int, chat_id: int, tool_calls: List[ChatMessageToolCall],
|
async def _process_tool_calls(self, bot_id: int, chat_id: int, tool_calls: List[ChatMessageToolCall],
|
||||||
context: List[MessageTypedDict]) -> ToolCallResult:
|
context: List[MessageTypedDict]) -> _ToolsArtifacts:
|
||||||
result = AiAgent.ToolCallResult()
|
artifacts = AiAgent._ToolsArtifacts()
|
||||||
if tool_calls is not None:
|
if tool_calls is None:
|
||||||
for tool_call in tool_calls:
|
return artifacts
|
||||||
tool_name = tool_call.function.name
|
|
||||||
tool_args = json.loads(tool_call.function.arguments)
|
|
||||||
if tool_name == "generate_image":
|
|
||||||
prompt = tool_args.get("prompt", "")
|
|
||||||
aspect_ratio = tool_args.get("aspect_ratio", None)
|
|
||||||
result.generated_image, success = \
|
|
||||||
await self._generate_image(bot_id, chat_id, prompt=prompt, aspect_ratio=aspect_ratio)
|
|
||||||
tool_result_content = []
|
|
||||||
if success:
|
|
||||||
tool_result_content.append(
|
|
||||||
{"type": "text",
|
|
||||||
"text": "Изображение сгенерировано и будет показано пользователю."})
|
|
||||||
tool_result_content.append(
|
|
||||||
{"type": "image_url", "image_url": {"url": _encode_image(result.generated_image)}})
|
|
||||||
else:
|
|
||||||
tool_result_content.append(
|
|
||||||
{"type": "text",
|
|
||||||
"text": "Не удалось сгенерировать изображение."})
|
|
||||||
context.append({
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": tool_call.id,
|
|
||||||
"content": tool_result_content
|
|
||||||
})
|
|
||||||
result.tools_called = True
|
|
||||||
break
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def _generate_image(self, bot_id: int, chat_id: int, prompt: str, aspect_ratio: Optional[str]) \
|
functions_map: Dict[str,
|
||||||
-> Tuple[Optional[bytes], bool]:
|
Callable[[int, int, Dict, AiAgent._ToolsArtifacts],
|
||||||
|
Awaitable[List[ChatMessageContentItemTypedDict]]]] = {
|
||||||
|
"generate_image": self._process_tool_generate_image
|
||||||
|
}
|
||||||
|
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
tool_name = tool_call.function.name
|
||||||
|
tool_args = json.loads(tool_call.function.arguments)
|
||||||
|
if tool_name in functions_map:
|
||||||
|
tool_result = await functions_map[tool_name](bot_id, chat_id, tool_args, artifacts)
|
||||||
|
context.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tool_call.id,
|
||||||
|
"content": tool_result
|
||||||
|
})
|
||||||
|
artifacts.tools_called = True
|
||||||
|
return artifacts
|
||||||
|
|
||||||
|
async def _process_tool_generate_image(self, bot_id: int, chat_id: int, args: dict, artifacts: _ToolsArtifacts) \
|
||||||
|
-> List[ChatMessageContentItemTypedDict]:
|
||||||
|
prompt = args.get("prompt", "")
|
||||||
|
aspect_ratio = args.get("aspect_ratio", None)
|
||||||
|
result = await self._generate_image(bot_id, chat_id, prompt=prompt, aspect_ratio=aspect_ratio)
|
||||||
|
|
||||||
|
content = []
|
||||||
|
if result.is_ok():
|
||||||
|
content.append(
|
||||||
|
{"type": "text",
|
||||||
|
"text": "Изображение сгенерировано и будет показано пользователю."})
|
||||||
|
content.append(
|
||||||
|
{"type": "image_url", "image_url": {"url": _encode_image(result.ok_value)}})
|
||||||
|
artifacts.generated_image = result.ok_value
|
||||||
|
else:
|
||||||
|
content.append(
|
||||||
|
{"type": "text",
|
||||||
|
"text": f"Не удалось сгенерировать изображение: {result.err_value}"})
|
||||||
|
return content
|
||||||
|
|
||||||
|
async def _generate_image(self, bot_id: int, chat_id: int,
|
||||||
|
prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]:
|
||||||
print(f"Генерация изображения: {prompt}")
|
print(f"Генерация изображения: {prompt}")
|
||||||
context = [{"role": "user", "content": prompt}]
|
context = [{"role": "user", "content": prompt}]
|
||||||
|
|
||||||
|
|
@ -329,9 +344,10 @@ class AiAgent:
|
||||||
image.save(output, format="JPEG", quality=80, optimize=True)
|
image.save(output, format="JPEG", quality=80, optimize=True)
|
||||||
image_bytes = output.getvalue()
|
image_bytes = output.getvalue()
|
||||||
|
|
||||||
return image_bytes, True
|
return Ok(image_bytes)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
return None, False
|
print(f"Ошибка генерации изображения: {e}")
|
||||||
|
return Err(str(e))
|
||||||
|
|
||||||
|
|
||||||
agent: AiAgent
|
agent: AiAgent
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue