"""AI chat agent for the TG/VK bot.

Generates chat replies through an OpenRouter-hosted LLM, executes
image-generation tool calls via fal.ai and Replicate, and keeps a rolling
per-chat conversation context in the database.
"""

import base64
import datetime
import json
from collections.abc import Callable
from dataclasses import dataclass
from io import BytesIO
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union

import aiohttp
from PIL import Image
from result import Err, Ok, Result

from openrouter import OpenRouter, RetryConfig
from openrouter.components import AssistantMessage, AssistantMessageTypedDict, ChatMessageContentItemTypedDict, \
    ChatMessageToolCall, MessageTypedDict, SystemMessageTypedDict
from openrouter.errors import ChatError, ResponseValidationError
from openrouter.utils import BackoffStrategy
from fal_client import AsyncClient as FalClient
from replicate import Client as ReplicateClient

from database import BasicDatabase

OPENROUTER_X_TITLE = "TG/VK Chat Bot"
OPENROUTER_HTTP_REFERER = "https://ultracoder.org"
GROUP_CHAT_MAX_MESSAGES = 40
PRIVATE_CHAT_MAX_MESSAGES = 40
MAX_OUTPUT_TOKENS = 500
FAL_MODEL = "fal-ai/bytedance/seedream/v4.5/text-to-image"
REPLICATE_MODEL = "ultracoderru/nova-anime-xl-14:1a8ded85309ffeb41780db667102ea935e4140b1a15f4a0f669a60a104b722db"


@dataclass()
class Message:
    """A single chat message, incoming or outgoing."""
    user_name: Optional[str] = None   # display name of the author, if known
    text: Optional[str] = None        # plain-text body
    image: Optional[bytes] = None     # JPEG payload, if the message carries an image
    message_id: Optional[int] = None  # platform-specific message id


class AiAgent:
    """Chat agent backed by an OpenRouter LLM with image-generation tools.

    The agent keeps a bounded per-chat context in the database, replays it
    to the model on every request, and executes tool calls the model emits
    (text-to-image via fal.ai, anime-style images via Replicate).
    """

    def __init__(self, openrouter_token: str, openrouter_model: str, fal_token: str, replicate_token: str,
                 db: BasicDatabase, platform: str):
        """Create the agent and its API clients.

        :param platform: either 'tg' or 'vk'; substituted into the system prompt.
        """
        retry_config = RetryConfig(
            strategy="backoff",
            backoff=BackoffStrategy(initial_interval=2000, max_interval=8000,
                                    exponent=2, max_elapsed_time=14000),
            retry_connection_errors=True)
        self.db = db
        self.openrouter_model = openrouter_model
        self.platform = platform
        self._load_prompts()
        self.client_openrouter = OpenRouter(api_key=openrouter_token, x_title=OPENROUTER_X_TITLE,
                                            http_referer=OPENROUTER_HTTP_REFERER, retry_config=retry_config)
        self.client_fal = FalClient(key=fal_token)
        self.replicate_client = ReplicateClient(api_token=replicate_token)

    @dataclass()
    class _ToolsArtifacts:
        # Side products of tool execution that must be delivered to the user.
        generated_image: Optional[bytes] = None

    async def get_group_chat_reply(self, bot_id: int, chat_id: int, message: Message,
                                   forwarded_messages: List[Message]) -> Tuple[Message, bool]:
        """Generate a reply for a group chat.

        :param forwarded_messages: quoted/forwarded messages accompanying the
            triggering message; each is wrapped in a citation marker.
        :return: (reply, success); on failure the reply text carries a
            user-facing error description.
        """
        message.text = _add_message_prefix(message.text, message.user_name)
        context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id)
        context.append(_serialize_message(role="user", text=message.text, image=message.image))
        for fwd_message in forwarded_messages:
            # Wrap forwarded messages so the model can tell quotations apart.
            message_text = '<Цитируемое сообщение от {}>'.format(fwd_message.user_name)
            if fwd_message.text is not None:
                message_text += '\n' + fwd_message.text
            fwd_message.text = message_text
            context.append(_serialize_message(role="user", text=fwd_message.text, image=fwd_message.image))
        try:
            response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
            # FIX: the assistant tool-call message must precede any tool-role
            # results in the context (mirrors get_private_chat_reply).
            context.append(_serialize_assistant_message(response))
            ai_response = response.content
            tools_artifacts = AiAgent._ToolsArtifacts()
            if response.tool_calls is not None:
                tools_artifacts = await self._process_tool_calls(bot_id, chat_id,
                                                                 tool_calls=response.tool_calls, context=context)
                # Second round-trip: let the model phrase a reply that
                # takes the tool results into account.
                response2 = await self._generate_reply(bot_id, chat_id, context=context)
                ai_response = response2.content
            # Persist the exchange only after a successful generation.
            self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
                                        message_id=message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
            for fwd_message in forwarded_messages:
                self.db.context_add_message(bot_id, chat_id, role="user", text=fwd_message.text,
                                            image=fwd_message.image, message_id=fwd_message.message_id,
                                            max_messages=GROUP_CHAT_MAX_MESSAGES)
            self.db.context_add_message(bot_id, chat_id, role="assistant", text=ai_response,
                                        image=tools_artifacts.generated_image, message_id=None,
                                        max_messages=GROUP_CHAT_MAX_MESSAGES)
            return Message(text=ai_response, image=tools_artifacts.generated_image), True
        except Exception as e:
            return self._error_reply(e)

    async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: Message) -> Tuple[Message, bool]:
        """Generate a reply for a private (one-on-one) chat.

        :return: (reply, success); on failure the reply text carries a
            user-facing error description.
        """
        message.text = _add_message_prefix(message.text)
        context = self._get_chat_context(is_group_chat=False, bot_id=bot_id, chat_id=chat_id)
        context.append(_serialize_message(role="user", text=message.text, image=message.image))
        try:
            response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
            context.append(_serialize_assistant_message(response))
            ai_response = response.content
            tools_artifacts = AiAgent._ToolsArtifacts()
            if response.tool_calls is not None:
                tools_artifacts = await self._process_tool_calls(bot_id, chat_id,
                                                                 tool_calls=response.tool_calls, context=context)
                response2 = await self._generate_reply(bot_id, chat_id, context=context)
                ai_response = response2.content
            self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
                                        message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
            self.db.context_add_message(bot_id, chat_id, role="assistant", text=ai_response,
                                        image=tools_artifacts.generated_image, message_id=None,
                                        max_messages=PRIVATE_CHAT_MAX_MESSAGES)
            return Message(text=ai_response, image=tools_artifacts.generated_image), True
        except Exception as e:
            return self._error_reply(e)

    @staticmethod
    def _error_reply(e: Exception) -> Tuple[Message, bool]:
        # Shared failure path for both public reply methods.
        if "Rate limit exceeded" in str(e):
            return Message(text="Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."), False
        print(f"Ошибка выполнения запроса к ИИ: {e}")
        return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False

    def get_last_assistant_message_id(self, bot_id: int, chat_id: int):
        """Return the message id of the last assistant reply stored in context."""
        return self.db.context_get_last_assistant_message_id(bot_id, chat_id)

    def set_last_response_id(self, bot_id: int, chat_id: int, message_id: int):
        """Record the platform message id assigned to the last reply."""
        self.db.context_set_last_message_id(bot_id, chat_id, message_id)

    def clear_chat_context(self, bot_id: int, chat_id: int):
        """Drop the stored conversation context for a chat."""
        self.db.context_clear(bot_id, chat_id)

    ####################################################################################

    def _get_chat_context(self, is_group_chat: bool, bot_id: int, chat_id: int) -> List[MessageTypedDict]:
        """Build the message list for the model: system prompt + stored history."""
        context: List[MessageTypedDict] = [
            self._construct_system_prompt(is_group_chat=is_group_chat, bot_id=bot_id, chat_id=chat_id)
        ]
        for message in self.db.context_get_messages(bot_id, chat_id):
            context.append(_serialize_message(message["role"], message["text"], message["image"]))
        return context

    def _construct_system_prompt(self, is_group_chat: bool, bot_id: int, chat_id: int) -> SystemMessageTypedDict:
        """Assemble the system prompt from the base prompt, the image-generation
        addendum and optional per-bot / per-chat custom prompts."""
        prompt = self.system_prompt_group_chat if is_group_chat else self.system_prompt_private_chat
        prompt = prompt.replace('{platform}', 'Telegram' if self.platform == 'tg' else 'VK')
        prompt += '\n' + self.system_prompt_image_generation
        bot = self.db.get_bot(bot_id)
        if bot['ai_prompt'] is not None:
            prompt += '\n' + bot['ai_prompt'] + '\n'
        chat = self.db.create_chat_if_not_exists(bot_id, chat_id)
        if chat['ai_prompt'] is not None:
            prompt += '\n' + chat['ai_prompt']
        return {"role": "system", "content": prompt}

    async def _generate_reply(self, bot_id: int, chat_id: int, context: List[MessageTypedDict],
                              allow_tools: bool = False) -> AssistantMessage:
        """Run one chat completion and return the (filtered) assistant message."""
        response = await self._async_chat_completion_request(
            model=self.openrouter_model,
            messages=context,
            tools=self.tools_description if allow_tools else None,
            tool_choice="auto" if allow_tools else None,
            max_tokens=MAX_OUTPUT_TOKENS,
            user=f'{self.platform}_{bot_id}_{chat_id}'
        )
        return self._filter_response(response.choices[0].message)

    async def _process_tool_calls(self, bot_id: int, chat_id: int, tool_calls: List[ChatMessageToolCall],
                                  context: List[MessageTypedDict]) -> _ToolsArtifacts:
        """Execute the model's tool calls and append their results to context.

        Unknown tool names are silently skipped.
        """
        artifacts = AiAgent._ToolsArtifacts()
        if tool_calls is None:
            return artifacts
        functions_map: Dict[str, Callable[[int, int, Dict, AiAgent._ToolsArtifacts],
                                          Awaitable[List[ChatMessageContentItemTypedDict]]]] = {
            "generate_image": self._process_tool_generate_image,
            "generate_image_anime": self._process_tool_generate_image_anime
        }
        for tool_call in tool_calls:
            tool_name = tool_call.function.name
            tool_args = json.loads(tool_call.function.arguments)
            if tool_name in functions_map:
                tool_result = await functions_map[tool_name](bot_id, chat_id, tool_args, artifacts)
                context.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": tool_result
                })
        return artifacts

    async def _process_tool_generate_image(self, _bot_id: int, _chat_id: int, args: dict,
                                           artifacts: _ToolsArtifacts) \
            -> List[ChatMessageContentItemTypedDict]:
        """Tool handler: generic text-to-image generation (fal.ai)."""
        prompt = args.get("prompt", "")
        aspect_ratio = args.get("aspect_ratio", None)
        result = await self._generate_image(prompt=prompt, aspect_ratio=aspect_ratio)
        if result.is_ok():
            artifacts.generated_image = result.ok_value
            return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.",
                                              image=result.ok_value)
        else:
            return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {result.err_value}")

    async def _generate_image(self, prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]:
        """Generate an image with the fal.ai model; returns JPEG bytes or an error string."""
        width, height = _get_resolution_for_aspect_ratio(aspect_ratio)
        print(f"Генерация изображения {width}x{height}: {prompt}")
        arguments = {
            "prompt": prompt,
            "image_size": {"width": width, "height": height},
            "enable_safety_checker": False
        }
        try:
            result = await self.client_fal.run(FAL_MODEL, arguments=arguments)
            if "images" not in result:
                raise RuntimeError("Неожиданный ответ от сервера.")
            image_url = result["images"][0]["url"]
            async with aiohttp.ClientSession() as session:
                async with session.get(image_url) as response:
                    if response.status == 200:
                        image = await response.read()
                        # Re-encode anything that is not already a JPEG.
                        if not image_url.endswith(".jpg"):
                            image = _convert_image_to_jpeg(image)
                        return Ok(image)
                    else:
                        raise RuntimeError(f"Не удалось загрузить изображение ({response.status}).")
        except Exception as e:
            print(f"Ошибка генерации изображения: {e}")
            return Err(str(e))

    async def _process_tool_generate_image_anime(self, _bot_id: int, _chat_id: int, args: dict,
                                                 artifacts: _ToolsArtifacts) \
            -> List[ChatMessageContentItemTypedDict]:
        """Tool handler: anime-style image generation (Replicate)."""
        prompt = args.get("prompt", "")
        negative_prompt = args.get("negative_prompt", "")
        aspect_ratio = args.get("aspect_ratio", None)
        result = await self._generate_image_anime(prompt=prompt, negative_prompt=negative_prompt,
                                                  aspect_ratio=aspect_ratio)
        if result.is_ok():
            artifacts.generated_image = result.ok_value
            return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.",
                                              image=result.ok_value)
        else:
            return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {result.err_value}")

    async def _generate_image_anime(self, prompt: str, negative_prompt: str, aspect_ratio: Optional[str]) \
            -> Result[bytes, str]:
        """Generate an anime-style image with the Replicate model."""
        width, height = _get_resolution_for_aspect_ratio(aspect_ratio)
        print(f"Генерация аниме-изображения {width}x{height}:\n+ {prompt}\n- {negative_prompt}")
        arguments = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "add_recommended_tags": False,
            "width": width,
            "height": height,
            "guidance_scale": 4.5,
            "num_inference_steps": 20,
            "hires_enable": True,
            "disable_safety_checker": True
        }
        try:
            outputs = await self.replicate_client.async_run(REPLICATE_MODEL, input=arguments)
            # The model may upscale; resize back to the requested resolution.
            image = _convert_image_to_jpeg(await outputs[0].aread(), (width, height))
            return Ok(image)
        except Exception as e:
            print(f"Ошибка генерации изображения: {e}")
            return Err(str(e))

    async def _async_chat_completion_request(self, **kwargs):
        """Send a chat completion request, unwrapping provider error payloads.

        Workaround for the OpenRouter SDK:
        https://github.com/OpenRouterTeam/python-sdk/issues/44
        """
        try:
            return await self.client_openrouter.chat.send_async(**kwargs)
        except ResponseValidationError as e:
            body = json.loads(e.body)
            if "error" in body:
                try:
                    raw_response = json.loads(body["error"]["metadata"]["raw"])
                    message = str(raw_response["error"]["message"])
                    e = RuntimeError(message)
                except Exception:
                    pass  # keep the original exception if the payload is malformed
            raise e
        except ChatError as e:
            if e.message == "Provider returned error":
                body = json.loads(e.body)
                try:
                    raw_response = json.loads(body["error"]["metadata"]["raw"])
                    message = str(raw_response["error"]["message"])
                    e = RuntimeError(message)
                except Exception:
                    pass  # keep the original exception if the payload is malformed
            raise e

    @staticmethod
    def _filter_response(response: AssistantMessage) -> AssistantMessage:
        """Normalize the assistant message content to a plain string.

        FIX: a ``None`` content (typical for tool-call-only responses) is left
        as ``None`` instead of being stringified to the literal "None".
        NOTE(review): the original also ran ``.replace("", "")`` — a no-op,
        presumably a special-token strip whose token was lost; confirm intent.
        """
        if response.content is not None:
            response.content = str(response.content)
        return response

    def _load_prompts(self):
        """Load system prompts and the tool descriptions from the prompts/ directory."""
        # encoding specified explicitly: prompt files contain non-ASCII text
        with open("prompts/group_chat.md", "r", encoding="utf-8") as f:
            self.system_prompt_group_chat = f.read()
        with open("prompts/private_chat.md", "r", encoding="utf-8") as f:
            self.system_prompt_private_chat = f.read()
        with open("prompts/image_generation.md", "r", encoding="utf-8") as f:
            self.system_prompt_image_generation = f.read()
        with open("prompts/tools.json", "r", encoding="utf-8") as f:
            self.tools_description = json.loads(f.read())


# Module-level singleton, initialized by create_ai_agent().
agent: AiAgent


def create_ai_agent(openrouter_token: str, openrouter_model: str, fal_token: str, replicate_token: str,
                    db: BasicDatabase, platform: str) -> None:
    """Instantiate the global AiAgent singleton."""
    global agent
    agent = AiAgent(openrouter_token, openrouter_model, fal_token, replicate_token, db, platform)


def _add_message_prefix(text: Optional[str], username: Optional[str] = None) -> str:
    """Prefix a message with a timestamp (and username, when given)."""
    current_time = datetime.datetime.now().strftime("%d.%m.%Y %H:%M")
    prefix = f"[{current_time}, {username}]" if username is not None else f"[{current_time}]"
    return f"{prefix}: {text}" if text is not None else f"{prefix}:"


def _encode_image(image: bytes) -> str:
    """Encode JPEG bytes as a base64 data URL."""
    encoded_image = base64.b64encode(image).decode('utf-8')
    return f"data:image/jpeg;base64,{encoded_image}"


def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict:
    """Build a chat API message dict from role, text and optional image."""
    return {"role": role, "content": _serialize_message_content(text, image)}


def _serialize_message_content(text: Optional[str], image: Optional[bytes] = None) -> list[dict]:
    """Build the content-item list for a message; omits absent parts."""
    content = []
    if text is not None:
        content.append({"type": "text", "text": text})
    if image is not None:
        content.append({"type": "image_url", "detail": "high", "image_url": {"url": _encode_image(image)}})
    return content


def _remove_none_recursive(data: Union[dict, list, Any]) -> Union[dict, list, Any]:
    """Recursively strip None values from dicts and lists (other types pass through)."""
    if isinstance(data, dict):
        return {
            k: _remove_none_recursive(v)
            for k, v in data.items()
            if v is not None
        }
    elif isinstance(data, list):
        return [
            _remove_none_recursive(item)
            for item in data
            if item is not None
        ]
    else:
        return data


def _serialize_assistant_message(message: AssistantMessage) -> AssistantMessageTypedDict:
    """Dump an AssistantMessage back into the wire dict, dropping None fields."""
    return _remove_none_recursive(message.model_dump(by_alias=True))


def _get_resolution_for_aspect_ratio(aspect_ratio: Optional[str]) -> Tuple[int, int]:
    """Map an aspect-ratio string to pixel dimensions; defaults to 4:3 (1280x1024)."""
    aspect_ratio_resolution_map = {
        "1:1": (1280, 1280),
        "4:3": (1280, 1024),
        "3:4": (1024, 1280),
        "16:9": (1280, 720),
        "9:16": (720, 1280)
    }
    return aspect_ratio_resolution_map.get(aspect_ratio, (1280, 1024))


def _convert_image_to_jpeg(image: bytes, size: Optional[tuple[int, int]] = None) -> bytes:
    """Re-encode arbitrary image bytes as JPEG, optionally resizing first."""
    img = Image.open(BytesIO(image)).convert("RGB")
    if size is not None:
        img = img.resize(size=size, resample=Image.Resampling.BILINEAR)
    output = BytesIO()
    img.save(output, format='JPEG', quality=87, optimize=True)
    return output.getvalue()