Browse files

Refactor/message cycle manage and knowledge retrieval (#20460)

Signed-off-by: -LAN- <laipz8200@outlook.com>
tags/1.4.2
-LAN- 5 months ago
parent
commit
a6ea15e63c

api/core/app/apps/advanced_chat/generate_task_pipeline.py   +17 −27

- import json
import logging
import time
from collections.abc import Generator, Mapping
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
- from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
+ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
- from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import SystemVariableKey
)


self._task_state = WorkflowTaskState()
- self._message_cycle_manager = MessageCycleManage(
+ self._message_cycle_manager = MessageCycleManager(
application_generate_entity=application_generate_entity, task_state=self._task_state
)


:return:
"""
# start generate conversation name thread
- self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
+ self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query
)


yield self._message_end_to_stream_response()
break
elif isinstance(event, QueueRetrieverResourcesEvent):
- self._message_cycle_manager._handle_retriever_resources(event)
+ self._message_cycle_manager.handle_retriever_resources(event)


with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
- message.message_metadata = (
- json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
- )
+ message.message_metadata = self._task_state.metadata.model_dump_json()
session.commit()
elif isinstance(event, QueueAnnotationReplyEvent):
- self._message_cycle_manager._handle_annotation_reply(event)
+ self._message_cycle_manager.handle_annotation_reply(event)


with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
- message.message_metadata = (
- json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
- )
+ message.message_metadata = self._task_state.metadata.model_dump_json()
session.commit()
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
tts_publisher.publish(queue_message)


self._task_state.answer += delta_text
- yield self._message_cycle_manager._message_to_stream_response(
+ yield self._message_cycle_manager.message_to_stream_response(
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
- yield self._message_cycle_manager._message_replace_to_stream_response(
+ yield self._message_cycle_manager.message_replace_to_stream_response(
answer=event.text, reason=event.reason
)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
- yield self._message_cycle_manager._message_replace_to_stream_response(
+ yield self._message_cycle_manager.message_replace_to_stream_response(
answer=output_moderation_answer,
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
)
message = self._get_message(session=session)
message.answer = self._task_state.answer
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
- message.message_metadata = (
- json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
- )
+ message.message_metadata = self._task_state.metadata.model_dump_json()
message_files = [
MessageFile(
message_id=message.id,
message.answer_price_unit = usage.completion_price_unit
message.total_price = usage.total_price
message.currency = usage.currency
- self._task_state.metadata["usage"] = jsonable_encoder(usage)
+ self._task_state.metadata.usage = usage
else:
- self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage())
+ self._task_state.metadata.usage = LLMUsage.empty_usage()
message_was_created.send(
message,
application_generate_entity=self._application_generate_entity,
Message end to stream response.
:return:
"""
- extras = {}
- if self._task_state.metadata:
- extras["metadata"] = self._task_state.metadata.copy()
+ extras = self._task_state.metadata.model_dump()


- if "annotation_reply" in extras["metadata"]:
- del extras["metadata"]["annotation_reply"]
+ if self._task_state.metadata.annotation_reply:
+ del extras["annotation_reply"]


return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message_id,
files=self._recorded_files,
- metadata=extras.get("metadata", {}),
+ metadata=extras,
)


def _handle_output_moderation_chunk(self, text: str) -> bool:
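The pipeline now composes a MessageCycleManager instead of inheriting the old MessageCycleManage mixin, and the manager's methods lose their leading underscore. A minimal sketch of the new call pattern, assuming only the constructor and method signatures visible in this diff (ExamplePipeline is an illustrative stand-in, not the real class):

from core.app.task_pipeline.message_cycle_manager import MessageCycleManager


class ExamplePipeline:  # illustrative stand-in for the real pipeline class
    def __init__(self, application_generate_entity, task_state, conversation_id: str):
        self._application_generate_entity = application_generate_entity
        self._conversation_id = conversation_id
        self._task_state = task_state
        # Composition instead of inheritance: the manager is a collaborator, not a mixin.
        self._message_cycle_manager = MessageCycleManager(
            application_generate_entity=application_generate_entity,
            task_state=task_state,
        )

    def start_conversation_name_thread(self):
        # Was self._generate_conversation_name(...) on the mixin; now a public method.
        return self._message_cycle_manager.generate_conversation_name(
            conversation_id=self._conversation_id,
            query=self._application_generate_entity.query,
        )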

api/core/app/apps/workflow/generate_task_pipeline.py   +0 −4

WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
- WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
)


self._application_generate_entity = application_generate_entity
- self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
- self._task_state = WorkflowTaskState()
self._workflow_run_id = ""


def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
if tts_publisher:
tts_publisher.publish(queue_message)


- self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(
delta_text, from_variable_selector=event.from_variable_selector
)

api/core/app/entities/queue_entities.py   +3 −2

- from collections.abc import Mapping
+ from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import Enum, StrEnum
from typing import Any, Optional
from pydantic import BaseModel


from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
+ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
"""


event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
- retriever_resources: list[dict]
+ retriever_resources: Sequence[RetrievalSourceMetadata]
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None

api/core/app/entities/task_entities.py   +20 −3

from enum import Enum
from typing import Any, Optional


- from pydantic import BaseModel, ConfigDict
+ from pydantic import BaseModel, ConfigDict, Field


- from core.model_runtime.entities.llm_entities import LLMResult
+ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
+ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus




+ class AnnotationReplyAccount(BaseModel):
+ id: str
+ name: str


+ class AnnotationReply(BaseModel):
+ id: str
+ account: AnnotationReplyAccount


+ class TaskStateMetadata(BaseModel):
+ annotation_reply: AnnotationReply | None = None
+ retriever_resources: Sequence[RetrievalSourceMetadata] = Field(default_factory=list)
+ usage: LLMUsage | None = None


class TaskState(BaseModel):
"""
TaskState entity
"""


- metadata: dict = {}
+ metadata: TaskStateMetadata = Field(default_factory=TaskStateMetadata)




class EasyUITaskState(TaskState):
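With metadata now a typed model, call sites in this change replace string-keyed dict writes and json.dumps(jsonable_encoder(...)) with attribute assignment and pydantic's model_dump()/model_dump_json(). A minimal before/after sketch, assuming the models defined above (the field values are illustrative):

from core.app.entities.task_entities import TaskStateMetadata
from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata

metadata = TaskStateMetadata()

# Old: task_state.metadata["usage"] = jsonable_encoder(usage)
metadata.usage = LLMUsage.empty_usage()
# Old: task_state.metadata["retriever_resources"] = event.retriever_resources
metadata.retriever_resources = [
    RetrievalSourceMetadata(dataset_name="FAQ", document_name="faq.md", score=0.87)
]

# Old: json.dumps(jsonable_encoder(task_state.metadata)) if task_state.metadata else None
message_metadata = metadata.model_dump_json()
# Old: extras["metadata"] = task_state.metadata
metadata_dict = metadata.model_dump()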

api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py   +22 −23

- import json
import logging
import time
from collections.abc import Generator
StreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
- from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
+ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
AssistantPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
- from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.prompt.utils.prompt_message_util import PromptMessageUtil
logger = logging.getLogger(__name__)




- class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleManage):
+ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
"""
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
)
)


+ self._message_cycle_manager = MessageCycleManager(
+ application_generate_entity=application_generate_entity,
+ task_state=self._task_state,
+ )

self._conversation_name_generate_thread: Optional[Thread] = None


def process(
]:
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
- self._conversation_name_generate_thread = self._generate_conversation_name(
+ self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
)


if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, MessageEndStreamResponse):
- extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
+ extras = {"usage": self._task_state.llm_result.usage.model_dump()}
if self._task_state.metadata:
- extras["metadata"] = self._task_state.metadata
+ extras["metadata"] = self._task_state.metadata.model_dump()
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation_mode == AppMode.COMPLETION.value:
response = CompletionAppBlockingResponse(
)
if output_moderation_answer:
self._task_state.llm_result.message.content = output_moderation_answer
- yield self._message_replace_to_stream_response(answer=output_moderation_answer)
+ yield self._message_cycle_manager.message_replace_to_stream_response(
+ answer=output_moderation_answer
+ )


with Session(db.engine) as session:
# Save message
message_end_resp = self._message_end_to_stream_response()
yield message_end_resp
elif isinstance(event, QueueRetrieverResourcesEvent):
- self._handle_retriever_resources(event)
+ self._message_cycle_manager.handle_retriever_resources(event)
elif isinstance(event, QueueAnnotationReplyEvent):
- annotation = self._handle_annotation_reply(event)
+ annotation = self._message_cycle_manager.handle_annotation_reply(event)
if annotation:
self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, QueueAgentThoughtEvent):
if agent_thought_response is not None:
yield agent_thought_response
elif isinstance(event, QueueMessageFileEvent):
- response = self._message_file_to_stream_response(event)
+ response = self._message_cycle_manager.message_file_to_stream_response(event)
if response:
yield response
elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
self._task_state.llm_result.message.content = current_content


if isinstance(event, QueueLLMChunkEvent):
- yield self._message_to_stream_response(
+ yield self._message_cycle_manager.message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
)
message_id=self._message_id,
)
elif isinstance(event, QueueMessageReplaceEvent):
- yield self._message_replace_to_stream_response(answer=event.text)
+ yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
else:
message.provider_response_latency = time.perf_counter() - self._start_at
message.total_price = usage.total_price
message.currency = usage.currency
- message.message_metadata = (
- json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
- )
+ message.message_metadata = self._task_state.metadata.model_dump_json()


if trace_manager:
trace_manager.add_trace_task(
Message end to stream response.
:return:
"""
- self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage)

- extras = {}
- if self._task_state.metadata:
- extras["metadata"] = self._task_state.metadata

+ self._task_state.metadata.usage = self._task_state.llm_result.usage
+ metadata_dict = self._task_state.metadata.model_dump()
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message_id,
- metadata=extras.get("metadata", {}),
+ metadata=metadata_dict,
)


def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:

api/core/app/task_pipeline/message_cycle_manage.py → api/core/app/task_pipeline/message_cycle_manager.py

QueueRetrieverResourcesEvent,
)
from core.app.entities.task_entities import (
+ AnnotationReply,
+ AnnotationReplyAccount,
EasyUITaskState,
MessageFileStreamResponse,
MessageReplaceStreamResponse,
from services.annotation_service import AppAnnotationService




- class MessageCycleManage:
+ class MessageCycleManager:
def __init__(
self,
*,
self._application_generate_entity = application_generate_entity
self._task_state = task_state


- def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
+ def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
"""
Generate conversation name.
:param conversation_id: conversation id
db.session.commit()
db.session.close()


- def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
+ def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
"""
Handle annotation reply.
:param event: event
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
if annotation:
account = annotation.account
- self._task_state.metadata["annotation_reply"] = {
- "id": annotation.id,
- "account": {"id": annotation.account_id, "name": account.name if account else "Dify user"},
- }
+ self._task_state.metadata.annotation_reply = AnnotationReply(
+ id=annotation.id,
+ account=AnnotationReplyAccount(
+ id=annotation.account_id,
+ name=account.name if account else "Dify user",
+ ),
+ )


return annotation


return None


- def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
+ def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
"""
Handle retriever resources.
:param event: event
:return:
"""
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
- self._task_state.metadata["retriever_resources"] = event.retriever_resources
+ self._task_state.metadata.retriever_resources = event.retriever_resources


- def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
+ def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
"""
Message file to stream response.
:param event: event


return None


- def _message_to_stream_response(
+ def message_to_stream_response(
self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
) -> MessageStreamResponse:
"""
from_variable_selector=from_variable_selector,
)


- def _message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
+ def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
"""
Message replace to stream response.
:param answer: answer

api/core/callback_handler/index_tool_callback_handler.py   +4 −1

import logging
+ from collections.abc import Sequence


from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
+ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document
from extensions.ext_database import db


db.session.commit()


- def return_retriever_resource_info(self, resource: list):
+ # TODO(-LAN-): Improve type check
+ def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
"""Handle return_retriever_resource_info."""
self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
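Since QueueRetrieverResourcesEvent now declares retriever_resources as Sequence[RetrievalSourceMetadata], pydantic validates the citation payload this callback publishes instead of accepting an arbitrary list[dict]. A minimal sketch with illustrative field values:

from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.entities.citation_metadata import RetrievalSourceMetadata

resources = [
    RetrievalSourceMetadata(
        dataset_id="d1",
        dataset_name="FAQ",
        document_name="faq.md",
        data_source_type="upload_file",
        score=0.92,
        content="question: ... \nanswer: ...",
    )
]

# The event model now rejects payloads that do not match RetrievalSourceMetadata.
event = QueueRetrieverResourcesEvent(retriever_resources=resources)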

api/core/rag/entities/citation_metadata.py   +23 −0

+ from typing import Any, Optional

+ from pydantic import BaseModel


+ class RetrievalSourceMetadata(BaseModel):
+ position: Optional[int] = None
+ dataset_id: Optional[str] = None
+ dataset_name: Optional[str] = None
+ document_id: Optional[str] = None
+ document_name: Optional[str] = None
+ data_source_type: Optional[str] = None
+ segment_id: Optional[str] = None
+ retriever_from: Optional[str] = None
+ score: Optional[float] = None
+ hit_count: Optional[int] = None
+ word_count: Optional[int] = None
+ segment_position: Optional[int] = None
+ index_node_hash: Optional[str] = None
+ content: Optional[str] = None
+ page: Optional[int] = None
+ doc_metadata: Optional[dict[str, Any]] = None
+ title: Optional[str] = None
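Downstream code switches from dict keys to attribute access on these instances; the sort-and-position pass in dataset_retrieval.py later in this diff reduces to the pattern below (sample values are illustrative):

from core.rag.entities.citation_metadata import RetrievalSourceMetadata

retrieval_resource_list = [
    RetrievalSourceMetadata(document_name="a.md", score=0.42, content="..."),
    RetrievalSourceMetadata(document_name="b.md", score=0.91, content="..."),
    RetrievalSourceMetadata(document_name="c.md", score=None, content="..."),
]

# Old: sorted(..., key=lambda x: x.get("score") or 0.0, reverse=True)
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True)
for position, item in enumerate(retrieval_resource_list, start=1):
    # Old: item["position"] = position
    item.position = position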

api/core/rag/retrieval/dataset_retrieval.py   +32 −31

from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService
+ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType


dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
- document_context_list = []
- retrieval_resource_list = []
+ document_context_list: list[DocumentContext] = []
+ retrieval_resource_list: list[RetrievalSourceMetadata] = []
# deal with external documents
for item in external_documents:
document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score")))
- source = {
- "dataset_id": item.metadata.get("dataset_id"),
- "dataset_name": item.metadata.get("dataset_name"),
- "document_id": item.metadata.get("document_id") or item.metadata.get("title"),
- "document_name": item.metadata.get("title"),
- "data_source_type": "external",
- "retriever_from": invoke_from.to_source(),
- "score": item.metadata.get("score"),
- "content": item.page_content,
- }
+ source = RetrievalSourceMetadata(
+ dataset_id=item.metadata.get("dataset_id"),
+ dataset_name=item.metadata.get("dataset_name"),
+ document_id=item.metadata.get("document_id") or item.metadata.get("title"),
+ document_name=item.metadata.get("title"),
+ data_source_type="external",
+ retriever_from=invoke_from.to_source(),
+ score=item.metadata.get("score"),
+ content=item.page_content,
+ )
retrieval_resource_list.append(source)
# deal with dify documents
if dify_documents:
.first()
)
if dataset and document:
- source = {
- "dataset_id": dataset.id,
- "dataset_name": dataset.name,
- "document_id": document.id,
- "document_name": document.name,
- "data_source_type": document.data_source_type,
- "segment_id": segment.id,
- "retriever_from": invoke_from.to_source(),
- "score": record.score or 0.0,
- "doc_metadata": document.doc_metadata,
- }
+ source = RetrievalSourceMetadata(
+ dataset_id=dataset.id,
+ dataset_name=dataset.name,
+ document_id=document.id,
+ document_name=document.name,
+ data_source_type=document.data_source_type,
+ segment_id=segment.id,
+ retriever_from=invoke_from.to_source(),
+ score=record.score or 0.0,
+ doc_metadata=document.doc_metadata,
+ )


if invoke_from.to_source() == "dev":
- source["hit_count"] = segment.hit_count
- source["word_count"] = segment.word_count
- source["segment_position"] = segment.position
- source["index_node_hash"] = segment.index_node_hash
+ source.hit_count = segment.hit_count
+ source.word_count = segment.word_count
+ source.segment_position = segment.position
+ source.index_node_hash = segment.index_node_hash
if segment.answer:
- source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
+ source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else:
- source["content"] = segment.content
+ source.content = segment.content
retrieval_resource_list.append(source)
if hit_callback and retrieval_resource_list:
- retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score") or 0.0, reverse=True)
+ retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True)
for position, item in enumerate(retrieval_resource_list, start=1):
- item["position"] = position
+ item.position = position
hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list:
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)

api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py   +20 −19

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
+ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
else:
document_context_list.append(segment.get_sign_content())
if self.return_resource:
- context_list = []
+ context_list: list[RetrievalSourceMetadata] = []
resource_number = 1
for segment in sorted_segments:
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
.first()
)
if dataset and document:
- source = {
- "position": resource_number,
- "dataset_id": dataset.id,
- "dataset_name": dataset.name,
- "document_id": document.id,
- "document_name": document.name,
- "data_source_type": document.data_source_type,
- "segment_id": segment.id,
- "retriever_from": self.retriever_from,
- "score": document_score_list.get(segment.index_node_id, None),
- "doc_metadata": document.doc_metadata,
- }
+ source = RetrievalSourceMetadata(
+ position=resource_number,
+ dataset_id=dataset.id,
+ dataset_name=dataset.name,
+ document_id=document.id,
+ document_name=document.name,
+ data_source_type=document.data_source_type,
+ segment_id=segment.id,
+ retriever_from=self.retriever_from,
+ score=document_score_list.get(segment.index_node_id, None),
+ doc_metadata=document.doc_metadata,
+ )


if self.retriever_from == "dev":
- source["hit_count"] = segment.hit_count
- source["word_count"] = segment.word_count
- source["segment_position"] = segment.position
- source["index_node_hash"] = segment.index_node_hash
+ source.hit_count = segment.hit_count
+ source.word_count = segment.word_count
+ source.segment_position = segment.position
+ source.index_node_hash = segment.index_node_hash
if segment.answer:
- source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
+ source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else:
- source["content"] = segment.content
+ source.content = segment.content
context_list.append(source)
resource_number += 1



api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py   +37 −36



from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.datasource.retrieval_service import RetrievalService
+ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService


- default_retrieval_model = {
+ default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
else:
document_ids_filter = None
if dataset.provider == "external":
- results = []
+ results: list[RetrievalDocument] = []
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document.metadata["dataset_name"] = dataset.name
results.append(document)
# deal with external documents
- context_list = []
+ context_list: list[RetrievalSourceMetadata] = []
for position, item in enumerate(results, start=1):
if item.metadata is not None:
- source = {
- "position": position,
- "dataset_id": item.metadata.get("dataset_id"),
- "dataset_name": item.metadata.get("dataset_name"),
- "document_id": item.metadata.get("document_id") or item.metadata.get("title"),
- "document_name": item.metadata.get("title"),
- "data_source_type": "external",
- "retriever_from": self.retriever_from,
- "score": item.metadata.get("score"),
- "title": item.metadata.get("title"),
- "content": item.page_content,
- }
+ source = RetrievalSourceMetadata(
+ position=position,
+ dataset_id=item.metadata.get("dataset_id"),
+ dataset_name=item.metadata.get("dataset_name"),
+ document_id=item.metadata.get("document_id") or item.metadata.get("title"),
+ document_name=item.metadata.get("title"),
+ data_source_type="external",
+ retriever_from=self.retriever_from,
+ score=item.metadata.get("score"),
+ title=item.metadata.get("title"),
+ content=item.page_content,
+ )
context_list.append(source)
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
return ""
# get retrieval model , if the model is not setting , using default
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
- retrieval_resource_list = []
+ retrieval_resource_list: list[RetrievalSourceMetadata] = []
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
for item in documents:
if item.metadata is not None and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
- document_context_list = []
+ document_context_list: list[DocumentContext] = []
records = RetrievalService.format_retrieval_documents(documents)
if records:
for record in records:
.first()
)
if dataset and document:
- source = {
- "dataset_id": dataset.id,
- "dataset_name": dataset.name,
- "document_id": document.id, # type: ignore
- "document_name": document.name, # type: ignore
- "data_source_type": document.data_source_type, # type: ignore
- "segment_id": segment.id,
- "retriever_from": self.retriever_from,
- "score": record.score or 0.0,
- "doc_metadata": document.doc_metadata, # type: ignore
- }
+ source = RetrievalSourceMetadata(
+ dataset_id=dataset.id,
+ dataset_name=dataset.name,
+ document_id=document.id, # type: ignore
+ document_name=document.name, # type: ignore
+ data_source_type=document.data_source_type, # type: ignore
+ segment_id=segment.id,
+ retriever_from=self.retriever_from,
+ score=record.score or 0.0,
+ doc_metadata=document.doc_metadata, # type: ignore
+ )


if self.retriever_from == "dev":
- source["hit_count"] = segment.hit_count
- source["word_count"] = segment.word_count
- source["segment_position"] = segment.position
- source["index_node_hash"] = segment.index_node_hash
+ source.hit_count = segment.hit_count
+ source.word_count = segment.word_count
+ source.segment_position = segment.position
+ source.index_node_hash = segment.index_node_hash
if segment.answer:
- source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
+ source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else:
- source["content"] = segment.content
+ source.content = segment.content
retrieval_resource_list.append(source)


if self.return_resource and retrieval_resource_list:
retrieval_resource_list = sorted(
retrieval_resource_list,
- key=lambda x: x.get("score") or 0.0,
+ key=lambda x: x.score or 0.0,
reverse=True,
)
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
- item["position"] = position # type: ignore
+ item.position = position # type: ignore
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list:

api/core/workflow/graph_engine/entities/event.py   +3 −2

- from collections.abc import Mapping
+ from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any, Optional


from pydantic import BaseModel, Field


+ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import NodeType




class NodeRunRetrieverResourceEvent(BaseNodeEvent):
- retriever_resources: list[dict] = Field(..., description="retriever resources")
+ retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")





api/core/workflow/nodes/event/event.py   +3 −1

+ from collections.abc import Sequence
from datetime import datetime


from pydantic import BaseModel, Field


from core.model_runtime.entities.llm_entities import LLMUsage
+ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus






class RunRetrieverResourceEvent(BaseModel):
- retriever_resources: list[dict] = Field(..., description="retriever resources")
+ retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")





api/core/workflow/nodes/llm/node.py   +21 −20

from core.plugin.entities.plugin import ModelProviderID
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
+ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.variables import (
ArrayAnySegment,
ArrayFileSegment,
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
elif isinstance(context_value_variable, ArraySegment):
context_str = ""
- original_retriever_resource = []
+ original_retriever_resource: list[RetrievalSourceMetadata] = []
for item in context_value_variable.value:
if isinstance(item, str):
context_str += item + "\n"
retriever_resources=original_retriever_resource, context=context_str.strip()
)


- def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
+ def _convert_to_original_retriever_resource(self, context_dict: dict):
if (
"metadata" in context_dict
and "_source" in context_dict["metadata"]
):
metadata = context_dict.get("metadata", {})


- source = {
- "position": metadata.get("position"),
- "dataset_id": metadata.get("dataset_id"),
- "dataset_name": metadata.get("dataset_name"),
- "document_id": metadata.get("document_id"),
- "document_name": metadata.get("document_name"),
- "data_source_type": metadata.get("data_source_type"),
- "segment_id": metadata.get("segment_id"),
- "retriever_from": metadata.get("retriever_from"),
- "score": metadata.get("score"),
- "hit_count": metadata.get("segment_hit_count"),
- "word_count": metadata.get("segment_word_count"),
- "segment_position": metadata.get("segment_position"),
- "index_node_hash": metadata.get("segment_index_node_hash"),
- "content": context_dict.get("content"),
- "page": metadata.get("page"),
- "doc_metadata": metadata.get("doc_metadata"),
- }
+ source = RetrievalSourceMetadata(
+ position=metadata.get("position"),
+ dataset_id=metadata.get("dataset_id"),
+ dataset_name=metadata.get("dataset_name"),
+ document_id=metadata.get("document_id"),
+ document_name=metadata.get("document_name"),
+ data_source_type=metadata.get("data_source_type"),
+ segment_id=metadata.get("segment_id"),
+ retriever_from=metadata.get("retriever_from"),
+ score=metadata.get("score"),
+ hit_count=metadata.get("segment_hit_count"),
+ word_count=metadata.get("segment_word_count"),
+ segment_position=metadata.get("segment_position"),
+ index_node_hash=metadata.get("segment_index_node_hash"),
+ content=context_dict.get("content"),
+ page=metadata.get("page"),
+ doc_metadata=metadata.get("doc_metadata"),
+ )


return source


