@@ -93,7 +93,7 @@ class DraftRagPipelineApi(Resource):
            parser.add_argument("hash", type=str, required=False, location="json")
            parser.add_argument("environment_variables", type=list, required=False, location="json")
            parser.add_argument("conversation_variables", type=list, required=False, location="json")
-           parser.add_argument("rag_pipeline_variables", type=dict, required=False, location="json")
+           parser.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
            args = parser.parse_args()
        elif "text/plain" in content_type:
            try:
@@ -101,8 +101,8 @@ class DraftRagPipelineApi(Resource):
                if "graph" not in data or "features" not in data:
                    raise ValueError("graph or features not found in data")
-               if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
-                   raise ValueError("graph or features is not a dict")
+               if not isinstance(data.get("graph"), dict):
+                   raise ValueError("graph is not a dict")
                args = {
                    "graph": data.get("graph"),
@@ -129,11 +129,9 @@ class DraftRagPipelineApi(Resource):
        conversation_variables = [
            variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
        ]
-       rag_pipeline_variables_list = args.get("rag_pipeline_variables") or {}
-       rag_pipeline_variables = {
-           k: [variable_factory.build_pipeline_variable_from_mapping(obj) for obj in v]
-           for k, v in rag_pipeline_variables_list.items()
-       }
+       rag_pipeline_variables_list = args.get("rag_pipeline_variables") or []
+       rag_pipeline_variables = [
+           variable_factory.build_pipeline_variable_from_mapping(obj) for obj in rag_pipeline_variables_list
+       ]
        rag_pipeline_service = RagPipelineService()
        workflow = rag_pipeline_service.sync_draft_workflow(
            pipeline=pipeline,
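Note: after this hunk, `rag_pipeline_variables` travels as a flat list of mappings instead of a provider-keyed dict. A minimal sketch of the new payload and factory call, assuming field names from the updated `pipeline_variable_fields` serializer further down (the exact keys are not confirmed by this hunk):

    # Hypothetical request-body fragment; values are illustrative only
    args = {
        "rag_pipeline_variables": [
            {"variable": "start_page", "type": "number", "value": 1, "belong_to_node_id": "shared"},
            {"variable": "api_key", "type": "secret", "value": "***", "belong_to_node_id": "datasource-node-1"},
        ]
    }
    rag_pipeline_variables = [
        variable_factory.build_pipeline_variable_from_mapping(obj)
        for obj in (args.get("rag_pipeline_variables") or [])
    ]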
@@ -634,12 +632,17 @@ class RagPipelineSecondStepApi(Resource):
        # The role of the current user in the ta table must be admin, owner, or editor
        if not current_user.is_editor:
            raise Forbidden()
-       datasource_provider = request.args.get("datasource_provider", required=True, type=str)
+       parser = reqparse.RequestParser()
+       parser.add_argument("node_id", type=str, required=True, location="args")
+       node_id = parser.parse_args()["node_id"]
        rag_pipeline_service = RagPipelineService()
-       return rag_pipeline_service.get_second_step_parameters(
-           pipeline=pipeline, datasource_provider=datasource_provider
-       )
+       variables = rag_pipeline_service.get_second_step_parameters(
+           pipeline=pipeline, node_id=node_id
+       )
+       return {
+           "variables": variables,
+       }


class RagPipelineWorkflowRunListApi(Resource):
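Note: the endpoint now takes `node_id` as a query parameter instead of `datasource_provider`, and wraps the result in a `variables` envelope. A hedged client-side sketch (host, API prefix, and ids are placeholders; the route itself is registered in the next hunk):

    import requests

    BASE_URL = "http://localhost:5001/console/api"  # placeholder host and prefix
    pipeline_id = "00000000-0000-0000-0000-000000000000"  # placeholder UUID

    resp = requests.get(
        f"{BASE_URL}/rag/pipelines/{pipeline_id}/workflows/processing/parameters",
        params={"node_id": "datasource-node-1"},
    )
    resp.json()  # expected shape: {"variables": [...]}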
@@ -785,3 +786,7 @@ api.add_resource(
    DatasourceListApi,
    "/rag/pipelines/datasource-plugins",
)
+api.add_resource(
+    RagPipelineSecondStepApi,
+    "/rag/pipelines/<uuid:pipeline_id>/workflows/processing/parameters",
+)
@@ -4,7 +4,6 @@ from typing import Any, Optional, TextIO, Union
 from pydantic import BaseModel

 from configs import dify_config
 from core.datasource.entities.datasource_entities import DatasourceInvokeMessage
-from core.ops.entities.trace_entity import TraceTaskName
 from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.tools.entities.tool_entities import ToolInvokeMessage
@@ -114,35 +113,6 @@ class DifyAgentCallbackHandler(BaseModel):
            color=self.color,
        )

-   def on_datasource_end(
-       self,
-       datasource_name: str,
-       datasource_inputs: Mapping[str, Any],
-       datasource_outputs: Iterable[DatasourceInvokeMessage] | str,
-       message_id: Optional[str] = None,
-       timer: Optional[Any] = None,
-       trace_manager: Optional[TraceQueueManager] = None,
-   ) -> None:
-       """Run on datasource end."""
-       if dify_config.DEBUG:
-           print_text("\n[on_datasource_end]\n", color=self.color)
-           print_text("Datasource: " + datasource_name + "\n", color=self.color)
-           print_text("Inputs: " + str(datasource_inputs) + "\n", color=self.color)
-           print_text("Outputs: " + str(datasource_outputs)[:1000] + "\n", color=self.color)
-           print_text("\n")
-
-       if trace_manager:
-           trace_manager.add_trace_task(
-               TraceTask(
-                   TraceTaskName.DATASOURCE_TRACE,
-                   message_id=message_id,
-                   datasource_name=datasource_name,
-                   datasource_inputs=datasource_inputs,
-                   datasource_outputs=datasource_outputs,
-                   timer=timer,
-               )
-           )
-
    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
@@ -145,7 +145,7 @@ class DatasourceProviderEntity(ToolProviderEntity):

 class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
     datasources: list[DatasourceEntity] = Field(default_factory=list)


 class DatasourceInvokeMeta(BaseModel):
@@ -127,7 +127,7 @@ class GeneralStructureChunk(BaseModel):
    """
    General Structure Chunk.
    """

-   general_chunk: list[str]
+   general_chunks: list[str]
    data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
@@ -80,9 +80,9 @@ def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Va
 def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
-    if not mapping.get("name"):
-        raise VariableError("missing name")
-    return _build_variable_from_mapping(mapping=mapping, selector=[PIPELINE_VARIABLE_NODE_ID, mapping["name"]])
+    if not mapping.get("variable"):
+        raise VariableError("missing variable")
+    return _build_variable_from_mapping(mapping=mapping, selector=[PIPELINE_VARIABLE_NODE_ID, mapping["variable"]])


 def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
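Note: the mapping contract for pipeline variables changes from a `name` key to a `variable` key, which now also feeds the selector. A minimal sketch of the new contract (the remaining keys are assumptions, mirroring what `_build_variable_from_mapping` appears to expect):

    # Hypothetical mapping; "variable" is now mandatory and becomes the selector tail
    mapping = {"variable": "start_page", "value_type": "number", "value": 1}
    var = build_pipeline_variable_from_mapping(mapping)
    # var.selector == [PIPELINE_VARIABLE_NODE_ID, "start_page"]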
@@ -123,6 +123,43 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
         result = result.model_copy(update={"selector": selector})
     return cast(Variable, result)


+def _build_rag_pipeline_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
+    """
+    This factory function is used to create a rag pipeline variable;
+    it does not support the File type.
+    """
+    if (type := mapping.get("type")) is None:
+        raise VariableError("missing type")
+    if (value := mapping.get("value")) is None:
+        raise VariableError("missing value")
+    # FIXME: using Any here, fix it later
+    result: Any
+    match type:
+        case SegmentType.STRING:
+            result = StringVariable.model_validate(mapping)
+        case SegmentType.SECRET:
+            result = SecretVariable.model_validate(mapping)
+        case SegmentType.NUMBER if isinstance(value, int):
+            result = IntegerVariable.model_validate(mapping)
+        case SegmentType.NUMBER if isinstance(value, float):
+            result = FloatVariable.model_validate(mapping)
+        case SegmentType.NUMBER if not isinstance(value, float | int):
+            raise VariableError(f"invalid number value {value}")
+        case SegmentType.OBJECT if isinstance(value, dict):
+            result = ObjectVariable.model_validate(mapping)
+        case SegmentType.ARRAY_STRING if isinstance(value, list):
+            result = ArrayStringVariable.model_validate(mapping)
+        case SegmentType.ARRAY_NUMBER if isinstance(value, list):
+            result = ArrayNumberVariable.model_validate(mapping)
+        case SegmentType.ARRAY_OBJECT if isinstance(value, list):
+            result = ArrayObjectVariable.model_validate(mapping)
+        case _:
+            raise VariableError(f"not supported type {type}")
+    if result.size > dify_config.MAX_VARIABLE_SIZE:
+        raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
+    if not result.selector:
+        result = result.model_copy(update={"selector": selector})
+    return cast(Variable, result)


 def build_segment(value: Any, /) -> Segment:
     if value is None:
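Note: the new factory dispatches on `(type, value shape)` pairs, so a `SegmentType.NUMBER` carrying a non-numeric value falls into the guarded raise rather than the catch-all. A quick illustration of both paths (assuming the Variable models accept these mappings; keys are illustrative):

    # Routes to FloatVariable via the `isinstance(value, float)` guard
    _build_rag_pipeline_variable_from_mapping(
        mapping={"variable": "ratio", "type": SegmentType.NUMBER, "value": 0.5},
        selector=[PIPELINE_VARIABLE_NODE_ID, "ratio"],
    )
    # Raises VariableError("invalid number value oops")
    _build_rag_pipeline_variable_from_mapping(
        mapping={"variable": "ratio", "type": SegmentType.NUMBER, "value": "oops"},
        selector=[PIPELINE_VARIABLE_NODE_ID, "ratio"],
    )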
@@ -42,9 +42,19 @@ conversation_variable_fields = {
 pipeline_variable_fields = {
     "id": fields.String,
-    "name": fields.String,
-    "value_type": fields.String(attribute="value_type.value"),
-    "value": fields.Raw,
+    "label": fields.String,
+    "variable": fields.String,
+    "type": fields.String(attribute="type.value"),
+    "belong_to_node_id": fields.String,
+    "max_length": fields.Integer,
+    "required": fields.Boolean,
+    "default_value": fields.Raw,
+    "options": fields.List(fields.String),
+    "placeholder": fields.String,
+    "tooltips": fields.String,
+    "allowed_file_types": fields.List(fields.String),
+    "allow_file_extension": fields.List(fields.String),
+    "allow_file_upload_methods": fields.List(fields.String),
 }

 workflow_fields = {
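Note: for reference, a variable marshalled with the new `pipeline_variable_fields` would come out roughly like this (a sketch; values are illustrative, and `type` is rendered via its `.value`):

    {
        "id": "<uuid>",
        "variable": "start_page",
        "label": "Start page",
        "type": "number",
        "belong_to_node_id": "shared",
        "max_length": None,
        "required": True,
        "default_value": 1,
        "options": [],
        "placeholder": "Enter a page number",
        "tooltips": None,
        # plus: allowed_file_types, allow_file_extension, allow_file_upload_methods
    }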
@@ -62,6 +72,7 @@ workflow_fields = {
     "tool_published": fields.Boolean,
     "environment_variables": fields.List(EnvironmentVariableField()),
     "conversation_variables": fields.List(fields.Nested(conversation_variable_fields)),
+    "rag_pipeline_variables": fields.List(fields.Nested(pipeline_variable_fields)),
 }

 workflow_partial_fields = {
@@ -352,21 +352,19 @@ class Workflow(Base):
    )

    @property
-   def rag_pipeline_variables(self) -> dict[str, Sequence[Variable]]:
+   def rag_pipeline_variables(self) -> Sequence[Variable]:
        # TODO: find some way to init `self._conversation_variables` when instance created.
        if self._rag_pipeline_variables is None:
            self._rag_pipeline_variables = "{}"

        variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
-       results = {}
-       for k, v in variables_dict.items():
-           results[k] = [variable_factory.build_pipeline_variable_from_mapping(item) for item in v.values()]
+       results = [variable_factory.build_pipeline_variable_from_mapping(v) for v in variables_dict.values()]
        return results

    @rag_pipeline_variables.setter
-   def rag_pipeline_variables(self, values: dict[str, Sequence[Variable]]) -> None:
+   def rag_pipeline_variables(self, values: Sequence[Variable]) -> None:
        self._rag_pipeline_variables = json.dumps(
-           {k: {item.name: item.model_dump() for item in v} for k, v in values.items()},
+           {item.name: item.model_dump() for item in values},
            ensure_ascii=False,
        )
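Note: the setter still persists a name-keyed JSON object while the getter flattens `.values()` back into a list, so two variables sharing a name would silently collapse to one. A minimal round-trip sketch, assuming `item.name` exists on the Variable model:

    import json

    # serialize: Sequence[Variable] -> {"<name>": {...}, ...}
    stored = json.dumps({item.name: item.model_dump() for item in variables}, ensure_ascii=False)
    # deserialize: values() flattens the stored object back into a list
    rebuilt = [
        variable_factory.build_pipeline_variable_from_mapping(v)
        for v in json.loads(stored).values()
    ]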
@@ -201,7 +201,7 @@ class RagPipelineService:
        account: Account,
        environment_variables: Sequence[Variable],
        conversation_variables: Sequence[Variable],
-       rag_pipeline_variables: dict[str, Sequence[Variable]],
+       rag_pipeline_variables: Sequence[Variable],
    ) -> Workflow:
        """
        Sync draft workflow
@@ -552,7 +552,7 @@ class RagPipelineService:
        return workflow

-   def get_second_step_parameters(self, pipeline: Pipeline, datasource_provider: str) -> dict:
+   def get_second_step_parameters(self, pipeline: Pipeline, node_id: str) -> list:
        """
        Get second step parameters of rag pipeline
        """
@@ -562,13 +562,15 @@ class RagPipelineService:
            raise ValueError("Workflow not initialized")

        # get second step node
-       pipeline_variables = workflow.pipeline_variables
-       if not pipeline_variables:
-           return {}
+       rag_pipeline_variables = workflow.rag_pipeline_variables
+       if not rag_pipeline_variables:
+           return []

-       # get datasource provider
-       datasource_provider_variables = pipeline_variables.get(datasource_provider, [])
-       shared_variables = pipeline_variables.get("shared", [])
-       return datasource_provider_variables + shared_variables
+       # select variables bound to this node, plus the shared ones
+       datasource_provider_variables = [
+           item for item in rag_pipeline_variables
+           if item.belong_to_node_id in (node_id, "shared")
+       ]
+       return datasource_provider_variables

    def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
        """