From 10c91f2cdc083b6bbb4ec0a6ff08a7cee1329f59 Mon Sep 17 00:00:00 2001 From: Kirill Kirilenko Date: Fri, 20 Feb 2026 16:18:57 +0300 Subject: [PATCH] =?UTF-8?q?=D0=93=D0=B5=D0=BD=D0=B5=D1=80=D0=B0=D1=86?= =?UTF-8?q?=D0=B8=D1=8F=20=D0=B8=D0=B7=D0=BE=D0=B1=D1=80=D0=B0=D0=B6=D0=B5?= =?UTF-8?q?=D0=BD=D0=B8=D0=B9=20=D1=87=D0=B5=D1=80=D0=B5=D0=B7=20fal.ai.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ai_agent.py | 97 +++++++++++++++++++++++++++----------------------- tg/__main__.py | 4 +-- vk/__main__.py | 4 +-- 3 files changed, 56 insertions(+), 49 deletions(-) diff --git a/ai_agent.py b/ai_agent.py index c2d66b0..62d1ba7 100644 --- a/ai_agent.py +++ b/ai_agent.py @@ -1,3 +1,4 @@ +import aiohttp import base64 import datetime import json @@ -15,6 +16,8 @@ from openrouter.components import AssistantMessage, AssistantMessageTypedDict, C from openrouter.errors import ResponseValidationError, ChatError from openrouter.utils import BackoffStrategy +from fal_client import AsyncClient as FalClient + from database import BasicDatabase GROUP_CHAT_SYSTEM_PROMPT = """ @@ -121,7 +124,7 @@ def _get_tools_description() -> List[ToolDefinitionJSONTypedDict]: }, "aspect_ratio": { "type": "string", - "enum": ["1:1", "3:4", "4:3", "9:16", "16:9"], + "enum": ["1:1", "4:3", "3:4", "16:9", "9:16"], "description": "Соотношение сторон (опционально)." } }, @@ -133,8 +136,8 @@ def _get_tools_description() -> List[ToolDefinitionJSONTypedDict]: class AiAgent: def __init__(self, - api_token_main: str, model_main: str, - api_token_image: str, model_image: str, + openrouter_token: str, openrouter_model: str, + fal_token: str, fal_model: str, db: BasicDatabase, platform: str): retry_config = RetryConfig(strategy="backoff", @@ -142,14 +145,13 @@ class AiAgent: initial_interval=2000, max_interval=8000, exponent=2, max_elapsed_time=14000), retry_connection_errors=True) self.db = db - self.model_main = model_main - self.model_image = model_image + self.model_main = openrouter_model + self.model_image = fal_model self.platform = platform - self.client = OpenRouter(api_key=api_token_main, - x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER, - retry_config=retry_config) - self.client_image = OpenRouter(api_key=api_token_image, - x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER) + self.client_openrouter = OpenRouter(api_key=openrouter_token, + x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER, + retry_config=retry_config) + self.client_fal = FalClient(key=fal_token) @dataclass() class _ToolsArtifacts: @@ -300,11 +302,11 @@ class AiAgent: artifacts.tools_called = True return artifacts - async def _process_tool_generate_image(self, bot_id: int, chat_id: int, args: dict, artifacts: _ToolsArtifacts) \ + async def _process_tool_generate_image(self, _bot_id: int, _chat_id: int, args: dict, artifacts: _ToolsArtifacts) \ -> List[ChatMessageContentItemTypedDict]: prompt = args.get("prompt", "") aspect_ratio = args.get("aspect_ratio", None) - result = await self._generate_image(bot_id, chat_id, prompt=prompt, aspect_ratio=aspect_ratio) + result = await self._generate_image(prompt=prompt, aspect_ratio=aspect_ratio) content = [] if result.is_ok(): @@ -320,45 +322,50 @@ class AiAgent: "text": f"Не удалось сгенерировать изображение: {result.err_value}"}) return content - async def _generate_image(self, bot_id: int, chat_id: int, - prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]: - print(f"Генерация изображения: {prompt}") - context = [{"role": "user", "content": prompt}] + async def _generate_image(self, prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]: + aspect_ratio_resolution_map = { + "1:1": (1280, 1280), + "4:3": (1280, 1024), + "3:4": (1024, 1280), + "16:9": (1280, 720), + "9:16": (720, 1280) + } + width, height = aspect_ratio_resolution_map.get(aspect_ratio, (1280, 1024)) + print(f"Генерация изображения {width}x{height}: {prompt}") + + arguments = { + "prompt": prompt, + "image_size": {"width": width, "height": height}, + "enable_safety_checker": False + } try: - response = await self._async_chat_completion_request( - model=self.model_image, - messages=context, - user=f'{self.platform}_{bot_id}_{chat_id}', - modalities=["image"], - image_config={"aspect_ratio": aspect_ratio} if aspect_ratio is not None else None - ) + result = await self.client_fal.run(self.model_image, arguments=arguments) + if "images" not in result: + raise RuntimeError("Неожиданный ответ от сервера.") + image_url = result["images"][0]["url"] - image_url = response.choices[0].message.images[0].image_url.url - header, image_base64 = image_url.split(",", 1) - mime_type = header.split(";")[0].replace("data:", "") - image_bytes = base64.b64decode(image_base64) + async with aiohttp.ClientSession() as session: + async with session.get(image_url) as response: + if response.status == 200: + image_bytes = await response.read() + if not image_url.endswith(".jpg"): + image = Image.open(BytesIO(image_bytes)).convert("RGB") + output = BytesIO() + image.save(output, format="JPEG", quality=80, optimize=True) + image_bytes = output.getvalue() - if mime_type != "image/jpeg": - image = Image.open(BytesIO(image_bytes)).convert("RGB") - output = BytesIO() - image.save(output, format="JPEG", quality=80, optimize=True) - image_bytes = output.getvalue() - - return Ok(image_bytes) + return Ok(image_bytes) + else: + raise RuntimeError(f"Не удалось загрузить изображение ({response.status}).") except Exception as e: - # Костыль для модели Seedream 4.5 - message = str(e) - prefix = "Request id:" - if prefix in message: - message = message.split(prefix)[0].strip() - print(f"Ошибка генерации изображения: {message}") - return Err(message) + print(f"Ошибка генерации изображения: {e}") + return Err(str(e)) async def _async_chat_completion_request(self, **kwargs): try: - return await self.client_image.chat.send_async(**kwargs) + return await self.client_openrouter.chat.send_async(**kwargs) except ResponseValidationError as e: # Костыль для OpenRouter SDK: # https://github.com/OpenRouterTeam/python-sdk/issues/44 @@ -386,8 +393,8 @@ class AiAgent: agent: AiAgent -def create_ai_agent(api_token_main: str, model_main: str, - api_token_image: str, model_image: str, +def create_ai_agent(openrouter_token: str, openrouter_model: str, + fal_token: str, fal_model: str, db: BasicDatabase, platform: str): global agent - agent = AiAgent(api_token_main, model_main, api_token_image, model_image, db, platform) + agent = AiAgent(openrouter_token, openrouter_model, fal_token, fal_model, db, platform) diff --git a/tg/__main__.py b/tg/__main__.py index 682fd50..c03d57f 100644 --- a/tg/__main__.py +++ b/tg/__main__.py @@ -24,8 +24,8 @@ async def main() -> None: database.create_database(config['db_connection_string']) - create_ai_agent(config['openrouter_token_main'], config['openrouter_model_main'], - config['openrouter_token_image'], config['openrouter_model_image'], + create_ai_agent(config['openrouter_token'], config['openrouter_model'], + config['fal_token'], config['fal_model'], database.DB, 'tg') bots: list[Bot] = [] diff --git a/vk/__main__.py b/vk/__main__.py index 2cc8a07..0b9a702 100644 --- a/vk/__main__.py +++ b/vk/__main__.py @@ -24,8 +24,8 @@ if __name__ == '__main__': database.create_database(config['db_connection_string']) - create_ai_agent(config['openrouter_token_main'], config['openrouter_model_main'], - config['openrouter_token_image'], config['openrouter_model_image'], + create_ai_agent(config['openrouter_token'], config['openrouter_model'], + config['fal_token'], config['fal_model'], database.DB, 'vk') bot = Bot(labeler=handlers.labeler)