| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193 |
- import os
- from typing import Optional, Union
-
- import pytest
- from _pytest.monkeypatch import MonkeyPatch
- from requests.adapters import HTTPAdapter
- from tcvectordb import RPCVectorDBClient # type: ignore
- from tcvectordb.model import enum
- from tcvectordb.model.collection import FilterIndexConfig
- from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank # type: ignore
- from tcvectordb.model.enum import ReadConsistency # type: ignore
- from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex # type: ignore
- from tcvectordb.rpc.model.collection import RPCCollection
- from tcvectordb.rpc.model.database import RPCDatabase
- from xinference_client.types import Embedding # type: ignore
-
-
class MockTcvectordbClass:
    """Canned replacements for the ``RPCVectorDBClient`` methods used in tests.

    The functions defined here are patched onto ``RPCVectorDBClient`` one by
    one, so at call time ``self`` is a client instance rather than an instance
    of this class. They therefore rely only on the attributes assigned in
    ``mock_vector_db_client`` and never call one another through ``self``.
    """

    def mock_vector_db_client(
        self,
        url: str,
        username="",
        key="",
        read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
        timeout=10,
        adapter: Optional[HTTPAdapter] = None,
        pool_size: int = 2,
        proxies: Optional[dict] = None,
        password: Optional[str] = None,
        **kwargs,
    ):
        """Replacement for ``RPCVectorDBClient.__init__``.

        All connection arguments are accepted and ignored except
        ``read_consistency``, which is stored for the other mocked methods.
        No connection object is created (``_conn`` is set to ``None``).
        """
        self._conn = None
        self._read_consistency = read_consistency

    def create_database_if_not_exists(self, database_name: str, timeout: Optional[float] = None) -> RPCDatabase:
        """Return an ``RPCDatabase`` handle named after ``database_name``."""
        # Echo the requested name instead of a hard-coded one so this mock
        # agrees with ``describe_collection``.
        return RPCDatabase(
            name=database_name,
            read_consistency=self._read_consistency,
        )

    def exists_collection(self, database_name: str, collection_name: str) -> bool:
        """Report every collection as existing."""
        return True

    def describe_collection(
        self, database_name: str, collection_name: str, timeout: Optional[float] = None
    ) -> RPCCollection:
        """Return an ``RPCCollection`` carrying a fixed index definition.

        The index has a primary-key ``id`` field, a 128-dimension HNSW
        ``vector`` field using the IP metric, and ``text`` / ``metadata``
        filter fields.
        """
        index = Index(
            FilterIndex("id", enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
            VectorIndex(
                "vector",
                128,
                enum.IndexType.HNSW,
                enum.MetricType.IP,
                HNSWParams(m=16, efconstruction=200),
            ),
            FilterIndex("text", enum.FieldType.String, enum.IndexType.FILTER),
            FilterIndex("metadata", enum.FieldType.String, enum.IndexType.FILTER),
        )
        database = RPCDatabase(
            name=database_name,
            read_consistency=self._read_consistency,
        )
        return RPCCollection(database, collection_name, index=index)

    def create_collection(
        self,
        database_name: str,
        collection_name: str,
        shard: int,
        replicas: int,
        description: Optional[str] = None,
        index: Optional[Index] = None,
        embedding: Optional[Embedding] = None,
        timeout: Optional[float] = None,
        ttl_config: Optional[dict] = None,
        filter_index_config: Optional[FilterIndexConfig] = None,
        indexes: Optional[list[IndexField]] = None,
    ) -> RPCCollection:
        """Return an ``RPCCollection`` built from the given arguments.

        Nothing is created remotely; the arguments are simply forwarded to the
        ``RPCCollection`` constructor.
        """
        # Use the caller's database name (previously hard-coded) so the
        # returned handle matches what ``describe_collection`` reports.
        database = RPCDatabase(
            name=database_name,
            read_consistency=self._read_consistency,
        )
        return RPCCollection(
            database,
            collection_name,
            shard,
            replicas,
            description,
            index,
            embedding=embedding,
            read_consistency=self._read_consistency,
            timeout=timeout,
            ttl_config=ttl_config,
            filter_index_config=filter_index_config,
            indexes=indexes,
        )

    def collection_upsert(
        self,
        database_name: str,
        collection_name: str,
        documents: list[Union[Document, dict]],
        timeout: Optional[float] = None,
        build_index: bool = True,
        **kwargs,
    ):
        """Pretend the upsert succeeded and return a success payload."""
        return {"code": 0, "msg": "operation success"}

    def collection_search(
        self,
        database_name: str,
        collection_name: str,
        vectors: list[list[float]],
        filter: Optional[Filter] = None,
        params=None,
        retrieve_vector: bool = False,
        limit: int = 10,
        output_fields: Optional[list[str]] = None,
        timeout: Optional[float] = None,
    ) -> list[list[dict]]:
        """Return one fixed hit (``doc_id`` ``foo1``) regardless of the query."""
        return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]

    def collection_hybrid_search(
        self,
        database_name: str,
        collection_name: str,
        ann: Optional[Union[list[AnnSearch], AnnSearch]] = None,
        match: Optional[Union[list[KeywordSearch], KeywordSearch]] = None,
        filter: Optional[Union[Filter, str]] = None,
        rerank: Optional[Rerank] = None,
        retrieve_vector: Optional[bool] = None,
        output_fields: Optional[list[str]] = None,
        limit: Optional[int] = None,
        timeout: Optional[float] = None,
        return_pd_object=False,
        **kwargs,
    ) -> list[list[dict]]:
        """Return the same single fixed hit as ``collection_search``."""
        return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]

    def collection_query(
        self,
        database_name: str,
        collection_name: str,
        document_ids: Optional[list] = None,
        retrieve_vector: bool = False,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        filter: Optional[Filter] = None,
        output_fields: Optional[list[str]] = None,
        timeout: Optional[float] = None,
    ) -> list[dict]:
        """Return one fixed document regardless of the query."""
        # Unlike the search mocks, ``metadata`` here is a JSON-encoded string.
        return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]

    def collection_delete(
        self,
        database_name: str,
        collection_name: str,
        document_ids: Optional[list[str]] = None,
        filter: Optional[Filter] = None,
        timeout: Optional[float] = None,
    ):
        """Pretend the delete succeeded and return a success payload."""
        return {"code": 0, "msg": "operation success"}

    def drop_collection(self, database_name: str, collection_name: str, timeout: Optional[float] = None) -> dict:
        """Pretend the collection was dropped and return a success payload."""
        return {"code": 0, "msg": "operation success"}
-
-
# Mocking is enabled only when the MOCK_SWITCH environment variable equals
# "true" (compared case-insensitively); it is off when the variable is unset.
MOCK = os.environ.get("MOCK_SWITCH", "false").lower() == "true"
-
-
@pytest.fixture
def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
    """Swap ``RPCVectorDBClient`` methods for their mocks while ``MOCK`` is on.

    When ``MOCK`` is false the fixture only yields and the client class is
    left untouched.
    """
    if MOCK:
        # (attribute on RPCVectorDBClient, replacement function), applied in order.
        replacements = (
            ("__init__", MockTcvectordbClass.mock_vector_db_client),
            ("create_database_if_not_exists", MockTcvectordbClass.create_database_if_not_exists),
            ("exists_collection", MockTcvectordbClass.exists_collection),
            ("create_collection", MockTcvectordbClass.create_collection),
            ("describe_collection", MockTcvectordbClass.describe_collection),
            ("upsert", MockTcvectordbClass.collection_upsert),
            ("search", MockTcvectordbClass.collection_search),
            ("hybrid_search", MockTcvectordbClass.collection_hybrid_search),
            ("query", MockTcvectordbClass.collection_query),
            ("delete", MockTcvectordbClass.collection_delete),
            ("drop_collection", MockTcvectordbClass.drop_collection),
        )
        for attribute, stub in replacements:
            monkeypatch.setattr(RPCVectorDBClient, attribute, stub)

    yield

    if MOCK:
        # Roll back the patches made through this monkeypatch instance.
        monkeypatch.undo()
|