Co-authored-by: 金鹏程 <jinpengcheng01@corp.netease.com> Co-authored-by: crazywoola <427733928@qq.com>tags/1.4.0
| @@ -69,6 +69,13 @@ class CotAgentRunner(BaseAgentRunner, ABC): | |||
| tool_instances, prompt_messages_tools = self._init_prompt_tools() | |||
| self._prompt_messages_tools = prompt_messages_tools | |||
| # fix metadata filter not work | |||
| if app_config.dataset is not None: | |||
| metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions | |||
| for key, dataset_retriever_tool in tool_instances.items(): | |||
| if hasattr(dataset_retriever_tool, "retrieval_tool"): | |||
| dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions | |||
| function_call_state = True | |||
| llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} | |||
| final_answer = "" | |||
| @@ -45,6 +45,13 @@ class FunctionCallAgentRunner(BaseAgentRunner): | |||
| # convert tools into ModelRuntime Tool format | |||
| tool_instances, prompt_messages_tools = self._init_prompt_tools() | |||
| # fix metadata filter not work | |||
| if app_config.dataset is not None: | |||
| metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions | |||
| for key, dataset_retriever_tool in tool_instances.items(): | |||
| if hasattr(dataset_retriever_tool, "retrieval_tool"): | |||
| dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions | |||
| assert app_config.agent | |||
| iteration_step = 1 | |||
| @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.entities.context_entities import DocumentContext | |||
| from core.rag.entities.metadata_entities import MetadataCondition | |||
| from core.rag.models.document import Document as RetrievalDocument | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool | |||
| @@ -33,6 +34,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| args_schema: type[BaseModel] = DatasetRetrieverToolInput | |||
| description: str = "use this to retrieve a dataset. " | |||
| dataset_id: str | |||
| metadata_filtering_conditions: MetadataCondition | |||
| @classmethod | |||
| def from_dataset(cls, dataset: Dataset, **kwargs): | |||
| @@ -46,6 +48,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| tenant_id=dataset.tenant_id, | |||
| dataset_id=dataset.id, | |||
| description=description, | |||
| metadata_filtering_conditions=MetadataCondition(), | |||
| **kwargs, | |||
| ) | |||
| @@ -65,6 +68,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| dataset_id=dataset.id, | |||
| query=query, | |||
| external_retrieval_parameters=dataset.retrieval_model, | |||
| metadata_condition=self.metadata_filtering_conditions, | |||
| ) | |||
| for external_document in external_documents: | |||
| document = RetrievalDocument( | |||