| @@ -98,6 +98,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| ) | |||
| self._stream_generate_routes = self._get_stream_generate_routes() | |||
| self._conversation_name_generate_thread = None | |||
| def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: | |||
| """ | |||
| @@ -108,6 +109,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| db.session.refresh(self._user) | |||
| db.session.close() | |||
| # start generate conversation name thread | |||
| self._conversation_name_generate_thread = self._generate_conversation_name( | |||
| self._conversation, | |||
| self._application_generate_entity.query | |||
| ) | |||
| generator = self._process_stream_response() | |||
| if self._stream: | |||
| return self._to_stream_response(generator) | |||
| @@ -278,6 +285,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| else: | |||
| continue | |||
| if self._conversation_name_generate_thread: | |||
| self._conversation_name_generate_thread.join() | |||
| def _save_message(self) -> None: | |||
| """ | |||
| Save message. | |||
| @@ -97,6 +97,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| ) | |||
| ) | |||
| self._conversation_name_generate_thread = None | |||
| def process(self) -> Union[ | |||
| ChatbotAppBlockingResponse, | |||
| CompletionAppBlockingResponse, | |||
| @@ -110,6 +112,13 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| db.session.refresh(self._message) | |||
| db.session.close() | |||
| if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: | |||
| # start generate conversation name thread | |||
| self._conversation_name_generate_thread = self._generate_conversation_name( | |||
| self._conversation, | |||
| self._application_generate_entity.query | |||
| ) | |||
| generator = self._process_stream_response() | |||
| if self._stream: | |||
| return self._to_stream_response(generator) | |||
| @@ -256,6 +265,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| else: | |||
| continue | |||
| if self._conversation_name_generate_thread: | |||
| self._conversation_name_generate_thread.join() | |||
| def _save_message(self) -> None: | |||
| """ | |||
| Save message. | |||
| @@ -1,5 +1,8 @@ | |||
| from threading import Thread | |||
| from typing import Optional, Union | |||
| from flask import Flask, current_app | |||
| from core.app.entities.app_invoke_entities import ( | |||
| AdvancedChatAppGenerateEntity, | |||
| AgentChatAppGenerateEntity, | |||
| @@ -19,9 +22,10 @@ from core.app.entities.task_entities import ( | |||
| MessageReplaceStreamResponse, | |||
| MessageStreamResponse, | |||
| ) | |||
| from core.llm_generator.llm_generator import LLMGenerator | |||
| from core.tools.tool_file_manager import ToolFileManager | |||
| from extensions.ext_database import db | |||
| from models.model import MessageAnnotation, MessageFile | |||
| from models.model import AppMode, Conversation, MessageAnnotation, MessageFile | |||
| from services.annotation_service import AppAnnotationService | |||
| @@ -34,6 +38,59 @@ class MessageCycleManage: | |||
| ] | |||
| _task_state: Union[EasyUITaskState, AdvancedChatTaskState] | |||
| def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]: | |||
| """ | |||
| Generate conversation name. | |||
| :param conversation: conversation | |||
| :param query: query | |||
| :return: thread | |||
| """ | |||
| is_first_message = self._application_generate_entity.conversation_id is None | |||
| extras = self._application_generate_entity.extras | |||
| auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True) | |||
| if auto_generate_conversation_name and is_first_message: | |||
| # start generate thread | |||
| thread = Thread(target=self._generate_conversation_name_worker, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'conversation_id': conversation.id, | |||
| 'query': query | |||
| }) | |||
| thread.start() | |||
| return thread | |||
| return None | |||
| def _generate_conversation_name_worker(self, | |||
| flask_app: Flask, | |||
| conversation_id: str, | |||
| query: str): | |||
| with flask_app.app_context(): | |||
| # get conversation and message | |||
| conversation = ( | |||
| db.session.query(Conversation) | |||
| .filter(Conversation.id == conversation_id) | |||
| .first() | |||
| ) | |||
| if conversation.mode != AppMode.COMPLETION.value: | |||
| app_model = conversation.app | |||
| if not app_model: | |||
| return | |||
| # generate conversation name | |||
| try: | |||
| name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query) | |||
| conversation.name = name | |||
| except: | |||
| pass | |||
| db.session.merge(conversation) | |||
| db.session.commit() | |||
| db.session.close() | |||
| def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]: | |||
| """ | |||
| Handle annotation reply. | |||
| @@ -5,7 +5,6 @@ from .create_installed_app_when_app_created import handle | |||
| from .create_site_record_when_app_created import handle | |||
| from .deduct_quota_when_messaeg_created import handle | |||
| from .delete_installed_app_when_app_deleted import handle | |||
| from .generate_conversation_name_when_first_message_created import handle | |||
| from .update_app_dataset_join_when_app_model_config_updated import handle | |||
| from .update_provider_last_used_at_when_messaeg_created import handle | |||
| from .update_app_dataset_join_when_app_published_workflow_updated import handle | |||
| @@ -1,32 +0,0 @@ | |||
| from core.llm_generator.llm_generator import LLMGenerator | |||
| from events.message_event import message_was_created | |||
| from extensions.ext_database import db | |||
| from models.model import AppMode | |||
| @message_was_created.connect | |||
| def handle(sender, **kwargs): | |||
| message = sender | |||
| conversation = kwargs.get('conversation') | |||
| is_first_message = kwargs.get('is_first_message') | |||
| extras = kwargs.get('extras', {}) | |||
| auto_generate_conversation_name = True | |||
| if extras: | |||
| auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True) | |||
| if auto_generate_conversation_name and is_first_message: | |||
| if conversation.mode != AppMode.COMPLETION.value: | |||
| app_model = conversation.app | |||
| if not app_model: | |||
| return | |||
| # generate conversation name | |||
| try: | |||
| name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query) | |||
| conversation.name = name | |||
| except: | |||
| pass | |||
| db.session.merge(conversation) | |||
| db.session.commit() | |||