Реализована генерация изображений в ЛС.
This commit is contained in:
parent
f1ae49c82d
commit
44f9bb8c04
4 changed files with 166 additions and 37 deletions
186
ai_agent.py
186
ai_agent.py
|
|
@ -1,9 +1,14 @@
|
||||||
import base64
|
import base64
|
||||||
import datetime
|
import datetime
|
||||||
|
import json
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from io import BytesIO
|
||||||
|
from PIL import Image
|
||||||
from typing import List, Tuple, Any, Optional
|
from typing import List, Tuple, Any, Optional
|
||||||
|
|
||||||
from openrouter import OpenRouter, RetryConfig
|
from openrouter import OpenRouter, RetryConfig
|
||||||
|
from openrouter.components import ToolDefinitionJSONTypedDict, MessageTypedDict
|
||||||
from openrouter.utils import BackoffStrategy
|
from openrouter.utils import BackoffStrategy
|
||||||
|
|
||||||
from database import BasicDatabase
|
from database import BasicDatabase
|
||||||
|
|
@ -23,7 +28,8 @@ PRIVATE_CHAT_SYSTEM_PROMPT = """
|
||||||
Отвечай на вопросы и поддерживай контекст беседы.\n
|
Отвечай на вопросы и поддерживай контекст беседы.\n
|
||||||
Сообщения пользователя будут приходить в следующем формате: '[дата время]: текст сообщения'\n
|
Сообщения пользователя будут приходить в следующем формате: '[дата время]: текст сообщения'\n
|
||||||
При ответе НЕ нужно указывать время.\n
|
При ответе НЕ нужно указывать время.\n
|
||||||
НЕ используй разметку Markdown.
|
Никогда не используй разметку Markdown.\n
|
||||||
|
Никогда не добавляй ASCII-арты в ответ.
|
||||||
"""
|
"""
|
||||||
PRIVATE_CHAT_MAX_MESSAGES = 40
|
PRIVATE_CHAT_MAX_MESSAGES = 40
|
||||||
|
|
||||||
|
|
@ -47,27 +53,68 @@ def _add_message_prefix(text: Optional[str], username: Optional[str] = None) ->
|
||||||
return f"{prefix}: {text}" if text is not None else f"{prefix}:"
|
return f"{prefix}: {text}" if text is not None else f"{prefix}:"
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_image(image: bytes) -> str:
|
||||||
|
encoded_image = base64.b64encode(image).decode('utf-8')
|
||||||
|
return f"data:image/jpeg;base64,{encoded_image}"
|
||||||
|
|
||||||
|
|
||||||
def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict:
|
def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict:
|
||||||
json = {"role": role, "content": []}
|
serialized = {"role": role, "content": []}
|
||||||
if text is not None:
|
if text is not None:
|
||||||
json["content"].append({"type": "text", "text": text})
|
serialized["content"].append({"type": "text", "text": text})
|
||||||
if image is not None:
|
if image is not None:
|
||||||
encoded_image = base64.b64encode(image).decode('utf-8')
|
serialized["content"].append({"type": "image_url", "image_url": {"url": _encode_image(image)}})
|
||||||
image_url = f"data:image/jpeg;base64,{encoded_image}"
|
|
||||||
json["content"].append({"type": "image_url", "image_url": {"url": image_url}})
|
return serialized
|
||||||
return json
|
|
||||||
|
|
||||||
|
def _get_tools_description() -> List[ToolDefinitionJSONTypedDict]:
|
||||||
|
return [{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "generate_image",
|
||||||
|
"description": """
|
||||||
|
Генерация изображения по описанию.
|
||||||
|
Используй этот инструмент, если пользователь просит сгенерировать изображение ('нарисуй', 'покажи' и т.п.),
|
||||||
|
ИЛИ если это улучшит ответ (например, в ролевой игре для визуализации сцены).
|
||||||
|
""",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"prompt": {
|
||||||
|
"type": "string",
|
||||||
|
"description": """
|
||||||
|
Детальное описание на английском (рекомендуется).
|
||||||
|
Добавь детали для стиля, цвета, композиции, если нужно.
|
||||||
|
"""
|
||||||
|
},
|
||||||
|
"aspect_ratio": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["1:1", "3:4", "4:3", "9:16", "16:9"],
|
||||||
|
"description": "Соотношение сторон (опционально)."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
|
||||||
|
|
||||||
class AiAgent:
|
class AiAgent:
|
||||||
def __init__(self,
             api_token_main: str, model_main: str,
             api_token_image: str, model_image: str,
             db: BasicDatabase,
             platform: str):
    """Store settings and create the two OpenRouter clients.

    Args:
        api_token_main: API key for the chat client.
        model_main: Model name used for chat requests.
        api_token_image: API key for the image client.
        model_image: Model name used for image generation requests.
        db: Database object kept on the instance as ``self.db``.
        platform: Platform tag kept on the instance as ``self.platform``.
    """
    backoff = BackoffStrategy(
        initial_interval=2000, max_interval=8000, exponent=2, max_elapsed_time=14000)
    retry_config = RetryConfig(strategy="backoff", backoff=backoff, retry_connection_errors=True)

    self.db = db
    self.model_main = model_main
    self.model_image = model_image
    self.platform = platform
    # Only the chat client gets the retry configuration; the image client
    # is created with the SDK defaults.
    self.client = OpenRouter(api_key=api_token_main, retry_config=retry_config)
    self.client_image = OpenRouter(api_key=api_token_image)
|
||||||
|
|
||||||
async def get_group_chat_reply(self, bot_id: int, chat_id: int,
|
async def get_group_chat_reply(self, bot_id: int, chat_id: int,
|
||||||
message: Message, forwarded_messages: List[Message]) -> Tuple[str, bool]:
|
message: Message, forwarded_messages: List[Message]) -> Tuple[str, bool]:
|
||||||
|
|
@ -84,19 +131,15 @@ class AiAgent:
|
||||||
context.append(_serialize_message(role="user", text=fwd_message.text, image=fwd_message.image))
|
context.append(_serialize_message(role="user", text=fwd_message.text, image=fwd_message.image))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get response from OpenRouter
|
|
||||||
response = await self.client.chat.send_async(
|
response = await self.client.chat.send_async(
|
||||||
model=self.model,
|
model=self.model_main,
|
||||||
messages=context,
|
messages=context,
|
||||||
max_tokens=500,
|
max_tokens=500,
|
||||||
user=f'{self.platform}_{bot_id}_{chat_id}',
|
user=f'{self.platform}_{bot_id}_{chat_id}',
|
||||||
http_headers=OPENROUTER_HEADERS
|
http_headers=OPENROUTER_HEADERS
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract AI response
|
|
||||||
ai_response = response.choices[0].message.content
|
ai_response = response.choices[0].message.content
|
||||||
|
|
||||||
# Add input messages and AI response to context
|
|
||||||
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
|
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
|
||||||
message_id=message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
message_id=message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
||||||
for fwd_message in forwarded_messages:
|
for fwd_message in forwarded_messages:
|
||||||
|
|
@ -114,45 +157,85 @@ class AiAgent:
|
||||||
print(f"Ошибка выполнения запроса к ИИ: {e}")
|
print(f"Ошибка выполнения запроса к ИИ: {e}")
|
||||||
return f"Извините, при обработке запроса произошла ошибка.", False
|
return f"Извините, при обработке запроса произошла ошибка.", False
|
||||||
|
|
||||||
async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: Message) -> Tuple[Message, bool]:
    """Answer a private-chat message, generating an image if the model asks for one.

    The stored chat context plus the incoming message is sent to the main
    model together with the tool description. When the reply contains a
    ``generate_image`` tool call, the image is generated, the tool result is
    appended to the context and the main model is queried again for the final
    text. The user message and the assistant reply are then saved to the
    database.

    Args:
        bot_id: Bot identifier used for context lookup and the user tag.
        chat_id: Chat identifier used for context lookup and the user tag.
        message: Incoming message; its ``text`` is replaced in place with the
            prefixed text.

    Returns:
        ``(reply, True)`` on success, where ``reply`` carries the final text
        and the generated image (or None). On any exception
        ``(error_message, False)`` is returned instead and nothing is raised.
    """
    context = self._get_chat_context(is_group_chat=False, bot_id=bot_id, chat_id=chat_id)
    message.text = _add_message_prefix(message.text)
    # Same serialization as the group-chat path uses for user messages.
    context.append(_serialize_message(role="user", text=message.text, image=message.image))

    try:
        user_tag = f'{self.platform}_{bot_id}_{chat_id}'

        response = await self.client.chat.send_async(
            model=self.model_main,
            messages=context,
            tools=_get_tools_description(),
            tool_choice="auto",
            max_tokens=500,
            user=user_tag,
            http_headers=OPENROUTER_HEADERS
        )

        reply = response.choices[0].message
        ai_response = reply.content
        context.append(reply)

        # tool_calls may be None when the model answers with plain text;
        # treat that the same as "no tool calls" instead of failing on len().
        tool_calls = reply.tool_calls or []

        if tool_calls and ai_response:
            print(f"Модель хочет вызвать функцию, но также вернула текст: '{ai_response}'")

        image_response_image: Optional[bytes] = None
        for tool_call in tool_calls:
            tool_name = tool_call.function.name
            tool_args = json.loads(tool_call.function.arguments)
            if tool_name == "generate_image":
                prompt = tool_args.get("prompt", "")
                aspect_ratio = tool_args.get("aspect_ratio", None)
                # The text returned by the image model is not used.
                _, image_response_image = \
                    await self._generate_image(prompt=prompt, aspect_ratio=aspect_ratio, user_tag=user_tag)
                context.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": [
                        {"type": "text",
                         "text": """
                         Изображение сгенерировано и будет показано пользователю.
                         НЕ добавляй никаких тегов или маркеров вроде <image>, [image] — они запрещены и не нужны.
                         """},
                        {"type": "image_url", "image_url": {"url": _encode_image(image_response_image)}}
                    ]
                })

                # Second round-trip: let the main model write the final text
                # now that it has seen the tool result.
                response2 = await self.client.chat.send_async(
                    model=self.model_main,
                    messages=context,
                    max_tokens=500,
                    user=user_tag,
                    http_headers=OPENROUTER_HEADERS
                )
                ai_response = response2.choices[0].message.content
                # Only the first generate_image call is served per reply.
                break

        self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
                                    message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
        # image_response_image is None when no image was generated, so a
        # single call covers both cases.
        self.db.context_add_message(bot_id, chat_id, role="assistant",
                                    text=ai_response, image=image_response_image,
                                    message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)

        return Message(text=ai_response, image=image_response_image), True

    except Exception as e:
        if "Rate limit exceeded" in str(e):
            return Message(text="Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."), False
        print(f"Ошибка выполнения запроса к ИИ: {e}")
        return Message(text="Извините, при обработке запроса произошла ошибка."), False
|
||||||
|
|
||||||
def get_last_assistant_message_id(self, bot_id: int, chat_id: int):
    """Return the database's last assistant message id for the given bot and chat."""
    last_message_id = self.db.context_get_last_assistant_message_id(bot_id, chat_id)
    return last_message_id
|
||||||
|
|
@ -163,7 +246,7 @@ class AiAgent:
|
||||||
def clear_chat_context(self, bot_id: int, chat_id: int):
    """Ask the database to clear the stored context for the given bot and chat."""
    database = self.db
    database.context_clear(bot_id, chat_id)
|
||||||
|
|
||||||
def _get_chat_context(self, is_group_chat: bool, bot_id: int, chat_id: int) -> list[dict]:
|
def _get_chat_context(self, is_group_chat: bool, bot_id: int, chat_id: int) -> List[MessageTypedDict]:
|
||||||
prompt = GROUP_CHAT_SYSTEM_PROMPT if is_group_chat else PRIVATE_CHAT_SYSTEM_PROMPT
|
prompt = GROUP_CHAT_SYSTEM_PROMPT if is_group_chat else PRIVATE_CHAT_SYSTEM_PROMPT
|
||||||
|
|
||||||
bot = self.db.get_bot(bot_id)
|
bot = self.db.get_bot(bot_id)
|
||||||
|
|
@ -176,15 +259,50 @@ class AiAgent:
|
||||||
|
|
||||||
messages = self.db.context_get_messages(bot_id, chat_id)
|
messages = self.db.context_get_messages(bot_id, chat_id)
|
||||||
|
|
||||||
context = [{"role": "system", "content": prompt}]
|
context: List[MessageTypedDict] = [{"role": "system", "content": prompt}]
|
||||||
for message in messages:
|
for message in messages:
|
||||||
context.append(_serialize_message(message["role"], message["text"], message["image"]))
|
context.append(_serialize_message(message["role"], message["text"], message["image"]))
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
async def _generate_image(self, prompt: str, aspect_ratio: Optional[str], user_tag: str) -> Tuple[str, bytes]:
    """Generate an image with the image model and return it as JPEG bytes.

    Args:
        prompt: Text sent to the image model as the only user message.
        aspect_ratio: Passed as ``image_config.aspect_ratio`` when not None;
            otherwise no ``image_config`` is sent.
        user_tag: Value forwarded as the ``user`` field of the request.

    Returns:
        ``(content, image_bytes)`` where ``content`` is the message content
        of the reply and ``image_bytes`` is the first returned image,
        re-encoded to JPEG unless it already is one.

    Raises:
        RuntimeError: If the reply contains no image.
    """
    # One request for both cases; image_config is only added when the caller
    # picked an aspect ratio.
    extra_args = {}
    if aspect_ratio is not None:
        extra_args["image_config"] = {"aspect_ratio": aspect_ratio}

    response = await self.client_image.chat.send_async(
        model=self.model_image,
        messages=[{"role": "user", "content": prompt}],
        user=user_tag,
        modalities=["image"],
        **extra_args
    )

    reply = response.choices[0].message
    content = reply.content

    # Fail with a clear message instead of an opaque TypeError/IndexError
    # when the model answered without an image.
    if not reply.images:
        raise RuntimeError("Модель генерации изображений не вернула изображение")

    # The image arrives as a data URL: "data:<mime>;base64,<payload>".
    image_url = reply.images[0].image_url.url
    header, image_base64 = image_url.split(",", 1)
    mime_type = header.split(";")[0].replace("data:", "")
    image_bytes = base64.b64decode(image_base64)

    # Callers embed and send the result as JPEG, so anything else is converted.
    if mime_type != "image/jpeg":
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        output = BytesIO()
        image.save(output, format="JPEG", quality=80, optimize=True)
        image_bytes = output.getvalue()

    return content, image_bytes
|
||||||
|
|
||||||
|
|
||||||
agent: AiAgent
|
agent: AiAgent
|
||||||
|
|
||||||
|
|
||||||
def create_ai_agent(api_token_main: str, model_main: str,
                    api_token_image: str, model_image: str,
                    db: BasicDatabase, platform: str):
    """Create an ``AiAgent`` from the given settings and bind it to the module-level ``agent``."""
    global agent
    agent = AiAgent(
        api_token_main=api_token_main,
        model_main=model_main,
        api_token_image=api_token_image,
        model_image=model_image,
        db=db,
        platform=platform,
    )
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,9 @@ async def main() -> None:
|
||||||
|
|
||||||
database.create_database(config['db_connection_string'])
|
database.create_database(config['db_connection_string'])
|
||||||
|
|
||||||
create_ai_agent(config['openrouter_token'], config['openrouter_model'], database.DB, 'tg')
|
create_ai_agent(config['openrouter_token_main'], config['openrouter_model_main'],
|
||||||
|
config['openrouter_token_image'], config['openrouter_model_image'],
|
||||||
|
database.DB, 'tg')
|
||||||
|
|
||||||
bots: list[Bot] = []
|
bots: list[Bot] = []
|
||||||
for item in database.DB.get_bots():
|
for item in database.DB.get_bots():
|
||||||
|
|
|
||||||
|
|
@ -52,11 +52,16 @@ async def any_message_handler(message: Message, bot: Bot):
|
||||||
await message.answer(MESSAGE_UNSUPPORTED_CONTENT_TYPE)
|
await message.answer(MESSAGE_UNSUPPORTED_CONTENT_TYPE)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
answer: ai_agent.Message
|
||||||
|
success: bool
|
||||||
answer, success = await utils.run_with_progress(
|
answer, success = await utils.run_with_progress(
|
||||||
partial(ai_agent.agent.get_private_chat_reply, bot.id, chat_id, ai_message),
|
partial(ai_agent.agent.get_private_chat_reply, bot.id, chat_id, ai_message),
|
||||||
partial(message.bot.send_chat_action, chat_id, 'typing'),
|
partial(message.bot.send_chat_action, chat_id, 'typing'),
|
||||||
interval=4)
|
interval=4)
|
||||||
|
|
||||||
answer_id = (await message.answer(answer)).message_id
|
if answer.image is not None:
|
||||||
|
answer_id = (await message.answer_photo(photo=wrap_photo(answer.image), caption=answer.text)).message_id
|
||||||
|
else:
|
||||||
|
answer_id = (await message.answer(answer.text)).message_id
|
||||||
if success:
|
if success:
|
||||||
ai_agent.agent.set_last_response_id(bot.id, chat_id, answer_id)
|
ai_agent.agent.set_last_response_id(bot.id, chat_id, answer_id)
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from typing import Optional
|
||||||
|
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
from aiogram.enums import ContentType
|
from aiogram.enums import ContentType
|
||||||
from aiogram.types import User, PhotoSize, Message
|
from aiogram.types import User, PhotoSize, Message, BufferedInputFile
|
||||||
|
|
||||||
import ai_agent
|
import ai_agent
|
||||||
import utils
|
import utils
|
||||||
|
|
@ -50,3 +50,7 @@ async def create_ai_message(message: Message, bot: Bot) -> ai_agent.Message:
|
||||||
else:
|
else:
|
||||||
raise utils.UnsupportedContentType()
|
raise utils.UnsupportedContentType()
|
||||||
return ai_message
|
return ai_message
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_photo(image: bytes) -> BufferedInputFile:
    """Wrap raw image bytes in a ``BufferedInputFile`` named ``image.jpg``."""
    photo = BufferedInputFile(file=image, filename='image.jpg')
    return photo
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue