You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

doc_store_conn.py 6.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. from abc import ABC, abstractmethod
  2. from dataclasses import dataclass
  3. import numpy as np
  4. import polars as pl
  5. DEFAULT_MATCH_VECTOR_TOPN = 10
  6. DEFAULT_MATCH_SPARSE_TOPN = 10
  7. VEC = list | np.ndarray
  8. @dataclass
  9. class SparseVector:
  10. indices: list[int]
  11. values: list[float] | list[int] | None = None
  12. def __post_init__(self):
  13. assert (self.values is None) or (len(self.indices) == len(self.values))
  14. def to_dict_old(self):
  15. d = {"indices": self.indices}
  16. if self.values is not None:
  17. d["values"] = self.values
  18. return d
  19. def to_dict(self):
  20. if self.values is None:
  21. raise ValueError("SparseVector.values is None")
  22. result = {}
  23. for i, v in zip(self.indices, self.values):
  24. result[str(i)] = v
  25. return result
  26. @staticmethod
  27. def from_dict(d):
  28. return SparseVector(d["indices"], d.get("values"))
  29. def __str__(self):
  30. return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})"
  31. def __repr__(self):
  32. return str(self)
  33. class MatchTextExpr(ABC):
  34. def __init__(
  35. self,
  36. fields: list[str],
  37. matching_text: str,
  38. topn: int,
  39. extra_options: dict = dict(),
  40. ):
  41. self.fields = fields
  42. self.matching_text = matching_text
  43. self.topn = topn
  44. self.extra_options = extra_options
  45. class MatchDenseExpr(ABC):
  46. def __init__(
  47. self,
  48. vector_column_name: str,
  49. embedding_data: VEC,
  50. embedding_data_type: str,
  51. distance_type: str,
  52. topn: int = DEFAULT_MATCH_VECTOR_TOPN,
  53. extra_options: dict = dict(),
  54. ):
  55. self.vector_column_name = vector_column_name
  56. self.embedding_data = embedding_data
  57. self.embedding_data_type = embedding_data_type
  58. self.distance_type = distance_type
  59. self.topn = topn
  60. self.extra_options = extra_options
  61. class MatchSparseExpr(ABC):
  62. def __init__(
  63. self,
  64. vector_column_name: str,
  65. sparse_data: SparseVector | dict,
  66. distance_type: str,
  67. topn: int,
  68. opt_params: dict | None = None,
  69. ):
  70. self.vector_column_name = vector_column_name
  71. self.sparse_data = sparse_data
  72. self.distance_type = distance_type
  73. self.topn = topn
  74. self.opt_params = opt_params
  75. class MatchTensorExpr(ABC):
  76. def __init__(
  77. self,
  78. column_name: str,
  79. query_data: VEC,
  80. query_data_type: str,
  81. topn: int,
  82. extra_option: dict | None = None,
  83. ):
  84. self.column_name = column_name
  85. self.query_data = query_data
  86. self.query_data_type = query_data_type
  87. self.topn = topn
  88. self.extra_option = extra_option
  89. class FusionExpr(ABC):
  90. def __init__(self, method: str, topn: int, fusion_params: dict | None = None):
  91. self.method = method
  92. self.topn = topn
  93. self.fusion_params = fusion_params
  94. MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr
  95. class OrderByExpr(ABC):
  96. def __init__(self):
  97. self.fields = list()
  98. def asc(self, field: str):
  99. self.fields.append((field, 0))
  100. return self
  101. def desc(self, field: str):
  102. self.fields.append((field, 1))
  103. return self
  104. def fields(self):
  105. return self.fields
  106. class DocStoreConnection(ABC):
  107. """
  108. Database operations
  109. """
  110. @abstractmethod
  111. def dbType(self) -> str:
  112. """
  113. Return the type of the database.
  114. """
  115. raise NotImplementedError("Not implemented")
  116. @abstractmethod
  117. def health(self) -> dict:
  118. """
  119. Return the health status of the database.
  120. """
  121. raise NotImplementedError("Not implemented")
  122. """
  123. Table operations
  124. """
  125. @abstractmethod
  126. def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
  127. """
  128. Create an index with given name
  129. """
  130. raise NotImplementedError("Not implemented")
  131. @abstractmethod
  132. def deleteIdx(self, indexName: str, knowledgebaseId: str):
  133. """
  134. Delete an index with given name
  135. """
  136. raise NotImplementedError("Not implemented")
  137. @abstractmethod
  138. def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
  139. """
  140. Check if an index with given name exists
  141. """
  142. raise NotImplementedError("Not implemented")
  143. """
  144. CRUD operations
  145. """
  146. @abstractmethod
  147. def search(
  148. self, selectFields: list[str], highlight: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str]
  149. ) -> list[dict] | pl.DataFrame:
  150. """
  151. Search with given conjunctive equivalent filtering condition and return all fields of matched documents
  152. """
  153. raise NotImplementedError("Not implemented")
  154. @abstractmethod
  155. def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
  156. """
  157. Get single chunk with given id
  158. """
  159. raise NotImplementedError("Not implemented")
  160. @abstractmethod
  161. def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
  162. """
  163. Update or insert a bulk of rows
  164. """
  165. raise NotImplementedError("Not implemented")
  166. @abstractmethod
  167. def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
  168. """
  169. Update rows with given conjunctive equivalent filtering condition
  170. """
  171. raise NotImplementedError("Not implemented")
  172. @abstractmethod
  173. def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
  174. """
  175. Delete rows with given conjunctive equivalent filtering condition
  176. """
  177. raise NotImplementedError("Not implemented")
  178. """
  179. Helper functions for search result
  180. """
  181. @abstractmethod
  182. def getTotal(self, res):
  183. raise NotImplementedError("Not implemented")
  184. @abstractmethod
  185. def getChunkIds(self, res):
  186. raise NotImplementedError("Not implemented")
  187. @abstractmethod
  188. def getFields(self, res, fields: list[str]) -> dict[str, dict]:
  189. raise NotImplementedError("Not implemented")
  190. @abstractmethod
  191. def getHighlight(self, res, keywords: list[str], fieldnm: str):
  192. raise NotImplementedError("Not implemented")
  193. @abstractmethod
  194. def getAggregation(self, res, fieldnm: str):
  195. raise NotImplementedError("Not implemented")
  196. """
  197. SQL
  198. """
  199. @abstractmethod
  200. def sql(sql: str, fetch_size: int, format: str):
  201. """
  202. Run the sql generated by text-to-sql
  203. """
  204. raise NotImplementedError("Not implemented")