| @@ -209,6 +209,7 @@ class ExternalKnowledgeHitTestingApi(Resource): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("query", type=str, location="json") | |||
| parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") | |||
| parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json") | |||
| args = parser.parse_args() | |||
| HitTestingService.hit_testing_args_check(args) | |||
| @@ -219,6 +220,7 @@ class ExternalKnowledgeHitTestingApi(Resource): | |||
| query=args["query"], | |||
| account=current_user, | |||
| external_retrieval_model=args["external_retrieval_model"], | |||
| metadata_filtering_conditions=args["metadata_filtering_conditions"], | |||
| ) | |||
| return response | |||
| @@ -91,6 +91,8 @@ class BaseAgentRunner(AppRunner): | |||
| return_resource=app_config.additional_features.show_retrieve_source, | |||
| invoke_from=application_generate_entity.invoke_from, | |||
| hit_callback=hit_callback, | |||
| user_id=user_id, | |||
| inputs=cast(dict, application_generate_entity.inputs), | |||
| ) | |||
| # get how many agent thoughts have been created | |||
| self.agent_thought_count = ( | |||
| @@ -69,13 +69,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): | |||
| tool_instances, prompt_messages_tools = self._init_prompt_tools() | |||
| self._prompt_messages_tools = prompt_messages_tools | |||
| # fix metadata filter not work | |||
| if app_config.dataset is not None: | |||
| metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions | |||
| for key, dataset_retriever_tool in tool_instances.items(): | |||
| if hasattr(dataset_retriever_tool, "retrieval_tool"): | |||
| dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions | |||
| function_call_state = True | |||
| llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} | |||
| final_answer = "" | |||
| @@ -45,13 +45,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): | |||
| # convert tools into ModelRuntime Tool format | |||
| tool_instances, prompt_messages_tools = self._init_prompt_tools() | |||
| # fix metadata filter not work | |||
| if app_config.dataset is not None: | |||
| metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions | |||
| for key, dataset_retriever_tool in tool_instances.items(): | |||
| if hasattr(dataset_retriever_tool, "retrieval_tool"): | |||
| dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions | |||
| assert app_config.agent | |||
| iteration_step = 1 | |||
| @@ -10,6 +10,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor | |||
| from core.rag.datasource.keyword.keyword_factory import Keyword | |||
| from core.rag.datasource.vdb.vector_factory import Vector | |||
| from core.rag.embedding.retrieval import RetrievalSegments | |||
| from core.rag.entities.metadata_entities import MetadataCondition | |||
| from core.rag.index_processor.constant.index_type import IndexType | |||
| from core.rag.models.document import Document | |||
| from core.rag.rerank.rerank_type import RerankMode | |||
| @@ -119,12 +120,25 @@ class RetrievalService: | |||
| return all_documents | |||
| @classmethod | |||
| def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None): | |||
| def external_retrieve( | |||
| cls, | |||
| dataset_id: str, | |||
| query: str, | |||
| external_retrieval_model: Optional[dict] = None, | |||
| metadata_filtering_conditions: Optional[dict] = None, | |||
| ): | |||
| dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| return [] | |||
| metadata_condition = ( | |||
| MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None | |||
| ) | |||
| all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( | |||
| dataset.tenant_id, dataset_id, query, external_retrieval_model or {} | |||
| dataset.tenant_id, | |||
| dataset_id, | |||
| query, | |||
| external_retrieval_model or {}, | |||
| metadata_condition=metadata_condition, | |||
| ) | |||
| return all_documents | |||
| @@ -149,7 +149,7 @@ class DatasetRetrieval: | |||
| else: | |||
| inputs = {} | |||
| available_datasets_ids = [dataset.id for dataset in available_datasets] | |||
| metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition( | |||
| metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition( | |||
| available_datasets_ids, | |||
| query, | |||
| tenant_id, | |||
| @@ -649,6 +649,8 @@ class DatasetRetrieval: | |||
| return_resource: bool, | |||
| invoke_from: InvokeFrom, | |||
| hit_callback: DatasetIndexToolCallbackHandler, | |||
| user_id: str, | |||
| inputs: dict, | |||
| ) -> Optional[list[DatasetRetrieverBaseTool]]: | |||
| """ | |||
| A dataset tool is a tool that can be used to retrieve information from a dataset | |||
| @@ -706,6 +708,9 @@ class DatasetRetrieval: | |||
| hit_callbacks=[hit_callback], | |||
| return_resource=return_resource, | |||
| retriever_from=invoke_from.to_source(), | |||
| retrieve_config=retrieve_config, | |||
| user_id=user_id, | |||
| inputs=inputs, | |||
| ) | |||
| tools.append(tool) | |||
| @@ -826,7 +831,7 @@ class DatasetRetrieval: | |||
| ) | |||
| return filter_documents[:top_k] if top_k else filter_documents | |||
| def _get_metadata_filter_condition( | |||
| def get_metadata_filter_condition( | |||
| self, | |||
| dataset_ids: list, | |||
| query: str, | |||
| @@ -876,20 +881,31 @@ class DatasetRetrieval: | |||
| ) | |||
| elif metadata_filtering_mode == "manual": | |||
| if metadata_filtering_conditions: | |||
| metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump()) | |||
| conditions = [] | |||
| for sequence, condition in enumerate(metadata_filtering_conditions.conditions): # type: ignore | |||
| metadata_name = condition.name | |||
| expected_value = condition.value | |||
| if expected_value is not None or condition.comparison_operator in ("empty", "not empty"): | |||
| if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): | |||
| if isinstance(expected_value, str): | |||
| expected_value = self._replace_metadata_filter_value(expected_value, inputs) | |||
| filters = self._process_metadata_filter_func( | |||
| sequence, | |||
| condition.comparison_operator, | |||
| metadata_name, | |||
| expected_value, | |||
| filters, | |||
| conditions.append( | |||
| Condition( | |||
| name=metadata_name, | |||
| comparison_operator=condition.comparison_operator, | |||
| value=expected_value, | |||
| ) | |||
| ) | |||
| filters = self._process_metadata_filter_func( | |||
| sequence, | |||
| condition.comparison_operator, | |||
| metadata_name, | |||
| expected_value, | |||
| filters, | |||
| ) | |||
| metadata_condition = MetadataCondition( | |||
| logical_operator=metadata_filtering_conditions.logical_operator, | |||
| conditions=conditions, | |||
| ) | |||
| else: | |||
| raise ValueError("Invalid metadata filtering mode") | |||
| if filters: | |||
| @@ -1,11 +1,12 @@ | |||
| from typing import Any | |||
| from typing import Any, Optional, cast | |||
| from pydantic import BaseModel, Field | |||
| from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.entities.context_entities import DocumentContext | |||
| from core.rag.entities.metadata_entities import MetadataCondition | |||
| from core.rag.models.document import Document as RetrievalDocument | |||
| from core.rag.retrieval.dataset_retrieval import DatasetRetrieval | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool | |||
| from extensions.ext_database import db | |||
| @@ -34,7 +35,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| args_schema: type[BaseModel] = DatasetRetrieverToolInput | |||
| description: str = "use this to retrieve a dataset. " | |||
| dataset_id: str | |||
| metadata_filtering_conditions: MetadataCondition | |||
| user_id: Optional[str] = None | |||
| retrieve_config: DatasetRetrieveConfigEntity | |||
| inputs: dict | |||
| @classmethod | |||
| def from_dataset(cls, dataset: Dataset, **kwargs): | |||
| @@ -48,7 +51,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| tenant_id=dataset.tenant_id, | |||
| dataset_id=dataset.id, | |||
| description=description, | |||
| metadata_filtering_conditions=MetadataCondition(), | |||
| **kwargs, | |||
| ) | |||
| @@ -61,6 +63,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| return "" | |||
| for hit_callback in self.hit_callbacks: | |||
| hit_callback.on_query(query, dataset.id) | |||
| dataset_retrieval = DatasetRetrieval() | |||
| metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition( | |||
| [dataset.id], | |||
| query, | |||
| self.tenant_id, | |||
| self.user_id or "unknown", | |||
| cast(str, self.retrieve_config.metadata_filtering_mode), | |||
| cast(ModelConfig, self.retrieve_config.metadata_model_config), | |||
| self.retrieve_config.metadata_filtering_conditions, | |||
| self.inputs, | |||
| ) | |||
| if metadata_filter_document_ids: | |||
| document_ids_filter = metadata_filter_document_ids.get(dataset.id, []) | |||
| else: | |||
| document_ids_filter = None | |||
| if dataset.provider == "external": | |||
| results = [] | |||
| external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( | |||
| @@ -68,7 +85,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| dataset_id=dataset.id, | |||
| query=query, | |||
| external_retrieval_parameters=dataset.retrieval_model, | |||
| metadata_condition=self.metadata_filtering_conditions, | |||
| metadata_condition=metadata_condition, | |||
| ) | |||
| for external_document in external_documents: | |||
| document = RetrievalDocument( | |||
| @@ -104,12 +121,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| return str("\n".join([item.page_content for item in results])) | |||
| else: | |||
| if metadata_condition and not document_ids_filter: | |||
| return "" | |||
| # get retrieval model , if the model is not setting , using default | |||
| retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model | |||
| if dataset.indexing_technique == "economy": | |||
| # use keyword table query | |||
| documents = RetrievalService.retrieve( | |||
| retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k | |||
| retrieval_method="keyword_search", | |||
| dataset_id=dataset.id, | |||
| query=query, | |||
| top_k=self.top_k, | |||
| document_ids_filter=document_ids_filter, | |||
| ) | |||
| return str("\n".join([document.page_content for document in documents])) | |||
| else: | |||
| @@ -128,6 +151,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| else None, | |||
| reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", | |||
| weights=retrieval_model.get("weights"), | |||
| document_ids_filter=document_ids_filter, | |||
| ) | |||
| else: | |||
| documents = [] | |||
| @@ -34,6 +34,8 @@ class DatasetRetrieverTool(Tool): | |||
| return_resource: bool, | |||
| invoke_from: InvokeFrom, | |||
| hit_callback: DatasetIndexToolCallbackHandler, | |||
| user_id: str, | |||
| inputs: dict, | |||
| ) -> list["DatasetRetrieverTool"]: | |||
| """ | |||
| get dataset tool | |||
| @@ -57,6 +59,8 @@ class DatasetRetrieverTool(Tool): | |||
| return_resource=return_resource, | |||
| invoke_from=invoke_from, | |||
| hit_callback=hit_callback, | |||
| user_id=user_id, | |||
| inputs=inputs, | |||
| ) | |||
| if retrieval_tools is None or len(retrieval_tools) == 0: | |||
| return [] | |||
| @@ -356,12 +356,12 @@ class KnowledgeRetrievalNode(LLMNode): | |||
| ) | |||
| elif node_data.metadata_filtering_mode == "manual": | |||
| if node_data.metadata_filtering_conditions: | |||
| metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump()) | |||
| conditions = [] | |||
| if node_data.metadata_filtering_conditions: | |||
| for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore | |||
| metadata_name = condition.name | |||
| expected_value = condition.value | |||
| if expected_value is not None or condition.comparison_operator in ("empty", "not empty"): | |||
| if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): | |||
| if isinstance(expected_value, str): | |||
| expected_value = self.graph_runtime_state.variable_pool.convert_template( | |||
| expected_value | |||
| @@ -372,13 +372,24 @@ class KnowledgeRetrievalNode(LLMNode): | |||
| expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore | |||
| else: | |||
| raise ValueError("Invalid expected metadata value type") | |||
| filters = self._process_metadata_filter_func( | |||
| sequence, | |||
| condition.comparison_operator, | |||
| metadata_name, | |||
| expected_value, | |||
| filters, | |||
| conditions.append( | |||
| Condition( | |||
| name=metadata_name, | |||
| comparison_operator=condition.comparison_operator, | |||
| value=expected_value, | |||
| ) | |||
| ) | |||
| filters = self._process_metadata_filter_func( | |||
| sequence, | |||
| condition.comparison_operator, | |||
| metadata_name, | |||
| expected_value, | |||
| filters, | |||
| ) | |||
| metadata_condition = MetadataCondition( | |||
| logical_operator=node_data.metadata_filtering_conditions.logical_operator, | |||
| conditions=conditions, | |||
| ) | |||
| else: | |||
| raise ValueError("Invalid metadata filtering mode") | |||
| if filters: | |||
| @@ -69,6 +69,7 @@ class HitTestingService: | |||
| query: str, | |||
| account: Account, | |||
| external_retrieval_model: dict, | |||
| metadata_filtering_conditions: dict, | |||
| ) -> dict: | |||
| if dataset.provider != "external": | |||
| return { | |||
| @@ -82,6 +83,7 @@ class HitTestingService: | |||
| dataset_id=dataset.id, | |||
| query=cls.escape_query_for_search(query), | |||
| external_retrieval_model=external_retrieval_model, | |||
| metadata_filtering_conditions=metadata_filtering_conditions, | |||
| ) | |||
| end = time.perf_counter() | |||