diff --git a/tg/handlers/default.py b/tg/handlers/default.py index 74b9e3f..9735af0 100644 --- a/tg/handlers/default.py +++ b/tg/handlers/default.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Optional from aiogram import Router, F, Bot @@ -92,8 +93,10 @@ async def any_message_handler(message: Message): # noinspection PyUnresolvedReferences agent = AiAgent(message.bot.config['openrouter_token']) - await message.bot.send_chat_action(chat_id, 'typing') - await message.reply(await agent.get_reply(chat_id, chat_prompt, ai_message, ai_fwd_messages)) + await message.reply( + await utils.run_with_progress(partial(agent.get_reply, chat_id, chat_prompt, ai_message, ai_fwd_messages), + partial(message.bot.send_chat_action, chat_id, 'typing'), + interval=4)) def clear_ai_chat_context(chat_id: int): @@ -114,5 +117,6 @@ async def get_ai_reply(bot: Bot, chat_id, user: User, message: str, fwd_user: Us # noinspection PyUnresolvedReferences agent = AiAgent(bot.config['openrouter_token']) - await bot.send_chat_action(chat_id, 'typing') - return await agent.get_reply(chat_id, chat_prompt, ai_message, ai_fwd_messages) + return await utils.run_with_progress(partial(agent.get_reply, chat_id, chat_prompt, ai_message, ai_fwd_messages), + partial(bot.send_chat_action(chat_id, 'typing')), + interval=4) diff --git a/utils.py b/utils.py index 5178918..965386e 100644 --- a/utils.py +++ b/utils.py @@ -1,5 +1,6 @@ +import asyncio from calendar import timegm -from typing import Optional +from typing import Awaitable, Callable, Coroutine, Optional from pymorphy3 import MorphAnalyzer from time import gmtime @@ -21,3 +22,26 @@ def full_name(first_name: str, last_name: Optional[str]) -> str: if last_name is not None: return f"{first_name} {last_name}" return first_name + + +async def run_with_progress(main_func: Callable[[], Coroutine], progress_func: Callable[[], Awaitable], interval: int): + completion_event = asyncio.Event() + + async def progress(): + while not completion_event.is_set(): + await progress_func() + wait_event_task = asyncio.create_task(completion_event.wait()) + wait_timer_task = asyncio.create_task(asyncio.sleep(interval)) + await asyncio.wait([wait_event_task, wait_timer_task], + return_when=asyncio.FIRST_COMPLETED) + if completion_event.is_set(): + wait_timer_task.cancel() + + progress_task = asyncio.create_task(progress()) + main_task = asyncio.create_task(main_func()) + + result = await main_task + completion_event.set() + await progress_task + + return result diff --git a/vk/handlers/default.py b/vk/handlers/default.py index f60ae78..30c9ef6 100644 --- a/vk/handlers/default.py +++ b/vk/handlers/default.py @@ -1,4 +1,5 @@ import re +from functools import partial from typing import Optional, Tuple, List from vkbottle import API @@ -87,8 +88,10 @@ async def any_message_handler(message: Message): # noinspection PyUnresolvedReferences agent = AiAgent(message.ctx_api.config['openrouter_token']) - await message.ctx_api.messages.set_activity(peer_id=chat_id, type='typing') - await message.reply(await agent.get_reply(chat_id, chat_prompt, ai_message, ai_fwd_messages)) + await message.reply( + await utils.run_with_progress(partial(agent.get_reply, chat_id, chat_prompt, ai_message, ai_fwd_messages), + partial(message.ctx_api.messages.set_activity, peer_id=chat_id, type='typing'), + interval=4)) def clear_ai_chat_context(chat_id: int): @@ -112,5 +115,6 @@ async def get_ai_reply(api: API, chat_id, message: Tuple[int, str], fwd_messages # noinspection PyUnresolvedReferences agent = AiAgent(api.config['openrouter_token']) - await api.messages.set_activity(peer_id=chat_id, type='typing') - return await agent.get_reply(chat_id, chat_prompt, ai_message, ai_fwd_messages) + return await utils.run_with_progress(partial(agent.get_reply, chat_id, chat_prompt, ai_message, ai_fwd_messages), + partial(api.messages.set_activity, peer_id=chat_id, type='typing'), + interval=4)