import random
import string
import time
import uuid
import click
from tqdm import tqdm
from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
from models.model import Account, AppModelConfig, App
import secrets
import base64
kw_index = IndexBuilder.get_index(dataset, 'economy')

# delete from vector index
if vector_index:
    if dataset.collection_binding_id:
        vector_index.delete_by_group_id(dataset.id)
    else:
        vector_index.delete()
kw_index.delete()
# update document
update_params = {

    is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
                                  model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig

index.create_qdrant_dataset(dataset)
index_struct = {
    "type": 'qdrant',
    "vector_store": {
        "class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
}
dataset.index_struct = json.dumps(index_struct)
db.session.commit()
click.echo('passed.')
except Exception as e:
    click.echo(
        click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
                    fg='red'))
    continue

click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
    is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
                                  model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)

from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig

click.echo('passed.')
except Exception as e:
    click.echo(
        click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
                    fg='red'))
    continue

click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
@click.command('normalization-collections', help='restore all collections in one')
def normalization_collections():
    click.echo(click.style('Start normalization collections.', fg='green'))
    normalization_count = 0
    page = 1
    while True:
        try:
            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
        except NotFound:
            break

        page += 1
        for dataset in datasets:
            if not dataset.collection_binding_id:
                try:
                    click.echo('restore dataset index: {}'.format(dataset.id))
                    try:
                        embedding_model = ModelFactory.get_embedding_model(
                            tenant_id=dataset.tenant_id,
                            model_provider_name=dataset.embedding_model_provider,
                            model_name=dataset.embedding_model
                        )
                    except Exception:
                        provider = Provider(
                            id='provider_id',
                            tenant_id=dataset.tenant_id,
                            provider_name='openai',
                            provider_type=ProviderType.CUSTOM.value,
                            encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
                            is_valid=True,
                        )
                        model_provider = OpenAIProvider(provider=provider)
                        embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
                                                          model_provider=model_provider)
                    embeddings = CacheEmbedding(embedding_model)

                    dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
                        filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
                               DatasetCollectionBinding.model_name == embedding_model.name). \
                        order_by(DatasetCollectionBinding.created_at). \
                        first()

                    if not dataset_collection_binding:
                        dataset_collection_binding = DatasetCollectionBinding(
                            provider_name=embedding_model.model_provider.provider_name,
                            model_name=embedding_model.name,
                            collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
                        )
                        db.session.add(dataset_collection_binding)
                        db.session.commit()

                    from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig

                    index = QdrantVectorIndex(
                        dataset=dataset,
                        config=QdrantConfig(
                            endpoint=current_app.config.get('QDRANT_URL'),
                            api_key=current_app.config.get('QDRANT_API_KEY'),
                            root_path=current_app.root_path
                        ),
                        embeddings=embeddings
                    )
                    if index:
                        index.restore_dataset_in_one(dataset, dataset_collection_binding)
                    else:
                        click.echo('passed.')

                    original_index = QdrantVectorIndex(
                        dataset=dataset,
                        config=QdrantConfig(
                            endpoint=current_app.config.get('QDRANT_URL'),
                            api_key=current_app.config.get('QDRANT_API_KEY'),
                            root_path=current_app.root_path
                        ),
                        embeddings=embeddings
                    )
                    if original_index:
                        original_index.delete_original_collection(dataset, dataset_collection_binding)
                        normalization_count += 1
                    else:
                        click.echo('passed.')
                except Exception as e:
                    click.echo(
                        click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
                                    fg='red'))
                    continue

    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green'))
@click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
def update_app_model_configs(batch_size):

        .join(App, App.app_model_config_id == AppModelConfig.id) \
        .filter(App.mode == 'completion') \
        .count()

    if total_records == 0:
        click.secho("No data to migrate.", fg='green')
        return

        offset = i * batch_size
        limit = min(batch_size, total_records - offset)

        click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')

        data_batch = db.session.query(AppModelConfig) \
            .join(App, App.app_model_config_id == AppModelConfig.id) \
            .filter(App.mode == 'completion') \
            .order_by(App.created_at) \
            .offset(offset).limit(limit).all()

        if not data_batch:
            click.secho("No more data to migrate.", fg='green')
            break

            app_data = db.session.query(App) \
                .filter(App.id == data.app_id) \
                .one()

            account_data = db.session.query(Account) \
                .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \
                .filter(TenantAccountJoin.role == 'owner') \

                db.session.commit()
            except Exception as e:
                click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
                            fg='red')
                continue

        click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
        pbar.update(len(data_batch))
def register_commands(app):
    app.cli.add_command(reset_password)
    app.cli.add_command(reset_email)
    app.cli.add_command(clean_unused_dataset_indexes)
    app.cli.add_command(create_qdrant_indexes)
    app.cli.add_command(update_qdrant_indexes)
    app.cli.add_command(update_app_model_configs)
    app.cli.add_command(normalization_collections)
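
# Usage note (a sketch, assuming the Flask app factory wires up register_commands(app) as elsewhere
# in the project): the command registered above can then be run with
#     flask normalization-collections
# to rebuild every high-quality dataset that has no collection_binding_id into a shared collection.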
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
    raise NotImplementedError

@abstractmethod
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
    raise NotImplementedError

@abstractmethod
def add_texts(self, texts: list[Document], **kwargs):
    raise NotImplementedError

def delete_by_ids(self, ids: list[str]) -> None:
    raise NotImplementedError

@abstractmethod
def delete_by_group_id(self, group_id: str) -> None:
    raise NotImplementedError

@abstractmethod
def delete_by_document_id(self, document_id: str):
    raise NotImplementedError
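
# create_with_collection_name builds an index into an explicitly named (shared) collection, and
# delete_by_group_id removes only one group's (one dataset's) records from it; the keyword-table,
# Weaviate and Qdrant indexes below provide the concrete implementations.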
    return self

def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
    keyword_table_handler = JiebaKeywordTableHandler()
    keyword_table = {}
    for text in texts:
        keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
        self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
        keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

    dataset_keyword_table = DatasetKeywordTable(
        dataset_id=self.dataset.id,
        keyword_table=json.dumps({
            '__type__': 'keyword_table',
            '__data__': {
                "index_id": self.dataset.id,
                "summary": None,
                "table": {}
            }
        }, cls=SetEncoder)
    )
    db.session.add(dataset_keyword_table)
    db.session.commit()

    self._save_dataset_keyword_table(keyword_table)

    return self

def add_texts(self, texts: list[Document], **kwargs):
    keyword_table_handler = JiebaKeywordTableHandler()

        db.session.delete(dataset_keyword_table)
        db.session.commit()

def delete_by_group_id(self, group_id: str) -> None:
    dataset_keyword_table = self.dataset.dataset_keyword_table
    if dataset_keyword_table:
        db.session.delete(dataset_keyword_table)
        db.session.commit()

def _save_dataset_keyword_table(self, keyword_table):
    keyword_table_dict = {
        '__type__': 'keyword_table',
from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, DatasetCollectionBinding
from models.dataset import Document as DatasetDocument

    for node_id in ids:
        vector_store.del_text(node_id)

def delete_by_group_id(self, group_id: str) -> None:
    vector_store = self._get_vector_store()
    vector_store = cast(self._get_vector_store_class(), vector_store)

    vector_store.delete()

def delete(self) -> None:
    vector_store = self._get_vector_store()
    vector_store = cast(self._get_vector_store_class(), vector_store)

            raise e

    logging.info(f"Dataset {dataset.id} recreate successfully.")
def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
    logging.info(f"restore dataset in_one,_dataset {dataset.id}")

    dataset_documents = db.session.query(DatasetDocument).filter(
        DatasetDocument.dataset_id == dataset.id,
        DatasetDocument.indexing_status == 'completed',
        DatasetDocument.enabled == True,
        DatasetDocument.archived == False,
    ).all()

    documents = []
    for dataset_document in dataset_documents:
        segments = db.session.query(DocumentSegment).filter(
            DocumentSegment.document_id == dataset_document.id,
            DocumentSegment.status == 'completed',
            DocumentSegment.enabled == True
        ).all()

        for segment in segments:
            document = Document(
                page_content=segment.content,
                metadata={
                    "doc_id": segment.index_node_id,
                    "doc_hash": segment.index_node_hash,
                    "document_id": segment.document_id,
                    "dataset_id": segment.dataset_id,
                }
            )
            documents.append(document)

    if documents:
        try:
            self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
        except Exception as e:
            raise e

    logging.info(f"Dataset {dataset.id} recreate successfully.")

def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
    logging.info(f"delete original collection: {dataset.id}")

    self.delete()

    dataset.collection_binding_id = dataset_collection_binding.id
    db.session.add(dataset)
    db.session.commit()

    logging.info(f"Dataset {dataset.id} recreate successfully.")
    return self

def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
    uuids = self._get_uuids(texts)
    self._vector_store = WeaviateVectorStore.from_documents(
        texts,
        self._embeddings,
        client=self._client,
        index_name=collection_name,
        uuids=uuids,
        by_text=False
    )

    return self

def _get_vector_store(self) -> VectorStore:
    """Only for created index."""
    if self._vector_store:
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
from qdrant_client.http.models import PayloadSchemaType

if TYPE_CHECKING:
    from qdrant_client import grpc  # noqa

CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR_NAME = None
def __init__(

    embeddings: Optional[Embeddings] = None,
    content_payload_key: str = CONTENT_KEY,
    metadata_payload_key: str = METADATA_KEY,
    group_payload_key: str = GROUP_KEY,
    group_id: str = None,
    distance_strategy: str = "COSINE",
    vector_name: Optional[str] = VECTOR_NAME,
    embedding_function: Optional[Callable] = None,  # deprecated
    is_new_collection: bool = False
):
    """Initialize with necessary components."""
    try:

    self.collection_name = collection_name
    self.content_payload_key = content_payload_key or self.CONTENT_KEY
    self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
    self.group_payload_key = group_payload_key or self.GROUP_KEY
    self.vector_name = vector_name or self.VECTOR_NAME
    self.group_id = group_id
    self.is_new_collection = is_new_collection

    if embedding_function is not None:
        warnings.warn(
        batch_size:
            How many vectors upload per-request.
            Default: 64
        group_id:
            collection group
    Returns:
        List of ids from adding the texts into the vectorstore.

            collection_name=self.collection_name, points=points, **kwargs
        )
        added_ids.extend(batch_ids)

    # if is new collection, create payload index on group_id
    if self.is_new_collection:
        self.client.create_payload_index(self.collection_name, self.group_payload_key,
                                         field_schema=PayloadSchemaType.KEYWORD,
                                         field_type=PayloadSchemaType.KEYWORD)

    return added_ids
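
# A keyword payload index on group_id lets Qdrant filter efficiently when one collection holds the
# vectors of many datasets; it is created only once, right after the collection itself has been
# created (is_new_collection is True).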
@sync_call_fallback

    distance_func: str = "Cosine",
    content_payload_key: str = CONTENT_KEY,
    metadata_payload_key: str = METADATA_KEY,
    group_payload_key: str = GROUP_KEY,
    group_id: str = None,
    vector_name: Optional[str] = VECTOR_NAME,
    batch_size: int = 64,
    shard_number: Optional[int] = None,

        metadata_payload_key:
            A payload key used to store the metadata of the document.
            Default: "metadata"
        group_payload_key:
            A payload key used to store the group id of the document.
            Default: "group_id"
        group_id:
            collection group id
        vector_name:
            Name of the vector to be used internally in Qdrant.
            Default: None

    distance_func,
    content_payload_key,
    metadata_payload_key,
    group_payload_key,
    group_id,
    vector_name,
    shard_number,
    replication_factor,
    distance_func: str = "Cosine",
    content_payload_key: str = CONTENT_KEY,
    metadata_payload_key: str = METADATA_KEY,
    group_payload_key: str = GROUP_KEY,
    group_id: str = None,
    vector_name: Optional[str] = VECTOR_NAME,
    shard_number: Optional[int] = None,
    replication_factor: Optional[int] = None,

    vector_size = len(partial_embeddings[0])
    collection_name = collection_name or uuid.uuid4().hex
    distance_func = distance_func.upper()
    is_new_collection = False
    client = qdrant_client.QdrantClient(
        location=location,
        url=url,

        init_from=init_from,
        timeout=timeout,  # type: ignore[arg-type]
    )

        is_new_collection = True

    qdrant = cls(
        client=client,
        collection_name=collection_name,

        metadata_payload_key=metadata_payload_key,
        distance_strategy=distance_func,
        vector_name=vector_name,
        group_id=group_id,
        group_payload_key=group_payload_key,
        is_new_collection=is_new_collection
    )
    return qdrant
    metadatas: Optional[List[dict]],
    content_payload_key: str,
    metadata_payload_key: str,
    group_id: str,
    group_payload_key: str
) -> List[dict]:
    payloads = []
    for i, text in enumerate(texts):

            {
                content_payload_key: text,
                metadata_payload_key: metadata,
                group_payload_key: group_id
            }
        )

    else:
        out.append(
            rest.FieldCondition(
                key=key,
                match=rest.MatchValue(value=value),
            )
        )
    metadatas: Optional[List[dict]] = None,
    ids: Optional[Sequence[str]] = None,
    batch_size: int = 64,
    group_id: Optional[str] = None,
) -> Generator[Tuple[List[str], List[rest.PointStruct]], None, None]:
    from qdrant_client.http import models as rest

                batch_metadatas,
                self.content_payload_key,
                self.metadata_payload_key,
                self.group_id,
                self.group_payload_key
            ),
        )
    ]
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from qdrant_client.http.models import HnswConfigDiff

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding


class QdrantConfig(BaseModel):
    endpoint: str
    api_key: Optional[str]
    root_path: Optional[str]

    def to_qdrant_params(self):
        if self.endpoint and self.endpoint.startswith('path:'):
            path = self.endpoint.replace('path:', '')

        return 'qdrant'
def get_index_name(self, dataset: Dataset) -> str:
    if dataset.collection_binding_id:
        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
            filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
            one_or_none()
        if dataset_collection_binding:
            return dataset_collection_binding.collection_name
        else:
            raise ValueError('Dataset Collection Binding does not exist.')
    else:
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            return class_prefix

        dataset_id = dataset.id
        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
    return {

        collection_name=self.get_index_name(self.dataset),
        ids=uuids,
        content_payload_key='page_content',
        group_id=self.dataset.id,
        group_payload_key='group_id',
        hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
                                   max_indexing_threads=0, on_disk=False),
        **self._client_config.to_qdrant_params()
    )

    return self

def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
    uuids = self._get_uuids(texts)
    self._vector_store = QdrantVectorStore.from_documents(
        texts,
        self._embeddings,
        collection_name=collection_name,
        ids=uuids,
        content_payload_key='page_content',
        group_id=self.dataset.id,
        group_payload_key='group_id',
        hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
                                   max_indexing_threads=0, on_disk=False),
        **self._client_config.to_qdrant_params()
    )
    if self._vector_store:
        return self._vector_store

    attributes = ['doc_id', 'dataset_id', 'document_id']
    if self._is_origin():
        attributes = ['doc_id']
    client = qdrant_client.QdrantClient(
        **self._client_config.to_qdrant_params()
    )

        client=client,
        collection_name=self.get_index_name(self.dataset),
        embeddings=self._embeddings,
        content_payload_key='page_content',
        group_id=self.dataset.id,
        group_payload_key='group_id'
    )
def _get_vector_store_class(self) -> type:
    return QdrantVectorStore

def delete_by_document_id(self, document_id: str):
    if self._is_origin():
        self.recreate_dataset(self.dataset)
        return

    vector_store = self._get_vector_store()
    vector_store = cast(self._get_vector_store_class(), vector_store)

    ))

def delete_by_ids(self, ids: list[str]) -> None:
    if self._is_origin():
        self.recreate_dataset(self.dataset)
        return

    vector_store = self._get_vector_store()
    vector_store = cast(self._get_vector_store_class(), vector_store)

        ],
    ))
def delete_by_group_id(self, group_id: str) -> None:
    vector_store = self._get_vector_store()
    vector_store = cast(self._get_vector_store_class(), vector_store)

    from qdrant_client.http import models
    vector_store.del_texts(models.Filter(
        must=[
            models.FieldCondition(
                key="group_id",
                match=models.MatchValue(value=group_id),
            ),
        ],
    ))
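
# Datasets that still live in a dedicated collection go through the _is_origin() fallback above
# (recreate_dataset / full delete); once a dataset shares a collection, delete_by_group_id removes
# only the points whose group_id payload matches that dataset's id.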
def _is_origin(self):
    if self.dataset.index_struct_dict:
        class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
    return self

def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
    uuids = self._get_uuids(texts)
    self._vector_store = WeaviateVectorStore.from_documents(
        texts,
        self._embeddings,
        client=self._client,
        index_name=self.get_index_name(self.dataset),
        uuids=uuids,
        by_text=False
    )

    return self

def _get_vector_store(self) -> VectorStore:
    """Only for created index."""
    if self._vector_store:
return_resource: str
retriever_from: str

@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
    description = dataset.description

        query,
        search_type='similarity_score_threshold',
        search_kwargs={
            'k': self.k,
            'filter': {
                'group_id': [dataset.id]
            }
        }
    )
else:
    self.client.delete_collection(collection_name=self.collection_name)

def delete_group(self):
    self._reload_if_needed()

    self.client.delete_collection(collection_name=self.collection_name)

@classmethod
def _document_from_scored_point(
    cls,
| """add_dataset_collection_binding | |||||
| Revision ID: 6e2cfb077b04 | |||||
| Revises: 77e83833755c | |||||
| Create Date: 2023-09-13 22:16:48.027810 | |||||
| """ | |||||
| from alembic import op | |||||
| import sqlalchemy as sa | |||||
| from sqlalchemy.dialects import postgresql | |||||
| # revision identifiers, used by Alembic. | |||||
| revision = '6e2cfb077b04' | |||||
| down_revision = '77e83833755c' | |||||
| branch_labels = None | |||||
| depends_on = None | |||||
| def upgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| op.create_table('dataset_collection_bindings', | |||||
| sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||||
| sa.Column('provider_name', sa.String(length=40), nullable=False), | |||||
| sa.Column('model_name', sa.String(length=40), nullable=False), | |||||
| sa.Column('collection_name', sa.String(length=64), nullable=False), | |||||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), | |||||
| sa.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey') | |||||
| ) | |||||
| with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: | |||||
| batch_op.create_index('provider_model_name_idx', ['provider_name', 'model_name'], unique=False) | |||||
| with op.batch_alter_table('datasets', schema=None) as batch_op: | |||||
| batch_op.add_column(sa.Column('collection_binding_id', postgresql.UUID(), nullable=True)) | |||||
| # ### end Alembic commands ### | |||||
| def downgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| with op.batch_alter_table('datasets', schema=None) as batch_op: | |||||
| batch_op.drop_column('collection_binding_id') | |||||
| with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: | |||||
| batch_op.drop_index('provider_model_name_idx') | |||||
| op.drop_table('dataset_collection_bindings') | |||||
| # ### end Alembic commands ### |
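
# To apply this revision locally, assuming the project's usual Flask-Migrate wrapper around Alembic:
#     flask db upgrade      # creates dataset_collection_bindings and datasets.collection_binding_id
#     flask db downgrade    # reverts this revision if needed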
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    embedding_model = db.Column(db.String(255), nullable=True)
    embedding_model_provider = db.Column(db.String(255), nullable=True)
    collection_binding_id = db.Column(UUID, nullable=True)

    @property
    def dataset_keyword_table(self):

    def get_embedding(self) -> list[float]:
        return pickle.loads(self.embedding)


class DatasetCollectionBinding(db.Model):
    __tablename__ = 'dataset_collection_bindings'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey'),
        db.Index('provider_model_name_idx', 'provider_name', 'model_name')
    )

    id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    provider_name = db.Column(db.String(40), nullable=False)
    model_name = db.Column(db.String(40), nullable=False)
    collection_name = db.Column(db.String(64), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
from extensions.ext_database import db
from libs import helper
from models.account import Account
from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment, \
    DatasetCollectionBinding
from models.model import UploadFile
from models.source import DataSourceBinding
from services.errors.account import NoPermissionError

    action = 'remove'
    filtered_data['embedding_model'] = None
    filtered_data['embedding_model_provider'] = None
    filtered_data['collection_binding_id'] = None
elif data['indexing_technique'] == 'high_quality':
    action = 'add'
    # get embedding model setting

        )
        filtered_data['embedding_model'] = embedding_model.name
        filtered_data['embedding_model_provider'] = embedding_model.model_provider.provider_name
        dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
            embedding_model.model_provider.provider_name,
            embedding_model.name
        )
        filtered_data['collection_binding_id'] = dataset_collection_binding.id
    except LLMBadRequestError:
        raise ValueError(
            f"No Embedding Model available. Please configure a valid provider "

)
dataset.embedding_model = embedding_model.name
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
    embedding_model.model_provider.provider_name,
    embedding_model.name
)
dataset.collection_binding_id = dataset_collection_binding.id
documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))

if total_count > tenant_document_count:
    raise ValueError(f"All your documents have exceeded the limit of {tenant_document_count}.")

embedding_model = None
dataset_collection_binding_id = None
if document_data['indexing_technique'] == 'high_quality':
    embedding_model = ModelFactory.get_embedding_model(
        tenant_id=tenant_id
    )
    dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
        embedding_model.model_provider.provider_name,
        embedding_model.name
    )
    dataset_collection_binding_id = dataset_collection_binding.id

# save dataset
dataset = Dataset(
    tenant_id=tenant_id,
    indexing_technique=document_data["indexing_technique"],
    created_by=account.id,
    embedding_model=embedding_model.name if embedding_model else None,
    embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None,
    collection_binding_id=dataset_collection_binding_id
)

db.session.add(dataset)

delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
db.session.delete(segment)
db.session.commit()
class DatasetCollectionBindingService:
    @classmethod
    def get_dataset_collection_binding(cls, provider_name: str, model_name: str) -> DatasetCollectionBinding:
        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
            filter(DatasetCollectionBinding.provider_name == provider_name,
                   DatasetCollectionBinding.model_name == model_name). \
            order_by(DatasetCollectionBinding.created_at). \
            first()

        if not dataset_collection_binding:
            dataset_collection_binding = DatasetCollectionBinding(
                provider_name=provider_name,
                model_name=model_name,
                collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
            )
            db.session.add(dataset_collection_binding)
            db.session.flush()

        return dataset_collection_binding
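
# Illustrative call pattern (a sketch mirroring the call sites earlier in this service; the
# embedding_model object is whatever ModelFactory.get_embedding_model returned):
#
#     binding = DatasetCollectionBindingService.get_dataset_collection_binding(
#         embedding_model.model_provider.provider_name,
#         embedding_model.name
#     )
#     dataset.collection_binding_id = binding.id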
        query,
        search_type='similarity_score_threshold',
        search_kwargs={
            'k': 10,
            'filter': {
                'group_id': [dataset.id]
            }
        }
    )
end = time.perf_counter()