|
|
|
@@ -1,3 +1,4 @@ |
|
|
|
import random |
|
|
|
import uuid |
|
|
|
from unittest.mock import MagicMock |
|
|
|
|
|
|
|
@@ -8,26 +9,18 @@ from extensions import ext_redis |
|
|
|
from models.dataset import Dataset |
|
|
|
|
|
|
|
|
|
|
|
def get_sample_text() -> str: |
|
|
|
def get_example_text() -> str: |
|
|
|
return 'test_text' |
|
|
|
|
|
|
|
|
|
|
|
def get_sample_embedding() -> list[float]: |
|
|
|
return [1.1, 2.2, 3.3] |
|
|
|
|
|
|
|
|
|
|
|
def get_sample_query_vector() -> list[float]: |
|
|
|
return get_sample_embedding() |
|
|
|
|
|
|
|
|
|
|
|
def get_sample_document(sample_dataset_id: str) -> Document: |
|
|
|
def get_example_document(doc_id: str) -> Document: |
|
|
|
doc = Document( |
|
|
|
page_content=get_sample_text(), |
|
|
|
page_content=get_example_text(), |
|
|
|
metadata={ |
|
|
|
"doc_id": sample_dataset_id, |
|
|
|
"doc_hash": sample_dataset_id, |
|
|
|
"document_id": sample_dataset_id, |
|
|
|
"dataset_id": sample_dataset_id, |
|
|
|
"doc_id": doc_id, |
|
|
|
"doc_hash": doc_id, |
|
|
|
"document_id": doc_id, |
|
|
|
"dataset_id": doc_id, |
|
|
|
} |
|
|
|
) |
|
|
|
return doc |
|
|
|
@@ -53,49 +46,48 @@ class AbstractTestVector: |
|
|
|
self.vector = None |
|
|
|
self.dataset_id = str(uuid.uuid4()) |
|
|
|
self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) |
|
|
|
self.example_doc_id = str(uuid.uuid4()) |
|
|
|
self.example_embedding = [1.001 * i for i in range(128)] |
|
|
|
|
|
|
|
def create_vector(self) -> None: |
|
|
|
self.vector.create( |
|
|
|
texts=[get_sample_document(self.dataset_id)], |
|
|
|
embeddings=[get_sample_embedding()], |
|
|
|
texts=[get_example_document(doc_id=self.example_doc_id)], |
|
|
|
embeddings=[self.example_embedding], |
|
|
|
) |
|
|
|
|
|
|
|
def search_by_vector(self): |
|
|
|
hits_by_vector = self.vector.search_by_vector(query_vector=get_sample_query_vector()) |
|
|
|
assert len(hits_by_vector) >= 1 |
|
|
|
hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding) |
|
|
|
assert len(hits_by_vector) == 1 |
|
|
|
assert hits_by_vector[0].metadata['doc_id'] == self.example_doc_id |
|
|
|
|
|
|
|
def search_by_full_text(self): |
|
|
|
hits_by_full_text = self.vector.search_by_full_text(query=get_sample_text()) |
|
|
|
assert len(hits_by_full_text) >= 1 |
|
|
|
hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text()) |
|
|
|
assert len(hits_by_full_text) == 1 |
|
|
|
assert hits_by_full_text[0].metadata['doc_id'] == self.example_doc_id |
|
|
|
|
|
|
|
def delete_vector(self): |
|
|
|
self.vector.delete() |
|
|
|
|
|
|
|
def delete_by_ids(self): |
|
|
|
self.vector.delete_by_ids([self.dataset_id]) |
|
|
|
|
|
|
|
def add_texts(self): |
|
|
|
self.vector.add_texts( |
|
|
|
documents=[ |
|
|
|
get_sample_document(str(uuid.uuid4())), |
|
|
|
get_sample_document(str(uuid.uuid4())), |
|
|
|
], |
|
|
|
embeddings=[ |
|
|
|
get_sample_embedding(), |
|
|
|
get_sample_embedding(), |
|
|
|
], |
|
|
|
) |
|
|
|
def delete_by_ids(self, ids: list[str]): |
|
|
|
self.vector.delete_by_ids(ids=ids) |
|
|
|
|
|
|
|
def add_texts(self) -> list[str]: |
|
|
|
batch_size = 100 |
|
|
|
documents = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)] |
|
|
|
embeddings = [self.example_embedding] * batch_size |
|
|
|
self.vector.add_texts(documents=documents, embeddings=embeddings) |
|
|
|
return [doc.metadata['doc_id'] for doc in documents] |
|
|
|
|
|
|
|
def text_exists(self): |
|
|
|
self.vector.text_exists(self.dataset_id) |
|
|
|
assert self.vector.text_exists(self.example_doc_id) |
|
|
|
|
|
|
|
def delete_document_by_id(self): |
|
|
|
def delete_by_document_id(self): |
|
|
|
with pytest.raises(NotImplementedError): |
|
|
|
self.vector.delete_by_document_id(self.dataset_id) |
|
|
|
self.vector.delete_by_document_id(document_id=self.example_doc_id) |
|
|
|
|
|
|
|
def get_ids_by_metadata_field(self): |
|
|
|
with pytest.raises(NotImplementedError): |
|
|
|
self.vector.get_ids_by_metadata_field('key', 'value') |
|
|
|
self.vector.get_ids_by_metadata_field(key='key', value='value') |
|
|
|
|
|
|
|
def run_all_tests(self): |
|
|
|
self.create_vector() |
|
|
|
@@ -103,7 +95,7 @@ class AbstractTestVector: |
|
|
|
self.search_by_full_text() |
|
|
|
self.text_exists() |
|
|
|
self.get_ids_by_metadata_field() |
|
|
|
self.add_texts() |
|
|
|
self.delete_document_by_id() |
|
|
|
self.delete_by_ids() |
|
|
|
self.delete_by_document_id() |
|
|
|
added_doc_ids = self.add_texts() |
|
|
|
self.delete_by_ids(added_doc_ids) |
|
|
|
self.delete_vector() |