Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

doc_store_conn.py 7.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. #
  2. # Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from abc import ABC, abstractmethod
  17. from dataclasses import dataclass
  18. import numpy as np
  19. DEFAULT_MATCH_VECTOR_TOPN = 10
  20. DEFAULT_MATCH_SPARSE_TOPN = 10
  21. VEC = list | np.ndarray
  22. @dataclass
  23. class SparseVector:
  24. indices: list[int]
  25. values: list[float] | list[int] | None = None
  26. def __post_init__(self):
  27. assert (self.values is None) or (len(self.indices) == len(self.values))
  28. def to_dict_old(self):
  29. d = {"indices": self.indices}
  30. if self.values is not None:
  31. d["values"] = self.values
  32. return d
  33. def to_dict(self):
  34. if self.values is None:
  35. raise ValueError("SparseVector.values is None")
  36. result = {}
  37. for i, v in zip(self.indices, self.values):
  38. result[str(i)] = v
  39. return result
  40. @staticmethod
  41. def from_dict(d):
  42. return SparseVector(d["indices"], d.get("values"))
  43. def __str__(self):
  44. return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})"
  45. def __repr__(self):
  46. return str(self)
  47. class MatchTextExpr(ABC):
  48. def __init__(
  49. self,
  50. fields: list[str],
  51. matching_text: str,
  52. topn: int,
  53. extra_options: dict = dict(),
  54. ):
  55. self.fields = fields
  56. self.matching_text = matching_text
  57. self.topn = topn
  58. self.extra_options = extra_options
  59. class MatchDenseExpr(ABC):
  60. def __init__(
  61. self,
  62. vector_column_name: str,
  63. embedding_data: VEC,
  64. embedding_data_type: str,
  65. distance_type: str,
  66. topn: int = DEFAULT_MATCH_VECTOR_TOPN,
  67. extra_options: dict = dict(),
  68. ):
  69. self.vector_column_name = vector_column_name
  70. self.embedding_data = embedding_data
  71. self.embedding_data_type = embedding_data_type
  72. self.distance_type = distance_type
  73. self.topn = topn
  74. self.extra_options = extra_options
  75. class MatchSparseExpr(ABC):
  76. def __init__(
  77. self,
  78. vector_column_name: str,
  79. sparse_data: SparseVector | dict,
  80. distance_type: str,
  81. topn: int,
  82. opt_params: dict | None = None,
  83. ):
  84. self.vector_column_name = vector_column_name
  85. self.sparse_data = sparse_data
  86. self.distance_type = distance_type
  87. self.topn = topn
  88. self.opt_params = opt_params
  89. class MatchTensorExpr(ABC):
  90. def __init__(
  91. self,
  92. column_name: str,
  93. query_data: VEC,
  94. query_data_type: str,
  95. topn: int,
  96. extra_option: dict | None = None,
  97. ):
  98. self.column_name = column_name
  99. self.query_data = query_data
  100. self.query_data_type = query_data_type
  101. self.topn = topn
  102. self.extra_option = extra_option
  103. class FusionExpr(ABC):
  104. def __init__(self, method: str, topn: int, fusion_params: dict | None = None):
  105. self.method = method
  106. self.topn = topn
  107. self.fusion_params = fusion_params
  108. MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr
  109. class OrderByExpr(ABC):
  110. def __init__(self):
  111. self.fields = list()
  112. def asc(self, field: str):
  113. self.fields.append((field, 0))
  114. return self
  115. def desc(self, field: str):
  116. self.fields.append((field, 1))
  117. return self
  118. def fields(self):
  119. return self.fields
  120. class DocStoreConnection(ABC):
  121. """
  122. Database operations
  123. """
  124. @abstractmethod
  125. def dbType(self) -> str:
  126. """
  127. Return the type of the database.
  128. """
  129. raise NotImplementedError("Not implemented")
  130. @abstractmethod
  131. def health(self) -> dict:
  132. """
  133. Return the health status of the database.
  134. """
  135. raise NotImplementedError("Not implemented")
  136. """
  137. Table operations
  138. """
  139. @abstractmethod
  140. def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
  141. """
  142. Create an index with given name
  143. """
  144. raise NotImplementedError("Not implemented")
  145. @abstractmethod
  146. def deleteIdx(self, indexName: str, knowledgebaseId: str):
  147. """
  148. Delete an index with given name
  149. """
  150. raise NotImplementedError("Not implemented")
  151. @abstractmethod
  152. def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
  153. """
  154. Check if an index with given name exists
  155. """
  156. raise NotImplementedError("Not implemented")
  157. """
  158. CRUD operations
  159. """
  160. @abstractmethod
  161. def search(
  162. self, selectFields: list[str],
  163. highlightFields: list[str],
  164. condition: dict,
  165. matchExprs: list[MatchExpr],
  166. orderBy: OrderByExpr,
  167. offset: int,
  168. limit: int,
  169. indexNames: str|list[str],
  170. knowledgebaseIds: list[str],
  171. aggFields: list[str] = [],
  172. rank_feature: dict | None = None
  173. ):
  174. """
  175. Search with given conjunctive equivalent filtering condition and return all fields of matched documents
  176. """
  177. raise NotImplementedError("Not implemented")
  178. @abstractmethod
  179. def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
  180. """
  181. Get single chunk with given id
  182. """
  183. raise NotImplementedError("Not implemented")
  184. @abstractmethod
  185. def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
  186. """
  187. Update or insert a bulk of rows
  188. """
  189. raise NotImplementedError("Not implemented")
  190. @abstractmethod
  191. def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
  192. """
  193. Update rows with given conjunctive equivalent filtering condition
  194. """
  195. raise NotImplementedError("Not implemented")
  196. @abstractmethod
  197. def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
  198. """
  199. Delete rows with given conjunctive equivalent filtering condition
  200. """
  201. raise NotImplementedError("Not implemented")
  202. """
  203. Helper functions for search result
  204. """
  205. @abstractmethod
  206. def getTotal(self, res):
  207. raise NotImplementedError("Not implemented")
  208. @abstractmethod
  209. def getChunkIds(self, res):
  210. raise NotImplementedError("Not implemented")
  211. @abstractmethod
  212. def getFields(self, res, fields: list[str]) -> dict[str, dict]:
  213. raise NotImplementedError("Not implemented")
  214. @abstractmethod
  215. def getHighlight(self, res, keywords: list[str], fieldnm: str):
  216. raise NotImplementedError("Not implemented")
  217. @abstractmethod
  218. def getAggregation(self, res, fieldnm: str):
  219. raise NotImplementedError("Not implemented")
  220. """
  221. SQL
  222. """
  223. @abstractmethod
  224. def sql(sql: str, fetch_size: int, format: str):
  225. """
  226. Run the sql generated by text-to-sql
  227. """
  228. raise NotImplementedError("Not implemented")