Compare commits
2 commits
97dbfd5dbb
...
54ef3bbddd
| Author | SHA1 | Date | |
|---|---|---|---|
| 54ef3bbddd | |||
| beda26cb55 |
5 changed files with 66 additions and 15 deletions
45
ai_agent.py
45
ai_agent.py
|
|
@ -11,11 +11,12 @@ from typing import List, Tuple, Optional, Union, Dict, Awaitable
|
||||||
from openrouter import OpenRouter, RetryConfig
|
from openrouter import OpenRouter, RetryConfig
|
||||||
from openrouter.components import AssistantMessage, AssistantMessageTypedDict, ChatMessageContentItemTypedDict, \
|
from openrouter.components import AssistantMessage, AssistantMessageTypedDict, ChatMessageContentItemTypedDict, \
|
||||||
ChatMessageToolCall, MessageTypedDict, SystemMessageTypedDict
|
ChatMessageToolCall, MessageTypedDict, SystemMessageTypedDict
|
||||||
from openrouter.errors import ResponseValidationError, ChatError
|
from openrouter.errors import ResponseValidationError, OpenRouterError
|
||||||
from openrouter.utils import BackoffStrategy
|
from openrouter.utils import BackoffStrategy
|
||||||
|
|
||||||
from fal_client import AsyncClient as FalClient
|
from fal_client import AsyncClient as FalClient
|
||||||
from replicate import Client as ReplicateClient
|
from replicate import Client as ReplicateClient
|
||||||
|
from tavily import TavilyClient
|
||||||
|
|
||||||
from utils import download_file
|
from utils import download_file
|
||||||
from database import BasicDatabase
|
from database import BasicDatabase
|
||||||
|
|
@ -43,7 +44,7 @@ class Message:
|
||||||
class AiAgent:
|
class AiAgent:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
openrouter_token: str, openrouter_model: str,
|
openrouter_token: str, openrouter_model: str,
|
||||||
fal_token: str, replicate_token: str,
|
fal_token: str, replicate_token: str, tavily_token: str,
|
||||||
db: BasicDatabase,
|
db: BasicDatabase,
|
||||||
platform: str):
|
platform: str):
|
||||||
retry_config = RetryConfig(strategy="backoff",
|
retry_config = RetryConfig(strategy="backoff",
|
||||||
|
|
@ -61,6 +62,7 @@ class AiAgent:
|
||||||
retry_config=retry_config)
|
retry_config=retry_config)
|
||||||
self.client_fal = FalClient(key=fal_token)
|
self.client_fal = FalClient(key=fal_token)
|
||||||
self.replicate_client = ReplicateClient(api_token=replicate_token)
|
self.replicate_client = ReplicateClient(api_token=replicate_token)
|
||||||
|
self.tavily_client = TavilyClient(api_key=tavily_token)
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
class _ToolsArtifacts:
|
class _ToolsArtifacts:
|
||||||
|
|
@ -168,7 +170,7 @@ class AiAgent:
|
||||||
def _construct_system_prompt(self, is_group_chat: bool, bot_id: int, chat_id: int) -> SystemMessageTypedDict:
|
def _construct_system_prompt(self, is_group_chat: bool, bot_id: int, chat_id: int) -> SystemMessageTypedDict:
|
||||||
prompt = self.system_prompt_group_chat if is_group_chat else self.system_prompt_private_chat
|
prompt = self.system_prompt_group_chat if is_group_chat else self.system_prompt_private_chat
|
||||||
prompt = prompt.replace('{platform}', 'Telegram' if self.platform == 'tg' else 'VK')
|
prompt = prompt.replace('{platform}', 'Telegram' if self.platform == 'tg' else 'VK')
|
||||||
prompt += '\n' + self.system_prompt_image_generation
|
prompt += '\n' + self.system_prompt_tools
|
||||||
|
|
||||||
bot = self.db.get_bot(bot_id)
|
bot = self.db.get_bot(bot_id)
|
||||||
if bot['ai_prompt'] is not None:
|
if bot['ai_prompt'] is not None:
|
||||||
|
|
@ -202,7 +204,8 @@ class AiAgent:
|
||||||
Callable[[int, int, Dict, AiAgent._ToolsArtifacts],
|
Callable[[int, int, Dict, AiAgent._ToolsArtifacts],
|
||||||
Awaitable[List[ChatMessageContentItemTypedDict]]]] = {
|
Awaitable[List[ChatMessageContentItemTypedDict]]]] = {
|
||||||
"generate_image": self._process_tool_generate_image,
|
"generate_image": self._process_tool_generate_image,
|
||||||
"generate_image_anime": self._process_tool_generate_image_anime
|
"generate_image_anime": self._process_tool_generate_image_anime,
|
||||||
|
"tavily_search": self._process_tool_tavily_search
|
||||||
}
|
}
|
||||||
|
|
||||||
for tool_call in tool_calls:
|
for tool_call in tool_calls:
|
||||||
|
|
@ -293,6 +296,30 @@ class AiAgent:
|
||||||
print(f"Ошибка генерации изображения: {e}")
|
print(f"Ошибка генерации изображения: {e}")
|
||||||
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}")
|
return _serialize_message_content(text=f"Не удалось сгенерировать изображение: {e}")
|
||||||
|
|
||||||
|
async def _process_tool_tavily_search(self, _bot_id: int, _chat_id: int, args: dict,
|
||||||
|
_artifacts: _ToolsArtifacts) -> List[ChatMessageContentItemTypedDict]:
|
||||||
|
query = args.get("query", "")
|
||||||
|
print(f"Веб-поиск: {query}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = self.tavily_client.search(query=query, max_results=5)
|
||||||
|
|
||||||
|
if not results or "results" not in results:
|
||||||
|
return _serialize_message_content(text="Не удалось получить результаты поиска.")
|
||||||
|
|
||||||
|
answer_parts = []
|
||||||
|
for i, result in enumerate(results["results"], 1):
|
||||||
|
title = result.get("title", "Без названия")
|
||||||
|
url = result.get("url", "")
|
||||||
|
content = result.get("content", "")
|
||||||
|
answer_parts.append(f"{i}. {title}\n {url}\n {content}\n")
|
||||||
|
|
||||||
|
answer = "\n".join(answer_parts)
|
||||||
|
return _serialize_message_content(text=f"По запросу \"{query}\" найдено:\n\n{answer}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Ошибка веб-поиска: {e}")
|
||||||
|
return _serialize_message_content(text=f"Не удалось выполнить веб-поиск: {e}")
|
||||||
|
|
||||||
async def _async_chat_completion_request(self, **kwargs):
|
async def _async_chat_completion_request(self, **kwargs):
|
||||||
try:
|
try:
|
||||||
return await self.client_openrouter.chat.send_async(**kwargs)
|
return await self.client_openrouter.chat.send_async(**kwargs)
|
||||||
|
|
@ -308,7 +335,7 @@ class AiAgent:
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
raise e
|
raise e
|
||||||
except ChatError as e:
|
except OpenRouterError as e:
|
||||||
if e.message == "Provider returned error":
|
if e.message == "Provider returned error":
|
||||||
body = json.loads(e.body)
|
body = json.loads(e.body)
|
||||||
try:
|
try:
|
||||||
|
|
@ -331,8 +358,8 @@ class AiAgent:
|
||||||
self.system_prompt_group_chat = f.read()
|
self.system_prompt_group_chat = f.read()
|
||||||
with open("prompts/private_chat.md", "r") as f:
|
with open("prompts/private_chat.md", "r") as f:
|
||||||
self.system_prompt_private_chat = f.read()
|
self.system_prompt_private_chat = f.read()
|
||||||
with open("prompts/image_generation.md", "r") as f:
|
with open("prompts/tools.md", "r") as f:
|
||||||
self.system_prompt_image_generation = f.read()
|
self.system_prompt_tools = f.read()
|
||||||
with open("prompts/tools.json", "r") as f:
|
with open("prompts/tools.json", "r") as f:
|
||||||
self.tools_description = json.loads(f.read())
|
self.tools_description = json.loads(f.read())
|
||||||
|
|
||||||
|
|
@ -341,10 +368,10 @@ agent: AiAgent
|
||||||
|
|
||||||
|
|
||||||
def create_ai_agent(openrouter_token: str, openrouter_model: str,
|
def create_ai_agent(openrouter_token: str, openrouter_model: str,
|
||||||
fal_token: str, replicate_token: str,
|
fal_token: str, replicate_token: str, tavily_token: str,
|
||||||
db: BasicDatabase, platform: str):
|
db: BasicDatabase, platform: str):
|
||||||
global agent
|
global agent
|
||||||
agent = AiAgent(openrouter_token, openrouter_model, fal_token, replicate_token, db, platform)
|
agent = AiAgent(openrouter_token, openrouter_model, fal_token, replicate_token, tavily_token, db, platform)
|
||||||
|
|
||||||
|
|
||||||
def _add_message_prefix(text: Optional[str], username: Optional[str] = None) -> str:
|
def _add_message_prefix(text: Optional[str], username: Optional[str] = None) -> str:
|
||||||
|
|
|
||||||
|
|
@ -46,5 +46,22 @@
|
||||||
"required": ["prompt", "negative_prompt"]
|
"required": ["prompt", "negative_prompt"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "tavily_search",
|
||||||
|
"description": "Веб-поиск по теме запроса. Используй для поиска информации о событиях, фактах, новостях, определениях, статистике и других актуальных данных.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Запрос для поиска (на русском или английском языке)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
# Генерация изображений
|
# Доступные инструменты
|
||||||
|
|
||||||
|
## Генерация изображений
|
||||||
Если пользователь просит "нарисовать" или "показать" что-то, сгенерируй изображение путем вызова одной из функций.
|
Если пользователь просит "нарисовать" или "показать" что-то, сгенерируй изображение путем вызова одной из функций.
|
||||||
При вызове функции не нужно добавлять сообщение - оно будет отброшено.
|
При вызове функции не нужно добавлять сообщение - оно будет отброшено.
|
||||||
Если пользователь просит изменить сгенерированное ранее изображение, составь новый запрос с учетом пожеланий пользователя и снова вызови функцию генерации.
|
Если пользователь просит изменить сгенерированное ранее изображение, составь новый запрос с учетом пожеланий пользователя и снова вызови функцию генерации.
|
||||||
|
|
@ -7,10 +9,12 @@
|
||||||
- Никогда не генерируй ASCII-арты вместо вызова функции.
|
- Никогда не генерируй ASCII-арты вместо вызова функции.
|
||||||
- Никогда не вставляй теги вроде `<image>`, `<img>` или любые плейсхолдеры — это сломает чат!
|
- Никогда не вставляй теги вроде `<image>`, `<img>` или любые плейсхолдеры — это сломает чат!
|
||||||
|
|
||||||
|
При вызове функции выбери оптимальное соотношение сторон для сцены (задается отдельным параметром функции) на основе контекста беседы или сцены.
|
||||||
|
|
||||||
НИКОГДА НЕ добавляй в ответ параметры или код генерации - пользователю это не нужно!
|
НИКОГДА НЕ добавляй в ответ параметры или код генерации - пользователю это не нужно!
|
||||||
Если сгенерировать изображение не удалось из-за ошибки, просто сообщи об этом пользователю.
|
Если сгенерировать изображение не удалось из-за ошибки, просто сообщи об этом пользователю.
|
||||||
|
|
||||||
## Генерация обычных (не аниме) изображений
|
### Генерация обычных (не аниме) изображений
|
||||||
Для генерации используй функцию `generate_image` и составляй запрос на естественном языке по следующей формуле:
|
Для генерации используй функцию `generate_image` и составляй запрос на естественном языке по следующей формуле:
|
||||||
1. Объекты сцены.
|
1. Объекты сцены.
|
||||||
2. Действие/поза.
|
2. Действие/поза.
|
||||||
|
|
@ -18,11 +22,14 @@
|
||||||
4. Освещение, ракурс, композиция.
|
4. Освещение, ракурс, композиция.
|
||||||
5. Стиль (digital art, cinematic, photorealistic и др).
|
5. Стиль (digital art, cinematic, photorealistic и др).
|
||||||
|
|
||||||
## Генерация изображений в стиле аниме
|
### Генерация изображений в стиле аниме
|
||||||
Для генерации используй функцию `generate_image_anime` и составляй запрос, следуя правилам:
|
Для генерации используй функцию `generate_image_anime` и составляй запрос, следуя правилам:
|
||||||
1. Описывай сцену набором тегов Danbooru для SDXL, обязательно разделяй теги запятыми.
|
1. Описывай сцену набором тегов Danbooru для SDXL, обязательно разделяй теги запятыми.
|
||||||
2. Положительный запрос должен начинаться с `masterpiece, best quality, amazing quality, 4k, very aesthetic, high resolution, ultra-detailed, absurdres, newest, scenery`, а заканчиваться `depth of field, volumetric lighting`.
|
2. Положительный запрос должен начинаться с `masterpiece, best quality, amazing quality, 4k, very aesthetic, high resolution, ultra-detailed, absurdres, newest, scenery`, а заканчиваться `depth of field, volumetric lighting`.
|
||||||
3. Отрицательный запрос должен заканчиваться `modern, recent, old, oldest, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured, long body, lowres, bad anatomy, bad hands, missing fingers, extra digits, fewer digits, cropped, very displeasing, (worst quality, bad quality:1.2), bad anatomy, sketch, jpeg artifacts, signature, watermark, username, signature, simple background, conjoined, bad ai-generated`.
|
3. Отрицательный запрос должен заканчиваться `modern, recent, old, oldest, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured, long body, lowres, bad anatomy, bad hands, missing fingers, extra digits, fewer digits, cropped, very displeasing, (worst quality, bad quality:1.2), bad anatomy, sketch, jpeg artifacts, signature, watermark, username, signature, simple background, conjoined, bad ai-generated`.
|
||||||
4. Ты можешь добавлять тегам веса, например: `1girl, (long hair:1.2), pink hair`.
|
4. Ты можешь добавлять тегам веса, например: `1girl, (long hair:1.2), pink hair`.
|
||||||
|
|
||||||
Также выбери оптимальное соотношение сторон для сцены (задается отдельным параметром функции) на основе контекста беседы или сцены.
|
## Веб-поиск
|
||||||
|
Для поиска информации о событиях, фактах, новостях, определениях, статистике и других актуальных данных используй функцию `tavily_search`.
|
||||||
|
- Вызывай функцию поиска, когда нужна актуальная информация из интернета.
|
||||||
|
- После получения результатов дай пользователю краткую сводку найденной информации.
|
||||||
|
|
@ -26,7 +26,7 @@ async def main() -> None:
|
||||||
|
|
||||||
create_ai_agent(config['openrouter_token'], config['openrouter_model'],
|
create_ai_agent(config['openrouter_token'], config['openrouter_model'],
|
||||||
config['fal_token'], config['replicate_token'],
|
config['fal_token'], config['replicate_token'],
|
||||||
database.DB, 'tg')
|
config['tavily_token'], database.DB, 'tg')
|
||||||
|
|
||||||
bots: list[Bot] = []
|
bots: list[Bot] = []
|
||||||
for item in database.DB.get_bots():
|
for item in database.DB.get_bots():
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
create_ai_agent(config['openrouter_token'], config['openrouter_model'],
|
create_ai_agent(config['openrouter_token'], config['openrouter_model'],
|
||||||
config['fal_token'], config['replicate_token'],
|
config['fal_token'], config['replicate_token'],
|
||||||
database.DB, 'vk')
|
config['tavily_token'], database.DB, 'vk')
|
||||||
|
|
||||||
bot = Bot(labeler=handlers.labeler)
|
bot = Bot(labeler=handlers.labeler)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue