Реализована генерация изображений в групповых чатах Telegram.
This commit is contained in:
parent
cb60b23ae7
commit
e05e0d4c82
5 changed files with 169 additions and 120 deletions
270
ai_agent.py
270
ai_agent.py
|
|
@ -5,10 +5,12 @@ import json
|
|||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from typing import List, Tuple, Any, Optional
|
||||
from typing import List, Tuple, Any, Optional, Union
|
||||
|
||||
from openrouter import OpenRouter, RetryConfig
|
||||
from openrouter.components import ToolDefinitionJSONTypedDict, MessageTypedDict
|
||||
from openrouter.components import AssistantMessage, ChatMessageToolCall, \
|
||||
MessageTypedDict, ToolDefinitionJSONTypedDict, AssistantMessageTypedDict
|
||||
|
||||
from openrouter.utils import BackoffStrategy
|
||||
|
||||
from database import BasicDatabase
|
||||
|
|
@ -19,24 +21,38 @@ GROUP_CHAT_SYSTEM_PROMPT = """
|
|||
Ты не можешь обсуждать политику и религию.\n
|
||||
Сообщения пользователей будут приходить в следующем формате: '[дата время, имя]: текст сообщения'\n
|
||||
При ответе НЕ нужно указывать ни время, ни пользователя, которому предназначен ответ, ни свое имя.\n
|
||||
НЕ используй разметку Markdown.
|
||||
НЕ используй разметку Markdown, она не поддерживается мессенджером.\n
|
||||
Если нужно нарисовать изображение, используй вызов инструмента.
|
||||
Запрещено генерировать ASCII-арты, а также добавлять теги/маркеры вроде <image>, [image].
|
||||
"""
|
||||
GROUP_CHAT_MAX_MESSAGES = 20
|
||||
|
||||
PRIVATE_CHAT_SYSTEM_PROMPT = """
|
||||
Ты - ИИ-помощник в чате c пользователем.\n
|
||||
Отвечай на вопросы и поддерживай контекст беседы.\n
|
||||
Сообщения пользователя будут приходить в следующем формате: '[дата время]: текст сообщения'\n
|
||||
При ответе НЕ нужно указывать время.\n
|
||||
Никогда не используй разметку Markdown.\n
|
||||
Никогда не добавляй ASCII-арты в ответ.
|
||||
НЕ используй разметку Markdown, она не поддерживается мессенджером.\n
|
||||
Если нужно нарисовать изображение, используй вызов инструмента.
|
||||
Запрещено генерировать ASCII-арты, а также добавлять теги/маркеры вроде <image>, [image].
|
||||
"""
|
||||
PRIVATE_CHAT_MAX_MESSAGES = 40
|
||||
|
||||
OPENROUTER_HEADERS = {
|
||||
'HTTP-Referer': 'https://ultracoder.org',
|
||||
'X-Title': 'TG/VK Chat Bot'
|
||||
}
|
||||
GENERATE_IMAGE_TOOL_DESCRIPTION = """
|
||||
Генерация изображения по описанию.
|
||||
Используй этот инструмент, если пользователь просит сгенерировать изображение ('нарисуй', 'покажи' и т.п.),
|
||||
или если это улучшит ответ (например, в ролевой игре для визуализации сцены).
|
||||
"""
|
||||
|
||||
GENERATE_IMAGE_TOOL_PROMPT_ARG_DESCRIPTION = """
|
||||
Детальное описание на русском языке.
|
||||
Добавь детали для стиля, цвета, композиции, если нужно.
|
||||
"""
|
||||
|
||||
OPENROUTER_X_TITLE = "TG/VK Chat Bot"
|
||||
OPENROUTER_HTTP_REFERER = "https://ultracoder.org"
|
||||
|
||||
GROUP_CHAT_MAX_MESSAGES = 20
|
||||
PRIVATE_CHAT_MAX_MESSAGES = 40
|
||||
MAX_OUTPUT_TOKENS = 500
|
||||
|
||||
|
||||
@dataclass()
|
||||
|
|
@ -64,42 +80,61 @@ def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -
|
|||
serialized["content"].append({"type": "text", "text": text})
|
||||
if image is not None:
|
||||
serialized["content"].append({"type": "image_url", "image_url": {"url": _encode_image(image)}})
|
||||
|
||||
return serialized
|
||||
|
||||
|
||||
def _remove_none_recursive(data: Union[dict, list, any]) -> Union[dict, list, any]:
|
||||
if isinstance(data, dict):
|
||||
return {
|
||||
k: _remove_none_recursive(v)
|
||||
for k, v in data.items()
|
||||
if v is not None
|
||||
}
|
||||
elif isinstance(data, list):
|
||||
return [
|
||||
_remove_none_recursive(item)
|
||||
for item in data
|
||||
if item is not None
|
||||
]
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
def _serialize_assistant_message(message: AssistantMessage) -> AssistantMessageTypedDict:
|
||||
return _remove_none_recursive(message.model_dump(by_alias=True))
|
||||
|
||||
|
||||
def _get_tools_description() -> List[ToolDefinitionJSONTypedDict]:
|
||||
return [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "generate_image",
|
||||
"description": """
|
||||
Генерация изображения по описанию.
|
||||
Используй этот инструмент, если пользователь просит сгенерировать изображение ('нарисуй', 'покажи' и т.п.),
|
||||
ИЛИ если это улучшит ответ (например, в ролевой игре для визуализации сцены).
|
||||
""",
|
||||
"description": GENERATE_IMAGE_TOOL_DESCRIPTION,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": """
|
||||
Детальное описание на английском (рекомендуется).
|
||||
Добавь детали для стиля, цвета, композиции, если нужно.
|
||||
"""
|
||||
"description": GENERATE_IMAGE_TOOL_PROMPT_ARG_DESCRIPTION
|
||||
},
|
||||
"aspect_ratio": {
|
||||
"type": "string",
|
||||
"enum": ["1:1", "3:4", "4:3", "9:16", "16:9"],
|
||||
"description": "Соотношение сторон (опционально)."
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["prompt"]
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
|
||||
class AiAgent:
|
||||
@dataclass()
|
||||
class ToolCallResult:
|
||||
tools_called: bool = False
|
||||
generated_image: Optional[bytes] = None
|
||||
|
||||
def __init__(self,
|
||||
api_token_main: str, model_main: str,
|
||||
api_token_image: str, model_image: str,
|
||||
|
|
@ -113,11 +148,14 @@ class AiAgent:
|
|||
self.model_main = model_main
|
||||
self.model_image = model_image
|
||||
self.platform = platform
|
||||
self.client = OpenRouter(api_key=api_token_main, retry_config=retry_config)
|
||||
self.client_image = OpenRouter(api_key=api_token_image)
|
||||
self.client = OpenRouter(api_key=api_token_main,
|
||||
x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER,
|
||||
retry_config=retry_config)
|
||||
self.client_image = OpenRouter(api_key=api_token_image,
|
||||
x_title=OPENROUTER_X_TITLE, http_referer=OPENROUTER_HTTP_REFERER)
|
||||
|
||||
async def get_group_chat_reply(self, bot_id: int, chat_id: int,
|
||||
message: Message, forwarded_messages: List[Message]) -> Tuple[str, bool]:
|
||||
message: Message, forwarded_messages: List[Message]) -> Tuple[Message, bool]:
|
||||
context = self._get_chat_context(is_group_chat=True, bot_id=bot_id, chat_id=chat_id)
|
||||
|
||||
message.text = _add_message_prefix(message.text, message.user_name)
|
||||
|
|
@ -131,31 +169,33 @@ class AiAgent:
|
|||
context.append(_serialize_message(role="user", text=fwd_message.text, image=fwd_message.image))
|
||||
|
||||
try:
|
||||
response = await self.client.chat.send_async(
|
||||
model=self.model_main,
|
||||
messages=context,
|
||||
max_tokens=500,
|
||||
user=f'{self.platform}_{bot_id}_{chat_id}',
|
||||
http_headers=OPENROUTER_HEADERS
|
||||
)
|
||||
ai_response = response.choices[0].message.content
|
||||
response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
|
||||
ai_response = response.content
|
||||
|
||||
tools_call_result = await self._process_tool_calls(bot_id, chat_id,
|
||||
tool_calls=response.tool_calls, context=context)
|
||||
if tools_call_result.tools_called:
|
||||
response2 = await self._generate_reply(bot_id, chat_id, context=context)
|
||||
ai_response = response2.content
|
||||
|
||||
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
|
||||
message_id=message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
||||
for fwd_message in forwarded_messages:
|
||||
self.db.context_add_message(bot_id, chat_id, role="user", text=fwd_message.text, image=fwd_message.image,
|
||||
self.db.context_add_message(bot_id, chat_id,
|
||||
role="user", text=fwd_message.text, image=fwd_message.image,
|
||||
message_id=fwd_message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
||||
self.db.context_add_message(bot_id, chat_id, role="assistant", text=ai_response, image=None,
|
||||
self.db.context_add_message(bot_id, chat_id,
|
||||
role="assistant", text=ai_response, image=tools_call_result.generated_image,
|
||||
message_id=None, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
||||
|
||||
return ai_response, True
|
||||
return Message(text=ai_response, image=tools_call_result.generated_image), True
|
||||
|
||||
except Exception as e:
|
||||
if str(e).find("Rate limit exceeded") != -1:
|
||||
return "Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК).", False
|
||||
return Message(text="Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."), False
|
||||
else:
|
||||
print(f"Ошибка выполнения запроса к ИИ: {e}")
|
||||
return f"Извините, при обработке запроса произошла ошибка.", False
|
||||
return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False
|
||||
|
||||
async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: Message) -> Tuple[Message, bool]:
|
||||
context = self._get_chat_context(is_group_chat=False, bot_id=bot_id, chat_id=chat_id)
|
||||
|
|
@ -168,72 +208,30 @@ class AiAgent:
|
|||
context.append({"role": "user", "content": content})
|
||||
|
||||
try:
|
||||
user_tag = f'{self.platform}_{bot_id}_{chat_id}'
|
||||
response = await self._generate_reply(bot_id, chat_id, context=context, allow_tools=True)
|
||||
context.append(_serialize_assistant_message(response))
|
||||
ai_response = response.content
|
||||
|
||||
response = await self.client.chat.send_async(
|
||||
model=self.model_main,
|
||||
messages=context,
|
||||
tools=_get_tools_description(),
|
||||
tool_choice="auto",
|
||||
max_tokens=500,
|
||||
user=user_tag,
|
||||
http_headers=OPENROUTER_HEADERS
|
||||
)
|
||||
ai_response = response.choices[0].message.content
|
||||
context.append(response.choices[0].message)
|
||||
|
||||
image_response_image: Optional[bytes] = None
|
||||
if response.choices[0].message.tool_calls is not None:
|
||||
for tool_call in response.choices[0].message.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
if tool_name == "generate_image":
|
||||
prompt = tool_args.get("prompt", "")
|
||||
aspect_ratio = tool_args.get("aspect_ratio", None)
|
||||
image_response_text, image_response_image =\
|
||||
await self._generate_image(prompt=prompt, aspect_ratio=aspect_ratio, user_tag=user_tag)
|
||||
context.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": [
|
||||
{"type": "text",
|
||||
"text": """
|
||||
Изображение сгенерировано и будет показано пользователю.
|
||||
НЕ добавляй никаких тегов или маркеров вроде <image>, [image] — они запрещены и не нужны.
|
||||
"""},
|
||||
{"type": "image_url", "image_url": {"url": _encode_image(image_response_image)}}
|
||||
]
|
||||
})
|
||||
|
||||
response2 = await self.client.chat.send_async(
|
||||
model=self.model_main,
|
||||
messages=context,
|
||||
max_tokens=500,
|
||||
user=user_tag,
|
||||
http_headers=OPENROUTER_HEADERS
|
||||
)
|
||||
ai_response = response2.choices[0].message.content
|
||||
break
|
||||
tools_call_result = await self._process_tool_calls(bot_id, chat_id,
|
||||
tool_calls=response.tool_calls, context=context)
|
||||
if tools_call_result.tools_called:
|
||||
response2 = await self._generate_reply(bot_id, chat_id, context=context)
|
||||
ai_response = response2.content
|
||||
|
||||
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
|
||||
message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
||||
self.db.context_add_message(bot_id, chat_id, role="assistant",
|
||||
text=ai_response, image=tools_call_result.generated_image,
|
||||
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
||||
|
||||
if image_response_image is not None:
|
||||
self.db.context_add_message(bot_id, chat_id, role="assistant",
|
||||
text=ai_response, image=image_response_image,
|
||||
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
||||
else:
|
||||
self.db.context_add_message(bot_id, chat_id, role="assistant", text=ai_response, image=None,
|
||||
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
||||
|
||||
return Message(text=ai_response, image=image_response_image), True
|
||||
return Message(text=ai_response, image=tools_call_result.generated_image), True
|
||||
|
||||
except Exception as e:
|
||||
if str(e).find("Rate limit exceeded") != -1:
|
||||
return Message(text="Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."), False
|
||||
else:
|
||||
print(f"Ошибка выполнения запроса к ИИ: {e}")
|
||||
return Message(text=f"Извините, при обработке запроса произошла ошибка."), False
|
||||
return Message(text=f"Извините, при обработке запроса произошла ошибка:\n{e}"), False
|
||||
|
||||
def get_last_assistant_message_id(self, bot_id: int, chat_id: int):
|
||||
return self.db.context_get_last_assistant_message_id(bot_id, chat_id)
|
||||
|
|
@ -262,38 +260,78 @@ class AiAgent:
|
|||
context.append(_serialize_message(message["role"], message["text"], message["image"]))
|
||||
return context
|
||||
|
||||
async def _generate_image(self, prompt: str, aspect_ratio: Optional[str], user_tag: str) -> Tuple[str, bytes]:
|
||||
async def _generate_reply(self, bot_id: int, chat_id: int,
|
||||
context: List[MessageTypedDict], allow_tools: bool = False) -> AssistantMessage:
|
||||
response = await self.client.chat.send_async(
|
||||
model=self.model_main,
|
||||
messages=context,
|
||||
tools=_get_tools_description() if allow_tools else None,
|
||||
tool_choice="auto" if allow_tools else None,
|
||||
max_tokens=MAX_OUTPUT_TOKENS,
|
||||
user=f'{self.platform}_{bot_id}_{chat_id}'
|
||||
)
|
||||
return response.choices[0].message
|
||||
|
||||
async def _process_tool_calls(self, bot_id: int, chat_id: int, tool_calls: List[ChatMessageToolCall],
|
||||
context: List[MessageTypedDict]) -> ToolCallResult:
|
||||
result = AiAgent.ToolCallResult()
|
||||
if tool_calls is not None:
|
||||
for tool_call in tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
if tool_name == "generate_image":
|
||||
prompt = tool_args.get("prompt", "")
|
||||
aspect_ratio = tool_args.get("aspect_ratio", None)
|
||||
result.generated_image, success = \
|
||||
await self._generate_image(bot_id, chat_id, prompt=prompt, aspect_ratio=aspect_ratio)
|
||||
tool_result_content = []
|
||||
if success:
|
||||
tool_result_content.append(
|
||||
{"type": "text",
|
||||
"text": "Изображение сгенерировано и будет показано пользователю."})
|
||||
tool_result_content.append(
|
||||
{"type": "image_url", "image_url": {"url": _encode_image(result.generated_image)}})
|
||||
else:
|
||||
tool_result_content.append(
|
||||
{"type": "text",
|
||||
"text": "Не удалось сгенерировать изображение."})
|
||||
context.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": tool_result_content
|
||||
})
|
||||
result.tools_called = True
|
||||
break
|
||||
return result
|
||||
|
||||
async def _generate_image(self, bot_id: int, chat_id: int, prompt: str, aspect_ratio: Optional[str]) \
|
||||
-> Tuple[Optional[bytes], bool]:
|
||||
print(f"Генерация изображения: {prompt}")
|
||||
context = [{"role": "user", "content": prompt}]
|
||||
if aspect_ratio is not None:
|
||||
|
||||
try:
|
||||
response = await self.client_image.chat.send_async(
|
||||
model=self.model_image,
|
||||
messages=context,
|
||||
user=user_tag,
|
||||
user=f'{self.platform}_{bot_id}_{chat_id}',
|
||||
modalities=["image"],
|
||||
image_config={"aspect_ratio": aspect_ratio}
|
||||
)
|
||||
else:
|
||||
response = await self.client_image.chat.send_async(
|
||||
model=self.model_image,
|
||||
messages=context,
|
||||
user=user_tag,
|
||||
modalities=["image"]
|
||||
image_config={"aspect_ratio": aspect_ratio} if aspect_ratio is not None else None
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
image_url = response.choices[0].message.images[0].image_url.url
|
||||
header, image_base64 = image_url.split(",", 1)
|
||||
mime_type = header.split(";")[0].replace("data:", "")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
|
||||
image_url = response.choices[0].message.images[0].image_url.url
|
||||
header, image_base64 = image_url.split(",", 1)
|
||||
mime_type = header.split(";")[0].replace("data:", "")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
if mime_type != "image/jpeg":
|
||||
image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
||||
output = BytesIO()
|
||||
image.save(output, format="JPEG", quality=80, optimize=True)
|
||||
image_bytes = output.getvalue()
|
||||
|
||||
if mime_type != "image/jpeg":
|
||||
image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
||||
output = BytesIO()
|
||||
image.save(output, format="JPEG", quality=80, optimize=True)
|
||||
image_bytes = output.getvalue()
|
||||
|
||||
return content, image_bytes
|
||||
return image_bytes, True
|
||||
except Exception:
|
||||
return None, False
|
||||
|
||||
|
||||
agent: AiAgent
|
||||
|
|
|
|||
|
|
@ -77,11 +77,16 @@ async def any_message_handler(message: Message, bot: Bot):
|
|||
ai_message = await create_ai_message(message, bot)
|
||||
ai_message.text = message_text
|
||||
|
||||
answer: ai_agent.Message
|
||||
success: bool
|
||||
answer, success = await utils.run_with_progress(
|
||||
partial(ai_agent.agent.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages),
|
||||
partial(message.bot.send_chat_action, chat_id, 'typing'),
|
||||
interval=4)
|
||||
|
||||
answer_id = (await message.reply(answer)).message_id
|
||||
if answer.image is not None:
|
||||
answer_id = (await message.reply_photo(photo=wrap_photo(answer.image), caption=answer.text)).message_id
|
||||
else:
|
||||
answer_id = (await message.reply(answer.text)).message_id
|
||||
if success:
|
||||
ai_agent.agent.set_last_response_id(bot.id, chat_id, answer_id)
|
||||
|
|
|
|||
|
|
@ -213,11 +213,13 @@ async def check_rules_violation_handler(message: Message, bot: Bot):
|
|||
ai_fwd_messages = [ai_agent.Message(user_name=await get_user_name_for_ai(message.reply_to_message.from_user),
|
||||
text=message.reply_to_message.text)]
|
||||
|
||||
answer: ai_agent.Message
|
||||
success: bool
|
||||
answer, success = await utils.run_with_progress(
|
||||
partial(ai_agent.agent.get_group_chat_reply, bot.id, chat_id, ai_message, ai_fwd_messages),
|
||||
partial(bot.send_chat_action, chat_id, 'typing'),
|
||||
interval=4)
|
||||
|
||||
answer_id = (await message.answer(answer)).message_id
|
||||
answer_id = (await message.answer(answer.text)).message_id
|
||||
if success:
|
||||
ai_agent.agent.set_last_response_id(bot.id, chat_id, answer_id)
|
||||
|
|
|
|||
|
|
@ -86,11 +86,13 @@ async def any_message_handler(message: Message):
|
|||
ai_message = await create_ai_message(message)
|
||||
ai_message.text = message_text
|
||||
|
||||
answer: ai_agent.Message
|
||||
success: bool
|
||||
answer, success = await utils.run_with_progress(
|
||||
partial(ai_agent.agent.get_group_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages),
|
||||
partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'),
|
||||
interval=4)
|
||||
|
||||
answer_id = (await message.reply(answer)).conversation_message_id
|
||||
answer_id = (await message.reply(answer.text)).conversation_message_id
|
||||
if success:
|
||||
ai_agent.agent.set_last_response_id(bot_id, chat_id, answer_id)
|
||||
|
|
|
|||
|
|
@ -264,11 +264,13 @@ async def check_rules_violation_handler(message: Message):
|
|||
await message.answer(MESSAGE_NEED_REPLY_OR_FORWARD)
|
||||
return
|
||||
|
||||
answer: ai_agent.Message
|
||||
success: bool
|
||||
answer, success = await utils.run_with_progress(
|
||||
partial(ai_agent.agent.get_group_chat_reply, bot_id, chat_id, ai_message, ai_fwd_messages),
|
||||
partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'),
|
||||
interval=4)
|
||||
|
||||
answer_id = (await message.answer(answer)).message_id
|
||||
answer_id = (await message.answer(answer.text)).message_id
|
||||
if success:
|
||||
ai_agent.agent.set_last_response_id(bot_id, chat_id, answer_id)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue