- from abc import ABC, abstractmethod
- from dataclasses import dataclass
- import numpy as np
- import polars as pl
-
# Default `topn` for dense-vector matching; used as the default value of
# MatchDenseExpr's `topn` parameter.
DEFAULT_MATCH_VECTOR_TOPN = 10
# Default `topn` for sparse-vector matching.
# NOTE(review): not referenced by MatchSparseExpr in this file (its `topn` has
# no default) -- confirm it is consumed by callers elsewhere.
DEFAULT_MATCH_SPARSE_TOPN = 10
# Type alias used to annotate dense vector / tensor query payloads.
VEC = list | np.ndarray
-
-
- @dataclass
- class SparseVector:
- indices: list[int]
- values: list[float] | list[int] | None = None
-
- def __post_init__(self):
- assert (self.values is None) or (len(self.indices) == len(self.values))
-
- def to_dict_old(self):
- d = {"indices": self.indices}
- if self.values is not None:
- d["values"] = self.values
- return d
-
- def to_dict(self):
- if self.values is None:
- raise ValueError("SparseVector.values is None")
- result = {}
- for i, v in zip(self.indices, self.values):
- result[str(i)] = v
- return result
-
- @staticmethod
- def from_dict(d):
- return SparseVector(d["indices"], d.get("values"))
-
- def __str__(self):
- return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})"
-
- def __repr__(self):
- return str(self)
-
-
- class MatchTextExpr(ABC):
- def __init__(
- self,
- fields: list[str],
- matching_text: str,
- topn: int,
- extra_options: dict = dict(),
- ):
- self.fields = fields
- self.matching_text = matching_text
- self.topn = topn
- self.extra_options = extra_options
-
-
class MatchDenseExpr(ABC):
    """Parameter holder for a dense-vector match expression.

    The constructor only stores its arguments; no validation is performed.

    Args:
        vector_column_name: name of the vector column to match against.
        embedding_data: the query vector (a ``list`` or ``numpy.ndarray``).
        embedding_data_type: stored as given (a string; its accepted values
            are defined by the consuming backend -- not visible here).
        distance_type: stored as given (a string; accepted values are
            backend-defined -- not visible here).
        topn: stored as given; defaults to ``DEFAULT_MATCH_VECTOR_TOPN``.
        extra_options: optional dict of additional options. When omitted or
            ``None``, each instance gets its own new empty dict.
    """

    def __init__(
        self,
        vector_column_name: str,
        embedding_data: VEC,
        embedding_data_type: str,
        distance_type: str,
        topn: int = DEFAULT_MATCH_VECTOR_TOPN,
        extra_options: dict | None = None,
    ):
        self.vector_column_name = vector_column_name
        self.embedding_data = embedding_data
        self.embedding_data_type = embedding_data_type
        self.distance_type = distance_type
        self.topn = topn
        # A `dict()` default in the signature would be created once and shared
        # by every instance; build a fresh dict per instance instead.
        self.extra_options = {} if extra_options is None else extra_options
-
-
class MatchSparseExpr(ABC):
    """Parameter holder for a sparse-vector match expression.

    Every constructor argument is stored verbatim on the instance under the
    same name; nothing is validated, copied or converted.

    Args:
        vector_column_name: name of the vector column to match against.
        sparse_data: the query payload, either a ``SparseVector`` or a dict.
        distance_type: stored as given (backend-defined string).
        topn: stored as given.
        opt_params: optional dict of extra parameters; ``None`` when omitted.
    """

    def __init__(
        self,
        vector_column_name: str,
        sparse_data: SparseVector | dict,
        distance_type: str,
        topn: int,
        opt_params: dict | None = None,
    ):
        # Paired assignments, kept in the same attribute order as the
        # parameter list.
        self.vector_column_name, self.sparse_data = vector_column_name, sparse_data
        self.distance_type, self.topn = distance_type, topn
        self.opt_params = opt_params
-
-
class MatchTensorExpr(ABC):
    """Parameter holder for a tensor match expression.

    Every constructor argument is stored verbatim on the instance under the
    same name; nothing is validated, copied or converted.

    Args:
        column_name: name of the column to match against.
        query_data: the query payload (a ``list`` or ``numpy.ndarray``).
        query_data_type: stored as given (backend-defined string).
        topn: stored as given.
        extra_option: optional dict of extra options; ``None`` when omitted.
    """

    def __init__(
        self,
        column_name: str,
        query_data: VEC,
        query_data_type: str,
        topn: int,
        extra_option: dict | None = None,
    ):
        # Attribute names mirror the parameter names one-to-one; assign them
        # in declaration order.
        for attr_name, attr_value in (
            ("column_name", column_name),
            ("query_data", query_data),
            ("query_data_type", query_data_type),
            ("topn", topn),
            ("extra_option", extra_option),
        ):
            setattr(self, attr_name, attr_value)
-
-
- class FusionExpr(ABC):
- def __init__(self, method: str, topn: int, fusion_params: dict | None = None):
- self.method = method
- self.topn = topn
- self.fusion_params = fusion_params
-
-
# Union of every match / fusion expression class in this module; used to
# annotate the `matchExprs` parameter of DocStoreConnection.search.
MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr
-
class OrderByExpr(ABC):
    """Fluent builder for an ordered list of sort keys.

    Attributes:
        fields: list of ``(field_name, direction)`` tuples in the order they
            were added, where ``direction`` is ``0`` for :meth:`asc` and
            ``1`` for :meth:`desc`. Read it as a plain attribute
            (``expr.fields``).
    """

    # NOTE: this class used to declare a ``fields(self)`` accessor as well.
    # The instance attribute assigned in ``__init__`` has the same name and
    # always shadowed it, so ``expr.fields`` was already the list and
    # ``expr.fields()`` raised ``TypeError``. The dead method was removed;
    # the list attribute is the accessor.

    def __init__(self):
        self.fields: list[tuple[str, int]] = []

    def asc(self, field: str) -> "OrderByExpr":
        """Append ``(field, 0)`` to ``fields`` and return ``self`` for chaining."""
        self.fields.append((field, 0))
        return self

    def desc(self, field: str) -> "OrderByExpr":
        """Append ``(field, 1)`` to ``fields`` and return ``self`` for chaining."""
        self.fields.append((field, 1))
        return self
-
class DocStoreConnection(ABC):
    """Abstract interface of a document-store backend.

    Every method is abstract and raises ``NotImplementedError`` here; concrete
    backends must override all of them. Methods are grouped into database
    operations, table (index) operations, CRUD operations, helpers that read
    a search result, and SQL.
    """

    # ------------------------------------------------------------------
    # Database operations
    # ------------------------------------------------------------------

    @abstractmethod
    def dbType(self) -> str:
        """
        Return the type of the database.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def health(self) -> dict:
        """
        Return the health status of the database.
        """
        raise NotImplementedError("Not implemented")

    # ------------------------------------------------------------------
    # Table operations
    # ------------------------------------------------------------------

    @abstractmethod
    def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
        """
        Create an index with the given name.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def deleteIdx(self, indexName: str, knowledgebaseId: str):
        """
        Delete the index with the given name.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
        """
        Check whether an index with the given name exists.
        """
        raise NotImplementedError("Not implemented")

    # ------------------------------------------------------------------
    # CRUD operations
    # ------------------------------------------------------------------

    @abstractmethod
    def search(
        self,
        selectFields: list[str],
        highlight: list[str],
        condition: dict,
        matchExprs: list[MatchExpr],
        orderBy: OrderByExpr,
        offset: int,
        limit: int,
        indexNames: str | list[str],
        knowledgebaseIds: list[str],
    ) -> list[dict] | pl.DataFrame:
        """
        Search with the given conjunctive equality filtering condition and
        return all fields of the matched documents.

        The result is either a list of dicts or a polars DataFrame, depending
        on the backend; pass it to the ``get*`` helper methods of the same
        backend to read it.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
        """
        Get the single chunk with the given id.

        The return annotation allows ``None``; when a backend returns it
        (presumably for a missing chunk) is implementation-defined.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
        """
        Update or insert a bulk of rows.

        Returns a list of strings whose meaning is defined by the
        implementation (not specified by this interface).
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
        """
        Update the rows matching the given conjunctive equality filtering
        condition.
        """
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
        """
        Delete the rows matching the given conjunctive equality filtering
        condition.
        """
        raise NotImplementedError("Not implemented")

    # ------------------------------------------------------------------
    # Helper functions for search result
    #
    # In each helper `res` is a backend search result (see `search`); the
    # return shapes are implementation-defined unless annotated.
    # ------------------------------------------------------------------

    @abstractmethod
    def getTotal(self, res):
        """Read the total from search result `res`."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getChunkIds(self, res):
        """Read the chunk ids from search result `res`."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getFields(self, res, fields: list[str]) -> dict[str, dict]:
        """Read the requested `fields` from search result `res` as a dict of dicts."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getHighlight(self, res, keywords: list[str], fieldnm: str):
        """Read highlight data for field `fieldnm` and `keywords` from `res`."""
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getAggregation(self, res, fieldnm: str):
        """Read aggregation data for field `fieldnm` from `res`."""
        raise NotImplementedError("Not implemented")

    # ------------------------------------------------------------------
    # SQL
    # ------------------------------------------------------------------

    @abstractmethod
    def sql(self, sql: str, fetch_size: int, format: str):
        """
        Run the sql generated by text-to-sql.

        `self` was missing from the original declaration, which made the
        abstract signature disagree with how instance methods are called.
        """
        raise NotImplementedError("Not implemented")
|