您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

tcvectordb.py 7.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import os
  2. from typing import Optional, Union
  3. import pytest
  4. from _pytest.monkeypatch import MonkeyPatch
  5. from requests.adapters import HTTPAdapter
  6. from tcvectordb import RPCVectorDBClient # type: ignore
  7. from tcvectordb.model import enum
  8. from tcvectordb.model.collection import FilterIndexConfig
  9. from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank # type: ignore
  10. from tcvectordb.model.enum import ReadConsistency # type: ignore
  11. from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex # type: ignore
  12. from tcvectordb.rpc.model.collection import RPCCollection
  13. from tcvectordb.rpc.model.database import RPCDatabase
  14. from xinference_client.types import Embedding # type: ignore
  15. class MockTcvectordbClass:
  16. def mock_vector_db_client(
  17. self,
  18. url: str,
  19. username="",
  20. key="",
  21. read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
  22. timeout=10,
  23. adapter: HTTPAdapter = None,
  24. pool_size: int = 2,
  25. proxies: Optional[dict] = None,
  26. password: Optional[str] = None,
  27. **kwargs,
  28. ):
  29. self._conn = None
  30. self._read_consistency = read_consistency
  31. def create_database_if_not_exists(self, database_name: str, timeout: Optional[float] = None) -> RPCDatabase:
  32. return RPCDatabase(
  33. name="dify",
  34. read_consistency=self._read_consistency,
  35. )
  36. def exists_collection(self, database_name: str, collection_name: str) -> bool:
  37. return True
  38. def describe_collection(
  39. self, database_name: str, collection_name: str, timeout: Optional[float] = None
  40. ) -> RPCCollection:
  41. index = Index(
  42. FilterIndex("id", enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
  43. VectorIndex(
  44. "vector",
  45. 128,
  46. enum.IndexType.HNSW,
  47. enum.MetricType.IP,
  48. HNSWParams(m=16, efconstruction=200),
  49. ),
  50. FilterIndex("text", enum.FieldType.String, enum.IndexType.FILTER),
  51. FilterIndex("metadata", enum.FieldType.String, enum.IndexType.FILTER),
  52. )
  53. return RPCCollection(
  54. RPCDatabase(
  55. name=database_name,
  56. read_consistency=self._read_consistency,
  57. ),
  58. collection_name,
  59. index=index,
  60. )
  61. def create_collection(
  62. self,
  63. database_name: str,
  64. collection_name: str,
  65. shard: int,
  66. replicas: int,
  67. description: Optional[str] = None,
  68. index: Index = None,
  69. embedding: Embedding = None,
  70. timeout: Optional[float] = None,
  71. ttl_config: Optional[dict] = None,
  72. filter_index_config: FilterIndexConfig = None,
  73. indexes: Optional[list[IndexField]] = None,
  74. ) -> RPCCollection:
  75. return RPCCollection(
  76. RPCDatabase(
  77. name="dify",
  78. read_consistency=self._read_consistency,
  79. ),
  80. collection_name,
  81. shard,
  82. replicas,
  83. description,
  84. index,
  85. embedding=embedding,
  86. read_consistency=self._read_consistency,
  87. timeout=timeout,
  88. ttl_config=ttl_config,
  89. filter_index_config=filter_index_config,
  90. indexes=indexes,
  91. )
  92. def collection_upsert(
  93. self,
  94. database_name: str,
  95. collection_name: str,
  96. documents: list[Union[Document, dict]],
  97. timeout: Optional[float] = None,
  98. build_index: bool = True,
  99. **kwargs,
  100. ):
  101. return {"code": 0, "msg": "operation success"}
  102. def collection_search(
  103. self,
  104. database_name: str,
  105. collection_name: str,
  106. vectors: list[list[float]],
  107. filter: Filter = None,
  108. params=None,
  109. retrieve_vector: bool = False,
  110. limit: int = 10,
  111. output_fields: Optional[list[str]] = None,
  112. timeout: Optional[float] = None,
  113. ) -> list[list[dict]]:
  114. return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]
  115. def collection_hybrid_search(
  116. self,
  117. database_name: str,
  118. collection_name: str,
  119. ann: Optional[Union[list[AnnSearch], AnnSearch]] = None,
  120. match: Optional[Union[list[KeywordSearch], KeywordSearch]] = None,
  121. filter: Union[Filter, str] = None,
  122. rerank: Optional[Rerank] = None,
  123. retrieve_vector: Optional[bool] = None,
  124. output_fields: Optional[list[str]] = None,
  125. limit: Optional[int] = None,
  126. timeout: Optional[float] = None,
  127. return_pd_object=False,
  128. **kwargs,
  129. ) -> list[list[dict]]:
  130. return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]
  131. def collection_query(
  132. self,
  133. database_name: str,
  134. collection_name: str,
  135. document_ids: Optional[list] = None,
  136. retrieve_vector: bool = False,
  137. limit: Optional[int] = None,
  138. offset: Optional[int] = None,
  139. filter: Optional[Filter] = None,
  140. output_fields: Optional[list[str]] = None,
  141. timeout: Optional[float] = None,
  142. ) -> list[dict]:
  143. return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]
  144. def collection_delete(
  145. self,
  146. database_name: str,
  147. collection_name: str,
  148. document_ids: Optional[list[str]] = None,
  149. filter: Filter = None,
  150. timeout: Optional[float] = None,
  151. ):
  152. return {"code": 0, "msg": "operation success"}
  153. def drop_collection(self, database_name: str, collection_name: str, timeout: Optional[float] = None) -> dict:
  154. return {"code": 0, "msg": "operation success"}
  155. MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
  156. @pytest.fixture
  157. def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
  158. if MOCK:
  159. monkeypatch.setattr(RPCVectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client)
  160. monkeypatch.setattr(
  161. RPCVectorDBClient, "create_database_if_not_exists", MockTcvectordbClass.create_database_if_not_exists
  162. )
  163. monkeypatch.setattr(RPCVectorDBClient, "exists_collection", MockTcvectordbClass.exists_collection)
  164. monkeypatch.setattr(RPCVectorDBClient, "create_collection", MockTcvectordbClass.create_collection)
  165. monkeypatch.setattr(RPCVectorDBClient, "describe_collection", MockTcvectordbClass.describe_collection)
  166. monkeypatch.setattr(RPCVectorDBClient, "upsert", MockTcvectordbClass.collection_upsert)
  167. monkeypatch.setattr(RPCVectorDBClient, "search", MockTcvectordbClass.collection_search)
  168. monkeypatch.setattr(RPCVectorDBClient, "hybrid_search", MockTcvectordbClass.collection_hybrid_search)
  169. monkeypatch.setattr(RPCVectorDBClient, "query", MockTcvectordbClass.collection_query)
  170. monkeypatch.setattr(RPCVectorDBClient, "delete", MockTcvectordbClass.collection_delete)
  171. monkeypatch.setattr(RPCVectorDBClient, "drop_collection", MockTcvectordbClass.drop_collection)
  172. yield
  173. if MOCK:
  174. monkeypatch.undo()