| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("query", type=str, location="json") | parser.add_argument("query", type=str, location="json") | ||||
| parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") | parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") | ||||
| parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| HitTestingService.hit_testing_args_check(args) | HitTestingService.hit_testing_args_check(args) | ||||
| query=args["query"], | query=args["query"], | ||||
| account=current_user, | account=current_user, | ||||
| external_retrieval_model=args["external_retrieval_model"], | external_retrieval_model=args["external_retrieval_model"], | ||||
| metadata_filtering_conditions=args["metadata_filtering_conditions"], | |||||
| ) | ) | ||||
| return response | return response |
| return_resource=app_config.additional_features.show_retrieve_source, | return_resource=app_config.additional_features.show_retrieve_source, | ||||
| invoke_from=application_generate_entity.invoke_from, | invoke_from=application_generate_entity.invoke_from, | ||||
| hit_callback=hit_callback, | hit_callback=hit_callback, | ||||
| user_id=user_id, | |||||
| inputs=cast(dict, application_generate_entity.inputs), | |||||
| ) | ) | ||||
| # get how many agent thoughts have been created | # get how many agent thoughts have been created | ||||
| self.agent_thought_count = ( | self.agent_thought_count = ( |
| tool_instances, prompt_messages_tools = self._init_prompt_tools() | tool_instances, prompt_messages_tools = self._init_prompt_tools() | ||||
| self._prompt_messages_tools = prompt_messages_tools | self._prompt_messages_tools = prompt_messages_tools | ||||
| # fix metadata filter not work | |||||
| if app_config.dataset is not None: | |||||
| metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions | |||||
| for key, dataset_retriever_tool in tool_instances.items(): | |||||
| if hasattr(dataset_retriever_tool, "retrieval_tool"): | |||||
| dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions | |||||
| function_call_state = True | function_call_state = True | ||||
| llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} | llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} | ||||
| final_answer = "" | final_answer = "" |
| # convert tools into ModelRuntime Tool format | # convert tools into ModelRuntime Tool format | ||||
| tool_instances, prompt_messages_tools = self._init_prompt_tools() | tool_instances, prompt_messages_tools = self._init_prompt_tools() | ||||
| # fix metadata filter not work | |||||
| if app_config.dataset is not None: | |||||
| metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions | |||||
| for key, dataset_retriever_tool in tool_instances.items(): | |||||
| if hasattr(dataset_retriever_tool, "retrieval_tool"): | |||||
| dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions | |||||
| assert app_config.agent | assert app_config.agent | ||||
| iteration_step = 1 | iteration_step = 1 |
| from core.rag.datasource.keyword.keyword_factory import Keyword | from core.rag.datasource.keyword.keyword_factory import Keyword | ||||
| from core.rag.datasource.vdb.vector_factory import Vector | from core.rag.datasource.vdb.vector_factory import Vector | ||||
| from core.rag.embedding.retrieval import RetrievalSegments | from core.rag.embedding.retrieval import RetrievalSegments | ||||
| from core.rag.entities.metadata_entities import MetadataCondition | |||||
| from core.rag.index_processor.constant.index_type import IndexType | from core.rag.index_processor.constant.index_type import IndexType | ||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from core.rag.rerank.rerank_type import RerankMode | from core.rag.rerank.rerank_type import RerankMode | ||||
| return all_documents | return all_documents | ||||
| @classmethod | @classmethod | ||||
| def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None): | |||||
| def external_retrieve( | |||||
| cls, | |||||
| dataset_id: str, | |||||
| query: str, | |||||
| external_retrieval_model: Optional[dict] = None, | |||||
| metadata_filtering_conditions: Optional[dict] = None, | |||||
| ): | |||||
| dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | ||||
| if not dataset: | if not dataset: | ||||
| return [] | return [] | ||||
| metadata_condition = ( | |||||
| MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None | |||||
| ) | |||||
| all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( | all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( | ||||
| dataset.tenant_id, dataset_id, query, external_retrieval_model or {} | |||||
| dataset.tenant_id, | |||||
| dataset_id, | |||||
| query, | |||||
| external_retrieval_model or {}, | |||||
| metadata_condition=metadata_condition, | |||||
| ) | ) | ||||
| return all_documents | return all_documents | ||||
| else: | else: | ||||
| inputs = {} | inputs = {} | ||||
| available_datasets_ids = [dataset.id for dataset in available_datasets] | available_datasets_ids = [dataset.id for dataset in available_datasets] | ||||
| metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition( | |||||
| metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition( | |||||
| available_datasets_ids, | available_datasets_ids, | ||||
| query, | query, | ||||
| tenant_id, | tenant_id, | ||||
| return_resource: bool, | return_resource: bool, | ||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| hit_callback: DatasetIndexToolCallbackHandler, | hit_callback: DatasetIndexToolCallbackHandler, | ||||
| user_id: str, | |||||
| inputs: dict, | |||||
| ) -> Optional[list[DatasetRetrieverBaseTool]]: | ) -> Optional[list[DatasetRetrieverBaseTool]]: | ||||
| """ | """ | ||||
| A dataset tool is a tool that can be used to retrieve information from a dataset | A dataset tool is a tool that can be used to retrieve information from a dataset | ||||
| hit_callbacks=[hit_callback], | hit_callbacks=[hit_callback], | ||||
| return_resource=return_resource, | return_resource=return_resource, | ||||
| retriever_from=invoke_from.to_source(), | retriever_from=invoke_from.to_source(), | ||||
| retrieve_config=retrieve_config, | |||||
| user_id=user_id, | |||||
| inputs=inputs, | |||||
| ) | ) | ||||
| tools.append(tool) | tools.append(tool) | ||||
| ) | ) | ||||
| return filter_documents[:top_k] if top_k else filter_documents | return filter_documents[:top_k] if top_k else filter_documents | ||||
| def _get_metadata_filter_condition( | |||||
| def get_metadata_filter_condition( | |||||
| self, | self, | ||||
| dataset_ids: list, | dataset_ids: list, | ||||
| query: str, | query: str, | ||||
| ) | ) | ||||
| elif metadata_filtering_mode == "manual": | elif metadata_filtering_mode == "manual": | ||||
| if metadata_filtering_conditions: | if metadata_filtering_conditions: | ||||
| metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump()) | |||||
| conditions = [] | |||||
| for sequence, condition in enumerate(metadata_filtering_conditions.conditions): # type: ignore | for sequence, condition in enumerate(metadata_filtering_conditions.conditions): # type: ignore | ||||
| metadata_name = condition.name | metadata_name = condition.name | ||||
| expected_value = condition.value | expected_value = condition.value | ||||
| if expected_value is not None or condition.comparison_operator in ("empty", "not empty"): | |||||
| if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): | |||||
| if isinstance(expected_value, str): | if isinstance(expected_value, str): | ||||
| expected_value = self._replace_metadata_filter_value(expected_value, inputs) | expected_value = self._replace_metadata_filter_value(expected_value, inputs) | ||||
| filters = self._process_metadata_filter_func( | |||||
| sequence, | |||||
| condition.comparison_operator, | |||||
| metadata_name, | |||||
| expected_value, | |||||
| filters, | |||||
| conditions.append( | |||||
| Condition( | |||||
| name=metadata_name, | |||||
| comparison_operator=condition.comparison_operator, | |||||
| value=expected_value, | |||||
| ) | ) | ||||
| ) | |||||
| filters = self._process_metadata_filter_func( | |||||
| sequence, | |||||
| condition.comparison_operator, | |||||
| metadata_name, | |||||
| expected_value, | |||||
| filters, | |||||
| ) | |||||
| metadata_condition = MetadataCondition( | |||||
| logical_operator=metadata_filtering_conditions.logical_operator, | |||||
| conditions=conditions, | |||||
| ) | |||||
| else: | else: | ||||
| raise ValueError("Invalid metadata filtering mode") | raise ValueError("Invalid metadata filtering mode") | ||||
| if filters: | if filters: |
| from typing import Any | |||||
| from typing import Any, Optional, cast | |||||
| from pydantic import BaseModel, Field | from pydantic import BaseModel, Field | ||||
| from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig | |||||
| from core.rag.datasource.retrieval_service import RetrievalService | from core.rag.datasource.retrieval_service import RetrievalService | ||||
| from core.rag.entities.context_entities import DocumentContext | from core.rag.entities.context_entities import DocumentContext | ||||
| from core.rag.entities.metadata_entities import MetadataCondition | |||||
| from core.rag.models.document import Document as RetrievalDocument | from core.rag.models.document import Document as RetrievalDocument | ||||
| from core.rag.retrieval.dataset_retrieval import DatasetRetrieval | |||||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | from core.rag.retrieval.retrieval_methods import RetrievalMethod | ||||
| from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool | from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| args_schema: type[BaseModel] = DatasetRetrieverToolInput | args_schema: type[BaseModel] = DatasetRetrieverToolInput | ||||
| description: str = "use this to retrieve a dataset. " | description: str = "use this to retrieve a dataset. " | ||||
| dataset_id: str | dataset_id: str | ||||
| metadata_filtering_conditions: MetadataCondition | |||||
| user_id: Optional[str] = None | |||||
| retrieve_config: DatasetRetrieveConfigEntity | |||||
| inputs: dict | |||||
| @classmethod | @classmethod | ||||
| def from_dataset(cls, dataset: Dataset, **kwargs): | def from_dataset(cls, dataset: Dataset, **kwargs): | ||||
| tenant_id=dataset.tenant_id, | tenant_id=dataset.tenant_id, | ||||
| dataset_id=dataset.id, | dataset_id=dataset.id, | ||||
| description=description, | description=description, | ||||
| metadata_filtering_conditions=MetadataCondition(), | |||||
| **kwargs, | **kwargs, | ||||
| ) | ) | ||||
| return "" | return "" | ||||
| for hit_callback in self.hit_callbacks: | for hit_callback in self.hit_callbacks: | ||||
| hit_callback.on_query(query, dataset.id) | hit_callback.on_query(query, dataset.id) | ||||
| dataset_retrieval = DatasetRetrieval() | |||||
| metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition( | |||||
| [dataset.id], | |||||
| query, | |||||
| self.tenant_id, | |||||
| self.user_id or "unknown", | |||||
| cast(str, self.retrieve_config.metadata_filtering_mode), | |||||
| cast(ModelConfig, self.retrieve_config.metadata_model_config), | |||||
| self.retrieve_config.metadata_filtering_conditions, | |||||
| self.inputs, | |||||
| ) | |||||
| if metadata_filter_document_ids: | |||||
| document_ids_filter = metadata_filter_document_ids.get(dataset.id, []) | |||||
| else: | |||||
| document_ids_filter = None | |||||
| if dataset.provider == "external": | if dataset.provider == "external": | ||||
| results = [] | results = [] | ||||
| external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( | external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( | ||||
| dataset_id=dataset.id, | dataset_id=dataset.id, | ||||
| query=query, | query=query, | ||||
| external_retrieval_parameters=dataset.retrieval_model, | external_retrieval_parameters=dataset.retrieval_model, | ||||
| metadata_condition=self.metadata_filtering_conditions, | |||||
| metadata_condition=metadata_condition, | |||||
| ) | ) | ||||
| for external_document in external_documents: | for external_document in external_documents: | ||||
| document = RetrievalDocument( | document = RetrievalDocument( | ||||
| return str("\n".join([item.page_content for item in results])) | return str("\n".join([item.page_content for item in results])) | ||||
| else: | else: | ||||
| if metadata_condition and not document_ids_filter: | |||||
| return "" | |||||
| # get retrieval model , if the model is not setting , using default | # get retrieval model , if the model is not setting , using default | ||||
| retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model | retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model | ||||
| if dataset.indexing_technique == "economy": | if dataset.indexing_technique == "economy": | ||||
| # use keyword table query | # use keyword table query | ||||
| documents = RetrievalService.retrieve( | documents = RetrievalService.retrieve( | ||||
| retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k | |||||
| retrieval_method="keyword_search", | |||||
| dataset_id=dataset.id, | |||||
| query=query, | |||||
| top_k=self.top_k, | |||||
| document_ids_filter=document_ids_filter, | |||||
| ) | ) | ||||
| return str("\n".join([document.page_content for document in documents])) | return str("\n".join([document.page_content for document in documents])) | ||||
| else: | else: | ||||
| else None, | else None, | ||||
| reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", | reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", | ||||
| weights=retrieval_model.get("weights"), | weights=retrieval_model.get("weights"), | ||||
| document_ids_filter=document_ids_filter, | |||||
| ) | ) | ||||
| else: | else: | ||||
| documents = [] | documents = [] |
| return_resource: bool, | return_resource: bool, | ||||
| invoke_from: InvokeFrom, | invoke_from: InvokeFrom, | ||||
| hit_callback: DatasetIndexToolCallbackHandler, | hit_callback: DatasetIndexToolCallbackHandler, | ||||
| user_id: str, | |||||
| inputs: dict, | |||||
| ) -> list["DatasetRetrieverTool"]: | ) -> list["DatasetRetrieverTool"]: | ||||
| """ | """ | ||||
| get dataset tool | get dataset tool | ||||
| return_resource=return_resource, | return_resource=return_resource, | ||||
| invoke_from=invoke_from, | invoke_from=invoke_from, | ||||
| hit_callback=hit_callback, | hit_callback=hit_callback, | ||||
| user_id=user_id, | |||||
| inputs=inputs, | |||||
| ) | ) | ||||
| if retrieval_tools is None or len(retrieval_tools) == 0: | if retrieval_tools is None or len(retrieval_tools) == 0: | ||||
| return [] | return [] |
| ) | ) | ||||
| elif node_data.metadata_filtering_mode == "manual": | elif node_data.metadata_filtering_mode == "manual": | ||||
| if node_data.metadata_filtering_conditions: | if node_data.metadata_filtering_conditions: | ||||
| metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump()) | |||||
| conditions = [] | |||||
| if node_data.metadata_filtering_conditions: | if node_data.metadata_filtering_conditions: | ||||
| for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore | for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore | ||||
| metadata_name = condition.name | metadata_name = condition.name | ||||
| expected_value = condition.value | expected_value = condition.value | ||||
| if expected_value is not None or condition.comparison_operator in ("empty", "not empty"): | |||||
| if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): | |||||
| if isinstance(expected_value, str): | if isinstance(expected_value, str): | ||||
| expected_value = self.graph_runtime_state.variable_pool.convert_template( | expected_value = self.graph_runtime_state.variable_pool.convert_template( | ||||
| expected_value | expected_value | ||||
| expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore | expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore | ||||
| else: | else: | ||||
| raise ValueError("Invalid expected metadata value type") | raise ValueError("Invalid expected metadata value type") | ||||
| filters = self._process_metadata_filter_func( | |||||
| sequence, | |||||
| condition.comparison_operator, | |||||
| metadata_name, | |||||
| expected_value, | |||||
| filters, | |||||
| conditions.append( | |||||
| Condition( | |||||
| name=metadata_name, | |||||
| comparison_operator=condition.comparison_operator, | |||||
| value=expected_value, | |||||
| ) | ) | ||||
| ) | |||||
| filters = self._process_metadata_filter_func( | |||||
| sequence, | |||||
| condition.comparison_operator, | |||||
| metadata_name, | |||||
| expected_value, | |||||
| filters, | |||||
| ) | |||||
| metadata_condition = MetadataCondition( | |||||
| logical_operator=node_data.metadata_filtering_conditions.logical_operator, | |||||
| conditions=conditions, | |||||
| ) | |||||
| else: | else: | ||||
| raise ValueError("Invalid metadata filtering mode") | raise ValueError("Invalid metadata filtering mode") | ||||
| if filters: | if filters: |
| query: str, | query: str, | ||||
| account: Account, | account: Account, | ||||
| external_retrieval_model: dict, | external_retrieval_model: dict, | ||||
| metadata_filtering_conditions: dict, | |||||
| ) -> dict: | ) -> dict: | ||||
| if dataset.provider != "external": | if dataset.provider != "external": | ||||
| return { | return { | ||||
| dataset_id=dataset.id, | dataset_id=dataset.id, | ||||
| query=cls.escape_query_for_search(query), | query=cls.escape_query_for_search(query), | ||||
| external_retrieval_model=external_retrieval_model, | external_retrieval_model=external_retrieval_model, | ||||
| metadata_filtering_conditions=metadata_filtering_conditions, | |||||
| ) | ) | ||||
| end = time.perf_counter() | end = time.perf_counter() |