Добавлена отправка изображения в высоком качестве пользователю.

Исправлен выбор размера изображения для Seedream.
Обновлена модель генерации аниме.
This commit is contained in:
Kirill Kirilenko 2026-03-07 20:21:06 +03:00
parent 25992ef772
commit e0f521256b
7 changed files with 64 additions and 62 deletions

View file

@ -1,4 +1,3 @@
import aiohttp
import base64 import base64
import datetime import datetime
import json import json
@ -7,7 +6,6 @@ from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from io import BytesIO from io import BytesIO
from PIL import Image from PIL import Image
from result import Ok, Err, Result
from typing import List, Tuple, Optional, Union, Dict, Awaitable from typing import List, Tuple, Optional, Union, Dict, Awaitable
from openrouter import OpenRouter, RetryConfig from openrouter import OpenRouter, RetryConfig
@ -19,6 +17,7 @@ from openrouter.utils import BackoffStrategy
from fal_client import AsyncClient as FalClient from fal_client import AsyncClient as FalClient
from replicate import Client as ReplicateClient from replicate import Client as ReplicateClient
from utils import download_file
from database import BasicDatabase from database import BasicDatabase
OPENROUTER_X_TITLE = "TG/VK Chat Bot" OPENROUTER_X_TITLE = "TG/VK Chat Bot"
@ -29,7 +28,7 @@ PRIVATE_CHAT_MAX_MESSAGES = 40
MAX_OUTPUT_TOKENS = 500 MAX_OUTPUT_TOKENS = 500
FAL_MODEL = "fal-ai/bytedance/seedream/v4.5/text-to-image" FAL_MODEL = "fal-ai/bytedance/seedream/v4.5/text-to-image"
REPLICATE_MODEL = "ultracoderru/nova-anime-xl-14:1a8ded85309ffeb41780db667102ea935e4140b1a15f4a0f669a60a104b722db" REPLICATE_MODEL = "ultracoderru/nova-anime-xl-14:3e9ada8e10123780c70bce6f14f907d7cdfc653e92f40767f47b08c9c24b6c4a"
@dataclass() @dataclass()
@ -37,6 +36,7 @@ class Message:
user_name: str = None user_name: str = None
text: str = None text: str = None
image: bytes = None image: bytes = None
image_hires: bytes = None
message_id: int = None message_id: int = None
@ -65,6 +65,7 @@ class AiAgent:
@dataclass() @dataclass()
class _ToolsArtifacts: class _ToolsArtifacts:
generated_image: Optional[bytes] = None generated_image: Optional[bytes] = None
generated_image_hires: Optional[bytes] = None
async def get_group_chat_reply(self, bot_id: int, chat_id: int, async def get_group_chat_reply(self, bot_id: int, chat_id: int,
message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]: message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
@ -101,7 +102,8 @@ class AiAgent:
role="assistant", text=ai_response, image=tools_artifacts.generated_image, role="assistant", text=ai_response, image=tools_artifacts.generated_image,
message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES) message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)
return Message(text=ai_response, image=tools_artifacts.generated_image), True return Message(text=ai_response, image=tools_artifacts.generated_image,
image_hires=tools_artifacts.generated_image_hires), True
except Exception as e: except Exception as e:
if str(e).find("Rate limit exceeded") != -1: if str(e).find("Rate limit exceeded") != -1:
@ -134,7 +136,8 @@ class AiAgent:
text=ai_response, image=tools_artifacts.generated_image, text=ai_response, image=tools_artifacts.generated_image,
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES) message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
return Message(text=ai_response, image=tools_artifacts.generated_image), True return Message(text=ai_response, image=tools_artifacts.generated_image,
image_hires=tools_artifacts.generated_image_hires), True
except Exception as e: except Exception as e:
if str(e).find("Rate limit exceeded") != -1: if str(e).find("Rate limit exceeded") != -1:
@ -218,22 +221,20 @@ class AiAgent:
-> List[ChatMessageContentItemTypedDict]: -> List[ChatMessageContentItemTypedDict]:
prompt = args.get("prompt", "") prompt = args.get("prompt", "")
aspect_ratio = args.get("aspect_ratio", None) aspect_ratio = args.get("aspect_ratio", None)
result = await self._generate_image(prompt=prompt, aspect_ratio=aspect_ratio)
if result.is_ok(): aspect_ratio_size_map = {
artifacts.generated_image = result.ok_value "1:1": "square",
return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.", "4:3": "landscape_4_3",
image=result.ok_value) "3:4": "portrait_4_3",
else: "16:9": "landscape_16_9",
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {result.err_value}") "9:16": "portrait_16_9",
}
async def _generate_image(self, prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]: image_size = aspect_ratio_size_map.get(aspect_ratio, "landscape_4_3")
width, height = _get_resolution_for_aspect_ratio(aspect_ratio) print(f"Генерация изображения {image_size}: {prompt}")
print(f"Генерация изображения {width}x{height}: {prompt}")
arguments = { arguments = {
"prompt": prompt, "prompt": prompt,
"image_size": {"width": width, "height": height}, "image_size": image_size,
"enable_safety_checker": False "enable_safety_checker": False
} }
@ -242,20 +243,13 @@ class AiAgent:
if "images" not in result: if "images" not in result:
raise RuntimeError("Неожиданный ответ от сервера.") raise RuntimeError("Неожиданный ответ от сервера.")
image_url = result["images"][0]["url"] image_url = result["images"][0]["url"]
artifacts.generated_image_hires = await download_file(image_url)
async with aiohttp.ClientSession() as session: artifacts.generated_image = _compress_image(artifacts.generated_image_hires, 1280)
async with session.get(image_url) as response: return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.",
if response.status == 200: image=artifacts.generated_image)
image = await response.read()
if not image_url.endswith(".jpg"):
image = _convert_image_to_jpeg(image)
return Ok(image)
else:
raise RuntimeError(f"Не удалось загрузить изображение ({response.status}).")
except Exception as e: except Exception as e:
print(f"Ошибка генерации изображения: {e}") print(f"Ошибка генерации изображения: {e}")
return Err(str(e)) return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}")
async def _process_tool_generate_image_anime(self, _bot_id: int, _chat_id: int, async def _process_tool_generate_image_anime(self, _bot_id: int, _chat_id: int,
args: dict, artifacts: _ToolsArtifacts) \ args: dict, artifacts: _ToolsArtifacts) \
@ -263,19 +257,15 @@ class AiAgent:
prompt = args.get("prompt", "") prompt = args.get("prompt", "")
negative_prompt = args.get("negative_prompt", "") negative_prompt = args.get("negative_prompt", "")
aspect_ratio = args.get("aspect_ratio", None) aspect_ratio = args.get("aspect_ratio", None)
result = await self._generate_image_anime(prompt=prompt, negative_prompt=negative_prompt,
aspect_ratio=aspect_ratio)
if result.is_ok(): aspect_ratio_resolution_map = {
artifacts.generated_image = result.ok_value "1:1": (1280, 1280),
return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.", "4:3": (1280, 1024),
image=result.ok_value) "3:4": (1024, 1280),
else: "16:9": (1280, 720),
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {result.err_value}") "9:16": (720, 1280)
}
async def _generate_image_anime(self, prompt: str, negative_prompt: str, aspect_ratio: Optional[str]) \ width, height = aspect_ratio_resolution_map.get(aspect_ratio, (1280, 1024))
-> Result[bytes, str]:
width, height = _get_resolution_for_aspect_ratio(aspect_ratio)
print(f"Генерация аниме-изображения {width}x{height}:\n+ {prompt}\n- {negative_prompt}") print(f"Генерация аниме-изображения {width}x{height}:\n+ {prompt}\n- {negative_prompt}")
arguments = { arguments = {
@ -287,16 +277,19 @@ class AiAgent:
"guidance_scale": 4.5, "guidance_scale": 4.5,
"num_inference_steps": 20, "num_inference_steps": 20,
"hires_enable": True, "hires_enable": True,
"hires_num_inference_steps": 30,
"disable_safety_checker": True "disable_safety_checker": True
} }
try: try:
outputs = await self.replicate_client.async_run(REPLICATE_MODEL, input=arguments) outputs = await self.replicate_client.async_run(REPLICATE_MODEL, input=arguments)
image = _convert_image_to_jpeg(await outputs[0].aread(), (width, height)) artifacts.generated_image_hires = await outputs[0].aread()
return Ok(image) artifacts.generated_image = _compress_image(artifacts.generated_image_hires, 1280)
return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.",
image=artifacts.generated_image)
except Exception as e: except Exception as e:
print(f"Ошибка генерации изображения: {e}") print(f"Ошибка генерации изображения: {e}")
return Err(str(e)) return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}")
async def _async_chat_completion_request(self, **kwargs): async def _async_chat_completion_request(self, **kwargs):
try: try:
@ -397,21 +390,15 @@ def _serialize_assistant_message(message: AssistantMessage) -> AssistantMessageT
return _remove_none_recursive(message.model_dump(by_alias=True)) return _remove_none_recursive(message.model_dump(by_alias=True))
def _compress_image(image: bytes, max_side: Optional[int] = None) -> bytes:
    """Re-encode an image as JPEG, optionally shrinking it to fit a square.

    Args:
        image: Encoded image bytes in any format Pillow can open.
        max_side: Upper bound, in pixels, for both width and height. When
            ``None`` (the default) the image keeps its dimensions and is
            only re-encoded.

    Returns:
        JPEG-encoded bytes (RGB mode, quality 87, optimized).

    Raises:
        ValueError: If ``max_side`` is given but is not positive.
    """
    if max_side is not None and max_side <= 0:
        raise ValueError(f"max_side must be positive, got {max_side}")

    # JPEG has no alpha channel, so normalize the mode before saving.
    img = Image.open(BytesIO(image)).convert("RGB")

    # None means "no size limit"; comparing an int with None would raise
    # TypeError, so the check is skipped entirely in that case.
    if max_side is not None and (img.width > max_side or img.height > max_side):
        # One common factor for both sides preserves the aspect ratio.
        scale = min(max_side / img.width, max_side / img.height)
        # Clamp to 1 px so a very elongated image never rounds down to a
        # zero-sized side.
        new_width = max(1, int(img.width * scale))
        new_height = max(1, int(img.height * scale))
        img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

    output = BytesIO()
    img.save(output, format='JPEG', quality=87, optimize=True)
    return output.getvalue()

View file

@ -86,6 +86,7 @@ async def any_message_handler(message: Message, bot: Bot):
if answer.image is not None: if answer.image is not None:
answer_id = (await message.reply_photo(photo=wrap_photo(answer.image), caption=answer.text)).message_id answer_id = (await message.reply_photo(photo=wrap_photo(answer.image), caption=answer.text)).message_id
await message.answer_document(document=wrap_document(answer.image_hires, 'image', 'png'))
else: else:
answer_id = (await message.reply(answer.text)).message_id answer_id = (await message.reply(answer.text)).message_id
if success: if success:

View file

@ -61,6 +61,7 @@ async def any_message_handler(message: Message, bot: Bot):
if answer.image is not None: if answer.image is not None:
answer_id = (await message.answer_photo(photo=wrap_photo(answer.image), caption=answer.text)).message_id answer_id = (await message.answer_photo(photo=wrap_photo(answer.image), caption=answer.text)).message_id
await message.answer_document(document=wrap_document(answer.image_hires, 'image', 'png'))
else: else:
answer_id = (await message.answer(answer.text)).message_id answer_id = (await message.answer(answer.text)).message_id
if success: if success:

View file

@ -1,4 +1,5 @@
import io from datetime import datetime
from io import BytesIO
from typing import Optional from typing import Optional
from aiogram import Bot from aiogram import Bot
@ -22,7 +23,7 @@ async def get_user_name_for_ai(user: User):
async def download_photo(photo: PhotoSize, bot: Bot) -> bytes:
    """Download a photo through the bot and return its content as bytes.

    Args:
        photo: Photo size entry whose ``file_id`` identifies the file.
        bot: Bot instance used to perform the download.

    Returns:
        Everything held by the buffer that ``bot.download`` returns.
    """
    # noinspection PyTypeChecker
    buffer: BytesIO = await bot.download(photo.file_id)
    content = buffer.getvalue()
    return content
@ -54,3 +55,8 @@ async def create_ai_message(message: Message, bot: Bot) -> ai_agent.Message:
def wrap_photo(image: bytes) -> BufferedInputFile:
    """Wrap raw image bytes in a ``BufferedInputFile`` named ``image.jpg``."""
    file_name = 'image.jpg'
    return BufferedInputFile(image, file_name)
def wrap_document(document: bytes, name_prefix: str, extension: str) -> BufferedInputFile:
    """Wrap raw bytes in a ``BufferedInputFile`` with a timestamped file name.

    The name has the form ``<name_prefix>_<YYYYmmdd_HHMMSS>.<extension>``,
    where the timestamp is the current local time.

    Args:
        document: File content.
        name_prefix: Leading part of the generated file name.
        extension: File extension, without the leading dot.

    Returns:
        The content wrapped together with the generated file name.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_name = "{}_{}.{}".format(name_prefix, timestamp, extension)
    return BufferedInputFile(document, file_name)

View file

@ -2,6 +2,7 @@ import asyncio
from calendar import timegm from calendar import timegm
from typing import Awaitable, Callable, Coroutine, Optional from typing import Awaitable, Callable, Coroutine, Optional
from aiohttp import ClientSession
from pymorphy3 import MorphAnalyzer from pymorphy3 import MorphAnalyzer
from time import gmtime from time import gmtime
@ -50,3 +51,9 @@ async def run_with_progress(main_func: Callable[[], Coroutine], progress_func: C
await progress_task await progress_task
return result return result
async def download_file(url: str) -> bytes:
    """Download the resource at ``url`` and return its raw body.

    Args:
        url: Address to fetch with an HTTP GET request.

    Returns:
        The response body as bytes.

    Raises:
        aiohttp.ClientResponseError: If the server answers with a status
            code of 400 or higher.
    """
    async with ClientSession() as session:
        async with session.get(url) as response:
            # Fail loudly on HTTP errors instead of returning an error page
            # as if it were the requested file.
            response.raise_for_status()
            return await response.read()

View file

@ -94,7 +94,7 @@ async def any_message_handler(message: Message):
interval=4) interval=4)
if answer.image is not None: if answer.image is not None:
photo = await upload_photo(answer.image, chat_id=chat_id, api=message.ctx_api) photo = await upload_photo(answer.image_hires, chat_id=chat_id, api=message.ctx_api)
answer_id = (await message.reply(answer.text, attachment=photo)).conversation_message_id answer_id = (await message.reply(answer.text, attachment=photo)).conversation_message_id
else: else:
answer_id = (await message.reply(answer.text)).conversation_message_id answer_id = (await message.reply(answer.text)).conversation_message_id

View file

@ -62,7 +62,7 @@ async def any_message_handler(message: Message):
interval=4) interval=4)
if answer.image is not None: if answer.image is not None:
photo = await upload_photo(answer.image, chat_id=chat_id, api=message.ctx_api) photo = await upload_photo(answer.image_hires, chat_id=chat_id, api=message.ctx_api)
answer_id = (await message.answer(answer.text, attachment=photo)).conversation_message_id answer_id = (await message.answer(answer.text, attachment=photo)).conversation_message_id
else: else:
answer_id = (await message.answer(answer.text)).message_id answer_id = (await message.answer(answer.text)).message_id