Signed-off-by: -LAN- <laipz8200@outlook.com>tags/1.4.2
| import json | |||||
| import logging | import logging | ||||
| import time | import time | ||||
| from collections.abc import Generator, Mapping | from collections.abc import Generator, Mapping | ||||
| WorkflowTaskState, | WorkflowTaskState, | ||||
| ) | ) | ||||
| from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | ||||
| from core.app.task_pipeline.message_cycle_manage import MessageCycleManage | |||||
| from core.app.task_pipeline.message_cycle_manager import MessageCycleManager | |||||
| from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | ||||
| from core.model_runtime.entities.llm_entities import LLMUsage | from core.model_runtime.entities.llm_entities import LLMUsage | ||||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||||
| from core.ops.ops_trace_manager import TraceQueueManager | from core.ops.ops_trace_manager import TraceQueueManager | ||||
| from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType | from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType | ||||
| from core.workflow.enums import SystemVariableKey | from core.workflow.enums import SystemVariableKey | ||||
| ) | ) | ||||
| self._task_state = WorkflowTaskState() | self._task_state = WorkflowTaskState() | ||||
| self._message_cycle_manager = MessageCycleManage( | |||||
| self._message_cycle_manager = MessageCycleManager( | |||||
| application_generate_entity=application_generate_entity, task_state=self._task_state | application_generate_entity=application_generate_entity, task_state=self._task_state | ||||
| ) | ) | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| # start generate conversation name thread | # start generate conversation name thread | ||||
| self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name( | |||||
| self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name( | |||||
| conversation_id=self._conversation_id, query=self._application_generate_entity.query | conversation_id=self._conversation_id, query=self._application_generate_entity.query | ||||
| ) | ) | ||||
| yield self._message_end_to_stream_response() | yield self._message_end_to_stream_response() | ||||
| break | break | ||||
| elif isinstance(event, QueueRetrieverResourcesEvent): | elif isinstance(event, QueueRetrieverResourcesEvent): | ||||
| self._message_cycle_manager._handle_retriever_resources(event) | |||||
| self._message_cycle_manager.handle_retriever_resources(event) | |||||
| with Session(db.engine, expire_on_commit=False) as session: | with Session(db.engine, expire_on_commit=False) as session: | ||||
| message = self._get_message(session=session) | message = self._get_message(session=session) | ||||
| message.message_metadata = ( | |||||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||||
| ) | |||||
| message.message_metadata = self._task_state.metadata.model_dump_json() | |||||
| session.commit() | session.commit() | ||||
| elif isinstance(event, QueueAnnotationReplyEvent): | elif isinstance(event, QueueAnnotationReplyEvent): | ||||
| self._message_cycle_manager._handle_annotation_reply(event) | |||||
| self._message_cycle_manager.handle_annotation_reply(event) | |||||
| with Session(db.engine, expire_on_commit=False) as session: | with Session(db.engine, expire_on_commit=False) as session: | ||||
| message = self._get_message(session=session) | message = self._get_message(session=session) | ||||
| message.message_metadata = ( | |||||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||||
| ) | |||||
| message.message_metadata = self._task_state.metadata.model_dump_json() | |||||
| session.commit() | session.commit() | ||||
| elif isinstance(event, QueueTextChunkEvent): | elif isinstance(event, QueueTextChunkEvent): | ||||
| delta_text = event.text | delta_text = event.text | ||||
| tts_publisher.publish(queue_message) | tts_publisher.publish(queue_message) | ||||
| self._task_state.answer += delta_text | self._task_state.answer += delta_text | ||||
| yield self._message_cycle_manager._message_to_stream_response( | |||||
| yield self._message_cycle_manager.message_to_stream_response( | |||||
| answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector | answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector | ||||
| ) | ) | ||||
| elif isinstance(event, QueueMessageReplaceEvent): | elif isinstance(event, QueueMessageReplaceEvent): | ||||
| # published by moderation | # published by moderation | ||||
| yield self._message_cycle_manager._message_replace_to_stream_response( | |||||
| yield self._message_cycle_manager.message_replace_to_stream_response( | |||||
| answer=event.text, reason=event.reason | answer=event.text, reason=event.reason | ||||
| ) | ) | ||||
| elif isinstance(event, QueueAdvancedChatMessageEndEvent): | elif isinstance(event, QueueAdvancedChatMessageEndEvent): | ||||
| ) | ) | ||||
| if output_moderation_answer: | if output_moderation_answer: | ||||
| self._task_state.answer = output_moderation_answer | self._task_state.answer = output_moderation_answer | ||||
| yield self._message_cycle_manager._message_replace_to_stream_response( | |||||
| yield self._message_cycle_manager.message_replace_to_stream_response( | |||||
| answer=output_moderation_answer, | answer=output_moderation_answer, | ||||
| reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION, | reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION, | ||||
| ) | ) | ||||
| message = self._get_message(session=session) | message = self._get_message(session=session) | ||||
| message.answer = self._task_state.answer | message.answer = self._task_state.answer | ||||
| message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at | message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at | ||||
| message.message_metadata = ( | |||||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||||
| ) | |||||
| message.message_metadata = self._task_state.metadata.model_dump_json() | |||||
| message_files = [ | message_files = [ | ||||
| MessageFile( | MessageFile( | ||||
| message_id=message.id, | message_id=message.id, | ||||
| message.answer_price_unit = usage.completion_price_unit | message.answer_price_unit = usage.completion_price_unit | ||||
| message.total_price = usage.total_price | message.total_price = usage.total_price | ||||
| message.currency = usage.currency | message.currency = usage.currency | ||||
| self._task_state.metadata["usage"] = jsonable_encoder(usage) | |||||
| self._task_state.metadata.usage = usage | |||||
| else: | else: | ||||
| self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage()) | |||||
| self._task_state.metadata.usage = LLMUsage.empty_usage() | |||||
| message_was_created.send( | message_was_created.send( | ||||
| message, | message, | ||||
| application_generate_entity=self._application_generate_entity, | application_generate_entity=self._application_generate_entity, | ||||
| Message end to stream response. | Message end to stream response. | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| extras = {} | |||||
| if self._task_state.metadata: | |||||
| extras["metadata"] = self._task_state.metadata.copy() | |||||
| extras = self._task_state.metadata.model_dump() | |||||
| if "annotation_reply" in extras["metadata"]: | |||||
| del extras["metadata"]["annotation_reply"] | |||||
| if self._task_state.metadata.annotation_reply: | |||||
| del extras["annotation_reply"] | |||||
| return MessageEndStreamResponse( | return MessageEndStreamResponse( | ||||
| task_id=self._application_generate_entity.task_id, | task_id=self._application_generate_entity.task_id, | ||||
| id=self._message_id, | id=self._message_id, | ||||
| files=self._recorded_files, | files=self._recorded_files, | ||||
| metadata=extras.get("metadata", {}), | |||||
| metadata=extras, | |||||
| ) | ) | ||||
| def _handle_output_moderation_chunk(self, text: str) -> bool: | def _handle_output_moderation_chunk(self, text: str) -> bool: |
| WorkflowAppStreamResponse, | WorkflowAppStreamResponse, | ||||
| WorkflowFinishStreamResponse, | WorkflowFinishStreamResponse, | ||||
| WorkflowStartStreamResponse, | WorkflowStartStreamResponse, | ||||
| WorkflowTaskState, | |||||
| ) | ) | ||||
| from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | ||||
| from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | ||||
| ) | ) | ||||
| self._application_generate_entity = application_generate_entity | self._application_generate_entity = application_generate_entity | ||||
| self._workflow_id = workflow.id | |||||
| self._workflow_features_dict = workflow.features_dict | self._workflow_features_dict = workflow.features_dict | ||||
| self._task_state = WorkflowTaskState() | |||||
| self._workflow_run_id = "" | self._workflow_run_id = "" | ||||
| def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: | def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: | ||||
| if tts_publisher: | if tts_publisher: | ||||
| tts_publisher.publish(queue_message) | tts_publisher.publish(queue_message) | ||||
| self._task_state.answer += delta_text | |||||
| yield self._text_chunk_to_stream_response( | yield self._text_chunk_to_stream_response( | ||||
| delta_text, from_variable_selector=event.from_variable_selector | delta_text, from_variable_selector=event.from_variable_selector | ||||
| ) | ) |
| from collections.abc import Mapping | |||||
| from collections.abc import Mapping, Sequence | |||||
| from datetime import datetime | from datetime import datetime | ||||
| from enum import Enum, StrEnum | from enum import Enum, StrEnum | ||||
| from typing import Any, Optional | from typing import Any, Optional | ||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk | from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk | ||||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit | from core.workflow.entities.node_entities import AgentNodeStrategyInit | ||||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey | from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey | ||||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | ||||
| """ | """ | ||||
| event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES | event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES | ||||
| retriever_resources: list[dict] | |||||
| retriever_resources: Sequence[RetrievalSourceMetadata] | |||||
| in_iteration_id: Optional[str] = None | in_iteration_id: Optional[str] = None | ||||
| """iteration id if node is in iteration""" | """iteration id if node is in iteration""" | ||||
| in_loop_id: Optional[str] = None | in_loop_id: Optional[str] = None |
| from enum import Enum | from enum import Enum | ||||
| from typing import Any, Optional | from typing import Any, Optional | ||||
| from pydantic import BaseModel, ConfigDict | |||||
| from pydantic import BaseModel, ConfigDict, Field | |||||
| from core.model_runtime.entities.llm_entities import LLMResult | |||||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage | |||||
| from core.model_runtime.utils.encoders import jsonable_encoder | from core.model_runtime.utils.encoders import jsonable_encoder | ||||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit | from core.workflow.entities.node_entities import AgentNodeStrategyInit | ||||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | ||||
| class AnnotationReplyAccount(BaseModel): | |||||
| id: str | |||||
| name: str | |||||
| class AnnotationReply(BaseModel): | |||||
| id: str | |||||
| account: AnnotationReplyAccount | |||||
| class TaskStateMetadata(BaseModel): | |||||
| annotation_reply: AnnotationReply | None = None | |||||
| retriever_resources: Sequence[RetrievalSourceMetadata] = Field(default_factory=list) | |||||
| usage: LLMUsage | None = None | |||||
| class TaskState(BaseModel): | class TaskState(BaseModel): | ||||
| """ | """ | ||||
| TaskState entity | TaskState entity | ||||
| """ | """ | ||||
| metadata: dict = {} | |||||
| metadata: TaskStateMetadata = Field(default_factory=TaskStateMetadata) | |||||
| class EasyUITaskState(TaskState): | class EasyUITaskState(TaskState): |
| import json | |||||
| import logging | import logging | ||||
| import time | import time | ||||
| from collections.abc import Generator | from collections.abc import Generator | ||||
| StreamResponse, | StreamResponse, | ||||
| ) | ) | ||||
| from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | ||||
| from core.app.task_pipeline.message_cycle_manage import MessageCycleManage | |||||
| from core.app.task_pipeline.message_cycle_manager import MessageCycleManager | |||||
| from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | ||||
| from core.model_manager import ModelInstance | from core.model_manager import ModelInstance | ||||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage | from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage | ||||
| AssistantPromptMessage, | AssistantPromptMessage, | ||||
| ) | ) | ||||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | ||||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||||
| from core.ops.entities.trace_entity import TraceTaskName | from core.ops.entities.trace_entity import TraceTaskName | ||||
| from core.ops.ops_trace_manager import TraceQueueManager, TraceTask | from core.ops.ops_trace_manager import TraceQueueManager, TraceTask | ||||
| from core.prompt.utils.prompt_message_util import PromptMessageUtil | from core.prompt.utils.prompt_message_util import PromptMessageUtil | ||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleManage): | |||||
| class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): | |||||
| """ | """ | ||||
| EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application. | EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application. | ||||
| """ | """ | ||||
| ) | ) | ||||
| ) | ) | ||||
| self._message_cycle_manager = MessageCycleManager( | |||||
| application_generate_entity=application_generate_entity, | |||||
| task_state=self._task_state, | |||||
| ) | |||||
| self._conversation_name_generate_thread: Optional[Thread] = None | self._conversation_name_generate_thread: Optional[Thread] = None | ||||
| def process( | def process( | ||||
| ]: | ]: | ||||
| if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: | if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: | ||||
| # start generate conversation name thread | # start generate conversation name thread | ||||
| self._conversation_name_generate_thread = self._generate_conversation_name( | |||||
| self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name( | |||||
| conversation_id=self._conversation_id, query=self._application_generate_entity.query or "" | conversation_id=self._conversation_id, query=self._application_generate_entity.query or "" | ||||
| ) | ) | ||||
| if isinstance(stream_response, ErrorStreamResponse): | if isinstance(stream_response, ErrorStreamResponse): | ||||
| raise stream_response.err | raise stream_response.err | ||||
| elif isinstance(stream_response, MessageEndStreamResponse): | elif isinstance(stream_response, MessageEndStreamResponse): | ||||
| extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)} | |||||
| extras = {"usage": self._task_state.llm_result.usage.model_dump()} | |||||
| if self._task_state.metadata: | if self._task_state.metadata: | ||||
| extras["metadata"] = self._task_state.metadata | |||||
| extras["metadata"] = self._task_state.metadata.model_dump() | |||||
| response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] | response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] | ||||
| if self._conversation_mode == AppMode.COMPLETION.value: | if self._conversation_mode == AppMode.COMPLETION.value: | ||||
| response = CompletionAppBlockingResponse( | response = CompletionAppBlockingResponse( | ||||
| ) | ) | ||||
| if output_moderation_answer: | if output_moderation_answer: | ||||
| self._task_state.llm_result.message.content = output_moderation_answer | self._task_state.llm_result.message.content = output_moderation_answer | ||||
| yield self._message_replace_to_stream_response(answer=output_moderation_answer) | |||||
| yield self._message_cycle_manager.message_replace_to_stream_response( | |||||
| answer=output_moderation_answer | |||||
| ) | |||||
| with Session(db.engine) as session: | with Session(db.engine) as session: | ||||
| # Save message | # Save message | ||||
| message_end_resp = self._message_end_to_stream_response() | message_end_resp = self._message_end_to_stream_response() | ||||
| yield message_end_resp | yield message_end_resp | ||||
| elif isinstance(event, QueueRetrieverResourcesEvent): | elif isinstance(event, QueueRetrieverResourcesEvent): | ||||
| self._handle_retriever_resources(event) | |||||
| self._message_cycle_manager.handle_retriever_resources(event) | |||||
| elif isinstance(event, QueueAnnotationReplyEvent): | elif isinstance(event, QueueAnnotationReplyEvent): | ||||
| annotation = self._handle_annotation_reply(event) | |||||
| annotation = self._message_cycle_manager.handle_annotation_reply(event) | |||||
| if annotation: | if annotation: | ||||
| self._task_state.llm_result.message.content = annotation.content | self._task_state.llm_result.message.content = annotation.content | ||||
| elif isinstance(event, QueueAgentThoughtEvent): | elif isinstance(event, QueueAgentThoughtEvent): | ||||
| if agent_thought_response is not None: | if agent_thought_response is not None: | ||||
| yield agent_thought_response | yield agent_thought_response | ||||
| elif isinstance(event, QueueMessageFileEvent): | elif isinstance(event, QueueMessageFileEvent): | ||||
| response = self._message_file_to_stream_response(event) | |||||
| response = self._message_cycle_manager.message_file_to_stream_response(event) | |||||
| if response: | if response: | ||||
| yield response | yield response | ||||
| elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent): | elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent): | ||||
| self._task_state.llm_result.message.content = current_content | self._task_state.llm_result.message.content = current_content | ||||
| if isinstance(event, QueueLLMChunkEvent): | if isinstance(event, QueueLLMChunkEvent): | ||||
| yield self._message_to_stream_response( | |||||
| yield self._message_cycle_manager.message_to_stream_response( | |||||
| answer=cast(str, delta_text), | answer=cast(str, delta_text), | ||||
| message_id=self._message_id, | message_id=self._message_id, | ||||
| ) | ) | ||||
| message_id=self._message_id, | message_id=self._message_id, | ||||
| ) | ) | ||||
| elif isinstance(event, QueueMessageReplaceEvent): | elif isinstance(event, QueueMessageReplaceEvent): | ||||
| yield self._message_replace_to_stream_response(answer=event.text) | |||||
| yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text) | |||||
| elif isinstance(event, QueuePingEvent): | elif isinstance(event, QueuePingEvent): | ||||
| yield self._ping_stream_response() | yield self._ping_stream_response() | ||||
| else: | else: | ||||
| message.provider_response_latency = time.perf_counter() - self._start_at | message.provider_response_latency = time.perf_counter() - self._start_at | ||||
| message.total_price = usage.total_price | message.total_price = usage.total_price | ||||
| message.currency = usage.currency | message.currency = usage.currency | ||||
| message.message_metadata = ( | |||||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||||
| ) | |||||
| message.message_metadata = self._task_state.metadata.model_dump_json() | |||||
| if trace_manager: | if trace_manager: | ||||
| trace_manager.add_trace_task( | trace_manager.add_trace_task( | ||||
| Message end to stream response. | Message end to stream response. | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage) | |||||
| extras = {} | |||||
| if self._task_state.metadata: | |||||
| extras["metadata"] = self._task_state.metadata | |||||
| self._task_state.metadata.usage = self._task_state.llm_result.usage | |||||
| metadata_dict = self._task_state.metadata.model_dump() | |||||
| return MessageEndStreamResponse( | return MessageEndStreamResponse( | ||||
| task_id=self._application_generate_entity.task_id, | task_id=self._application_generate_entity.task_id, | ||||
| id=self._message_id, | id=self._message_id, | ||||
| metadata=extras.get("metadata", {}), | |||||
| metadata=metadata_dict, | |||||
| ) | ) | ||||
| def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: | def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: |
| QueueRetrieverResourcesEvent, | QueueRetrieverResourcesEvent, | ||||
| ) | ) | ||||
| from core.app.entities.task_entities import ( | from core.app.entities.task_entities import ( | ||||
| AnnotationReply, | |||||
| AnnotationReplyAccount, | |||||
| EasyUITaskState, | EasyUITaskState, | ||||
| MessageFileStreamResponse, | MessageFileStreamResponse, | ||||
| MessageReplaceStreamResponse, | MessageReplaceStreamResponse, | ||||
| from services.annotation_service import AppAnnotationService | from services.annotation_service import AppAnnotationService | ||||
| class MessageCycleManage: | |||||
| class MessageCycleManager: | |||||
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| *, | *, | ||||
| self._application_generate_entity = application_generate_entity | self._application_generate_entity = application_generate_entity | ||||
| self._task_state = task_state | self._task_state = task_state | ||||
| def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: | |||||
| def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: | |||||
| """ | """ | ||||
| Generate conversation name. | Generate conversation name. | ||||
| :param conversation_id: conversation id | :param conversation_id: conversation id | ||||
| db.session.commit() | db.session.commit() | ||||
| db.session.close() | db.session.close() | ||||
| def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]: | |||||
| def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]: | |||||
| """ | """ | ||||
| Handle annotation reply. | Handle annotation reply. | ||||
| :param event: event | :param event: event | ||||
| annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) | annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) | ||||
| if annotation: | if annotation: | ||||
| account = annotation.account | account = annotation.account | ||||
| self._task_state.metadata["annotation_reply"] = { | |||||
| "id": annotation.id, | |||||
| "account": {"id": annotation.account_id, "name": account.name if account else "Dify user"}, | |||||
| } | |||||
| self._task_state.metadata.annotation_reply = AnnotationReply( | |||||
| id=annotation.id, | |||||
| account=AnnotationReplyAccount( | |||||
| id=annotation.account_id, | |||||
| name=account.name if account else "Dify user", | |||||
| ), | |||||
| ) | |||||
| return annotation | return annotation | ||||
| return None | return None | ||||
| def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None: | |||||
| def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None: | |||||
| """ | """ | ||||
| Handle retriever resources. | Handle retriever resources. | ||||
| :param event: event | :param event: event | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| if self._application_generate_entity.app_config.additional_features.show_retrieve_source: | if self._application_generate_entity.app_config.additional_features.show_retrieve_source: | ||||
| self._task_state.metadata["retriever_resources"] = event.retriever_resources | |||||
| self._task_state.metadata.retriever_resources = event.retriever_resources | |||||
| def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: | |||||
| def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: | |||||
| """ | """ | ||||
| Message file to stream response. | Message file to stream response. | ||||
| :param event: event | :param event: event | ||||
| return None | return None | ||||
| def _message_to_stream_response( | |||||
| def message_to_stream_response( | |||||
| self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None | self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None | ||||
| ) -> MessageStreamResponse: | ) -> MessageStreamResponse: | ||||
| """ | """ | ||||
| from_variable_selector=from_variable_selector, | from_variable_selector=from_variable_selector, | ||||
| ) | ) | ||||
| def _message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: | |||||
| def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: | |||||
| """ | """ | ||||
| Message replace to stream response. | Message replace to stream response. | ||||
| :param answer: answer | :param answer: answer |
| import logging | import logging | ||||
| from collections.abc import Sequence | |||||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | ||||
| from core.app.entities.app_invoke_entities import InvokeFrom | from core.app.entities.app_invoke_entities import InvokeFrom | ||||
| from core.app.entities.queue_entities import QueueRetrieverResourcesEvent | from core.app.entities.queue_entities import QueueRetrieverResourcesEvent | ||||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||||
| from core.rag.index_processor.constant.index_type import IndexType | from core.rag.index_processor.constant.index_type import IndexType | ||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| db.session.commit() | db.session.commit() | ||||
| def return_retriever_resource_info(self, resource: list): | |||||
| # TODO(-LAN-): Improve type check | |||||
| def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]): | |||||
| """Handle return_retriever_resource_info.""" | """Handle return_retriever_resource_info.""" | ||||
| self._queue_manager.publish( | self._queue_manager.publish( | ||||
| QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER | QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER |
| from typing import Any, Optional | |||||
| from pydantic import BaseModel | |||||
| class RetrievalSourceMetadata(BaseModel): | |||||
| position: Optional[int] = None | |||||
| dataset_id: Optional[str] = None | |||||
| dataset_name: Optional[str] = None | |||||
| document_id: Optional[str] = None | |||||
| document_name: Optional[str] = None | |||||
| data_source_type: Optional[str] = None | |||||
| segment_id: Optional[str] = None | |||||
| retriever_from: Optional[str] = None | |||||
| score: Optional[float] = None | |||||
| hit_count: Optional[int] = None | |||||
| word_count: Optional[int] = None | |||||
| segment_position: Optional[int] = None | |||||
| index_node_hash: Optional[str] = None | |||||
| content: Optional[str] = None | |||||
| page: Optional[int] = None | |||||
| doc_metadata: Optional[dict[str, Any]] = None | |||||
| title: Optional[str] = None |
| from core.rag.data_post_processor.data_post_processor import DataPostProcessor | from core.rag.data_post_processor.data_post_processor import DataPostProcessor | ||||
| from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler | from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler | ||||
| from core.rag.datasource.retrieval_service import RetrievalService | from core.rag.datasource.retrieval_service import RetrievalService | ||||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||||
| from core.rag.entities.context_entities import DocumentContext | from core.rag.entities.context_entities import DocumentContext | ||||
| from core.rag.entities.metadata_entities import Condition, MetadataCondition | from core.rag.entities.metadata_entities import Condition, MetadataCondition | ||||
| from core.rag.index_processor.constant.index_type import IndexType | from core.rag.index_processor.constant.index_type import IndexType | ||||
| dify_documents = [item for item in all_documents if item.provider == "dify"] | dify_documents = [item for item in all_documents if item.provider == "dify"] | ||||
| external_documents = [item for item in all_documents if item.provider == "external"] | external_documents = [item for item in all_documents if item.provider == "external"] | ||||
| document_context_list = [] | |||||
| retrieval_resource_list = [] | |||||
| document_context_list: list[DocumentContext] = [] | |||||
| retrieval_resource_list: list[RetrievalSourceMetadata] = [] | |||||
| # deal with external documents | # deal with external documents | ||||
| for item in external_documents: | for item in external_documents: | ||||
| document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score"))) | document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score"))) | ||||
| source = { | |||||
| "dataset_id": item.metadata.get("dataset_id"), | |||||
| "dataset_name": item.metadata.get("dataset_name"), | |||||
| "document_id": item.metadata.get("document_id") or item.metadata.get("title"), | |||||
| "document_name": item.metadata.get("title"), | |||||
| "data_source_type": "external", | |||||
| "retriever_from": invoke_from.to_source(), | |||||
| "score": item.metadata.get("score"), | |||||
| "content": item.page_content, | |||||
| } | |||||
| source = RetrievalSourceMetadata( | |||||
| dataset_id=item.metadata.get("dataset_id"), | |||||
| dataset_name=item.metadata.get("dataset_name"), | |||||
| document_id=item.metadata.get("document_id") or item.metadata.get("title"), | |||||
| document_name=item.metadata.get("title"), | |||||
| data_source_type="external", | |||||
| retriever_from=invoke_from.to_source(), | |||||
| score=item.metadata.get("score"), | |||||
| content=item.page_content, | |||||
| ) | |||||
| retrieval_resource_list.append(source) | retrieval_resource_list.append(source) | ||||
| # deal with dify documents | # deal with dify documents | ||||
| if dify_documents: | if dify_documents: | ||||
| .first() | .first() | ||||
| ) | ) | ||||
| if dataset and document: | if dataset and document: | ||||
| source = { | |||||
| "dataset_id": dataset.id, | |||||
| "dataset_name": dataset.name, | |||||
| "document_id": document.id, | |||||
| "document_name": document.name, | |||||
| "data_source_type": document.data_source_type, | |||||
| "segment_id": segment.id, | |||||
| "retriever_from": invoke_from.to_source(), | |||||
| "score": record.score or 0.0, | |||||
| "doc_metadata": document.doc_metadata, | |||||
| } | |||||
| source = RetrievalSourceMetadata( | |||||
| dataset_id=dataset.id, | |||||
| dataset_name=dataset.name, | |||||
| document_id=document.id, | |||||
| document_name=document.name, | |||||
| data_source_type=document.data_source_type, | |||||
| segment_id=segment.id, | |||||
| retriever_from=invoke_from.to_source(), | |||||
| score=record.score or 0.0, | |||||
| doc_metadata=document.doc_metadata, | |||||
| ) | |||||
| if invoke_from.to_source() == "dev": | if invoke_from.to_source() == "dev": | ||||
| source["hit_count"] = segment.hit_count | |||||
| source["word_count"] = segment.word_count | |||||
| source["segment_position"] = segment.position | |||||
| source["index_node_hash"] = segment.index_node_hash | |||||
| source.hit_count = segment.hit_count | |||||
| source.word_count = segment.word_count | |||||
| source.segment_position = segment.position | |||||
| source.index_node_hash = segment.index_node_hash | |||||
| if segment.answer: | if segment.answer: | ||||
| source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" | |||||
| source.content = f"question:{segment.content} \nanswer:{segment.answer}" | |||||
| else: | else: | ||||
| source["content"] = segment.content | |||||
| source.content = segment.content | |||||
| retrieval_resource_list.append(source) | retrieval_resource_list.append(source) | ||||
| if hit_callback and retrieval_resource_list: | if hit_callback and retrieval_resource_list: | ||||
| retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score") or 0.0, reverse=True) | |||||
| retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True) | |||||
| for position, item in enumerate(retrieval_resource_list, start=1): | for position, item in enumerate(retrieval_resource_list, start=1): | ||||
| item["position"] = position | |||||
| item.position = position | |||||
| hit_callback.return_retriever_resource_info(retrieval_resource_list) | hit_callback.return_retriever_resource_info(retrieval_resource_list) | ||||
| if document_context_list: | if document_context_list: | ||||
| document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) | document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) |
| from core.model_manager import ModelManager | from core.model_manager import ModelManager | ||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from core.rag.datasource.retrieval_service import RetrievalService | from core.rag.datasource.retrieval_service import RetrievalService | ||||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||||
| from core.rag.models.document import Document as RagDocument | from core.rag.models.document import Document as RagDocument | ||||
| from core.rag.rerank.rerank_model import RerankModelRunner | from core.rag.rerank.rerank_model import RerankModelRunner | ||||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | from core.rag.retrieval.retrieval_methods import RetrievalMethod | ||||
| else: | else: | ||||
| document_context_list.append(segment.get_sign_content()) | document_context_list.append(segment.get_sign_content()) | ||||
| if self.return_resource: | if self.return_resource: | ||||
| context_list = [] | |||||
| context_list: list[RetrievalSourceMetadata] = [] | |||||
| resource_number = 1 | resource_number = 1 | ||||
| for segment in sorted_segments: | for segment in sorted_segments: | ||||
| dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() | dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() | ||||
| .first() | .first() | ||||
| ) | ) | ||||
| if dataset and document: | if dataset and document: | ||||
| source = { | |||||
| "position": resource_number, | |||||
| "dataset_id": dataset.id, | |||||
| "dataset_name": dataset.name, | |||||
| "document_id": document.id, | |||||
| "document_name": document.name, | |||||
| "data_source_type": document.data_source_type, | |||||
| "segment_id": segment.id, | |||||
| "retriever_from": self.retriever_from, | |||||
| "score": document_score_list.get(segment.index_node_id, None), | |||||
| "doc_metadata": document.doc_metadata, | |||||
| } | |||||
| source = RetrievalSourceMetadata( | |||||
| position=resource_number, | |||||
| dataset_id=dataset.id, | |||||
| dataset_name=dataset.name, | |||||
| document_id=document.id, | |||||
| document_name=document.name, | |||||
| data_source_type=document.data_source_type, | |||||
| segment_id=segment.id, | |||||
| retriever_from=self.retriever_from, | |||||
| score=document_score_list.get(segment.index_node_id, None), | |||||
| doc_metadata=document.doc_metadata, | |||||
| ) | |||||
| if self.retriever_from == "dev": | if self.retriever_from == "dev": | ||||
| source["hit_count"] = segment.hit_count | |||||
| source["word_count"] = segment.word_count | |||||
| source["segment_position"] = segment.position | |||||
| source["index_node_hash"] = segment.index_node_hash | |||||
| source.hit_count = segment.hit_count | |||||
| source.word_count = segment.word_count | |||||
| source.segment_position = segment.position | |||||
| source.index_node_hash = segment.index_node_hash | |||||
| if segment.answer: | if segment.answer: | ||||
| source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" | |||||
| source.content = f"question:{segment.content} \nanswer:{segment.answer}" | |||||
| else: | else: | ||||
| source["content"] = segment.content | |||||
| source.content = segment.content | |||||
| context_list.append(source) | context_list.append(source) | ||||
| resource_number += 1 | resource_number += 1 | ||||
| from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig | from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig | ||||
| from core.rag.datasource.retrieval_service import RetrievalService | from core.rag.datasource.retrieval_service import RetrievalService | ||||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||||
| from core.rag.entities.context_entities import DocumentContext | from core.rag.entities.context_entities import DocumentContext | ||||
| from core.rag.models.document import Document as RetrievalDocument | from core.rag.models.document import Document as RetrievalDocument | ||||
| from core.rag.retrieval.dataset_retrieval import DatasetRetrieval | from core.rag.retrieval.dataset_retrieval import DatasetRetrieval | ||||
| from models.dataset import Document as DatasetDocument | from models.dataset import Document as DatasetDocument | ||||
| from services.external_knowledge_service import ExternalDatasetService | from services.external_knowledge_service import ExternalDatasetService | ||||
| default_retrieval_model = { | |||||
| default_retrieval_model: dict[str, Any] = { | |||||
| "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, | "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, | ||||
| "reranking_enable": False, | "reranking_enable": False, | ||||
| "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, | "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, | ||||
| else: | else: | ||||
| document_ids_filter = None | document_ids_filter = None | ||||
| if dataset.provider == "external": | if dataset.provider == "external": | ||||
| results = [] | |||||
| results: list[RetrievalDocument] = [] | |||||
| external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( | external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( | ||||
| tenant_id=dataset.tenant_id, | tenant_id=dataset.tenant_id, | ||||
| dataset_id=dataset.id, | dataset_id=dataset.id, | ||||
| document.metadata["dataset_name"] = dataset.name | document.metadata["dataset_name"] = dataset.name | ||||
| results.append(document) | results.append(document) | ||||
| # deal with external documents | # deal with external documents | ||||
| context_list = [] | |||||
| context_list: list[RetrievalSourceMetadata] = [] | |||||
| for position, item in enumerate(results, start=1): | for position, item in enumerate(results, start=1): | ||||
| if item.metadata is not None: | if item.metadata is not None: | ||||
| source = { | |||||
| "position": position, | |||||
| "dataset_id": item.metadata.get("dataset_id"), | |||||
| "dataset_name": item.metadata.get("dataset_name"), | |||||
| "document_id": item.metadata.get("document_id") or item.metadata.get("title"), | |||||
| "document_name": item.metadata.get("title"), | |||||
| "data_source_type": "external", | |||||
| "retriever_from": self.retriever_from, | |||||
| "score": item.metadata.get("score"), | |||||
| "title": item.metadata.get("title"), | |||||
| "content": item.page_content, | |||||
| } | |||||
| source = RetrievalSourceMetadata( | |||||
| position=position, | |||||
| dataset_id=item.metadata.get("dataset_id"), | |||||
| dataset_name=item.metadata.get("dataset_name"), | |||||
| document_id=item.metadata.get("document_id") or item.metadata.get("title"), | |||||
| document_name=item.metadata.get("title"), | |||||
| data_source_type="external", | |||||
| retriever_from=self.retriever_from, | |||||
| score=item.metadata.get("score"), | |||||
| title=item.metadata.get("title"), | |||||
| content=item.page_content, | |||||
| ) | |||||
| context_list.append(source) | context_list.append(source) | ||||
| for hit_callback in self.hit_callbacks: | for hit_callback in self.hit_callbacks: | ||||
| hit_callback.return_retriever_resource_info(context_list) | hit_callback.return_retriever_resource_info(context_list) | ||||
| return "" | return "" | ||||
| # get retrieval model , if the model is not setting , using default | # get retrieval model , if the model is not setting , using default | ||||
| retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model | retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model | ||||
| retrieval_resource_list = [] | |||||
| retrieval_resource_list: list[RetrievalSourceMetadata] = [] | |||||
| if dataset.indexing_technique == "economy": | if dataset.indexing_technique == "economy": | ||||
| # use keyword table query | # use keyword table query | ||||
| documents = RetrievalService.retrieve( | documents = RetrievalService.retrieve( | ||||
| for item in documents: | for item in documents: | ||||
| if item.metadata is not None and item.metadata.get("score"): | if item.metadata is not None and item.metadata.get("score"): | ||||
| document_score_list[item.metadata["doc_id"]] = item.metadata["score"] | document_score_list[item.metadata["doc_id"]] = item.metadata["score"] | ||||
| document_context_list = [] | |||||
| document_context_list: list[DocumentContext] = [] | |||||
| records = RetrievalService.format_retrieval_documents(documents) | records = RetrievalService.format_retrieval_documents(documents) | ||||
| if records: | if records: | ||||
| for record in records: | for record in records: | ||||
| .first() | .first() | ||||
| ) | ) | ||||
| if dataset and document: | if dataset and document: | ||||
| source = { | |||||
| "dataset_id": dataset.id, | |||||
| "dataset_name": dataset.name, | |||||
| "document_id": document.id, # type: ignore | |||||
| "document_name": document.name, # type: ignore | |||||
| "data_source_type": document.data_source_type, # type: ignore | |||||
| "segment_id": segment.id, | |||||
| "retriever_from": self.retriever_from, | |||||
| "score": record.score or 0.0, | |||||
| "doc_metadata": document.doc_metadata, # type: ignore | |||||
| } | |||||
| source = RetrievalSourceMetadata( | |||||
| dataset_id=dataset.id, | |||||
| dataset_name=dataset.name, | |||||
| document_id=document.id, # type: ignore | |||||
| document_name=document.name, # type: ignore | |||||
| data_source_type=document.data_source_type, # type: ignore | |||||
| segment_id=segment.id, | |||||
| retriever_from=self.retriever_from, | |||||
| score=record.score or 0.0, | |||||
| doc_metadata=document.doc_metadata, # type: ignore | |||||
| ) | |||||
| if self.retriever_from == "dev": | if self.retriever_from == "dev": | ||||
| source["hit_count"] = segment.hit_count | |||||
| source["word_count"] = segment.word_count | |||||
| source["segment_position"] = segment.position | |||||
| source["index_node_hash"] = segment.index_node_hash | |||||
| source.hit_count = segment.hit_count | |||||
| source.word_count = segment.word_count | |||||
| source.segment_position = segment.position | |||||
| source.index_node_hash = segment.index_node_hash | |||||
| if segment.answer: | if segment.answer: | ||||
| source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" | |||||
| source.content = f"question:{segment.content} \nanswer:{segment.answer}" | |||||
| else: | else: | ||||
| source["content"] = segment.content | |||||
| source.content = segment.content | |||||
| retrieval_resource_list.append(source) | retrieval_resource_list.append(source) | ||||
| if self.return_resource and retrieval_resource_list: | if self.return_resource and retrieval_resource_list: | ||||
| retrieval_resource_list = sorted( | retrieval_resource_list = sorted( | ||||
| retrieval_resource_list, | retrieval_resource_list, | ||||
| key=lambda x: x.get("score") or 0.0, | |||||
| key=lambda x: x.score or 0.0, | |||||
| reverse=True, | reverse=True, | ||||
| ) | ) | ||||
| for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore | for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore | ||||
| item["position"] = position # type: ignore | |||||
| item.position = position # type: ignore | |||||
| for hit_callback in self.hit_callbacks: | for hit_callback in self.hit_callbacks: | ||||
| hit_callback.return_retriever_resource_info(retrieval_resource_list) | hit_callback.return_retriever_resource_info(retrieval_resource_list) | ||||
| if document_context_list: | if document_context_list: |
| from collections.abc import Mapping | |||||
| from collections.abc import Mapping, Sequence | |||||
| from datetime import datetime | from datetime import datetime | ||||
| from typing import Any, Optional | from typing import Any, Optional | ||||
| from pydantic import BaseModel, Field | from pydantic import BaseModel, Field | ||||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit | from core.workflow.entities.node_entities import AgentNodeStrategyInit | ||||
| from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState | from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState | ||||
| from core.workflow.nodes import NodeType | from core.workflow.nodes import NodeType | ||||
| class NodeRunRetrieverResourceEvent(BaseNodeEvent): | class NodeRunRetrieverResourceEvent(BaseNodeEvent): | ||||
| retriever_resources: list[dict] = Field(..., description="retriever resources") | |||||
| retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") | |||||
| context: str = Field(..., description="context") | context: str = Field(..., description="context") | ||||
| from collections.abc import Sequence | |||||
| from datetime import datetime | from datetime import datetime | ||||
| from pydantic import BaseModel, Field | from pydantic import BaseModel, Field | ||||
| from core.model_runtime.entities.llm_entities import LLMUsage | from core.model_runtime.entities.llm_entities import LLMUsage | ||||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||||
| from core.workflow.entities.node_entities import NodeRunResult | from core.workflow.entities.node_entities import NodeRunResult | ||||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | ||||
| class RunRetrieverResourceEvent(BaseModel): | class RunRetrieverResourceEvent(BaseModel): | ||||
| retriever_resources: list[dict] = Field(..., description="retriever resources") | |||||
| retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") | |||||
| context: str = Field(..., description="context") | context: str = Field(..., description="context") | ||||
| from core.plugin.entities.plugin import ModelProviderID | from core.plugin.entities.plugin import ModelProviderID | ||||
| from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig | from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig | ||||
| from core.prompt.utils.prompt_message_util import PromptMessageUtil | from core.prompt.utils.prompt_message_util import PromptMessageUtil | ||||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||||
| from core.variables import ( | from core.variables import ( | ||||
| ArrayAnySegment, | ArrayAnySegment, | ||||
| ArrayFileSegment, | ArrayFileSegment, | ||||
| yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value) | yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value) | ||||
| elif isinstance(context_value_variable, ArraySegment): | elif isinstance(context_value_variable, ArraySegment): | ||||
| context_str = "" | context_str = "" | ||||
| original_retriever_resource = [] | |||||
| original_retriever_resource: list[RetrievalSourceMetadata] = [] | |||||
| for item in context_value_variable.value: | for item in context_value_variable.value: | ||||
| if isinstance(item, str): | if isinstance(item, str): | ||||
| context_str += item + "\n" | context_str += item + "\n" | ||||
| retriever_resources=original_retriever_resource, context=context_str.strip() | retriever_resources=original_retriever_resource, context=context_str.strip() | ||||
| ) | ) | ||||
| def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: | |||||
| def _convert_to_original_retriever_resource(self, context_dict: dict): | |||||
| if ( | if ( | ||||
| "metadata" in context_dict | "metadata" in context_dict | ||||
| and "_source" in context_dict["metadata"] | and "_source" in context_dict["metadata"] | ||||
| ): | ): | ||||
| metadata = context_dict.get("metadata", {}) | metadata = context_dict.get("metadata", {}) | ||||
| source = { | |||||
| "position": metadata.get("position"), | |||||
| "dataset_id": metadata.get("dataset_id"), | |||||
| "dataset_name": metadata.get("dataset_name"), | |||||
| "document_id": metadata.get("document_id"), | |||||
| "document_name": metadata.get("document_name"), | |||||
| "data_source_type": metadata.get("data_source_type"), | |||||
| "segment_id": metadata.get("segment_id"), | |||||
| "retriever_from": metadata.get("retriever_from"), | |||||
| "score": metadata.get("score"), | |||||
| "hit_count": metadata.get("segment_hit_count"), | |||||
| "word_count": metadata.get("segment_word_count"), | |||||
| "segment_position": metadata.get("segment_position"), | |||||
| "index_node_hash": metadata.get("segment_index_node_hash"), | |||||
| "content": context_dict.get("content"), | |||||
| "page": metadata.get("page"), | |||||
| "doc_metadata": metadata.get("doc_metadata"), | |||||
| } | |||||
| source = RetrievalSourceMetadata( | |||||
| position=metadata.get("position"), | |||||
| dataset_id=metadata.get("dataset_id"), | |||||
| dataset_name=metadata.get("dataset_name"), | |||||
| document_id=metadata.get("document_id"), | |||||
| document_name=metadata.get("document_name"), | |||||
| data_source_type=metadata.get("data_source_type"), | |||||
| segment_id=metadata.get("segment_id"), | |||||
| retriever_from=metadata.get("retriever_from"), | |||||
| score=metadata.get("score"), | |||||
| hit_count=metadata.get("segment_hit_count"), | |||||
| word_count=metadata.get("segment_word_count"), | |||||
| segment_position=metadata.get("segment_position"), | |||||
| index_node_hash=metadata.get("segment_index_node_hash"), | |||||
| content=context_dict.get("content"), | |||||
| page=metadata.get("page"), | |||||
| doc_metadata=metadata.get("doc_metadata"), | |||||
| ) | |||||
| return source | return source | ||||