| @@ -81,6 +81,7 @@ from .datasets import ( | |||
| datasets_segments, | |||
| external, | |||
| hit_testing, | |||
| metadata, | |||
| website, | |||
| ) | |||
| @@ -621,7 +621,7 @@ class DocumentDetailApi(DocumentResource): | |||
| raise InvalidMetadataError(f"Invalid metadata value: {metadata}") | |||
| if metadata == "only": | |||
| response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata} | |||
| response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details} | |||
| elif metadata == "without": | |||
| dataset_process_rules = DatasetService.get_process_rules(dataset_id) | |||
| document_process_rules = document.dataset_process_rule.to_dict() | |||
| @@ -682,7 +682,7 @@ class DocumentDetailApi(DocumentResource): | |||
| "disabled_by": document.disabled_by, | |||
| "archived": document.archived, | |||
| "doc_type": document.doc_type, | |||
| "doc_metadata": document.doc_metadata, | |||
| "doc_metadata": document.doc_metadata_details, | |||
| "segment_count": document.segment_count, | |||
| "average_segment_length": document.average_segment_length, | |||
| "hit_count": document.hit_count, | |||
| @@ -0,0 +1,155 @@ | |||
from flask_login import current_user  # type: ignore
| from flask_restful import Resource, marshal_with, reqparse # type: ignore | |||
| from werkzeug.exceptions import NotFound | |||
| from controllers.console import api | |||
| from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required | |||
| from fields.dataset_fields import dataset_metadata_fields | |||
| from libs.login import login_required | |||
| from services.dataset_service import DatasetService | |||
| from services.entities.knowledge_entities.knowledge_entities import ( | |||
| MetadataArgs, | |||
| MetadataOperationData, | |||
| ) | |||
| from services.metadata_service import MetadataService | |||
| def _validate_name(name): | |||
| if not name or len(name) < 1 or len(name) > 40: | |||
raise ValueError("Name must be between 1 and 40 characters.")
| return name | |||
| def _validate_description_length(description): | |||
| if len(description) > 400: | |||
| raise ValueError("Description cannot exceed 400 characters.") | |||
| return description | |||
| class DatasetMetadataCreateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @enterprise_license_required | |||
| @marshal_with(dataset_metadata_fields) | |||
| def post(self, dataset_id): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("type", type=str, required=True, nullable=True, location="json") | |||
| parser.add_argument("name", type=str, required=True, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| metadata_args = MetadataArgs(**args) | |||
| dataset_id_str = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||
| if dataset is None: | |||
| raise NotFound("Dataset not found.") | |||
| DatasetService.check_dataset_permission(dataset, current_user) | |||
| metadata = MetadataService.create_metadata(dataset_id_str, metadata_args) | |||
| return metadata, 201 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @enterprise_license_required | |||
| def get(self, dataset_id): | |||
| dataset_id_str = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||
| if dataset is None: | |||
| raise NotFound("Dataset not found.") | |||
| return MetadataService.get_dataset_metadatas(dataset), 200 | |||
| class DatasetMetadataApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @enterprise_license_required | |||
| @marshal_with(dataset_metadata_fields) | |||
| def patch(self, dataset_id, metadata_id): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("name", type=str, required=True, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| dataset_id_str = str(dataset_id) | |||
| metadata_id_str = str(metadata_id) | |||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||
| if dataset is None: | |||
| raise NotFound("Dataset not found.") | |||
| DatasetService.check_dataset_permission(dataset, current_user) | |||
| metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) | |||
| return metadata, 200 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @enterprise_license_required | |||
| def delete(self, dataset_id, metadata_id): | |||
| dataset_id_str = str(dataset_id) | |||
| metadata_id_str = str(metadata_id) | |||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||
| if dataset is None: | |||
| raise NotFound("Dataset not found.") | |||
| DatasetService.check_dataset_permission(dataset, current_user) | |||
| MetadataService.delete_metadata(dataset_id_str, metadata_id_str) | |||
| return 200 | |||
| class DatasetMetadataBuiltInFieldApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @enterprise_license_required | |||
| def get(self): | |||
| built_in_fields = MetadataService.get_built_in_fields() | |||
| return {"fields": built_in_fields}, 200 | |||
| class DatasetMetadataBuiltInFieldActionApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @enterprise_license_required | |||
| def post(self, dataset_id, action): | |||
| dataset_id_str = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||
| if dataset is None: | |||
| raise NotFound("Dataset not found.") | |||
| DatasetService.check_dataset_permission(dataset, current_user) | |||
| if action == "enable": | |||
| MetadataService.enable_built_in_field(dataset) | |||
| elif action == "disable": | |||
| MetadataService.disable_built_in_field(dataset) | |||
| return 200 | |||
| class DocumentMetadataEditApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @enterprise_license_required | |||
| def post(self, dataset_id): | |||
| dataset_id_str = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||
| if dataset is None: | |||
| raise NotFound("Dataset not found.") | |||
| DatasetService.check_dataset_permission(dataset, current_user) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| metadata_args = MetadataOperationData(**args) | |||
| MetadataService.update_documents_metadata(dataset, metadata_args) | |||
| return 200 | |||
| api.add_resource(DatasetMetadataCreateApi, "/datasets/<uuid:dataset_id>/metadata") | |||
| api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>") | |||
| api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in") | |||
| api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>") | |||
| api.add_resource(DocumentMetadataEditApi, "/datasets/<uuid:dataset_id>/documents/metadata") | |||
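For orientation, the new console routes take small JSON bodies. A minimal client sketch (the base URL, auth header, and IDs below are placeholders, not part of this change):

import requests

BASE = "http://localhost:5001/console/api"  # placeholder console API root
HEADERS = {"Authorization": "Bearer <console-token>"}  # placeholder auth

# Create a metadata field; the body mirrors the reqparse arguments ("type", "name")
requests.post(f"{BASE}/datasets/<dataset_id>/metadata", headers=HEADERS, json={"type": "string", "name": "category"})

# Rename it
requests.patch(f"{BASE}/datasets/<dataset_id>/metadata/<metadata_id>", headers=HEADERS, json={"name": "topic"})

# Enable the built-in fields for a dataset
requests.post(f"{BASE}/datasets/<dataset_id>/metadata/built-in/enable", headers=HEADERS)

# Bulk-edit document metadata; each item's shape is defined by MetadataOperationData
requests.post(f"{BASE}/datasets/<dataset_id>/documents/metadata", headers=HEADERS, json={"operation_data": []})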
| @@ -1,7 +1,12 @@ | |||
| import uuid | |||
| from typing import Optional | |||
| from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity | |||
| from core.app.app_config.entities import ( | |||
| DatasetEntity, | |||
| DatasetRetrieveConfigEntity, | |||
| MetadataFilteringCondition, | |||
| ModelConfig, | |||
| ) | |||
| from core.entities.agent_entities import PlanningStrategy | |||
| from models.model import AppMode | |||
| from services.dataset_service import DatasetService | |||
| @@ -78,6 +83,15 @@ class DatasetConfigManager: | |||
| retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( | |||
| dataset_configs["retrieval_model"] | |||
| ), | |||
| metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"), | |||
| metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config")) | |||
| if dataset_configs.get("metadata_model_config") | |||
| else None, | |||
| metadata_filtering_conditions=MetadataFilteringCondition( | |||
| **dataset_configs.get("metadata_filtering_conditions", {}) | |||
| ) | |||
| if dataset_configs.get("metadata_filtering_conditions") | |||
| else None, | |||
| ), | |||
| ) | |||
| else: | |||
| @@ -96,6 +110,15 @@ class DatasetConfigManager: | |||
| weights=dataset_configs.get("weights"), | |||
| reranking_enabled=dataset_configs.get("reranking_enabled", True), | |||
| rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"), | |||
| metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"), | |||
| metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config")) | |||
| if dataset_configs.get("metadata_model_config") | |||
| else None, | |||
| metadata_filtering_conditions=MetadataFilteringCondition( | |||
| **dataset_configs.get("metadata_filtering_conditions", {}) | |||
| ) | |||
| if dataset_configs.get("metadata_filtering_conditions") | |||
| else None, | |||
| ), | |||
| ) | |||
| @@ -1,10 +1,11 @@ | |||
| from collections.abc import Sequence | |||
| from enum import Enum, StrEnum | |||
| from typing import Any, Optional | |||
| from typing import Any, Literal, Optional | |||
| from pydantic import BaseModel, Field, field_validator | |||
| from core.file import FileTransferMethod, FileType, FileUploadConfig | |||
| from core.model_runtime.entities.llm_entities import LLMMode | |||
| from core.model_runtime.entities.message_entities import PromptMessageRole | |||
| from models.model import AppMode | |||
| @@ -135,6 +136,55 @@ class ExternalDataVariableEntity(BaseModel): | |||
| config: dict[str, Any] = Field(default_factory=dict) | |||
| SupportedComparisonOperator = Literal[ | |||
| # for string or array | |||
| "contains", | |||
| "not contains", | |||
| "start with", | |||
| "end with", | |||
| "is", | |||
| "is not", | |||
| "empty", | |||
| "not empty", | |||
| # for number | |||
| "=", | |||
| "≠", | |||
| ">", | |||
| "<", | |||
| "≥", | |||
| "≤", | |||
| # for time | |||
| "before", | |||
| "after", | |||
| ] | |||
| class ModelConfig(BaseModel): | |||
| provider: str | |||
| name: str | |||
| mode: LLMMode | |||
| completion_params: dict[str, Any] = {} | |||
| class Condition(BaseModel): | |||
| """ | |||
Condition detail
| """ | |||
| name: str | |||
| comparison_operator: SupportedComparisonOperator | |||
value: str | Sequence[str] | int | float | None = None
| class MetadataFilteringCondition(BaseModel): | |||
| """ | |||
| Metadata Filtering Condition. | |||
| """ | |||
| logical_operator: Optional[Literal["and", "or"]] = "and" | |||
| conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) | |||
| class DatasetRetrieveConfigEntity(BaseModel): | |||
| """ | |||
| Dataset Retrieve Config Entity. | |||
| @@ -171,6 +221,9 @@ class DatasetRetrieveConfigEntity(BaseModel): | |||
| reranking_model: Optional[dict] = None | |||
| weights: Optional[dict] = None | |||
| reranking_enabled: Optional[bool] = True | |||
| metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled" | |||
| metadata_model_config: Optional[ModelConfig] = None | |||
| metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None | |||
| class DatasetEntity(BaseModel): | |||
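A quick sketch of how the new retrieval-config fields compose (values are made up; only the models defined above are used):

# Manual metadata filtering, as it would be carried on
# DatasetRetrieveConfigEntity.metadata_filtering_conditions
example = MetadataFilteringCondition(
    logical_operator="and",
    conditions=[
        Condition(name="category", comparison_operator="is", value="news"),
        Condition(name="publish_year", comparison_operator="≥", value=2023),
    ],
)
# Pydantic round-trip, e.g. when the app model config is persisted and reloaded
assert MetadataFilteringCondition(**example.model_dump()) == example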
| @@ -180,6 +180,7 @@ class ChatAppRunner(AppRunner): | |||
| hit_callback=hit_callback, | |||
| memory=memory, | |||
| message_id=message.id, | |||
| inputs=inputs, | |||
| ) | |||
| # reorganize all inputs and template to prompt messages | |||
| @@ -139,6 +139,7 @@ class CompletionAppRunner(AppRunner): | |||
| show_retrieve_source=app_config.additional_features.show_retrieve_source, | |||
| hit_callback=hit_callback, | |||
| message_id=message.id, | |||
| inputs=inputs, | |||
| ) | |||
| # reorganize all inputs and template to prompt messages | |||
| @@ -88,16 +88,17 @@ class Jieba(BaseKeyword): | |||
| keyword_table = self._get_dataset_keyword_table() | |||
| k = kwargs.get("top_k", 4) | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k) | |||
| documents = [] | |||
| for chunk_index in sorted_chunk_indices: | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index) | |||
| .first() | |||
| segment_query = db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index | |||
| ) | |||
| if document_ids_filter: | |||
| segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter)) | |||
| segment = segment_query.first() | |||
| if segment: | |||
| documents.append( | |||
| @@ -41,6 +41,7 @@ class RetrievalService: | |||
| reranking_model: Optional[dict] = None, | |||
| reranking_mode: str = "reranking_model", | |||
| weights: Optional[dict] = None, | |||
| document_ids_filter: Optional[list[str]] = None, | |||
| ): | |||
| if not query: | |||
| return [] | |||
| @@ -64,6 +65,7 @@ class RetrievalService: | |||
| top_k=top_k, | |||
| all_documents=all_documents, | |||
| exceptions=exceptions, | |||
| document_ids_filter=document_ids_filter, | |||
| ) | |||
| ) | |||
| if RetrievalMethod.is_support_semantic_search(retrieval_method): | |||
| @@ -79,6 +81,7 @@ class RetrievalService: | |||
| all_documents=all_documents, | |||
| retrieval_method=retrieval_method, | |||
| exceptions=exceptions, | |||
| document_ids_filter=document_ids_filter, | |||
| ) | |||
| ) | |||
| if RetrievalMethod.is_support_fulltext_search(retrieval_method): | |||
| @@ -130,7 +133,14 @@ class RetrievalService: | |||
| @classmethod | |||
| def keyword_search( | |||
| cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list | |||
| cls, | |||
| flask_app: Flask, | |||
| dataset_id: str, | |||
| query: str, | |||
| top_k: int, | |||
| all_documents: list, | |||
| exceptions: list, | |||
| document_ids_filter: Optional[list[str]] = None, | |||
| ): | |||
| with flask_app.app_context(): | |||
| try: | |||
| @@ -139,7 +149,10 @@ class RetrievalService: | |||
| raise ValueError("dataset not found") | |||
| keyword = Keyword(dataset=dataset) | |||
| documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k) | |||
| documents = keyword.search( | |||
| cls.escape_query_for_search(query), top_k=top_k, document_ids_filter=document_ids_filter | |||
| ) | |||
| all_documents.extend(documents) | |||
| except Exception as e: | |||
| exceptions.append(str(e)) | |||
| @@ -156,6 +169,7 @@ class RetrievalService: | |||
| all_documents: list, | |||
| retrieval_method: str, | |||
| exceptions: list, | |||
| document_ids_filter: Optional[list[str]] = None, | |||
| ): | |||
| with flask_app.app_context(): | |||
| try: | |||
| @@ -170,6 +184,7 @@ class RetrievalService: | |||
| top_k=top_k, | |||
| score_threshold=score_threshold, | |||
| filter={"group_id": [dataset.id]}, | |||
| document_ids_filter=document_ids_filter, | |||
| ) | |||
| if documents: | |||
| @@ -53,7 +53,7 @@ class AnalyticdbVector(BaseVector): | |||
| self.analyticdb_vector.delete_by_metadata_field(key, value) | |||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||
| return self.analyticdb_vector.search_by_vector(query_vector) | |||
| return self.analyticdb_vector.search_by_vector(query_vector, **kwargs) | |||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | |||
| return self.analyticdb_vector.search_by_full_text(query, **kwargs) | |||
| @@ -196,6 +196,11 @@ class AnalyticdbVectorBySql: | |||
| top_k = kwargs.get("top_k", 4) | |||
| if not isinstance(top_k, int) or top_k <= 0: | |||
| raise ValueError("top_k must be a positive integer") | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| where_clause = "WHERE 1=1" | |||
| if document_ids_filter: | |||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||
where_clause += f" AND metadata_->>'document_id' IN ({document_ids})"
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| with self._get_cursor() as cur: | |||
| query_vector_str = json.dumps(query_vector) | |||
| @@ -204,7 +209,7 @@ class AnalyticdbVectorBySql: | |||
| f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, " | |||
| f"t.page_content as page_content, t.metadata_ AS metadata_ " | |||
| f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score " | |||
| f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t", | |||
| f"FROM {self.table_name} {where_clause} ORDER BY score LIMIT {top_k} ) t", | |||
| (query_vector_str,), | |||
| ) | |||
| documents = [] | |||
| @@ -224,12 +229,17 @@ class AnalyticdbVectorBySql: | |||
| top_k = kwargs.get("top_k", 4) | |||
| if not isinstance(top_k, int) or top_k <= 0: | |||
| raise ValueError("top_k must be a positive integer") | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| where_clause = "" | |||
| if document_ids_filter: | |||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||
| where_clause += f"AND metadata_->>'document_id' IN ({document_ids})" | |||
| with self._get_cursor() as cur: | |||
| cur.execute( | |||
| f"""SELECT id, vector, page_content, metadata_, | |||
| ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score | |||
| FROM {self.table_name} | |||
| WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') | |||
| WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause} | |||
| ORDER BY score DESC | |||
| LIMIT {top_k}""", | |||
| (f"'{query}'", f"'{query}'"), | |||
| @@ -123,11 +123,21 @@ class BaiduVector(BaseVector): | |||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||
| query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector] | |||
| anns = AnnSearch( | |||
| vector_field=self.field_vector, | |||
| vector_floats=query_vector, | |||
| params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)), | |||
| ) | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| if document_ids_filter: | |||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||
| anns = AnnSearch( | |||
| vector_field=self.field_vector, | |||
| vector_floats=query_vector, | |||
| params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)), | |||
| filter=f"document_id IN ({document_ids})", | |||
| ) | |||
| else: | |||
| anns = AnnSearch( | |||
| vector_field=self.field_vector, | |||
| vector_floats=query_vector, | |||
| params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)), | |||
| ) | |||
| res = self._db.table(self._collection_name).search( | |||
| anns=anns, | |||
| projections=[self.field_id, self.field_text, self.field_metadata], | |||
| @@ -95,7 +95,15 @@ class ChromaVector(BaseVector): | |||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||
| collection = self._client.get_or_create_collection(self._collection_name) | |||
| results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| if document_ids_filter: | |||
| results: QueryResult = collection.query( | |||
| query_embeddings=query_vector, | |||
| n_results=kwargs.get("top_k", 4), | |||
| where={"document_id": {"$in": document_ids_filter}}, # type: ignore | |||
| ) | |||
| else: | |||
| results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| # Check if results contain data | |||
| @@ -117,6 +117,9 @@ class ElasticSearchVector(BaseVector): | |||
| top_k = kwargs.get("top_k", 4) | |||
| num_candidates = math.ceil(top_k * 1.5) | |||
| knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates} | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| if document_ids_filter: | |||
| knn["filter"] = {"terms": {"metadata.document_id": document_ids_filter}} | |||
| results = self._client.search(index=self._collection_name, knn=knn, size=top_k) | |||
| @@ -145,6 +148,9 @@ class ElasticSearchVector(BaseVector): | |||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | |||
| query_str = {"match": {Field.CONTENT_KEY.value: query}} | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| if document_ids_filter: | |||
query_str = {"bool": {"must": [query_str], "filter": {"terms": {"metadata.document_id": document_ids_filter}}}}  # wrap in bool so the terms filter applies alongside the match
| results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4)) | |||
| docs = [] | |||
| for hit in results["hits"]["hits"]: | |||
| @@ -168,7 +168,12 @@ class LindormVectorStore(BaseVector): | |||
| raise ValueError("All elements in query_vector should be floats") | |||
| top_k = kwargs.get("top_k", 10) | |||
| query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs) | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| filters = [] | |||
| if document_ids_filter: | |||
| filters.append({"terms": {"metadata.document_id": document_ids_filter}}) | |||
| query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs) | |||
| try: | |||
| params = {} | |||
| if self._using_ugc: | |||
| @@ -206,7 +211,10 @@ class LindormVectorStore(BaseVector): | |||
| should = kwargs.get("should") | |||
| minimum_should_match = kwargs.get("minimum_should_match", 0) | |||
| top_k = kwargs.get("top_k", 10) | |||
| filters = kwargs.get("filter") | |||
| filters = kwargs.get("filter", []) | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| if document_ids_filter: | |||
| filters.append({"terms": {"metadata.document_id": document_ids_filter}}) | |||
| routing = self._routing | |||
| full_text_query = default_text_search_query( | |||
| query_text=query, | |||
| @@ -228,12 +228,18 @@ class MilvusVector(BaseVector): | |||
| """ | |||
| Search for documents by vector similarity. | |||
| """ | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| filter = "" | |||
| if document_ids_filter: | |||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||
| filter = f'metadata["document_id"] in ({document_ids})' | |||
| results = self._client.search( | |||
| collection_name=self._collection_name, | |||
| data=[query_vector], | |||
| anns_field=Field.VECTOR.value, | |||
| limit=kwargs.get("top_k", 4), | |||
| output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], | |||
| filter=filter, | |||
| ) | |||
| return self._process_search_results( | |||
| @@ -249,6 +255,11 @@ class MilvusVector(BaseVector): | |||
| if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value): | |||
| logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)") | |||
| return [] | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| filter = "" | |||
| if document_ids_filter: | |||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||
| filter = f'metadata["document_id"] in ({document_ids})' | |||
| results = self._client.search( | |||
| collection_name=self._collection_name, | |||
| @@ -256,6 +267,7 @@ class MilvusVector(BaseVector): | |||
| anns_field=Field.SPARSE_VECTOR.value, | |||
| limit=kwargs.get("top_k", 4), | |||
| output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], | |||
| filter=filter, | |||
| ) | |||
| return self._process_search_results( | |||
| @@ -133,6 +133,10 @@ class MyScaleVector(BaseVector): | |||
| if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 | |||
| else "" | |||
| ) | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| if document_ids_filter: | |||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||
where_str = f"{where_str} AND metadata['document_id'] in ({document_ids})" if where_str else f"WHERE metadata['document_id'] in ({document_ids})"
| sql = f""" | |||
| SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name} | |||
| {where_str} ORDER BY dist {order.value} LIMIT {top_k} | |||
| @@ -154,6 +154,11 @@ class OceanBaseVector(BaseVector): | |||
| return [] | |||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| where_clause = None | |||
| if document_ids_filter: | |||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||
| where_clause = f"metadata->>'$.document_id' in ({document_ids})" | |||
| ef_search = kwargs.get("ef_search", self._hnsw_ef_search) | |||
| if ef_search != self._hnsw_ef_search: | |||
| self._client.set_ob_hnsw_ef_search(ef_search) | |||
| @@ -167,6 +172,7 @@ class OceanBaseVector(BaseVector): | |||
| distance_func=func.l2_distance, | |||
| output_column_names=["text", "metadata"], | |||
| with_dist=True, | |||
| where_clause=where_clause, | |||
| ) | |||
| docs = [] | |||
| for text, metadata, distance in cur: | |||
| @@ -154,6 +154,9 @@ class OpenSearchVector(BaseVector): | |||
| "size": kwargs.get("top_k", 4), | |||
| "query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}}, | |||
| } | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| if document_ids_filter: | |||
query["query"] = {"bool": {"must": [query["query"]], "filter": {"terms": {"metadata.document_id": document_ids_filter}}}}  # keep the knn clause; add the filter instead of overwriting the query
| try: | |||
| response = self._client.search(index=self._collection_name.lower(), body=query) | |||
| @@ -179,6 +182,9 @@ class OpenSearchVector(BaseVector): | |||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | |||
| full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}} | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| if document_ids_filter: | |||
full_text_query["query"] = {"bool": {"must": [full_text_query["query"]], "filter": {"terms": {"metadata.document_id": document_ids_filter}}}}  # combine match and terms in a bool query
| response = self._client.search(index=self._collection_name.lower(), body=full_text_query) | |||
| @@ -201,10 +201,15 @@ class OracleVector(BaseVector): | |||
| :return: List of Documents that are nearest to the query vector. | |||
| """ | |||
| top_k = kwargs.get("top_k", 4) | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| where_clause = "" | |||
| if document_ids_filter: | |||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||
| where_clause = f"WHERE metadata->>'document_id' in ({document_ids})" | |||
| with self._get_cursor() as cur: | |||
| cur.execute( | |||
| f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}" | |||
| f" ORDER BY distance fetch first {top_k} rows only", | |||
| f" {where_clause} ORDER BY distance fetch first {top_k} rows only", | |||
| [numpy.array(query_vector)], | |||
| ) | |||
| docs = [] | |||
| @@ -257,9 +262,15 @@ class OracleVector(BaseVector): | |||
| if token not in stop_words: | |||
| entities.append(token) | |||
| with self._get_cursor() as cur: | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| where_clause = "" | |||
| if document_ids_filter: | |||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||
| where_clause = f" AND metadata->>'document_id' in ({document_ids}) " | |||
| cur.execute( | |||
| f"select meta, text, embedding FROM {self.table_name}" | |||
| f" WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only", | |||
f" WHERE CONTAINS(text, :1, 1) > 0 {where_clause} "
| f"order by score(1) desc fetch first {top_k} rows only", | |||
| [" ACCUM ".join(entities)], | |||
| ) | |||
| docs = [] | |||
| @@ -189,6 +189,9 @@ class PGVectoRS(BaseVector): | |||
| .limit(kwargs.get("top_k", 4)) | |||
| .order_by("distance") | |||
| ) | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| if document_ids_filter: | |||
| stmt = stmt.where(self._table.meta["document_id"].in_(document_ids_filter)) | |||
| res = session.execute(stmt) | |||
| results = [(row[0], row[1]) for row in res] | |||
| @@ -173,10 +173,16 @@ class PGVector(BaseVector): | |||
| top_k = kwargs.get("top_k", 4) | |||
| if not isinstance(top_k, int) or top_k <= 0: | |||
| raise ValueError("top_k must be a positive integer") | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| where_clause = "" | |||
| if document_ids_filter: | |||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||
| where_clause = f" WHERE metadata->>'document_id' in ({document_ids}) " | |||
| with self._get_cursor() as cur: | |||
| cur.execute( | |||
| f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}" | |||
| f" {where_clause}" | |||
| f" ORDER BY distance LIMIT {top_k}", | |||
| (json.dumps(query_vector),), | |||
| ) | |||
| @@ -195,12 +201,18 @@ class PGVector(BaseVector): | |||
| if not isinstance(top_k, int) or top_k <= 0: | |||
| raise ValueError("top_k must be a positive integer") | |||
| with self._get_cursor() as cur: | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| where_clause = "" | |||
| if document_ids_filter: | |||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||
| where_clause = f" AND metadata->>'document_id' in ({document_ids}) " | |||
| if self.pg_bigm: | |||
| cur.execute("SET pg_bigm.similarity_limit TO 0.000001") | |||
| cur.execute( | |||
| f"""SELECT meta, text, bigm_similarity(unistr(%s), coalesce(text, '')) AS score | |||
| FROM {self.table_name} | |||
| WHERE text =%% unistr(%s) | |||
| {where_clause} | |||
| ORDER BY score DESC | |||
| LIMIT {top_k}""", | |||
| # f"'{query}'" is required in order to account for whitespace in query | |||
| @@ -211,6 +223,7 @@ class PGVector(BaseVector): | |||
| f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score | |||
| FROM {self.table_name} | |||
| WHERE to_tsvector(text) @@ plainto_tsquery(%s) | |||
| {where_clause} | |||
| ORDER BY score DESC | |||
| LIMIT {top_k}""", | |||
| # f"'{query}'" is required in order to account for whitespace in query | |||
| @@ -286,27 +286,26 @@ class QdrantVector(BaseVector): | |||
| from qdrant_client.http import models | |||
| from qdrant_client.http.exceptions import UnexpectedResponse | |||
| for node_id in ids: | |||
| try: | |||
| filter = models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key="metadata.doc_id", | |||
| match=models.MatchValue(value=node_id), | |||
| ), | |||
| ], | |||
| ) | |||
| self._client.delete( | |||
| collection_name=self._collection_name, | |||
| points_selector=FilterSelector(filter=filter), | |||
| ) | |||
| except UnexpectedResponse as e: | |||
| # Collection does not exist, so return | |||
| if e.status_code == 404: | |||
| return | |||
| # Some other error occurred, so re-raise the exception | |||
| else: | |||
| raise e | |||
| try: | |||
| filter = models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key="metadata.doc_id", | |||
| match=models.MatchAny(any=ids), | |||
| ), | |||
| ], | |||
| ) | |||
| self._client.delete( | |||
| collection_name=self._collection_name, | |||
| points_selector=FilterSelector(filter=filter), | |||
| ) | |||
| except UnexpectedResponse as e: | |||
| # Collection does not exist, so return | |||
| if e.status_code == 404: | |||
| return | |||
| # Some other error occurred, so re-raise the exception | |||
| else: | |||
| raise e | |||
| def text_exists(self, id: str) -> bool: | |||
| all_collection_name = [] | |||
| @@ -331,6 +330,15 @@ class QdrantVector(BaseVector): | |||
| ), | |||
| ], | |||
| ) | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| if document_ids_filter: | |||
| if filter.must: | |||
| filter.must.append( | |||
| models.FieldCondition( | |||
| key="metadata.document_id", | |||
| match=models.MatchAny(any=document_ids_filter), | |||
| ) | |||
| ) | |||
| results = self._client.search( | |||
| collection_name=self._collection_name, | |||
| query_vector=query_vector, | |||
| @@ -377,6 +385,15 @@ class QdrantVector(BaseVector): | |||
| ), | |||
| ] | |||
| ) | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| if document_ids_filter: | |||
| if scroll_filter.must: | |||
| scroll_filter.must.append( | |||
| models.FieldCondition( | |||
| key="metadata.document_id", | |||
| match=models.MatchAny(any=document_ids_filter), | |||
| ) | |||
| ) | |||
| response = self._client.scroll( | |||
| collection_name=self._collection_name, | |||
| scroll_filter=scroll_filter, | |||
| @@ -223,8 +223,12 @@ class RelytVector(BaseVector): | |||
| return len(result) > 0 | |||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| filter = kwargs.get("filter", {}) | |||
| if document_ids_filter: | |||
| filter["document_id"] = document_ids_filter | |||
| results = self.similarity_search_with_score_by_vector( | |||
| k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=kwargs.get("filter") | |||
| k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=filter | |||
| ) | |||
| # Organize results. | |||
| @@ -246,9 +250,9 @@ class RelytVector(BaseVector): | |||
| filter_condition = "" | |||
| if filter is not None: | |||
| conditions = [ | |||
| f"metadata->>{key!r} in ({', '.join(map(repr, value))})" | |||
f"metadata->>'{key}' in ({', '.join(map(repr, value))})"
| if len(value) > 1 | |||
| else f"metadata->>{key!r} = {value[0]!r}" | |||
else f"metadata->>'{key}' = {value[0]!r}"
| for key, value in filter.items() | |||
| ] | |||
| filter_condition = f"WHERE {' AND '.join(conditions)}" | |||
| @@ -145,11 +145,16 @@ class TencentVector(BaseVector): | |||
| self._db.collection(self._collection_name).delete(document_ids=ids) | |||
| def delete_by_metadata_field(self, key: str, value: str) -> None: | |||
| self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value]))) | |||
| self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(f"metadata.{key}", [value]))) | |||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| filter = None | |||
| if document_ids_filter: | |||
| filter = Filter(Filter.In("metadata.document_id", document_ids_filter)) | |||
| res = self._db.collection(self._collection_name).search( | |||
| vectors=[query_vector], | |||
| filter=filter, | |||
| params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)), | |||
| retrieve_vector=False, | |||
| limit=kwargs.get("top_k", 4), | |||
| @@ -326,6 +326,18 @@ class TidbOnQdrantVector(BaseVector): | |||
| ), | |||
| ], | |||
| ) | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| if document_ids_filter: | |||
| should_conditions = [] | |||
| for document_id_filter in document_ids_filter: | |||
| should_conditions.append( | |||
| models.FieldCondition( | |||
| key="metadata.document_id", | |||
| match=models.MatchValue(value=document_id_filter), | |||
| ) | |||
| ) | |||
| if should_conditions: | |||
| filter.should = should_conditions # type: ignore | |||
| results = self._client.search( | |||
| collection_name=self._collection_name, | |||
| query_vector=query_vector, | |||
| @@ -368,6 +380,18 @@ class TidbOnQdrantVector(BaseVector): | |||
| ) | |||
| ] | |||
| ) | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| if document_ids_filter: | |||
| should_conditions = [] | |||
| for document_id_filter in document_ids_filter: | |||
| should_conditions.append( | |||
| models.FieldCondition( | |||
| key="metadata.document_id", | |||
| match=models.MatchValue(value=document_id_filter), | |||
| ) | |||
| ) | |||
| if should_conditions: | |||
| scroll_filter.should = should_conditions # type: ignore | |||
| response = self._client.scroll( | |||
| collection_name=self._collection_name, | |||
| scroll_filter=scroll_filter, | |||
| @@ -196,6 +196,11 @@ class TiDBVector(BaseVector): | |||
| docs = [] | |||
| tidb_dist_func = self._get_distance_func() | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| where_clause = "" | |||
| if document_ids_filter: | |||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||
| where_clause = f" WHERE meta->>'$.document_id' in ({document_ids}) " | |||
| with Session(self._engine) as session: | |||
| select_statement = sql_text(f""" | |||
| @@ -206,6 +211,7 @@ class TiDBVector(BaseVector): | |||
| text, | |||
| {tidb_dist_func}(vector, :query_vector_str) AS distance | |||
| FROM {self._collection_name} | |||
| {where_clause} | |||
| ORDER BY distance ASC | |||
| LIMIT :top_k | |||
| ) t | |||
| @@ -88,7 +88,20 @@ class UpstashVector(BaseVector): | |||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||
| top_k = kwargs.get("top_k", 4) | |||
| result = self.index.query(vector=query_vector, top_k=top_k, include_metadata=True, include_data=True) | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| if document_ids_filter: | |||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||
| filter = f"document_id in ({document_ids})" | |||
| else: | |||
| filter = "" | |||
| result = self.index.query( | |||
| vector=query_vector, | |||
| top_k=top_k, | |||
| include_metadata=True, | |||
| include_data=True, | |||
| include_vectors=False, | |||
| filter=filter, | |||
| ) | |||
| docs = [] | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| for record in result: | |||
| @@ -177,7 +177,11 @@ class VikingDBVector(BaseVector): | |||
| query_vector, limit=kwargs.get("top_k", 4) | |||
| ) | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| return self._get_search_res(results, score_threshold) | |||
| docs = self._get_search_res(results, score_threshold) | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| if document_ids_filter: | |||
| docs = [doc for doc in docs if doc.metadata.get("document_id") in document_ids_filter] | |||
| return docs | |||
| def _get_search_res(self, results, score_threshold) -> list[Document]: | |||
| if len(results) == 0: | |||
| @@ -187,8 +187,10 @@ class WeaviateVector(BaseVector): | |||
| query_obj = self._client.query.get(collection_name, properties) | |||
| vector = {"vector": query_vector} | |||
| if kwargs.get("where_filter"): | |||
| query_obj = query_obj.with_where(kwargs.get("where_filter")) | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| if document_ids_filter: | |||
| where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter} | |||
| query_obj = query_obj.with_where(where_filter) | |||
| result = ( | |||
| query_obj.with_near_vector(vector) | |||
| .with_limit(kwargs.get("top_k", 4)) | |||
| @@ -233,8 +235,10 @@ class WeaviateVector(BaseVector): | |||
| if kwargs.get("search_distance"): | |||
| content["certainty"] = kwargs.get("search_distance") | |||
| query_obj = self._client.query.get(collection_name, properties) | |||
| if kwargs.get("where_filter"): | |||
| query_obj = query_obj.with_where(kwargs.get("where_filter")) | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| if document_ids_filter: | |||
| where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter} | |||
| query_obj = query_obj.with_where(where_filter) | |||
| query_obj = query_obj.with_additional(["vector"]) | |||
| properties = ["text"] | |||
| result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do() | |||
| @@ -0,0 +1,45 @@ | |||
| from collections.abc import Sequence | |||
| from typing import Literal, Optional | |||
| from pydantic import BaseModel, Field | |||
| SupportedComparisonOperator = Literal[ | |||
| # for string or array | |||
| "contains", | |||
| "not contains", | |||
| "start with", | |||
| "end with", | |||
| "is", | |||
| "is not", | |||
| "empty", | |||
| "not empty", | |||
| # for number | |||
| "=", | |||
| "≠", | |||
| ">", | |||
| "<", | |||
| "≥", | |||
| "≤", | |||
| # for time | |||
| "before", | |||
| "after", | |||
| ] | |||
| class Condition(BaseModel): | |||
| """ | |||
Condition detail
| """ | |||
| name: str | |||
| comparison_operator: SupportedComparisonOperator | |||
value: str | Sequence[str] | int | float | None = None
| class MetadataCondition(BaseModel): | |||
| """ | |||
| Metadata Condition. | |||
| """ | |||
| logical_operator: Optional[Literal["and", "or"]] = "and" | |||
| conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) | |||
| @@ -0,0 +1,15 @@ | |||
| from enum import Enum | |||
| class BuiltInField(str, Enum): | |||
| document_name = "document_name" | |||
| uploader = "uploader" | |||
| upload_date = "upload_date" | |||
| last_update_date = "last_update_date" | |||
| source = "source" | |||
| class MetadataDataSource(Enum): | |||
| upload_file = "file_upload" | |||
| website_crawl = "website" | |||
| notion_import = "notion" | |||
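A small illustration of the two enums (illustrative only):

# Keys of the built-in metadata fields
[field.value for field in BuiltInField]
# -> ['document_name', 'uploader', 'upload_date', 'last_update_date', 'source']

# Map a document's data-source type to its short label
MetadataDataSource["upload_file"].value  # -> 'file_upload'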
| @@ -1,35 +1,61 @@ | |||
| import json | |||
| import math | |||
| import re | |||
| import threading | |||
| from collections import Counter | |||
| from typing import Any, Optional, cast | |||
| from collections import Counter, defaultdict | |||
| from collections.abc import Generator, Mapping | |||
| from typing import Any, Optional, Union, cast | |||
| from flask import Flask, current_app | |||
| from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity | |||
| from sqlalchemy import Integer, and_, or_, text | |||
| from sqlalchemy import cast as sqlalchemy_cast | |||
| from core.app.app_config.entities import ( | |||
| DatasetEntity, | |||
| DatasetRetrieveConfigEntity, | |||
| MetadataFilteringCondition, | |||
| ModelConfig, | |||
| ) | |||
| from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.entities.agent_entities import PlanningStrategy | |||
| from core.entities.model_entities import ModelStatus | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_manager import ModelInstance, ModelManager | |||
| from core.model_runtime.entities.message_entities import PromptMessageTool | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage | |||
| from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool | |||
| from core.model_runtime.entities.model_entities import ModelFeature, ModelType | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.ops.entities.trace_entity import TraceTaskName | |||
| from core.ops.ops_trace_manager import TraceQueueManager, TraceTask | |||
| from core.ops.utils import measure_time | |||
| from core.prompt.advanced_prompt_transform import AdvancedPromptTransform | |||
| from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate | |||
| from core.prompt.simple_prompt_transform import ModelMode | |||
| from core.rag.data_post_processor.data_post_processor import DataPostProcessor | |||
| from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.entities.context_entities import DocumentContext | |||
| from core.rag.entities.metadata_entities import Condition, MetadataCondition | |||
| from core.rag.index_processor.constant.index_type import IndexType | |||
| from core.rag.models.document import Document | |||
| from core.rag.rerank.rerank_type import RerankMode | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter | |||
| from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter | |||
| from core.rag.retrieval.template_prompts import ( | |||
| METADATA_FILTER_ASSISTANT_PROMPT_1, | |||
| METADATA_FILTER_ASSISTANT_PROMPT_2, | |||
| METADATA_FILTER_COMPLETION_PROMPT, | |||
| METADATA_FILTER_SYSTEM_PROMPT, | |||
| METADATA_FILTER_USER_PROMPT_1, | |||
| METADATA_FILTER_USER_PROMPT_2, | |||
| METADATA_FILTER_USER_PROMPT_3, | |||
| ) | |||
| from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool | |||
| from extensions.ext_database import db | |||
| from models.dataset import ChildChunk, Dataset, DatasetQuery, DocumentSegment | |||
| from libs.json_in_md_parser import parse_and_check_json_markdown | |||
| from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment | |||
| from models.dataset import Document as DatasetDocument | |||
| from services.external_knowledge_service import ExternalDatasetService | |||
| @@ -59,6 +85,7 @@ class DatasetRetrieval: | |||
| hit_callback: DatasetIndexToolCallbackHandler, | |||
| message_id: str, | |||
| memory: Optional[TokenBufferMemory] = None, | |||
| inputs: Optional[Mapping[str, Any]] = None, | |||
| ) -> Optional[str]: | |||
| """ | |||
| Retrieve dataset. | |||
| @@ -116,6 +143,22 @@ class DatasetRetrieval: | |||
| continue | |||
| available_datasets.append(dataset) | |||
| if inputs: | |||
| inputs = {key: str(value) for key, value in inputs.items()} | |||
| else: | |||
| inputs = {} | |||
| available_datasets_ids = [dataset.id for dataset in available_datasets] | |||
| metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition( | |||
| available_datasets_ids, | |||
| query, | |||
| tenant_id, | |||
| user_id, | |||
| retrieve_config.metadata_filtering_mode, # type: ignore | |||
| retrieve_config.metadata_model_config, # type: ignore | |||
| retrieve_config.metadata_filtering_conditions, | |||
| inputs, | |||
| ) | |||
| all_documents = [] | |||
| user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" | |||
| if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: | |||
| @@ -130,6 +173,8 @@ class DatasetRetrieval: | |||
| model_config, | |||
| planning_strategy, | |||
| message_id, | |||
| metadata_filter_document_ids, | |||
| metadata_condition, | |||
| ) | |||
| elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: | |||
| all_documents = self.multiple_retrieve( | |||
| @@ -146,6 +191,8 @@ class DatasetRetrieval: | |||
| retrieve_config.weights, | |||
| retrieve_config.reranking_enabled or True, | |||
| message_id, | |||
| metadata_filter_document_ids, | |||
| metadata_condition, | |||
| ) | |||
| dify_documents = [item for item in all_documents if item.provider == "dify"] | |||
| @@ -239,6 +286,8 @@ class DatasetRetrieval: | |||
| model_config: ModelConfigWithCredentialsEntity, | |||
| planning_strategy: PlanningStrategy, | |||
| message_id: Optional[str] = None, | |||
| metadata_filter_document_ids: Optional[dict[str, list[str]]] = None, | |||
| metadata_condition: Optional[MetadataCondition] = None, | |||
| ): | |||
| tools = [] | |||
| for dataset in available_datasets: | |||
| @@ -279,6 +328,7 @@ class DatasetRetrieval: | |||
| dataset_id=dataset_id, | |||
| query=query, | |||
| external_retrieval_parameters=dataset.retrieval_model, | |||
| metadata_condition=metadata_condition, | |||
| ) | |||
| for external_document in external_documents: | |||
| document = Document( | |||
| @@ -293,6 +343,15 @@ class DatasetRetrieval: | |||
| document.metadata["dataset_name"] = dataset.name | |||
| results.append(document) | |||
| else: | |||
| if metadata_condition and not metadata_filter_document_ids: | |||
| return [] | |||
| document_ids_filter = None | |||
| if metadata_filter_document_ids: | |||
| document_ids = metadata_filter_document_ids.get(dataset.id, []) | |||
| if document_ids: | |||
| document_ids_filter = document_ids | |||
| else: | |||
| return [] | |||
| retrieval_model_config = dataset.retrieval_model or default_retrieval_model | |||
| # get top k | |||
| @@ -324,6 +383,7 @@ class DatasetRetrieval: | |||
| reranking_model=reranking_model, | |||
| reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"), | |||
| weights=retrieval_model_config.get("weights", None), | |||
| document_ids_filter=document_ids_filter, | |||
| ) | |||
| self._on_query(query, [dataset_id], app_id, user_from, user_id) | |||
| @@ -348,6 +408,8 @@ class DatasetRetrieval: | |||
| weights: Optional[dict[str, Any]] = None, | |||
| reranking_enable: bool = True, | |||
| message_id: Optional[str] = None, | |||
| metadata_filter_document_ids: Optional[dict[str, list[str]]] = None, | |||
| metadata_condition: Optional[MetadataCondition] = None, | |||
| ): | |||
| if not available_datasets: | |||
| return [] | |||
| @@ -387,6 +449,16 @@ class DatasetRetrieval: | |||
| for dataset in available_datasets: | |||
| index_type = dataset.indexing_technique | |||
| document_ids_filter = None | |||
| if dataset.provider != "external": | |||
| if metadata_condition and not metadata_filter_document_ids: | |||
| continue | |||
| if metadata_filter_document_ids: | |||
| document_ids = metadata_filter_document_ids.get(dataset.id, []) | |||
| if document_ids: | |||
| document_ids_filter = document_ids | |||
| else: | |||
| continue | |||
| retrieval_thread = threading.Thread( | |||
| target=self._retriever, | |||
| kwargs={ | |||
| @@ -395,6 +467,8 @@ class DatasetRetrieval: | |||
| "query": query, | |||
| "top_k": top_k, | |||
| "all_documents": all_documents, | |||
| "document_ids_filter": document_ids_filter, | |||
| "metadata_condition": metadata_condition, | |||
| }, | |||
| ) | |||
| threads.append(retrieval_thread) | |||
| @@ -493,7 +567,16 @@ class DatasetRetrieval: | |||
| db.session.add_all(dataset_queries) | |||
| db.session.commit() | |||
| def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): | |||
| def _retriever( | |||
| self, | |||
| flask_app: Flask, | |||
| dataset_id: str, | |||
| query: str, | |||
| top_k: int, | |||
| all_documents: list, | |||
| document_ids_filter: Optional[list[str]] = None, | |||
| metadata_condition: Optional[MetadataCondition] = None, | |||
| ): | |||
| with flask_app.app_context(): | |||
| dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | |||
| @@ -506,6 +589,7 @@ class DatasetRetrieval: | |||
| dataset_id=dataset_id, | |||
| query=query, | |||
| external_retrieval_parameters=dataset.retrieval_model, | |||
| metadata_condition=metadata_condition, | |||
| ) | |||
| for external_document in external_documents: | |||
| document = Document( | |||
| @@ -546,6 +630,7 @@ class DatasetRetrieval: | |||
| else None, | |||
| reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", | |||
| weights=retrieval_model.get("weights", None), | |||
| document_ids_filter=document_ids_filter, | |||
| ) | |||
| all_documents.extend(documents) | |||
| @@ -733,3 +818,340 @@ class DatasetRetrieval: | |||
| filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True | |||
| ) | |||
| return filter_documents[:top_k] if top_k else filter_documents | |||
| def _get_metadata_filter_condition( | |||
| self, | |||
| dataset_ids: list, | |||
| query: str, | |||
| tenant_id: str, | |||
| user_id: str, | |||
| metadata_filtering_mode: str, | |||
| metadata_model_config: ModelConfig, | |||
| metadata_filtering_conditions: Optional[MetadataFilteringCondition], | |||
| inputs: dict, | |||
| ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: | |||
| document_query = db.session.query(DatasetDocument).filter( | |||
| DatasetDocument.dataset_id.in_(dataset_ids), | |||
| DatasetDocument.indexing_status == "completed", | |||
| DatasetDocument.enabled == True, | |||
| DatasetDocument.archived == False, | |||
| ) | |||
| filters = [] # type: ignore | |||
| metadata_condition = None | |||
| if metadata_filtering_mode == "disabled": | |||
| return None, None | |||
| elif metadata_filtering_mode == "automatic": | |||
| automatic_metadata_filters = self._automatic_metadata_filter_func( | |||
| dataset_ids, query, tenant_id, user_id, metadata_model_config | |||
| ) | |||
| if automatic_metadata_filters: | |||
| conditions = [] | |||
| for filter in automatic_metadata_filters: | |||
| self._process_metadata_filter_func( | |||
| filter.get("condition"), # type: ignore | |||
| filter.get("metadata_name"), # type: ignore | |||
| filter.get("value"), | |||
| filters, # type: ignore | |||
| ) | |||
| conditions.append( | |||
| Condition( | |||
| name=filter.get("metadata_name"), # type: ignore | |||
| comparison_operator=filter.get("condition"), # type: ignore | |||
| value=filter.get("value"), | |||
| ) | |||
| ) | |||
| metadata_condition = MetadataCondition( | |||
logical_operator=metadata_filtering_conditions.logical_operator if metadata_filtering_conditions else "and",  # may be None in automatic mode
| conditions=conditions, | |||
| ) | |||
| elif metadata_filtering_mode == "manual": | |||
| if metadata_filtering_conditions: | |||
| metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump()) | |||
| for condition in metadata_filtering_conditions.conditions: # type: ignore | |||
| metadata_name = condition.name | |||
| expected_value = condition.value | |||
| if expected_value or condition.comparison_operator in ("empty", "not empty"): | |||
| if isinstance(expected_value, str): | |||
| expected_value = self._replace_metadata_filter_value(expected_value, inputs) | |||
| filters = self._process_metadata_filter_func( | |||
| condition.comparison_operator, metadata_name, expected_value, filters | |||
| ) | |||
| else: | |||
| raise ValueError("Invalid metadata filtering mode") | |||
| if filters: | |||
if metadata_filtering_conditions and metadata_filtering_conditions.logical_operator == "or":
| document_query = document_query.filter(or_(*filters)) | |||
| else: | |||
| document_query = document_query.filter(and_(*filters)) | |||
| documents = document_query.all() | |||
| # group by dataset_id | |||
| metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore | |||
| for document in documents: | |||
| metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore | |||
| return metadata_filter_document_ids, metadata_condition | |||
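# Illustrative return shape: document ids that pass the filters, grouped per dataset, plus the
# condition that produced them, e.g.
#   ({"<dataset_id>": ["<doc_id_1>", "<doc_id_2>"]},
#    MetadataCondition(logical_operator="and",
#                      conditions=[Condition(name="category", comparison_operator="is", value="news")]))
# (None, None) means filtering is disabled; a None mapping with a non-None condition means the
# filters matched no documents, which callers treat as "return no results".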
| def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str: | |||
| def replacer(match): | |||
| key = match.group(1) | |||
| return str(inputs.get(key, f"{{{{{key}}}}}")) | |||
| pattern = re.compile(r"\{\{(\w+)\}\}") | |||
| return pattern.sub(replacer, text) | |||
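# A quick illustration of the substitution above (values are made up):
#   inputs = {"topic": "finance"}
#   _replace_metadata_filter_value("about {{topic}}", inputs)    -> "about finance"
#   _replace_metadata_filter_value("{{missing}} stays", inputs)  -> "{{missing}} stays"
# Unknown placeholders are left as-is because the replacer falls back to the original "{{key}}".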
| def _automatic_metadata_filter_func( | |||
| self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig | |||
| ) -> Optional[list[dict[str, Any]]]: | |||
| # get all metadata field | |||
| metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all() | |||
| all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] | |||
| # get metadata model config | |||
| if metadata_model_config is None: | |||
| raise ValueError("metadata_model_config is required") | |||
| # get metadata model instance | |||
| # fetch model config | |||
| model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config) | |||
| # fetch prompt messages | |||
| prompt_messages, stop = self._get_prompt_template( | |||
| model_config=model_config, | |||
| mode=metadata_model_config.mode, | |||
| metadata_fields=all_metadata_fields, | |||
| query=query or "", | |||
| ) | |||
| result_text = "" | |||
| try: | |||
| # handle invoke result | |||
| invoke_result = cast( | |||
| Generator[LLMResult, None, None], | |||
| model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| model_parameters=model_config.parameters, | |||
| stop=stop, | |||
| stream=True, | |||
| user=user_id, | |||
| ), | |||
| ) | |||
| # handle invoke result | |||
| result_text, usage = self._handle_invoke_result(invoke_result=invoke_result) | |||
| result_text_json = parse_and_check_json_markdown(result_text, []) | |||
| automatic_metadata_filters = [] | |||
| if "metadata_map" in result_text_json: | |||
| metadata_map = result_text_json["metadata_map"] | |||
| for item in metadata_map: | |||
| if item.get("metadata_field_name") in all_metadata_fields: | |||
| automatic_metadata_filters.append( | |||
| { | |||
| "metadata_name": item.get("metadata_field_name"), | |||
| "value": item.get("metadata_field_value"), | |||
| "condition": item.get("comparison_operator"), | |||
| } | |||
| ) | |||
| except Exception: | |||
| # fall back to applying no automatic metadata filters if extraction fails | |||
| return None | |||
| return automatic_metadata_filters | |||
| def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: Optional[Any], filters: list): | |||
| match condition: | |||
| case "contains": | |||
| filters.append( | |||
| (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}%") | |||
| ) | |||
| case "not contains": | |||
| filters.append( | |||
| (text("documents.doc_metadata ->> :key NOT LIKE :value")).params( | |||
| key=metadata_name, value=f"%{value}%" | |||
| ) | |||
| ) | |||
| case "start with": | |||
| filters.append( | |||
| (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"{value}%") | |||
| ) | |||
| case "end with": | |||
| filters.append( | |||
| (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}") | |||
| ) | |||
| case "is" | "=": | |||
| if isinstance(value, str): | |||
| filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"') | |||
| else: | |||
| filters.append( | |||
| sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) == value | |||
| ) | |||
| case "is not" | "≠": | |||
| if isinstance(value, str): | |||
| filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"') | |||
| else: | |||
| filters.append( | |||
| sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) != value | |||
| ) | |||
| case "empty": | |||
| filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None)) | |||
| case "not empty": | |||
| filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None)) | |||
| case "before" | "<": | |||
| filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) < value) | |||
| case "after" | ">": | |||
| filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) > value) | |||
| case "≤" | "<=": | |||
| filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) <= value) | |||
| case "≥" | ">=": | |||
| filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) >= value) | |||
| case _: | |||
| pass | |||
| return filters | |||
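As a reading aid, here is a small sketch of the filter expressions the branches above accumulate for a query such as "movies rated above 9 whose title contains star" (field names and values are illustrative; the model alias is assumed to match this module's imports):

```python
from sqlalchemy import Integer, text
from sqlalchemy import cast as sqlalchemy_cast

from models.dataset import Document as DatasetDocument  # alias assumed from this module's imports

filters = [
    # "contains" on a string field -> documents.doc_metadata ->> 'title' LIKE '%star%'
    text("documents.doc_metadata ->> :key LIKE :value").params(key="title", value="%star%"),
    # ">" on a numeric field -> CAST(documents.doc_metadata ->> 'rating' AS INTEGER) > 9
    sqlalchemy_cast(DatasetDocument.doc_metadata["rating"].astext, Integer) > 9,
]
# _get_metadata_filter_condition then applies them with and_(*filters) or or_(*filters)
```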
| def _fetch_model_config( | |||
| self, tenant_id: str, model: ModelConfig | |||
| ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: | |||
| """ | |||
| Fetch model config | |||
| :param tenant_id: tenant id | |||
| :param model: model config | |||
| :return: | |||
| """ | |||
| if model is None: | |||
| raise ValueError("model is required") | |||
| model_name = model.name | |||
| provider_name = model.provider | |||
| model_manager = ModelManager() | |||
| model_instance = model_manager.get_model_instance( | |||
| tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name | |||
| ) | |||
| provider_model_bundle = model_instance.provider_model_bundle | |||
| model_type_instance = model_instance.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| model_credentials = model_instance.credentials | |||
| # check model | |||
| provider_model = provider_model_bundle.configuration.get_provider_model( | |||
| model=model_name, model_type=ModelType.LLM | |||
| ) | |||
| if provider_model is None: | |||
| raise ValueError(f"Model {model_name} does not exist.") | |||
| if provider_model.status == ModelStatus.NO_CONFIGURE: | |||
| raise ValueError(f"Model {model_name} credentials are not initialized.") | |||
| elif provider_model.status == ModelStatus.NO_PERMISSION: | |||
| raise ValueError(f"Dify Hosted OpenAI {model_name} is currently not supported.") | |||
| elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: | |||
| raise ValueError(f"Model provider {provider_name} quota exceeded.") | |||
| # model config | |||
| completion_params = model.completion_params | |||
| stop = [] | |||
| if "stop" in completion_params: | |||
| stop = completion_params["stop"] | |||
| del completion_params["stop"] | |||
| # get model mode | |||
| model_mode = model.mode | |||
| if not model_mode: | |||
| raise ValueError("LLM mode is required.") | |||
| model_schema = model_type_instance.get_model_schema(model_name, model_credentials) | |||
| if not model_schema: | |||
| raise ValueError(f"Model {model_name} does not exist.") | |||
| return model_instance, ModelConfigWithCredentialsEntity( | |||
| provider=provider_name, | |||
| model=model_name, | |||
| model_schema=model_schema, | |||
| mode=model_mode, | |||
| provider_model_bundle=provider_model_bundle, | |||
| credentials=model_credentials, | |||
| parameters=completion_params, | |||
| stop=stop, | |||
| ) | |||
| def _get_prompt_template( | |||
| self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str | |||
| ): | |||
| model_mode = ModelMode.value_of(mode) | |||
| input_text = query | |||
| prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]] | |||
| if model_mode == ModelMode.CHAT: | |||
| prompt_template = [] | |||
| system_prompt_messages = ChatModelMessage(role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT) | |||
| prompt_template.append(system_prompt_messages) | |||
| user_prompt_message_1 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1) | |||
| prompt_template.append(user_prompt_message_1) | |||
| assistant_prompt_message_1 = ChatModelMessage( | |||
| role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1 | |||
| ) | |||
| prompt_template.append(assistant_prompt_message_1) | |||
| user_prompt_message_2 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2) | |||
| prompt_template.append(user_prompt_message_2) | |||
| assistant_prompt_message_2 = ChatModelMessage( | |||
| role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2 | |||
| ) | |||
| prompt_template.append(assistant_prompt_message_2) | |||
| user_prompt_message_3 = ChatModelMessage( | |||
| role=PromptMessageRole.USER, | |||
| text=METADATA_FILTER_USER_PROMPT_3.format( | |||
| input_text=input_text, | |||
| metadata_fields=json.dumps(metadata_fields, ensure_ascii=False), | |||
| ), | |||
| ) | |||
| prompt_template.append(user_prompt_message_3) | |||
| elif model_mode == ModelMode.COMPLETION: | |||
| prompt_template = CompletionModelPromptTemplate( | |||
| text=METADATA_FILTER_COMPLETION_PROMPT.format( | |||
| input_text=input_text, | |||
| metadata_fields=json.dumps(metadata_fields, ensure_ascii=False), | |||
| ) | |||
| ) | |||
| else: | |||
| raise ValueError(f"Model mode {model_mode} is not supported.") | |||
| prompt_transform = AdvancedPromptTransform() | |||
| prompt_messages = prompt_transform.get_prompt( | |||
| prompt_template=prompt_template, | |||
| inputs={}, | |||
| query=query or "", | |||
| files=[], | |||
| context=None, | |||
| memory_config=None, | |||
| memory=None, | |||
| model_config=model_config, | |||
| ) | |||
| stop = model_config.stop | |||
| return prompt_messages, stop | |||
| def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]: | |||
| """ | |||
| Handle invoke result | |||
| :param invoke_result: invoke result | |||
| :return: | |||
| """ | |||
| model = None | |||
| prompt_messages: list[PromptMessage] = [] | |||
| full_text = "" | |||
| usage = None | |||
| for result in invoke_result: | |||
| text = result.delta.message.content | |||
| full_text += text | |||
| if not model: | |||
| model = result.model | |||
| if not prompt_messages: | |||
| prompt_messages = result.prompt_messages | |||
| if not usage and result.delta.usage: | |||
| usage = result.delta.usage | |||
| if not usage: | |||
| usage = LLMUsage.empty_usage() | |||
| return full_text, usage | |||
| @@ -0,0 +1,66 @@ | |||
| METADATA_FILTER_SYSTEM_PROMPT = """ | |||
| ### Job Description | |||
| You are a text metadata extraction engine that extracts a text's metadata based on user input and sets the metadata value. | |||
| ### Task | |||
| Your task is to ONLY extract the metadata fields that exist in the input text from the provided metadata list and use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return the result in JSON format with the key "metadata_fields", the value "metadata_field_value" and the comparison operator "comparison_operator". | |||
| ### Format | |||
| The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields. | |||
| ### Constraint | |||
| DO NOT include anything other than the JSON array in your response. | |||
| """ # noqa: E501 | |||
| METADATA_FILTER_USER_PROMPT_1 = """ | |||
| { "input_text": "I want to know which company’s email address test@example.com is?", | |||
| "metadata_fields": ["filename", "email", "phone", "address"] | |||
| } | |||
| """ | |||
| METADATA_FILTER_ASSISTANT_PROMPT_1 = """ | |||
| ```json | |||
| {"metadata_map": [ | |||
| {"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="} | |||
| ] | |||
| } | |||
| ``` | |||
| """ | |||
| METADATA_FILTER_USER_PROMPT_2 = """ | |||
| {"input_text": "What are the movies with a score of more than 9 in 2024?", | |||
| "metadata_fields": ["name", "year", "rating", "country"]} | |||
| """ | |||
| METADATA_FILTER_ASSISTANT_PROMPT_2 = """ | |||
| ```json | |||
| {"metadata_map": [ | |||
| {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, | |||
| {"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"} | |||
| ]} | |||
| ``` | |||
| """ | |||
| METADATA_FILTER_USER_PROMPT_3 = """ | |||
| {{"input_text": "{input_text}", | |||
| "metadata_fields": {metadata_fields}}} | |||
| """ | |||
| METADATA_FILTER_COMPLETION_PROMPT = """ | |||
| ### Job Description | |||
| You are a text metadata extraction engine that extracts a text's metadata based on user input and sets the metadata value. | |||
| ### Task | |||
| Your task is to ONLY extract the metadata fields that exist in the input text from the provided metadata list and use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return the result in JSON format with the key "metadata_fields", the value "metadata_field_value" and the comparison operator "comparison_operator". | |||
| ### Format | |||
| The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields. | |||
| ### Constraint | |||
| DO NOT include anything other than the JSON array in your response. | |||
| ### Example | |||
| Here is the chat example between human and assistant, inside <example></example> XML tags. | |||
| <example> | |||
| User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}} | |||
| Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}} | |||
| User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}} | |||
| Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}} | |||
| </example> | |||
| ### User Input | |||
| {{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}} | |||
| ### Assistant Output | |||
| """ # noqa: E501 | |||
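For context, the few-shot prompts above steer the model to answer with a JSON object inside a fenced json block; `_automatic_metadata_filter_func` parses it with `parse_and_check_json_markdown` and keeps only fields that actually exist on the queried datasets, ending up with a list shaped like this (values illustrative):

```python
# parsed model output (the "metadata_map" object) ...
result_text_json = {
    "metadata_map": [
        {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
    ]
}
# ... reduced to the filter dicts consumed by _get_metadata_filter_condition
automatic_metadata_filters = [
    {"metadata_name": "year", "value": "2024", "condition": "="},
]
```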
| @@ -1,8 +1,10 @@ | |||
| from collections.abc import Sequence | |||
| from typing import Any, Literal, Optional | |||
| from pydantic import BaseModel | |||
| from pydantic import BaseModel, Field | |||
| from core.workflow.nodes.base import BaseNodeData | |||
| from core.workflow.nodes.llm.entities import VisionConfig | |||
| class RerankingModelConfig(BaseModel): | |||
| @@ -73,6 +75,48 @@ class SingleRetrievalConfig(BaseModel): | |||
| model: ModelConfig | |||
| SupportedComparisonOperator = Literal[ | |||
| # for string or array | |||
| "contains", | |||
| "not contains", | |||
| "start with", | |||
| "end with", | |||
| "is", | |||
| "is not", | |||
| "empty", | |||
| "not empty", | |||
| # for number | |||
| "=", | |||
| "≠", | |||
| ">", | |||
| "<", | |||
| "≥", | |||
| "≤", | |||
| # for time | |||
| "before", | |||
| "after", | |||
| ] | |||
| class Condition(BaseModel): | |||
| """ | |||
| Condition detail | |||
| """ | |||
| name: str | |||
| comparison_operator: SupportedComparisonOperator | |||
| value: str | Sequence[str] | int | float | None = None | |||
| class MetadataFilteringCondition(BaseModel): | |||
| """ | |||
| Metadata Filtering Condition. | |||
| """ | |||
| logical_operator: Optional[Literal["and", "or"]] = "and" | |||
| conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) | |||
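A minimal sketch of a manual filtering condition built from these entities (field names and values are illustrative):

```python
# match documents from 2024 whose "author" metadata contains "smith"
metadata_filtering_conditions = MetadataFilteringCondition(
    logical_operator="and",
    conditions=[
        Condition(name="year", comparison_operator="=", value=2024),
        Condition(name="author", comparison_operator="contains", value="smith"),
    ],
)
```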
| class KnowledgeRetrievalNodeData(BaseNodeData): | |||
| """ | |||
| Knowledge retrieval Node Data. | |||
| @@ -84,3 +128,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData): | |||
| retrieval_mode: Literal["single", "multiple"] | |||
| multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None | |||
| single_retrieval_config: Optional[SingleRetrievalConfig] = None | |||
| metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled" | |||
| metadata_model_config: Optional[ModelConfig] = None | |||
| metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None | |||
| vision: VisionConfig = Field(default_factory=VisionConfig) | |||
| @@ -16,3 +16,7 @@ class ModelNotSupportedError(KnowledgeRetrievalNodeError): | |||
| class ModelQuotaExceededError(KnowledgeRetrievalNodeError): | |||
| """Raised when the model provider quota is exceeded.""" | |||
| class InvalidModelTypeError(KnowledgeRetrievalNodeError): | |||
| """Raised when the model is not a Large Language Model.""" | |||
| @@ -1,32 +1,51 @@ | |||
| import json | |||
| import logging | |||
| import time | |||
| from collections import defaultdict | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any, cast | |||
| from typing import Any, Optional, cast | |||
| from sqlalchemy import func | |||
| from sqlalchemy import Integer, and_, func, or_, text | |||
| from sqlalchemy import cast as sqlalchemy_cast | |||
| from core.app.app_config.entities import DatasetRetrieveConfigEntity | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| from core.entities.agent_entities import PlanningStrategy | |||
| from core.entities.model_entities import ModelStatus | |||
| from core.model_manager import ModelInstance, ModelManager | |||
| from core.model_runtime.entities.message_entities import PromptMessageRole | |||
| from core.model_runtime.entities.model_entities import ModelFeature, ModelType | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.prompt.simple_prompt_transform import ModelMode | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.entities.metadata_entities import Condition, MetadataCondition | |||
| from core.rag.retrieval.dataset_retrieval import DatasetRetrieval | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from core.variables import StringSegment | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.event.event import ModelInvokeCompletedEvent | |||
| from core.workflow.nodes.knowledge_retrieval.template_prompts import ( | |||
| METADATA_FILTER_ASSISTANT_PROMPT_1, | |||
| METADATA_FILTER_ASSISTANT_PROMPT_2, | |||
| METADATA_FILTER_COMPLETION_PROMPT, | |||
| METADATA_FILTER_SYSTEM_PROMPT, | |||
| METADATA_FILTER_USER_PROMPT_1, | |||
| METADATA_FILTER_USER_PROMPT_2, | |||
| METADATA_FILTER_USER_PROMPT_3, | |||
| ) | |||
| from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate | |||
| from core.workflow.nodes.llm.node import LLMNode | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset, Document, RateLimitLog | |||
| from libs.json_in_md_parser import parse_and_check_json_markdown | |||
| from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from services.feature_service import FeatureService | |||
| from .entities import KnowledgeRetrievalNodeData | |||
| from .entities import KnowledgeRetrievalNodeData, ModelConfig | |||
| from .exc import ( | |||
| InvalidModelTypeError, | |||
| KnowledgeRetrievalNodeError, | |||
| ModelCredentialsNotInitializedError, | |||
| ModelNotExistError, | |||
| @@ -45,13 +64,14 @@ default_retrieval_model = { | |||
| } | |||
| class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): | |||
| _node_data_cls = KnowledgeRetrievalNodeData | |||
| class KnowledgeRetrievalNode(LLMNode): | |||
| _node_data_cls = KnowledgeRetrievalNodeData # type: ignore | |||
| _node_type = NodeType.KNOWLEDGE_RETRIEVAL | |||
| def _run(self) -> NodeRunResult: | |||
| def _run(self) -> NodeRunResult: # type: ignore | |||
| node_data = cast(KnowledgeRetrievalNodeData, self.node_data) | |||
| # extract variables | |||
| variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector) | |||
| variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector) | |||
| if not isinstance(variable, StringSegment): | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| @@ -91,7 +111,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): | |||
| # retrieve knowledge | |||
| try: | |||
| results = self._fetch_dataset_retriever(node_data=self.node_data, query=query) | |||
| results = self._fetch_dataset_retriever(node_data=node_data, query=query) | |||
| outputs = {"result": results} | |||
| return NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs | |||
| @@ -145,11 +165,14 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): | |||
| if not dataset: | |||
| continue | |||
| available_datasets.append(dataset) | |||
| metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition( | |||
| [dataset.id for dataset in available_datasets], query, node_data | |||
| ) | |||
| all_documents = [] | |||
| dataset_retrieval = DatasetRetrieval() | |||
| if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value: | |||
| # fetch model config | |||
| model_instance, model_config = self._fetch_model_config(node_data) | |||
| model_instance, model_config = self._fetch_model_config(node_data.single_retrieval_config.model) # type: ignore | |||
| # check model is support tool calling | |||
| model_type_instance = model_config.provider_model_bundle.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| @@ -174,6 +197,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): | |||
| model_config=model_config, | |||
| model_instance=model_instance, | |||
| planning_strategy=planning_strategy, | |||
| metadata_filter_document_ids=metadata_filter_document_ids, | |||
| metadata_condition=metadata_condition, | |||
| ) | |||
| elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: | |||
| if node_data.multiple_retrieval_config is None: | |||
| @@ -220,6 +245,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): | |||
| reranking_model=reranking_model, | |||
| weights=weights, | |||
| reranking_enable=node_data.multiple_retrieval_config.reranking_enable, | |||
| metadata_filter_document_ids=metadata_filter_document_ids, | |||
| metadata_condition=metadata_condition, | |||
| ) | |||
| dify_documents = [item for item in all_documents if item.provider == "dify"] | |||
| external_documents = [item for item in all_documents if item.provider == "external"] | |||
| @@ -287,13 +314,187 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): | |||
| item["metadata"]["position"] = position | |||
| return retrieval_resource_list | |||
| def _get_metadata_filter_condition( | |||
| self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData | |||
| ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: | |||
| document_query = db.session.query(Document).filter( | |||
| Document.dataset_id.in_(dataset_ids), | |||
| Document.indexing_status == "completed", | |||
| Document.enabled == True, | |||
| Document.archived == False, | |||
| ) | |||
| filters = [] # type: ignore | |||
| metadata_condition = None | |||
| if node_data.metadata_filtering_mode == "disabled": | |||
| return None, None | |||
| elif node_data.metadata_filtering_mode == "automatic": | |||
| automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data) | |||
| if automatic_metadata_filters: | |||
| conditions = [] | |||
| for filter in automatic_metadata_filters: | |||
| self._process_metadata_filter_func( | |||
| filter.get("condition", ""), | |||
| filter.get("metadata_name", ""), | |||
| filter.get("value"), | |||
| filters, # type: ignore | |||
| ) | |||
| conditions.append( | |||
| Condition( | |||
| name=filter.get("metadata_name"), # type: ignore | |||
| comparison_operator=filter.get("condition"), # type: ignore | |||
| value=filter.get("value"), | |||
| ) | |||
| ) | |||
| metadata_condition = MetadataCondition( | |||
| logical_operator=node_data.metadata_filtering_conditions.logical_operator, # type: ignore | |||
| conditions=conditions, | |||
| ) | |||
| elif node_data.metadata_filtering_mode == "manual": | |||
| if node_data.metadata_filtering_conditions: | |||
| metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump()) | |||
| if node_data.metadata_filtering_conditions: | |||
| for condition in node_data.metadata_filtering_conditions.conditions: # type: ignore | |||
| metadata_name = condition.name | |||
| expected_value = condition.value | |||
| if expected_value or condition.comparison_operator in ("empty", "not empty"): | |||
| if isinstance(expected_value, str): | |||
| expected_value = self.graph_runtime_state.variable_pool.convert_template( | |||
| expected_value | |||
| ).text | |||
| filters = self._process_metadata_filter_func( | |||
| condition.comparison_operator, metadata_name, expected_value, filters | |||
| ) | |||
| else: | |||
| raise ValueError("Invalid metadata filtering mode") | |||
| if filters: | |||
| if node_data.metadata_filtering_conditions.logical_operator == "and": # type: ignore | |||
| document_query = document_query.filter(and_(*filters)) | |||
| else: | |||
| document_query = document_query.filter(or_(*filters)) | |||
| documents = document_query.all() | |||
| # group by dataset_id | |||
| metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore | |||
| for document in documents: | |||
| metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore | |||
| return metadata_filter_document_ids, metadata_condition | |||
| def _automatic_metadata_filter_func( | |||
| self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData | |||
| ) -> list[dict[str, Any]]: | |||
| # get all metadata field | |||
| metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all() | |||
| all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] | |||
| # get metadata model config | |||
| metadata_model_config = node_data.metadata_model_config | |||
| if metadata_model_config is None: | |||
| raise ValueError("metadata_model_config is required") | |||
| # get metadata model instance | |||
| # fetch model config | |||
| model_instance, model_config = self._fetch_model_config(node_data.metadata_model_config) # type: ignore | |||
| # fetch prompt messages | |||
| prompt_template = self._get_prompt_template( | |||
| node_data=node_data, | |||
| metadata_fields=all_metadata_fields, | |||
| query=query or "", | |||
| ) | |||
| prompt_messages, stop = self._fetch_prompt_messages( | |||
| prompt_template=prompt_template, | |||
| sys_query=query, | |||
| memory=None, | |||
| model_config=model_config, | |||
| sys_files=[], | |||
| vision_enabled=node_data.vision.enabled, | |||
| vision_detail=node_data.vision.configs.detail, | |||
| variable_pool=self.graph_runtime_state.variable_pool, | |||
| jinja2_variables=[], | |||
| ) | |||
| result_text = "" | |||
| try: | |||
| # handle invoke result | |||
| generator = self._invoke_llm( | |||
| node_data_model=node_data.metadata_model_config, # type: ignore | |||
| model_instance=model_instance, | |||
| prompt_messages=prompt_messages, | |||
| stop=stop, | |||
| ) | |||
| for event in generator: | |||
| if isinstance(event, ModelInvokeCompletedEvent): | |||
| result_text = event.text | |||
| break | |||
| result_text_json = parse_and_check_json_markdown(result_text, []) | |||
| automatic_metadata_filters = [] | |||
| if "metadata_map" in result_text_json: | |||
| metadata_map = result_text_json["metadata_map"] | |||
| for item in metadata_map: | |||
| if item.get("metadata_field_name") in all_metadata_fields: | |||
| automatic_metadata_filters.append( | |||
| { | |||
| "metadata_name": item.get("metadata_field_name"), | |||
| "value": item.get("metadata_field_value"), | |||
| "condition": item.get("comparison_operator"), | |||
| } | |||
| ) | |||
| except Exception: | |||
| # fall back to applying no automatic metadata filters if extraction fails | |||
| return [] | |||
| return automatic_metadata_filters | |||
| def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: Optional[str], filters: list): | |||
| match condition: | |||
| case "contains": | |||
| filters.append( | |||
| (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}%") | |||
| ) | |||
| case "not contains": | |||
| filters.append( | |||
| (text("documents.doc_metadata ->> :key NOT LIKE :value")).params( | |||
| key=metadata_name, value=f"%{value}%" | |||
| ) | |||
| ) | |||
| case "start with": | |||
| filters.append( | |||
| (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"{value}%") | |||
| ) | |||
| case "end with": | |||
| filters.append( | |||
| (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}") | |||
| ) | |||
| case "=" | "is": | |||
| if isinstance(value, str): | |||
| filters.append(Document.doc_metadata[metadata_name] == f'"{value}"') | |||
| else: | |||
| filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) == value) | |||
| case "is not" | "≠": | |||
| if isinstance(value, str): | |||
| filters.append(Document.doc_metadata[metadata_name] != f'"{value}"') | |||
| else: | |||
| filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) != value) | |||
| case "empty": | |||
| filters.append(Document.doc_metadata[metadata_name].is_(None)) | |||
| case "not empty": | |||
| filters.append(Document.doc_metadata[metadata_name].isnot(None)) | |||
| case "before" | "<": | |||
| filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) < value) | |||
| case "after" | ">": | |||
| filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) > value) | |||
| case "≤" | "<=": | |||
| filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) <= value) | |||
| case "≥" | ">=": | |||
| filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) >= value) | |||
| case _: | |||
| pass | |||
| return filters | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| *, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: KnowledgeRetrievalNodeData, | |||
| node_data: KnowledgeRetrievalNodeData, # type: ignore | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| @@ -306,18 +507,16 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): | |||
| variable_mapping[node_id + ".query"] = node_data.query_variable_selector | |||
| return variable_mapping | |||
| def _fetch_model_config( | |||
| self, node_data: KnowledgeRetrievalNodeData | |||
| ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: | |||
| def _fetch_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: # type: ignore | |||
| """ | |||
| Fetch model config | |||
| :param node_data: node data | |||
| :param model: model | |||
| :return: | |||
| """ | |||
| if node_data.single_retrieval_config is None: | |||
| raise ValueError("single_retrieval_config is required") | |||
| model_name = node_data.single_retrieval_config.model.name | |||
| provider_name = node_data.single_retrieval_config.model.provider | |||
| if model is None: | |||
| raise ValueError("model is required") | |||
| model_name = model.name | |||
| provider_name = model.provider | |||
| model_manager = ModelManager() | |||
| model_instance = model_manager.get_model_instance( | |||
| @@ -346,14 +545,14 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): | |||
| raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.") | |||
| # model config | |||
| completion_params = node_data.single_retrieval_config.model.completion_params | |||
| completion_params = model.completion_params | |||
| stop = [] | |||
| if "stop" in completion_params: | |||
| stop = completion_params["stop"] | |||
| del completion_params["stop"] | |||
| # get model mode | |||
| model_mode = node_data.single_retrieval_config.model.mode | |||
| model_mode = model.mode | |||
| if not model_mode: | |||
| raise ModelNotExistError("LLM mode is required.") | |||
| @@ -372,3 +571,50 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): | |||
| parameters=completion_params, | |||
| stop=stop, | |||
| ) | |||
| def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str): | |||
| model_mode = ModelMode.value_of(node_data.metadata_model_config.mode) # type: ignore | |||
| input_text = query | |||
| memory_str = "" | |||
| prompt_messages: list[LLMNodeChatModelMessage] = [] | |||
| if model_mode == ModelMode.CHAT: | |||
| system_prompt_messages = LLMNodeChatModelMessage( | |||
| role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT | |||
| ) | |||
| prompt_messages.append(system_prompt_messages) | |||
| user_prompt_message_1 = LLMNodeChatModelMessage( | |||
| role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1 | |||
| ) | |||
| prompt_messages.append(user_prompt_message_1) | |||
| assistant_prompt_message_1 = LLMNodeChatModelMessage( | |||
| role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1 | |||
| ) | |||
| prompt_messages.append(assistant_prompt_message_1) | |||
| user_prompt_message_2 = LLMNodeChatModelMessage( | |||
| role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2 | |||
| ) | |||
| prompt_messages.append(user_prompt_message_2) | |||
| assistant_prompt_message_2 = LLMNodeChatModelMessage( | |||
| role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2 | |||
| ) | |||
| prompt_messages.append(assistant_prompt_message_2) | |||
| user_prompt_message_3 = LLMNodeChatModelMessage( | |||
| role=PromptMessageRole.USER, | |||
| text=METADATA_FILTER_USER_PROMPT_3.format( | |||
| input_text=input_text, | |||
| metadata_fields=json.dumps(metadata_fields, ensure_ascii=False), | |||
| ), | |||
| ) | |||
| prompt_messages.append(user_prompt_message_3) | |||
| return prompt_messages | |||
| elif model_mode == ModelMode.COMPLETION: | |||
| return LLMNodeCompletionModelPromptTemplate( | |||
| text=METADATA_FILTER_COMPLETION_PROMPT.format( | |||
| input_text=input_text, | |||
| metadata_fields=json.dumps(metadata_fields, ensure_ascii=False), | |||
| ) | |||
| ) | |||
| else: | |||
| raise InvalidModelTypeError(f"Model mode {model_mode} is not supported.") | |||
| @@ -0,0 +1,66 @@ | |||
| METADATA_FILTER_SYSTEM_PROMPT = """ | |||
| ### Job Description | |||
| You are a text metadata extraction engine that extracts a text's metadata based on user input and sets the metadata value. | |||
| ### Task | |||
| Your task is to ONLY extract the metadata fields that exist in the input text from the provided metadata list and use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return the result in JSON format with the key "metadata_fields", the value "metadata_field_value" and the comparison operator "comparison_operator". | |||
| ### Format | |||
| The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields. | |||
| ### Constraint | |||
| DO NOT include anything other than the JSON array in your response. | |||
| """ # noqa: E501 | |||
| METADATA_FILTER_USER_PROMPT_1 = """ | |||
| { "input_text": "I want to know which company’s email address test@example.com is?", | |||
| "metadata_fields": ["filename", "email", "phone", "address"] | |||
| } | |||
| """ | |||
| METADATA_FILTER_ASSISTANT_PROMPT_1 = """ | |||
| ```json | |||
| {"metadata_map": [ | |||
| {"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="} | |||
| ] | |||
| } | |||
| ``` | |||
| """ | |||
| METADATA_FILTER_USER_PROMPT_2 = """ | |||
| {"input_text": "What are the movies with a score of more than 9 in 2024?", | |||
| "metadata_fields": ["name", "year", "rating", "country"]} | |||
| """ | |||
| METADATA_FILTER_ASSISTANT_PROMPT_2 = """ | |||
| ```json | |||
| {"metadata_map": [ | |||
| {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, | |||
| {"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"} | |||
| ]} | |||
| ``` | |||
| """ | |||
| METADATA_FILTER_USER_PROMPT_3 = """ | |||
| {{"input_text": "{input_text}", | |||
| "metadata_fields": {metadata_fields}}} | |||
| """ | |||
| METADATA_FILTER_COMPLETION_PROMPT = """ | |||
| ### Job Description | |||
| You are a text metadata extraction engine that extracts a text's metadata based on user input and sets the metadata value. | |||
| ### Task | |||
| Your task is to ONLY extract the metadata fields that exist in the input text from the provided metadata list and use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return the result in JSON format with the key "metadata_fields", the value "metadata_field_value" and the comparison operator "comparison_operator". | |||
| ### Format | |||
| The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields. | |||
| ### Constraint | |||
| DO NOT include anything other than the JSON array in your response. | |||
| ### Example | |||
| Here is the chat example between human and assistant, inside <example></example> XML tags. | |||
| <example> | |||
| User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}} | |||
| Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}} | |||
| User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}} | |||
| Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}} | |||
| </example> | |||
| ### User Input | |||
| {{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}} | |||
| ### Assistant Output | |||
| """ # noqa: E501 | |||
| @@ -53,6 +53,8 @@ external_knowledge_info_fields = { | |||
| "external_knowledge_api_endpoint": fields.String, | |||
| } | |||
| doc_metadata_fields = {"id": fields.String, "name": fields.String, "type": fields.String} | |||
| dataset_detail_fields = { | |||
| "id": fields.String, | |||
| "name": fields.String, | |||
| @@ -76,6 +78,8 @@ dataset_detail_fields = { | |||
| "doc_form": fields.String, | |||
| "external_knowledge_info": fields.Nested(external_knowledge_info_fields), | |||
| "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True), | |||
| "doc_metadata": fields.List(fields.Nested(doc_metadata_fields)), | |||
| "built_in_field_enabled": fields.Boolean, | |||
| } | |||
| dataset_query_detail_fields = { | |||
| @@ -87,3 +91,9 @@ dataset_query_detail_fields = { | |||
| "created_by": fields.String, | |||
| "created_at": TimestampField, | |||
| } | |||
| dataset_metadata_fields = { | |||
| "id": fields.String, | |||
| "type": fields.String, | |||
| "name": fields.String, | |||
| } | |||
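With `@marshal_with(dataset_metadata_fields)`, a created metadata field marshals into a small object along these lines (values illustrative):

```python
# example marshaled response for a newly created metadata field
example_metadata_response = {"id": "0b1e0c5e-...-uuid", "type": "string", "name": "author"}
```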
| @@ -3,6 +3,13 @@ from flask_restful import fields # type: ignore | |||
| from fields.dataset_fields import dataset_fields | |||
| from libs.helper import TimestampField | |||
| document_metadata_fields = { | |||
| "id": fields.String, | |||
| "name": fields.String, | |||
| "type": fields.String, | |||
| "value": fields.String, | |||
| } | |||
| document_fields = { | |||
| "id": fields.String, | |||
| "position": fields.Integer, | |||
| @@ -25,6 +32,7 @@ document_fields = { | |||
| "word_count": fields.Integer, | |||
| "hit_count": fields.Integer, | |||
| "doc_form": fields.String, | |||
| "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"), | |||
| } | |||
| document_with_segments_fields = { | |||
| @@ -51,6 +59,7 @@ document_with_segments_fields = { | |||
| "hit_count": fields.Integer, | |||
| "completed_segments": fields.Integer, | |||
| "total_segments": fields.Integer, | |||
| "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"), | |||
| } | |||
| dataset_and_document_fields = { | |||
| @@ -0,0 +1,90 @@ | |||
| """add_metadata_function | |||
| Revision ID: d20049ed0af6 | |||
| Revises: 08ec4f75af5e | |||
| Create Date: 2025-02-27 09:17:48.903213 | |||
| """ | |||
| from alembic import op | |||
| import models as models | |||
| import sqlalchemy as sa | |||
| from sqlalchemy.dialects import postgresql | |||
| # revision identifiers, used by Alembic. | |||
| revision = 'd20049ed0af6' | |||
| down_revision = 'f051706725cc' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| op.create_table('dataset_metadata_bindings', | |||
| sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||
| sa.Column('tenant_id', models.types.StringUUID(), nullable=False), | |||
| sa.Column('dataset_id', models.types.StringUUID(), nullable=False), | |||
| sa.Column('metadata_id', models.types.StringUUID(), nullable=False), | |||
| sa.Column('document_id', models.types.StringUUID(), nullable=False), | |||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), | |||
| sa.Column('created_by', models.types.StringUUID(), nullable=False), | |||
| sa.PrimaryKeyConstraint('id', name='dataset_metadata_binding_pkey') | |||
| ) | |||
| with op.batch_alter_table('dataset_metadata_bindings', schema=None) as batch_op: | |||
| batch_op.create_index('dataset_metadata_binding_dataset_idx', ['dataset_id'], unique=False) | |||
| batch_op.create_index('dataset_metadata_binding_document_idx', ['document_id'], unique=False) | |||
| batch_op.create_index('dataset_metadata_binding_metadata_idx', ['metadata_id'], unique=False) | |||
| batch_op.create_index('dataset_metadata_binding_tenant_idx', ['tenant_id'], unique=False) | |||
| op.create_table('dataset_metadatas', | |||
| sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||
| sa.Column('tenant_id', models.types.StringUUID(), nullable=False), | |||
| sa.Column('dataset_id', models.types.StringUUID(), nullable=False), | |||
| sa.Column('type', sa.String(length=255), nullable=False), | |||
| sa.Column('name', sa.String(length=255), nullable=False), | |||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), | |||
| sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), | |||
| sa.Column('created_by', models.types.StringUUID(), nullable=False), | |||
| sa.Column('updated_by', models.types.StringUUID(), nullable=True), | |||
| sa.PrimaryKeyConstraint('id', name='dataset_metadata_pkey') | |||
| ) | |||
| with op.batch_alter_table('dataset_metadatas', schema=None) as batch_op: | |||
| batch_op.create_index('dataset_metadata_dataset_idx', ['dataset_id'], unique=False) | |||
| batch_op.create_index('dataset_metadata_tenant_idx', ['tenant_id'], unique=False) | |||
| with op.batch_alter_table('datasets', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('built_in_field_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False)) | |||
| with op.batch_alter_table('documents', schema=None) as batch_op: | |||
| batch_op.alter_column('doc_metadata', | |||
| existing_type=postgresql.JSON(astext_type=sa.Text()), | |||
| type_=postgresql.JSONB(astext_type=sa.Text()), | |||
| existing_nullable=True) | |||
| batch_op.create_index('document_metadata_idx', ['doc_metadata'], unique=False, postgresql_using='gin') | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('documents', schema=None) as batch_op: | |||
| batch_op.drop_index('document_metadata_idx', postgresql_using='gin') | |||
| batch_op.alter_column('doc_metadata', | |||
| existing_type=postgresql.JSONB(astext_type=sa.Text()), | |||
| type_=postgresql.JSON(astext_type=sa.Text()), | |||
| existing_nullable=True) | |||
| with op.batch_alter_table('datasets', schema=None) as batch_op: | |||
| batch_op.drop_column('built_in_field_enabled') | |||
| with op.batch_alter_table('dataset_metadatas', schema=None) as batch_op: | |||
| batch_op.drop_index('dataset_metadata_tenant_idx') | |||
| batch_op.drop_index('dataset_metadata_dataset_idx') | |||
| op.drop_table('dataset_metadatas') | |||
| with op.batch_alter_table('dataset_metadata_bindings', schema=None) as batch_op: | |||
| batch_op.drop_index('dataset_metadata_binding_tenant_idx') | |||
| batch_op.drop_index('dataset_metadata_binding_metadata_idx') | |||
| batch_op.drop_index('dataset_metadata_binding_document_idx') | |||
| batch_op.drop_index('dataset_metadata_binding_dataset_idx') | |||
| op.drop_table('dataset_metadata_bindings') | |||
| # ### end Alembic commands ### | |||
| @@ -16,6 +16,7 @@ from sqlalchemy.dialects.postgresql import JSONB | |||
| from sqlalchemy.orm import Mapped | |||
| from configs import dify_config | |||
| from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from extensions.ext_storage import storage | |||
| from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule | |||
| @@ -60,6 +61,7 @@ class Dataset(db.Model): # type: ignore[name-defined] | |||
| embedding_model_provider = db.Column(db.String(255), nullable=True) | |||
| collection_binding_id = db.Column(StringUUID, nullable=True) | |||
| retrieval_model = db.Column(JSONB, nullable=True) | |||
| built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) | |||
| @property | |||
| def dataset_keyword_table(self): | |||
| @@ -197,6 +199,56 @@ class Dataset(db.Model): # type: ignore[name-defined] | |||
| "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""), | |||
| } | |||
| @property | |||
| def doc_metadata(self): | |||
| dataset_metadatas = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == self.id).all() | |||
| doc_metadata = [ | |||
| { | |||
| "id": dataset_metadata.id, | |||
| "name": dataset_metadata.name, | |||
| "type": dataset_metadata.type, | |||
| } | |||
| for dataset_metadata in dataset_metadatas | |||
| ] | |||
| if self.built_in_field_enabled: | |||
| doc_metadata.append( | |||
| { | |||
| "id": "built-in", | |||
| "name": BuiltInField.document_name.value, | |||
| "type": "string", | |||
| } | |||
| ) | |||
| doc_metadata.append( | |||
| { | |||
| "id": "built-in", | |||
| "name": BuiltInField.uploader.value, | |||
| "type": "string", | |||
| } | |||
| ) | |||
| doc_metadata.append( | |||
| { | |||
| "id": "built-in", | |||
| "name": BuiltInField.upload_date.value, | |||
| "type": "time", | |||
| } | |||
| ) | |||
| doc_metadata.append( | |||
| { | |||
| "id": "built-in", | |||
| "name": BuiltInField.last_update_date.value, | |||
| "type": "time", | |||
| } | |||
| ) | |||
| doc_metadata.append( | |||
| { | |||
| "id": "built-in", | |||
| "name": BuiltInField.source.value, | |||
| "type": "string", | |||
| } | |||
| ) | |||
| return doc_metadata | |||
| @staticmethod | |||
| def gen_collection_name_by_id(dataset_id: str) -> str: | |||
| normalized_dataset_id = dataset_id.replace("-", "_") | |||
| @@ -250,6 +302,7 @@ class Document(db.Model): # type: ignore[name-defined] | |||
| db.Index("document_dataset_id_idx", "dataset_id"), | |||
| db.Index("document_is_paused_idx", "is_paused"), | |||
| db.Index("document_tenant_idx", "tenant_id"), | |||
| db.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"), | |||
| ) | |||
| # initial fields | |||
| @@ -306,7 +359,7 @@ class Document(db.Model): # type: ignore[name-defined] | |||
| archived_at = db.Column(db.DateTime, nullable=True) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| doc_type = db.Column(db.String(40), nullable=True) | |||
| doc_metadata = db.Column(db.JSON, nullable=True) | |||
| doc_metadata = db.Column(JSONB, nullable=True) | |||
| doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) | |||
| doc_language = db.Column(db.String(255), nullable=True) | |||
| @@ -396,12 +449,95 @@ class Document(db.Model): # type: ignore[name-defined] | |||
| .scalar() | |||
| ) | |||
| @property | |||
| def uploader(self): | |||
| user = db.session.query(Account).filter(Account.id == self.created_by).first() | |||
| return user.name if user else None | |||
| @property | |||
| def upload_date(self): | |||
| return self.created_at | |||
| @property | |||
| def last_update_date(self): | |||
| return self.updated_at | |||
| @property | |||
| def doc_metadata_details(self): | |||
| if self.doc_metadata: | |||
| document_metadatas = ( | |||
| db.session.query(DatasetMetadata) | |||
| .join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id) | |||
| .filter( | |||
| DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id | |||
| ) | |||
| .all() | |||
| ) | |||
| metadata_list = [] | |||
| for metadata in document_metadatas: | |||
| metadata_dict = { | |||
| "id": metadata.id, | |||
| "name": metadata.name, | |||
| "type": metadata.type, | |||
| "value": self.doc_metadata.get(metadata.name), | |||
| } | |||
| metadata_list.append(metadata_dict) | |||
| # deal built-in fields | |||
| metadata_list.extend(self.get_built_in_fields()) | |||
| return metadata_list | |||
| return None | |||
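Taken together, `doc_metadata_details` returns the user-defined fields joined through `DatasetMetadataBinding` plus the built-in fields, roughly in this shape (IDs and values are illustrative):

```python
# example return value of Document.doc_metadata_details
[
    {"id": "metadata-uuid-1", "name": "author", "type": "string", "value": "smith"},
    {"id": "built-in", "name": BuiltInField.document_name, "type": "string", "value": "report.pdf"},
    {"id": "built-in", "name": BuiltInField.upload_date, "type": "time", "value": 1740648000.0},
    # ... uploader, last_update_date and source entries follow the same pattern
]
```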
| @property | |||
| def process_rule_dict(self): | |||
| if self.dataset_process_rule_id: | |||
| return self.dataset_process_rule.to_dict() | |||
| return None | |||
| def get_built_in_fields(self): | |||
| built_in_fields = [] | |||
| built_in_fields.append( | |||
| { | |||
| "id": "built-in", | |||
| "name": BuiltInField.document_name, | |||
| "type": "string", | |||
| "value": self.name, | |||
| } | |||
| ) | |||
| built_in_fields.append( | |||
| { | |||
| "id": "built-in", | |||
| "name": BuiltInField.uploader, | |||
| "type": "string", | |||
| "value": self.uploader, | |||
| } | |||
| ) | |||
| built_in_fields.append( | |||
| { | |||
| "id": "built-in", | |||
| "name": BuiltInField.upload_date, | |||
| "type": "time", | |||
| "value": self.created_at.timestamp(), | |||
| } | |||
| ) | |||
| built_in_fields.append( | |||
| { | |||
| "id": "built-in", | |||
| "name": BuiltInField.last_update_date, | |||
| "type": "time", | |||
| "value": self.updated_at.timestamp(), | |||
| } | |||
| ) | |||
| built_in_fields.append( | |||
| { | |||
| "id": "built-in", | |||
| "name": BuiltInField.source, | |||
| "type": "string", | |||
| "value": MetadataDataSource[self.data_source_type].value, | |||
| } | |||
| ) | |||
| return built_in_fields | |||
| def to_dict(self): | |||
| return { | |||
| "id": self.id, | |||
| @@ -945,3 +1081,41 @@ class RateLimitLog(db.Model): # type: ignore[name-defined] | |||
| subscription_plan = db.Column(db.String(255), nullable=False) | |||
| operation = db.Column(db.String(255), nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| class DatasetMetadata(db.Model): # type: ignore[name-defined] | |||
| __tablename__ = "dataset_metadatas" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"), | |||
| db.Index("dataset_metadata_tenant_idx", "tenant_id"), | |||
| db.Index("dataset_metadata_dataset_idx", "dataset_id"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| dataset_id = db.Column(StringUUID, nullable=False) | |||
| type = db.Column(db.String(255), nullable=False) | |||
| name = db.Column(db.String(255), nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| created_by = db.Column(StringUUID, nullable=False) | |||
| updated_by = db.Column(StringUUID, nullable=True) | |||
| class DatasetMetadataBinding(db.Model): # type: ignore[name-defined] | |||
| __tablename__ = "dataset_metadata_bindings" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"), | |||
| db.Index("dataset_metadata_binding_tenant_idx", "tenant_id"), | |||
| db.Index("dataset_metadata_binding_dataset_idx", "dataset_id"), | |||
| db.Index("dataset_metadata_binding_metadata_idx", "metadata_id"), | |||
| db.Index("dataset_metadata_binding_document_idx", "document_id"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| dataset_id = db.Column(StringUUID, nullable=False) | |||
| metadata_id = db.Column(StringUUID, nullable=False) | |||
| document_id = db.Column(StringUUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| created_by = db.Column(StringUUID, nullable=False) | |||
| @@ -1,3 +1,4 @@ | |||
| import copy | |||
| import datetime | |||
| import json | |||
| import logging | |||
| @@ -17,6 +18,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.plugin.entities.plugin import ModelProviderID | |||
| from core.rag.index_processor.constant.built_in_field import BuiltInField | |||
| from core.rag.index_processor.constant.index_type import IndexType | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from events.dataset_event import dataset_was_deleted | |||
| @@ -643,9 +645,45 @@ class DocumentService: | |||
| return document | |||
| @staticmethod | |||
| def get_document_by_ids(document_ids: list[str]) -> list[Document]: | |||
| documents = ( | |||
| db.session.query(Document) | |||
| .filter( | |||
| Document.id.in_(document_ids), | |||
| Document.enabled == True, | |||
| Document.indexing_status == "completed", | |||
| Document.archived == False, | |||
| ) | |||
| .all() | |||
| ) | |||
| return documents | |||
| @staticmethod | |||
| def get_document_by_dataset_id(dataset_id: str) -> list[Document]: | |||
| documents = db.session.query(Document).filter(Document.dataset_id == dataset_id, Document.enabled == True).all() | |||
| documents = ( | |||
| db.session.query(Document) | |||
| .filter( | |||
| Document.dataset_id == dataset_id, | |||
| Document.enabled == True, | |||
| ) | |||
| .all() | |||
| ) | |||
| return documents | |||
| @staticmethod | |||
| def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]: | |||
| documents = ( | |||
| db.session.query(Document) | |||
| .filter( | |||
| Document.dataset_id == dataset_id, | |||
| Document.enabled == True, | |||
| Document.indexing_status == "completed", | |||
| Document.archived == False, | |||
| ) | |||
| .all() | |||
| ) | |||
| return documents | |||
| @@ -728,8 +766,13 @@ class DocumentService: | |||
| if document.tenant_id != current_user.current_tenant_id: | |||
| raise ValueError("No permission.") | |||
| document.name = name | |||
| if dataset.built_in_field_enabled: | |||
| if document.doc_metadata: | |||
| doc_metadata = copy.deepcopy(document.doc_metadata) | |||
| doc_metadata[BuiltInField.document_name.value] = name | |||
| document.doc_metadata = doc_metadata | |||
| document.name = name | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| @@ -1128,9 +1171,20 @@ class DocumentService: | |||
| doc_form=document_form, | |||
| doc_language=document_language, | |||
| ) | |||
| doc_metadata = {} | |||
| if dataset.built_in_field_enabled: | |||
| doc_metadata = { | |||
| BuiltInField.document_name.value: name, | |||
| BuiltInField.uploader.value: account.name, | |||
| BuiltInField.upload_date.value: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"), | |||
| BuiltInField.last_update_date.value: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"), | |||
| BuiltInField.source.value: data_source_type, | |||
| } | |||
| if metadata is not None: | |||
| document.doc_metadata = metadata.doc_metadata | |||
| doc_metadata.update(metadata.doc_metadata) | |||
| document.doc_type = metadata.doc_type | |||
| if doc_metadata: | |||
| document.doc_metadata = doc_metadata | |||
| return document | |||
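So for a new document with `built_in_field_enabled`, `doc_metadata` starts out roughly like the dict below before any caller-supplied metadata is merged in (the concrete key strings come from `BuiltInField`; the sample values and the `upload_file` source are illustrative assumptions):

```python
# illustrative doc_metadata for a freshly created document with built-in fields enabled
{
    BuiltInField.document_name.value: "report.pdf",
    BuiltInField.uploader.value: "Alice",
    BuiltInField.upload_date.value: "2025-02-27 09:17:48",
    BuiltInField.last_update_date.value: "2025-02-27 09:17:48",
    BuiltInField.source.value: "upload_file",  # the document's data_source_type (assumed value)
}
```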
| @staticmethod | |||
| @@ -125,3 +125,36 @@ class SegmentUpdateArgs(BaseModel): | |||
| class ChildChunkUpdateArgs(BaseModel): | |||
| id: Optional[str] = None | |||
| content: str | |||
| class MetadataArgs(BaseModel): | |||
| type: Literal["string", "number", "time"] | |||
| name: str | |||
| class MetadataUpdateArgs(BaseModel): | |||
| name: str | |||
| value: Optional[str | int | float] = None | |||
| class MetadataValueUpdateArgs(BaseModel): | |||
| fields: list[MetadataUpdateArgs] | |||
| class MetadataDetail(BaseModel): | |||
| id: str | |||
| name: str | |||
| value: Optional[str | int | float] = None | |||
| class DocumentMetadataOperation(BaseModel): | |||
| document_id: str | |||
| metadata_list: list[MetadataDetail] | |||
| class MetadataOperationData(BaseModel): | |||
| """ | |||
| Metadata operation data | |||
| """ | |||
| operation_data: list[DocumentMetadataOperation] | |||
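For reference, a request that sets two custom fields on one document deserializes into these models roughly as follows (ids are placeholders):

    payload = MetadataOperationData(
        operation_data=[
            DocumentMetadataOperation(
                document_id="<document-uuid>",
                metadata_list=[
                    MetadataDetail(id="<metadata-uuid>", name="author", value="Alice"),
                    MetadataDetail(id="<metadata-uuid>", name="page_count", value=12),
                ],
            )
        ]
    )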
| @@ -8,6 +8,7 @@ import validators | |||
| from constants import HIDDEN_VALUE | |||
| from core.helper import ssrf_proxy | |||
| from core.rag.entities.metadata_entities import MetadataCondition | |||
| from extensions.ext_database import db | |||
| from models.dataset import ( | |||
| Dataset, | |||
| @@ -245,7 +246,11 @@ class ExternalDatasetService: | |||
| @staticmethod | |||
| def fetch_external_knowledge_retrieval( | |||
| tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict | |||
| tenant_id: str, | |||
| dataset_id: str, | |||
| query: str, | |||
| external_retrieval_parameters: dict, | |||
| metadata_condition: Optional[MetadataCondition] = None, | |||
| ) -> list: | |||
| external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by( | |||
| dataset_id=dataset_id, tenant_id=tenant_id | |||
| @@ -272,6 +277,7 @@ class ExternalDatasetService: | |||
| }, | |||
| "query": query, | |||
| "knowledge_id": external_knowledge_binding.external_knowledge_id, | |||
| "metadata_condition": metadata_condition.model_dump() if metadata_condition else None, | |||
| } | |||
| response = ExternalDatasetService.process_external_api( | |||
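Because metadata_condition defaults to None, existing callers of fetch_external_knowledge_retrieval are unaffected; new callers can forward document-metadata filters to the external knowledge API. A hedged call-site sketch (the keys inside external_retrieval_parameters are illustrative):

    results = ExternalDatasetService.fetch_external_knowledge_retrieval(
        tenant_id=tenant_id,
        dataset_id=dataset_id,
        query="refund policy",
        external_retrieval_parameters={"top_k": 5, "score_threshold": 0.5},
        metadata_condition=None,  # or a MetadataCondition instance to filter on doc_metadata
    )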
| @@ -0,0 +1,241 @@ | |||
| import copy | |||
| import datetime | |||
| import logging | |||
| from typing import Optional | |||
| from flask_login import current_user # type: ignore | |||
| from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding | |||
| from services.dataset_service import DocumentService | |||
| from services.entities.knowledge_entities.knowledge_entities import ( | |||
| MetadataArgs, | |||
| MetadataOperationData, | |||
| ) | |||
| class MetadataService: | |||
| @staticmethod | |||
| def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata: | |||
| # check if metadata name already exists | |||
| if DatasetMetadata.query.filter_by( | |||
| tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name | |||
| ).first(): | |||
| raise ValueError("Metadata name already exists.") | |||
| for field in BuiltInField: | |||
| if field.value == metadata_args.name: | |||
| raise ValueError("Metadata name already exists in Built-in fields.") | |||
| metadata = DatasetMetadata( | |||
| tenant_id=current_user.current_tenant_id, | |||
| dataset_id=dataset_id, | |||
| type=metadata_args.type, | |||
| name=metadata_args.name, | |||
| created_by=current_user.id, | |||
| ) | |||
| db.session.add(metadata) | |||
| db.session.commit() | |||
| return metadata | |||
| @staticmethod | |||
| def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore | |||
| lock_key = f"dataset_metadata_lock_{dataset_id}" | |||
| # check if metadata name already exists | |||
| if DatasetMetadata.query.filter_by( | |||
| tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name | |||
| ).first(): | |||
| raise ValueError("Metadata name already exists.") | |||
| for field in BuiltInField: | |||
| if field.value == name: | |||
| raise ValueError("Metadata name already exists in Built-in fields.") | |||
| try: | |||
| MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) | |||
| metadata = DatasetMetadata.query.filter_by(id=metadata_id).first() | |||
| if metadata is None: | |||
| raise ValueError("Metadata not found.") | |||
| old_name = metadata.name | |||
| metadata.name = name | |||
| metadata.updated_by = current_user.id | |||
| metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | |||
| # update related documents | |||
| dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all() | |||
| if dataset_metadata_bindings: | |||
| document_ids = [binding.document_id for binding in dataset_metadata_bindings] | |||
| documents = DocumentService.get_document_by_ids(document_ids) | |||
| for document in documents: | |||
| doc_metadata = copy.deepcopy(document.doc_metadata) | |||
| value = doc_metadata.pop(old_name, None) | |||
| doc_metadata[name] = value | |||
| document.doc_metadata = doc_metadata | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| return metadata # type: ignore | |||
| except Exception: | |||
| logging.exception("Update metadata name failed") | |||
| finally: | |||
| redis_client.delete(lock_key) | |||
| @staticmethod | |||
| def delete_metadata(dataset_id: str, metadata_id: str): | |||
| lock_key = f"dataset_metadata_lock_{dataset_id}" | |||
| try: | |||
| MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) | |||
| metadata = DatasetMetadata.query.filter_by(id=metadata_id).first() | |||
| if metadata is None: | |||
| raise ValueError("Metadata not found.") | |||
| db.session.delete(metadata) | |||
| # remove the deleted field from related documents | |||
| dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all() | |||
| if dataset_metadata_bindings: | |||
| document_ids = [binding.document_id for binding in dataset_metadata_bindings] | |||
| documents = DocumentService.get_document_by_ids(document_ids) | |||
| for document in documents: | |||
| doc_metadata = copy.deepcopy(document.doc_metadata) | |||
| doc_metadata.pop(metadata.name, None) | |||
| document.doc_metadata = doc_metadata | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| return metadata | |||
| except Exception: | |||
| logging.exception("Delete metadata failed") | |||
| finally: | |||
| redis_client.delete(lock_key) | |||
| @staticmethod | |||
| def get_built_in_fields(): | |||
| return [ | |||
| {"name": BuiltInField.document_name.value, "type": "string"}, | |||
| {"name": BuiltInField.uploader.value, "type": "string"}, | |||
| {"name": BuiltInField.upload_date.value, "type": "time"}, | |||
| {"name": BuiltInField.last_update_date.value, "type": "time"}, | |||
| {"name": BuiltInField.source.value, "type": "string"}, | |||
| ] | |||
| @staticmethod | |||
| def enable_built_in_field(dataset: Dataset): | |||
| if dataset.built_in_field_enabled: | |||
| return | |||
| lock_key = f"dataset_metadata_lock_{dataset.id}" | |||
| try: | |||
| MetadataService.knowledge_base_metadata_lock_check(dataset.id, None) | |||
| dataset.built_in_field_enabled = True | |||
| db.session.add(dataset) | |||
| documents = DocumentService.get_working_documents_by_dataset_id(dataset.id) | |||
| if documents: | |||
| for document in documents: | |||
| if not document.doc_metadata: | |||
| doc_metadata = {} | |||
| else: | |||
| doc_metadata = copy.deepcopy(document.doc_metadata) | |||
| doc_metadata[BuiltInField.document_name.value] = document.name | |||
| doc_metadata[BuiltInField.uploader.value] = document.uploader | |||
| doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp() | |||
| doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp() | |||
| doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value | |||
| document.doc_metadata = doc_metadata | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| except Exception: | |||
| logging.exception("Enable built-in field failed") | |||
| finally: | |||
| redis_client.delete(lock_key) | |||
| @staticmethod | |||
| def disable_built_in_field(dataset: Dataset): | |||
| if not dataset.built_in_field_enabled: | |||
| return | |||
| lock_key = f"dataset_metadata_lock_{dataset.id}" | |||
| try: | |||
| MetadataService.knowledge_base_metadata_lock_check(dataset.id, None) | |||
| dataset.built_in_field_enabled = False | |||
| db.session.add(dataset) | |||
| documents = DocumentService.get_working_documents_by_dataset_id(dataset.id) | |||
| document_ids = [] | |||
| if documents: | |||
| for document in documents: | |||
| doc_metadata = copy.deepcopy(document.doc_metadata) | |||
| doc_metadata.pop(BuiltInField.document_name.value, None) | |||
| doc_metadata.pop(BuiltInField.uploader.value, None) | |||
| doc_metadata.pop(BuiltInField.upload_date.value, None) | |||
| doc_metadata.pop(BuiltInField.last_update_date.value, None) | |||
| doc_metadata.pop(BuiltInField.source.value, None) | |||
| document.doc_metadata = doc_metadata | |||
| db.session.add(document) | |||
| document_ids.append(document.id) | |||
| db.session.commit() | |||
| except Exception: | |||
| logging.exception("Disable built-in field failed") | |||
| finally: | |||
| redis_client.delete(lock_key) | |||
| @staticmethod | |||
| def update_documents_metadata(dataset: Dataset, metadata_args: MetadataOperationData): | |||
| for operation in metadata_args.operation_data: | |||
| lock_key = f"document_metadata_lock_{operation.document_id}" | |||
| try: | |||
| MetadataService.knowledge_base_metadata_lock_check(None, operation.document_id) | |||
| document = DocumentService.get_document(dataset.id, operation.document_id) | |||
| if document is None: | |||
| raise ValueError("Document not found.") | |||
| doc_metadata = {} | |||
| for metadata_value in operation.metadata_list: | |||
| doc_metadata[metadata_value.name] = metadata_value.value | |||
| if dataset.built_in_field_enabled: | |||
| doc_metadata[BuiltInField.document_name.value] = document.name | |||
| doc_metadata[BuiltInField.uploader.value] = document.uploader | |||
| doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp() | |||
| doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp() | |||
| doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value | |||
| document.doc_metadata = doc_metadata | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| # rebuild metadata bindings for this document | |||
| DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete() | |||
| for metadata_value in operation.metadata_list: | |||
| dataset_metadata_binding = DatasetMetadataBinding( | |||
| tenant_id=current_user.current_tenant_id, | |||
| dataset_id=dataset.id, | |||
| document_id=operation.document_id, | |||
| metadata_id=metadata_value.id, | |||
| created_by=current_user.id, | |||
| ) | |||
| db.session.add(dataset_metadata_binding) | |||
| db.session.commit() | |||
| except Exception: | |||
| logging.exception("Update documents metadata failed") | |||
| finally: | |||
| redis_client.delete(lock_key) | |||
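Each operation rebuilds the document's doc_metadata from metadata_list (re-applying built-in fields when enabled) and rewrites the DatasetMetadataBinding rows for that document, so the list is a full replacement rather than a patch. A sketch of a direct service call, assuming dataset, document and a DatasetMetadata row named "author" already exist:

    MetadataService.update_documents_metadata(
        dataset,
        MetadataOperationData(
            operation_data=[
                DocumentMetadataOperation(
                    document_id=document.id,
                    metadata_list=[MetadataDetail(id=metadata.id, name="author", value="Alice")],
                )
            ]
        ),
    )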
| @staticmethod | |||
| def knowledge_base_metadata_lock_check(dataset_id: Optional[str], document_id: Optional[str]): | |||
| if dataset_id: | |||
| lock_key = f"dataset_metadata_lock_{dataset_id}" | |||
| if redis_client.get(lock_key): | |||
| raise ValueError("Another knowledge base metadata operation is running, please wait a moment.") | |||
| redis_client.set(lock_key, 1, ex=3600) | |||
| if document_id: | |||
| lock_key = f"document_metadata_lock_{document_id}" | |||
| if redis_client.get(lock_key): | |||
| raise ValueError("Another document metadata operation is running, please wait a moment.") | |||
| redis_client.set(lock_key, 1, ex=3600) | |||
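This helper only acquires the lock (with a 3600-second expiry); it never releases it. Every caller above pairs it with a finally block, along these lines:

    lock_key = f"dataset_metadata_lock_{dataset_id}"
    try:
        MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
        # ... perform the metadata mutation and commit ...
    finally:
        redis_client.delete(lock_key)  # the ex=3600 TTL is only a safety net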
| @staticmethod | |||
| def get_dataset_metadatas(dataset: Dataset): | |||
| return { | |||
| "doc_metadata": [ | |||
| { | |||
| "id": item.get("id"), | |||
| "name": item.get("name"), | |||
| "type": item.get("type"), | |||
| "count": DatasetMetadataBinding.query.filter_by( | |||
| metadata_id=item.get("id"), dataset_id=dataset.id | |||
| ).count(), | |||
| } | |||
| for item in dataset.doc_metadata or [] | |||
| if item.get("id") != "built-in" | |||
| ], | |||
| "built_in_field_enabled": dataset.built_in_field_enabled, | |||
| } | |||
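Illustrative return value of get_dataset_metadatas (id and count are placeholders); built-in fields are reported via the flag rather than listed individually:

    {
        "doc_metadata": [
            {"id": "<metadata-uuid>", "name": "author", "type": "string", "count": 3},
        ],
        "built_in_field_enabled": True,
    }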
| @@ -20,7 +20,7 @@ class TagService: | |||
| ) | |||
| if keyword: | |||
| query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%"))) | |||
| query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at) | |||
| query = query.group_by(Tag.id, Tag.type, Tag.name) | |||
| results: list = query.order_by(Tag.created_at.desc()).all() | |||
| return results | |||