| from constants.languages import languages | from constants.languages import languages | ||||
| from core.rag.datasource.vdb.vector_factory import Vector | from core.rag.datasource.vdb.vector_factory import Vector | ||||
| from core.rag.datasource.vdb.vector_type import VectorType | |||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from libs.helper import email as email_validate | from libs.helper import email as email_validate | ||||
| skipped_count = skipped_count + 1 | skipped_count = skipped_count + 1 | ||||
| continue | continue | ||||
| collection_name = '' | collection_name = '' | ||||
| if vector_type == "weaviate": | |||||
| if vector_type == VectorType.WEAVIATE: | |||||
| dataset_id = dataset.id | dataset_id = dataset.id | ||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id) | collection_name = Dataset.gen_collection_name_by_id(dataset_id) | ||||
| index_struct_dict = { | index_struct_dict = { | ||||
| "type": 'weaviate', | |||||
| "type": VectorType.WEAVIATE, | |||||
| "vector_store": {"class_prefix": collection_name} | "vector_store": {"class_prefix": collection_name} | ||||
| } | } | ||||
| dataset.index_struct = json.dumps(index_struct_dict) | dataset.index_struct = json.dumps(index_struct_dict) | ||||
| elif vector_type == "qdrant": | |||||
| elif vector_type == VectorType.QDRANT: | |||||
| if dataset.collection_binding_id: | if dataset.collection_binding_id: | ||||
| dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ | dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ | ||||
| filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \ | filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \ | ||||
| dataset_id = dataset.id | dataset_id = dataset.id | ||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id) | collection_name = Dataset.gen_collection_name_by_id(dataset_id) | ||||
| index_struct_dict = { | index_struct_dict = { | ||||
| "type": 'qdrant', | |||||
| "type": VectorType.QDRANT, | |||||
| "vector_store": {"class_prefix": collection_name} | "vector_store": {"class_prefix": collection_name} | ||||
| } | } | ||||
| dataset.index_struct = json.dumps(index_struct_dict) | dataset.index_struct = json.dumps(index_struct_dict) | ||||
| elif vector_type == "milvus": | |||||
| elif vector_type == VectorType.MILVUS: | |||||
| dataset_id = dataset.id | dataset_id = dataset.id | ||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id) | collection_name = Dataset.gen_collection_name_by_id(dataset_id) | ||||
| index_struct_dict = { | index_struct_dict = { | ||||
| "type": 'milvus', | |||||
| "type": VectorType.MILVUS, | |||||
| "vector_store": {"class_prefix": collection_name} | "vector_store": {"class_prefix": collection_name} | ||||
| } | } | ||||
| dataset.index_struct = json.dumps(index_struct_dict) | dataset.index_struct = json.dumps(index_struct_dict) | ||||
| elif vector_type == "relyt": | |||||
| elif vector_type == VectorType.RELYT: | |||||
| dataset_id = dataset.id | dataset_id = dataset.id | ||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id) | collection_name = Dataset.gen_collection_name_by_id(dataset_id) | ||||
| index_struct_dict = { | index_struct_dict = { | ||||
| "vector_store": {"class_prefix": collection_name} | "vector_store": {"class_prefix": collection_name} | ||||
| } | } | ||||
| dataset.index_struct = json.dumps(index_struct_dict) | dataset.index_struct = json.dumps(index_struct_dict) | ||||
| elif vector_type == "pgvector": | |||||
| elif vector_type == VectorType.PGVECTOR: | |||||
| dataset_id = dataset.id | dataset_id = dataset.id | ||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id) | collection_name = Dataset.gen_collection_name_by_id(dataset_id) | ||||
| index_struct_dict = { | index_struct_dict = { | ||||
| "type": 'pgvector', | |||||
| "type": VectorType.PGVECTOR, | |||||
| "vector_store": {"class_prefix": collection_name} | "vector_store": {"class_prefix": collection_name} | ||||
| } | } | ||||
| dataset.index_struct = json.dumps(index_struct_dict) | dataset.index_struct = json.dumps(index_struct_dict) | ||||
| else: | else: | ||||
| raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") | |||||
| raise ValueError(f"Vector store {vector_type} is not supported.") | |||||
| vector = Vector(dataset) | vector = Vector(dataset) | ||||
| click.echo(f"Start to migrate dataset {dataset.id}.") | click.echo(f"Start to migrate dataset {dataset.id}.") |
| from core.indexing_runner import IndexingRunner | from core.indexing_runner import IndexingRunner | ||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from core.provider_manager import ProviderManager | from core.provider_manager import ProviderManager | ||||
| from core.rag.datasource.vdb.vector_type import VectorType | |||||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | from core.rag.extractor.entity.extract_setting import ExtractSetting | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from fields.app_fields import related_app_list | from fields.app_fields import related_app_list | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def get(self): | def get(self): | ||||
| vector_type = current_app.config['VECTOR_STORE'] | vector_type = current_app.config['VECTOR_STORE'] | ||||
| if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs", 'tidb_vector'}: | |||||
| return { | |||||
| 'retrieval_method': [ | |||||
| 'semantic_search' | |||||
| ] | |||||
| } | |||||
| elif vector_type in {"qdrant", "weaviate"}: | |||||
| return { | |||||
| 'retrieval_method': [ | |||||
| 'semantic_search', 'full_text_search', 'hybrid_search' | |||||
| ] | |||||
| } | |||||
| else: | |||||
| raise ValueError("Unsupported vector db type.") | |||||
| match vector_type: | |||||
| case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR: | |||||
| return { | |||||
| 'retrieval_method': [ | |||||
| 'semantic_search' | |||||
| ] | |||||
| } | |||||
| case VectorType.QDRANT | VectorType.WEAVIATE: | |||||
| return { | |||||
| 'retrieval_method': [ | |||||
| 'semantic_search', 'full_text_search', 'hybrid_search' | |||||
| ] | |||||
| } | |||||
| case _: | |||||
| raise ValueError(f"Unsupported vector db type {vector_type}.") | |||||
| class DatasetRetrievalSettingMockApi(Resource): | class DatasetRetrievalSettingMockApi(Resource): | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def get(self, vector_type): | def get(self, vector_type): | ||||
| if vector_type in {'milvus', 'relyt', 'pgvector', 'tidb_vector'}: | |||||
| return { | |||||
| 'retrieval_method': [ | |||||
| 'semantic_search' | |||||
| ] | |||||
| } | |||||
| elif vector_type in {'qdrant', 'weaviate'}: | |||||
| return { | |||||
| 'retrieval_method': [ | |||||
| 'semantic_search', 'full_text_search', 'hybrid_search' | |||||
| ] | |||||
| } | |||||
| else: | |||||
| raise ValueError("Unsupported vector db type.") | |||||
| match vector_type: | |||||
| case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR: | |||||
| return { | |||||
| 'retrieval_method': [ | |||||
| 'semantic_search' | |||||
| ] | |||||
| } | |||||
| case VectorType.QDRANT | VectorType.WEAVIATE: | |||||
| return { | |||||
| 'retrieval_method': [ | |||||
| 'semantic_search', 'full_text_search', 'hybrid_search' | |||||
| ] | |||||
| } | |||||
| case _: | |||||
| raise ValueError(f"Unsupported vector db type {vector_type}.") | |||||
| class DatasetErrorDocs(Resource): | class DatasetErrorDocs(Resource): | ||||
| @setup_required | @setup_required |
| import json | |||||
| import logging | import logging | ||||
| from typing import Any, Optional | from typing import Any, Optional | ||||
| from uuid import uuid4 | from uuid import uuid4 | ||||
| from flask import current_app | |||||
| from pydantic import BaseModel, root_validator | from pydantic import BaseModel, root_validator | ||||
| from pymilvus import MilvusClient, MilvusException, connections | from pymilvus import MilvusClient, MilvusException, connections | ||||
| from core.rag.datasource.entity.embedding import Embeddings | |||||
| from core.rag.datasource.vdb.field import Field | from core.rag.datasource.vdb.field import Field | ||||
| from core.rag.datasource.vdb.vector_base import BaseVector | from core.rag.datasource.vdb.vector_base import BaseVector | ||||
| from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory | |||||
| from core.rag.datasource.vdb.vector_type import VectorType | |||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from models.dataset import Dataset | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| self._fields = [] | self._fields = [] | ||||
| def get_type(self) -> str: | def get_type(self) -> str: | ||||
| return 'milvus' | |||||
| return VectorType.MILVUS | |||||
| def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | ||||
| index_params = { | index_params = { | ||||
| schema=schema, index_param=index_params, | schema=schema, index_param=index_params, | ||||
| consistency_level=self._consistency_level) | consistency_level=self._consistency_level) | ||||
| redis_client.set(collection_exist_cache_key, 1, ex=3600) | redis_client.set(collection_exist_cache_key, 1, ex=3600) | ||||
| def _init_client(self, config) -> MilvusClient: | def _init_client(self, config) -> MilvusClient: | ||||
| if config.secure: | if config.secure: | ||||
| uri = "https://" + str(config.host) + ":" + str(config.port) | uri = "https://" + str(config.host) + ":" + str(config.port) | ||||
| else: | else: | ||||
| uri = "http://" + str(config.host) + ":" + str(config.port) | uri = "http://" + str(config.host) + ":" + str(config.port) | ||||
| client = MilvusClient(uri=uri, user=config.user, password=config.password,db_name=config.database) | |||||
| client = MilvusClient(uri=uri, user=config.user, password=config.password, db_name=config.database) | |||||
| return client | return client | ||||
| class MilvusVectorFactory(AbstractVectorFactory): | |||||
| def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector: | |||||
| if dataset.index_struct_dict: | |||||
| class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] | |||||
| collection_name = class_prefix | |||||
| else: | |||||
| dataset_id = dataset.id | |||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id) | |||||
| dataset.index_struct = json.dumps( | |||||
| self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) | |||||
| config = current_app.config | |||||
| return MilvusVector( | |||||
| collection_name=collection_name, | |||||
| config=MilvusConfig( | |||||
| host=config.get('MILVUS_HOST'), | |||||
| port=config.get('MILVUS_PORT'), | |||||
| user=config.get('MILVUS_USER'), | |||||
| password=config.get('MILVUS_PASSWORD'), | |||||
| secure=config.get('MILVUS_SECURE'), | |||||
| database=config.get('MILVUS_DATABASE'), | |||||
| ) | |||||
| ) |
| import json | |||||
| import logging | import logging | ||||
| from typing import Any | from typing import Any | ||||
| from uuid import UUID, uuid4 | from uuid import UUID, uuid4 | ||||
| from flask import current_app | |||||
| from numpy import ndarray | from numpy import ndarray | ||||
| from pgvecto_rs.sqlalchemy import Vector | from pgvecto_rs.sqlalchemy import Vector | ||||
| from pydantic import BaseModel, root_validator | from pydantic import BaseModel, root_validator | ||||
| from sqlalchemy.dialects import postgresql | from sqlalchemy.dialects import postgresql | ||||
| from sqlalchemy.orm import Mapped, Session, mapped_column | from sqlalchemy.orm import Mapped, Session, mapped_column | ||||
| from core.rag.datasource.entity.embedding import Embeddings | |||||
| from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM | from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM | ||||
| from core.rag.datasource.vdb.vector_base import BaseVector | from core.rag.datasource.vdb.vector_base import BaseVector | ||||
| from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory | |||||
| from core.rag.datasource.vdb.vector_type import VectorType | |||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from models.dataset import Dataset | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| self._distance_op = "<=>" | self._distance_op = "<=>" | ||||
| def get_type(self) -> str: | def get_type(self) -> str: | ||||
| return 'pgvecto-rs' | |||||
| return VectorType.PGVECTO_RS | |||||
| def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | ||||
| self.create_collection(len(embeddings[0])) | self.create_collection(len(embeddings[0])) | ||||
| # docs.append(doc) | # docs.append(doc) | ||||
| # return docs | # return docs | ||||
| return [] | return [] | ||||
| class PGVectoRSFactory(AbstractVectorFactory): | |||||
| def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVectoRS: | |||||
| if dataset.index_struct_dict: | |||||
| class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] | |||||
| collection_name = class_prefix.lower() | |||||
| else: | |||||
| dataset_id = dataset.id | |||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() | |||||
| dataset.index_struct = json.dumps( | |||||
| self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) | |||||
| dim = len(embeddings.embed_query("pgvecto_rs")) | |||||
| config = current_app.config | |||||
| return PGVectoRS( | |||||
| collection_name=collection_name, | |||||
| config=PgvectoRSConfig( | |||||
| host=config.get('PGVECTO_RS_HOST'), | |||||
| port=config.get('PGVECTO_RS_PORT'), | |||||
| user=config.get('PGVECTO_RS_USER'), | |||||
| password=config.get('PGVECTO_RS_PASSWORD'), | |||||
| database=config.get('PGVECTO_RS_DATABASE'), | |||||
| ), | |||||
| dim=dim | |||||
| ) |
| import psycopg2.extras | import psycopg2.extras | ||||
| import psycopg2.pool | import psycopg2.pool | ||||
| from flask import current_app | |||||
| from pydantic import BaseModel, root_validator | from pydantic import BaseModel, root_validator | ||||
| from core.rag.datasource.entity.embedding import Embeddings | |||||
| from core.rag.datasource.vdb.vector_base import BaseVector | from core.rag.datasource.vdb.vector_base import BaseVector | ||||
| from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory | |||||
| from core.rag.datasource.vdb.vector_type import VectorType | |||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from models.dataset import Dataset | |||||
| class PGVectorConfig(BaseModel): | class PGVectorConfig(BaseModel): | ||||
| self.table_name = f"embedding_{collection_name}" | self.table_name = f"embedding_{collection_name}" | ||||
| def get_type(self) -> str: | def get_type(self) -> str: | ||||
| return "pgvector" | |||||
| return VectorType.PGVECTOR | |||||
| def _create_connection_pool(self, config: PGVectorConfig): | def _create_connection_pool(self, config: PGVectorConfig): | ||||
| return psycopg2.pool.SimpleConnectionPool( | return psycopg2.pool.SimpleConnectionPool( | ||||
| cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension)) | cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension)) | ||||
| # TODO: create index https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing | # TODO: create index https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing | ||||
| redis_client.set(collection_exist_cache_key, 1, ex=3600) | redis_client.set(collection_exist_cache_key, 1, ex=3600) | ||||
| class PGVectorFactory(AbstractVectorFactory): | |||||
| def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVector: | |||||
| if dataset.index_struct_dict: | |||||
| class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] | |||||
| collection_name = class_prefix | |||||
| else: | |||||
| dataset_id = dataset.id | |||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id) | |||||
| dataset.index_struct = json.dumps( | |||||
| self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name)) | |||||
| config = current_app.config | |||||
| return PGVector( | |||||
| collection_name=collection_name, | |||||
| config=PGVectorConfig( | |||||
| host=config.get("PGVECTOR_HOST"), | |||||
| port=config.get("PGVECTOR_PORT"), | |||||
| user=config.get("PGVECTOR_USER"), | |||||
| password=config.get("PGVECTOR_PASSWORD"), | |||||
| database=config.get("PGVECTOR_DATABASE"), | |||||
| ), | |||||
| ) |
| import json | |||||
| import os | import os | ||||
| import uuid | import uuid | ||||
| from collections.abc import Generator, Iterable, Sequence | from collections.abc import Generator, Iterable, Sequence | ||||
| from typing import TYPE_CHECKING, Any, Optional, Union, cast | from typing import TYPE_CHECKING, Any, Optional, Union, cast | ||||
| import qdrant_client | import qdrant_client | ||||
| from flask import current_app | |||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||
| from qdrant_client.http import models as rest | from qdrant_client.http import models as rest | ||||
| from qdrant_client.http.models import ( | from qdrant_client.http.models import ( | ||||
| ) | ) | ||||
| from qdrant_client.local.qdrant_local import QdrantLocal | from qdrant_client.local.qdrant_local import QdrantLocal | ||||
| from core.rag.datasource.entity.embedding import Embeddings | |||||
| from core.rag.datasource.vdb.field import Field | from core.rag.datasource.vdb.field import Field | ||||
| from core.rag.datasource.vdb.vector_base import BaseVector | from core.rag.datasource.vdb.vector_base import BaseVector | ||||
| from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory | |||||
| from core.rag.datasource.vdb.vector_type import VectorType | |||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from extensions.ext_database import db | |||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from models.dataset import Dataset, DatasetCollectionBinding | |||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from qdrant_client import grpc # noqa | from qdrant_client import grpc # noqa | ||||
| self._group_id = group_id | self._group_id = group_id | ||||
| def get_type(self) -> str: | def get_type(self) -> str: | ||||
| return 'qdrant' | |||||
| return VectorType.QDRANT | |||||
| def to_index_struct(self) -> dict: | def to_index_struct(self) -> dict: | ||||
| return { | return { | ||||
| page_content=scored_point.payload.get(content_payload_key), | page_content=scored_point.payload.get(content_payload_key), | ||||
| metadata=scored_point.payload.get(metadata_payload_key) or {}, | metadata=scored_point.payload.get(metadata_payload_key) or {}, | ||||
| ) | ) | ||||
| class QdrantVectorFactory(AbstractVectorFactory): | |||||
| def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector: | |||||
| if dataset.collection_binding_id: | |||||
| dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ | |||||
| filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \ | |||||
| one_or_none() | |||||
| if dataset_collection_binding: | |||||
| collection_name = dataset_collection_binding.collection_name | |||||
| else: | |||||
| raise ValueError('Dataset Collection Bindings is not exist!') | |||||
| else: | |||||
| if dataset.index_struct_dict: | |||||
| class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] | |||||
| collection_name = class_prefix | |||||
| else: | |||||
| dataset_id = dataset.id | |||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id) | |||||
| if not dataset.index_struct_dict: | |||||
| dataset.index_struct = json.dumps( | |||||
| self.gen_index_struct_dict(VectorType.QDRANT, collection_name)) | |||||
| config = current_app.config | |||||
| return QdrantVector( | |||||
| collection_name=collection_name, | |||||
| group_id=dataset.id, | |||||
| config=QdrantConfig( | |||||
| endpoint=config.get('QDRANT_URL'), | |||||
| api_key=config.get('QDRANT_API_KEY'), | |||||
| root_path=config.root_path, | |||||
| timeout=config.get('QDRANT_CLIENT_TIMEOUT'), | |||||
| grpc_port=config.get('QDRANT_GRPC_PORT'), | |||||
| prefer_grpc=config.get('QDRANT_GRPC_ENABLED') | |||||
| ) | |||||
| ) |
| import json | |||||
| import uuid | import uuid | ||||
| from typing import Any, Optional | from typing import Any, Optional | ||||
| from flask import current_app | |||||
| from pydantic import BaseModel, root_validator | from pydantic import BaseModel, root_validator | ||||
| from sqlalchemy import Column, Sequence, String, Table, create_engine, insert | from sqlalchemy import Column, Sequence, String, Table, create_engine, insert | ||||
| from sqlalchemy import text as sql_text | from sqlalchemy import text as sql_text | ||||
| from sqlalchemy.dialects.postgresql import JSON, TEXT | from sqlalchemy.dialects.postgresql import JSON, TEXT | ||||
| from sqlalchemy.orm import Session | from sqlalchemy.orm import Session | ||||
| from core.rag.datasource.entity.embedding import Embeddings | |||||
| from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory | |||||
| from core.rag.datasource.vdb.vector_type import VectorType | |||||
| from models.dataset import Dataset | |||||
| try: | try: | ||||
| from sqlalchemy.orm import declarative_base | from sqlalchemy.orm import declarative_base | ||||
| except ImportError: | except ImportError: | ||||
| self._group_id = group_id | self._group_id = group_id | ||||
| def get_type(self) -> str: | def get_type(self) -> str: | ||||
| return 'relyt' | |||||
| return VectorType.RELYT | |||||
| def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | ||||
| index_params = {} | index_params = {} | ||||
| return docs | return docs | ||||
| def similarity_search_with_score_by_vector( | def similarity_search_with_score_by_vector( | ||||
| self, | |||||
| embedding: list[float], | |||||
| k: int = 4, | |||||
| filter: Optional[dict] = None, | |||||
| self, | |||||
| embedding: list[float], | |||||
| k: int = 4, | |||||
| filter: Optional[dict] = None, | |||||
| ) -> list[tuple[Document, float]]: | ) -> list[tuple[Document, float]]: | ||||
| # Add the filter if provided | # Add the filter if provided | ||||
| try: | try: | ||||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | ||||
| # milvus/zilliz/relyt doesn't support bm25 search | # milvus/zilliz/relyt doesn't support bm25 search | ||||
| return [] | return [] | ||||
| class RelytVectorFactory(AbstractVectorFactory): | |||||
| def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector: | |||||
| if dataset.index_struct_dict: | |||||
| class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] | |||||
| collection_name = class_prefix | |||||
| else: | |||||
| dataset_id = dataset.id | |||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id) | |||||
| dataset.index_struct = json.dumps( | |||||
| self.gen_index_struct_dict(VectorType.RELYT, collection_name)) | |||||
| config = current_app.config | |||||
| return RelytVector( | |||||
| collection_name=collection_name, | |||||
| config=RelytConfig( | |||||
| host=config.get('RELYT_HOST'), | |||||
| port=config.get('RELYT_PORT'), | |||||
| user=config.get('RELYT_USER'), | |||||
| password=config.get('RELYT_PASSWORD'), | |||||
| database=config.get('RELYT_DATABASE'), | |||||
| ), | |||||
| group_id=dataset.id | |||||
| ) |
| from typing import Any | from typing import Any | ||||
| import sqlalchemy | import sqlalchemy | ||||
| from flask import current_app | |||||
| from pydantic import BaseModel, root_validator | from pydantic import BaseModel, root_validator | ||||
| from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert | from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert | ||||
| from sqlalchemy import text as sql_text | from sqlalchemy import text as sql_text | ||||
| from sqlalchemy.orm import Session, declarative_base | from sqlalchemy.orm import Session, declarative_base | ||||
| from core.rag.datasource.entity.embedding import Embeddings | |||||
| from core.rag.datasource.vdb.vector_base import BaseVector | from core.rag.datasource.vdb.vector_base import BaseVector | ||||
| from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory | |||||
| from core.rag.datasource.vdb.vector_type import VectorType | |||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from models.dataset import Dataset | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| class TiDBVector(BaseVector): | class TiDBVector(BaseVector): | ||||
| def get_type(self) -> str: | |||||
| return VectorType.TIDB_VECTOR | |||||
| def _table(self, dim: int) -> Table: | def _table(self, dim: int) -> Table: | ||||
| from tidb_vector.sqlalchemy import VectorType | from tidb_vector.sqlalchemy import VectorType | ||||
| return Table( | return Table( | ||||
| with Session(self._engine) as session: | with Session(self._engine) as session: | ||||
| session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};""")) | session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};""")) | ||||
| session.commit() | session.commit() | ||||
| class TiDBVectorFactory(AbstractVectorFactory): | |||||
| def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector: | |||||
| if dataset.index_struct_dict: | |||||
| class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] | |||||
| collection_name = class_prefix.lower() | |||||
| else: | |||||
| dataset_id = dataset.id | |||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() | |||||
| dataset.index_struct = json.dumps( | |||||
| self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name)) | |||||
| config = current_app.config | |||||
| return TiDBVector( | |||||
| collection_name=collection_name, | |||||
| config=TiDBVectorConfig( | |||||
| host=config.get('TIDB_VECTOR_HOST'), | |||||
| port=config.get('TIDB_VECTOR_PORT'), | |||||
| user=config.get('TIDB_VECTOR_USER'), | |||||
| password=config.get('TIDB_VECTOR_PASSWORD'), | |||||
| database=config.get('TIDB_VECTOR_DATABASE'), | |||||
| ), | |||||
| ) |
| def __init__(self, collection_name: str): | def __init__(self, collection_name: str): | ||||
| self._collection_name = collection_name | self._collection_name = collection_name | ||||
| @abstractmethod | |||||
| def get_type(self) -> str: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | @abstractmethod | ||||
| def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | ||||
| raise NotImplementedError | raise NotImplementedError |
| import json | |||||
| from abc import ABC, abstractmethod | |||||
| from typing import Any | from typing import Any | ||||
| from flask import current_app | from flask import current_app | ||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from core.rag.datasource.entity.embedding import Embeddings | from core.rag.datasource.entity.embedding import Embeddings | ||||
| from core.rag.datasource.vdb.vector_base import BaseVector | from core.rag.datasource.vdb.vector_base import BaseVector | ||||
| from core.rag.datasource.vdb.vector_type import VectorType | |||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from extensions.ext_database import db | |||||
| from models.dataset import Dataset, DatasetCollectionBinding | |||||
| from models.dataset import Dataset | |||||
| class AbstractVectorFactory(ABC): | |||||
| @abstractmethod | |||||
| def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector: | |||||
| raise NotImplementedError | |||||
| @staticmethod | |||||
| def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict: | |||||
| index_struct_dict = { | |||||
| "type": vector_type, | |||||
| "vector_store": {"class_prefix": collection_name} | |||||
| } | |||||
| return index_struct_dict | |||||
| class Vector: | class Vector: | ||||
| if not vector_type: | if not vector_type: | ||||
| raise ValueError("Vector store must be specified.") | raise ValueError("Vector store must be specified.") | ||||
| if vector_type == "weaviate": | |||||
| from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector | |||||
| if self._dataset.index_struct_dict: | |||||
| class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] | |||||
| collection_name = class_prefix | |||||
| else: | |||||
| dataset_id = self._dataset.id | |||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id) | |||||
| index_struct_dict = { | |||||
| "type": 'weaviate', | |||||
| "vector_store": {"class_prefix": collection_name} | |||||
| } | |||||
| self._dataset.index_struct = json.dumps(index_struct_dict) | |||||
| return WeaviateVector( | |||||
| collection_name=collection_name, | |||||
| config=WeaviateConfig( | |||||
| endpoint=config.get('WEAVIATE_ENDPOINT'), | |||||
| api_key=config.get('WEAVIATE_API_KEY'), | |||||
| batch_size=int(config.get('WEAVIATE_BATCH_SIZE')) | |||||
| ), | |||||
| attributes=self._attributes | |||||
| ) | |||||
| elif vector_type == "qdrant": | |||||
| from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector | |||||
| if self._dataset.collection_binding_id: | |||||
| dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ | |||||
| filter(DatasetCollectionBinding.id == self._dataset.collection_binding_id). \ | |||||
| one_or_none() | |||||
| if dataset_collection_binding: | |||||
| collection_name = dataset_collection_binding.collection_name | |||||
| else: | |||||
| raise ValueError('Dataset Collection Bindings is not exist!') | |||||
| else: | |||||
| if self._dataset.index_struct_dict: | |||||
| class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] | |||||
| collection_name = class_prefix | |||||
| else: | |||||
| dataset_id = self._dataset.id | |||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id) | |||||
| if not self._dataset.index_struct_dict: | |||||
| index_struct_dict = { | |||||
| "type": 'qdrant', | |||||
| "vector_store": {"class_prefix": collection_name} | |||||
| } | |||||
| self._dataset.index_struct = json.dumps(index_struct_dict) | |||||
| return QdrantVector( | |||||
| collection_name=collection_name, | |||||
| group_id=self._dataset.id, | |||||
| config=QdrantConfig( | |||||
| endpoint=config.get('QDRANT_URL'), | |||||
| api_key=config.get('QDRANT_API_KEY'), | |||||
| root_path=current_app.root_path, | |||||
| timeout=config.get('QDRANT_CLIENT_TIMEOUT'), | |||||
| grpc_port=config.get('QDRANT_GRPC_PORT'), | |||||
| prefer_grpc=config.get('QDRANT_GRPC_ENABLED') | |||||
| ) | |||||
| ) | |||||
| elif vector_type == "milvus": | |||||
| from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector | |||||
| if self._dataset.index_struct_dict: | |||||
| class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] | |||||
| collection_name = class_prefix | |||||
| else: | |||||
| dataset_id = self._dataset.id | |||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id) | |||||
| index_struct_dict = { | |||||
| "type": 'milvus', | |||||
| "vector_store": {"class_prefix": collection_name} | |||||
| } | |||||
| self._dataset.index_struct = json.dumps(index_struct_dict) | |||||
| return MilvusVector( | |||||
| collection_name=collection_name, | |||||
| config=MilvusConfig( | |||||
| host=config.get('MILVUS_HOST'), | |||||
| port=config.get('MILVUS_PORT'), | |||||
| user=config.get('MILVUS_USER'), | |||||
| password=config.get('MILVUS_PASSWORD'), | |||||
| secure=config.get('MILVUS_SECURE'), | |||||
| database=config.get('MILVUS_DATABASE'), | |||||
| ) | |||||
| ) | |||||
| elif vector_type == "relyt": | |||||
| from core.rag.datasource.vdb.relyt.relyt_vector import RelytConfig, RelytVector | |||||
| if self._dataset.index_struct_dict: | |||||
| class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] | |||||
| collection_name = class_prefix | |||||
| else: | |||||
| dataset_id = self._dataset.id | |||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id) | |||||
| index_struct_dict = { | |||||
| "type": 'relyt', | |||||
| "vector_store": {"class_prefix": collection_name} | |||||
| } | |||||
| self._dataset.index_struct = json.dumps(index_struct_dict) | |||||
| return RelytVector( | |||||
| collection_name=collection_name, | |||||
| config=RelytConfig( | |||||
| host=config.get('RELYT_HOST'), | |||||
| port=config.get('RELYT_PORT'), | |||||
| user=config.get('RELYT_USER'), | |||||
| password=config.get('RELYT_PASSWORD'), | |||||
| database=config.get('RELYT_DATABASE'), | |||||
| ), | |||||
| group_id=self._dataset.id | |||||
| ) | |||||
| elif vector_type == "pgvecto_rs": | |||||
| from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig | |||||
| if self._dataset.index_struct_dict: | |||||
| class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] | |||||
| collection_name = class_prefix.lower() | |||||
| else: | |||||
| dataset_id = self._dataset.id | |||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() | |||||
| index_struct_dict = { | |||||
| "type": 'pgvecto_rs', | |||||
| "vector_store": {"class_prefix": collection_name} | |||||
| } | |||||
| self._dataset.index_struct = json.dumps(index_struct_dict) | |||||
| dim = len(self._embeddings.embed_query("pgvecto_rs")) | |||||
| return PGVectoRS( | |||||
| collection_name=collection_name, | |||||
| config=PgvectoRSConfig( | |||||
| host=config.get('PGVECTO_RS_HOST'), | |||||
| port=config.get('PGVECTO_RS_PORT'), | |||||
| user=config.get('PGVECTO_RS_USER'), | |||||
| password=config.get('PGVECTO_RS_PASSWORD'), | |||||
| database=config.get('PGVECTO_RS_DATABASE'), | |||||
| ), | |||||
| dim=dim | |||||
| ) | |||||
| elif vector_type == "pgvector": | |||||
| from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig | |||||
| if self._dataset.index_struct_dict: | |||||
| class_prefix: str = self._dataset.index_struct_dict["vector_store"]["class_prefix"] | |||||
| collection_name = class_prefix | |||||
| else: | |||||
| dataset_id = self._dataset.id | |||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id) | |||||
| index_struct_dict = { | |||||
| "type": "pgvector", | |||||
| "vector_store": {"class_prefix": collection_name}} | |||||
| self._dataset.index_struct = json.dumps(index_struct_dict) | |||||
| return PGVector( | |||||
| collection_name=collection_name, | |||||
| config=PGVectorConfig( | |||||
| host=config.get("PGVECTOR_HOST"), | |||||
| port=config.get("PGVECTOR_PORT"), | |||||
| user=config.get("PGVECTOR_USER"), | |||||
| password=config.get("PGVECTOR_PASSWORD"), | |||||
| database=config.get("PGVECTOR_DATABASE"), | |||||
| ), | |||||
| ) | |||||
| elif vector_type == "tidb_vector": | |||||
| from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig | |||||
| if self._dataset.index_struct_dict: | |||||
| class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix'] | |||||
| collection_name = class_prefix.lower() | |||||
| else: | |||||
| dataset_id = self._dataset.id | |||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() | |||||
| index_struct_dict = { | |||||
| "type": 'tidb_vector', | |||||
| "vector_store": {"class_prefix": collection_name} | |||||
| } | |||||
| self._dataset.index_struct = json.dumps(index_struct_dict) | |||||
| return TiDBVector( | |||||
| collection_name=collection_name, | |||||
| config=TiDBVectorConfig( | |||||
| host=config.get('TIDB_VECTOR_HOST'), | |||||
| port=config.get('TIDB_VECTOR_PORT'), | |||||
| user=config.get('TIDB_VECTOR_USER'), | |||||
| password=config.get('TIDB_VECTOR_PASSWORD'), | |||||
| database=config.get('TIDB_VECTOR_DATABASE'), | |||||
| ), | |||||
| ) | |||||
| else: | |||||
| raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") | |||||
| vector_factory_cls = self.get_vector_factory(vector_type) | |||||
| return vector_factory_cls().init_vector(self._dataset, self._attributes, self._embeddings) | |||||
| @staticmethod | |||||
| def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]: | |||||
| match vector_type: | |||||
| case VectorType.MILVUS: | |||||
| from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory | |||||
| return MilvusVectorFactory | |||||
| case VectorType.PGVECTOR: | |||||
| from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory | |||||
| return PGVectorFactory | |||||
| case VectorType.PGVECTO_RS: | |||||
| from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory | |||||
| return PGVectoRSFactory | |||||
| case VectorType.QDRANT: | |||||
| from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory | |||||
| return QdrantVectorFactory | |||||
| case VectorType.RELYT: | |||||
| from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory | |||||
| return RelytVectorFactory | |||||
| case VectorType.TIDB_VECTOR: | |||||
| from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory | |||||
| return TiDBVectorFactory | |||||
| case VectorType.WEAVIATE: | |||||
| from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory | |||||
| return WeaviateVectorFactory | |||||
| case _: | |||||
| raise ValueError(f"Vector store {vector_type} is not supported.") | |||||
| def create(self, texts: list = None, **kwargs): | def create(self, texts: list = None, **kwargs): | ||||
| if texts: | if texts: |
| from enum import Enum | |||||
| class VectorType(str, Enum): | |||||
| MILVUS = 'milvus' | |||||
| PGVECTOR = 'pgvector' | |||||
| PGVECTO_RS = 'pgvecto-rs' | |||||
| QDRANT = 'qdrant' | |||||
| RELYT = 'relyt' | |||||
| TIDB_VECTOR = 'tidb_vector' | |||||
| WEAVIATE = 'weaviate' |
| import datetime | import datetime | ||||
| import json | |||||
| from typing import Any, Optional | from typing import Any, Optional | ||||
| import requests | import requests | ||||
| import weaviate | import weaviate | ||||
| from flask import current_app | |||||
| from pydantic import BaseModel, root_validator | from pydantic import BaseModel, root_validator | ||||
| from core.rag.datasource.entity.embedding import Embeddings | |||||
| from core.rag.datasource.vdb.field import Field | from core.rag.datasource.vdb.field import Field | ||||
| from core.rag.datasource.vdb.vector_base import BaseVector | from core.rag.datasource.vdb.vector_base import BaseVector | ||||
| from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory | |||||
| from core.rag.datasource.vdb.vector_type import VectorType | |||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from models.dataset import Dataset | from models.dataset import Dataset | ||||
| return client | return client | ||||
| def get_type(self) -> str: | def get_type(self) -> str: | ||||
| return 'weaviate' | |||||
| return VectorType.WEAVIATE | |||||
| def get_collection_name(self, dataset: Dataset) -> str: | def get_collection_name(self, dataset: Dataset) -> str: | ||||
| if dataset.index_struct_dict: | if dataset.index_struct_dict: | ||||
| if isinstance(value, datetime.datetime): | if isinstance(value, datetime.datetime): | ||||
| return value.isoformat() | return value.isoformat() | ||||
| return value | return value | ||||
| class WeaviateVectorFactory(AbstractVectorFactory): | |||||
| def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector: | |||||
| if dataset.index_struct_dict: | |||||
| class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] | |||||
| collection_name = class_prefix | |||||
| else: | |||||
| dataset_id = dataset.id | |||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id) | |||||
| dataset.index_struct = json.dumps( | |||||
| self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) | |||||
| return WeaviateVector( | |||||
| collection_name=collection_name, | |||||
| config=WeaviateConfig( | |||||
| endpoint=current_app.config.get('WEAVIATE_ENDPOINT'), | |||||
| api_key=current_app.config.get('WEAVIATE_API_KEY'), | |||||
| batch_size=int(current_app.config.get('WEAVIATE_BATCH_SIZE')) | |||||
| ), | |||||
| attributes=attributes | |||||
| ) |