RELYT_PASSWORD=postgres
RELYT_DATABASE=postgres

# Tencent configuration
TENCENT_VECTOR_DB_URL=http://127.0.0.1
TENCENT_VECTOR_DB_API_KEY=dify
TENCENT_VECTOR_DB_TIMEOUT=30
TENCENT_VECTOR_DB_USERNAME=dify
TENCENT_VECTOR_DB_DATABASE=dify
TENCENT_VECTOR_DB_SHARD=1
TENCENT_VECTOR_DB_REPLICAS=2
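# Note: SHARD and REPLICAS map to the collection's shard and replica counts at
# creation time (see tencent_vector.py below); the limits that apply depend on
# your Tencent Cloud VectorDB instance type.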

# PGVECTO_RS configuration
PGVECTO_RS_HOST=localhost
PGVECTO_RS_PORT=5431
| "vector_store": {"class_prefix": collection_name} | "vector_store": {"class_prefix": collection_name} | ||||
| } | } | ||||
| dataset.index_struct = json.dumps(index_struct_dict) | dataset.index_struct = json.dumps(index_struct_dict) | ||||
| elif vector_type == VectorType.TENCENT: | |||||
| dataset_id = dataset.id | |||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id) | |||||
| index_struct_dict = { | |||||
| "type": VectorType.TENCENT, | |||||
| "vector_store": {"class_prefix": collection_name} | |||||
| } | |||||
| dataset.index_struct = json.dumps(index_struct_dict) | |||||
| elif vector_type == VectorType.PGVECTOR: | elif vector_type == VectorType.PGVECTOR: | ||||
| dataset_id = dataset.id | dataset_id = dataset.id | ||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id) | collection_name = Dataset.gen_collection_name_by_id(dataset_id) |
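# Each branch of this migration rewrites dataset.index_struct with the target
# store's type and a collection name derived from the dataset id, so existing
# datasets can be re-indexed into the newly configured vector store.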
self.RELYT_PASSWORD = get_env('RELYT_PASSWORD')
self.RELYT_DATABASE = get_env('RELYT_DATABASE')

# tencent settings
self.TENCENT_VECTOR_DB_URL = get_env('TENCENT_VECTOR_DB_URL')
self.TENCENT_VECTOR_DB_API_KEY = get_env('TENCENT_VECTOR_DB_API_KEY')
self.TENCENT_VECTOR_DB_TIMEOUT = get_env('TENCENT_VECTOR_DB_TIMEOUT')
self.TENCENT_VECTOR_DB_USERNAME = get_env('TENCENT_VECTOR_DB_USERNAME')
self.TENCENT_VECTOR_DB_DATABASE = get_env('TENCENT_VECTOR_DB_DATABASE')
self.TENCENT_VECTOR_DB_SHARD = get_env('TENCENT_VECTOR_DB_SHARD')
self.TENCENT_VECTOR_DB_REPLICAS = get_env('TENCENT_VECTOR_DB_REPLICAS')

# pgvecto rs settings
self.PGVECTO_RS_HOST = get_env('PGVECTO_RS_HOST')
self.PGVECTO_RS_PORT = get_env('PGVECTO_RS_PORT')
@account_initialization_required
def get(self):
    vector_type = current_app.config['VECTOR_STORE']
    match vector_type:
        case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
            return {
                'retrieval_method': [
                    'semantic_search'
@account_initialization_required
def get(self, vector_type):
    match vector_type:
        case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
            return {
                'retrieval_method': [
                    'semantic_search'
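# TENCENT joins the group of stores that advertise only 'semantic_search':
# TencentVector.search_by_full_text (below) returns an empty list, so no
# full-text retrieval method is exposed for it.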
import json
from typing import Any, Optional

from flask import current_app
from pydantic import BaseModel
from tcvectordb import VectorDBClient
from tcvectordb.model import document, enum
from tcvectordb.model import index as vdb_index
from tcvectordb.model.document import Filter

from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

class TencentConfig(BaseModel):
    url: str
    api_key: Optional[str]
    timeout: float = 30
    username: Optional[str]
    database: Optional[str]
    index_type: str = "HNSW"
    metric_type: str = "L2"
    shard: int = 1
    replicas: int = 2

    def to_tencent_params(self):
        return {
            'url': self.url,
            'username': self.username,
            'key': self.api_key,
            'timeout': self.timeout
        }
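# Only the connection settings are forwarded to VectorDBClient via
# to_tencent_params(); database, shard, replicas, index_type and metric_type
# are consumed later when the database and collection are created.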

class TencentVector(BaseVector):
    field_id: str = "id"
    field_vector: str = "vector"
    field_text: str = "text"
    field_metadata: str = "metadata"

    def __init__(self, collection_name: str, config: TencentConfig):
        super().__init__(collection_name)
        self._client_config = config
        self._client = VectorDBClient(**self._client_config.to_tencent_params())
        self._db = self._init_database()

    def _init_database(self):
        exists = False
        for db in self._client.list_databases():
            if db.database_name == self._client_config.database:
                exists = True
                break
        if exists:
            return self._client.database(self._client_config.database)
        else:
            return self._client.create_database(database_name=self._client_config.database)

    def get_type(self) -> str:
        return 'tencent'

    def to_index_struct(self) -> dict:
        return {
            "type": self.get_type(),
            "vector_store": {"class_prefix": self._collection_name}
        }

    def _has_collection(self) -> bool:
        collections = self._db.list_collections()
        for collection in collections:
            if collection.collection_name == self._collection_name:
                return True
        return False
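    # _create_collection guards creation with a Redis lock plus a one-hour
    # cache key, so concurrent indexing workers cannot race to create (or
    # re-create) the same collection.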
    def _create_collection(self, dimension: int) -> None:
        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
            if redis_client.get(collection_exist_cache_key):
                return

            if self._has_collection():
                return

            self.delete()
            index_type = None
            for k, v in enum.IndexType.__members__.items():
                if k == self._client_config.index_type:
                    index_type = v
            if index_type is None:
                raise ValueError("unsupported index_type")
            metric_type = None
            for k, v in enum.MetricType.__members__.items():
                if k == self._client_config.metric_type:
                    metric_type = v
            if metric_type is None:
                raise ValueError("unsupported metric_type")
            params = vdb_index.HNSWParams(m=16, efconstruction=200)
            index = vdb_index.Index(
                vdb_index.FilterIndex(
                    self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY
                ),
                vdb_index.VectorIndex(
                    self.field_vector,
                    dimension,
                    index_type,
                    metric_type,
                    params,
                ),
                vdb_index.FilterIndex(
                    self.field_text, enum.FieldType.String, enum.IndexType.FILTER
                ),
                vdb_index.FilterIndex(
                    self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
                ),
            )
            self._db.create_collection(
                name=self._collection_name,
                shard=self._client_config.shard,
                replicas=self._client_config.replicas,
                description="Collection for Dify",
                index=index,
            )
            redis_client.set(collection_exist_cache_key, 1, ex=3600)
    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        self._create_collection(len(embeddings[0]))
        self.add_texts(texts, embeddings)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        total_count = len(embeddings)
        docs = []
        for i in range(0, total_count):
            if metadatas is None:
                continue
            metadata = json.dumps(metadatas[i])
            doc = document.Document(
                id=metadatas[i]["doc_id"],
                vector=embeddings[i],
                text=texts[i],
                metadata=metadata,
            )
            docs.append(doc)
        self._db.collection(self._collection_name).upsert(docs, self._client_config.timeout)

    def text_exists(self, id: str) -> bool:
        docs = self._db.collection(self._collection_name).query(document_ids=[id])
        if docs and len(docs) > 0:
            return True
        return False

    def delete_by_ids(self, ids: list[str]) -> None:
        self._db.collection(self._collection_name).delete(document_ids=ids)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value])))
    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        res = self._db.collection(self._collection_name).search(
            vectors=[query_vector],
            params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)),
            retrieve_vector=False,
            limit=kwargs.get('top_k', 4),
            timeout=self._client_config.timeout,
        )
        score_threshold = kwargs.get("score_threshold") or 0.0
        return self._get_search_res(res, score_threshold)
    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        return []
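    # The SDK returns a distance-style score with each hit; _get_search_res
    # converts it to a similarity via `1 - score` and keeps only results above
    # the caller's score_threshold.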
    def _get_search_res(self, res, score_threshold):
        docs = []
        if res is None or len(res) == 0:
            return docs

        for result in res[0]:
            meta = result.get(self.field_metadata)
            if meta is not None:
                meta = json.loads(meta)
            score = 1 - result.get("score", 0.0)
            if score > score_threshold:
                meta["score"] = score
                doc = Document(page_content=result.get(self.field_text), metadata=meta)
                docs.append(doc)
        return docs

    def delete(self) -> None:
        self._db.drop_collection(name=self._collection_name)

class TencentVectorFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector:
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
            collection_name = class_prefix.lower()
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
            dataset.index_struct = json.dumps(
                self.gen_index_struct_dict(VectorType.TENCENT, collection_name))

        config = current_app.config
        return TencentVector(
            collection_name=collection_name,
            config=TencentConfig(
                url=config.get('TENCENT_VECTOR_DB_URL'),
                api_key=config.get('TENCENT_VECTOR_DB_API_KEY'),
                timeout=config.get('TENCENT_VECTOR_DB_TIMEOUT'),
                username=config.get('TENCENT_VECTOR_DB_USERNAME'),
                database=config.get('TENCENT_VECTOR_DB_DATABASE'),
                shard=config.get('TENCENT_VECTOR_DB_SHARD'),
                replicas=config.get('TENCENT_VECTOR_DB_REPLICAS'),
            )
        )
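# A minimal usage sketch (connection values borrowed from the integration test
# below; the collection name here is hypothetical):
#
#   vector = TencentVector("dify", TencentConfig(
#       url="http://127.0.0.1", api_key="dify", timeout=30,
#       username="dify", database="dify", shard=1, replicas=2))
#   vector.create(texts=documents, embeddings=embeddings)
#   hits = vector.search_by_vector(query_vector, top_k=4)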
    def _init_vector(self) -> BaseVector:
        config = current_app.config
        vector_type = config.get('VECTOR_STORE')

        if self._dataset.index_struct_dict:
            vector_type = self._dataset.index_struct_dict['type']

            case VectorType.WEAVIATE:
                from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
                return WeaviateVectorFactory
            case VectorType.TENCENT:
                from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
                return TencentVectorFactory
            case _:
                raise ValueError(f"Vector store {vector_type} is not supported.")
    RELYT = 'relyt'
    TIDB_VECTOR = 'tidb_vector'
    WEAVIATE = 'weaviate'
    TENCENT = 'tencent'
| test = ["Pillow", "contourpy[test-no-images]", "matplotlib"] | test = ["Pillow", "contourpy[test-no-images]", "matplotlib"] | ||||
| test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"] | test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"] | ||||
| [[package]] | |||||
| name = "cos-python-sdk-v5" | |||||
| version = "1.9.29" | |||||
| description = "cos-python-sdk-v5" | |||||
| optional = false | |||||
| python-versions = "*" | |||||
| files = [ | |||||
| {file = "cos-python-sdk-v5-1.9.29.tar.gz", hash = "sha256:1bb07022368d178e7a50a3cc42e0d6cbf4b0bef2af12a3bb8436904339cdec8e"}, | |||||
| ] | |||||
| [package.dependencies] | |||||
| crcmod = "*" | |||||
| pycryptodome = "*" | |||||
| requests = ">=2.8" | |||||
| six = "*" | |||||
| xmltodict = "*" | |||||
| [[package]] | [[package]] | ||||
| name = "coverage" | name = "coverage" | ||||
| version = "7.2.7" | version = "7.2.7" | ||||
| [package.extras] | [package.extras] | ||||
| widechars = ["wcwidth"] | widechars = ["wcwidth"] | ||||
| [[package]] | |||||
| name = "tcvectordb" | |||||
| version = "1.3.2" | |||||
| description = "Tencent VectorDB Python SDK" | |||||
| optional = false | |||||
| python-versions = ">=3" | |||||
| files = [ | |||||
| {file = "tcvectordb-1.3.2-py3-none-any.whl", hash = "sha256:c4b6922d5df4cf14fcd3e61220d9374d1d53ec7270c254216ae35f8a752908f3"}, | |||||
| {file = "tcvectordb-1.3.2.tar.gz", hash = "sha256:2772f5871a69744ffc7c970b321312d626078533a721de3c744059a81aab419e"}, | |||||
| ] | |||||
| [package.dependencies] | |||||
| cos-python-sdk-v5 = ">=1.9.26" | |||||
| requests = "*" | |||||
| [[package]] | [[package]] | ||||
| name = "tenacity" | name = "tenacity" | ||||
| version = "8.3.0" | version = "8.3.0" | ||||
| {file = "XlsxWriter-3.2.0.tar.gz", hash = "sha256:9977d0c661a72866a61f9f7a809e25ebbb0fb7036baa3b9fe74afcfca6b3cb8c"}, | {file = "XlsxWriter-3.2.0.tar.gz", hash = "sha256:9977d0c661a72866a61f9f7a809e25ebbb0fb7036baa3b9fe74afcfca6b3cb8c"}, | ||||
| ] | ] | ||||
| [[package]] | |||||
| name = "xmltodict" | |||||
| version = "0.13.0" | |||||
| description = "Makes working with XML feel like you are working with JSON" | |||||
| optional = false | |||||
| python-versions = ">=3.4" | |||||
| files = [ | |||||
| {file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"}, | |||||
| {file = "xmltodict-0.13.0.tar.gz", hash = "sha256:341595a488e3e01a85a9d8911d8912fd922ede5fecc4dce437eb4b6c8d037e56"}, | |||||
| ] | |||||
| [[package]] | [[package]] | ||||
| name = "yarl" | name = "yarl" | ||||
| version = "1.9.4" | version = "1.9.4" | ||||
| [metadata] | [metadata] | ||||
| lock-version = "2.0" | lock-version = "2.0" | ||||
| python-versions = "^3.10" | python-versions = "^3.10" | ||||
| content-hash = "32a9ac027beabdb863fb33886bbf6f0000cbddf4d6089cbdb5c5dbfba23b29b4" | |||||
| content-hash = "e967aa4b61dc7c40f2f50eb325038da1dc0ff633d8f778e7a7560bdabce744dc" |
vanna = {version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"]}
kaleido = "0.2.1"
tencentcloud-sdk-python-hunyuan = "~3.0.1158"
tcvectordb = "1.3.2"
chromadb = "~0.5.0"

[tool.poetry.group.dev]
pydantic~=2.7.4
pydantic_extra_types~=2.8.1
pgvecto-rs==0.1.4
tcvectordb==1.3.2
firecrawl-py==0.0.5
oss2==2.18.5
pgvector==0.2.5
import os
from typing import Optional

import pytest
from _pytest.monkeypatch import MonkeyPatch
from requests.adapters import HTTPAdapter
from tcvectordb import VectorDBClient
from tcvectordb.model.database import Collection, Database
from tcvectordb.model.document import Document, Filter
from tcvectordb.model.enum import ReadConsistency
from tcvectordb.model.index import Index
from xinference_client.types import Embedding


class MockTcvectordbClass:

    def VectorDBClient(self, url=None, username='', key='',
                       read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
                       timeout=5,
                       adapter: HTTPAdapter = None):
        self._conn = None
        self._read_consistency = read_consistency

    def list_databases(self) -> list[Database]:
        return [
            Database(
                conn=self._conn,
                read_consistency=self._read_consistency,
                name='dify',
            )]

    def list_collections(self, timeout: Optional[float] = None) -> list[Collection]:
        return []

    def drop_collection(self, name: str, timeout: Optional[float] = None):
        return {
            "code": 0,
            "msg": "operation success"
        }

    def create_collection(
            self,
            name: str,
            shard: int,
            replicas: int,
            description: str,
            index: Index,
            embedding: Embedding = None,
            timeout: float = None,
    ) -> Collection:
        return Collection(self, name, shard, replicas, description, index, embedding=embedding,
                          read_consistency=self._read_consistency, timeout=timeout)

    def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection:
        collection = Collection(
            self,
            name,
            shard=1,
            replicas=2,
            description=name,
            timeout=timeout
        )
        return collection

    def collection_upsert(
            self,
            documents: list[Document],
            timeout: Optional[float] = None,
            build_index: bool = True,
            **kwargs
    ):
        return {
            "code": 0,
            "msg": "operation success"
        }

    def collection_search(
            self,
            vectors: list[list[float]],
            filter: Filter = None,
            params=None,
            retrieve_vector: bool = False,
            limit: int = 10,
            output_fields: Optional[list[str]] = None,
            timeout: Optional[float] = None,
    ) -> list[list[dict]]:
        return [[{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]]

    def collection_query(
            self,
            document_ids: Optional[list] = None,
            retrieve_vector: bool = False,
            limit: Optional[int] = None,
            offset: Optional[int] = None,
            filter: Optional[Filter] = None,
            output_fields: Optional[list[str]] = None,
            timeout: Optional[float] = None,
    ) -> list[dict]:
        return [{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]

    def collection_delete(
            self,
            document_ids: list[str] = None,
            filter: Filter = None,
            timeout: float = None,
    ):
        return {
            "code": 0,
            "msg": "operation success"
        }


MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'


@pytest.fixture
def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
    if MOCK:
        monkeypatch.setattr(VectorDBClient, '__init__', MockTcvectordbClass.VectorDBClient)
        monkeypatch.setattr(VectorDBClient, 'list_databases', MockTcvectordbClass.list_databases)
        monkeypatch.setattr(Database, 'collection', MockTcvectordbClass.describe_collection)
        monkeypatch.setattr(Database, 'list_collections', MockTcvectordbClass.list_collections)
        monkeypatch.setattr(Database, 'drop_collection', MockTcvectordbClass.drop_collection)
        monkeypatch.setattr(Database, 'create_collection', MockTcvectordbClass.create_collection)
        monkeypatch.setattr(Collection, 'upsert', MockTcvectordbClass.collection_upsert)
        monkeypatch.setattr(Collection, 'search', MockTcvectordbClass.collection_search)
        monkeypatch.setattr(Collection, 'query', MockTcvectordbClass.collection_query)
        monkeypatch.setattr(Collection, 'delete', MockTcvectordbClass.collection_delete)

    yield

    if MOCK:
        monkeypatch.undo()
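# With MOCK_SWITCH=true the fixture patches the tcvectordb SDK classes with the
# stubs above; with any other value the integration tests talk to a real
# Tencent VectorDB instance at the configured URL.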
from unittest.mock import MagicMock

from core.rag.datasource.vdb.tencent.tencent_vector import TencentConfig, TencentVector
from tests.integration_tests.vdb.__mock.tcvectordb import setup_tcvectordb_mock
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis

mock_client = MagicMock()
mock_client.list_databases.return_value = [{"name": "test"}]


class TencentVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        self.vector = TencentVector("dify", TencentConfig(
            url="http://127.0.0.1",
            api_key="dify",
            timeout=30,
            username="dify",
            database="dify",
            shard=1,
            replicas=2,
        ))

    def search_by_vector(self):
        hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(hits_by_vector) == 1

    def search_by_full_text(self):
        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
        assert len(hits_by_full_text) == 0


def test_tencent_vector(setup_mock_redis, setup_tcvectordb_mock):
    TencentVectorTest().run_all_tests()
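# To run just this suite against the mock (hypothetical invocation; adjust the
# path to where the test file lives in your checkout):
#   MOCK_SWITCH=true pytest tests/integration_tests/vdb/tencent/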
RELYT_USER: postgres
RELYT_PASSWORD: difyai123456
RELYT_DATABASE: postgres
# tencent configurations
TENCENT_VECTOR_DB_URL: http://127.0.0.1
TENCENT_VECTOR_DB_API_KEY: dify
TENCENT_VECTOR_DB_TIMEOUT: 30
TENCENT_VECTOR_DB_USERNAME: dify
TENCENT_VECTOR_DB_DATABASE: dify
TENCENT_VECTOR_DB_SHARD: 1
TENCENT_VECTOR_DB_REPLICAS: 2
# pgvector configurations
PGVECTOR_HOST: pgvector
PGVECTOR_PORT: 5432