425 lines
19 KiB
Python
425 lines
19 KiB
Python
import aiohttp
|
||
import base64
|
||
import datetime
|
||
import json
|
||
|
||
from collections.abc import Callable
|
||
from dataclasses import dataclass
|
||
from io import BytesIO
|
||
from PIL import Image
|
||
from result import Ok, Err, Result
|
||
from typing import List, Tuple, Any, Optional, Union, Dict, Awaitable
|
||
|
||
from openrouter import OpenRouter, RetryConfig
|
||
from openrouter.components import AssistantMessage, AssistantMessageTypedDict, ChatMessageContentItemTypedDict, \
|
||
ChatMessageToolCall, MessageTypedDict
|
||
from openrouter.errors import ResponseValidationError, ChatError
|
||
from openrouter.utils import BackoffStrategy
|
||
|
||
from fal_client import AsyncClient as FalClient
|
||
from replicate import Client as ReplicateClient
|
||
|
||
from database import BasicDatabase
|
||
|
||
# Attribution headers sent to OpenRouter with every chat completion request.
OPENROUTER_X_TITLE = "TG/VK Chat Bot"
OPENROUTER_HTTP_REFERER = "https://ultracoder.org"

# Rolling context-window limits (number of messages kept per chat in the DB)
# and the cap on reply length, in tokens.
GROUP_CHAT_MAX_MESSAGES = 20
PRIVATE_CHAT_MAX_MESSAGES = 40
MAX_OUTPUT_TOKENS = 500

# Image-generation backends: fal.ai model used by _generate_image,
# Replicate model used by _generate_image_anime.
FAL_MODEL = "fal-ai/bytedance/seedream/v4.5/text-to-image"
REPLICATE_MODEL = "ultracoderru/nova-anime-xl-il-140:2af9bf809587d173212ddf9679d99f1d7f9a5442ed23c0c02e77d3a230865303"
|
||
|
||
|
||
@dataclass()
class Message:
    """A single chat message exchanged with the bot.

    All fields are optional: a message may carry only text, only an image,
    or both; ``message_id`` is None for messages the bot generates itself.

    BUGFIX: fields defaulted to None but were annotated with non-optional
    types (``str = None`` etc.); annotations corrected to ``Optional[...]``.
    """
    user_name: Optional[str] = None    # display name of the sender, if known
    text: Optional[str] = None         # plain-text body
    image: Optional[bytes] = None      # raw image payload, if any
    message_id: Optional[int] = None   # platform message id, used for reply tracking
|
||
|
||
|
||
class AiAgent:
    """Chat agent backed by an OpenRouter LLM with image-generation tools.

    Keeps a rolling per-chat message context in *db*, talks to OpenRouter
    for text replies, and to fal.ai / Replicate for tool-driven image
    generation.
    """

    def __init__(self,
                 openrouter_token: str, openrouter_model: str,
                 fal_token: str, replicate_token: str,
                 db: BasicDatabase,
                 platform: str):
        # Retry transient OpenRouter failures with exponential backoff
        # (2s initial, 8s cap, ~14s total budget).
        retry_config = RetryConfig(strategy="backoff",
                                   backoff=BackoffStrategy(
                                       initial_interval=2000, max_interval=8000, exponent=2, max_elapsed_time=14000),
                                   retry_connection_errors=True)
        self.db = db
        self.openrouter_model = openrouter_model
        # 'tg' selects Telegram wording in the system prompt, anything else -> VK.
        self.platform = platform

        self._load_prompts()

        self.client_openrouter = OpenRouter(api_key=openrouter_token,
                                            x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER,
                                            retry_config=retry_config)
        self.client_fal = FalClient(key=fal_token)
        self.replicate_client = ReplicateClient(api_token=replicate_token)

    @dataclass()
    class _ToolsArtifacts:
        """Side products collected while executing tool calls."""
        # Last successfully generated image, if any.
        generated_image: Optional[bytes] = None
        # BUGFIX: _process_tool_calls assigns this flag, but it was never a
        # declared field of the dataclass (it only worked because plain
        # dataclasses accept dynamic attributes). Declared explicitly.
        tools_called: bool = False
|
||
|
||
async def get_group_chat_reply(self, bot_id: int, chat_id: int,
                               message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
    """Produce the bot's reply for a group-chat message.

    The incoming message (plus any quoted/forwarded messages) is added to the
    chat context, the LLM is queried with tools enabled, and on success the
    whole exchange is persisted to the rolling context.

    Returns (reply_message, success_flag); on failure the reply contains a
    user-facing error text and the flag is False.
    """
    context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id)

    # Prefix the user text with a "[timestamp, user]" header so the model can
    # attribute messages in a multi-user chat.
    message.text = _add_message_prefix(message.text, message.user_name)
    context.append(_serialize_message(role="user", text=message.text, image=message.image))

    for fwd_message in forwarded_messages:
        message_text = '<Цитируемое сообщение от {}>'.format(fwd_message.user_name)
        if fwd_message.text is not None:
            message_text += '\n' + fwd_message.text
        fwd_message.text = message_text
        context.append(_serialize_message(role="user", text=fwd_message.text, image=fwd_message.image))

    try:
        response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
        ai_response = response.content

        tools_artifacts = AiAgent._ToolsArtifacts()
        if response.tool_calls is not None:
            # BUGFIX (parity with get_private_chat_reply): the assistant
            # message carrying the tool calls must be appended to the context
            # before the role="tool" result messages, otherwise the follow-up
            # completion references tool_call_ids the model never emitted.
            context.append(_serialize_assistant_message(response))
            tools_artifacts = await self._process_tool_calls(bot_id, chat_id,
                                                             tool_calls=response.tool_calls, context=context)
            response2 = await self._generate_reply(bot_id, chat_id, context=context)
            ai_response = response2.content

        # Persist the exchange only after the model call(s) succeeded.
        self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
                                    message_id=message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
        for fwd_message in forwarded_messages:
            self.db.context_add_message(bot_id, chat_id,
                                        role="user", text=fwd_message.text, image=fwd_message.image,
                                        message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
        self.db.context_add_message(bot_id, chat_id,
                                    role="assistant", text=ai_response, image=tools_artifacts.generated_image,
                                    message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)

        return Message(text=ai_response, image=tools_artifacts.generated_image), True

    except Exception as e:
        if str(e).find("Rate limit exceeded") != -1:
            return Message(text="Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."), False
        else:
            print(f"Ошибка выполнения запроса к ИИ: {e}")
            return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False
|
||
|
||
async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: Message) -> Tuple[Message, bool]:
    """Produce the bot's reply for a private (one-on-one) chat message.

    Returns (reply_message, success_flag); on failure the reply contains a
    user-facing error text and the flag is False.
    """
    context = self._get_chat_context(is_group_chat=False, bot_id=bot_id, chat_id=chat_id)
    # Timestamp header only — no username attribution in private chats.
    message.text = _add_message_prefix(message.text)
    context.append(_serialize_message(role="user", text=message.text, image=message.image))

    try:
        first_reply = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
        context.append(_serialize_assistant_message(first_reply))
        reply_text = first_reply.content

        artifacts = AiAgent._ToolsArtifacts()
        if first_reply.tool_calls is not None:
            # Run the requested tools, then ask the model again so it can
            # phrase a final answer around the tool output.
            artifacts = await self._process_tool_calls(bot_id, chat_id,
                                                       tool_calls=first_reply.tool_calls, context=context)
            second_reply = await self._generate_reply(bot_id, chat_id, context=context)
            reply_text = second_reply.content

        # Persist both sides of the exchange in the rolling context.
        for role, text, image, msg_id in (
                ("user", message.text, message.image, message.message_id),
                ("assistant", reply_text, artifacts.generated_image, None)):
            self.db.context_add_message(bot_id, chat_id, role=role, text=text, image=image,
                                        message_id=msg_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)

        return Message(text=reply_text, image=artifacts.generated_image), True

    except Exception as e:
        if "Rate limit exceeded" in str(e):
            return Message(text="Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."), False
        print(f"Ошибка выполнения запроса к ИИ: {e}")
        return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False
|
||
|
||
def get_last_assistant_message_id(self, bot_id: int, chat_id: int):
    """Return the stored message id of the bot's last assistant reply in this chat (delegates to the database)."""
    return self.db.context_get_last_assistant_message_id(bot_id, chat_id)
|
||
|
||
def set_last_response_id(self, bot_id: int, chat_id: int, message_id: int):
    """Store *message_id* as the chat's last message id (delegates to db.context_set_last_message_id)."""
    self.db.context_set_last_message_id(bot_id, chat_id, message_id)
|
||
|
||
def clear_chat_context(self, bot_id: int, chat_id: int):
    """Drop the stored conversation history for the given chat."""
    self.db.context_clear(bot_id, chat_id)
|
||
|
||
#######################################
|
||
|
||
def _get_chat_context(self, is_group_chat: bool, bot_id: int, chat_id: int) -> List[MessageTypedDict]:
    """Assemble the system prompt plus the stored message history for a chat."""
    if is_group_chat:
        prompt = self.system_prompt_group_chat
    else:
        prompt = self.system_prompt_private_chat
    platform_name = 'Telegram' if self.platform == 'tg' else 'VK'
    prompt = prompt.replace('{platform}', platform_name)
    prompt += '\n' + self.system_prompt_image_generation

    # Append per-bot and per-chat prompt extensions configured in the DB.
    bot = self.db.get_bot(bot_id)
    if bot['ai_prompt'] is not None:
        prompt += '\n' + bot['ai_prompt'] + '\n'

    chat = self.db.create_chat_if_not_exists(bot_id, chat_id)
    if chat['ai_prompt'] is not None:
        prompt += '\n' + chat['ai_prompt']

    history = self.db.context_get_messages(bot_id, chat_id)

    context: List[MessageTypedDict] = [{"role": "system", "content": prompt}]
    context.extend(_serialize_message(item["role"], item["text"], item["image"])
                   for item in history)
    return context
|
||
|
||
async def _generate_reply(self, bot_id: int, chat_id: int,
                          context: List[MessageTypedDict], allow_tools: bool = False) -> AssistantMessage:
    """Run one chat completion over *context*; optionally expose the tool schema."""
    request = dict(
        model=self.openrouter_model,
        messages=context,
        tools=None,
        tool_choice=None,
        max_tokens=MAX_OUTPUT_TOKENS,
        # Stable per-chat user tag for provider-side accounting.
        user=f'{self.platform}_{bot_id}_{chat_id}',
    )
    if allow_tools:
        request["tools"] = self.tools_description
        request["tool_choice"] = "auto"
    response = await self._async_chat_completion_request(**request)
    return self._filter_response(response.choices[0].message)
|
||
|
||
async def _process_tool_calls(self, bot_id: int, chat_id: int, tool_calls: List[ChatMessageToolCall],
                              context: List[MessageTypedDict]) -> _ToolsArtifacts:
    """Execute every recognized tool call and append its result to *context*."""
    artifacts = AiAgent._ToolsArtifacts()
    if tool_calls is None:
        return artifacts

    # Dispatch table: tool name -> coroutine implementing it.
    dispatch: Dict[str,
                   Callable[[int, int, Dict, AiAgent._ToolsArtifacts],
                            Awaitable[List[ChatMessageContentItemTypedDict]]]] = {
        "generate_image": self._process_tool_generate_image,
        "generate_image_anime": self._process_tool_generate_image_anime,
    }

    for call in tool_calls:
        # Arguments are parsed for every call (even unrecognized ones), so
        # malformed JSON surfaces as an exception to the caller.
        arguments = json.loads(call.function.arguments)
        handler = dispatch.get(call.function.name)
        if handler is None:
            continue  # unknown tool: leave the context untouched
        result = await handler(bot_id, chat_id, arguments, artifacts)
        context.append({
            "role": "tool",
            "tool_call_id": call.id,
            "content": result,
        })
        artifacts.tools_called = True
    return artifacts
|
||
|
||
async def _process_tool_generate_image(self, _bot_id: int, _chat_id: int, args: dict, artifacts: _ToolsArtifacts) \
        -> List[ChatMessageContentItemTypedDict]:
    """Handle the "generate_image" tool call and build its result content."""
    outcome = await self._generate_image(prompt=args.get("prompt", ""),
                                         aspect_ratio=args.get("aspect_ratio", None))

    if outcome.is_err():
        # Report the failure back to the model as plain text.
        return [{"type": "text",
                 "text": f"Не удалось сгенерировать изображение: {outcome.err_value}"}]

    # Stash the image so the caller can attach it to the outgoing reply.
    artifacts.generated_image = outcome.ok_value
    return [
        {"type": "text",
         "text": "Изображение сгенерировано и будет показано пользователю."},
        {"type": "image_url", "image_url": {"url": _encode_image(outcome.ok_value)}},
    ]
|
||
|
||
async def _generate_image(self, prompt: str, aspect_ratio: Optional[str]) -> Result[bytes, str]:
    """Generate an image via the fal.ai text-to-image model.

    Returns Ok(image_bytes) on success or Err(message) on any failure;
    errors are printed but never propagated to the caller.
    """
    width, height = _get_resolution_for_aspect_ratio(aspect_ratio)
    print(f"Генерация изображения {width}x{height}: {prompt}")

    arguments = {
        "prompt": prompt,
        "image_size": {"width": width, "height": height},
        "enable_safety_checker": False
    }

    try:
        result = await self.client_fal.run(FAL_MODEL, arguments=arguments)
        if "images" not in result:
            # Raised (and caught below) so every failure path funnels into Err.
            raise RuntimeError("Неожиданный ответ от сервера.")
        image_url = result["images"][0]["url"]

        # The model returns a URL rather than bytes — download the image here.
        async with aiohttp.ClientSession() as session:
            async with session.get(image_url) as response:
                if response.status == 200:
                    image = await response.read()
                    if not image_url.endswith(".jpg"):
                        # Normalize any non-JPEG payload to JPEG.
                        image = _convert_image_to_jpeg(image)
                    return Ok(image)
                else:
                    raise RuntimeError(f"Не удалось загрузить изображение ({response.status}).")

    except Exception as e:
        print(f"Ошибка генерации изображения: {e}")
        return Err(str(e))
|
||
|
||
async def _process_tool_generate_image_anime(self, _bot_id: int, _chat_id: int,
                                             args: dict, artifacts: _ToolsArtifacts) \
        -> List[ChatMessageContentItemTypedDict]:
    """Handle the "generate_image_anime" tool call and build its result content."""
    outcome = await self._generate_image_anime(prompt=args.get("prompt", ""),
                                               negative_prompt=args.get("negative_prompt", ""),
                                               aspect_ratio=args.get("aspect_ratio", None))

    if outcome.is_err():
        # Report the failure back to the model as plain text.
        return [{"type": "text",
                 "text": f"Не удалось сгенерировать изображение: {outcome.err_value}"}]

    # Stash the image so the caller can attach it to the outgoing reply.
    artifacts.generated_image = outcome.ok_value
    return [
        {"type": "text",
         "text": "Изображение сгенерировано и будет показано пользователю."},
        {"type": "image_url", "image_url": {"url": _encode_image(outcome.ok_value)}},
    ]
|
||
|
||
async def _generate_image_anime(self, prompt: str, negative_prompt: str, aspect_ratio: Optional[str]) \
        -> Result[bytes, str]:
    """Generate an image via the Replicate model (REPLICATE_MODEL).

    Returns Ok(jpeg_bytes) on success or Err(message) on any failure;
    errors are printed but never propagated to the caller.
    """
    width, height = _get_resolution_for_aspect_ratio(aspect_ratio)
    print(f"Генерация изображения {width}x{height}: positive='{prompt}', negative='{negative_prompt}'")

    arguments = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "width": width,
        "height": height,
        "cfg": 4.5,
        "steps": 20,
        "disable_safety_checker": True
    }

    try:
        outputs = await self.replicate_client.async_run(REPLICATE_MODEL, input=arguments)
        # Only the first output is used; the payload is re-encoded as JPEG.
        image = _convert_image_to_jpeg(await outputs[0].aread())
        return Ok(image)
    except Exception as e:
        print(f"Ошибка генерации изображения: {e}")
        return Err(str(e))
|
||
|
||
async def _async_chat_completion_request(self, **kwargs):
    """Send a chat completion request, unwrapping OpenRouter SDK error wrappers.

    Workaround for the OpenRouter SDK
    (https://github.com/OpenRouterTeam/python-sdk/issues/44): the provider's
    real error message is buried in ``error.metadata.raw`` of the response
    body; surface it as a RuntimeError when possible.

    Improvements over the previous version: the identical unwrapping code in
    both except branches is factored into one helper, and a non-JSON error
    body no longer masks the original exception with a JSONDecodeError
    (previously the outer json.loads was unprotected).
    """

    def unwrap(err):
        # Best effort: dig the provider's message out of the wrapped body;
        # fall back to the original exception on any parsing failure.
        try:
            body = json.loads(err.body)
            raw_response = json.loads(body["error"]["metadata"]["raw"])
            return RuntimeError(str(raw_response["error"]["message"]))
        except Exception:
            return err

    try:
        return await self.client_openrouter.chat.send_async(**kwargs)
    except ResponseValidationError as e:
        raise unwrap(e)
    except ChatError as e:
        if e.message == "Provider returned error":
            raise unwrap(e)
        raise
|
||
|
||
@staticmethod
def _filter_response(response: AssistantMessage) -> AssistantMessage:
    """Strip internal "<image>" placeholders from the model's reply text.

    BUGFIX: the content was previously passed through str() unconditionally,
    which turned a None content (e.g. a tool-call-only reply) into the
    literal string "None"; None is now preserved.
    """
    if response.content is not None:
        response.content = str(response.content).replace("<image>", "")
    return response
|
||
|
||
def _load_prompts(self):
    """Read the system prompts and the tool schema from the prompts/ directory."""

    def read_file(path: str) -> str:
        with open(path, "r") as f:
            return f.read()

    self.system_prompt_group_chat = read_file("prompts/group_chat.txt")
    self.system_prompt_private_chat = read_file("prompts/private_chat.txt")
    self.system_prompt_image_generation = read_file("prompts/image_generation.txt")
    # Tool schema passed verbatim to the chat completion API.
    self.tools_description = json.loads(read_file("prompts/tools.json"))
|
||
|
||
|
||
# Module-level singleton, initialized via create_ai_agent().
agent: AiAgent


def create_ai_agent(openrouter_token: str, openrouter_model: str,
                    fal_token: str, replicate_token: str,
                    db: BasicDatabase, platform: str):
    """Create the global AiAgent instance (stored in the module attribute ``agent``)."""
    global agent
    agent = AiAgent(openrouter_token, openrouter_model, fal_token, replicate_token, db, platform)
|
||
|
||
|
||
def _add_message_prefix(text: Optional[str], username: Optional[str] = None) -> str:
|
||
current_time = datetime.datetime.now().strftime("%d.%m.%Y %H:%M")
|
||
prefix = f"[{current_time}, {username}]" if username is not None else f"[{current_time}]"
|
||
return f"{prefix}: {text}" if text is not None else f"{prefix}:"
|
||
|
||
|
||
def _encode_image(image: bytes) -> str:
|
||
encoded_image = base64.b64encode(image).decode('utf-8')
|
||
return f"data:image/jpeg;base64,{encoded_image}"
|
||
|
||
|
||
def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict:
|
||
serialized = {"role": role, "content": []}
|
||
if text is not None:
|
||
serialized["content"].append({"type": "text", "text": text})
|
||
if image is not None:
|
||
serialized["content"].append({"type": "image_url", "image_url": {"url": _encode_image(image)}})
|
||
return serialized
|
||
|
||
|
||
def _remove_none_recursive(data: Union[dict, list, any]) -> Union[dict, list, any]:
|
||
if isinstance(data, dict):
|
||
return {
|
||
k: _remove_none_recursive(v)
|
||
for k, v in data.items()
|
||
if v is not None
|
||
}
|
||
elif isinstance(data, list):
|
||
return [
|
||
_remove_none_recursive(item)
|
||
for item in data
|
||
if item is not None
|
||
]
|
||
else:
|
||
return data
|
||
|
||
|
||
def _serialize_assistant_message(message: AssistantMessage) -> AssistantMessageTypedDict:
    """Dump an SDK AssistantMessage to a plain dict for re-sending in the chat
    context, recursively dropping every None field so optional keys are
    omitted entirely."""
    return _remove_none_recursive(message.model_dump(by_alias=True))
|
||
|
||
|
||
def _get_resolution_for_aspect_ratio(aspect_ratio: str) -> Tuple[int, int]:
|
||
aspect_ratio_resolution_map = {
|
||
"1:1": (1280, 1280),
|
||
"4:3": (1280, 1024),
|
||
"3:4": (1024, 1280),
|
||
"16:9": (1280, 720),
|
||
"9:16": (720, 1280)
|
||
}
|
||
return aspect_ratio_resolution_map.get(aspect_ratio, (1280, 1024))
|
||
|
||
|
||
def _convert_image_to_jpeg(image: bytes) -> bytes:
    """Re-encode arbitrary image bytes as a high-quality RGB JPEG."""
    source = BytesIO(image)
    converted = Image.open(source).convert("RGB")
    target = BytesIO()
    converted.save(target, format='JPEG', quality=95, optimize=True)
    return target.getvalue()
|