
Support knowledge metadata filter (#15982)

tags/1.1.0
Jyong 7 months ago
commit abeaea4f79
48 changed files with 2501 additions and 573 deletions
1. api/controllers/console/__init__.py (+1 -0)
2. api/controllers/console/datasets/datasets_document.py (+2 -2)
3. api/controllers/console/datasets/metadata.py (+155 -0)
4. api/core/app/app_config/easy_ui_based_app/dataset/manager.py (+24 -1)
5. api/core/app/app_config/entities.py (+54 -1)
6. api/core/app/apps/chat/app_runner.py (+1 -0)
7. api/core/app/apps/completion/app_runner.py (+1 -0)
8. api/core/rag/datasource/keyword/jieba/jieba.py (+6 -5)
9. api/core/rag/datasource/retrieval_service.py (+17 -2)
10. api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py (+1 -1)
11. api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py (+12 -2)
12. api/core/rag/datasource/vdb/baidu/baidu_vector.py (+15 -5)
13. api/core/rag/datasource/vdb/chroma/chroma_vector.py (+9 -1)
14. api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py (+6 -0)
15. api/core/rag/datasource/vdb/lindorm/lindorm_vector.py (+10 -2)
16. api/core/rag/datasource/vdb/milvus/milvus_vector.py (+12 -0)
17. api/core/rag/datasource/vdb/myscale/myscale_vector.py (+4 -0)
18. api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py (+6 -0)
19. api/core/rag/datasource/vdb/opensearch/opensearch_vector.py (+6 -0)
20. api/core/rag/datasource/vdb/oracle/oraclevector.py (+13 -2)
21. api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py (+3 -0)
22. api/core/rag/datasource/vdb/pgvector/pgvector.py (+13 -0)
23. api/core/rag/datasource/vdb/qdrant/qdrant_vector.py (+38 -21)
24. api/core/rag/datasource/vdb/relyt/relyt_vector.py (+7 -3)
25. api/core/rag/datasource/vdb/tencent/tencent_vector.py (+6 -1)
26. api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py (+24 -0)
27. api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py (+6 -0)
28. api/core/rag/datasource/vdb/upstash/upstash_vector.py (+14 -1)
29. api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py (+5 -1)
30. api/core/rag/datasource/vdb/weaviate/weaviate_vector.py (+8 -4)
31. api/core/rag/entities/metadata_entities.py (+45 -0)
32. api/core/rag/index_processor/constant/built_in_field.py (+15 -0)
33. api/core/rag/retrieval/dataset_retrieval.py (+429 -7)
34. api/core/rag/retrieval/template_prompts.py (+66 -0)
35. api/core/workflow/nodes/knowledge_retrieval/entities.py (+49 -1)
36. api/core/workflow/nodes/knowledge_retrieval/exc.py (+4 -0)
37. api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py (+268 -22)
38. api/core/workflow/nodes/knowledge_retrieval/template_prompts.py (+66 -0)
39. api/fields/dataset_fields.py (+10 -0)
40. api/fields/document_fields.py (+9 -0)
41. api/migrations/versions/2025_02_27_0917-d20049ed0af6_add_metadata_function.py (+90 -0)
42. api/models/dataset.py (+175 -1)
43. api/poetry.lock (+457 -482)
44. api/services/dataset_service.py (+57 -3)
45. api/services/entities/knowledge_entities/knowledge_entities.py (+33 -0)
46. api/services/external_knowledge_service.py (+7 -1)
47. api/services/metadata_service.py (+241 -0)
48. api/services/tag_service.py (+1 -1)

api/controllers/console/__init__.py (+1 -0)

      datasets_segments,
      external,
      hit_testing,
+     metadata,
      website,
  )



api/controllers/console/datasets/datasets_document.py (+2 -2)

  raise InvalidMetadataError(f"Invalid metadata value: {metadata}")

  if metadata == "only":
-     response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata}
+     response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
  elif metadata == "without":
      dataset_process_rules = DatasetService.get_process_rules(dataset_id)
      document_process_rules = document.dataset_process_rule.to_dict()

  "disabled_by": document.disabled_by,
  "archived": document.archived,
  "doc_type": document.doc_type,
- "doc_metadata": document.doc_metadata,
+ "doc_metadata": document.doc_metadata_details,
  "segment_count": document.segment_count,
  "average_segment_length": document.average_segment_length,
  "hit_count": document.hit_count,

api/controllers/console/datasets/metadata.py (+155 -0)

from flask_login import current_user  # type: ignore
from flask_restful import Resource, marshal_with, reqparse  # type: ignore
from werkzeug.exceptions import NotFound

from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields
from libs.login import login_required
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
    MetadataArgs,
    MetadataOperationData,
)
from services.metadata_service import MetadataService


def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 40:
        raise ValueError("Name must be between 1 and 40 characters.")
    return name


def _validate_description_length(description):
    if len(description) > 400:
        raise ValueError("Description cannot exceed 400 characters.")
    return description


class DatasetMetadataCreateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    @marshal_with(dataset_metadata_fields)
    def post(self, dataset_id):
        parser = reqparse.RequestParser()
        parser.add_argument("type", type=str, required=True, nullable=True, location="json")
        parser.add_argument("name", type=str, required=True, nullable=True, location="json")
        args = parser.parse_args()
        metadata_args = MetadataArgs(**args)

        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
        return metadata, 201

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        return MetadataService.get_dataset_metadatas(dataset), 200


class DatasetMetadataApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    @marshal_with(dataset_metadata_fields)
    def patch(self, dataset_id, metadata_id):
        parser = reqparse.RequestParser()
        parser.add_argument("name", type=str, required=True, nullable=True, location="json")
        args = parser.parse_args()

        dataset_id_str = str(dataset_id)
        metadata_id_str = str(metadata_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
        return metadata, 200

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def delete(self, dataset_id, metadata_id):
        dataset_id_str = str(dataset_id)
        metadata_id_str = str(metadata_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
        return 200


class DatasetMetadataBuiltInFieldApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def get(self):
        built_in_fields = MetadataService.get_built_in_fields()
        return {"fields": built_in_fields}, 200


class DatasetMetadataBuiltInFieldActionApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def post(self, dataset_id, action):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        if action == "enable":
            MetadataService.enable_built_in_field(dataset)
        elif action == "disable":
            MetadataService.disable_built_in_field(dataset)
        return 200


class DocumentMetadataEditApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def post(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_permission(dataset, current_user)

        parser = reqparse.RequestParser()
        parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json")
        args = parser.parse_args()
        metadata_args = MetadataOperationData(**args)

        MetadataService.update_documents_metadata(dataset, metadata_args)

        return 200


api.add_resource(DatasetMetadataCreateApi, "/datasets/<uuid:dataset_id>/metadata")
api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in")
api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
api.add_resource(DocumentMetadataEditApi, "/datasets/<uuid:dataset_id>/documents/metadata")
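The resources registered above cover dataset-level metadata CRUD, the built-in field toggle, and per-document assignment. A minimal sketch of driving them over HTTP; the base path, auth header, response fields, and the item layout inside operation_data (defined by MetadataOperationData, which is not shown in this diff) are assumptions, not documented behaviour:

import requests

BASE = "http://localhost:5001/console/api"  # assumed console API prefix
HEADERS = {"Authorization": "Bearer <console-token>"}  # assumed auth scheme
dataset_id = "<dataset-uuid>"

# Create a metadata field on the dataset (payload mirrors the reqparse parser: type + name).
created = requests.post(
    f"{BASE}/datasets/{dataset_id}/metadata",
    json={"type": "string", "name": "department"},
    headers=HEADERS,
).json()

# Rename it (PATCH takes only "name").
requests.patch(
    f"{BASE}/datasets/{dataset_id}/metadata/{created['id']}",
    json={"name": "team"},
    headers=HEADERS,
)

# Enable the built-in fields and list what is available.
requests.post(f"{BASE}/datasets/{dataset_id}/metadata/built-in/enable", headers=HEADERS)
requests.get(f"{BASE}/datasets/metadata/built-in", headers=HEADERS)

# Assign values to documents; the operation_data item shape is a guess.
requests.post(
    f"{BASE}/datasets/{dataset_id}/documents/metadata",
    json={"operation_data": [{"document_id": "<doc-uuid>",
                              "metadata_list": [{"id": created["id"], "name": "team", "value": "support"}]}]},
    headers=HEADERS,
)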

api/core/app/app_config/easy_ui_based_app/dataset/manager.py (+24 -1)

  import uuid
  from typing import Optional

- from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
+ from core.app.app_config.entities import (
+     DatasetEntity,
+     DatasetRetrieveConfigEntity,
+     MetadataFilteringCondition,
+     ModelConfig,
+ )
  from core.entities.agent_entities import PlanningStrategy
  from models.model import AppMode
  from services.dataset_service import DatasetService

  retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
      dataset_configs["retrieval_model"]
  ),
+ metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
+ metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
+ if dataset_configs.get("metadata_model_config")
+ else None,
+ metadata_filtering_conditions=MetadataFilteringCondition(
+     **dataset_configs.get("metadata_filtering_conditions", {})
+ )
+ if dataset_configs.get("metadata_filtering_conditions")
+ else None,
  ),
  )
  else:

  weights=dataset_configs.get("weights"),
  reranking_enabled=dataset_configs.get("reranking_enabled", True),
  rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
+ metadata_filtering_mode=dataset_configs.get("metadata_filtering_mode", "disabled"),
+ metadata_model_config=ModelConfig(**dataset_configs.get("metadata_model_config"))
+ if dataset_configs.get("metadata_model_config")
+ else None,
+ metadata_filtering_conditions=MetadataFilteringCondition(
+     **dataset_configs.get("metadata_filtering_conditions", {})
+ )
+ if dataset_configs.get("metadata_filtering_conditions")
+ else None,
  ),
  )



api/core/app/app_config/entities.py (+54 -1)

  from collections.abc import Sequence
  from enum import Enum, StrEnum
- from typing import Any, Optional
+ from typing import Any, Literal, Optional

  from pydantic import BaseModel, Field, field_validator

  from core.file import FileTransferMethod, FileType, FileUploadConfig
+ from core.model_runtime.entities.llm_entities import LLMMode
  from core.model_runtime.entities.message_entities import PromptMessageRole
  from models.model import AppMode

  config: dict[str, Any] = Field(default_factory=dict)

+ SupportedComparisonOperator = Literal[
+     # for string or array
+     "contains",
+     "not contains",
+     "start with",
+     "end with",
+     "is",
+     "is not",
+     "empty",
+     "not empty",
+     # for number
+     "=",
+     "≠",
+     ">",
+     "<",
+     "≥",
+     "≤",
+     # for time
+     "before",
+     "after",
+ ]
+
+
+ class ModelConfig(BaseModel):
+     provider: str
+     name: str
+     mode: LLMMode
+     completion_params: dict[str, Any] = {}
+
+
+ class Condition(BaseModel):
+     """
+     Condition detail
+     """
+
+     name: str
+     comparison_operator: SupportedComparisonOperator
+     value: str | Sequence[str] | None | int | float = None
+
+
+ class MetadataFilteringCondition(BaseModel):
+     """
+     Metadata Filtering Condition.
+     """
+
+     logical_operator: Optional[Literal["and", "or"]] = "and"
+     conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)


  class DatasetRetrieveConfigEntity(BaseModel):
      """
      Dataset Retrieve Config Entity.
      reranking_model: Optional[dict] = None
      weights: Optional[dict] = None
      reranking_enabled: Optional[bool] = True
+     metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
+     metadata_model_config: Optional[ModelConfig] = None
+     metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None


  class DatasetEntity(BaseModel):
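These entities define the shape a metadata filter takes inside an app's dataset_configs; the dataset config manager above builds them from metadata_filtering_mode, metadata_model_config and metadata_filtering_conditions. A small construction sketch with made-up values (only the field names and operator literals come from the classes above):

condition = Condition(name="department", comparison_operator="is", value="support")
filtering = MetadataFilteringCondition(logical_operator="and", conditions=[condition])

# The pydantic models round-trip as plain dicts, which is how they travel in dataset_configs.
print(filtering.model_dump())
# {'logical_operator': 'and', 'conditions': [{'name': 'department', 'comparison_operator': 'is', 'value': 'support'}]}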

api/core/app/apps/chat/app_runner.py (+1 -0)

  hit_callback=hit_callback,
  memory=memory,
  message_id=message.id,
+ inputs=inputs,
  )

  # reorganize all inputs and template to prompt messages

api/core/app/apps/completion/app_runner.py (+1 -0)

  show_retrieve_source=app_config.additional_features.show_retrieve_source,
  hit_callback=hit_callback,
  message_id=message.id,
+ inputs=inputs,
  )

  # reorganize all inputs and template to prompt messages

api/core/rag/datasource/keyword/jieba/jieba.py (+6 -5)

  keyword_table = self._get_dataset_keyword_table()

  k = kwargs.get("top_k", 4)
+ document_ids_filter = kwargs.get("document_ids_filter")
  sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)

  documents = []
  for chunk_index in sorted_chunk_indices:
-     segment = (
-         db.session.query(DocumentSegment)
-         .filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index)
-         .first()
+     segment_query = db.session.query(DocumentSegment).filter(
+         DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
      )
+     if document_ids_filter:
+         segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter))
+     segment = segment_query.first()

      if segment:
          documents.append(

api/core/rag/datasource/retrieval_service.py (+17 -2)

  reranking_model: Optional[dict] = None,
  reranking_mode: str = "reranking_model",
  weights: Optional[dict] = None,
+ document_ids_filter: Optional[list[str]] = None,
  ):
      if not query:
          return []

  top_k=top_k,
  all_documents=all_documents,
  exceptions=exceptions,
+ document_ids_filter=document_ids_filter,
  )
  )
  if RetrievalMethod.is_support_semantic_search(retrieval_method):

  all_documents=all_documents,
  retrieval_method=retrieval_method,
  exceptions=exceptions,
+ document_ids_filter=document_ids_filter,
  )
  )
  if RetrievalMethod.is_support_fulltext_search(retrieval_method):

  @classmethod
  def keyword_search(
-     cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
+     cls,
+     flask_app: Flask,
+     dataset_id: str,
+     query: str,
+     top_k: int,
+     all_documents: list,
+     exceptions: list,
+     document_ids_filter: Optional[list[str]] = None,
  ):
      with flask_app.app_context():
          try:
              raise ValueError("dataset not found")

              keyword = Keyword(dataset=dataset)
-             documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k)
+             documents = keyword.search(
+                 cls.escape_query_for_search(query), top_k=top_k, document_ids_filter=document_ids_filter
+             )
              all_documents.extend(documents)
          except Exception as e:
              exceptions.append(str(e))

  all_documents: list,
  retrieval_method: str,
  exceptions: list,
+ document_ids_filter: Optional[list[str]] = None,
  ):
      with flask_app.app_context():
          try:

  top_k=top_k,
  score_threshold=score_threshold,
  filter={"group_id": [dataset.id]},
+ document_ids_filter=document_ids_filter,
  )

  if documents:
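The pattern is the same for every backend: retrieve() grows an optional document_ids_filter and forwards it into keyword, semantic and full-text search. A usage sketch; apart from document_ids_filter, the argument names are read off the surrounding signatures or assumed, and the retrieval_method string may differ from the real enum value:

docs = RetrievalService.retrieve(
    retrieval_method="semantic_search",      # assumed string form of RetrievalMethod
    dataset_id="<dataset-uuid>",
    query="refund policy",
    top_k=4,
    document_ids_filter=["<doc-uuid-1>", "<doc-uuid-2>"],  # only these documents are searched
)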

api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py (+1 -1)

  self.analyticdb_vector.delete_by_metadata_field(key, value)

  def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-     return self.analyticdb_vector.search_by_vector(query_vector)
+     return self.analyticdb_vector.search_by_vector(query_vector, **kwargs)

  def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
      return self.analyticdb_vector.search_by_full_text(query, **kwargs)

api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py (+12 -2)

  top_k = kwargs.get("top_k", 4)
  if not isinstance(top_k, int) or top_k <= 0:
      raise ValueError("top_k must be a positive integer")
+ document_ids_filter = kwargs.get("document_ids_filter")
+ where_clause = "WHERE 1=1"
+ if document_ids_filter:
+     document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+     where_clause += f" AND metadata_->>'document_id' IN ({document_ids})"
  score_threshold = float(kwargs.get("score_threshold") or 0.0)
  with self._get_cursor() as cur:
      query_vector_str = json.dumps(query_vector)
      f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
      f"t.page_content as page_content, t.metadata_ AS metadata_ "
      f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
-     f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t",
+     f"FROM {self.table_name} {where_clause} ORDER BY score LIMIT {top_k} ) t",
      (query_vector_str,),
  )
  documents = []

  top_k = kwargs.get("top_k", 4)
  if not isinstance(top_k, int) or top_k <= 0:
      raise ValueError("top_k must be a positive integer")
+ document_ids_filter = kwargs.get("document_ids_filter")
+ where_clause = ""
+ if document_ids_filter:
+     document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+     where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
  with self._get_cursor() as cur:
      cur.execute(
          f"""SELECT id, vector, page_content, metadata_,
          ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
          FROM {self.table_name}
-         WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
+         WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
          ORDER BY score DESC
          LIMIT {top_k}""",
          (f"'{query}'", f"'{query}'"),

api/core/rag/datasource/vdb/baidu/baidu_vector.py (+15 -5)

  def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
      query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
-     anns = AnnSearch(
-         vector_field=self.field_vector,
-         vector_floats=query_vector,
-         params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
-     )
+     document_ids_filter = kwargs.get("document_ids_filter")
+     if document_ids_filter:
+         document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+         anns = AnnSearch(
+             vector_field=self.field_vector,
+             vector_floats=query_vector,
+             params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
+             filter=f"document_id IN ({document_ids})",
+         )
+     else:
+         anns = AnnSearch(
+             vector_field=self.field_vector,
+             vector_floats=query_vector,
+             params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
+         )
      res = self._db.table(self._collection_name).search(
          anns=anns,
          projections=[self.field_id, self.field_text, self.field_metadata],

api/core/rag/datasource/vdb/chroma/chroma_vector.py (+9 -1)

  def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
      collection = self._client.get_or_create_collection(self._collection_name)
-     results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
+     document_ids_filter = kwargs.get("document_ids_filter")
+     if document_ids_filter:
+         results: QueryResult = collection.query(
+             query_embeddings=query_vector,
+             n_results=kwargs.get("top_k", 4),
+             where={"document_id": {"$in": document_ids_filter}},  # type: ignore
+         )
+     else:
+         results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))  # type: ignore
      score_threshold = float(kwargs.get("score_threshold") or 0.0)

      # Check if results contain data

api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py (+6 -0)

  top_k = kwargs.get("top_k", 4)
  num_candidates = math.ceil(top_k * 1.5)
  knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
+ document_ids_filter = kwargs.get("document_ids_filter")
+ if document_ids_filter:
+     knn["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}

  results = self._client.search(index=self._collection_name, knn=knn, size=top_k)

  def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
      query_str = {"match": {Field.CONTENT_KEY.value: query}}
+     document_ids_filter = kwargs.get("document_ids_filter")
+     if document_ids_filter:
+         query_str["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}  # type: ignore
      results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
      docs = []
      for hit in results["hits"]["hits"]:

api/core/rag/datasource/vdb/lindorm/lindorm_vector.py (+10 -2)

  raise ValueError("All elements in query_vector should be floats")

  top_k = kwargs.get("top_k", 10)
- query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
+ document_ids_filter = kwargs.get("document_ids_filter")
+ filters = []
+ if document_ids_filter:
+     filters.append({"terms": {"metadata.document_id": document_ids_filter}})
+ query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs)

  try:
      params = {}
      if self._using_ugc:

  should = kwargs.get("should")
  minimum_should_match = kwargs.get("minimum_should_match", 0)
  top_k = kwargs.get("top_k", 10)
- filters = kwargs.get("filter")
+ filters = kwargs.get("filter", [])
+ document_ids_filter = kwargs.get("document_ids_filter")
+ if document_ids_filter:
+     filters.append({"terms": {"metadata.document_id": document_ids_filter}})
  routing = self._routing
  full_text_query = default_text_search_query(
      query_text=query,

api/core/rag/datasource/vdb/milvus/milvus_vector.py (+12 -0)

  """
  Search for documents by vector similarity.
  """
+ document_ids_filter = kwargs.get("document_ids_filter")
+ filter = ""
+ if document_ids_filter:
+     document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+     filter = f'metadata["document_id"] in ({document_ids})'
  results = self._client.search(
      collection_name=self._collection_name,
      data=[query_vector],
      anns_field=Field.VECTOR.value,
      limit=kwargs.get("top_k", 4),
      output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
+     filter=filter,
  )

  return self._process_search_results(

  if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
      logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
      return []
+ document_ids_filter = kwargs.get("document_ids_filter")
+ filter = ""
+ if document_ids_filter:
+     document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+     filter = f'metadata["document_id"] in ({document_ids})'

  results = self._client.search(
      collection_name=self._collection_name,
      anns_field=Field.SPARSE_VECTOR.value,
      limit=kwargs.get("top_k", 4),
      output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
+     filter=filter,
  )

  return self._process_search_results(
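Most vector-store changes in this commit follow the same recipe: quote the document ids, join them, and splice the list into the backend's native filter syntax. A standalone reproduction of the Milvus expression built above (same string logic, not Dify code):

def milvus_document_filter(document_ids_filter: list[str]) -> str:
    # Mirrors the f-string above; returns "" when no filter applies.
    if not document_ids_filter:
        return ""
    document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
    return f'metadata["document_id"] in ({document_ids})'

print(milvus_document_filter(["a1", "b2"]))   # metadata["document_id"] in ('a1', 'b2')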

api/core/rag/datasource/vdb/myscale/myscale_vector.py (+4 -0)

  if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
  else ""
  )
+ document_ids_filter = kwargs.get("document_ids_filter")
+ if document_ids_filter:
+     document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+     where_str = f"{where_str} AND metadata['document_id'] in ({document_ids})"
  sql = f"""
      SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
      {where_str} ORDER BY dist {order.value} LIMIT {top_k}

api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py (+6 -0)

  return []

  def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+     document_ids_filter = kwargs.get("document_ids_filter")
+     where_clause = None
+     if document_ids_filter:
+         document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+         where_clause = f"metadata->>'$.document_id' in ({document_ids})"
      ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
      if ef_search != self._hnsw_ef_search:
          self._client.set_ob_hnsw_ef_search(ef_search)

  distance_func=func.l2_distance,
  output_column_names=["text", "metadata"],
  with_dist=True,
+ where_clause=where_clause,
  )
  docs = []
  for text, metadata, distance in cur:

api/core/rag/datasource/vdb/opensearch/opensearch_vector.py (+6 -0)

  "size": kwargs.get("top_k", 4),
  "query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}},
  }
+ document_ids_filter = kwargs.get("document_ids_filter")
+ if document_ids_filter:
+     query["query"] = {"terms": {"metadata.document_id": document_ids_filter}}

  try:
      response = self._client.search(index=self._collection_name.lower(), body=query)

  def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
      full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}}
+     document_ids_filter = kwargs.get("document_ids_filter")
+     if document_ids_filter:
+         full_text_query["query"]["terms"] = {"metadata.document_id": document_ids_filter}

      response = self._client.search(index=self._collection_name.lower(), body=full_text_query)

api/core/rag/datasource/vdb/oracle/oraclevector.py (+13 -2)

  :return: List of Documents that are nearest to the query vector.
  """
  top_k = kwargs.get("top_k", 4)
+ document_ids_filter = kwargs.get("document_ids_filter")
+ where_clause = ""
+ if document_ids_filter:
+     document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+     where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
  with self._get_cursor() as cur:
      cur.execute(
          f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
-         f" ORDER BY distance fetch first {top_k} rows only",
+         f" {where_clause} ORDER BY distance fetch first {top_k} rows only",
          [numpy.array(query_vector)],
      )
  docs = []

  if token not in stop_words:
      entities.append(token)
  with self._get_cursor() as cur:
+     document_ids_filter = kwargs.get("document_ids_filter")
+     where_clause = ""
+     if document_ids_filter:
+         document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+         where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
      cur.execute(
          f"select meta, text, embedding FROM {self.table_name}"
-         f" WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only",
+         f" WHERE CONTAINS(text, :1, 1) > 0 {where_clause} "
+         f"order by score(1) desc fetch first {top_k} rows only",
          [" ACCUM ".join(entities)],
      )
  docs = []

api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py (+3 -0)

  .limit(kwargs.get("top_k", 4))
  .order_by("distance")
  )
+ document_ids_filter = kwargs.get("document_ids_filter")
+ if document_ids_filter:
+     stmt = stmt.where(self._table.meta["document_id"].in_(document_ids_filter))
  res = session.execute(stmt)
  results = [(row[0], row[1]) for row in res]



api/core/rag/datasource/vdb/pgvector/pgvector.py (+13 -0)

  top_k = kwargs.get("top_k", 4)
  if not isinstance(top_k, int) or top_k <= 0:
      raise ValueError("top_k must be a positive integer")
+ document_ids_filter = kwargs.get("document_ids_filter")
+ where_clause = ""
+ if document_ids_filter:
+     document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+     where_clause = f" WHERE metadata->>'document_id' in ({document_ids}) "

  with self._get_cursor() as cur:
      cur.execute(
          f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
+         f" {where_clause}"
          f" ORDER BY distance LIMIT {top_k}",
          (json.dumps(query_vector),),
      )

  if not isinstance(top_k, int) or top_k <= 0:
      raise ValueError("top_k must be a positive integer")
  with self._get_cursor() as cur:
+     document_ids_filter = kwargs.get("document_ids_filter")
+     where_clause = ""
+     if document_ids_filter:
+         document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+         where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
      if self.pg_bigm:
          cur.execute("SET pg_bigm.similarity_limit TO 0.000001")
          cur.execute(
              f"""SELECT meta, text, bigm_similarity(unistr(%s), coalesce(text, '')) AS score
              FROM {self.table_name}
              WHERE text =%% unistr(%s)
+             {where_clause}
              ORDER BY score DESC
              LIMIT {top_k}""",
              # f"'{query}'" is required in order to account for whitespace in query

              f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
              FROM {self.table_name}
              WHERE to_tsvector(text) @@ plainto_tsquery(%s)
+             {where_clause}
              ORDER BY score DESC
              LIMIT {top_k}""",
              # f"'{query}'" is required in order to account for whitespace in query

api/core/rag/datasource/vdb/qdrant/qdrant_vector.py (+38 -21)

  from qdrant_client.http import models
  from qdrant_client.http.exceptions import UnexpectedResponse

- for node_id in ids:
-     try:
-         filter = models.Filter(
-             must=[
-                 models.FieldCondition(
-                     key="metadata.doc_id",
-                     match=models.MatchValue(value=node_id),
-                 ),
-             ],
-         )
-         self._client.delete(
-             collection_name=self._collection_name,
-             points_selector=FilterSelector(filter=filter),
-         )
-     except UnexpectedResponse as e:
-         # Collection does not exist, so return
-         if e.status_code == 404:
-             return
-         # Some other error occurred, so re-raise the exception
-         else:
-             raise e
+ try:
+     filter = models.Filter(
+         must=[
+             models.FieldCondition(
+                 key="metadata.doc_id",
+                 match=models.MatchAny(any=ids),
+             ),
+         ],
+     )
+     self._client.delete(
+         collection_name=self._collection_name,
+         points_selector=FilterSelector(filter=filter),
+     )
+ except UnexpectedResponse as e:
+     # Collection does not exist, so return
+     if e.status_code == 404:
+         return
+     # Some other error occurred, so re-raise the exception
+     else:
+         raise e

  def text_exists(self, id: str) -> bool:
      all_collection_name = []

  ),
  ],
  )
+ document_ids_filter = kwargs.get("document_ids_filter")
+ if document_ids_filter:
+     if filter.must:
+         filter.must.append(
+             models.FieldCondition(
+                 key="metadata.document_id",
+                 match=models.MatchAny(any=document_ids_filter),
+             )
+         )
  results = self._client.search(
      collection_name=self._collection_name,
      query_vector=query_vector,

  ),
  ]
  )
+ document_ids_filter = kwargs.get("document_ids_filter")
+ if document_ids_filter:
+     if scroll_filter.must:
+         scroll_filter.must.append(
+             models.FieldCondition(
+                 key="metadata.document_id",
+                 match=models.MatchAny(any=document_ids_filter),
+             )
+         )
  response = self._client.scroll(
      collection_name=self._collection_name,
      scroll_filter=scroll_filter,

api/core/rag/datasource/vdb/relyt/relyt_vector.py (+7 -3)

  return len(result) > 0

  def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+     document_ids_filter = kwargs.get("document_ids_filter")
+     filter = kwargs.get("filter", {})
+     if document_ids_filter:
+         filter["document_id"] = document_ids_filter
      results = self.similarity_search_with_score_by_vector(
-         k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=kwargs.get("filter")
+         k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=filter
      )

      # Organize results.

  filter_condition = ""
  if filter is not None:
      conditions = [
-         f"metadata->>{key!r} in ({', '.join(map(repr, value))})"
+         f"metadata->>'{key!r}' in ({', '.join(map(repr, value))})"
          if len(value) > 1
-         else f"metadata->>{key!r} = {value[0]!r}"
+         else f"metadata->>'{key!r}' = {value[0]!r}"
          for key, value in filter.items()
      ]
      filter_condition = f"WHERE {' AND '.join(conditions)}"

api/core/rag/datasource/vdb/tencent/tencent_vector.py (+6 -1)

  self._db.collection(self._collection_name).delete(document_ids=ids)

  def delete_by_metadata_field(self, key: str, value: str) -> None:
-     self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))
+     self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(f"metadata.{key}", [value])))

  def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+     document_ids_filter = kwargs.get("document_ids_filter")
+     filter = None
+     if document_ids_filter:
+         filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
      res = self._db.collection(self._collection_name).search(
          vectors=[query_vector],
+         filter=filter,
          params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
          retrieve_vector=False,
          limit=kwargs.get("top_k", 4),

api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py (+24 -0)

  ),
  ],
  )
+ document_ids_filter = kwargs.get("document_ids_filter")
+ if document_ids_filter:
+     should_conditions = []
+     for document_id_filter in document_ids_filter:
+         should_conditions.append(
+             models.FieldCondition(
+                 key="metadata.document_id",
+                 match=models.MatchValue(value=document_id_filter),
+             )
+         )
+     if should_conditions:
+         filter.should = should_conditions  # type: ignore
  results = self._client.search(
      collection_name=self._collection_name,
      query_vector=query_vector,

  )
  ]
  )
+ document_ids_filter = kwargs.get("document_ids_filter")
+ if document_ids_filter:
+     should_conditions = []
+     for document_id_filter in document_ids_filter:
+         should_conditions.append(
+             models.FieldCondition(
+                 key="metadata.document_id",
+                 match=models.MatchValue(value=document_id_filter),
+             )
+         )
+     if should_conditions:
+         scroll_filter.should = should_conditions  # type: ignore
  response = self._client.scroll(
      collection_name=self._collection_name,
      scroll_filter=scroll_filter,

api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py (+6 -0)

  docs = []
  tidb_dist_func = self._get_distance_func()
+ document_ids_filter = kwargs.get("document_ids_filter")
+ where_clause = ""
+ if document_ids_filter:
+     document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+     where_clause = f" WHERE meta->>'$.document_id' in ({document_ids}) "

  with Session(self._engine) as session:
      select_statement = sql_text(f"""
          text,
          {tidb_dist_func}(vector, :query_vector_str) AS distance
          FROM {self._collection_name}
+         {where_clause}
          ORDER BY distance ASC
          LIMIT :top_k
      ) t

api/core/rag/datasource/vdb/upstash/upstash_vector.py (+14 -1)

  def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
      top_k = kwargs.get("top_k", 4)
-     result = self.index.query(vector=query_vector, top_k=top_k, include_metadata=True, include_data=True)
+     document_ids_filter = kwargs.get("document_ids_filter")
+     if document_ids_filter:
+         document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+         filter = f"document_id in ({document_ids})"
+     else:
+         filter = ""
+     result = self.index.query(
+         vector=query_vector,
+         top_k=top_k,
+         include_metadata=True,
+         include_data=True,
+         include_vectors=False,
+         filter=filter,
+     )
      docs = []
      score_threshold = float(kwargs.get("score_threshold") or 0.0)
      for record in result:

api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py (+5 -1)

  query_vector, limit=kwargs.get("top_k", 4)
  )
  score_threshold = float(kwargs.get("score_threshold") or 0.0)
- return self._get_search_res(results, score_threshold)
+ docs = self._get_search_res(results, score_threshold)
+ document_ids_filter = kwargs.get("document_ids_filter")
+ if document_ids_filter:
+     docs = [doc for doc in docs if doc.metadata.get("document_id") in document_ids_filter]
+ return docs

  def _get_search_res(self, results, score_threshold) -> list[Document]:
      if len(results) == 0:

api/core/rag/datasource/vdb/weaviate/weaviate_vector.py (+8 -4)

  query_obj = self._client.query.get(collection_name, properties)

  vector = {"vector": query_vector}
- if kwargs.get("where_filter"):
-     query_obj = query_obj.with_where(kwargs.get("where_filter"))
+ document_ids_filter = kwargs.get("document_ids_filter")
+ if document_ids_filter:
+     where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter}
+     query_obj = query_obj.with_where(where_filter)
  result = (
      query_obj.with_near_vector(vector)
      .with_limit(kwargs.get("top_k", 4))

  if kwargs.get("search_distance"):
      content["certainty"] = kwargs.get("search_distance")

  query_obj = self._client.query.get(collection_name, properties)
- if kwargs.get("where_filter"):
-     query_obj = query_obj.with_where(kwargs.get("where_filter"))
+ document_ids_filter = kwargs.get("document_ids_filter")
+ if document_ids_filter:
+     where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter}
+     query_obj = query_obj.with_where(where_filter)
  query_obj = query_obj.with_additional(["vector"])
  properties = ["text"]
  result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do()

api/core/rag/entities/metadata_entities.py (+45 -0)

from collections.abc import Sequence
from typing import Literal, Optional

from pydantic import BaseModel, Field

SupportedComparisonOperator = Literal[
    # for string or array
    "contains",
    "not contains",
    "start with",
    "end with",
    "is",
    "is not",
    "empty",
    "not empty",
    # for number
    "=",
    "≠",
    ">",
    "<",
    "≥",
    "≤",
    # for time
    "before",
    "after",
]


class Condition(BaseModel):
    """
    Condition detail
    """

    name: str
    comparison_operator: SupportedComparisonOperator
    value: str | Sequence[str] | None | int | float = None


class MetadataCondition(BaseModel):
    """
    Metadata Condition.
    """

    logical_operator: Optional[Literal["and", "or"]] = "and"
    conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
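MetadataCondition is the resolved form of a filter, e.g. what dataset_retrieval.py below hands to external knowledge retrieval as metadata_condition. A short round-trip sketch with invented values; the schema itself is exactly the one defined above:

payload = {
    "logical_operator": "or",
    "conditions": [
        {"name": "uploader", "comparison_operator": "is", "value": "alice"},
        {"name": "document_name", "comparison_operator": "contains", "value": "handbook"},
    ],
}
condition = MetadataCondition.model_validate(payload)
assert condition.conditions is not None and condition.conditions[0].comparison_operator == "is"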

api/core/rag/index_processor/constant/built_in_field.py (+15 -0)

from enum import Enum


class BuiltInField(str, Enum):
    document_name = "document_name"
    uploader = "uploader"
    upload_date = "upload_date"
    last_update_date = "last_update_date"
    source = "source"


class MetadataDataSource(Enum):
    upload_file = "file_upload"
    website_crawl = "website"
    notion_import = "notion"
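These are the built-in fields that DatasetMetadataBuiltInFieldActionApi enables or disables per dataset. A trivial sketch of listing them, roughly the data MetadataService.get_built_in_fields would draw on (that service is part of this commit but not shown in this excerpt):

for field in BuiltInField:
    print(field.value)
# document_name, uploader, upload_date, last_update_date, source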

+ 429
- 7
api/core/rag/retrieval/dataset_retrieval.py View File

import json
import math import math
import re
import threading import threading
from collections import Counter
from typing import Any, Optional, cast
from collections import Counter, defaultdict
from collections.abc import Generator, Mapping
from typing import Any, Optional, Union, cast


from flask import Flask, current_app from flask import Flask, current_app

from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
from sqlalchemy import Integer, and_, or_, text
from sqlalchemy import cast as sqlalchemy_cast

from core.app.app_config.entities import (
DatasetEntity,
DatasetRetrieveConfigEntity,
MetadataFilteringCondition,
ModelConfig,
)
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.ops.entities.trace_entity import TraceTaskName from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time from core.ops.utils import measure_time
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.context_entities import DocumentContext from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
from core.rag.retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2,
METADATA_FILTER_COMPLETION_PROMPT,
METADATA_FILTER_SYSTEM_PROMPT,
METADATA_FILTER_USER_PROMPT_1,
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DatasetQuery, DocumentSegment
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService from services.external_knowledge_service import ExternalDatasetService


hit_callback: DatasetIndexToolCallbackHandler, hit_callback: DatasetIndexToolCallbackHandler,
message_id: str, message_id: str,
memory: Optional[TokenBufferMemory] = None, memory: Optional[TokenBufferMemory] = None,
inputs: Optional[Mapping[str, Any]] = None,
) -> Optional[str]: ) -> Optional[str]:
""" """
Retrieve dataset. Retrieve dataset.
continue continue


available_datasets.append(dataset) available_datasets.append(dataset)
if inputs:
inputs = {key: str(value) for key, value in inputs.items()}
else:
inputs = {}
available_datasets_ids = [dataset.id for dataset in available_datasets]
metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
available_datasets_ids,
query,
tenant_id,
user_id,
retrieve_config.metadata_filtering_mode, # type: ignore
retrieve_config.metadata_model_config, # type: ignore
retrieve_config.metadata_filtering_conditions,
inputs,
)

all_documents = [] all_documents = []
user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
model_config, model_config,
planning_strategy, planning_strategy,
message_id, message_id,
metadata_filter_document_ids,
metadata_condition,
) )
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
all_documents = self.multiple_retrieve( all_documents = self.multiple_retrieve(
retrieve_config.weights, retrieve_config.weights,
retrieve_config.reranking_enabled or True, retrieve_config.reranking_enabled or True,
message_id, message_id,
metadata_filter_document_ids,
metadata_condition,
) )


dify_documents = [item for item in all_documents if item.provider == "dify"]
model_config: ModelConfigWithCredentialsEntity,
planning_strategy: PlanningStrategy,
message_id: Optional[str] = None,
metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
metadata_condition: Optional[MetadataCondition] = None,
):
tools = []
for dataset in available_datasets:
dataset_id=dataset_id,
query=query,
external_retrieval_parameters=dataset.retrieval_model,
metadata_condition=metadata_condition,
)
for external_document in external_documents:
document = Document(
document.metadata["dataset_name"] = dataset.name
results.append(document)
else:
if metadata_condition and not metadata_filter_document_ids:
return []
document_ids_filter = None
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
return []
retrieval_model_config = dataset.retrieval_model or default_retrieval_model


# get top k
reranking_model=reranking_model,
reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
weights=retrieval_model_config.get("weights", None),
document_ids_filter=document_ids_filter,
)
self._on_query(query, [dataset_id], app_id, user_from, user_id)


weights: Optional[dict[str, Any]] = None,
reranking_enable: bool = True,
message_id: Optional[str] = None,
metadata_filter_document_ids: Optional[dict[str, list[str]]] = None,
metadata_condition: Optional[MetadataCondition] = None,
):
if not available_datasets:
return []


for dataset in available_datasets:
index_type = dataset.indexing_technique
document_ids_filter = None
if dataset.provider != "external":
if metadata_condition and not metadata_filter_document_ids:
continue
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
continue
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"query": query,
"top_k": top_k,
"all_documents": all_documents,
"document_ids_filter": document_ids_filter,
"metadata_condition": metadata_condition,
},
)
threads.append(retrieval_thread)
db.session.add_all(dataset_queries)
db.session.commit()


def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
def _retriever(
self,
flask_app: Flask,
dataset_id: str,
query: str,
top_k: int,
all_documents: list,
document_ids_filter: Optional[list[str]] = None,
metadata_condition: Optional[MetadataCondition] = None,
):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()


dataset_id=dataset_id,
query=query,
external_retrieval_parameters=dataset.retrieval_model,
metadata_condition=metadata_condition,
)
for external_document in external_documents:
document = Document(
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
document_ids_filter=document_ids_filter,
)


all_documents.extend(documents)
filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True
)
return filter_documents[:top_k] if top_k else filter_documents

def _get_metadata_filter_condition(
self,
dataset_ids: list,
query: str,
tenant_id: str,
user_id: str,
metadata_filtering_mode: str,
metadata_model_config: ModelConfig,
metadata_filtering_conditions: Optional[MetadataFilteringCondition],
inputs: dict,
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
document_query = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id.in_(dataset_ids),
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
filters = [] # type: ignore
metadata_condition = None
if metadata_filtering_mode == "disabled":
return None, None
elif metadata_filtering_mode == "automatic":
automatic_metadata_filters = self._automatic_metadata_filter_func(
dataset_ids, query, tenant_id, user_id, metadata_model_config
)
if automatic_metadata_filters:
conditions = []
for filter in automatic_metadata_filters:
self._process_metadata_filter_func(
filter.get("condition"), # type: ignore
filter.get("metadata_name"), # type: ignore
filter.get("value"),
filters, # type: ignore
)
conditions.append(
Condition(
name=filter.get("metadata_name"), # type: ignore
comparison_operator=filter.get("condition"), # type: ignore
value=filter.get("value"),
)
)
metadata_condition = MetadataCondition(
logical_operator=metadata_filtering_conditions.logical_operator, # type: ignore
conditions=conditions,
)
elif metadata_filtering_mode == "manual":
if metadata_filtering_conditions:
metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump())
for condition in metadata_filtering_conditions.conditions: # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value or condition.comparison_operator in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self._replace_metadata_filter_value(expected_value, inputs)
filters = self._process_metadata_filter_func(
condition.comparison_operator, metadata_name, expected_value, filters
)
else:
raise ValueError("Invalid metadata filtering mode")
if filters:
if metadata_filtering_conditions.logical_operator == "or": # type: ignore
document_query = document_query.filter(or_(*filters))
else:
document_query = document_query.filter(and_(*filters))
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
for document in documents:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, metadata_condition

def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
def replacer(match):
key = match.group(1)
return str(inputs.get(key, f"{{{{{key}}}}}"))

pattern = re.compile(r"\{\{(\w+)\}\}")
return pattern.sub(replacer, text)
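
Note: a minimal standalone sketch of the placeholder substitution performed by `_replace_metadata_filter_value` above, for readers tracing the manual filtering path; the inputs and values are illustrative only, not taken from this change.

```python
import re


def replace_metadata_filter_value(text: str, inputs: dict) -> str:
    # Mirror of the helper above: substitute {{key}} with the matching app input,
    # and re-emit the placeholder unchanged when the key is missing from inputs.
    def replacer(match):
        key = match.group(1)
        return str(inputs.get(key, f"{{{{{key}}}}}"))

    return re.compile(r"\{\{(\w+)\}\}").sub(replacer, text)


print(replace_metadata_filter_value("{{author}} / {{missing}}", {"author": "Alice"}))
# -> Alice / {{missing}}
```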

def _automatic_metadata_filter_func(
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
) -> Optional[list[dict[str, Any]]]:
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
# get metadata model config
if metadata_model_config is None:
raise ValueError("metadata_model_config is required")
# get metadata model instance
# fetch model config
model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config)

# fetch prompt messages
prompt_messages, stop = self._get_prompt_template(
model_config=model_config,
mode=metadata_model_config.mode,
metadata_fields=all_metadata_fields,
query=query or "",
)

result_text = ""
try:
# handle invoke result
invoke_result = cast(
Generator[LLMResult, None, None],
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_config.parameters,
stop=stop,
stream=True,
user=user_id,
),
)

# handle invoke result
result_text, usage = self._handle_invoke_result(invoke_result=invoke_result)

result_text_json = parse_and_check_json_markdown(result_text, [])
automatic_metadata_filters = []
if "metadata_map" in result_text_json:
metadata_map = result_text_json["metadata_map"]
for item in metadata_map:
if item.get("metadata_field_name") in all_metadata_fields:
automatic_metadata_filters.append(
{
"metadata_name": item.get("metadata_field_name"),
"value": item.get("metadata_field_value"),
"condition": item.get("comparison_operator"),
}
)
except Exception as e:
return None
return automatic_metadata_filters

def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: Optional[Any], filters: list):
match condition:
case "contains":
filters.append(
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}%")
)
case "not contains":
filters.append(
(text("documents.doc_metadata ->> :key NOT LIKE :value")).params(
key=metadata_name, value=f"%{value}%"
)
)
case "start with":
filters.append(
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"{value}%")
)

case "end with":
filters.append(
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}")
)
case "is" | "=":
if isinstance(value, str):
filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"')
else:
filters.append(
sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) == value
)
case "is not" | "≠":
if isinstance(value, str):
filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"')
else:
filters.append(
sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) != value
)
case "empty":
filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
case "not empty":
filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) < value)
case "after" | ">":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) > value)
case "≤" | "<=":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) <= value)
case "≥" | ">=":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) >= value)
case _:
pass
return filters
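
Note: a hedged sketch of the filter clauses `_process_metadata_filter_func` assembles against the JSONB `documents.doc_metadata` column, and of how they are later combined; the field names and values are invented for illustration.

```python
from sqlalchemy import and_, text

filters = []
# "contains" on a string field becomes a LIKE over the ->> text projection.
filters.append(
    text("documents.doc_metadata ->> :key LIKE :value").params(key="author", value="%Alice%")
)
# Numeric operators (">", "<", "≥", "≤") instead cast the JSON value, e.g.
#   sqlalchemy_cast(DatasetDocument.doc_metadata["year"].astext, Integer) > 2023

# _get_metadata_filter_condition then applies the clauses to the document query
# with the configured logical operator: and_(*filters) by default, or_(*filters)
# when logical_operator == "or".
combined = and_(*filters)
print(combined)
```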

def _fetch_model_config(
self, tenant_id: str, model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch model config
:param model: model config
:return:
"""
if model is None:
raise ValueError("single_retrieval_config is required")
model_name = model.name
provider_name = model.provider

model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
)

provider_model_bundle = model_instance.provider_model_bundle
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)

model_credentials = model_instance.credentials

# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_name, model_type=ModelType.LLM
)

if provider_model is None:
raise ValueError(f"Model {model_name} not exist.")

if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ValueError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ValueError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise ValueError(f"Model provider {provider_name} quota exceeded.")

# model config
completion_params = model.completion_params
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]

# get model mode
model_mode = model.mode
if not model_mode:
raise ValueError("LLM mode is required.")

model_schema = model_type_instance.get_model_schema(model_name, model_credentials)

if not model_schema:
raise ValueError(f"Model {model_name} not exist.")

return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name,
model=model_name,
model_schema=model_schema,
mode=model_mode,
provider_model_bundle=provider_model_bundle,
credentials=model_credentials,
parameters=completion_params,
stop=stop,
)

def _get_prompt_template(
self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
):
model_mode = ModelMode.value_of(mode)
input_text = query

prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
if model_mode == ModelMode.CHAT:
prompt_template = []
system_prompt_messages = ChatModelMessage(role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT)
prompt_template.append(system_prompt_messages)
user_prompt_message_1 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1)
prompt_template.append(user_prompt_message_1)
assistant_prompt_message_1 = ChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
)
prompt_template.append(assistant_prompt_message_1)
user_prompt_message_2 = ChatModelMessage(role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2)
prompt_template.append(user_prompt_message_2)
assistant_prompt_message_2 = ChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
)
prompt_template.append(assistant_prompt_message_2)
user_prompt_message_3 = ChatModelMessage(
role=PromptMessageRole.USER,
text=METADATA_FILTER_USER_PROMPT_3.format(
input_text=input_text,
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
),
)
prompt_template.append(user_prompt_message_3)
elif model_mode == ModelMode.COMPLETION:
prompt_template = CompletionModelPromptTemplate(
text=METADATA_FILTER_COMPLETION_PROMPT.format(
input_text=input_text,
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
)
)

else:
raise ValueError(f"Model mode {model_mode} not support.")

prompt_transform = AdvancedPromptTransform()
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs={},
query=query or "",
files=[],
context=None,
memory_config=None,
memory=None,
model_config=model_config,
)
stop = model_config.stop

return prompt_messages, stop

def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
"""
Handle invoke result
:param invoke_result: invoke result
:return:
"""
model = None
prompt_messages: list[PromptMessage] = []
full_text = ""
usage = None
for result in invoke_result:
text = result.delta.message.content
full_text += text

if not model:
model = result.model

if not prompt_messages:
prompt_messages = result.prompt_messages

if not usage and result.delta.usage:
usage = result.delta.usage

if not usage:
usage = LLMUsage.empty_usage()

return full_text, usage

+ 66
- 0
api/core/rag/retrieval/template_prompts.py View File

METADATA_FILTER_SYSTEM_PROMPT = """
### Job Description
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
""" # noqa: E501

METADATA_FILTER_USER_PROMPT_1 = """
{ "input_text": "I want to know which company’s email address test@example.com is?",
"metadata_fields": ["filename", "email", "phone", "address"]
}
"""

METADATA_FILTER_ASSISTANT_PROMPT_1 = """
```json
{"metadata_map": [
{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
]
}
```
"""

METADATA_FILTER_USER_PROMPT_2 = """
{"input_text": "What are the movies with a score of more than 9 in 2024?",
"metadata_fields": ["name", "year", "rating", "country"]}
"""

METADATA_FILTER_ASSISTANT_PROMPT_2 = """
```json
{"metadata_map": [
{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"},
]}
```
"""

METADATA_FILTER_USER_PROMPT_3 = """
{{"input_text": "{input_text}",
"metadata_fields": {metadata_fields}}}
"""

METADATA_FILTER_COMPLETION_PROMPT = """
### Job Description
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
<example>
User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
</example>
### User Input
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
### Assistant Output
""" # noqa: E501

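Note: the prompts above ask the model to reply with a fenced JSON object keyed by `metadata_map`. A rough sketch of the parsing step in `_automatic_metadata_filter_func`, using a simplified stand-in for `parse_and_check_json_markdown` and the reply from `METADATA_FILTER_ASSISTANT_PROMPT_1`:

```python
import json

result_text = (
    "```json\n"
    '{"metadata_map": ['
    '{"metadata_field_name": "email", "metadata_field_value": "test@example.com", '
    '"comparison_operator": "="}]}\n'
    "```"
)

# simplified stand-in for libs.json_in_md_parser.parse_and_check_json_markdown
payload = json.loads(result_text.strip().removeprefix("```json").removesuffix("```"))

all_metadata_fields = ["filename", "email", "phone", "address"]
automatic_metadata_filters = [
    {
        "metadata_name": item.get("metadata_field_name"),
        "value": item.get("metadata_field_value"),
        "condition": item.get("comparison_operator"),
    }
    for item in payload["metadata_map"]
    if item.get("metadata_field_name") in all_metadata_fields
]
print(automatic_metadata_filters)
# -> [{'metadata_name': 'email', 'value': 'test@example.com', 'condition': '='}]
```
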
+ 49
- 1
api/core/workflow/nodes/knowledge_retrieval/entities.py View File

from collections.abc import Sequence
from typing import Any, Literal, Optional


from pydantic import BaseModel
from pydantic import BaseModel, Field


from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import VisionConfig




class RerankingModelConfig(BaseModel):
model: ModelConfig




SupportedComparisonOperator = Literal[
# for string or array
"contains",
"not contains",
"start with",
"end with",
"is",
"is not",
"empty",
"not empty",
# for number
"=",
"≠",
">",
"<",
"≥",
"≤",
# for time
"before",
"after",
]


class Condition(BaseModel):
"""
Condition detail
"""

name: str
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | None | int | float = None


class MetadataFilteringCondition(BaseModel):
"""
Metadata Filtering Condition.
"""

logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)


class KnowledgeRetrievalNodeData(BaseNodeData):
"""
Knowledge retrieval Node Data.
retrieval_mode: Literal["single", "multiple"]
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
single_retrieval_config: Optional[SingleRetrievalConfig] = None
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
metadata_model_config: Optional[ModelConfig] = None
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
vision: VisionConfig = Field(default_factory=VisionConfig)
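
Note: a brief sketch of how the new manual-filtering fields can be populated, assuming the imports resolve inside the api package; the field names and values are placeholders, not fixtures from this change.

```python
from core.workflow.nodes.knowledge_retrieval.entities import (
    Condition,
    MetadataFilteringCondition,
)

# "manual" mode: each condition is evaluated against documents.doc_metadata, and
# string values may carry {{variable}} templates resolved from the variable pool.
manual_conditions = MetadataFilteringCondition(
    logical_operator="and",
    conditions=[
        Condition(name="year", comparison_operator=">", value=2023),
        Condition(name="author", comparison_operator="is", value="{{author}}"),
    ],
)
# Set on the node data as metadata_filtering_mode="manual",
# metadata_filtering_conditions=manual_conditions.
```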

+ 4
- 0
api/core/workflow/nodes/knowledge_retrieval/exc.py View File



class ModelQuotaExceededError(KnowledgeRetrievalNodeError):
"""Raised when the model provider quota is exceeded."""


class InvalidModelTypeError(KnowledgeRetrievalNodeError):
"""Raised when the model is not a Large Language Model."""

+ 268
- 22
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py View File

import json
import logging
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, cast
from typing import Any, Optional, cast


from sqlalchemy import func
from sqlalchemy import Integer, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast


from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2,
METADATA_FILTER_COMPLETION_PROMPT,
METADATA_FILTER_SYSTEM_PROMPT,
METADATA_FILTER_USER_PROMPT_1,
METADATA_FILTER_USER_PROMPT_3,
)
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document, RateLimitLog
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from models.workflow import WorkflowNodeExecutionStatus
from services.feature_service import FeatureService


from .entities import KnowledgeRetrievalNodeData
from .entities import KnowledgeRetrievalNodeData, ModelConfig
from .exc import (
InvalidModelTypeError,
KnowledgeRetrievalNodeError,
ModelCredentialsNotInitializedError,
ModelNotExistError,
}




class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
_node_data_cls = KnowledgeRetrievalNodeData
class KnowledgeRetrievalNode(LLMNode):
_node_data_cls = KnowledgeRetrievalNodeData # type: ignore
_node_type = NodeType.KNOWLEDGE_RETRIEVAL


def _run(self) -> NodeRunResult:
def _run(self) -> NodeRunResult: # type: ignore
node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
# extract variables
variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,


# retrieve knowledge
try:
results = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
results = self._fetch_dataset_retriever(node_data=node_data, query=query)
outputs = {"result": results}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
if not dataset:
continue
available_datasets.append(dataset)
metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
[dataset.id for dataset in available_datasets], query, node_data
)
all_documents = []
dataset_retrieval = DatasetRetrieval()
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data)
model_instance, model_config = self._fetch_model_config(node_data.single_retrieval_config.model) # type: ignore
# check model is support tool calling
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_config=model_config,
model_instance=model_instance,
planning_strategy=planning_strategy,
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
) )
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
if node_data.multiple_retrieval_config is None:
reranking_model=reranking_model,
weights=weights,
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
)
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
item["metadata"]["position"] = position
return retrieval_resource_list


def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
document_query = db.session.query(Document).filter(
Document.dataset_id.in_(dataset_ids),
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
filters = [] # type: ignore
metadata_condition = None
if node_data.metadata_filtering_mode == "disabled":
return None, None
elif node_data.metadata_filtering_mode == "automatic":
automatic_metadata_filters = self._automatic_metadata_filter_func(dataset_ids, query, node_data)
if automatic_metadata_filters:
conditions = []
for filter in automatic_metadata_filters:
self._process_metadata_filter_func(
filter.get("condition", ""),
filter.get("metadata_name", ""),
filter.get("value"),
filters, # type: ignore
)
conditions.append(
Condition(
name=filter.get("metadata_name"), # type: ignore
comparison_operator=filter.get("condition"), # type: ignore
value=filter.get("value"),
)
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator, # type: ignore
conditions=conditions,
)
elif node_data.metadata_filtering_mode == "manual":
if node_data.metadata_filtering_conditions:
metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump())
if node_data.metadata_filtering_conditions:
for condition in node_data.metadata_filtering_conditions.conditions: # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value or condition.comparison_operator in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).text

filters = self._process_metadata_filter_func(
condition.comparison_operator, metadata_name, expected_value, filters
)
else:
raise ValueError("Invalid metadata filtering mode")
if filters:
if node_data.metadata_filtering_conditions.logical_operator == "and": # type: ignore
document_query = document_query.filter(and_(*filters))
else:
document_query = document_query.filter(or_(*filters))
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
for document in documents:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, metadata_condition

def _automatic_metadata_filter_func(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> list[dict[str, Any]]:
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
# get metadata model config
metadata_model_config = node_data.metadata_model_config
if metadata_model_config is None:
raise ValueError("metadata_model_config is required")
# get metadata model instance
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.metadata_model_config) # type: ignore
# fetch prompt messages
prompt_template = self._get_prompt_template(
node_data=node_data,
metadata_fields=all_metadata_fields,
query=query or "",
)
prompt_messages, stop = self._fetch_prompt_messages(
prompt_template=prompt_template,
sys_query=query,
memory=None,
model_config=model_config,
sys_files=[],
vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[],
)

result_text = ""
try:
# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.metadata_model_config, # type: ignore
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
)

for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
break

result_text_json = parse_and_check_json_markdown(result_text, [])
automatic_metadata_filters = []
if "metadata_map" in result_text_json:
metadata_map = result_text_json["metadata_map"]
for item in metadata_map:
if item.get("metadata_field_name") in all_metadata_fields:
automatic_metadata_filters.append(
{
"metadata_name": item.get("metadata_field_name"),
"value": item.get("metadata_field_value"),
"condition": item.get("comparison_operator"),
}
)
except Exception as e:
return []
return automatic_metadata_filters

def _process_metadata_filter_func(self, condition: str, metadata_name: str, value: Optional[str], filters: list):
match condition:
case "contains":
filters.append(
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}%")
)
case "not contains":
filters.append(
(text("documents.doc_metadata ->> :key NOT LIKE :value")).params(
key=metadata_name, value=f"%{value}%"
)
)
case "start with":
filters.append(
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"{value}%")
)
case "end with":
filters.append(
(text("documents.doc_metadata ->> :key LIKE :value")).params(key=metadata_name, value=f"%{value}")
)
case "=" | "is":
if isinstance(value, str):
filters.append(Document.doc_metadata[metadata_name] == f'"{value}"')
else:
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) == value)
case "is not" | "≠":
if isinstance(value, str):
filters.append(Document.doc_metadata[metadata_name] != f'"{value}"')
else:
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) != value)
case "empty":
filters.append(Document.doc_metadata[metadata_name].is_(None))
case "not empty":
filters.append(Document.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) < value)
case "after" | ">":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) > value)
case "≤" | "<=":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) <= value)
case "≥" | ">=":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) >= value)
case _:
pass
return filters

@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: KnowledgeRetrievalNodeData,
node_data: KnowledgeRetrievalNodeData, # type: ignore
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
return variable_mapping


def _fetch_model_config(
self, node_data: KnowledgeRetrievalNodeData
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
def _fetch_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: # type: ignore
"""
Fetch model config
:param node_data: node data
:param model: model
:return:
"""
if node_data.single_retrieval_config is None:
raise ValueError("single_retrieval_config is required")
model_name = node_data.single_retrieval_config.model.name
provider_name = node_data.single_retrieval_config.model.provider
if model is None:
raise ValueError("model is required")
model_name = model.name
provider_name = model.provider


model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.")


# model config
completion_params = node_data.single_retrieval_config.model.completion_params
completion_params = model.completion_params
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]


# get model mode
model_mode = node_data.single_retrieval_config.model.mode
model_mode = model.mode
if not model_mode:
raise ModelNotExistError("LLM mode is required.")


parameters=completion_params,
stop=stop,
)

def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
model_mode = ModelMode.value_of(node_data.metadata_model_config.mode) # type: ignore
input_text = query
memory_str = ""

prompt_messages: list[LLMNodeChatModelMessage] = []
if model_mode == ModelMode.CHAT:
system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT
)
prompt_messages.append(system_prompt_messages)
user_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1
)
prompt_messages.append(user_prompt_message_1)
assistant_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
)
prompt_messages.append(assistant_prompt_message_1)
user_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2
)
prompt_messages.append(user_prompt_message_2)
assistant_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
)
prompt_messages.append(assistant_prompt_message_2)
user_prompt_message_3 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER,
text=METADATA_FILTER_USER_PROMPT_3.format(
input_text=input_text,
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
),
)
prompt_messages.append(user_prompt_message_3)
return prompt_messages
elif model_mode == ModelMode.COMPLETION:
return LLMNodeCompletionModelPromptTemplate(
text=METADATA_FILTER_COMPLETION_PROMPT.format(
input_text=input_text,
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
)
)

else:
raise InvalidModelTypeError(f"Model mode {model_mode} not support.")

+ 66
- 0
api/core/workflow/nodes/knowledge_retrieval/template_prompts.py View File

METADATA_FILTER_SYSTEM_PROMPT = """
### Job Description
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
""" # noqa: E501

METADATA_FILTER_USER_PROMPT_1 = """
{ "input_text": "I want to know which company’s email address test@example.com is?",
"metadata_fields": ["filename", "email", "phone", "address"]
}
"""

METADATA_FILTER_ASSISTANT_PROMPT_1 = """
```json
{"metadata_map": [
{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}
]
}
```
"""

METADATA_FILTER_USER_PROMPT_2 = """
{"input_text": "What are the movies with a score of more than 9 in 2024?",
"metadata_fields": ["name", "year", "rating", "country"]}
"""

METADATA_FILTER_ASSISTANT_PROMPT_2 = """
```json
{"metadata_map": [
{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="},
{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"},
]}
```
"""

METADATA_FILTER_USER_PROMPT_3 = """
{{"input_text": "{input_text}",
"metadata_fields": {metadata_fields}}}
"""

METADATA_FILTER_COMPLETION_PROMPT = """
### Job Description
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside <example></example> XML tags.
<example>
User:{{"input_text": ["I want to know which company’s email address test@example.com is?"], "metadata_fields": ["filename", "email", "phone", "address"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
</example>
### User Input
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
### Assistant Output
""" # noqa: E501

+ 10
- 0
api/fields/dataset_fields.py View File

"external_knowledge_api_endpoint": fields.String,
}


doc_metadata_fields = {"id": fields.String, "name": fields.String, "type": fields.String}

dataset_detail_fields = {
"id": fields.String,
"name": fields.String,
"doc_form": fields.String,
"external_knowledge_info": fields.Nested(external_knowledge_info_fields),
"external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True),
"doc_metadata": fields.List(fields.Nested(doc_metadata_fields)),
"built_in_field_enabled": fields.Boolean,
}


dataset_query_detail_fields = {
"created_by": fields.String,
"created_at": TimestampField,
}

dataset_metadata_fields = {
"id": fields.String,
"type": fields.String,
"name": fields.String,
}

+ 9
- 0
api/fields/document_fields.py View File

from fields.dataset_fields import dataset_fields
from libs.helper import TimestampField


document_metadata_fields = {
"id": fields.String,
"name": fields.String,
"type": fields.String,
"value": fields.String,
}

document_fields = {
"id": fields.String,
"position": fields.Integer,
"word_count": fields.Integer,
"hit_count": fields.Integer,
"doc_form": fields.String,
"doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
}


document_with_segments_fields = {
"hit_count": fields.Integer,
"completed_segments": fields.Integer,
"total_segments": fields.Integer,
"doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
}


dataset_and_document_fields = {

+ 90
- 0
api/migrations/versions/2025_02_27_0917-d20049ed0af6_add_metadata_function.py View File

"""add_metadata_function

Revision ID: d20049ed0af6
Revises: 08ec4f75af5e
Create Date: 2025-02-27 09:17:48.903213

"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = 'd20049ed0af6'
down_revision = 'f051706725cc'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('dataset_metadata_bindings',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
sa.Column('metadata_id', models.types.StringUUID(), nullable=False),
sa.Column('document_id', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('created_by', models.types.StringUUID(), nullable=False),
sa.PrimaryKeyConstraint('id', name='dataset_metadata_binding_pkey')
)
with op.batch_alter_table('dataset_metadata_bindings', schema=None) as batch_op:
batch_op.create_index('dataset_metadata_binding_dataset_idx', ['dataset_id'], unique=False)
batch_op.create_index('dataset_metadata_binding_document_idx', ['document_id'], unique=False)
batch_op.create_index('dataset_metadata_binding_metadata_idx', ['metadata_id'], unique=False)
batch_op.create_index('dataset_metadata_binding_tenant_idx', ['tenant_id'], unique=False)

op.create_table('dataset_metadatas',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
sa.Column('type', sa.String(length=255), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('created_by', models.types.StringUUID(), nullable=False),
sa.Column('updated_by', models.types.StringUUID(), nullable=True),
sa.PrimaryKeyConstraint('id', name='dataset_metadata_pkey')
)
with op.batch_alter_table('dataset_metadatas', schema=None) as batch_op:
batch_op.create_index('dataset_metadata_dataset_idx', ['dataset_id'], unique=False)
batch_op.create_index('dataset_metadata_tenant_idx', ['tenant_id'], unique=False)

with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('built_in_field_enabled', sa.Boolean(), server_default=sa.text('false'), nullable=False))

with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.alter_column('doc_metadata',
existing_type=postgresql.JSON(astext_type=sa.Text()),
type_=postgresql.JSONB(astext_type=sa.Text()),
existing_nullable=True)
batch_op.create_index('document_metadata_idx', ['doc_metadata'], unique=False, postgresql_using='gin')
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.drop_index('document_metadata_idx', postgresql_using='gin')
batch_op.alter_column('doc_metadata',
existing_type=postgresql.JSONB(astext_type=sa.Text()),
type_=postgresql.JSON(astext_type=sa.Text()),
existing_nullable=True)

with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.drop_column('built_in_field_enabled')

with op.batch_alter_table('dataset_metadatas', schema=None) as batch_op:
batch_op.drop_index('dataset_metadata_tenant_idx')
batch_op.drop_index('dataset_metadata_dataset_idx')

op.drop_table('dataset_metadatas')
with op.batch_alter_table('dataset_metadata_bindings', schema=None) as batch_op:
batch_op.drop_index('dataset_metadata_binding_tenant_idx')
batch_op.drop_index('dataset_metadata_binding_metadata_idx')
batch_op.drop_index('dataset_metadata_binding_document_idx')
batch_op.drop_index('dataset_metadata_binding_dataset_idx')

op.drop_table('dataset_metadata_bindings')
# ### end Alembic commands ###
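
Note: besides the new tables, the migration converts `documents.doc_metadata` to JSONB and adds a GIN index on it. A hedged sketch of a containment-style lookup that a GIN (jsonb_ops) index can serve; the raw SQL and values are illustrative and not part of the migration:

```python
import json

from sqlalchemy import text

# Containment lookup on the JSONB column; the default GIN opclass supports the
# @> operator, so filters shaped like this can use document_metadata_idx.
stmt = text(
    "SELECT id FROM documents "
    "WHERE dataset_id = :dataset_id AND doc_metadata @> CAST(:meta AS JSONB)"
).params(dataset_id="<dataset-uuid>", meta=json.dumps({"author": "Alice"}))
# db.session.execute(stmt) inside the Dify API app context
```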

+ 175
- 1
api/models/dataset.py View File

from sqlalchemy.orm import Mapped


from configs import dify_config
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_storage import storage
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
embedding_model_provider = db.Column(db.String(255), nullable=True)
collection_binding_id = db.Column(StringUUID, nullable=True)
retrieval_model = db.Column(JSONB, nullable=True)
built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))


@property
def dataset_keyword_table(self):
"external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
}


@property
def doc_metadata(self):
dataset_metadatas = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == self.id).all()

doc_metadata = [
{
"id": dataset_metadata.id,
"name": dataset_metadata.name,
"type": dataset_metadata.type,
}
for dataset_metadata in dataset_metadatas
]
if self.built_in_field_enabled:
doc_metadata.append(
{
"id": "built-in",
"name": BuiltInField.document_name.value,
"type": "string",
}
)
doc_metadata.append(
{
"id": "built-in",
"name": BuiltInField.uploader.value,
"type": "string",
}
)
doc_metadata.append(
{
"id": "built-in",
"name": BuiltInField.upload_date.value,
"type": "time",
}
)
doc_metadata.append(
{
"id": "built-in",
"name": BuiltInField.last_update_date.value,
"type": "time",
}
)
doc_metadata.append(
{
"id": "built-in",
"name": BuiltInField.source.value,
"type": "string",
}
)
return doc_metadata

@staticmethod
def gen_collection_name_by_id(dataset_id: str) -> str:
normalized_dataset_id = dataset_id.replace("-", "_")
db.Index("document_dataset_id_idx", "dataset_id"),
db.Index("document_is_paused_idx", "is_paused"),
db.Index("document_tenant_idx", "tenant_id"),
db.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"),
)


# initial fields
archived_at = db.Column(db.DateTime, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
doc_type = db.Column(db.String(40), nullable=True)
doc_metadata = db.Column(db.JSON, nullable=True)
doc_metadata = db.Column(JSONB, nullable=True)
doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
doc_language = db.Column(db.String(255), nullable=True)


.scalar()
)


@property
def uploader(self):
user = db.session.query(Account).filter(Account.id == self.created_by).first()
return user.name if user else None

@property
def upload_date(self):
return self.created_at

@property
def last_update_date(self):
return self.updated_at

@property
def doc_metadata_details(self):
if self.doc_metadata:
document_metadatas = (
db.session.query(DatasetMetadata)
.join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
.filter(
DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
)
.all()
)
metadata_list = []
for metadata in document_metadatas:
metadata_dict = {
"id": metadata.id,
"name": metadata.name,
"type": metadata.type,
"value": self.doc_metadata.get(metadata.name),
}
metadata_list.append(metadata_dict)
# deal built-in fields
metadata_list.extend(self.get_built_in_fields())

return metadata_list
return None

@property
def process_rule_dict(self):
if self.dataset_process_rule_id:
return self.dataset_process_rule.to_dict()
return None


def get_built_in_fields(self):
built_in_fields = []
built_in_fields.append(
{
"id": "built-in",
"name": BuiltInField.document_name,
"type": "string",
"value": self.name,
}
)
built_in_fields.append(
{
"id": "built-in",
"name": BuiltInField.uploader,
"type": "string",
"value": self.uploader,
}
)
built_in_fields.append(
{
"id": "built-in",
"name": BuiltInField.upload_date,
"type": "time",
"value": self.created_at.timestamp(),
}
)
built_in_fields.append(
{
"id": "built-in",
"name": BuiltInField.last_update_date,
"type": "time",
"value": self.updated_at.timestamp(),
}
)
built_in_fields.append(
{
"id": "built-in",
"name": BuiltInField.source,
"type": "string",
"value": MetadataDataSource[self.data_source_type].value,
}
)
return built_in_fields
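
Note: an illustrative shape of the JSONB payload stored on `Document.doc_metadata`; values are keyed by metadata name, which is how `doc_metadata_details` above looks them up, with built-in fields layered on top when the dataset has `built_in_field_enabled` set. The names and values below are examples only.

```python
# user-defined fields declared in dataset_metadatas, plus built-in fields
# (document_name, uploader, upload_date, last_update_date, source) when enabled
doc_metadata = {
    "author": "Alice",
    "year": 2024,
    "document_name": "quarterly-report.pdf",
}
value = doc_metadata.get("author")  # mirrors self.doc_metadata.get(metadata.name)
```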

def to_dict(self):
return {
"id": self.id,
subscription_plan = db.Column(db.String(255), nullable=False)
operation = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


class DatasetMetadata(db.Model): # type: ignore[name-defined]
__tablename__ = "dataset_metadatas"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
db.Index("dataset_metadata_tenant_idx", "tenant_id"),
db.Index("dataset_metadata_dataset_idx", "dataset_id"),
)

id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
type = db.Column(db.String(255), nullable=False)
name = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_by = db.Column(StringUUID, nullable=False)
updated_by = db.Column(StringUUID, nullable=True)


class DatasetMetadataBinding(db.Model): # type: ignore[name-defined]
__tablename__ = "dataset_metadata_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),
db.Index("dataset_metadata_binding_tenant_idx", "tenant_id"),
db.Index("dataset_metadata_binding_dataset_idx", "dataset_id"),
db.Index("dataset_metadata_binding_metadata_idx", "metadata_id"),
db.Index("dataset_metadata_binding_document_idx", "document_id"),
)

id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
metadata_id = db.Column(StringUUID, nullable=False)
document_id = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_by = db.Column(StringUUID, nullable=False)
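A minimal usage sketch for the two new models, assuming an active Flask application context and pre-existing tenant_id, dataset_id, document_id, and user_id variables (none of these names come from the change itself): a custom field is defined once per dataset in dataset_metadatas, then attached to individual documents through dataset_metadata_bindings.

# Sketch only: define a custom field and bind it to one document.
from extensions.ext_database import db
from models.dataset import DatasetMetadata, DatasetMetadataBinding

field = DatasetMetadata(
    tenant_id=tenant_id,
    dataset_id=dataset_id,
    type="string",
    name="category",
    created_by=user_id,
)
db.session.add(field)
db.session.flush()  # populate field.id before creating the binding

binding = DatasetMetadataBinding(
    tenant_id=tenant_id,
    dataset_id=dataset_id,
    document_id=document_id,
    metadata_id=field.id,
    created_by=user_id,
)
db.session.add(binding)
db.session.commit()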

+ 457 - 482 api/poetry.lock
File diff suppressed because it is too large


+ 57 - 3 api/services/dataset_service.py

import copy
import datetime
import json
import logging
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from events.dataset_event import dataset_was_deleted


return document


@staticmethod
def get_document_by_ids(document_ids: list[str]) -> list[Document]:
documents = (
db.session.query(Document)
.filter(
Document.id.in_(document_ids),
Document.enabled == True,
Document.indexing_status == "completed",
Document.archived == False,
)
.all()
)
return documents

@staticmethod
def get_document_by_dataset_id(dataset_id: str) -> list[Document]:
documents = db.session.query(Document).filter(Document.dataset_id == dataset_id, Document.enabled == True).all()
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset_id,
Document.enabled == True,
)
.all()
)

return documents

@staticmethod
def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset_id,
Document.enabled == True,
Document.indexing_status == "completed",
Document.archived == False,
)
.all()
)


return documents


if document.tenant_id != current_user.current_tenant_id:
raise ValueError("No permission.")


document.name = name
if dataset.built_in_field_enabled:
if document.doc_metadata:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata[BuiltInField.document_name.value] = name
document.doc_metadata = doc_metadata


document.name = name
db.session.add(document)
db.session.commit()


doc_form=document_form,
doc_language=document_language,
)
doc_metadata = {}
if dataset.built_in_field_enabled:
doc_metadata = {
BuiltInField.document_name: name,
BuiltInField.uploader: account.name,
BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
BuiltInField.source: data_source_type,
}
if metadata is not None:
document.doc_metadata = metadata.doc_metadata
doc_metadata.update(metadata.doc_metadata)
document.doc_type = metadata.doc_type
if doc_metadata:
document.doc_metadata = doc_metadata
return document
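To illustrate the branch above: when built_in_field_enabled is on, the built-in keys are written first and any caller-supplied metadata.doc_metadata is merged on top, so a freshly created document ends up with roughly the following doc_metadata (values are examples only):

# Illustrative shape of document.doc_metadata after creation (assumed values)
{
    BuiltInField.document_name: "guide.pdf",
    BuiltInField.uploader: "Jane Doe",
    BuiltInField.upload_date: "2025-02-27 09:17:00",
    BuiltInField.last_update_date: "2025-02-27 09:17:00",
    BuiltInField.source: "upload_file",
    "category": "handbook",  # merged in from metadata.doc_metadata
}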


@staticmethod

+ 33 - 0 api/services/entities/knowledge_entities/knowledge_entities.py

class ChildChunkUpdateArgs(BaseModel):
id: Optional[str] = None
content: str


class MetadataArgs(BaseModel):
type: Literal["string", "number", "time"]
name: str


class MetadataUpdateArgs(BaseModel):
name: str
value: Optional[str | int | float] = None


class MetadataValueUpdateArgs(BaseModel):
fields: list[MetadataUpdateArgs]


class MetadataDetail(BaseModel):
id: str
name: str
value: Optional[str | int | float] = None


class DocumentMetadataOperation(BaseModel):
document_id: str
metadata_list: list[MetadataDetail]


class MetadataOperationData(BaseModel):
"""
Metadata operation data
"""

operation_data: list[DocumentMetadataOperation]
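A small sketch of how these entities compose into the payload later consumed by MetadataService.update_documents_metadata; the ids are placeholders, and in practice the metadata id would come from an existing DatasetMetadata row:

# Sketch: assign a "category" value to one document.
from services.entities.knowledge_entities.knowledge_entities import (
    DocumentMetadataOperation,
    MetadataDetail,
    MetadataOperationData,
)

payload = MetadataOperationData(
    operation_data=[
        DocumentMetadataOperation(
            document_id="example-document-id",
            metadata_list=[
                MetadataDetail(id="example-metadata-id", name="category", value="handbook"),
            ],
        )
    ]
)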

+ 7 - 1 api/services/external_knowledge_service.py



from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
from core.rag.entities.metadata_entities import MetadataCondition
from extensions.ext_database import db
from models.dataset import (
Dataset,


@staticmethod
def fetch_external_knowledge_retrieval(
tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict
tenant_id: str,
dataset_id: str,
query: str,
external_retrieval_parameters: dict,
metadata_condition: Optional[MetadataCondition] = None,
) -> list:
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id, tenant_id=tenant_id
},
"query": query,
"knowledge_id": external_knowledge_binding.external_knowledge_id,
"metadata_condition": metadata_condition.model_dump() if metadata_condition else None,
}


response = ExternalDatasetService.process_external_api(
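With the new parameter, the request body sent to the external knowledge API gains a metadata_condition entry alongside query and knowledge_id. A hedged sketch of that body follows; the inner field names (logical_operator, conditions, and so on) are assumptions for illustration, since the MetadataCondition schema lives in api/core/rag/entities/metadata_entities.py, which is not shown in this excerpt:

# Assumed shape of the external retrieval request body (illustrative values)
{
    "query": "refund policy",
    "knowledge_id": "external-knowledge-id",
    "metadata_condition": {
        "logical_operator": "and",  # assumed field name
        "conditions": [  # assumed field name and item shape
            {"name": ["category"], "comparison_operator": "is", "value": "handbook"},
        ],
    },
}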

+ 241 - 0 api/services/metadata_service.py

import copy
import datetime
import logging
from typing import Optional

from flask_login import current_user # type: ignore

from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding
from services.dataset_service import DocumentService
from services.entities.knowledge_entities.knowledge_entities import (
MetadataArgs,
MetadataOperationData,
)


class MetadataService:
@staticmethod
def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata:
# check if metadata name already exists
if DatasetMetadata.query.filter_by(
tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name
).first():
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
if field.value == metadata_args.name:
raise ValueError("Metadata name already exists in Built-in fields.")
metadata = DatasetMetadata(
tenant_id=current_user.current_tenant_id,
dataset_id=dataset_id,
type=metadata_args.type,
name=metadata_args.name,
created_by=current_user.id,
)
db.session.add(metadata)
db.session.commit()
return metadata
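A usage sketch, assuming a request context in which flask_login's current_user is populated and dataset_id refers to an existing knowledge base:

# Sketch: define a new "time"-typed field on a dataset.
from services.entities.knowledge_entities.knowledge_entities import MetadataArgs
from services.metadata_service import MetadataService

args = MetadataArgs(type="time", name="published_at")
metadata = MetadataService.create_metadata(dataset_id, metadata_args=args)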

@staticmethod
def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore
lock_key = f"dataset_metadata_lock_{dataset_id}"
# check if metadata name already exists
if DatasetMetadata.query.filter_by(
tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name
).first():
raise ValueError("Metadata name already exists.")
for field in BuiltInField:
if field.value == name:
raise ValueError("Metadata name already exists in Built-in fields.")
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
if metadata is None:
raise ValueError("Metadata not found.")
old_name = metadata.name
metadata.name = name
metadata.updated_by = current_user.id
metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

# update related documents
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
for document in documents:
doc_metadata = copy.deepcopy(document.doc_metadata)
value = doc_metadata.pop(old_name, None)
doc_metadata[name] = value
document.doc_metadata = doc_metadata
db.session.add(document)
db.session.commit()
return metadata # type: ignore
except Exception:
logging.exception("Update metadata name failed")
finally:
redis_client.delete(lock_key)

@staticmethod
def delete_metadata(dataset_id: str, metadata_id: str):
lock_key = f"dataset_metadata_lock_{dataset_id}"
try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first()
if metadata is None:
raise ValueError("Metadata not found.")
db.session.delete(metadata)

# deal related documents
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all()
if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids)
for document in documents:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata.pop(metadata.name, None)
document.doc_metadata = doc_metadata
db.session.add(document)
db.session.commit()
return metadata
except Exception:
logging.exception("Delete metadata failed")
finally:
redis_client.delete(lock_key)

@staticmethod
def get_built_in_fields():
return [
{"name": BuiltInField.document_name.value, "type": "string"},
{"name": BuiltInField.uploader.value, "type": "string"},
{"name": BuiltInField.upload_date.value, "type": "time"},
{"name": BuiltInField.last_update_date.value, "type": "time"},
{"name": BuiltInField.source.value, "type": "string"},
]

@staticmethod
def enable_built_in_field(dataset: Dataset):
if dataset.built_in_field_enabled:
return
lock_key = f"dataset_metadata_lock_{dataset.id}"
try:
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
dataset.built_in_field_enabled = True
db.session.add(dataset)
documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
if documents:
for document in documents:
if not document.doc_metadata:
doc_metadata = {}
else:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata[BuiltInField.document_name.value] = document.name
doc_metadata[BuiltInField.uploader.value] = document.uploader
doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
document.doc_metadata = doc_metadata
db.session.add(document)
db.session.commit()
except Exception:
logging.exception("Enable built-in field failed")
finally:
redis_client.delete(lock_key)

@staticmethod
def disable_built_in_field(dataset: Dataset):
if not dataset.built_in_field_enabled:
return
lock_key = f"dataset_metadata_lock_{dataset.id}"
try:
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
dataset.built_in_field_enabled = False
db.session.add(dataset)
documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
document_ids = []
if documents:
for document in documents:
doc_metadata = copy.deepcopy(document.doc_metadata)
doc_metadata.pop(BuiltInField.document_name.value, None)
doc_metadata.pop(BuiltInField.uploader.value, None)
doc_metadata.pop(BuiltInField.upload_date.value, None)
doc_metadata.pop(BuiltInField.last_update_date.value, None)
doc_metadata.pop(BuiltInField.source.value, None)
document.doc_metadata = doc_metadata
db.session.add(document)
document_ids.append(document.id)
db.session.commit()
except Exception:
logging.exception("Disable built-in field failed")
finally:
redis_client.delete(lock_key)

@staticmethod
def update_documents_metadata(dataset: Dataset, metadata_args: MetadataOperationData):
for operation in metadata_args.operation_data:
lock_key = f"document_metadata_lock_{operation.document_id}"
try:
MetadataService.knowledge_base_metadata_lock_check(None, operation.document_id)
document = DocumentService.get_document(dataset.id, operation.document_id)
if document is None:
raise ValueError("Document not found.")
doc_metadata = {}
for metadata_value in operation.metadata_list:
doc_metadata[metadata_value.name] = metadata_value.value
if dataset.built_in_field_enabled:
doc_metadata[BuiltInField.document_name.value] = document.name
doc_metadata[BuiltInField.uploader.value] = document.uploader
doc_metadata[BuiltInField.upload_date.value] = document.upload_date.timestamp()
doc_metadata[BuiltInField.last_update_date.value] = document.last_update_date.timestamp()
doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value
document.doc_metadata = doc_metadata
db.session.add(document)
db.session.commit()
# deal metadata binding
DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete()
for metadata_value in operation.metadata_list:
dataset_metadata_binding = DatasetMetadataBinding(
tenant_id=current_user.current_tenant_id,
dataset_id=dataset.id,
document_id=operation.document_id,
metadata_id=metadata_value.id,
created_by=current_user.id,
)
db.session.add(dataset_metadata_binding)
db.session.commit()
except Exception:
logging.exception("Update documents metadata failed")
finally:
redis_client.delete(lock_key)

@staticmethod
def knowledge_base_metadata_lock_check(dataset_id: Optional[str], document_id: Optional[str]):
if dataset_id:
lock_key = f"dataset_metadata_lock_{dataset_id}"
if redis_client.get(lock_key):
raise ValueError("Another knowledge base metadata operation is running, please wait a moment.")
redis_client.set(lock_key, 1, ex=3600)
if document_id:
lock_key = f"document_metadata_lock_{document_id}"
if redis_client.get(lock_key):
raise ValueError("Another document metadata operation is running, please wait a moment.")
redis_client.set(lock_key, 1, ex=3600)
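The methods above all follow the same acquire-and-release discipline around this helper: the helper sets a one-hour Redis lock (or raises if one is already held), and the caller deletes the key in a finally block. A minimal sketch of that pattern, with the mutation itself elided:

lock_key = f"dataset_metadata_lock_{dataset_id}"
try:
    MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
    # ... perform the metadata mutation while the lock is held ...
finally:
    redis_client.delete(lock_key)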

@staticmethod
def get_dataset_metadatas(dataset: Dataset):
return {
"doc_metadata": [
{
"id": item.get("id"),
"name": item.get("name"),
"type": item.get("type"),
"count": DatasetMetadataBinding.query.filter_by(
metadata_id=item.get("id"), dataset_id=dataset.id
).count(),
}
for item in dataset.doc_metadata or []
if item.get("id") != "built-in"
],
"built_in_field_enabled": dataset.built_in_field_enabled,
}

+ 1 - 1 api/services/tag_service.py

)
if keyword:
query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
query = query.group_by(Tag.id, Tag.type, Tag.name)
results: list = query.order_by(Tag.created_at.desc()).all()
return results


