You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

tcvectordb.py 5.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. import os
  2. from typing import Optional, Union
  3. import pytest
  4. from _pytest.monkeypatch import MonkeyPatch
  5. from requests.adapters import HTTPAdapter
  6. from tcvectordb import RPCVectorDBClient # type: ignore
  7. from tcvectordb.model.collection import FilterIndexConfig
  8. from tcvectordb.model.document import Document, Filter # type: ignore
  9. from tcvectordb.model.enum import ReadConsistency # type: ignore
  10. from tcvectordb.model.index import Index, IndexField # type: ignore
  11. from tcvectordb.rpc.model.collection import RPCCollection
  12. from tcvectordb.rpc.model.database import RPCDatabase
  13. from xinference_client.types import Embedding # type: ignore
  14. class MockTcvectordbClass:
  15. def mock_vector_db_client(
  16. self,
  17. url: str,
  18. username="",
  19. key="",
  20. read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
  21. timeout=10,
  22. adapter: HTTPAdapter = None,
  23. pool_size: int = 2,
  24. proxies: Optional[dict] = None,
  25. password: Optional[str] = None,
  26. **kwargs,
  27. ):
  28. self._conn = None
  29. self._read_consistency = read_consistency
  30. def create_database_if_not_exists(self, database_name: str, timeout: Optional[float] = None) -> RPCDatabase:
  31. return RPCDatabase(
  32. name="dify",
  33. read_consistency=self._read_consistency,
  34. )
  35. def exists_collection(self, database_name: str, collection_name: str) -> bool:
  36. return True
  37. def create_collection(
  38. self,
  39. database_name: str,
  40. collection_name: str,
  41. shard: int,
  42. replicas: int,
  43. description: Optional[str] = None,
  44. index: Index = None,
  45. embedding: Embedding = None,
  46. timeout: Optional[float] = None,
  47. ttl_config: Optional[dict] = None,
  48. filter_index_config: FilterIndexConfig = None,
  49. indexes: Optional[list[IndexField]] = None,
  50. ) -> RPCCollection:
  51. return RPCCollection(
  52. RPCDatabase(
  53. name="dify",
  54. read_consistency=self._read_consistency,
  55. ),
  56. collection_name,
  57. shard,
  58. replicas,
  59. description,
  60. index,
  61. embedding=embedding,
  62. read_consistency=self._read_consistency,
  63. timeout=timeout,
  64. ttl_config=ttl_config,
  65. filter_index_config=filter_index_config,
  66. indexes=indexes,
  67. )
  68. def collection_upsert(
  69. self,
  70. database_name: str,
  71. collection_name: str,
  72. documents: list[Union[Document, dict]],
  73. timeout: Optional[float] = None,
  74. build_index: bool = True,
  75. **kwargs,
  76. ):
  77. return {"code": 0, "msg": "operation success"}
  78. def collection_search(
  79. self,
  80. database_name: str,
  81. collection_name: str,
  82. vectors: list[list[float]],
  83. filter: Filter = None,
  84. params=None,
  85. retrieve_vector: bool = False,
  86. limit: int = 10,
  87. output_fields: Optional[list[str]] = None,
  88. timeout: Optional[float] = None,
  89. ) -> list[list[dict]]:
  90. return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]
  91. def collection_query(
  92. self,
  93. database_name: str,
  94. collection_name: str,
  95. document_ids: Optional[list] = None,
  96. retrieve_vector: bool = False,
  97. limit: Optional[int] = None,
  98. offset: Optional[int] = None,
  99. filter: Optional[Filter] = None,
  100. output_fields: Optional[list[str]] = None,
  101. timeout: Optional[float] = None,
  102. ) -> list[dict]:
  103. return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]
  104. def collection_delete(
  105. self,
  106. database_name: str,
  107. collection_name: str,
  108. document_ids: Optional[list[str]] = None,
  109. filter: Filter = None,
  110. timeout: Optional[float] = None,
  111. ):
  112. return {"code": 0, "msg": "operation success"}
  113. def drop_collection(self, database_name: str, collection_name: str, timeout: Optional[float] = None) -> dict:
  114. return {"code": 0, "msg": "operation success"}
  115. MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
  116. @pytest.fixture
  117. def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
  118. if MOCK:
  119. monkeypatch.setattr(RPCVectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client)
  120. monkeypatch.setattr(
  121. RPCVectorDBClient, "create_database_if_not_exists", MockTcvectordbClass.create_database_if_not_exists
  122. )
  123. monkeypatch.setattr(RPCVectorDBClient, "exists_collection", MockTcvectordbClass.exists_collection)
  124. monkeypatch.setattr(RPCVectorDBClient, "create_collection", MockTcvectordbClass.create_collection)
  125. monkeypatch.setattr(RPCVectorDBClient, "upsert", MockTcvectordbClass.collection_upsert)
  126. monkeypatch.setattr(RPCVectorDBClient, "search", MockTcvectordbClass.collection_search)
  127. monkeypatch.setattr(RPCVectorDBClient, "query", MockTcvectordbClass.collection_query)
  128. monkeypatch.setattr(RPCVectorDBClient, "delete", MockTcvectordbClass.collection_delete)
  129. monkeypatch.setattr(RPCVectorDBClient, "drop_collection", MockTcvectordbClass.drop_collection)
  130. yield
  131. if MOCK:
  132. monkeypatch.undo()