選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

doc_store_conn.py 6.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. from abc import ABC, abstractmethod
  2. from typing import Optional, Union
  3. from dataclasses import dataclass
  4. import numpy as np
  5. import polars as pl
  6. from typing import List, Dict
  # Type alias for dense vector / tensor payloads: a plain list or a NumPy array.
  VEC = Union[list, np.ndarray]

  # Default ``topn`` for dense-vector match expressions.
  DEFAULT_MATCH_VECTOR_TOPN = 10

  # Default ``topn`` for sparse-vector matches.
  # NOTE(review): not referenced in the visible part of this file; presumably
  # used by callers elsewhere.
  DEFAULT_MATCH_SPARSE_TOPN = 10
  @dataclass
  class SparseVector:
      """Sparse vector stored as parallel ``indices`` / ``values`` lists.

      ``values`` is optional; when it is given it must have exactly one entry
      per index.

      Raises:
          AssertionError: if ``values`` is given and its length differs from
              the length of ``indices``.
      """

      indices: list[int]
      values: Union[list[float], list[int], None] = None

      def __post_init__(self):
          # Raised explicitly rather than via ``assert`` so the check is not
          # stripped under ``python -O``. AssertionError is kept so callers
          # observe the same exception type as before.
          if self.values is not None and len(self.indices) != len(self.values):
              raise AssertionError(
                  "SparseVector.indices and SparseVector.values must have the same length"
              )

      def to_dict_old(self):
          """Return ``{"indices": [...]}``, plus ``"values"`` when values are set."""
          d = {"indices": self.indices}
          if self.values is not None:
              d["values"] = self.values
          return d

      def to_dict(self):
          """Return a ``{str(index): value}`` mapping.

          Raises:
              ValueError: if ``values`` is None.
          """
          if self.values is None:
              raise ValueError("SparseVector.values is None")
          return {str(i): v for i, v in zip(self.indices, self.values)}

      @staticmethod
      def from_dict(d):
          """Build a SparseVector from a dict with ``"indices"`` and optional ``"values"``."""
          return SparseVector(d["indices"], d.get("values"))

      def __str__(self):
          # The values part is omitted entirely when no values are stored.
          values_part = "" if self.values is None else f", values={self.values}"
          return f"SparseVector(indices={self.indices}{values_part})"

      def __repr__(self):
          return str(self)
  class MatchTextExpr(ABC):
      """Container for the arguments of a text match expression.

      Args:
          fields: Field specification to match against (annotated ``str``).
          matching_text: The query text.
          topn: Number of results requested.
          extra_options: Optional dict of additional options. When omitted,
              each instance gets its own fresh empty dict.
      """

      def __init__(
          self,
          fields: str,
          matching_text: str,
          topn: int,
          extra_options: Optional[dict] = None,
      ):
          self.fields = fields
          self.matching_text = matching_text
          self.topn = topn
          # A ``dict()`` default would be created once and shared by every
          # instance, so mutating one expression's options would leak into
          # the others. A caller-supplied dict is stored as-is (not copied).
          self.extra_options = {} if extra_options is None else extra_options
  class MatchDenseExpr(ABC):
      """Container for the arguments of a dense-vector match expression.

      Args:
          vector_column_name: Name of the vector column to search.
          embedding_data: Query embedding (a list or NumPy array, see ``VEC``).
          embedding_data_type: Element type of the embedding, as a string.
          distance_type: Distance/similarity metric name, as a string.
          topn: Number of results requested; defaults to
              ``DEFAULT_MATCH_VECTOR_TOPN``.
          extra_options: Optional dict of additional options. When omitted,
              each instance gets its own fresh empty dict.
      """

      def __init__(
          self,
          vector_column_name: str,
          embedding_data: VEC,
          embedding_data_type: str,
          distance_type: str,
          topn: int = DEFAULT_MATCH_VECTOR_TOPN,
          extra_options: Optional[dict] = None,
      ):
          self.vector_column_name = vector_column_name
          self.embedding_data = embedding_data
          self.embedding_data_type = embedding_data_type
          self.distance_type = distance_type
          self.topn = topn
          # A ``dict()`` default would be created once and shared by every
          # instance; build a fresh dict per instance instead. A
          # caller-supplied dict is stored as-is (not copied).
          self.extra_options = {} if extra_options is None else extra_options
  class MatchSparseExpr(ABC):
      """Container for the arguments of a sparse-vector match expression.

      The constructor stores every argument unchanged on the instance under
      the attribute of the same name.
      """

      def __init__(
          self,
          vector_column_name: str,
          sparse_data: SparseVector | dict,
          distance_type: str,
          topn: int,
          opt_params: Optional[dict] = None,
      ):
          # What to search and with which query vector.
          self.vector_column_name, self.sparse_data = vector_column_name, sparse_data
          # How to score and how many hits to request.
          self.distance_type, self.topn = distance_type, topn
          # Optional extra parameters; stays None when not supplied.
          self.opt_params = opt_params
  class MatchTensorExpr(ABC):
      """Container for the arguments of a tensor match expression.

      The constructor stores every argument unchanged on the instance under
      the attribute of the same name.
      """

      def __init__(
          self,
          column_name: str,
          query_data: VEC,
          query_data_type: str,
          topn: int,
          extra_option: Optional[dict] = None,
      ):
          # Target column.
          self.column_name = column_name
          # Query payload and the string naming its element type.
          self.query_data, self.query_data_type = query_data, query_data_type
          # Number of hits requested.
          self.topn = topn
          # Optional extra settings; stays None when not supplied.
          self.extra_option = extra_option
  class FusionExpr(ABC):
      """Container for the arguments of a fusion expression.

      Stores the fusion ``method`` name, the requested ``topn`` and an
      optional ``fusion_params`` dict (None when not supplied).
      """

      def __init__(
          self,
          method: str,
          topn: int,
          fusion_params: Optional[dict] = None,
      ):
          self.method, self.topn = method, topn
          self.fusion_params = fusion_params
  # Any single expression that may appear in a search's list of match expressions.
  MatchExpr = Union[MatchTextExpr, MatchDenseExpr, MatchSparseExpr, MatchTensorExpr, FusionExpr]
  class OrderByExpr(ABC):
      """Fluent builder for an ordered list of sort keys.

      Sort keys are collected in the public ``fields`` attribute as
      ``(field_name, flag)`` tuples, where ``flag`` is ``0`` for ascending and
      ``1`` for descending. ``asc`` and ``desc`` return ``self`` so calls can
      be chained.

      Read the result through the ``fields`` attribute. There used to be a
      ``fields()`` method as well, but the instance attribute of the same name
      always shadowed it, so calling it raised
      ``TypeError: 'list' object is not callable``. The dead method has been
      removed.
      """

      def __init__(self):
          self.fields: list[tuple[str, int]] = []

      def asc(self, field: str) -> "OrderByExpr":
          """Append ``field`` as an ascending sort key and return ``self``."""
          self.fields.append((field, 0))
          return self

      def desc(self, field: str) -> "OrderByExpr":
          """Append ``field`` as a descending sort key and return ``self``."""
          self.fields.append((field, 1))
          return self
  class DocStoreConnection(ABC):
      """Abstract interface for a document-store backend.

      Every method is abstract and its body only raises
      ``NotImplementedError``; concrete subclasses must implement all of them.
      Methods are grouped into database, table (index), CRUD, search-result
      helper and SQL operations.
      """

      # ------------------------------------------------------------------
      # Database operations
      # ------------------------------------------------------------------

      @abstractmethod
      def dbType(self) -> str:
          """Return the type of the database."""
          raise NotImplementedError("Not implemented")

      @abstractmethod
      def health(self) -> dict:
          """Return the health status of the database."""
          raise NotImplementedError("Not implemented")

      # ------------------------------------------------------------------
      # Table operations
      # ------------------------------------------------------------------

      @abstractmethod
      def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
          """Create an index with the given name."""
          raise NotImplementedError("Not implemented")

      @abstractmethod
      def deleteIdx(self, indexName: str, knowledgebaseId: str):
          """Delete the index with the given name."""
          raise NotImplementedError("Not implemented")

      @abstractmethod
      def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
          """Return whether an index with the given name exists."""
          raise NotImplementedError("Not implemented")

      # ------------------------------------------------------------------
      # CRUD operations
      # ------------------------------------------------------------------

      @abstractmethod
      def search(
          self,
          selectFields: list[str],
          highlight: list[str],
          condition: dict,
          matchExprs: list[MatchExpr],
          orderBy: OrderByExpr,
          offset: int,
          limit: int,
          indexNames: str | list[str],
          knowledgebaseIds: list[str],
      ) -> list[dict] | pl.DataFrame:
          """Search with the given conjunctive equivalent filtering condition.

          Returns the matched documents, either as a list of dicts or as a
          polars DataFrame depending on the implementation.

          NOTE(review): the original docstring says "return all fields of
          matched documents" although a ``selectFields`` argument exists;
          confirm the intended projection against the implementations.
          """
          raise NotImplementedError("Not implemented")

      @abstractmethod
      def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
          """Get the single chunk with the given id (``None`` is a permitted result)."""
          raise NotImplementedError("Not implemented")

      @abstractmethod
      def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
          """Update or insert a batch of rows.

          NOTE(review): the meaning of the returned ``list[str]`` is not
          established here; confirm against the implementations.
          """
          raise NotImplementedError("Not implemented")

      @abstractmethod
      def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
          """Update rows matching the given conjunctive equivalent filtering condition."""
          raise NotImplementedError("Not implemented")

      @abstractmethod
      def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
          """Delete rows matching the given conjunctive equivalent filtering condition."""
          raise NotImplementedError("Not implemented")

      # ------------------------------------------------------------------
      # Helper functions for search result
      #
      # Each helper takes ``res``, a search result. Their exact return shapes
      # are defined by the implementations; the summaries below are inferred
      # from the method names. TODO confirm.
      # ------------------------------------------------------------------

      @abstractmethod
      def getTotal(self, res):
          """Presumably return the total number of hits in ``res``."""
          raise NotImplementedError("Not implemented")

      @abstractmethod
      def getChunkIds(self, res):
          """Presumably return the chunk ids contained in ``res``."""
          raise NotImplementedError("Not implemented")

      @abstractmethod
      def getFields(self, res, fields: List[str]) -> Dict[str, dict]:
          """Extract ``fields`` from ``res`` as a dict of dicts keyed by string."""
          raise NotImplementedError("Not implemented")

      @abstractmethod
      def getHighlight(self, res, keywords: List[str], fieldnm: str):
          """Presumably return highlight data for ``keywords`` in field ``fieldnm``."""
          raise NotImplementedError("Not implemented")

      @abstractmethod
      def getAggregation(self, res, fieldnm: str):
          """Presumably return aggregation data for field ``fieldnm``."""
          raise NotImplementedError("Not implemented")

      # ------------------------------------------------------------------
      # SQL
      # ------------------------------------------------------------------

      @abstractmethod
      def sql(self, sql: str, fetch_size: int, format: str):
          """Run the SQL generated by text-to-sql.

          The abstract declaration previously omitted ``self``, which made the
          documented signature wrong for an instance method. The parameter
          names are unchanged.
          """
          raise NotImplementedError("Not implemented")