|
|
|
@@ -144,6 +144,10 @@ class TidbOnQdrantVector(BaseVector): |
|
|
|
self._client.create_payload_index( |
|
|
|
collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD |
|
|
|
) |
|
|
|
# create document_id payload index |
|
|
|
self._client.create_payload_index( |
|
|
|
collection_name, Field.DOCUMENT_ID.value, field_schema=PayloadSchemaType.KEYWORD |
|
|
|
) |
|
|
|
# create full text index |
|
|
|
text_index_params = TextIndexParams( |
|
|
|
type=TextIndexType.TEXT, |
|
|
|
@@ -318,23 +322,17 @@ class TidbOnQdrantVector(BaseVector): |
|
|
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: |
|
|
|
from qdrant_client.http import models |
|
|
|
|
|
|
|
filter = models.Filter( |
|
|
|
must=[ |
|
|
|
models.FieldCondition( |
|
|
|
key="group_id", |
|
|
|
match=models.MatchValue(value=self._group_id), |
|
|
|
), |
|
|
|
], |
|
|
|
) |
|
|
|
filter = None |
|
|
|
document_ids_filter = kwargs.get("document_ids_filter") |
|
|
|
if document_ids_filter: |
|
|
|
if filter.must: |
|
|
|
filter.must.append( |
|
|
|
filter = models.Filter( |
|
|
|
must=[ |
|
|
|
models.FieldCondition( |
|
|
|
key="metadata.document_id", |
|
|
|
match=models.MatchAny(any=document_ids_filter), |
|
|
|
) |
|
|
|
) |
|
|
|
], |
|
|
|
) |
|
|
|
results = self._client.search( |
|
|
|
collection_name=self._collection_name, |
|
|
|
query_vector=query_vector, |
|
|
|
@@ -369,23 +367,17 @@ class TidbOnQdrantVector(BaseVector): |
|
|
|
""" |
|
|
|
from qdrant_client.http import models |
|
|
|
|
|
|
|
scroll_filter = models.Filter( |
|
|
|
must=[ |
|
|
|
models.FieldCondition( |
|
|
|
key="page_content", |
|
|
|
match=models.MatchText(text=query), |
|
|
|
) |
|
|
|
] |
|
|
|
) |
|
|
|
scroll_filter = None |
|
|
|
document_ids_filter = kwargs.get("document_ids_filter") |
|
|
|
if document_ids_filter: |
|
|
|
if scroll_filter.must: |
|
|
|
scroll_filter.must.append( |
|
|
|
scroll_filter = models.Filter( |
|
|
|
must=[ |
|
|
|
models.FieldCondition( |
|
|
|
key="metadata.document_id", |
|
|
|
match=models.MatchAny(any=document_ids_filter), |
|
|
|
) |
|
|
|
) |
|
|
|
] |
|
|
|
) |
|
|
|
response = self._client.scroll( |
|
|
|
collection_name=self._collection_name, |
|
|
|
scroll_filter=scroll_filter, |