Parcourir la source

feat: support tencent vector db (#3568)

tags/0.6.11
quicksand il y a 1 an
Parent
révision
4080f7b8ad
Aucun compte lié à l'adresse e-mail de l'auteur

+ 9
- 0
api/.env.example Voir le fichier

@@ -99,6 +99,15 @@ RELYT_USER=postgres
RELYT_PASSWORD=postgres
RELYT_DATABASE=postgres

# Tencent configuration
TENCENT_VECTOR_DB_URL=http://127.0.0.1
TENCENT_VECTOR_DB_API_KEY=dify
TENCENT_VECTOR_DB_TIMEOUT=30
TENCENT_VECTOR_DB_USERNAME=dify
TENCENT_VECTOR_DB_DATABASE=dify
TENCENT_VECTOR_DB_SHARD=1
TENCENT_VECTOR_DB_REPLICAS=2

# PGVECTO_RS configuration
PGVECTO_RS_HOST=localhost
PGVECTO_RS_PORT=5431

+ 8
- 0
api/commands.py Voir le fichier

@@ -309,6 +309,14 @@ def migrate_knowledge_vector_database():
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.TENCENT:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.TENCENT,
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.PGVECTOR:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)

+ 10
- 0
api/config.py Voir le fichier

@@ -288,6 +288,16 @@ class Config:
self.RELYT_PASSWORD = get_env('RELYT_PASSWORD')
self.RELYT_DATABASE = get_env('RELYT_DATABASE')


# tencent settings
self.TENCENT_VECTOR_DB_URL = get_env('TENCENT_VECTOR_DB_URL')
self.TENCENT_VECTOR_DB_API_KEY = get_env('TENCENT_VECTOR_DB_API_KEY')
self.TENCENT_VECTOR_DB_TIMEOUT = get_env('TENCENT_VECTOR_DB_TIMEOUT')
self.TENCENT_VECTOR_DB_USERNAME = get_env('TENCENT_VECTOR_DB_USERNAME')
self.TENCENT_VECTOR_DB_DATABASE = get_env('TENCENT_VECTOR_DB_DATABASE')
self.TENCENT_VECTOR_DB_SHARD = get_env('TENCENT_VECTOR_DB_SHARD')
self.TENCENT_VECTOR_DB_REPLICAS = get_env('TENCENT_VECTOR_DB_REPLICAS')

# pgvecto rs settings
self.PGVECTO_RS_HOST = get_env('PGVECTO_RS_HOST')
self.PGVECTO_RS_PORT = get_env('PGVECTO_RS_PORT')

+ 2
- 3
api/controllers/console/datasets/datasets.py Voir le fichier

@@ -480,9 +480,8 @@ class DatasetRetrievalSettingApi(Resource):
@account_initialization_required
def get(self):
vector_type = current_app.config['VECTOR_STORE']

match vector_type:
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA:
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
return {
'retrieval_method': [
'semantic_search'
@@ -504,7 +503,7 @@ class DatasetRetrievalSettingMockApi(Resource):
@account_initialization_required
def get(self, vector_type):
match vector_type:
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA:
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
return {
'retrieval_method': [
'semantic_search'

+ 0
- 0
api/core/rag/datasource/vdb/tencent/__init__.py Voir le fichier


+ 227
- 0
api/core/rag/datasource/vdb/tencent/tencent_vector.py Voir le fichier

@@ -0,0 +1,227 @@
import json
from typing import Any, Optional

from flask import current_app
from pydantic import BaseModel
from tcvectordb import VectorDBClient
from tcvectordb.model import document, enum
from tcvectordb.model import index as vdb_index
from tcvectordb.model.document import Filter

from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset


class TencentConfig(BaseModel):
    """Connection and collection settings for Tencent Cloud VectorDB."""

    url: str
    api_key: Optional[str] = None
    timeout: float = 30
    username: Optional[str] = None
    database: Optional[str] = None
    index_type: str = "HNSW"
    metric_type: str = "L2"
    # BUG FIX: the original declared these as "shard: int = 1," and
    # "replicas: int = 2," — the trailing commas turned the defaults into the
    # tuples (1,) and (2,) instead of ints.
    shard: int = 1
    replicas: int = 2

    def to_tencent_params(self):
        """Return the keyword arguments expected by tcvectordb.VectorDBClient."""
        return {
            'url': self.url,
            'username': self.username,
            'key': self.api_key,
            'timeout': self.timeout
        }


class TencentVector(BaseVector):
    """BaseVector implementation backed by Tencent Cloud VectorDB.

    Each dataset maps to one collection. A document row stores a string
    primary key, the embedding vector, the chunk text, and a JSON-encoded
    metadata string; filter indexes are created on text and metadata.
    """

    field_id: str = "id"
    field_vector: str = "vector"
    field_text: str = "text"
    field_metadata: str = "metadata"

    def __init__(self, collection_name: str, config: TencentConfig):
        super().__init__(collection_name)
        self._client_config = config
        self._client = VectorDBClient(**self._client_config.to_tencent_params())
        self._db = self._init_database()

    def _init_database(self):
        """Return the configured database handle, creating it on first use."""
        exists = False
        for db in self._client.list_databases():
            if db.database_name == self._client_config.database:
                exists = True
                break
        if exists:
            return self._client.database(self._client_config.database)
        else:
            return self._client.create_database(database_name=self._client_config.database)

    def get_type(self) -> str:
        # VectorType is a str-valued enum, so this compares equal to 'tencent'
        # for existing callers while staying consistent with the enum.
        return VectorType.TENCENT

    def to_index_struct(self) -> dict:
        """Index-struct dict persisted on the dataset row."""
        return {
            "type": self.get_type(),
            "vector_store": {"class_prefix": self._collection_name}
        }

    def _has_collection(self) -> bool:
        return any(
            collection.collection_name == self._collection_name
            for collection in self._db.list_collections()
        )

    def _create_collection(self, dimension: int) -> None:
        """Create the collection for this dataset (idempotent).

        A redis lock serializes concurrent creators; a redis cache key
        short-circuits repeat calls for an hour.
        """
        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
            if redis_client.get(collection_exist_cache_key):
                return

            if self._has_collection():
                return

            # Drop any stale collection of the same name before re-creating.
            self.delete()

            # Map the configured names onto the SDK enums; reject unknown values.
            index_type = enum.IndexType.__members__.get(self._client_config.index_type)
            if index_type is None:
                raise ValueError("unsupported index_type")
            metric_type = enum.MetricType.__members__.get(self._client_config.metric_type)
            if metric_type is None:
                raise ValueError("unsupported metric_type")

            params = vdb_index.HNSWParams(m=16, efconstruction=200)
            index = vdb_index.Index(
                vdb_index.FilterIndex(
                    self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY
                ),
                vdb_index.VectorIndex(
                    self.field_vector,
                    dimension,
                    index_type,
                    metric_type,
                    params,
                ),
                vdb_index.FilterIndex(
                    self.field_text, enum.FieldType.String, enum.IndexType.FILTER
                ),
                vdb_index.FilterIndex(
                    self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
                ),
            )

            self._db.create_collection(
                name=self._collection_name,
                shard=self._client_config.shard,
                replicas=self._client_config.replicas,
                description="Collection for Dify",
                index=index,
            )
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        """Create the collection sized to the embedding dimension, then insert."""
        self._create_collection(len(embeddings[0]))
        self.add_texts(texts, embeddings)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        """Upsert documents with their embeddings.

        BUG FIX: the original checked ``if metadatas is None`` inside the loop
        (always false for a list built from the documents) and would crash on a
        document whose own metadata was None; skip such documents instead.
        """
        docs = []
        for doc, vector in zip(documents, embeddings):
            if doc.metadata is None:
                continue
            docs.append(document.Document(
                id=doc.metadata["doc_id"],
                vector=vector,
                text=doc.page_content,
                metadata=json.dumps(doc.metadata),
            ))
        self._db.collection(self._collection_name).upsert(docs, self._client_config.timeout)

    def text_exists(self, id: str) -> bool:
        """True if a document with the given primary id exists."""
        docs = self._db.collection(self._collection_name).query(document_ids=[id])
        return bool(docs)

    def delete_by_ids(self, ids: list[str]) -> None:
        self._db.collection(self._collection_name).delete(document_ids=ids)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        """ANN search; honors ``ef``, ``top_k`` and ``score_threshold`` kwargs."""
        res = self._db.collection(self._collection_name).search(
            vectors=[query_vector],
            params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
            retrieve_vector=False,
            limit=kwargs.get('top_k', 4),
            timeout=self._client_config.timeout,
        )
        # Treat a missing/None/0 threshold uniformly as 0.0.
        score_threshold = kwargs.get("score_threshold") or 0.0
        return self._get_search_res(res, score_threshold)

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        # Full-text search is not supported by this backend.
        return []

    def _get_search_res(self, res, score_threshold):
        """Convert raw SDK search hits into Documents above the threshold."""
        docs = []
        if res is None or len(res) == 0:
            return docs

        for result in res[0]:
            raw_meta = result.get(self.field_metadata)
            # BUG FIX: the original left meta as None when the field was absent
            # and then crashed on meta["score"]; fall back to an empty dict.
            meta = json.loads(raw_meta) if raw_meta else {}
            # Converts the raw distance to a similarity; assumes smaller raw
            # score means a closer match — TODO confirm for non-L2 metrics.
            score = 1 - result.get("score", 0.0)
            if score > score_threshold:
                meta["score"] = score
                docs.append(Document(page_content=result.get(self.field_text), metadata=meta))

        return docs

    def delete(self) -> None:
        """Drop the whole collection."""
        self._db.drop_collection(name=self._collection_name)




class TencentVectorFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector:
        """Build a TencentVector for the dataset, stamping its index struct on first use."""
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
            collection_name = class_prefix.lower()
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
            # BUG FIX: the original stamped VectorType.TIDB_VECTOR here, which
            # recorded the wrong store type and would route future lookups for
            # this dataset to the TiDB factory.
            dataset.index_struct = json.dumps(
                self.gen_index_struct_dict(VectorType.TENCENT, collection_name))

        config = current_app.config
        return TencentVector(
            collection_name=collection_name,
            config=TencentConfig(
                url=config.get('TENCENT_VECTOR_DB_URL'),
                api_key=config.get('TENCENT_VECTOR_DB_API_KEY'),
                timeout=config.get('TENCENT_VECTOR_DB_TIMEOUT'),
                username=config.get('TENCENT_VECTOR_DB_USERNAME'),
                database=config.get('TENCENT_VECTOR_DB_DATABASE'),
                shard=config.get('TENCENT_VECTOR_DB_SHARD'),
                replicas=config.get('TENCENT_VECTOR_DB_REPLICAS'),
            )
        )

+ 3
- 1
api/core/rag/datasource/vdb/vector_factory.py Voir le fichier

@@ -39,7 +39,6 @@ class Vector:
def _init_vector(self) -> BaseVector:
config = current_app.config
vector_type = config.get('VECTOR_STORE')

if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']

@@ -76,6 +75,9 @@ class Vector:
case VectorType.WEAVIATE:
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
return WeaviateVectorFactory
case VectorType.TENCENT:
from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
return TencentVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")


+ 1
- 0
api/core/rag/datasource/vdb/vector_type.py Voir le fichier

@@ -10,3 +10,4 @@ class VectorType(str, Enum):
RELYT = 'relyt'
TIDB_VECTOR = 'tidb_vector'
WEAVIATE = 'weaviate'
TENCENT = 'tencent'

+ 44
- 1
api/poetry.lock Voir le fichier

@@ -1439,6 +1439,23 @@ mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.8.0)", "types-Pill
test = ["Pillow", "contourpy[test-no-images]", "matplotlib"]
test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"]

[[package]]
name = "cos-python-sdk-v5"
version = "1.9.29"
description = "cos-python-sdk-v5"
optional = false
python-versions = "*"
files = [
{file = "cos-python-sdk-v5-1.9.29.tar.gz", hash = "sha256:1bb07022368d178e7a50a3cc42e0d6cbf4b0bef2af12a3bb8436904339cdec8e"},
]

[package.dependencies]
crcmod = "*"
pycryptodome = "*"
requests = ">=2.8"
six = "*"
xmltodict = "*"

[[package]]
name = "coverage"
version = "7.2.7"
@@ -7411,6 +7428,21 @@ files = [
[package.extras]
widechars = ["wcwidth"]

[[package]]
name = "tcvectordb"
version = "1.3.2"
description = "Tencent VectorDB Python SDK"
optional = false
python-versions = ">=3"
files = [
{file = "tcvectordb-1.3.2-py3-none-any.whl", hash = "sha256:c4b6922d5df4cf14fcd3e61220d9374d1d53ec7270c254216ae35f8a752908f3"},
{file = "tcvectordb-1.3.2.tar.gz", hash = "sha256:2772f5871a69744ffc7c970b321312d626078533a721de3c744059a81aab419e"},
]

[package.dependencies]
cos-python-sdk-v5 = ">=1.9.26"
requests = "*"

[[package]]
name = "tenacity"
version = "8.3.0"
@@ -8641,6 +8673,17 @@ files = [
{file = "XlsxWriter-3.2.0.tar.gz", hash = "sha256:9977d0c661a72866a61f9f7a809e25ebbb0fb7036baa3b9fe74afcfca6b3cb8c"},
]

[[package]]
name = "xmltodict"
version = "0.13.0"
description = "Makes working with XML feel like you are working with JSON"
optional = false
python-versions = ">=3.4"
files = [
{file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"},
{file = "xmltodict-0.13.0.tar.gz", hash = "sha256:341595a488e3e01a85a9d8911d8912fd922ede5fecc4dce437eb4b6c8d037e56"},
]

[[package]]
name = "yarl"
version = "1.9.4"
@@ -8878,4 +8921,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "32a9ac027beabdb863fb33886bbf6f0000cbddf4d6089cbdb5c5dbfba23b29b4"
content-hash = "e967aa4b61dc7c40f2f50eb325038da1dc0ff633d8f778e7a7560bdabce744dc"

+ 1
- 0
api/pyproject.toml Voir le fichier

@@ -179,6 +179,7 @@ google-cloud-aiplatform = "1.49.0"
vanna = {version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"]}
kaleido = "0.2.1"
tencentcloud-sdk-python-hunyuan = "~3.0.1158"
tcvectordb = "1.3.2"
chromadb = "~0.5.0"

[tool.poetry.group.dev]

+ 1
- 0
api/requirements.txt Voir le fichier

@@ -78,6 +78,7 @@ lxml==5.1.0
pydantic~=2.7.4
pydantic_extra_types~=2.8.1
pgvecto-rs==0.1.4
tcvectordb==1.3.2
firecrawl-py==0.0.5
oss2==2.18.5
pgvector==0.2.5

+ 0
- 0
api/tests/integration_tests/vdb/__mock/__init__.py Voir le fichier


+ 132
- 0
api/tests/integration_tests/vdb/__mock/tcvectordb.py Voir le fichier

@@ -0,0 +1,132 @@
import os
from typing import Optional

import pytest
from _pytest.monkeypatch import MonkeyPatch
from requests.adapters import HTTPAdapter
from tcvectordb import VectorDBClient
from tcvectordb.model.database import Collection, Database
from tcvectordb.model.document import Document, Filter
from tcvectordb.model.enum import ReadConsistency
from tcvectordb.model.index import Index
from xinference_client.types import Embedding


class MockTcvectordbClass:
    """Stand-in implementations for tcvectordb client/database/collection methods.

    These are patched onto the real SDK classes by ``setup_tcvectordb_mock``
    (see the fixture below), so each signature must mirror the SDK method it
    replaces exactly — do not reorder or rename parameters.
    """

    def VectorDBClient(self, url=None, username='', key='',
                       read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
                       timeout=5,
                       adapter: HTTPAdapter = None):
        # Replaces VectorDBClient.__init__: skip any network setup, keep only
        # the attributes the other mocked methods read.
        self._conn = None
        self._read_consistency = read_consistency

    def list_databases(self) -> list[Database]:
        # Always report a single pre-existing database named 'dify' so
        # TencentVector._init_database takes the "already exists" path.
        return [
            Database(
                conn=self._conn,
                read_consistency=self._read_consistency,
                name='dify',
            )]

    def list_collections(self, timeout: Optional[float] = None) -> list[Collection]:
        # No collections ever exist, forcing the create path.
        return []

    def drop_collection(self, name: str, timeout: Optional[float] = None):
        # Mimic the SDK's success response payload.
        return {
            "code": 0,
            "msg": "operation success"
        }

    def create_collection(
            self,
            name: str,
            shard: int,
            replicas: int,
            description: str,
            index: Index,
            embedding: Embedding = None,
            timeout: float = None,
    ) -> Collection:
        # Build a real Collection object (no server round-trip happens).
        return Collection(self, name, shard, replicas, description, index, embedding=embedding,
                          read_consistency=self._read_consistency, timeout=timeout)

    def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection:
        # Patched over Database.collection: return a fixed-shape Collection.
        collection = Collection(
            self,
            name,
            shard=1,
            replicas=2,
            description=name,
            timeout=timeout
        )
        return collection

    def collection_upsert(
            self,
            documents: list[Document],
            timeout: Optional[float] = None,
            build_index: bool = True,
            **kwargs
    ):
        # Accept any documents and report success.
        return {
            "code": 0,
            "msg": "operation success"
        }

    def collection_search(
            self,
            vectors: list[list[float]],
            filter: Filter = None,
            params=None,
            retrieve_vector: bool = False,
            limit: int = 10,
            output_fields: Optional[list[str]] = None,
            timeout: Optional[float] = None,
    ) -> list[list[dict]]:
        # One query, one hit: metadata is a JSON string, as the real store returns.
        return [[{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]]

    def collection_query(
            self,
            document_ids: Optional[list] = None,
            retrieve_vector: bool = False,
            limit: Optional[int] = None,
            offset: Optional[int] = None,
            filter: Optional[Filter] = None,
            output_fields: Optional[list[str]] = None,
            timeout: Optional[float] = None,
    ) -> list[dict]:
        # Single fixed row regardless of the requested ids.
        return [{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]

    def collection_delete(
            self,
            document_ids: list[str] = None,
            filter: Filter = None,
            timeout: float = None,
    ):
        # Accept any delete request and report success.
        return {
            "code": 0,
            "msg": "operation success"
        }
# When MOCK_SWITCH=true, the integration tests run against the in-process
# mocks above instead of a live Tencent VectorDB instance.
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'

@pytest.fixture
def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
    """Patch the tcvectordb SDK classes with MockTcvectordbClass methods.

    Note the methods are patched unbound (``MockTcvectordbClass.X``, not an
    instance attribute), so ``self`` in each mock is the SDK object being
    patched; the patches are undone after the test when mocking is enabled.
    """
    if MOCK:
        monkeypatch.setattr(VectorDBClient, '__init__', MockTcvectordbClass.VectorDBClient)
        monkeypatch.setattr(VectorDBClient, 'list_databases', MockTcvectordbClass.list_databases)
        monkeypatch.setattr(Database, 'collection', MockTcvectordbClass.describe_collection)
        monkeypatch.setattr(Database, 'list_collections', MockTcvectordbClass.list_collections)
        monkeypatch.setattr(Database, 'drop_collection', MockTcvectordbClass.drop_collection)
        monkeypatch.setattr(Database, 'create_collection', MockTcvectordbClass.create_collection)
        monkeypatch.setattr(Collection, 'upsert', MockTcvectordbClass.collection_upsert)
        monkeypatch.setattr(Collection, 'search', MockTcvectordbClass.collection_search)
        monkeypatch.setattr(Collection, 'query', MockTcvectordbClass.collection_query)
        monkeypatch.setattr(Collection, 'delete', MockTcvectordbClass.collection_delete)

    yield

    if MOCK:
        monkeypatch.undo()

+ 0
- 0
api/tests/integration_tests/vdb/tcvectordb/__init__.py Voir le fichier


+ 35
- 0
api/tests/integration_tests/vdb/tcvectordb/test_tencent.py Voir le fichier

@@ -0,0 +1,35 @@
from unittest.mock import MagicMock

from core.rag.datasource.vdb.tencent.tencent_vector import TencentConfig, TencentVector
from tests.integration_tests.vdb.__mock.tcvectordb import setup_tcvectordb_mock
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis

# NOTE(review): this MagicMock is never referenced in this module — the actual
# mocking is done by the setup_tcvectordb_mock fixture; candidate for removal.
mock_client = MagicMock()
mock_client.list_databases.return_value = [{"name": "test"}]

class TencentVectorTest(AbstractVectorTest):
    """Runs the shared vector-store test suite against TencentVector."""

    def __init__(self):
        super().__init__()
        # Connection details match the mocked client; no live server is needed.
        config = TencentConfig(
            url="http://127.0.0.1",
            api_key="dify",
            timeout=30,
            username="dify",
            database="dify",
            shard=1,
            replicas=2,
        )
        self.vector = TencentVector("dify", config)

    def search_by_vector(self):
        # The mock always returns exactly one hit for any query vector.
        hits = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(hits) == 1

    def search_by_full_text(self):
        # The Tencent backend does not implement full-text search.
        hits = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits) == 0

def test_tencent_vector(setup_mock_redis, setup_tcvectordb_mock):
    """Drive the full shared suite with redis and tcvectordb mocked out."""
    suite = TencentVectorTest()
    suite.run_all_tests()




+ 8
- 0
docker/docker-compose.yaml Voir le fichier

@@ -298,6 +298,14 @@ services:
RELYT_USER: postgres
RELYT_PASSWORD: difyai123456
RELYT_DATABASE: postgres
# tencent configurations
TENCENT_VECTOR_DB_URL: http://127.0.0.1
TENCENT_VECTOR_DB_API_KEY: dify
TENCENT_VECTOR_DB_TIMEOUT: 30
TENCENT_VECTOR_DB_USERNAME: dify
TENCENT_VECTOR_DB_DATABASE: dify
TENCENT_VECTOR_DB_SHARD: 1
TENCENT_VECTOR_DB_REPLICAS: 2
# pgvector configurations
PGVECTOR_HOST: pgvector
PGVECTOR_PORT: 5432

Chargement…
Annuler
Enregistrer