Signed-off-by: -LAN- <laipz8200@outlook.com>
tags/1.4.2
| @@ -1,4 +1,3 @@ | |||
| import json | |||
| import logging | |||
| import time | |||
| from collections.abc import Generator, Mapping | |||
| @@ -57,10 +56,9 @@ from core.app.entities.task_entities import ( | |||
| WorkflowTaskState, | |||
| ) | |||
| from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | |||
| from core.app.task_pipeline.message_cycle_manage import MessageCycleManage | |||
| from core.app.task_pipeline.message_cycle_manager import MessageCycleManager | |||
| from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | |||
| from core.model_runtime.entities.llm_entities import LLMUsage | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType | |||
| from core.workflow.enums import SystemVariableKey | |||
| @@ -141,7 +139,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| ) | |||
| self._task_state = WorkflowTaskState() | |||
| self._message_cycle_manager = MessageCycleManage( | |||
| self._message_cycle_manager = MessageCycleManager( | |||
| application_generate_entity=application_generate_entity, task_state=self._task_state | |||
| ) | |||
| @@ -162,7 +160,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| :return: | |||
| """ | |||
| # start generate conversation name thread | |||
| self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name( | |||
| self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name( | |||
| conversation_id=self._conversation_id, query=self._application_generate_entity.query | |||
| ) | |||
| @@ -605,22 +603,18 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| yield self._message_end_to_stream_response() | |||
| break | |||
| elif isinstance(event, QueueRetrieverResourcesEvent): | |||
| self._message_cycle_manager._handle_retriever_resources(event) | |||
| self._message_cycle_manager.handle_retriever_resources(event) | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| message = self._get_message(session=session) | |||
| message.message_metadata = ( | |||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||
| ) | |||
| message.message_metadata = self._task_state.metadata.model_dump_json() | |||
| session.commit() | |||
| elif isinstance(event, QueueAnnotationReplyEvent): | |||
| self._message_cycle_manager._handle_annotation_reply(event) | |||
| self._message_cycle_manager.handle_annotation_reply(event) | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| message = self._get_message(session=session) | |||
| message.message_metadata = ( | |||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||
| ) | |||
| message.message_metadata = self._task_state.metadata.model_dump_json() | |||
| session.commit() | |||
| elif isinstance(event, QueueTextChunkEvent): | |||
| delta_text = event.text | |||
| @@ -637,12 +631,12 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| tts_publisher.publish(queue_message) | |||
| self._task_state.answer += delta_text | |||
| yield self._message_cycle_manager._message_to_stream_response( | |||
| yield self._message_cycle_manager.message_to_stream_response( | |||
| answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector | |||
| ) | |||
| elif isinstance(event, QueueMessageReplaceEvent): | |||
| # published by moderation | |||
| yield self._message_cycle_manager._message_replace_to_stream_response( | |||
| yield self._message_cycle_manager.message_replace_to_stream_response( | |||
| answer=event.text, reason=event.reason | |||
| ) | |||
| elif isinstance(event, QueueAdvancedChatMessageEndEvent): | |||
| @@ -654,7 +648,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| ) | |||
| if output_moderation_answer: | |||
| self._task_state.answer = output_moderation_answer | |||
| yield self._message_cycle_manager._message_replace_to_stream_response( | |||
| yield self._message_cycle_manager.message_replace_to_stream_response( | |||
| answer=output_moderation_answer, | |||
| reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION, | |||
| ) | |||
| @@ -683,9 +677,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| message = self._get_message(session=session) | |||
| message.answer = self._task_state.answer | |||
| message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at | |||
| message.message_metadata = ( | |||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||
| ) | |||
| message.message_metadata = self._task_state.metadata.model_dump_json() | |||
| message_files = [ | |||
| MessageFile( | |||
| message_id=message.id, | |||
| @@ -713,9 +705,9 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| message.answer_price_unit = usage.completion_price_unit | |||
| message.total_price = usage.total_price | |||
| message.currency = usage.currency | |||
| self._task_state.metadata["usage"] = jsonable_encoder(usage) | |||
| self._task_state.metadata.usage = usage | |||
| else: | |||
| self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage()) | |||
| self._task_state.metadata.usage = LLMUsage.empty_usage() | |||
| message_was_created.send( | |||
| message, | |||
| application_generate_entity=self._application_generate_entity, | |||
| @@ -726,18 +718,16 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| Message end to stream response. | |||
| :return: | |||
| """ | |||
| extras = {} | |||
| if self._task_state.metadata: | |||
| extras["metadata"] = self._task_state.metadata.copy() | |||
| extras = self._task_state.metadata.model_dump() | |||
| if "annotation_reply" in extras["metadata"]: | |||
| del extras["metadata"]["annotation_reply"] | |||
| if self._task_state.metadata.annotation_reply: | |||
| del extras["annotation_reply"] | |||
| return MessageEndStreamResponse( | |||
| task_id=self._application_generate_entity.task_id, | |||
| id=self._message_id, | |||
| files=self._recorded_files, | |||
| metadata=extras.get("metadata", {}), | |||
| metadata=extras, | |||
| ) | |||
| def _handle_output_moderation_chunk(self, text: str) -> bool: | |||
| @@ -50,7 +50,6 @@ from core.app.entities.task_entities import ( | |||
| WorkflowAppStreamResponse, | |||
| WorkflowFinishStreamResponse, | |||
| WorkflowStartStreamResponse, | |||
| WorkflowTaskState, | |||
| ) | |||
| from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | |||
| from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | |||
| @@ -130,9 +129,7 @@ class WorkflowAppGenerateTaskPipeline: | |||
| ) | |||
| self._application_generate_entity = application_generate_entity | |||
| self._workflow_id = workflow.id | |||
| self._workflow_features_dict = workflow.features_dict | |||
| self._task_state = WorkflowTaskState() | |||
| self._workflow_run_id = "" | |||
| def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: | |||
| @@ -543,7 +540,6 @@ class WorkflowAppGenerateTaskPipeline: | |||
| if tts_publisher: | |||
| tts_publisher.publish(queue_message) | |||
| self._task_state.answer += delta_text | |||
| yield self._text_chunk_to_stream_response( | |||
| delta_text, from_variable_selector=event.from_variable_selector | |||
| ) | |||
| @@ -1,4 +1,4 @@ | |||
| from collections.abc import Mapping | |||
| from collections.abc import Mapping, Sequence | |||
| from datetime import datetime | |||
| from enum import Enum, StrEnum | |||
| from typing import Any, Optional | |||
| @@ -6,6 +6,7 @@ from typing import Any, Optional | |||
| from pydantic import BaseModel | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey | |||
| from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState | |||
| @@ -283,7 +284,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent): | |||
| """ | |||
| event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES | |||
| retriever_resources: list[dict] | |||
| retriever_resources: Sequence[RetrievalSourceMetadata] | |||
| in_iteration_id: Optional[str] = None | |||
| """iteration id if node is in iteration""" | |||
| in_loop_id: Optional[str] = None | |||
| @@ -2,20 +2,37 @@ from collections.abc import Mapping, Sequence | |||
| from enum import Enum | |||
| from typing import Any, Optional | |||
| from pydantic import BaseModel, ConfigDict | |||
| from pydantic import BaseModel, ConfigDict, Field | |||
| from core.model_runtime.entities.llm_entities import LLMResult | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | |||
| class AnnotationReplyAccount(BaseModel): | |||
| id: str | |||
| name: str | |||
| class AnnotationReply(BaseModel): | |||
| id: str | |||
| account: AnnotationReplyAccount | |||
| class TaskStateMetadata(BaseModel): | |||
| annotation_reply: AnnotationReply | None = None | |||
| retriever_resources: Sequence[RetrievalSourceMetadata] = Field(default_factory=list) | |||
| usage: LLMUsage | None = None | |||
| class TaskState(BaseModel): | |||
| """ | |||
| TaskState entity | |||
| """ | |||
| metadata: dict = {} | |||
| metadata: TaskStateMetadata = Field(default_factory=TaskStateMetadata) | |||
| class EasyUITaskState(TaskState): | |||
| @@ -1,4 +1,3 @@ | |||
| import json | |||
| import logging | |||
| import time | |||
| from collections.abc import Generator | |||
| @@ -43,7 +42,7 @@ from core.app.entities.task_entities import ( | |||
| StreamResponse, | |||
| ) | |||
| from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline | |||
| from core.app.task_pipeline.message_cycle_manage import MessageCycleManage | |||
| from core.app.task_pipeline.message_cycle_manager import MessageCycleManager | |||
| from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk | |||
| from core.model_manager import ModelInstance | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage | |||
| @@ -51,7 +50,6 @@ from core.model_runtime.entities.message_entities import ( | |||
| AssistantPromptMessage, | |||
| ) | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.ops.entities.trace_entity import TraceTaskName | |||
| from core.ops.ops_trace_manager import TraceQueueManager, TraceTask | |||
| from core.prompt.utils.prompt_message_util import PromptMessageUtil | |||
| @@ -63,7 +61,7 @@ from models.model import AppMode, Conversation, Message, MessageAgentThought | |||
| logger = logging.getLogger(__name__) | |||
| class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleManage): | |||
| class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): | |||
| """ | |||
| EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application. | |||
| """ | |||
| @@ -104,6 +102,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| ) | |||
| ) | |||
| self._message_cycle_manager = MessageCycleManager( | |||
| application_generate_entity=application_generate_entity, | |||
| task_state=self._task_state, | |||
| ) | |||
| self._conversation_name_generate_thread: Optional[Thread] = None | |||
| def process( | |||
| @@ -115,7 +118,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| ]: | |||
| if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: | |||
| # start generate conversation name thread | |||
| self._conversation_name_generate_thread = self._generate_conversation_name( | |||
| self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name( | |||
| conversation_id=self._conversation_id, query=self._application_generate_entity.query or "" | |||
| ) | |||
| @@ -136,9 +139,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| if isinstance(stream_response, ErrorStreamResponse): | |||
| raise stream_response.err | |||
| elif isinstance(stream_response, MessageEndStreamResponse): | |||
| extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)} | |||
| extras = {"usage": self._task_state.llm_result.usage.model_dump()} | |||
| if self._task_state.metadata: | |||
| extras["metadata"] = self._task_state.metadata | |||
| extras["metadata"] = self._task_state.metadata.model_dump() | |||
| response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] | |||
| if self._conversation_mode == AppMode.COMPLETION.value: | |||
| response = CompletionAppBlockingResponse( | |||
| @@ -277,7 +280,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| ) | |||
| if output_moderation_answer: | |||
| self._task_state.llm_result.message.content = output_moderation_answer | |||
| yield self._message_replace_to_stream_response(answer=output_moderation_answer) | |||
| yield self._message_cycle_manager.message_replace_to_stream_response( | |||
| answer=output_moderation_answer | |||
| ) | |||
| with Session(db.engine) as session: | |||
| # Save message | |||
| @@ -286,9 +291,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| message_end_resp = self._message_end_to_stream_response() | |||
| yield message_end_resp | |||
| elif isinstance(event, QueueRetrieverResourcesEvent): | |||
| self._handle_retriever_resources(event) | |||
| self._message_cycle_manager.handle_retriever_resources(event) | |||
| elif isinstance(event, QueueAnnotationReplyEvent): | |||
| annotation = self._handle_annotation_reply(event) | |||
| annotation = self._message_cycle_manager.handle_annotation_reply(event) | |||
| if annotation: | |||
| self._task_state.llm_result.message.content = annotation.content | |||
| elif isinstance(event, QueueAgentThoughtEvent): | |||
| @@ -296,7 +301,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| if agent_thought_response is not None: | |||
| yield agent_thought_response | |||
| elif isinstance(event, QueueMessageFileEvent): | |||
| response = self._message_file_to_stream_response(event) | |||
| response = self._message_cycle_manager.message_file_to_stream_response(event) | |||
| if response: | |||
| yield response | |||
| elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent): | |||
| @@ -318,7 +323,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| self._task_state.llm_result.message.content = current_content | |||
| if isinstance(event, QueueLLMChunkEvent): | |||
| yield self._message_to_stream_response( | |||
| yield self._message_cycle_manager.message_to_stream_response( | |||
| answer=cast(str, delta_text), | |||
| message_id=self._message_id, | |||
| ) | |||
| @@ -328,7 +333,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| message_id=self._message_id, | |||
| ) | |||
| elif isinstance(event, QueueMessageReplaceEvent): | |||
| yield self._message_replace_to_stream_response(answer=event.text) | |||
| yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text) | |||
| elif isinstance(event, QueuePingEvent): | |||
| yield self._ping_stream_response() | |||
| else: | |||
| @@ -372,9 +377,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| message.provider_response_latency = time.perf_counter() - self._start_at | |||
| message.total_price = usage.total_price | |||
| message.currency = usage.currency | |||
| message.message_metadata = ( | |||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||
| ) | |||
| message.message_metadata = self._task_state.metadata.model_dump_json() | |||
| if trace_manager: | |||
| trace_manager.add_trace_task( | |||
| @@ -423,16 +426,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan | |||
| Message end to stream response. | |||
| :return: | |||
| """ | |||
| self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage) | |||
| extras = {} | |||
| if self._task_state.metadata: | |||
| extras["metadata"] = self._task_state.metadata | |||
| self._task_state.metadata.usage = self._task_state.llm_result.usage | |||
| metadata_dict = self._task_state.metadata.model_dump() | |||
| return MessageEndStreamResponse( | |||
| task_id=self._application_generate_entity.task_id, | |||
| id=self._message_id, | |||
| metadata=extras.get("metadata", {}), | |||
| metadata=metadata_dict, | |||
| ) | |||
| def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: | |||
| @@ -17,6 +17,8 @@ from core.app.entities.queue_entities import ( | |||
| QueueRetrieverResourcesEvent, | |||
| ) | |||
| from core.app.entities.task_entities import ( | |||
| AnnotationReply, | |||
| AnnotationReplyAccount, | |||
| EasyUITaskState, | |||
| MessageFileStreamResponse, | |||
| MessageReplaceStreamResponse, | |||
| @@ -30,7 +32,7 @@ from models.model import AppMode, Conversation, MessageAnnotation, MessageFile | |||
| from services.annotation_service import AppAnnotationService | |||
| class MessageCycleManage: | |||
| class MessageCycleManager: | |||
| def __init__( | |||
| self, | |||
| *, | |||
| @@ -45,7 +47,7 @@ class MessageCycleManage: | |||
| self._application_generate_entity = application_generate_entity | |||
| self._task_state = task_state | |||
| def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: | |||
| def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: | |||
| """ | |||
| Generate conversation name. | |||
| :param conversation_id: conversation id | |||
| @@ -102,7 +104,7 @@ class MessageCycleManage: | |||
| db.session.commit() | |||
| db.session.close() | |||
| def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]: | |||
| def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]: | |||
| """ | |||
| Handle annotation reply. | |||
| :param event: event | |||
| @@ -111,25 +113,28 @@ class MessageCycleManage: | |||
| annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) | |||
| if annotation: | |||
| account = annotation.account | |||
| self._task_state.metadata["annotation_reply"] = { | |||
| "id": annotation.id, | |||
| "account": {"id": annotation.account_id, "name": account.name if account else "Dify user"}, | |||
| } | |||
| self._task_state.metadata.annotation_reply = AnnotationReply( | |||
| id=annotation.id, | |||
| account=AnnotationReplyAccount( | |||
| id=annotation.account_id, | |||
| name=account.name if account else "Dify user", | |||
| ), | |||
| ) | |||
| return annotation | |||
| return None | |||
| def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None: | |||
| def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None: | |||
| """ | |||
| Handle retriever resources. | |||
| :param event: event | |||
| :return: | |||
| """ | |||
| if self._application_generate_entity.app_config.additional_features.show_retrieve_source: | |||
| self._task_state.metadata["retriever_resources"] = event.retriever_resources | |||
| self._task_state.metadata.retriever_resources = event.retriever_resources | |||
| def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: | |||
| def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: | |||
| """ | |||
| Message file to stream response. | |||
| :param event: event | |||
| @@ -166,7 +171,7 @@ class MessageCycleManage: | |||
| return None | |||
| def _message_to_stream_response( | |||
| def message_to_stream_response( | |||
| self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None | |||
| ) -> MessageStreamResponse: | |||
| """ | |||
| @@ -182,7 +187,7 @@ class MessageCycleManage: | |||
| from_variable_selector=from_variable_selector, | |||
| ) | |||
| def _message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: | |||
| def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: | |||
| """ | |||
| Message replace to stream response. | |||
| :param answer: answer | |||
| @@ -1,8 +1,10 @@ | |||
| import logging | |||
| from collections.abc import Sequence | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.app.entities.queue_entities import QueueRetrieverResourcesEvent | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.rag.index_processor.constant.index_type import IndexType | |||
| from core.rag.models.document import Document | |||
| from extensions.ext_database import db | |||
| @@ -85,7 +87,8 @@ class DatasetIndexToolCallbackHandler: | |||
| db.session.commit() | |||
| def return_retriever_resource_info(self, resource: list): | |||
| # TODO(-LAN-): Improve type check | |||
| def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]): | |||
| """Handle return_retriever_resource_info.""" | |||
| self._queue_manager.publish( | |||
| QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER | |||
| @@ -0,0 +1,23 @@ | |||
| from typing import Any, Optional | |||
| from pydantic import BaseModel | |||
| class RetrievalSourceMetadata(BaseModel): | |||
| position: Optional[int] = None | |||
| dataset_id: Optional[str] = None | |||
| dataset_name: Optional[str] = None | |||
| document_id: Optional[str] = None | |||
| document_name: Optional[str] = None | |||
| data_source_type: Optional[str] = None | |||
| segment_id: Optional[str] = None | |||
| retriever_from: Optional[str] = None | |||
| score: Optional[float] = None | |||
| hit_count: Optional[int] = None | |||
| word_count: Optional[int] = None | |||
| segment_position: Optional[int] = None | |||
| index_node_hash: Optional[str] = None | |||
| content: Optional[str] = None | |||
| page: Optional[int] = None | |||
| doc_metadata: Optional[dict[str, Any]] = None | |||
| title: Optional[str] = None | |||
| @@ -35,6 +35,7 @@ from core.prompt.simple_prompt_transform import ModelMode | |||
| from core.rag.data_post_processor.data_post_processor import DataPostProcessor | |||
| from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.rag.entities.context_entities import DocumentContext | |||
| from core.rag.entities.metadata_entities import Condition, MetadataCondition | |||
| from core.rag.index_processor.constant.index_type import IndexType | |||
| @@ -198,21 +199,21 @@ class DatasetRetrieval: | |||
| dify_documents = [item for item in all_documents if item.provider == "dify"] | |||
| external_documents = [item for item in all_documents if item.provider == "external"] | |||
| document_context_list = [] | |||
| retrieval_resource_list = [] | |||
| document_context_list: list[DocumentContext] = [] | |||
| retrieval_resource_list: list[RetrievalSourceMetadata] = [] | |||
| # deal with external documents | |||
| for item in external_documents: | |||
| document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score"))) | |||
| source = { | |||
| "dataset_id": item.metadata.get("dataset_id"), | |||
| "dataset_name": item.metadata.get("dataset_name"), | |||
| "document_id": item.metadata.get("document_id") or item.metadata.get("title"), | |||
| "document_name": item.metadata.get("title"), | |||
| "data_source_type": "external", | |||
| "retriever_from": invoke_from.to_source(), | |||
| "score": item.metadata.get("score"), | |||
| "content": item.page_content, | |||
| } | |||
| source = RetrievalSourceMetadata( | |||
| dataset_id=item.metadata.get("dataset_id"), | |||
| dataset_name=item.metadata.get("dataset_name"), | |||
| document_id=item.metadata.get("document_id") or item.metadata.get("title"), | |||
| document_name=item.metadata.get("title"), | |||
| data_source_type="external", | |||
| retriever_from=invoke_from.to_source(), | |||
| score=item.metadata.get("score"), | |||
| content=item.page_content, | |||
| ) | |||
| retrieval_resource_list.append(source) | |||
| # deal with dify documents | |||
| if dify_documents: | |||
| @@ -248,32 +249,32 @@ class DatasetRetrieval: | |||
| .first() | |||
| ) | |||
| if dataset and document: | |||
| source = { | |||
| "dataset_id": dataset.id, | |||
| "dataset_name": dataset.name, | |||
| "document_id": document.id, | |||
| "document_name": document.name, | |||
| "data_source_type": document.data_source_type, | |||
| "segment_id": segment.id, | |||
| "retriever_from": invoke_from.to_source(), | |||
| "score": record.score or 0.0, | |||
| "doc_metadata": document.doc_metadata, | |||
| } | |||
| source = RetrievalSourceMetadata( | |||
| dataset_id=dataset.id, | |||
| dataset_name=dataset.name, | |||
| document_id=document.id, | |||
| document_name=document.name, | |||
| data_source_type=document.data_source_type, | |||
| segment_id=segment.id, | |||
| retriever_from=invoke_from.to_source(), | |||
| score=record.score or 0.0, | |||
| doc_metadata=document.doc_metadata, | |||
| ) | |||
| if invoke_from.to_source() == "dev": | |||
| source["hit_count"] = segment.hit_count | |||
| source["word_count"] = segment.word_count | |||
| source["segment_position"] = segment.position | |||
| source["index_node_hash"] = segment.index_node_hash | |||
| source.hit_count = segment.hit_count | |||
| source.word_count = segment.word_count | |||
| source.segment_position = segment.position | |||
| source.index_node_hash = segment.index_node_hash | |||
| if segment.answer: | |||
| source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" | |||
| source.content = f"question:{segment.content} \nanswer:{segment.answer}" | |||
| else: | |||
| source["content"] = segment.content | |||
| source.content = segment.content | |||
| retrieval_resource_list.append(source) | |||
| if hit_callback and retrieval_resource_list: | |||
| retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score") or 0.0, reverse=True) | |||
| retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True) | |||
| for position, item in enumerate(retrieval_resource_list, start=1): | |||
| item["position"] = position | |||
| item.position = position | |||
| hit_callback.return_retriever_resource_info(retrieval_resource_list) | |||
| if document_context_list: | |||
| document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) | |||
| @@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.rag.models.document import Document as RagDocument | |||
| from core.rag.rerank.rerank_model import RerankModelRunner | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| @@ -107,7 +108,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): | |||
| else: | |||
| document_context_list.append(segment.get_sign_content()) | |||
| if self.return_resource: | |||
| context_list = [] | |||
| context_list: list[RetrievalSourceMetadata] = [] | |||
| resource_number = 1 | |||
| for segment in sorted_segments: | |||
| dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() | |||
| @@ -121,28 +122,28 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): | |||
| .first() | |||
| ) | |||
| if dataset and document: | |||
| source = { | |||
| "position": resource_number, | |||
| "dataset_id": dataset.id, | |||
| "dataset_name": dataset.name, | |||
| "document_id": document.id, | |||
| "document_name": document.name, | |||
| "data_source_type": document.data_source_type, | |||
| "segment_id": segment.id, | |||
| "retriever_from": self.retriever_from, | |||
| "score": document_score_list.get(segment.index_node_id, None), | |||
| "doc_metadata": document.doc_metadata, | |||
| } | |||
| source = RetrievalSourceMetadata( | |||
| position=resource_number, | |||
| dataset_id=dataset.id, | |||
| dataset_name=dataset.name, | |||
| document_id=document.id, | |||
| document_name=document.name, | |||
| data_source_type=document.data_source_type, | |||
| segment_id=segment.id, | |||
| retriever_from=self.retriever_from, | |||
| score=document_score_list.get(segment.index_node_id, None), | |||
| doc_metadata=document.doc_metadata, | |||
| ) | |||
| if self.retriever_from == "dev": | |||
| source["hit_count"] = segment.hit_count | |||
| source["word_count"] = segment.word_count | |||
| source["segment_position"] = segment.position | |||
| source["index_node_hash"] = segment.index_node_hash | |||
| source.hit_count = segment.hit_count | |||
| source.word_count = segment.word_count | |||
| source.segment_position = segment.position | |||
| source.index_node_hash = segment.index_node_hash | |||
| if segment.answer: | |||
| source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" | |||
| source.content = f"question:{segment.content} \nanswer:{segment.answer}" | |||
| else: | |||
| source["content"] = segment.content | |||
| source.content = segment.content | |||
| context_list.append(source) | |||
| resource_number += 1 | |||
| @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field | |||
| from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.rag.entities.context_entities import DocumentContext | |||
| from core.rag.models.document import Document as RetrievalDocument | |||
| from core.rag.retrieval.dataset_retrieval import DatasetRetrieval | |||
| @@ -14,7 +15,7 @@ from models.dataset import Dataset | |||
| from models.dataset import Document as DatasetDocument | |||
| from services.external_knowledge_service import ExternalDatasetService | |||
| default_retrieval_model = { | |||
| default_retrieval_model: dict[str, Any] = { | |||
| "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| "reranking_enable": False, | |||
| "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, | |||
| @@ -79,7 +80,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| else: | |||
| document_ids_filter = None | |||
| if dataset.provider == "external": | |||
| results = [] | |||
| results: list[RetrievalDocument] = [] | |||
| external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( | |||
| tenant_id=dataset.tenant_id, | |||
| dataset_id=dataset.id, | |||
| @@ -100,21 +101,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| document.metadata["dataset_name"] = dataset.name | |||
| results.append(document) | |||
| # deal with external documents | |||
| context_list = [] | |||
| context_list: list[RetrievalSourceMetadata] = [] | |||
| for position, item in enumerate(results, start=1): | |||
| if item.metadata is not None: | |||
| source = { | |||
| "position": position, | |||
| "dataset_id": item.metadata.get("dataset_id"), | |||
| "dataset_name": item.metadata.get("dataset_name"), | |||
| "document_id": item.metadata.get("document_id") or item.metadata.get("title"), | |||
| "document_name": item.metadata.get("title"), | |||
| "data_source_type": "external", | |||
| "retriever_from": self.retriever_from, | |||
| "score": item.metadata.get("score"), | |||
| "title": item.metadata.get("title"), | |||
| "content": item.page_content, | |||
| } | |||
| source = RetrievalSourceMetadata( | |||
| position=position, | |||
| dataset_id=item.metadata.get("dataset_id"), | |||
| dataset_name=item.metadata.get("dataset_name"), | |||
| document_id=item.metadata.get("document_id") or item.metadata.get("title"), | |||
| document_name=item.metadata.get("title"), | |||
| data_source_type="external", | |||
| retriever_from=self.retriever_from, | |||
| score=item.metadata.get("score"), | |||
| title=item.metadata.get("title"), | |||
| content=item.page_content, | |||
| ) | |||
| context_list.append(source) | |||
| for hit_callback in self.hit_callbacks: | |||
| hit_callback.return_retriever_resource_info(context_list) | |||
| @@ -125,7 +126,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| return "" | |||
| # get retrieval model , if the model is not setting , using default | |||
| retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model | |||
| retrieval_resource_list = [] | |||
| retrieval_resource_list: list[RetrievalSourceMetadata] = [] | |||
| if dataset.indexing_technique == "economy": | |||
| # use keyword table query | |||
| documents = RetrievalService.retrieve( | |||
| @@ -163,7 +164,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| for item in documents: | |||
| if item.metadata is not None and item.metadata.get("score"): | |||
| document_score_list[item.metadata["doc_id"]] = item.metadata["score"] | |||
| document_context_list = [] | |||
| document_context_list: list[DocumentContext] = [] | |||
| records = RetrievalService.format_retrieval_documents(documents) | |||
| if records: | |||
| for record in records: | |||
| @@ -197,37 +198,37 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| .first() | |||
| ) | |||
| if dataset and document: | |||
| source = { | |||
| "dataset_id": dataset.id, | |||
| "dataset_name": dataset.name, | |||
| "document_id": document.id, # type: ignore | |||
| "document_name": document.name, # type: ignore | |||
| "data_source_type": document.data_source_type, # type: ignore | |||
| "segment_id": segment.id, | |||
| "retriever_from": self.retriever_from, | |||
| "score": record.score or 0.0, | |||
| "doc_metadata": document.doc_metadata, # type: ignore | |||
| } | |||
| source = RetrievalSourceMetadata( | |||
| dataset_id=dataset.id, | |||
| dataset_name=dataset.name, | |||
| document_id=document.id, # type: ignore | |||
| document_name=document.name, # type: ignore | |||
| data_source_type=document.data_source_type, # type: ignore | |||
| segment_id=segment.id, | |||
| retriever_from=self.retriever_from, | |||
| score=record.score or 0.0, | |||
| doc_metadata=document.doc_metadata, # type: ignore | |||
| ) | |||
| if self.retriever_from == "dev": | |||
| source["hit_count"] = segment.hit_count | |||
| source["word_count"] = segment.word_count | |||
| source["segment_position"] = segment.position | |||
| source["index_node_hash"] = segment.index_node_hash | |||
| source.hit_count = segment.hit_count | |||
| source.word_count = segment.word_count | |||
| source.segment_position = segment.position | |||
| source.index_node_hash = segment.index_node_hash | |||
| if segment.answer: | |||
| source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" | |||
| source.content = f"question:{segment.content} \nanswer:{segment.answer}" | |||
| else: | |||
| source["content"] = segment.content | |||
| source.content = segment.content | |||
| retrieval_resource_list.append(source) | |||
| if self.return_resource and retrieval_resource_list: | |||
| retrieval_resource_list = sorted( | |||
| retrieval_resource_list, | |||
| key=lambda x: x.get("score") or 0.0, | |||
| key=lambda x: x.score or 0.0, | |||
| reverse=True, | |||
| ) | |||
| for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore | |||
| item["position"] = position # type: ignore | |||
| item.position = position # type: ignore | |||
| for hit_callback in self.hit_callbacks: | |||
| hit_callback.return_retriever_resource_info(retrieval_resource_list) | |||
| if document_context_list: | |||
| @@ -1,9 +1,10 @@ | |||
| from collections.abc import Mapping | |||
| from collections.abc import Mapping, Sequence | |||
| from datetime import datetime | |||
| from typing import Any, Optional | |||
| from pydantic import BaseModel, Field | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.workflow.entities.node_entities import AgentNodeStrategyInit | |||
| from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState | |||
| from core.workflow.nodes import NodeType | |||
| @@ -82,7 +83,7 @@ class NodeRunStreamChunkEvent(BaseNodeEvent): | |||
| class NodeRunRetrieverResourceEvent(BaseNodeEvent): | |||
| retriever_resources: list[dict] = Field(..., description="retriever resources") | |||
| retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") | |||
| context: str = Field(..., description="context") | |||
| @@ -1,8 +1,10 @@ | |||
| from collections.abc import Sequence | |||
| from datetime import datetime | |||
| from pydantic import BaseModel, Field | |||
| from core.model_runtime.entities.llm_entities import LLMUsage | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| @@ -17,7 +19,7 @@ class RunStreamChunkEvent(BaseModel): | |||
| class RunRetrieverResourceEvent(BaseModel): | |||
| retriever_resources: list[dict] = Field(..., description="retriever resources") | |||
| retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") | |||
| context: str = Field(..., description="context") | |||
| @@ -43,6 +43,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.plugin.entities.plugin import ModelProviderID | |||
| from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig | |||
| from core.prompt.utils.prompt_message_util import PromptMessageUtil | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.variables import ( | |||
| ArrayAnySegment, | |||
| ArrayFileSegment, | |||
| @@ -474,7 +475,7 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value) | |||
| elif isinstance(context_value_variable, ArraySegment): | |||
| context_str = "" | |||
| original_retriever_resource = [] | |||
| original_retriever_resource: list[RetrievalSourceMetadata] = [] | |||
| for item in context_value_variable.value: | |||
| if isinstance(item, str): | |||
| context_str += item + "\n" | |||
| @@ -492,7 +493,7 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| retriever_resources=original_retriever_resource, context=context_str.strip() | |||
| ) | |||
| def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: | |||
| def _convert_to_original_retriever_resource(self, context_dict: dict): | |||
| if ( | |||
| "metadata" in context_dict | |||
| and "_source" in context_dict["metadata"] | |||
| @@ -500,24 +501,24 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| ): | |||
| metadata = context_dict.get("metadata", {}) | |||
| source = { | |||
| "position": metadata.get("position"), | |||
| "dataset_id": metadata.get("dataset_id"), | |||
| "dataset_name": metadata.get("dataset_name"), | |||
| "document_id": metadata.get("document_id"), | |||
| "document_name": metadata.get("document_name"), | |||
| "data_source_type": metadata.get("data_source_type"), | |||
| "segment_id": metadata.get("segment_id"), | |||
| "retriever_from": metadata.get("retriever_from"), | |||
| "score": metadata.get("score"), | |||
| "hit_count": metadata.get("segment_hit_count"), | |||
| "word_count": metadata.get("segment_word_count"), | |||
| "segment_position": metadata.get("segment_position"), | |||
| "index_node_hash": metadata.get("segment_index_node_hash"), | |||
| "content": context_dict.get("content"), | |||
| "page": metadata.get("page"), | |||
| "doc_metadata": metadata.get("doc_metadata"), | |||
| } | |||
| source = RetrievalSourceMetadata( | |||
| position=metadata.get("position"), | |||
| dataset_id=metadata.get("dataset_id"), | |||
| dataset_name=metadata.get("dataset_name"), | |||
| document_id=metadata.get("document_id"), | |||
| document_name=metadata.get("document_name"), | |||
| data_source_type=metadata.get("data_source_type"), | |||
| segment_id=metadata.get("segment_id"), | |||
| retriever_from=metadata.get("retriever_from"), | |||
| score=metadata.get("score"), | |||
| hit_count=metadata.get("segment_hit_count"), | |||
| word_count=metadata.get("segment_word_count"), | |||
| segment_position=metadata.get("segment_position"), | |||
| index_node_hash=metadata.get("segment_index_node_hash"), | |||
| content=context_dict.get("content"), | |||
| page=metadata.get("page"), | |||
| doc_metadata=metadata.get("doc_metadata"), | |||
| ) | |||
| return source | |||