| @@ -2,8 +2,6 @@ import decimal | |||
| import json | |||
| from typing import Optional, Union | |||
| from gunicorn.config import User | |||
| from core.callback_handler.entity.agent_loop import AgentLoop | |||
| from core.callback_handler.entity.dataset_query import DatasetQueryObj | |||
| from core.callback_handler.entity.llm_message import LLMMessage | |||
| @@ -269,7 +267,7 @@ class ConversationMessageTask: | |||
| class PubHandler: | |||
| def __init__(self, user: Union[Account | User], task_id: str, | |||
| def __init__(self, user: Union[Account | EndUser], task_id: str, | |||
| message: Message, conversation: Conversation, | |||
| chain_pub: bool = False, agent_thought_pub: bool = False): | |||
| self._channel = PubHandler.generate_channel_name(user, task_id) | |||
| @@ -282,12 +280,12 @@ class PubHandler: | |||
| self._agent_thought_pub = agent_thought_pub | |||
| @classmethod | |||
| def generate_channel_name(cls, user: Union[Account | User], task_id: str): | |||
| def generate_channel_name(cls, user: Union[Account | EndUser], task_id: str): | |||
| user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id | |||
| return "generate_result:{}-{}".format(user_str, task_id) | |||
| @classmethod | |||
| def generate_stopped_cache_key(cls, user: Union[Account | User], task_id: str): | |||
| def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str): | |||
| user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id | |||
| return "generate_result_stopped:{}-{}".format(user_str, task_id) | |||
| @@ -366,7 +364,7 @@ class PubHandler: | |||
| redis_client.publish(self._channel, json.dumps(content)) | |||
| @classmethod | |||
| def pub_error(cls, user: Union[Account | User], task_id: str, e): | |||
| def pub_error(cls, user: Union[Account | EndUser], task_id: str, e): | |||
| content = { | |||
| 'error': type(e).__name__, | |||
| 'description': e.description if getattr(e, 'description', None) is not None else str(e) | |||
| @@ -379,7 +377,7 @@ class PubHandler: | |||
| return redis_client.get(self._stopped_cache_key) is not None | |||
| @classmethod | |||
| def stop(cls, user: Union[Account | User], task_id: str): | |||
| def stop(cls, user: Union[Account | EndUser], task_id: str): | |||
| stopped_cache_key = cls.generate_stopped_cache_key(user, task_id) | |||
| redis_client.setex(stopped_cache_key, 600, 1) | |||