@@ -1,4 +1,5 @@
 import json
+import logging
 import time
 import uuid
 from typing import Any
@@ -9,11 +10,24 @@ from pymochow import MochowClient  # type: ignore
 from pymochow.auth.bce_credentials import BceCredentials  # type: ignore
 from pymochow.configuration import Configuration  # type: ignore
 from pymochow.exception import ServerError  # type: ignore
+from pymochow.model.database import Database  # type: ignore
 from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState  # type: ignore
-from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex  # type: ignore
-from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row  # type: ignore
+from pymochow.model.schema import (
+    Field,
+    FilteringIndex,
+    HNSWParams,
+    InvertedIndex,
+    InvertedIndexAnalyzer,
+    InvertedIndexFieldAttribute,
+    InvertedIndexParams,
+    InvertedIndexParseMode,
+    Schema,
+    VectorIndex,
+)  # type: ignore
+from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams, Partition, Row  # type: ignore

 from configs import dify_config
+from core.rag.datasource.vdb.field import Field as VDBField
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
@@ -22,6 +36,8 @@ from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset

+logger = logging.getLogger(__name__)
+

 class BaiduConfig(BaseModel):
     endpoint: str
@@ -30,9 +46,11 @@ class BaiduConfig(BaseModel):
     api_key: str
     database: str
     index_type: str = "HNSW"
-    metric_type: str = "L2"
+    metric_type: str = "IP"
     shard: int = 1
     replicas: int = 3
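+    # Names of pymochow InvertedIndexAnalyzer / InvertedIndexParseMode members that
+    # control tokenization for BM25 search; unknown values fall back to these defaults.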
+    inverted_index_analyzer: str = "DEFAULT_ANALYZER"
+    inverted_index_parser_mode: str = "COARSE_MODE"

     @model_validator(mode="before")
     @classmethod
@@ -49,13 +67,9 @@ class BaiduConfig(BaseModel):


 class BaiduVector(BaseVector):
-    field_id: str = "id"
-    field_vector: str = "vector"
-    field_text: str = "text"
-    field_metadata: str = "metadata"
-    field_app_id: str = "app_id"
-    field_annotation_id: str = "annotation_id"
-    index_vector: str = "vector_idx"
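+    # Table index names: ANN vector index, scalar filtering index, BM25 inverted index.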
+    vector_index: str = "vector_idx"
+    filtering_index: str = "filtering_idx"
+    inverted_index: str = "content_inverted_idx"

     def __init__(self, collection_name: str, config: BaiduConfig):
         super().__init__(collection_name)
@@ -74,8 +88,6 @@ class BaiduVector(BaseVector):
         self.add_texts(texts, embeddings)

     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
-        texts = [doc.page_content for doc in documents]
-        metadatas = [doc.metadata for doc in documents if doc.metadata is not None]
         total_count = len(documents)
         batch_size = 1000

@@ -84,29 +96,31 @@ class BaiduVector(BaseVector):
         for start in range(0, total_count, batch_size):
             end = min(start + batch_size, total_count)
             rows = []
-            assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
             for i in range(start, end, 1):
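+                # Document.metadata may be None; default to {} so .get() below is safe.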
+                metadata = documents[i].metadata or {}
                 row = Row(
-                    id=metadatas[i].get("doc_id", str(uuid.uuid4())),
+                    id=metadata.get("doc_id", str(uuid.uuid4())),
+                    page_content=documents[i].page_content,
+                    metadata=metadata,
                     vector=embeddings[i],
-                    text=texts[i],
-                    metadata=json.dumps(metadatas[i]),
-                    app_id=metadatas[i].get("app_id", ""),
-                    annotation_id=metadatas[i].get("annotation_id", ""),
                 )
                 rows.append(row)
             table.upsert(rows=rows)

         # rebuild vector index after upsert finished
-        table.rebuild_index(self.index_vector)
+        table.rebuild_index(self.vector_index)
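+        # The rebuild is asynchronous; poll describe_index until the state
+        # returns to NORMAL (or give up after the timeout below).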
+        timeout = 3600  # 1 hour timeout
+        start_time = time.time()
         while True:
             time.sleep(1)
-            index = table.describe_index(self.index_vector)
+            index = table.describe_index(self.vector_index)
             if index.state == IndexState.NORMAL:
                 break
+            if time.time() - start_time > timeout:
+                raise TimeoutError(f"Index rebuild timeout after {timeout} seconds")

     def text_exists(self, id: str) -> bool:
-        res = self._db.table(self._collection_name).query(primary_key={self.field_id: id})
+        res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id})
         if res and res.code == 0:
             return True
         return False
@@ -115,53 +129,73 @@ class BaiduVector(BaseVector):
         if not ids:
             return
         quoted_ids = [f"'{id}'" for id in ids]
-        self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")
+        self._db.table(self._collection_name).delete(filter=f"{VDBField.PRIMARY_KEY} IN({', '.join(quoted_ids)})")

     def delete_by_metadata_field(self, key: str, value: str):
-        self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'")
+        # Escape double quotes in value to prevent injection
+        escaped_value = value.replace('"', '\\"')
+        self._db.table(self._collection_name).delete(filter=f'metadata["{key}"] = "{escaped_value}"')

     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
         document_ids_filter = kwargs.get("document_ids_filter")
+        filter = ""
         if document_ids_filter:
             document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
-            anns = AnnSearch(
-                vector_field=self.field_vector,
-                vector_floats=query_vector,
-                params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
-                filter=f"document_id IN ({document_ids})",
-            )
-        else:
-            anns = AnnSearch(
-                vector_field=self.field_vector,
-                vector_floats=query_vector,
-                params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
-            )
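+            # document_id lives inside the JSON metadata column, hence the JSON-path filter.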
+            filter = f'metadata["document_id"] IN({document_ids})'
+        anns = AnnSearch(
+            vector_field=VDBField.VECTOR,
+            vector_floats=query_vector,
+            params=HNSWSearchParams(ef=kwargs.get("ef", 20), limit=kwargs.get("top_k", 4)),
+            filter=filter,
+        )
         res = self._db.table(self._collection_name).search(
             anns=anns,
-            projections=[self.field_id, self.field_text, self.field_metadata],
-            retrieve_vector=True,
+            projections=[VDBField.CONTENT_KEY, VDBField.METADATA_KEY],
+            retrieve_vector=False,
         )
         score_threshold = float(kwargs.get("score_threshold") or 0.0)
         return self._get_search_res(res, score_threshold)

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
-        # baidu vector database doesn't support bm25 search on current version
-        return []
+        # document ids filter
+        document_ids_filter = kwargs.get("document_ids_filter")
+        filter = ""
+        if document_ids_filter:
+            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+            filter = f'metadata["document_id"] IN({document_ids})'
+
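+        # BM25 keyword search against the inverted index built over page_content.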
+        request = BM25SearchRequest(
+            index_name=self.inverted_index, search_text=query, limit=kwargs.get("top_k", 4), filter=filter
+        )
+        res = self._db.table(self._collection_name).bm25_search(
+            request=request, projections=[VDBField.CONTENT_KEY, VDBField.METADATA_KEY]
+        )
+        score_threshold = float(kwargs.get("score_threshold") or 0.0)
+        return self._get_search_res(res, score_threshold)

     def _get_search_res(self, res, score_threshold) -> list[Document]:
         docs = []
         for row in res.rows:
             row_data = row.get("row", {})
-            meta = row_data.get(self.field_metadata)
-            if meta is not None:
-                meta = json.loads(meta)
             score = row.get("score", 0.0)
+            meta = row_data.get(VDBField.METADATA_KEY, {})
+
+            # Handle both JSON string and dict formats for backward compatibility
+            if isinstance(meta, str):
+                try:
+                    meta = json.loads(meta)
+                except (json.JSONDecodeError, TypeError):
+                    meta = {}
+            elif not isinstance(meta, dict):
+                meta = {}
+
             if score >= score_threshold:
                 meta["score"] = score
-                doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
+                doc = Document(page_content=row_data.get(VDBField.CONTENT_KEY), metadata=meta)
                 docs.append(doc)

         return docs

     def delete(self):
@@ -178,7 +212,7 @@ class BaiduVector(BaseVector):
         client = MochowClient(config)
         return client

-    def _init_database(self):
+    def _init_database(self) -> Database:
         exists = False
         for db in self._client.list_databases():
             if db.database_name == self._client_config.database:
@@ -192,10 +226,10 @@ class BaiduVector(BaseVector):
                 self._client.create_database(database_name=self._client_config.database)
             except ServerError as e:
                 if e.code == ServerErrCode.DB_ALREADY_EXIST:
-                    pass
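+                    # Another writer created the database concurrently; reuse it.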
+                    return self._client.database(self._client_config.database)
                 else:
                     raise
-            return
+            return self._client.database(self._client_config.database)

     def _table_existed(self) -> bool:
         tables = self._db.list_table()
@@ -232,7 +266,7 @@ class BaiduVector(BaseVector):
         fields = []
         fields.append(
             Field(
-                self.field_id,
+                VDBField.PRIMARY_KEY,
                 FieldType.STRING,
                 primary_key=True,
                 partition_key=True,
@@ -240,24 +274,57 @@ class BaiduVector(BaseVector):
                 not_null=True,
             )
         )
-        fields.append(Field(self.field_metadata, FieldType.STRING, not_null=True))
-        fields.append(Field(self.field_app_id, FieldType.STRING))
-        fields.append(Field(self.field_annotation_id, FieldType.STRING))
-        fields.append(Field(self.field_text, FieldType.TEXT, not_null=True))
-        fields.append(Field(self.field_vector, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension))
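+        # page_content is TEXT so it can carry the BM25 inverted index; metadata is
+        # native JSON so filters can address keys such as metadata["document_id"].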
+        fields.append(Field(VDBField.CONTENT_KEY, FieldType.TEXT, not_null=False))
+        fields.append(Field(VDBField.METADATA_KEY, FieldType.JSON, not_null=False))
+        fields.append(Field(VDBField.VECTOR, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension))

         # Construct vector index params
         indexes = []
         indexes.append(
             VectorIndex(
-                index_name="vector_idx",
+                index_name=self.vector_index,
                 index_type=index_type,
-                field="vector",
+                field=VDBField.VECTOR,
                 metric_type=metric_type,
                 params=HNSWParams(m=16, efconstruction=200),
             )
         )
+
+        # Filtering index
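+        # (backs the metadata["..."] filter expressions used by search and delete)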
+        indexes.append(
+            FilteringIndex(
+                index_name=self.filtering_index,
+                fields=[VDBField.METADATA_KEY],
+            )
+        )
+
+        # Get analyzer and parse_mode from config
+        analyzer = getattr(
+            InvertedIndexAnalyzer,
+            self._client_config.inverted_index_analyzer,
+            InvertedIndexAnalyzer.DEFAULT_ANALYZER,
+        )
+
+        parse_mode = getattr(
+            InvertedIndexParseMode,
+            self._client_config.inverted_index_parser_mode,
+            InvertedIndexParseMode.COARSE_MODE,
+        )
+
+        # Inverted index
+        indexes.append(
+            InvertedIndex(
+                index_name=self.inverted_index,
+                fields=[VDBField.CONTENT_KEY],
+                params=InvertedIndexParams(
+                    analyzer=analyzer,
+                    parse_mode=parse_mode,
+                    case_sensitive=True,
+                ),
+                field_attributes=[InvertedIndexFieldAttribute.ANALYZED],
+            )
+        )

         # Create table
         self._db.create_table(
             table_name=self._collection_name,
@@ -268,11 +335,15 @@ class BaiduVector(BaseVector):
         )

         # Wait for table created
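+        # Table creation is asynchronous; bound the polling loop so a table that
+        # never reaches NORMAL cannot hang forever.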
+        timeout = 300  # 5 minutes timeout
+        start_time = time.time()
         while True:
             time.sleep(1)
             table = self._db.describe_table(self._collection_name)
             if table.state == TableState.NORMAL:
                 break
+            if time.time() - start_time > timeout:
+                raise TimeoutError(f"Table creation timeout after {timeout} seconds")
         redis_client.set(table_exist_cache_key, 1, ex=3600)

@@ -296,5 +367,7 @@ class BaiduVectorFactory(AbstractVectorFactory):
                 database=dify_config.BAIDU_VECTOR_DB_DATABASE or "",
                 shard=dify_config.BAIDU_VECTOR_DB_SHARD,
                 replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
+                inverted_index_analyzer=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER,
+                inverted_index_parser_mode=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE,
             ),
         )