| from controllers.console import api | from controllers.console import api | ||||
| from controllers.console.apikey import api_key_fields, api_key_list | from controllers.console.apikey import api_key_fields, api_key_list | ||||
| from controllers.console.app.error import ProviderNotInitializeError | from controllers.console.app.error import ProviderNotInitializeError | ||||
| from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError | |||||
| from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError | |||||
| from controllers.console.setup import setup_required | from controllers.console.setup import setup_required | ||||
| from controllers.console.wraps import account_initialization_required | from controllers.console.wraps import account_initialization_required | ||||
| from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | ||||
| "in the Settings -> Model Provider.") | "in the Settings -> Model Provider.") | ||||
| except ProviderTokenNotInitError as ex: | except ProviderTokenNotInitError as ex: | ||||
| raise ProviderNotInitializeError(ex.description) | raise ProviderNotInitializeError(ex.description) | ||||
| except Exception as e: | |||||
| raise IndexingEstimateError(str(e)) | |||||
| return response, 200 | return response, 200 | ||||
| ArchivedDocumentImmutableError, | ArchivedDocumentImmutableError, | ||||
| DocumentAlreadyFinishedError, | DocumentAlreadyFinishedError, | ||||
| DocumentIndexingError, | DocumentIndexingError, | ||||
| IndexingEstimateError, | |||||
| InvalidActionError, | InvalidActionError, | ||||
| InvalidMetadataError, | InvalidMetadataError, | ||||
| ) | ) | ||||
| "in the Settings -> Model Provider.") | "in the Settings -> Model Provider.") | ||||
| except ProviderTokenNotInitError as ex: | except ProviderTokenNotInitError as ex: | ||||
| raise ProviderNotInitializeError(ex.description) | raise ProviderNotInitializeError(ex.description) | ||||
| except Exception as e: | |||||
| raise IndexingEstimateError(str(e)) | |||||
| return response | return response | ||||
| "in the Settings -> Model Provider.") | "in the Settings -> Model Provider.") | ||||
| except ProviderTokenNotInitError as ex: | except ProviderTokenNotInitError as ex: | ||||
| raise ProviderNotInitializeError(ex.description) | raise ProviderNotInitializeError(ex.description) | ||||
| except Exception as e: | |||||
| raise IndexingEstimateError(str(e)) | |||||
| return response | return response | ||||
| error_code = 'dataset_in_use' | error_code = 'dataset_in_use' | ||||
| description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it." | description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it." | ||||
| code = 409 | code = 409 | ||||
class IndexingEstimateError(BaseHTTPException):
    """Raised when a knowledge-base indexing estimate fails.

    Surfaced to the console API caller as an HTTP 500 with a stable,
    machine-readable error code.
    """
    error_code = 'indexing_estimate_error'
    # NOTE(review): "{message}" looks like a template placeholder — presumably
    # the base class or caller interpolates the failure detail; confirm.
    description = "Knowledge indexing estimate failed: {message}"
    code = 500
| self._save_dataset_keyword_table(keyword_table) | self._save_dataset_keyword_table(keyword_table) | ||||
def delete_by_document_id(self, document_id: str):
    """Remove every keyword-table entry belonging to one document.

    Looks up all segments of *document_id* within this dataset, then strips
    their index-node ids from the dataset keyword table.  The whole mutation
    runs under the dataset-scoped Redis lock shared by the other
    keyword-indexing operations.
    """
    lock = redis_client.lock(f'keyword_indexing_lock_{self.dataset.id}', timeout=600)
    with lock:
        # Collect the index-node ids of all segments that make up the document.
        document_segments = (
            db.session.query(DocumentSegment)
            .filter(
                DocumentSegment.dataset_id == self.dataset.id,
                DocumentSegment.document_id == document_id,
            )
            .all()
        )
        node_ids = [seg.index_node_id for seg in document_segments]
        # Load, prune, and persist the dataset keyword table.
        pruned_table = self._delete_ids_from_keyword_table(
            self._get_dataset_keyword_table(), node_ids
        )
        self._save_dataset_keyword_table(pruned_table)
| def search( | def search( | ||||
| self, query: str, | self, query: str, | ||||
| **kwargs: Any | **kwargs: Any | ||||
| ).first() | ).first() | ||||
| if segment: | if segment: | ||||
| documents.append(Document( | documents.append(Document( | ||||
| page_content=segment.content, | page_content=segment.content, | ||||
| metadata={ | metadata={ |
def delete_by_ids(self, ids: list[str]) -> None:
    """Delete keyword index entries for the given node ids.

    Abstract hook — concrete keyword processors must override this.
    """
    raise NotImplementedError
@abstractmethod
def delete_by_document_id(self, document_id: str) -> None:
    """Delete all keyword index entries belonging to one document.

    Abstract hook — concrete keyword processors must override this.
    """
    raise NotImplementedError
def delete(self) -> None:
    """Drop the entire keyword index. Abstract hook for subclasses."""
    raise NotImplementedError
def delete_by_ids(self, ids: list[str]) -> None:
    # Facade: forward straight to the configured keyword processor.
    self._keyword_processor.delete_by_ids(ids)
def delete_by_document_id(self, document_id: str) -> None:
    # Facade: forward straight to the configured keyword processor.
    self._keyword_processor.delete_by_document_id(document_id)
def delete(self) -> None:
    # Facade: forward straight to the configured keyword processor.
    self._keyword_processor.delete()
| raise e | raise e | ||||
| return pks | return pks | ||||
def delete_by_document_id(self, document_id: str):
    """Delete every vector row whose metadata ties it to *document_id*."""
    pks = self.get_ids_by_metadata_field('document_id', document_id)
    if not pks:
        # Nothing stored for this document — skip the client round-trip.
        return
    self._client.delete(collection_name=self._collection_name, pks=pks)
| def get_ids_by_metadata_field(self, key: str, value: str): | def get_ids_by_metadata_field(self, key: str, value: str): | ||||
| result = self._client.query(collection_name=self._collection_name, | result = self._client.query(collection_name=self._collection_name, | ||||
| filter=f'metadata["{key}"] == "{value}"', | filter=f'metadata["{key}"] == "{value}"', |
| helpers.bulk(self._client, actions) | helpers.bulk(self._client, actions) | ||||
def delete_by_document_id(self, document_id: str):
    """Delete every indexed chunk that belongs to *document_id*."""
    matching_ids = self.get_ids_by_metadata_field('document_id', document_id)
    if not matching_ids:
        # Nothing to remove for this document.
        return
    self.delete_by_ids(matching_ids)
| def get_ids_by_metadata_field(self, key: str, value: str): | def get_ids_by_metadata_field(self, key: str, value: str): | ||||
| query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}} | query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}} | ||||
| response = self._client.search(index=self._collection_name.lower(), body=query) | response = self._client.search(index=self._collection_name.lower(), body=query) |
| # idss.append(record[0]) | # idss.append(record[0]) | ||||
| # return idss | # return idss | ||||
| #def delete_by_document_id(self, document_id: str): | |||||
| # ids = self.get_ids_by_metadata_field('doc_id', document_id) | |||||
| # if len(ids)>0: | |||||
| # with self._get_cursor() as cur: | |||||
| # cur.execute(f"delete FROM {self.table_name} d WHERE d.meta.doc_id in '%s'" % ("','".join(ids),)) | |||||
def delete_by_ids(self, ids: list[str]) -> None:
    """Delete rows whose ``id`` column matches any of *ids*.

    Fixes two defects of the previous ``"IN %s" % (tuple(ids),)`` form:
    a single-element list rendered as ``('x',)``, whose trailing comma is
    invalid SQL, and an empty list produced ``IN ()``, which is also
    invalid — the empty case is now an explicit no-op.
    """
    if not ids:
        return
    # Quote each id explicitly; ids are internally generated UUID strings,
    # so string-building is safe here, though bind parameters would be
    # preferable if the driver's paramstyle were known.
    in_list = ", ".join("'%s'" % item_id for item_id in ids)
    with self._get_cursor() as cur:
        cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({in_list})")
| return pks | return pks | ||||
def delete_by_document_id(self, document_id: str):
    """Delete all vector rows that were stored for *document_id*."""
    row_ids = self.get_ids_by_metadata_field('document_id', document_id)
    if not row_ids:
        # No rows recorded for this document — nothing to delete.
        return
    with Session(self._client) as session:
        # Bind the id list as a single ANY(:ids) parameter.
        delete_statement = sql_text(
            f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)"
        )
        session.execute(delete_statement, {'ids': row_ids})
        session.commit()
| def get_ids_by_metadata_field(self, key: str, value: str): | def get_ids_by_metadata_field(self, key: str, value: str): | ||||
| result = None | result = None | ||||
| with Session(self._client) as session: | with Session(self._client) as session: |
| return ids | return ids | ||||
def delete_by_document_id(self, document_id: str):
    """Delete every stored vector that belongs to *document_id*."""
    uuids = self.get_ids_by_metadata_field('document_id', document_id)
    if not uuids:
        # Nothing stored for this document.
        return
    self.delete_by_uuids(uuids)
| def get_ids_by_metadata_field(self, key: str, value: str): | def get_ids_by_metadata_field(self, key: str, value: str): | ||||
| result = None | result = None | ||||
| with Session(self.client) as session: | with Session(self.client) as session: |
| print("Delete operation failed:", str(e)) | print("Delete operation failed:", str(e)) | ||||
| return False | return False | ||||
def delete_by_document_id(self, document_id: str):
    """Delete every stored vector that belongs to *document_id*."""
    matched = self.get_ids_by_metadata_field('document_id', document_id)
    if not matched:
        # Nothing stored for this document.
        return
    self._delete_by_ids(matched)
| def get_ids_by_metadata_field(self, key: str, value: str): | def get_ids_by_metadata_field(self, key: str, value: str): | ||||
| with Session(self._engine) as session: | with Session(self._engine) as session: | ||||
| select_statement = sql_text( | select_statement = sql_text( |
def delete_by_ids(self, ids: list[str]) -> None:
    """Delete stored vectors by id. Abstract — backends must override."""
    raise NotImplementedError
def delete_by_document_id(self, document_id: str):
    """Delete all vectors of one document. Abstract — backends must override."""
    raise NotImplementedError
def get_ids_by_metadata_field(self, key: str, value: str):
    """Return stored ids whose metadata field *key* equals *value*.

    Abstract — backends must override.
    """
    raise NotImplementedError
| """Abstract interface for document loader implementations.""" | """Abstract interface for document loader implementations.""" | ||||
| import os | |||||
| from typing import Optional | from typing import Optional | ||||
| import pandas as pd | import pandas as pd | ||||
| def extract(self) -> list[Document]: | def extract(self) -> list[Document]: | ||||
| """ Load from Excel file in xls or xlsx format using Pandas.""" | """ Load from Excel file in xls or xlsx format using Pandas.""" | ||||
| documents = [] | documents = [] | ||||
| # Determine the file extension | |||||
| file_extension = os.path.splitext(self._file_path)[-1].lower() | |||||
| # Read each worksheet of an Excel file using Pandas | # Read each worksheet of an Excel file using Pandas | ||||
| excel_file = pd.ExcelFile(self._file_path) | |||||
| if file_extension == '.xlsx': | |||||
| excel_file = pd.ExcelFile(self._file_path, engine='openpyxl') | |||||
| elif file_extension == '.xls': | |||||
| excel_file = pd.ExcelFile(self._file_path, engine='xlrd') | |||||
| else: | |||||
| raise ValueError(f"Unsupported file extension: {file_extension}") | |||||
| for sheet_name in excel_file.sheet_names: | for sheet_name in excel_file.sheet_names: | ||||
| df: pd.DataFrame = excel_file.parse(sheet_name=sheet_name) | df: pd.DataFrame = excel_file.parse(sheet_name=sheet_name) | ||||
| hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) | hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) | ||||
| assert len(hits_by_full_text) == 0 | assert len(hits_by_full_text) == 0 | ||||
def delete_by_document_id(self):
    # Exercise document-scoped deletion against the example document.
    self.vector.delete_by_document_id(document_id=self.example_doc_id)
| def get_ids_by_metadata_field(self): | def get_ids_by_metadata_field(self): | ||||
| ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) | ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) | ||||
| assert len(ids) == 1 | assert len(ids) == 1 |
| assert hits_by_vector[0].metadata['document_id'] == self.example_doc_id, \ | assert hits_by_vector[0].metadata['document_id'] == self.example_doc_id, \ | ||||
| f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}" | f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}" | ||||
| def test_delete_by_document_id(self): | |||||
| self.vector._client.delete_by_query.return_value = {'deleted': 1} | |||||
| doc = Document(page_content="Test content to delete", metadata={"document_id": self.example_doc_id}) | doc = Document(page_content="Test content to delete", metadata={"document_id": self.example_doc_id}) | ||||
| embedding = [0.1] * 128 | embedding = [0.1] * 128 | ||||
| mock_bulk.return_value = ([], []) | mock_bulk.return_value = ([], []) | ||||
| self.vector.add_texts([doc], [embedding]) | self.vector.add_texts([doc], [embedding]) | ||||
| self.vector.delete_by_document_id(document_id=self.example_doc_id) | |||||
| self.vector._client.search.return_value = {'hits': {'total': {'value': 0}, 'hits': []}} | self.vector._client.search.return_value = {'hits': {'total': {'value': 0}, 'hits': []}} | ||||
| ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) | ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) | ||||
| expected_doc_id = "example_doc_id" | expected_doc_id = "example_doc_id" | ||||
| self.tester.test_search_by_full_text(search_response, expected_length, expected_doc_id) | self.tester.test_search_by_full_text(search_response, expected_length, expected_doc_id) | ||||
def test_delete_by_document_id(self):
    # Re-initialize the shared tester fixture, then run its
    # delete-by-document-id scenario.
    self.tester.setup_method()
    self.tester.test_delete_by_document_id()
def test_get_ids_by_metadata_field(self):
    # Re-initialize the shared tester fixture, then run its
    # metadata-field lookup scenario.
    self.tester.setup_method()
    self.tester.test_get_ids_by_metadata_field()
| hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) | hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) | ||||
| assert len(hits_by_full_text) == 0 | assert len(hits_by_full_text) == 0 | ||||
def delete_by_document_id(self):
    # Exercise document-scoped deletion against the example document.
    self.vector.delete_by_document_id(document_id=self.example_doc_id)
| def get_ids_by_metadata_field(self): | def get_ids_by_metadata_field(self): | ||||
| ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) | ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) | ||||
| assert len(ids) == 1 | assert len(ids) == 1 |
def text_exists(self):
    # The example document added during setup should be reported as present.
    assert self.vector.text_exists(self.example_doc_id)
def delete_by_document_id(self):
    # This backend does not implement document-scoped deletion; the base
    # class stub is expected to raise.
    with pytest.raises(NotImplementedError):
        self.vector.delete_by_document_id(document_id=self.example_doc_id)
def get_ids_by_metadata_field(self):
    # This backend does not implement metadata lookup; the base class stub
    # is expected to raise.
    with pytest.raises(NotImplementedError):
        self.vector.get_ids_by_metadata_field(key='key', value='value')
| self.search_by_full_text() | self.search_by_full_text() | ||||
| self.text_exists() | self.text_exists() | ||||
| self.get_ids_by_metadata_field() | self.get_ids_by_metadata_field() | ||||
| self.delete_by_document_id() | |||||
| added_doc_ids = self.add_texts() | added_doc_ids = self.add_texts() | ||||
| self.delete_by_ids(added_doc_ids) | self.delete_by_ids(added_doc_ids) | ||||
| self.delete_vector() | self.delete_vector() |
| ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) | ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) | ||||
| assert len(ids) == 0 | assert len(ids) == 0 | ||||
def delete_by_document_id(self):
    # Exercise document-scoped deletion against the example document.
    self.vector.delete_by_document_id(document_id=self.example_doc_id)
def test_tidb_vector(setup_mock_redis, setup_tidbvector_mock, tidb_vector, mock_session):
    # Fixtures mock Redis and the TiDB client/session; run the full
    # vector-store scenario suite against the mocked backend.
    TiDBVectorTest(vector=tidb_vector).run_all_tests()