diff --git a/ai_agent.py b/ai_agent.py index 39b910b..acc1c20 100644 --- a/ai_agent.py +++ b/ai_agent.py @@ -18,6 +18,8 @@ from openrouter.utils import BackoffStrategy from fal_client import AsyncClient as FalClient +import replicate + from database import BasicDatabase OPENROUTER_X_TITLE = "TG/VK Chat Bot" @@ -91,7 +93,7 @@ def _get_resolution_for_aspect_ratio(aspect_ratio: str) -> Tuple[int, int]: class AiAgent: def __init__(self, openrouter_token: str, openrouter_model: str, - fal_token: str, + fal_token: str, replicate_token: str, db: BasicDatabase, platform: str): retry_config = RetryConfig(strategy="backoff", @@ -109,6 +111,7 @@ class AiAgent: x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER, retry_config=retry_config) self.client_fal = FalClient(key=fal_token) + self.replicate_client = replicate.Client(api_token=replicate_token) @dataclass() class _ToolsArtifacts: @@ -341,8 +344,7 @@ class AiAgent: "text": f"Не удалось сгенерировать изображение: {result.err_value}"}) return content - @staticmethod - async def _generate_image_anime(prompt: str, negative_prompt: str, aspect_ratio: Optional[str]) \ + async def _generate_image_anime(self, prompt: str, negative_prompt: str, aspect_ratio: Optional[str]) \ -> Result[bytes, str]: width, height = _get_resolution_for_aspect_ratio(aspect_ratio) print(f"Генерация изображения {width}x{height}: positive='{prompt}', negative='{negative_prompt}'") @@ -351,24 +353,23 @@ class AiAgent: "prompt": prompt, "negative_prompt": negative_prompt, "width": width, - "height": height + "height": height, + "cfg": 6, + "steps": 30, + "disable_safety_checker": True } try: - async with aiohttp.ClientSession() as session: - async with session.post("http://192.168.64.2:8787/sdapi/v1/txt2img", - json=arguments, timeout=120) as response: - if response.status == 200: - data = await response.json() - image_base64 = data["images"][0] - image_bytes = base64.b64decode(image_base64) - image = Image.open(BytesIO(image_bytes)).convert("RGB") - output = BytesIO() - image.save(output, format="JPEG", quality=80, optimize=True) - image_bytes = output.getvalue() - return Ok(image_bytes) - else: - raise RuntimeError(f"Сервер вернул код {response.status}") + outputs = await self.replicate_client.async_run( + "ultracoderru/nova-anime-xl-il-140:2af9bf809587d173212ddf9679d99f1d7f9a5442ed23c0c02e77d3a230865303", + input=arguments) + + image_bytes = await outputs[0].aread() + image = Image.open(BytesIO(image_bytes)).convert("RGB") + output = BytesIO() + image.save(output, format="JPEG", quality=80, optimize=True) + image_bytes = output.getvalue() + return Ok(image_bytes) except Exception as e: print(f"Ошибка генерации изображения: {e}") return Err(str(e)) @@ -420,7 +421,8 @@ class AiAgent: agent: AiAgent -def create_ai_agent(openrouter_token: str, openrouter_model: str, fal_token: str, +def create_ai_agent(openrouter_token: str, openrouter_model: str, + fal_token: str, replicate_token: str, db: BasicDatabase, platform: str): global agent - agent = AiAgent(openrouter_token, openrouter_model, fal_token, db, platform) + agent = AiAgent(openrouter_token, openrouter_model, fal_token, replicate_token, db, platform) diff --git a/tg/__main__.py b/tg/__main__.py index e89ca16..2b0e77b 100644 --- a/tg/__main__.py +++ b/tg/__main__.py @@ -24,7 +24,8 @@ async def main() -> None: database.create_database(config['db_connection_string']) - create_ai_agent(config['openrouter_token'], config['openrouter_model'], config['fal_token'], + create_ai_agent(config['openrouter_token'], config['openrouter_model'], + config['fal_token'], config['replicate_token'], database.DB, 'tg') bots: list[Bot] = [] diff --git a/vk/__main__.py b/vk/__main__.py index d8d8c3c..1d6bd75 100644 --- a/vk/__main__.py +++ b/vk/__main__.py @@ -24,7 +24,8 @@ if __name__ == '__main__': database.create_database(config['db_connection_string']) - create_ai_agent(config['openrouter_token'], config['openrouter_model'], config['fal_token'], + create_ai_agent(config['openrouter_token'], config['openrouter_model'], + config['fal_token'], config['replicate_token'], database.DB, 'vk') bot = Bot(labeler=handlers.labeler)