Реализована генерация изображений в ЛС.
This commit is contained in:
parent
f1ae49c82d
commit
44f9bb8c04
4 changed files with 166 additions and 37 deletions
186
ai_agent.py
186
ai_agent.py
|
|
@ -1,9 +1,14 @@
|
|||
import base64
|
||||
import datetime
|
||||
import json
|
||||
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from typing import List, Tuple, Any, Optional
|
||||
|
||||
from openrouter import OpenRouter, RetryConfig
|
||||
from openrouter.components import ToolDefinitionJSONTypedDict, MessageTypedDict
|
||||
from openrouter.utils import BackoffStrategy
|
||||
|
||||
from database import BasicDatabase
|
||||
|
|
@ -23,7 +28,8 @@ PRIVATE_CHAT_SYSTEM_PROMPT = """
|
|||
Отвечай на вопросы и поддерживай контекст беседы.\n
|
||||
Сообщения пользователя будут приходить в следующем формате: '[дата время]: текст сообщения'\n
|
||||
При ответе НЕ нужно указывать время.\n
|
||||
НЕ используй разметку Markdown.
|
||||
Никогда не используй разметку Markdown.\n
|
||||
Никогда не добавляй ASCII-арты в ответ.
|
||||
"""
|
||||
PRIVATE_CHAT_MAX_MESSAGES = 40
|
||||
|
||||
|
|
@ -47,27 +53,68 @@ def _add_message_prefix(text: Optional[str], username: Optional[str] = None) ->
|
|||
return f"{prefix}: {text}" if text is not None else f"{prefix}:"
|
||||
|
||||
|
||||
def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict:
|
||||
json = {"role": role, "content": []}
|
||||
if text is not None:
|
||||
json["content"].append({"type": "text", "text": text})
|
||||
if image is not None:
|
||||
def _encode_image(image: bytes) -> str:
|
||||
encoded_image = base64.b64encode(image).decode('utf-8')
|
||||
image_url = f"data:image/jpeg;base64,{encoded_image}"
|
||||
json["content"].append({"type": "image_url", "image_url": {"url": image_url}})
|
||||
return json
|
||||
return f"data:image/jpeg;base64,{encoded_image}"
|
||||
|
||||
|
||||
def _serialize_message(role: str, text: Optional[str], image: Optional[bytes]) -> dict:
|
||||
serialized = {"role": role, "content": []}
|
||||
if text is not None:
|
||||
serialized["content"].append({"type": "text", "text": text})
|
||||
if image is not None:
|
||||
serialized["content"].append({"type": "image_url", "image_url": {"url": _encode_image(image)}})
|
||||
|
||||
return serialized
|
||||
|
||||
|
||||
def _get_tools_description() -> List[ToolDefinitionJSONTypedDict]:
|
||||
return [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "generate_image",
|
||||
"description": """
|
||||
Генерация изображения по описанию.
|
||||
Используй этот инструмент, если пользователь просит сгенерировать изображение ('нарисуй', 'покажи' и т.п.),
|
||||
ИЛИ если это улучшит ответ (например, в ролевой игре для визуализации сцены).
|
||||
""",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": """
|
||||
Детальное описание на английском (рекомендуется).
|
||||
Добавь детали для стиля, цвета, композиции, если нужно.
|
||||
"""
|
||||
},
|
||||
"aspect_ratio": {
|
||||
"type": "string",
|
||||
"enum": ["1:1", "3:4", "4:3", "9:16", "16:9"],
|
||||
"description": "Соотношение сторон (опционально)."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
|
||||
class AiAgent:
|
||||
def __init__(self, api_token: str, model: str, db: BasicDatabase, platform: str):
|
||||
def __init__(self,
|
||||
api_token_main: str, model_main: str,
|
||||
api_token_image: str, model_image: str,
|
||||
db: BasicDatabase,
|
||||
platform: str):
|
||||
retry_config = RetryConfig(strategy="backoff",
|
||||
backoff=BackoffStrategy(
|
||||
initial_interval=2000, max_interval=8000, exponent=2, max_elapsed_time=14000),
|
||||
retry_connection_errors=True)
|
||||
self.db = db
|
||||
self.model = model
|
||||
self.model_main = model_main
|
||||
self.model_image = model_image
|
||||
self.platform = platform
|
||||
self.client = OpenRouter(api_key=api_token, retry_config=retry_config)
|
||||
self.client = OpenRouter(api_key=api_token_main, retry_config=retry_config)
|
||||
self.client_image = OpenRouter(api_key=api_token_image)
|
||||
|
||||
async def get_group_chat_reply(self, bot_id: int, chat_id: int,
|
||||
message: Message, forwarded_messages: List[Message]) -> Tuple[str, bool]:
|
||||
|
|
@ -84,19 +131,15 @@ class AiAgent:
|
|||
context.append(_serialize_message(role="user", text=fwd_message.text, image=fwd_message.image))
|
||||
|
||||
try:
|
||||
# Get response from OpenRouter
|
||||
response = await self.client.chat.send_async(
|
||||
model=self.model,
|
||||
model=self.model_main,
|
||||
messages=context,
|
||||
max_tokens=500,
|
||||
user=f'{self.platform}_{bot_id}_{chat_id}',
|
||||
http_headers=OPENROUTER_HEADERS
|
||||
)
|
||||
|
||||
# Extract AI response
|
||||
ai_response = response.choices[0].message.content
|
||||
|
||||
# Add input messages and AI response to context
|
||||
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
|
||||
message_id=message.message_id, max_messages=GROUP_CHAT_MAX_MESSAGES)
|
||||
for fwd_message in forwarded_messages:
|
||||
|
|
@ -114,45 +157,85 @@ class AiAgent:
|
|||
print(f"Ошибка выполнения запроса к ИИ: {e}")
|
||||
return f"Извините, при обработке запроса произошла ошибка.", False
|
||||
|
||||
async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: Message) -> Tuple[str, bool]:
|
||||
async def get_private_chat_reply(self, bot_id: int, chat_id: int, message: Message) -> Tuple[Message, bool]:
|
||||
context = self._get_chat_context(is_group_chat=False, bot_id=bot_id, chat_id=chat_id)
|
||||
message.text = _add_message_prefix(message.text)
|
||||
content: list[dict[str, Any]] = []
|
||||
if message.text is not None:
|
||||
content.append({"type": "text", "text": message.text})
|
||||
if message.image is not None:
|
||||
encoded_image = base64.b64encode(message.image).decode('utf-8')
|
||||
image_url = f"data:image/jpeg;base64,{encoded_image}"
|
||||
content.append({"type": "image_url", "image_url": {"url": image_url}})
|
||||
content.append({"type": "image_url", "image_url": {"url": _encode_image(message.image)}})
|
||||
context.append({"role": "user", "content": content})
|
||||
|
||||
try:
|
||||
# Get response from OpenRouter
|
||||
user_tag = f'{self.platform}_{bot_id}_{chat_id}'
|
||||
|
||||
response = await self.client.chat.send_async(
|
||||
model=self.model,
|
||||
model=self.model_main,
|
||||
messages=context,
|
||||
tools=_get_tools_description(),
|
||||
tool_choice="auto",
|
||||
max_tokens=500,
|
||||
user=f'{self.platform}_{bot_id}_{chat_id}',
|
||||
user=user_tag,
|
||||
http_headers=OPENROUTER_HEADERS
|
||||
)
|
||||
|
||||
# Extract AI response
|
||||
ai_response = response.choices[0].message.content
|
||||
context.append(response.choices[0].message)
|
||||
|
||||
if len(response.choices[0].message.tool_calls) > 0 and ai_response is not None and len(ai_response) > 0:
|
||||
print(f"Модель хочет вызвать функцию, но также вернула текст: '{ai_response}'")
|
||||
|
||||
image_response_image: Optional[bytes] = None
|
||||
for tool_call in response.choices[0].message.tool_calls:
|
||||
tool_name = tool_call.function.name
|
||||
tool_args = json.loads(tool_call.function.arguments)
|
||||
if tool_name == "generate_image":
|
||||
prompt = tool_args.get("prompt", "")
|
||||
aspect_ratio = tool_args.get("aspect_ratio", None)
|
||||
image_response_text, image_response_image =\
|
||||
await self._generate_image(prompt=prompt, aspect_ratio=aspect_ratio, user_tag=user_tag)
|
||||
context.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": [
|
||||
{"type": "text",
|
||||
"text": """
|
||||
Изображение сгенерировано и будет показано пользователю.
|
||||
НЕ добавляй никаких тегов или маркеров вроде <image>, [image] — они запрещены и не нужны.
|
||||
"""},
|
||||
{"type": "image_url", "image_url": {"url": _encode_image(image_response_image)}}
|
||||
]
|
||||
})
|
||||
|
||||
response2 = await self.client.chat.send_async(
|
||||
model=self.model_main,
|
||||
messages=context,
|
||||
max_tokens=500,
|
||||
user=user_tag,
|
||||
http_headers=OPENROUTER_HEADERS
|
||||
)
|
||||
ai_response = response2.choices[0].message.content
|
||||
break
|
||||
|
||||
# Add message and AI response to context
|
||||
self.db.context_add_message(bot_id, chat_id, role="user", text=message.text, image=message.image,
|
||||
message_id=message.message_id, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
||||
|
||||
if image_response_image is not None:
|
||||
self.db.context_add_message(bot_id, chat_id, role="assistant",
|
||||
text=ai_response, image=image_response_image,
|
||||
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
||||
else:
|
||||
self.db.context_add_message(bot_id, chat_id, role="assistant", text=ai_response, image=None,
|
||||
message_id=None, max_messages=PRIVATE_CHAT_MAX_MESSAGES)
|
||||
|
||||
return ai_response, True
|
||||
return Message(text=ai_response, image=image_response_image), True
|
||||
|
||||
except Exception as e:
|
||||
if str(e).find("Rate limit exceeded") != -1:
|
||||
return "Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК).", False
|
||||
return Message(text="Извините, достигнут дневной лимит запросов к ИИ (обновляется в 03:00 МСК)."), False
|
||||
else:
|
||||
print(f"Ошибка выполнения запроса к ИИ: {e}")
|
||||
return f"Извините, при обработке запроса произошла ошибка.", False
|
||||
return Message(text=f"Извините, при обработке запроса произошла ошибка."), False
|
||||
|
||||
def get_last_assistant_message_id(self, bot_id: int, chat_id: int):
|
||||
return self.db.context_get_last_assistant_message_id(bot_id, chat_id)
|
||||
|
|
@ -163,7 +246,7 @@ class AiAgent:
|
|||
def clear_chat_context(self, bot_id: int, chat_id: int):
|
||||
self.db.context_clear(bot_id, chat_id)
|
||||
|
||||
def _get_chat_context(self, is_group_chat: bool, bot_id: int, chat_id: int) -> list[dict]:
|
||||
def _get_chat_context(self, is_group_chat: bool, bot_id: int, chat_id: int) -> List[MessageTypedDict]:
|
||||
prompt = GROUP_CHAT_SYSTEM_PROMPT if is_group_chat else PRIVATE_CHAT_SYSTEM_PROMPT
|
||||
|
||||
bot = self.db.get_bot(bot_id)
|
||||
|
|
@ -176,15 +259,50 @@ class AiAgent:
|
|||
|
||||
messages = self.db.context_get_messages(bot_id, chat_id)
|
||||
|
||||
context = [{"role": "system", "content": prompt}]
|
||||
context: List[MessageTypedDict] = [{"role": "system", "content": prompt}]
|
||||
for message in messages:
|
||||
context.append(_serialize_message(message["role"], message["text"], message["image"]))
|
||||
return context
|
||||
|
||||
async def _generate_image(self, prompt: str, aspect_ratio: Optional[str], user_tag: str) -> Tuple[str, bytes]:
|
||||
context = [{"role": "user", "content": prompt}]
|
||||
if aspect_ratio is not None:
|
||||
response = await self.client_image.chat.send_async(
|
||||
model=self.model_image,
|
||||
messages=context,
|
||||
user=user_tag,
|
||||
modalities=["image"],
|
||||
image_config={"aspect_ratio": aspect_ratio}
|
||||
)
|
||||
else:
|
||||
response = await self.client_image.chat.send_async(
|
||||
model=self.model_image,
|
||||
messages=context,
|
||||
user=user_tag,
|
||||
modalities=["image"]
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
|
||||
image_url = response.choices[0].message.images[0].image_url.url
|
||||
header, image_base64 = image_url.split(",", 1)
|
||||
mime_type = header.split(";")[0].replace("data:", "")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
|
||||
if mime_type != "image/jpeg":
|
||||
image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
||||
output = BytesIO()
|
||||
image.save(output, format="JPEG", quality=80, optimize=True)
|
||||
image_bytes = output.getvalue()
|
||||
|
||||
return content, image_bytes
|
||||
|
||||
|
||||
agent: AiAgent
|
||||
|
||||
|
||||
def create_ai_agent(api_token: str, model: str, db: BasicDatabase, platform: str):
|
||||
def create_ai_agent(api_token_main: str, model_main: str,
|
||||
api_token_image: str, model_image: str,
|
||||
db: BasicDatabase, platform: str):
|
||||
global agent
|
||||
agent = AiAgent(api_token, model, db, platform)
|
||||
agent = AiAgent(api_token_main, model_main, api_token_image, model_image, db, platform)
|
||||
|
|
|
|||
|
|
@ -24,7 +24,9 @@ async def main() -> None:
|
|||
|
||||
database.create_database(config['db_connection_string'])
|
||||
|
||||
create_ai_agent(config['openrouter_token'], config['openrouter_model'], database.DB, 'tg')
|
||||
create_ai_agent(config['openrouter_token_main'], config['openrouter_model_main'],
|
||||
config['openrouter_token_image'], config['openrouter_model_image'],
|
||||
database.DB, 'tg')
|
||||
|
||||
bots: list[Bot] = []
|
||||
for item in database.DB.get_bots():
|
||||
|
|
|
|||
|
|
@ -52,11 +52,16 @@ async def any_message_handler(message: Message, bot: Bot):
|
|||
await message.answer(MESSAGE_UNSUPPORTED_CONTENT_TYPE)
|
||||
return
|
||||
|
||||
answer: ai_agent.Message
|
||||
success: bool
|
||||
answer, success = await utils.run_with_progress(
|
||||
partial(ai_agent.agent.get_private_chat_reply, bot.id, chat_id, ai_message),
|
||||
partial(message.bot.send_chat_action, chat_id, 'typing'),
|
||||
interval=4)
|
||||
|
||||
answer_id = (await message.answer(answer)).message_id
|
||||
if answer.image is not None:
|
||||
answer_id = (await message.answer_photo(photo=wrap_photo(answer.image), caption=answer.text)).message_id
|
||||
else:
|
||||
answer_id = (await message.answer(answer.text)).message_id
|
||||
if success:
|
||||
ai_agent.agent.set_last_response_id(bot.id, chat_id, answer_id)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Optional
|
|||
|
||||
from aiogram import Bot
|
||||
from aiogram.enums import ContentType
|
||||
from aiogram.types import User, PhotoSize, Message
|
||||
from aiogram.types import User, PhotoSize, Message, BufferedInputFile
|
||||
|
||||
import ai_agent
|
||||
import utils
|
||||
|
|
@ -50,3 +50,7 @@ async def create_ai_message(message: Message, bot: Bot) -> ai_agent.Message:
|
|||
else:
|
||||
raise utils.UnsupportedContentType()
|
||||
return ai_message
|
||||
|
||||
|
||||
def wrap_photo(image: bytes) -> BufferedInputFile:
|
||||
return BufferedInputFile(image, 'image.jpg')
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue