| from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse | from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse | ||||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | ||||
| from core.ops.ops_trace_manager import TraceQueueManager | from core.ops.ops_trace_manager import TraceQueueManager | ||||
| from core.prompt.utils.get_thread_messages_length import get_thread_messages_length | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from factories import file_factory | from factories import file_factory | ||||
| from models.account import Account | from models.account import Account | ||||
| class AdvancedChatAppGenerator(MessageBasedAppGenerator): | class AdvancedChatAppGenerator(MessageBasedAppGenerator): | ||||
| _dialogue_count: int | |||||
| def generate( | def generate( | ||||
| self, | self, | ||||
| app_model: App, | app_model: App, | ||||
| db.session.commit() | db.session.commit() | ||||
| db.session.refresh(conversation) | db.session.refresh(conversation) | ||||
| # get conversation dialogue count | |||||
| self._dialogue_count = get_thread_messages_length(conversation.id) | |||||
| # init queue manager | # init queue manager | ||||
| queue_manager = MessageBasedAppQueueManager( | queue_manager = MessageBasedAppQueueManager( | ||||
| task_id=application_generate_entity.task_id, | task_id=application_generate_entity.task_id, | ||||
| queue_manager=queue_manager, | queue_manager=queue_manager, | ||||
| conversation=conversation, | conversation=conversation, | ||||
| message=message, | message=message, | ||||
| dialogue_count=self._dialogue_count, | |||||
| ) | ) | ||||
| runner.run() | runner.run() | ||||
| message=message, | message=message, | ||||
| user=user, | user=user, | ||||
| stream=stream, | stream=stream, | ||||
| dialogue_count=self._dialogue_count, | |||||
| ) | ) | ||||
| try: | try: |
| queue_manager: AppQueueManager, | queue_manager: AppQueueManager, | ||||
| conversation: Conversation, | conversation: Conversation, | ||||
| message: Message, | message: Message, | ||||
| dialogue_count: int, | |||||
| ) -> None: | ) -> None: | ||||
| super().__init__(queue_manager) | super().__init__(queue_manager) | ||||
| self.application_generate_entity = application_generate_entity | self.application_generate_entity = application_generate_entity | ||||
| self.conversation = conversation | self.conversation = conversation | ||||
| self.message = message | self.message = message | ||||
| self._dialogue_count = dialogue_count | |||||
| def run(self) -> None: | def run(self) -> None: | ||||
| app_config = self.application_generate_entity.app_config | app_config = self.application_generate_entity.app_config | ||||
| session.commit() | session.commit() | ||||
| # Increment dialogue count. | |||||
| self.conversation.dialogue_count += 1 | |||||
| conversation_dialogue_count = self.conversation.dialogue_count | |||||
| db.session.commit() | |||||
| # Create a variable pool. | # Create a variable pool. | ||||
| system_inputs = { | system_inputs = { | ||||
| SystemVariableKey.QUERY: query, | SystemVariableKey.QUERY: query, | ||||
| SystemVariableKey.FILES: files, | SystemVariableKey.FILES: files, | ||||
| SystemVariableKey.CONVERSATION_ID: self.conversation.id, | SystemVariableKey.CONVERSATION_ID: self.conversation.id, | ||||
| SystemVariableKey.USER_ID: user_id, | SystemVariableKey.USER_ID: user_id, | ||||
| SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count, | |||||
| SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count, | |||||
| SystemVariableKey.APP_ID: app_config.app_id, | SystemVariableKey.APP_ID: app_config.app_id, | ||||
| SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, | SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, | ||||
| SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id, | SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id, |
| message: Message, | message: Message, | ||||
| user: Union[Account, EndUser], | user: Union[Account, EndUser], | ||||
| stream: bool, | stream: bool, | ||||
| dialogue_count: int, | |||||
| ) -> None: | ) -> None: | ||||
| """ | """ | ||||
| Initialize AdvancedChatAppGenerateTaskPipeline. | Initialize AdvancedChatAppGenerateTaskPipeline. | ||||
| :param message: message | :param message: message | ||||
| :param user: user | :param user: user | ||||
| :param stream: stream | :param stream: stream | ||||
| :param dialogue_count: dialogue count | |||||
| """ | """ | ||||
| super().__init__(application_generate_entity, queue_manager, user, stream) | super().__init__(application_generate_entity, queue_manager, user, stream) | ||||
| SystemVariableKey.FILES: application_generate_entity.files, | SystemVariableKey.FILES: application_generate_entity.files, | ||||
| SystemVariableKey.CONVERSATION_ID: conversation.id, | SystemVariableKey.CONVERSATION_ID: conversation.id, | ||||
| SystemVariableKey.USER_ID: user_id, | SystemVariableKey.USER_ID: user_id, | ||||
| SystemVariableKey.DIALOGUE_COUNT: conversation.dialogue_count, | |||||
| SystemVariableKey.DIALOGUE_COUNT: dialogue_count, | |||||
| SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, | SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, | ||||
| SystemVariableKey.WORKFLOW_ID: workflow.id, | SystemVariableKey.WORKFLOW_ID: workflow.id, | ||||
| SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, | SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, |
| from core.prompt.utils.extract_thread_messages import extract_thread_messages | |||||
| from extensions.ext_database import db | |||||
| from models.model import Message | |||||
| def get_thread_messages_length(conversation_id: str) -> int: | |||||
| """ | |||||
| Get the number of thread messages based on the parent message id. | |||||
| """ | |||||
| # Fetch all messages related to the conversation | |||||
| query = ( | |||||
| db.session.query( | |||||
| Message.id, | |||||
| Message.parent_message_id, | |||||
| Message.answer, | |||||
| ) | |||||
| .filter( | |||||
| Message.conversation_id == conversation_id, | |||||
| ) | |||||
| .order_by(Message.created_at.desc()) | |||||
| ) | |||||
| messages = query.all() | |||||
| # Extract thread messages | |||||
| thread_messages = extract_thread_messages(messages) | |||||
| # Exclude the newly created message with an empty answer | |||||
| if thread_messages and not thread_messages[0].answer: | |||||
| thread_messages.pop(0) | |||||
| return len(thread_messages) |