vk_chat_bot/ai_agent.py

344 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
import datetime
import json
from dataclasses import dataclass
from io import BytesIO
from PIL import Image
from typing import List, Tuple, Any, Optional, Union
from openrouter import OpenRouter, RetryConfig
from openrouter.components import AssistantMessage, ChatMessageToolCall, \
MessageTypedDict, ToolDefinitionJSONTypedDict, AssistantMessageTypedDict
from openrouter.utils import BackoffStrategy
from database import BasicDatabase
GROUP_CHAT_SYSTEM_PROMPT = """
Ты - ИИ-помощник в групповом чате.\n
Отвечай на вопросы и поддерживай контекст беседы.\n
Ты не можешь обсуждать политику и религию.\n
Сообщения пользователей будут приходить в следующем формате: '[дата время, имя]: текст сообщения'\n
При ответе НЕ нужно указывать ни время, ни пользователя, которому предназначен ответ, ни свое имя.\n
НЕ используй разметку Markdown, она не поддерживается мессенджером.\n
Если нужно нарисовать изображение, используй вызов инструмента.
Запрещено генерировать ASCII-арты, а также добавлять теги/маркеры вроде <image>, [image].
"""
PRIVATE_CHAT_SYSTEM_PROMPT = """
Ты - ИИ-помощник в чате c пользователем.\n
Отвечай на вопросы и поддерживай контекст беседы.\n
Сообщения пользователя будут приходить в следующем формате: '[дата время]: текст сообщения'\n
При ответе НЕ нужно указывать время.\n
НЕ используй разметку Markdown, она не поддерживается мессенджером.\n
Если нужно нарисовать изображение, используй вызов инструмента.
Запрещено генерировать ASCII-арты, а также добавлять теги/маркеры вроде <image>, [image].
"""
GENERATE_IMAGE_TOOL_DESCRIPTION = """
Генерация изображения по описанию.
Используй этот инструмент, если пользователь просит сгенерировать изображение ('нарисуй', 'покажи' и т.п.),
или если это улучшит ответ (например, в ролевой игре для визуализации сцены).
"""
GENERATE_IMAGE_TOOL_PROMPT_ARG_DESCRIPTION = """
Детальное описание на русском языке.
Добавь детали для стиля, цвета, композиции, если нужно.
"""
OPENROUTER_X_TITLE = "TG/VK Chat Bot"
OPENROUTER_HTTP_REFERER = "https://ultracoder.org"
GROUP_CHAT_MAX_MESSAGES = 20
PRIVATE_CHAT_MAX_MESSAGES = 40
MAX_OUTPUT_TOKENS = 500
@dataclass()
class Message:
user_name: str = None
text: str = None
image: bytes = None
message_id: int = None
def _add_message_prefix(text: Optional[str], username: Optional[str] = None) -> str:
current_time = datetime.datetime.now().strftime("%d.%m.%Y %H:%M")
prefix = f"[{current_time}, {username}]" if username is not None else f"[{current_time}]"
return f"{prefix}: {text}" if text is not None else f"{prefix}:"
def _encode_image(image: bytes) -> str:
encoded_image = base64.b64encode(image).decode('utf-8')
return f"data:image/jpeg;base64,{encoded_image}"
def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict:
serialized = {"role": role, "content": []}
if text is not None:
serialized["content"].append({"type": "text", "text": text})
if image is not None:
serialized["content"].append({"type": "image_url", "image_url": {"url": _encode_image(image)}})
return serialized
def _remove_none_recursive(data: Union[dict, list, any]) -> Union[dict, list, any]:
if isinstance(data, dict):
return {
k: _remove_none_recursive(v)
for k, v in data.items()
if v is not None
}
elif isinstance(data, list):
return [
_remove_none_recursive(item)
for item in data
if item is not None
]
else:
return data
def _serialize_assistant_message(message: AssistantMessage) -> AssistantMessageTypedDict:
return _remove_none_recursive(message.model_dump(by_alias=True))
def _get_tools_description() -> List[ToolDefinitionJSONTypedDict]:
return [{
"type": "function",
"function": {
"name": "generate_image",
"description": GENERATE_IMAGE_TOOL_DESCRIPTION,
"parameters": {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": GENERATE_IMAGE_TOOL_PROMPT_ARG_DESCRIPTION
},
"aspect_ratio": {
"type": "string",
"enum": ["1:1", "3:4", "4:3", "9:16", "16:9"],
"description": "Соотношение сторон (опционально)."
}
},
"required": ["prompt"]
}
}
}]
class AiAgent:
@dataclass()
class ToolCallResult:
tools_called: bool = False
generated_image: Optional[bytes] = None
def __init__(self,
api_token_main: str, model_main: str,
api_token_image: str, model_image: str,
db: BasicDatabase,
platform: str):
retry_config = RetryConfig(strategy="backoff",
backoff=BackoffStrategy(
initial_interval=2000, max_interval=8000, exponent=2, max_elapsed_time=14000),
retry_connection_errors=True)
self.db = db
self.model_main = model_main
self.model_image = model_image
self.platform = platform
self.client = OpenRouter(api_key=api_token_main,
x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER,
retry_config=retry_config)
self.client_image = OpenRouter(api_key=api_token_image,
x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER)
async def get_group_chat_reply(self, bot_id: int, chat_id: int,
message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id)
message.text = _add_message_prefix(message.text, message.user_name)
context.append(_serialize_message(role="user", text=message.text, image=message.image))
for fwd_message in forwarded_messages:
message_text = '<Цитируемое сообщение от {}>'.format(fwd_message.user_name)
if fwd_message.text is not None:
message_text += '\n' + fwd_message.text
fwd_message.text = message_text
context.append(_serialize_message(role="user", text=fwd_message.text, image=fwd_message.image))
try:
response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
ai_response = response.content
tools_call_result = await self._process_tool_calls(bot_id, chat_id,
tool_calls=response.tool_calls, context=context)
if tools_call_result.tools_called:
response2 = await self._generate_reply(bot_id, chat_id, context=context)
ai_response = response2.content
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
message_id=message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
for fwd_message in forwarded_messages:
self.db.context_add_message(bot_id, chat_id,
role="user", text=fwd_message.text, image=fwd_message.image,
message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
self.db.context_add_message(bot_id, chat_id,
role="assistant", text=ai_response, image=tools_call_result.generated_image,
message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)
return Message(text=ai_response, image=tools_call_result.generated_image), True
except Exception as e:
if str(e).find("Rate limit exceeded") != -1:
return Message(text="Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."), False
else:
print(f"Ошибка выполнения запроса к ИИ: {e}")
return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False
async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: Message) -> Tuple[Message, bool]:
context = self._get_chat_context(is_group_chat=False, bot_id=bot_id, chat_id=chat_id)
message.text = _add_message_prefix(message.text)
content: list[dict[str, Any]] = []
if message.text is not None:
content.append({"type": "text", "text": message.text})
if message.image is not None:
content.append({"type": "image_url", "image_url": {"url": _encode_image(message.image)}})
context.append({"role": "user", "content": content})
try:
response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
context.append(_serialize_assistant_message(response))
ai_response = response.content
tools_call_result = await self._process_tool_calls(bot_id, chat_id,
tool_calls=response.tool_calls, context=context)
if tools_call_result.tools_called:
response2 = await self._generate_reply(bot_id, chat_id, context=context)
ai_response = response2.content
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
self.db.context_add_message(bot_id, chat_id, role="assistant",
text=ai_response, image=tools_call_result.generated_image,
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
return Message(text=ai_response, image=tools_call_result.generated_image), True
except Exception as e:
if str(e).find("Rate limit exceeded") != -1:
return Message(text="Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."), False
else:
print(f"Ошибка выполнения запроса к ИИ: {e}")
return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False
def get_last_assistant_message_id(self, bot_id: int, chat_id: int):
return self.db.context_get_last_assistant_message_id(bot_id, chat_id)
def set_last_response_id(self, bot_id: int, chat_id: int, message_id: int):
self.db.context_set_last_message_id(bot_id, chat_id, message_id)
def clear_chat_context(self, bot_id: int, chat_id: int):
self.db.context_clear(bot_id, chat_id)
def _get_chat_context(self, is_group_chat: bool, bot_id: int, chat_id: int) -> List[MessageTypedDict]:
prompt = GROUP_CHAT_SYSTEM_PROMPT if is_group_chat else PRIVATE_CHAT_SYSTEM_PROMPT
bot = self.db.get_bot(bot_id)
if bot['ai_prompt'] is not None:
prompt += '\n\n' + bot['ai_prompt']
chat = self.db.create_chat_if_not_exists(bot_id, chat_id)
if chat['ai_prompt'] is not None:
prompt += '\n\n' + chat['ai_prompt']
messages = self.db.context_get_messages(bot_id, chat_id)
context: List[MessageTypedDict] = [{"role": "system", "content": prompt}]
for message in messages:
context.append(_serialize_message(message["role"], message["text"], message["image"]))
return context
async def _generate_reply(self, bot_id: int, chat_id: int,
context: List[MessageTypedDict], allow_tools: bool = False) -> AssistantMessage:
response = await self.client.chat.send_async(
model=self.model_main,
messages=context,
tools=_get_tools_description() if allow_tools else None,
tool_choice="auto" if allow_tools else None,
max_tokens=MAX_OUTPUT_TOKENS,
user=f'{self.platform}_{bot_id}_{chat_id}'
)
return response.choices[0].message
async def _process_tool_calls(self, bot_id: int, chat_id: int, tool_calls: List[ChatMessageToolCall],
context: List[MessageTypedDict]) -> ToolCallResult:
result = AiAgent.ToolCallResult()
if tool_calls is not None:
for tool_call in tool_calls:
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
if tool_name == "generate_image":
prompt = tool_args.get("prompt", "")
aspect_ratio = tool_args.get("aspect_ratio", None)
result.generated_image, success = \
await self._generate_image(bot_id, chat_id, prompt=prompt, aspect_ratio=aspect_ratio)
tool_result_content = []
if success:
tool_result_content.append(
{"type": "text",
"text": "Изображение сгенерировано и будет показано пользователю."})
tool_result_content.append(
{"type": "image_url", "image_url": {"url": _encode_image(result.generated_image)}})
else:
tool_result_content.append(
{"type": "text",
"text": "Не удалось сгенерировать изображение."})
context.append({
"role": "tool",
"tool_call_id": tool_call.id,
"content": tool_result_content
})
result.tools_called = True
break
return result
async def _generate_image(self, bot_id: int, chat_id: int, prompt: str, aspect_ratio: Optional[str]) \
-> Tuple[Optional[bytes], bool]:
print(f"Генерация изображения: {prompt}")
context = [{"role": "user", "content": prompt}]
try:
response = await self.client_image.chat.send_async(
model=self.model_image,
messages=context,
user=f'{self.platform}_{bot_id}_{chat_id}',
modalities=["image"],
image_config={"aspect_ratio": aspect_ratio} if aspect_ratio is not None else None
)
image_url = response.choices[0].message.images[0].image_url.url
header, image_base64 = image_url.split(",", 1)
mime_type = header.split(";")[0].replace("data:", "")
image_bytes = base64.b64decode(image_base64)
if mime_type != "image/jpeg":
image = Image.open(BytesIO(image_bytes)).convert("RGB")
output = BytesIO()
image.save(output, format="JPEG", quality=80, optimize=True)
image_bytes = output.getvalue()
return image_bytes, True
except Exception:
return None, False
agent: AiAgent
def create_ai_agent(api_token_main: str, model_main: str,
api_token_image: str, model_image: str,
db: BasicDatabase, platform: str):
global agent
agent = AiAgent(api_token_main, model_main, api_token_image, model_image, db, platform)