| datasets_segments, | datasets_segments, | ||||
| external, | external, | ||||
| hit_testing, | hit_testing, | ||||
| metadata, | |||||
| website, | website, | ||||
| ) | ) | ||||
| raise InvalidMetadataError(f"Invalid metadata value: {metadata}") | raise InvalidMetadataError(f"Invalid metadata value: {metadata}") | ||||
| if metadata == "only": | if metadata == "only": | ||||
| response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata} | |||||
| response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details} | |||||
| elif metadata == "without": | elif metadata == "without": | ||||
| dataset_process_rules = DatasetService.get_process_rules(dataset_id) | dataset_process_rules = DatasetService.get_process_rules(dataset_id) | ||||
| document_process_rules = document.dataset_process_rule.to_dict() | document_process_rules = document.dataset_process_rule.to_dict() | ||||
| "disabled_by": document.disabled_by, | "disabled_by": document.disabled_by, | ||||
| "archived": document.archived, | "archived": document.archived, | ||||
| "doc_type": document.doc_type, | "doc_type": document.doc_type, | ||||
| "doc_metadata": document.doc_metadata, | |||||
| "doc_metadata": document.doc_metadata_details, | |||||
| "segment_count": document.segment_count, | "segment_count": document.segment_count, | ||||
| "average_segment_length": document.average_segment_length, | "average_segment_length": document.average_segment_length, | ||||
| "hit_count": document.hit_count, | "hit_count": document.hit_count, |
| from flask_login import current_user # type: ignore | |||||
| from flask_restful import Resource, marshal_with, reqparse # type: ignore | |||||
| from werkzeug.exceptions import NotFound | |||||
| from controllers.console import api | |||||
| from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required | |||||
| from fields.dataset_fields import dataset_metadata_fields | |||||
| from libs.login import login_required | |||||
| from services.dataset_service import DatasetService | |||||
| from services.entities.knowledge_entities.knowledge_entities import ( | |||||
| MetadataArgs, | |||||
| MetadataOperationData, | |||||
| ) | |||||
| from services.metadata_service import MetadataService | |||||
| def _validate_name(name): | |||||
| if not name or len(name) < 1 or len(name) > 40: | |||||
| raise ValueError("Name must be between 1 to 40 characters.") | |||||
| return name | |||||
| def _validate_description_length(description): | |||||
| if len(description) > 400: | |||||
| raise ValueError("Description cannot exceed 400 characters.") | |||||
| return description | |||||
| class DatasetMetadataCreateApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @enterprise_license_required | |||||
| @marshal_with(dataset_metadata_fields) | |||||
| def post(self, dataset_id): | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("type", type=str, required=True, nullable=True, location="json") | |||||
| parser.add_argument("name", type=str, required=True, nullable=True, location="json") | |||||
| args = parser.parse_args() | |||||
| metadata_args = MetadataArgs(**args) | |||||
| dataset_id_str = str(dataset_id) | |||||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||||
| if dataset is None: | |||||
| raise NotFound("Dataset not found.") | |||||
| DatasetService.check_dataset_permission(dataset, current_user) | |||||
| metadata = MetadataService.create_metadata(dataset_id_str, metadata_args) | |||||
| return metadata, 201 | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @enterprise_license_required | |||||
| def get(self, dataset_id): | |||||
| dataset_id_str = str(dataset_id) | |||||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||||
| if dataset is None: | |||||
| raise NotFound("Dataset not found.") | |||||
| return MetadataService.get_dataset_metadatas(dataset), 200 | |||||
| class DatasetMetadataApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @enterprise_license_required | |||||
| @marshal_with(dataset_metadata_fields) | |||||
| def patch(self, dataset_id, metadata_id): | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("name", type=str, required=True, nullable=True, location="json") | |||||
| args = parser.parse_args() | |||||
| dataset_id_str = str(dataset_id) | |||||
| metadata_id_str = str(metadata_id) | |||||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||||
| if dataset is None: | |||||
| raise NotFound("Dataset not found.") | |||||
| DatasetService.check_dataset_permission(dataset, current_user) | |||||
| metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) | |||||
| return metadata, 200 | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @enterprise_license_required | |||||
| def delete(self, dataset_id, metadata_id): | |||||
| dataset_id_str = str(dataset_id) | |||||
| metadata_id_str = str(metadata_id) | |||||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||||
| if dataset is None: | |||||
| raise NotFound("Dataset not found.") | |||||
| DatasetService.check_dataset_permission(dataset, current_user) | |||||
| MetadataService.delete_metadata(dataset_id_str, metadata_id_str) | |||||
| return 200 | |||||
| class DatasetMetadataBuiltInFieldApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @enterprise_license_required | |||||
| def get(self): | |||||
| built_in_fields = MetadataService.get_built_in_fields() | |||||
| return {"fields": built_in_fields}, 200 | |||||
| class DatasetMetadataBuiltInFieldActionApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @enterprise_license_required | |||||
| def post(self, dataset_id, action): | |||||
| dataset_id_str = str(dataset_id) | |||||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||||
| if dataset is None: | |||||
| raise NotFound("Dataset not found.") | |||||
| DatasetService.check_dataset_permission(dataset, current_user) | |||||
| if action == "enable": | |||||
| MetadataService.enable_built_in_field(dataset) | |||||
| elif action == "disable": | |||||
| MetadataService.disable_built_in_field(dataset) | |||||
| return 200 | |||||
| class DocumentMetadataEditApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| @enterprise_license_required | |||||
| def post(self, dataset_id): | |||||
| dataset_id_str = str(dataset_id) | |||||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||||
| if dataset is None: | |||||
| raise NotFound("Dataset not found.") | |||||
| DatasetService.check_dataset_permission(dataset, current_user) | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json") | |||||
| args = parser.parse_args() | |||||
| metadata_args = MetadataOperationData(**args) | |||||
| MetadataService.update_documents_metadata(dataset, metadata_args) | |||||
| return 200 | |||||
| api.add_resource(DatasetMetadataCreateApi, "/datasets/<uuid:dataset_id>/metadata") | |||||
| api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>") | |||||
| api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in") | |||||
| api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>") | |||||
| api.add_resource(DocumentMetadataEditApi, "/datasets/<uuid:dataset_id>/documents/metadata") |
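A minimal sketch of exercising the endpoints registered above, assuming a local console API served under /console/api and a valid bearer token; the base URL, token, and dataset UUID are placeholders, not values from the diff:

    import requests

    CONSOLE_URL = "http://localhost:5001/console/api"  # assumed base path
    HEADERS = {"Authorization": "Bearer <console-token>", "Content-Type": "application/json"}
    dataset_id = "00000000-0000-0000-0000-000000000000"  # placeholder UUID

    # Create a metadata field on the dataset (DatasetMetadataCreateApi.post).
    resp = requests.post(
        f"{CONSOLE_URL}/datasets/{dataset_id}/metadata",
        headers=HEADERS,
        json={"type": "string", "name": "department"},
    )
    print(resp.status_code, resp.json())

    # Enable the built-in metadata fields (DatasetMetadataBuiltInFieldActionApi.post).
    requests.post(f"{CONSOLE_URL}/datasets/{dataset_id}/metadata/built-in/enable", headers=HEADERS)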
| import uuid | import uuid | ||||
| from typing import Optional | from typing import Optional | ||||
| from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity | |||||
| from core.app.app_config.entities import ( | |||||
| DatasetEntity, | |||||
| DatasetRetrieveConfigEntity, | |||||
| MetadataFilteringCondition, | |||||
| ModelConfig, | |||||
| ) | |||||
| from core.entities.agent_entities import PlanningStrategy | from core.entities.agent_entities import PlanningStrategy | ||||
| from models.model import AppMode | from models.model import AppMode | ||||
| from services.dataset_service import DatasetService | from services.dataset_service import DatasetService | ||||
| retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( | retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( | ||||
| dataset_configs["retrieval_model"] | dataset_configs["retrieval_model"] | ||||
| ), | ), | ||||
| metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"), | |||||
| metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config")) | |||||
| if dataset_configs.get("metadata_model_config") | |||||
| else None, | |||||
| metadata_filtering_conditions=MetadataFilteringCondition( | |||||
| **dataset_configs.get("metadata_filtering_conditions", {}) | |||||
| ) | |||||
| if dataset_configs.get("metadata_filtering_conditions") | |||||
| else None, | |||||
| ), | ), | ||||
| ) | ) | ||||
| else: | else: | ||||
| weights=dataset_configs.get("weights"), | weights=dataset_configs.get("weights"), | ||||
| reranking_enabled=dataset_configs.get("reranking_enabled", True), | reranking_enabled=dataset_configs.get("reranking_enabled", True), | ||||
| rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"), | rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"), | ||||
| metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"), | |||||
| metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config")) | |||||
| if dataset_configs.get("metadata_model_config") | |||||
| else None, | |||||
| metadata_filtering_conditions=MetadataFilteringCondition( | |||||
| **dataset_configs.get("metadata_filtering_conditions", {}) | |||||
| ) | |||||
| if dataset_configs.get("metadata_filtering_conditions") | |||||
| else None, | |||||
| ), | ), | ||||
| ) | ) | ||||
| from collections.abc import Sequence | from collections.abc import Sequence | ||||
| from enum import Enum, StrEnum | from enum import Enum, StrEnum | ||||
| from typing import Any, Optional | |||||
| from typing import Any, Literal, Optional | |||||
| from pydantic import BaseModel, Field, field_validator | from pydantic import BaseModel, Field, field_validator | ||||
| from core.file import FileTransferMethod, FileType, FileUploadConfig | from core.file import FileTransferMethod, FileType, FileUploadConfig | ||||
| from core.model_runtime.entities.llm_entities import LLMMode | |||||
| from core.model_runtime.entities.message_entities import PromptMessageRole | from core.model_runtime.entities.message_entities import PromptMessageRole | ||||
| from models.model import AppMode | from models.model import AppMode | ||||
| config: dict[str, Any] = Field(default_factory=dict) | config: dict[str, Any] = Field(default_factory=dict) | ||||
| SupportedComparisonOperator = Literal[ | |||||
| # for string or array | |||||
| "contains", | |||||
| "not contains", | |||||
| "start with", | |||||
| "end with", | |||||
| "is", | |||||
| "is not", | |||||
| "empty", | |||||
| "not empty", | |||||
| # for number | |||||
| "=", | |||||
| "≠", | |||||
| ">", | |||||
| "<", | |||||
| "≥", | |||||
| "≤", | |||||
| # for time | |||||
| "before", | |||||
| "after", | |||||
| ] | |||||
| class ModelConfig(BaseModel): | |||||
| provider: str | |||||
| name: str | |||||
| mode: LLMMode | |||||
| completion_params: dict[str, Any] = {} | |||||
| class Condition(BaseModel): | |||||
| """ | |||||
| Condition detail | |||||
| """ | |||||
| name: str | |||||
| comparison_operator: SupportedComparisonOperator | |||||
| value: str | Sequence[str] | None | int | float = None | |||||
| class MetadataFilteringCondition(BaseModel): | |||||
| """ | |||||
| Metadata Filtering Condition. | |||||
| """ | |||||
| logical_operator: Optional[Literal["and", "or"]] = "and" | |||||
| conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) | |||||
| class DatasetRetrieveConfigEntity(BaseModel): | class DatasetRetrieveConfigEntity(BaseModel): | ||||
| """ | """ | ||||
| Dataset Retrieve Config Entity. | Dataset Retrieve Config Entity. | ||||
| reranking_model: Optional[dict] = None | reranking_model: Optional[dict] = None | ||||
| weights: Optional[dict] = None | weights: Optional[dict] = None | ||||
| reranking_enabled: Optional[bool] = True | reranking_enabled: Optional[bool] = True | ||||
| metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled" | |||||
| metadata_model_config: Optional[ModelConfig] = None | |||||
| metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None | |||||
| class DatasetEntity(BaseModel): | class DatasetEntity(BaseModel): |
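For orientation, a sketch of the dataset_configs shape that the converter earlier in this diff reads the new fields from; only the metadata-related keys are shown in full, and the provider and model names are illustrative:

    dataset_configs = {
        "retrieval_model": "multiple",
        "reranking_enabled": True,
        "reranking_mode": "reranking_model",
        "metadata_filtering_mode": "manual",  # "disabled" | "automatic" | "manual"
        "metadata_model_config": {  # consulted when the mode is "automatic"
            "provider": "openai",  # illustrative provider/model
            "name": "gpt-4o-mini",
            "mode": "chat",
            "completion_params": {},
        },
        "metadata_filtering_conditions": {
            "logical_operator": "and",
            "conditions": [
                {"name": "department", "comparison_operator": "is", "value": "legal"},
            ],
        },
    }

    # The converter turns the nested dicts into the entities defined above.
    conditions = MetadataFilteringCondition(**dataset_configs["metadata_filtering_conditions"])
    model_config = ModelConfig(**dataset_configs["metadata_model_config"])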
| hit_callback=hit_callback, | hit_callback=hit_callback, | ||||
| memory=memory, | memory=memory, | ||||
| message_id=message.id, | message_id=message.id, | ||||
| inputs=inputs, | |||||
| ) | ) | ||||
| # reorganize all inputs and template to prompt messages | # reorganize all inputs and template to prompt messages |
| show_retrieve_source=app_config.additional_features.show_retrieve_source, | show_retrieve_source=app_config.additional_features.show_retrieve_source, | ||||
| hit_callback=hit_callback, | hit_callback=hit_callback, | ||||
| message_id=message.id, | message_id=message.id, | ||||
| inputs=inputs, | |||||
| ) | ) | ||||
| # reorganize all inputs and template to prompt messages | # reorganize all inputs and template to prompt messages |
| keyword_table = self._get_dataset_keyword_table() | keyword_table = self._get_dataset_keyword_table() | ||||
| k = kwargs.get("top_k", 4) | k = kwargs.get("top_k", 4) | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k) | sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k) | ||||
| documents = [] | documents = [] | ||||
| for chunk_index in sorted_chunk_indices: | for chunk_index in sorted_chunk_indices: | ||||
| segment = ( | |||||
| db.session.query(DocumentSegment) | |||||
| .filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index) | |||||
| .first() | |||||
| segment_query = db.session.query(DocumentSegment).filter( | |||||
| DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index | |||||
| ) | ) | ||||
| if document_ids_filter: | |||||
| segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter)) | |||||
| segment = segment_query.first() | |||||
| if segment: | if segment: | ||||
| documents.append( | documents.append( |
| reranking_model: Optional[dict] = None, | reranking_model: Optional[dict] = None, | ||||
| reranking_mode: str = "reranking_model", | reranking_mode: str = "reranking_model", | ||||
| weights: Optional[dict] = None, | weights: Optional[dict] = None, | ||||
| document_ids_filter: Optional[list[str]] = None, | |||||
| ): | ): | ||||
| if not query: | if not query: | ||||
| return [] | return [] | ||||
| top_k=top_k, | top_k=top_k, | ||||
| all_documents=all_documents, | all_documents=all_documents, | ||||
| exceptions=exceptions, | exceptions=exceptions, | ||||
| document_ids_filter=document_ids_filter, | |||||
| ) | ) | ||||
| ) | ) | ||||
| if RetrievalMethod.is_support_semantic_search(retrieval_method): | if RetrievalMethod.is_support_semantic_search(retrieval_method): | ||||
| all_documents=all_documents, | all_documents=all_documents, | ||||
| retrieval_method=retrieval_method, | retrieval_method=retrieval_method, | ||||
| exceptions=exceptions, | exceptions=exceptions, | ||||
| document_ids_filter=document_ids_filter, | |||||
| ) | ) | ||||
| ) | ) | ||||
| if RetrievalMethod.is_support_fulltext_search(retrieval_method): | if RetrievalMethod.is_support_fulltext_search(retrieval_method): | ||||
| @classmethod | @classmethod | ||||
| def keyword_search( | def keyword_search( | ||||
| cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list | |||||
| cls, | |||||
| flask_app: Flask, | |||||
| dataset_id: str, | |||||
| query: str, | |||||
| top_k: int, | |||||
| all_documents: list, | |||||
| exceptions: list, | |||||
| document_ids_filter: Optional[list[str]] = None, | |||||
| ): | ): | ||||
| with flask_app.app_context(): | with flask_app.app_context(): | ||||
| try: | try: | ||||
| raise ValueError("dataset not found") | raise ValueError("dataset not found") | ||||
| keyword = Keyword(dataset=dataset) | keyword = Keyword(dataset=dataset) | ||||
| documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k) | |||||
| documents = keyword.search( | |||||
| cls.escape_query_for_search(query), top_k=top_k, document_ids_filter=document_ids_filter | |||||
| ) | |||||
| all_documents.extend(documents) | all_documents.extend(documents) | ||||
| except Exception as e: | except Exception as e: | ||||
| exceptions.append(str(e)) | exceptions.append(str(e)) | ||||
| all_documents: list, | all_documents: list, | ||||
| retrieval_method: str, | retrieval_method: str, | ||||
| exceptions: list, | exceptions: list, | ||||
| document_ids_filter: Optional[list[str]] = None, | |||||
| ): | ): | ||||
| with flask_app.app_context(): | with flask_app.app_context(): | ||||
| try: | try: | ||||
| top_k=top_k, | top_k=top_k, | ||||
| score_threshold=score_threshold, | score_threshold=score_threshold, | ||||
| filter={"group_id": [dataset.id]}, | filter={"group_id": [dataset.id]}, | ||||
| document_ids_filter=document_ids_filter, | |||||
| ) | ) | ||||
| if documents: | if documents: |
| self.analyticdb_vector.delete_by_metadata_field(key, value) | self.analyticdb_vector.delete_by_metadata_field(key, value) | ||||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | ||||
| return self.analyticdb_vector.search_by_vector(query_vector) | |||||
| return self.analyticdb_vector.search_by_vector(query_vector, **kwargs) | |||||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | ||||
| return self.analyticdb_vector.search_by_full_text(query, **kwargs) | return self.analyticdb_vector.search_by_full_text(query, **kwargs) |
| top_k = kwargs.get("top_k", 4) | top_k = kwargs.get("top_k", 4) | ||||
| if not isinstance(top_k, int) or top_k <= 0: | if not isinstance(top_k, int) or top_k <= 0: | ||||
| raise ValueError("top_k must be a positive integer") | raise ValueError("top_k must be a positive integer") | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| where_clause = "WHERE 1=1" | |||||
| if document_ids_filter: | |||||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||||
| where_clause += f"AND metadata_->>'document_id' IN ({document_ids})" | |||||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | score_threshold = float(kwargs.get("score_threshold") or 0.0) | ||||
| with self._get_cursor() as cur: | with self._get_cursor() as cur: | ||||
| query_vector_str = json.dumps(query_vector) | query_vector_str = json.dumps(query_vector) | ||||
| f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, " | f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, " | ||||
| f"t.page_content as page_content, t.metadata_ AS metadata_ " | f"t.page_content as page_content, t.metadata_ AS metadata_ " | ||||
| f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score " | f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score " | ||||
| f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t", | |||||
| f"FROM {self.table_name} {where_clause} ORDER BY score LIMIT {top_k} ) t", | |||||
| (query_vector_str,), | (query_vector_str,), | ||||
| ) | ) | ||||
| documents = [] | documents = [] | ||||
| top_k = kwargs.get("top_k", 4) | top_k = kwargs.get("top_k", 4) | ||||
| if not isinstance(top_k, int) or top_k <= 0: | if not isinstance(top_k, int) or top_k <= 0: | ||||
| raise ValueError("top_k must be a positive integer") | raise ValueError("top_k must be a positive integer") | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| where_clause = "" | |||||
| if document_ids_filter: | |||||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||||
| where_clause += f"AND metadata_->>'document_id' IN ({document_ids})" | |||||
| with self._get_cursor() as cur: | with self._get_cursor() as cur: | ||||
| cur.execute( | cur.execute( | ||||
| f"""SELECT id, vector, page_content, metadata_, | f"""SELECT id, vector, page_content, metadata_, | ||||
| ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score | ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score | ||||
| FROM {self.table_name} | FROM {self.table_name} | ||||
| WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') | |||||
| WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause} | |||||
| ORDER BY score DESC | ORDER BY score DESC | ||||
| LIMIT {top_k}""", | LIMIT {top_k}""", | ||||
| (f"'{query}'", f"'{query}'"), | (f"'{query}'", f"'{query}'"), |
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | ||||
| query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector] | query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector] | ||||
| anns = AnnSearch( | |||||
| vector_field=self.field_vector, | |||||
| vector_floats=query_vector, | |||||
| params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)), | |||||
| ) | |||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| if document_ids_filter: | |||||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||||
| anns = AnnSearch( | |||||
| vector_field=self.field_vector, | |||||
| vector_floats=query_vector, | |||||
| params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)), | |||||
| filter=f"document_id IN ({document_ids})", | |||||
| ) | |||||
| else: | |||||
| anns = AnnSearch( | |||||
| vector_field=self.field_vector, | |||||
| vector_floats=query_vector, | |||||
| params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)), | |||||
| ) | |||||
| res = self._db.table(self._collection_name).search( | res = self._db.table(self._collection_name).search( | ||||
| anns=anns, | anns=anns, | ||||
| projections=[self.field_id, self.field_text, self.field_metadata], | projections=[self.field_id, self.field_text, self.field_metadata], |
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | ||||
| collection = self._client.get_or_create_collection(self._collection_name) | collection = self._client.get_or_create_collection(self._collection_name) | ||||
| results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) | |||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| if document_ids_filter: | |||||
| results: QueryResult = collection.query( | |||||
| query_embeddings=query_vector, | |||||
| n_results=kwargs.get("top_k", 4), | |||||
| where={"document_id": {"$in": document_ids_filter}}, # type: ignore | |||||
| ) | |||||
| else: | |||||
| results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore | |||||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | score_threshold = float(kwargs.get("score_threshold") or 0.0) | ||||
| # Check if results contain data | # Check if results contain data |
| top_k = kwargs.get("top_k", 4) | top_k = kwargs.get("top_k", 4) | ||||
| num_candidates = math.ceil(top_k * 1.5) | num_candidates = math.ceil(top_k * 1.5) | ||||
| knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates} | knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates} | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| if document_ids_filter: | |||||
| knn["filter"] = {"terms": {"metadata.document_id": document_ids_filter}} | |||||
| results = self._client.search(index=self._collection_name, knn=knn, size=top_k) | results = self._client.search(index=self._collection_name, knn=knn, size=top_k) | ||||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | ||||
| query_str = {"match": {Field.CONTENT_KEY.value: query}} | query_str = {"match": {Field.CONTENT_KEY.value: query}} | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| if document_ids_filter: | |||||
| query_str["filter"] = {"terms": {"metadata.document_id": document_ids_filter}} # type: ignore | |||||
| results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4)) | results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4)) | ||||
| docs = [] | docs = [] | ||||
| for hit in results["hits"]["hits"]: | for hit in results["hits"]["hits"]: |
| raise ValueError("All elements in query_vector should be floats") | raise ValueError("All elements in query_vector should be floats") | ||||
| top_k = kwargs.get("top_k", 10) | top_k = kwargs.get("top_k", 10) | ||||
| query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs) | |||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| filters = [] | |||||
| if document_ids_filter: | |||||
| filters.append({"terms": {"metadata.document_id": document_ids_filter}}) | |||||
| query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs) | |||||
| try: | try: | ||||
| params = {} | params = {} | ||||
| if self._using_ugc: | if self._using_ugc: | ||||
| should = kwargs.get("should") | should = kwargs.get("should") | ||||
| minimum_should_match = kwargs.get("minimum_should_match", 0) | minimum_should_match = kwargs.get("minimum_should_match", 0) | ||||
| top_k = kwargs.get("top_k", 10) | top_k = kwargs.get("top_k", 10) | ||||
| filters = kwargs.get("filter") | |||||
| filters = kwargs.get("filter", []) | |||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| if document_ids_filter: | |||||
| filters.append({"terms": {"metadata.document_id": document_ids_filter}}) | |||||
| routing = self._routing | routing = self._routing | ||||
| full_text_query = default_text_search_query( | full_text_query = default_text_search_query( | ||||
| query_text=query, | query_text=query, |
| """ | """ | ||||
| Search for documents by vector similarity. | Search for documents by vector similarity. | ||||
| """ | """ | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| filter = "" | |||||
| if document_ids_filter: | |||||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||||
| filter = f'metadata["document_id"] in ({document_ids})' | |||||
| results = self._client.search( | results = self._client.search( | ||||
| collection_name=self._collection_name, | collection_name=self._collection_name, | ||||
| data=[query_vector], | data=[query_vector], | ||||
| anns_field=Field.VECTOR.value, | anns_field=Field.VECTOR.value, | ||||
| limit=kwargs.get("top_k", 4), | limit=kwargs.get("top_k", 4), | ||||
| output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], | output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], | ||||
| filter=filter, | |||||
| ) | ) | ||||
| return self._process_search_results( | return self._process_search_results( | ||||
| if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value): | if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value): | ||||
| logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)") | logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)") | ||||
| return [] | return [] | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| filter = "" | |||||
| if document_ids_filter: | |||||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||||
| filter = f'metadata["document_id"] in ({document_ids})' | |||||
| results = self._client.search( | results = self._client.search( | ||||
| collection_name=self._collection_name, | collection_name=self._collection_name, | ||||
| anns_field=Field.SPARSE_VECTOR.value, | anns_field=Field.SPARSE_VECTOR.value, | ||||
| limit=kwargs.get("top_k", 4), | limit=kwargs.get("top_k", 4), | ||||
| output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], | output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], | ||||
| filter=filter, | |||||
| ) | ) | ||||
| return self._process_search_results( | return self._process_search_results( |
| if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 | if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 | ||||
| else "" | else "" | ||||
| ) | ) | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| if document_ids_filter: | |||||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||||
| where_str = f"{where_str} AND metadata['document_id'] in ({document_ids})" | |||||
| sql = f""" | sql = f""" | ||||
| SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name} | SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name} | ||||
| {where_str} ORDER BY dist {order.value} LIMIT {top_k} | {where_str} ORDER BY dist {order.value} LIMIT {top_k} |
| return [] | return [] | ||||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| where_clause = None | |||||
| if document_ids_filter: | |||||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||||
| where_clause = f"metadata->>'$.document_id' in ({document_ids})" | |||||
| ef_search = kwargs.get("ef_search", self._hnsw_ef_search) | ef_search = kwargs.get("ef_search", self._hnsw_ef_search) | ||||
| if ef_search != self._hnsw_ef_search: | if ef_search != self._hnsw_ef_search: | ||||
| self._client.set_ob_hnsw_ef_search(ef_search) | self._client.set_ob_hnsw_ef_search(ef_search) | ||||
| distance_func=func.l2_distance, | distance_func=func.l2_distance, | ||||
| output_column_names=["text", "metadata"], | output_column_names=["text", "metadata"], | ||||
| with_dist=True, | with_dist=True, | ||||
| where_clause=where_clause, | |||||
| ) | ) | ||||
| docs = [] | docs = [] | ||||
| for text, metadata, distance in cur: | for text, metadata, distance in cur: |
| "size": kwargs.get("top_k", 4), | "size": kwargs.get("top_k", 4), | ||||
| "query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}}, | "query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}}, | ||||
| } | } | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| if document_ids_filter: | |||||
| query["query"] = {"terms": {"metadata.document_id": document_ids_filter}} | |||||
| try: | try: | ||||
| response = self._client.search(index=self._collection_name.lower(), body=query) | response = self._client.search(index=self._collection_name.lower(), body=query) | ||||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | ||||
| full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}} | full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}} | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| if document_ids_filter: | |||||
| full_text_query["query"]["terms"] = {"metadata.document_id": document_ids_filter} | |||||
| response = self._client.search(index=self._collection_name.lower(), body=full_text_query) | response = self._client.search(index=self._collection_name.lower(), body=full_text_query) | ||||
| :return: List of Documents that are nearest to the query vector. | :return: List of Documents that are nearest to the query vector. | ||||
| """ | """ | ||||
| top_k = kwargs.get("top_k", 4) | top_k = kwargs.get("top_k", 4) | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| where_clause = "" | |||||
| if document_ids_filter: | |||||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||||
| where_clause = f"WHERE metadata->>'document_id' in ({document_ids})" | |||||
| with self._get_cursor() as cur: | with self._get_cursor() as cur: | ||||
| cur.execute( | cur.execute( | ||||
| f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}" | f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}" | ||||
| f" ORDER BY distance fetch first {top_k} rows only", | |||||
| f" {where_clause} ORDER BY distance fetch first {top_k} rows only", | |||||
| [numpy.array(query_vector)], | [numpy.array(query_vector)], | ||||
| ) | ) | ||||
| docs = [] | docs = [] | ||||
| if token not in stop_words: | if token not in stop_words: | ||||
| entities.append(token) | entities.append(token) | ||||
| with self._get_cursor() as cur: | with self._get_cursor() as cur: | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| where_clause = "" | |||||
| if document_ids_filter: | |||||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||||
| where_clause = f" AND metadata->>'document_id' in ({document_ids}) " | |||||
| cur.execute( | cur.execute( | ||||
| f"select meta, text, embedding FROM {self.table_name}" | f"select meta, text, embedding FROM {self.table_name}" | ||||
| f" WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only", | |||||
| f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} " | |||||
| f"order by score(1) desc fetch first {top_k} rows only", | |||||
| [" ACCUM ".join(entities)], | [" ACCUM ".join(entities)], | ||||
| ) | ) | ||||
| docs = [] | docs = [] |
| .limit(kwargs.get("top_k", 4)) | .limit(kwargs.get("top_k", 4)) | ||||
| .order_by("distance") | .order_by("distance") | ||||
| ) | ) | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| if document_ids_filter: | |||||
| stmt = stmt.where(self._table.meta["document_id"].in_(document_ids_filter)) | |||||
| res = session.execute(stmt) | res = session.execute(stmt) | ||||
| results = [(row[0], row[1]) for row in res] | results = [(row[0], row[1]) for row in res] | ||||
| top_k = kwargs.get("top_k", 4) | top_k = kwargs.get("top_k", 4) | ||||
| if not isinstance(top_k, int) or top_k <= 0: | if not isinstance(top_k, int) or top_k <= 0: | ||||
| raise ValueError("top_k must be a positive integer") | raise ValueError("top_k must be a positive integer") | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| where_clause = "" | |||||
| if document_ids_filter: | |||||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||||
| where_clause = f" WHERE metadata->>'document_id' in ({document_ids}) " | |||||
| with self._get_cursor() as cur: | with self._get_cursor() as cur: | ||||
| cur.execute( | cur.execute( | ||||
| f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}" | f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}" | ||||
| f" {where_clause}" | |||||
| f" ORDER BY distance LIMIT {top_k}", | f" ORDER BY distance LIMIT {top_k}", | ||||
| (json.dumps(query_vector),), | (json.dumps(query_vector),), | ||||
| ) | ) | ||||
| if not isinstance(top_k, int) or top_k <= 0: | if not isinstance(top_k, int) or top_k <= 0: | ||||
| raise ValueError("top_k must be a positive integer") | raise ValueError("top_k must be a positive integer") | ||||
| with self._get_cursor() as cur: | with self._get_cursor() as cur: | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| where_clause = "" | |||||
| if document_ids_filter: | |||||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||||
| where_clause = f" AND metadata->>'document_id' in ({document_ids}) " | |||||
| if self.pg_bigm: | if self.pg_bigm: | ||||
| cur.execute("SET pg_bigm.similarity_limit TO 0.000001") | cur.execute("SET pg_bigm.similarity_limit TO 0.000001") | ||||
| cur.execute( | cur.execute( | ||||
| f"""SELECT meta, text, bigm_similarity(unistr(%s), coalesce(text, '')) AS score | f"""SELECT meta, text, bigm_similarity(unistr(%s), coalesce(text, '')) AS score | ||||
| FROM {self.table_name} | FROM {self.table_name} | ||||
| WHERE text =%% unistr(%s) | WHERE text =%% unistr(%s) | ||||
| {where_clause} | |||||
| ORDER BY score DESC | ORDER BY score DESC | ||||
| LIMIT {top_k}""", | LIMIT {top_k}""", | ||||
| # f"'{query}'" is required in order to account for whitespace in query | # f"'{query}'" is required in order to account for whitespace in query | ||||
| f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score | f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score | ||||
| FROM {self.table_name} | FROM {self.table_name} | ||||
| WHERE to_tsvector(text) @@ plainto_tsquery(%s) | WHERE to_tsvector(text) @@ plainto_tsquery(%s) | ||||
| {where_clause} | |||||
| ORDER BY score DESC | ORDER BY score DESC | ||||
| LIMIT {top_k}""", | LIMIT {top_k}""", | ||||
| # f"'{query}'" is required in order to account for whitespace in query | # f"'{query}'" is required in order to account for whitespace in query |
| from qdrant_client.http import models | from qdrant_client.http import models | ||||
| from qdrant_client.http.exceptions import UnexpectedResponse | from qdrant_client.http.exceptions import UnexpectedResponse | ||||
| for node_id in ids: | |||||
| try: | |||||
| filter = models.Filter( | |||||
| must=[ | |||||
| models.FieldCondition( | |||||
| key="metadata.doc_id", | |||||
| match=models.MatchValue(value=node_id), | |||||
| ), | |||||
| ], | |||||
| ) | |||||
| self._client.delete( | |||||
| collection_name=self._collection_name, | |||||
| points_selector=FilterSelector(filter=filter), | |||||
| ) | |||||
| except UnexpectedResponse as e: | |||||
| # Collection does not exist, so return | |||||
| if e.status_code == 404: | |||||
| return | |||||
| # Some other error occurred, so re-raise the exception | |||||
| else: | |||||
| raise e | |||||
| try: | |||||
| filter = models.Filter( | |||||
| must=[ | |||||
| models.FieldCondition( | |||||
| key="metadata.doc_id", | |||||
| match=models.MatchAny(any=ids), | |||||
| ), | |||||
| ], | |||||
| ) | |||||
| self._client.delete( | |||||
| collection_name=self._collection_name, | |||||
| points_selector=FilterSelector(filter=filter), | |||||
| ) | |||||
| except UnexpectedResponse as e: | |||||
| # Collection does not exist, so return | |||||
| if e.status_code == 404: | |||||
| return | |||||
| # Some other error occurred, so re-raise the exception | |||||
| else: | |||||
| raise e | |||||
| def text_exists(self, id: str) -> bool: | def text_exists(self, id: str) -> bool: | ||||
| all_collection_name = [] | all_collection_name = [] | ||||
| ), | ), | ||||
| ], | ], | ||||
| ) | ) | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| if document_ids_filter: | |||||
| if filter.must: | |||||
| filter.must.append( | |||||
| models.FieldCondition( | |||||
| key="metadata.document_id", | |||||
| match=models.MatchAny(any=document_ids_filter), | |||||
| ) | |||||
| ) | |||||
| results = self._client.search( | results = self._client.search( | ||||
| collection_name=self._collection_name, | collection_name=self._collection_name, | ||||
| query_vector=query_vector, | query_vector=query_vector, | ||||
| ), | ), | ||||
| ] | ] | ||||
| ) | ) | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| if document_ids_filter: | |||||
| if scroll_filter.must: | |||||
| scroll_filter.must.append( | |||||
| models.FieldCondition( | |||||
| key="metadata.document_id", | |||||
| match=models.MatchAny(any=document_ids_filter), | |||||
| ) | |||||
| ) | |||||
| response = self._client.scroll( | response = self._client.scroll( | ||||
| collection_name=self._collection_name, | collection_name=self._collection_name, | ||||
| scroll_filter=scroll_filter, | scroll_filter=scroll_filter, |
| return len(result) > 0 | return len(result) > 0 | ||||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| filter = kwargs.get("filter", {}) | |||||
| if document_ids_filter: | |||||
| filter["document_id"] = document_ids_filter | |||||
| results = self.similarity_search_with_score_by_vector( | results = self.similarity_search_with_score_by_vector( | ||||
| k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=kwargs.get("filter") | |||||
| k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=filter | |||||
| ) | ) | ||||
| # Organize results. | # Organize results. | ||||
| filter_condition = "" | filter_condition = "" | ||||
| if filter is not None: | if filter is not None: | ||||
| conditions = [ | conditions = [ | ||||
| f"metadata->>{key!r} in ({', '.join(map(repr, value))})" | |||||
| f"metadata->>'{key!r}' in ({', '.join(map(repr, value))})" | |||||
| if len(value) > 1 | if len(value) > 1 | ||||
| else f"metadata->>{key!r} = {value[0]!r}" | |||||
| else f"metadata->>'{key!r}' = {value[0]!r}" | |||||
| for key, value in filter.items() | for key, value in filter.items() | ||||
| ] | ] | ||||
| filter_condition = f"WHERE {' AND '.join(conditions)}" | filter_condition = f"WHERE {' AND '.join(conditions)}" |
| self._db.collection(self._collection_name).delete(document_ids=ids) | self._db.collection(self._collection_name).delete(document_ids=ids) | ||||
| def delete_by_metadata_field(self, key: str, value: str) -> None: | def delete_by_metadata_field(self, key: str, value: str) -> None: | ||||
| self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value]))) | |||||
| self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(f"metadata.{key}", [value]))) | |||||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| filter = None | |||||
| if document_ids_filter: | |||||
| filter = Filter(Filter.In("metadata.document_id", document_ids_filter)) | |||||
| res = self._db.collection(self._collection_name).search( | res = self._db.collection(self._collection_name).search( | ||||
| vectors=[query_vector], | vectors=[query_vector], | ||||
| filter=filter, | |||||
| params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)), | params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)), | ||||
| retrieve_vector=False, | retrieve_vector=False, | ||||
| limit=kwargs.get("top_k", 4), | limit=kwargs.get("top_k", 4), |
| ), | ), | ||||
| ], | ], | ||||
| ) | ) | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| if document_ids_filter: | |||||
| should_conditions = [] | |||||
| for document_id_filter in document_ids_filter: | |||||
| should_conditions.append( | |||||
| models.FieldCondition( | |||||
| key="metadata.document_id", | |||||
| match=models.MatchValue(value=document_id_filter), | |||||
| ) | |||||
| ) | |||||
| if should_conditions: | |||||
| filter.should = should_conditions # type: ignore | |||||
| results = self._client.search( | results = self._client.search( | ||||
| collection_name=self._collection_name, | collection_name=self._collection_name, | ||||
| query_vector=query_vector, | query_vector=query_vector, | ||||
| ) | ) | ||||
| ] | ] | ||||
| ) | ) | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| if document_ids_filter: | |||||
| should_conditions = [] | |||||
| for document_id_filter in document_ids_filter: | |||||
| should_conditions.append( | |||||
| models.FieldCondition( | |||||
| key="metadata.document_id", | |||||
| match=models.MatchValue(value=document_id_filter), | |||||
| ) | |||||
| ) | |||||
| if should_conditions: | |||||
| scroll_filter.should = should_conditions # type: ignore | |||||
| response = self._client.scroll( | response = self._client.scroll( | ||||
| collection_name=self._collection_name, | collection_name=self._collection_name, | ||||
| scroll_filter=scroll_filter, | scroll_filter=scroll_filter, |
| docs = [] | docs = [] | ||||
| tidb_dist_func = self._get_distance_func() | tidb_dist_func = self._get_distance_func() | ||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| where_clause = "" | |||||
| if document_ids_filter: | |||||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||||
| where_clause = f" WHERE meta->>'$.document_id' in ({document_ids}) " | |||||
| with Session(self._engine) as session: | with Session(self._engine) as session: | ||||
| select_statement = sql_text(f""" | select_statement = sql_text(f""" | ||||
| text, | text, | ||||
| {tidb_dist_func}(vector, :query_vector_str) AS distance | {tidb_dist_func}(vector, :query_vector_str) AS distance | ||||
| FROM {self._collection_name} | FROM {self._collection_name} | ||||
| {where_clause} | |||||
| ORDER BY distance ASC | ORDER BY distance ASC | ||||
| LIMIT :top_k | LIMIT :top_k | ||||
| ) t | ) t |
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | ||||
| top_k = kwargs.get("top_k", 4) | top_k = kwargs.get("top_k", 4) | ||||
| result = self.index.query(vector=query_vector, top_k=top_k, include_metadata=True, include_data=True) | |||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| if document_ids_filter: | |||||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||||
| filter = f"document_id in ({document_ids})" | |||||
| else: | |||||
| filter = "" | |||||
| result = self.index.query( | |||||
| vector=query_vector, | |||||
| top_k=top_k, | |||||
| include_metadata=True, | |||||
| include_data=True, | |||||
| include_vectors=False, | |||||
| filter=filter, | |||||
| ) | |||||
| docs = [] | docs = [] | ||||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | score_threshold = float(kwargs.get("score_threshold") or 0.0) | ||||
| for record in result: | for record in result: |
| query_vector, limit=kwargs.get("top_k", 4) | query_vector, limit=kwargs.get("top_k", 4) | ||||
| ) | ) | ||||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | score_threshold = float(kwargs.get("score_threshold") or 0.0) | ||||
| return self._get_search_res(results, score_threshold) | |||||
| docs = self._get_search_res(results, score_threshold) | |||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| if document_ids_filter: | |||||
| docs = [doc for doc in docs if doc.metadata.get("document_id") in document_ids_filter] | |||||
| return docs | |||||
| def _get_search_res(self, results, score_threshold) -> list[Document]: | def _get_search_res(self, results, score_threshold) -> list[Document]: | ||||
| if len(results) == 0: | if len(results) == 0: |
| query_obj = self._client.query.get(collection_name, properties) | query_obj = self._client.query.get(collection_name, properties) | ||||
| vector = {"vector": query_vector} | vector = {"vector": query_vector} | ||||
| if kwargs.get("where_filter"): | |||||
| query_obj = query_obj.with_where(kwargs.get("where_filter")) | |||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| if document_ids_filter: | |||||
| where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter} | |||||
| query_obj = query_obj.with_where(where_filter) | |||||
| result = ( | result = ( | ||||
| query_obj.with_near_vector(vector) | query_obj.with_near_vector(vector) | ||||
| .with_limit(kwargs.get("top_k", 4)) | .with_limit(kwargs.get("top_k", 4)) | ||||
| if kwargs.get("search_distance"): | if kwargs.get("search_distance"): | ||||
| content["certainty"] = kwargs.get("search_distance") | content["certainty"] = kwargs.get("search_distance") | ||||
| query_obj = self._client.query.get(collection_name, properties) | query_obj = self._client.query.get(collection_name, properties) | ||||
| if kwargs.get("where_filter"): | |||||
| query_obj = query_obj.with_where(kwargs.get("where_filter")) | |||||
| document_ids_filter = kwargs.get("document_ids_filter") | |||||
| if document_ids_filter: | |||||
| where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter} | |||||
| query_obj = query_obj.with_where(where_filter) | |||||
| query_obj = query_obj.with_additional(["vector"]) | query_obj = query_obj.with_additional(["vector"]) | ||||
| properties = ["text"] | properties = ["text"] | ||||
| result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do() | result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do() |
| from collections.abc import Sequence | |||||
| from typing import Literal, Optional | |||||
| from pydantic import BaseModel, Field | |||||
| SupportedComparisonOperator = Literal[ | |||||
| # for string or array | |||||
| "contains", | |||||
| "not contains", | |||||
| "start with", | |||||
| "end with", | |||||
| "is", | |||||
| "is not", | |||||
| "empty", | |||||
| "not empty", | |||||
| # for number | |||||
| "=", | |||||
| "≠", | |||||
| ">", | |||||
| "<", | |||||
| "≥", | |||||
| "≤", | |||||
| # for time | |||||
| "before", | |||||
| "after", | |||||
| ] | |||||
| class Condition(BaseModel): | |||||
| """ | |||||
| Condition detail | |||||
| """ | |||||
| name: str | |||||
| comparison_operator: SupportedComparisonOperator | |||||
| value: str | Sequence[str] | None | int | float = None | |||||
| class MetadataCondition(BaseModel): | |||||
| """ | |||||
| Metadata Condition. | |||||
| """ | |||||
| logical_operator: Optional[Literal["and", "or"]] = "and" | |||||
| conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) |
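A short sketch of constructing the condition objects defined above; the field names and values are illustrative:

    metadata_condition = MetadataCondition(
        logical_operator="or",
        conditions=[
            Condition(name="category", comparison_operator="is", value="faq"),
            Condition(name="tags", comparison_operator="contains", value=["policy", "hr"]),
        ],
    )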
| from enum import Enum | |||||
| class BuiltInField(str, Enum): | |||||
| document_name = "document_name" | |||||
| uploader = "uploader" | |||||
| upload_date = "upload_date" | |||||
| last_update_date = "last_update_date" | |||||
| source = "source" | |||||
| class MetadataDataSource(Enum): | |||||
| upload_file = "file_upload" | |||||
| website_crawl = "website" | |||||
| notion_import = "notion" |
| import json | |||||
| import math | import math | ||||
| import re | |||||
| import threading | import threading | ||||
| from collections import Counter | |||||
| from typing import Any, Optional, cast | |||||
| from collections import Counter, defaultdict | |||||
| from collections.abc import Generator, Mapping | |||||
| from typing import Any, Optional, Union, cast | |||||
| from flask import Flask, current_app | from flask import Flask, current_app | ||||
| from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity | |||||
| from sqlalchemy import Integer, and_, or_, text | |||||
| from sqlalchemy import cast as sqlalchemy_cast | |||||
| from core.app.app_config.entities import ( | |||||
| DatasetEntity, | |||||
| DatasetRetrieveConfigEntity, | |||||
| MetadataFilteringCondition, | |||||
| ModelConfig, | |||||
| ) | |||||
| from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity | from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity | ||||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | ||||
| from core.entities.agent_entities import PlanningStrategy | from core.entities.agent_entities import PlanningStrategy | ||||
| from core.entities.model_entities import ModelStatus | |||||
| from core.memory.token_buffer_memory import TokenBufferMemory | from core.memory.token_buffer_memory import TokenBufferMemory | ||||
| from core.model_manager import ModelInstance, ModelManager | from core.model_manager import ModelInstance, ModelManager | ||||
| from core.model_runtime.entities.message_entities import PromptMessageTool | |||||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage | |||||
| from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool | |||||
| from core.model_runtime.entities.model_entities import ModelFeature, ModelType | from core.model_runtime.entities.model_entities import ModelFeature, ModelType | ||||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | ||||
| from core.ops.entities.trace_entity import TraceTaskName | from core.ops.entities.trace_entity import TraceTaskName | ||||
| from core.ops.ops_trace_manager import TraceQueueManager, TraceTask | from core.ops.ops_trace_manager import TraceQueueManager, TraceTask | ||||
| from core.ops.utils import measure_time | from core.ops.utils import measure_time | ||||
| from core.prompt.advanced_prompt_transform import AdvancedPromptTransform | |||||
| from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate | |||||
| from core.prompt.simple_prompt_transform import ModelMode | |||||
| from core.rag.data_post_processor.data_post_processor import DataPostProcessor | from core.rag.data_post_processor.data_post_processor import DataPostProcessor | ||||
| from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler | from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler | ||||
| from core.rag.datasource.retrieval_service import RetrievalService | from core.rag.datasource.retrieval_service import RetrievalService | ||||
| from core.rag.entities.context_entities import DocumentContext | from core.rag.entities.context_entities import DocumentContext | ||||
| from core.rag.entities.metadata_entities import Condition, MetadataCondition | |||||
| from core.rag.index_processor.constant.index_type import IndexType | from core.rag.index_processor.constant.index_type import IndexType | ||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from core.rag.rerank.rerank_type import RerankMode | from core.rag.rerank.rerank_type import RerankMode | ||||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | from core.rag.retrieval.retrieval_methods import RetrievalMethod | ||||
| from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter | from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter | ||||
| from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter | from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter | ||||
| from core.rag.retrieval.template_prompts import ( | |||||
| METADATA_FILTER_ASSISTANT_PROMPT_1, | |||||
| METADATA_FILTER_ASSISTANT_PROMPT_2, | |||||
| METADATA_FILTER_COMPLETION_PROMPT, | |||||
| METADATA_FILTER_SYSTEM_PROMPT, | |||||
| METADATA_FILTER_USER_PROMPT_1, | |||||
| METADATA_FILTER_USER_PROMPT_2, | |||||
| METADATA_FILTER_USER_PROMPT_3, | |||||
| ) | |||||
| from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool | from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import ChildChunk, Dataset, DatasetQuery, DocumentSegment | |||||
| from libs.json_in_md_parser import parse_and_check_json_markdown | |||||
| from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment | |||||
| from models.dataset import Document as DatasetDocument | from models.dataset import Document as DatasetDocument | ||||
| from services.external_knowledge_service import ExternalDatasetService | from services.external_knowledge_service import ExternalDatasetService | ||||
| hit_callback: DatasetIndexToolCallbackHandler, | hit_callback: DatasetIndexToolCallbackHandler, | ||||
| message_id: str, | message_id: str, | ||||
| memory: Optional[TokenBufferMemory] = None, | memory: Optional[TokenBufferMemory] = None, | ||||
| inputs: Optional[Mapping[str, Any]] = None, | |||||
| ) -> Optional[str]: | ) -> Optional[str]: | ||||
| """ | """ | ||||
| Retrieve dataset. | Retrieve dataset. | ||||
| continue | continue | ||||
| available_datasets.append(dataset) | available_datasets.append(dataset) | ||||
| if inputs: | |||||
| inputs = {key: str(value) for key, value in inputs.items()} | |||||
| else: | |||||
| inputs = {} | |||||
| available_datasets_ids = [dataset.id for dataset in available_datasets] | |||||
| metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition( | |||||
| available_datasets_ids, | |||||
| query, | |||||
| tenant_id, | |||||
| user_id, | |||||
| retrieve_config.metadata_filtering_mode, # type: ignore | |||||
| retrieve_config.metadata_model_config, # type: ignore | |||||
| retrieve_config.metadata_filtering_conditions, | |||||
| inputs, | |||||
| ) | |||||
| all_documents = [] | all_documents = [] | ||||
| user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" | user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" | ||||
| if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: | if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: | ||||
| model_config, | model_config, | ||||
| planning_strategy, | planning_strategy, | ||||
| message_id, | message_id, | ||||
| metadata_filter_document_ids, | |||||
| metadata_condition, | |||||
| ) | ) | ||||
| elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: | elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: | ||||
| all_documents = self.multiple_retrieve( | all_documents = self.multiple_retrieve( | ||||
| retrieve_config.weights, | retrieve_config.weights, | ||||
| retrieve_config.reranking_enabled or True, | retrieve_config.reranking_enabled or True, | ||||
| message_id, | message_id, | ||||
| metadata_filter_document_ids, | |||||
| metadata_condition, | |||||
| ) | ) | ||||
| dify_documents = [item for item in all_documents if item.provider == "dify"] | dify_documents = [item for item in all_documents if item.provider == "dify"] | ||||
| model_config: ModelConfigWithCredentialsEntity, | model_config: ModelConfigWithCredentialsEntity, | ||||
| planning_strategy: PlanningStrategy, | planning_strategy: PlanningStrategy, | ||||
| message_id: Optional[str] = None, | message_id: Optional[str] = None, | ||||
| metadata_filter_document_ids: Optional[dict[str, list[str]]] = None, | |||||
| metadata_condition: Optional[MetadataCondition] = None, | |||||
| ): | ): | ||||
| tools = [] | tools = [] | ||||
| for dataset in available_datasets: | for dataset in available_datasets: | ||||
| dataset_id=dataset_id, | dataset_id=dataset_id, | ||||
| query=query, | query=query, | ||||
| external_retrieval_parameters=dataset.retrieval_model, | external_retrieval_parameters=dataset.retrieval_model, | ||||
| metadata_condition=metadata_condition, | |||||
| ) | ) | ||||
| for external_document in external_documents: | for external_document in external_documents: | ||||
| document = Document( | document = Document( | ||||
| document.metadata["dataset_name"] = dataset.name | document.metadata["dataset_name"] = dataset.name | ||||
| results.append(document) | results.append(document) | ||||
| else: | else: | ||||
| if metadata_condition and not metadata_filter_document_ids: | |||||
| return [] | |||||
| document_ids_filter = None | |||||
| if metadata_filter_document_ids: | |||||
| document_ids = metadata_filter_document_ids.get(dataset.id, []) | |||||
| if document_ids: | |||||
| document_ids_filter = document_ids | |||||
| else: | |||||
| return [] | |||||
| retrieval_model_config = dataset.retrieval_model or default_retrieval_model | retrieval_model_config = dataset.retrieval_model or default_retrieval_model | ||||
| # get top k | # get top k | ||||
| reranking_model=reranking_model, | reranking_model=reranking_model, | ||||
| reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"), | reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"), | ||||
| weights=retrieval_model_config.get("weights", None), | weights=retrieval_model_config.get("weights", None), | ||||
| document_ids_filter=document_ids_filter, | |||||
| ) | ) | ||||
| self._on_query(query, [dataset_id], app_id, user_from, user_id) | self._on_query(query, [dataset_id], app_id, user_from, user_id) | ||||
| weights: Optional[dict[str, Any]] = None, | weights: Optional[dict[str, Any]] = None, | ||||
| reranking_enable: bool = True, | reranking_enable: bool = True, | ||||
| message_id: Optional[str] = None, | message_id: Optional[str] = None, | ||||
| metadata_filter_document_ids: Optional[dict[str, list[str]]] = None, | |||||
| metadata_condition: Optional[MetadataCondition] = None, | |||||
| ): | ): | ||||
| if not available_datasets: | if not available_datasets: | ||||
| return [] | return [] | ||||
| for dataset in available_datasets: | for dataset in available_datasets: | ||||
| index_type = dataset.indexing_technique | index_type = dataset.indexing_technique | ||||
| document_ids_filter = None | |||||
| if dataset.provider != "external": | |||||
| if metadata_condition and not metadata_filter_document_ids: | |||||
| continue | |||||
| if metadata_filter_document_ids: | |||||
| document_ids = metadata_filter_document_ids.get(dataset.id, []) | |||||
| if document_ids: | |||||
| document_ids_filter = document_ids | |||||
| else: | |||||
| continue | |||||
| retrieval_thread = threading.Thread( | retrieval_thread = threading.Thread( | ||||
| target=self._retriever, | target=self._retriever, | ||||
| kwargs={ | kwargs={ | ||||
| "query": query, | "query": query, | ||||
| "top_k": top_k, | "top_k": top_k, | ||||
| "all_documents": all_documents, | "all_documents": all_documents, | ||||
| "document_ids_filter": document_ids_filter, | |||||
| "metadata_condition": metadata_condition, | |||||
| }, | }, | ||||
| ) | ) | ||||
| threads.append(retrieval_thread) | threads.append(retrieval_thread) | ||||
| db.session.add_all(dataset_queries) | db.session.add_all(dataset_queries) | ||||
| db.session.commit() | db.session.commit() | ||||
| def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): | |||||
| def _retriever( | |||||
| self, | |||||
| flask_app: Flask, | |||||
| dataset_id: str, | |||||
| query: str, | |||||
| top_k: int, | |||||
| all_documents: list, | |||||
| document_ids_filter: Optional[list[str]] = None, | |||||
| metadata_condition: Optional[MetadataCondition] = None, | |||||
| ): | |||||
| with flask_app.app_context(): | with flask_app.app_context(): | ||||
| dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | ||||
| dataset_id=dataset_id, | dataset_id=dataset_id, | ||||
| query=query, | query=query, | ||||
| external_retrieval_parameters=dataset.retrieval_model, | external_retrieval_parameters=dataset.retrieval_model, | ||||
| metadata_condition=metadata_condition, | |||||
| ) | ) | ||||
| for external_document in external_documents: | for external_document in external_documents: | ||||
| document = Document( | document = Document( | ||||
| else None, | else None, | ||||
| reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", | reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", | ||||
| weights=retrieval_model.get("weights", None), | weights=retrieval_model.get("weights", None), | ||||
| document_ids_filter=document_ids_filter, | |||||
| ) | ) | ||||
| all_documents.extend(documents) | all_documents.extend(documents) | ||||
| filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True | filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True | ||||
| ) | ) | ||||
| return filter_documents[:top_k] if top_k else filter_documents | return filter_documents[:top_k] if top_k else filter_documents | ||||
| def _get_metadata_filter_condition( | |||||
| self, | |||||
| dataset_ids: list, | |||||
| query: str, | |||||
| tenant_id: str, | |||||
| user_id: str, | |||||
| metadata_filtering_mode: str, | |||||
| metadata_model_config: ModelConfig, | |||||
| metadata_filtering_conditions: Optional[MetadataFilteringCondition], | |||||
| inputs: dict, | |||||
| ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: | |||||
| document_query = db.session.query(DatasetDocument).filter( | |||||
| DatasetDocument.dataset_id.in_(dataset_ids), | |||||
| DatasetDocument.indexing_status == "completed", | |||||
| DatasetDocument.enabled == True, | |||||
| DatasetDocument.archived == False, | |||||
| ) | |||||
| filters = [] # type: ignore | |||||
| metadata_condition = None | |||||
| if metadata_filtering_mode == "disabled": | |||||
| return None, None | |||||
| elif metadata_filtering_mode == "automatic": | |||||
| automatic_metadata_filters = self._automatic_metadata_filter_func( | |||||
| dataset_ids, query, tenant_id, user_id, metadata_model_config | |||||
| ) | |||||
| if automatic_metadata_filters: | |||||
| conditions = [] | |||||
| for filter in automatic_metadata_filters: | |||||
| self._process_metadata_filter_func( | |||||
| filter.get("condition"), # type: ignore | |||||
| filter.get("metadata_name"), # type: ignore | |||||
| filter.get("value"), | |||||
| filters, # type: ignore | |||||
| ) | |||||
| conditions.append( | |||||
| Condition( | |||||
| name=filter.get("metadata_name"), # type: ignore | |||||
| comparison_operator=filter.get("condition"), # type: ignore | |||||
| value=filter.get("value"), | |||||
| ) | |||||
| ) | |||||
| metadata_condition = MetadataCondition( | |||||
| logical_operator=metadata_filtering_conditions.logical_operator if metadata_filtering_conditions else "and", | |||||
| conditions=conditions, | |||||
| ) | |||||
| elif metadata_filtering_mode == "manual": | |||||
| if metadata_filtering_conditions: | |||||
| metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump()) | |||||
| for condition in metadata_filtering_conditions.conditions: # type: ignore | |||||
| metadata_name = condition.name | |||||
| expected_value = condition.value | |||||
| if expected_value or condition.comparison_operator in ("empty", "not empty"): | |||||
| if isinstance(expected_value, str): | |||||
| expected_value = self._replace_metadata_filter_value(expected_value, inputs) | |||||
| filters = self._process_metadata_filter_func( | |||||
| condition.comparison_operator, metadata_name, expected_value, filters | |||||
| ) | |||||
| else: | |||||
| raise ValueError("Invalid metadata filtering mode") | |||||
| if filters: | |||||
| if metadata_filtering_conditions and metadata_filtering_conditions.logical_operator == "or": | |||||
| document_query = document_query.filter(or_(*filters)) | |||||
| else: | |||||
| document_query = document_query.filter(and_(*filters)) | |||||
| documents = document_query.all() | |||||
| # group by dataset_id | |||||
| metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore | |||||
| for document in documents: | |||||
| metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore | |||||
| return metadata_filter_document_ids, metadata_condition | |||||
| def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str: | |||||
| def replacer(match): | |||||
| key = match.group(1) | |||||
| return str(inputs.get(key, f"{{{{{key}}}}}")) | |||||
| pattern = re.compile(r"\{\{(\w+)\}\}") | |||||
| return pattern.sub(replacer, text) | |||||
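In manual mode, `_replace_metadata_filter_value` only resolves `{{variable}}` placeholders that exist in `inputs` and leaves unknown ones untouched. A minimal standalone sketch of that behavior (the function name is a local copy and the sample values are invented):

```python
import re


def replace_metadata_filter_value(text: str, inputs: dict) -> str:
    # Substitute {{key}} with str(inputs[key]); keep the literal placeholder when the key is missing.
    def replacer(match):
        key = match.group(1)
        return str(inputs.get(key, f"{{{{{key}}}}}"))

    return re.compile(r"\{\{(\w+)\}\}").sub(replacer, text)


print(replace_metadata_filter_value("{{year}}-report", {"year": 2024}))  # 2024-report
print(replace_metadata_filter_value("{{missing}}", {}))                  # {{missing}}
```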
| def _automatic_metadata_filter_func( | |||||
| self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig | |||||
| ) -> Optional[list[dict[str, Any]]]: | |||||
| # get all metadata field | |||||
| metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all() | |||||
| all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] | |||||
| # get metadata model config | |||||
| if metadata_model_config is None: | |||||
| raise ValueError("metadata_model_config is required") | |||||
| # get metadata model instance | |||||
| # fetch model config | |||||
| model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config) | |||||
| # fetch prompt messages | |||||
| prompt_messages, stop = self._get_prompt_template( | |||||
| model_config=model_config, | |||||
| mode=metadata_model_config.mode, | |||||
| metadata_fields=all_metadata_fields, | |||||
| query=query or "", | |||||
| ) | |||||
| result_text = "" | |||||
| try: | |||||
| # handle invoke result | |||||
| invoke_result = cast( | |||||
| Generator[LLMResult, None, None], | |||||
| model_instance.invoke_llm( | |||||
| prompt_messages=prompt_messages, | |||||
| model_parameters=model_config.parameters, | |||||
| stop=stop, | |||||
| stream=True, | |||||
| user=user_id, | |||||
| ), | |||||
| ) | |||||
| # handle invoke result | |||||
| result_text, usage = self._handle_invoke_result(invoke_result=invoke_result) | |||||
| result_text_json = parse_and_check_json_markdown(result_text, []) | |||||
| automatic_metadata_filters = [] | |||||
| if "metadata_map" in result_text_json: | |||||
| metadata_map = result_text_json["metadata_map"] | |||||
| for item in metadata_map: | |||||
| if item.get("metadata_field_name") in all_metadata_fields: | |||||
| automatic_metadata_filters.append( | |||||
| { | |||||
| "metadata_name": item.get("metadata_field_name"), | |||||
| "value": item.get("metadata_field_value"), | |||||
| "condition": item.get("comparison_operator"), | |||||
| } | |||||
| ) | |||||
| except Exception as e: | |||||
| return None | |||||
| return automatic_metadata_filters | |||||
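The automatic path depends on the model replying with a fenced json code block shaped like the few-shot examples in the prompt templates: a `metadata_map` list of `{metadata_field_name, metadata_field_value, comparison_operator}` objects. A rough sketch of the post-processing, using plain `json` as a stand-in for `parse_and_check_json_markdown` (whose exact behavior is assumed here, not shown in this diff):

```python
import json

# Hypothetical raw completion, shaped like METADATA_FILTER_ASSISTANT_PROMPT_2.
result_text = '```json\n{"metadata_map": [{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}]}\n```'

all_metadata_fields = ["name", "year", "rating", "country"]

# Simplified stand-in for parse_and_check_json_markdown: strip the fence and parse.
payload = json.loads(result_text.strip().removeprefix("```json").removesuffix("```"))

automatic_metadata_filters = [
    {
        "metadata_name": item.get("metadata_field_name"),
        "value": item.get("metadata_field_value"),
        "condition": item.get("comparison_operator"),
    }
    for item in payload.get("metadata_map", [])
    if item.get("metadata_field_name") in all_metadata_fields
]
print(automatic_metadata_filters)
# [{'metadata_name': 'year', 'value': '2024', 'condition': '='}]
```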
| def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: Optional[Any], filters: list): | |||||
| match condition: | |||||
| case "contains": | |||||
| filters.append( | |||||
| (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}%") | |||||
| ) | |||||
| case "not contains": | |||||
| filters.append( | |||||
| (text("documents.doc_metadata ->> :key NOT LIKE :value")).params( | |||||
| key=metadata_name, value=f"%{value}%" | |||||
| ) | |||||
| ) | |||||
| case "start with": | |||||
| filters.append( | |||||
| (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"{value}%") | |||||
| ) | |||||
| case "end with": | |||||
| filters.append( | |||||
| (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}") | |||||
| ) | |||||
| case "is" | "=": | |||||
| if isinstance(value, str): | |||||
| filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"') | |||||
| else: | |||||
| filters.append( | |||||
| sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) == value | |||||
| ) | |||||
| case "is not" | "≠": | |||||
| if isinstance(value, str): | |||||
| filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"') | |||||
| else: | |||||
| filters.append( | |||||
| sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) != value | |||||
| ) | |||||
| case "empty": | |||||
| filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None)) | |||||
| case "not empty": | |||||
| filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None)) | |||||
| case "before" | "<": | |||||
| filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) < value) | |||||
| case "after" | ">": | |||||
| filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) > value) | |||||
| case "≤" | ">=": | |||||
| filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) <= value) | |||||
| case "≥" | ">=": | |||||
| filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) >= value) | |||||
| case _: | |||||
| pass | |||||
| return filters | |||||
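String conditions become parameterized `->>` text clauses against the `documents.doc_metadata` JSONB column, while numeric ones cast the JSON value to `Integer`. A small illustration of the two shapes the filters take, assuming the same column layout as the models in this PR (the sample key and value are invented):

```python
from sqlalchemy import text

filters = []

# "contains" style: a bound-parameter LIKE on the JSONB text value (no string interpolation).
filters.append(
    text("documents.doc_metadata ->> :key LIKE :value").params(key="country", value="%US%")
)

# Numeric style, e.g. "rating > 9", where DatasetDocument stands for the mapped model above:
# filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata["rating"].astext, Integer) > 9)

# The caller then combines everything according to logical_operator:
# document_query = document_query.filter(and_(*filters))  # or or_(*filters)
```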
| def _fetch_model_config( | |||||
| self, tenant_id: str, model: ModelConfig | |||||
| ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: | |||||
| """ | |||||
| Fetch model config | |||||
| :param model: model config | |||||
| :return: | |||||
| """ | |||||
| if model is None: | |||||
| raise ValueError("single_retrieval_config is required") | |||||
| model_name = model.name | |||||
| provider_name = model.provider | |||||
| model_manager = ModelManager() | |||||
| model_instance = model_manager.get_model_instance( | |||||
| tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name | |||||
| ) | |||||
| provider_model_bundle = model_instance.provider_model_bundle | |||||
| model_type_instance = model_instance.model_type_instance | |||||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||||
| model_credentials = model_instance.credentials | |||||
| # check model | |||||
| provider_model = provider_model_bundle.configuration.get_provider_model( | |||||
| model=model_name, model_type=ModelType.LLM | |||||
| ) | |||||
| if provider_model is None: | |||||
| raise ValueError(f"Model {model_name} not exist.") | |||||
| if provider_model.status == ModelStatus.NO_CONFIGURE: | |||||
| raise ValueError(f"Model {model_name} credentials is not initialized.") | |||||
| elif provider_model.status == ModelStatus.NO_PERMISSION: | |||||
| raise ValueError(f"Dify Hosted OpenAI {model_name} currently not support.") | |||||
| elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: | |||||
| raise ValueError(f"Model provider {provider_name} quota exceeded.") | |||||
| # model config | |||||
| completion_params = model.completion_params | |||||
| stop = [] | |||||
| if "stop" in completion_params: | |||||
| stop = completion_params["stop"] | |||||
| del completion_params["stop"] | |||||
| # get model mode | |||||
| model_mode = model.mode | |||||
| if not model_mode: | |||||
| raise ValueError("LLM mode is required.") | |||||
| model_schema = model_type_instance.get_model_schema(model_name, model_credentials) | |||||
| if not model_schema: | |||||
| raise ValueError(f"Model {model_name} not exist.") | |||||
| return model_instance, ModelConfigWithCredentialsEntity( | |||||
| provider=provider_name, | |||||
| model=model_name, | |||||
| model_schema=model_schema, | |||||
| mode=model_mode, | |||||
| provider_model_bundle=provider_model_bundle, | |||||
| credentials=model_credentials, | |||||
| parameters=completion_params, | |||||
| stop=stop, | |||||
| ) | |||||
| def _get_prompt_template( | |||||
| self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str | |||||
| ): | |||||
| model_mode = ModelMode.value_of(mode) | |||||
| input_text = query | |||||
| prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]] | |||||
| if model_mode == ModelMode.CHAT: | |||||
| prompt_template = [] | |||||
| system_prompt_messages = ChatModelMessage(role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT) | |||||
| prompt_template.append(system_prompt_messages) | |||||
| user_prompt_message_1 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1) | |||||
| prompt_template.append(user_prompt_message_1) | |||||
| assistant_prompt_message_1 = ChatModelMessage( | |||||
| role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1 | |||||
| ) | |||||
| prompt_template.append(assistant_prompt_message_1) | |||||
| user_prompt_message_2 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2) | |||||
| prompt_template.append(user_prompt_message_2) | |||||
| assistant_prompt_message_2 = ChatModelMessage( | |||||
| role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2 | |||||
| ) | |||||
| prompt_template.append(assistant_prompt_message_2) | |||||
| user_prompt_message_3 = ChatModelMessage( | |||||
| role=PromptMessageRole.USER, | |||||
| text=METADATA_FILTER_USER_PROMPT_3.format( | |||||
| input_text=input_text, | |||||
| metadata_fields=json.dumps(metadata_fields, ensure_ascii=False), | |||||
| ), | |||||
| ) | |||||
| prompt_template.append(user_prompt_message_3) | |||||
| elif model_mode == ModelMode.COMPLETION: | |||||
| prompt_template = CompletionModelPromptTemplate( | |||||
| text=METADATA_FILTER_COMPLETION_PROMPT.format( | |||||
| input_text=input_text, | |||||
| metadata_fields=json.dumps(metadata_fields, ensure_ascii=False), | |||||
| ) | |||||
| ) | |||||
| else: | |||||
| raise ValueError(f"Model mode {model_mode} not support.") | |||||
| prompt_transform = AdvancedPromptTransform() | |||||
| prompt_messages = prompt_transform.get_prompt( | |||||
| prompt_template=prompt_template, | |||||
| inputs={}, | |||||
| query=query or "", | |||||
| files=[], | |||||
| context=None, | |||||
| memory_config=None, | |||||
| memory=None, | |||||
| model_config=model_config, | |||||
| ) | |||||
| stop = model_config.stop | |||||
| return prompt_messages, stop | |||||
| def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]: | |||||
| """ | |||||
| Handle invoke result | |||||
| :param invoke_result: invoke result | |||||
| :return: | |||||
| """ | |||||
| model = None | |||||
| prompt_messages: list[PromptMessage] = [] | |||||
| full_text = "" | |||||
| usage = None | |||||
| for result in invoke_result: | |||||
| text = result.delta.message.content | |||||
| full_text += text | |||||
| if not model: | |||||
| model = result.model | |||||
| if not prompt_messages: | |||||
| prompt_messages = result.prompt_messages | |||||
| if not usage and result.delta.usage: | |||||
| usage = result.delta.usage | |||||
| if not usage: | |||||
| usage = LLMUsage.empty_usage() | |||||
| return full_text, usage |
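Taking the two retrieval paths together, `_get_metadata_filter_condition` returns a per-dataset map of matching document ids plus a `MetadataCondition` that is forwarded to external knowledge bases; each internal dataset then either gets a `document_ids_filter` or is skipped. A compressed sketch of that per-dataset decision (the helper `resolve_document_ids_filter` is introduced here purely for illustration; the logic mirrors the code above):

```python
from typing import Optional


def resolve_document_ids_filter(
    dataset_id: str,
    metadata_filter_document_ids: Optional[dict[str, list[str]]],
    has_metadata_condition: bool,
) -> tuple[bool, Optional[list[str]]]:
    """Return (skip_dataset, document_ids_filter) for one internal dataset."""
    # A metadata condition was built but no document matched anywhere: skip retrieval.
    if has_metadata_condition and not metadata_filter_document_ids:
        return True, None
    if metadata_filter_document_ids:
        document_ids = metadata_filter_document_ids.get(dataset_id, [])
        # No match in this dataset; other datasets may still have matches.
        if not document_ids:
            return True, None
        return False, document_ids
    # Metadata filtering disabled or produced no condition: retrieve without a filter.
    return False, None


print(resolve_document_ids_filter("a", {"a": ["doc-1"]}, True))  # (False, ['doc-1'])
print(resolve_document_ids_filter("b", {"a": ["doc-1"]}, True))  # (True, None)
```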
| METADATA_FILTER_SYSTEM_PROMPT = """ | |||||
| ### Job Description | |||||
| You are a text metadata extraction engine that extracts the text's metadata based on user input and sets the metadata values. | |||||
| ### Task | |||||
| Your task is to ONLY extract the metadata fields that exist in the input text from the provided metadata list, use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, and then return the result in JSON format with the key "metadata_field_name", the value "metadata_field_value", and the comparison operator "comparison_operator". | |||||
| ### Format | |||||
| The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields. | |||||
| ### Constraint | |||||
| DO NOT include anything other than the JSON array in your response. | |||||
| """ # noqa: E501 | |||||
| METADATA_FILTER_USER_PROMPT_1 = """ | |||||
| { "input_text": "I want to know which company’s email address test@example.com is?", | |||||
| "metadata_fields": ["filename", "email", "phone", "address"] | |||||
| } | |||||
| """ | |||||
| METADATA_FILTER_ASSISTANT_PROMPT_1 = """ | |||||
| ```json | |||||
| {"metadata_map": [ | |||||
| {"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="} | |||||
| ] | |||||
| } | |||||
| ``` | |||||
| """ | |||||
| METADATA_FILTER_USER_PROMPT_2 = """ | |||||
| {"input_text": "What are the movies with a score of more than 9 in 2024?", | |||||
| "metadata_fields": ["name", "year", "rating", "country"]} | |||||
| """ | |||||
| METADATA_FILTER_ASSISTANT_PROMPT_2 = """ | |||||
| ```json | |||||
| {"metadata_map": [ | |||||
| {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, | |||||
| {"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}, | |||||
| ]} | |||||
| ``` | |||||
| """ | |||||
| METADATA_FILTER_USER_PROMPT_3 = """ | |||||
| '{{"input_text": "{input_text}",', | |||||
| '"metadata_fields": {metadata_fields}}}' | |||||
| """ | |||||
| METADATA_FILTER_COMPLETION_PROMPT = """ | |||||
| ### Job Description | |||||
| You are a text metadata extraction engine that extracts the text's metadata based on user input and sets the metadata values. | |||||
| ### Task | |||||
| Your task is to ONLY extract the metadata fields that exist in the input text from the provided metadata list, use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, and then return the result in JSON format with the key "metadata_field_name", the value "metadata_field_value", and the comparison operator "comparison_operator". | |||||
| ### Format | |||||
| The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields. | |||||
| ### Constraint | |||||
| DO NOT include anything other than the JSON array in your response. | |||||
| ### Example | |||||
| Here is the chat example between human and assistant, inside <example></example> XML tags. | |||||
| <example> | |||||
| User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}} | |||||
| Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}} | |||||
| User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}} | |||||
| Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}} | |||||
| </example> | |||||
| ### User Input | |||||
| {{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}} | |||||
| ### Assistant Output | |||||
| """ # noqa: E501 |
| from collections.abc import Sequence | |||||
| from typing import Any, Literal, Optional | from typing import Any, Literal, Optional | ||||
| from pydantic import BaseModel | |||||
| from pydantic import BaseModel, Field | |||||
| from core.workflow.nodes.base import BaseNodeData | from core.workflow.nodes.base import BaseNodeData | ||||
| from core.workflow.nodes.llm.entities import VisionConfig | |||||
| class RerankingModelConfig(BaseModel): | class RerankingModelConfig(BaseModel): | ||||
| model: ModelConfig | model: ModelConfig | ||||
| SupportedComparisonOperator = Literal[ | |||||
| # for string or array | |||||
| "contains", | |||||
| "not contains", | |||||
| "start with", | |||||
| "end with", | |||||
| "is", | |||||
| "is not", | |||||
| "empty", | |||||
| "not empty", | |||||
| # for number | |||||
| "=", | |||||
| "≠", | |||||
| ">", | |||||
| "<", | |||||
| "≥", | |||||
| "≤", | |||||
| # for time | |||||
| "before", | |||||
| "after", | |||||
| ] | |||||
| class Condition(BaseModel): | |||||
| """ | |||||
| Condition detail | |||||
| """ | |||||
| name: str | |||||
| comparison_operator: SupportedComparisonOperator | |||||
| value: str | Sequence[str] | None | int | float = None | |||||
| class MetadataFilteringCondition(BaseModel): | |||||
| """ | |||||
| Metadata Filtering Condition. | |||||
| """ | |||||
| logical_operator: Optional[Literal["and", "or"]] = "and" | |||||
| conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) | |||||
| class KnowledgeRetrievalNodeData(BaseNodeData): | class KnowledgeRetrievalNodeData(BaseNodeData): | ||||
| """ | """ | ||||
| Knowledge retrieval Node Data. | Knowledge retrieval Node Data. | ||||
| retrieval_mode: Literal["single", "multiple"] | retrieval_mode: Literal["single", "multiple"] | ||||
| multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None | multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None | ||||
| single_retrieval_config: Optional[SingleRetrievalConfig] = None | single_retrieval_config: Optional[SingleRetrievalConfig] = None | ||||
| metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled" | |||||
| metadata_model_config: Optional[ModelConfig] = None | |||||
| metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None | |||||
| vision: VisionConfig = Field(default_factory=VisionConfig) |
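A minimal sketch of the manual-mode configuration these entities accept (the `Condition` / `MetadataFilteringCondition` definitions are copied from above so the snippet is self-contained; the field names and values in the example are invented):

```python
from collections.abc import Sequence
from typing import Literal, Optional

from pydantic import BaseModel, Field

SupportedComparisonOperator = Literal[
    "contains", "not contains", "start with", "end with", "is", "is not",
    "empty", "not empty", "=", "≠", ">", "<", "≥", "≤", "before", "after",
]


class Condition(BaseModel):
    name: str
    comparison_operator: SupportedComparisonOperator
    value: str | Sequence[str] | None | int | float = None


class MetadataFilteringCondition(BaseModel):
    logical_operator: Optional[Literal["and", "or"]] = "and"
    conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)


# Manual mode: documents whose "year" equals 2024 AND whose "rating" is greater than 9.
filtering = MetadataFilteringCondition(
    logical_operator="and",
    conditions=[
        Condition(name="year", comparison_operator="=", value=2024),
        Condition(name="rating", comparison_operator=">", value=9),
    ],
)
print(filtering.model_dump())
```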
| class ModelQuotaExceededError(KnowledgeRetrievalNodeError): | class ModelQuotaExceededError(KnowledgeRetrievalNodeError): | ||||
| """Raised when the model provider quota is exceeded.""" | """Raised when the model provider quota is exceeded.""" | ||||
| class InvalidModelTypeError(KnowledgeRetrievalNodeError): | |||||
| """Raised when the model is not a Large Language Model.""" |
| import json | |||||
| import logging | import logging | ||||
| import time | import time | ||||
| from collections import defaultdict | |||||
| from collections.abc import Mapping, Sequence | from collections.abc import Mapping, Sequence | ||||
| from typing import Any, cast | |||||
| from typing import Any, Optional, cast | |||||
| from sqlalchemy import func | |||||
| from sqlalchemy import Integer, and_, func, or_, text | |||||
| from sqlalchemy import cast as sqlalchemy_cast | |||||
| from core.app.app_config.entities import DatasetRetrieveConfigEntity | from core.app.app_config.entities import DatasetRetrieveConfigEntity | ||||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | ||||
| from core.entities.agent_entities import PlanningStrategy | from core.entities.agent_entities import PlanningStrategy | ||||
| from core.entities.model_entities import ModelStatus | from core.entities.model_entities import ModelStatus | ||||
| from core.model_manager import ModelInstance, ModelManager | from core.model_manager import ModelInstance, ModelManager | ||||
| from core.model_runtime.entities.message_entities import PromptMessageRole | |||||
| from core.model_runtime.entities.model_entities import ModelFeature, ModelType | from core.model_runtime.entities.model_entities import ModelFeature, ModelType | ||||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | ||||
| from core.prompt.simple_prompt_transform import ModelMode | |||||
| from core.rag.datasource.retrieval_service import RetrievalService | from core.rag.datasource.retrieval_service import RetrievalService | ||||
| from core.rag.entities.metadata_entities import Condition, MetadataCondition | |||||
| from core.rag.retrieval.dataset_retrieval import DatasetRetrieval | from core.rag.retrieval.dataset_retrieval import DatasetRetrieval | ||||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | from core.rag.retrieval.retrieval_methods import RetrievalMethod | ||||
| from core.variables import StringSegment | from core.variables import StringSegment | ||||
| from core.workflow.entities.node_entities import NodeRunResult | from core.workflow.entities.node_entities import NodeRunResult | ||||
| from core.workflow.nodes.base import BaseNode | |||||
| from core.workflow.nodes.enums import NodeType | from core.workflow.nodes.enums import NodeType | ||||
| from core.workflow.nodes.event.event import ModelInvokeCompletedEvent | |||||
| from core.workflow.nodes.knowledge_retrieval.template_prompts import ( | |||||
| METADATA_FILTER_ASSISTANT_PROMPT_1, | |||||
| METADATA_FILTER_ASSISTANT_PROMPT_2, | |||||
| METADATA_FILTER_COMPLETION_PROMPT, | |||||
| METADATA_FILTER_SYSTEM_PROMPT, | |||||
| METADATA_FILTER_USER_PROMPT_1, | |||||
| METADATA_FILTER_USER_PROMPT_2, | |||||
| METADATA_FILTER_USER_PROMPT_3, | |||||
| ) | |||||
| from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate | |||||
| from core.workflow.nodes.llm.node import LLMNode | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from models.dataset import Dataset, Document, RateLimitLog | |||||
| from libs.json_in_md_parser import parse_and_check_json_markdown | |||||
| from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog | |||||
| from models.workflow import WorkflowNodeExecutionStatus | from models.workflow import WorkflowNodeExecutionStatus | ||||
| from services.feature_service import FeatureService | from services.feature_service import FeatureService | ||||
| from .entities import KnowledgeRetrievalNodeData | |||||
| from .entities import KnowledgeRetrievalNodeData, ModelConfig | |||||
| from .exc import ( | from .exc import ( | ||||
| InvalidModelTypeError, | |||||
| KnowledgeRetrievalNodeError, | KnowledgeRetrievalNodeError, | ||||
| ModelCredentialsNotInitializedError, | ModelCredentialsNotInitializedError, | ||||
| ModelNotExistError, | ModelNotExistError, | ||||
| } | } | ||||
| class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): | |||||
| _node_data_cls = KnowledgeRetrievalNodeData | |||||
| class KnowledgeRetrievalNode(LLMNode): | |||||
| _node_data_cls = KnowledgeRetrievalNodeData # type: ignore | |||||
| _node_type = NodeType.KNOWLEDGE_RETRIEVAL | _node_type = NodeType.KNOWLEDGE_RETRIEVAL | ||||
| def _run(self) -> NodeRunResult: | |||||
| def _run(self) -> NodeRunResult: # type: ignore | |||||
| node_data = cast(KnowledgeRetrievalNodeData, self.node_data) | |||||
| # extract variables | # extract variables | ||||
| variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector) | |||||
| variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector) | |||||
| if not isinstance(variable, StringSegment): | if not isinstance(variable, StringSegment): | ||||
| return NodeRunResult( | return NodeRunResult( | ||||
| status=WorkflowNodeExecutionStatus.FAILED, | status=WorkflowNodeExecutionStatus.FAILED, | ||||
| # retrieve knowledge | # retrieve knowledge | ||||
| try: | try: | ||||
| results = self._fetch_dataset_retriever(node_data=self.node_data, query=query) | |||||
| results = self._fetch_dataset_retriever(node_data=node_data, query=query) | |||||
| outputs = {"result": results} | outputs = {"result": results} | ||||
| return NodeRunResult( | return NodeRunResult( | ||||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs | status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs | ||||
| if not dataset: | if not dataset: | ||||
| continue | continue | ||||
| available_datasets.append(dataset) | available_datasets.append(dataset) | ||||
| metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition( | |||||
| [dataset.id for dataset in available_datasets], query, node_data | |||||
| ) | |||||
| all_documents = [] | all_documents = [] | ||||
| dataset_retrieval = DatasetRetrieval() | dataset_retrieval = DatasetRetrieval() | ||||
| if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value: | if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value: | ||||
| # fetch model config | # fetch model config | ||||
| model_instance, model_config = self._fetch_model_config(node_data) | |||||
| model_instance, model_config = self._fetch_model_config(node_data.single_retrieval_config.model) # type: ignore | |||||
| # check model is support tool calling | # check model is support tool calling | ||||
| model_type_instance = model_config.provider_model_bundle.model_type_instance | model_type_instance = model_config.provider_model_bundle.model_type_instance | ||||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | model_type_instance = cast(LargeLanguageModel, model_type_instance) | ||||
| model_config=model_config, | model_config=model_config, | ||||
| model_instance=model_instance, | model_instance=model_instance, | ||||
| planning_strategy=planning_strategy, | planning_strategy=planning_strategy, | ||||
| metadata_filter_document_ids=metadata_filter_document_ids, | |||||
| metadata_condition=metadata_condition, | |||||
| ) | ) | ||||
| elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: | elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: | ||||
| if node_data.multiple_retrieval_config is None: | if node_data.multiple_retrieval_config is None: | ||||
| reranking_model=reranking_model, | reranking_model=reranking_model, | ||||
| weights=weights, | weights=weights, | ||||
| reranking_enable=node_data.multiple_retrieval_config.reranking_enable, | reranking_enable=node_data.multiple_retrieval_config.reranking_enable, | ||||
| metadata_filter_document_ids=metadata_filter_document_ids, | |||||
| metadata_condition=metadata_condition, | |||||
| ) | ) | ||||
| dify_documents = [item for item in all_documents if item.provider == "dify"] | dify_documents = [item for item in all_documents if item.provider == "dify"] | ||||
| external_documents = [item for item in all_documents if item.provider == "external"] | external_documents = [item for item in all_documents if item.provider == "external"] | ||||
| item["metadata"]["position"] = position | item["metadata"]["position"] = position | ||||
| return retrieval_resource_list | return retrieval_resource_list | ||||
| def _get_metadata_filter_condition( | |||||
| self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData | |||||
| ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: | |||||
| document_query = db.session.query(Document).filter( | |||||
| Document.dataset_id.in_(dataset_ids), | |||||
| Document.indexing_status == "completed", | |||||
| Document.enabled == True, | |||||
| Document.archived == False, | |||||
| ) | |||||
| filters = [] # type: ignore | |||||
| metadata_condition = None | |||||
| if node_data.metadata_filtering_mode == "disabled": | |||||
| return None, None | |||||
| elif node_data.metadata_filtering_mode == "automatic": | |||||
| automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data) | |||||
| if automatic_metadata_filters: | |||||
| conditions = [] | |||||
| for filter in automatic_metadata_filters: | |||||
| self._process_metadata_filter_func( | |||||
| filter.get("condition", ""), | |||||
| filter.get("metadata_name", ""), | |||||
| filter.get("value"), | |||||
| filters, # type: ignore | |||||
| ) | |||||
| conditions.append( | |||||
| Condition( | |||||
| name=filter.get("metadata_name"), # type: ignore | |||||
| comparison_operator=filter.get("condition"), # type: ignore | |||||
| value=filter.get("value"), | |||||
| ) | |||||
| ) | |||||
| metadata_condition = MetadataCondition( | |||||
| logical_operator=node_data.metadata_filtering_conditions.logical_operator if node_data.metadata_filtering_conditions else "and", | |||||
| conditions=conditions, | |||||
| ) | |||||
| elif node_data.metadata_filtering_mode == "manual": | |||||
| if node_data.metadata_filtering_conditions: | |||||
| metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump()) | |||||
| if node_data.metadata_filtering_conditions: | |||||
| for condition in node_data.metadata_filtering_conditions.conditions: # type: ignore | |||||
| metadata_name = condition.name | |||||
| expected_value = condition.value | |||||
| if expected_value or condition.comparison_operator in ("empty", "not empty"): | |||||
| if isinstance(expected_value, str): | |||||
| expected_value = self.graph_runtime_state.variable_pool.convert_template( | |||||
| expected_value | |||||
| ).text | |||||
| filters = self._process_metadata_filter_func( | |||||
| condition.comparison_operator, metadata_name, expected_value, filters | |||||
| ) | |||||
| else: | |||||
| raise ValueError("Invalid metadata filtering mode") | |||||
| if filters: | |||||
| if node_data.metadata_filtering_conditions and node_data.metadata_filtering_conditions.logical_operator == "or": | |||||
| document_query = document_query.filter(or_(*filters)) | |||||
| else: | |||||
| document_query = document_query.filter(and_(*filters)) | |||||
| documents = document_query.all() | |||||
| # group by dataset_id | |||||
| metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore | |||||
| for document in documents: | |||||
| metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore | |||||
| return metadata_filter_document_ids, metadata_condition | |||||
| def _automatic_metadata_filter_func( | |||||
| self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData | |||||
| ) -> list[dict[str, Any]]: | |||||
| # get all metadata field | |||||
| metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all() | |||||
| all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] | |||||
| # get metadata model config | |||||
| metadata_model_config = node_data.metadata_model_config | |||||
| if metadata_model_config is None: | |||||
| raise ValueError("metadata_model_config is required") | |||||
| # get metadata model instance | |||||
| # fetch model config | |||||
| model_instance, model_config = self._fetch_model_config(node_data.metadata_model_config) # type: ignore | |||||
| # fetch prompt messages | |||||
| prompt_template = self._get_prompt_template( | |||||
| node_data=node_data, | |||||
| metadata_fields=all_metadata_fields, | |||||
| query=query or "", | |||||
| ) | |||||
| prompt_messages, stop = self._fetch_prompt_messages( | |||||
| prompt_template=prompt_template, | |||||
| sys_query=query, | |||||
| memory=None, | |||||
| model_config=model_config, | |||||
| sys_files=[], | |||||
| vision_enabled=node_data.vision.enabled, | |||||
| vision_detail=node_data.vision.configs.detail, | |||||
| variable_pool=self.graph_runtime_state.variable_pool, | |||||
| jinja2_variables=[], | |||||
| ) | |||||
| result_text = "" | |||||
| try: | |||||
| # handle invoke result | |||||
| generator = self._invoke_llm( | |||||
| node_data_model=node_data.metadata_model_config, # type: ignore | |||||
| model_instance=model_instance, | |||||
| prompt_messages=prompt_messages, | |||||
| stop=stop, | |||||
| ) | |||||
| for event in generator: | |||||
| if isinstance(event, ModelInvokeCompletedEvent): | |||||
| result_text = event.text | |||||
| break | |||||
| result_text_json = parse_and_check_json_markdown(result_text, []) | |||||
| automatic_metadata_filters = [] | |||||
| if "metadata_map" in result_text_json: | |||||
| metadata_map = result_text_json["metadata_map"] | |||||
| for item in metadata_map: | |||||
| if item.get("metadata_field_name") in all_metadata_fields: | |||||
| automatic_metadata_filters.append( | |||||
| { | |||||
| "metadata_name": item.get("metadata_field_name"), | |||||
| "value": item.get("metadata_field_value"), | |||||
| "condition": item.get("comparison_operator"), | |||||
| } | |||||
| ) | |||||
| except Exception as e: | |||||
| return [] | |||||
| return automatic_metadata_filters | |||||
| def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: Optional[str], filters: list): | |||||
| match condition: | |||||
| case "contains": | |||||
| filters.append( | |||||
| (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}%") | |||||
| ) | |||||
| case "not contains": | |||||
| filters.append( | |||||
| (text("documents.doc_metadata ->> :key NOT LIKE :value")).params( | |||||
| key=metadata_name, value=f"%{value}%" | |||||
| ) | |||||
| ) | |||||
| case "start with": | |||||
| filters.append( | |||||
| (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"{value}%") | |||||
| ) | |||||
| case "end with": | |||||
| filters.append( | |||||
| (text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}") | |||||
| ) | |||||
| case "=" | "is": | |||||
| if isinstance(value, str): | |||||
| filters.append(Document.doc_metadata[metadata_name] == f'"{value}"') | |||||
| else: | |||||
| filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) == value) | |||||
| case "is not" | "≠": | |||||
| if isinstance(value, str): | |||||
| filters.append(Document.doc_metadata[metadata_name] != f'"{value}"') | |||||
| else: | |||||
| filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) != value) | |||||
| case "empty": | |||||
| filters.append(Document.doc_metadata[metadata_name].is_(None)) | |||||
| case "not empty": | |||||
| filters.append(Document.doc_metadata[metadata_name].isnot(None)) | |||||
| case "before" | "<": | |||||
| filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) < value) | |||||
| case "after" | ">": | |||||
| filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) > value) | |||||
| case "≤" | ">=": | |||||
| filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) <= value) | |||||
| case "≥" | ">=": | |||||
| filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) >= value) | |||||
| case _: | |||||
| pass | |||||
| return filters | |||||
| @classmethod | @classmethod | ||||
| def _extract_variable_selector_to_variable_mapping( | def _extract_variable_selector_to_variable_mapping( | ||||
| cls, | cls, | ||||
| *, | *, | ||||
| graph_config: Mapping[str, Any], | graph_config: Mapping[str, Any], | ||||
| node_id: str, | node_id: str, | ||||
| node_data: KnowledgeRetrievalNodeData, | |||||
| node_data: KnowledgeRetrievalNodeData, # type: ignore | |||||
| ) -> Mapping[str, Sequence[str]]: | ) -> Mapping[str, Sequence[str]]: | ||||
| """ | """ | ||||
| Extract variable selector to variable mapping | Extract variable selector to variable mapping | ||||
| variable_mapping[node_id + ".query"] = node_data.query_variable_selector | variable_mapping[node_id + ".query"] = node_data.query_variable_selector | ||||
| return variable_mapping | return variable_mapping | ||||
| def _fetch_model_config( | |||||
| self, node_data: KnowledgeRetrievalNodeData | |||||
| ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: | |||||
| def _fetch_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: # type: ignore | |||||
| """ | """ | ||||
| Fetch model config | Fetch model config | ||||
| :param node_data: node data | |||||
| :param model: model | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| if node_data.single_retrieval_config is None: | |||||
| raise ValueError("single_retrieval_config is required") | |||||
| model_name = node_data.single_retrieval_config.model.name | |||||
| provider_name = node_data.single_retrieval_config.model.provider | |||||
| if model is None: | |||||
| raise ValueError("model is required") | |||||
| model_name = model.name | |||||
| provider_name = model.provider | |||||
| model_manager = ModelManager() | model_manager = ModelManager() | ||||
| model_instance = model_manager.get_model_instance( | model_instance = model_manager.get_model_instance( | ||||
| raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.") | raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.") | ||||
| # model config | # model config | ||||
| completion_params = node_data.single_retrieval_config.model.completion_params | |||||
| completion_params = model.completion_params | |||||
| stop = [] | stop = [] | ||||
| if "stop" in completion_params: | if "stop" in completion_params: | ||||
| stop = completion_params["stop"] | stop = completion_params["stop"] | ||||
| del completion_params["stop"] | del completion_params["stop"] | ||||
| # get model mode | # get model mode | ||||
| model_mode = node_data.single_retrieval_config.model.mode | |||||
| model_mode = model.mode | |||||
| if not model_mode: | if not model_mode: | ||||
| raise ModelNotExistError("LLM mode is required.") | raise ModelNotExistError("LLM mode is required.") | ||||
| parameters=completion_params, | parameters=completion_params, | ||||
| stop=stop, | stop=stop, | ||||
| ) | ) | ||||
| def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str): | |||||
| model_mode = ModelMode.value_of(node_data.metadata_model_config.mode) # type: ignore | |||||
| input_text = query | |||||
| memory_str = "" | |||||
| prompt_messages: list[LLMNodeChatModelMessage] = [] | |||||
| if model_mode == ModelMode.CHAT: | |||||
| system_prompt_messages = LLMNodeChatModelMessage( | |||||
| role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT | |||||
| ) | |||||
| prompt_messages.append(system_prompt_messages) | |||||
| user_prompt_message_1 = LLMNodeChatModelMessage( | |||||
| role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1 | |||||
| ) | |||||
| prompt_messages.append(user_prompt_message_1) | |||||
| assistant_prompt_message_1 = LLMNodeChatModelMessage( | |||||
| role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1 | |||||
| ) | |||||
| prompt_messages.append(assistant_prompt_message_1) | |||||
| user_prompt_message_2 = LLMNodeChatModelMessage( | |||||
| role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2 | |||||
| ) | |||||
| prompt_messages.append(user_prompt_message_2) | |||||
| assistant_prompt_message_2 = LLMNodeChatModelMessage( | |||||
| role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2 | |||||
| ) | |||||
| prompt_messages.append(assistant_prompt_message_2) | |||||
| user_prompt_message_3 = LLMNodeChatModelMessage( | |||||
| role=PromptMessageRole.USER, | |||||
| text=METADATA_FILTER_USER_PROMPT_3.format( | |||||
| input_text=input_text, | |||||
| metadata_fields=json.dumps(metadata_fields, ensure_ascii=False), | |||||
| ), | |||||
| ) | |||||
| prompt_messages.append(user_prompt_message_3) | |||||
| return prompt_messages | |||||
| elif model_mode == ModelMode.COMPLETION: | |||||
| return LLMNodeCompletionModelPromptTemplate( | |||||
| text=METADATA_FILTER_COMPLETION_PROMPT.format( | |||||
| input_text=input_text, | |||||
| metadata_fields=json.dumps(metadata_fields, ensure_ascii=False), | |||||
| ) | |||||
| ) | |||||
| else: | |||||
| raise InvalidModelTypeError(f"Model mode {model_mode} not support.") |
| METADATA_FILTER_SYSTEM_PROMPT = """ | |||||
| ### Job Description | |||||
| You are a text metadata extraction engine that extracts the text's metadata based on user input and sets the metadata values. | |||||
| ### Task | |||||
| Your task is to ONLY extract the metadata fields that exist in the input text from the provided metadata list, use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, and then return the result in JSON format with the key "metadata_field_name", the value "metadata_field_value", and the comparison operator "comparison_operator". | |||||
| ### Format | |||||
| The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields. | |||||
| ### Constraint | |||||
| DO NOT include anything other than the JSON array in your response. | |||||
| """ # noqa: E501 | |||||
| METADATA_FILTER_USER_PROMPT_1 = """ | |||||
| { "input_text": "I want to know which company’s email address test@example.com is?", | |||||
| "metadata_fields": ["filename", "email", "phone", "address"] | |||||
| } | |||||
| """ | |||||
| METADATA_FILTER_ASSISTANT_PROMPT_1 = """ | |||||
| ```json | |||||
| {"metadata_map": [ | |||||
| {"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="} | |||||
| ] | |||||
| } | |||||
| ``` | |||||
| """ | |||||
| METADATA_FILTER_USER_PROMPT_2 = """ | |||||
| {"input_text": "What are the movies with a score of more than 9 in 2024?", | |||||
| "metadata_fields": ["name", "year", "rating", "country"]} | |||||
| """ | |||||
| METADATA_FILTER_ASSISTANT_PROMPT_2 = """ | |||||
| ```json | |||||
| {"metadata_map": [ | |||||
| {"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, | |||||
| {"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}, | |||||
| ]} | |||||
| ``` | |||||
| """ | |||||
| METADATA_FILTER_USER_PROMPT_3 = """ | |||||
| '{{"input_text": "{input_text}",', | |||||
| '"metadata_fields": {metadata_fields}}}' | |||||
| """ | |||||
| METADATA_FILTER_COMPLETION_PROMPT = """ | |||||
| ### Job Description | |||||
| You are a text metadata extraction engine that extracts the text's metadata based on user input and sets the metadata values. | |||||
| ### Task | |||||
| Your task is to ONLY extract the metadata fields that exist in the input text from the provided metadata list, use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, and then return the result in JSON format with the key "metadata_field_name", the value "metadata_field_value", and the comparison operator "comparison_operator". | |||||
| ### Format | |||||
| The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields. | |||||
| ### Constraint | |||||
| DO NOT include anything other than the JSON array in your response. | |||||
| ### Example | |||||
| Here is the chat example between human and assistant, inside <example></example> XML tags. | |||||
| <example> | |||||
| User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}} | |||||
| Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}} | |||||
| User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}} | |||||
| Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}} | |||||
| </example> | |||||
| ### User Input | |||||
| {{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}} | |||||
| ### Assistant Output | |||||
| """ # noqa: E501 |
| "external_knowledge_api_endpoint": fields.String, | "external_knowledge_api_endpoint": fields.String, | ||||
| } | } | ||||
| doc_metadata_fields = {"id": fields.String, "name": fields.String, "type": fields.String} | |||||
| dataset_detail_fields = { | dataset_detail_fields = { | ||||
| "id": fields.String, | "id": fields.String, | ||||
| "name": fields.String, | "name": fields.String, | ||||
| "doc_form": fields.String, | "doc_form": fields.String, | ||||
| "external_knowledge_info": fields.Nested(external_knowledge_info_fields), | "external_knowledge_info": fields.Nested(external_knowledge_info_fields), | ||||
| "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True), | "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True), | ||||
| "doc_metadata": fields.List(fields.Nested(doc_metadata_fields)), | |||||
| "built_in_field_enabled": fields.Boolean, | |||||
| } | } | ||||
| dataset_query_detail_fields = { | dataset_query_detail_fields = { | ||||
| "created_by": fields.String, | "created_by": fields.String, | ||||
| "created_at": TimestampField, | "created_at": TimestampField, | ||||
| } | } | ||||
| dataset_metadata_fields = { | |||||
| "id": fields.String, | |||||
| "type": fields.String, | |||||
| "name": fields.String, | |||||
| } |
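For orientation, here is a sketch of the fragment these new keys contribute to a marshalled dataset detail payload. The shape of the `doc_metadata` entries is assumed from `doc_metadata_fields` above and from the `Dataset.doc_metadata` property added later in this patch; the IDs are placeholders.

```python
# Illustrative fragment of a marshalled dataset detail response.
{
    "doc_metadata": [
        {"id": "c0ffee00-0000-4000-8000-000000000000", "name": "author", "type": "string"},
        {"id": "built-in", "name": "document_name", "type": "string"},
    ],
    "built_in_field_enabled": True,
}
```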
| from fields.dataset_fields import dataset_fields | from fields.dataset_fields import dataset_fields | ||||
| from libs.helper import TimestampField | from libs.helper import TimestampField | ||||
| document_metadata_fields = { | |||||
| "id": fields.String, | |||||
| "name": fields.String, | |||||
| "type": fields.String, | |||||
| "value": fields.String, | |||||
| } | |||||
| document_fields = { | document_fields = { | ||||
| "id": fields.String, | "id": fields.String, | ||||
| "position": fields.Integer, | "position": fields.Integer, | ||||
| "word_count": fields.Integer, | "word_count": fields.Integer, | ||||
| "hit_count": fields.Integer, | "hit_count": fields.Integer, | ||||
| "doc_form": fields.String, | "doc_form": fields.String, | ||||
| "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"), | |||||
| } | } | ||||
| document_with_segments_fields = { | document_with_segments_fields = { | ||||
| "hit_count": fields.Integer, | "hit_count": fields.Integer, | ||||
| "completed_segments": fields.Integer, | "completed_segments": fields.Integer, | ||||
| "total_segments": fields.Integer, | "total_segments": fields.Integer, | ||||
| "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"), | |||||
| } | } | ||||
| dataset_and_document_fields = { | dataset_and_document_fields = { |
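A hedged sketch of how these marshalling dicts are consumed at request time with flask_restful's `marshal`. `document` is assumed to be a `models.dataset.Document` row, and the `doc_metadata` key is filled from its `doc_metadata_details` property via the `attribute` override rather than from the raw JSONB column.

```python
from flask_restful import marshal  # type: ignore

# Illustrative: serializing a Document row with the document_fields defined above.
def serialize_document(document) -> dict:
    payload = marshal(document, document_fields)
    # payload["doc_metadata"] is a list such as:
    # [{"id": "...", "name": "year", "type": "number", "value": "2024"},
    #  {"id": "built-in", "name": "document_name", "type": "string", "value": "report.pdf"}]
    return payload
```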
| """add_metadata_function | |||||
| Revision ID: d20049ed0af6 | |||||
| Revises: f051706725cc | |||||
| Create Date: 2025-02-27 09:17:48.903213 | |||||
| """ | |||||
| from alembic import op | |||||
| import models as models | |||||
| import sqlalchemy as sa | |||||
| from sqlalchemy.dialects import postgresql | |||||
| # revision identifiers, used by Alembic. | |||||
| revision = 'd20049ed0af6' | |||||
| down_revision = 'f051706725cc' | |||||
| branch_labels = None | |||||
| depends_on = None | |||||
| def upgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| op.create_table('dataset_metadata_bindings', | |||||
| sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||||
| sa.Column('tenant_id', models.types.StringUUID(), nullable=False), | |||||
| sa.Column('dataset_id', models.types.StringUUID(), nullable=False), | |||||
| sa.Column('metadata_id', models.types.StringUUID(), nullable=False), | |||||
| sa.Column('document_id', models.types.StringUUID(), nullable=False), | |||||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), | |||||
| sa.Column('created_by', models.types.StringUUID(), nullable=False), | |||||
| sa.PrimaryKeyConstraint('id', name='dataset_metadata_binding_pkey') | |||||
| ) | |||||
| with op.batch_alter_table('dataset_metadata_bindings', schema=None) as batch_op: | |||||
| batch_op.create_index('dataset_metadata_binding_dataset_idx', ['dataset_id'], unique=False) | |||||
| batch_op.create_index('dataset_metadata_binding_document_idx', ['document_id'], unique=False) | |||||
| batch_op.create_index('dataset_metadata_binding_metadata_idx', ['metadata_id'], unique=False) | |||||
| batch_op.create_index('dataset_metadata_binding_tenant_idx', ['tenant_id'], unique=False) | |||||
| op.create_table('dataset_metadatas', | |||||
| sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||||
| sa.Column('tenant_id', models.types.StringUUID(), nullable=False), | |||||
| sa.Column('dataset_id', models.types.StringUUID(), nullable=False), | |||||
| sa.Column('type', sa.String(length=255), nullable=False), | |||||
| sa.Column('name', sa.String(length=255), nullable=False), | |||||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), | |||||
| sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), | |||||
| sa.Column('created_by', models.types.StringUUID(), nullable=False), | |||||
| sa.Column('updated_by', models.types.StringUUID(), nullable=True), | |||||
| sa.PrimaryKeyConstraint('id', name='dataset_metadata_pkey') | |||||
| ) | |||||
| with op.batch_alter_table('dataset_metadatas', schema=None) as batch_op: | |||||
| batch_op.create_index('dataset_metadata_dataset_idx', ['dataset_id'], unique=False) | |||||
| batch_op.create_index('dataset_metadata_tenant_idx', ['tenant_id'], unique=False) | |||||
| with op.batch_alter_table('datasets', schema=None) as batch_op: | |||||
| batch_op.add_column(sa.Column('built_in_field_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False)) | |||||
| with op.batch_alter_table('documents', schema=None) as batch_op: | |||||
| batch_op.alter_column('doc_metadata', | |||||
| existing_type=postgresql.JSON(astext_type=sa.Text()), | |||||
| type_=postgresql.JSONB(astext_type=sa.Text()), | |||||
| existing_nullable=True) | |||||
| batch_op.create_index('document_metadata_idx', ['doc_metadata'], unique=False, postgresql_using='gin') | |||||
| # ### end Alembic commands ### | |||||
| def downgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| with op.batch_alter_table('documents', schema=None) as batch_op: | |||||
| batch_op.drop_index('document_metadata_idx', postgresql_using='gin') | |||||
| batch_op.alter_column('doc_metadata', | |||||
| existing_type=postgresql.JSONB(astext_type=sa.Text()), | |||||
| type_=postgresql.JSON(astext_type=sa.Text()), | |||||
| existing_nullable=True) | |||||
| with op.batch_alter_table('datasets', schema=None) as batch_op: | |||||
| batch_op.drop_column('built_in_field_enabled') | |||||
| with op.batch_alter_table('dataset_metadatas', schema=None) as batch_op: | |||||
| batch_op.drop_index('dataset_metadata_tenant_idx') | |||||
| batch_op.drop_index('dataset_metadata_dataset_idx') | |||||
| op.drop_table('dataset_metadatas') | |||||
| with op.batch_alter_table('dataset_metadata_bindings', schema=None) as batch_op: | |||||
| batch_op.drop_index('dataset_metadata_binding_tenant_idx') | |||||
| batch_op.drop_index('dataset_metadata_binding_metadata_idx') | |||||
| batch_op.drop_index('dataset_metadata_binding_document_idx') | |||||
| batch_op.drop_index('dataset_metadata_binding_dataset_idx') | |||||
| op.drop_table('dataset_metadata_bindings') | |||||
| # ### end Alembic commands ### |
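The migration converts `documents.doc_metadata` to JSONB and adds a GIN index, which is what makes containment-style filtering on metadata cheap. Below is a sketch of the kind of query that index serves, using SQLAlchemy's JSONB `contains()` (compiled to the `@>` operator); the module paths mirror the ones used elsewhere in this patch.

```python
from extensions.ext_database import db
from models.dataset import Document

# Illustrative: completed documents in a dataset whose JSONB metadata contains
# {"year": "2024"}. With the GIN index added above, the @> containment test can
# be answered by an index scan instead of scanning every row.
def documents_with_year(dataset_id: str, year: str) -> list[Document]:
    return (
        db.session.query(Document)
        .filter(
            Document.dataset_id == dataset_id,
            Document.indexing_status == "completed",
            Document.doc_metadata.contains({"year": year}),
        )
        .all()
    )
```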
| from sqlalchemy.orm import Mapped | from sqlalchemy.orm import Mapped | ||||
| from configs import dify_config | from configs import dify_config | ||||
| from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource | |||||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | from core.rag.retrieval.retrieval_methods import RetrievalMethod | ||||
| from extensions.ext_storage import storage | from extensions.ext_storage import storage | ||||
| from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule | from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule | ||||
| embedding_model_provider = db.Column(db.String(255), nullable=True) | embedding_model_provider = db.Column(db.String(255), nullable=True) | ||||
| collection_binding_id = db.Column(StringUUID, nullable=True) | collection_binding_id = db.Column(StringUUID, nullable=True) | ||||
| retrieval_model = db.Column(JSONB, nullable=True) | retrieval_model = db.Column(JSONB, nullable=True) | ||||
| built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) | |||||
| @property | @property | ||||
| def dataset_keyword_table(self): | def dataset_keyword_table(self): | ||||
| "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""), | "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""), | ||||
| } | } | ||||
| @property | |||||
| def doc_metadata(self): | |||||
| dataset_metadatas = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == self.id).all() | |||||
| doc_metadata = [ | |||||
| { | |||||
| "id": dataset_metadata.id, | |||||
| "name": dataset_metadata.name, | |||||
| "type": dataset_metadata.type, | |||||
| } | |||||
| for dataset_metadata in dataset_metadatas | |||||
| ] | |||||
| if self.built_in_field_enabled: | |||||
| doc_metadata.append( | |||||
| { | |||||
| "id": "built-in", | |||||
| "name": BuiltInField.document_name.value, | |||||
| "type": "string", | |||||
| } | |||||
| ) | |||||
| doc_metadata.append( | |||||
| { | |||||
| "id": "built-in", | |||||
| "name": BuiltInField.uploader.value, | |||||
| "type": "string", | |||||
| } | |||||
| ) | |||||
| doc_metadata.append( | |||||
| { | |||||
| "id": "built-in", | |||||
| "name": BuiltInField.upload_date.value, | |||||
| "type": "time", | |||||
| } | |||||
| ) | |||||
| doc_metadata.append( | |||||
| { | |||||
| "id": "built-in", | |||||
| "name": BuiltInField.last_update_date.value, | |||||
| "type": "time", | |||||
| } | |||||
| ) | |||||
| doc_metadata.append( | |||||
| { | |||||
| "id": "built-in", | |||||
| "name": BuiltInField.source.value, | |||||
| "type": "string", | |||||
| } | |||||
| ) | |||||
| return doc_metadata | |||||
| @staticmethod | @staticmethod | ||||
| def gen_collection_name_by_id(dataset_id: str) -> str: | def gen_collection_name_by_id(dataset_id: str) -> str: | ||||
| normalized_dataset_id = dataset_id.replace("-", "_") | normalized_dataset_id = dataset_id.replace("-", "_") | ||||
| db.Index("document_dataset_id_idx", "dataset_id"), | db.Index("document_dataset_id_idx", "dataset_id"), | ||||
| db.Index("document_is_paused_idx", "is_paused"), | db.Index("document_is_paused_idx", "is_paused"), | ||||
| db.Index("document_tenant_idx", "tenant_id"), | db.Index("document_tenant_idx", "tenant_id"), | ||||
| db.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"), | |||||
| ) | ) | ||||
| # initial fields | # initial fields | ||||
| archived_at = db.Column(db.DateTime, nullable=True) | archived_at = db.Column(db.DateTime, nullable=True) | ||||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| doc_type = db.Column(db.String(40), nullable=True) | doc_type = db.Column(db.String(40), nullable=True) | ||||
| doc_metadata = db.Column(db.JSON, nullable=True) | |||||
| doc_metadata = db.Column(JSONB, nullable=True) | |||||
| doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) | doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) | ||||
| doc_language = db.Column(db.String(255), nullable=True) | doc_language = db.Column(db.String(255), nullable=True) | ||||
| .scalar() | .scalar() | ||||
| ) | ) | ||||
| @property | |||||
| def uploader(self): | |||||
| user = db.session.query(Account).filter(Account.id == self.created_by).first() | |||||
| return user.name if user else None | |||||
| @property | |||||
| def upload_date(self): | |||||
| return self.created_at | |||||
| @property | |||||
| def last_update_date(self): | |||||
| return self.updated_at | |||||
| @property | |||||
| def doc_metadata_details(self): | |||||
| if self.doc_metadata: | |||||
| document_metadatas = ( | |||||
| db.session.query(DatasetMetadata) | |||||
| .join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id) | |||||
| .filter( | |||||
| DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id | |||||
| ) | |||||
| .all() | |||||
| ) | |||||
| metadata_list = [] | |||||
| for metadata in document_metadatas: | |||||
| metadata_dict = { | |||||
| "id": metadata.id, | |||||
| "name": metadata.name, | |||||
| "type": metadata.type, | |||||
| "value": self.doc_metadata.get(metadata.name), | |||||
| } | |||||
| metadata_list.append(metadata_dict) | |||||
| # deal built-in fields | |||||
| metadata_list.extend(self.get_built_in_fields()) | |||||
| return metadata_list | |||||
| return None | |||||
| @property | @property | ||||
| def process_rule_dict(self): | def process_rule_dict(self): | ||||
| if self.dataset_process_rule_id: | if self.dataset_process_rule_id: | ||||
| return self.dataset_process_rule.to_dict() | return self.dataset_process_rule.to_dict() | ||||
| return None | return None | ||||
| def get_built_in_fields(self): | |||||
| built_in_fields = [] | |||||
| built_in_fields.append( | |||||
| { | |||||
| "id": "built-in", | |||||
| "name": BuiltInField.document_name, | |||||
| "type": "string", | |||||
| "value": self.name, | |||||
| } | |||||
| ) | |||||
| built_in_fields.append( | |||||
| { | |||||
| "id": "built-in", | |||||
| "name": BuiltInField.uploader, | |||||
| "type": "string", | |||||
| "value": self.uploader, | |||||
| } | |||||
| ) | |||||
| built_in_fields.append( | |||||
| { | |||||
| "id": "built-in", | |||||
| "name": BuiltInField.upload_date, | |||||
| "type": "time", | |||||
| "value": self.created_at.timestamp(), | |||||
| } | |||||
| ) | |||||
| built_in_fields.append( | |||||
| { | |||||
| "id": "built-in", | |||||
| "name": BuiltInField.last_update_date, | |||||
| "type": "time", | |||||
| "value": self.updated_at.timestamp(), | |||||
| } | |||||
| ) | |||||
| built_in_fields.append( | |||||
| { | |||||
| "id": "built-in", | |||||
| "name": BuiltInField.source, | |||||
| "type": "string", | |||||
| "value": MetadataDataSource[self.data_source_type].value, | |||||
| } | |||||
| ) | |||||
| return built_in_fields | |||||
| def to_dict(self): | def to_dict(self): | ||||
| return { | return { | ||||
| "id": self.id, | "id": self.id, | ||||
| subscription_plan = db.Column(db.String(255), nullable=False) | subscription_plan = db.Column(db.String(255), nullable=False) | ||||
| operation = db.Column(db.String(255), nullable=False) | operation = db.Column(db.String(255), nullable=False) | ||||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | ||||
| class DatasetMetadata(db.Model): # type: ignore[name-defined] | |||||
| __tablename__ = "dataset_metadatas" | |||||
| __table_args__ = ( | |||||
| db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"), | |||||
| db.Index("dataset_metadata_tenant_idx", "tenant_id"), | |||||
| db.Index("dataset_metadata_dataset_idx", "dataset_id"), | |||||
| ) | |||||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||||
| tenant_id = db.Column(StringUUID, nullable=False) | |||||
| dataset_id = db.Column(StringUUID, nullable=False) | |||||
| type = db.Column(db.String(255), nullable=False) | |||||
| name = db.Column(db.String(255), nullable=False) | |||||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||||
| created_by = db.Column(StringUUID, nullable=False) | |||||
| updated_by = db.Column(StringUUID, nullable=True) | |||||
| class DatasetMetadataBinding(db.Model): # type: ignore[name-defined] | |||||
| __tablename__ = "dataset_metadata_bindings" | |||||
| __table_args__ = ( | |||||
| db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"), | |||||
| db.Index("dataset_metadata_binding_tenant_idx", "tenant_id"), | |||||
| db.Index("dataset_metadata_binding_dataset_idx", "dataset_id"), | |||||
| db.Index("dataset_metadata_binding_metadata_idx", "metadata_id"), | |||||
| db.Index("dataset_metadata_binding_document_idx", "document_id"), | |||||
| ) | |||||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||||
| tenant_id = db.Column(StringUUID, nullable=False) | |||||
| dataset_id = db.Column(StringUUID, nullable=False) | |||||
| metadata_id = db.Column(StringUUID, nullable=False) | |||||
| document_id = db.Column(StringUUID, nullable=False) | |||||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| created_by = db.Column(StringUUID, nullable=False) |
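To summarize how the new tables relate: `DatasetMetadata` declares a field for a dataset, `DatasetMetadataBinding` attaches that field to a specific document, and the value itself lives in the document's `doc_metadata` JSONB blob keyed by the field name. A minimal sketch follows, assuming `tenant_id`, `user_id`, `dataset`, and `document` are already in scope; the real write path goes through `MetadataService`.

```python
# Illustrative only: wiring a custom "author" field to one document by hand.
metadata = DatasetMetadata(
    tenant_id=tenant_id, dataset_id=dataset.id, type="string", name="author", created_by=user_id
)
db.session.add(metadata)
db.session.flush()  # populate metadata.id before creating the binding

binding = DatasetMetadataBinding(
    tenant_id=tenant_id,
    dataset_id=dataset.id,
    document_id=document.id,
    metadata_id=metadata.id,
    created_by=user_id,
)
db.session.add(binding)

# The value is stored on the document itself, keyed by the field name.
document.doc_metadata = {**(document.doc_metadata or {}), "author": "Ada Lovelace"}
db.session.add(document)
db.session.commit()
```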
| import copy | |||||
| import datetime | import datetime | ||||
| import json | import json | ||||
| import logging | import logging | ||||
| from core.model_manager import ModelManager | from core.model_manager import ModelManager | ||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from core.plugin.entities.plugin import ModelProviderID | from core.plugin.entities.plugin import ModelProviderID | ||||
| from core.rag.index_processor.constant.built_in_field import BuiltInField | |||||
| from core.rag.index_processor.constant.index_type import IndexType | from core.rag.index_processor.constant.index_type import IndexType | ||||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | from core.rag.retrieval.retrieval_methods import RetrievalMethod | ||||
| from events.dataset_event import dataset_was_deleted | from events.dataset_event import dataset_was_deleted | ||||
| return document | return document | ||||
| @staticmethod | |||||
| def get_document_by_ids(document_ids: list[str]) -> list[Document]: | |||||
| documents = ( | |||||
| db.session.query(Document) | |||||
| .filter( | |||||
| Document.id.in_(document_ids), | |||||
| Document.enabled == True, | |||||
| Document.indexing_status == "completed", | |||||
| Document.archived == False, | |||||
| ) | |||||
| .all() | |||||
| ) | |||||
| return documents | |||||
| @staticmethod | @staticmethod | ||||
| def get_document_by_dataset_id(dataset_id: str) -> list[Document]: | def get_document_by_dataset_id(dataset_id: str) -> list[Document]: | ||||
| documents = db.session.query(Document).filter(Document.dataset_id == dataset_id, Document.enabled == True).all() | |||||
| documents = ( | |||||
| db.session.query(Document) | |||||
| .filter( | |||||
| Document.dataset_id == dataset_id, | |||||
| Document.enabled == True, | |||||
| ) | |||||
| .all() | |||||
| ) | |||||
| return documents | |||||
| @staticmethod | |||||
| def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]: | |||||
| documents = ( | |||||
| db.session.query(Document) | |||||
| .filter( | |||||
| Document.dataset_id == dataset_id, | |||||
| Document.enabled == True, | |||||
| Document.indexing_status == "completed", | |||||
| Document.archived == False, | |||||
| ) | |||||
| .all() | |||||
| ) | |||||
| return documents | return documents | ||||
| if document.tenant_id != current_user.current_tenant_id: | if document.tenant_id != current_user.current_tenant_id: | ||||
| raise ValueError("No permission.") | raise ValueError("No permission.") | ||||
| document.name = name | |||||
| if dataset.built_in_field_enabled: | |||||
| if document.doc_metadata: | |||||
| doc_metadata = copy.deepcopy(document.doc_metadata) | |||||
| doc_metadata[BuiltInField.document_name.value] = name | |||||
| document.doc_metadata = doc_metadata | |||||
| document.name = name | |||||
| db.session.add(document) | db.session.add(document) | ||||
| db.session.commit() | db.session.commit() | ||||
| doc_form=document_form, | doc_form=document_form, | ||||
| doc_language=document_language, | doc_language=document_language, | ||||
| ) | ) | ||||
| doc_metadata = {} | |||||
| if dataset.built_in_field_enabled: | |||||
| doc_metadata = { | |||||
| BuiltInField.document_name: name, | |||||
| BuiltInField.uploader: account.name, | |||||
| BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"), | |||||
| BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"), | |||||
| BuiltInField.source: data_source_type, | |||||
| } | |||||
| if metadata is not None: | if metadata is not None: | ||||
| document.doc_metadata = metadata.doc_metadata | |||||
| doc_metadata.update(metadata.doc_metadata) | |||||
| document.doc_type = metadata.doc_type | document.doc_type = metadata.doc_type | ||||
| if doc_metadata: | |||||
| document.doc_metadata = doc_metadata | |||||
| return document | return document | ||||
| @staticmethod | @staticmethod |
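For reference, here is the JSONB blob a newly created document ends up with under the branch above, when `built_in_field_enabled` is true and the caller also supplied a custom "author" value. The key names assume the `BuiltInField` values match the member names, and the values are illustrative. Note that this creation path stores the raw `data_source_type` and formatted date strings, whereas `enable_built_in_field` and `update_documents_metadata` store `MetadataDataSource[...].value` and numeric timestamps.

```python
# Illustrative shape of document.doc_metadata right after creation.
{
    "document_name": "quarterly_report.pdf",
    "uploader": "alice",
    "upload_date": "2025-02-27 09:17:48",
    "last_update_date": "2025-02-27 09:17:48",
    "source": "upload_file",   # raw data_source_type on this path
    "author": "Ada",           # caller-supplied metadata merged in last
}
```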
| class ChildChunkUpdateArgs(BaseModel): | class ChildChunkUpdateArgs(BaseModel): | ||||
| id: Optional[str] = None | id: Optional[str] = None | ||||
| content: str | content: str | ||||
| class MetadataArgs(BaseModel): | |||||
| type: Literal["string", "number", "time"] | |||||
| name: str | |||||
| class MetadataUpdateArgs(BaseModel): | |||||
| name: str | |||||
| value: Optional[str | int | float] = None | |||||
| class MetadataValueUpdateArgs(BaseModel): | |||||
| fields: list[MetadataUpdateArgs] | |||||
| class MetadataDetail(BaseModel): | |||||
| id: str | |||||
| name: str | |||||
| value: Optional[str | int | float] = None | |||||
| class DocumentMetadataOperation(BaseModel): | |||||
| document_id: str | |||||
| metadata_list: list[MetadataDetail] | |||||
| class MetadataOperationData(BaseModel): | |||||
| """ | |||||
| Metadata operation data | |||||
| """ | |||||
| operation_data: list[DocumentMetadataOperation] |
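A minimal sketch of the request payload these models are meant to validate, assuming Pydantic v2's `model_validate`; the IDs are placeholders.

```python
# Illustrative: validating a metadata update payload into MetadataOperationData.
payload = {
    "operation_data": [
        {
            "document_id": "9f8a7c1e-0000-4000-8000-000000000001",
            "metadata_list": [
                {"id": "a1b2c3d4-0000-4000-8000-000000000002", "name": "author", "value": "Ada"},
                {"id": "b2c3d4e5-0000-4000-8000-000000000003", "name": "year", "value": 2024},
            ],
        }
    ]
}
args = MetadataOperationData.model_validate(payload)
assert args.operation_data[0].metadata_list[1].value == 2024
```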
| from constants import HIDDEN_VALUE | from constants import HIDDEN_VALUE | ||||
| from core.helper import ssrf_proxy | from core.helper import ssrf_proxy | ||||
| from core.rag.entities.metadata_entities import MetadataCondition | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import ( | from models.dataset import ( | ||||
| Dataset, | Dataset, | ||||
| @staticmethod | @staticmethod | ||||
| def fetch_external_knowledge_retrieval( | def fetch_external_knowledge_retrieval( | ||||
| tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict | |||||
| tenant_id: str, | |||||
| dataset_id: str, | |||||
| query: str, | |||||
| external_retrieval_parameters: dict, | |||||
| metadata_condition: Optional[MetadataCondition] = None, | |||||
| ) -> list: | ) -> list: | ||||
| external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by( | external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by( | ||||
| dataset_id=dataset_id, tenant_id=tenant_id | dataset_id=dataset_id, tenant_id=tenant_id | ||||
| }, | }, | ||||
| "query": query, | "query": query, | ||||
| "knowledge_id": external_knowledge_binding.external_knowledge_id, | "knowledge_id": external_knowledge_binding.external_knowledge_id, | ||||
| "metadata_condition": metadata_condition.model_dump() if metadata_condition else None, | |||||
| } | } | ||||
| response = ExternalDatasetService.process_external_api( | response = ExternalDatasetService.process_external_api( |
| import copy | |||||
| import datetime | |||||
| import logging | |||||
| from typing import Optional | |||||
| from flask_login import current_user # type: ignore | |||||
| from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource | |||||
| from extensions.ext_database import db | |||||
| from extensions.ext_redis import redis_client | |||||
| from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding | |||||
| from services.dataset_service import DocumentService | |||||
| from services.entities.knowledge_entities.knowledge_entities import ( | |||||
| MetadataArgs, | |||||
| MetadataOperationData, | |||||
| ) | |||||
| class MetadataService: | |||||
| @staticmethod | |||||
| def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata: | |||||
| # check if metadata name already exists | |||||
| if DatasetMetadata.query.filter_by( | |||||
| tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name | |||||
| ).first(): | |||||
| raise ValueError("Metadata name already exists.") | |||||
| for field in BuiltInField: | |||||
| if field.value == metadata_args.name: | |||||
| raise ValueError("Metadata name already exists in Built-in fields.") | |||||
| metadata = DatasetMetadata( | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| dataset_id=dataset_id, | |||||
| type=metadata_args.type, | |||||
| name=metadata_args.name, | |||||
| created_by=current_user.id, | |||||
| ) | |||||
| db.session.add(metadata) | |||||
| db.session.commit() | |||||
| return metadata | |||||
| @staticmethod | |||||
| def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore | |||||
| lock_key = f"dataset_metadata_lock_{dataset_id}" | |||||
| # check if metadata name already exists | |||||
| if DatasetMetadata.query.filter_by( | |||||
| tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name | |||||
| ).first(): | |||||
| raise ValueError("Metadata name already exists.") | |||||
| for field in BuiltInField: | |||||
| if field.value == name: | |||||
| raise ValueError("Metadata name already exists in Built-in fields.") | |||||
| try: | |||||
| MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) | |||||
| metadata = DatasetMetadata.query.filter_by(id=metadata_id).first() | |||||
| if metadata is None: | |||||
| raise ValueError("Metadata not found.") | |||||
| old_name = metadata.name | |||||
| metadata.name = name | |||||
| metadata.updated_by = current_user.id | |||||
| metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | |||||
| # update related documents | |||||
| dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all() | |||||
| if dataset_metadata_bindings: | |||||
| document_ids = [binding.document_id for binding in dataset_metadata_bindings] | |||||
| documents = DocumentService.get_document_by_ids(document_ids) | |||||
| for document in documents: | |||||
| doc_metadata = copy.deepcopy(document.doc_metadata) | |||||
| value = doc_metadata.pop(old_name, None) | |||||
| doc_metadata[name] = value | |||||
| document.doc_metadata = doc_metadata | |||||
| db.session.add(document) | |||||
| db.session.commit() | |||||
| return metadata # type: ignore | |||||
| except Exception: | |||||
| logging.exception("Update metadata name failed") | |||||
| finally: | |||||
| redis_client.delete(lock_key) | |||||
| @staticmethod | |||||
| def delete_metadata(dataset_id: str, metadata_id: str): | |||||
| lock_key = f"dataset_metadata_lock_{dataset_id}" | |||||
| try: | |||||
| MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) | |||||
| metadata = DatasetMetadata.query.filter_by(id=metadata_id).first() | |||||
| if metadata is None: | |||||
| raise ValueError("Metadata not found.") | |||||
| db.session.delete(metadata) | |||||
| # deal related documents | |||||
| dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all() | |||||
| if dataset_metadata_bindings: | |||||
| document_ids = [binding.document_id for binding in dataset_metadata_bindings] | |||||
| documents = DocumentService.get_document_by_ids(document_ids) | |||||
| for document in documents: | |||||
| doc_metadata = copy.deepcopy(document.doc_metadata) | |||||
| doc_metadata.pop(metadata.name, None) | |||||
| document.doc_metadata = doc_metadata | |||||
| db.session.add(document) | |||||
| db.session.commit() | |||||
| return metadata | |||||
| except Exception: | |||||
| logging.exception("Delete metadata failed") | |||||
| finally: | |||||
| redis_client.delete(lock_key) | |||||
| @staticmethod | |||||
| def get_built_in_fields(): | |||||
| return [ | |||||
| {"name": BuiltInField.document_name.value, "type": "string"}, | |||||
| {"name": BuiltInField.uploader.value, "type": "string"}, | |||||
| {"name": BuiltInField.upload_date.value, "type": "time"}, | |||||
| {"name": BuiltInField.last_update_date.value, "type": "time"}, | |||||
| {"name": BuiltInField.source.value, "type": "string"}, | |||||
| ] | |||||
| @staticmethod | |||||
| def enable_built_in_field(dataset: Dataset): | |||||
| if dataset.built_in_field_enabled: | |||||
| return | |||||
| lock_key = f"dataset_metadata_lock_{dataset.id}" | |||||
| try: | |||||
| MetadataService.knowledge_base_metadata_lock_check(dataset.id, None) | |||||
| dataset.built_in_field_enabled = True | |||||
| db.session.add(dataset) | |||||
| documents = DocumentService.get_working_documents_by_dataset_id(dataset.id) | |||||
| if documents: | |||||
| for document in documents: | |||||
| if not document.doc_metadata: | |||||
| doc_metadata = {} | |||||
| else: | |||||
| doc_metadata = copy.deepcopy(document.doc_metadata) | |||||
| doc_metadata[BuiltInField.document_name.value] = document.name | |||||
| doc_metadata[BuiltInField.uploader.value] = document.uploader | |||||
| doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp() | |||||
| doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp() | |||||
| doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value | |||||
| document.doc_metadata = doc_metadata | |||||
| db.session.add(document) | |||||
| db.session.commit() | |||||
| except Exception: | |||||
| logging.exception("Enable built-in field failed") | |||||
| finally: | |||||
| redis_client.delete(lock_key) | |||||
| @staticmethod | |||||
| def disable_built_in_field(dataset: Dataset): | |||||
| if not dataset.built_in_field_enabled: | |||||
| return | |||||
| lock_key = f"dataset_metadata_lock_{dataset.id}" | |||||
| try: | |||||
| MetadataService.knowledge_base_metadata_lock_check(dataset.id, None) | |||||
| dataset.built_in_field_enabled = False | |||||
| db.session.add(dataset) | |||||
| documents = DocumentService.get_working_documents_by_dataset_id(dataset.id) | |||||
| document_ids = [] | |||||
| if documents: | |||||
| for document in documents: | |||||
| doc_metadata = copy.deepcopy(document.doc_metadata) | |||||
| doc_metadata.pop(BuiltInField.document_name.value, None) | |||||
| doc_metadata.pop(BuiltInField.uploader.value, None) | |||||
| doc_metadata.pop(BuiltInField.upload_date.value, None) | |||||
| doc_metadata.pop(BuiltInField.last_update_date.value, None) | |||||
| doc_metadata.pop(BuiltInField.source.value, None) | |||||
| document.doc_metadata = doc_metadata | |||||
| db.session.add(document) | |||||
| document_ids.append(document.id) | |||||
| db.session.commit() | |||||
| except Exception: | |||||
| logging.exception("Disable built-in field failed") | |||||
| finally: | |||||
| redis_client.delete(lock_key) | |||||
| @staticmethod | |||||
| def update_documents_metadata(dataset: Dataset, metadata_args: MetadataOperationData): | |||||
| for operation in metadata_args.operation_data: | |||||
| lock_key = f"document_metadata_lock_{operation.document_id}" | |||||
| try: | |||||
| MetadataService.knowledge_base_metadata_lock_check(None, operation.document_id) | |||||
| document = DocumentService.get_document(dataset.id, operation.document_id) | |||||
| if document is None: | |||||
| raise ValueError("Document not found.") | |||||
| doc_metadata = {} | |||||
| for metadata_value in operation.metadata_list: | |||||
| doc_metadata[metadata_value.name] = metadata_value.value | |||||
| if dataset.built_in_field_enabled: | |||||
| doc_metadata[BuiltInField.document_name.value] = document.name | |||||
| doc_metadata[BuiltInField.uploader.value] = document.uploader | |||||
| doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp() | |||||
| doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp() | |||||
| doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value | |||||
| document.doc_metadata = doc_metadata | |||||
| db.session.add(document) | |||||
| db.session.commit() | |||||
| # deal metadata binding | |||||
| DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete() | |||||
| for metadata_value in operation.metadata_list: | |||||
| dataset_metadata_binding = DatasetMetadataBinding( | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| dataset_id=dataset.id, | |||||
| document_id=operation.document_id, | |||||
| metadata_id=metadata_value.id, | |||||
| created_by=current_user.id, | |||||
| ) | |||||
| db.session.add(dataset_metadata_binding) | |||||
| db.session.commit() | |||||
| except Exception: | |||||
| logging.exception("Update documents metadata failed") | |||||
| finally: | |||||
| redis_client.delete(lock_key) | |||||
| @staticmethod | |||||
| def knowledge_base_metadata_lock_check(dataset_id: Optional[str], document_id: Optional[str]): | |||||
| if dataset_id: | |||||
| lock_key = f"dataset_metadata_lock_{dataset_id}" | |||||
| if redis_client.get(lock_key): | |||||
| raise ValueError("Another knowledge base metadata operation is running, please wait a moment.") | |||||
| redis_client.set(lock_key, 1, ex=3600) | |||||
| if document_id: | |||||
| lock_key = f"document_metadata_lock_{document_id}" | |||||
| if redis_client.get(lock_key): | |||||
| raise ValueError("Another document metadata operation is running, please wait a moment.") | |||||
| redis_client.set(lock_key, 1, ex=3600) | |||||
| @staticmethod | |||||
| def get_dataset_metadatas(dataset: Dataset): | |||||
| return { | |||||
| "doc_metadata": [ | |||||
| { | |||||
| "id": item.get("id"), | |||||
| "name": item.get("name"), | |||||
| "type": item.get("type"), | |||||
| "count": DatasetMetadataBinding.query.filter_by( | |||||
| metadata_id=item.get("id"), dataset_id=dataset.id | |||||
| ).count(), | |||||
| } | |||||
| for item in dataset.doc_metadata or [] | |||||
| if item.get("id") != "built-in" | |||||
| ], | |||||
| "built_in_field_enabled": dataset.built_in_field_enabled, | |||||
| } |
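One design note on the locking above: `knowledge_base_metadata_lock_check` does a GET followed by a SET, so two callers can pass the check concurrently. Below is a hedged sketch of an atomic alternative, assuming `redis_client` is a standard redis-py client; SET with NX and EX acquires the lock and detects an existing holder in a single round trip.

```python
# Sketch only: atomic lock acquisition; raises if the lock is already held.
def acquire_metadata_lock(lock_key: str, ttl_seconds: int = 3600) -> None:
    if not redis_client.set(lock_key, 1, ex=ttl_seconds, nx=True):
        raise ValueError("Another metadata operation is running, please wait a moment.")
```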
| ) | ) | ||||
| if keyword: | if keyword: | ||||
| query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%"))) | query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%"))) | ||||
| query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at) | |||||
| query = query.group_by(Tag.id, Tag.type, Tag.name) | |||||
| results: list = query.order_by(Tag.created_at.desc()).all() | results: list = query.order_by(Tag.created_at.desc()).all() | ||||
| return results | return results | ||||