@@ -1,4 +1,5 @@
 import datetime
+import json
 import math
 import random
 import string
@@ -6,10 +7,16 @@ import time

 import click
 from flask import current_app
 from langchain.embeddings import OpenAIEmbeddings
 from werkzeug.exceptions import NotFound

+from core.embedding.cached_embedding import CacheEmbedding
 from core.index.index import IndexBuilder
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.providers.hosted import hosted_model_providers
+from core.model_providers.providers.openai_provider import OpenAIProvider
 from libs.password import password_pattern, valid_password, hash_password
 from libs.helper import email as email_validate
 from extensions.ext_database import db
@@ -296,6 +303,66 @@ def sync_anthropic_hosted_providers():
     click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))


+@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
+def create_qdrant_indexes():
+    click.echo(click.style('Start create qdrant indexes.', fg='green'))
+    create_count = 0
+
+    page = 1
+    while True:
+        try:
+            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
+                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for dataset in datasets:
+            try:
+                click.echo('Create dataset qdrant index: {}'.format(dataset.id))
+                try:
+                    embedding_model = ModelFactory.get_embedding_model(
+                        tenant_id=dataset.tenant_id,
+                        model_provider_name=dataset.embedding_model_provider,
+                        model_name=dataset.embedding_model
+                    )
+                except Exception:
+                    provider = Provider(
+                        id='provider_id',
+                        tenant_id='tenant_id',
+                        provider_name='openai',
+                        provider_type=ProviderType.CUSTOM.value,
+                        encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
+                        is_valid=True,
+                    )
+                    model_provider = OpenAIProvider(provider=provider)
+                    embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
+
+                embeddings = CacheEmbedding(embedding_model)
+
+                from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
+
+                index = QdrantVectorIndex(
+                    dataset=dataset,
+                    config=QdrantConfig(
+                        endpoint=current_app.config.get('QDRANT_URL'),
+                        api_key=current_app.config.get('QDRANT_API_KEY'),
+                        root_path=current_app.root_path
+                    ),
+                    embeddings=embeddings
+                )
+                if index:
+                    index.create_qdrant_dataset(dataset)
+                    create_count += 1
+                else:
+                    click.echo('passed.')
+            except Exception as e:
+                click.echo(
+                    click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+                continue
+
+    click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
+
+
 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)

@@ -304,3 +371,4 @@ def register_commands(app):
     app.cli.add_command(recreate_all_dataset_indexes)
     app.cli.add_command(sync_anthropic_hosted_providers)
     app.cli.add_command(clean_unused_dataset_indexes)
+    app.cli.add_command(create_qdrant_indexes)
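
Once registered on `app.cli`, the new command is exposed through Flask's CLI. A minimal invocation sketch, assuming `create_app` is the project's Flask app factory (an assumption, not shown in this diff); in practice the command would be run as `flask create-qdrant-indexes` from the api directory:

```python
# Hypothetical invocation via Flask's test CLI runner; `create_app` is assumed
# to be the project's app factory.
from app import create_app

app = create_app()
runner = app.test_cli_runner()
result = runner.invoke(args=['create-qdrant-indexes'])
print(result.output)  # 'Start create qdrant indexes.' ... final created count
```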
@@ -38,7 +38,7 @@ class ExcelLoader(BaseLoader):
                 else:
                     row_dict = dict(zip(keys, list(map(str, row))))
                     row_dict = {k: v for k, v in row_dict.items() if v}
-                    item = ''.join(f'{k}:{v}\n' for k, v in row_dict.items())
+                    item = ''.join(f'{k}:{v};' for k, v in row_dict.items())
                     document = Document(page_content=item, metadata={'source': self._file_path})
                     data.append(document)
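
The only change here is the pair separator in each row's serialized text: newline-terminated pairs become semicolon-terminated pairs, so every spreadsheet row collapses to a single line of page content. A small illustration with made-up values:

```python
# Illustrative row from a spreadsheet (values are made up).
row_dict = {'name': 'Alice', 'age': '30'}

before = ''.join(f'{k}:{v}\n' for k, v in row_dict.items())
# 'name:Alice\nage:30\n'  -> one key/value pair per line

after = ''.join(f'{k}:{v};' for k, v in row_dict.items())
# 'name:Alice;age:30;'    -> the whole row on one line
```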
@@ -173,3 +173,49 @@ class BaseVectorIndex(BaseIndex):
         self.dataset = dataset

         logging.info(f"Dataset {dataset.id} recreate successfully.")
+
+    def create_qdrant_dataset(self, dataset: Dataset):
+        logging.info(f"create_qdrant_dataset {dataset.id}")
+
+        try:
+            self.delete()
+        except UnexpectedStatusCodeException as e:
+            if e.status_code != 400:
+                # 400 means index not exists
+                raise e
+
+        dataset_documents = db.session.query(DatasetDocument).filter(
+            DatasetDocument.dataset_id == dataset.id,
+            DatasetDocument.indexing_status == 'completed',
+            DatasetDocument.enabled == True,
+            DatasetDocument.archived == False,
+        ).all()
+
+        documents = []
+        for dataset_document in dataset_documents:
+            segments = db.session.query(DocumentSegment).filter(
+                DocumentSegment.document_id == dataset_document.id,
+                DocumentSegment.status == 'completed',
+                DocumentSegment.enabled == True
+            ).all()
+
+            for segment in segments:
+                document = Document(
+                    page_content=segment.content,
+                    metadata={
+                        "doc_id": segment.index_node_id,
+                        "doc_hash": segment.index_node_hash,
+                        "document_id": segment.document_id,
+                        "dataset_id": segment.dataset_id,
+                    }
+                )
+
+                documents.append(document)
+
+        if documents:
+            try:
+                self.create(documents)
+            except Exception as e:
+                raise e
+
+        logging.info(f"Dataset {dataset.id} recreate successfully.")
@@ -0,0 +1,114 @@
+from typing import Optional, cast
+
+from langchain.embeddings.base import Embeddings
+from langchain.schema import Document, BaseRetriever
+from langchain.vectorstores import VectorStore, milvus
+from pydantic import BaseModel, root_validator
+
+from core.index.base import BaseIndex
+from core.index.vector_index.base import BaseVectorIndex
+from core.vector_store.milvus_vector_store import MilvusVectorStore
+from core.vector_store.weaviate_vector_store import WeaviateVectorStore
+from models.dataset import Dataset
+
+
+class MilvusConfig(BaseModel):
+    endpoint: str
+    user: str
+    password: str
+    batch_size: int = 100
+
+    @root_validator()
+    def validate_config(cls, values: dict) -> dict:
+        if not values['endpoint']:
+            raise ValueError("config MILVUS_ENDPOINT is required")
+        if not values['user']:
+            raise ValueError("config MILVUS_USER is required")
+        if not values['password']:
+            raise ValueError("config MILVUS_PASSWORD is required")
+        return values
+
+
+class MilvusVectorIndex(BaseVectorIndex):
+    def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
+        super().__init__(dataset, embeddings)
+        self._client = self._init_client(config)
+
+    def get_type(self) -> str:
+        return 'milvus'
+
+    def get_index_name(self, dataset: Dataset) -> str:
+        if self.dataset.index_struct_dict:
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+            if not class_prefix.endswith('_Node'):
+                # original class_prefix
+                class_prefix += '_Node'
+
+            return class_prefix
+
+        dataset_id = dataset.id
+        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+
+    def to_index_struct(self) -> dict:
+        return {
+            "type": self.get_type(),
+            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
+        }
+
+    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
+        uuids = self._get_uuids(texts)
+        self._vector_store = WeaviateVectorStore.from_documents(
+            texts,
+            self._embeddings,
+            client=self._client,
+            index_name=self.get_index_name(self.dataset),
+            uuids=uuids,
+            by_text=False
+        )
+
+        return self
+
+    def _get_vector_store(self) -> VectorStore:
+        """Only for created index."""
+        if self._vector_store:
+            return self._vector_store
+
+        attributes = ['doc_id', 'dataset_id', 'document_id']
+        if self._is_origin():
+            attributes = ['doc_id']
+
+        return WeaviateVectorStore(
+            client=self._client,
+            index_name=self.get_index_name(self.dataset),
+            text_key='text',
+            embedding=self._embeddings,
+            attributes=attributes,
+            by_text=False
+        )
+
+    def _get_vector_store_class(self) -> type:
+        return MilvusVectorStore
+
+    def delete_by_document_id(self, document_id: str):
+        if self._is_origin():
+            self.recreate_dataset(self.dataset)
+            return
+
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        vector_store.del_texts({
+            "operator": "Equal",
+            "path": ["document_id"],
+            "valueText": document_id
+        })
+
+    def _is_origin(self):
+        if self.dataset.index_struct_dict:
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+            if not class_prefix.endswith('_Node'):
+                # original class_prefix
+                return True
+
+        return False
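
Note that `create()` and `_get_vector_store()` in this new Milvus index still construct a `WeaviateVectorStore` with Weaviate-specific arguments (`client=`, `uuids=`, `by_text=`), even though `_get_vector_store_class()` returns `MilvusVectorStore`; the methods appear to be carried over from the Weaviate index. A minimal sketch of what a Milvus-backed `create()` might look like instead, assuming the config is kept on the instance (e.g. as `self._client_config`) and a hypothetical `to_milvus_params()` helper maps `MilvusConfig` to pymilvus connection args (both are assumptions, not part of this diff):

```python
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
    # Hypothetical Milvus-backed variant. langchain's Milvus.from_documents
    # accepts connection_args/collection_name; to_milvus_params() is an assumed
    # helper, e.g. returning {'uri': endpoint, 'user': user, 'password': password}.
    self._vector_store = MilvusVectorStore.from_documents(
        texts,
        self._embeddings,
        collection_name=self.get_index_name(self.dataset),
        connection_args=self._client_config.to_milvus_params(),
    )

    return self
```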
@@ -44,15 +44,20 @@ class QdrantVectorIndex(BaseVectorIndex):
     def get_index_name(self, dataset: Dataset) -> str:
         if self.dataset.index_struct_dict:
-            return self.dataset.index_struct_dict['vector_store']['collection_name']
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+            if not class_prefix.endswith('_Node'):
+                # original class_prefix
+                class_prefix += '_Node'
+
+            return class_prefix

         dataset_id = dataset.id
-        return "Index_" + dataset_id.replace("-", "_")
+        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

     def to_index_struct(self) -> dict:
         return {
             "type": self.get_type(),
-            "vector_store": {"collection_name": self.get_index_name(self.dataset)}
+            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
         }

     def create(self, texts: list[Document], **kwargs) -> BaseIndex:

@@ -62,7 +67,7 @@ class QdrantVectorIndex(BaseVectorIndex):
             self._embeddings,
             collection_name=self.get_index_name(self.dataset),
             ids=uuids,
-            content_payload_key='text',
+            content_payload_key='page_content',
             **self._client_config.to_qdrant_params()
         )

@@ -72,7 +77,9 @@ class QdrantVectorIndex(BaseVectorIndex):
         """Only for created index."""
         if self._vector_store:
             return self._vector_store

         attributes = ['doc_id', 'dataset_id', 'document_id']
+        if self._is_origin():
+            attributes = ['doc_id']

         client = qdrant_client.QdrantClient(
             **self._client_config.to_qdrant_params()
         )

@@ -81,7 +88,7 @@ class QdrantVectorIndex(BaseVectorIndex):
             client=client,
             collection_name=self.get_index_name(self.dataset),
             embeddings=self._embeddings,
-            content_payload_key='text'
+            content_payload_key='page_content'
         )

     def _get_vector_store_class(self) -> type:

@@ -108,8 +115,8 @@ class QdrantVectorIndex(BaseVectorIndex):
     def _is_origin(self):
         if self.dataset.index_struct_dict:
-            class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
-            if class_prefix.startswith('Vector_'):
+            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
+            if not class_prefix.endswith('_Node'):
                 # original class_prefix
                 return True

         return False
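
Two things change in the Qdrant index: the payload key under which langchain's Qdrant wrapper stores the chunk text moves from 'text' to 'page_content' on both the write path (`create()`) and the read path (`_get_vector_store()`), since a collection written under one key cannot be read back under the other; and `_is_origin()` now detects legacy collections by the absence of the '_Node' suffix rather than by the 'Vector_' prefix. A quick way to check which payload key a collection actually uses, with an assumed local endpoint and a hypothetical collection name:

```python
import qdrant_client

# Assumed endpoint and hypothetical collection name, for illustration only.
client = qdrant_client.QdrantClient(url='http://localhost:6333')
points, _ = client.scroll(
    collection_name='Vector_index_example_Node',
    limit=1,
    with_payload=True
)
if points:
    # Expect 'page_content' (plus 'metadata') after this change.
    print(points[0].payload.keys())
```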
@@ -0,0 +1,38 @@
+from langchain.vectorstores import Milvus
+
+
+class MilvusVectorStore(Milvus):
+    def del_texts(self, where_filter: dict):
+        if not where_filter:
+            raise ValueError('where_filter must not be empty')
+
+        self._client.batch.delete_objects(
+            class_name=self._index_name,
+            where=where_filter,
+            output='minimal'
+        )
+
+    def del_text(self, uuid: str) -> None:
+        self._client.data_object.delete(
+            uuid,
+            class_name=self._index_name
+        )
+
+    def text_exists(self, uuid: str) -> bool:
+        result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
+            "path": ["doc_id"],
+            "operator": "Equal",
+            "valueText": uuid,
+        }).with_limit(1).do()
+
+        if "errors" in result:
+            raise ValueError(f"Error during query: {result['errors']}")
+
+        entries = result["data"]["Get"][self._index_name]
+        if len(entries) == 0:
+            return False
+
+        return True
+
+    def delete(self):
+        self._client.schema.delete_class(self._index_name)
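
This helper has the same issue as the index above: every method calls Weaviate's client API (`batch.delete_objects`, `data_object.delete`, `query.get(...).with_where(...)`, `schema.delete_class`), while langchain's `Milvus` base keeps a pymilvus `Collection` on `self.col` and has no `_client` or `_index_name` attributes. A hedged sketch of pymilvus-backed equivalents (deletion by boolean expression is pymilvus's native filter form; this is an illustration, not part of the diff):

```python
from langchain.vectorstores import Milvus
from pymilvus import utility


class MilvusVectorStoreSketch(Milvus):
    def del_texts(self, expr: str):
        # pymilvus deletes by boolean expression, e.g. 'document_id == "abc"'.
        if not expr:
            raise ValueError('expr must not be empty')
        self.col.delete(expr)

    def delete(self):
        # Drop the whole collection backing this store.
        utility.drop_collection(self.collection_name)
```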
@@ -1,10 +1,11 @@
 from typing import cast, Any

 from langchain.schema import Document
-from langchain.vectorstores import Qdrant
 from qdrant_client.http.models import Filter, PointIdsList, FilterSelector
 from qdrant_client.local.qdrant_local import QdrantLocal
+
+from core.index.vector_index.qdrant import Qdrant


 class QdrantVectorStore(Qdrant):
     def del_texts(self, filter: Filter):