Pārlūkot izejas kodu

Add Full-Text & Hybrid Search Support to Baidu Vector DB and Update SDK, Closes #25982 (#25983)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
tags/1.9.0
Shili Cao pirms 1 mēnesi
vecāks
revīzija
345ac8333c
Revīzijas autora e-pasta adrese nav piesaistīta nevienam kontam

+ 2
- 0
api/.env.example Parādīt failu

@@ -304,6 +304,8 @@ BAIDU_VECTOR_DB_API_KEY=dify
BAIDU_VECTOR_DB_DATABASE=dify
BAIDU_VECTOR_DB_SHARD=1
BAIDU_VECTOR_DB_REPLICAS=3
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE

# Upstash configuration
UPSTASH_VECTOR_URL=your-server-url

+ 10
- 0
api/configs/middleware/vdb/baidu_vector_config.py Parādīt failu

@@ -41,3 +41,13 @@ class BaiduVectorDBConfig(BaseSettings):
description="Number of replicas for the Baidu Vector Database (default is 3)",
default=3,
)

BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER: str = Field(
description="Analyzer type for inverted index in Baidu Vector Database (default is DEFAULT_ANALYZER)",
default="DEFAULT_ANALYZER",
)

BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE: str = Field(
description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)",
default="COARSE_MODE",
)

+ 2
- 2
api/controllers/console/datasets/datasets.py Parādīt failu

@@ -782,7 +782,6 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.PGVECTO_RS
| VectorType.BAIDU
| VectorType.VIKINGDB
| VectorType.UPSTASH
):
@@ -809,6 +808,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.TENCENT
| VectorType.MATRIXONE
| VectorType.CLICKZETTA
| VectorType.BAIDU
):
return {
"retrieval_method": [
@@ -838,7 +838,6 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.PGVECTO_RS
| VectorType.BAIDU
| VectorType.VIKINGDB
| VectorType.UPSTASH
):
@@ -863,6 +862,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.HUAWEI_CLOUD
| VectorType.MATRIXONE
| VectorType.CLICKZETTA
| VectorType.BAIDU
):
return {
"retrieval_method": [

+ 128
- 55
api/core/rag/datasource/vdb/baidu/baidu_vector.py Parādīt failu

@@ -1,4 +1,5 @@
import json
import logging
import time
import uuid
from typing import Any
@@ -9,11 +10,24 @@ from pymochow import MochowClient # type: ignore
from pymochow.auth.bce_credentials import BceCredentials # type: ignore
from pymochow.configuration import Configuration # type: ignore
from pymochow.exception import ServerError # type: ignore
from pymochow.model.database import Database
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex # type: ignore
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row # type: ignore
from pymochow.model.schema import (
Field,
FilteringIndex,
HNSWParams,
InvertedIndex,
InvertedIndexAnalyzer,
InvertedIndexFieldAttribute,
InvertedIndexParams,
InvertedIndexParseMode,
Schema,
VectorIndex,
) # type: ignore
from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams, Partition, Row # type: ignore

from configs import dify_config
from core.rag.datasource.vdb.field import Field as VDBField
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -22,6 +36,8 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

logger = logging.getLogger(__name__)


class BaiduConfig(BaseModel):
endpoint: str
@@ -30,9 +46,11 @@ class BaiduConfig(BaseModel):
api_key: str
database: str
index_type: str = "HNSW"
metric_type: str = "L2"
metric_type: str = "IP"
shard: int = 1
replicas: int = 3
inverted_index_analyzer: str = "DEFAULT_ANALYZER"
inverted_index_parser_mode: str = "COARSE_MODE"

@model_validator(mode="before")
@classmethod
@@ -49,13 +67,9 @@ class BaiduConfig(BaseModel):


class BaiduVector(BaseVector):
field_id: str = "id"
field_vector: str = "vector"
field_text: str = "text"
field_metadata: str = "metadata"
field_app_id: str = "app_id"
field_annotation_id: str = "annotation_id"
index_vector: str = "vector_idx"
vector_index: str = "vector_idx"
filtering_index: str = "filtering_idx"
inverted_index: str = "content_inverted_idx"

def __init__(self, collection_name: str, config: BaiduConfig):
super().__init__(collection_name)
@@ -74,8 +88,6 @@ class BaiduVector(BaseVector):
self.add_texts(texts, embeddings)

def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents if doc.metadata is not None]
total_count = len(documents)
batch_size = 1000

@@ -84,29 +96,31 @@ class BaiduVector(BaseVector):
for start in range(0, total_count, batch_size):
end = min(start + batch_size, total_count)
rows = []
assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
for i in range(start, end, 1):
metadata = documents[i].metadata
row = Row(
id=metadatas[i].get("doc_id", str(uuid.uuid4())),
id=metadata.get("doc_id", str(uuid.uuid4())),
page_content=documents[i].page_content,
metadata=metadata,
vector=embeddings[i],
text=texts[i],
metadata=json.dumps(metadatas[i]),
app_id=metadatas[i].get("app_id", ""),
annotation_id=metadatas[i].get("annotation_id", ""),
)
rows.append(row)
table.upsert(rows=rows)

# rebuild vector index after upsert finished
table.rebuild_index(self.index_vector)
table.rebuild_index(self.vector_index)
timeout = 3600 # 1 hour timeout
start_time = time.time()
while True:
time.sleep(1)
index = table.describe_index(self.index_vector)
index = table.describe_index(self.vector_index)
if index.state == IndexState.NORMAL:
break
if time.time() - start_time > timeout:
raise TimeoutError(f"Index rebuild timeout after {timeout} seconds")

def text_exists(self, id: str) -> bool:
res = self._db.table(self._collection_name).query(primary_key={self.field_id: id})
res = self._db.table(self._collection_name).query(primary_key={VDBField.PRIMARY_KEY: id})
if res and res.code == 0:
return True
return False
@@ -115,53 +129,73 @@ class BaiduVector(BaseVector):
if not ids:
return
quoted_ids = [f"'{id}'" for id in ids]
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")
self._db.table(self._collection_name).delete(filter=f"{VDBField.PRIMARY_KEY} IN({', '.join(quoted_ids)})")

def delete_by_metadata_field(self, key: str, value: str):
self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'")
# Escape double quotes in value to prevent injection
escaped_value = value.replace('"', '\\"')
self._db.table(self._collection_name).delete(filter=f'metadata["{key}"] = "{escaped_value}"')

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
document_ids_filter = kwargs.get("document_ids_filter")
filter = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
anns = AnnSearch(
vector_field=self.field_vector,
vector_floats=query_vector,
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
filter=f"document_id IN ({document_ids})",
)
else:
anns = AnnSearch(
vector_field=self.field_vector,
vector_floats=query_vector,
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
)
filter = f'metadata["document_id"] IN({document_ids})'
anns = AnnSearch(
vector_field=VDBField.VECTOR,
vector_floats=query_vector,
params=HNSWSearchParams(ef=kwargs.get("ef", 20), limit=kwargs.get("top_k", 4)),
filter=filter,
)
res = self._db.table(self._collection_name).search(
anns=anns,
projections=[self.field_id, self.field_text, self.field_metadata],
retrieve_vector=True,
projections=[VDBField.CONTENT_KEY, VDBField.METADATA_KEY],
retrieve_vector=False,
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(res, score_threshold)

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# baidu vector database doesn't support bm25 search on current version
return []
# document ids filter
document_ids_filter = kwargs.get("document_ids_filter")
filter = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
filter = f'metadata["document_id"] IN({document_ids})'

request = BM25SearchRequest(
index_name=self.inverted_index, search_text=query, limit=kwargs.get("top_k", 4), filter=filter
)
res = self._db.table(self._collection_name).bm25_search(
request=request, projections=[VDBField.CONTENT_KEY, VDBField.METADATA_KEY]
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(res, score_threshold)

def _get_search_res(self, res, score_threshold) -> list[Document]:
docs = []
for row in res.rows:
row_data = row.get("row", {})
meta = row_data.get(self.field_metadata)
if meta is not None:
meta = json.loads(meta)
score = row.get("score", 0.0)
meta = row_data.get(VDBField.METADATA_KEY, {})

# Handle both JSON string and dict formats for backward compatibility
if isinstance(meta, str):
try:
import json

meta = json.loads(meta)
except (json.JSONDecodeError, TypeError):
meta = {}
elif not isinstance(meta, dict):
meta = {}

if score >= score_threshold:
meta["score"] = score
doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
doc = Document(page_content=row_data.get(VDBField.CONTENT_KEY), metadata=meta)
docs.append(doc)

return docs

def delete(self):
@@ -178,7 +212,7 @@ class BaiduVector(BaseVector):
client = MochowClient(config)
return client

def _init_database(self):
def _init_database(self) -> Database:
exists = False
for db in self._client.list_databases():
if db.database_name == self._client_config.database:
@@ -192,10 +226,10 @@ class BaiduVector(BaseVector):
self._client.create_database(database_name=self._client_config.database)
except ServerError as e:
if e.code == ServerErrCode.DB_ALREADY_EXIST:
pass
return self._client.database(self._client_config.database)
else:
raise
return
return self._client.database(self._client_config.database)

def _table_existed(self) -> bool:
tables = self._db.list_table()
@@ -232,7 +266,7 @@ class BaiduVector(BaseVector):
fields = []
fields.append(
Field(
self.field_id,
VDBField.PRIMARY_KEY,
FieldType.STRING,
primary_key=True,
partition_key=True,
@@ -240,24 +274,57 @@ class BaiduVector(BaseVector):
not_null=True,
)
)
fields.append(Field(self.field_metadata, FieldType.STRING, not_null=True))
fields.append(Field(self.field_app_id, FieldType.STRING))
fields.append(Field(self.field_annotation_id, FieldType.STRING))
fields.append(Field(self.field_text, FieldType.TEXT, not_null=True))
fields.append(Field(self.field_vector, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension))
fields.append(Field(VDBField.CONTENT_KEY, FieldType.TEXT, not_null=False))
fields.append(Field(VDBField.METADATA_KEY, FieldType.JSON, not_null=False))
fields.append(Field(VDBField.VECTOR, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension))

# Construct vector index params
indexes = []
indexes.append(
VectorIndex(
index_name="vector_idx",
index_name=self.vector_index,
index_type=index_type,
field="vector",
field=VDBField.VECTOR,
metric_type=metric_type,
params=HNSWParams(m=16, efconstruction=200),
)
)

# Filtering index
indexes.append(
FilteringIndex(
index_name=self.filtering_index,
fields=[VDBField.METADATA_KEY],
)
)

# Get analyzer and parse_mode from config
analyzer = getattr(
InvertedIndexAnalyzer,
self._client_config.inverted_index_analyzer,
InvertedIndexAnalyzer.DEFAULT_ANALYZER,
)

parse_mode = getattr(
InvertedIndexParseMode,
self._client_config.inverted_index_parser_mode,
InvertedIndexParseMode.COARSE_MODE,
)

# Inverted index
indexes.append(
InvertedIndex(
index_name=self.inverted_index,
fields=[VDBField.CONTENT_KEY],
params=InvertedIndexParams(
analyzer=analyzer,
parse_mode=parse_mode,
case_sensitive=True,
),
field_attributes=[InvertedIndexFieldAttribute.ANALYZED],
)
)

# Create table
self._db.create_table(
table_name=self._collection_name,
@@ -268,11 +335,15 @@ class BaiduVector(BaseVector):
)

# Wait for table created
timeout = 300 # 5 minutes timeout
start_time = time.time()
while True:
time.sleep(1)
table = self._db.describe_table(self._collection_name)
if table.state == TableState.NORMAL:
break
if time.time() - start_time > timeout:
raise TimeoutError(f"Table creation timeout after {timeout} seconds")
redis_client.set(table_exist_cache_key, 1, ex=3600)


@@ -296,5 +367,7 @@ class BaiduVectorFactory(AbstractVectorFactory):
database=dify_config.BAIDU_VECTOR_DB_DATABASE or "",
shard=dify_config.BAIDU_VECTOR_DB_SHARD,
replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
inverted_index_analyzer=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER,
inverted_index_parser_mode=dify_config.BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE,
),
)

+ 1
- 1
api/pyproject.toml Parādīt failu

@@ -211,7 +211,7 @@ vdb = [
"pgvecto-rs[sqlalchemy]~=0.2.1",
"pgvector==0.2.5",
"pymilvus~=2.5.0",
"pymochow==1.3.1",
"pymochow==2.2.9",
"pyobvector~=0.2.15",
"qdrant-client==1.9.0",
"tablestore==6.2.0",

+ 4
- 4
api/tests/integration_tests/vdb/__mock/baiduvectordb.py Parādīt failu

@@ -100,8 +100,8 @@ class MockBaiduVectorDBClass:
"row": {
"id": primary_key.get("id"),
"vector": [0.23432432, 0.8923744, 0.89238432],
"text": "text",
"metadata": '{"doc_id": "doc_id_001"}',
"page_content": "text",
"metadata": {"doc_id": "doc_id_001"},
},
"code": 0,
"msg": "Success",
@@ -127,8 +127,8 @@ class MockBaiduVectorDBClass:
"row": {
"id": "doc_id_001",
"vector": [0.23432432, 0.8923744, 0.89238432],
"text": "text",
"metadata": '{"doc_id": "doc_id_001"}',
"page_content": "text",
"metadata": {"doc_id": "doc_id_001"},
},
"distance": 0.1,
"score": 0.5,

+ 5
- 5
api/uv.lock Parādīt failu

@@ -1,5 +1,5 @@
version = 1
revision = 3
revision = 2
requires-python = ">=3.11, <3.13"
resolution-markers = [
"python_full_version >= '3.12.4' and sys_platform == 'linux'",
@@ -1670,7 +1670,7 @@ vdb = [
{ name = "pgvecto-rs", extras = ["sqlalchemy"], specifier = "~=0.2.1" },
{ name = "pgvector", specifier = "==0.2.5" },
{ name = "pymilvus", specifier = "~=2.5.0" },
{ name = "pymochow", specifier = "==1.3.1" },
{ name = "pymochow", specifier = "==2.2.9" },
{ name = "pyobvector", specifier = "~=0.2.15" },
{ name = "qdrant-client", specifier = "==1.9.0" },
{ name = "tablestore", specifier = "==6.2.0" },
@@ -4935,16 +4935,16 @@ wheels = [

[[package]]
name = "pymochow"
version = "1.3.1"
version = "2.2.9"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "future" },
{ name = "orjson" },
{ name = "requests" },
]
sdist = { url = "https://files.pythonhosted.org/packages/cc/da/3027eeeaf7a7db9b0ca761079de4e676a002e1cc2c4260dab0ce812972b8/pymochow-1.3.1.tar.gz", hash = "sha256:1693d10cd0bb7bce45327890a90adafb503155922ccc029acb257699a73a20ba", size = 30800, upload-time = "2024-09-11T12:06:37.88Z" }
sdist = { url = "https://files.pythonhosted.org/packages/b5/29/d9b112684ce490057b90bddede3fb6a69cf2787a3fd7736bdce203e77388/pymochow-2.2.9.tar.gz", hash = "sha256:5a28058edc8861deb67524410e786814571ed9fe0700c8c9fc0bc2ad5835b06c", size = 50079, upload-time = "2025-06-05T08:33:19.59Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/6b/74/4b6227717f6baa37e7288f53e0fd55764939abc4119342eed4924a98f477/pymochow-1.3.1-py3-none-any.whl", hash = "sha256:a7f3b34fd6ea5d1d8413650bb6678365aa148fc396ae945e4ccb4f2365a52327", size = 42697, upload-time = "2024-09-11T12:06:36.114Z" },
{ url = "https://files.pythonhosted.org/packages/bf/9b/be18f9709dfd8187ff233be5acb253a9f4f1b07f1db0e7b09d84197c28e2/pymochow-2.2.9-py3-none-any.whl", hash = "sha256:639192b97f143d4a22fc163872be12aee19523c46f12e22416e8f289f1354d15", size = 77899, upload-time = "2025-06-05T08:33:17.424Z" },
]

[[package]]

+ 2
- 0
docker/.env.example Parādīt failu

@@ -635,6 +635,8 @@ BAIDU_VECTOR_DB_API_KEY=dify
BAIDU_VECTOR_DB_DATABASE=dify
BAIDU_VECTOR_DB_SHARD=1
BAIDU_VECTOR_DB_REPLICAS=3
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE

# VikingDB configurations, only available when VECTOR_STORE is `vikingdb`
VIKINGDB_ACCESS_KEY=your-ak

+ 2
- 0
docker/docker-compose.yaml Parādīt failu

@@ -286,6 +286,8 @@ x-shared-env: &shared-api-worker-env
BAIDU_VECTOR_DB_DATABASE: ${BAIDU_VECTOR_DB_DATABASE:-dify}
BAIDU_VECTOR_DB_SHARD: ${BAIDU_VECTOR_DB_SHARD:-1}
BAIDU_VECTOR_DB_REPLICAS: ${BAIDU_VECTOR_DB_REPLICAS:-3}
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER: ${BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER:-DEFAULT_ANALYZER}
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE: ${BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE:-COARSE_MODE}
VIKINGDB_ACCESS_KEY: ${VIKINGDB_ACCESS_KEY:-your-ak}
VIKINGDB_SECRET_KEY: ${VIKINGDB_SECRET_KEY:-your-sk}
VIKINGDB_REGION: ${VIKINGDB_REGION:-cn-shanghai}

Notiek ielāde…
Atcelt
Saglabāt