Co-authored-by: 璟义 <yangshangpo.ysp@alibaba-inc.com>
ANALYTICDB_PASSWORD=testpassword
ANALYTICDB_NAMESPACE=dify
ANALYTICDB_NAMESPACE_PASSWORD=difypassword
ANALYTICDB_HOST=gp-test.aliyuncs.com
ANALYTICDB_PORT=5432
ANALYTICDB_MIN_CONNECTION=1
ANALYTICDB_MAX_CONNECTION=5

# OpenSearch configuration
OPENSEARCH_HOST=127.0.0.1
from typing import Optional

from pydantic import BaseModel, Field, PositiveInt


class AnalyticdbConfig(BaseModel):
    # ... existing AnalyticDB OpenAPI settings (key id/secret, region, instance, account) ...
    ANALYTICDB_NAMESPACE_PASSWORD: Optional[str] = Field(
        default=None,
        description="The password for accessing the specified namespace within the AnalyticDB instance"
        " (if namespace feature is enabled).",
    )
    ANALYTICDB_HOST: Optional[str] = Field(
        default=None, description="The host of the AnalyticDB instance you want to connect to."
    )
    ANALYTICDB_PORT: PositiveInt = Field(
        default=5432, description="The port of the AnalyticDB instance you want to connect to."
    )
    ANALYTICDB_MIN_CONNECTION: PositiveInt = Field(
        default=1, description="Minimum number of connections in the AnalyticDB connection pool."
    )
    ANALYTICDB_MAX_CONNECTION: PositiveInt = Field(
        default=5, description="Maximum number of connections in the AnalyticDB connection pool."
    )
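The typing here is deliberate: ANALYTICDB_HOST stays optional because leaving it unset selects the OpenAPI client in the factory further down, while the port and pool bounds use PositiveInt so zero or negative values fail fast. A minimal, self-contained sketch of that validation behaviour (the _PoolSettings model is illustrative only, not part of this change):

# Minimal sketch (not part of this change): how PositiveInt guards the new
# pool settings; _PoolSettings is a throwaway stand-in for the config above.
from typing import Optional

from pydantic import BaseModel, Field, PositiveInt, ValidationError


class _PoolSettings(BaseModel):
    ANALYTICDB_HOST: Optional[str] = Field(default=None)
    ANALYTICDB_PORT: PositiveInt = Field(default=5432)
    ANALYTICDB_MIN_CONNECTION: PositiveInt = Field(default=1)
    ANALYTICDB_MAX_CONNECTION: PositiveInt = Field(default=5)


print(_PoolSettings().ANALYTICDB_PORT)  # 5432: defaults apply when nothing is set

try:
    _PoolSettings(ANALYTICDB_MAX_CONNECTION=0)  # zero/negative pool sizes are rejected
except ValidationError as exc:
    print(exc)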
import json
from typing import Any

from configs import dify_config
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
    AnalyticdbVectorOpenAPI,
    AnalyticdbVectorOpenAPIConfig,
)
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from models.dataset import Dataset


class AnalyticdbVector(BaseVector):
    def __init__(
        self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig
    ):
        super().__init__(collection_name)
        if api_config is not None:
            self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config)
        else:
            self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)

    def get_type(self) -> str:
        return VectorType.ANALYTICDB

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        dimension = len(embeddings[0])
        self.analyticdb_vector._create_collection_if_not_exists(dimension)
        self.analyticdb_vector.add_texts(texts, embeddings)

    def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        self.analyticdb_vector.add_texts(texts, embeddings)

    def text_exists(self, id: str) -> bool:
        return self.analyticdb_vector.text_exists(id)

    def delete_by_ids(self, ids: list[str]) -> None:
        self.analyticdb_vector.delete_by_ids(ids)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        self.analyticdb_vector.delete_by_metadata_field(key, value)

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        return self.analyticdb_vector.search_by_vector(query_vector, **kwargs)

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        return self.analyticdb_vector.search_by_full_text(query, **kwargs)

    def delete(self) -> None:
        self.analyticdb_vector.delete()


class AnalyticdbVectorFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AnalyticdbVector:
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
            collection_name = class_prefix.lower()
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name))

        if dify_config.ANALYTICDB_HOST is None:
            # implemented through OpenAPI
            apiConfig = AnalyticdbVectorOpenAPIConfig(
                access_key_id=dify_config.ANALYTICDB_KEY_ID,
                access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
                region_id=dify_config.ANALYTICDB_REGION_ID,
                instance_id=dify_config.ANALYTICDB_INSTANCE_ID,
                account=dify_config.ANALYTICDB_ACCOUNT,
                account_password=dify_config.ANALYTICDB_PASSWORD,
                namespace=dify_config.ANALYTICDB_NAMESPACE,
                namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
            )
            sqlConfig = None
        else:
            # implemented through SQL
            sqlConfig = AnalyticdbVectorBySqlConfig(
                host=dify_config.ANALYTICDB_HOST,
                port=dify_config.ANALYTICDB_PORT,
                account=dify_config.ANALYTICDB_ACCOUNT,
                account_password=dify_config.ANALYTICDB_PASSWORD,
                min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
                max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
                namespace=dify_config.ANALYTICDB_NAMESPACE,
            )
            apiConfig = None
        return AnalyticdbVector(
            collection_name,
            apiConfig,
            sqlConfig,
        )
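Illustrative wiring of the two modes (not part of the change; credentials and endpoints are placeholders, and constructing either store immediately initializes the remote instance, so a reachable AnalyticDB and Redis are assumed):

# Illustrative only: the two ways AnalyticdbVector can now be wired, mirroring
# the factory above. Credentials are placeholders.
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig

# No ANALYTICDB_HOST configured -> requests go through the Alibaba Cloud OpenAPI SDK.
openapi_vector = AnalyticdbVector(
    "dify_demo_collection",
    api_config=AnalyticdbVectorOpenAPIConfig(
        access_key_id="ak", access_key_secret="sk", region_id="cn-hangzhou",
        instance_id="gp-test", account="dify", account_password="pwd",
        namespace="dify", namespace_password="pwd",
    ),
    sql_config=None,
)

# ANALYTICDB_HOST configured -> psycopg2 connects to the instance directly over SQL.
sql_vector = AnalyticdbVector(
    "dify_demo_collection",
    api_config=None,
    sql_config=AnalyticdbVectorBySqlConfig(
        host="gp-test.aliyuncs.com", port=5432, account="dify", account_password="pwd",
        min_connection=1, max_connection=5, namespace="dify",
    ),
)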
import json
from typing import Any

from pydantic import BaseModel, model_validator

from core.rag.models.document import Document
from extensions.ext_redis import redis_client

_import_err_msg = (
    "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
    "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)


class AnalyticdbVectorOpenAPIConfig(BaseModel):
    access_key_id: str
    access_key_secret: str
    region_id: str
    instance_id: str
    account: str
    account_password: str
    namespace: str = "dify"
    namespace_password: str | None = None
    metrics: str = "cosine"
    read_timeout: int = 60000

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        if not values["access_key_id"]:
            raise ValueError("config ANALYTICDB_KEY_ID is required")
        if not values["access_key_secret"]:
            raise ValueError("config ANALYTICDB_KEY_SECRET is required")
        if not values["region_id"]:
            raise ValueError("config ANALYTICDB_REGION_ID is required")
        if not values["instance_id"]:
            raise ValueError("config ANALYTICDB_INSTANCE_ID is required")
        if not values["account"]:
            raise ValueError("config ANALYTICDB_ACCOUNT is required")
        if not values["account_password"]:
            raise ValueError("config ANALYTICDB_PASSWORD is required")
        if not values["namespace_password"]:
            raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
        return values

    def to_analyticdb_client_params(self):
        return {
            "access_key_id": self.access_key_id,
            "access_key_secret": self.access_key_secret,
            "region_id": self.region_id,
            "read_timeout": self.read_timeout,
        }


class AnalyticdbVectorOpenAPI:
    def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
        try:
            from alibabacloud_gpdb20160503.client import Client
            from alibabacloud_tea_openapi import models as open_api_models
        except ImportError:
            raise ImportError(_import_err_msg)
        self._collection_name = collection_name.lower()
        self.config = config
        self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
        self._client = Client(self._client_config)
        self._initialize()

    def _initialize(self) -> None:
        cache_key = f"vector_initialize_{self.config.instance_id}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            database_exist_cache_key = f"vector_initialize_{self.config.instance_id}"
            if redis_client.get(database_exist_cache_key):
                return
            self._initialize_vector_database()
            self._create_namespace_if_not_exists()
            redis_client.set(database_exist_cache_key, 1, ex=3600)

    def _initialize_vector_database(self) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.InitVectorDatabaseRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            manager_account=self.config.account,
            manager_account_password=self.config.account_password,
        )
        self._client.init_vector_database(request)

    def _create_namespace_if_not_exists(self) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        from Tea.exceptions import TeaException

        try:
            request = gpdb_20160503_models.DescribeNamespaceRequest(
                dbinstance_id=self.config.instance_id,
                region_id=self.config.region_id,
                namespace=self.config.namespace,
                manager_account=self.config.account,
                manager_account_password=self.config.account_password,
            )
            self._client.describe_namespace(request)
        except TeaException as e:
            if e.statusCode == 404:
                request = gpdb_20160503_models.CreateNamespaceRequest(
                    dbinstance_id=self.config.instance_id,
                    region_id=self.config.region_id,
                    manager_account=self.config.account,
                    manager_account_password=self.config.account_password,
                    namespace=self.config.namespace,
                    namespace_password=self.config.namespace_password,
                )
                self._client.create_namespace(request)
            else:
                raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")

    def _create_collection_if_not_exists(self, embedding_dimension: int):
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        from Tea.exceptions import TeaException

        cache_key = f"vector_indexing_{self._collection_name}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(collection_exist_cache_key):
                return
            try:
                request = gpdb_20160503_models.DescribeCollectionRequest(
                    dbinstance_id=self.config.instance_id,
                    region_id=self.config.region_id,
                    namespace=self.config.namespace,
                    namespace_password=self.config.namespace_password,
                    collection=self._collection_name,
                )
                self._client.describe_collection(request)
            except TeaException as e:
                if e.statusCode == 404:
                    metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
                    full_text_retrieval_fields = "page_content"
                    request = gpdb_20160503_models.CreateCollectionRequest(
                        dbinstance_id=self.config.instance_id,
                        region_id=self.config.region_id,
                        manager_account=self.config.account,
                        manager_account_password=self.config.account_password,
                        namespace=self.config.namespace,
                        collection=self._collection_name,
                        dimension=embedding_dimension,
                        metrics=self.config.metrics,
                        metadata=metadata,
                        full_text_retrieval_fields=full_text_retrieval_fields,
                    )
                    self._client.create_collection(request)
                else:
                    raise ValueError(f"failed to create collection {self._collection_name}: {e}")
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
        for doc, embedding in zip(documents, embeddings, strict=True):
            metadata = {
                "ref_doc_id": doc.metadata["doc_id"],
                "page_content": doc.page_content,
                "metadata_": json.dumps(doc.metadata),
            }
            rows.append(
                gpdb_20160503_models.UpsertCollectionDataRequestRows(
                    vector=embedding,
                    metadata=metadata,
                )
            )
        request = gpdb_20160503_models.UpsertCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            rows=rows,
        )
        self._client.upsert_collection_data(request)

    def text_exists(self, id: str) -> bool:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            metrics=self.config.metrics,
            include_values=True,
            vector=None,
            content=None,
            top_k=1,
            filter=f"ref_doc_id='{id}'",
        )
        response = self._client.query_collection_data(request)
        return len(response.body.matches.match) > 0

    def delete_by_ids(self, ids: list[str]) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        ids_str = ",".join(f"'{id}'" for id in ids)
        ids_str = f"({ids_str})"
        request = gpdb_20160503_models.DeleteCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            collection_data=None,
            collection_data_filter=f"ref_doc_id IN {ids_str}",
        )
        self._client.delete_collection_data(request)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.DeleteCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            collection_data=None,
            collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
        )
        self._client.delete_collection_data(request)

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        score_threshold = kwargs.get("score_threshold") or 0.0
        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            include_values=kwargs.pop("include_values", True),
            metrics=self.config.metrics,
            vector=query_vector,
            content=None,
            top_k=kwargs.get("top_k", 4),
            filter=None,
        )
        response = self._client.query_collection_data(request)
        documents = []
        for match in response.body.matches.match:
            if match.score > score_threshold:
                metadata = json.loads(match.metadata.get("metadata_"))
                metadata["score"] = match.score
                doc = Document(
                    page_content=match.metadata.get("page_content"),
                    vector=match.values.value,
                    metadata=metadata,
                )
                documents.append(doc)
        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
        return documents

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            include_values=kwargs.pop("include_values", True),
            metrics=self.config.metrics,
            vector=None,
            content=query,
            top_k=kwargs.get("top_k", 4),
            filter=None,
        )
        response = self._client.query_collection_data(request)
        documents = []
        for match in response.body.matches.match:
            if match.score > score_threshold:
                metadata = json.loads(match.metadata.get("metadata_"))
                metadata["score"] = match.score
                doc = Document(
                    page_content=match.metadata.get("page_content"),
                    vector=match.values.value,
                    metadata=metadata,
                )
                documents.append(doc)
        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
        return documents

    def delete(self) -> None:
        try:
            from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

            request = gpdb_20160503_models.DeleteCollectionRequest(
                collection=self._collection_name,
                dbinstance_id=self.config.instance_id,
                namespace=self.config.namespace,
                namespace_password=self.config.namespace_password,
                region_id=self.config.region_id,
            )
            self._client.delete_collection(request)
        except Exception as e:
            raise e
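For reference, this is how a single Document maps onto an AnalyticDB row in this module. The snippet below is a hedged sketch, not part of the change; it simply mirrors the metadata dict that add_texts builds, and note that every Document must carry a doc_id in its metadata, which becomes the ref_doc_id used by text_exists and delete_by_ids:

# Sketch of the Document -> row mapping used by add_texts above (illustrative).
import json

from core.rag.models.document import Document

doc = Document(page_content="hello analyticdb", metadata={"doc_id": "doc-1", "dataset_id": "ds-1"})
row_metadata = {
    "ref_doc_id": doc.metadata["doc_id"],   # matched by text_exists / delete_by_ids
    "page_content": doc.page_content,       # also the full-text retrieval field
    "metadata_": json.dumps(doc.metadata),  # json-loaded again by the search methods
}
print(row_metadata)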
import json
import uuid
from contextlib import contextmanager
from typing import Any

import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator

from core.rag.models.document import Document
from extensions.ext_redis import redis_client


class AnalyticdbVectorBySqlConfig(BaseModel):
    host: str
    port: int
    account: str
    account_password: str
    min_connection: int
    max_connection: int
    namespace: str = "dify"
    metrics: str = "cosine"

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        if not values["host"]:
            raise ValueError("config ANALYTICDB_HOST is required")
        if not values["port"]:
            raise ValueError("config ANALYTICDB_PORT is required")
        if not values["account"]:
            raise ValueError("config ANALYTICDB_ACCOUNT is required")
        if not values["account_password"]:
            raise ValueError("config ANALYTICDB_PASSWORD is required")
        if not values["min_connection"]:
            raise ValueError("config ANALYTICDB_MIN_CONNECTION is required")
        if not values["max_connection"]:
            raise ValueError("config ANALYTICDB_MAX_CONNECTION is required")
        if values["min_connection"] > values["max_connection"]:
            raise ValueError("config ANALYTICDB_MIN_CONNECTION should be less than ANALYTICDB_MAX_CONNECTION")
        return values


class AnalyticdbVectorBySql:
    def __init__(self, collection_name: str, config: AnalyticdbVectorBySqlConfig):
        self._collection_name = collection_name.lower()
        self.databaseName = "knowledgebase"
        self.config = config
        self.table_name = f"{self.config.namespace}.{self._collection_name}"
        self.pool = None
        self._initialize()
        if not self.pool:
            self.pool = self._create_connection_pool()

    def _initialize(self) -> None:
        cache_key = f"vector_initialize_{self.config.host}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            database_exist_cache_key = f"vector_initialize_{self.config.host}"
            if redis_client.get(database_exist_cache_key):
                return
            self._initialize_vector_database()
            redis_client.set(database_exist_cache_key, 1, ex=3600)

    def _create_connection_pool(self):
        return psycopg2.pool.SimpleConnectionPool(
            self.config.min_connection,
            self.config.max_connection,
            host=self.config.host,
            port=self.config.port,
            user=self.config.account,
            password=self.config.account_password,
            database=self.databaseName,
        )

    @contextmanager
    def _get_cursor(self):
        conn = self.pool.getconn()
        cur = conn.cursor()
        try:
            yield cur
        finally:
            cur.close()
            conn.commit()
            self.pool.putconn(conn)

    def _initialize_vector_database(self) -> None:
        conn = psycopg2.connect(
            host=self.config.host,
            port=self.config.port,
            user=self.config.account,
            password=self.config.account_password,
            database="postgres",
        )
        conn.autocommit = True
        cur = conn.cursor()
        try:
            cur.execute(f"CREATE DATABASE {self.databaseName}")
        except Exception as e:
            if "already exists" in str(e):
                return
            raise e
        finally:
            cur.close()
            conn.close()
        self.pool = self._create_connection_pool()
        with self._get_cursor() as cur:
            try:
                cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)")
                cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple")
            except Exception as e:
                if "already exists" not in str(e):
                    raise e
            cur.execute(
                "CREATE OR REPLACE FUNCTION "
                "public.to_tsquery_from_text(txt text, lang regconfig DEFAULT 'english'::regconfig) "
                "RETURNS tsquery LANGUAGE sql IMMUTABLE STRICT AS $function$ "
                "SELECT to_tsquery(lang, COALESCE(string_agg(split_part(word, ':', 1), ' | '), '')) "
                "FROM (SELECT unnest(string_to_array(to_tsvector(lang, txt)::text, ' ')) AS word) "
                "AS words_only;$function$"
            )
            cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")

    def _create_collection_if_not_exists(self, embedding_dimension: int):
        cache_key = f"vector_indexing_{self._collection_name}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(collection_exist_cache_key):
                return
            with self._get_cursor() as cur:
                cur.execute(
                    f"CREATE TABLE IF NOT EXISTS {self.table_name}("
                    f"id text PRIMARY KEY,"
                    f"vector real[], ref_doc_id text, page_content text, metadata_ jsonb, "
                    f"to_tsvector TSVECTOR"
                    f") WITH (fillfactor=70) DISTRIBUTED BY (id);"
                )
                if embedding_dimension is not None:
                    index_name = f"{self._collection_name}_embedding_idx"
                    cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN")
                    cur.execute(
                        f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
                        f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', "
                        f"pq_enable=0, external_storage=0)"
                    )
                    cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        values = []
        id_prefix = str(uuid.uuid4()) + "_"
        sql = f"""
        INSERT INTO {self.table_name}
        (id, ref_doc_id, vector, page_content, metadata_, to_tsvector)
        VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
        """
        for i, doc in enumerate(documents):
            values.append(
                (
                    id_prefix + str(i),
                    doc.metadata.get("doc_id", str(uuid.uuid4())),
                    embeddings[i],
                    doc.page_content,
                    json.dumps(doc.metadata),
                    doc.page_content,
                )
            )
        with self._get_cursor() as cur:
            psycopg2.extras.execute_batch(cur, sql, values)

    def text_exists(self, id: str) -> bool:
        with self._get_cursor() as cur:
            cur.execute(f"SELECT id FROM {self.table_name} WHERE ref_doc_id = %s", (id,))
            return cur.fetchone() is not None

    def delete_by_ids(self, ids: list[str]) -> None:
        with self._get_cursor() as cur:
            try:
                cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id IN %s", (tuple(ids),))
            except Exception as e:
                if "does not exist" not in str(e):
                    raise e

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        with self._get_cursor() as cur:
            try:
                cur.execute(f"DELETE FROM {self.table_name} WHERE metadata_->>%s = %s", (key, value))
            except Exception as e:
                if "does not exist" not in str(e):
                    raise e

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 4)
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        with self._get_cursor() as cur:
            query_vector_str = json.dumps(query_vector)
            query_vector_str = "{" + query_vector_str[1:-1] + "}"
            cur.execute(
                f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
                f"t.page_content as page_content, t.metadata_ AS metadata_ "
                f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
                f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t",
                (query_vector_str,),
            )
            documents = []
            for record in cur:
                id, vector, score, page_content, metadata = record
                if score > score_threshold:
                    metadata["score"] = score
                    doc = Document(
                        page_content=page_content,
                        vector=vector,
                        metadata=metadata,
                    )
                    documents.append(doc)
        return documents

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 4)
        with self._get_cursor() as cur:
            cur.execute(
                f"""SELECT id, vector, page_content, metadata_,
                ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
                FROM {self.table_name}
                WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
                ORDER BY score DESC
                LIMIT {top_k}""",
                (f"'{query}'", f"'{query}'"),
            )
            documents = []
            for record in cur:
                id, vector, page_content, metadata, score = record
                metadata["score"] = score
                doc = Document(
                    page_content=page_content,
                    vector=vector,
                    metadata=metadata,
                )
                documents.append(doc)
        return documents

    def delete(self) -> None:
        with self._get_cursor() as cur:
            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis


class AnalyticdbVectorTest(AbstractVectorTest):
    def __init__(self, config_type: str):
        super().__init__()
        # AnalyticDB requires collection_name length less than 60,
        # which is fine for normal usage.
        self.collection_name = self.collection_name.replace("_test", "")
        if config_type == "sql":
            self.vector = AnalyticdbVector(
                collection_name=self.collection_name,
                sql_config=AnalyticdbVectorBySqlConfig(
                    host="test_host",
                    port=5432,
                    account="test_account",
                    account_password="test_passwd",
                    min_connection=1,
                    max_connection=5,
                    namespace="difytest_namespace",
                ),
                api_config=None,
            )
        else:
            self.vector = AnalyticdbVector(
                collection_name=self.collection_name,
                sql_config=None,
                api_config=AnalyticdbVectorOpenAPIConfig(
                    access_key_id="test_key_id",
                    access_key_secret="test_key_secret",
                    region_id="test_region",
                    instance_id="test_id",
                    account="test_account",
                    account_password="test_passwd",
                    namespace="difytest_namespace",
                    collection="difytest_collection",
                    namespace_password="test_passwd",
                ),
            )

    def run_all_tests(self):
        self.vector.delete()


def test_analyticdb_vector(setup_mock_redis):
    AnalyticdbVectorTest("api").run_all_tests()
    AnalyticdbVectorTest("sql").run_all_tests()
ANALYTICDB_PASSWORD=testpassword
ANALYTICDB_NAMESPACE=dify
ANALYTICDB_NAMESPACE_PASSWORD=difypassword
ANALYTICDB_HOST=gp-test.aliyuncs.com
ANALYTICDB_PORT=5432
ANALYTICDB_MIN_CONNECTION=1
ANALYTICDB_MAX_CONNECTION=5

# TiDB vector configurations, only available when VECTOR_STORE is `tidb`
TIDB_VECTOR_HOST=tidb
ANALYTICDB_PASSWORD: ${ANALYTICDB_PASSWORD:-}
ANALYTICDB_NAMESPACE: ${ANALYTICDB_NAMESPACE:-dify}
ANALYTICDB_NAMESPACE_PASSWORD: ${ANALYTICDB_NAMESPACE_PASSWORD:-}
ANALYTICDB_HOST: ${ANALYTICDB_HOST:-}
ANALYTICDB_PORT: ${ANALYTICDB_PORT:-5432}
ANALYTICDB_MIN_CONNECTION: ${ANALYTICDB_MIN_CONNECTION:-1}
ANALYTICDB_MAX_CONNECTION: ${ANALYTICDB_MAX_CONNECTION:-5}
OPENSEARCH_HOST: ${OPENSEARCH_HOST:-opensearch}
OPENSEARCH_PORT: ${OPENSEARCH_PORT:-9200}
OPENSEARCH_USER: ${OPENSEARCH_USER:-admin}