|
|
|
@@ -31,6 +31,7 @@ class OceanBaseVectorConfig(BaseModel): |
|
|
|
user: str |
|
|
|
password: str |
|
|
|
database: str |
|
|
|
enable_hybrid_search: bool = False |
|
|
|
|
|
|
|
@model_validator(mode="before") |
|
|
|
@classmethod |
|
|
|
@@ -57,6 +58,7 @@ class OceanBaseVector(BaseVector): |
|
|
|
password=self._config.password, |
|
|
|
db_name=self._config.database, |
|
|
|
) |
|
|
|
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported |
|
|
|
|
|
|
|
def get_type(self) -> str: |
|
|
|
return VectorType.OCEANBASE |
|
|
|
@@ -98,6 +100,16 @@ class OceanBaseVector(BaseVector): |
|
|
|
columns=cols, |
|
|
|
vidxs=vidx_params, |
|
|
|
) |
|
|
|
try: |
|
|
|
if self._hybrid_search_enabled: |
|
|
|
self._client.perform_raw_text_sql(f"""ALTER TABLE {self._collection_name} |
|
|
|
ADD FULLTEXT INDEX fulltext_index_for_col_text (text) WITH PARSER ik""") |
|
|
|
except Exception as e: |
|
|
|
raise Exception( |
|
|
|
"Failed to add fulltext index to the target table, your OceanBase version must be 4.3.5.1 or above " |
|
|
|
+ "to support fulltext index and vector index in the same table", |
|
|
|
e, |
|
|
|
) |
|
|
|
vals = [] |
|
|
|
params = self._client.perform_raw_text_sql("SHOW PARAMETERS LIKE '%ob_vector_memory_limit_percentage%'") |
|
|
|
for row in params: |
|
|
|
@@ -116,6 +128,27 @@ class OceanBaseVector(BaseVector): |
|
|
|
) |
|
|
|
redis_client.set(collection_exist_cache_key, 1, ex=3600) |
|
|
|
|
|
|
|
def _check_hybrid_search_support(self) -> bool: |
|
|
|
""" |
|
|
|
Check if the current OceanBase version supports hybrid search. |
|
|
|
Returns True if the version is >= 4.3.5.1, otherwise False. |
|
|
|
""" |
|
|
|
if not self._config.enable_hybrid_search: |
|
|
|
return False |
|
|
|
|
|
|
|
try: |
|
|
|
from packaging import version |
|
|
|
|
|
|
|
# return OceanBase_CE 4.3.5.1 (r101000042025031818-bxxxx) (Built Mar 18 2025 18:13:36) |
|
|
|
result = self._client.perform_raw_text_sql("SELECT @@version_comment AS version") |
|
|
|
ob_full_version = result.fetchone()[0] |
|
|
|
ob_version = ob_full_version.split()[1] |
|
|
|
logger.debug("Current OceanBase version is %s", ob_version) |
|
|
|
return version.parse(ob_version).base_version >= version.parse("4.3.5.1").base_version |
|
|
|
except Exception as e: |
|
|
|
logger.warning(f"Failed to check OceanBase version: {str(e)}. Disabling hybrid search.") |
|
|
|
return False |
|
|
|
|
|
|
|
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): |
|
|
|
ids = self._get_uuids(documents) |
|
|
|
for id, doc, emb in zip(ids, documents, embeddings): |
|
|
|
@@ -130,7 +163,7 @@ class OceanBaseVector(BaseVector): |
|
|
|
) |
|
|
|
|
|
|
|
def text_exists(self, id: str) -> bool: |
|
|
|
cur = self._client.get(table_name=self._collection_name, id=id) |
|
|
|
cur = self._client.get(table_name=self._collection_name, ids=id) |
|
|
|
return bool(cur.rowcount != 0) |
|
|
|
|
|
|
|
def delete_by_ids(self, ids: list[str]) -> None: |
|
|
|
@@ -139,9 +172,12 @@ class OceanBaseVector(BaseVector): |
|
|
|
self._client.delete(table_name=self._collection_name, ids=ids) |
|
|
|
|
|
|
|
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]: |
|
|
|
from sqlalchemy import text |
|
|
|
|
|
|
|
cur = self._client.get( |
|
|
|
table_name=self._collection_name, |
|
|
|
where_clause=f"metadata->>'$.{key}' = '{value}'", |
|
|
|
ids=None, |
|
|
|
where_clause=[text(f"metadata->>'$.{key}' = '{value}'")], |
|
|
|
output_column_name=["id"], |
|
|
|
) |
|
|
|
return [row[0] for row in cur] |
|
|
|
@@ -151,36 +187,84 @@ class OceanBaseVector(BaseVector): |
|
|
|
self.delete_by_ids(ids) |
|
|
|
|
|
|
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: |
|
|
|
return [] |
|
|
|
if not self._hybrid_search_enabled: |
|
|
|
return [] |
|
|
|
|
|
|
|
try: |
|
|
|
top_k = kwargs.get("top_k", 5) |
|
|
|
if not isinstance(top_k, int) or top_k <= 0: |
|
|
|
raise ValueError("top_k must be a positive integer") |
|
|
|
|
|
|
|
document_ids_filter = kwargs.get("document_ids_filter") |
|
|
|
where_clause = "" |
|
|
|
if document_ids_filter: |
|
|
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) |
|
|
|
where_clause = f" AND metadata->>'$.document_id' IN ({document_ids})" |
|
|
|
|
|
|
|
full_sql = f"""SELECT metadata, text, MATCH (text) AGAINST (:query) AS score |
|
|
|
FROM {self._collection_name} |
|
|
|
WHERE MATCH (text) AGAINST (:query) > 0 |
|
|
|
{where_clause} |
|
|
|
ORDER BY score DESC |
|
|
|
LIMIT {top_k}""" |
|
|
|
|
|
|
|
with self._client.engine.connect() as conn: |
|
|
|
with conn.begin(): |
|
|
|
from sqlalchemy import text |
|
|
|
|
|
|
|
result = conn.execute(text(full_sql), {"query": query}) |
|
|
|
rows = result.fetchall() |
|
|
|
|
|
|
|
docs = [] |
|
|
|
for row in rows: |
|
|
|
metadata_str, _text, score = row |
|
|
|
try: |
|
|
|
metadata = json.loads(metadata_str) |
|
|
|
except json.JSONDecodeError: |
|
|
|
print(f"Invalid JSON metadata: {metadata_str}") |
|
|
|
metadata = {} |
|
|
|
metadata["score"] = score |
|
|
|
docs.append(Document(page_content=_text, metadata=metadata)) |
|
|
|
|
|
|
|
return docs |
|
|
|
except Exception as e: |
|
|
|
logger.warning(f"Failed to fulltext search: {str(e)}.") |
|
|
|
return [] |
|
|
|
|
|
|
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: |
|
|
|
document_ids_filter = kwargs.get("document_ids_filter") |
|
|
|
where_clause = None |
|
|
|
_where_clause = None |
|
|
|
if document_ids_filter: |
|
|
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) |
|
|
|
where_clause = f"metadata->>'$.document_id' in ({document_ids})" |
|
|
|
from sqlalchemy import text |
|
|
|
|
|
|
|
_where_clause = [text(where_clause)] |
|
|
|
ef_search = kwargs.get("ef_search", self._hnsw_ef_search) |
|
|
|
if ef_search != self._hnsw_ef_search: |
|
|
|
self._client.set_ob_hnsw_ef_search(ef_search) |
|
|
|
self._hnsw_ef_search = ef_search |
|
|
|
topk = kwargs.get("top_k", 10) |
|
|
|
cur = self._client.ann_search( |
|
|
|
table_name=self._collection_name, |
|
|
|
vec_column_name="vector", |
|
|
|
vec_data=query_vector, |
|
|
|
topk=topk, |
|
|
|
distance_func=func.l2_distance, |
|
|
|
output_column_names=["text", "metadata"], |
|
|
|
with_dist=True, |
|
|
|
where_clause=where_clause, |
|
|
|
) |
|
|
|
try: |
|
|
|
cur = self._client.ann_search( |
|
|
|
table_name=self._collection_name, |
|
|
|
vec_column_name="vector", |
|
|
|
vec_data=query_vector, |
|
|
|
topk=topk, |
|
|
|
distance_func=func.l2_distance, |
|
|
|
output_column_names=["text", "metadata"], |
|
|
|
with_dist=True, |
|
|
|
where_clause=_where_clause, |
|
|
|
) |
|
|
|
except Exception as e: |
|
|
|
raise Exception("Failed to search by vector. ", e) |
|
|
|
docs = [] |
|
|
|
for text, metadata, distance in cur: |
|
|
|
for _text, metadata, distance in cur: |
|
|
|
metadata = json.loads(metadata) |
|
|
|
metadata["score"] = 1 - distance / math.sqrt(2) |
|
|
|
docs.append( |
|
|
|
Document( |
|
|
|
page_content=text, |
|
|
|
page_content=_text, |
|
|
|
metadata=metadata, |
|
|
|
) |
|
|
|
) |