|  |  | @@ -1,12 +1,14 @@ | 
		
	
		
			
			|  |  |  | import json | 
		
	
		
			
			|  |  |  | import logging | 
		
	
		
			
			|  |  |  | import math | 
		
	
		
			
			|  |  |  | from typing import Any, Optional | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | from pydantic import BaseModel | 
		
	
		
			
			|  |  |  | from tcvdb_text.encoder import BM25Encoder  # type: ignore | 
		
	
		
			
			|  |  |  | from tcvectordb import RPCVectorDBClient, VectorDBException  # type: ignore | 
		
	
		
			
			|  |  |  | from tcvectordb.model import document, enum  # type: ignore | 
		
	
		
			
			|  |  |  | from tcvectordb.model import index as vdb_index  # type: ignore | 
		
	
		
			
			|  |  |  | from tcvectordb.model.document import Filter  # type: ignore | 
		
	
		
			
			|  |  |  | from tcvectordb.model.document import AnnSearch, Filter, KeywordSearch, WeightedRerank  # type: ignore | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | from configs import dify_config | 
		
	
		
			
			|  |  |  | from core.rag.datasource.vdb.vector_base import BaseVector | 
		
	
	
		
			
			|  |  | @@ -17,6 +19,8 @@ from core.rag.models.document import Document | 
		
	
		
			
			|  |  |  | from extensions.ext_redis import redis_client | 
		
	
		
			
			|  |  |  | from models.dataset import Dataset | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | logger = logging.getLogger(__name__) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | class TencentConfig(BaseModel): | 
		
	
		
			
			|  |  |  | url: str | 
		
	
	
		
			
			|  |  | @@ -25,10 +29,11 @@ class TencentConfig(BaseModel): | 
		
	
		
			
			|  |  |  | username: Optional[str] | 
		
	
		
			
			|  |  |  | database: Optional[str] | 
		
	
		
			
			|  |  |  | index_type: str = "HNSW" | 
		
	
		
			
			|  |  |  | metric_type: str = "L2" | 
		
	
		
			
			|  |  |  | metric_type: str = "IP" | 
		
	
		
			
			|  |  |  | shard: int = 1 | 
		
	
		
			
			|  |  |  | replicas: int = 2 | 
		
	
		
			
			|  |  |  | max_upsert_batch_size: int = 128 | 
		
	
		
			
			|  |  |  | enable_hybrid_search: bool = False  # Flag to enable hybrid search | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | def to_tencent_params(self): | 
		
	
		
			
			|  |  |  | return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout} | 
		
	
	
		
			
			|  |  | @@ -44,6 +49,29 @@ class TencentVector(BaseVector): | 
		
	
		
			
			|  |  |  | super().__init__(collection_name) | 
		
	
		
			
			|  |  |  | self._client_config = config | 
		
	
		
			
			|  |  |  | self._client = RPCVectorDBClient(**self._client_config.to_tencent_params()) | 
		
	
		
			
			|  |  |  | self._enable_hybrid_search = False | 
		
	
		
			
			|  |  |  | self._dimension = 1024 | 
		
	
		
			
			|  |  |  | self._load_collection() | 
		
	
		
			
			|  |  |  | self._bm25 = BM25Encoder.default("zh") | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | def _load_collection(self): | 
		
	
		
			
			|  |  |  | """ | 
		
	
		
			
			|  |  |  | Check if the collection supports hybrid search. | 
		
	
		
			
			|  |  |  | """ | 
		
	
		
			
			|  |  |  | if self._client_config.enable_hybrid_search: | 
		
	
		
			
			|  |  |  | self._enable_hybrid_search = True | 
		
	
		
			
			|  |  |  | if self._has_collection(): | 
		
	
		
			
			|  |  |  | coll = self._client.describe_collection( | 
		
	
		
			
			|  |  |  | database_name=self._client_config.database, collection_name=self.collection_name | 
		
	
		
			
			|  |  |  | ) | 
		
	
		
			
			|  |  |  | has_hybrid_search = False | 
		
	
		
			
			|  |  |  | for idx in coll.indexes: | 
		
	
		
			
			|  |  |  | if idx.name == "sparse_vector": | 
		
	
		
			
			|  |  |  | has_hybrid_search = True | 
		
	
		
			
			|  |  |  | elif idx.name == "vector": | 
		
	
		
			
			|  |  |  | self._dimension = idx.dimension | 
		
	
		
			
			|  |  |  | if not has_hybrid_search: | 
		
	
		
			
			|  |  |  | self._enable_hybrid_search = False | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | def _init_database(self): | 
		
	
		
			
			|  |  |  | return self._client.create_database_if_not_exists(database_name=self._client_config.database) | 
		
	
	
		
			
			|  |  | @@ -62,6 +90,7 @@ class TencentVector(BaseVector): | 
		
	
		
			
			|  |  |  | ) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | def _create_collection(self, dimension: int) -> None: | 
		
	
		
			
			|  |  |  | self._dimension = dimension | 
		
	
		
			
			|  |  |  | lock_name = "vector_indexing_lock_{}".format(self._collection_name) | 
		
	
		
			
			|  |  |  | with redis_client.lock(lock_name, timeout=20): | 
		
	
		
			
			|  |  |  | collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) | 
		
	
	
		
			
			|  |  | @@ -84,18 +113,25 @@ class TencentVector(BaseVector): | 
		
	
		
			
			|  |  |  | if metric_type is None: | 
		
	
		
			
			|  |  |  | raise ValueError("unsupported metric_type") | 
		
	
		
			
			|  |  |  | params = vdb_index.HNSWParams(m=16, efconstruction=200) | 
		
	
		
			
			|  |  |  | index = vdb_index.Index( | 
		
	
		
			
			|  |  |  | vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY), | 
		
	
		
			
			|  |  |  | vdb_index.VectorIndex( | 
		
	
		
			
			|  |  |  | self.field_vector, | 
		
	
		
			
			|  |  |  | dimension, | 
		
	
		
			
			|  |  |  | index_type, | 
		
	
		
			
			|  |  |  | metric_type, | 
		
	
		
			
			|  |  |  | params, | 
		
	
		
			
			|  |  |  | ), | 
		
	
		
			
			|  |  |  | vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER), | 
		
	
		
			
			|  |  |  | vdb_index.FilterIndex(self.field_metadata, enum.FieldType.Json, enum.IndexType.FILTER), | 
		
	
		
			
			|  |  |  | index_id = vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY) | 
		
	
		
			
			|  |  |  | index_vector = vdb_index.VectorIndex( | 
		
	
		
			
			|  |  |  | self.field_vector, | 
		
	
		
			
			|  |  |  | dimension, | 
		
	
		
			
			|  |  |  | index_type, | 
		
	
		
			
			|  |  |  | metric_type, | 
		
	
		
			
			|  |  |  | params, | 
		
	
		
			
			|  |  |  | ) | 
		
	
		
			
			|  |  |  | index_text = vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER) | 
		
	
		
			
			|  |  |  | index_metadate = vdb_index.FilterIndex(self.field_metadata, enum.FieldType.Json, enum.IndexType.FILTER) | 
		
	
		
			
			|  |  |  | index_sparse_vector = vdb_index.SparseIndex( | 
		
	
		
			
			|  |  |  | name="sparse_vector", | 
		
	
		
			
			|  |  |  | field_type=enum.FieldType.SparseVector, | 
		
	
		
			
			|  |  |  | index_type=enum.IndexType.SPARSE_INVERTED, | 
		
	
		
			
			|  |  |  | metric_type=enum.MetricType.IP, | 
		
	
		
			
			|  |  |  | ) | 
		
	
		
			
			|  |  |  | indexes = [index_id, index_vector, index_text, index_metadate] | 
		
	
		
			
			|  |  |  | if self._enable_hybrid_search: | 
		
	
		
			
			|  |  |  | indexes.append(index_sparse_vector) | 
		
	
		
			
			|  |  |  | try: | 
		
	
		
			
			|  |  |  | self._client.create_collection( | 
		
	
		
			
			|  |  |  | database_name=self._client_config.database, | 
		
	
	
		
			
			|  |  | @@ -103,31 +139,25 @@ class TencentVector(BaseVector): | 
		
	
		
			
			|  |  |  | shard=self._client_config.shard, | 
		
	
		
			
			|  |  |  | replicas=self._client_config.replicas, | 
		
	
		
			
			|  |  |  | description="Collection for Dify", | 
		
	
		
			
			|  |  |  | index=index, | 
		
	
		
			
			|  |  |  | indexes=indexes, | 
		
	
		
			
			|  |  |  | ) | 
		
	
		
			
			|  |  |  | except VectorDBException as e: | 
		
	
		
			
			|  |  |  | if "fieldType:json" not in e.message: | 
		
	
		
			
			|  |  |  | raise e | 
		
	
		
			
			|  |  |  | # vdb version not support json, use string | 
		
	
		
			
			|  |  |  | index = vdb_index.Index( | 
		
	
		
			
			|  |  |  | vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY), | 
		
	
		
			
			|  |  |  | vdb_index.VectorIndex( | 
		
	
		
			
			|  |  |  | self.field_vector, | 
		
	
		
			
			|  |  |  | dimension, | 
		
	
		
			
			|  |  |  | index_type, | 
		
	
		
			
			|  |  |  | metric_type, | 
		
	
		
			
			|  |  |  | params, | 
		
	
		
			
			|  |  |  | ), | 
		
	
		
			
			|  |  |  | vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER), | 
		
	
		
			
			|  |  |  | vdb_index.FilterIndex(self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER), | 
		
	
		
			
			|  |  |  | index_metadate = vdb_index.FilterIndex( | 
		
	
		
			
			|  |  |  | self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER | 
		
	
		
			
			|  |  |  | ) | 
		
	
		
			
			|  |  |  | indexes = [index_id, index_vector, index_text, index_metadate] | 
		
	
		
			
			|  |  |  | if self._enable_hybrid_search: | 
		
	
		
			
			|  |  |  | indexes.append(index_sparse_vector) | 
		
	
		
			
			|  |  |  | self._client.create_collection( | 
		
	
		
			
			|  |  |  | database_name=self._client_config.database, | 
		
	
		
			
			|  |  |  | collection_name=self._collection_name, | 
		
	
		
			
			|  |  |  | shard=self._client_config.shard, | 
		
	
		
			
			|  |  |  | replicas=self._client_config.replicas, | 
		
	
		
			
			|  |  |  | description="Collection for Dify", | 
		
	
		
			
			|  |  |  | index=index, | 
		
	
		
			
			|  |  |  | indexes=indexes, | 
		
	
		
			
			|  |  |  | ) | 
		
	
		
			
			|  |  |  | redis_client.set(collection_exist_cache_key, 1, ex=3600) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
	
		
			
			|  |  | @@ -155,6 +185,8 @@ class TencentVector(BaseVector): | 
		
	
		
			
			|  |  |  | text=texts[i], | 
		
	
		
			
			|  |  |  | metadata=metadata, | 
		
	
		
			
			|  |  |  | ) | 
		
	
		
			
			|  |  |  | if self._enable_hybrid_search: | 
		
	
		
			
			|  |  |  | doc.__dict__["sparse_vector"] = self._bm25.encode_texts(texts[i]) | 
		
	
		
			
			|  |  |  | docs.append(doc) | 
		
	
		
			
			|  |  |  | self._client.upsert( | 
		
	
		
			
			|  |  |  | database_name=self._client_config.database, | 
		
	
	
		
			
			|  |  | @@ -204,7 +236,32 @@ class TencentVector(BaseVector): | 
		
	
		
			
			|  |  |  | return self._get_search_res(res, score_threshold) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | 
		
	
		
			
			|  |  |  | return [] | 
		
	
		
			
			|  |  |  | if not self._enable_hybrid_search: | 
		
	
		
			
			|  |  |  | return [] | 
		
	
		
			
			|  |  |  | res = self._client.hybrid_search( | 
		
	
		
			
			|  |  |  | database_name=self._client_config.database, | 
		
	
		
			
			|  |  |  | collection_name=self.collection_name, | 
		
	
		
			
			|  |  |  | ann=[ | 
		
	
		
			
			|  |  |  | AnnSearch( | 
		
	
		
			
			|  |  |  | field_name="vector", | 
		
	
		
			
			|  |  |  | data=[0.0] * self._dimension, | 
		
	
		
			
			|  |  |  | ) | 
		
	
		
			
			|  |  |  | ], | 
		
	
		
			
			|  |  |  | match=[ | 
		
	
		
			
			|  |  |  | KeywordSearch( | 
		
	
		
			
			|  |  |  | field_name="sparse_vector", | 
		
	
		
			
			|  |  |  | data=self._bm25.encode_queries(query), | 
		
	
		
			
			|  |  |  | ), | 
		
	
		
			
			|  |  |  | ], | 
		
	
		
			
			|  |  |  | rerank=WeightedRerank( | 
		
	
		
			
			|  |  |  | field_list=["vector", "sparse_vector"], | 
		
	
		
			
			|  |  |  | weight=[0, 1], | 
		
	
		
			
			|  |  |  | ), | 
		
	
		
			
			|  |  |  | retrieve_vector=False, | 
		
	
		
			
			|  |  |  | limit=kwargs.get("top_k", 4), | 
		
	
		
			
			|  |  |  | ) | 
		
	
		
			
			|  |  |  | score_threshold = float(kwargs.get("score_threshold") or 0.0) | 
		
	
		
			
			|  |  |  | return self._get_search_res(res, score_threshold) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | def _get_search_res(self, res: list | None, score_threshold: float) -> list[Document]: | 
		
	
		
			
			|  |  |  | docs: list[Document] = [] | 
		
	
	
		
			
			|  |  | @@ -213,7 +270,7 @@ class TencentVector(BaseVector): | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | for result in res[0]: | 
		
	
		
			
			|  |  |  | meta = result.get(self.field_metadata) | 
		
	
		
			
			|  |  |  | score = 1 - result.get("score", 0.0) | 
		
	
		
			
			|  |  |  | score = result.get("score", 0.0) | 
		
	
		
			
			|  |  |  | if score > score_threshold: | 
		
	
		
			
			|  |  |  | meta["score"] = score | 
		
	
		
			
			|  |  |  | doc = Document(page_content=result.get(self.field_text), metadata=meta) | 
		
	
	
		
			
			|  |  | @@ -245,5 +302,6 @@ class TencentVectorFactory(AbstractVectorFactory): | 
		
	
		
			
			|  |  |  | database=dify_config.TENCENT_VECTOR_DB_DATABASE, | 
		
	
		
			
			|  |  |  | shard=dify_config.TENCENT_VECTOR_DB_SHARD, | 
		
	
		
			
			|  |  |  | replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS, | 
		
	
		
			
			|  |  |  | enable_hybrid_search=dify_config.TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH or False, | 
		
	
		
			
			|  |  |  | ), | 
		
	
		
			
			|  |  |  | ) |