vk_chat_bot/ai_agent.py
Kirill Kirilenko 997bc39a22 Системные запросы и описание инструментов вынесены в отдельные текстовые файлы.
Модель генерации изображений жестко прописана в коде для оптимизации составления запросов.
Добавлено удаление из ответа ИИ тегов <image>.
2026-02-21 00:50:33 +03:00

366 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import aiohttp
import base64
import datetime
import json
from collections.abc import Callable
from dataclasses import dataclass
from io import BytesIO
from PIL import Image
from result import Ok, Err, Result
from typing import List, Tuple, Any, Optional, Union, Dict, Awaitable
from openrouter import OpenRouter, RetryConfig
from openrouter.components import AssistantMessage, AssistantMessageTypedDict, ChatMessageContentItemTypedDict, \
ChatMessageToolCall, MessageTypedDict
from openrouter.errors import ResponseValidationError, ChatError
from openrouter.utils import BackoffStrategy
from fal_client import AsyncClient as FalClient
from database import BasicDatabase
# Attribution values handed to the OpenRouter client (x_title / http_referer).
OPENROUTER_X_TITLE: str = "TG/VK Chat Bot"
OPENROUTER_HTTP_REFERER: str = "https://ultracoder.org"

# Passed as ``max_messages`` to db.context_add_message for group / private chats.
GROUP_CHAT_MAX_MESSAGES: int = 20
PRIVATE_CHAT_MAX_MESSAGES: int = 40

# Passed as ``max_tokens`` to every chat completion request.
MAX_OUTPUT_TOKENS: int = 500
@dataclass()
class Message:
    """A chat message passed to or returned by the AI agent.

    Every field is optional: a message may carry only text, only an image, or both.
    """
    user_name: Optional[str] = None  # author name; used in the text prefix for group chats
    text: Optional[str] = None
    image: Optional[bytes] = None  # raw image bytes
    message_id: Optional[int] = None  # id stored with the context entry; None for generated replies
def _add_message_prefix(text: Optional[str], username: Optional[str] = None) -> str:
    """Prepend a "[dd.mm.YYYY HH:MM]" or "[dd.mm.YYYY HH:MM, username]" header to ``text``.

    The timestamp is the current local time. When ``text`` is None only the header
    followed by a colon is returned.
    """
    stamp = datetime.datetime.now().strftime("%d.%m.%Y %H:%M")
    if username is None:
        header = f"[{stamp}]"
    else:
        header = f"[{stamp}, {username}]"
    if text is None:
        return f"{header}:"
    return f"{header}: {text}"
def _encode_image(image: bytes) -> str:
    """Encode raw image bytes as a base64 ``data:`` URL.

    The MIME type is sniffed from the file signature (PNG, GIF, WebP). Anything
    else, including JPEG, is labelled ``image/jpeg`` as before, so existing JPEG
    callers get the same URL.
    """
    if image.startswith(b"\x89PNG\r\n\x1a\n"):
        mime_type = "image/png"
    elif image.startswith((b"GIF87a", b"GIF89a")):
        mime_type = "image/gif"
    elif image[:4] == b"RIFF" and image[8:12] == b"WEBP":
        mime_type = "image/webp"
    else:
        mime_type = "image/jpeg"
    encoded_image = base64.b64encode(image).decode('utf-8')
    return f"data:{mime_type};base64,{encoded_image}"
def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict:
    """Build a chat-completion message dict with a list of content parts.

    A text part is emitted first when ``text`` is given, then an ``image_url`` part
    (base64 data URL) when ``image`` is given. With neither, ``content`` is empty.
    """
    parts = [{"type": "text", "text": text}] if text is not None else []
    if image is not None:
        image_part = {"type": "image_url", "image_url": {"url": _encode_image(image)}}
        parts.append(image_part)
    return {"role": role, "content": parts}
def _remove_none_recursive(data: Union[dict, list, Any]) -> Union[dict, list, Any]:
    """Return ``data`` with every None removed, recursing into dicts and lists.

    Dict entries whose value is None are dropped. None items are removed from
    lists, which shifts the remaining items. Any other value is returned unchanged.
    """
    if isinstance(data, dict):
        return {k: _remove_none_recursive(v) for k, v in data.items() if v is not None}
    if isinstance(data, list):
        return [_remove_none_recursive(item) for item in data if item is not None]
    return data
def _serialize_assistant_message(message: AssistantMessage) -> AssistantMessageTypedDict:
    """Convert an SDK assistant message into a plain dict for the request context.

    The model is dumped by field alias and every None value is stripped, so unset
    optional fields are not sent back.
    """
    as_dict = message.model_dump(by_alias=True)
    return _remove_none_recursive(as_dict)
class AiAgent:
    """Produces chat replies through OpenRouter, with fal.ai image generation as a tool.

    Per-chat message context and custom prompts are read from and written to ``db``.
    System prompts and the tool schema are loaded from files under ``prompts/``.
    """

    def __init__(self,
                 openrouter_token: str, openrouter_model: str,
                 fal_token: str,
                 db: BasicDatabase,
                 platform: str):
        """
        :param openrouter_token: OpenRouter API key.
        :param openrouter_model: model identifier passed to every chat completion request.
        :param fal_token: fal.ai API key used by the image generation tool.
        :param db: storage for bots, chats and their message context.
        :param platform: 'tg' is presented to the model as Telegram, anything else as VK;
            it is also part of the ``user`` identifier sent to OpenRouter.
        :raises OSError: if a prompt file cannot be read.
        """
        # Exponential backoff for failed OpenRouter requests, including connection errors
        # (intervals presumably in milliseconds - TODO confirm against the SDK docs).
        retry_config = RetryConfig(strategy="backoff",
                                   backoff=BackoffStrategy(
                                       initial_interval=2000, max_interval=8000, exponent=2,
                                       max_elapsed_time=14000),
                                   retry_connection_errors=True)
        self.db = db
        self.openrouter_model = openrouter_model
        # Hard-coded so that the image-generation prompt can be tailored to this model.
        self.fal_model = "fal-ai/bytedance/seedream/v4.5/text-to-image"
        self.platform = platform
        self._load_prompts()
        self.client_openrouter = OpenRouter(api_key=openrouter_token,
                                            x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER,
                                            retry_config=retry_config)
        self.client_fal = FalClient(key=fal_token)

    @dataclass()
    class _ToolsArtifacts:
        """Side results produced while executing the tool calls of one reply."""
        generated_image: Optional[bytes] = None  # JPEG bytes from generate_image, if it succeeded
        tools_called: bool = False  # True once at least one known tool has been executed

    async def get_group_chat_reply(self, bot_id: int, chat_id: int,
                                   message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
        """Produce a reply to a group chat message and its quoted (forwarded) messages.

        ``message.text`` is prefixed with a timestamp and the author name, and each
        forwarded message's ``text`` is wrapped in a quote header. Both are modified
        in place. On success the incoming messages and the reply are appended to the
        stored context.

        :returns: (reply, True) on success, or (apology message, False) on any error.
        """
        context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id)
        message.text = _add_message_prefix(message.text, message.user_name)
        context.append(_serialize_message(role="user", text=message.text, image=message.image))
        for fwd_message in forwarded_messages:
            quoted_text = '<Цитируемое сообщение от {}>'.format(fwd_message.user_name)
            if fwd_message.text is not None:
                quoted_text += '\n' + fwd_message.text
            fwd_message.text = quoted_text
            context.append(_serialize_message(role="user", text=fwd_message.text, image=fwd_message.image))
        try:
            ai_response, tools_artifacts = await self._complete_with_tools(bot_id, chat_id, context)
            self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
                                        message_id=message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
            for fwd_message in forwarded_messages:
                self.db.context_add_message(bot_id, chat_id,
                                            role="user", text=fwd_message.text, image=fwd_message.image,
                                            message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
            self.db.context_add_message(bot_id, chat_id,
                                        role="assistant", text=ai_response, image=tools_artifacts.generated_image,
                                        message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)
            return Message(text=ai_response, image=tools_artifacts.generated_image), True
        except Exception as e:
            # Top-level boundary: any failure becomes a user-facing apology.
            return AiAgent._error_reply(e)

    async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: Message) -> Tuple[Message, bool]:
        """Produce a reply to a private chat message.

        ``message.text`` is prefixed with a timestamp (no author name) in place.
        On success the message and the reply are appended to the stored context.

        :returns: (reply, True) on success, or (apology message, False) on any error.
        """
        context = self._get_chat_context(is_group_chat=False, bot_id=bot_id, chat_id=chat_id)
        message.text = _add_message_prefix(message.text)
        context.append(_serialize_message(role="user", text=message.text, image=message.image))
        try:
            ai_response, tools_artifacts = await self._complete_with_tools(bot_id, chat_id, context)
            self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
                                        message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
            self.db.context_add_message(bot_id, chat_id, role="assistant",
                                        text=ai_response, image=tools_artifacts.generated_image,
                                        message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
            return Message(text=ai_response, image=tools_artifacts.generated_image), True
        except Exception as e:
            # Top-level boundary: any failure becomes a user-facing apology.
            return AiAgent._error_reply(e)

    def get_last_assistant_message_id(self, bot_id: int, chat_id: int):
        """Delegate to ``db.context_get_last_assistant_message_id``."""
        return self.db.context_get_last_assistant_message_id(bot_id, chat_id)

    def set_last_response_id(self, bot_id: int, chat_id: int, message_id: int):
        """Delegate to ``db.context_set_last_message_id``."""
        self.db.context_set_last_message_id(bot_id, chat_id, message_id)

    def clear_chat_context(self, bot_id: int, chat_id: int):
        """Delegate to ``db.context_clear``."""
        self.db.context_clear(bot_id, chat_id)

    #######################################

    async def _complete_with_tools(self, bot_id: int, chat_id: int,
                                   context: List[MessageTypedDict]) -> Tuple[Optional[str], _ToolsArtifacts]:
        """Request a reply, running any tool calls the model asks for first.

        ``context`` is extended in place. The assistant message is appended before the
        tool results, because a ``tool`` message must follow the assistant message that
        carries the matching ``tool_calls``.

        :returns: (reply text, or None if the model returned no text; tool artifacts).
        """
        response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
        context.append(_serialize_assistant_message(response))
        if not response.tool_calls:
            return response.content, AiAgent._ToolsArtifacts()
        tools_artifacts = await self._process_tool_calls(bot_id, chat_id,
                                                         tool_calls=response.tool_calls, context=context)
        # Second request without tools so the model must answer in text.
        follow_up = await self._generate_reply(bot_id, chat_id, context=context)
        return follow_up.content, tools_artifacts

    @staticmethod
    def _error_reply(e: Exception) -> Tuple[Message, bool]:
        """Map a failure to the apology shown to the user (rate limits get a dedicated text)."""
        if "Rate limit exceeded" in str(e):
            return Message(text="Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."), False
        print(f"Ошибка выполнения запроса к ИИ: {e}")
        return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False

    def _get_chat_context(self, is_group_chat: bool, bot_id: int, chat_id: int) -> List[MessageTypedDict]:
        """Build the request context: one system message followed by the stored messages.

        The system prompt is the group or private base prompt (with ``{platform}``
        substituted), then the image-generation prompt, then the bot's and the chat's
        custom ``ai_prompt`` when set. Calls ``db.create_chat_if_not_exists``.
        """
        prompt = self.system_prompt_group_chat if is_group_chat else self.system_prompt_private_chat
        prompt = prompt.replace('{platform}', 'Telegram' if self.platform == 'tg' else 'VK')
        prompt += '\n' + self.system_prompt_image_generation
        bot = self.db.get_bot(bot_id)
        if bot['ai_prompt'] is not None:
            prompt += '\n' + bot['ai_prompt'] + '\n'
        chat = self.db.create_chat_if_not_exists(bot_id, chat_id)
        if chat['ai_prompt'] is not None:
            prompt += '\n' + chat['ai_prompt']
        messages = self.db.context_get_messages(bot_id, chat_id)
        context: List[MessageTypedDict] = [{"role": "system", "content": prompt}]
        for message in messages:
            context.append(_serialize_message(message["role"], message["text"], message["image"]))
        return context

    async def _generate_reply(self, bot_id: int, chat_id: int,
                              context: List[MessageTypedDict], allow_tools: bool = False) -> AssistantMessage:
        """Send ``context`` to the model and return the first choice's message, filtered.

        Tools from ``prompts/tools.json`` are offered only when ``allow_tools`` is True.
        """
        response = await self._async_chat_completion_request(
            model=self.openrouter_model,
            messages=context,
            tools=self.tools_description if allow_tools else None,
            tool_choice="auto" if allow_tools else None,
            max_tokens=MAX_OUTPUT_TOKENS,
            user=f'{self.platform}_{bot_id}_{chat_id}'
        )
        return self._filter_response(response.choices[0].message)

    async def _process_tool_calls(self, bot_id: int, chat_id: int, tool_calls: List[ChatMessageToolCall],
                                  context: List[MessageTypedDict]) -> _ToolsArtifacts:
        """Execute the requested tool calls and append one ``tool`` message per call to ``context``.

        Every call gets a reply, which is an error text for unknown tools or malformed
        arguments. An unanswered ``tool_call_id`` would make the follow-up request invalid.
        """
        artifacts = AiAgent._ToolsArtifacts()
        if tool_calls is None:
            return artifacts
        functions_map: Dict[str,
                            Callable[[int, int, Dict, AiAgent._ToolsArtifacts],
                                     Awaitable[List[ChatMessageContentItemTypedDict]]]] = {
            "generate_image": self._process_tool_generate_image
        }
        for tool_call in tool_calls:
            tool_name = tool_call.function.name
            handler = functions_map.get(tool_name)
            tool_args = AiAgent._parse_tool_arguments(tool_call.function.arguments)
            if handler is None:
                tool_result = [{"type": "text", "text": f"Неизвестный инструмент: {tool_name}"}]
            elif tool_args is None:
                tool_result = [{"type": "text", "text": f"Некорректные аргументы инструмента {tool_name}."}]
            else:
                tool_result = await handler(bot_id, chat_id, tool_args, artifacts)
                artifacts.tools_called = True
            context.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": tool_result
            })
        return artifacts

    @staticmethod
    def _parse_tool_arguments(arguments: Optional[str]) -> Optional[dict]:
        """Decode a tool call's JSON arguments. Empty or missing arguments mean ``{}``.

        :returns: the decoded dict, or None if the arguments are not a JSON object.
        """
        if not arguments:
            return {}
        try:
            parsed = json.loads(arguments)
        except json.JSONDecodeError:
            return None
        return parsed if isinstance(parsed, dict) else None

    async def _process_tool_generate_image(self, _bot_id: int, _chat_id: int, args: dict, artifacts: _ToolsArtifacts) \
            -> List[ChatMessageContentItemTypedDict]:
        """Handle the ``generate_image`` tool: produce an image and describe the outcome to the model.

        On success the image is stored in ``artifacts.generated_image`` and also returned
        to the model as an ``image_url`` part.
        """
        prompt = args.get("prompt", "")
        aspect_ratio = args.get("aspect_ratio", None)
        result = await self._generate_image(prompt=prompt, aspect_ratio=aspect_ratio)
        content = []
        if result.is_ok():
            content.append(
                {"type": "text",
                 "text": "Изображение сгенерировано и будет показано пользователю."})
            content.append(
                {"type": "image_url", "image_url": {"url": _encode_image(result.ok_value)}})
            artifacts.generated_image = result.ok_value
        else:
            content.append(
                {"type": "text",
                 "text": f"Не удалось сгенерировать изображение: {result.err_value}"})
        return content

    async def _generate_image(self, prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]:
        """Generate an image with fal.ai and return it as JPEG bytes.

        :param prompt: text prompt for the image model.
        :param aspect_ratio: one of "1:1", "4:3", "3:4", "16:9", "9:16". Any other
            value, including None, falls back to 1280x1024.
        :returns: Ok(JPEG bytes), or Err(error text). Errors are reported, not raised.
        """
        # NOTE(review): 1280x1024 is 5:4 rather than 4:3 (and 1024x1280 is 4:5); presumably
        # chosen to fit the model's size limits - confirm before changing.
        aspect_ratio_resolution_map = {
            "1:1": (1280, 1280),
            "4:3": (1280, 1024),
            "3:4": (1024, 1280),
            "16:9": (1280, 720),
            "9:16": (720, 1280)
        }
        width, height = aspect_ratio_resolution_map.get(aspect_ratio, (1280, 1024))
        print(f"Генерация изображения {width}x{height}: {prompt}")
        arguments = {
            "prompt": prompt,
            "image_size": {"width": width, "height": height},
            "enable_safety_checker": False
        }
        try:
            result = await self.client_fal.run(self.fal_model, arguments=arguments)
            images = result.get("images") if isinstance(result, dict) else None
            if not images:
                raise RuntimeError("Неожиданный ответ от сервера.")
            image_url = images[0]["url"]
            async with aiohttp.ClientSession() as session:
                async with session.get(image_url) as response:
                    if response.status != 200:
                        raise RuntimeError(f"Не удалось загрузить изображение ({response.status}).")
                    image_bytes = await response.read()
            # Re-encode anything that is not already JPEG. The check uses the JPEG SOI
            # marker instead of the URL suffix, which may be missing or followed by a query.
            if not image_bytes.startswith(b"\xff\xd8\xff"):
                image_bytes = AiAgent._convert_to_jpeg(image_bytes)
            return Ok(image_bytes)
        except Exception as e:
            # Deliberate best-effort: the error text is returned to the model as a tool result.
            print(f"Ошибка генерации изображения: {e}")
            return Err(str(e))

    @staticmethod
    def _convert_to_jpeg(image_bytes: bytes) -> bytes:
        """Re-encode Pillow-readable image bytes as an RGB JPEG (quality 80, optimized)."""
        with Image.open(BytesIO(image_bytes)) as image:
            rgb_image = image.convert("RGB")
        output = BytesIO()
        rgb_image.save(output, format="JPEG", quality=80, optimize=True)
        return output.getvalue()

    async def _async_chat_completion_request(self, **kwargs):
        """Call the OpenRouter chat endpoint, surfacing the provider's own error message.

        Workaround for the OpenRouter SDK:
        https://github.com/OpenRouterTeam/python-sdk/issues/44
        When the provider's raw error message can be extracted from the error body, a
        RuntimeError with that message is raised, chained to the original. Otherwise the
        original exception is re-raised.
        """
        try:
            return await self.client_openrouter.chat.send_async(**kwargs)
        except ResponseValidationError as e:
            message = AiAgent._extract_provider_error(e.body)
            if message is not None:
                raise RuntimeError(message) from e
            raise
        except ChatError as e:
            if e.message == "Provider returned error":
                message = AiAgent._extract_provider_error(e.body)
                if message is not None:
                    raise RuntimeError(message) from e
            raise

    @staticmethod
    def _extract_provider_error(body) -> Optional[str]:
        """Return ``error.message`` from the JSON in ``body["error"]["metadata"]["raw"]``, or None."""
        try:
            parsed_body = json.loads(body)
            raw_response = json.loads(parsed_body["error"]["metadata"]["raw"])
            return str(raw_response["error"]["message"])
        except (ValueError, KeyError, TypeError):
            # Body or raw payload is not in the expected shape; caller keeps the original error.
            return None

    @staticmethod
    def _filter_response(response: AssistantMessage) -> AssistantMessage:
        """Strip ``<image>`` tags from the reply text in place and return the message.

        A None content (e.g. a reply that only carries tool calls) is left as None
        instead of becoming the literal string "None".
        """
        content = response.content
        if content is None:
            return response
        text = content if isinstance(content, str) else str(content)
        response.content = text.replace("<image>", "")
        return response

    def _load_prompts(self):
        """Read the system prompts and the tool schema from ``prompts/`` (relative to the CWD).

        Files are read as UTF-8 explicitly, because the prompts are not ASCII and the
        platform default encoding may differ.
        """
        with open("prompts/group_chat.txt", "r", encoding="utf-8") as f:
            self.system_prompt_group_chat = f.read()
        with open("prompts/private_chat.txt", "r", encoding="utf-8") as f:
            self.system_prompt_private_chat = f.read()
        with open("prompts/image_generation.txt", "r", encoding="utf-8") as f:
            self.system_prompt_image_generation = f.read()
        with open("prompts/tools.json", "r", encoding="utf-8") as f:
            self.tools_description = json.load(f)
# Module-level agent instance, assigned by create_ai_agent().
agent: AiAgent


def create_ai_agent(openrouter_token: str, openrouter_model: str, fal_token: str,
                    db: BasicDatabase, platform: str) -> AiAgent:
    """Create an AiAgent, store it in the module-level ``agent`` and return it.

    :raises OSError: propagated from AiAgent if a prompt file cannot be read.
    """
    global agent
    agent = AiAgent(openrouter_token, openrouter_model, fal_token, db, platform)
    return agent