 import math
 import random
 import string
+import threading
 import time
 import uuid
 import click
 from tqdm import tqdm
-from flask import current_app
+from flask import current_app, Flask
 from langchain.embeddings import OpenAIEmbeddings
 from werkzeug.exceptions import NotFound
 @click.command('normalization-collections', help='restore all collections in one')
 def normalization_collections():
     click.echo(click.style('Start normalization collections.', fg='green'))
-    normalization_count = 0
+    normalization_count = []
     page = 1
     while True:
         try:
             datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
-                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=100)
         except NotFound:
             break
+        datasets_result = datasets.items
         page += 1
-        for dataset in datasets:
-            if not dataset.collection_binding_id:
-                try:
-                    click.echo('restore dataset index: {}'.format(dataset.id))
-                    try:
-                        embedding_model = ModelFactory.get_embedding_model(
-                            tenant_id=dataset.tenant_id,
-                            model_provider_name=dataset.embedding_model_provider,
-                            model_name=dataset.embedding_model
-                        )
-                    except Exception:
-                        provider = Provider(
-                            id='provider_id',
-                            tenant_id=dataset.tenant_id,
-                            provider_name='openai',
-                            provider_type=ProviderType.CUSTOM.value,
-                            encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
-                            is_valid=True,
-                        )
-                        model_provider = OpenAIProvider(provider=provider)
-                        embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
-                                                          model_provider=model_provider)
-                    embeddings = CacheEmbedding(embedding_model)
-                    dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
-                        filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
-                               DatasetCollectionBinding.model_name == embedding_model.name). \
-                        order_by(DatasetCollectionBinding.created_at). \
-                        first()
-                    if not dataset_collection_binding:
-                        dataset_collection_binding = DatasetCollectionBinding(
-                            provider_name=embedding_model.model_provider.provider_name,
-                            model_name=embedding_model.name,
-                            collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
-                        )
-                        db.session.add(dataset_collection_binding)
-                        db.session.commit()
-                    from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
-                    index = QdrantVectorIndex(
-                        dataset=dataset,
-                        config=QdrantConfig(
-                            endpoint=current_app.config.get('QDRANT_URL'),
-                            api_key=current_app.config.get('QDRANT_API_KEY'),
-                            root_path=current_app.root_path
-                        ),
-                        embeddings=embeddings
-                    )
-                    if index:
-                        index.restore_dataset_in_one(dataset, dataset_collection_binding)
-                    else:
-                        click.echo('passed.')
-                    original_index = QdrantVectorIndex(
-                        dataset=dataset,
-                        config=QdrantConfig(
-                            endpoint=current_app.config.get('QDRANT_URL'),
-                            api_key=current_app.config.get('QDRANT_API_KEY'),
-                            root_path=current_app.root_path
-                        ),
-                        embeddings=embeddings
-                    )
-                    if original_index:
-                        original_index.delete_original_collection(dataset, dataset_collection_binding)
-                        normalization_count += 1
-                    else:
-                        click.echo('passed.')
-                except Exception as e:
-                    click.echo(
-                        click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
-                                    fg='red'))
-                    continue
-    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green'))
+        for i in range(0, len(datasets_result), 5):
+            threads = []
+            sub_datasets = datasets_result[i:i + 5]
+            for dataset in sub_datasets:
+                document_format_thread = threading.Thread(target=deal_dataset_vector, kwargs={
+                    'flask_app': current_app._get_current_object(),
+                    'dataset': dataset,
+                    'normalization_count': normalization_count
+                })
+                threads.append(document_format_thread)
+                document_format_thread.start()
+            for thread in threads:
+                thread.join()
+    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
+
+
+def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
+    with flask_app.app_context():
+        try:
+            click.echo('restore dataset index: {}'.format(dataset.id))
+            try:
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except Exception:
+                provider = Provider(
+                    id='provider_id',
+                    tenant_id=dataset.tenant_id,
+                    provider_name='openai',
+                    provider_type=ProviderType.CUSTOM.value,
+                    encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
+                    is_valid=True,
+                )
+                model_provider = OpenAIProvider(provider=provider)
+                embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
+                                                  model_provider=model_provider)
+            embeddings = CacheEmbedding(embedding_model)
+            dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
+                       DatasetCollectionBinding.model_name == embedding_model.name). \
+                order_by(DatasetCollectionBinding.created_at). \
+                first()
+            if not dataset_collection_binding:
+                dataset_collection_binding = DatasetCollectionBinding(
+                    provider_name=embedding_model.model_provider.provider_name,
+                    model_name=embedding_model.name,
+                    collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
+                )
+                db.session.add(dataset_collection_binding)
+                db.session.commit()
+            from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
+            index = QdrantVectorIndex(
+                dataset=dataset,
+                config=QdrantConfig(
+                    endpoint=current_app.config.get('QDRANT_URL'),
+                    api_key=current_app.config.get('QDRANT_API_KEY'),
+                    root_path=current_app.root_path
+                ),
+                embeddings=embeddings
+            )
+            if index:
+                # index.delete_by_group_id(dataset.id)
+                index.restore_dataset_in_one(dataset, dataset_collection_binding)
+            else:
+                click.echo('passed.')
+            normalization_count.append(1)
+        except Exception as e:
+            click.echo(
+                click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                            fg='red'))
 @click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
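The pattern in the new command — handing `current_app._get_current_object()` to a `threading.Thread` and re-entering an app context inside the worker — is what lets `deal_dataset_vector` touch `db.session` and `current_app.config` off the main thread. A minimal, self-contained sketch of the same batching idea (the `work` function and the numbers here are illustrative, not part of the patch):

import threading
from flask import Flask, current_app

app = Flask(__name__)

def work(flask_app: Flask, item: int, results: list) -> None:
    # current_app is unbound in a bare thread; enter an app context explicitly.
    with flask_app.app_context():
        results.append(item)  # list.append is atomic under the GIL

with app.app_context():
    results: list = []
    items = list(range(12))
    for i in range(0, len(items), 5):  # batches of five, mirroring the command
        threads = [threading.Thread(target=work,
                                    args=(current_app._get_current_object(), item, results))
                   for item in items[i:i + 5]]
        for t in threads:
            t.start()
        for t in threads:
            t.join()  # finish the batch before moving on
    print(len(results))  # 12

Joining each batch of five before requesting the next page also bounds how many Qdrant and database connections are open at once.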
     def delete_by_group_id(self, group_id: str) -> None:
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
-        vector_store.delete()
+        if self.dataset.collection_binding_id:
+            vector_store.delete_by_group_id(group_id)
+        else:
+            vector_store.delete()
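A binding-aware delete like this ultimately resolves to a payload-filtered delete on the shared collection rather than dropping the collection itself. For orientation, the equivalent raw qdrant-client call looks roughly like the following (endpoint, collection name, and dataset id are placeholders, not values from this patch):

from qdrant_client import QdrantClient
from qdrant_client.http import models

client = QdrantClient(url="http://localhost:6333")  # placeholder endpoint
client.delete(
    collection_name="Vector_index_example_Node",    # placeholder shared collection
    points_selector=models.FilterSelector(
        filter=models.Filter(must=[
            models.FieldCondition(key="group_id",
                                  match=models.MatchValue(value="<dataset-id>")),
        ])
    ),
)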
     def delete(self) -> None:
         vector_store = self._get_vector_store()
         if documents:
             try:
-                self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
+                self.add_texts(documents)
             except Exception as e:
                 raise e
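With a binding present, documents are appended into the already-existing shared collection instead of provisioning a collection per dataset. Conceptually, each point carries a `group_id` payload entry so the scoped deletes above can find it later; a rough upsert-side sketch with placeholder names (only the `group_id` key is attested by this patch):

import uuid
from qdrant_client import QdrantClient
from qdrant_client.http import models

client = QdrantClient(url="http://localhost:6333")  # placeholder endpoint
client.upsert(
    collection_name="Vector_index_example_Node",    # placeholder shared collection
    points=[models.PointStruct(
        id=str(uuid.uuid4()),
        vector=[0.0] * 1536,                        # the embedding for one chunk
        payload={"group_id": "<dataset-id>"},       # scopes the point to one dataset
    )],
)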
             path=path,
             **kwargs,
         )
+        try:
+            # Skip any validation in case of forced collection recreate.
+            if force_recreate:
+                raise ValueError
+            # Get the vector configuration of the existing collection and vector, if it
+            # was specified. If the old configuration does not match the current one,
+            # an exception is being thrown.
+            collection_info = client.get_collection(collection_name=collection_name)
+            current_vector_config = collection_info.config.params.vectors
+            if isinstance(current_vector_config, dict) and vector_name is not None:
+                if vector_name not in current_vector_config:
+                    raise QdrantException(
+                        f"Existing Qdrant collection {collection_name} does not "
+                        f"contain vector named {vector_name}. Did you mean one of the "
+                        f"existing vectors: {', '.join(current_vector_config.keys())}? "
+                        f"If you want to recreate the collection, set `force_recreate` "
+                        f"parameter to `True`."
+                    )
+                current_vector_config = current_vector_config.get(
+                    vector_name
+                )  # type: ignore[assignment]
+            elif isinstance(current_vector_config, dict) and vector_name is None:
+                raise QdrantException(
+                    f"Existing Qdrant collection {collection_name} uses named vectors. "
+                    f"If you want to reuse it, please set `vector_name` to any of the "
+                    f"existing named vectors: "
+                    f"{', '.join(current_vector_config.keys())}."  # noqa
+                    f"If you want to recreate the collection, set `force_recreate` "
+                    f"parameter to `True`."
+                )
+            elif (
+                not isinstance(current_vector_config, dict) and vector_name is not None
+            ):
+                raise QdrantException(
+                    f"Existing Qdrant collection {collection_name} doesn't use named "
+                    f"vectors. If you want to reuse it, please set `vector_name` to "
+                    f"`None`. If you want to recreate the collection, set "
+                    f"`force_recreate` parameter to `True`."
+                )
+            # Check if the vector configuration has the same dimensionality.
+            if current_vector_config.size != vector_size:  # type: ignore[union-attr]
+                raise QdrantException(
+                    f"Existing Qdrant collection is configured for vectors with "
+                    f"{current_vector_config.size} "  # type: ignore[union-attr]
+                    f"dimensions. Selected embeddings are {vector_size}-dimensional. "
+                    f"If you want to recreate the collection, set `force_recreate` "
+                    f"parameter to `True`."
+                )
+            current_distance_func = (
+                current_vector_config.distance.name.upper()  # type: ignore[union-attr]
+            )
+            if current_distance_func != distance_func:
+                raise QdrantException(
+                    f"Existing Qdrant collection is configured for "
+                    f"{current_vector_config.distance} "  # type: ignore[union-attr]
+                    f"similarity. Please set `distance_func` parameter to "
+                    f"`{distance_func}` if you want to reuse it. If you want to "
+                    f"recreate the collection, set `force_recreate` parameter to "
+                    f"`True`."
+                )
+        except (UnexpectedResponse, RpcError, ValueError):
+            all_collection_name = []
+            collections_response = client.get_collections()
+            collection_list = collections_response.collections
+            for collection in collection_list:
+                all_collection_name.append(collection.name)
+            if collection_name not in all_collection_name:
                 vectors_config = rest.VectorParams(
                     size=vector_size,
                     distance=rest.Distance[distance_func],
                     timeout=timeout,  # type: ignore[arg-type]
                 )
             is_new_collection = True
-        if force_recreate:
-            raise ValueError
-        # Get the vector configuration of the existing collection and vector, if it
-        # was specified. If the old configuration does not match the current one,
-        # an exception is being thrown.
-        collection_info = client.get_collection(collection_name=collection_name)
-        current_vector_config = collection_info.config.params.vectors
-        if isinstance(current_vector_config, dict) and vector_name is not None:
-            if vector_name not in current_vector_config:
-                raise QdrantException(
-                    f"Existing Qdrant collection {collection_name} does not "
-                    f"contain vector named {vector_name}. Did you mean one of the "
-                    f"existing vectors: {', '.join(current_vector_config.keys())}? "
-                    f"If you want to recreate the collection, set `force_recreate` "
-                    f"parameter to `True`."
-                )
-            current_vector_config = current_vector_config.get(
-                vector_name
-            )  # type: ignore[assignment]
-        elif isinstance(current_vector_config, dict) and vector_name is None:
-            raise QdrantException(
-                f"Existing Qdrant collection {collection_name} uses named vectors. "
-                f"If you want to reuse it, please set `vector_name` to any of the "
-                f"existing named vectors: "
-                f"{', '.join(current_vector_config.keys())}."  # noqa
-                f"If you want to recreate the collection, set `force_recreate` "
-                f"parameter to `True`."
-            )
-        elif (
-            not isinstance(current_vector_config, dict) and vector_name is not None
-        ):
-            raise QdrantException(
-                f"Existing Qdrant collection {collection_name} doesn't use named "
-                f"vectors. If you want to reuse it, please set `vector_name` to "
-                f"`None`. If you want to recreate the collection, set "
-                f"`force_recreate` parameter to `True`."
-            )
-        # Check if the vector configuration has the same dimensionality.
-        if current_vector_config.size != vector_size:  # type: ignore[union-attr]
-            raise QdrantException(
-                f"Existing Qdrant collection is configured for vectors with "
-                f"{current_vector_config.size} "  # type: ignore[union-attr]
-                f"dimensions. Selected embeddings are {vector_size}-dimensional. "
-                f"If you want to recreate the collection, set `force_recreate` "
-                f"parameter to `True`."
-            )
-        current_distance_func = (
-            current_vector_config.distance.name.upper()  # type: ignore[union-attr]
-        )
-        if current_distance_func != distance_func:
-            raise QdrantException(
-                f"Existing Qdrant collection is configured for "
-                f"{current_vector_config.distance} "  # type: ignore[union-attr]
-                f"similarity. Please set `distance_func` parameter to "
-                f"`{distance_func}` if you want to reuse it. If you want to "
-                f"recreate the collection, set `force_recreate` parameter to "
-                f"`True`."
-            )
         qdrant = cls(
             client=client,
             collection_name=collection_name,
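The effect of the reworked `except` branch is that a validation failure (or `force_recreate`) no longer implies recreating the collection: creation now happens only when the collection is genuinely absent, so multiple datasets can safely point at one pre-existing shared collection. Stripped of the surrounding method, the existence check amounts to something like this (names are placeholders):

from qdrant_client import QdrantClient
from qdrant_client.http import models as rest

client = QdrantClient(url="http://localhost:6333")   # placeholder endpoint
existing = {c.name for c in client.get_collections().collections}
if "Vector_index_example_Node" not in existing:      # placeholder collection name
    client.recreate_collection(
        collection_name="Vector_index_example_Node",
        vectors_config=rest.VectorParams(size=1536, distance=rest.Distance.COSINE),
    )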
             ],
         ))

+    def delete(self) -> None:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+        from qdrant_client.http import models
+        vector_store.del_texts(models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="group_id",
+                    match=models.MatchValue(value=self.dataset.id),
+                ),
+            ],
+        ))
+
     def _is_origin(self):
         if self.dataset.index_struct_dict:
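Because `delete` now removes only points whose `group_id` matches the dataset, other datasets sharing the collection are left untouched. A hypothetical post-delete sanity check with the raw client (all names are placeholders):

from qdrant_client import QdrantClient
from qdrant_client.http import models

client = QdrantClient(url="http://localhost:6333")   # placeholder endpoint
remaining = client.count(
    collection_name="Vector_index_example_Node",
    count_filter=models.Filter(must=[
        models.FieldCondition(key="group_id",
                              match=models.MatchValue(value="<dataset-id>")),
    ]),
    exact=True,
)
assert remaining.count == 0  # this dataset's points are gone; neighbours remain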
 @dataset_was_deleted.connect
 def handle(sender, **kwargs):
     dataset = sender
-    clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, dataset.index_struct)
+    clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique,
+                             dataset.index_struct, dataset.collection_binding_id)
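For context, `dataset_was_deleted` is a blinker-style signal: the deleted dataset arrives as `sender`, and the handler's only job is to forward its fields to the Celery task. A minimal sketch of that plumbing (the signal name matches the diff; the send() call site is an assumption):

from blinker import signal

dataset_was_deleted = signal('dataset-was-deleted')

@dataset_was_deleted.connect
def handle(sender, **kwargs):
    print('cleaning up dataset', sender.id)

class FakeDataset:  # stand-in for the ORM model in this sketch
    id = 'example-id'

dataset_was_deleted.send(FakeDataset())  # invokes handle(sender=<FakeDataset>)

Forwarding collection_binding_id in the message matters because the task runs after the row is gone and rebuilds the dataset from these arguments, as the task body below shows.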
 @shared_task(queue='dataset')
-def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, index_struct: str):
+def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
+                       index_struct: str, collection_binding_id: str):
     """
     Clean dataset when dataset deleted.
     :param dataset_id: dataset id
     :param tenant_id: tenant id
     :param indexing_technique: indexing technique
     :param index_struct: index struct dict
+    :param collection_binding_id: collection binding id

     Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct)
     """
         id=dataset_id,
         tenant_id=tenant_id,
         indexing_technique=indexing_technique,
-        index_struct=index_struct
+        index_struct=index_struct,
+        collection_binding_id=collection_binding_id
     )
     documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
     segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
     if dataset.indexing_technique == 'high_quality':
         vector_index = IndexBuilder.get_default_high_quality_index(dataset)
         try:
-            vector_index.delete()
+            vector_index.delete_by_group_id(dataset.id)
         except Exception:
             logging.exception("Delete doc index failed when dataset deleted.")
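One operational note on the signature change: messages already queued by pre-upgrade producers carry four arguments and will fail once workers expect five. A common mitigation (an editorial suggestion, not part of this patch) is to give the new parameter a default during the rollout window:

from typing import Optional
from celery import shared_task

@shared_task(queue='dataset')
def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
                       index_struct: str, collection_binding_id: Optional[str] = None):
    # Messages from not-yet-upgraded producers arrive without
    # collection_binding_id; treat None as "no shared collection".
    ...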
         raise Exception('Dataset not found')

     if action == "remove":
-        index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
-        index.delete()
+        index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+        index.delete_by_group_id(dataset.id)
     elif action == "add":
         dataset_documents = db.session.query(DatasetDocument).filter(
             DatasetDocument.dataset_id == dataset_id,