Co-authored-by: jyong <jyong@dify.ai>
tags/0.3.23
@@ -4,6 +4,7 @@ import math
 import random
 import string
 import time
+import uuid
 import click
 from tqdm import tqdm
@@ -23,7 +24,7 @@ from libs.helper import email as email_validate
 from extensions.ext_database import db
 from libs.rsa import generate_key_pair
 from models.account import InvitationCode, Tenant, TenantAccountJoin
-from models.dataset import Dataset, DatasetQuery, Document
+from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
 from models.model import Account, AppModelConfig, App
 import secrets
 import base64
@@ -239,7 +240,13 @@ def clean_unused_dataset_indexes():
                 kw_index = IndexBuilder.get_index(dataset, 'economy')
                 # delete from vector index
                 if vector_index:
-                    vector_index.delete()
+                    if dataset.collection_binding_id:
+                        vector_index.delete_by_group_id(dataset.id)
+                    else:
+                        vector_index.delete()
                 kw_index.delete()
                 # update document
                 update_params = {
@@ -346,7 +353,8 @@ def create_qdrant_indexes():
                             is_valid=True,
                         )
                         model_provider = OpenAIProvider(provider=provider)
-                        embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
+                        embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
+                                                          model_provider=model_provider)
                     embeddings = CacheEmbedding(embedding_model)
                     from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
@@ -364,7 +372,8 @@ def create_qdrant_indexes():
                         index.create_qdrant_dataset(dataset)
                         index_struct = {
                             "type": 'qdrant',
-                            "vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
+                            "vector_store": {
+                                "class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
                         }
                         dataset.index_struct = json.dumps(index_struct)
                         db.session.commit()
@@ -373,7 +382,8 @@ def create_qdrant_indexes():
                         click.echo('passed.')
                 except Exception as e:
                     click.echo(
-                        click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+                        click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                    fg='red'))
                     continue
     click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
@@ -414,7 +424,8 @@ def update_qdrant_indexes():
                             is_valid=True,
                         )
                         model_provider = OpenAIProvider(provider=provider)
-                        embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
+                        embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
+                                                          model_provider=model_provider)
                     embeddings = CacheEmbedding(embedding_model)
                     from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
@@ -435,11 +446,104 @@ def update_qdrant_indexes():
                         click.echo('passed.')
                 except Exception as e:
                     click.echo(
-                        click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+                        click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                    fg='red'))
                     continue
     click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
+
+
+@click.command('normalization-collections', help='restore all collections in one')
+def normalization_collections():
+    click.echo(click.style('Start normalization collections.', fg='green'))
+    normalization_count = 0
+    page = 1
+    while True:
+        try:
+            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
+                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for dataset in datasets:
+            if not dataset.collection_binding_id:
+                try:
+                    click.echo('restore dataset index: {}'.format(dataset.id))
+                    try:
+                        embedding_model = ModelFactory.get_embedding_model(
+                            tenant_id=dataset.tenant_id,
+                            model_provider_name=dataset.embedding_model_provider,
+                            model_name=dataset.embedding_model
+                        )
+                    except Exception:
+                        provider = Provider(
+                            id='provider_id',
+                            tenant_id=dataset.tenant_id,
+                            provider_name='openai',
+                            provider_type=ProviderType.CUSTOM.value,
+                            encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
+                            is_valid=True,
+                        )
+                        model_provider = OpenAIProvider(provider=provider)
+                        embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
+                                                          model_provider=model_provider)
+                    embeddings = CacheEmbedding(embedding_model)
+                    dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                        filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
+                               DatasetCollectionBinding.model_name == embedding_model.name). \
+                        order_by(DatasetCollectionBinding.created_at). \
+                        first()
+
+                    if not dataset_collection_binding:
+                        dataset_collection_binding = DatasetCollectionBinding(
+                            provider_name=embedding_model.model_provider.provider_name,
+                            model_name=embedding_model.name,
+                            collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
+                        )
+                        db.session.add(dataset_collection_binding)
+                        db.session.commit()
+
+                    from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
+
+                    index = QdrantVectorIndex(
+                        dataset=dataset,
+                        config=QdrantConfig(
+                            endpoint=current_app.config.get('QDRANT_URL'),
+                            api_key=current_app.config.get('QDRANT_API_KEY'),
+                            root_path=current_app.root_path
+                        ),
+                        embeddings=embeddings
+                    )
+                    if index:
+                        index.restore_dataset_in_one(dataset, dataset_collection_binding)
+                    else:
+                        click.echo('passed.')
+
+                    original_index = QdrantVectorIndex(
+                        dataset=dataset,
+                        config=QdrantConfig(
+                            endpoint=current_app.config.get('QDRANT_URL'),
+                            api_key=current_app.config.get('QDRANT_API_KEY'),
+                            root_path=current_app.root_path
+                        ),
+                        embeddings=embeddings
+                    )
+                    if original_index:
+                        original_index.delete_original_collection(dataset, dataset_collection_binding)
+                        normalization_count += 1
+                    else:
+                        click.echo('passed.')
+                except Exception as e:
+                    click.echo(
+                        click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                    fg='red'))
+                    continue
+
+    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green'))
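Note: a minimal sketch of exercising the new command in isolation with click's test
runner; this assumes a constructed Flask app whose context provides db and
current_app (normally the command is run through the Flask CLI once
register_commands wires it up below).

    from click.testing import CliRunner

    runner = CliRunner()
    with app.app_context():                      # `app` is an assumed Flask app instance
        result = runner.invoke(normalization_collections)
    print(result.output)                         # ends with "Congratulations! restore N dataset indexes."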
 
 @click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
 @click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
 def update_app_model_configs(batch_size):
@@ -473,7 +577,7 @@ def update_app_model_configs(batch_size):
         .join(App, App.app_model_config_id == AppModelConfig.id) \
         .filter(App.mode == 'completion') \
         .count()
 
     if total_records == 0:
         click.secho("No data to migrate.", fg='green')
         return
@@ -485,14 +589,14 @@ def update_app_model_configs(batch_size):
             offset = i * batch_size
             limit = min(batch_size, total_records - offset)
 
-            click.secho(f"Fetching batch {i+1}/{num_batches} from source database...", fg='green')
+            click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')
 
             data_batch = db.session.query(AppModelConfig) \
                 .join(App, App.app_model_config_id == AppModelConfig.id) \
                 .filter(App.mode == 'completion') \
                 .order_by(App.created_at) \
                 .offset(offset).limit(limit).all()
 
             if not data_batch:
                 click.secho("No more data to migrate.", fg='green')
                 break
@@ -512,7 +616,7 @@ def update_app_model_configs(batch_size):
                     app_data = db.session.query(App) \
                         .filter(App.id == data.app_id) \
                         .one()
 
                     account_data = db.session.query(Account) \
                         .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \
                         .filter(TenantAccountJoin.role == 'owner') \
@@ -534,13 +638,15 @@ def update_app_model_configs(batch_size):
                     db.session.commit()
 
                 except Exception as e:
-                    click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}", fg='red')
+                    click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
+                                fg='red')
                     continue
 
-            click.secho(f"Successfully migrated batch {i+1}/{num_batches}.", fg='green')
+            click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
 
             pbar.update(len(data_batch))
 
 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
@@ -551,4 +657,5 @@ def register_commands(app):
     app.cli.add_command(clean_unused_dataset_indexes)
     app.cli.add_command(create_qdrant_indexes)
     app.cli.add_command(update_qdrant_indexes)
-    app.cli.add_command(update_app_model_configs)
+    app.cli.add_command(update_app_model_configs)
+    app.cli.add_command(normalization_collections)
@@ -16,6 +16,10 @@ class BaseIndex(ABC):
     def create(self, texts: list[Document], **kwargs) -> BaseIndex:
         raise NotImplementedError
 
+    @abstractmethod
+    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
+        raise NotImplementedError
+
     @abstractmethod
     def add_texts(self, texts: list[Document], **kwargs):
         raise NotImplementedError
@@ -28,6 +32,10 @@ class BaseIndex(ABC):
     def delete_by_ids(self, ids: list[str]) -> None:
         raise NotImplementedError
 
+    @abstractmethod
+    def delete_by_group_id(self, group_id: str) -> None:
+        raise NotImplementedError
+
     @abstractmethod
     def delete_by_document_id(self, document_id: str):
         raise NotImplementedError
@@ -46,6 +46,32 @@ class KeywordTableIndex(BaseIndex):
         return self
 
+    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
+        keyword_table_handler = JiebaKeywordTableHandler()
+        keyword_table = {}
+        for text in texts:
+            keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
+            self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
+            keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
+
+        dataset_keyword_table = DatasetKeywordTable(
+            dataset_id=self.dataset.id,
+            keyword_table=json.dumps({
+                '__type__': 'keyword_table',
+                '__data__': {
+                    "index_id": self.dataset.id,
+                    "summary": None,
+                    "table": {}
+                }
+            }, cls=SetEncoder)
+        )
+        db.session.add(dataset_keyword_table)
+        db.session.commit()
+
+        self._save_dataset_keyword_table(keyword_table)
+
+        return self
+
     def add_texts(self, texts: list[Document], **kwargs):
         keyword_table_handler = JiebaKeywordTableHandler()
 
@@ -120,6 +146,12 @@ class KeywordTableIndex(BaseIndex):
             db.session.delete(dataset_keyword_table)
             db.session.commit()
 
+    def delete_by_group_id(self, group_id: str) -> None:
+        dataset_keyword_table = self.dataset.dataset_keyword_table
+        if dataset_keyword_table:
+            db.session.delete(dataset_keyword_table)
+            db.session.commit()
+
     def _save_dataset_keyword_table(self, keyword_table):
         keyword_table_dict = {
             '__type__': 'keyword_table',
@@ -10,7 +10,7 @@ from weaviate import UnexpectedStatusCodeException
 from core.index.base import BaseIndex
 from extensions.ext_database import db
-from models.dataset import Dataset, DocumentSegment
+from models.dataset import Dataset, DocumentSegment, DatasetCollectionBinding
 from models.dataset import Document as DatasetDocument
@@ -110,6 +110,12 @@ class BaseVectorIndex(BaseIndex):
         for node_id in ids:
             vector_store.del_text(node_id)
 
+    def delete_by_group_id(self, group_id: str) -> None:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        vector_store.delete()
+
     def delete(self) -> None:
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -243,3 +249,53 @@ class BaseVectorIndex(BaseIndex):
                 raise e
 
         logging.info(f"Dataset {dataset.id} recreate successfully.")
+
+    def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
+        logging.info(f"restore dataset in_one,_dataset {dataset.id}")
+
+        dataset_documents = db.session.query(DatasetDocument).filter(
+            DatasetDocument.dataset_id == dataset.id,
+            DatasetDocument.indexing_status == 'completed',
+            DatasetDocument.enabled == True,
+            DatasetDocument.archived == False,
+        ).all()
+
+        documents = []
+        for dataset_document in dataset_documents:
+            segments = db.session.query(DocumentSegment).filter(
+                DocumentSegment.document_id == dataset_document.id,
+                DocumentSegment.status == 'completed',
+                DocumentSegment.enabled == True
+            ).all()
+
+            for segment in segments:
+                document = Document(
+                    page_content=segment.content,
+                    metadata={
+                        "doc_id": segment.index_node_id,
+                        "doc_hash": segment.index_node_hash,
+                        "document_id": segment.document_id,
+                        "dataset_id": segment.dataset_id,
+                    }
+                )
+                documents.append(document)
+
+        if documents:
+            try:
+                self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
+            except Exception as e:
+                raise e
+
+        logging.info(f"Dataset {dataset.id} recreate successfully.")
+
+    def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
+        logging.info(f"delete original collection: {dataset.id}")
+
+        self.delete()
+
+        dataset.collection_binding_id = dataset_collection_binding.id
+        db.session.add(dataset)
+        db.session.commit()
+
+        logging.info(f"Dataset {dataset.id} recreate successfully.")
@@ -69,6 +69,19 @@ class MilvusVectorIndex(BaseVectorIndex):
         return self
 
+    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
+        uuids = self._get_uuids(texts)
+        self._vector_store = WeaviateVectorStore.from_documents(
+            texts,
+            self._embeddings,
+            client=self._client,
+            index_name=collection_name,
+            uuids=uuids,
+            by_text=False
+        )
+
+        return self
+
     def _get_vector_store(self) -> VectorStore:
         """Only for created index."""
         if self._vector_store:
@@ -28,6 +28,7 @@ from langchain.docstore.document import Document
 from langchain.embeddings.base import Embeddings
 from langchain.vectorstores import VectorStore
 from langchain.vectorstores.utils import maximal_marginal_relevance
+from qdrant_client.http.models import PayloadSchemaType
 
 if TYPE_CHECKING:
     from qdrant_client import grpc  # noqa
@@ -84,6 +85,7 @@ class Qdrant(VectorStore):
     CONTENT_KEY = "page_content"
     METADATA_KEY = "metadata"
+    GROUP_KEY = "group_id"
     VECTOR_NAME = None
 
     def __init__(
@@ -93,9 +95,12 @@ class Qdrant(VectorStore):
         embeddings: Optional[Embeddings] = None,
         content_payload_key: str = CONTENT_KEY,
         metadata_payload_key: str = METADATA_KEY,
+        group_payload_key: str = GROUP_KEY,
+        group_id: str = None,
         distance_strategy: str = "COSINE",
         vector_name: Optional[str] = VECTOR_NAME,
         embedding_function: Optional[Callable] = None,  # deprecated
+        is_new_collection: bool = False
     ):
         """Initialize with necessary components."""
         try:
@@ -129,7 +134,10 @@ class Qdrant(VectorStore):
         self.collection_name = collection_name
         self.content_payload_key = content_payload_key or self.CONTENT_KEY
         self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
+        self.group_payload_key = group_payload_key or self.GROUP_KEY
         self.vector_name = vector_name or self.VECTOR_NAME
+        self.group_id = group_id
+        self.is_new_collection = is_new_collection
 
         if embedding_function is not None:
             warnings.warn(
@@ -170,6 +178,8 @@ class Qdrant(VectorStore):
             batch_size:
                 How many vectors upload per-request.
                 Default: 64
+            group_id:
+                collection group
 
         Returns:
             List of ids from adding the texts into the vectorstore.
@@ -182,7 +192,11 @@ class Qdrant(VectorStore):
                 collection_name=self.collection_name, points=points, **kwargs
             )
             added_ids.extend(batch_ids)
+        # if is new collection, create payload index on group_id
+        if self.is_new_collection:
+            self.client.create_payload_index(self.collection_name, self.group_payload_key,
+                                             field_schema=PayloadSchemaType.KEYWORD,
+                                             field_type=PayloadSchemaType.KEYWORD)
         return added_ids
 
     @sync_call_fallback
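Note: on a freshly created collection, add_texts now also builds a keyword payload
index over the group key, so group-filtered deletes and searches do not have to scan
the whole shared collection. A standalone sketch of the equivalent call (the
endpoint and collection name are placeholders, not values from this patch):

    from qdrant_client import QdrantClient
    from qdrant_client.http.models import PayloadSchemaType

    client = QdrantClient(url="http://localhost:6333")    # assumed local instance
    client.create_payload_index(
        collection_name="Vector_index_example_Node",      # placeholder shared collection
        field_name="group_id",                            # payload key written on every point
        field_schema=PayloadSchemaType.KEYWORD,           # exact-match keyword index
    )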
@@ -970,6 +984,8 @@ class Qdrant(VectorStore):
         distance_func: str = "Cosine",
         content_payload_key: str = CONTENT_KEY,
         metadata_payload_key: str = METADATA_KEY,
+        group_payload_key: str = GROUP_KEY,
+        group_id: str = None,
         vector_name: Optional[str] = VECTOR_NAME,
         batch_size: int = 64,
         shard_number: Optional[int] = None,
@@ -1034,6 +1050,11 @@ class Qdrant(VectorStore):
             metadata_payload_key:
                 A payload key used to store the metadata of the document.
                 Default: "metadata"
+            group_payload_key:
+                A payload key used to store the group id of the document.
+                Default: "group_id"
+            group_id:
+                collection group id
             vector_name:
                 Name of the vector to be used internally in Qdrant.
                 Default: None
@@ -1107,6 +1128,8 @@ class Qdrant(VectorStore):
             distance_func,
             content_payload_key,
             metadata_payload_key,
+            group_payload_key,
+            group_id,
             vector_name,
             shard_number,
             replication_factor,
@@ -1321,6 +1344,8 @@ class Qdrant(VectorStore):
         distance_func: str = "Cosine",
         content_payload_key: str = CONTENT_KEY,
         metadata_payload_key: str = METADATA_KEY,
+        group_payload_key: str = GROUP_KEY,
+        group_id: str = None,
         vector_name: Optional[str] = VECTOR_NAME,
         shard_number: Optional[int] = None,
         replication_factor: Optional[int] = None,
@@ -1350,6 +1375,7 @@ class Qdrant(VectorStore):
         vector_size = len(partial_embeddings[0])
         collection_name = collection_name or uuid.uuid4().hex
         distance_func = distance_func.upper()
+        is_new_collection = False
         client = qdrant_client.QdrantClient(
             location=location,
             url=url,
@@ -1454,6 +1480,7 @@ class Qdrant(VectorStore):
                 init_from=init_from,
                 timeout=timeout,  # type: ignore[arg-type]
             )
+            is_new_collection = True
         qdrant = cls(
             client=client,
             collection_name=collection_name,
@@ -1462,6 +1489,9 @@ class Qdrant(VectorStore):
             metadata_payload_key=metadata_payload_key,
             distance_strategy=distance_func,
             vector_name=vector_name,
+            group_id=group_id,
+            group_payload_key=group_payload_key,
+            is_new_collection=is_new_collection
         )
         return qdrant
@@ -1516,6 +1546,8 @@ class Qdrant(VectorStore):
         metadatas: Optional[List[dict]],
         content_payload_key: str,
         metadata_payload_key: str,
+        group_id: str,
+        group_payload_key: str
     ) -> List[dict]:
         payloads = []
         for i, text in enumerate(texts):
@@ -1529,6 +1561,7 @@ class Qdrant(VectorStore):
                 {
                     content_payload_key: text,
                     metadata_payload_key: metadata,
+                    group_payload_key: group_id
                 }
             )
@@ -1578,7 +1611,7 @@ class Qdrant(VectorStore):
             else:
                 out.append(
                     rest.FieldCondition(
-                        key=f"{self.metadata_payload_key}.{key}",
+                        key=key,
                         match=rest.MatchValue(value=value),
                     )
                 )
@@ -1654,6 +1687,7 @@ class Qdrant(VectorStore):
         metadatas: Optional[List[dict]] = None,
         ids: Optional[Sequence[str]] = None,
         batch_size: int = 64,
+        group_id: Optional[str] = None,
     ) -> Generator[Tuple[List[str], List[rest.PointStruct]], None, None]:
         from qdrant_client.http import models as rest
@@ -1684,6 +1718,8 @@ class Qdrant(VectorStore):
                         batch_metadatas,
                         self.content_payload_key,
                         self.metadata_payload_key,
+                        self.group_id,
+                        self.group_payload_key
                     ),
                 )
             ]
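Note: taken together, _build_payloads and _generate_rest_batches mean each point in
a shared collection now carries a payload of roughly this shape (illustrative
values; keys come from CONTENT_KEY, METADATA_KEY and GROUP_KEY above):

    payload = {
        "page_content": "chunk text ...",
        "metadata": {
            "doc_id": "...", "doc_hash": "...",
            "document_id": "...", "dataset_id": "...",
        },
        "group_id": "<owning dataset uuid>",   # the key the new payload index covers
    }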
@@ -6,18 +6,20 @@ from langchain.embeddings.base import Embeddings
 from langchain.schema import Document, BaseRetriever
 from langchain.vectorstores import VectorStore
 from pydantic import BaseModel
+from qdrant_client.http.models import HnswConfigDiff
 
 from core.index.base import BaseIndex
 from core.index.vector_index.base import BaseVectorIndex
 from core.vector_store.qdrant_vector_store import QdrantVectorStore
-from models.dataset import Dataset
+from extensions.ext_database import db
+from models.dataset import Dataset, DatasetCollectionBinding
 
 
 class QdrantConfig(BaseModel):
     endpoint: str
     api_key: Optional[str]
     root_path: Optional[str]
 
     def to_qdrant_params(self):
         if self.endpoint and self.endpoint.startswith('path:'):
             path = self.endpoint.replace('path:', '')
@@ -43,16 +45,21 @@ class QdrantVectorIndex(BaseVectorIndex):
         return 'qdrant'
 
     def get_index_name(self, dataset: Dataset) -> str:
-        if self.dataset.index_struct_dict:
-            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
-            if not class_prefix.endswith('_Node'):
-                # original class_prefix
-                class_prefix += '_Node'
-            return class_prefix
+        if dataset.collection_binding_id:
+            dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
+                one_or_none()
+            if dataset_collection_binding:
+                return dataset_collection_binding.collection_name
+            else:
+                raise ValueError('Dataset Collection Bindings is not exist!')
+        else:
+            if self.dataset.index_struct_dict:
+                class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+                return class_prefix
 
-        dataset_id = dataset.id
-        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+            dataset_id = dataset.id
+            return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
 
     def to_index_struct(self) -> dict:
         return {
@@ -68,6 +75,27 @@ class QdrantVectorIndex(BaseVectorIndex):
             collection_name=self.get_index_name(self.dataset),
             ids=uuids,
             content_payload_key='page_content',
+            group_id=self.dataset.id,
+            group_payload_key='group_id',
+            hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
+                                       max_indexing_threads=0, on_disk=False),
             **self._client_config.to_qdrant_params()
         )
 
         return self
 
+    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
+        uuids = self._get_uuids(texts)
+        self._vector_store = QdrantVectorStore.from_documents(
+            texts,
+            self._embeddings,
+            collection_name=collection_name,
+            ids=uuids,
+            content_payload_key='page_content',
+            group_id=self.dataset.id,
+            group_payload_key='group_id',
+            hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
+                                       max_indexing_threads=0, on_disk=False),
+            **self._client_config.to_qdrant_params()
+        )
+
+        return self
@@ -78,8 +106,6 @@ class QdrantVectorIndex(BaseVectorIndex):
         if self._vector_store:
             return self._vector_store
 
         attributes = ['doc_id', 'dataset_id', 'document_id']
-        if self._is_origin():
-            attributes = ['doc_id']
         client = qdrant_client.QdrantClient(
             **self._client_config.to_qdrant_params()
         )
@@ -88,16 +114,15 @@ class QdrantVectorIndex(BaseVectorIndex):
             client=client,
             collection_name=self.get_index_name(self.dataset),
             embeddings=self._embeddings,
-            content_payload_key='page_content'
+            content_payload_key='page_content',
+            group_id=self.dataset.id,
+            group_payload_key='group_id'
         )
 
     def _get_vector_store_class(self) -> type:
         return QdrantVectorStore
 
     def delete_by_document_id(self, document_id: str):
-        if self._is_origin():
-            self.recreate_dataset(self.dataset)
-            return
 
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -114,9 +139,6 @@ class QdrantVectorIndex(BaseVectorIndex):
         ))
 
     def delete_by_ids(self, ids: list[str]) -> None:
-        if self._is_origin():
-            self.recreate_dataset(self.dataset)
-            return
 
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -132,6 +154,22 @@ class QdrantVectorIndex(BaseVectorIndex):
             ],
         ))
 
+    def delete_by_group_id(self, group_id: str) -> None:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        from qdrant_client.http import models
+        vector_store.del_texts(models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="group_id",
+                    match=models.MatchValue(value=group_id),
+                ),
+            ],
+        ))
+
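Note: delete_by_group_id removes only this dataset's points from the shared
collection instead of dropping the collection. A standalone sketch of roughly the
same operation against a bare qdrant client (placeholder endpoint and collection
name; del_texts above applies an equivalent filter-scoped delete):

    from qdrant_client import QdrantClient
    from qdrant_client.http import models

    client = QdrantClient(url="http://localhost:6333")     # placeholder endpoint
    client.delete(
        collection_name="Vector_index_shared_Node",        # placeholder shared collection
        points_selector=models.FilterSelector(
            filter=models.Filter(must=[
                models.FieldCondition(key="group_id",
                                      match=models.MatchValue(value="<dataset-id>")),
            ])
        ),
    )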
     def _is_origin(self):
         if self.dataset.index_struct_dict:
             class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
@@ -91,6 +91,20 @@ class WeaviateVectorIndex(BaseVectorIndex):
         return self
 
+    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
+        uuids = self._get_uuids(texts)
+        self._vector_store = WeaviateVectorStore.from_documents(
+            texts,
+            self._embeddings,
+            client=self._client,
+            index_name=self.get_index_name(self.dataset),
+            uuids=uuids,
+            by_text=False
+        )
+
+        return self
+
     def _get_vector_store(self) -> VectorStore:
         """Only for created index."""
         if self._vector_store:
@@ -33,7 +33,6 @@ class DatasetRetrieverTool(BaseTool):
     return_resource: str
     retriever_from: str
 
-
     @classmethod
     def from_dataset(cls, dataset: Dataset, **kwargs):
         description = dataset.description
@@ -94,7 +93,10 @@ class DatasetRetrieverTool(BaseTool):
                 query,
                 search_type='similarity_score_threshold',
                 search_kwargs={
-                    'k': self.k
+                    'k': self.k,
+                    'filter': {
+                        'group_id': [dataset.id]
+                    }
                 }
             )
         else:
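Note: this pairs with the key=f"{self.metadata_payload_key}.{key}" -> key=key change
above: a search kwarg of 'filter': {'group_id': [dataset.id]} now matches the
top-level group_id payload key rather than metadata.group_id, scoping retrieval to
the calling dataset inside the shared collection. The resulting condition is
roughly this shape (a sketch; the exact expansion of list values depends on the
filter builder):

    from qdrant_client.http import models as rest

    condition = rest.FieldCondition(
        key="group_id",                          # top-level payload key, not metadata.group_id
        match=rest.MatchValue(value=dataset.id),
    )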
@@ -46,6 +46,11 @@ class QdrantVectorStore(Qdrant):
         self.client.delete_collection(collection_name=self.collection_name)
 
+    def delete_group(self):
+        self._reload_if_needed()
+
+        self.client.delete_collection(collection_name=self.collection_name)
+
     @classmethod
     def _document_from_scored_point(
         cls,
@@ -0,0 +1,47 @@
+"""add_dataset_collection_binding
+
+Revision ID: 6e2cfb077b04
+Revises: 77e83833755c
+Create Date: 2023-09-13 22:16:48.027810
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '6e2cfb077b04'
+down_revision = '77e83833755c'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('dataset_collection_bindings',
+        sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+        sa.Column('provider_name', sa.String(length=40), nullable=False),
+        sa.Column('model_name', sa.String(length=40), nullable=False),
+        sa.Column('collection_name', sa.String(length=64), nullable=False),
+        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+        sa.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey')
+    )
+    with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
+        batch_op.create_index('provider_model_name_idx', ['provider_name', 'model_name'], unique=False)
+
+    with op.batch_alter_table('datasets', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('collection_binding_id', postgresql.UUID(), nullable=True))
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('datasets', schema=None) as batch_op:
+        batch_op.drop_column('collection_binding_id')
+
+    with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
+        batch_op.drop_index('provider_model_name_idx')
+
+    op.drop_table('dataset_collection_bindings')
+    # ### end Alembic commands ###
@@ -38,6 +38,8 @@ class Dataset(db.Model):
                            server_default=db.text('CURRENT_TIMESTAMP(0)'))
     embedding_model = db.Column(db.String(255), nullable=True)
     embedding_model_provider = db.Column(db.String(255), nullable=True)
+    collection_binding_id = db.Column(UUID, nullable=True)
+
     @property
     def dataset_keyword_table(self):
@@ -445,3 +447,19 @@ class Embedding(db.Model):
     def get_embedding(self) -> list[float]:
         return pickle.loads(self.embedding)
+
+
+class DatasetCollectionBinding(db.Model):
+    __tablename__ = 'dataset_collection_bindings'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey'),
+        db.Index('provider_model_name_idx', 'provider_name', 'model_name')
+    )
+
+    id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
+    provider_name = db.Column(db.String(40), nullable=False)
+    model_name = db.Column(db.String(40), nullable=False)
+    collection_name = db.Column(db.String(64), nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@@ -20,7 +20,8 @@ from events.document_event import document_was_deleted
 from extensions.ext_database import db
 from libs import helper
 from models.account import Account
-from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment
+from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment, \
+    DatasetCollectionBinding
 from models.model import UploadFile
 from models.source import DataSourceBinding
 from services.errors.account import NoPermissionError
@@ -147,6 +148,7 @@ class DatasetService:
                 action = 'remove'
                 filtered_data['embedding_model'] = None
                 filtered_data['embedding_model_provider'] = None
+                filtered_data['collection_binding_id'] = None
             elif data['indexing_technique'] == 'high_quality':
                 action = 'add'
                 # get embedding model setting
@@ -156,6 +158,11 @@ class DatasetService:
                     )
                     filtered_data['embedding_model'] = embedding_model.name
                     filtered_data['embedding_model_provider'] = embedding_model.model_provider.provider_name
+                    dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+                        embedding_model.model_provider.provider_name,
+                        embedding_model.name
+                    )
+                    filtered_data['collection_binding_id'] = dataset_collection_binding.id
                 except LLMBadRequestError:
                     raise ValueError(
                         f"No Embedding Model available. Please configure a valid provider "
@@ -464,7 +471,11 @@ class DocumentService:
             )
             dataset.embedding_model = embedding_model.name
             dataset.embedding_model_provider = embedding_model.model_provider.provider_name
+            dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+                embedding_model.model_provider.provider_name,
+                embedding_model.name
+            )
+            dataset.collection_binding_id = dataset_collection_binding.id
 
         documents = []
         batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
@@ -720,10 +731,16 @@ class DocumentService:
         if total_count > tenant_document_count:
             raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
 
         embedding_model = None
+        dataset_collection_binding_id = None
         if document_data['indexing_technique'] == 'high_quality':
             embedding_model = ModelFactory.get_embedding_model(
                 tenant_id=tenant_id
             )
+            dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+                embedding_model.model_provider.provider_name,
+                embedding_model.name
+            )
+            dataset_collection_binding_id = dataset_collection_binding.id
 
         # save dataset
         dataset = Dataset(
             tenant_id=tenant_id,
@@ -732,7 +749,8 @@ class DocumentService:
             indexing_technique=document_data["indexing_technique"],
             created_by=account.id,
             embedding_model=embedding_model.name if embedding_model else None,
-            embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None
+            embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None,
+            collection_binding_id=dataset_collection_binding_id
         )
 
         db.session.add(dataset)
@@ -1069,3 +1087,23 @@ class SegmentService:
             delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
         db.session.delete(segment)
         db.session.commit()
+
+
+class DatasetCollectionBindingService:
+    @classmethod
+    def get_dataset_collection_binding(cls, provider_name: str, model_name: str) -> DatasetCollectionBinding:
+        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+            filter(DatasetCollectionBinding.provider_name == provider_name,
+                   DatasetCollectionBinding.model_name == model_name). \
+            order_by(DatasetCollectionBinding.created_at). \
+            first()
+
+        if not dataset_collection_binding:
+            dataset_collection_binding = DatasetCollectionBinding(
+                provider_name=provider_name,
+                model_name=model_name,
+                collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
+            )
+            db.session.add(dataset_collection_binding)
+            db.session.flush()
+
+        return dataset_collection_binding
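Note: a caller-side sketch of the new service, mirroring its use in DatasetService
above; get_dataset_collection_binding only flushes, so committing the transaction
is left to the caller:

    binding = DatasetCollectionBindingService.get_dataset_collection_binding(
        'openai',                      # provider_name, as stored on the dataset
        'text-embedding-ada-002',      # model_name
    )
    dataset.collection_binding_id = binding.id
    db.session.commit()                # the service itself only db.session.flush()es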
@@ -47,7 +47,10 @@ class HitTestingService:
             query,
             search_type='similarity_score_threshold',
             search_kwargs={
-                'k': 10
+                'k': 10,
+                'filter': {
+                    'group_id': [dataset.id]
+                }
             }
         )
 
         end = time.perf_counter()