306 lines
14 KiB
Python
306 lines
14 KiB
Python
import base64
|
||
import datetime
|
||
import json
|
||
|
||
from dataclasses import dataclass
|
||
from io import BytesIO
|
||
from PIL import Image
|
||
from typing import List, Tuple, Any, Optional
|
||
|
||
from openrouter import OpenRouter, RetryConfig
|
||
from openrouter.components import ToolDefinitionJSONTypedDict, MessageTypedDict
|
||
from openrouter.utils import BackoffStrategy
|
||
|
||
from database import BasicDatabase
|
||
|
||
# System prompt for group chats (Russian; sent verbatim to the model).
# NOTE: the embedded '\n' escapes are intentional parts of the prompt text.
GROUP_CHAT_SYSTEM_PROMPT = """
Ты - ИИ-помощник в групповом чате.\n
Отвечай на вопросы и поддерживай контекст беседы.\n
Ты не можешь обсуждать политику и религию.\n
Сообщения пользователей будут приходить в следующем формате: '[дата время, имя]: текст сообщения'\n
При ответе НЕ нужно указывать ни время, ни пользователя, которому предназначен ответ, ни свое имя.\n
НЕ используй разметку Markdown.
"""
# Cap on stored context messages per group chat (passed to the DB layer).
GROUP_CHAT_MAX_MESSAGES = 20

# System prompt for one-on-one (private) chats.
PRIVATE_CHAT_SYSTEM_PROMPT = """
Ты - ИИ-помощник в чате c пользователем.\n
Отвечай на вопросы и поддерживай контекст беседы.\n
Сообщения пользователя будут приходить в следующем формате: '[дата время]: текст сообщения'\n
При ответе НЕ нужно указывать время.\n
Никогда не используй разметку Markdown.\n
Никогда не добавляй ASCII-арты в ответ.
"""
# Private chats keep a longer history than group chats.
PRIVATE_CHAT_MAX_MESSAGES = 40

# Extra HTTP headers sent with every OpenRouter request (app attribution).
OPENROUTER_HEADERS = {
    'HTTP-Referer': 'https://ultracoder.org',
    'X-Title': 'TG/VK Chat Bot'
}
|
||
|
||
|
||
@dataclass()
class Message:
    """A platform-agnostic chat message: any combination of text and image.

    All fields are optional — a message may carry text, an image, both,
    or (for placeholder/reply tracking) neither.
    """
    # Display name of the sender; None in private chats where it is implicit.
    user_name: Optional[str] = None
    # Plain message text, if any.
    text: Optional[str] = None
    # Raw image bytes (JPEG is assumed downstream by _encode_image).
    image: Optional[bytes] = None
    # Platform message id, used to link replies; None for synthetic messages.
    message_id: Optional[int] = None
|
||
|
||
|
||
def _add_message_prefix(text: Optional[str], username: Optional[str] = None) -> str:
|
||
current_time = datetime.datetime.now().strftime("%d.%m.%Y %H:%M")
|
||
prefix = f"[{current_time}, {username}]" if username is not None else f"[{current_time}]"
|
||
return f"{prefix}: {text}" if text is not None else f"{prefix}:"
|
||
|
||
|
||
def _encode_image(image: bytes) -> str:
|
||
encoded_image = base64.b64encode(image).decode('utf-8')
|
||
return f"data:image/jpeg;base64,{encoded_image}"
|
||
|
||
|
||
def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict:
|
||
serialized = {"role": role, "content": []}
|
||
if text is not None:
|
||
serialized["content"].append({"type": "text", "text": text})
|
||
if image is not None:
|
||
serialized["content"].append({"type": "image_url", "image_url": {"url": _encode_image(image)}})
|
||
|
||
return serialized
|
||
|
||
|
||
def _get_tools_description() -> List[ToolDefinitionJSONTypedDict]:
    """Return the tool schema advertised to the model (private chats only).

    Currently a single function tool, "generate_image", taking a prompt
    and an optional aspect ratio; descriptions are in Russian because the
    bot's system prompts and user base are Russian-language.
    """
    return [{
        "type": "function",
        "function": {
            "name": "generate_image",
            "description": """
            Генерация изображения по описанию.
            Используй этот инструмент, если пользователь просит сгенерировать изображение ('нарисуй', 'покажи' и т.п.),
            ИЛИ если это улучшит ответ (например, в ролевой игре для визуализации сцены).
            """,
            "parameters": {
                "type": "object",
                "properties": {
                    "prompt": {
                        "type": "string",
                        "description": """
                        Детальное описание на английском (рекомендуется).
                        Добавь детали для стиля, цвета, композиции, если нужно.
                        """
                    },
                    "aspect_ratio": {
                        "type": "string",
                        "enum": ["1:1", "3:4", "4:3", "9:16", "16:9"],
                        "description": "Соотношение сторон (опционально)."
                    }
                }
            }
        }
    }]
|
||
|
||
|
||
class AiAgent:
    """Chat agent backed by OpenRouter, with per-chat history stored in a DB.

    Two clients are held: `client` (main chat model, with retries) and
    `client_image` (image-generation model). History persistence happens
    only after a successful model call, so failed requests do not pollute
    the stored context.
    """

    def __init__(self,
                 api_token_main: str, model_main: str,
                 api_token_image: str, model_image: str,
                 db: BasicDatabase,
                 platform: str):
        # Exponential backoff for the main client: 2s initial, 8s cap,
        # exponent 2, giving up after ~14s elapsed; connection errors retried.
        retry_config = RetryConfig(strategy="backoff",
                                   backoff=BackoffStrategy(
                                       initial_interval=2000, max_interval=8000, exponent=2, max_elapsed_time=14000),
                                   retry_connection_errors=True)
        self.db = db
        self.model_main = model_main
        self.model_image = model_image
        # Platform name ('tg'/'vk' presumably — used only in the user tag).
        self.platform = platform
        self.client = OpenRouter(api_key=api_token_main, retry_config=retry_config)
        # Image client intentionally has no retry config (generation is slow).
        self.client_image = OpenRouter(api_key=api_token_image)

    async def get_group_chat_reply(self, bot_id: int, chat_id: int,
                                   message: Message, forwarded_messages: List[Message]) -> Tuple[str, bool]:
        """Answer a group-chat message (plus any quoted/forwarded messages).

        Returns (reply_text, ok). On failure `ok` is False and the text is
        a user-facing Russian error message. Mutates `message.text` and
        each `fwd_message.text` in place by adding the timestamp prefix /
        quote marker.
        """
        context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id)

        message.text = _add_message_prefix(message.text, message.user_name)
        context.append(_serialize_message(role="user", text=message.text, image=message.image))

        # Forwarded messages are appended after the main one, each wrapped
        # in a '<Цитируемое сообщение от …>' quote header.
        for fwd_message in forwarded_messages:
            message_text = '<Цитируемое сообщение от {}>'.format(fwd_message.user_name)
            if fwd_message.text is not None:
                message_text += '\n' + fwd_message.text
            fwd_message.text = message_text
            context.append(_serialize_message(role="user", text=fwd_message.text, image=fwd_message.image))

        try:
            response = await self.client.chat.send_async(
                model=self.model_main,
                messages=context,
                max_tokens=500,
                user=f'{self.platform}_{bot_id}_{chat_id}',
                http_headers=OPENROUTER_HEADERS
            )
            ai_response = response.choices[0].message.content

            # Persist user message, forwarded messages, then the assistant
            # reply — only after the API call succeeded.
            self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
                                        message_id=message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
            for fwd_message in forwarded_messages:
                self.db.context_add_message(bot_id, chat_id, role="user", text=fwd_message.text, image=fwd_message.image,
                                            message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
            self.db.context_add_message(bot_id, chat_id, role="assistant", text=ai_response, image=None,
                                        message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)

            return ai_response, True

        except Exception as e:
            # Rate-limit detection is string-based because the SDK surfaces
            # it inside a generic exception message.
            if str(e).find("Rate limit exceeded") != -1:
                return "Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК).", False
            else:
                print(f"Ошибка выполнения запроса к ИИ: {e}")
                return f"Извините, при обработке запроса произошла ошибка.", False

    async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: Message) -> Tuple[Message, bool]:
        """Answer a private-chat message; may also generate an image via tools.

        Returns (Message(text, image), ok). The image-generation tool is
        exposed to the model; if it is invoked, a second model call produces
        the final textual reply. Only the first generate_image call is
        honored (loop breaks afterwards). Mutates `message.text` in place.
        """
        context = self._get_chat_context(is_group_chat=False, bot_id=bot_id, chat_id=chat_id)
        message.text = _add_message_prefix(message.text)
        content: list[dict[str, Any]] = []
        # NOTE: after _add_message_prefix the text is never None, so the
        # first branch always fires; kept as-is for symmetry with images.
        if message.text is not None:
            content.append({"type": "text", "text": message.text})
        if message.image is not None:
            content.append({"type": "image_url", "image_url": {"url": _encode_image(message.image)}})
        context.append({"role": "user", "content": content})

        try:
            user_tag = f'{self.platform}_{bot_id}_{chat_id}'

            response = await self.client.chat.send_async(
                model=self.model_main,
                messages=context,
                tools=_get_tools_description(),
                tool_choice="auto",
                max_tokens=500,
                user=user_tag,
                http_headers=OPENROUTER_HEADERS
            )
            ai_response = response.choices[0].message.content
            # The assistant message (possibly containing tool calls) must be
            # echoed back into the context before any tool results.
            context.append(response.choices[0].message)

            image_response_image: Optional[bytes] = None
            if response.choices[0].message.tool_calls is not None:
                for tool_call in response.choices[0].message.tool_calls:
                    tool_name = tool_call.function.name
                    tool_args = json.loads(tool_call.function.arguments)
                    if tool_name == "generate_image":
                        prompt = tool_args.get("prompt", "")
                        aspect_ratio = tool_args.get("aspect_ratio", None)
                        # The generator's accompanying text is ignored; only
                        # the image bytes are used.
                        image_response_text, image_response_image =\
                            await self._generate_image(prompt=prompt, aspect_ratio=aspect_ratio, user_tag=user_tag)
                        context.append({
                            "role": "tool",
                            "tool_call_id": tool_call.id,
                            "content": [
                                {"type": "text",
                                 "text": """
                                 Изображение сгенерировано и будет показано пользователю.
                                 НЕ добавляй никаких тегов или маркеров вроде <image>, [image] — они запрещены и не нужны.
                                 """},
                                {"type": "image_url", "image_url": {"url": _encode_image(image_response_image)}}
                            ]
                        })

                        # Second round-trip: let the model phrase the final
                        # reply now that it has seen the generated image.
                        response2 = await self.client.chat.send_async(
                            model=self.model_main,
                            messages=context,
                            max_tokens=500,
                            user=user_tag,
                            http_headers=OPENROUTER_HEADERS
                        )
                        ai_response = response2.choices[0].message.content
                        break

            self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
                                        message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)

            if image_response_image is not None:
                self.db.context_add_message(bot_id, chat_id, role="assistant",
                                            text=ai_response, image=image_response_image,
                                            message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
            else:
                self.db.context_add_message(bot_id, chat_id, role="assistant", text=ai_response, image=None,
                                            message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)

            return Message(text=ai_response, image=image_response_image), True

        except Exception as e:
            if str(e).find("Rate limit exceeded") != -1:
                return Message(text="Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."), False
            else:
                print(f"Ошибка выполнения запроса к ИИ: {e}")
                return Message(text=f"Извините, при обработке запроса произошла ошибка."), False

    def get_last_assistant_message_id(self, bot_id: int, chat_id: int):
        # Thin pass-through to the DB layer.
        return self.db.context_get_last_assistant_message_id(bot_id, chat_id)

    def set_last_response_id(self, bot_id: int, chat_id: int, message_id: int):
        # Record the platform message id of the bot's last reply.
        self.db.context_set_last_message_id(bot_id, chat_id, message_id)

    def clear_chat_context(self, bot_id: int, chat_id: int):
        # Drop the whole stored conversation for this chat.
        self.db.context_clear(bot_id, chat_id)

    def _get_chat_context(self, is_group_chat: bool, bot_id: int, chat_id: int) -> List[MessageTypedDict]:
        """Assemble the message list: system prompt (+ optional per-bot and
        per-chat custom prompts appended) followed by the stored history."""
        prompt = GROUP_CHAT_SYSTEM_PROMPT if is_group_chat else PRIVATE_CHAT_SYSTEM_PROMPT

        bot = self.db.get_bot(bot_id)
        if bot['ai_prompt'] is not None:
            prompt += '\n\n' + bot['ai_prompt']

        chat = self.db.create_chat_if_not_exists(bot_id, chat_id)
        if chat['ai_prompt'] is not None:
            prompt += '\n\n' + chat['ai_prompt']

        messages = self.db.context_get_messages(bot_id, chat_id)

        context: List[MessageTypedDict] = [{"role": "system", "content": prompt}]
        for message in messages:
            context.append(_serialize_message(message["role"], message["text"], message["image"]))
        return context

    async def _generate_image(self, prompt: str, aspect_ratio: Optional[str], user_tag: str) -> Tuple[str, bytes]:
        """Generate an image and return (model_text, jpeg_bytes).

        Non-JPEG results are re-encoded to JPEG (quality 80) so downstream
        code can always assume image/jpeg.
        """
        context = [{"role": "user", "content": prompt}]
        # image_config is only passed when an aspect ratio was requested.
        if aspect_ratio is not None:
            response = await self.client_image.chat.send_async(
                model=self.model_image,
                messages=context,
                user=user_tag,
                modalities=["image"],
                image_config={"aspect_ratio": aspect_ratio}
            )
        else:
            response = await self.client_image.chat.send_async(
                model=self.model_image,
                messages=context,
                user=user_tag,
                modalities=["image"]
            )

        content = response.choices[0].message.content

        # The image arrives as a data-URL: 'data:<mime>;base64,<payload>'.
        image_url = response.choices[0].message.images[0].image_url.url
        header, image_base64 = image_url.split(",", 1)
        mime_type = header.split(";")[0].replace("data:", "")
        image_bytes = base64.b64decode(image_base64)

        if mime_type != "image/jpeg":
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
            output = BytesIO()
            image.save(output, format="JPEG", quality=80, optimize=True)
            image_bytes = output.getvalue()

        return content, image_bytes
|
||
|
||
|
||
# Module-level singleton, assigned by create_ai_agent(); unbound until then.
agent: AiAgent
|
||
|
||
|
||
def create_ai_agent(api_token_main: str, model_main: str,
                    api_token_image: str, model_image: str,
                    db: BasicDatabase, platform: str):
    """Initialize the module-level `agent` singleton with a fresh AiAgent."""
    global agent
    agent = AiAgent(
        api_token_main=api_token_main,
        model_main=model_main,
        api_token_image=api_token_image,
        model_image=model_image,
        db=db,
        platform=platform,
    )
|