|
|
|
|
|
|
|
|
import json
import logging
import math
from typing import Any, Optional

from pydantic import BaseModel
from tcvdb_text.encoder import BM25Encoder  # type: ignore
from tcvectordb import RPCVectorDBClient, VectorDBException  # type: ignore
from tcvectordb.model import document, enum  # type: ignore
from tcvectordb.model import index as vdb_index  # type: ignore
from tcvectordb.model.document import AnnSearch, Filter, KeywordSearch, WeightedRerank  # type: ignore

from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

logger = logging.getLogger(__name__)
|

class TencentConfig(BaseModel):
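    """Connection and collection settings for Tencent Cloud VectorDB (tcvectordb)."""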
|
|
    url: str
    api_key: Optional[str]
    timeout: float = 30
    username: Optional[str]
    database: Optional[str]
    index_type: str = "HNSW"
    metric_type: str = "IP"
    shard: int = 1
    replicas: int = 2
    max_upsert_batch_size: int = 128
    enable_hybrid_search: bool = False  # Flag to enable hybrid search
|
    def to_tencent_params(self):
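        """Build the keyword arguments used to construct the tcvectordb RPC client."""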
|
|
        return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
|

class TencentVector(BaseVector):
    field_id: str = "id"
    field_vector: str = "vector"
    field_text: str = "text"
    field_metadata: str = "metadata"

    def __init__(self, collection_name: str, config: TencentConfig):
        super().__init__(collection_name)
        self._client_config = config
        self._client = RPCVectorDBClient(**self._client_config.to_tencent_params())
        self._enable_hybrid_search = False
        self._dimension = 1024
        self._load_collection()
        self._bm25 = BM25Encoder.default("zh")
|
    def _load_collection(self):
        """
        Check if the collection supports hybrid search.
        """
        if self._client_config.enable_hybrid_search:
            self._enable_hybrid_search = True
            if self._has_collection():
                coll = self._client.describe_collection(
                    database_name=self._client_config.database, collection_name=self.collection_name
                )
                has_hybrid_search = False
                for idx in coll.indexes:
                    if idx.name == "sparse_vector":
                        has_hybrid_search = True
                    elif idx.name == "vector":
                        self._dimension = idx.dimension
                if not has_hybrid_search:
                    self._enable_hybrid_search = False
|
    def _init_database(self):
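        """Create the configured database on the server if it does not already exist."""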
|
|
        return self._client.create_database_if_not_exists(database_name=self._client_config.database)

    def _has_collection(self) -> bool:
        return bool(
            self._client.exists_collection(
                database_name=self._client_config.database, collection_name=self.collection_name
            )
        )
|
    def _create_collection(self, dimension: int) -> None:
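        """Create the collection and its indexes if they do not exist yet.

        A Redis lock plus a short-lived cache key keeps concurrent indexing workers from
        racing to create the same collection. The collection gets a primary-key index, an
        HNSW vector index, and filter indexes for text and metadata; a sparse inverted
        index is added as well when hybrid search is enabled. If the server rejects the
        JSON field type for metadata, creation is retried with a String metadata index.
        """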
|
|
|
|
|
        self._dimension = dimension
        lock_name = "vector_indexing_lock_{}".format(self._collection_name)
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
            if redis_client.get(collection_exist_cache_key):
                return
            if self._has_collection():
                return

            self._init_database()
            index_type = None
            if self._client_config.index_type == "HNSW":
                index_type = enum.IndexType.HNSW
            if index_type is None:
                raise ValueError("unsupported index_type")
            metric_type = None
            if self._client_config.metric_type == "L2":
                metric_type = enum.MetricType.L2
            elif self._client_config.metric_type == "IP":
                metric_type = enum.MetricType.IP
            if metric_type is None:
                raise ValueError("unsupported metric_type")
            params = vdb_index.HNSWParams(m=16, efconstruction=200)

            index_id = vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY)
            index_vector = vdb_index.VectorIndex(
                self.field_vector,
                dimension,
                index_type,
                metric_type,
                params,
            )
            index_text = vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER)
            index_metadata = vdb_index.FilterIndex(self.field_metadata, enum.FieldType.Json, enum.IndexType.FILTER)
            index_sparse_vector = vdb_index.SparseIndex(
                name="sparse_vector",
                field_type=enum.FieldType.SparseVector,
                index_type=enum.IndexType.SPARSE_INVERTED,
                metric_type=enum.MetricType.IP,
            )
            indexes = [index_id, index_vector, index_text, index_metadata]
            if self._enable_hybrid_search:
                indexes.append(index_sparse_vector)
            try:
                self._client.create_collection(
                    database_name=self._client_config.database,
                    collection_name=self._collection_name,
                    shard=self._client_config.shard,
                    replicas=self._client_config.replicas,
                    description="Collection for Dify",
                    indexes=indexes,
                )
            except VectorDBException as e:
                if "fieldType:json" not in e.message:
                    raise e
                # This vdb version does not support the JSON field type; fall back to a String metadata index.
                index_metadata = vdb_index.FilterIndex(
                    self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
                )
                indexes = [index_id, index_vector, index_text, index_metadata]
                if self._enable_hybrid_search:
                    indexes.append(index_sparse_vector)
                self._client.create_collection(
                    database_name=self._client_config.database,
                    collection_name=self._collection_name,
                    shard=self._client_config.shard,
                    replicas=self._client_config.replicas,
                    description="Collection for Dify",
                    indexes=indexes,
                )
            redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        total_count = len(embeddings)
        batch_size = self._client_config.max_upsert_batch_size
        # Upsert in batches so a single request never exceeds max_upsert_batch_size documents.
        batch_count = math.ceil(total_count / batch_size)
        for j in range(batch_count):
            docs = []
            start = j * batch_size
            end = min(total_count, start + batch_size)
            for i in range(start, end):
                metadata = metadatas[i] or {}
                doc = document.Document(
                    id=metadata["doc_id"],
                    vector=embeddings[i],
                    text=texts[i],
                    metadata=metadata,
                )
                if self._enable_hybrid_search:
                    doc.__dict__["sparse_vector"] = self._bm25.encode_texts(texts[i])
                docs.append(doc)
            self._client.upsert(
                database_name=self._client_config.database,
                collection_name=self.collection_name,
                documents=docs,
                timeout=self._client_config.timeout,
            )
|
    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        res = self._client.search(
            database_name=self._client_config.database,
            collection_name=self.collection_name,
            vectors=[query_vector],
            retrieve_vector=False,
            limit=kwargs.get("top_k", 4),
            timeout=self._client_config.timeout,
        )
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        return self._get_search_res(res, score_threshold)
|
    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
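        """Keyword (BM25) retrieval via the server-side hybrid search API.

        Returns an empty list unless hybrid search is enabled for this collection. The
        dense leg of the hybrid query is a zero vector and the rerank weights are [0, 1],
        so only the sparse BM25 match contributes to the final ranking.
        """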
|
|
        if not self._enable_hybrid_search:
            return []
        res = self._client.hybrid_search(
            database_name=self._client_config.database,
            collection_name=self.collection_name,
            ann=[
                AnnSearch(
                    field_name="vector",
                    data=[0.0] * self._dimension,
                )
            ],
            match=[
                KeywordSearch(
                    field_name="sparse_vector",
                    data=self._bm25.encode_queries(query),
                ),
            ],
            rerank=WeightedRerank(
                field_list=["vector", "sparse_vector"],
                weight=[0, 1],
            ),
            retrieve_vector=False,
            limit=kwargs.get("top_k", 4),
        )
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        return self._get_search_res(res, score_threshold)
|
    def _get_search_res(self, res: list | None, score_threshold: float) -> list[Document]:
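        """Convert raw search results into Documents, keeping only hits above score_threshold."""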
|
|
        docs: list[Document] = []
        if res is None or len(res) == 0:
            return docs

        for result in res[0]:
            meta = result.get(self.field_metadata)
            if isinstance(meta, str):
                # Older collections store metadata as a JSON string; parse it back into a dict.
                meta = json.loads(meta)
            score = result.get("score", 0.0)
            if score > score_threshold:
                meta["score"] = score
                doc = Document(page_content=result.get(self.field_text), metadata=meta)
                docs.append(doc)
        return docs
|

class TencentVectorFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector:
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
            collection_name = class_prefix.lower()
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TENCENT, collection_name))

        return TencentVector(
            collection_name=collection_name,
            config=TencentConfig(
                url=dify_config.TENCENT_VECTOR_DB_URL,
                api_key=dify_config.TENCENT_VECTOR_DB_API_KEY,
                timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT,
                username=dify_config.TENCENT_VECTOR_DB_USERNAME,
                database=dify_config.TENCENT_VECTOR_DB_DATABASE,
                shard=dify_config.TENCENT_VECTOR_DB_SHARD,
                replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS,
                enable_hybrid_search=dify_config.TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH or False,
            ),
        )