From e0f521256bce9688551da66cb9611c13cd46b0c8 Mon Sep 17 00:00:00 2001 From: Kirill Kirilenko Date: Sat, 7 Mar 2026 20:21:06 +0300 Subject: [PATCH] =?UTF-8?q?=D0=94=D0=BE=D0=B1=D0=B0=D0=B2=D0=BB=D0=B5?= =?UTF-8?q?=D0=BD=D0=B0=20=D0=BE=D1=82=D0=BF=D1=80=D0=B0=D0=B2=D0=BA=D0=B0?= =?UTF-8?q?=20=D0=B8=D0=B7=D0=BE=D0=B1=D1=80=D0=B0=D0=B6=D0=B5=D0=BD=D0=B8?= =?UTF-8?q?=D1=8F=20=D0=B2=20=D0=B2=D1=8B=D1=81=D0=BE=D0=BA=D0=BE=D0=BC=20?= =?UTF-8?q?=D0=BA=D0=B0=D1=87=D0=B5=D1=81=D1=82=D0=B2=D0=B5=20=D0=BF=D0=BE?= =?UTF-8?q?=D0=BB=D1=8C=D0=B7=D0=BE=D0=B2=D0=B0=D1=82=D0=B5=D0=BB=D1=8E.?= =?UTF-8?q?=20=D0=98=D1=81=D0=BF=D1=80=D0=B0=D0=B2=D0=BB=D0=B5=D0=BD=20?= =?UTF-8?q?=D0=B2=D1=8B=D0=B1=D0=BE=D1=80=20=D1=80=D0=B0=D0=B7=D0=BC=D0=B5?= =?UTF-8?q?=D1=80=D0=B0=20=D0=B8=D0=B7=D0=BE=D0=B1=D1=80=D0=B0=D0=B6=D0=B5?= =?UTF-8?q?=D0=BD=D0=B8=D1=8F=20=D0=B4=D0=BB=D1=8F=20Seedream.=20=D0=9E?= =?UTF-8?q?=D0=B1=D0=BD=D0=BE=D0=B2=D0=BB=D0=B5=D0=BD=D0=B0=20=D0=BC=D0=BE?= =?UTF-8?q?=D0=B4=D0=B5=D0=BB=D1=8C=20=D0=B3=D0=B5=D0=BD=D0=B5=D1=80=D0=B0?= =?UTF-8?q?=D1=86=D0=B8=D0=B8=20=D0=B0=D0=BD=D0=B8=D0=BC=D0=B5.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ai_agent.py | 103 ++++++++++++++++++----------------------- tg/handlers/default.py | 1 + tg/handlers/private.py | 1 + tg/utils.py | 10 +++- utils.py | 7 +++ vk/handlers/default.py | 2 +- vk/handlers/private.py | 2 +- 7 files changed, 64 insertions(+), 62 deletions(-) diff --git a/ai_agent.py b/ai_agent.py index b514caa..ac011d8 100644 --- a/ai_agent.py +++ b/ai_agent.py @@ -1,4 +1,3 @@ -import aiohttp import base64 import datetime import json @@ -7,7 +6,6 @@ from collections.abc import Callable from dataclasses import dataclass from io import BytesIO from PIL import Image -from result import Ok, Err, Result from typing import List, Tuple, Optional, Union, Dict, Awaitable from openrouter import OpenRouter, RetryConfig @@ -19,6 +17,7 @@ from openrouter.utils import BackoffStrategy from fal_client import AsyncClient as FalClient from replicate import Client as ReplicateClient +from utils import download_file from database import BasicDatabase OPENROUTER_X_TITLE = "TG/VK Chat Bot" @@ -29,7 +28,7 @@ PRIVATE_CHAT_MAX_MESSAGES = 40 MAX_OUTPUT_TOKENS = 500 FAL_MODEL = "fal-ai/bytedance/seedream/v4.5/text-to-image" -REPLICATE_MODEL = "ultracoderru/nova-anime-xl-14:1a8ded85309ffeb41780db667102ea935e4140b1a15f4a0f669a60a104b722db" +REPLICATE_MODEL = "ultracoderru/nova-anime-xl-14:3e9ada8e10123780c70bce6f14f907d7cdfc653e92f40767f47b08c9c24b6c4a" @dataclass() @@ -37,6 +36,7 @@ class Message: user_name: str = None text: str = None image: bytes = None + image_hires: bytes = None message_id: int = None @@ -65,6 +65,7 @@ class AiAgent: @dataclass() class _ToolsArtifacts: generated_image: Optional[bytes] = None + generated_image_hires: Optional[bytes] = None async def get_group_chat_reply(self, bot_id: int, chat_id: int, message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]: @@ -101,7 +102,8 @@ class AiAgent: role="assistant", text=ai_response, image=tools_artifacts.generated_image, message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES) - return Message(text=ai_response, image=tools_artifacts.generated_image), True + return Message(text=ai_response, image=tools_artifacts.generated_image, + image_hires=tools_artifacts.generated_image_hires), True except Exception as e: if str(e).find("Rate limit exceeded") != -1: @@ -134,7 +136,8 @@ class AiAgent: text=ai_response, image=tools_artifacts.generated_image, message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES) - return Message(text=ai_response, image=tools_artifacts.generated_image), True + return Message(text=ai_response, image=tools_artifacts.generated_image, + image_hires=tools_artifacts.generated_image_hires), True except Exception as e: if str(e).find("Rate limit exceeded") != -1: @@ -218,22 +221,20 @@ class AiAgent: -> List[ChatMessageContentItemTypedDict]: prompt = args.get("prompt", "") aspect_ratio = args.get("aspect_ratio", None) - result = await self._generate_image(prompt=prompt, aspect_ratio=aspect_ratio) - if result.is_ok(): - artifacts.generated_image = result.ok_value - return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.", - image=result.ok_value) - else: - return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {result.err_value}") - - async def _generate_image(self, prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]: - width, height = _get_resolution_for_aspect_ratio(aspect_ratio) - print(f"Генерация изображения {width}x{height}: {prompt}") + aspect_ratio_size_map = { + "1:1": "square", + "4:3": "landscape_4_3", + "3:4": "portrait_4_3", + "16:9": "landscape_16_9", + "9:16": "portrait_16_9", + } + image_size = aspect_ratio_size_map.get(aspect_ratio, "landscape_4_3") + print(f"Генерация изображения {image_size}: {prompt}") arguments = { "prompt": prompt, - "image_size": {"width": width, "height": height}, + "image_size": image_size, "enable_safety_checker": False } @@ -242,20 +243,13 @@ class AiAgent: if "images" not in result: raise RuntimeError("Неожиданный ответ от сервера.") image_url = result["images"][0]["url"] - - async with aiohttp.ClientSession() as session: - async with session.get(image_url) as response: - if response.status == 200: - image = await response.read() - if not image_url.endswith(".jpg"): - image = _convert_image_to_jpeg(image) - return Ok(image) - else: - raise RuntimeError(f"Не удалось загрузить изображение ({response.status}).") - + artifacts.generated_image_hires = await download_file(image_url) + artifacts.generated_image = _compress_image(artifacts.generated_image_hires, 1280) + return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.", + image=artifacts.generated_image) except Exception as e: print(f"Ошибка генерации изображения: {e}") - return Err(str(e)) + return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}") async def _process_tool_generate_image_anime(self, _bot_id: int, _chat_id: int, args: dict, artifacts: _ToolsArtifacts) \ @@ -263,19 +257,15 @@ class AiAgent: prompt = args.get("prompt", "") negative_prompt = args.get("negative_prompt", "") aspect_ratio = args.get("aspect_ratio", None) - result = await self._generate_image_anime(prompt=prompt, negative_prompt=negative_prompt, - aspect_ratio=aspect_ratio) - if result.is_ok(): - artifacts.generated_image = result.ok_value - return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.", - image=result.ok_value) - else: - return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {result.err_value}") - - async def _generate_image_anime(self, prompt: str, negative_prompt: str, aspect_ratio: Optional[str]) \ - -> Result[bytes, str]: - width, height = _get_resolution_for_aspect_ratio(aspect_ratio) + aspect_ratio_resolution_map = { + "1:1": (1280, 1280), + "4:3": (1280, 1024), + "3:4": (1024, 1280), + "16:9": (1280, 720), + "9:16": (720, 1280) + } + width, height = aspect_ratio_resolution_map.get(aspect_ratio, (1280, 1024)) print(f"Генерация аниме-изображения {width}x{height}:\n+ {prompt}\n- {negative_prompt}") arguments = { @@ -287,16 +277,19 @@ class AiAgent: "guidance_scale": 4.5, "num_inference_steps": 20, "hires_enable": True, + "hires_num_inference_steps": 30, "disable_safety_checker": True } try: outputs = await self.replicate_client.async_run(REPLICATE_MODEL, input=arguments) - image = _convert_image_to_jpeg(await outputs[0].aread(), (width, height)) - return Ok(image) + artifacts.generated_image_hires = await outputs[0].aread() + artifacts.generated_image = _compress_image(artifacts.generated_image_hires, 1280) + return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.", + image=artifacts.generated_image) except Exception as e: print(f"Ошибка генерации изображения: {e}") - return Err(str(e)) + return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}") async def _async_chat_completion_request(self, **kwargs): try: @@ -397,21 +390,15 @@ def _serialize_assistant_message(message: AssistantMessage) -> AssistantMessageT return _remove_none_recursive(message.model_dump(by_alias=True)) -def _get_resolution_for_aspect_ratio(aspect_ratio: str) -> Tuple[int, int]: - aspect_ratio_resolution_map = { - "1:1": (1280, 1280), - "4:3": (1280, 1024), - "3:4": (1024, 1280), - "16:9": (1280, 720), - "9:16": (720, 1280) - } - return aspect_ratio_resolution_map.get(aspect_ratio, (1280, 1024)) - - -def _convert_image_to_jpeg(image: bytes, size: Optional[tuple[int, int]] = None) -> bytes: +def _compress_image(image: bytes, max_side: Optional[int] = None) -> bytes: img = Image.open(BytesIO(image)).convert("RGB") - if size is not None: - img = img.resize(size=size, resample=Image.Resampling.BILINEAR) + + if img.width > max_side or img.height > max_side: + scale = min(max_side / img.width, max_side / img.height) + new_width = int(img.width * scale) + new_height = int(img.height * scale) + img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) + output = BytesIO() img.save(output, format='JPEG', quality=87, optimize=True) return output.getvalue() diff --git a/tg/handlers/default.py b/tg/handlers/default.py index 62fdade..ed26d14 100644 --- a/tg/handlers/default.py +++ b/tg/handlers/default.py @@ -86,6 +86,7 @@ async def any_message_handler(message: Message, bot: Bot): if answer.image is not None: answer_id = (await message.reply_photo(photo=wrap_photo(answer.image), caption=answer.text)).message_id + await message.answer_document(document=wrap_document(answer.image_hires, 'image', 'png')) else: answer_id = (await message.reply(answer.text)).message_id if success: diff --git a/tg/handlers/private.py b/tg/handlers/private.py index da13d89..0ac9d46 100644 --- a/tg/handlers/private.py +++ b/tg/handlers/private.py @@ -61,6 +61,7 @@ async def any_message_handler(message: Message, bot: Bot): if answer.image is not None: answer_id = (await message.answer_photo(photo=wrap_photo(answer.image), caption=answer.text)).message_id + await message.answer_document(document=wrap_document(answer.image_hires, 'image', 'png')) else: answer_id = (await message.answer(answer.text)).message_id if success: diff --git a/tg/utils.py b/tg/utils.py index ea4a5d2..1085885 100644 --- a/tg/utils.py +++ b/tg/utils.py @@ -1,4 +1,5 @@ -import io +from datetime import datetime +from io import BytesIO from typing import Optional from aiogram import Bot @@ -22,7 +23,7 @@ async def get_user_name_for_ai(user: User): async def download_photo(photo: PhotoSize, bot: Bot) -> bytes: # noinspection PyTypeChecker - photo_bytes: io.BytesIO = await bot.download(photo.file_id) + photo_bytes: BytesIO = await bot.download(photo.file_id) return photo_bytes.getvalue() @@ -54,3 +55,8 @@ async def create_ai_message(message: Message, bot: Bot) -> ai_agent.Message: def wrap_photo(image: bytes) -> BufferedInputFile: return BufferedInputFile(image, 'image.jpg') + + +def wrap_document(document: bytes, name_prefix: str, extension: str) -> BufferedInputFile: + name = "{}_{}.{}".format(name_prefix, datetime.now().strftime("%Y%m%d_%H%M%S"), extension) + return BufferedInputFile(document, name) diff --git a/utils.py b/utils.py index 4fb2caa..0cd1d22 100644 --- a/utils.py +++ b/utils.py @@ -2,6 +2,7 @@ import asyncio from calendar import timegm from typing import Awaitable, Callable, Coroutine, Optional +from aiohttp import ClientSession from pymorphy3 import MorphAnalyzer from time import gmtime @@ -50,3 +51,9 @@ async def run_with_progress(main_func: Callable[[], Coroutine], progress_func: C await progress_task return result + + +async def download_file(url: str) -> bytes: + async with ClientSession() as session: + async with session.get(url) as response: + return await response.read() diff --git a/vk/handlers/default.py b/vk/handlers/default.py index e0b370e..b9c33c7 100644 --- a/vk/handlers/default.py +++ b/vk/handlers/default.py @@ -94,7 +94,7 @@ async def any_message_handler(message: Message): interval=4) if answer.image is not None: - photo = await upload_photo(answer.image, chat_id=chat_id, api=message.ctx_api) + photo = await upload_photo(answer.image_hires, chat_id=chat_id, api=message.ctx_api) answer_id = (await message.reply(answer.text, attachment=photo)).conversation_message_id else: answer_id = (await message.reply(answer.text)).conversation_message_id diff --git a/vk/handlers/private.py b/vk/handlers/private.py index 32ef825..f6ea638 100644 --- a/vk/handlers/private.py +++ b/vk/handlers/private.py @@ -62,7 +62,7 @@ async def any_message_handler(message: Message): interval=4) if answer.image is not None: - photo = await upload_photo(answer.image, chat_id=chat_id, api=message.ctx_api) + photo = await upload_photo(answer.image_hires, chat_id=chat_id, api=message.ctx_api) answer_id = (await message.answer(answer.text, attachment=photo)).conversation_message_id else: answer_id = (await message.answer(answer.text)).message_id