from mimetypes import guess_extension
from typing import Optional

+from core.datasource.datasource_file_manager import DatasourceFileManager
+from core.datasource.entities.datasource_entities import DatasourceInvokeMessage
from core.file import File, FileTransferMethod, FileType
-from core.tools.entities.tool_entities import ToolInvokeMessage
-from core.tools.tool_file_manager import ToolFileManager

logger = logging.getLogger(__name__)


-class ToolFileMessageTransformer:
+class DatasourceFileMessageTransformer:
    @classmethod
-    def transform_tool_invoke_messages(
+    def transform_datasource_invoke_messages(
        cls,
-        messages: Generator[ToolInvokeMessage, None, None],
+        messages: Generator[DatasourceInvokeMessage, None, None],
        user_id: str,
        tenant_id: str,
        conversation_id: Optional[str] = None,
-    ) -> Generator[ToolInvokeMessage, None, None]:
+    ) -> Generator[DatasourceInvokeMessage, None, None]:
        """
-        Transform tool message and handle file download
+        Transform datasource message and handle file download
        """
        for message in messages:
-            if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
+            if message.type in {DatasourceInvokeMessage.MessageType.TEXT, DatasourceInvokeMessage.MessageType.LINK}:
                yield message
-            elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(
-                message.message, ToolInvokeMessage.TextMessage
+            elif message.type == DatasourceInvokeMessage.MessageType.IMAGE and isinstance(
+                message.message, DatasourceInvokeMessage.TextMessage
            ):
                # try to download image
                try:
-                    assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+                    assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)

-                    file = ToolFileManager.create_file_by_url(
+                    file = DatasourceFileManager.create_file_by_url(
                        user_id=user_id,
                        tenant_id=tenant_id,
                        file_url=message.message.text,
                        conversation_id=conversation_id,
                    )

-                    url = f"/files/tools/{file.id}{guess_extension(file.mimetype) or '.png'}"
+                    url = f"/files/datasources/{file.id}{guess_extension(file.mimetype) or '.png'}"

-                    yield ToolInvokeMessage(
-                        type=ToolInvokeMessage.MessageType.IMAGE_LINK,
-                        message=ToolInvokeMessage.TextMessage(text=url),
+                    yield DatasourceInvokeMessage(
+                        type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
+                        message=DatasourceInvokeMessage.TextMessage(text=url),
                        meta=message.meta.copy() if message.meta is not None else {},
                    )
                except Exception as e:
-                    yield ToolInvokeMessage(
-                        type=ToolInvokeMessage.MessageType.TEXT,
-                        message=ToolInvokeMessage.TextMessage(
+                    yield DatasourceInvokeMessage(
+                        type=DatasourceInvokeMessage.MessageType.TEXT,
+                        message=DatasourceInvokeMessage.TextMessage(
                            text=f"Failed to download image: {message.message.text}: {e}"
                        ),
                        meta=message.meta.copy() if message.meta is not None else {},
                    )
-            elif message.type == ToolInvokeMessage.MessageType.BLOB:
+            elif message.type == DatasourceInvokeMessage.MessageType.BLOB:
                # get mime type and save blob to storage
                meta = message.meta or {}
                filename = meta.get("file_name", None)
                # if message is str, encode it to bytes
-                if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
+                if not isinstance(message.message, DatasourceInvokeMessage.BlobMessage):
                    raise ValueError("unexpected message type")

                # FIXME: should do a type check here.
                assert isinstance(message.message.blob, bytes)

-                file = ToolFileManager.create_file_by_raw(
+                file = DatasourceFileManager.create_file_by_raw(
                    user_id=user_id,
                    tenant_id=tenant_id,
                    conversation_id=conversation_id,
                    filename=filename,
                )

-                url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype))
+                url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mimetype))

                # check if file is image
                if "image" in mimetype:
-                    yield ToolInvokeMessage(
-                        type=ToolInvokeMessage.MessageType.IMAGE_LINK,
-                        message=ToolInvokeMessage.TextMessage(text=url),
+                    yield DatasourceInvokeMessage(
+                        type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
+                        message=DatasourceInvokeMessage.TextMessage(text=url),
                        meta=meta.copy() if meta is not None else {},
                    )
                else:
-                    yield ToolInvokeMessage(
-                        type=ToolInvokeMessage.MessageType.BINARY_LINK,
-                        message=ToolInvokeMessage.TextMessage(text=url),
+                    yield DatasourceInvokeMessage(
+                        type=DatasourceInvokeMessage.MessageType.BINARY_LINK,
+                        message=DatasourceInvokeMessage.TextMessage(text=url),
                        meta=meta.copy() if meta is not None else {},
                    )
-            elif message.type == ToolInvokeMessage.MessageType.FILE:
+            elif message.type == DatasourceInvokeMessage.MessageType.FILE:
                meta = message.meta or {}
                file = meta.get("file", None)
                if isinstance(file, File):
                    assert file.related_id is not None
                    url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
                    if file.type == FileType.IMAGE:
-                        yield ToolInvokeMessage(
-                            type=ToolInvokeMessage.MessageType.IMAGE_LINK,
-                            message=ToolInvokeMessage.TextMessage(text=url),
+                        yield DatasourceInvokeMessage(
+                            type=DatasourceInvokeMessage.MessageType.IMAGE_LINK,
+                            message=DatasourceInvokeMessage.TextMessage(text=url),
                            meta=meta.copy() if meta is not None else {},
                        )
                    else:
-                        yield ToolInvokeMessage(
-                            type=ToolInvokeMessage.MessageType.LINK,
-                            message=ToolInvokeMessage.TextMessage(text=url),
+                        yield DatasourceInvokeMessage(
+                            type=DatasourceInvokeMessage.MessageType.LINK,
+                            message=DatasourceInvokeMessage.TextMessage(text=url),
                            meta=meta.copy() if meta is not None else {},
                        )
                else:
                    yield message

    @classmethod
-    def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
-        return f"/files/tools/{tool_file_id}{extension or '.bin'}"
+    def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str:
+        return f"/files/datasources/{datasource_file_id}{extension or '.bin'}"

        response = self._request_with_plugin_daemon_response_stream(
            "POST",
-            f"plugin/{tenant_id}/dispatch/datasource/invoke_first_step",
+            f"plugin/{tenant_id}/dispatch/datasource/{online_document}/pages",
            ToolInvokeMessage,
            data={
                "user_id": user_id,
                "data": {
                    "provider": datasource_provider_id.provider_name,
                    "datasource": datasource_name,
                    "credentials": credentials,
                    "datasource_parameters": datasource_parameters,
                },

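# Illustrative only (not part of the diff): the dispatch path built by the new
# f-string above, for hypothetical values of the surrounding variables.
tenant_id = "tenant-123"              # placeholder
online_document = "online_document"   # placeholder path segment
print(f"plugin/{tenant_id}/dispatch/datasource/{online_document}/pages")
# -> plugin/tenant-123/dispatch/datasource/online_document/pages
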
from sqlalchemy.orm import Session

from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
+from core.datasource.datasource_engine import DatasourceEngine
+from core.datasource.entities.datasource_entities import DatasourceInvokeMessage, DatasourceParameter
+from core.datasource.errors import DatasourceInvokeError
+from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
from core.file import File, FileTransferMethod
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.plugin.manager.plugin import PluginInstallationManager
-from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
-from core.tools.errors import ToolInvokeError
-from core.tools.tool_engine import ToolEngine
-from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from services.tools.builtin_tools_manage_service import BuiltinToolManageService

from .entities import DatasourceNodeData
-from .exc import (
-    ToolFileError,
-    ToolNodeError,
-    ToolParameterError,
-)
+from .exc import DatasourceNodeError, DatasourceParameterError, ToolFileError


class DatasourceNode(BaseNode[DatasourceNodeData]):

        # get datasource runtime
        try:
-            from core.tools.tool_manager import ToolManager
+            from core.datasource.datasource_manager import DatasourceManager

-            tool_runtime = ToolManager.get_workflow_tool_runtime(
+            datasource_runtime = DatasourceManager.get_workflow_datasource_runtime(
                self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
            )
-        except ToolNodeError as e:
+        except DatasourceNodeError as e:
            yield RunCompletedEvent(
                run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
            return

        # get parameters
-        tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
+        datasource_parameters = datasource_runtime.get_merged_runtime_parameters() or []
        parameters = self._generate_parameters(
-            tool_parameters=tool_parameters,
+            datasource_parameters=datasource_parameters,
            variable_pool=self.graph_runtime_state.variable_pool,
            node_data=self.node_data,
        )
        parameters_for_log = self._generate_parameters(
-            tool_parameters=tool_parameters,
+            datasource_parameters=datasource_parameters,
            variable_pool=self.graph_runtime_state.variable_pool,
            node_data=self.node_data,
            for_log=True,

        conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])

        try:
-            message_stream = ToolEngine.generic_invoke(
-                tool=tool_runtime,
-                tool_parameters=parameters,
+            message_stream = DatasourceEngine.generic_invoke(
+                datasource=datasource_runtime,
+                datasource_parameters=parameters,
                user_id=self.user_id,
                workflow_tool_callback=DifyWorkflowCallbackHandler(),
                workflow_call_depth=self.workflow_call_depth,
                app_id=self.app_id,
                conversation_id=conversation_id.text if conversation_id else None,
            )
-        except ToolNodeError as e:
+        except DatasourceNodeError as e:
            yield RunCompletedEvent(
                run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
-                    metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
-                    error=f"Failed to invoke tool: {str(e)}",
+                    metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+                    error=f"Failed to invoke datasource: {str(e)}",
                    error_type=type(e).__name__,
                )
            )
            return

        try:
-            # convert tool messages
-            yield from self._transform_message(message_stream, tool_info, parameters_for_log)
-        except (PluginDaemonClientSideError, ToolInvokeError) as e:
+            # convert datasource messages
+            yield from self._transform_message(message_stream, datasource_info, parameters_for_log)
+        except (PluginDaemonClientSideError, DatasourceInvokeError) as e:
            yield RunCompletedEvent(
                run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
-                    metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
-                    error=f"Failed to transform tool message: {str(e)}",
+                    metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
+                    error=f"Failed to transform datasource message: {str(e)}",
                    error_type=type(e).__name__,
                )
            )

    def _generate_parameters(
        self,
        *,
-        tool_parameters: Sequence[ToolParameter],
+        datasource_parameters: Sequence[DatasourceParameter],
        variable_pool: VariablePool,
-        node_data: ToolNodeData,
+        node_data: DatasourceNodeData,
        for_log: bool = False,
    ) -> dict[str, Any]:
        """
            Mapping[str, Any]: A dictionary containing the generated parameters.
        """
-        tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
+        datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}

        result: dict[str, Any] = {}
-        for parameter_name in node_data.tool_parameters:
-            parameter = tool_parameters_dictionary.get(parameter_name)
+        for parameter_name in node_data.datasource_parameters:
+            parameter = datasource_parameters_dictionary.get(parameter_name)
            if not parameter:
                result[parameter_name] = None
                continue
-            tool_input = node_data.tool_parameters[parameter_name]
-            if tool_input.type == "variable":
-                variable = variable_pool.get(tool_input.value)
+            datasource_input = node_data.datasource_parameters[parameter_name]
+            if datasource_input.type == "variable":
+                variable = variable_pool.get(datasource_input.value)
                if variable is None:
-                    raise ToolParameterError(f"Variable {tool_input.value} does not exist")
+                    raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
                parameter_value = variable.value
-            elif tool_input.type in {"mixed", "constant"}:
-                segment_group = variable_pool.convert_template(str(tool_input.value))
+            elif datasource_input.type in {"mixed", "constant"}:
+                segment_group = variable_pool.convert_template(str(datasource_input.value))
                parameter_value = segment_group.log if for_log else segment_group.text
            else:
-                raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
+                raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
            result[parameter_name] = parameter_value

        return result
    def _transform_message(
        self,
-        messages: Generator[ToolInvokeMessage, None, None],
-        tool_info: Mapping[str, Any],
+        messages: Generator[DatasourceInvokeMessage, None, None],
+        datasource_info: Mapping[str, Any],
        parameters_for_log: dict[str, Any],
    ) -> Generator:
        """
        Convert ToolInvokeMessages into tuple[plain_text, files]
        """
        # transform message and handle file storage
-        message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
+        message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
            messages=messages,
            user_id=self.user_id,
            tenant_id=self.tenant_id,

        for message in message_stream:
            if message.type in {
-                ToolInvokeMessage.MessageType.IMAGE_LINK,
-                ToolInvokeMessage.MessageType.BINARY_LINK,
-                ToolInvokeMessage.MessageType.IMAGE,
+                DatasourceInvokeMessage.MessageType.IMAGE_LINK,
+                DatasourceInvokeMessage.MessageType.BINARY_LINK,
+                DatasourceInvokeMessage.MessageType.IMAGE,
            }:
-                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+                assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)

                url = message.message.text
                if message.meta:
                        tenant_id=self.tenant_id,
                    )
                    files.append(file)
-            elif message.type == ToolInvokeMessage.MessageType.BLOB:
+            elif message.type == DatasourceInvokeMessage.MessageType.BLOB:
                # get tool file id
-                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+                assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
                assert message.meta

                tool_file_id = message.message.text.split("/")[-1].split(".")[0]
                        tenant_id=self.tenant_id,
                    )
                )
-            elif message.type == ToolInvokeMessage.MessageType.TEXT:
-                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+            elif message.type == DatasourceInvokeMessage.MessageType.TEXT:
+                assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
                text += message.message.text
                yield RunStreamChunkEvent(
                    chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
                )
-            elif message.type == ToolInvokeMessage.MessageType.JSON:
-                assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
+            elif message.type == DatasourceInvokeMessage.MessageType.JSON:
+                assert isinstance(message.message, DatasourceInvokeMessage.JsonMessage)
                if self.node_type == NodeType.AGENT:
                    msg_metadata = message.message.json_object.pop("execution_metadata", {})
                    agent_execution_metadata = {
                        if key in NodeRunMetadataKey.__members__.values()
                    }
                json.append(message.message.json_object)
-            elif message.type == ToolInvokeMessage.MessageType.LINK:
-                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+            elif message.type == DatasourceInvokeMessage.MessageType.LINK:
+                assert isinstance(message.message, DatasourceInvokeMessage.TextMessage)
                stream_text = f"Link: {message.message.text}\n"
                text += stream_text
                yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
-            elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
-                assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
+            elif message.type == DatasourceInvokeMessage.MessageType.VARIABLE:
+                assert isinstance(message.message, DatasourceInvokeMessage.VariableMessage)
                variable_name = message.message.variable_name
                variable_value = message.message.variable_value
                if message.message.stream:
                    )
                else:
                    variables[variable_name] = variable_value
-            elif message.type == ToolInvokeMessage.MessageType.FILE:
+            elif message.type == DatasourceInvokeMessage.MessageType.FILE:
                assert message.meta is not None
                files.append(message.meta["file"])
-            elif message.type == ToolInvokeMessage.MessageType.LOG:
-                assert isinstance(message.message, ToolInvokeMessage.LogMessage)
+            elif message.type == DatasourceInvokeMessage.MessageType.LOG:
+                assert isinstance(message.message, DatasourceInvokeMessage.LogMessage)
                if message.message.metadata:
-                    icon = tool_info.get("icon", "")
+                    icon = datasource_info.get("icon", "")
                    dict_metadata = dict(message.message.metadata)
                    if dict_metadata.get("provider"):
                        manager = PluginInstallationManager()

                outputs={"text": text, "files": files, "json": json, **variables},
                metadata={
                    **agent_execution_metadata,
-                    NodeRunMetadataKey.TOOL_INFO: tool_info,
+                    NodeRunMetadataKey.DATASOURCE_INFO: datasource_info,
                    NodeRunMetadataKey.AGENT_LOG: agent_logs,
                },
                inputs=parameters_for_log,

        *,
        graph_config: Mapping[str, Any],
        node_id: str,
-        node_data: ToolNodeData,
+        node_data: DatasourceNodeData,
    ) -> Mapping[str, Sequence[str]]:
        """
        Extract variable selector to variable mapping
        :return:
        """
        result = {}
-        for parameter_name in node_data.tool_parameters:
-            input = node_data.tool_parameters[parameter_name]
+        for parameter_name in node_data.datasource_parameters:
+            input = node_data.datasource_parameters[parameter_name]
            if input.type == "mixed":
                assert isinstance(input.value, str)
                selectors = VariableTemplateParser(input.value).extract_variable_selectors()

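# --- Illustrative sketch (not part of the diff) ---
# The parameter-resolution rules from _generate_parameters above, mirrored with
# plain dicts instead of VariablePool/DatasourceParameter objects; every name here
# is a placeholder.
def resolve_parameters(node_inputs: dict, variables: dict) -> dict:
    result = {}
    for name, conf in node_inputs.items():
        if conf["type"] == "variable":
            if conf["value"] not in variables:
                raise ValueError(f"Variable {conf['value']} does not exist")
            result[name] = variables[conf["value"]]
        elif conf["type"] in {"mixed", "constant"}:
            # a real run would render templates such as "crawl {{#start.url#}}" here
            result[name] = str(conf["value"])
        else:
            raise ValueError(f"Unknown datasource input type '{conf['type']}'")
    return result


# resolve_parameters({"url": {"type": "variable", "value": "start.url"}}, {"start.url": "https://example.com"})
# -> {"url": "https://example.com"}
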
-class ToolNodeError(ValueError):
-    """Base exception for tool node errors."""
+class DatasourceNodeError(ValueError):
+    """Base exception for datasource node errors."""

    pass


-class ToolParameterError(ToolNodeError):
-    """Exception raised for errors in tool parameters."""
+class DatasourceParameterError(DatasourceNodeError):
+    """Exception raised for errors in datasource parameters."""

    pass


-class ToolFileError(ToolNodeError):
-    """Exception raised for errors related to tool files."""
+class DatasourceFileError(DatasourceNodeError):
+    """Exception raised for errors related to datasource files."""

    pass

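# Illustrative only (not part of the diff): the datasource node above catches the
# base DatasourceNodeError, so the more specific errors propagate through it.
try:
    raise DatasourceParameterError("Variable start.query does not exist")
except DatasourceNodeError as e:
    print(f"{type(e).__name__}: {e}")  # -> DatasourceParameterError: Variable start.query does not exist
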
| ANSWER = "answer" | ANSWER = "answer" | ||||
| LLM = "llm" | LLM = "llm" | ||||
| KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" | KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" | ||||
| KNOWLEDGE_INDEX = "knowledge-index" | |||||
| IF_ELSE = "if-else" | IF_ELSE = "if-else" | ||||
| CODE = "code" | CODE = "code" | ||||
| TEMPLATE_TRANSFORM = "template-transform" | TEMPLATE_TRANSFORM = "template-transform" |
from .knowledge_index_node import KnowledgeRetrievalNode

__all__ = ["KnowledgeRetrievalNode"]

from collections.abc import Sequence
from typing import Any, Literal, Optional, Union

from pydantic import BaseModel, Field

from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import VisionConfig


class RerankingModelConfig(BaseModel):
    """
    Reranking Model Config.
    """

    provider: str
    model: str


class VectorSetting(BaseModel):
    """
    Vector Setting.
    """

    vector_weight: float
    embedding_provider_name: str
    embedding_model_name: str


class KeywordSetting(BaseModel):
    """
    Keyword Setting.
    """

    keyword_weight: float


class WeightedScoreConfig(BaseModel):
    """
    Weighted score Config.
    """

    vector_setting: VectorSetting
    keyword_setting: KeywordSetting


class EmbeddingSetting(BaseModel):
    """
    Embedding Setting.
    """

    embedding_provider_name: str
    embedding_model_name: str


class EconomySetting(BaseModel):
    """
    Economy Setting.
    """

    keyword_number: int


class RetrievalSetting(BaseModel):
    """
    Retrieval Setting.
    """

    search_method: Literal["semantic_search", "keyword_search", "hybrid_search"]
    top_k: int
    score_threshold: Optional[float] = 0.5
    score_threshold_enabled: bool = False
    reranking_mode: str = "reranking_model"
    reranking_enable: bool = True
    reranking_model: Optional[RerankingModelConfig] = None
    weights: Optional[WeightedScoreConfig] = None


class IndexMethod(BaseModel):
    """
    Knowledge Index Setting.
    """

    indexing_technique: Literal["high_quality", "economy"]
    embedding_setting: EmbeddingSetting
    economy_setting: EconomySetting


class FileInfo(BaseModel):
    """
    File Info.
    """

    file_id: str


class OnlineDocumentIcon(BaseModel):
    """
    Document Icon.
    """

    icon_url: str
    icon_type: str
    icon_emoji: str


class OnlineDocumentInfo(BaseModel):
    """
    Online document info.
    """

    provider: str
    workspace_id: str
    page_id: str
    page_type: str
    icon: OnlineDocumentIcon


class WebsiteInfo(BaseModel):
    """
    website import info.
    """

    provider: str
    url: str


class GeneralStructureChunk(BaseModel):
    """
    General Structure Chunk.
    """

    general_chunk: list[str]
    data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]


class ParentChildChunk(BaseModel):
    """
    Parent Child Chunk.
    """

    parent_content: str
    child_content: list[str]


class ParentChildStructureChunk(BaseModel):
    """
    Parent Child Structure Chunk.
    """

    parent_child_chunks: list[ParentChildChunk]
    data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]


class KnowledgeIndexNodeData(BaseNodeData):
    """
    Knowledge index Node Data.
    """

    type: str = "knowledge-index"
    dataset_id: str
    index_chunk_variable_selector: list[str]
    chunk_structure: Literal["general", "parent-child"]
    index_method: IndexMethod
    retrieval_setting: RetrievalSetting

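# Illustrative only (not part of the diff): building the settings objects defined
# above; the provider and model names are placeholders.
index_method = IndexMethod(
    indexing_technique="high_quality",
    embedding_setting=EmbeddingSetting(
        embedding_provider_name="openai",
        embedding_model_name="text-embedding-3-small",
    ),
    economy_setting=EconomySetting(keyword_number=10),
)
retrieval_setting = RetrievalSetting(search_method="semantic_search", top_k=2)
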
class KnowledgeIndexNodeError(ValueError):
    """Base class for KnowledgeIndexNode errors."""


class ModelNotExistError(KnowledgeIndexNodeError):
    """Raised when the model does not exist."""


class ModelCredentialsNotInitializedError(KnowledgeIndexNodeError):
    """Raised when the model credentials are not initialized."""


class ModelNotSupportedError(KnowledgeIndexNodeError):
    """Raised when the model is not supported."""


class ModelQuotaExceededError(KnowledgeIndexNodeError):
    """Raised when the model provider quota is exceeded."""


class InvalidModelTypeError(KnowledgeIndexNodeError):
    """Raised when the model is not a Large Language Model."""

import json
import logging
import re
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast

from sqlalchemy import Integer, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast

from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment
from core.variables.segments import ObjectSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
    METADATA_FILTER_ASSISTANT_PROMPT_1,
    METADATA_FILTER_ASSISTANT_PROMPT_2,
    METADATA_FILTER_COMPLETION_PROMPT,
    METADATA_FILTER_SYSTEM_PROMPT,
    METADATA_FILTER_USER_PROMPT_1,
    METADATA_FILTER_USER_PROMPT_3,
)
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from models.workflow import WorkflowNodeExecutionStatus
from services.dataset_service import DatasetService
from services.feature_service import FeatureService

from .entities import KnowledgeIndexNodeData, KnowledgeRetrievalNodeData, ModelConfig
from .exc import (
    InvalidModelTypeError,
    KnowledgeIndexNodeError,
    KnowledgeRetrievalNodeError,
    ModelCredentialsNotInitializedError,
    ModelNotExistError,
    ModelNotSupportedError,
    ModelQuotaExceededError,
)

logger = logging.getLogger(__name__)

default_retrieval_model = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
    "reranking_enable": False,
    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
    "top_k": 2,
    "score_threshold_enabled": False,
}


class KnowledgeIndexNode(LLMNode):
    _node_data_cls = KnowledgeIndexNodeData  # type: ignore
    _node_type = NodeType.KNOWLEDGE_INDEX

    def _run(self) -> NodeRunResult:  # type: ignore
        node_data = cast(KnowledgeIndexNodeData, self.node_data)
        # extract variables
        variable = self.graph_runtime_state.variable_pool.get(node_data.index_chunk_variable_selector)
        if not isinstance(variable, ObjectSegment):
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs={},
                error="Query variable is not object type.",
            )
        chunks = variable.value
        variables = {"chunks": chunks}
        if not chunks:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
            )
        # check rate limit
        if self.tenant_id:
            knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
            if knowledge_rate_limit.enabled:
                current_time = int(time.time() * 1000)
                key = f"rate_limit_{self.tenant_id}"
                redis_client.zadd(key, {current_time: current_time})
                redis_client.zremrangebyscore(key, 0, current_time - 60000)
                request_count = redis_client.zcard(key)
                if request_count > knowledge_rate_limit.limit:
                    # add ratelimit record
                    rate_limit_log = RateLimitLog(
                        tenant_id=self.tenant_id,
                        subscription_plan=knowledge_rate_limit.subscription_plan,
                        operation="knowledge",
                    )
                    db.session.add(rate_limit_log)
                    db.session.commit()
                    return NodeRunResult(
                        status=WorkflowNodeExecutionStatus.FAILED,
                        inputs=variables,
                        error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
                        error_type="RateLimitExceeded",
                    )

        # retrieve knowledge
        try:
            results = self._invoke_knowledge_index(node_data=node_data, chunks=chunks)
            outputs = {"result": results}
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
            )
        except KnowledgeIndexNodeError as e:
            logger.warning("Error when running knowledge index node")
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
            )
        # Temporary handle all exceptions from DatasetRetrieval class here.
        except Exception as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
            )

    def _invoke_knowledge_index(self, node_data: KnowledgeIndexNodeData, chunks: list[any]) -> Any:
        dataset = Dataset.query.filter_by(id=node_data.dataset_id).first()
        if not dataset:
            raise KnowledgeIndexNodeError(f"Dataset {node_data.dataset_id} not found.")

        DatasetService.invoke_knowledge_index(
            dataset=dataset,
            chunks=chunks,
            index_method=node_data.index_method,
            retrieval_setting=node_data.retrieval_setting,
        )

        pass

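# --- Illustrative sketch (not part of the diff) ---
# The sliding-window rate limit used in _run above, extracted into a standalone
# helper; assumes a redis-py client and `limit` requests allowed per 60 seconds.
import time

from redis import Redis


def over_knowledge_rate_limit(redis_conn: Redis, tenant_id: str, limit: int) -> bool:
    now_ms = int(time.time() * 1000)
    key = f"rate_limit_{tenant_id}"
    redis_conn.zadd(key, {now_ms: now_ms})               # record this request
    redis_conn.zremrangebyscore(key, 0, now_ms - 60000)  # drop entries older than 60 seconds
    return redis_conn.zcard(key) > limit                 # count requests still in the window
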
METADATA_FILTER_SYSTEM_PROMPT = """
### Job Description',
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
"""  # noqa: E501

METADATA_FILTER_USER_PROMPT_1 = """
{ "input_text": "I want to know which company’s email address test@example.com is?",
"metadata_fields": ["filename", "email", "phone", "address"]
}
"""

METADATA_FILTER_ASSISTANT_PROMPT_1 = """
```json
{"metadata_map": [
    {"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
]
}
```
"""

METADATA_FILTER_USER_PROMPT_2 = """
{"input_text": "What are the movies with a score of more than 9 in 2024?",
"metadata_fields": ["name", "year", "rating", "country"]}
"""

METADATA_FILTER_ASSISTANT_PROMPT_2 = """
```json
{"metadata_map": [
    {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
    {"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"},
]}
```
"""

METADATA_FILTER_USER_PROMPT_3 = """
'{{"input_text": "{input_text}",',
'"metadata_fields": {metadata_fields}}}'
"""

METADATA_FILTER_COMPLETION_PROMPT = """
### Job Description
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
<example>
User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
</example>
### User Input
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
### Assistant Output
"""  # noqa: E501

class ModelConfig(BaseModel):
    """
    Model Config.
    """

    provider: str
    name: str

    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    embedding_model = db.Column(db.String(255), nullable=True)
    embedding_model_provider = db.Column(db.String(255), nullable=True)
+    keyword_number = db.Column(db.Integer, nullable=True, server_default=db.text("10"))
    collection_binding_id = db.Column(StringUUID, nullable=True)
    retrieval_model = db.Column(JSONB, nullable=True)
    built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))

from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
+from core.workflow.nodes.knowledge_index.entities import IndexMethod, RetrievalSetting
from events.dataset_event import dataset_was_deleted
from events.document_event import document_was_deleted
from extensions.ext_database import db

        return documents, batch

    @staticmethod
    def save_document_with_dataset_id(
        dataset: Dataset,
        knowledge_config: KnowledgeConfig,
        account: Account | Any,
        dataset_process_rule: Optional[DatasetProcessRule] = None,
        created_from: str = "web",
    ):
        # check document limit
        features = FeatureService.get_features(current_user.current_tenant_id)

        if features.billing.enabled:
            if not knowledge_config.original_document_id:
                count = 0
                if knowledge_config.data_source:
                    if knowledge_config.data_source.info_list.data_source_type == "upload_file":
                        upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
                        count = len(upload_file_list)
                    elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                        notion_info_list = knowledge_config.data_source.info_list.notion_info_list
                        for notion_info in notion_info_list:  # type: ignore
                            count = count + len(notion_info.pages)
                    elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
                        website_info = knowledge_config.data_source.info_list.website_info_list
                        count = len(website_info.urls)  # type: ignore
                batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
                if features.billing.subscription.plan == "sandbox" and count > 1:
                    raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
                if count > batch_upload_limit:
                    raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

                DocumentService.check_documents_upload_quota(count, features)

        # if dataset is empty, update dataset data_source_type
        if not dataset.data_source_type:
            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type  # type: ignore

        if not dataset.indexing_technique:
            if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
                raise ValueError("Indexing technique is invalid")

            dataset.indexing_technique = knowledge_config.indexing_technique
            if knowledge_config.indexing_technique == "high_quality":
                model_manager = ModelManager()
                if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
                    dataset_embedding_model = knowledge_config.embedding_model
                    dataset_embedding_model_provider = knowledge_config.embedding_model_provider
                else:
                    embedding_model = model_manager.get_default_model_instance(
                        tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
                    )
                    dataset_embedding_model = embedding_model.model
                    dataset_embedding_model_provider = embedding_model.provider
                dataset.embedding_model = dataset_embedding_model
                dataset.embedding_model_provider = dataset_embedding_model_provider
                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                    dataset_embedding_model_provider, dataset_embedding_model
                )
                dataset.collection_binding_id = dataset_collection_binding.id
                if not dataset.retrieval_model:
                    default_retrieval_model = {
                        "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
                        "reranking_enable": False,
                        "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
                        "top_k": 2,
                        "score_threshold_enabled": False,
                    }

                    dataset.retrieval_model = (
                        knowledge_config.retrieval_model.model_dump()
                        if knowledge_config.retrieval_model
                        else default_retrieval_model
                    )  # type: ignore

        documents = []
        if knowledge_config.original_document_id:
            document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
            documents.append(document)
            batch = document.batch
        else:
            batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
            # save process rule
            if not dataset_process_rule:
                process_rule = knowledge_config.process_rule
                if process_rule:
                    if process_rule.mode in ("custom", "hierarchical"):
                        dataset_process_rule = DatasetProcessRule(
                            dataset_id=dataset.id,
                            mode=process_rule.mode,
                            rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
                            created_by=account.id,
                        )
                    elif process_rule.mode == "automatic":
                        dataset_process_rule = DatasetProcessRule(
                            dataset_id=dataset.id,
                            mode=process_rule.mode,
                            rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
                            created_by=account.id,
                        )
                    else:
                        logging.warn(
                            f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
                        )
                        return
                    db.session.add(dataset_process_rule)
                    db.session.commit()
            lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
            with redis_client.lock(lock_name, timeout=600):
                position = DocumentService.get_documents_position(dataset.id)
                document_ids = []
                duplicate_document_ids = []
                if knowledge_config.data_source.info_list.data_source_type == "upload_file":  # type: ignore
                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
                    for file_id in upload_file_list:
                        file = (
                            db.session.query(UploadFile)
                            .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
                            .first()
                        )

                        # raise error if file not found
                        if not file:
                            raise FileNotExistsError()

                        file_name = file.name
                        data_source_info = {
                            "upload_file_id": file_id,
                        }
                        # check duplicate
                        if knowledge_config.duplicate:
                            document = Document.query.filter_by(
                                dataset_id=dataset.id,
                                tenant_id=current_user.current_tenant_id,
                                data_source_type="upload_file",
                                enabled=True,
                                name=file_name,
                            ).first()
                            if document:
                                document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
                                document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                                document.created_from = created_from
                                document.doc_form = knowledge_config.doc_form
                                document.doc_language = knowledge_config.doc_language
                                document.data_source_info = json.dumps(data_source_info)
                                document.batch = batch
                                document.indexing_status = "waiting"
                                db.session.add(document)
                                documents.append(document)
                                duplicate_document_ids.append(document.id)
                                continue
                        document = DocumentService.build_document(
                            dataset,
                            dataset_process_rule.id,  # type: ignore
                            knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                            knowledge_config.doc_form,
                            knowledge_config.doc_language,
                            data_source_info,
                            created_from,
                            position,
                            account,
                            file_name,
                            batch,
                        )
                        db.session.add(document)
                        db.session.flush()
                        document_ids.append(document.id)
                        documents.append(document)
                        position += 1
                elif knowledge_config.data_source.info_list.data_source_type == "notion_import":  # type: ignore
                    notion_info_list = knowledge_config.data_source.info_list.notion_info_list  # type: ignore
                    if not notion_info_list:
                        raise ValueError("No notion info list found.")
                    exist_page_ids = []
                    exist_document = {}
                    documents = Document.query.filter_by(
                        dataset_id=dataset.id,
                        tenant_id=current_user.current_tenant_id,
                        data_source_type="notion_import",
                        enabled=True,
                    ).all()
                    if documents:
                        for document in documents:
                            data_source_info = json.loads(document.data_source_info)
                            exist_page_ids.append(data_source_info["notion_page_id"])
                            exist_document[data_source_info["notion_page_id"]] = document.id
                    for notion_info in notion_info_list:
                        workspace_id = notion_info.workspace_id
                        data_source_binding = DataSourceOauthBinding.query.filter(
                            db.and_(
                                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
                                DataSourceOauthBinding.provider == "notion",
                                DataSourceOauthBinding.disabled == False,
                                DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
                            )
                        ).first()
                        if not data_source_binding:
                            raise ValueError("Data source binding not found.")
                        for page in notion_info.pages:
                            if page.page_id not in exist_page_ids:
                                data_source_info = {
                                    "notion_workspace_id": workspace_id,
                                    "notion_page_id": page.page_id,
                                    "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
                                    "type": page.type,
                                }
                                # Truncate page name to 255 characters to prevent DB field length errors
                                truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
                                document = DocumentService.build_document(
                                    dataset,
                                    dataset_process_rule.id,  # type: ignore
                                    knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                                    knowledge_config.doc_form,
                                    knowledge_config.doc_language,
                                    data_source_info,
                                    created_from,
                                    position,
                                    account,
                                    truncated_page_name,
                                    batch,
                                )
                                db.session.add(document)
                                db.session.flush()
                                document_ids.append(document.id)
                                documents.append(document)
                                position += 1
                            else:
                                exist_document.pop(page.page_id)
                    # delete not selected documents
                    if len(exist_document) > 0:
                        clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
                elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":  # type: ignore
                    website_info = knowledge_config.data_source.info_list.website_info_list  # type: ignore
                    if not website_info:
                        raise ValueError("No website info list found.")
                    urls = website_info.urls
                    for url in urls:
                        data_source_info = {
                            "url": url,
                            "provider": website_info.provider,
                            "job_id": website_info.job_id,
                            "only_main_content": website_info.only_main_content,
                            "mode": "crawl",
                        }
                        if len(url) > 255:
                            document_name = url[:200] + "..."
                        else:
                            document_name = url
                        document = DocumentService.build_document(
                            dataset,
                            dataset_process_rule.id,  # type: ignore
                            knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                            knowledge_config.doc_form,
                            knowledge_config.doc_language,
                            data_source_info,
                            created_from,
                            position,
                            account,
                            document_name,
                            batch,
                        )
                        db.session.add(document)
                        db.session.flush()
                        document_ids.append(document.id)
                        documents.append(document)
                        position += 1
                db.session.commit()

                # trigger async task
                if document_ids:
                    document_indexing_task.delay(dataset.id, document_ids)
                if duplicate_document_ids:
                    duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)

        return documents, batch

    @staticmethod
    def invoke_knowledge_index(
        dataset: Dataset,
        chunks: list[Any],
        index_method: IndexMethod,
        retrieval_setting: RetrievalSetting,
        account: Account | Any,
        original_document_id: str | None = None,
        created_from: str = "rag-pipline",
    ):
        if not dataset.indexing_technique:
            if index_method.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
                raise ValueError("Indexing technique is invalid")

            dataset.indexing_technique = index_method.indexing_technique
            if index_method.indexing_technique == "high_quality":
                model_manager = ModelManager()
                if index_method.embedding_setting.embedding_model and index_method.embedding_setting.embedding_model_provider:
                    dataset_embedding_model = index_method.embedding_setting.embedding_model
                    dataset_embedding_model_provider = index_method.embedding_setting.embedding_model_provider
                else:
                    embedding_model = model_manager.get_default_model_instance(
                        tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
                    )
                    dataset_embedding_model = embedding_model.model
                    dataset_embedding_model_provider = embedding_model.provider
                dataset.embedding_model = dataset_embedding_model
                dataset.embedding_model_provider = dataset_embedding_model_provider
                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                    dataset_embedding_model_provider, dataset_embedding_model
                )
                dataset.collection_binding_id = dataset_collection_binding.id
                if not dataset.retrieval_model:
                    default_retrieval_model = {
                        "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
                        "reranking_enable": False,
                        "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
                        "top_k": 2,
                        "score_threshold_enabled": False,
                    }

                    dataset.retrieval_model = (
                        retrieval_setting.model_dump()
                        if retrieval_setting
                        else default_retrieval_model
                    )  # type: ignore

        documents = []
        if original_document_id:
            document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
            documents.append(document)
            batch = document.batch
        else:
            batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))

        lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
        with redis_client.lock(lock_name, timeout=600):
            position = DocumentService.get_documents_position(dataset.id)
            document_ids = []
            duplicate_document_ids = []
            for chunk in chunks:
                file = (
                    db.session.query(UploadFile)
                    .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
                    .first()
                )

                # raise error if file not found
                if not file:
                    raise FileNotExistsError()

                file_name = file.name
                data_source_info = {
                    "upload_file_id": file_id,
                }
                # check duplicate
                if knowledge_config.duplicate:
                    document = Document.query.filter_by(
                        dataset_id=dataset.id,
                        tenant_id=current_user.current_tenant_id,
                        data_source_type="upload_file",
                        enabled=True,
                        name=file_name,
                    ).first()
                    if document:
                        document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
                        document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                        document.created_from = created_from
                        document.doc_form = knowledge_config.doc_form
                        document.doc_language = knowledge_config.doc_language
                        document.data_source_info = json.dumps(data_source_info)
                        document.batch = batch
                        document.indexing_status = "waiting"
                        db.session.add(document)
                        documents.append(document)
                        duplicate_document_ids.append(document.id)
                        continue
                document = DocumentService.build_document(
                    dataset,
                    dataset_process_rule.id,  # type: ignore
                    knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                    knowledge_config.doc_form,
                    knowledge_config.doc_language,
                    data_source_info,
                    created_from,
                    position,
                    account,
                    file_name,
                    batch,
                )
                db.session.add(document)
                db.session.flush()
                document_ids.append(document.id)
                documents.append(document)
                position += 1
            db.session.commit()

        # trigger async task
        if document_ids:
            document_indexing_task.delay(dataset.id, document_ids)
        if duplicate_document_ids:
            duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)

        return documents, batch

    @staticmethod
    def check_documents_upload_quota(count: int, features: FeatureModel):
        can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size
        if count > can_upload_size:
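
# --- Illustrative sketch (not part of the diff) ---
# The per-dataset locking and batch-id pattern used by the document-saving methods
# above, shown with a redis-py client; the function name is a placeholder.
import random
import time

from redis import Redis


def add_documents_exclusively(redis_conn: Redis, dataset_id: str) -> str:
    lock_name = f"add_document_lock_dataset_id_{dataset_id}"
    with redis_conn.lock(lock_name, timeout=600):  # serialize writers per dataset
        batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
        # ... build and flush Document rows for this batch while the lock is held ...
    return batch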