Добавлена отправка изображения в высоком качестве пользователю.
Исправлен выбор размера изображения для Seedream. Обновлена модель генерации аниме.
This commit is contained in:
parent
25992ef772
commit
e0f521256b
7 changed files with 64 additions and 62 deletions
103
ai_agent.py
103
ai_agent.py
|
|
@ -1,4 +1,3 @@
|
||||||
import aiohttp
|
|
||||||
import base64
|
import base64
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
|
|
@ -7,7 +6,6 @@ from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from result import Ok, Err, Result
|
|
||||||
from typing import List, Tuple, Optional, Union, Dict, Awaitable
|
from typing import List, Tuple, Optional, Union, Dict, Awaitable
|
||||||
|
|
||||||
from openrouter import OpenRouter, RetryConfig
|
from openrouter import OpenRouter, RetryConfig
|
||||||
|
|
@ -19,6 +17,7 @@ from openrouter.utils import BackoffStrategy
|
||||||
from fal_client import AsyncClient as FalClient
|
from fal_client import AsyncClient as FalClient
|
||||||
from replicate import Client as ReplicateClient
|
from replicate import Client as ReplicateClient
|
||||||
|
|
||||||
|
from utils import download_file
|
||||||
from database import BasicDatabase
|
from database import BasicDatabase
|
||||||
|
|
||||||
OPENROUTER_X_TITLE = "TG/VK Chat Bot"
|
OPENROUTER_X_TITLE = "TG/VK Chat Bot"
|
||||||
|
|
@ -29,7 +28,7 @@ PRIVATE_CHAT_MAX_MESSAGES = 40
|
||||||
MAX_OUTPUT_TOKENS = 500
|
MAX_OUTPUT_TOKENS = 500
|
||||||
|
|
||||||
FAL_MODEL = "fal-ai/bytedance/seedream/v4.5/text-to-image"
|
FAL_MODEL = "fal-ai/bytedance/seedream/v4.5/text-to-image"
|
||||||
REPLICATE_MODEL = "ultracoderru/nova-anime-xl-14:1a8ded85309ffeb41780db667102ea935e4140b1a15f4a0f669a60a104b722db"
|
REPLICATE_MODEL = "ultracoderru/nova-anime-xl-14:3e9ada8e10123780c70bce6f14f907d7cdfc653e92f40767f47b08c9c24b6c4a"
|
||||||
|
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
|
|
@ -37,6 +36,7 @@ class Message:
|
||||||
user_name: str = None
|
user_name: str = None
|
||||||
text: str = None
|
text: str = None
|
||||||
image: bytes = None
|
image: bytes = None
|
||||||
|
image_hires: bytes = None
|
||||||
message_id: int = None
|
message_id: int = None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -65,6 +65,7 @@ class AiAgent:
|
||||||
@dataclass()
|
@dataclass()
|
||||||
class _ToolsArtifacts:
|
class _ToolsArtifacts:
|
||||||
generated_image: Optional[bytes] = None
|
generated_image: Optional[bytes] = None
|
||||||
|
generated_image_hires: Optional[bytes] = None
|
||||||
|
|
||||||
async def get_group_chat_reply(self, bot_id: int, chat_id: int,
|
async def get_group_chat_reply(self, bot_id: int, chat_id: int,
|
||||||
message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
|
message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
|
||||||
|
|
@ -101,7 +102,8 @@ class AiAgent:
|
||||||
role="assistant", text=ai_response, image=tools_artifacts.generated_image,
|
role="assistant", text=ai_response, image=tools_artifacts.generated_image,
|
||||||
message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
||||||
|
|
||||||
return Message(text=ai_response, image=tools_artifacts.generated_image), True
|
return Message(text=ai_response, image=tools_artifacts.generated_image,
|
||||||
|
image_hires=tools_artifacts.generated_image_hires), True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if str(e).find("Rate limit exceeded") != -1:
|
if str(e).find("Rate limit exceeded") != -1:
|
||||||
|
|
@ -134,7 +136,8 @@ class AiAgent:
|
||||||
text=ai_response, image=tools_artifacts.generated_image,
|
text=ai_response, image=tools_artifacts.generated_image,
|
||||||
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
||||||
|
|
||||||
return Message(text=ai_response, image=tools_artifacts.generated_image), True
|
return Message(text=ai_response, image=tools_artifacts.generated_image,
|
||||||
|
image_hires=tools_artifacts.generated_image_hires), True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if str(e).find("Rate limit exceeded") != -1:
|
if str(e).find("Rate limit exceeded") != -1:
|
||||||
|
|
@ -218,22 +221,20 @@ class AiAgent:
|
||||||
-> List[ChatMessageContentItemTypedDict]:
|
-> List[ChatMessageContentItemTypedDict]:
|
||||||
prompt = args.get("prompt", "")
|
prompt = args.get("prompt", "")
|
||||||
aspect_ratio = args.get("aspect_ratio", None)
|
aspect_ratio = args.get("aspect_ratio", None)
|
||||||
result = await self._generate_image(prompt=prompt, aspect_ratio=aspect_ratio)
|
|
||||||
|
|
||||||
if result.is_ok():
|
aspect_ratio_size_map = {
|
||||||
artifacts.generated_image = result.ok_value
|
"1:1": "square",
|
||||||
return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.",
|
"4:3": "landscape_4_3",
|
||||||
image=result.ok_value)
|
"3:4": "portrait_4_3",
|
||||||
else:
|
"16:9": "landscape_16_9",
|
||||||
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {result.err_value}")
|
"9:16": "portrait_16_9",
|
||||||
|
}
|
||||||
async def _generate_image(self, prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]:
|
image_size = aspect_ratio_size_map.get(aspect_ratio, "landscape_4_3")
|
||||||
width, height = _get_resolution_for_aspect_ratio(aspect_ratio)
|
print(f"Генерация изображения {image_size}: {prompt}")
|
||||||
print(f"Генерация изображения {width}x{height}: {prompt}")
|
|
||||||
|
|
||||||
arguments = {
|
arguments = {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"image_size": {"width": width, "height": height},
|
"image_size": image_size,
|
||||||
"enable_safety_checker": False
|
"enable_safety_checker": False
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -242,20 +243,13 @@ class AiAgent:
|
||||||
if "images" not in result:
|
if "images" not in result:
|
||||||
raise RuntimeError("Неожиданный ответ от сервера.")
|
raise RuntimeError("Неожиданный ответ от сервера.")
|
||||||
image_url = result["images"][0]["url"]
|
image_url = result["images"][0]["url"]
|
||||||
|
artifacts.generated_image_hires = await download_file(image_url)
|
||||||
async with aiohttp.ClientSession() as session:
|
artifacts.generated_image = _compress_image(artifacts.generated_image_hires, 1280)
|
||||||
async with session.get(image_url) as response:
|
return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.",
|
||||||
if response.status == 200:
|
image=artifacts.generated_image)
|
||||||
image = await response.read()
|
|
||||||
if not image_url.endswith(".jpg"):
|
|
||||||
image = _convert_image_to_jpeg(image)
|
|
||||||
return Ok(image)
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Не удалось загрузить изображение ({response.status}).")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Ошибка генерации изображения: {e}")
|
print(f"Ошибка генерации изображения: {e}")
|
||||||
return Err(str(e))
|
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}")
|
||||||
|
|
||||||
async def _process_tool_generate_image_anime(self, _bot_id: int, _chat_id: int,
|
async def _process_tool_generate_image_anime(self, _bot_id: int, _chat_id: int,
|
||||||
args: dict, artifacts: _ToolsArtifacts) \
|
args: dict, artifacts: _ToolsArtifacts) \
|
||||||
|
|
@ -263,19 +257,15 @@ class AiAgent:
|
||||||
prompt = args.get("prompt", "")
|
prompt = args.get("prompt", "")
|
||||||
negative_prompt = args.get("negative_prompt", "")
|
negative_prompt = args.get("negative_prompt", "")
|
||||||
aspect_ratio = args.get("aspect_ratio", None)
|
aspect_ratio = args.get("aspect_ratio", None)
|
||||||
result = await self._generate_image_anime(prompt=prompt, negative_prompt=negative_prompt,
|
|
||||||
aspect_ratio=aspect_ratio)
|
|
||||||
|
|
||||||
if result.is_ok():
|
aspect_ratio_resolution_map = {
|
||||||
artifacts.generated_image = result.ok_value
|
"1:1": (1280, 1280),
|
||||||
return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.",
|
"4:3": (1280, 1024),
|
||||||
image=result.ok_value)
|
"3:4": (1024, 1280),
|
||||||
else:
|
"16:9": (1280, 720),
|
||||||
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {result.err_value}")
|
"9:16": (720, 1280)
|
||||||
|
}
|
||||||
async def _generate_image_anime(self, prompt: str, negative_prompt: str, aspect_ratio: Optional[str]) \
|
width, height = aspect_ratio_resolution_map.get(aspect_ratio, (1280, 1024))
|
||||||
-> Result[bytes, str]:
|
|
||||||
width, height = _get_resolution_for_aspect_ratio(aspect_ratio)
|
|
||||||
print(f"Генерация аниме-изображения {width}x{height}:\n+ {prompt}\n- {negative_prompt}")
|
print(f"Генерация аниме-изображения {width}x{height}:\n+ {prompt}\n- {negative_prompt}")
|
||||||
|
|
||||||
arguments = {
|
arguments = {
|
||||||
|
|
@ -287,16 +277,19 @@ class AiAgent:
|
||||||
"guidance_scale": 4.5,
|
"guidance_scale": 4.5,
|
||||||
"num_inference_steps": 20,
|
"num_inference_steps": 20,
|
||||||
"hires_enable": True,
|
"hires_enable": True,
|
||||||
|
"hires_num_inference_steps": 30,
|
||||||
"disable_safety_checker": True
|
"disable_safety_checker": True
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
outputs = await self.replicate_client.async_run(REPLICATE_MODEL, input=arguments)
|
outputs = await self.replicate_client.async_run(REPLICATE_MODEL, input=arguments)
|
||||||
image = _convert_image_to_jpeg(await outputs[0].aread(), (width, height))
|
artifacts.generated_image_hires = await outputs[0].aread()
|
||||||
return Ok(image)
|
artifacts.generated_image = _compress_image(artifacts.generated_image_hires, 1280)
|
||||||
|
return _serialize_message_content(text="Изображение сгенерировано и будет показано пользователю.",
|
||||||
|
image=artifacts.generated_image)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Ошибка генерации изображения: {e}")
|
print(f"Ошибка генерации изображения: {e}")
|
||||||
return Err(str(e))
|
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}")
|
||||||
|
|
||||||
async def _async_chat_completion_request(self, **kwargs):
|
async def _async_chat_completion_request(self, **kwargs):
|
||||||
try:
|
try:
|
||||||
|
|
@ -397,21 +390,15 @@ def _serialize_assistant_message(message: AssistantMessage) -> AssistantMessageT
|
||||||
return _remove_none_recursive(message.model_dump(by_alias=True))
|
return _remove_none_recursive(message.model_dump(by_alias=True))
|
||||||
|
|
||||||
|
|
||||||
def _get_resolution_for_aspect_ratio(aspect_ratio: str) -> Tuple[int, int]:
|
def _compress_image(image: bytes, max_side: Optional[int] = None) -> bytes:
|
||||||
aspect_ratio_resolution_map = {
|
|
||||||
"1:1": (1280, 1280),
|
|
||||||
"4:3": (1280, 1024),
|
|
||||||
"3:4": (1024, 1280),
|
|
||||||
"16:9": (1280, 720),
|
|
||||||
"9:16": (720, 1280)
|
|
||||||
}
|
|
||||||
return aspect_ratio_resolution_map.get(aspect_ratio, (1280, 1024))
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_image_to_jpeg(image: bytes, size: Optional[tuple[int, int]] = None) -> bytes:
|
|
||||||
img = Image.open(BytesIO(image)).convert("RGB")
|
img = Image.open(BytesIO(image)).convert("RGB")
|
||||||
if size is not None:
|
|
||||||
img = img.resize(size=size, resample=Image.Resampling.BILINEAR)
|
if img.width > max_side or img.height > max_side:
|
||||||
|
scale = min(max_side / img.width, max_side / img.height)
|
||||||
|
new_width = int(img.width * scale)
|
||||||
|
new_height = int(img.height * scale)
|
||||||
|
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
output = BytesIO()
|
output = BytesIO()
|
||||||
img.save(output, format='JPEG', quality=87, optimize=True)
|
img.save(output, format='JPEG', quality=87, optimize=True)
|
||||||
return output.getvalue()
|
return output.getvalue()
|
||||||
|
|
|
||||||
|
|
@ -86,6 +86,7 @@ async def any_message_handler(message: Message, bot: Bot):
|
||||||
|
|
||||||
if answer.image is not None:
|
if answer.image is not None:
|
||||||
answer_id = (await message.reply_photo(photo=wrap_photo(answer.image), caption=answer.text)).message_id
|
answer_id = (await message.reply_photo(photo=wrap_photo(answer.image), caption=answer.text)).message_id
|
||||||
|
await message.answer_document(document=wrap_document(answer.image_hires, 'image', 'png'))
|
||||||
else:
|
else:
|
||||||
answer_id = (await message.reply(answer.text)).message_id
|
answer_id = (await message.reply(answer.text)).message_id
|
||||||
if success:
|
if success:
|
||||||
|
|
|
||||||
|
|
@ -61,6 +61,7 @@ async def any_message_handler(message: Message, bot: Bot):
|
||||||
|
|
||||||
if answer.image is not None:
|
if answer.image is not None:
|
||||||
answer_id = (await message.answer_photo(photo=wrap_photo(answer.image), caption=answer.text)).message_id
|
answer_id = (await message.answer_photo(photo=wrap_photo(answer.image), caption=answer.text)).message_id
|
||||||
|
await message.answer_document(document=wrap_document(answer.image_hires, 'image', 'png'))
|
||||||
else:
|
else:
|
||||||
answer_id = (await message.answer(answer.text)).message_id
|
answer_id = (await message.answer(answer.text)).message_id
|
||||||
if success:
|
if success:
|
||||||
|
|
|
||||||
10
tg/utils.py
10
tg/utils.py
|
|
@ -1,4 +1,5 @@
|
||||||
import io
|
from datetime import datetime
|
||||||
|
from io import BytesIO
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
|
|
@ -22,7 +23,7 @@ async def get_user_name_for_ai(user: User):
|
||||||
|
|
||||||
async def download_photo(photo: PhotoSize, bot: Bot) -> bytes:
|
async def download_photo(photo: PhotoSize, bot: Bot) -> bytes:
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
photo_bytes: io.BytesIO = await bot.download(photo.file_id)
|
photo_bytes: BytesIO = await bot.download(photo.file_id)
|
||||||
return photo_bytes.getvalue()
|
return photo_bytes.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -54,3 +55,8 @@ async def create_ai_message(message: Message, bot: Bot) -> ai_agent.Message:
|
||||||
|
|
||||||
def wrap_photo(image: bytes) -> BufferedInputFile:
|
def wrap_photo(image: bytes) -> BufferedInputFile:
|
||||||
return BufferedInputFile(image, 'image.jpg')
|
return BufferedInputFile(image, 'image.jpg')
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_document(document: bytes, name_prefix: str, extension: str) -> BufferedInputFile:
|
||||||
|
name = "{}_{}.{}".format(name_prefix, datetime.now().strftime("%Y%m%d_%H%M%S"), extension)
|
||||||
|
return BufferedInputFile(document, name)
|
||||||
|
|
|
||||||
7
utils.py
7
utils.py
|
|
@ -2,6 +2,7 @@ import asyncio
|
||||||
from calendar import timegm
|
from calendar import timegm
|
||||||
from typing import Awaitable, Callable, Coroutine, Optional
|
from typing import Awaitable, Callable, Coroutine, Optional
|
||||||
|
|
||||||
|
from aiohttp import ClientSession
|
||||||
from pymorphy3 import MorphAnalyzer
|
from pymorphy3 import MorphAnalyzer
|
||||||
from time import gmtime
|
from time import gmtime
|
||||||
|
|
||||||
|
|
@ -50,3 +51,9 @@ async def run_with_progress(main_func: Callable[[], Coroutine], progress_func: C
|
||||||
await progress_task
|
await progress_task
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def download_file(url: str) -> bytes:
|
||||||
|
async with ClientSession() as session:
|
||||||
|
async with session.get(url) as response:
|
||||||
|
return await response.read()
|
||||||
|
|
|
||||||
|
|
@ -94,7 +94,7 @@ async def any_message_handler(message: Message):
|
||||||
interval=4)
|
interval=4)
|
||||||
|
|
||||||
if answer.image is not None:
|
if answer.image is not None:
|
||||||
photo = await upload_photo(answer.image, chat_id=chat_id, api=message.ctx_api)
|
photo = await upload_photo(answer.image_hires, chat_id=chat_id, api=message.ctx_api)
|
||||||
answer_id = (await message.reply(answer.text, attachment=photo)).conversation_message_id
|
answer_id = (await message.reply(answer.text, attachment=photo)).conversation_message_id
|
||||||
else:
|
else:
|
||||||
answer_id = (await message.reply(answer.text)).conversation_message_id
|
answer_id = (await message.reply(answer.text)).conversation_message_id
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,7 @@ async def any_message_handler(message: Message):
|
||||||
interval=4)
|
interval=4)
|
||||||
|
|
||||||
if answer.image is not None:
|
if answer.image is not None:
|
||||||
photo = await upload_photo(answer.image, chat_id=chat_id, api=message.ctx_api)
|
photo = await upload_photo(answer.image_hires, chat_id=chat_id, api=message.ctx_api)
|
||||||
answer_id = (await message.answer(answer.text, attachment=photo)).conversation_message_id
|
answer_id = (await message.answer(answer.text, attachment=photo)).conversation_message_id
|
||||||
else:
|
else:
|
||||||
answer_id = (await message.answer(answer.text)).message_id
|
answer_id = (await message.answer(answer.text)).message_id
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue