Генерация изображений через fal.ai.
This commit is contained in:
parent
d1df004de3
commit
10c91f2cdc
3 changed files with 56 additions and 49 deletions
83
ai_agent.py
83
ai_agent.py
|
|
@ -1,3 +1,4 @@
|
||||||
|
import aiohttp
|
||||||
import base64
|
import base64
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
|
|
@ -15,6 +16,8 @@ from openrouter.components import AssistantMessage, AssistantMessageTypedDict, C
|
||||||
from openrouter.errors import ResponseValidationError, ChatError
|
from openrouter.errors import ResponseValidationError, ChatError
|
||||||
from openrouter.utils import BackoffStrategy
|
from openrouter.utils import BackoffStrategy
|
||||||
|
|
||||||
|
from fal_client import AsyncClient as FalClient
|
||||||
|
|
||||||
from database import BasicDatabase
|
from database import BasicDatabase
|
||||||
|
|
||||||
GROUP_CHAT_SYSTEM_PROMPT = """
|
GROUP_CHAT_SYSTEM_PROMPT = """
|
||||||
|
|
@ -121,7 +124,7 @@ def _get_tools_description() -> List[ToolDefinitionJSONTypedDict]:
|
||||||
},
|
},
|
||||||
"aspect_ratio": {
|
"aspect_ratio": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["1:1", "3:4", "4:3", "9:16", "16:9"],
|
"enum": ["1:1", "4:3", "3:4", "16:9", "9:16"],
|
||||||
"description": "Соотношение сторон (опционально)."
|
"description": "Соотношение сторон (опционально)."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
@ -133,8 +136,8 @@ def _get_tools_description() -> List[ToolDefinitionJSONTypedDict]:
|
||||||
|
|
||||||
class AiAgent:
|
class AiAgent:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
api_token_main: str, model_main: str,
|
openrouter_token: str, openrouter_model: str,
|
||||||
api_token_image: str, model_image: str,
|
fal_token: str, fal_model: str,
|
||||||
db: BasicDatabase,
|
db: BasicDatabase,
|
||||||
platform: str):
|
platform: str):
|
||||||
retry_config = RetryConfig(strategy="backoff",
|
retry_config = RetryConfig(strategy="backoff",
|
||||||
|
|
@ -142,14 +145,13 @@ class AiAgent:
|
||||||
initial_interval=2000, max_interval=8000, exponent=2, max_elapsed_time=14000),
|
initial_interval=2000, max_interval=8000, exponent=2, max_elapsed_time=14000),
|
||||||
retry_connection_errors=True)
|
retry_connection_errors=True)
|
||||||
self.db = db
|
self.db = db
|
||||||
self.model_main = model_main
|
self.model_main = openrouter_model
|
||||||
self.model_image = model_image
|
self.model_image = fal_model
|
||||||
self.platform = platform
|
self.platform = platform
|
||||||
self.client = OpenRouter(api_key=api_token_main,
|
self.client_openrouter = OpenRouter(api_key=openrouter_token,
|
||||||
x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER,
|
x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER,
|
||||||
retry_config=retry_config)
|
retry_config=retry_config)
|
||||||
self.client_image = OpenRouter(api_key=api_token_image,
|
self.client_fal = FalClient(key=fal_token)
|
||||||
x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER)
|
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
class _ToolsArtifacts:
|
class _ToolsArtifacts:
|
||||||
|
|
@ -300,11 +302,11 @@ class AiAgent:
|
||||||
artifacts.tools_called = True
|
artifacts.tools_called = True
|
||||||
return artifacts
|
return artifacts
|
||||||
|
|
||||||
async def _process_tool_generate_image(self, bot_id: int, chat_id: int, args: dict, artifacts: _ToolsArtifacts) \
|
async def _process_tool_generate_image(self, _bot_id: int, _chat_id: int, args: dict, artifacts: _ToolsArtifacts) \
|
||||||
-> List[ChatMessageContentItemTypedDict]:
|
-> List[ChatMessageContentItemTypedDict]:
|
||||||
prompt = args.get("prompt", "")
|
prompt = args.get("prompt", "")
|
||||||
aspect_ratio = args.get("aspect_ratio", None)
|
aspect_ratio = args.get("aspect_ratio", None)
|
||||||
result = await self._generate_image(bot_id, chat_id, prompt=prompt, aspect_ratio=aspect_ratio)
|
result = await self._generate_image(prompt=prompt, aspect_ratio=aspect_ratio)
|
||||||
|
|
||||||
content = []
|
content = []
|
||||||
if result.is_ok():
|
if result.is_ok():
|
||||||
|
|
@ -320,45 +322,50 @@ class AiAgent:
|
||||||
"text": f"Не удалось сгенерировать изображение: {result.err_value}"})
|
"text": f"Не удалось сгенерировать изображение: {result.err_value}"})
|
||||||
return content
|
return content
|
||||||
|
|
||||||
async def _generate_image(self, bot_id: int, chat_id: int,
|
async def _generate_image(self, prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]:
|
||||||
prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]:
|
aspect_ratio_resolution_map = {
|
||||||
print(f"Генерация изображения: {prompt}")
|
"1:1": (1280, 1280),
|
||||||
context = [{"role": "user", "content": prompt}]
|
"4:3": (1280, 1024),
|
||||||
|
"3:4": (1024, 1280),
|
||||||
|
"16:9": (1280, 720),
|
||||||
|
"9:16": (720, 1280)
|
||||||
|
}
|
||||||
|
width, height = aspect_ratio_resolution_map.get(aspect_ratio, (1280, 1024))
|
||||||
|
print(f"Генерация изображения {width}x{height}: {prompt}")
|
||||||
|
|
||||||
|
arguments = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"image_size": {"width": width, "height": height},
|
||||||
|
"enable_safety_checker": False
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await self._async_chat_completion_request(
|
result = await self.client_fal.run(self.model_image, arguments=arguments)
|
||||||
model=self.model_image,
|
if "images" not in result:
|
||||||
messages=context,
|
raise RuntimeError("Неожиданный ответ от сервера.")
|
||||||
user=f'{self.platform}_{bot_id}_{chat_id}',
|
image_url = result["images"][0]["url"]
|
||||||
modalities=["image"],
|
|
||||||
image_config={"aspect_ratio": aspect_ratio} if aspect_ratio is not None else None
|
|
||||||
)
|
|
||||||
|
|
||||||
image_url = response.choices[0].message.images[0].image_url.url
|
async with aiohttp.ClientSession() as session:
|
||||||
header, image_base64 = image_url.split(",", 1)
|
async with session.get(image_url) as response:
|
||||||
mime_type = header.split(";")[0].replace("data:", "")
|
if response.status == 200:
|
||||||
image_bytes = base64.b64decode(image_base64)
|
image_bytes = await response.read()
|
||||||
|
if not image_url.endswith(".jpg"):
|
||||||
if mime_type != "image/jpeg":
|
|
||||||
image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
||||||
output = BytesIO()
|
output = BytesIO()
|
||||||
image.save(output, format="JPEG", quality=80, optimize=True)
|
image.save(output, format="JPEG", quality=80, optimize=True)
|
||||||
image_bytes = output.getvalue()
|
image_bytes = output.getvalue()
|
||||||
|
|
||||||
return Ok(image_bytes)
|
return Ok(image_bytes)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Не удалось загрузить изображение ({response.status}).")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Костыль для модели Seedream 4.5
|
print(f"Ошибка генерации изображения: {e}")
|
||||||
message = str(e)
|
return Err(str(e))
|
||||||
prefix = "Request id:"
|
|
||||||
if prefix in message:
|
|
||||||
message = message.split(prefix)[0].strip()
|
|
||||||
print(f"Ошибка генерации изображения: {message}")
|
|
||||||
return Err(message)
|
|
||||||
|
|
||||||
async def _async_chat_completion_request(self, **kwargs):
|
async def _async_chat_completion_request(self, **kwargs):
|
||||||
try:
|
try:
|
||||||
return await self.client_image.chat.send_async(**kwargs)
|
return await self.client_openrouter.chat.send_async(**kwargs)
|
||||||
except ResponseValidationError as e:
|
except ResponseValidationError as e:
|
||||||
# Костыль для OpenRouter SDK:
|
# Костыль для OpenRouter SDK:
|
||||||
# https://github.com/OpenRouterTeam/python-sdk/issues/44
|
# https://github.com/OpenRouterTeam/python-sdk/issues/44
|
||||||
|
|
@ -386,8 +393,8 @@ class AiAgent:
|
||||||
agent: AiAgent
|
agent: AiAgent
|
||||||
|
|
||||||
|
|
||||||
def create_ai_agent(api_token_main: str, model_main: str,
|
def create_ai_agent(openrouter_token: str, openrouter_model: str,
|
||||||
api_token_image: str, model_image: str,
|
fal_token: str, fal_model: str,
|
||||||
db: BasicDatabase, platform: str):
|
db: BasicDatabase, platform: str):
|
||||||
global agent
|
global agent
|
||||||
agent = AiAgent(api_token_main, model_main, api_token_image, model_image, db, platform)
|
agent = AiAgent(openrouter_token, openrouter_model, fal_token, fal_model, db, platform)
|
||||||
|
|
|
||||||
|
|
@ -24,8 +24,8 @@ async def main() -> None:
|
||||||
|
|
||||||
database.create_database(config['db_connection_string'])
|
database.create_database(config['db_connection_string'])
|
||||||
|
|
||||||
create_ai_agent(config['openrouter_token_main'], config['openrouter_model_main'],
|
create_ai_agent(config['openrouter_token'], config['openrouter_model'],
|
||||||
config['openrouter_token_image'], config['openrouter_model_image'],
|
config['fal_token'], config['fal_model'],
|
||||||
database.DB, 'tg')
|
database.DB, 'tg')
|
||||||
|
|
||||||
bots: list[Bot] = []
|
bots: list[Bot] = []
|
||||||
|
|
|
||||||
|
|
@ -24,8 +24,8 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
database.create_database(config['db_connection_string'])
|
database.create_database(config['db_connection_string'])
|
||||||
|
|
||||||
create_ai_agent(config['openrouter_token_main'], config['openrouter_model_main'],
|
create_ai_agent(config['openrouter_token'], config['openrouter_model'],
|
||||||
config['openrouter_token_image'], config['openrouter_model_image'],
|
config['fal_token'], config['fal_model'],
|
||||||
database.DB, 'vk')
|
database.DB, 'vk')
|
||||||
|
|
||||||
bot = Bot(labeler=handlers.labeler)
|
bot = Bot(labeler=handlers.labeler)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue