|
|
|
@@ -1,5 +1,6 @@ |
|
|
|
import json |
|
|
|
import logging |
|
|
|
import math |
|
|
|
from typing import Any, Optional |
|
|
|
|
|
|
|
import tablestore # type: ignore |
|
|
|
@@ -22,6 +23,7 @@ class TableStoreConfig(BaseModel): |
|
|
|
access_key_secret: Optional[str] = None |
|
|
|
instance_name: Optional[str] = None |
|
|
|
endpoint: Optional[str] = None |
|
|
|
normalize_full_text_bm25_score: Optional[bool] = False |
|
|
|
|
|
|
|
@model_validator(mode="before") |
|
|
|
@classmethod |
|
|
|
@@ -47,6 +49,7 @@ class TableStoreVector(BaseVector): |
|
|
|
config.access_key_secret, |
|
|
|
config.instance_name, |
|
|
|
) |
|
|
|
self._normalize_full_text_bm25_score = config.normalize_full_text_bm25_score |
|
|
|
self._table_name = f"{collection_name}" |
|
|
|
self._index_name = f"{collection_name}_idx" |
|
|
|
self._tags_field = f"{Field.METADATA_KEY.value}_tags" |
|
|
|
@@ -131,8 +134,8 @@ class TableStoreVector(BaseVector): |
|
|
|
filtered_list = None |
|
|
|
if document_ids_filter: |
|
|
|
filtered_list = ["document_id=" + item for item in document_ids_filter] |
|
|
|
|
|
|
|
return self._search_by_full_text(query, filtered_list, top_k) |
|
|
|
score_threshold = float(kwargs.get("score_threshold") or 0.0) |
|
|
|
return self._search_by_full_text(query, filtered_list, top_k, score_threshold) |
|
|
|
|
|
|
|
def delete(self) -> None: |
|
|
|
self._delete_table_if_exist() |
|
|
|
@@ -318,7 +321,19 @@ class TableStoreVector(BaseVector): |
|
|
|
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) |
|
|
|
return documents |
|
|
|
|
|
|
|
def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]: |
|
|
|
@staticmethod |
|
|
|
def _normalize_score_exp_decay(score: float, k: float = 0.15) -> float: |
|
|
|
""" |
|
|
|
Args: |
|
|
|
score: BM25 search score. |
|
|
|
k: decay factor, the larger the k, the steeper the low score end |
|
|
|
""" |
|
|
|
normalized_score = 1 - math.exp(-k * score) |
|
|
|
return max(0.0, min(1.0, normalized_score)) |
|
|
|
|
|
|
|
def _search_by_full_text( |
|
|
|
self, query: str, document_ids_filter: list[str] | None, top_k: int, score_threshold: float |
|
|
|
) -> list[Document]: |
|
|
|
bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[]) |
|
|
|
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value)) |
|
|
|
|
|
|
|
@@ -339,15 +354,27 @@ class TableStoreVector(BaseVector): |
|
|
|
|
|
|
|
documents = [] |
|
|
|
for search_hit in search_response.search_hits: |
|
|
|
score = None |
|
|
|
if self._normalize_full_text_bm25_score: |
|
|
|
score = self._normalize_score_exp_decay(search_hit.score) |
|
|
|
|
|
|
|
# Skip hits whose normalized score is at or below the threshold (only when normalization is enabled) |
|
|
|
if score and score <= score_threshold: |
|
|
|
continue |
|
|
|
|
|
|
|
ots_column_map = {} |
|
|
|
for col in search_hit.row[1]: |
|
|
|
ots_column_map[col[0]] = col[1] |
|
|
|
|
|
|
|
vector_str = ots_column_map.get(Field.VECTOR.value) |
|
|
|
metadata_str = ots_column_map.get(Field.METADATA_KEY.value) |
|
|
|
vector = json.loads(vector_str) if vector_str else None |
|
|
|
metadata = json.loads(metadata_str) if metadata_str else {} |
|
|
|
|
|
|
|
vector_str = ots_column_map.get(Field.VECTOR.value) |
|
|
|
vector = json.loads(vector_str) if vector_str else None |
|
|
|
|
|
|
|
if score: |
|
|
|
metadata["score"] = score |
|
|
|
|
|
|
|
documents.append( |
|
|
|
Document( |
|
|
|
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "", |
|
|
|
@@ -355,6 +382,8 @@ class TableStoreVector(BaseVector): |
|
|
|
metadata=metadata, |
|
|
|
) |
|
|
|
) |
|
|
|
if self._normalize_full_text_bm25_score: |
|
|
|
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) |
|
|
|
return documents |
|
|
|
|
|
|
|
|
|
|
|
@@ -375,5 +404,6 @@ class TableStoreVectorFactory(AbstractVectorFactory): |
|
|
|
instance_name=dify_config.TABLESTORE_INSTANCE_NAME, |
|
|
|
access_key_id=dify_config.TABLESTORE_ACCESS_KEY_ID, |
|
|
|
access_key_secret=dify_config.TABLESTORE_ACCESS_KEY_SECRET, |
|
|
|
normalize_full_text_bm25_score=dify_config.TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE, |
|
|
|
), |
|
|
|
) |