Browse Source

fix: tablestore vdb support metadata filter (#22774)

Co-authored-by: xiaozhiqing.xzq <xiaozhiqing.xzq@alibaba-inc.com>
tags/1.7.0
wanttobeamaster 3 months ago
parent
commit
a2048fd0f4
No account linked to committer's email address

+ 74
- 33
api/core/rag/datasource/vdb/tablestore/tablestore_vector.py View File



def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
    """Run a vector search, optionally restricted to a set of documents.

    Recognised keyword arguments:
        top_k: maximum number of hits to request (default 4).
        document_ids_filter: iterable of document ids. When non-empty, each id
            is turned into a ``document_id=<id>`` tag and forwarded as a filter;
            when empty or missing, no filter is applied.
        score_threshold: forwarded as a float; missing/``None`` becomes 0.0.

    Returns whatever ``self._search_by_vector`` returns for those arguments.
    """
    top_k = kwargs.get("top_k", 4)
    document_ids_filter = kwargs.get("document_ids_filter")
    # f-string instead of "+" so ids that are not already ``str`` do not raise.
    filtered_list = (
        [f"document_id={item}" for item in document_ids_filter] if document_ids_filter else None
    )
    score_threshold = float(kwargs.get("score_threshold") or 0.0)
    return self._search_by_vector(query_vector, filtered_list, top_k, score_threshold)


def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
    """Run a full-text search, optionally restricted to a set of documents.

    Recognised keyword arguments:
        top_k: maximum number of hits to request (default 4).
        document_ids_filter: iterable of document ids. When non-empty, each id
            is turned into a ``document_id=<id>`` tag and forwarded as a filter;
            when empty or missing, no filter is applied.

    Returns whatever ``self._search_by_full_text`` returns for those arguments.
    """
    top_k = kwargs.get("top_k", 4)
    document_ids_filter = kwargs.get("document_ids_filter")
    # f-string instead of "+" so ids that are not already ``str`` do not raise.
    filtered_list = (
        [f"document_id={item}" for item in document_ids_filter] if document_ids_filter else None
    )
    return self._search_by_full_text(query, filtered_list, top_k)


def delete(self) -> None:
    """Delete this store by delegating to ``self._delete_table_if_exist()``."""
    self._delete_table_if_exist()
primary_key = [("id", id)] primary_key = [("id", id)]
row = tablestore.Row(primary_key) row = tablestore.Row(primary_key)
self._tablestore_client.delete_row(self._table_name, row, None) self._tablestore_client.delete_row(self._table_name, row, None)
logging.info("Tablestore delete row successfully. id:%s", id)


def _search_by_metadata(self, key: str, value: str) -> list[str]:
    """Return the primary-key value of every indexed row tagged ``<key>=<value>``.

    The tags field is queried with a term query and the result set is walked
    page by page (1000 rows per request) using the response's ``next_token``,
    so matches beyond the first page are included.
    """
    query = tablestore.SearchQuery(
        tablestore.TermQuery(self._tags_field, str(key) + "=" + str(value)),
        limit=1000,
        get_total_count=False,
    )
    rows: list[str] = []
    next_token = None
    while True:
        if next_token is not None:
            query.next_token = next_token

        search_response = self._tablestore_client.search(
            table_name=self._table_name,
            index_name=self._index_name,
            search_query=query,
            # Only the primary key is needed, so do not fetch other columns.
            columns_to_get=tablestore.ColumnsToGet(
                column_names=[Field.PRIMARY_KEY.value], return_type=tablestore.ColumnReturnType.SPECIFIED
            ),
        )
        if search_response is None:
            break

        # row[0] is the primary-key part of a row; [0][1] is the value of its first column.
        rows.extend(row[0][0][1] for row in search_response.rows)

        # Any empty token (b"" or None) means there are no more pages; treating
        # only b"" as the terminator would re-request the first page forever on None.
        next_token = search_response.next_token
        if not next_token:
            break

    return rows

def _search_by_vector(
    self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float
) -> list[Document]:
    """Run a KNN query against the vector field and convert the hits to documents.

    Args:
        query_vector: embedding to search with.
        document_ids_filter: tag strings to restrict the search to; ``None`` or
            empty means no restriction.
        top_k: number of neighbours requested and maximum rows returned.
        score_threshold: only hits whose score is strictly greater are kept.

    Returns:
        Documents ordered by descending score, with the score stored under
        ``metadata["score"]``.
    """
    knn_vector_query = tablestore.KnnVectorQuery(
        field_name=Field.VECTOR.value,
        top_k=top_k,
        float32_query_vector=query_vector,
    )
    if document_ids_filter:
        knn_vector_query.filter = tablestore.TermsQuery(self._tags_field, document_ids_filter)

    sort = tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)])
    search_query = tablestore.SearchQuery(knn_vector_query, limit=top_k, get_total_count=False, sort=sort)

    search_response = self._tablestore_client.search(
        table_name=self._table_name,
        # Passed explicitly, as in the other search calls on this client.
        index_name=self._index_name,
        search_query=search_query,
        columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
    )

    documents = []
    for search_hit in search_response.search_hits:
        if search_hit.score > score_threshold:
            # NOTE(review): the positional indexes assume the attribute columns come
            # back as [0]=metadata JSON, [2]=content, [3]=vector JSON - confirm
            # against the index schema.
            attribute_columns = search_hit.row[1]
            metadata = json.loads(attribute_columns[0][1])
            metadata["score"] = search_hit.score
            documents.append(
                Document(
                    page_content=attribute_columns[2][1],
                    vector=json.loads(attribute_columns[3][1]),
                    metadata=metadata,
                )
            )

    documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
    return documents


def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]:
    """Run a match query against the content field and convert the hits to documents.

    Args:
        query: text to match.
        document_ids_filter: tag strings to restrict the search to; ``None`` or
            empty means no restriction.
        top_k: maximum number of rows requested.

    Returns:
        Documents in the order the service returned them (score-descending sort
        is requested). No score is added to the metadata.
    """
    bool_query = tablestore.BoolQuery()
    bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value))

    if document_ids_filter:
        bool_query.filter_queries.append(tablestore.TermsQuery(self._tags_field, document_ids_filter))

    search_query = tablestore.SearchQuery(
        query=bool_query,
        sort=tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)]),
        limit=top_k,
    )
    search_response = self._tablestore_client.search(
        table_name=self._table_name,
        # Passed explicitly, as in the other search calls on this client.
        index_name=self._index_name,
        search_query=search_query,
        columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
    )

    # NOTE(review): the positional indexes assume the attribute columns come back
    # as [0]=metadata JSON, [2]=content, [3]=vector JSON - confirm against the
    # index schema.
    return [
        Document(
            page_content=search_hit.row[1][2][1],
            vector=json.loads(search_hit.row[1][3][1]),
            metadata=json.loads(search_hit.row[1][0][1]),
        )
        for search_hit in search_response.search_hits
    ]




class TableStoreVectorFactory(AbstractVectorFactory): class TableStoreVectorFactory(AbstractVectorFactory):

+ 48
- 0
api/tests/integration_tests/vdb/tablestore/test_tablestore.py View File

import os import os
import uuid

import tablestore


from core.rag.datasource.vdb.tablestore.tablestore_vector import ( from core.rag.datasource.vdb.tablestore.tablestore_vector import (
TableStoreConfig, TableStoreConfig,
) )
from tests.integration_tests.vdb.test_vector_store import ( from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest, AbstractVectorTest,
get_example_document,
get_example_text,
setup_mock_redis, setup_mock_redis,
) )


assert len(ids) == 1 assert len(ids) == 1
assert ids[0] == self.example_doc_id assert ids[0] == self.example_doc_id


def create_vector(self):
    """Create the store with the example document and wait until the index sees it.

    After ``create`` the search index is polled with a match-all count query
    until it reports exactly one row, so later search tests do not race the
    indexing.

    Raises:
        TimeoutError: if the index does not report one row within the deadline.
    """
    import time  # local import: the module's import block is not changed by this edit

    self.vector.create(
        texts=[get_example_document(doc_id=self.example_doc_id)],
        embeddings=[self.example_embedding],
    )

    poll_interval_s = 0.5
    deadline = time.monotonic() + 120
    while True:
        search_response = self.vector._tablestore_client.search(
            table_name=self.vector._table_name,
            index_name=self.vector._index_name,
            # limit=0 with get_total_count=True: only the count is needed, not the rows.
            search_query=tablestore.SearchQuery(query=tablestore.MatchAllQuery(), get_total_count=True, limit=0),
            columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
        )
        if search_response.total_count == 1:
            return
        # Bound the wait and back off between polls instead of spinning forever.
        if time.monotonic() >= deadline:
            raise TimeoutError("Tablestore search index did not report the example document in time")
        time.sleep(poll_interval_s)

def search_by_vector(self):
    """Run the base vector-search checks, then exercise ``document_ids_filter``."""
    super().search_by_vector()

    # Filtering on the example document's id must return exactly that document, with a score.
    matching = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[self.example_doc_id])
    assert len(matching) == 1
    only_hit = matching[0]
    assert only_hit.metadata["doc_id"] == self.example_doc_id
    assert only_hit.metadata["score"] > 0

    # Filtering on a random id that was never stored must return nothing.
    unknown_id = str(uuid.uuid4())
    missing = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[unknown_id])
    assert len(missing) == 0

def search_by_full_text(self):
    """Run the base full-text checks, then exercise ``document_ids_filter``."""
    super().search_by_full_text()

    # Filtering on the example document's id must return exactly that document.
    example_text = get_example_text()
    matching = self.vector.search_by_full_text(example_text, document_ids_filter=[self.example_doc_id])
    assert len(matching) == 1
    only_hit = matching[0]
    assert only_hit.metadata["doc_id"] == self.example_doc_id
    # NOTE(review): this checks for a ``score`` attribute on the document object,
    # not a "score" key in its metadata - confirm which one was intended.
    assert not hasattr(only_hit, "score")

    # Filtering on a random id that was never stored must return nothing.
    unknown_id = str(uuid.uuid4())
    missing = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[unknown_id])
    assert len(missing) == 0

def run_all_tests(self):
    """Drop any store left over from an earlier run, then run the base test sequence.

    The initial delete is best-effort: a failure is logged and ignored so the
    run still proceeds.
    """
    import logging  # local import: the module's import block is not changed by this edit

    try:
        self.vector.delete()
    except Exception:
        # Keep going, but leave a trace instead of swallowing the error silently.
        logging.getLogger(__name__).warning("Ignoring failure while deleting tablestore vector", exc_info=True)

    return super().run_all_tests()



def test_tablestore_vector(setup_mock_redis):
    """Run the whole Tablestore vector test sequence with the mock-redis fixture active."""
    TableStoreVectorTest().run_all_tests()

Loading…
Cancel
Save