| import json | import json | ||||
| from typing import Optional, Union | from typing import Optional, Union | ||||
| from gunicorn.config import User | |||||
| from core.callback_handler.entity.agent_loop import AgentLoop | from core.callback_handler.entity.agent_loop import AgentLoop | ||||
| from core.callback_handler.entity.dataset_query import DatasetQueryObj | from core.callback_handler.entity.dataset_query import DatasetQueryObj | ||||
| from core.callback_handler.entity.llm_message import LLMMessage | from core.callback_handler.entity.llm_message import LLMMessage | ||||
| class PubHandler: | class PubHandler: | ||||
| def __init__(self, user: Union[Account | User], task_id: str, | |||||
| def __init__(self, user: Union[Account | EndUser], task_id: str, | |||||
| message: Message, conversation: Conversation, | message: Message, conversation: Conversation, | ||||
| chain_pub: bool = False, agent_thought_pub: bool = False): | chain_pub: bool = False, agent_thought_pub: bool = False): | ||||
| self._channel = PubHandler.generate_channel_name(user, task_id) | self._channel = PubHandler.generate_channel_name(user, task_id) | ||||
| self._agent_thought_pub = agent_thought_pub | self._agent_thought_pub = agent_thought_pub | ||||
| @classmethod | @classmethod | ||||
| def generate_channel_name(cls, user: Union[Account | User], task_id: str): | |||||
| def generate_channel_name(cls, user: Union[Account | EndUser], task_id: str): | |||||
| user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id | user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id | ||||
| return "generate_result:{}-{}".format(user_str, task_id) | return "generate_result:{}-{}".format(user_str, task_id) | ||||
| @classmethod | @classmethod | ||||
| def generate_stopped_cache_key(cls, user: Union[Account | User], task_id: str): | |||||
| def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str): | |||||
| user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id | user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id | ||||
| return "generate_result_stopped:{}-{}".format(user_str, task_id) | return "generate_result_stopped:{}-{}".format(user_str, task_id) | ||||
| redis_client.publish(self._channel, json.dumps(content)) | redis_client.publish(self._channel, json.dumps(content)) | ||||
| @classmethod | @classmethod | ||||
| def pub_error(cls, user: Union[Account | User], task_id: str, e): | |||||
| def pub_error(cls, user: Union[Account | EndUser], task_id: str, e): | |||||
| content = { | content = { | ||||
| 'error': type(e).__name__, | 'error': type(e).__name__, | ||||
| 'description': e.description if getattr(e, 'description', None) is not None else str(e) | 'description': e.description if getattr(e, 'description', None) is not None else str(e) | ||||
| return redis_client.get(self._stopped_cache_key) is not None | return redis_client.get(self._stopped_cache_key) is not None | ||||
| @classmethod | @classmethod | ||||
| def stop(cls, user: Union[Account | User], task_id: str): | |||||
| def stop(cls, user: Union[Account | EndUser], task_id: str): | |||||
| stopped_cache_key = cls.generate_stopped_cache_key(user, task_id) | stopped_cache_key = cls.generate_stopped_cache_key(user, task_id) | ||||
| redis_client.setex(stopped_cache_key, 600, 1) | redis_client.setex(stopped_cache_key, 600, 1) | ||||