|
|
|
@@ -233,6 +233,12 @@ class AnalyticdbVectorOpenAPI: |
|
|
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: |
|
|
|
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models |
|
|
|
|
|
|
|
document_ids_filter = kwargs.get("document_ids_filter") |
|
|
|
where_clause = "" |
|
|
|
if document_ids_filter: |
|
|
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) |
|
|
|
where_clause += f"metadata_->>'document_id' IN ({document_ids})" |
|
|
|
|
|
|
|
score_threshold = kwargs.get("score_threshold") or 0.0 |
|
|
|
request = gpdb_20160503_models.QueryCollectionDataRequest( |
|
|
|
dbinstance_id=self.config.instance_id, |
|
|
|
@@ -245,7 +251,7 @@ class AnalyticdbVectorOpenAPI: |
|
|
|
vector=query_vector, |
|
|
|
content=None, |
|
|
|
top_k=kwargs.get("top_k", 4), |
|
|
|
filter=None, |
|
|
|
filter=where_clause, |
|
|
|
) |
|
|
|
response = self._client.query_collection_data(request) |
|
|
|
documents = [] |
|
|
|
@@ -265,6 +271,11 @@ class AnalyticdbVectorOpenAPI: |
|
|
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: |
|
|
|
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models |
|
|
|
|
|
|
|
document_ids_filter = kwargs.get("document_ids_filter") |
|
|
|
where_clause = "" |
|
|
|
if document_ids_filter: |
|
|
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) |
|
|
|
where_clause += f"metadata_->>'document_id' IN ({document_ids})" |
|
|
|
score_threshold = float(kwargs.get("score_threshold") or 0.0) |
|
|
|
request = gpdb_20160503_models.QueryCollectionDataRequest( |
|
|
|
dbinstance_id=self.config.instance_id, |
|
|
|
@@ -277,7 +288,7 @@ class AnalyticdbVectorOpenAPI: |
|
|
|
vector=None, |
|
|
|
content=query, |
|
|
|
top_k=kwargs.get("top_k", 4), |
|
|
|
filter=None, |
|
|
|
filter=where_clause, |
|
|
|
) |
|
|
|
response = self._client.query_collection_data(request) |
|
|
|
documents = [] |