Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

doc_store_conn.py 6.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. from abc import ABC, abstractmethod
  2. from dataclasses import dataclass
  3. import numpy as np
# Default ``topn`` used by MatchDenseExpr when the caller does not pass one.
DEFAULT_MATCH_VECTOR_TOPN = 10
# Default top-n for sparse matching. NOTE(review): not referenced anywhere in
# the visible part of this file — confirm it is still used by callers.
DEFAULT_MATCH_SPARSE_TOPN = 10
# Type alias for dense vector / tensor payloads: a plain list or a NumPy array.
VEC = list | np.ndarray
  7. @dataclass
  8. class SparseVector:
  9. indices: list[int]
  10. values: list[float] | list[int] | None = None
  11. def __post_init__(self):
  12. assert (self.values is None) or (len(self.indices) == len(self.values))
  13. def to_dict_old(self):
  14. d = {"indices": self.indices}
  15. if self.values is not None:
  16. d["values"] = self.values
  17. return d
  18. def to_dict(self):
  19. if self.values is None:
  20. raise ValueError("SparseVector.values is None")
  21. result = {}
  22. for i, v in zip(self.indices, self.values):
  23. result[str(i)] = v
  24. return result
  25. @staticmethod
  26. def from_dict(d):
  27. return SparseVector(d["indices"], d.get("values"))
  28. def __str__(self):
  29. return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})"
  30. def __repr__(self):
  31. return str(self)
  32. class MatchTextExpr(ABC):
  33. def __init__(
  34. self,
  35. fields: list[str],
  36. matching_text: str,
  37. topn: int,
  38. extra_options: dict = dict(),
  39. ):
  40. self.fields = fields
  41. self.matching_text = matching_text
  42. self.topn = topn
  43. self.extra_options = extra_options
  44. class MatchDenseExpr(ABC):
  45. def __init__(
  46. self,
  47. vector_column_name: str,
  48. embedding_data: VEC,
  49. embedding_data_type: str,
  50. distance_type: str,
  51. topn: int = DEFAULT_MATCH_VECTOR_TOPN,
  52. extra_options: dict = dict(),
  53. ):
  54. self.vector_column_name = vector_column_name
  55. self.embedding_data = embedding_data
  56. self.embedding_data_type = embedding_data_type
  57. self.distance_type = distance_type
  58. self.topn = topn
  59. self.extra_options = extra_options
  60. class MatchSparseExpr(ABC):
  61. def __init__(
  62. self,
  63. vector_column_name: str,
  64. sparse_data: SparseVector | dict,
  65. distance_type: str,
  66. topn: int,
  67. opt_params: dict | None = None,
  68. ):
  69. self.vector_column_name = vector_column_name
  70. self.sparse_data = sparse_data
  71. self.distance_type = distance_type
  72. self.topn = topn
  73. self.opt_params = opt_params
  74. class MatchTensorExpr(ABC):
  75. def __init__(
  76. self,
  77. column_name: str,
  78. query_data: VEC,
  79. query_data_type: str,
  80. topn: int,
  81. extra_option: dict | None = None,
  82. ):
  83. self.column_name = column_name
  84. self.query_data = query_data
  85. self.query_data_type = query_data_type
  86. self.topn = topn
  87. self.extra_option = extra_option
  88. class FusionExpr(ABC):
  89. def __init__(self, method: str, topn: int, fusion_params: dict | None = None):
  90. self.method = method
  91. self.topn = topn
  92. self.fusion_params = fusion_params
# Union of all expression classes accepted in the ``matchExprs`` list of
# DocStoreConnection.search (the four match kinds plus the fusion step).
MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr
  94. class OrderByExpr(ABC):
  95. def __init__(self):
  96. self.fields = list()
  97. def asc(self, field: str):
  98. self.fields.append((field, 0))
  99. return self
  100. def desc(self, field: str):
  101. self.fields.append((field, 1))
  102. return self
  103. def fields(self):
  104. return self.fields
  105. class DocStoreConnection(ABC):
  106. """
  107. Database operations
  108. """
  109. @abstractmethod
  110. def dbType(self) -> str:
  111. """
  112. Return the type of the database.
  113. """
  114. raise NotImplementedError("Not implemented")
  115. @abstractmethod
  116. def health(self) -> dict:
  117. """
  118. Return the health status of the database.
  119. """
  120. raise NotImplementedError("Not implemented")
  121. """
  122. Table operations
  123. """
  124. @abstractmethod
  125. def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
  126. """
  127. Create an index with given name
  128. """
  129. raise NotImplementedError("Not implemented")
  130. @abstractmethod
  131. def deleteIdx(self, indexName: str, knowledgebaseId: str):
  132. """
  133. Delete an index with given name
  134. """
  135. raise NotImplementedError("Not implemented")
  136. @abstractmethod
  137. def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
  138. """
  139. Check if an index with given name exists
  140. """
  141. raise NotImplementedError("Not implemented")
  142. """
  143. CRUD operations
  144. """
  145. @abstractmethod
  146. def search(
  147. self, selectFields: list[str],
  148. highlightFields: list[str],
  149. condition: dict,
  150. matchExprs: list[MatchExpr],
  151. orderBy: OrderByExpr,
  152. offset: int,
  153. limit: int,
  154. indexNames: str|list[str],
  155. knowledgebaseIds: list[str],
  156. aggFields: list[str] = [],
  157. rank_feature: dict | None = None
  158. ):
  159. """
  160. Search with given conjunctive equivalent filtering condition and return all fields of matched documents
  161. """
  162. raise NotImplementedError("Not implemented")
  163. @abstractmethod
  164. def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
  165. """
  166. Get single chunk with given id
  167. """
  168. raise NotImplementedError("Not implemented")
  169. @abstractmethod
  170. def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
  171. """
  172. Update or insert a bulk of rows
  173. """
  174. raise NotImplementedError("Not implemented")
  175. @abstractmethod
  176. def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
  177. """
  178. Update rows with given conjunctive equivalent filtering condition
  179. """
  180. raise NotImplementedError("Not implemented")
  181. @abstractmethod
  182. def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
  183. """
  184. Delete rows with given conjunctive equivalent filtering condition
  185. """
  186. raise NotImplementedError("Not implemented")
  187. """
  188. Helper functions for search result
  189. """
  190. @abstractmethod
  191. def getTotal(self, res):
  192. raise NotImplementedError("Not implemented")
  193. @abstractmethod
  194. def getChunkIds(self, res):
  195. raise NotImplementedError("Not implemented")
  196. @abstractmethod
  197. def getFields(self, res, fields: list[str]) -> dict[str, dict]:
  198. raise NotImplementedError("Not implemented")
  199. @abstractmethod
  200. def getHighlight(self, res, keywords: list[str], fieldnm: str):
  201. raise NotImplementedError("Not implemented")
  202. @abstractmethod
  203. def getAggregation(self, res, fieldnm: str):
  204. raise NotImplementedError("Not implemented")
  205. """
  206. SQL
  207. """
  208. @abstractmethod
  209. def sql(sql: str, fetch_size: int, format: str):
  210. """
  211. Run the sql generated by text-to-sql
  212. """
  213. raise NotImplementedError("Not implemented")