
Commit 49d1846e63 · jyong · 6 months ago
Branch: r2 · Tag: tags/2.0.0-beta.1

api/core/datasource/utils/message_transformer.py (+38 -38)

  from mimetypes import guess_extension
  from typing import Optional

+ from core.datasource.datasource_file_manager import DatasourceFileManager
+ from core.datasource.entities.datasource_entities import DatasourceInvokeMessage
  from core.file import File, FileTransferMethod, FileType
- from core.tools.entities.tool_entities import ToolInvokeMessage
- from core.tools.tool_file_manager import ToolFileManager

  logger = logging.getLogger(__name__)


- class ToolFileMessageTransformer:
+ class DatasourceFileMessageTransformer:
  @classmethod
- def transform_tool_invoke_messages(
+ def transform_datasource_invoke_messages(
  cls,
- messages: Generator[ToolInvokeMessage, None, None],
+ messages: Generator[DatasourceInvokeMessage, None, None],
  user_id: str,
  tenant_id: str,
  conversation_id: Optional[str] = None,
- ) -> Generator[ToolInvokeMessage, None, None]:
+ ) -> Generator[DatasourceInvokeMessage, None, None]:
  """
- Transform tool message and handle file download
+ Transform datasource message and handle file download
  """
  for message in messages:
- if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
+ if message.type in {DatasourceInvokeMessage.MessageType.TEXT, DatasourceInvokeMessage.MessageType.LINK}:
  yield message
- elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(
- message.message, ToolInvokeMessage.TextMessage
+ elif message.type == DatasourceInvokeMessage.MessageType.IMAGE and isinstance(
+ message.message, DatasourceInvokeMessage.TextMessage
  ):
  # try to download image
  try:
- assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+ assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)

- file = ToolFileManager.create_file_by_url(
+ file = DatasourceFileManager.create_file_by_url(
  user_id=user_id,
  tenant_id=tenant_id,
  file_url=message.message.text,
  conversation_id=conversation_id,
  )

- url = f"/files/tools/{file.id}{guess_extension(file.mimetype) or '.png'}"
+ url = f"/files/datasources/{file.id}{guess_extension(file.mimetype) or '.png'}"

- yield ToolInvokeMessage(
- type=ToolInvokeMessage.MessageType.IMAGE_LINK,
- message=ToolInvokeMessage.TextMessage(text=url),
+ yield DatasourceInvokeMessage(
+ type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
+ message=DatasourceInvokeMessage.TextMessage(text=url),
  meta=message.meta.copy() if message.meta is not None else {},
  )
  except Exception as e:
- yield ToolInvokeMessage(
- type=ToolInvokeMessage.MessageType.TEXT,
- message=ToolInvokeMessage.TextMessage(
+ yield DatasourceInvokeMessage(
+ type=DatasourceInvokeMessage.MessageType.TEXT,
+ message=DatasourceInvokeMessage.TextMessage(
  text=f"Failed to download image: {message.message.text}: {e}"
  ),
  meta=message.meta.copy() if message.meta is not None else {},
  )
- elif message.type == ToolInvokeMessage.MessageType.BLOB:
+ elif message.type == DatasourceInvokeMessage.MessageType.BLOB:
  # get mime type and save blob to storage
  meta = message.meta or {}

  filename = meta.get("file_name", None)
  # if message is str, encode it to bytes

- if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
+ if not isinstance(message.message, DatasourceInvokeMessage.BlobMessage):
  raise ValueError("unexpected message type")

  # FIXME: should do a type check here.
  assert isinstance(message.message.blob, bytes)
- file = ToolFileManager.create_file_by_raw(
+ file = DatasourceFileManager.create_file_by_raw(
  user_id=user_id,
  tenant_id=tenant_id,
  conversation_id=conversation_id,
  filename=filename,
  )

- url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype))
+ url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mimetype))

  # check if file is image
  if "image" in mimetype:
- yield ToolInvokeMessage(
- type=ToolInvokeMessage.MessageType.IMAGE_LINK,
- message=ToolInvokeMessage.TextMessage(text=url),
+ yield DatasourceInvokeMessage(
+ type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
+ message=DatasourceInvokeMessage.TextMessage(text=url),
  meta=meta.copy() if meta is not None else {},
  )
  else:
- yield ToolInvokeMessage(
- type=ToolInvokeMessage.MessageType.BINARY_LINK,
- message=ToolInvokeMessage.TextMessage(text=url),
+ yield DatasourceInvokeMessage(
+ type=DatasourceInvokeMessage.MessageType.BINARY_LINK,
+ message=DatasourceInvokeMessage.TextMessage(text=url),
  meta=meta.copy() if meta is not None else {},
  )
- elif message.type == ToolInvokeMessage.MessageType.FILE:
+ elif message.type == DatasourceInvokeMessage.MessageType.FILE:
  meta = message.meta or {}
  file = meta.get("file", None)
  if isinstance(file, File):
  assert file.related_id is not None
  url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
  if file.type == FileType.IMAGE:
- yield ToolInvokeMessage(
- type=ToolInvokeMessage.MessageType.IMAGE_LINK,
- message=ToolInvokeMessage.TextMessage(text=url),
+ yield DatasourceInvokeMessage(
+ type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
+ message=DatasourceInvokeMessage.TextMessage(text=url),
  meta=meta.copy() if meta is not None else {},
  )
  else:
- yield ToolInvokeMessage(
- type=ToolInvokeMessage.MessageType.LINK,
- message=ToolInvokeMessage.TextMessage(text=url),
+ yield DatasourceInvokeMessage(
+ type=DatasourceInvokeMessage.MessageType.LINK,
+ message=DatasourceInvokeMessage.TextMessage(text=url),
  meta=meta.copy() if meta is not None else {},
  )
  else:
  yield message

  @classmethod
- def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
- return f"/files/tools/{tool_file_id}{extension or '.bin'}"
+ def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str:
+ return f"/files/datasources/{datasource_file_id}{extension or '.bin'}"
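
The renamed helper keeps the old URL scheme and only swaps the path prefix from /files/tools/ to /files/datasources/, falling back to .bin when no extension can be guessed from the MIME type. A minimal standalone sketch of that pattern (the function name below is illustrative, not the class method from the diff):

from mimetypes import guess_extension
from typing import Optional


def build_datasource_file_url(file_id: str, mimetype: Optional[str]) -> str:
    # Mirror of the helper above: fall back to ".bin" when the MIME type
    # is unknown or has no registered extension.
    extension = guess_extension(mimetype) if mimetype else None
    return f"/files/datasources/{file_id}{extension or '.bin'}"


if __name__ == "__main__":
    print(build_datasource_file_url("abc123", "image/png"))              # /files/datasources/abc123.png
    print(build_datasource_file_url("abc123", "application/x-unknown"))  # /files/datasources/abc123.bin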

api/core/plugin/manager/datasource.py (+2 -1)



  response = self._request_with_plugin_daemon_response_stream(
  "POST",
- f"plugin/{tenant_id}/dispatch/datasource/invoke_first_step",
+ f"plugin/{tenant_id}/dispatch/datasource/{online_document}/pages",
  ToolInvokeMessage,
  data={
  "user_id": user_id,
  "data": {
  "provider": datasource_provider_id.provider_name,
  "datasource": datasource_name,

  "credentials": credentials,
  "datasource_parameters": datasource_parameters,
  },
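
The dispatch path changes, but the call still streams results back from the plugin daemon. A generic sketch of consuming such a streaming endpoint with requests, assuming a newline-delimited JSON response; the base URL, payload shape, and helper name are assumptions, not the actual _request_with_plugin_daemon_response_stream implementation:

import json

import requests  # assumption: any HTTP client with streaming support would do


def stream_datasource_pages(base_url: str, tenant_id: str, online_document: str, payload: dict):
    # Path mirrors the new dispatch route from the diff; base_url is hypothetical.
    url = f"{base_url}/plugin/{tenant_id}/dispatch/datasource/{online_document}/pages"
    with requests.post(url, json=payload, stream=True, timeout=60) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines():
            if line:
                yield json.loads(line)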

api/core/workflow/nodes/datasource/datasource_node.py (+59 -63)

  from sqlalchemy.orm import Session

  from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
+ from core.datasource.datasource_engine import DatasourceEngine
+ from core.datasource.entities.datasource_entities import DatasourceInvokeMessage, DatasourceParameter
+ from core.datasource.errors import DatasourceInvokeError
+ from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
  from core.file import File, FileTransferMethod
  from core.plugin.manager.exc import PluginDaemonClientSideError
  from core.plugin.manager.plugin import PluginInstallationManager
- from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
- from core.tools.errors import ToolInvokeError
- from core.tools.tool_engine import ToolEngine
- from core.tools.utils.message_transformer import ToolFileMessageTransformer
  from core.variables.segments import ArrayAnySegment
  from core.variables.variables import ArrayAnyVariable
  from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
  from services.tools.builtin_tools_manage_service import BuiltinToolManageService

  from .entities import DatasourceNodeData
- from .exc import (
- ToolFileError,
- ToolNodeError,
- ToolParameterError,
- )
+ from .exc import DatasourceNodeError, DatasourceParameterError, ToolFileError


  class DatasourceNode(BaseNode[DatasourceNodeData]):

  # get datasource runtime
  try:
- from core.tools.tool_manager import ToolManager
+ from core.datasource.datasource_manager import DatasourceManager

- tool_runtime = ToolManager.get_workflow_tool_runtime(
+ datasource_runtime = DatasourceManager.get_workflow_datasource_runtime(
  self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
  )
- except ToolNodeError as e:
+ except DatasourceNodeError as e:
  yield RunCompletedEvent(
  run_result=NodeRunResult(
  status=WorkflowNodeExecutionStatus.FAILED,
  return

  # get parameters
- tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
+ datasource_parameters = datasource_runtime.get_merged_runtime_parameters() or []
  parameters = self._generate_parameters(
- tool_parameters=tool_parameters,
+ datasource_parameters=datasource_parameters,
  variable_pool=self.graph_runtime_state.variable_pool,
  node_data=self.node_data,
  )
  parameters_for_log = self._generate_parameters(
- tool_parameters=tool_parameters,
+ datasource_parameters=datasource_parameters,
  variable_pool=self.graph_runtime_state.variable_pool,
  node_data=self.node_data,
  for_log=True,
  conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])

  try:
- message_stream = ToolEngine.generic_invoke(
- tool=tool_runtime,
- tool_parameters=parameters,
+ message_stream = DatasourceEngine.generic_invoke(
+ datasource=datasource_runtime,
+ datasource_parameters=parameters,
  user_id=self.user_id,
  workflow_tool_callback=DifyWorkflowCallbackHandler(),
  workflow_call_depth=self.workflow_call_depth,
  app_id=self.app_id,
  conversation_id=conversation_id.text if conversation_id else None,
  )
- except ToolNodeError as e:
+ except DatasourceNodeError as e:
  yield RunCompletedEvent(
  run_result=NodeRunResult(
  status=WorkflowNodeExecutionStatus.FAILED,
  inputs=parameters_for_log,
- metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
- error=f"Failed to invoke tool: {str(e)}",
+ metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+ error=f"Failed to invoke datasource: {str(e)}",
  error_type=type(e).__name__,
  )
  )
  return

  try:
- # convert tool messages
- yield from self._transform_message(message_stream, tool_info, parameters_for_log)
- except (PluginDaemonClientSideError, ToolInvokeError) as e:
+ # convert datasource messages
+ yield from self._transform_message(message_stream, datasource_info, parameters_for_log)
+ except (PluginDaemonClientSideError, DatasourceInvokeError) as e:
  yield RunCompletedEvent(
  run_result=NodeRunResult(
  status=WorkflowNodeExecutionStatus.FAILED,
  inputs=parameters_for_log,
- metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
- error=f"Failed to transform tool message: {str(e)}",
+ metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+ error=f"Failed to transform datasource message: {str(e)}",
  error_type=type(e).__name__,
  )
  )
  def _generate_parameters(
  self,
  *,
- tool_parameters: Sequence[ToolParameter],
+ datasource_parameters: Sequence[DatasourceParameter],
  variable_pool: VariablePool,
- node_data: ToolNodeData,
+ node_data: DatasourceNodeData,
  for_log: bool = False,
  ) -> dict[str, Any]:
  """
  Mapping[str, Any]: A dictionary containing the generated parameters.

  """
- tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
+ datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}

  result: dict[str, Any] = {}
- for parameter_name in node_data.tool_parameters:
- parameter = tool_parameters_dictionary.get(parameter_name)
+ for parameter_name in node_data.datasource_parameters:
+ parameter = datasource_parameters_dictionary.get(parameter_name)
  if not parameter:
  result[parameter_name] = None
  continue
- tool_input = node_data.tool_parameters[parameter_name]
- if tool_input.type == "variable":
- variable = variable_pool.get(tool_input.value)
+ datasource_input = node_data.datasource_parameters[parameter_name]
+ if datasource_input.type == "variable":
+ variable = variable_pool.get(datasource_input.value)
  if variable is None:
- raise ToolParameterError(f"Variable {tool_input.value} does not exist")
+ raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
  parameter_value = variable.value
- elif tool_input.type in {"mixed", "constant"}:
- segment_group = variable_pool.convert_template(str(tool_input.value))
+ elif datasource_input.type in {"mixed", "constant"}:
+ segment_group = variable_pool.convert_template(str(datasource_input.value))
  parameter_value = segment_group.log if for_log else segment_group.text
  else:
- raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
+ raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
  result[parameter_name] = parameter_value

  return result

  def _transform_message(
  self,
- messages: Generator[ToolInvokeMessage, None, None],
- tool_info: Mapping[str, Any],
+ messages: Generator[DatasourceInvokeMessage, None, None],
+ datasource_info: Mapping[str, Any],
  parameters_for_log: dict[str, Any],
  ) -> Generator:
  """
  Convert ToolInvokeMessages into tuple[plain_text, files]
  """
  # transform message and handle file storage
- message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
+ message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
  messages=messages,
  user_id=self.user_id,
  tenant_id=self.tenant_id,

  for message in message_stream:
  if message.type in {
- ToolInvokeMessage.MessageType.IMAGE_LINK,
- ToolInvokeMessage.MessageType.BINARY_LINK,
- ToolInvokeMessage.MessageType.IMAGE,
+ DatasourceInvokeMessage.MessageType.IMAGE_LINK,
+ DatasourceInvokeMessage.MessageType.BINARY_LINK,
+ DatasourceInvokeMessage.MessageType.IMAGE,
  }:
- assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+ assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)

  url = message.message.text
  if message.meta:
  tenant_id=self.tenant_id,
  )
  files.append(file)
- elif message.type == ToolInvokeMessage.MessageType.BLOB:
+ elif message.type == DatasourceInvokeMessage.MessageType.BLOB:
  # get tool file id
- assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+ assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
  assert message.meta

  tool_file_id = message.message.text.split("/")[-1].split(".")[0]
  tenant_id=self.tenant_id,
  )
  )
- elif message.type == ToolInvokeMessage.MessageType.TEXT:
- assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+ elif message.type == DatasourceInvokeMessage.MessageType.TEXT:
+ assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
  text += message.message.text
  yield RunStreamChunkEvent(
  chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
  )
- elif message.type == ToolInvokeMessage.MessageType.JSON:
- assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
+ elif message.type == DatasourceInvokeMessage.MessageType.JSON:
+ assert isinstance(message.message, DatasourceInvokeMessage.JsonMessage)
  if self.node_type == NodeType.AGENT:
  msg_metadata = message.message.json_object.pop("execution_metadata", {})
  agent_execution_metadata = {
  if key in NodeRunMetadataKey.__members__.values()
  }
  json.append(message.message.json_object)
- elif message.type == ToolInvokeMessage.MessageType.LINK:
- assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+ elif message.type == DatasourceInvokeMessage.MessageType.LINK:
+ assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
  stream_text = f"Link: {message.message.text}\n"
  text += stream_text
  yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
- elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
- assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
+ elif message.type == DatasourceInvokeMessage.MessageType.VARIABLE:
+ assert isinstance(message.message, DatasourceInvokeMessage.VariableMessage)
  variable_name = message.message.variable_name
  variable_value = message.message.variable_value
  if message.message.stream:
  )
  else:
  variables[variable_name] = variable_value
- elif message.type == ToolInvokeMessage.MessageType.FILE:
+ elif message.type == DatasourceInvokeMessage.MessageType.FILE:
  assert message.meta is not None
  files.append(message.meta["file"])
- elif message.type == ToolInvokeMessage.MessageType.LOG:
- assert isinstance(message.message, ToolInvokeMessage.LogMessage)
+ elif message.type == DatasourceInvokeMessage.MessageType.LOG:
+ assert isinstance(message.message, DatasourceInvokeMessage.LogMessage)
  if message.message.metadata:
- icon = tool_info.get("icon", "")
+ icon = datasource_info.get("icon", "")
  dict_metadata = dict(message.message.metadata)
  if dict_metadata.get("provider"):
  manager = PluginInstallationManager()
  outputs={"text": text, "files": files, "json": json, **variables},
  metadata={
  **agent_execution_metadata,
- NodeRunMetadataKey.TOOL_INFO: tool_info,
+ NodeRunMetadataKey.DATASOURCE_INFO: datasource_info,
  NodeRunMetadataKey.AGENT_LOG: agent_logs,
  },
  inputs=parameters_for_log,
  *,
  graph_config: Mapping[str, Any],
  node_id: str,
- node_data: ToolNodeData,
+ node_data: DatasourceNodeData,
  ) -> Mapping[str, Sequence[str]]:
  """
  Extract variable selector to variable mapping
  :return:
  """
  result = {}
- for parameter_name in node_data.tool_parameters:
- input = node_data.tool_parameters[parameter_name]
+ for parameter_name in node_data.datasource_parameters:
+ input = node_data.datasource_parameters[parameter_name]
  if input.type == "mixed":
  assert isinstance(input.value, str)
  selectors = VariableTemplateParser(input.value).extract_variable_selectors()
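
The renamed _generate_parameters keeps the same resolution rules: a "variable" input is looked up in the variable pool, "mixed" and "constant" inputs are rendered as templates, and anything else raises a parameter error. A self-contained sketch of that branching, with a plain dict standing in for the variable pool and a simplified input type (both are illustrative stand-ins, not the real classes):

from dataclasses import dataclass
from typing import Any


@dataclass
class DatasourceInput:  # simplified stand-in for the node's parameter config
    type: str           # "variable", "constant", or "mixed"
    value: Any


def generate_parameters(inputs: dict[str, DatasourceInput], variable_pool: dict[str, Any]) -> dict[str, Any]:
    result: dict[str, Any] = {}
    for name, item in inputs.items():
        if item.type == "variable":
            if item.value not in variable_pool:
                raise ValueError(f"Variable {item.value} does not exist")
            result[name] = variable_pool[item.value]
        elif item.type in {"mixed", "constant"}:
            # The real node renders the template against the pool; plain str() stands in here.
            result[name] = str(item.value)
        else:
            raise ValueError(f"Unknown datasource input type '{item.type}'")
    return result


print(generate_parameters(
    {"workspace": DatasourceInput("variable", "sys.workspace_id"), "limit": DatasourceInput("constant", 20)},
    {"sys.workspace_id": "ws-1"},
))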

api/core/workflow/nodes/datasource/exc.py (+6 -6)

- class ToolNodeError(ValueError):
- """Base exception for tool node errors."""
+ class DatasourceNodeError(ValueError):
+ """Base exception for datasource node errors."""

  pass


- class ToolParameterError(ToolNodeError):
- """Exception raised for errors in tool parameters."""
+ class DatasourceParameterError(DatasourceNodeError):
+ """Exception raised for errors in datasource parameters."""

  pass


- class ToolFileError(ToolNodeError):
- """Exception raised for errors related to tool files."""
+ class DatasourceFileError(DatasourceNodeError):
+ """Exception raised for errors related to datasource files."""

  pass
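
Because DatasourceParameterError and DatasourceFileError inherit from DatasourceNodeError, callers such as the node above can catch the base class and still record the concrete subclass name as error_type. A small standalone illustration (local copies of the classes, not imports from the diff):

class DatasourceNodeError(ValueError):
    """Base exception for datasource node errors."""


class DatasourceParameterError(DatasourceNodeError):
    """Exception raised for errors in datasource parameters."""


def resolve(selector: str):
    raise DatasourceParameterError(f"Variable {selector} does not exist")


try:
    resolve("sys.missing")
except DatasourceNodeError as e:
    # error_type mirrors how the node records type(e).__name__ in NodeRunResult
    print(f"Failed to invoke datasource: {e} (error_type={type(e).__name__})")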

api/core/workflow/nodes/enums.py (+1 -0)

  ANSWER = "answer"
  LLM = "llm"
  KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
+ KNOWLEDGE_INDEX = "knowledge-index"
  IF_ELSE = "if-else"
  CODE = "code"
  TEMPLATE_TRANSFORM = "template-transform"

api/core/workflow/nodes/knowledge_index/__init__.py (+3 -0)

from .knowledge_index_node import KnowledgeIndexNode

__all__ = ["KnowledgeIndexNode"]

api/core/workflow/nodes/knowledge_index/entities.py (+147 -0)

from collections.abc import Sequence
from typing import Any, Literal, Optional, Union

from pydantic import BaseModel, Field

from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import VisionConfig


class RerankingModelConfig(BaseModel):
"""
Reranking Model Config.
"""

provider: str
model: str

class VectorSetting(BaseModel):
"""
Vector Setting.
"""

vector_weight: float
embedding_provider_name: str
embedding_model_name: str


class KeywordSetting(BaseModel):
"""
Keyword Setting.
"""

keyword_weight: float

class WeightedScoreConfig(BaseModel):
"""
Weighted score Config.
"""

vector_setting: VectorSetting
keyword_setting: KeywordSetting


class EmbeddingSetting(BaseModel):
"""
Embedding Setting.
"""
embedding_provider_name: str
embedding_model_name: str


class EconomySetting(BaseModel):
"""
Economy Setting.
"""

keyword_number: int


class RetrievalSetting(BaseModel):
"""
Retrieval Setting.
"""
search_method: Literal["semantic_search", "keyword_search", "hybrid_search"]
top_k: int
score_threshold: Optional[float] = 0.5
score_threshold_enabled: bool = False
reranking_mode: str = "reranking_model"
reranking_enable: bool = True
reranking_model: Optional[RerankingModelConfig] = None
weights: Optional[WeightedScoreConfig] = None

class IndexMethod(BaseModel):
"""
Knowledge Index Setting.
"""
indexing_technique: Literal["high_quality", "economy"]
embedding_setting: EmbeddingSetting
economy_setting: EconomySetting

class FileInfo(BaseModel):
"""
File Info.
"""
file_id: str

class OnlineDocumentIcon(BaseModel):
"""
Document Icon.
"""
icon_url: str
icon_type: str
icon_emoji: str

class OnlineDocumentInfo(BaseModel):
"""
Online document info.
"""
provider: str
workspace_id: str
page_id: str
page_type: str
icon: OnlineDocumentIcon

class WebsiteInfo(BaseModel):
"""
website import info.
"""
provider: str
url: str

class GeneralStructureChunk(BaseModel):
"""
General Structure Chunk.
"""
general_chunk: list[str]
data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]


class ParentChildChunk(BaseModel):
"""
Parent Child Chunk.
"""
parent_content: str
child_content: list[str]


class ParentChildStructureChunk(BaseModel):
"""
Parent Child Structure Chunk.
"""
parent_child_chunks: list[ParentChildChunk]
data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]


class KnowledgeIndexNodeData(BaseNodeData):
"""
Knowledge index Node Data.
"""

type: str = "knowledge-index"
dataset_id: str
index_chunk_variable_selector: list[str]
chunk_structure: Literal["general", "parent-child"]
index_method: IndexMethod
retrieval_setting: RetrievalSetting
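
These pydantic models validate the knowledge-index node configuration up front; an unknown literal such as a misspelled search_method is rejected before the node runs. A trimmed-down, self-contained illustration using a local copy of a few RetrievalSetting fields (not the full model above):

from typing import Literal, Optional

from pydantic import BaseModel, ValidationError


class RetrievalSettingSketch(BaseModel):
    # Local copy of a few RetrievalSetting fields, for illustration only.
    search_method: Literal["semantic_search", "keyword_search", "hybrid_search"]
    top_k: int
    score_threshold: Optional[float] = 0.5
    reranking_enable: bool = True


print(RetrievalSettingSketch(search_method="hybrid_search", top_k=4))

try:
    RetrievalSettingSketch(search_method="fuzzy_search", top_k=4)
except ValidationError as e:
    print("rejected:", e.errors()[0]["loc"])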


api/core/workflow/nodes/knowledge_index/exc.py (+22 -0)

class KnowledgeIndexNodeError(ValueError):
"""Base class for KnowledgeIndexNode errors."""


class ModelNotExistError(KnowledgeIndexNodeError):
"""Raised when the model does not exist."""


class ModelCredentialsNotInitializedError(KnowledgeIndexNodeError):
"""Raised when the model credentials are not initialized."""


class ModelNotSupportedError(KnowledgeIndexNodeError):
"""Raised when the model is not supported."""


class ModelQuotaExceededError(KnowledgeIndexNodeError):
"""Raised when the model provider quota is exceeded."""


class InvalidModelTypeError(KnowledgeIndexNodeError):
"""Raised when the model is not a Large Language Model."""

api/core/workflow/nodes/knowledge_index/knowledge_index_node.py (+154 -0)

import json
import logging
import re
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast

from sqlalchemy import Integer, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast

from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment
from core.variables.segments import ObjectSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2,
METADATA_FILTER_COMPLETION_PROMPT,
METADATA_FILTER_SYSTEM_PROMPT,
METADATA_FILTER_USER_PROMPT_1,
METADATA_FILTER_USER_PROMPT_3,
)
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from models.workflow import WorkflowNodeExecutionStatus
from services.dataset_service import DatasetService
from services.feature_service import FeatureService

from .entities import KnowledgeIndexNodeData
from .exc import (
InvalidModelTypeError,
KnowledgeIndexNodeError,
ModelCredentialsNotInitializedError,
ModelNotExistError,
ModelNotSupportedError,
ModelQuotaExceededError,
)

logger = logging.getLogger(__name__)

default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}


class KnowledgeIndexNode(LLMNode):
_node_data_cls = KnowledgeIndexNodeData # type: ignore
_node_type = NodeType.KNOWLEDGE_INDEX

def _run(self) -> NodeRunResult: # type: ignore
node_data = cast(KnowledgeIndexNodeData, self.node_data)
# extract variables
variable = self.graph_runtime_state.variable_pool.get(node_data.index_chunk_variable_selector)
if not isinstance(variable, ObjectSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error="Query variable is not object type.",
)
chunks = variable.value
variables = {"chunks": chunks}
if not chunks:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
)
# check rate limit
if self.tenant_id:
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{self.tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=self.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
db.session.add(rate_limit_log)
db.session.commit()
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
error_type="RateLimitExceeded",
)

# retrieve knowledge
try:
results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks)
outputs = {"result": results}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
)

except KnowledgeIndexNodeError as e:
logger.warning("Error when running knowledge index node")
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
)
# Temporary handle all exceptions from DatasetRetrieval class here.
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
)


def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, chunks: list[Any]) -> Any:
dataset = Dataset.query.filter_by(id=node_data.dataset_id).first()
if not dataset:
raise KnowledgeIndexNodeError(f"Dataset {node_data.dataset_id} not found.")
DatasetService.invoke_knowledge_index(
dataset=dataset,
chunks=chunks,
index_method=node_data.index_method,
retrieval_setting=node_data.retrieval_setting,
)
pass
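
The rate-limit block in _run implements a sliding window per tenant: it adds the current timestamp to a Redis sorted set, trims entries older than 60 seconds, and rejects the request when the remaining count exceeds the plan limit. A self-contained sketch of the same logic, with an in-memory list standing in for the zadd / zremrangebyscore / zcard calls:

import time
from collections import defaultdict

_requests: dict[str, list[int]] = defaultdict(list)  # tenant_id -> request timestamps (ms)


def over_knowledge_rate_limit(tenant_id: str, limit: int, window_ms: int = 60_000) -> bool:
    now = int(time.time() * 1000)
    bucket = _requests[tenant_id]
    bucket.append(now)                                                          # zadd
    _requests[tenant_id] = bucket = [t for t in bucket if t > now - window_ms]  # zremrangebyscore
    return len(bucket) > limit                                                  # zcard > limit


for _ in range(5):
    print(over_knowledge_rate_limit("tenant-1", limit=3))
# -> False False False True True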

api/core/workflow/nodes/knowledge_index/template_prompts.py (+66 -0)

METADATA_FILTER_SYSTEM_PROMPT = """
### Job Description
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
""" # noqa: E501

METADATA_FILTER_USER_PROMPT_1 = """
{ "input_text": "I want to know which company’s email address test@example.com is?",
"metadata_fields": ["filename", "email", "phone", "address"]
}
"""

METADATA_FILTER_ASSISTANT_PROMPT_1 = """
```json
{"metadata_map": [
{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
]
}
```
"""

METADATA_FILTER_USER_PROMPT_2 = """
{"input_text": "What are the movies with a score of more than 9 in 2024?",
"metadata_fields": ["name", "year", "rating", "country"]}
"""

METADATA_FILTER_ASSISTANT_PROMPT_2 = """
```json
{"metadata_map": [
{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"},
]}
```
"""

METADATA_FILTER_USER_PROMPT_3 = """
'{{"input_text": "{input_text}",',
'"metadata_fields": {metadata_fields}}}'
"""

METADATA_FILTER_COMPLETION_PROMPT = """
### Job Description
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
<example>
User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
</example>
### User Input
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
### Assistant Output
""" # noqa: E501

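These templates escape literal braces as {{ }} so that str.format only substitutes {input_text} and {metadata_fields}, and the model is expected to answer with a JSON object under the metadata_map key. A short sketch of filling a user prompt and parsing such a reply (the template is a shortened local copy and the reply is hard-coded for illustration):

import json

USER_PROMPT = '{{"input_text": "{input_text}", "metadata_fields": {metadata_fields}}}'

prompt = USER_PROMPT.format(
    input_text="What are the movies with a score of more than 9 in 2024?",
    metadata_fields=json.dumps(["name", "year", "rating", "country"]),
)
print(prompt)

# A reply in the expected shape; a real run would come back from the LLM.
reply = '{"metadata_map": [{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}]}'
for item in json.loads(reply)["metadata_map"]:
    print(item["metadata_field_name"], item["comparison_operator"], item["metadata_field_value"])
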
api/core/workflow/nodes/knowledge_retrieval/entities.py (+0 -1)

  class ModelConfig(BaseModel):
  """
  Model Config.
  """
-

  provider: str
  name: str

api/models/dataset.py (+1 -0)

  updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
  embedding_model = db.Column(db.String(255), nullable=True)
  embedding_model_provider = db.Column(db.String(255), nullable=True)
+ keyword_number = db.Column(db.Integer, nullable=True, server_default=db.text("10"))
  collection_binding_id = db.Column(StringUUID, nullable=True)
  retrieval_model = db.Column(JSONB, nullable=True)
  built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))

api/services/dataset_service.py (+403 -0)

  from core.rag.index_processor.constant.built_in_field import BuiltInField
  from core.rag.index_processor.constant.index_type import IndexType
  from core.rag.retrieval.retrieval_methods import RetrievalMethod
+ from core.workflow.nodes.knowledge_index.entities import IndexMethod, RetrievalSetting
  from events.dataset_event import dataset_was_deleted
  from events.document_event import document_was_deleted
  from extensions.ext_database import db
  return documents, batch

  @staticmethod
def save_document_with_dataset_id(
dataset: Dataset,
knowledge_config: KnowledgeConfig,
account: Account | Any,
dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = "web",
):
# check document limit
features = FeatureService.get_features(current_user.current_tenant_id)

if features.billing.enabled:
if not knowledge_config.original_document_id:
count = 0
if knowledge_config.data_source:
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
count = len(upload_file_list)
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
notion_info_list = knowledge_config.data_source.info_list.notion_info_list
for notion_info in notion_info_list: # type: ignore
count = count + len(notion_info.pages)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
website_info = knowledge_config.data_source.info_list.website_info_list
count = len(website_info.urls) # type: ignore
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)

if features.billing.subscription.plan == "sandbox" and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

DocumentService.check_documents_upload_quota(count, features)

# if dataset is empty, update dataset data_source_type
if not dataset.data_source_type:
dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore

if not dataset.indexing_technique:
if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
raise ValueError("Indexing technique is invalid")

dataset.indexing_technique = knowledge_config.indexing_technique
if knowledge_config.indexing_technique == "high_quality":
model_manager = ModelManager()
if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
dataset_embedding_model = knowledge_config.embedding_model
dataset_embedding_model_provider = knowledge_config.embedding_model_provider
else:
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset_embedding_model = embedding_model.model
dataset_embedding_model_provider = embedding_model.provider
dataset.embedding_model = dataset_embedding_model
dataset.embedding_model_provider = dataset_embedding_model_provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
dataset_embedding_model_provider, dataset_embedding_model
)
dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model:
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}

dataset.retrieval_model = (
knowledge_config.retrieval_model.model_dump()
if knowledge_config.retrieval_model
else default_retrieval_model
) # type: ignore

documents = []
if knowledge_config.original_document_id:
document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
documents.append(document)
batch = document.batch
else:
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
# save process rule
if not dataset_process_rule:
process_rule = knowledge_config.process_rule
if process_rule:
if process_rule.mode in ("custom", "hierarchical"):
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
created_by=account.id,
)
elif process_rule.mode == "automatic":
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
created_by=account.id,
)
else:
logging.warning(
f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
)
return
db.session.add(dataset_process_rule)
db.session.commit()
lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
with redis_client.lock(lock_name, timeout=600):
position = DocumentService.get_documents_position(dataset.id)
document_ids = []
duplicate_document_ids = []
if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
)

# raise error if file not found
if not file:
raise FileNotExistsError()

file_name = file.name
data_source_info = {
"upload_file_id": file_id,
}
# check duplicate
if knowledge_config.duplicate:
document = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
).first()
if document:
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(
dataset,
dataset_process_rule.id, # type: ignore
knowledge_config.data_source.info_list.data_source_type, # type: ignore
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
file_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
if not notion_info_list:
raise ValueError("No notion info list found.")
exist_page_ids = []
exist_document = {}
documents = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
).all()
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
exist_document[data_source_info["notion_page_id"]] = document.id
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
).first()
if not data_source_binding:
raise ValueError("Data source binding not found.")
for page in notion_info.pages:
if page.page_id not in exist_page_ids:
data_source_info = {
"notion_workspace_id": workspace_id,
"notion_page_id": page.page_id,
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
"type": page.type,
}
# Truncate page name to 255 characters to prevent DB field length errors
truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
document = DocumentService.build_document(
dataset,
dataset_process_rule.id, # type: ignore
knowledge_config.data_source.info_list.data_source_type, # type: ignore
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
truncated_page_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
else:
exist_document.pop(page.page_id)
# delete not selected documents
if len(exist_document) > 0:
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore
website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore
if not website_info:
raise ValueError("No website info list found.")
urls = website_info.urls
for url in urls:
data_source_info = {
"url": url,
"provider": website_info.provider,
"job_id": website_info.job_id,
"only_main_content": website_info.only_main_content,
"mode": "crawl",
}
if len(url) > 255:
document_name = url[:200] + "..."
else:
document_name = url
document = DocumentService.build_document(
dataset,
dataset_process_rule.id, # type: ignore
knowledge_config.data_source.info_list.data_source_type, # type: ignore
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
document_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
db.session.commit()

# trigger async task
if document_ids:
document_indexing_task.delay(dataset.id, document_ids)
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)

return documents, batch

@staticmethod
def invoke_knowledge_index(
dataset: Dataset,
chunks: list[Any],
index_method: IndexMethod,
retrieval_setting: RetrievalSetting,
account: Account | Any,
original_document_id: str | None = None,
created_from: str = "rag-pipline",
):

if not dataset.indexing_technique:
if index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
raise ValueError("Indexing technique is invalid")

dataset.indexing_technique = index_method.indexing_technique
if index_method.indexing_technique == "high_quality":
model_manager = ModelManager()
if index_method.embedding_setting.embedding_model_name and index_method.embedding_setting.embedding_provider_name:
dataset_embedding_model = index_method.embedding_setting.embedding_model_name
dataset_embedding_model_provider = index_method.embedding_setting.embedding_provider_name
else:
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset_embedding_model = embedding_model.model
dataset_embedding_model_provider = embedding_model.provider
dataset.embedding_model = dataset_embedding_model
dataset.embedding_model_provider = dataset_embedding_model_provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
dataset_embedding_model_provider, dataset_embedding_model
)
dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model:
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}

dataset.retrieval_model = (
retrieval_setting.model_dump()
if retrieval_setting
else default_retrieval_model
) # type: ignore

documents = []
if original_document_id:
document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
documents.append(document)
batch = document.batch
else:
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))

lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
with redis_client.lock(lock_name, timeout=600):
position = DocumentService.get_documents_position(dataset.id)
document_ids = []
duplicate_document_ids = []
for chunk in chunks:
file = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
)

# raise error if file not found
if not file:
raise FileNotExistsError()

file_name = file.name
data_source_info = {
"upload_file_id": file_id,
}
# check duplicate
if knowledge_config.duplicate:
document = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type="upload_file",
enabled=True,
name=file_name,
).first()
if document:
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.created_from = created_from
document.doc_form = knowledge_config.doc_form
document.doc_language = knowledge_config.doc_language
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = "waiting"
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(
dataset,
dataset_process_rule.id, # type: ignore
knowledge_config.data_source.info_list.data_source_type, # type: ignore
knowledge_config.doc_form,
knowledge_config.doc_language,
data_source_info,
created_from,
position,
account,
file_name,
batch,
)
db.session.add(document)
db.session.flush()
document_ids.append(document.id)
documents.append(document)
position += 1
db.session.commit()

# trigger async task
if document_ids:
document_indexing_task.delay(dataset.id, document_ids)
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)

return documents, batch
@staticmethod
def check_documents_upload_quota(count: int, features: FeatureModel):
can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size
if count > can_upload_size:
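
Both document-creation paths in this service follow the same shape: build a batch identifier, take a per-dataset lock so concurrent uploads do not interleave positions, create Document rows, then commit and hand the new ids to an async indexing task. A compact sketch of that flow, with a threading.Lock standing in for redis_client.lock and a plain list standing in for the ORM session and the Celery task:

import random
import threading
import time
from collections import defaultdict

_dataset_locks: dict[str, threading.Lock] = defaultdict(threading.Lock)  # stand-in for redis_client.lock


def new_batch_id() -> str:
    # Same scheme as the service: timestamp plus a six-digit random suffix.
    return time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))


def add_documents(dataset_id: str, names: list[str], store: list[dict]) -> tuple[list[dict], str]:
    batch = new_batch_id()
    with _dataset_locks[dataset_id]:  # add_document_lock_dataset_id_<id>
        position = len([d for d in store if d["dataset_id"] == dataset_id])
        created = []
        for name in names:
            position += 1
            doc = {"dataset_id": dataset_id, "name": name, "position": position, "batch": batch}
            store.append(doc)  # db.session.add + flush in the real service
            created.append(doc)
    # here the real service commits and enqueues document_indexing_task.delay(dataset_id, ids)
    return created, batch


store: list[dict] = []
docs, batch = add_documents("ds-1", ["a.md", "b.md"], store)
print(batch, [d["position"] for d in docs])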
