
feat:support baidu vector db (#9185)

tags/0.9.2
Shili Cao 1 year ago
parent
current commit
2ec6ffe478

api/.env.example (+9, -0)

@@ -208,6 +208,15 @@ OPENSEARCH_USER=admin
OPENSEARCH_PASSWORD=admin
OPENSEARCH_SECURE=true

# Baidu configuration
BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000
BAIDU_VECTOR_DB_ACCOUNT=root
BAIDU_VECTOR_DB_API_KEY=dify
BAIDU_VECTOR_DB_DATABASE=dify
BAIDU_VECTOR_DB_SHARD=1
BAIDU_VECTOR_DB_REPLICAS=3

# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5
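These Baidu settings only take effect once the vector store backend is switched over; a minimal sketch for a local api/.env, assuming the same VECTOR_STORE selector mentioned in docker/.env.example further down:

VECTOR_STORE=baidu
BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
BAIDU_VECTOR_DB_ACCOUNT=root
BAIDU_VECTOR_DB_API_KEY=dify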

api/commands.py (+8, -0)

@@ -347,6 +347,14 @@ def migrate_knowledge_vector_database():
index_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.BAIDU:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.BAIDU,
"vector_store": {"class_prefix": collection_name},
}
dataset.index_struct = json.dumps(index_struct_dict)
else:
raise ValueError(f"Vector store {vector_type} is not supported.")


api/configs/middleware/vdb/baidu_vector_config.py (+45, -0)

@@ -0,0 +1,45 @@
from typing import Optional

from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings


class BaiduVectorDBConfig(BaseSettings):
"""
Configuration settings for Baidu Vector Database
"""

BAIDU_VECTOR_DB_ENDPOINT: Optional[str] = Field(
description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')",
default=None,
)

BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: PositiveInt = Field(
description="Timeout in milliseconds for Baidu Vector Database operations (default is 30000 milliseconds)",
default=30000,
)

BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field(
description="Account for authenticating with the Baidu Vector Database",
default=None,
)

BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field(
description="API key for authenticating with the Baidu Vector Database service",
default=None,
)

BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field(
description="Name of the specific Baidu Vector Database to connect to",
default=None,
)

BAIDU_VECTOR_DB_SHARD: PositiveInt = Field(
description="Number of shards for the Baidu Vector Database (default is 1)",
default=1,
)

BAIDU_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
description="Number of replicas for the Baidu Vector Database (default is 3)",
default=3,
)
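Because BaiduVectorDBConfig extends pydantic-settings' BaseSettings, each field is populated from the environment variable of the same name, which is how the entries in api/.env.example reach the code. A minimal sketch of that mechanism (DemoConfig is illustrative, not part of this change):

import os

from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings


class DemoConfig(BaseSettings):
    BAIDU_VECTOR_DB_ENDPOINT: str = Field(default="")
    BAIDU_VECTOR_DB_SHARD: PositiveInt = Field(default=1)


os.environ["BAIDU_VECTOR_DB_ENDPOINT"] = "http://127.0.0.1:5287"
os.environ["BAIDU_VECTOR_DB_SHARD"] = "2"

cfg = DemoConfig()
assert cfg.BAIDU_VECTOR_DB_ENDPOINT == "http://127.0.0.1:5287"
assert cfg.BAIDU_VECTOR_DB_SHARD == 2  # env strings are coerced to the annotated type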

api/controllers/console/datasets/datasets.py (+2, -0)

@@ -617,6 +617,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.CHROMA
| VectorType.TENCENT
| VectorType.PGVECTO_RS
| VectorType.BAIDU
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
@@ -653,6 +654,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.CHROMA
| VectorType.TENCENT
| VectorType.PGVECTO_RS
| VectorType.BAIDU
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
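Both match arms route Baidu into the semantic-search-only branch because the SDK exposes no full-text (BM25) search yet; assuming RetrievalMethod.SEMANTIC_SEARCH carries the string value "semantic_search", a client querying retrieval settings for a `baidu` store would receive:

{"retrieval_method": ["semantic_search"]}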

api/core/rag/datasource/vdb/baidu/__init__.py (+0, -0)


api/core/rag/datasource/vdb/baidu/baidu_vector.py (+272, -0)

@@ -0,0 +1,272 @@
import json
import time
import uuid
from typing import Any

from pydantic import BaseModel, model_validator
from pymochow import MochowClient
from pymochow.auth.bce_credentials import BceCredentials
from pymochow.configuration import Configuration
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, TableState
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row

from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset


class BaiduConfig(BaseModel):
endpoint: str
connection_timeout_in_mills: int = 30 * 1000
account: str
api_key: str
database: str
index_type: str = "HNSW"
metric_type: str = "L2"
shard: int = 1
replicas: int = 3

@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["endpoint"]:
raise ValueError("config BAIDU_VECTOR_DB_ENDPOINT is required")
if not values["account"]:
raise ValueError("config BAIDU_VECTOR_DB_ACCOUNT is required")
if not values["api_key"]:
raise ValueError("config BAIDU_VECTOR_DB_API_KEY is required")
if not values["database"]:
raise ValueError("config BAIDU_VECTOR_DB_DATABASE is required")
return values


class BaiduVector(BaseVector):
field_id: str = "id"
field_vector: str = "vector"
field_text: str = "text"
field_metadata: str = "metadata"
field_app_id: str = "app_id"
field_annotation_id: str = "annotation_id"
index_vector: str = "vector_idx"

def __init__(self, collection_name: str, config: BaiduConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._db = self._init_database()

def get_type(self) -> str:
return VectorType.BAIDU

def to_index_struct(self) -> dict:
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}

def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self._create_table(len(embeddings[0]))
self.add_texts(texts, embeddings)

def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
total_count = len(documents)
batch_size = 1000

# upsert texts and embeddings batch by batch
table = self._db.table(self._collection_name)
for start in range(0, total_count, batch_size):
end = min(start + batch_size, total_count)
rows = []
for i in range(start, end, 1):
row = Row(
id=metadatas[i].get("doc_id", str(uuid.uuid4())),
vector=embeddings[i],
text=texts[i],
metadata=json.dumps(metadatas[i]),
app_id=metadatas[i].get("app_id", ""),
annotation_id=metadatas[i].get("annotation_id", ""),
)
rows.append(row)
table.upsert(rows=rows)

# Rebuild the vector index after the upsert finishes
table.rebuild_index(self.index_vector)
while True:
time.sleep(1)
index = table.describe_index(self.index_vector)
if index.state == IndexState.NORMAL:
break

def text_exists(self, id: str) -> bool:
res = self._db.table(self._collection_name).query(primary_key={self.field_id: id})
if res and res.code == 0:
return True
return False

def delete_by_ids(self, ids: list[str]) -> None:
quoted_ids = [f"'{id}'" for id in ids]
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")

def delete_by_metadata_field(self, key: str, value: str) -> None:
self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'")

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
anns = AnnSearch(
vector_field=self.field_vector,
vector_floats=query_vector,
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
)
res = self._db.table(self._collection_name).search(
anns=anns,
projections=[self.field_id, self.field_text, self.field_metadata],
retrieve_vector=True,
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(res, score_threshold)

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# Baidu vector database doesn't support BM25 search in the current version
return []

def _get_search_res(self, res, score_threshold):
docs = []
for row in res.rows:
row_data = row.get("row", {})
meta = row_data.get(self.field_metadata)
if meta is not None:
meta = json.loads(meta)
score = row.get("score", 0.0)
if score > score_threshold:
meta["score"] = score
doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
docs.append(doc)

return docs

def delete(self) -> None:
self._db.drop_table(table_name=self._collection_name)

def _init_client(self, config) -> MochowClient:
config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint)
client = MochowClient(config)
return client

def _init_database(self):
exists = False
for db in self._client.list_databases():
if db.database_name == self._client_config.database:
exists = True
break
# Create the database if it does not exist yet
if exists:
return self._client.database(self._client_config.database)
else:
return self._client.create_database(database_name=self._client_config.database)

def _table_existed(self) -> bool:
tables = self._db.list_table()
return any(table.table_name == self._collection_name for table in tables)

def _create_table(self, dimension: int) -> None:
# Try to grab distributed lock and create table
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
table_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(table_exist_cache_key):
return

if self._table_existed():
return

self.delete()

# check IndexType and MetricType
index_type = None
for k, v in IndexType.__members__.items():
if k == self._client_config.index_type:
index_type = v
if index_type is None:
raise ValueError("unsupported index_type")
metric_type = None
for k, v in MetricType.__members__.items():
if k == self._client_config.metric_type:
metric_type = v
if metric_type is None:
raise ValueError("unsupported metric_type")

# Construct field schema
fields = []
fields.append(
Field(
self.field_id,
FieldType.STRING,
primary_key=True,
partition_key=True,
auto_increment=False,
not_null=True,
)
)
fields.append(Field(self.field_metadata, FieldType.STRING, not_null=True))
fields.append(Field(self.field_app_id, FieldType.STRING))
fields.append(Field(self.field_annotation_id, FieldType.STRING))
fields.append(Field(self.field_text, FieldType.TEXT, not_null=True))
fields.append(Field(self.field_vector, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension))

# Construct vector index params
indexes = []
indexes.append(
VectorIndex(
index_name="vector_idx",
index_type=index_type,
field="vector",
metric_type=metric_type,
params=HNSWParams(m=16, efconstruction=200),
)
)

# Create table
self._db.create_table(
table_name=self._collection_name,
replication=self._client_config.replicas,
partition=Partition(partition_num=self._client_config.shard),
schema=Schema(fields=fields, indexes=indexes),
description="Table for Dify",
)

redis_client.set(table_exist_cache_key, 1, ex=3600)

# Wait for the table to be created
while True:
time.sleep(1)
table = self._db.describe_table(self._collection_name)
if table.state == TableState.NORMAL:
break


class BaiduVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.BAIDU, collection_name))

return BaiduVector(
collection_name=collection_name,
config=BaiduConfig(
endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT,
connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS,
account=dify_config.BAIDU_VECTOR_DB_ACCOUNT,
api_key=dify_config.BAIDU_VECTOR_DB_API_KEY,
database=dify_config.BAIDU_VECTOR_DB_DATABASE,
shard=dify_config.BAIDU_VECTOR_DB_SHARD,
replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
),
)
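Putting the pieces together, a direct-usage sketch that mirrors the integration test further down in this diff; the endpoint and credentials are local placeholders, and in production BaiduVectorFactory builds the config from dify_config instead:

from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector
from core.rag.models.document import Document

vector = BaiduVector(
    collection_name="dify",
    config=BaiduConfig(
        endpoint="http://127.0.0.1:5287",  # placeholder local endpoint
        account="root",
        api_key="dify",
        database="dify",
    ),
)

docs = [Document(page_content="hello baidu vdb", metadata={"doc_id": "doc-1"})]
embeddings = [[0.1, 0.2, 0.3]]

vector.create(docs, embeddings)  # creates the table, upserts rows, rebuilds the index
hits = vector.search_by_vector([0.1, 0.2, 0.3], top_k=4)
assert vector.search_by_full_text("hello") == []  # BM25 is not supported yet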

api/core/rag/datasource/vdb/vector_factory.py (+4, -0)

@@ -103,6 +103,10 @@ class Vector:
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory

return AnalyticdbVectorFactory
case VectorType.BAIDU:
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory

return BaiduVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")


api/core/rag/datasource/vdb/vector_type.py (+1, -0)

@@ -16,3 +16,4 @@ class VectorType(str, Enum):
TENCENT = "tencent"
ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch"
BAIDU = "baidu"

api/poetry.lock (+34, -13)

@@ -732,7 +732,7 @@ name = "bce-python-sdk"
version = "0.9.23"
description = "BCE SDK for python"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, <4"
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,<4,>=2.7"
files = [
{file = "bce_python_sdk-0.9.23-py3-none-any.whl", hash = "sha256:8debe21a040e00060f6044877d594765ed7b18bc765c6bf16b878bca864140a3"},
{file = "bce_python_sdk-0.9.23.tar.gz", hash = "sha256:19739fed5cd0725356fc5ffa2acbdd8fb23f2a81edb91db21a03174551d0cf41"},
@@ -847,7 +847,7 @@ name = "botocore"
version = "1.35.38"
description = "Low-level, data-driven core of boto 3."
optional = false
python-versions = ">= 3.8"
python-versions = ">=3.8"
files = [
{file = "botocore-1.35.38-py3-none-any.whl", hash = "sha256:2eb17d32fa2d3bb5d475132a83564d28e3acc2161534f24b75a54418a1d51359"},
{file = "botocore-1.35.38.tar.gz", hash = "sha256:55d9305c44e5ba29476df456120fa4fb919f03f066afa82f2ae400485e7465f4"},
@@ -1068,7 +1068,7 @@ name = "build"
version = "1.2.2.post1"
description = "A simple, correct Python build frontend"
optional = false
python-versions = ">= 3.8"
python-versions = ">=3.8"
files = [
{file = "build-1.2.2.post1-py3-none-any.whl", hash = "sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5"},
{file = "build-1.2.2.post1.tar.gz", hash = "sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7"},
@@ -3385,7 +3385,7 @@ name = "gotrue"
version = "2.9.2"
description = "Python Client Library for Supabase Auth"
optional = false
python-versions = ">=3.8,<4.0"
python-versions = "<4.0,>=3.8"
files = [
{file = "gotrue-2.9.2-py3-none-any.whl", hash = "sha256:fcd5279e8f1cc630f3ac35af5485fe39f8030b23906776920d2c32a4e308cff4"},
{file = "gotrue-2.9.2.tar.gz", hash = "sha256:57b3245e916c5efbf19a21b1181011a903c1276bb1df2d847558f2f24f29abb2"},
@@ -4415,7 +4415,7 @@ name = "langfuse"
version = "2.51.5"
description = "A client library for accessing langfuse"
optional = false
python-versions = ">=3.8.1,<4.0"
python-versions = "<4.0,>=3.8.1"
files = [
{file = "langfuse-2.51.5-py3-none-any.whl", hash = "sha256:b95401ca710ef94b521afa6541933b6f93d7cfd4a97523c8fc75bca4d6d219fb"},
{file = "langfuse-2.51.5.tar.gz", hash = "sha256:55bc37b5c5d3ae133c1a95db09117cfb3117add110ba02ebbf2ce45ac4395c5b"},
@@ -4440,7 +4440,7 @@ name = "langsmith"
version = "0.1.134"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false
python-versions = ">=3.8.1,<4.0"
python-versions = "<4.0,>=3.8.1"
files = [
{file = "langsmith-0.1.134-py3-none-any.whl", hash = "sha256:ada98ad80ef38807725f32441a472da3dd28394010877751f48f458d3289da04"},
{file = "langsmith-0.1.134.tar.gz", hash = "sha256:23abee3b508875a0e63c602afafffc02442a19cfd88f9daae05b3e9054fd6b61"},
@@ -6429,7 +6429,7 @@ name = "postgrest"
version = "0.17.1"
description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
optional = false
python-versions = ">=3.8,<4.0"
python-versions = "<4.0,>=3.8"
files = [
{file = "postgrest-0.17.1-py3-none-any.whl", hash = "sha256:ec1d00dc8532fe5ffb342cfc7c4e610a1e0e2272eb14f78f9b2b61094f9be510"},
{file = "postgrest-0.17.1.tar.gz", hash = "sha256:e31d9977dbb80dc5f9fdd4d444014686606692dc4ddb9adc85639e56c6d54c92"},
@@ -7047,6 +7047,22 @@ bulk-writer = ["azure-storage-blob", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "r
dev = ["black", "grpcio (==1.62.2)", "grpcio-testing (==1.62.2)", "grpcio-tools (==1.62.2)", "pytest (>=5.3.4)", "pytest-cov (>=2.8.1)", "pytest-timeout (>=1.3.4)", "ruff (>0.4.0)"]
model = ["milvus-model (>=0.1.0)"]

[[package]]
name = "pymochow"
version = "1.3.1"
description = "Python SDK for mochow"
optional = false
python-versions = ">=3.7"
files = [
{file = "pymochow-1.3.1-py3-none-any.whl", hash = "sha256:a7f3b34fd6ea5d1d8413650bb6678365aa148fc396ae945e4ccb4f2365a52327"},
{file = "pymochow-1.3.1.tar.gz", hash = "sha256:1693d10cd0bb7bce45327890a90adafb503155922ccc029acb257699a73a20ba"},
]

[package.dependencies]
future = "*"
orjson = "*"
requests = "*"

[[package]]
name = "pymysql"
version = "1.1.1"
@@ -7746,7 +7762,7 @@ name = "realtime"
version = "2.0.2"
description = ""
optional = false
python-versions = ">=3.9,<4.0"
python-versions = "<4.0,>=3.9"
files = [
{file = "realtime-2.0.2-py3-none-any.whl", hash = "sha256:2634c915bc38807f2013f21e8bcc4d2f79870dfd81460ddb9393883d0489928a"},
{file = "realtime-2.0.2.tar.gz", hash = "sha256:519da9325b3b8102139d51785013d592f6b2403d81fa21d838a0b0234723ed7d"},
@@ -8173,7 +8189,7 @@ name = "s3transfer"
version = "0.10.3"
description = "An Amazon S3 Transfer Manager"
optional = false
python-versions = ">= 3.8"
python-versions = ">=3.8"
files = [
{file = "s3transfer-0.10.3-py3-none-any.whl", hash = "sha256:263ed587a5803c6c708d3ce44dc4dfedaab4c1a32e8329bab818933d79ddcf5d"},
{file = "s3transfer-0.10.3.tar.gz", hash = "sha256:4f50ed74ab84d474ce614475e0b8d5047ff080810aac5d01ea25231cfc944b0c"},
@@ -8417,6 +8433,11 @@ files = [
{file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"},
{file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"},
{file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"},
{file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"},
{file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"},
{file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"},
{file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"},
{file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"},
{file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"},
{file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"},
{file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"},
@@ -8836,7 +8857,7 @@ name = "storage3"
version = "0.8.1"
description = "Supabase Storage client for Python."
optional = false
python-versions = ">=3.8,<4.0"
python-versions = "<4.0,>=3.8"
files = [
{file = "storage3-0.8.1-py3-none-any.whl", hash = "sha256:0b21205f43eaf0d1dd33bde6c6d0612f88524b7865f017d2ae9827e3f63d9cdc"},
{file = "storage3-0.8.1.tar.gz", hash = "sha256:ea60b68b2221b3868ccc1a7f1294d57d0d9c51642cdc639d8115fe5d0adc8892"},
@@ -8882,7 +8903,7 @@ name = "supabase"
version = "2.8.1"
description = "Supabase client for Python."
optional = false
python-versions = ">=3.9,<4.0"
python-versions = "<4.0,>=3.9"
files = [
{file = "supabase-2.8.1-py3-none-any.whl", hash = "sha256:dfa8bef89b54129093521d5bba2136ff765baf67cd76d8ad0aa4984d61a7815c"},
{file = "supabase-2.8.1.tar.gz", hash = "sha256:711c70e6acd9e2ff48ca0dc0b1bb70c01c25378cc5189ec9f5ed9655b30bc41d"},
@@ -8902,7 +8923,7 @@ name = "supafunc"
version = "0.6.1"
description = "Library for Supabase Functions"
optional = false
python-versions = ">=3.8,<4.0"
python-versions = "<4.0,>=3.8"
files = [
{file = "supafunc-0.6.1-py3-none-any.whl", hash = "sha256:01aeeeb4bf429977664454a32c86418345140faf6d2e6eb0636d52e4547c5fbb"},
{file = "supafunc-0.6.1.tar.gz", hash = "sha256:3c8761e3999336ccdb7550498a395fd08afc8469382f55ea56f7f640e5a909aa"},
@@ -10615,4 +10636,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "cc10ee218369eb5576d1e5ac8aeeb72e8927bbcb8bd1ac1594167c45aa9d9a21"
content-hash = "375ac3a91760513924647e67376cb6018505ec61d967651b254c68af9808d774"

api/pyproject.toml (+1, -0)

@@ -242,6 +242,7 @@ oracledb = "~2.2.1"
pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] }
pgvector = "0.2.5"
pymilvus = "~2.4.4"
pymochow = "1.3.1"
qdrant-client = "1.7.3"
tcvectordb = "1.3.2"
tidb-vector = "0.0.9"

api/tests/integration_tests/vdb/__mock/baiduvectordb.py (+154, -0)

@@ -0,0 +1,154 @@
import os

import pytest
from _pytest.monkeypatch import MonkeyPatch
from pymochow import MochowClient
from pymochow.model.database import Database
from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState
from pymochow.model.schema import HNSWParams, VectorIndex
from pymochow.model.table import Table
from requests.adapters import HTTPAdapter


class MockBaiduVectorDBClass:
def mock_vector_db_client(
self,
config=None,
adapter: HTTPAdapter = None,
):
self._conn = None
self._config = None

def list_databases(self, config=None) -> list[Database]:
return [
Database(
conn=self._conn,
database_name="dify",
config=self._config,
)
]

def create_database(self, database_name: str, config=None) -> Database:
return Database(conn=self._conn, database_name=database_name, config=config)

def list_table(self, config=None) -> list[Table]:
return []

def drop_table(self, table_name: str, config=None):
return {"code": 0, "msg": "Success"}

def create_table(
self,
table_name: str,
replication: int,
partition: int,
schema,
enable_dynamic_field=False,
description: str = "",
config=None,
) -> Table:
return Table(self, table_name, replication, partition, schema, enable_dynamic_field, description, config)

def describe_table(self, table_name: str, config=None) -> Table:
return Table(
self,
table_name,
3,
1,
None,
enable_dynamic_field=False,
description="table for dify",
config=config,
state=TableState.NORMAL,
)

def upsert(self, rows, config=None):
return {"code": 0, "msg": "operation success", "affectedCount": 1}

def rebuild_index(self, index_name: str, config=None):
return {"code": 0, "msg": "Success"}

def describe_index(self, index_name: str, config=None):
return VectorIndex(
index_name=index_name,
index_type=IndexType.HNSW,
field="vector",
metric_type=MetricType.L2,
params=HNSWParams(m=16, efconstruction=200),
auto_build=False,
state=IndexState.NORMAL,
)

def query(
self,
primary_key,
partition_key=None,
projections=None,
retrieve_vector=False,
read_consistency=ReadConsistency.EVENTUAL,
config=None,
):
return {
"row": {
"id": "doc_id_001",
"vector": [0.23432432, 0.8923744, 0.89238432],
"text": "text",
"metadata": {"doc_id": "doc_id_001"},
},
"code": 0,
"msg": "Success",
}

def delete(self, primary_key=None, partition_key=None, filter=None, config=None):
return {"code": 0, "msg": "Success"}

def search(
self,
anns,
partition_key=None,
projections=None,
retrieve_vector=False,
read_consistency=ReadConsistency.EVENTUAL,
config=None,
):
return {
"rows": [
{
"row": {
"id": "doc_id_001",
"vector": [0.23432432, 0.8923744, 0.89238432],
"text": "text",
"metadata": {"doc_id": "doc_id_001"},
},
"distance": 0.1,
"score": 0.5,
}
],
"code": 0,
"msg": "Success",
}


MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_baiduvectordb_mock(request, monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(MochowClient, "__init__", MockBaiduVectorDBClass.mock_vector_db_client)
monkeypatch.setattr(MochowClient, "list_databases", MockBaiduVectorDBClass.list_databases)
monkeypatch.setattr(MochowClient, "create_database", MockBaiduVectorDBClass.create_database)
monkeypatch.setattr(Database, "table", MockBaiduVectorDBClass.describe_table)
monkeypatch.setattr(Database, "list_table", MockBaiduVectorDBClass.list_table)
monkeypatch.setattr(Database, "create_table", MockBaiduVectorDBClass.create_table)
monkeypatch.setattr(Database, "drop_table", MockBaiduVectorDBClass.drop_table)
monkeypatch.setattr(Database, "describe_table", MockBaiduVectorDBClass.describe_table)
monkeypatch.setattr(Table, "rebuild_index", MockBaiduVectorDBClass.rebuild_index)
monkeypatch.setattr(Table, "describe_index", MockBaiduVectorDBClass.describe_index)
monkeypatch.setattr(Table, "delete", MockBaiduVectorDBClass.delete)
monkeypatch.setattr(Table, "search", MockBaiduVectorDBClass.search)

yield

if MOCK:
monkeypatch.undo()
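Usage note: the fixture patches nothing unless MOCK_SWITCH=true, so the same test file can run against either the mock or a live Mochow instance; from the api directory, something like:

MOCK_SWITCH=true pytest tests/integration_tests/vdb/baidu/test_baidu.py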

api/tests/integration_tests/vdb/baidu/__init__.py (+0, -0)


api/tests/integration_tests/vdb/baidu/test_baidu.py (+36, -0)

@@ -0,0 +1,36 @@
from unittest.mock import MagicMock

from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector
from tests.integration_tests.vdb.__mock.baiduvectordb import setup_baiduvectordb_mock
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis

mock_client = MagicMock()
mock_client.list_databases.return_value = [{"name": "test"}]


class BaiduVectorTest(AbstractVectorTest):
def __init__(self):
super().__init__()
self.vector = BaiduVector(
"dify",
BaiduConfig(
endpoint="http://127.0.0.1:5287",
account="root",
api_key="dify",
database="dify",
shard=1,
replicas=3,
),
)

def search_by_vector(self):
hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
assert len(hits_by_vector) == 1

def search_by_full_text(self):
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0


def test_baidu_vector(setup_mock_redis, setup_baiduvectordb_mock):
BaiduVectorTest().run_all_tests()

docker/.env.example (+9, -0)

@@ -462,6 +462,15 @@ ELASTICSEARCH_PORT=9200
ELASTICSEARCH_USERNAME=elastic
ELASTICSEARCH_PASSWORD=elastic

# Baidu vector configuration, only available when VECTOR_STORE is `baidu`
BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000
BAIDU_VECTOR_DB_ACCOUNT=root
BAIDU_VECTOR_DB_API_KEY=dify
BAIDU_VECTOR_DB_DATABASE=dify
BAIDU_VECTOR_DB_SHARD=1
BAIDU_VECTOR_DB_REPLICAS=3

# ------------------------------
# Knowledge Configuration
# ------------------------------

docker/docker-compose.yaml (+7, -0)

@@ -165,6 +165,13 @@ x-shared-env: &shared-api-worker-env
TENCENT_VECTOR_DB_DATABASE: ${TENCENT_VECTOR_DB_DATABASE:-dify}
TENCENT_VECTOR_DB_SHARD: ${TENCENT_VECTOR_DB_SHARD:-1}
TENCENT_VECTOR_DB_REPLICAS: ${TENCENT_VECTOR_DB_REPLICAS:-2}
BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-http://127.0.0.1:5287}
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: ${BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS:-30000}
BAIDU_VECTOR_DB_ACCOUNT: ${BAIDU_VECTOR_DB_ACCOUNT:-root}
BAIDU_VECTOR_DB_API_KEY: ${BAIDU_VECTOR_DB_API_KEY:-dify}
BAIDU_VECTOR_DB_DATABASE: ${BAIDU_VECTOR_DB_DATABASE:-dify}
BAIDU_VECTOR_DB_SHARD: ${BAIDU_VECTOR_DB_SHARD:-1}
BAIDU_VECTOR_DB_REPLICAS: ${BAIDU_VECTOR_DB_REPLICAS:-3}
UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
ETL_TYPE: ${ETL_TYPE:-dify}
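Because every entry uses the ${VAR:-default} form, a deployment only needs to override the values that differ, in docker/.env; a minimal override sketch (the host below is a placeholder):

VECTOR_STORE=baidu
BAIDU_VECTOR_DB_ENDPOINT=http://your-vdb-host:5287
BAIDU_VECTOR_DB_API_KEY=your-api-key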
