Compare commits

..

2 commits

5 changed files with 66 additions and 15 deletions

View file

@ -11,11 +11,12 @@ from typing import List, Tuple, Optional, Union, Dict, Awaitable
from openrouter import OpenRouter, RetryConfig from openrouter import OpenRouter, RetryConfig
from openrouter.components import AssistantMessage, AssistantMessageTypedDict, ChatMessageContentItemTypedDict, \ from openrouter.components import AssistantMessage, AssistantMessageTypedDict, ChatMessageContentItemTypedDict, \
ChatMessageToolCall, MessageTypedDict, SystemMessageTypedDict ChatMessageToolCall, MessageTypedDict, SystemMessageTypedDict
from openrouter.errors import ResponseValidationError, ChatError from openrouter.errors import ResponseValidationError, OpenRouterError
from openrouter.utils import BackoffStrategy from openrouter.utils import BackoffStrategy
from fal_client import AsyncClient as FalClient from fal_client import AsyncClient as FalClient
from replicate import Client as ReplicateClient from replicate import Client as ReplicateClient
from tavily import TavilyClient
from utils import download_file from utils import download_file
from database import BasicDatabase from database import BasicDatabase
@ -43,7 +44,7 @@ class Message:
class AiAgent: class AiAgent:
def __init__(self, def __init__(self,
openrouter_token: str, openrouter_model: str, openrouter_token: str, openrouter_model: str,
fal_token: str, replicate_token: str, fal_token: str, replicate_token: str, tavily_token: str,
db: BasicDatabase, db: BasicDatabase,
platform: str): platform: str):
retry_config = RetryConfig(strategy="backoff", retry_config = RetryConfig(strategy="backoff",
@ -61,6 +62,7 @@ class AiAgent:
retry_config=retry_config) retry_config=retry_config)
self.client_fal = FalClient(key=fal_token) self.client_fal = FalClient(key=fal_token)
self.replicate_client = ReplicateClient(api_token=replicate_token) self.replicate_client = ReplicateClient(api_token=replicate_token)
self.tavily_client = TavilyClient(api_key=tavily_token)
@dataclass() @dataclass()
class _ToolsArtifacts: class _ToolsArtifacts:
@ -168,7 +170,7 @@ class AiAgent:
def _construct_system_prompt(self, is_group_chat: bool, bot_id: int, chat_id: int) -> SystemMessageTypedDict: def _construct_system_prompt(self, is_group_chat: bool, bot_id: int, chat_id: int) -> SystemMessageTypedDict:
prompt = self.system_prompt_group_chat if is_group_chat else self.system_prompt_private_chat prompt = self.system_prompt_group_chat if is_group_chat else self.system_prompt_private_chat
prompt = prompt.replace('{platform}', 'Telegram' if self.platform == 'tg' else 'VK') prompt = prompt.replace('{platform}', 'Telegram' if self.platform == 'tg' else 'VK')
prompt += '\n' + self.system_prompt_image_generation prompt += '\n' + self.system_prompt_tools
bot = self.db.get_bot(bot_id) bot = self.db.get_bot(bot_id)
if bot['ai_prompt'] is not None: if bot['ai_prompt'] is not None:
@ -202,7 +204,8 @@ class AiAgent:
Callable[[int, int, Dict, AiAgent._ToolsArtifacts], Callable[[int, int, Dict, AiAgent._ToolsArtifacts],
Awaitable[List[ChatMessageContentItemTypedDict]]]] = { Awaitable[List[ChatMessageContentItemTypedDict]]]] = {
"generate_image": self._process_tool_generate_image, "generate_image": self._process_tool_generate_image,
"generate_image_anime": self._process_tool_generate_image_anime "generate_image_anime": self._process_tool_generate_image_anime,
"tavily_search": self._process_tool_tavily_search
} }
for tool_call in tool_calls: for tool_call in tool_calls:
@ -293,6 +296,30 @@ class AiAgent:
print(f"Ошибка генерации изображения: {e}") print(f"Ошибка генерации изображения: {e}")
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}") return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}")
async def _process_tool_tavily_search(self, _bot_id: int, _chat_id: int, args: dict,
_artifacts: _ToolsArtifacts) -> List[ChatMessageContentItemTypedDict]:
query = args.get("query", "")
print(f"Веб-поиск: {query}")
try:
results = self.tavily_client.search(query=query, max_results=5)
if not results or "results" not in results:
return _serialize_message_content(text="Не удалось получить результаты поиска.")
answer_parts = []
for i, result in enumerate(results["results"], 1):
title = result.get("title", "Без названия")
url = result.get("url", "")
content = result.get("content", "")
answer_parts.append(f"{i}. {title}\n {url}\n {content}\n")
answer = "\n".join(answer_parts)
return _serialize_message_content(text=f"По запросу \"{query}\" найдено:\n\n{answer}")
except Exception as e:
print(f"Ошибка веб-поиска: {e}")
return _serialize_message_content(text=f"Не удалось выполнить веб-поиск: {e}")
async def _async_chat_completion_request(self, **kwargs): async def _async_chat_completion_request(self, **kwargs):
try: try:
return await self.client_openrouter.chat.send_async(**kwargs) return await self.client_openrouter.chat.send_async(**kwargs)
@ -308,7 +335,7 @@ class AiAgent:
except Exception: except Exception:
pass pass
raise e raise e
except ChatError as e: except OpenRouterError as e:
if e.message == "Provider returned error": if e.message == "Provider returned error":
body = json.loads(e.body) body = json.loads(e.body)
try: try:
@ -331,8 +358,8 @@ class AiAgent:
self.system_prompt_group_chat = f.read() self.system_prompt_group_chat = f.read()
with open("prompts/private_chat.md", "r") as f: with open("prompts/private_chat.md", "r") as f:
self.system_prompt_private_chat = f.read() self.system_prompt_private_chat = f.read()
with open("prompts/image_generation.md", "r") as f: with open("prompts/tools.md", "r") as f:
self.system_prompt_image_generation = f.read() self.system_prompt_tools = f.read()
with open("prompts/tools.json", "r") as f: with open("prompts/tools.json", "r") as f:
self.tools_description = json.loads(f.read()) self.tools_description = json.loads(f.read())
@ -341,10 +368,10 @@ agent: AiAgent
def create_ai_agent(openrouter_token: str, openrouter_model: str, def create_ai_agent(openrouter_token: str, openrouter_model: str,
fal_token: str, replicate_token: str, fal_token: str, replicate_token: str, tavily_token: str,
db: BasicDatabase, platform: str): db: BasicDatabase, platform: str):
global agent global agent
agent = AiAgent(openrouter_token, openrouter_model, fal_token, replicate_token, db, platform) agent = AiAgent(openrouter_token, openrouter_model, fal_token, replicate_token, tavily_token, db, platform)
def _add_message_prefix(text: Optional[str], username: Optional[str] = None) -> str: def _add_message_prefix(text: Optional[str], username: Optional[str] = None) -> str:

View file

@ -46,5 +46,22 @@
"required": ["prompt", "negative_prompt"] "required": ["prompt", "negative_prompt"]
} }
} }
},
{
"type": "function",
"function": {
"name": "tavily_search",
"description": "Веб-поиск по теме запроса. Используй для поиска информации о событиях, фактах, новостях, определениях, статистике и других актуальных данных.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Запрос для поиска (на русском или английском языке)"
}
},
"required": ["query"]
}
}
} }
] ]

View file

@ -1,4 +1,6 @@
# Генерация изображений # Доступные инструменты
## Генерация изображений
Если пользователь просит "нарисовать" или "показать" что-то, сгенерируй изображение путем вызова одной из функций. Если пользователь просит "нарисовать" или "показать" что-то, сгенерируй изображение путем вызова одной из функций.
При вызове функции не нужно добавлять сообщение - оно будет отброшено. При вызове функции не нужно добавлять сообщение - оно будет отброшено.
Если пользователь просит изменить сгенерированное ранее изображение, составь новый запрос с учетом пожеланий пользователя и снова вызови функцию генерации. Если пользователь просит изменить сгенерированное ранее изображение, составь новый запрос с учетом пожеланий пользователя и снова вызови функцию генерации.
@ -7,10 +9,12 @@
- Никогда не генерируй ASCII-арты вместо вызова функции. - Никогда не генерируй ASCII-арты вместо вызова функции.
- Никогда не вставляй теги вроде `<image>`, `<img>` или любые плейсхолдеры — это сломает чат! - Никогда не вставляй теги вроде `<image>`, `<img>` или любые плейсхолдеры — это сломает чат!
При вызове функции выбери оптимальное соотношение сторон для сцены (задается отдельным параметром функции) на основе контекста беседы или сцены.
НИКОГДА НЕ добавляй в ответ параметры или код генерации - пользователю это не нужно! НИКОГДА НЕ добавляй в ответ параметры или код генерации - пользователю это не нужно!
Если сгенерировать изображение не удалось из-за ошибки, просто сообщи об этом пользователю. Если сгенерировать изображение не удалось из-за ошибки, просто сообщи об этом пользователю.
## Генерация обычных (не аниме) изображений ### Генерация обычных (не аниме) изображений
Для генерации используй функцию `generate_image` и составляй запрос на естественном языке по следующей формуле: Для генерации используй функцию `generate_image` и составляй запрос на естественном языке по следующей формуле:
1. Объекты сцены. 1. Объекты сцены.
2. Действие/поза. 2. Действие/поза.
@ -18,11 +22,14 @@
4. Освещение, ракурс, композиция. 4. Освещение, ракурс, композиция.
5. Стиль (digital art, cinematic, photorealistic и др). 5. Стиль (digital art, cinematic, photorealistic и др).
## Генерация изображений в стиле аниме ### Генерация изображений в стиле аниме
Для генерации используй функцию `generate_image_anime` и составляй запрос, следуя правилам: Для генерации используй функцию `generate_image_anime` и составляй запрос, следуя правилам:
1. Описывай сцену набором тегов Danbooru для SDXL, обязательно разделяй теги запятыми. 1. Описывай сцену набором тегов Danbooru для SDXL, обязательно разделяй теги запятыми.
2. Положительный запрос должен начинаться с `masterpiece, best quality, amazing quality, 4k, very aesthetic, high resolution, ultra-detailed, absurdres, newest, scenery`, а заканчиваться `depth of field, volumetric lighting`. 2. Положительный запрос должен начинаться с `masterpiece, best quality, amazing quality, 4k, very aesthetic, high resolution, ultra-detailed, absurdres, newest, scenery`, а заканчиваться `depth of field, volumetric lighting`.
3. Отрицательный запрос должен заканчиваться `modern, recent, old, oldest, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured, long body, lowres, bad anatomy, bad hands, missing fingers, extra digits, fewer digits, cropped, very displeasing, (worst quality, bad quality:1.2), bad anatomy, sketch, jpeg artifacts, signature, watermark, username, signature, simple background, conjoined, bad ai-generated`. 3. Отрицательный запрос должен заканчиваться `modern, recent, old, oldest, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured, long body, lowres, bad anatomy, bad hands, missing fingers, extra digits, fewer digits, cropped, very displeasing, (worst quality, bad quality:1.2), bad anatomy, sketch, jpeg artifacts, signature, watermark, username, signature, simple background, conjoined, bad ai-generated`.
4. Ты можешь добавлять тегам веса, например: `1girl, (long hair:1.2), pink hair`. 4. Ты можешь добавлять тегам веса, например: `1girl, (long hair:1.2), pink hair`.
Также выбери оптимальное соотношение сторон для сцены (задается отдельным параметром функции) на основе контекста беседы или сцены. ## Веб-поиск
Для поиска информации о событиях, фактах, новостях, определениях, статистике и других актуальных данных используй функцию `tavily_search`.
- Вызывай функцию поиска, когда нужна актуальная информация из интернета.
- После получения результатов дай пользователю краткую сводку найденной информации.

View file

@ -26,7 +26,7 @@ async def main() -> None:
create_ai_agent(config['openrouter_token'], config['openrouter_model'], create_ai_agent(config['openrouter_token'], config['openrouter_model'],
config['fal_token'], config['replicate_token'], config['fal_token'], config['replicate_token'],
database.DB, 'tg') config['tavily_token'], database.DB, 'tg')
bots: list[Bot] = [] bots: list[Bot] = []
for item in database.DB.get_bots(): for item in database.DB.get_bots():

View file

@ -26,7 +26,7 @@ if __name__ == '__main__':
create_ai_agent(config['openrouter_token'], config['openrouter_model'], create_ai_agent(config['openrouter_token'], config['openrouter_model'],
config['fal_token'], config['replicate_token'], config['fal_token'], config['replicate_token'],
database.DB, 'vk') config['tavily_token'], database.DB, 'vk')
bot = Bot(labeler=handlers.labeler) bot = Bot(labeler=handlers.labeler)