| @@ -23,6 +23,7 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, | |||
| from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.prompt.utils.get_thread_messages_length import get_thread_messages_length | |||
| from extensions.ext_database import db | |||
| from factories import file_factory | |||
| from models.account import Account | |||
| @@ -33,6 +34,8 @@ logger = logging.getLogger(__name__) | |||
| class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| _dialogue_count: int | |||
| def generate( | |||
| self, | |||
| app_model: App, | |||
| @@ -211,6 +214,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| db.session.commit() | |||
| db.session.refresh(conversation) | |||
| # get conversation dialogue count | |||
| self._dialogue_count = get_thread_messages_length(conversation.id) | |||
| # init queue manager | |||
| queue_manager = MessageBasedAppQueueManager( | |||
| task_id=application_generate_entity.task_id, | |||
| @@ -281,6 +287,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| queue_manager=queue_manager, | |||
| conversation=conversation, | |||
| message=message, | |||
| dialogue_count=self._dialogue_count, | |||
| ) | |||
| runner.run() | |||
| @@ -334,6 +341,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| message=message, | |||
| user=user, | |||
| stream=stream, | |||
| dialogue_count=self._dialogue_count, | |||
| ) | |||
| try: | |||
| @@ -39,12 +39,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| queue_manager: AppQueueManager, | |||
| conversation: Conversation, | |||
| message: Message, | |||
| dialogue_count: int, | |||
| ) -> None: | |||
| super().__init__(queue_manager) | |||
| self.application_generate_entity = application_generate_entity | |||
| self.conversation = conversation | |||
| self.message = message | |||
| self._dialogue_count = dialogue_count | |||
| def run(self) -> None: | |||
| app_config = self.application_generate_entity.app_config | |||
| @@ -122,19 +124,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| session.commit() | |||
| # Increment dialogue count. | |||
| self.conversation.dialogue_count += 1 | |||
| conversation_dialogue_count = self.conversation.dialogue_count | |||
| db.session.commit() | |||
| # Create a variable pool. | |||
| system_inputs = { | |||
| SystemVariableKey.QUERY: query, | |||
| SystemVariableKey.FILES: files, | |||
| SystemVariableKey.CONVERSATION_ID: self.conversation.id, | |||
| SystemVariableKey.USER_ID: user_id, | |||
| SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count, | |||
| SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count, | |||
| SystemVariableKey.APP_ID: app_config.app_id, | |||
| SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, | |||
| SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id, | |||
| @@ -88,6 +88,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| message: Message, | |||
| user: Union[Account, EndUser], | |||
| stream: bool, | |||
| dialogue_count: int, | |||
| ) -> None: | |||
| """ | |||
| Initialize AdvancedChatAppGenerateTaskPipeline. | |||
| @@ -98,6 +99,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| :param message: message | |||
| :param user: user | |||
| :param stream: stream | |||
| :param dialogue_count: dialogue count | |||
| """ | |||
| super().__init__(application_generate_entity, queue_manager, user, stream) | |||
| @@ -114,7 +116,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| SystemVariableKey.FILES: application_generate_entity.files, | |||
| SystemVariableKey.CONVERSATION_ID: conversation.id, | |||
| SystemVariableKey.USER_ID: user_id, | |||
| SystemVariableKey.DIALOGUE_COUNT: conversation.dialogue_count, | |||
| SystemVariableKey.DIALOGUE_COUNT: dialogue_count, | |||
| SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, | |||
| SystemVariableKey.WORKFLOW_ID: workflow.id, | |||
| SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, | |||
| @@ -0,0 +1,32 @@ | |||
| from core.prompt.utils.extract_thread_messages import extract_thread_messages | |||
| from extensions.ext_database import db | |||
| from models.model import Message | |||
| def get_thread_messages_length(conversation_id: str) -> int: | |||
| """ | |||
| Get the number of thread messages based on the parent message id. | |||
| """ | |||
| # Fetch all messages related to the conversation | |||
| query = ( | |||
| db.session.query( | |||
| Message.id, | |||
| Message.parent_message_id, | |||
| Message.answer, | |||
| ) | |||
| .filter( | |||
| Message.conversation_id == conversation_id, | |||
| ) | |||
| .order_by(Message.created_at.desc()) | |||
| ) | |||
| messages = query.all() | |||
| # Extract thread messages | |||
| thread_messages = extract_thread_messages(messages) | |||
| # Exclude the newly created message with an empty answer | |||
| if thread_messages and not thread_messages[0].answer: | |||
| thread_messages.pop(0) | |||
| return len(thread_messages) | |||