Co-authored-by: 璟义 <yangshangpo.ysp@alibaba-inc.com>
| @@ -234,6 +234,10 @@ ANALYTICDB_ACCOUNT=testaccount | |||
| ANALYTICDB_PASSWORD=testpassword | |||
| ANALYTICDB_NAMESPACE=dify | |||
| ANALYTICDB_NAMESPACE_PASSWORD=difypassword | |||
| ANALYTICDB_HOST=gp-test.aliyuncs.com | |||
| ANALYTICDB_PORT=5432 | |||
| ANALYTICDB_MIN_CONNECTION=1 | |||
| ANALYTICDB_MAX_CONNECTION=5 | |||
| # OpenSearch configuration | |||
| OPENSEARCH_HOST=127.0.0.1 | |||
| @@ -1,6 +1,6 @@ | |||
| from typing import Optional | |||
| from pydantic import BaseModel, Field | |||
| from pydantic import BaseModel, Field, PositiveInt | |||
| class AnalyticdbConfig(BaseModel): | |||
| @@ -40,3 +40,11 @@ class AnalyticdbConfig(BaseModel): | |||
| description="The password for accessing the specified namespace within the AnalyticDB instance" | |||
| " (if namespace feature is enabled).", | |||
| ) | |||
| ANALYTICDB_HOST: Optional[str] = Field( | |||
| default=None, description="The host of the AnalyticDB instance you want to connect to." | |||
| ) | |||
| ANALYTICDB_PORT: PositiveInt = Field( | |||
| default=5432, description="The port of the AnalyticDB instance you want to connect to." | |||
| ) | |||
| ANALYTICDB_MIN_CONNECTION: PositiveInt = Field(default=1, description="Minimum number of connections in the AnalyticDB connection pool.") | |||
| ANALYTICDB_MAX_CONNECTION: PositiveInt = Field(default=5, description="Maximum number of connections in the AnalyticDB connection pool.") | |||
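These settings drive the new direct-SQL connection path; PositiveInt rejects zero or negative ports and pool bounds when the configuration is loaded. A minimal sketch of that behaviour, assuming a throwaway model named _ConnDemo that is not part of this change:

from pydantic import BaseModel, Field, PositiveInt, ValidationError

class _ConnDemo(BaseModel):
    ANALYTICDB_PORT: PositiveInt = Field(default=5432)
    ANALYTICDB_MIN_CONNECTION: PositiveInt = Field(default=1)
    ANALYTICDB_MAX_CONNECTION: PositiveInt = Field(default=5)

print(_ConnDemo())  # ANALYTICDB_PORT=5432 ANALYTICDB_MIN_CONNECTION=1 ANALYTICDB_MAX_CONNECTION=5
try:
    _ConnDemo(ANALYTICDB_MIN_CONNECTION=0)  # rejected: PositiveInt requires a value greater than 0
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # "greater_than"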
| @@ -1,310 +1,62 @@ | |||
| import json | |||
| from typing import Any | |||
| from pydantic import BaseModel | |||
| _import_err_msg = ( | |||
| "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, " | |||
| "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`" | |||
| ) | |||
| from configs import dify_config | |||
| from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import ( | |||
| AnalyticdbVectorOpenAPI, | |||
| AnalyticdbVectorOpenAPIConfig, | |||
| ) | |||
| from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig | |||
| from core.rag.datasource.vdb.vector_base import BaseVector | |||
| from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory | |||
| from core.rag.datasource.vdb.vector_type import VectorType | |||
| from core.rag.embedding.embedding_base import Embeddings | |||
| from core.rag.models.document import Document | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset | |||
| class AnalyticdbConfig(BaseModel): | |||
| access_key_id: str | |||
| access_key_secret: str | |||
| region_id: str | |||
| instance_id: str | |||
| account: str | |||
| account_password: str | |||
| namespace: str = ("dify",) | |||
| namespace_password: str = (None,) | |||
| metrics: str = ("cosine",) | |||
| read_timeout: int = 60000 | |||
| def to_analyticdb_client_params(self): | |||
| return { | |||
| "access_key_id": self.access_key_id, | |||
| "access_key_secret": self.access_key_secret, | |||
| "region_id": self.region_id, | |||
| "read_timeout": self.read_timeout, | |||
| } | |||
| class AnalyticdbVector(BaseVector): | |||
| def __init__(self, collection_name: str, config: AnalyticdbConfig): | |||
| self._collection_name = collection_name.lower() | |||
| try: | |||
| from alibabacloud_gpdb20160503.client import Client | |||
| from alibabacloud_tea_openapi import models as open_api_models | |||
| except: | |||
| raise ImportError(_import_err_msg) | |||
| self.config = config | |||
| self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params()) | |||
| self._client = Client(self._client_config) | |||
| self._initialize() | |||
| def _initialize(self) -> None: | |||
| cache_key = f"vector_indexing_{self.config.instance_id}" | |||
| lock_name = f"{cache_key}_lock" | |||
| with redis_client.lock(lock_name, timeout=20): | |||
| collection_exist_cache_key = f"vector_indexing_{self.config.instance_id}" | |||
| if redis_client.get(collection_exist_cache_key): | |||
| return | |||
| self._initialize_vector_database() | |||
| self._create_namespace_if_not_exists() | |||
| redis_client.set(collection_exist_cache_key, 1, ex=3600) | |||
| def _initialize_vector_database(self) -> None: | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| request = gpdb_20160503_models.InitVectorDatabaseRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| manager_account=self.config.account, | |||
| manager_account_password=self.config.account_password, | |||
| ) | |||
| self._client.init_vector_database(request) | |||
| def _create_namespace_if_not_exists(self) -> None: | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| from Tea.exceptions import TeaException | |||
| try: | |||
| request = gpdb_20160503_models.DescribeNamespaceRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| namespace=self.config.namespace, | |||
| manager_account=self.config.account, | |||
| manager_account_password=self.config.account_password, | |||
| ) | |||
| self._client.describe_namespace(request) | |||
| except TeaException as e: | |||
| if e.statusCode == 404: | |||
| request = gpdb_20160503_models.CreateNamespaceRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| manager_account=self.config.account, | |||
| manager_account_password=self.config.account_password, | |||
| namespace=self.config.namespace, | |||
| namespace_password=self.config.namespace_password, | |||
| ) | |||
| self._client.create_namespace(request) | |||
| else: | |||
| raise ValueError(f"failed to create namespace {self.config.namespace}: {e}") | |||
| def _create_collection_if_not_exists(self, embedding_dimension: int): | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| from Tea.exceptions import TeaException | |||
| cache_key = f"vector_indexing_{self._collection_name}" | |||
| lock_name = f"{cache_key}_lock" | |||
| with redis_client.lock(lock_name, timeout=20): | |||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | |||
| if redis_client.get(collection_exist_cache_key): | |||
| return | |||
| try: | |||
| request = gpdb_20160503_models.DescribeCollectionRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| namespace=self.config.namespace, | |||
| namespace_password=self.config.namespace_password, | |||
| collection=self._collection_name, | |||
| ) | |||
| self._client.describe_collection(request) | |||
| except TeaException as e: | |||
| if e.statusCode == 404: | |||
| metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}' | |||
| full_text_retrieval_fields = "page_content" | |||
| request = gpdb_20160503_models.CreateCollectionRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| manager_account=self.config.account, | |||
| manager_account_password=self.config.account_password, | |||
| namespace=self.config.namespace, | |||
| collection=self._collection_name, | |||
| dimension=embedding_dimension, | |||
| metrics=self.config.metrics, | |||
| metadata=metadata, | |||
| full_text_retrieval_fields=full_text_retrieval_fields, | |||
| ) | |||
| self._client.create_collection(request) | |||
| else: | |||
| raise ValueError(f"failed to create collection {self._collection_name}: {e}") | |||
| redis_client.set(collection_exist_cache_key, 1, ex=3600) | |||
| def __init__( | |||
| self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig | |||
| ): | |||
| super().__init__(collection_name) | |||
| if api_config is not None: | |||
| self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config) | |||
| else: | |||
| self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config) | |||
| def get_type(self) -> str: | |||
| return VectorType.ANALYTICDB | |||
| def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | |||
| dimension = len(embeddings[0]) | |||
| self._create_collection_if_not_exists(dimension) | |||
| self.add_texts(texts, embeddings) | |||
| self.analyticdb_vector._create_collection_if_not_exists(dimension) | |||
| self.analyticdb_vector.add_texts(texts, embeddings) | |||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = [] | |||
| for doc, embedding in zip(documents, embeddings, strict=True): | |||
| metadata = { | |||
| "ref_doc_id": doc.metadata["doc_id"], | |||
| "page_content": doc.page_content, | |||
| "metadata_": json.dumps(doc.metadata), | |||
| } | |||
| rows.append( | |||
| gpdb_20160503_models.UpsertCollectionDataRequestRows( | |||
| vector=embedding, | |||
| metadata=metadata, | |||
| ) | |||
| ) | |||
| request = gpdb_20160503_models.UpsertCollectionDataRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| namespace=self.config.namespace, | |||
| namespace_password=self.config.namespace_password, | |||
| collection=self._collection_name, | |||
| rows=rows, | |||
| ) | |||
| self._client.upsert_collection_data(request) | |||
| def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | |||
| self.analyticdb_vector.add_texts(texts, embeddings) | |||
| def text_exists(self, id: str) -> bool: | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| request = gpdb_20160503_models.QueryCollectionDataRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| namespace=self.config.namespace, | |||
| namespace_password=self.config.namespace_password, | |||
| collection=self._collection_name, | |||
| metrics=self.config.metrics, | |||
| include_values=True, | |||
| vector=None, | |||
| content=None, | |||
| top_k=1, | |||
| filter=f"ref_doc_id='{id}'", | |||
| ) | |||
| response = self._client.query_collection_data(request) | |||
| return len(response.body.matches.match) > 0 | |||
| return self.analyticdb_vector.text_exists(id) | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| ids_str = ",".join(f"'{id}'" for id in ids) | |||
| ids_str = f"({ids_str})" | |||
| request = gpdb_20160503_models.DeleteCollectionDataRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| namespace=self.config.namespace, | |||
| namespace_password=self.config.namespace_password, | |||
| collection=self._collection_name, | |||
| collection_data=None, | |||
| collection_data_filter=f"ref_doc_id IN {ids_str}", | |||
| ) | |||
| self._client.delete_collection_data(request) | |||
| self.analyticdb_vector.delete_by_ids(ids) | |||
| def delete_by_metadata_field(self, key: str, value: str) -> None: | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| request = gpdb_20160503_models.DeleteCollectionDataRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| namespace=self.config.namespace, | |||
| namespace_password=self.config.namespace_password, | |||
| collection=self._collection_name, | |||
| collection_data=None, | |||
| collection_data_filter=f"metadata_ ->> '{key}' = '{value}'", | |||
| ) | |||
| self._client.delete_collection_data(request) | |||
| self.analyticdb_vector.delete_by_metadata_field(key, value) | |||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| score_threshold = kwargs.get("score_threshold") or 0.0 | |||
| request = gpdb_20160503_models.QueryCollectionDataRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| namespace=self.config.namespace, | |||
| namespace_password=self.config.namespace_password, | |||
| collection=self._collection_name, | |||
| include_values=kwargs.pop("include_values", True), | |||
| metrics=self.config.metrics, | |||
| vector=query_vector, | |||
| content=None, | |||
| top_k=kwargs.get("top_k", 4), | |||
| filter=None, | |||
| ) | |||
| response = self._client.query_collection_data(request) | |||
| documents = [] | |||
| for match in response.body.matches.match: | |||
| if match.score > score_threshold: | |||
| metadata = json.loads(match.metadata.get("metadata_")) | |||
| metadata["score"] = match.score | |||
| doc = Document( | |||
| page_content=match.metadata.get("page_content"), | |||
| metadata=metadata, | |||
| ) | |||
| documents.append(doc) | |||
| documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) | |||
| return documents | |||
| return self.analyticdb_vector.search_by_vector(query_vector) | |||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| request = gpdb_20160503_models.QueryCollectionDataRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| namespace=self.config.namespace, | |||
| namespace_password=self.config.namespace_password, | |||
| collection=self._collection_name, | |||
| include_values=kwargs.pop("include_values", True), | |||
| metrics=self.config.metrics, | |||
| vector=None, | |||
| content=query, | |||
| top_k=kwargs.get("top_k", 4), | |||
| filter=None, | |||
| ) | |||
| response = self._client.query_collection_data(request) | |||
| documents = [] | |||
| for match in response.body.matches.match: | |||
| if match.score > score_threshold: | |||
| metadata = json.loads(match.metadata.get("metadata_")) | |||
| metadata["score"] = match.score | |||
| doc = Document( | |||
| page_content=match.metadata.get("page_content"), | |||
| vector=match.metadata.get("vector"), | |||
| metadata=metadata, | |||
| ) | |||
| documents.append(doc) | |||
| documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) | |||
| return documents | |||
| return self.analyticdb_vector.search_by_full_text(query, **kwargs) | |||
| def delete(self) -> None: | |||
| try: | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| request = gpdb_20160503_models.DeleteCollectionRequest( | |||
| collection=self._collection_name, | |||
| dbinstance_id=self.config.instance_id, | |||
| namespace=self.config.namespace, | |||
| namespace_password=self.config.namespace_password, | |||
| region_id=self.config.region_id, | |||
| ) | |||
| self._client.delete_collection(request) | |||
| except Exception as e: | |||
| raise e | |||
| self.analyticdb_vector.delete() | |||
| class AnalyticdbVectorFactory(AbstractVectorFactory): | |||
| def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings): | |||
| def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AnalyticdbVector: | |||
| if dataset.index_struct_dict: | |||
| class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] | |||
| collection_name = class_prefix.lower() | |||
| @@ -313,26 +65,9 @@ class AnalyticdbVectorFactory(AbstractVectorFactory): | |||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() | |||
| dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)) | |||
| # handle optional params | |||
| if dify_config.ANALYTICDB_KEY_ID is None: | |||
| raise ValueError("ANALYTICDB_KEY_ID should not be None") | |||
| if dify_config.ANALYTICDB_KEY_SECRET is None: | |||
| raise ValueError("ANALYTICDB_KEY_SECRET should not be None") | |||
| if dify_config.ANALYTICDB_REGION_ID is None: | |||
| raise ValueError("ANALYTICDB_REGION_ID should not be None") | |||
| if dify_config.ANALYTICDB_INSTANCE_ID is None: | |||
| raise ValueError("ANALYTICDB_INSTANCE_ID should not be None") | |||
| if dify_config.ANALYTICDB_ACCOUNT is None: | |||
| raise ValueError("ANALYTICDB_ACCOUNT should not be None") | |||
| if dify_config.ANALYTICDB_PASSWORD is None: | |||
| raise ValueError("ANALYTICDB_PASSWORD should not be None") | |||
| if dify_config.ANALYTICDB_NAMESPACE is None: | |||
| raise ValueError("ANALYTICDB_NAMESPACE should not be None") | |||
| if dify_config.ANALYTICDB_NAMESPACE_PASSWORD is None: | |||
| raise ValueError("ANALYTICDB_NAMESPACE_PASSWORD should not be None") | |||
| return AnalyticdbVector( | |||
| collection_name, | |||
| AnalyticdbConfig( | |||
| if dify_config.ANALYTICDB_HOST is None: | |||
| # implemented through OpenAPI | |||
| apiConfig = AnalyticdbVectorOpenAPIConfig( | |||
| access_key_id=dify_config.ANALYTICDB_KEY_ID, | |||
| access_key_secret=dify_config.ANALYTICDB_KEY_SECRET, | |||
| region_id=dify_config.ANALYTICDB_REGION_ID, | |||
| @@ -341,5 +76,22 @@ class AnalyticdbVectorFactory(AbstractVectorFactory): | |||
| account_password=dify_config.ANALYTICDB_PASSWORD, | |||
| namespace=dify_config.ANALYTICDB_NAMESPACE, | |||
| namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD, | |||
| ), | |||
| ) | |||
| sqlConfig = None | |||
| else: | |||
| # implemented through sql | |||
| sqlConfig = AnalyticdbVectorBySqlConfig( | |||
| host=dify_config.ANALYTICDB_HOST, | |||
| port=dify_config.ANALYTICDB_PORT, | |||
| account=dify_config.ANALYTICDB_ACCOUNT, | |||
| account_password=dify_config.ANALYTICDB_PASSWORD, | |||
| min_connection=dify_config.ANALYTICDB_MIN_CONNECTION, | |||
| max_connection=dify_config.ANALYTICDB_MAX_CONNECTION, | |||
| namespace=dify_config.ANALYTICDB_NAMESPACE, | |||
| ) | |||
| apiConfig = None | |||
| return AnalyticdbVector( | |||
| collection_name, | |||
| apiConfig, | |||
| sqlConfig, | |||
| ) | |||
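The factory now routes on a single setting: with ANALYTICDB_HOST left empty the collection is managed through the Aliyun OpenAPI client, while a configured host switches to a direct PostgreSQL connection pool. A condensed sketch of that routing rule (choose_backend is an illustrative helper, not part of the change):

def choose_backend(analyticdb_host: str | None) -> str:
    # Mirrors the branch above: no host -> OpenAPI-backed store, host set -> SQL-backed store.
    return "openapi" if analyticdb_host is None else "sql"

assert choose_backend(None) == "openapi"
assert choose_backend("gp-test.aliyuncs.com") == "sql"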
| @@ -0,0 +1,309 @@ | |||
| import json | |||
| from typing import Any | |||
| from pydantic import BaseModel, model_validator | |||
| _import_err_msg = ( | |||
| "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, " | |||
| "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`" | |||
| ) | |||
| from core.rag.models.document import Document | |||
| from extensions.ext_redis import redis_client | |||
| class AnalyticdbVectorOpenAPIConfig(BaseModel): | |||
| access_key_id: str | |||
| access_key_secret: str | |||
| region_id: str | |||
| instance_id: str | |||
| account: str | |||
| account_password: str | |||
| namespace: str = "dify" | |||
| namespace_password: str | None = None | |||
| metrics: str = "cosine" | |||
| read_timeout: int = 60000 | |||
| @model_validator(mode="before") | |||
| @classmethod | |||
| def validate_config(cls, values: dict) -> dict: | |||
| if not values["access_key_id"]: | |||
| raise ValueError("config ANALYTICDB_KEY_ID is required") | |||
| if not values["access_key_secret"]: | |||
| raise ValueError("config ANALYTICDB_KEY_SECRET is required") | |||
| if not values["region_id"]: | |||
| raise ValueError("config ANALYTICDB_REGION_ID is required") | |||
| if not values["instance_id"]: | |||
| raise ValueError("config ANALYTICDB_INSTANCE_ID is required") | |||
| if not values["account"]: | |||
| raise ValueError("config ANALYTICDB_ACCOUNT is required") | |||
| if not values["account_password"]: | |||
| raise ValueError("config ANALYTICDB_PASSWORD is required") | |||
| if not values["namespace_password"]: | |||
| raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required") | |||
| return values | |||
| def to_analyticdb_client_params(self): | |||
| return { | |||
| "access_key_id": self.access_key_id, | |||
| "access_key_secret": self.access_key_secret, | |||
| "region_id": self.region_id, | |||
| "read_timeout": self.read_timeout, | |||
| } | |||
| class AnalyticdbVectorOpenAPI: | |||
| def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig): | |||
| try: | |||
| from alibabacloud_gpdb20160503.client import Client | |||
| from alibabacloud_tea_openapi import models as open_api_models | |||
| except ImportError: | |||
| raise ImportError(_import_err_msg) | |||
| self._collection_name = collection_name.lower() | |||
| self.config = config | |||
| self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params()) | |||
| self._client = Client(self._client_config) | |||
| self._initialize() | |||
| def _initialize(self) -> None: | |||
| cache_key = f"vector_initialize_{self.config.instance_id}" | |||
| lock_name = f"{cache_key}_lock" | |||
| with redis_client.lock(lock_name, timeout=20): | |||
| database_exist_cache_key = f"vector_initialize_{self.config.instance_id}" | |||
| if redis_client.get(database_exist_cache_key): | |||
| return | |||
| self._initialize_vector_database() | |||
| self._create_namespace_if_not_exists() | |||
| redis_client.set(database_exist_cache_key, 1, ex=3600) | |||
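# The initialization above relies on a Redis lock plus a one-hour cache key so that
# InitVectorDatabase and namespace creation run at most once per instance, even with
# several workers starting concurrently. Reduced to its core, the idiom is
# (sketch only; run_once_per_hour and do_once are illustrative names):
#
#     def run_once_per_hour(key: str, do_once) -> None:
#         with redis_client.lock(f"{key}_lock", timeout=20):
#             if redis_client.get(key):
#                 return                        # another worker already did the work
#             do_once()
#             redis_client.set(key, 1, ex=3600)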
| def _initialize_vector_database(self) -> None: | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| request = gpdb_20160503_models.InitVectorDatabaseRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| manager_account=self.config.account, | |||
| manager_account_password=self.config.account_password, | |||
| ) | |||
| self._client.init_vector_database(request) | |||
| def _create_namespace_if_not_exists(self) -> None: | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| from Tea.exceptions import TeaException | |||
| try: | |||
| request = gpdb_20160503_models.DescribeNamespaceRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| namespace=self.config.namespace, | |||
| manager_account=self.config.account, | |||
| manager_account_password=self.config.account_password, | |||
| ) | |||
| self._client.describe_namespace(request) | |||
| except TeaException as e: | |||
| if e.statusCode == 404: | |||
| request = gpdb_20160503_models.CreateNamespaceRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| manager_account=self.config.account, | |||
| manager_account_password=self.config.account_password, | |||
| namespace=self.config.namespace, | |||
| namespace_password=self.config.namespace_password, | |||
| ) | |||
| self._client.create_namespace(request) | |||
| else: | |||
| raise ValueError(f"failed to create namespace {self.config.namespace}: {e}") | |||
| def _create_collection_if_not_exists(self, embedding_dimension: int): | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| from Tea.exceptions import TeaException | |||
| cache_key = f"vector_indexing_{self._collection_name}" | |||
| lock_name = f"{cache_key}_lock" | |||
| with redis_client.lock(lock_name, timeout=20): | |||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | |||
| if redis_client.get(collection_exist_cache_key): | |||
| return | |||
| try: | |||
| request = gpdb_20160503_models.DescribeCollectionRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| namespace=self.config.namespace, | |||
| namespace_password=self.config.namespace_password, | |||
| collection=self._collection_name, | |||
| ) | |||
| self._client.describe_collection(request) | |||
| except TeaException as e: | |||
| if e.statusCode == 404: | |||
| metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}' | |||
| full_text_retrieval_fields = "page_content" | |||
| request = gpdb_20160503_models.CreateCollectionRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| manager_account=self.config.account, | |||
| manager_account_password=self.config.account_password, | |||
| namespace=self.config.namespace, | |||
| collection=self._collection_name, | |||
| dimension=embedding_dimension, | |||
| metrics=self.config.metrics, | |||
| metadata=metadata, | |||
| full_text_retrieval_fields=full_text_retrieval_fields, | |||
| ) | |||
| self._client.create_collection(request) | |||
| else: | |||
| raise ValueError(f"failed to create collection {self._collection_name}: {e}") | |||
| redis_client.set(collection_exist_cache_key, 1, ex=3600) | |||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = [] | |||
| for doc, embedding in zip(documents, embeddings, strict=True): | |||
| metadata = { | |||
| "ref_doc_id": doc.metadata["doc_id"], | |||
| "page_content": doc.page_content, | |||
| "metadata_": json.dumps(doc.metadata), | |||
| } | |||
| rows.append( | |||
| gpdb_20160503_models.UpsertCollectionDataRequestRows( | |||
| vector=embedding, | |||
| metadata=metadata, | |||
| ) | |||
| ) | |||
| request = gpdb_20160503_models.UpsertCollectionDataRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| namespace=self.config.namespace, | |||
| namespace_password=self.config.namespace_password, | |||
| collection=self._collection_name, | |||
| rows=rows, | |||
| ) | |||
| self._client.upsert_collection_data(request) | |||
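# Each UpsertCollectionDataRequestRows entry pairs one embedding with the three metadata
# fields declared when the collection was created, e.g. (illustrative values only):
#
#     {
#         "vector": [0.1, 0.2, 0.3, 0.4],
#         "metadata": {
#             "ref_doc_id": "d1",
#             "page_content": "hello world",
#             "metadata_": "{\"doc_id\": \"d1\"}"
#         }
#     }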
| def text_exists(self, id: str) -> bool: | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| request = gpdb_20160503_models.QueryCollectionDataRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| namespace=self.config.namespace, | |||
| namespace_password=self.config.namespace_password, | |||
| collection=self._collection_name, | |||
| metrics=self.config.metrics, | |||
| include_values=True, | |||
| vector=None, | |||
| content=None, | |||
| top_k=1, | |||
| filter=f"ref_doc_id='{id}'", | |||
| ) | |||
| response = self._client.query_collection_data(request) | |||
| return len(response.body.matches.match) > 0 | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| ids_str = ",".join(f"'{id}'" for id in ids) | |||
| ids_str = f"({ids_str})" | |||
| request = gpdb_20160503_models.DeleteCollectionDataRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| namespace=self.config.namespace, | |||
| namespace_password=self.config.namespace_password, | |||
| collection=self._collection_name, | |||
| collection_data=None, | |||
| collection_data_filter=f"ref_doc_id IN {ids_str}", | |||
| ) | |||
| self._client.delete_collection_data(request) | |||
| def delete_by_metadata_field(self, key: str, value: str) -> None: | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| request = gpdb_20160503_models.DeleteCollectionDataRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| namespace=self.config.namespace, | |||
| namespace_password=self.config.namespace_password, | |||
| collection=self._collection_name, | |||
| collection_data=None, | |||
| collection_data_filter=f"metadata_ ->> '{key}' = '{value}'", | |||
| ) | |||
| self._client.delete_collection_data(request) | |||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| score_threshold = kwargs.get("score_threshold") or 0.0 | |||
| request = gpdb_20160503_models.QueryCollectionDataRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| namespace=self.config.namespace, | |||
| namespace_password=self.config.namespace_password, | |||
| collection=self._collection_name, | |||
| include_values=kwargs.pop("include_values", True), | |||
| metrics=self.config.metrics, | |||
| vector=query_vector, | |||
| content=None, | |||
| top_k=kwargs.get("top_k", 4), | |||
| filter=None, | |||
| ) | |||
| response = self._client.query_collection_data(request) | |||
| documents = [] | |||
| for match in response.body.matches.match: | |||
| if match.score > score_threshold: | |||
| metadata = json.loads(match.metadata.get("metadata_")) | |||
| metadata["score"] = match.score | |||
| doc = Document( | |||
| page_content=match.metadata.get("page_content"), | |||
| vector=match.values.value, | |||
| metadata=metadata, | |||
| ) | |||
| documents.append(doc) | |||
| documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) | |||
| return documents | |||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| request = gpdb_20160503_models.QueryCollectionDataRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| namespace=self.config.namespace, | |||
| namespace_password=self.config.namespace_password, | |||
| collection=self._collection_name, | |||
| include_values=kwargs.pop("include_values", True), | |||
| metrics=self.config.metrics, | |||
| vector=None, | |||
| content=query, | |||
| top_k=kwargs.get("top_k", 4), | |||
| filter=None, | |||
| ) | |||
| response = self._client.query_collection_data(request) | |||
| documents = [] | |||
| for match in response.body.matches.match: | |||
| if match.score > score_threshold: | |||
| metadata = json.loads(match.metadata.get("metadata_")) | |||
| metadata["score"] = match.score | |||
| doc = Document( | |||
| page_content=match.metadata.get("page_content"), | |||
| vector=match.values.value, | |||
| metadata=metadata, | |||
| ) | |||
| documents.append(doc) | |||
| documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) | |||
| return documents | |||
| def delete(self) -> None: | |||
| try: | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| request = gpdb_20160503_models.DeleteCollectionRequest( | |||
| collection=self._collection_name, | |||
| dbinstance_id=self.config.instance_id, | |||
| namespace=self.config.namespace, | |||
| namespace_password=self.config.namespace_password, | |||
| region_id=self.config.region_id, | |||
| ) | |||
| self._client.delete_collection(request) | |||
| except Exception as e: | |||
| raise e | |||
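Given valid credentials, the class above is self-contained: it lazily creates the vector database, namespace and collection, then upserts rows keyed by ref_doc_id. A hedged usage sketch (credentials, region and collection name are placeholders; a reachable AnalyticDB instance and the Redis used for the initialization locks are assumed):

from core.rag.models.document import Document

config = AnalyticdbVectorOpenAPIConfig(
    access_key_id="ak-placeholder",
    access_key_secret="sk-placeholder",
    region_id="cn-hangzhou",
    instance_id="gp-xxxx",
    account="testaccount",
    account_password="testpassword",
    namespace="dify",
    namespace_password="difypassword",
)
store = AnalyticdbVectorOpenAPI("example_collection", config)
store._create_collection_if_not_exists(embedding_dimension=4)  # normally driven by AnalyticdbVector.create()
store.add_texts([Document(page_content="hello world", metadata={"doc_id": "d1"})], [[0.1, 0.2, 0.3, 0.4]])
print(store.search_by_vector([0.1, 0.2, 0.3, 0.4], top_k=1))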
| @@ -0,0 +1,245 @@ | |||
| import json | |||
| import uuid | |||
| from contextlib import contextmanager | |||
| from typing import Any | |||
| import psycopg2.extras | |||
| import psycopg2.pool | |||
| from pydantic import BaseModel, model_validator | |||
| from core.rag.models.document import Document | |||
| from extensions.ext_redis import redis_client | |||
| class AnalyticdbVectorBySqlConfig(BaseModel): | |||
| host: str | |||
| port: int | |||
| account: str | |||
| account_password: str | |||
| min_connection: int | |||
| max_connection: int | |||
| namespace: str = "dify" | |||
| metrics: str = "cosine" | |||
| @model_validator(mode="before") | |||
| @classmethod | |||
| def validate_config(cls, values: dict) -> dict: | |||
| if not values["host"]: | |||
| raise ValueError("config ANALYTICDB_HOST is required") | |||
| if not values["port"]: | |||
| raise ValueError("config ANALYTICDB_PORT is required") | |||
| if not values["account"]: | |||
| raise ValueError("config ANALYTICDB_ACCOUNT is required") | |||
| if not values["account_password"]: | |||
| raise ValueError("config ANALYTICDB_PASSWORD is required") | |||
| if not values["min_connection"]: | |||
| raise ValueError("config ANALYTICDB_MIN_CONNECTION is required") | |||
| if not values["max_connection"]: | |||
| raise ValueError("config ANALYTICDB_MAX_CONNECTION is required") | |||
| if values["min_connection"] > values["max_connection"]: | |||
| raise ValueError("config ANALYTICDB_MIN_CONNECTION should less than ANALYTICDB_MAX_CONNECTION") | |||
| return values | |||
| class AnalyticdbVectorBySql: | |||
| def __init__(self, collection_name: str, config: AnalyticdbVectorBySqlConfig): | |||
| self._collection_name = collection_name.lower() | |||
| self.databaseName = "knowledgebase" | |||
| self.config = config | |||
| self.table_name = f"{self.config.namespace}.{self._collection_name}" | |||
| self.pool = None | |||
| self._initialize() | |||
| if not self.pool: | |||
| self.pool = self._create_connection_pool() | |||
| def _initialize(self) -> None: | |||
| cache_key = f"vector_initialize_{self.config.host}" | |||
| lock_name = f"{cache_key}_lock" | |||
| with redis_client.lock(lock_name, timeout=20): | |||
| database_exist_cache_key = f"vector_initialize_{self.config.host}" | |||
| if redis_client.get(database_exist_cache_key): | |||
| return | |||
| self._initialize_vector_database() | |||
| redis_client.set(database_exist_cache_key, 1, ex=3600) | |||
| def _create_connection_pool(self): | |||
| return psycopg2.pool.SimpleConnectionPool( | |||
| self.config.min_connection, | |||
| self.config.max_connection, | |||
| host=self.config.host, | |||
| port=self.config.port, | |||
| user=self.config.account, | |||
| password=self.config.account_password, | |||
| database=self.databaseName, | |||
| ) | |||
| @contextmanager | |||
| def _get_cursor(self): | |||
| conn = self.pool.getconn() | |||
| cur = conn.cursor() | |||
| try: | |||
| yield cur | |||
| finally: | |||
| cur.close() | |||
| conn.commit() | |||
| self.pool.putconn(conn) | |||
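# _get_cursor checks a connection out of the pool, hands the caller a cursor, and in the
# finally block closes the cursor, commits, and returns the connection to the pool (so a
# commit also happens when the caller's block raises). Typical call shape (sketch only):
#
#     with self._get_cursor() as cur:
#         cur.execute(f"SELECT count(*) FROM {self.table_name}")
#         print(cur.fetchone()[0])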
| def _initialize_vector_database(self) -> None: | |||
| conn = psycopg2.connect( | |||
| host=self.config.host, | |||
| port=self.config.port, | |||
| user=self.config.account, | |||
| password=self.config.account_password, | |||
| database="postgres", | |||
| ) | |||
| conn.autocommit = True | |||
| cur = conn.cursor() | |||
| try: | |||
| cur.execute(f"CREATE DATABASE {self.databaseName}") | |||
| except Exception as e: | |||
| if "already exists" in str(e): | |||
| return | |||
| raise e | |||
| finally: | |||
| cur.close() | |||
| conn.close() | |||
| self.pool = self._create_connection_pool() | |||
| with self._get_cursor() as cur: | |||
| try: | |||
| cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)") | |||
| cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple") | |||
| except Exception as e: | |||
| if "already exists" not in str(e): | |||
| raise e | |||
| cur.execute( | |||
| "CREATE OR REPLACE FUNCTION " | |||
| "public.to_tsquery_from_text(txt text, lang regconfig DEFAULT 'english'::regconfig) " | |||
| "RETURNS tsquery LANGUAGE sql IMMUTABLE STRICT AS $function$ " | |||
| "SELECT to_tsquery(lang, COALESCE(string_agg(split_part(word, ':', 1), ' | '), '')) " | |||
| "FROM (SELECT unnest(string_to_array(to_tsvector(lang, txt)::text, ' ')) AS word) " | |||
| "AS words_only;$function$" | |||
| ) | |||
| cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}") | |||
| def _create_collection_if_not_exists(self, embedding_dimension: int): | |||
| cache_key = f"vector_indexing_{self._collection_name}" | |||
| lock_name = f"{cache_key}_lock" | |||
| with redis_client.lock(lock_name, timeout=20): | |||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | |||
| if redis_client.get(collection_exist_cache_key): | |||
| return | |||
| with self._get_cursor() as cur: | |||
| cur.execute( | |||
| f"CREATE TABLE IF NOT EXISTS {self.table_name}(" | |||
| f"id text PRIMARY KEY," | |||
| f"vector real[], ref_doc_id text, page_content text, metadata_ jsonb, " | |||
| f"to_tsvector TSVECTOR" | |||
| f") WITH (fillfactor=70) DISTRIBUTED BY (id);" | |||
| ) | |||
| if embedding_dimension is not None: | |||
| index_name = f"{self._collection_name}_embedding_idx" | |||
| cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN") | |||
| cur.execute( | |||
| f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) " | |||
| f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', " | |||
| f"pq_enable=0, external_storage=0)" | |||
| ) | |||
| cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)") | |||
| redis_client.set(collection_exist_cache_key, 1, ex=3600) | |||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | |||
| values = [] | |||
| id_prefix = str(uuid.uuid4()) + "_" | |||
| sql = f""" | |||
| INSERT INTO {self.table_name} | |||
| (id, ref_doc_id, vector, page_content, metadata_, to_tsvector) | |||
| VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s)); | |||
| """ | |||
| for i, doc in enumerate(documents): | |||
| values.append( | |||
| ( | |||
| id_prefix + str(i), | |||
| doc.metadata.get("doc_id", str(uuid.uuid4())), | |||
| embeddings[i], | |||
| doc.page_content, | |||
| json.dumps(doc.metadata), | |||
| doc.page_content, | |||
| ) | |||
| ) | |||
| with self._get_cursor() as cur: | |||
| psycopg2.extras.execute_batch(cur, sql, values) | |||
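# Each batch insert writes one row per document, e.g. (illustrative values only):
#
#     id            '3f2a..._0'          -- random uuid prefix + position in the batch
#     ref_doc_id    'd1'                 -- doc.metadata["doc_id"]
#     vector        {0.1,0.2,0.3,0.4}
#     page_content  'hello world'
#     metadata_     {"doc_id": "d1"}
#     to_tsvector   to_tsvector('zh_cn', page_content)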
| def text_exists(self, id: str) -> bool: | |||
| with self._get_cursor() as cur: | |||
| cur.execute(f"SELECT id FROM {self.table_name} WHERE ref_doc_id = %s", (id,)) | |||
| return cur.fetchone() is not None | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| with self._get_cursor() as cur: | |||
| try: | |||
| cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id IN %s", (tuple(ids),)) | |||
| except Exception as e: | |||
| if "does not exist" not in str(e): | |||
| raise e | |||
| def delete_by_metadata_field(self, key: str, value: str) -> None: | |||
| with self._get_cursor() as cur: | |||
| try: | |||
| cur.execute(f"DELETE FROM {self.table_name} WHERE metadata_->>%s = %s", (key, value)) | |||
| except Exception as e: | |||
| if "does not exist" not in str(e): | |||
| raise e | |||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||
| top_k = kwargs.get("top_k", 4) | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| with self._get_cursor() as cur: | |||
| query_vector_str = json.dumps(query_vector) | |||
| query_vector_str = "{" + query_vector_str[1:-1] + "}" | |||
| cur.execute( | |||
| f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, " | |||
| f"t.page_content as page_content, t.metadata_ AS metadata_ " | |||
| f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score " | |||
| f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t", | |||
| (query_vector_str,), | |||
| ) | |||
| documents = [] | |||
| for record in cur: | |||
| id, vector, score, page_content, metadata = record | |||
| if score > score_threshold: | |||
| metadata["score"] = score | |||
| doc = Document( | |||
| page_content=page_content, | |||
| vector=vector, | |||
| metadata=metadata, | |||
| ) | |||
| documents.append(doc) | |||
| return documents | |||
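# The inner query orders rows by the ann distance (vector <=> query) and the outer query
# turns that distance into a score via 1.0 - distance (a cosine similarity when
# metrics="cosine"), so score_threshold is compared against similarity, e.g.:
#
#     distance 0.12  ->  score 0.88   (kept when score_threshold <= 0.88)
#     distance 0.95  ->  score 0.05   (dropped when score_threshold = 0.5)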
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | |||
| top_k = kwargs.get("top_k", 4) | |||
| with self._get_cursor() as cur: | |||
| cur.execute( | |||
| f"""SELECT id, vector, page_content, metadata_, | |||
| ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score | |||
| FROM {self.table_name} | |||
| WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') | |||
| ORDER BY score DESC | |||
| LIMIT {top_k}""", | |||
| (f"'{query}'", f"'{query}'"), | |||
| ) | |||
| documents = [] | |||
| for record in cur: | |||
| id, vector, page_content, metadata, score = record | |||
| metadata["score"] = score | |||
| doc = Document( | |||
| page_content=page_content, | |||
| vector=vector, | |||
| metadata=metadata, | |||
| ) | |||
| documents.append(doc) | |||
| return documents | |||
| def delete(self) -> None: | |||
| with self._get_cursor() as cur: | |||
| cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") | |||
| @@ -1,27 +1,43 @@ | |||
| from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbConfig, AnalyticdbVector | |||
| from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig | |||
| from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig | |||
| from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis | |||
| class AnalyticdbVectorTest(AbstractVectorTest): | |||
| def __init__(self): | |||
| def __init__(self, config_type: str): | |||
| super().__init__() | |||
| # AnalyticDB requires the collection name to be shorter than 60 characters, | |||
| # which is fine for normal usage. | |||
| self.collection_name = self.collection_name.replace("_test", "") | |||
| self.vector = AnalyticdbVector( | |||
| collection_name=self.collection_name, | |||
| config=AnalyticdbConfig( | |||
| access_key_id="test_key_id", | |||
| access_key_secret="test_key_secret", | |||
| region_id="test_region", | |||
| instance_id="test_id", | |||
| account="test_account", | |||
| account_password="test_passwd", | |||
| namespace="difytest_namespace", | |||
| collection="difytest_collection", | |||
| namespace_password="test_passwd", | |||
| ), | |||
| ) | |||
| if config_type == "sql": | |||
| self.vector = AnalyticdbVector( | |||
| collection_name=self.collection_name, | |||
| sql_config=AnalyticdbVectorBySqlConfig( | |||
| host="test_host", | |||
| port=5432, | |||
| account="test_account", | |||
| account_password="test_passwd", | |||
| namespace="difytest_namespace", | |||
| ), | |||
| api_config=None, | |||
| ) | |||
| else: | |||
| self.vector = AnalyticdbVector( | |||
| collection_name=self.collection_name, | |||
| sql_config=None, | |||
| api_config=AnalyticdbVectorOpenAPIConfig( | |||
| access_key_id="test_key_id", | |||
| access_key_secret="test_key_secret", | |||
| region_id="test_region", | |||
| instance_id="test_id", | |||
| account="test_account", | |||
| account_password="test_passwd", | |||
| namespace="difytest_namespace", | |||
| collection="difytest_collection", | |||
| namespace_password="test_passwd", | |||
| ), | |||
| ) | |||
| def run_all_tests(self): | |||
| self.vector.delete() | |||
| @@ -29,4 +45,5 @@ class AnalyticdbVectorTest(AbstractVectorTest): | |||
| def test_analyticdb_vector(setup_mock_redis): | |||
| AnalyticdbVectorTest().run_all_tests() | |||
| AnalyticdbVectorTest("api").run_all_tests() | |||
| AnalyticdbVectorTest("sql").run_all_tests() | |||
| @@ -450,6 +450,10 @@ ANALYTICDB_ACCOUNT=testaccount | |||
| ANALYTICDB_PASSWORD=testpassword | |||
| ANALYTICDB_NAMESPACE=dify | |||
| ANALYTICDB_NAMESPACE_PASSWORD=difypassword | |||
| ANALYTICDB_HOST=gp-test.aliyuncs.com | |||
| ANALYTICDB_PORT=5432 | |||
| ANALYTICDB_MIN_CONNECTION=1 | |||
| ANALYTICDB_MAX_CONNECTION=5 | |||
| # TiDB vector configurations, only available when VECTOR_STORE is `tidb` | |||
| TIDB_VECTOR_HOST=tidb | |||
| @@ -185,6 +185,10 @@ x-shared-env: &shared-api-worker-env | |||
| ANALYTICDB_PASSWORD: ${ANALYTICDB_PASSWORD:-} | |||
| ANALYTICDB_NAMESPACE: ${ANALYTICDB_NAMESPACE:-dify} | |||
| ANALYTICDB_NAMESPACE_PASSWORD: ${ANALYTICDB_NAMESPACE_PASSWORD:-} | |||
| ANALYTICDB_HOST: ${ANALYTICDB_HOST:-} | |||
| ANALYTICDB_PORT: ${ANALYTICDB_PORT:-5432} | |||
| ANALYTICDB_MIN_CONNECTION: ${ANALYTICDB_MIN_CONNECTION:-1} | |||
| ANALYTICDB_MAX_CONNECTION: ${ANALYTICDB_MAX_CONNECTION:-5} | |||
| OPENSEARCH_HOST: ${OPENSEARCH_HOST:-opensearch} | |||
| OPENSEARCH_PORT: ${OPENSEARCH_PORT:-9200} | |||
| OPENSEARCH_USER: ${OPENSEARCH_USER:-admin} | |||