|
|
|
@@ -118,10 +118,21 @@ class TableStoreVector(BaseVector): |
|
|
|
|
|
|
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
    """Run a vector search, optionally restricted to a set of documents.

    Keyword arguments read from ``kwargs``:
        top_k: number of hits to request; defaults to 4.
        document_ids_filter: iterable of document ids. When non-empty, each id
            is converted to a ``document_id=<id>`` term and the resulting list
            is passed on as the filter; otherwise ``None`` is passed.
        score_threshold: forwarded as a float; a missing or falsy value
            becomes 0.0.

    Returns:
        Whatever ``self._search_by_vector`` returns for the prepared arguments.
    """
    top_k = kwargs.get("top_k", 4)
    document_ids_filter = kwargs.get("document_ids_filter")

    # Build the filter terms before delegating; an empty or missing filter
    # means "no restriction" and is passed through as None.
    filtered_list = None
    if document_ids_filter:
        filtered_list = [f"document_id={item}" for item in document_ids_filter]

    # `or 0.0` maps both an absent key and an explicit None to 0.0.
    score_threshold = float(kwargs.get("score_threshold") or 0.0)

    # Single exit: the helper takes (vector, filter, top_k, score_threshold).
    return self._search_by_vector(query_vector, filtered_list, top_k, score_threshold)
|
|
|
|
|
|
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
    """Run a full-text search, optionally restricted to a set of documents.

    Keyword arguments read from ``kwargs``:
        top_k: number of hits to request; defaults to 4.
        document_ids_filter: iterable of document ids. When non-empty, each id
            is converted to a ``document_id=<id>`` term and the resulting list
            is passed on as the filter; otherwise ``None`` is passed.

    Returns:
        Whatever ``self._search_by_full_text`` returns for the prepared
        arguments.
    """
    top_k = kwargs.get("top_k", 4)
    document_ids_filter = kwargs.get("document_ids_filter")

    # An empty or missing filter means "no restriction" and is passed as None.
    filtered_list = None
    if document_ids_filter:
        filtered_list = [f"document_id={item}" for item in document_ids_filter]

    # Single exit: the helper takes (query, filter, top_k).
    return self._search_by_full_text(query, filtered_list, top_k)
|
|
|
|
|
|
|
def delete(self) -> None: |
|
|
|
self._delete_table_if_exist() |
|
|
|
@@ -230,32 +241,51 @@ class TableStoreVector(BaseVector): |
|
|
|
primary_key = [("id", id)] |
|
|
|
row = tablestore.Row(primary_key) |
|
|
|
self._tablestore_client.delete_row(self._table_name, row, None) |
|
|
|
logging.info("Tablestore delete row successfully. id:%s", id) |
|
|
|
|
|
|
|
def _search_by_metadata(self, key: str, value: str) -> list[str]:
    """Collect the ids of all rows whose tags field contains ``key=value``.

    Issues a term query on ``self._tags_field`` and follows the response's
    ``next_token`` until it is exhausted, so results are not capped at a
    single page.

    Args:
        key: metadata key; stringified and joined with ``value`` as
            ``"<key>=<value>"`` to form the term.
        value: metadata value; stringified as above.

    Returns:
        The ``row[0][0][1]`` element of every returned row, in response order
        (presumably the primary-key value; confirm against the table schema).
    """
    query = tablestore.SearchQuery(
        tablestore.TermQuery(self._tags_field, str(key) + "=" + str(value)),
        limit=1000,
        get_total_count=False,
    )
    # Only the primary-key column is requested; nothing else is read below.
    columns_to_get = tablestore.ColumnsToGet(
        column_names=[Field.PRIMARY_KEY.value],
        return_type=tablestore.ColumnReturnType.SPECIFIED,
    )

    rows: list[str] = []
    next_token = None
    while True:
        if next_token is not None:
            query.next_token = next_token

        search_response = self._tablestore_client.search(
            table_name=self._table_name,
            index_name=self._index_name,
            search_query=query,
            columns_to_get=columns_to_get,
        )
        if search_response is None:
            break

        rows.extend(row[0][0][1] for row in search_response.rows)

        # Stop on any empty token (b"" or None). Comparing only against b""
        # would re-issue the same query forever if the token came back as None.
        next_token = search_response.next_token
        if not next_token:
            break

    return rows
|
|
|
|
|
|
|
def _search_by_vector( |
|
|
|
self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float |
|
|
|
) -> list[Document]: |
|
|
|
knn_vector_query = tablestore.KnnVectorQuery( |
|
|
|
field_name=Field.VECTOR.value, |
|
|
|
top_k=top_k, |
|
|
|
float32_query_vector=query_vector, |
|
|
|
) |
|
|
|
if document_ids_filter: |
|
|
|
knn_vector_query.filter = tablestore.TermsQuery(self._tags_field, document_ids_filter) |
|
|
|
|
|
|
|
sort = tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)]) |
|
|
|
search_query = tablestore.SearchQuery(ots_query, limit=top_k, get_total_count=False, sort=sort) |
|
|
|
search_query = tablestore.SearchQuery(knn_vector_query, limit=top_k, get_total_count=False, sort=sort) |
|
|
|
|
|
|
|
search_response = self._tablestore_client.search( |
|
|
|
table_name=self._table_name, |
|
|
|
@@ -263,30 +293,32 @@ class TableStoreVector(BaseVector): |
|
|
|
search_query=search_query, |
|
|
|
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX), |
|
|
|
) |
|
|
|
logging.info( |
|
|
|
"Tablestore search successfully. request_id:%s", |
|
|
|
search_response.request_id, |
|
|
|
) |
|
|
|
return self._to_query_result(search_response) |
|
|
|
|
|
|
|
def _to_query_result(self, search_response: tablestore.SearchResponse) -> list[Document]: |
|
|
|
documents = [] |
|
|
|
for row in search_response.rows: |
|
|
|
documents.append( |
|
|
|
Document( |
|
|
|
page_content=row[1][2][1], |
|
|
|
vector=json.loads(row[1][3][1]), |
|
|
|
metadata=json.loads(row[1][0][1]), |
|
|
|
for search_hit in search_response.search_hits: |
|
|
|
if search_hit.score > score_threshold: |
|
|
|
metadata = json.loads(search_hit.row[1][0][1]) |
|
|
|
metadata["score"] = search_hit.score |
|
|
|
documents.append( |
|
|
|
Document( |
|
|
|
page_content=search_hit.row[1][2][1], |
|
|
|
vector=json.loads(search_hit.row[1][3][1]), |
|
|
|
metadata=metadata, |
|
|
|
) |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) |
|
|
|
return documents |
|
|
|
|
|
|
|
def _search_by_full_text(self, query: str) -> list[Document]: |
|
|
|
def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]: |
|
|
|
bool_query = tablestore.BoolQuery() |
|
|
|
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value)) |
|
|
|
|
|
|
|
if document_ids_filter: |
|
|
|
bool_query.filter_queries.append(tablestore.TermsQuery(self._tags_field, document_ids_filter)) |
|
|
|
|
|
|
|
search_query = tablestore.SearchQuery( |
|
|
|
query=tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value), |
|
|
|
query=bool_query, |
|
|
|
sort=tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)]), |
|
|
|
limit=100, |
|
|
|
limit=top_k, |
|
|
|
) |
|
|
|
search_response = self._tablestore_client.search( |
|
|
|
table_name=self._table_name, |
|
|
|
@@ -295,7 +327,16 @@ class TableStoreVector(BaseVector): |
|
|
|
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX), |
|
|
|
) |
|
|
|
|
|
|
|
return self._to_query_result(search_response) |
|
|
|
documents = [] |
|
|
|
for search_hit in search_response.search_hits: |
|
|
|
documents.append( |
|
|
|
Document( |
|
|
|
page_content=search_hit.row[1][2][1], |
|
|
|
vector=json.loads(search_hit.row[1][3][1]), |
|
|
|
metadata=json.loads(search_hit.row[1][0][1]), |
|
|
|
) |
|
|
|
) |
|
|
|
return documents |
|
|
|
|
|
|
|
|
|
|
|
class TableStoreVectorFactory(AbstractVectorFactory): |