Добавлена отправка изображения в высоком качестве пользователю.
Исправлен выбор размера изображения для Seedream. Обновлена модель генерации аниме.
This commit is contained in:
parent
25992ef772
commit
e0f521256b
7 changed files with 64 additions and 62 deletions
103
ai_agent.py
103
ai_agent.py
|
|
@ -1,4 +1,3 @@
|
|||
import aiohttp
|
||||
import base64
|
||||
import datetime
|
||||
import json
|
||||
|
|
@ -7,7 +6,6 @@ from collections.abc import Callable
|
|||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from result import Ok, Err, Result
|
||||
from typing import List, Tuple, Optional, Union, Dict, Awaitable
|
||||
|
||||
from openrouter import OpenRouter, RetryConfig
|
||||
|
|
@ -19,6 +17,7 @@ from openrouter.utils import BackoffStrategy
|
|||
from fal_client import AsyncClient as FalClient
|
||||
from replicate import Client as ReplicateClient
|
||||
|
||||
from utils import download_file
|
||||
from database import BasicDatabase
|
||||
|
||||
OPENROUTER_X_TITLE = "TG/VK Chat Bot"
|
||||
|
|
@ -29,7 +28,7 @@ PRIVATE_CHAT_MAX_MESSAGES = 40
|
|||
MAX_OUTPUT_TOKENS = 500
|
||||
|
||||
FAL_MODEL = "fal-ai/bytedance/seedream/v4.5/text-to-image"
|
||||
REPLICATE_MODEL = "ultracoderru/nova-anime-xl-14:1a8ded85309ffeb41780db667102ea935e4140b1a15f4a0f669a60a104b722db"
|
||||
REPLICATE_MODEL = "ultracoderru/nova-anime-xl-14:3e9ada8e10123780c70bce6f14f907d7cdfc653e92f40767f47b08c9c24b6c4a"
|
||||
|
||||
|
||||
@dataclass()
|
||||
|
|
@ -37,6 +36,7 @@ class Message:
|
|||
user_name: str = None
|
||||
text: str = None
|
||||
image: bytes = None
|
||||
image_hires: bytes = None
|
||||
message_id: int = None
|
||||
|
||||
|
||||
|
|
@ -65,6 +65,7 @@ class AiAgent:
|
|||
@dataclass()
|
||||
class _ToolsArtifacts:
|
||||
generated_image: Optional[bytes] = None
|
||||
generated_image_hires: Optional[bytes] = None
|
||||
|
||||
async def get_group_chat_reply(self, bot_id: int, chat_id: int,
|
||||
message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
|
||||
|
|
@ -101,7 +102,8 @@ class AiAgent:
|
|||
role="assistant", text=ai_response, image=tools_artifacts.generated_image,
|
||||
message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
||||
|
||||
return Message(text=ai_response, image=tools_artifacts.generated_image), True
|
||||
return Message(text=ai_response, image=tools_artifacts.generated_image,
|
||||
image_hires=tools_artifacts.generated_image_hires), True
|
||||
|
||||
except Exception as e:
|
||||
if str(e).find("Rate limit exceeded") != -1:
|
||||
|
|
@ -134,7 +136,8 @@ class AiAgent:
|
|||
text=ai_response, image=tools_artifacts.generated_image,
|
||||
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
||||
|
||||
return Message(text=ai_response, image=tools_artifacts.generated_image), True
|
||||
return Message(text=ai_response, image=tools_artifacts.generated_image,
|
||||
image_hires=tools_artifacts.generated_image_hires), True
|
||||
|
||||
except Exception as e:
|
||||
if str(e).find("Rate limit exceeded") != -1:
|
||||
|
|
@ -218,22 +221,20 @@ class AiAgent:
|
|||
-> List[ChatMessageContentItemTypedDict]:
|
||||
prompt = args.get("prompt", "")
|
||||
aspect_ratio = args.get("aspect_ratio", None)
|
||||
result = await self._generate_image(prompt=prompt, aspect_ratio=aspect_ratio)
|
||||
|
||||
if result.is_ok():
|
||||
artifacts.generated_image = result.ok_value
|
||||
return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.",
|
||||
image=result.ok_value)
|
||||
else:
|
||||
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {result.err_value}")
|
||||
|
||||
async def _generate_image(self, prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]:
|
||||
width, height = _get_resolution_for_aspect_ratio(aspect_ratio)
|
||||
print(f"Генерация изображения {width}x{height}: {prompt}")
|
||||
aspect_ratio_size_map = {
|
||||
"1:1": "square",
|
||||
"4:3": "landscape_4_3",
|
||||
"3:4": "portrait_4_3",
|
||||
"16:9": "landscape_16_9",
|
||||
"9:16": "portrait_16_9",
|
||||
}
|
||||
image_size = aspect_ratio_size_map.get(aspect_ratio, "landscape_4_3")
|
||||
print(f"Генерация изображения {image_size}: {prompt}")
|
||||
|
||||
arguments = {
|
||||
"prompt": prompt,
|
||||
"image_size": {"width": width, "height": height},
|
||||
"image_size": image_size,
|
||||
"enable_safety_checker": False
|
||||
}
|
||||
|
||||
|
|
@ -242,20 +243,13 @@ class AiAgent:
|
|||
if "images" not in result:
|
||||
raise RuntimeError("Неожиданный ответ от сервера.")
|
||||
image_url = result["images"][0]["url"]
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url) as response:
|
||||
if response.status == 200:
|
||||
image = await response.read()
|
||||
if not image_url.endswith(".jpg"):
|
||||
image = _convert_image_to_jpeg(image)
|
||||
return Ok(image)
|
||||
else:
|
||||
raise RuntimeError(f"Не удалось загрузить изображение ({response.status}).")
|
||||
|
||||
artifacts.generated_image_hires = await download_file(image_url)
|
||||
artifacts.generated_image = _compress_image(artifacts.generated_image_hires, 1280)
|
||||
return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.",
|
||||
image=artifacts.generated_image)
|
||||
except Exception as e:
|
||||
print(f"Ошибка генерации изображения: {e}")
|
||||
return Err(str(e))
|
||||
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}")
|
||||
|
||||
async def _process_tool_generate_image_anime(self, _bot_id: int, _chat_id: int,
|
||||
args: dict, artifacts: _ToolsArtifacts) \
|
||||
|
|
@ -263,19 +257,15 @@ class AiAgent:
|
|||
prompt = args.get("prompt", "")
|
||||
negative_prompt = args.get("negative_prompt", "")
|
||||
aspect_ratio = args.get("aspect_ratio", None)
|
||||
result = await self._generate_image_anime(prompt=prompt, negative_prompt=negative_prompt,
|
||||
aspect_ratio=aspect_ratio)
|
||||
|
||||
if result.is_ok():
|
||||
artifacts.generated_image = result.ok_value
|
||||
return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.",
|
||||
image=result.ok_value)
|
||||
else:
|
||||
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {result.err_value}")
|
||||
|
||||
async def _generate_image_anime(self, prompt: str, negative_prompt: str, aspect_ratio: Optional[str]) \
|
||||
-> Result[bytes, str]:
|
||||
width, height = _get_resolution_for_aspect_ratio(aspect_ratio)
|
||||
aspect_ratio_resolution_map = {
|
||||
"1:1": (1280, 1280),
|
||||
"4:3": (1280, 1024),
|
||||
"3:4": (1024, 1280),
|
||||
"16:9": (1280, 720),
|
||||
"9:16": (720, 1280)
|
||||
}
|
||||
width, height = aspect_ratio_resolution_map.get(aspect_ratio, (1280, 1024))
|
||||
print(f"Генерация аниме-изображения {width}x{height}:\n+ {prompt}\n- {negative_prompt}")
|
||||
|
||||
arguments = {
|
||||
|
|
@ -287,16 +277,19 @@ class AiAgent:
|
|||
"guidance_scale": 4.5,
|
||||
"num_inference_steps": 20,
|
||||
"hires_enable": True,
|
||||
"hires_num_inference_steps": 30,
|
||||
"disable_safety_checker": True
|
||||
}
|
||||
|
||||
try:
|
||||
outputs = await self.replicate_client.async_run(REPLICATE_MODEL, input=arguments)
|
||||
image = _convert_image_to_jpeg(await outputs[0].aread(), (width, height))
|
||||
return Ok(image)
|
||||
artifacts.generated_image_hires = await outputs[0].aread()
|
||||
artifacts.generated_image = _compress_image(artifacts.generated_image_hires, 1280)
|
||||
return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.",
|
||||
image=artifacts.generated_image)
|
||||
except Exception as e:
|
||||
print(f"Ошибка генерации изображения: {e}")
|
||||
return Err(str(e))
|
||||
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}")
|
||||
|
||||
async def _async_chat_completion_request(self, **kwargs):
|
||||
try:
|
||||
|
|
@ -397,21 +390,15 @@ def _serialize_assistant_message(message: AssistantMessage) -> AssistantMessageT
|
|||
return _remove_none_recursive(message.model_dump(by_alias=True))
|
||||
|
||||
|
||||
def _get_resolution_for_aspect_ratio(aspect_ratio: str) -> Tuple[int, int]:
|
||||
aspect_ratio_resolution_map = {
|
||||
"1:1": (1280, 1280),
|
||||
"4:3": (1280, 1024),
|
||||
"3:4": (1024, 1280),
|
||||
"16:9": (1280, 720),
|
||||
"9:16": (720, 1280)
|
||||
}
|
||||
return aspect_ratio_resolution_map.get(aspect_ratio, (1280, 1024))
|
||||
|
||||
|
||||
def _convert_image_to_jpeg(image: bytes, size: Optional[tuple[int, int]] = None) -> bytes:
|
||||
def _compress_image(image: bytes, max_side: Optional[int] = None) -> bytes:
|
||||
img = Image.open(BytesIO(image)).convert("RGB")
|
||||
if size is not None:
|
||||
img = img.resize(size=size, resample=Image.Resampling.BILINEAR)
|
||||
|
||||
if img.width > max_side or img.height > max_side:
|
||||
scale = min(max_side / img.width, max_side / img.height)
|
||||
new_width = int(img.width * scale)
|
||||
new_height = int(img.height * scale)
|
||||
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
output = BytesIO()
|
||||
img.save(output, format='JPEG', quality=87, optimize=True)
|
||||
return output.getvalue()
|
||||
|
|
|
|||
|
|
@ -86,6 +86,7 @@ async def any_message_handler(message: Message, bot: Bot):
|
|||
|
||||
if answer.image is not None:
|
||||
answer_id = (await message.reply_photo(photo=wrap_photo(answer.image), caption=answer.text)).message_id
|
||||
await message.answer_document(document=wrap_document(answer.image_hires, 'image', 'png'))
|
||||
else:
|
||||
answer_id = (await message.reply(answer.text)).message_id
|
||||
if success:
|
||||
|
|
|
|||
|
|
@ -61,6 +61,7 @@ async def any_message_handler(message: Message, bot: Bot):
|
|||
|
||||
if answer.image is not None:
|
||||
answer_id = (await message.answer_photo(photo=wrap_photo(answer.image), caption=answer.text)).message_id
|
||||
await message.answer_document(document=wrap_document(answer.image_hires, 'image', 'png'))
|
||||
else:
|
||||
answer_id = (await message.answer(answer.text)).message_id
|
||||
if success:
|
||||
|
|
|
|||
10
tg/utils.py
10
tg/utils.py
|
|
@ -1,4 +1,5 @@
|
|||
import io
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
from aiogram import Bot
|
||||
|
|
@ -22,7 +23,7 @@ async def get_user_name_for_ai(user: User):
|
|||
|
||||
async def download_photo(photo: PhotoSize, bot: Bot) -> bytes:
|
||||
# noinspection PyTypeChecker
|
||||
photo_bytes: io.BytesIO = await bot.download(photo.file_id)
|
||||
photo_bytes: BytesIO = await bot.download(photo.file_id)
|
||||
return photo_bytes.getvalue()
|
||||
|
||||
|
||||
|
|
@ -54,3 +55,8 @@ async def create_ai_message(message: Message, bot: Bot) -> ai_agent.Message:
|
|||
|
||||
def wrap_photo(image: bytes) -> BufferedInputFile:
|
||||
return BufferedInputFile(image, 'image.jpg')
|
||||
|
||||
|
||||
def wrap_document(document: bytes, name_prefix: str, extension: str) -> BufferedInputFile:
|
||||
name = "{}_{}.{}".format(name_prefix, datetime.now().strftime("%Y%m%d_%H%M%S"), extension)
|
||||
return BufferedInputFile(document, name)
|
||||
|
|
|
|||
7
utils.py
7
utils.py
|
|
@ -2,6 +2,7 @@ import asyncio
|
|||
from calendar import timegm
|
||||
from typing import Awaitable, Callable, Coroutine, Optional
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from pymorphy3 import MorphAnalyzer
|
||||
from time import gmtime
|
||||
|
||||
|
|
@ -50,3 +51,9 @@ async def run_with_progress(main_func: Callable[[], Coroutine], progress_func: C
|
|||
await progress_task
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def download_file(url: str) -> bytes:
|
||||
async with ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
return await response.read()
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ async def any_message_handler(message: Message):
|
|||
interval=4)
|
||||
|
||||
if answer.image is not None:
|
||||
photo = await upload_photo(answer.image, chat_id=chat_id, api=message.ctx_api)
|
||||
photo = await upload_photo(answer.image_hires, chat_id=chat_id, api=message.ctx_api)
|
||||
answer_id = (await message.reply(answer.text, attachment=photo)).conversation_message_id
|
||||
else:
|
||||
answer_id = (await message.reply(answer.text)).conversation_message_id
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ async def any_message_handler(message: Message):
|
|||
interval=4)
|
||||
|
||||
if answer.image is not None:
|
||||
photo = await upload_photo(answer.image, chat_id=chat_id, api=message.ctx_api)
|
||||
photo = await upload_photo(answer.image_hires, chat_id=chat_id, api=message.ctx_api)
|
||||
answer_id = (await message.answer(answer.text, attachment=photo)).conversation_message_id
|
||||
else:
|
||||
answer_id = (await message.answer(answer.text)).message_id
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue