Signed-off-by: -LAN- <laipz8200@outlook.com>tags/0.15.0
| @@ -5,6 +5,9 @@ from collections.abc import Generator, Mapping | |||
| from threading import Thread | |||
| from typing import Any, Optional, Union | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME | |||
| from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| @@ -79,8 +82,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| _task_state: WorkflowTaskState | |||
| _application_generate_entity: AdvancedChatAppGenerateEntity | |||
| _workflow: Workflow | |||
| _user: Union[Account, EndUser] | |||
| _workflow_system_variables: dict[SystemVariableKey, Any] | |||
| _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] | |||
| _conversation_name_generate_thread: Optional[Thread] = None | |||
| @@ -96,32 +97,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| stream: bool, | |||
| dialogue_count: int, | |||
| ) -> None: | |||
| """ | |||
| Initialize AdvancedChatAppGenerateTaskPipeline. | |||
| :param application_generate_entity: application generate entity | |||
| :param workflow: workflow | |||
| :param queue_manager: queue manager | |||
| :param conversation: conversation | |||
| :param message: message | |||
| :param user: user | |||
| :param stream: stream | |||
| :param dialogue_count: dialogue count | |||
| """ | |||
| super().__init__(application_generate_entity, queue_manager, user, stream) | |||
| super().__init__( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| stream=stream, | |||
| ) | |||
| if isinstance(self._user, EndUser): | |||
| user_id = self._user.session_id | |||
| if isinstance(user, EndUser): | |||
| self._user_id = user.session_id | |||
| self._created_by_role = CreatedByRole.END_USER | |||
| elif isinstance(user, Account): | |||
| self._user_id = user.id | |||
| self._created_by_role = CreatedByRole.ACCOUNT | |||
| else: | |||
| user_id = self._user.id | |||
| raise NotImplementedError(f"User type not supported: {type(user)}") | |||
| self._workflow_id = workflow.id | |||
| self._workflow_features_dict = workflow.features_dict | |||
| self._conversation_id = conversation.id | |||
| self._conversation_mode = conversation.mode | |||
| self._message_id = message.id | |||
| self._message_created_at = int(message.created_at.timestamp()) | |||
| self._workflow = workflow | |||
| self._conversation = conversation | |||
| self._message = message | |||
| self._workflow_system_variables = { | |||
| SystemVariableKey.QUERY: message.query, | |||
| SystemVariableKey.FILES: application_generate_entity.files, | |||
| SystemVariableKey.CONVERSATION_ID: conversation.id, | |||
| SystemVariableKey.USER_ID: user_id, | |||
| SystemVariableKey.USER_ID: self._user_id, | |||
| SystemVariableKey.DIALOGUE_COUNT: dialogue_count, | |||
| SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, | |||
| SystemVariableKey.WORKFLOW_ID: workflow.id, | |||
| @@ -139,13 +143,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| Process generate task pipeline. | |||
| :return: | |||
| """ | |||
| db.session.refresh(self._workflow) | |||
| db.session.refresh(self._user) | |||
| db.session.close() | |||
| # start generate conversation name thread | |||
| self._conversation_name_generate_thread = self._generate_conversation_name( | |||
| self._conversation, self._application_generate_entity.query | |||
| conversation_id=self._conversation_id, query=self._application_generate_entity.query | |||
| ) | |||
| generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) | |||
| @@ -171,12 +171,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| return ChatbotAppBlockingResponse( | |||
| task_id=stream_response.task_id, | |||
| data=ChatbotAppBlockingResponse.Data( | |||
| id=self._message.id, | |||
| mode=self._conversation.mode, | |||
| conversation_id=self._conversation.id, | |||
| message_id=self._message.id, | |||
| id=self._message_id, | |||
| mode=self._conversation_mode, | |||
| conversation_id=self._conversation_id, | |||
| message_id=self._message_id, | |||
| answer=self._task_state.answer, | |||
| created_at=int(self._message.created_at.timestamp()), | |||
| created_at=self._message_created_at, | |||
| **extras, | |||
| ), | |||
| ) | |||
| @@ -194,9 +194,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| """ | |||
| for stream_response in generator: | |||
| yield ChatbotAppStreamResponse( | |||
| conversation_id=self._conversation.id, | |||
| message_id=self._message.id, | |||
| created_at=int(self._message.created_at.timestamp()), | |||
| conversation_id=self._conversation_id, | |||
| message_id=self._message_id, | |||
| created_at=self._message_created_at, | |||
| stream_response=stream_response, | |||
| ) | |||
| @@ -214,7 +214,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| tts_publisher = None | |||
| task_id = self._application_generate_entity.task_id | |||
| tenant_id = self._application_generate_entity.app_config.tenant_id | |||
| features_dict = self._workflow.features_dict | |||
| features_dict = self._workflow_features_dict | |||
| if ( | |||
| features_dict.get("text_to_speech") | |||
| @@ -274,26 +274,33 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| if isinstance(event, QueuePingEvent): | |||
| yield self._ping_stream_response() | |||
| elif isinstance(event, QueueErrorEvent): | |||
| err = self._handle_error(event, self._message) | |||
| with Session(db.engine) as session: | |||
| err = self._handle_error(event=event, session=session, message_id=self._message_id) | |||
| session.commit() | |||
| yield self._error_to_stream_response(err) | |||
| break | |||
| elif isinstance(event, QueueWorkflowStartedEvent): | |||
| # override graph runtime state | |||
| graph_runtime_state = event.graph_runtime_state | |||
| # init workflow run | |||
| workflow_run = self._handle_workflow_run_start() | |||
| self._refetch_message() | |||
| self._message.workflow_run_id = workflow_run.id | |||
| db.session.commit() | |||
| db.session.refresh(self._message) | |||
| db.session.close() | |||
| yield self._workflow_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| with Session(db.engine) as session: | |||
| # init workflow run | |||
| workflow_run = self._handle_workflow_run_start( | |||
| session=session, | |||
| workflow_id=self._workflow_id, | |||
| user_id=self._user_id, | |||
| created_by_role=self._created_by_role, | |||
| ) | |||
| message = self._get_message(session=session) | |||
| if not message: | |||
| raise ValueError(f"Message not found: {self._message_id}") | |||
| message.workflow_run_id = workflow_run.id | |||
| session.commit() | |||
| workflow_start_resp = self._workflow_start_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| yield workflow_start_resp | |||
| elif isinstance( | |||
| event, | |||
| QueueNodeRetryEvent, | |||
| @@ -304,28 +311,28 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| workflow_run=workflow_run, event=event | |||
| ) | |||
| response = self._workflow_node_retry_to_stream_response( | |||
| node_retry_resp = self._workflow_node_retry_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| ) | |||
| if response: | |||
| yield response | |||
| if node_retry_resp: | |||
| yield node_retry_resp | |||
| elif isinstance(event, QueueNodeStartedEvent): | |||
| if not workflow_run: | |||
| raise ValueError("workflow run not initialized.") | |||
| workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) | |||
| response_start = self._workflow_node_start_to_stream_response( | |||
| node_start_resp = self._workflow_node_start_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| ) | |||
| if response_start: | |||
| yield response_start | |||
| if node_start_resp: | |||
| yield node_start_resp | |||
| elif isinstance(event, QueueNodeSucceededEvent): | |||
| workflow_node_execution = self._handle_workflow_node_execution_success(event) | |||
| @@ -333,25 +340,24 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| if event.node_type in [NodeType.ANSWER, NodeType.END]: | |||
| self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) | |||
| response_finish = self._workflow_node_finish_to_stream_response( | |||
| node_finish_resp = self._workflow_node_finish_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| ) | |||
| if response_finish: | |||
| yield response_finish | |||
| if node_finish_resp: | |||
| yield node_finish_resp | |||
| elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): | |||
| workflow_node_execution = self._handle_workflow_node_execution_failed(event) | |||
| response_finish = self._workflow_node_finish_to_stream_response( | |||
| node_finish_resp = self._workflow_node_finish_to_stream_response( | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_node_execution=workflow_node_execution, | |||
| ) | |||
| if response: | |||
| yield response | |||
| if node_finish_resp: | |||
| yield node_finish_resp | |||
| elif isinstance(event, QueueParallelBranchRunStartedEvent): | |||
| if not workflow_run: | |||
| @@ -395,20 +401,24 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| if not graph_runtime_state: | |||
| raise ValueError("workflow run not initialized.") | |||
| workflow_run = self._handle_workflow_run_success( | |||
| workflow_run=workflow_run, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| outputs=event.outputs, | |||
| conversation_id=self._conversation.id, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._handle_workflow_run_success( | |||
| session=session, | |||
| workflow_run=workflow_run, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| outputs=event.outputs, | |||
| conversation_id=self._conversation_id, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| yield self._workflow_finish_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| workflow_finish_resp = self._workflow_finish_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| session.commit() | |||
| yield workflow_finish_resp | |||
| self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) | |||
| elif isinstance(event, QueueWorkflowPartialSuccessEvent): | |||
| if not workflow_run: | |||
| @@ -417,21 +427,25 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| if not graph_runtime_state: | |||
| raise ValueError("graph runtime state not initialized.") | |||
| workflow_run = self._handle_workflow_run_partial_success( | |||
| workflow_run=workflow_run, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| outputs=event.outputs, | |||
| exceptions_count=event.exceptions_count, | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._handle_workflow_run_partial_success( | |||
| session=session, | |||
| workflow_run=workflow_run, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| outputs=event.outputs, | |||
| exceptions_count=event.exceptions_count, | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| yield self._workflow_finish_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| workflow_finish_resp = self._workflow_finish_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| session.commit() | |||
| yield workflow_finish_resp | |||
| self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) | |||
| elif isinstance(event, QueueWorkflowFailedEvent): | |||
| if not workflow_run: | |||
| @@ -440,71 +454,73 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| if not graph_runtime_state: | |||
| raise ValueError("graph runtime state not initialized.") | |||
| workflow_run = self._handle_workflow_run_failed( | |||
| workflow_run=workflow_run, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| status=WorkflowRunStatus.FAILED, | |||
| error=event.error, | |||
| conversation_id=self._conversation.id, | |||
| trace_manager=trace_manager, | |||
| exceptions_count=event.exceptions_count, | |||
| ) | |||
| yield self._workflow_finish_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) | |||
| yield self._error_to_stream_response(self._handle_error(err_event, self._message)) | |||
| break | |||
| elif isinstance(event, QueueStopEvent): | |||
| if workflow_run and graph_runtime_state: | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._handle_workflow_run_failed( | |||
| session=session, | |||
| workflow_run=workflow_run, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| status=WorkflowRunStatus.STOPPED, | |||
| error=event.get_stop_reason(), | |||
| conversation_id=self._conversation.id, | |||
| status=WorkflowRunStatus.FAILED, | |||
| error=event.error, | |||
| conversation_id=self._conversation_id, | |||
| trace_manager=trace_manager, | |||
| exceptions_count=event.exceptions_count, | |||
| ) | |||
| yield self._workflow_finish_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| workflow_finish_resp = self._workflow_finish_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| # Save message | |||
| self._save_message(graph_runtime_state=graph_runtime_state) | |||
| err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) | |||
| err = self._handle_error(event=err_event, session=session, message_id=self._message_id) | |||
| session.commit() | |||
| yield workflow_finish_resp | |||
| yield self._error_to_stream_response(err) | |||
| break | |||
| elif isinstance(event, QueueStopEvent): | |||
| if workflow_run and graph_runtime_state: | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._handle_workflow_run_failed( | |||
| session=session, | |||
| workflow_run=workflow_run, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| status=WorkflowRunStatus.STOPPED, | |||
| error=event.get_stop_reason(), | |||
| conversation_id=self._conversation_id, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| workflow_finish_resp = self._workflow_finish_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| ) | |||
| # Save message | |||
| self._save_message(session=session, graph_runtime_state=graph_runtime_state) | |||
| session.commit() | |||
| yield workflow_finish_resp | |||
| yield self._message_end_to_stream_response() | |||
| break | |||
| elif isinstance(event, QueueRetrieverResourcesEvent): | |||
| self._handle_retriever_resources(event) | |||
| self._refetch_message() | |||
| self._message.message_metadata = ( | |||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||
| ) | |||
| db.session.commit() | |||
| db.session.refresh(self._message) | |||
| db.session.close() | |||
| with Session(db.engine) as session: | |||
| message = self._get_message(session=session) | |||
| message.message_metadata = ( | |||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||
| ) | |||
| session.commit() | |||
| elif isinstance(event, QueueAnnotationReplyEvent): | |||
| self._handle_annotation_reply(event) | |||
| self._refetch_message() | |||
| self._message.message_metadata = ( | |||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||
| ) | |||
| db.session.commit() | |||
| db.session.refresh(self._message) | |||
| db.session.close() | |||
| with Session(db.engine) as session: | |||
| message = self._get_message(session=session) | |||
| message.message_metadata = ( | |||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||
| ) | |||
| session.commit() | |||
| elif isinstance(event, QueueTextChunkEvent): | |||
| delta_text = event.text | |||
| if delta_text is None: | |||
| @@ -521,7 +537,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| self._task_state.answer += delta_text | |||
| yield self._message_to_stream_response( | |||
| answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector | |||
| answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector | |||
| ) | |||
| elif isinstance(event, QueueMessageReplaceEvent): | |||
| # published by moderation | |||
| @@ -536,7 +552,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| yield self._message_replace_to_stream_response(answer=output_moderation_answer) | |||
| # Save message | |||
| self._save_message(graph_runtime_state=graph_runtime_state) | |||
| with Session(db.engine) as session: | |||
| self._save_message(session=session, graph_runtime_state=graph_runtime_state) | |||
| session.commit() | |||
| yield self._message_end_to_stream_response() | |||
| else: | |||
| @@ -549,54 +567,46 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| if self._conversation_name_generate_thread: | |||
| self._conversation_name_generate_thread.join() | |||
| def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: | |||
| self._refetch_message() | |||
| self._message.answer = self._task_state.answer | |||
| self._message.provider_response_latency = time.perf_counter() - self._start_at | |||
| self._message.message_metadata = ( | |||
| def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: | |||
| message = self._get_message(session=session) | |||
| message.answer = self._task_state.answer | |||
| message.provider_response_latency = time.perf_counter() - self._start_at | |||
| message.message_metadata = ( | |||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||
| ) | |||
| message_files = [ | |||
| MessageFile( | |||
| message_id=self._message.id, | |||
| message_id=message.id, | |||
| type=file["type"], | |||
| transfer_method=file["transfer_method"], | |||
| url=file["remote_url"], | |||
| belongs_to="assistant", | |||
| upload_file_id=file["related_id"], | |||
| created_by_role=CreatedByRole.ACCOUNT | |||
| if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} | |||
| if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} | |||
| else CreatedByRole.END_USER, | |||
| created_by=self._message.from_account_id or self._message.from_end_user_id or "", | |||
| created_by=message.from_account_id or message.from_end_user_id or "", | |||
| ) | |||
| for file in self._recorded_files | |||
| ] | |||
| db.session.add_all(message_files) | |||
| session.add_all(message_files) | |||
| if graph_runtime_state and graph_runtime_state.llm_usage: | |||
| usage = graph_runtime_state.llm_usage | |||
| self._message.message_tokens = usage.prompt_tokens | |||
| self._message.message_unit_price = usage.prompt_unit_price | |||
| self._message.message_price_unit = usage.prompt_price_unit | |||
| self._message.answer_tokens = usage.completion_tokens | |||
| self._message.answer_unit_price = usage.completion_unit_price | |||
| self._message.answer_price_unit = usage.completion_price_unit | |||
| self._message.total_price = usage.total_price | |||
| self._message.currency = usage.currency | |||
| message.message_tokens = usage.prompt_tokens | |||
| message.message_unit_price = usage.prompt_unit_price | |||
| message.message_price_unit = usage.prompt_price_unit | |||
| message.answer_tokens = usage.completion_tokens | |||
| message.answer_unit_price = usage.completion_unit_price | |||
| message.answer_price_unit = usage.completion_price_unit | |||
| message.total_price = usage.total_price | |||
| message.currency = usage.currency | |||
| self._task_state.metadata["usage"] = jsonable_encoder(usage) | |||
| else: | |||
| self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage()) | |||
| db.session.commit() | |||
| message_was_created.send( | |||
| self._message, | |||
| message, | |||
| application_generate_entity=self._application_generate_entity, | |||
| conversation=self._conversation, | |||
| is_first_message=self._application_generate_entity.conversation_id is None, | |||
| extras=self._application_generate_entity.extras, | |||
| ) | |||
| def _message_end_to_stream_response(self) -> MessageEndStreamResponse: | |||
| @@ -613,7 +623,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| return MessageEndStreamResponse( | |||
| task_id=self._application_generate_entity.task_id, | |||
| id=self._message.id, | |||
| id=self._message_id, | |||
| files=self._recorded_files, | |||
| metadata=extras.get("metadata", {}), | |||
| ) | |||
| @@ -641,11 +651,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| return False | |||
| def _refetch_message(self) -> None: | |||
| """ | |||
| Refetch message. | |||
| :return: | |||
| """ | |||
| message = db.session.query(Message).filter(Message.id == self._message.id).first() | |||
| if message: | |||
| self._message = message | |||
| def _get_message(self, *, session: Session): | |||
| stmt = select(Message).where(Message.id == self._message_id) | |||
| message = session.scalar(stmt) | |||
| if not message: | |||
| raise ValueError(f"Message not found: {self._message_id}") | |||
| return message | |||
| @@ -70,7 +70,6 @@ class MessageBasedAppGenerator(BaseAppGenerator): | |||
| queue_manager=queue_manager, | |||
| conversation=conversation, | |||
| message=message, | |||
| user=user, | |||
| stream=stream, | |||
| ) | |||
| @@ -3,6 +3,8 @@ import time | |||
| from collections.abc import Generator | |||
| from typing import Any, Optional, Union | |||
| from sqlalchemy.orm import Session | |||
| from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME | |||
| from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| @@ -50,6 +52,7 @@ from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.workflow.enums import SystemVariableKey | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.enums import CreatedByRole | |||
| from models.model import EndUser | |||
| from models.workflow import ( | |||
| Workflow, | |||
| @@ -68,8 +71,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. | |||
| """ | |||
| _workflow: Workflow | |||
| _user: Union[Account, EndUser] | |||
| _task_state: WorkflowTaskState | |||
| _application_generate_entity: WorkflowAppGenerateEntity | |||
| _workflow_system_variables: dict[SystemVariableKey, Any] | |||
| @@ -83,25 +84,27 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| user: Union[Account, EndUser], | |||
| stream: bool, | |||
| ) -> None: | |||
| """ | |||
| Initialize GenerateTaskPipeline. | |||
| :param application_generate_entity: application generate entity | |||
| :param workflow: workflow | |||
| :param queue_manager: queue manager | |||
| :param user: user | |||
| :param stream: is streamed | |||
| """ | |||
| super().__init__(application_generate_entity, queue_manager, user, stream) | |||
| super().__init__( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| stream=stream, | |||
| ) | |||
| if isinstance(self._user, EndUser): | |||
| user_id = self._user.session_id | |||
| if isinstance(user, EndUser): | |||
| self._user_id = user.session_id | |||
| self._created_by_role = CreatedByRole.END_USER | |||
| elif isinstance(user, Account): | |||
| self._user_id = user.id | |||
| self._created_by_role = CreatedByRole.ACCOUNT | |||
| else: | |||
| user_id = self._user.id | |||
| raise ValueError(f"Invalid user type: {type(user)}") | |||
| self._workflow_id = workflow.id | |||
| self._workflow_features_dict = workflow.features_dict | |||
| self._workflow = workflow | |||
| self._workflow_system_variables = { | |||
| SystemVariableKey.FILES: application_generate_entity.files, | |||
| SystemVariableKey.USER_ID: user_id, | |||
| SystemVariableKey.USER_ID: self._user_id, | |||
| SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, | |||
| SystemVariableKey.WORKFLOW_ID: workflow.id, | |||
| SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, | |||
| @@ -115,10 +118,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| Process generate task pipeline. | |||
| :return: | |||
| """ | |||
| db.session.refresh(self._workflow) | |||
| db.session.refresh(self._user) | |||
| db.session.close() | |||
| generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) | |||
| if self._stream: | |||
| return self._to_stream_response(generator) | |||
| @@ -185,7 +184,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| tts_publisher = None | |||
| task_id = self._application_generate_entity.task_id | |||
| tenant_id = self._application_generate_entity.app_config.tenant_id | |||
| features_dict = self._workflow.features_dict | |||
| features_dict = self._workflow_features_dict | |||
| if ( | |||
| features_dict.get("text_to_speech") | |||
| @@ -242,18 +241,26 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| if isinstance(event, QueuePingEvent): | |||
| yield self._ping_stream_response() | |||
| elif isinstance(event, QueueErrorEvent): | |||
| err = self._handle_error(event) | |||
| err = self._handle_error(event=event) | |||
| yield self._error_to_stream_response(err) | |||
| break | |||
| elif isinstance(event, QueueWorkflowStartedEvent): | |||
| # override graph runtime state | |||
| graph_runtime_state = event.graph_runtime_state | |||
| # init workflow run | |||
| workflow_run = self._handle_workflow_run_start() | |||
| yield self._workflow_start_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| with Session(db.engine) as session: | |||
| # init workflow run | |||
| workflow_run = self._handle_workflow_run_start( | |||
| session=session, | |||
| workflow_id=self._workflow_id, | |||
| user_id=self._user_id, | |||
| created_by_role=self._created_by_role, | |||
| ) | |||
| start_resp = self._workflow_start_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| session.commit() | |||
| yield start_resp | |||
| elif isinstance( | |||
| event, | |||
| QueueNodeRetryEvent, | |||
| @@ -350,22 +357,28 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| if not graph_runtime_state: | |||
| raise ValueError("graph runtime state not initialized.") | |||
| workflow_run = self._handle_workflow_run_success( | |||
| workflow_run=workflow_run, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| outputs=event.outputs, | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| # save workflow app log | |||
| self._save_workflow_app_log(workflow_run) | |||
| yield self._workflow_finish_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._handle_workflow_run_success( | |||
| session=session, | |||
| workflow_run=workflow_run, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| outputs=event.outputs, | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| # save workflow app log | |||
| self._save_workflow_app_log(session=session, workflow_run=workflow_run) | |||
| workflow_finish_resp = self._workflow_finish_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| ) | |||
| session.commit() | |||
| yield workflow_finish_resp | |||
| elif isinstance(event, QueueWorkflowPartialSuccessEvent): | |||
| if not workflow_run: | |||
| raise ValueError("workflow run not initialized.") | |||
| @@ -373,49 +386,58 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| if not graph_runtime_state: | |||
| raise ValueError("graph runtime state not initialized.") | |||
| workflow_run = self._handle_workflow_run_partial_success( | |||
| workflow_run=workflow_run, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| outputs=event.outputs, | |||
| exceptions_count=event.exceptions_count, | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| # save workflow app log | |||
| self._save_workflow_app_log(workflow_run) | |||
| yield self._workflow_finish_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._handle_workflow_run_partial_success( | |||
| session=session, | |||
| workflow_run=workflow_run, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| outputs=event.outputs, | |||
| exceptions_count=event.exceptions_count, | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| # save workflow app log | |||
| self._save_workflow_app_log(session=session, workflow_run=workflow_run) | |||
| workflow_finish_resp = self._workflow_finish_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| session.commit() | |||
| yield workflow_finish_resp | |||
| elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): | |||
| if not workflow_run: | |||
| raise ValueError("workflow run not initialized.") | |||
| if not graph_runtime_state: | |||
| raise ValueError("graph runtime state not initialized.") | |||
| workflow_run = self._handle_workflow_run_failed( | |||
| workflow_run=workflow_run, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| status=WorkflowRunStatus.FAILED | |||
| if isinstance(event, QueueWorkflowFailedEvent) | |||
| else WorkflowRunStatus.STOPPED, | |||
| error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, | |||
| ) | |||
| # save workflow app log | |||
| self._save_workflow_app_log(workflow_run) | |||
| yield self._workflow_finish_to_stream_response( | |||
| task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._handle_workflow_run_failed( | |||
| session=session, | |||
| workflow_run=workflow_run, | |||
| start_at=graph_runtime_state.start_at, | |||
| total_tokens=graph_runtime_state.total_tokens, | |||
| total_steps=graph_runtime_state.node_run_steps, | |||
| status=WorkflowRunStatus.FAILED | |||
| if isinstance(event, QueueWorkflowFailedEvent) | |||
| else WorkflowRunStatus.STOPPED, | |||
| error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, | |||
| ) | |||
| # save workflow app log | |||
| self._save_workflow_app_log(session=session, workflow_run=workflow_run) | |||
| workflow_finish_resp = self._workflow_finish_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| session.commit() | |||
| yield workflow_finish_resp | |||
| elif isinstance(event, QueueTextChunkEvent): | |||
| delta_text = event.text | |||
| if delta_text is None: | |||
| @@ -435,7 +457,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| if tts_publisher: | |||
| tts_publisher.publish(None) | |||
| def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: | |||
| def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None: | |||
| """ | |||
| Save workflow app log. | |||
| :return: | |||
| @@ -457,12 +479,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| workflow_app_log.workflow_id = workflow_run.workflow_id | |||
| workflow_app_log.workflow_run_id = workflow_run.id | |||
| workflow_app_log.created_from = created_from.value | |||
| workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user" | |||
| workflow_app_log.created_by = self._user.id | |||
| workflow_app_log.created_by_role = self._created_by_role | |||
| workflow_app_log.created_by = self._user_id | |||
| db.session.add(workflow_app_log) | |||
| db.session.commit() | |||
| db.session.close() | |||
| session.add(workflow_app_log) | |||
| def _text_chunk_to_stream_response( | |||
| self, text: str, from_variable_selector: Optional[list[str]] = None | |||
| @@ -1,6 +1,9 @@ | |||
| import logging | |||
| import time | |||
| from typing import Optional, Union | |||
| from typing import Optional | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.entities.app_invoke_entities import ( | |||
| @@ -17,9 +20,7 @@ from core.app.entities.task_entities import ( | |||
| from core.errors.error import QuotaExceededError | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||
| from core.moderation.output_moderation import ModerationRule, OutputModeration | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.model import EndUser, Message | |||
| from models.model import Message | |||
| logger = logging.getLogger(__name__) | |||
| @@ -36,7 +37,6 @@ class BasedGenerateTaskPipeline: | |||
| self, | |||
| application_generate_entity: AppGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| user: Union[Account, EndUser], | |||
| stream: bool, | |||
| ) -> None: | |||
| """ | |||
| @@ -48,18 +48,11 @@ class BasedGenerateTaskPipeline: | |||
| """ | |||
| self._application_generate_entity = application_generate_entity | |||
| self._queue_manager = queue_manager | |||
| self._user = user | |||
| self._start_at = time.perf_counter() | |||
| self._output_moderation_handler = self._init_output_moderation() | |||
| self._stream = stream | |||
| def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None): | |||
| """ | |||
| Handle error event. | |||
| :param event: event | |||
| :param message: message | |||
| :return: | |||
| """ | |||
| def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""): | |||
| logger.debug("error: %s", event.error) | |||
| e = event.error | |||
| err: Exception | |||
| @@ -71,16 +64,17 @@ class BasedGenerateTaskPipeline: | |||
| else: | |||
| err = Exception(e.description if getattr(e, "description", None) is not None else str(e)) | |||
| if message: | |||
| refetch_message = db.session.query(Message).filter(Message.id == message.id).first() | |||
| if refetch_message: | |||
| err_desc = self._error_to_desc(err) | |||
| refetch_message.status = "error" | |||
| refetch_message.error = err_desc | |||
| if not message_id or not session: | |||
| return err | |||
| db.session.commit() | |||
| stmt = select(Message).where(Message.id == message_id) | |||
| message = session.scalar(stmt) | |||
| if not message: | |||
| return err | |||
| err_desc = self._error_to_desc(err) | |||
| message.status = "error" | |||
| message.error = err_desc | |||
| return err | |||
| def _error_to_desc(self, e: Exception) -> str: | |||
| @@ -5,6 +5,9 @@ from collections.abc import Generator | |||
| from threading import Thread | |||
| from typing import Optional, Union, cast | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME | |||
| from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| @@ -55,8 +58,7 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil | |||
| from core.prompt.utils.prompt_template_parser import PromptTemplateParser | |||
| from events.message_event import message_was_created | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.model import AppMode, Conversation, EndUser, Message, MessageAgentThought | |||
| from models.model import AppMode, Conversation, Message, MessageAgentThought | |||
| logger = logging.getLogger(__name__) | |||
| @@ -77,23 +79,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| queue_manager: AppQueueManager, | |||
| conversation: Conversation, | |||
| message: Message, | |||
| user: Union[Account, EndUser], | |||
| stream: bool, | |||
| ) -> None: | |||
| """ | |||
| Initialize GenerateTaskPipeline. | |||
| :param application_generate_entity: application generate entity | |||
| :param queue_manager: queue manager | |||
| :param conversation: conversation | |||
| :param message: message | |||
| :param user: user | |||
| :param stream: stream | |||
| """ | |||
| super().__init__(application_generate_entity, queue_manager, user, stream) | |||
| super().__init__( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| stream=stream, | |||
| ) | |||
| self._model_config = application_generate_entity.model_conf | |||
| self._app_config = application_generate_entity.app_config | |||
| self._conversation = conversation | |||
| self._message = message | |||
| self._conversation_id = conversation.id | |||
| self._conversation_mode = conversation.mode | |||
| self._message_id = message.id | |||
| self._message_created_at = int(message.created_at.timestamp()) | |||
| self._task_state = EasyUITaskState( | |||
| llm_result=LLMResult( | |||
| @@ -113,18 +113,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| CompletionAppBlockingResponse, | |||
| Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None], | |||
| ]: | |||
| """ | |||
| Process generate task pipeline. | |||
| :return: | |||
| """ | |||
| db.session.refresh(self._conversation) | |||
| db.session.refresh(self._message) | |||
| db.session.close() | |||
| if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: | |||
| # start generate conversation name thread | |||
| self._conversation_name_generate_thread = self._generate_conversation_name( | |||
| self._conversation, self._application_generate_entity.query or "" | |||
| conversation_id=self._conversation_id, query=self._application_generate_entity.query or "" | |||
| ) | |||
| generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) | |||
| @@ -148,15 +140,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| if self._task_state.metadata: | |||
| extras["metadata"] = self._task_state.metadata | |||
| response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] | |||
| if self._conversation.mode == AppMode.COMPLETION.value: | |||
| if self._conversation_mode == AppMode.COMPLETION.value: | |||
| response = CompletionAppBlockingResponse( | |||
| task_id=self._application_generate_entity.task_id, | |||
| data=CompletionAppBlockingResponse.Data( | |||
| id=self._message.id, | |||
| mode=self._conversation.mode, | |||
| message_id=self._message.id, | |||
| id=self._message_id, | |||
| mode=self._conversation_mode, | |||
| message_id=self._message_id, | |||
| answer=cast(str, self._task_state.llm_result.message.content), | |||
| created_at=int(self._message.created_at.timestamp()), | |||
| created_at=self._message_created_at, | |||
| **extras, | |||
| ), | |||
| ) | |||
| @@ -164,12 +156,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| response = ChatbotAppBlockingResponse( | |||
| task_id=self._application_generate_entity.task_id, | |||
| data=ChatbotAppBlockingResponse.Data( | |||
| id=self._message.id, | |||
| mode=self._conversation.mode, | |||
| conversation_id=self._conversation.id, | |||
| message_id=self._message.id, | |||
| id=self._message_id, | |||
| mode=self._conversation_mode, | |||
| conversation_id=self._conversation_id, | |||
| message_id=self._message_id, | |||
| answer=cast(str, self._task_state.llm_result.message.content), | |||
| created_at=int(self._message.created_at.timestamp()), | |||
| created_at=self._message_created_at, | |||
| **extras, | |||
| ), | |||
| ) | |||
| @@ -190,15 +182,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| for stream_response in generator: | |||
| if isinstance(self._application_generate_entity, CompletionAppGenerateEntity): | |||
| yield CompletionAppStreamResponse( | |||
| message_id=self._message.id, | |||
| created_at=int(self._message.created_at.timestamp()), | |||
| message_id=self._message_id, | |||
| created_at=self._message_created_at, | |||
| stream_response=stream_response, | |||
| ) | |||
| else: | |||
| yield ChatbotAppStreamResponse( | |||
| conversation_id=self._conversation.id, | |||
| message_id=self._message.id, | |||
| created_at=int(self._message.created_at.timestamp()), | |||
| conversation_id=self._conversation_id, | |||
| message_id=self._message_id, | |||
| created_at=self._message_created_at, | |||
| stream_response=stream_response, | |||
| ) | |||
| @@ -265,7 +257,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| event = message.event | |||
| if isinstance(event, QueueErrorEvent): | |||
| err = self._handle_error(event, self._message) | |||
| with Session(db.engine) as session: | |||
| err = self._handle_error(event=event, session=session, message_id=self._message_id) | |||
| session.commit() | |||
| yield self._error_to_stream_response(err) | |||
| break | |||
| elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): | |||
| @@ -283,10 +277,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| self._task_state.llm_result.message.content = output_moderation_answer | |||
| yield self._message_replace_to_stream_response(answer=output_moderation_answer) | |||
| # Save message | |||
| self._save_message(trace_manager) | |||
| yield self._message_end_to_stream_response() | |||
| with Session(db.engine) as session: | |||
| # Save message | |||
| self._save_message(session=session, trace_manager=trace_manager) | |||
| session.commit() | |||
| message_end_resp = self._message_end_to_stream_response() | |||
| yield message_end_resp | |||
| elif isinstance(event, QueueRetrieverResourcesEvent): | |||
| self._handle_retriever_resources(event) | |||
| elif isinstance(event, QueueAnnotationReplyEvent): | |||
| @@ -320,9 +316,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| self._task_state.llm_result.message.content = current_content | |||
| if isinstance(event, QueueLLMChunkEvent): | |||
| yield self._message_to_stream_response(cast(str, delta_text), self._message.id) | |||
| yield self._message_to_stream_response( | |||
| answer=cast(str, delta_text), | |||
| message_id=self._message_id, | |||
| ) | |||
| else: | |||
| yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id) | |||
| yield self._agent_message_to_stream_response( | |||
| answer=cast(str, delta_text), | |||
| message_id=self._message_id, | |||
| ) | |||
| elif isinstance(event, QueueMessageReplaceEvent): | |||
| yield self._message_replace_to_stream_response(answer=event.text) | |||
| elif isinstance(event, QueuePingEvent): | |||
| @@ -334,7 +336,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| if self._conversation_name_generate_thread: | |||
| self._conversation_name_generate_thread.join() | |||
| def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None: | |||
| def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None: | |||
| """ | |||
| Save message. | |||
| :return: | |||
| @@ -342,53 +344,46 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| llm_result = self._task_state.llm_result | |||
| usage = llm_result.usage | |||
| message = db.session.query(Message).filter(Message.id == self._message.id).first() | |||
| message_stmt = select(Message).where(Message.id == self._message_id) | |||
| message = session.scalar(message_stmt) | |||
| if not message: | |||
| raise Exception(f"Message {self._message.id} not found") | |||
| self._message = message | |||
| conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() | |||
| raise ValueError(f"message {self._message_id} not found") | |||
| conversation_stmt = select(Conversation).where(Conversation.id == self._conversation_id) | |||
| conversation = session.scalar(conversation_stmt) | |||
| if not conversation: | |||
| raise Exception(f"Conversation {self._conversation.id} not found") | |||
| self._conversation = conversation | |||
| raise ValueError(f"Conversation {self._conversation_id} not found") | |||
| self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( | |||
| message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( | |||
| self._model_config.mode, self._task_state.llm_result.prompt_messages | |||
| ) | |||
| self._message.message_tokens = usage.prompt_tokens | |||
| self._message.message_unit_price = usage.prompt_unit_price | |||
| self._message.message_price_unit = usage.prompt_price_unit | |||
| self._message.answer = ( | |||
| message.message_tokens = usage.prompt_tokens | |||
| message.message_unit_price = usage.prompt_unit_price | |||
| message.message_price_unit = usage.prompt_price_unit | |||
| message.answer = ( | |||
| PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip()) | |||
| if llm_result.message.content | |||
| else "" | |||
| ) | |||
| self._message.answer_tokens = usage.completion_tokens | |||
| self._message.answer_unit_price = usage.completion_unit_price | |||
| self._message.answer_price_unit = usage.completion_price_unit | |||
| self._message.provider_response_latency = time.perf_counter() - self._start_at | |||
| self._message.total_price = usage.total_price | |||
| self._message.currency = usage.currency | |||
| self._message.message_metadata = ( | |||
| message.answer_tokens = usage.completion_tokens | |||
| message.answer_unit_price = usage.completion_unit_price | |||
| message.answer_price_unit = usage.completion_price_unit | |||
| message.provider_response_latency = time.perf_counter() - self._start_at | |||
| message.total_price = usage.total_price | |||
| message.currency = usage.currency | |||
| message.message_metadata = ( | |||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||
| ) | |||
| db.session.commit() | |||
| if trace_manager: | |||
| trace_manager.add_trace_task( | |||
| TraceTask( | |||
| TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id | |||
| TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id | |||
| ) | |||
| ) | |||
| message_was_created.send( | |||
| self._message, | |||
| message, | |||
| application_generate_entity=self._application_generate_entity, | |||
| conversation=self._conversation, | |||
| is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT} | |||
| and hasattr(self._application_generate_entity, "conversation_id") | |||
| and self._application_generate_entity.conversation_id is None, | |||
| extras=self._application_generate_entity.extras, | |||
| ) | |||
| def _handle_stop(self, event: QueueStopEvent) -> None: | |||
| @@ -434,7 +429,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| return MessageEndStreamResponse( | |||
| task_id=self._application_generate_entity.task_id, | |||
| id=self._message.id, | |||
| id=self._message_id, | |||
| metadata=extras.get("metadata", {}), | |||
| ) | |||
| @@ -36,7 +36,7 @@ class MessageCycleManage: | |||
| ] | |||
| _task_state: Union[EasyUITaskState, WorkflowTaskState] | |||
| def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]: | |||
| def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: | |||
| """ | |||
| Generate conversation name. | |||
| :param conversation: conversation | |||
| @@ -56,7 +56,7 @@ class MessageCycleManage: | |||
| target=self._generate_conversation_name_worker, | |||
| kwargs={ | |||
| "flask_app": current_app._get_current_object(), # type: ignore | |||
| "conversation_id": conversation.id, | |||
| "conversation_id": conversation_id, | |||
| "query": query, | |||
| }, | |||
| ) | |||
| @@ -5,6 +5,7 @@ from datetime import UTC, datetime | |||
| from typing import Any, Optional, Union, cast | |||
| from uuid import uuid4 | |||
| from sqlalchemy import func, select | |||
| from sqlalchemy.orm import Session | |||
| from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity | |||
| @@ -63,27 +64,34 @@ from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError | |||
| class WorkflowCycleManage: | |||
| _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] | |||
| _workflow: Workflow | |||
| _user: Union[Account, EndUser] | |||
| _task_state: WorkflowTaskState | |||
| _workflow_system_variables: dict[SystemVariableKey, Any] | |||
| _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] | |||
| def _handle_workflow_run_start(self) -> WorkflowRun: | |||
| max_sequence = ( | |||
| db.session.query(db.func.max(WorkflowRun.sequence_number)) | |||
| .filter(WorkflowRun.tenant_id == self._workflow.tenant_id) | |||
| .filter(WorkflowRun.app_id == self._workflow.app_id) | |||
| .scalar() | |||
| or 0 | |||
| def _handle_workflow_run_start( | |||
| self, | |||
| *, | |||
| session: Session, | |||
| workflow_id: str, | |||
| user_id: str, | |||
| created_by_role: CreatedByRole, | |||
| ) -> WorkflowRun: | |||
| workflow_stmt = select(Workflow).where(Workflow.id == workflow_id) | |||
| workflow = session.scalar(workflow_stmt) | |||
| if not workflow: | |||
| raise ValueError(f"Workflow not found: {workflow_id}") | |||
| max_sequence_stmt = select(func.max(WorkflowRun.sequence_number)).where( | |||
| WorkflowRun.tenant_id == workflow.tenant_id, | |||
| WorkflowRun.app_id == workflow.app_id, | |||
| ) | |||
| max_sequence = session.scalar(max_sequence_stmt) or 0 | |||
| new_sequence_number = max_sequence + 1 | |||
| inputs = {**self._application_generate_entity.inputs} | |||
| for key, value in (self._workflow_system_variables or {}).items(): | |||
| if key.value == "conversation": | |||
| continue | |||
| inputs[f"sys.{key.value}"] = value | |||
| triggered_from = ( | |||
| @@ -96,33 +104,32 @@ class WorkflowCycleManage: | |||
| inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) | |||
| # init workflow run | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = WorkflowRun() | |||
| system_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID] | |||
| workflow_run.id = system_id or str(uuid4()) | |||
| workflow_run.tenant_id = self._workflow.tenant_id | |||
| workflow_run.app_id = self._workflow.app_id | |||
| workflow_run.sequence_number = new_sequence_number | |||
| workflow_run.workflow_id = self._workflow.id | |||
| workflow_run.type = self._workflow.type | |||
| workflow_run.triggered_from = triggered_from.value | |||
| workflow_run.version = self._workflow.version | |||
| workflow_run.graph = self._workflow.graph | |||
| workflow_run.inputs = json.dumps(inputs) | |||
| workflow_run.status = WorkflowRunStatus.RUNNING | |||
| workflow_run.created_by_role = ( | |||
| CreatedByRole.ACCOUNT if isinstance(self._user, Account) else CreatedByRole.END_USER | |||
| ) | |||
| workflow_run.created_by = self._user.id | |||
| workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) | |||
| session.add(workflow_run) | |||
| session.commit() | |||
| workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4())) | |||
| workflow_run = WorkflowRun() | |||
| workflow_run.id = workflow_run_id | |||
| workflow_run.tenant_id = workflow.tenant_id | |||
| workflow_run.app_id = workflow.app_id | |||
| workflow_run.sequence_number = new_sequence_number | |||
| workflow_run.workflow_id = workflow.id | |||
| workflow_run.type = workflow.type | |||
| workflow_run.triggered_from = triggered_from.value | |||
| workflow_run.version = workflow.version | |||
| workflow_run.graph = workflow.graph | |||
| workflow_run.inputs = json.dumps(inputs) | |||
| workflow_run.status = WorkflowRunStatus.RUNNING | |||
| workflow_run.created_by_role = created_by_role | |||
| workflow_run.created_by = user_id | |||
| workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) | |||
| session.add(workflow_run) | |||
| return workflow_run | |||
| def _handle_workflow_run_success( | |||
| self, | |||
| *, | |||
| session: Session, | |||
| workflow_run: WorkflowRun, | |||
| start_at: float, | |||
| total_tokens: int, | |||
| @@ -141,7 +148,7 @@ class WorkflowCycleManage: | |||
| :param conversation_id: conversation id | |||
| :return: | |||
| """ | |||
| workflow_run = self._refetch_workflow_run(workflow_run.id) | |||
| workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id) | |||
| outputs = WorkflowEntry.handle_special_values(outputs) | |||
| @@ -152,9 +159,6 @@ class WorkflowCycleManage: | |||
| workflow_run.total_steps = total_steps | |||
| workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| db.session.commit() | |||
| db.session.refresh(workflow_run) | |||
| if trace_manager: | |||
| trace_manager.add_trace_task( | |||
| TraceTask( | |||
| @@ -165,12 +169,12 @@ class WorkflowCycleManage: | |||
| ) | |||
| ) | |||
| db.session.close() | |||
| return workflow_run | |||
| def _handle_workflow_run_partial_success( | |||
| self, | |||
| *, | |||
| session: Session, | |||
| workflow_run: WorkflowRun, | |||
| start_at: float, | |||
| total_tokens: int, | |||
| @@ -190,7 +194,7 @@ class WorkflowCycleManage: | |||
| :param conversation_id: conversation id | |||
| :return: | |||
| """ | |||
| workflow_run = self._refetch_workflow_run(workflow_run.id) | |||
| workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id) | |||
| outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) | |||
| @@ -201,8 +205,6 @@ class WorkflowCycleManage: | |||
| workflow_run.total_steps = total_steps | |||
| workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| workflow_run.exceptions_count = exceptions_count | |||
| db.session.commit() | |||
| db.session.refresh(workflow_run) | |||
| if trace_manager: | |||
| trace_manager.add_trace_task( | |||
| @@ -214,12 +216,12 @@ class WorkflowCycleManage: | |||
| ) | |||
| ) | |||
| db.session.close() | |||
| return workflow_run | |||
| def _handle_workflow_run_failed( | |||
| self, | |||
| *, | |||
| session: Session, | |||
| workflow_run: WorkflowRun, | |||
| start_at: float, | |||
| total_tokens: int, | |||
| @@ -240,7 +242,7 @@ class WorkflowCycleManage: | |||
| :param error: error message | |||
| :return: | |||
| """ | |||
| workflow_run = self._refetch_workflow_run(workflow_run.id) | |||
| workflow_run = self._refetch_workflow_run(session=session, workflow_run_id=workflow_run.id) | |||
| workflow_run.status = status.value | |||
| workflow_run.error = error | |||
| @@ -249,21 +251,18 @@ class WorkflowCycleManage: | |||
| workflow_run.total_steps = total_steps | |||
| workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| workflow_run.exceptions_count = exceptions_count | |||
| db.session.commit() | |||
| running_workflow_node_executions = ( | |||
| db.session.query(WorkflowNodeExecution) | |||
| .filter( | |||
| WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, | |||
| WorkflowNodeExecution.app_id == workflow_run.app_id, | |||
| WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, | |||
| WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, | |||
| WorkflowNodeExecution.workflow_run_id == workflow_run.id, | |||
| WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, | |||
| ) | |||
| .all() | |||
| stmt = select(WorkflowNodeExecution).where( | |||
| WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, | |||
| WorkflowNodeExecution.app_id == workflow_run.app_id, | |||
| WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, | |||
| WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, | |||
| WorkflowNodeExecution.workflow_run_id == workflow_run.id, | |||
| WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, | |||
| ) | |||
| running_workflow_node_executions = session.scalars(stmt).all() | |||
| for workflow_node_execution in running_workflow_node_executions: | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value | |||
| workflow_node_execution.error = error | |||
| @@ -271,13 +270,6 @@ class WorkflowCycleManage: | |||
| workflow_node_execution.elapsed_time = ( | |||
| workflow_node_execution.finished_at - workflow_node_execution.created_at | |||
| ).total_seconds() | |||
| db.session.commit() | |||
| db.session.close() | |||
| # with Session(db.engine, expire_on_commit=False) as session: | |||
| # session.add(workflow_run) | |||
| # session.refresh(workflow_run) | |||
| if trace_manager: | |||
| trace_manager.add_trace_task( | |||
| @@ -485,14 +477,14 @@ class WorkflowCycleManage: | |||
| ################################################# | |||
| def _workflow_start_to_stream_response( | |||
| self, task_id: str, workflow_run: WorkflowRun | |||
| self, | |||
| *, | |||
| session: Session, | |||
| task_id: str, | |||
| workflow_run: WorkflowRun, | |||
| ) -> WorkflowStartStreamResponse: | |||
| """ | |||
| Workflow start to stream response. | |||
| :param task_id: task id | |||
| :param workflow_run: workflow run | |||
| :return: | |||
| """ | |||
| # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this | |||
| _ = session | |||
| return WorkflowStartStreamResponse( | |||
| task_id=task_id, | |||
| workflow_run_id=workflow_run.id, | |||
| @@ -506,36 +498,32 @@ class WorkflowCycleManage: | |||
| ) | |||
| def _workflow_finish_to_stream_response( | |||
| self, task_id: str, workflow_run: WorkflowRun | |||
| self, | |||
| *, | |||
| session: Session, | |||
| task_id: str, | |||
| workflow_run: WorkflowRun, | |||
| ) -> WorkflowFinishStreamResponse: | |||
| """ | |||
| Workflow finish to stream response. | |||
| :param task_id: task id | |||
| :param workflow_run: workflow run | |||
| :return: | |||
| """ | |||
| # Attach WorkflowRun to an active session so "created_by_role" can be accessed. | |||
| workflow_run = db.session.merge(workflow_run) | |||
| # Refresh to ensure any expired attributes are fully loaded | |||
| db.session.refresh(workflow_run) | |||
| created_by = None | |||
| if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value: | |||
| created_by_account = workflow_run.created_by_account | |||
| if created_by_account: | |||
| if workflow_run.created_by_role == CreatedByRole.ACCOUNT: | |||
| stmt = select(Account).where(Account.id == workflow_run.created_by) | |||
| account = session.scalar(stmt) | |||
| if account: | |||
| created_by = { | |||
| "id": created_by_account.id, | |||
| "name": created_by_account.name, | |||
| "email": created_by_account.email, | |||
| "id": account.id, | |||
| "name": account.name, | |||
| "email": account.email, | |||
| } | |||
| else: | |||
| created_by_end_user = workflow_run.created_by_end_user | |||
| if created_by_end_user: | |||
| elif workflow_run.created_by_role == CreatedByRole.END_USER: | |||
| stmt = select(EndUser).where(EndUser.id == workflow_run.created_by) | |||
| end_user = session.scalar(stmt) | |||
| if end_user: | |||
| created_by = { | |||
| "id": created_by_end_user.id, | |||
| "user": created_by_end_user.session_id, | |||
| "id": end_user.id, | |||
| "user": end_user.session_id, | |||
| } | |||
| else: | |||
| raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}") | |||
| return WorkflowFinishStreamResponse( | |||
| task_id=task_id, | |||
| @@ -895,14 +883,14 @@ class WorkflowCycleManage: | |||
| return None | |||
| def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: | |||
| def _refetch_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun: | |||
| """ | |||
| Refetch workflow run | |||
| :param workflow_run_id: workflow run id | |||
| :return: | |||
| """ | |||
| workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() | |||
| stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) | |||
| workflow_run = session.scalar(stmt) | |||
| if not workflow_run: | |||
| raise WorkflowRunNotFoundError(workflow_run_id) | |||
| @@ -9,6 +9,8 @@ from typing import Any, Optional, Union | |||
| from uuid import UUID, uuid4 | |||
| from flask import current_app | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token | |||
| from core.ops.entities.config_entity import ( | |||
| @@ -329,15 +331,15 @@ class TraceTask: | |||
| ): | |||
| self.trace_type = trace_type | |||
| self.message_id = message_id | |||
| self.workflow_run = workflow_run | |||
| self.workflow_run_id = workflow_run.id if workflow_run else None | |||
| self.conversation_id = conversation_id | |||
| self.user_id = user_id | |||
| self.timer = timer | |||
| self.kwargs = kwargs | |||
| self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") | |||
| self.app_id = None | |||
| self.kwargs = kwargs | |||
| def execute(self): | |||
| return self.preprocess() | |||
| @@ -345,19 +347,23 @@ class TraceTask: | |||
| preprocess_map = { | |||
| TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs), | |||
| TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace( | |||
| self.workflow_run, self.conversation_id, self.user_id | |||
| workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id | |||
| ), | |||
| TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id), | |||
| TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace( | |||
| message_id=self.message_id, timer=self.timer, **self.kwargs | |||
| ), | |||
| TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id), | |||
| TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(self.message_id, self.timer, **self.kwargs), | |||
| TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace( | |||
| self.message_id, self.timer, **self.kwargs | |||
| message_id=self.message_id, timer=self.timer, **self.kwargs | |||
| ), | |||
| TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace( | |||
| self.message_id, self.timer, **self.kwargs | |||
| message_id=self.message_id, timer=self.timer, **self.kwargs | |||
| ), | |||
| TraceTaskName.TOOL_TRACE: lambda: self.tool_trace( | |||
| message_id=self.message_id, timer=self.timer, **self.kwargs | |||
| ), | |||
| TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(self.message_id, self.timer, **self.kwargs), | |||
| TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace( | |||
| self.conversation_id, self.timer, **self.kwargs | |||
| conversation_id=self.conversation_id, timer=self.timer, **self.kwargs | |||
| ), | |||
| } | |||
| @@ -367,86 +373,100 @@ class TraceTask: | |||
| def conversation_trace(self, **kwargs): | |||
| return kwargs | |||
| def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id): | |||
| if not workflow_run: | |||
| raise ValueError("Workflow run not found") | |||
| db.session.merge(workflow_run) | |||
| db.session.refresh(workflow_run) | |||
| workflow_id = workflow_run.workflow_id | |||
| tenant_id = workflow_run.tenant_id | |||
| workflow_run_id = workflow_run.id | |||
| workflow_run_elapsed_time = workflow_run.elapsed_time | |||
| workflow_run_status = workflow_run.status | |||
| workflow_run_inputs = workflow_run.inputs_dict | |||
| workflow_run_outputs = workflow_run.outputs_dict | |||
| workflow_run_version = workflow_run.version | |||
| error = workflow_run.error or "" | |||
| total_tokens = workflow_run.total_tokens | |||
| file_list = workflow_run_inputs.get("sys.file") or [] | |||
| query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" | |||
| # get workflow_app_log_id | |||
| workflow_app_log_data = ( | |||
| db.session.query(WorkflowAppLog) | |||
| .filter_by(tenant_id=tenant_id, app_id=workflow_run.app_id, workflow_run_id=workflow_run.id) | |||
| .first() | |||
| ) | |||
| workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None | |||
| # get message_id | |||
| message_data = ( | |||
| db.session.query(Message.id) | |||
| .filter_by(conversation_id=conversation_id, workflow_run_id=workflow_run_id) | |||
| .first() | |||
| ) | |||
| message_id = str(message_data.id) if message_data else None | |||
| metadata = { | |||
| "workflow_id": workflow_id, | |||
| "conversation_id": conversation_id, | |||
| "workflow_run_id": workflow_run_id, | |||
| "tenant_id": tenant_id, | |||
| "elapsed_time": workflow_run_elapsed_time, | |||
| "status": workflow_run_status, | |||
| "version": workflow_run_version, | |||
| "total_tokens": total_tokens, | |||
| "file_list": file_list, | |||
| "triggered_form": workflow_run.triggered_from, | |||
| "user_id": user_id, | |||
| } | |||
| def workflow_trace( | |||
| self, | |||
| *, | |||
| workflow_run_id: str | None, | |||
| conversation_id: str | None, | |||
| user_id: str | None, | |||
| ): | |||
| if not workflow_run_id: | |||
| return {} | |||
| workflow_trace_info = WorkflowTraceInfo( | |||
| workflow_data=workflow_run.to_dict(), | |||
| conversation_id=conversation_id, | |||
| workflow_id=workflow_id, | |||
| tenant_id=tenant_id, | |||
| workflow_run_id=workflow_run_id, | |||
| workflow_run_elapsed_time=workflow_run_elapsed_time, | |||
| workflow_run_status=workflow_run_status, | |||
| workflow_run_inputs=workflow_run_inputs, | |||
| workflow_run_outputs=workflow_run_outputs, | |||
| workflow_run_version=workflow_run_version, | |||
| error=error, | |||
| total_tokens=total_tokens, | |||
| file_list=file_list, | |||
| query=query, | |||
| metadata=metadata, | |||
| workflow_app_log_id=workflow_app_log_id, | |||
| message_id=message_id, | |||
| start_time=workflow_run.created_at, | |||
| end_time=workflow_run.finished_at, | |||
| ) | |||
| with Session(db.engine) as session: | |||
| workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) | |||
| workflow_run = session.scalars(workflow_run_stmt).first() | |||
| if not workflow_run: | |||
| raise ValueError("Workflow run not found") | |||
| workflow_id = workflow_run.workflow_id | |||
| tenant_id = workflow_run.tenant_id | |||
| workflow_run_id = workflow_run.id | |||
| workflow_run_elapsed_time = workflow_run.elapsed_time | |||
| workflow_run_status = workflow_run.status | |||
| workflow_run_inputs = workflow_run.inputs_dict | |||
| workflow_run_outputs = workflow_run.outputs_dict | |||
| workflow_run_version = workflow_run.version | |||
| error = workflow_run.error or "" | |||
| total_tokens = workflow_run.total_tokens | |||
| file_list = workflow_run_inputs.get("sys.file") or [] | |||
| query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" | |||
| # get workflow_app_log_id | |||
| workflow_app_log_data_stmt = select(WorkflowAppLog.id).where( | |||
| WorkflowAppLog.tenant_id == tenant_id, | |||
| WorkflowAppLog.app_id == workflow_run.app_id, | |||
| WorkflowAppLog.workflow_run_id == workflow_run.id, | |||
| ) | |||
| workflow_app_log_id = session.scalar(workflow_app_log_data_stmt) | |||
| # get message_id | |||
| message_id = None | |||
| if conversation_id: | |||
| message_data_stmt = select(Message.id).where( | |||
| Message.conversation_id == conversation_id, | |||
| Message.workflow_run_id == workflow_run_id, | |||
| ) | |||
| message_id = session.scalar(message_data_stmt) | |||
| metadata = { | |||
| "workflow_id": workflow_id, | |||
| "conversation_id": conversation_id, | |||
| "workflow_run_id": workflow_run_id, | |||
| "tenant_id": tenant_id, | |||
| "elapsed_time": workflow_run_elapsed_time, | |||
| "status": workflow_run_status, | |||
| "version": workflow_run_version, | |||
| "total_tokens": total_tokens, | |||
| "file_list": file_list, | |||
| "triggered_form": workflow_run.triggered_from, | |||
| "user_id": user_id, | |||
| } | |||
| workflow_trace_info = WorkflowTraceInfo( | |||
| workflow_data=workflow_run.to_dict(), | |||
| conversation_id=conversation_id, | |||
| workflow_id=workflow_id, | |||
| tenant_id=tenant_id, | |||
| workflow_run_id=workflow_run_id, | |||
| workflow_run_elapsed_time=workflow_run_elapsed_time, | |||
| workflow_run_status=workflow_run_status, | |||
| workflow_run_inputs=workflow_run_inputs, | |||
| workflow_run_outputs=workflow_run_outputs, | |||
| workflow_run_version=workflow_run_version, | |||
| error=error, | |||
| total_tokens=total_tokens, | |||
| file_list=file_list, | |||
| query=query, | |||
| metadata=metadata, | |||
| workflow_app_log_id=workflow_app_log_id, | |||
| message_id=message_id, | |||
| start_time=workflow_run.created_at, | |||
| end_time=workflow_run.finished_at, | |||
| ) | |||
| return workflow_trace_info | |||
| def message_trace(self, message_id): | |||
| def message_trace(self, message_id: str | None): | |||
| if not message_id: | |||
| return {} | |||
| message_data = get_message_data(message_id) | |||
| if not message_data: | |||
| return {} | |||
| conversation_mode = db.session.query(Conversation.mode).filter_by(id=message_data.conversation_id).first() | |||
| conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id) | |||
| conversation_mode = db.session.scalars(conversation_mode_stmt).all() | |||
| if not conversation_mode or len(conversation_mode) == 0: | |||
| return {} | |||
| conversation_mode = conversation_mode[0] | |||
| created_at = message_data.created_at | |||
| inputs = message_data.message | |||
| @@ -18,7 +18,7 @@ def filter_none_values(data: dict): | |||
| return new_data | |||
| def get_message_data(message_id): | |||
| def get_message_data(message_id: str): | |||
| return db.session.query(Message).filter(Message.id == message_id).first() | |||
| @@ -3,6 +3,7 @@ import json | |||
| from flask_login import UserMixin # type: ignore | |||
| from sqlalchemy import func | |||
| from sqlalchemy.orm import Mapped, mapped_column | |||
| from .engine import db | |||
| from .types import StringUUID | |||
| @@ -20,7 +21,7 @@ class Account(UserMixin, db.Model): # type: ignore[name-defined] | |||
| __tablename__ = "accounts" | |||
| __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| name = db.Column(db.String(255), nullable=False) | |||
| email = db.Column(db.String(255), nullable=False) | |||
| password = db.Column(db.String(255), nullable=True) | |||
| @@ -530,13 +530,13 @@ class Conversation(db.Model): # type: ignore[name-defined] | |||
| db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| app_id = db.Column(StringUUID, nullable=False) | |||
| app_model_config_id = db.Column(StringUUID, nullable=True) | |||
| model_provider = db.Column(db.String(255), nullable=True) | |||
| override_model_configs = db.Column(db.Text) | |||
| model_id = db.Column(db.String(255), nullable=True) | |||
| mode = db.Column(db.String(255), nullable=False) | |||
| mode: Mapped[str] = mapped_column(db.String(255)) | |||
| name = db.Column(db.String(255), nullable=False) | |||
| summary = db.Column(db.Text) | |||
| _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) | |||
| @@ -770,7 +770,7 @@ class Message(db.Model): # type: ignore[name-defined] | |||
| db.Index("message_created_at_idx", "created_at"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| app_id = db.Column(StringUUID, nullable=False) | |||
| model_provider = db.Column(db.String(255), nullable=True) | |||
| model_id = db.Column(db.String(255), nullable=True) | |||
| @@ -797,7 +797,7 @@ class Message(db.Model): # type: ignore[name-defined] | |||
| from_source = db.Column(db.String(255), nullable=False) | |||
| from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID) | |||
| from_account_id: Mapped[Optional[str]] = db.Column(StringUUID) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) | |||
| workflow_run_id = db.Column(StringUUID) | |||
| @@ -1322,7 +1322,7 @@ class EndUser(UserMixin, db.Model): # type: ignore[name-defined] | |||
| external_user_id = db.Column(db.String(255), nullable=True) | |||
| name = db.Column(db.String(255)) | |||
| is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) | |||
| session_id = db.Column(db.String(255), nullable=False) | |||
| session_id: Mapped[str] = mapped_column() | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| @@ -392,40 +392,28 @@ class WorkflowRun(db.Model): # type: ignore[name-defined] | |||
| db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| app_id = db.Column(StringUUID, nullable=False) | |||
| sequence_number = db.Column(db.Integer, nullable=False) | |||
| workflow_id = db.Column(StringUUID, nullable=False) | |||
| type = db.Column(db.String(255), nullable=False) | |||
| triggered_from = db.Column(db.String(255), nullable=False) | |||
| version = db.Column(db.String(255), nullable=False) | |||
| graph = db.Column(db.Text) | |||
| inputs = db.Column(db.Text) | |||
| status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded | |||
| id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id: Mapped[str] = mapped_column(StringUUID) | |||
| app_id: Mapped[str] = mapped_column(StringUUID) | |||
| sequence_number: Mapped[int] = mapped_column() | |||
| workflow_id: Mapped[str] = mapped_column(StringUUID) | |||
| type: Mapped[str] = mapped_column(db.String(255)) | |||
| triggered_from: Mapped[str] = mapped_column(db.String(255)) | |||
| version: Mapped[str] = mapped_column(db.String(255)) | |||
| graph: Mapped[str] = mapped_column(db.Text) | |||
| inputs: Mapped[str] = mapped_column(db.Text) | |||
| status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded | |||
| outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") | |||
| error = db.Column(db.Text) | |||
| error: Mapped[str] = mapped_column(db.Text) | |||
| elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) | |||
| total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) | |||
| total_tokens: Mapped[int] = mapped_column(server_default=db.text("0")) | |||
| total_steps = db.Column(db.Integer, server_default=db.text("0")) | |||
| created_by_role = db.Column(db.String(255), nullable=False) # account, end_user | |||
| created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user | |||
| created_by = db.Column(StringUUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| finished_at = db.Column(db.DateTime) | |||
| exceptions_count = db.Column(db.Integer, server_default=db.text("0")) | |||
| @property | |||
| def created_by_account(self): | |||
| created_by_role = CreatedByRole(self.created_by_role) | |||
| return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None | |||
| @property | |||
| def created_by_end_user(self): | |||
| from models.model import EndUser | |||
| created_by_role = CreatedByRole(self.created_by_role) | |||
| return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None | |||
| @property | |||
| def graph_dict(self): | |||
| return json.loads(self.graph) if self.graph else {} | |||
| @@ -750,11 +738,11 @@ class WorkflowAppLog(db.Model): # type: ignore[name-defined] | |||
| db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| app_id = db.Column(StringUUID, nullable=False) | |||
| id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id: Mapped[str] = mapped_column(StringUUID) | |||
| app_id: Mapped[str] = mapped_column(StringUUID) | |||
| workflow_id = db.Column(StringUUID, nullable=False) | |||
| workflow_run_id = db.Column(StringUUID, nullable=False) | |||
| workflow_run_id: Mapped[str] = mapped_column(StringUUID) | |||
| created_from = db.Column(db.String(255), nullable=False) | |||
| created_by_role = db.Column(db.String(255), nullable=False) | |||
| created_by = db.Column(StringUUID, nullable=False) | |||