import base64 import datetime import json from collections.abc import Callable from dataclasses import dataclass from io import BytesIO from PIL import Image from result import Ok, Err, Result from typing import List, Tuple, Any, Optional, Union, Dict, Awaitable from openrouter import OpenRouter, RetryConfig from openrouter.components import AssistantMessage, AssistantMessageTypedDict, ChatMessageContentItemTypedDict, \ ChatMessageToolCall, MessageTypedDict, ToolDefinitionJSONTypedDict from openrouter.utils import BackoffStrategy from database import BasicDatabase GROUP_CHAT_SYSTEM_PROMPT = """ Ты - ИИ-помощник в групповом чате.\n Отвечай на вопросы и поддерживай контекст беседы.\n Ты не можешь обсуждать политику и религию.\n Сообщения пользователей будут приходить в следующем формате: '[дата время, имя]: текст сообщения'\n При ответе НЕ нужно указывать ни время, ни пользователя, которому предназначен ответ, ни свое имя.\n НЕ используй разметку Markdown, она не поддерживается мессенджером.\n Если нужно нарисовать изображение, используй вызов инструмента. Запрещено генерировать ASCII-арты, а также добавлять теги/маркеры вроде , [image]. """ PRIVATE_CHAT_SYSTEM_PROMPT = """ Ты - ИИ-помощник в чате c пользователем.\n Отвечай на вопросы и поддерживай контекст беседы.\n Сообщения пользователя будут приходить в следующем формате: '[дата время]: текст сообщения'\n При ответе НЕ нужно указывать время.\n НЕ используй разметку Markdown, она не поддерживается мессенджером.\n Если нужно нарисовать изображение, используй вызов инструмента. Запрещено генерировать ASCII-арты, а также добавлять теги/маркеры вроде , [image]. """ GENERATE_IMAGE_TOOL_DESCRIPTION = """ Генерация изображения по описанию. Используй этот инструмент, если пользователь просит сгенерировать изображение ('нарисуй', 'покажи' и т.п.), или если это улучшит ответ (например, в ролевой игре для визуализации сцены). """ GENERATE_IMAGE_TOOL_PROMPT_ARG_DESCRIPTION = """ Детальное описание на русском языке. Добавь детали для стиля, цвета, композиции, если нужно. """ OPENROUTER_X_TITLE = "TG/VK Chat Bot" OPENROUTER_HTTP_REFERER = "https://ultracoder.org" GROUP_CHAT_MAX_MESSAGES = 20 PRIVATE_CHAT_MAX_MESSAGES = 40 MAX_OUTPUT_TOKENS = 500 @dataclass() class Message: user_name: str = None text: str = None image: bytes = None message_id: int = None def _add_message_prefix(text: Optional[str], username: Optional[str] = None) -> str: current_time = datetime.datetime.now().strftime("%d.%m.%Y %H:%M") prefix = f"[{current_time}, {username}]" if username is not None else f"[{current_time}]" return f"{prefix}: {text}" if text is not None else f"{prefix}:" def _encode_image(image: bytes) -> str: encoded_image = base64.b64encode(image).decode('utf-8') return f"data:image/jpeg;base64,{encoded_image}" def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict: serialized = {"role": role, "content": []} if text is not None: serialized["content"].append({"type": "text", "text": text}) if image is not None: serialized["content"].append({"type": "image_url", "image_url": {"url": _encode_image(image)}}) return serialized def _remove_none_recursive(data: Union[dict, list, any]) -> Union[dict, list, any]: if isinstance(data, dict): return { k: _remove_none_recursive(v) for k, v in data.items() if v is not None } elif isinstance(data, list): return [ _remove_none_recursive(item) for item in data if item is not None ] else: return data def _serialize_assistant_message(message: AssistantMessage) -> AssistantMessageTypedDict: return _remove_none_recursive(message.model_dump(by_alias=True)) def _get_tools_description() -> List[ToolDefinitionJSONTypedDict]: return [{ "type": "function", "function": { "name": "generate_image", "description": GENERATE_IMAGE_TOOL_DESCRIPTION, "parameters": { "type": "object", "properties": { "prompt": { "type": "string", "description": GENERATE_IMAGE_TOOL_PROMPT_ARG_DESCRIPTION }, "aspect_ratio": { "type": "string", "enum": ["1:1", "3:4", "4:3", "9:16", "16:9"], "description": "Соотношение сторон (опционально)." } }, "required": ["prompt"] } } }] class AiAgent: def __init__(self, api_token_main: str, model_main: str, api_token_image: str, model_image: str, db: BasicDatabase, platform: str): retry_config = RetryConfig(strategy="backoff", backoff=BackoffStrategy( initial_interval=2000, max_interval=8000, exponent=2, max_elapsed_time=14000), retry_connection_errors=True) self.db = db self.model_main = model_main self.model_image = model_image self.platform = platform self.client = OpenRouter(api_key=api_token_main, x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER, retry_config=retry_config) self.client_image = OpenRouter(api_key=api_token_image, x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER) @dataclass() class _ToolsArtifacts: generated_image: Optional[bytes] = None async def get_group_chat_reply(self, bot_id: int, chat_id: int, message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]: context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id) message.text = _add_message_prefix(message.text, message.user_name) context.append(_serialize_message(role="user", text=message.text, image=message.image)) for fwd_message in forwarded_messages: message_text = '<Цитируемое сообщение от {}>'.format(fwd_message.user_name) if fwd_message.text is not None: message_text += '\n' + fwd_message.text fwd_message.text = message_text context.append(_serialize_message(role="user", text=fwd_message.text, image=fwd_message.image)) try: response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True) ai_response = response.content tools_artifacts = AiAgent._ToolsArtifacts() if response.tool_calls is not None: tools_artifacts = await self._process_tool_calls(bot_id, chat_id, tool_calls=response.tool_calls, context=context) response2 = await self._generate_reply(bot_id, chat_id, context=context) ai_response = response2.content self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image, message_id=message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES) for fwd_message in forwarded_messages: self.db.context_add_message(bot_id, chat_id, role="user", text=fwd_message.text, image=fwd_message.image, message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES) self.db.context_add_message(bot_id, chat_id, role="assistant", text=ai_response, image=tools_artifacts.generated_image, message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES) return Message(text=ai_response, image=tools_artifacts.generated_image), True except Exception as e: if str(e).find("Rate limit exceeded") != -1: return Message(text="Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."), False else: print(f"Ошибка выполнения запроса к ИИ: {e}") return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: Message) -> Tuple[Message, bool]: context = self._get_chat_context(is_group_chat=False, bot_id=bot_id, chat_id=chat_id) message.text = _add_message_prefix(message.text) content: list[dict[str, Any]] = [] if message.text is not None: content.append({"type": "text", "text": message.text}) if message.image is not None: content.append({"type": "image_url", "image_url": {"url": _encode_image(message.image)}}) context.append({"role": "user", "content": content}) try: response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True) context.append(_serialize_assistant_message(response)) ai_response = response.content tools_artifacts = AiAgent._ToolsArtifacts() if response.tool_calls is not None: tools_artifacts = await self._process_tool_calls(bot_id, chat_id, tool_calls=response.tool_calls, context=context) response2 = await self._generate_reply(bot_id, chat_id, context=context) ai_response = response2.content self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image, message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES) self.db.context_add_message(bot_id, chat_id, role="assistant", text=ai_response, image=tools_artifacts.generated_image, message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES) return Message(text=ai_response, image=tools_artifacts.generated_image), True except Exception as e: if str(e).find("Rate limit exceeded") != -1: return Message(text="Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."), False else: print(f"Ошибка выполнения запроса к ИИ: {e}") return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False def get_last_assistant_message_id(self, bot_id: int, chat_id: int): return self.db.context_get_last_assistant_message_id(bot_id, chat_id) def set_last_response_id(self, bot_id: int, chat_id: int, message_id: int): self.db.context_set_last_message_id(bot_id, chat_id, message_id) def clear_chat_context(self, bot_id: int, chat_id: int): self.db.context_clear(bot_id, chat_id) def _get_chat_context(self, is_group_chat: bool, bot_id: int, chat_id: int) -> List[MessageTypedDict]: prompt = GROUP_CHAT_SYSTEM_PROMPT if is_group_chat else PRIVATE_CHAT_SYSTEM_PROMPT bot = self.db.get_bot(bot_id) if bot['ai_prompt'] is not None: prompt += '\n\n' + bot['ai_prompt'] chat = self.db.create_chat_if_not_exists(bot_id, chat_id) if chat['ai_prompt'] is not None: prompt += '\n\n' + chat['ai_prompt'] messages = self.db.context_get_messages(bot_id, chat_id) context: List[MessageTypedDict] = [{"role": "system", "content": prompt}] for message in messages: context.append(_serialize_message(message["role"], message["text"], message["image"])) return context async def _generate_reply(self, bot_id: int, chat_id: int, context: List[MessageTypedDict], allow_tools: bool = False) -> AssistantMessage: response = await self.client.chat.send_async( model=self.model_main, messages=context, tools=_get_tools_description() if allow_tools else None, tool_choice="auto" if allow_tools else None, max_tokens=MAX_OUTPUT_TOKENS, user=f'{self.platform}_{bot_id}_{chat_id}' ) return response.choices[0].message async def _process_tool_calls(self, bot_id: int, chat_id: int, tool_calls: List[ChatMessageToolCall], context: List[MessageTypedDict]) -> _ToolsArtifacts: artifacts = AiAgent._ToolsArtifacts() if tool_calls is None: return artifacts functions_map: Dict[str, Callable[[int, int, Dict, AiAgent._ToolsArtifacts], Awaitable[List[ChatMessageContentItemTypedDict]]]] = { "generate_image": self._process_tool_generate_image } for tool_call in tool_calls: tool_name = tool_call.function.name tool_args = json.loads(tool_call.function.arguments) if tool_name in functions_map: tool_result = await functions_map[tool_name](bot_id, chat_id, tool_args, artifacts) context.append({ "role": "tool", "tool_call_id": tool_call.id, "content": tool_result }) artifacts.tools_called = True return artifacts async def _process_tool_generate_image(self, bot_id: int, chat_id: int, args: dict, artifacts: _ToolsArtifacts) \ -> List[ChatMessageContentItemTypedDict]: prompt = args.get("prompt", "") aspect_ratio = args.get("aspect_ratio", None) result = await self._generate_image(bot_id, chat_id, prompt=prompt, aspect_ratio=aspect_ratio) content = [] if result.is_ok(): content.append( {"type": "text", "text": "Изображение сгенерировано и будет показано пользователю."}) content.append( {"type": "image_url", "image_url": {"url": _encode_image(result.ok_value)}}) artifacts.generated_image = result.ok_value else: content.append( {"type": "text", "text": f"Не удалось сгенерировать изображение: {result.err_value}"}) return content async def _generate_image(self, bot_id: int, chat_id: int, prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]: print(f"Генерация изображения: {prompt}") context = [{"role": "user", "content": prompt}] try: response = await self.client_image.chat.send_async( model=self.model_image, messages=context, user=f'{self.platform}_{bot_id}_{chat_id}', modalities=["image"], image_config={"aspect_ratio": aspect_ratio} if aspect_ratio is not None else None ) image_url = response.choices[0].message.images[0].image_url.url header, image_base64 = image_url.split(",", 1) mime_type = header.split(";")[0].replace("data:", "") image_bytes = base64.b64decode(image_base64) if mime_type != "image/jpeg": image = Image.open(BytesIO(image_bytes)).convert("RGB") output = BytesIO() image.save(output, format="JPEG", quality=80, optimize=True) image_bytes = output.getvalue() return Ok(image_bytes) except Exception as e: print(f"Ошибка генерации изображения: {e}") return Err(str(e)) agent: AiAgent def create_ai_agent(api_token_main: str, model_main: str, api_token_image: str, model_image: str, db: BasicDatabase, platform: str): global agent agent = AiAgent(api_token_main, model_main, api_token_image, model_image, db, platform)