Browse Source

fix: update AnalyticDB vector store to filter by metadata (#22698)

Co-authored-by: xiaozeyu <xiaozeyu.xzy@alibaba-inc.com>
tags/1.7.0
8bitpd 3 months ago
parent
commit
9251a66a10
No account linked to committer's email address

+ 13
- 2
api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py View File

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from alibabacloud_gpdb20160503 import models as gpdb_20160503_models


document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause += f"metadata_->>'document_id' IN ({document_ids})"

score_threshold = kwargs.get("score_threshold") or 0.0 score_threshold = kwargs.get("score_threshold") or 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest( request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id, dbinstance_id=self.config.instance_id,
vector=query_vector, vector=query_vector,
content=None, content=None,
top_k=kwargs.get("top_k", 4), top_k=kwargs.get("top_k", 4),
filter=None,
filter=where_clause,
) )
response = self._client.query_collection_data(request) response = self._client.query_collection_data(request)
documents = [] documents = []
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from alibabacloud_gpdb20160503 import models as gpdb_20160503_models


document_ids_filter = kwargs.get("document_ids_filter")
where_clause = ""
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause += f"metadata_->>'document_id' IN ({document_ids})"
score_threshold = float(kwargs.get("score_threshold") or 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest( request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id, dbinstance_id=self.config.instance_id,
vector=None, vector=None,
content=query, content=query,
top_k=kwargs.get("top_k", 4), top_k=kwargs.get("top_k", 4),
filter=None,
filter=where_clause,
) )
response = self._client.query_collection_data(request) response = self._client.query_collection_data(request)
documents = [] documents = []

Loading…
Cancel
Save