        default=False,
    )

    TIDB_SERVERLESS_NUMBER: PositiveInt = Field(
        description="number of idle TiDB Serverless clusters to keep pre-created",
        default=500,
    )
class WorkspaceConfig(BaseSettings):
    """
from configs.middleware.vdb.qdrant_config import QdrantConfig
from configs.middleware.vdb.relyt_config import RelytConfig
from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
from configs.middleware.vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
from configs.middleware.vdb.upstash_config import UpstashConfig
from configs.middleware.vdb.vikingdb_config import VikingDBConfig

        default=None,
    )

    VECTOR_STORE_WHITELIST_ENABLE: Optional[bool] = Field(
        description="enable tenant whitelist for vector store selection",
        default=False,
    )
class KeywordStoreConfig(BaseSettings):
    KEYWORD_STORE: str = Field(

    InternalTestConfig,
    VikingDBConfig,
    UpstashConfig,
    TidbOnQdrantConfig,
):
    pass
from typing import Optional

from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings


class TidbOnQdrantConfig(BaseSettings):
    """
    TiDB on Qdrant configs
    """

    TIDB_ON_QDRANT_URL: Optional[str] = Field(
        description="TiDB on Qdrant endpoint URL",
        default=None,
    )

    TIDB_ON_QDRANT_API_KEY: Optional[str] = Field(
        description="TiDB on Qdrant API key",
        default=None,
    )

    TIDB_ON_QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
        description="TiDB on Qdrant client timeout in seconds",
        default=20,
    )

    TIDB_ON_QDRANT_GRPC_ENABLED: bool = Field(
        description="whether to enable gRPC support for the TiDB on Qdrant connection",
        default=False,
    )

    TIDB_ON_QDRANT_GRPC_PORT: PositiveInt = Field(
        description="TiDB on Qdrant gRPC port",
        default=6334,
    )

    TIDB_PUBLIC_KEY: Optional[str] = Field(
        description="TiDB account public key",
        default=None,
    )

    TIDB_PRIVATE_KEY: Optional[str] = Field(
        description="TiDB account private key",
        default=None,
    )

    TIDB_API_URL: Optional[str] = Field(
        description="TiDB API URL",
        default=None,
    )

    TIDB_IAM_API_URL: Optional[str] = Field(
        description="TiDB IAM API URL",
        default=None,
    )

    TIDB_REGION: Optional[str] = Field(
        description="TiDB Serverless region",
        default="regions/aws-us-east-1",
    )

    TIDB_PROJECT_ID: Optional[str] = Field(
        description="TiDB project ID",
        default=None,
    )
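
# Illustrative deployment sketch (not part of the upstream file): pydantic-settings
# reads these fields from the environment, so a deployment configures the
# TiDB-on-Qdrant backend with variables like the placeholders below.
#
#   TIDB_ON_QDRANT_URL=https://qdrant.example.com:6333
#   TIDB_PUBLIC_KEY=<TiDB Cloud API public key>
#   TIDB_PRIVATE_KEY=<TiDB Cloud API private key>
#   TIDB_API_URL=<TiDB Cloud serverless API base URL>
#   TIDB_IAM_API_URL=<TiDB Cloud IAM API base URL>
#   TIDB_REGION=regions/aws-us-east-1
#   TIDB_PROJECT_ID=<TiDB Cloud project id>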
            | VectorType.ORACLE
            | VectorType.ELASTICSEARCH
            | VectorType.PGVECTOR
            | VectorType.TIDB_ON_QDRANT
        ):
            return {
                "retrieval_method": [
from typing import Optional

from pydantic import BaseModel


class ClusterEntity(BaseModel):
    """
    TiDB Serverless cluster entity.
    """

    name: str
    cluster_id: str
    displayName: str
    region: str
    spendingLimit: Optional[int] = 1000
    version: str
    createdBy: str
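
# Illustrative sketch (hypothetical values): mapping a TiDB Cloud cluster record
# onto the entity above.
#
#   cluster = ClusterEntity(
#       name="clusters/1234567890",
#       cluster_id="1234567890",
#       displayName="a1b2c3d4e5f6",
#       region="regions/aws-us-east-1",
#       version="v1",
#       createdBy="user@example.com",
#   )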
import json
import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Optional, Union, cast

import qdrant_client
import requests
from flask import current_app
from pydantic import BaseModel
from qdrant_client.http import models as rest
from qdrant_client.http.models import (
    FilterSelector,
    HnswConfigDiff,
    PayloadSchemaType,
    TextIndexParams,
    TextIndexType,
    TokenizerType,
)
from qdrant_client.local.qdrant_local import QdrantLocal
from requests.auth import HTTPDigestAuth

from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, TidbAuthBinding

if TYPE_CHECKING:
    from qdrant_client import grpc  # noqa
    from qdrant_client.conversions import common_types
    from qdrant_client.http import models as rest

    DictFilter = dict[str, Union[str, int, bool, dict, list]]
    MetadataFilter = Union[DictFilter, common_types.Filter]


class TidbOnQdrantConfig(BaseModel):
    endpoint: str
    api_key: Optional[str] = None
    timeout: float = 20
    root_path: Optional[str] = None
    grpc_port: int = 6334
    prefer_grpc: bool = False

    def to_qdrant_params(self):
        if self.endpoint and self.endpoint.startswith("path:"):
            path = self.endpoint.replace("path:", "")
            if not os.path.isabs(path):
                path = os.path.join(self.root_path, path)

            return {"path": path}
        else:
            return {
                "url": self.endpoint,
                "api_key": self.api_key,
                "timeout": self.timeout,
                "verify": False,
                "grpc_port": self.grpc_port,
                "prefer_grpc": self.prefer_grpc,
            }
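
# Illustrative sketch (not part of the upstream file): an endpoint prefixed with
# "path:" selects the embedded local Qdrant storage directory, while any other
# value is treated as a remote server URL. The values below are made-up examples.
#
#   TidbOnQdrantConfig(endpoint="path:qdrant_storage", root_path="/app/api").to_qdrant_params()
#   # -> {"path": "/app/api/qdrant_storage"}
#   TidbOnQdrantConfig(endpoint="https://qdrant.example.com:6333", api_key="k").to_qdrant_params()
#   # -> {"url": "https://qdrant.example.com:6333", "api_key": "k", "timeout": 20,
#   #     "verify": False, "grpc_port": 6334, "prefer_grpc": False}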
class TidbConfig(BaseModel):
    api_url: str
    public_key: str
    private_key: str


class TidbOnQdrantVector(BaseVector):
    def __init__(self, collection_name: str, group_id: str, config: TidbOnQdrantConfig, distance_func: str = "Cosine"):
        super().__init__(collection_name)
        self._client_config = config
        self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
        self._distance_func = distance_func.upper()
        self._group_id = group_id

    def get_type(self) -> str:
        return VectorType.TIDB_ON_QDRANT

    def to_index_struct(self) -> dict:
        return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        if texts:
            # get embedding vector size
            vector_size = len(embeddings[0])
            # get collection name
            collection_name = self._collection_name
            # create collection
            self.create_collection(collection_name, vector_size)

            self.add_texts(texts, embeddings, **kwargs)

    def create_collection(self, collection_name: str, vector_size: int):
        lock_name = "vector_indexing_lock_{}".format(collection_name)
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
            if redis_client.get(collection_exist_cache_key):
                return
            collection_name = collection_name or uuid.uuid4().hex
            all_collection_name = []
            collections_response = self._client.get_collections()
            collection_list = collections_response.collections
            for collection in collection_list:
                all_collection_name.append(collection.name)
            if collection_name not in all_collection_name:
                from qdrant_client.http import models as rest

                vectors_config = rest.VectorParams(
                    size=vector_size,
                    distance=rest.Distance[self._distance_func],
                )
                hnsw_config = HnswConfigDiff(
                    m=0,
                    payload_m=16,
                    ef_construct=100,
                    full_scan_threshold=10000,
                    max_indexing_threads=0,
                    on_disk=False,
                )
                self._client.recreate_collection(
                    collection_name=collection_name,
                    vectors_config=vectors_config,
                    hnsw_config=hnsw_config,
                    timeout=int(self._client_config.timeout),
                )

                # create group_id payload index
                self._client.create_payload_index(
                    collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD
                )
                # create doc_id payload index
                self._client.create_payload_index(
                    collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD
                )
                # create full text index
                text_index_params = TextIndexParams(
                    type=TextIndexType.TEXT,
                    tokenizer=TokenizerType.MULTILINGUAL,
                    min_token_len=2,
                    max_token_len=20,
                    lowercase=True,
                )
                self._client.create_payload_index(
                    collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params
                )
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        uuids = self._get_uuids(documents)
        texts = [d.page_content for d in documents]
        metadatas = [d.metadata for d in documents]

        added_ids = []
        for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
            self._client.upsert(collection_name=self._collection_name, points=points)
            added_ids.extend(batch_ids)

        return added_ids
    def _generate_rest_batches(
        self,
        texts: Iterable[str],
        embeddings: list[list[float]],
        metadatas: Optional[list[dict]] = None,
        ids: Optional[Sequence[str]] = None,
        batch_size: int = 64,
        group_id: Optional[str] = None,
    ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
        from qdrant_client.http import models as rest

        texts_iterator = iter(texts)
        embeddings_iterator = iter(embeddings)
        metadatas_iterator = iter(metadatas or [])
        ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
        while batch_texts := list(islice(texts_iterator, batch_size)):
            # Take the corresponding metadata and id for each text in a batch
            batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
            batch_ids = list(islice(ids_iterator, batch_size))

            # Generate the embeddings for all the texts in a batch
            batch_embeddings = list(islice(embeddings_iterator, batch_size))

            points = [
                rest.PointStruct(
                    id=point_id,
                    vector=vector,
                    payload=payload,
                )
                for point_id, vector, payload in zip(
                    batch_ids,
                    batch_embeddings,
                    self._build_payloads(
                        batch_texts,
                        batch_metadatas,
                        Field.CONTENT_KEY.value,
                        Field.METADATA_KEY.value,
                        group_id,
                        Field.GROUP_KEY.value,
                    ),
                )
            ]

            yield batch_ids, points

    @classmethod
    def _build_payloads(
        cls,
        texts: Iterable[str],
        metadatas: Optional[list[dict]],
        content_payload_key: str,
        metadata_payload_key: str,
        group_id: str,
        group_payload_key: str,
    ) -> list[dict]:
        payloads = []
        for i, text in enumerate(texts):
            if text is None:
                raise ValueError(
                    "At least one of the texts is None. Please remove it before "
                    "calling .from_texts or .add_texts on Qdrant instance."
                )
            metadata = metadatas[i] if metadatas is not None else None
            payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id})

        return payloads
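
    # Illustrative payload shape (placeholder values) produced by _build_payloads;
    # Field.CONTENT_KEY maps to "page_content", Field.METADATA_KEY to "metadata",
    # and Field.GROUP_KEY to "group_id", the dataset-level id used to scope queries:
    #   {
    #       "page_content": "chunk text",
    #       "metadata": {"doc_id": "<chunk id>"},
    #       "group_id": "<dataset id>",
    #   }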
    def delete_by_metadata_field(self, key: str, value: str):
        from qdrant_client.http import models
        from qdrant_client.http.exceptions import UnexpectedResponse

        try:
            filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key=f"metadata.{key}",
                        match=models.MatchValue(value=value),
                    ),
                ],
            )
            self._reload_if_needed()

            self._client.delete(
                collection_name=self._collection_name,
                points_selector=FilterSelector(filter=filter),
            )
        except UnexpectedResponse as e:
            # Collection does not exist, so return
            if e.status_code == 404:
                return
            # Some other error occurred, so re-raise the exception
            else:
                raise e

    def delete(self):
        from qdrant_client.http.exceptions import UnexpectedResponse

        try:
            self._client.delete_collection(collection_name=self._collection_name)
        except UnexpectedResponse as e:
            # Collection does not exist, so return
            if e.status_code == 404:
                return
            # Some other error occurred, so re-raise the exception
            else:
                raise e

    def delete_by_ids(self, ids: list[str]) -> None:
        from qdrant_client.http import models
        from qdrant_client.http.exceptions import UnexpectedResponse

        for node_id in ids:
            try:
                filter = models.Filter(
                    must=[
                        models.FieldCondition(
                            key="metadata.doc_id",
                            match=models.MatchValue(value=node_id),
                        ),
                    ],
                )
                self._client.delete(
                    collection_name=self._collection_name,
                    points_selector=FilterSelector(filter=filter),
                )
            except UnexpectedResponse as e:
                # Collection does not exist, so return
                if e.status_code == 404:
                    return
                # Some other error occurred, so re-raise the exception
                else:
                    raise e

    def text_exists(self, id: str) -> bool:
        all_collection_name = []
        collections_response = self._client.get_collections()
        collection_list = collections_response.collections
        for collection in collection_list:
            all_collection_name.append(collection.name)
        if self._collection_name not in all_collection_name:
            return False
        response = self._client.retrieve(collection_name=self._collection_name, ids=[id])

        return len(response) > 0

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        from qdrant_client.http import models

        filter = models.Filter(
            must=[
                models.FieldCondition(
                    key="group_id",
                    match=models.MatchValue(value=self._group_id),
                ),
            ],
        )
        results = self._client.search(
            collection_name=self._collection_name,
            query_vector=query_vector,
            query_filter=filter,
            limit=kwargs.get("top_k", 4),
            with_payload=True,
            with_vectors=True,
            score_threshold=kwargs.get("score_threshold", 0.0),
        )
        docs = []
        for result in results:
            metadata = result.payload.get(Field.METADATA_KEY.value) or {}
            # duplicate check score threshold
            score_threshold = kwargs.get("score_threshold") or 0.0
            if result.score > score_threshold:
                metadata["score"] = result.score
                doc = Document(
                    page_content=result.payload.get(Field.CONTENT_KEY.value),
                    metadata=metadata,
                )
                docs.append(doc)
        # Sort the documents by score in descending order
        docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        """Return docs most similar by bm25.

        Returns:
            List of documents most similar to the query text and distance for each.
        """
        from qdrant_client.http import models

        scroll_filter = models.Filter(
            must=[
                models.FieldCondition(
                    key="page_content",
                    match=models.MatchText(text=query),
                )
            ]
        )
        response = self._client.scroll(
            collection_name=self._collection_name,
            scroll_filter=scroll_filter,
            limit=kwargs.get("top_k", 2),
            with_payload=True,
            with_vectors=True,
        )
        results = response[0]
        documents = []
        for result in results:
            if result:
                document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value)
                document.metadata["vector"] = result.vector
                documents.append(document)

        return documents

    def _reload_if_needed(self):
        if isinstance(self._client, QdrantLocal):
            self._client = cast(QdrantLocal, self._client)
            self._client._load()

    @classmethod
    def _document_from_scored_point(
        cls,
        scored_point: Any,
        content_payload_key: str,
        metadata_payload_key: str,
    ) -> Document:
        return Document(
            page_content=scored_point.payload.get(content_payload_key),
            metadata=scored_point.payload.get(metadata_payload_key) or {},
        )
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
        tidb_auth_binding = (
            db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
        )
        if not tidb_auth_binding:
            idle_tidb_auth_binding = (
                db.session.query(TidbAuthBinding)
                .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
                .limit(1)
                .one_or_none()
            )
            if idle_tidb_auth_binding:
                idle_tidb_auth_binding.active = True
                idle_tidb_auth_binding.tenant_id = dataset.tenant_id
                db.session.commit()
                TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
            else:
                with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
                    tidb_auth_binding = (
                        db.session.query(TidbAuthBinding)
                        .filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
                        .one_or_none()
                    )
                    if tidb_auth_binding:
                        TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
                    else:
                        new_cluster = TidbService.create_tidb_serverless_cluster(
                            dify_config.TIDB_PROJECT_ID,
                            dify_config.TIDB_API_URL,
                            dify_config.TIDB_IAM_API_URL,
                            dify_config.TIDB_PUBLIC_KEY,
                            dify_config.TIDB_PRIVATE_KEY,
                            dify_config.TIDB_REGION,
                        )
                        new_tidb_auth_binding = TidbAuthBinding(
                            cluster_id=new_cluster["cluster_id"],
                            cluster_name=new_cluster["cluster_name"],
                            account=new_cluster["account"],
                            password=new_cluster["password"],
                            tenant_id=dataset.tenant_id,
                            active=True,
                            status="ACTIVE",
                        )
                        db.session.add(new_tidb_auth_binding)
                        db.session.commit()
                        TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"
        else:
            TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"

        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
            collection_name = class_prefix
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_ON_QDRANT, collection_name))

        config = current_app.config
        return TidbOnQdrantVector(
            collection_name=collection_name,
            group_id=dataset.id,
            config=TidbOnQdrantConfig(
                endpoint=dify_config.TIDB_ON_QDRANT_URL,
                api_key=TIDB_ON_QDRANT_API_KEY,
                root_path=config.root_path,
                timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT,
                grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT,
                prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED,
            ),
        )

    def create_tidb_serverless_cluster(self, tidb_config: TidbConfig, display_name: str, region: str):
        """
        Creates a new TiDB Serverless cluster.

        :param tidb_config: The configuration for the TiDB Cloud API.
        :param display_name: The user-friendly display name of the cluster (required).
        :param region: The region where the cluster will be created (required).
        :return: The response from the API.
        """
        region_object = {
            "name": region,
        }

        labels = {
            "tidb.cloud/project": "1372813089454548012",
        }
        cluster_data = {"displayName": display_name, "region": region_object, "labels": labels}

        response = requests.post(
            f"{tidb_config.api_url}/clusters",
            json=cluster_data,
            auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
        )

        if response.status_code == 200:
            return response.json()
        else:
            response.raise_for_status()

    def change_tidb_serverless_root_password(self, tidb_config: TidbConfig, cluster_id: str, new_password: str):
        """
        Changes the root password of a specific TiDB Serverless cluster.

        :param tidb_config: The configuration for the TiDB Cloud API.
        :param cluster_id: The ID of the cluster for which the password is to be changed (required).
        :param new_password: The new password for the root user (required).
        :return: The response from the API.
        """
        body = {"password": new_password}

        response = requests.put(
            f"{tidb_config.api_url}/clusters/{cluster_id}/password",
            json=body,
            auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key),
        )

        if response.status_code == 200:
            return response.json()
        else:
            response.raise_for_status()
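
# Illustrative usage sketch (hypothetical values, not part of the upstream file):
# once the factory has bound a tenant to a cluster, the vector store behaves like
# any other BaseVector implementation.
#
#   vector = TidbOnQdrantVector(
#       collection_name="<collection name>",
#       group_id="<dataset id>",
#       config=TidbOnQdrantConfig(endpoint="https://qdrant.example.com:6333", api_key="account:password"),
#   )
#   vector.create(texts=documents, embeddings=document_vectors)
#   hits = vector.search_by_vector(query_vector, top_k=4, score_threshold=0.5)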
import time
import uuid

import requests
from requests.auth import HTTPDigestAuth

from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import TidbAuthBinding


class TidbService:
    @staticmethod
    def create_tidb_serverless_cluster(
        project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
    ):
        """
        Creates a new TiDB Serverless cluster and waits until it becomes ACTIVE.

        :param project_id: The project ID of the TiDB Cloud project (required).
        :param api_url: The URL of the TiDB Cloud API (required).
        :param iam_url: The URL of the TiDB Cloud IAM API (required).
        :param public_key: The public key for the API (required).
        :param private_key: The private key for the API (required).
        :param region: The region where the cluster will be created (required).
        :return: The cluster id, name, account, and password of the new cluster.
        """
        region_object = {
            "name": region,
        }

        labels = {
            "tidb.cloud/project": project_id,
        }

        spending_limit = {
            "monthly": 100,
        }
        password = str(uuid.uuid4()).replace("-", "")[:16]
        display_name = str(uuid.uuid4()).replace("-", "")[:16]
        cluster_data = {
            "displayName": display_name,
            "region": region_object,
            "labels": labels,
            "spendingLimit": spending_limit,
            "rootPassword": password,
        }

        response = requests.post(f"{api_url}/clusters", json=cluster_data, auth=HTTPDigestAuth(public_key, private_key))

        if response.status_code == 200:
            response_data = response.json()
            cluster_id = response_data["clusterId"]
            retry_count = 0
            max_retries = 30
            # poll until the cluster is ACTIVE, giving up after 30 x 30 s (~15 minutes)
            while retry_count < max_retries:
                cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
                if cluster_response["state"] == "ACTIVE":
                    user_prefix = cluster_response["userPrefix"]
                    return {
                        "cluster_id": cluster_id,
                        "cluster_name": display_name,
                        "account": f"{user_prefix}.root",
                        "password": password,
                    }
                time.sleep(30)  # wait 30 seconds before retrying
                retry_count += 1
        else:
            response.raise_for_status()
    @staticmethod
    def delete_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str):
        """
        Deletes a specific TiDB Serverless cluster.

        :param api_url: The URL of the TiDB Cloud API (required).
        :param public_key: The public key for the API (required).
        :param private_key: The private key for the API (required).
        :param cluster_id: The ID of the cluster to be deleted (required).
        :return: The response from the API.
        """
        response = requests.delete(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))

        if response.status_code == 200:
            return response.json()
        else:
            response.raise_for_status()

    @staticmethod
    def get_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str):
        """
        Retrieves a specific TiDB Serverless cluster.

        :param api_url: The URL of the TiDB Cloud API (required).
        :param public_key: The public key for the API (required).
        :param private_key: The private key for the API (required).
        :param cluster_id: The ID of the cluster to retrieve (required).
        :return: The response from the API.
        """
        response = requests.get(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key))

        if response.status_code == 200:
            return response.json()
        else:
            response.raise_for_status()
    @staticmethod
    def change_tidb_serverless_root_password(
        api_url: str, public_key: str, private_key: str, cluster_id: str, account: str, new_password: str
    ):
        """
        Changes the root password of a specific TiDB Serverless cluster.

        :param api_url: The URL of the TiDB Cloud API (required).
        :param public_key: The public key for the API (required).
        :param private_key: The private key for the API (required).
        :param cluster_id: The ID of the cluster for which the password is to be changed (required).
        :param account: The account for which the password is to be changed (required).
        :param new_password: The new password for the root user (required).
        :return: The response from the API.
        """
        body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []}

        response = requests.patch(
            f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}",
            json=body,
            auth=HTTPDigestAuth(public_key, private_key),
        )

        if response.status_code == 200:
            return response.json()
        else:
            response.raise_for_status()
    @staticmethod
    def batch_update_tidb_serverless_cluster_status(
        tidb_serverless_list: list[TidbAuthBinding],
        project_id: str,
        api_url: str,
        iam_url: str,
        public_key: str,
        private_key: str,
    ) -> None:
        """
        Refreshes the status of TiDB Serverless clusters that are still CREATING.

        :param tidb_serverless_list: The TidbAuthBinding rows to refresh (required).
        :param project_id: The project ID of the TiDB Cloud project (required).
        :param api_url: The URL of the TiDB Cloud API (required).
        :param iam_url: The URL of the TiDB Cloud IAM API (required).
        :param public_key: The public key for the API (required).
        :param private_key: The private key for the API (required).
        """
        tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
        cluster_ids = [item.cluster_id for item in tidb_serverless_list]
        params = {"clusterIds": cluster_ids, "view": "FULL"}
        response = requests.get(
            f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key)
        )

        if response.status_code == 200:
            response_data = response.json()
            for item in response_data["clusters"]:
                state = item["state"]
                userPrefix = item["userPrefix"]
                if state == "ACTIVE" and len(userPrefix) > 0:
                    cluster_info = tidb_serverless_list_map[item["clusterId"]]
                    cluster_info.status = "ACTIVE"
                    cluster_info.account = f"{userPrefix}.root"
                    db.session.add(cluster_info)
            db.session.commit()
        else:
            response.raise_for_status()
    @staticmethod
    def batch_create_tidb_serverless_cluster(
        batch_size: int, project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str
    ) -> list[dict]:
        """
        Creates a batch of new TiDB Serverless clusters.

        :param batch_size: The number of clusters to create (required).
        :param project_id: The project ID of the TiDB Cloud project (required).
        :param api_url: The URL of the TiDB Cloud API (required).
        :param iam_url: The URL of the TiDB Cloud IAM API (required).
        :param public_key: The public key for the API (required).
        :param private_key: The private key for the API (required).
        :param region: The region where the clusters will be created (required).
        :return: The cluster id, name, account, and password of each new cluster.
        """
        clusters = []
        for _ in range(batch_size):
            region_object = {
                "name": region,
            }

            labels = {
                "tidb.cloud/project": project_id,
            }

            spending_limit = {
                "monthly": 10,
            }
            password = str(uuid.uuid4()).replace("-", "")[:16]
            display_name = str(uuid.uuid4()).replace("-", "")
            cluster_data = {
                "cluster": {
                    "displayName": display_name,
                    "region": region_object,
                    "labels": labels,
                    "spendingLimit": spending_limit,
                    "rootPassword": password,
                }
            }
            # stash the generated root password so it can be recovered after the batch call
            cache_key = f"tidb_serverless_cluster_password:{display_name}"
            redis_client.setex(cache_key, 3600, password)
            clusters.append(cluster_data)

        request_body = {"requests": clusters}
        response = requests.post(
            f"{api_url}/clusters:batchCreate", json=request_body, auth=HTTPDigestAuth(public_key, private_key)
        )

        if response.status_code == 200:
            response_data = response.json()
            cluster_infos = []
            for item in response_data["clusters"]:
                cache_key = f"tidb_serverless_cluster_password:{item['displayName']}"
                password = redis_client.get(cache_key)
                if not password:
                    continue
                cluster_info = {
                    "cluster_id": item["clusterId"],
                    "cluster_name": item["displayName"],
                    "account": "root",
                    "password": password.decode("utf-8"),
                }
                cluster_infos.append(cluster_info)
            return cluster_infos
        else:
            response.raise_for_status()
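
# Illustrative usage sketch (placeholder arguments): topping up the idle pool with
# twenty pre-created clusters, as the scheduled create task does.
#
#   new_clusters = TidbService.batch_create_tidb_serverless_cluster(
#       20, "<project id>", "<api url>", "<iam api url>", "<public key>", "<private key>", "regions/aws-us-east-1"
#   )
#   # -> [{"cluster_id": "...", "cluster_name": "...", "account": "root", "password": "..."}, ...]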
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Whitelist


class AbstractVectorFactory(ABC):

    def _init_vector(self) -> BaseVector:
        vector_type = dify_config.VECTOR_STORE

        if self._dataset.index_struct_dict:
            vector_type = self._dataset.index_struct_dict["type"]
        else:
            if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
                whitelist = (
                    db.session.query(Whitelist)
                    .filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
                    .one_or_none()
                )
                if whitelist:
                    vector_type = VectorType.TIDB_ON_QDRANT

        if not vector_type:
            raise ValueError("Vector store must be specified.")

                from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory

                return UpstashVectorFactory
            case VectorType.TIDB_ON_QDRANT:
                from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory

                return TidbOnQdrantVectorFactory
            case _:
                raise ValueError(f"Vector store {vector_type} is not supported.")

    BAIDU = "baidu"
    VIKINGDB = "vikingdb"
    UPSTASH = "upstash"
    TIDB_ON_QDRANT = "tidb_on_qdrant"
from datetime import timedelta

from celery import Celery, Task
from celery.schedules import crontab
from flask import Flask

from configs import dify_config

    imports = [
        "schedule.clean_embedding_cache_task",
        "schedule.clean_unused_datasets_task",
        "schedule.create_tidb_serverless_task",
        "schedule.update_tidb_serverless_status_task",
    ]
    day = dify_config.CELERY_BEAT_SCHEDULER_TIME
    beat_schedule = {
            "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task",
            "schedule": timedelta(days=day),
        },
        "create_tidb_serverless_task": {
            "task": "schedule.create_tidb_serverless_task.create_tidb_serverless_task",
            "schedule": crontab(minute="0", hour="*"),  # hourly, on the hour
        },
        "update_tidb_serverless_status_task": {
            "task": "schedule.update_tidb_serverless_status_task.update_tidb_serverless_status_task",
            "schedule": crontab(minute="30", hour="*"),  # hourly, at half past
        },
    }
    celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
| """add-tidb-auth-binding | |||||
| Revision ID: 0251a1c768cc | |||||
| Revises: 63a83fcf12ba | |||||
| Create Date: 2024-08-15 09:56:59.012490 | |||||
| """ | |||||
| import sqlalchemy as sa | |||||
| from alembic import op | |||||
| import models as models | |||||
| # revision identifiers, used by Alembic. | |||||
| revision = '0251a1c768cc' | |||||
| down_revision = 'bbadea11becb' | |||||
| branch_labels = None | |||||
| depends_on = None | |||||
| def upgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| op.create_table('tidb_auth_bindings', | |||||
| sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||||
| sa.Column('tenant_id', models.types.StringUUID(), nullable=True), | |||||
| sa.Column('cluster_id', sa.String(length=255), nullable=False), | |||||
| sa.Column('cluster_name', sa.String(length=255), nullable=False), | |||||
| sa.Column('active', sa.Boolean(), server_default=sa.text('false'), nullable=False), | |||||
| sa.Column('status', sa.String(length=255), server_default=sa.text("'CREATING'::character varying"), nullable=False), | |||||
| sa.Column('account', sa.String(length=255), nullable=False), | |||||
| sa.Column('password', sa.String(length=255), nullable=False), | |||||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), | |||||
| sa.PrimaryKeyConstraint('id', name='tidb_auth_bindings_pkey') | |||||
| ) | |||||
| with op.batch_alter_table('tidb_auth_bindings', schema=None) as batch_op: | |||||
| batch_op.create_index('tidb_auth_bindings_active_idx', ['active'], unique=False) | |||||
| batch_op.create_index('tidb_auth_bindings_status_idx', ['status'], unique=False) | |||||
| batch_op.create_index('tidb_auth_bindings_created_at_idx', ['created_at'], unique=False) | |||||
| batch_op.create_index('tidb_auth_bindings_tenant_idx', ['tenant_id'], unique=False) | |||||
| # ### end Alembic commands ### | |||||
| def downgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| with op.batch_alter_table('tidb_auth_bindings', schema=None) as batch_op: | |||||
| batch_op.drop_index('tidb_auth_bindings_tenant_idx') | |||||
| batch_op.drop_index('tidb_auth_bindings_created_at_idx') | |||||
| batch_op.drop_index('tidb_auth_bindings_active_idx') | |||||
| batch_op.drop_index('tidb_auth_bindings_status_idx') | |||||
| op.drop_table('tidb_auth_bindings') | |||||
| # ### end Alembic commands ### |
| """add_white_list | |||||
| Revision ID: 43fa78bc3b7d | |||||
| Revises: 0251a1c768cc | |||||
| Create Date: 2024-10-22 09:59:23.713716 | |||||
| """ | |||||
| from alembic import op | |||||
| import models as models | |||||
| import sqlalchemy as sa | |||||
| from sqlalchemy.dialects import postgresql | |||||
| # revision identifiers, used by Alembic. | |||||
| revision = '43fa78bc3b7d' | |||||
| down_revision = '0251a1c768cc' | |||||
| branch_labels = None | |||||
| depends_on = None | |||||
| def upgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| op.create_table('whitelists', | |||||
| sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||||
| sa.Column('tenant_id', models.types.StringUUID(), nullable=True), | |||||
| sa.Column('category', sa.String(length=255), nullable=False), | |||||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), | |||||
| sa.PrimaryKeyConstraint('id', name='whitelists_pkey') | |||||
| ) | |||||
| with op.batch_alter_table('whitelists', schema=None) as batch_op: | |||||
| batch_op.create_index('whitelists_tenant_idx', ['tenant_id'], unique=False) | |||||
| # ### end Alembic commands ### | |||||
| def downgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| with op.batch_alter_table('whitelists', schema=None) as batch_op: | |||||
| batch_op.drop_index('whitelists_tenant_idx') | |||||
| op.drop_table('whitelists') | |||||
| # ### end Alembic commands ### |
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


class TidbAuthBinding(db.Model):
    __tablename__ = "tidb_auth_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
        db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"),
        db.Index("tidb_auth_bindings_active_idx", "active"),
        db.Index("tidb_auth_bindings_created_at_idx", "created_at"),
        db.Index("tidb_auth_bindings_status_idx", "status"),
    )
    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=True)
    cluster_id = db.Column(db.String(255), nullable=False)
    cluster_name = db.Column(db.String(255), nullable=False)
    active = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    status = db.Column(db.String(255), nullable=False, server_default=db.text("'CREATING'::character varying"))
    account = db.Column(db.String(255), nullable=False)
    password = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


class Whitelist(db.Model):
    __tablename__ = "whitelists"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
        db.Index("whitelists_tenant_idx", "tenant_id"),
    )
    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=True)
    category = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
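
# Illustrative sketch (placeholder tenant id): enrolling a tenant in the vector-db
# whitelist, which AbstractVectorFactory consults when VECTOR_STORE_WHITELIST_ENABLE
# is on to route that tenant's datasets to TiDB on Qdrant.
#
#   db.session.add(Whitelist(tenant_id="<tenant id>", category="vector_db"))
#   db.session.commit()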

class DatasetPermission(db.Model):
    __tablename__ = "dataset_permissions"
    __table_args__ = (
import time

import click

import app
from configs import dify_config
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from extensions.ext_database import db
from models.dataset import TidbAuthBinding


@app.celery.task(queue="dataset")
def create_tidb_serverless_task():
    click.echo(click.style("Start create tidb serverless task.", fg="green"))
    tidb_serverless_number = dify_config.TIDB_SERVERLESS_NUMBER
    start_at = time.perf_counter()
    while True:
        try:
            # check the number of idle tidb serverless
            idle_tidb_serverless_number = TidbAuthBinding.query.filter(TidbAuthBinding.active == False).count()
            if idle_tidb_serverless_number >= tidb_serverless_number:
                break
            # create tidb serverless
            iterations_per_thread = 20
            create_clusters(iterations_per_thread)

        except Exception as e:
            click.echo(click.style(f"Error: {e}", fg="red"))
            break

    end_at = time.perf_counter()
    click.echo(click.style("Create tidb serverless task success latency: {}".format(end_at - start_at), fg="green"))


def create_clusters(batch_size):
    try:
        new_clusters = TidbService.batch_create_tidb_serverless_cluster(
            batch_size,
            dify_config.TIDB_PROJECT_ID,
            dify_config.TIDB_API_URL,
            dify_config.TIDB_IAM_API_URL,
            dify_config.TIDB_PUBLIC_KEY,
            dify_config.TIDB_PRIVATE_KEY,
            dify_config.TIDB_REGION,
        )
        for new_cluster in new_clusters:
            tidb_auth_binding = TidbAuthBinding(
                cluster_id=new_cluster["cluster_id"],
                cluster_name=new_cluster["cluster_name"],
                account=new_cluster["account"],
                password=new_cluster["password"],
            )
            db.session.add(tidb_auth_binding)
        db.session.commit()
    except Exception as e:
        click.echo(click.style(f"Error: {e}", fg="red"))
import time

import click

import app
from configs import dify_config
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from models.dataset import TidbAuthBinding


@app.celery.task(queue="dataset")
def update_tidb_serverless_status_task():
    click.echo(click.style("Update tidb serverless status task.", fg="green"))
    start_at = time.perf_counter()
    try:
        # find the clusters that are still being created
        tidb_serverless_list = TidbAuthBinding.query.filter(
            TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING"
        ).all()
        if len(tidb_serverless_list) > 0:
            # update tidb serverless status; the beat schedule re-runs this hourly
            update_clusters(tidb_serverless_list)
    except Exception as e:
        click.echo(click.style(f"Error: {e}", fg="red"))

    end_at = time.perf_counter()
    click.echo(
        click.style("Update tidb serverless status task success latency: {}".format(end_at - start_at), fg="green")
    )


def update_clusters(tidb_serverless_list: list[TidbAuthBinding]):
    try:
        # batch 20
        for i in range(0, len(tidb_serverless_list), 20):
            items = tidb_serverless_list[i : i + 20]
            TidbService.batch_update_tidb_serverless_cluster_status(
                items,
                dify_config.TIDB_PROJECT_ID,
                dify_config.TIDB_API_URL,
                dify_config.TIDB_IAM_API_URL,
                dify_config.TIDB_PUBLIC_KEY,
                dify_config.TIDB_PRIVATE_KEY,
            )
    except Exception as e:
        click.echo(click.style(f"Error: {e}", fg="red"))
import json

import requests

from services.auth.api_key_auth_base import ApiKeyAuthBase


class JinaAuth(ApiKeyAuthBase):
    def __init__(self, credentials: dict):
        super().__init__(credentials)
        auth_type = credentials.get("auth_type")
        if auth_type != "bearer":
            raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer")
        self.api_key = credentials.get("config").get("api_key", None)

        if not self.api_key:
            raise ValueError("No API key provided")

    def validate_credentials(self):
        headers = self._prepare_headers()
        options = {
            "url": "https://example.com",
        }
        response = self._post_request("https://r.jina.ai", options, headers)
        if response.status_code == 200:
            return True
        else:
            self._handle_error(response)

    def _prepare_headers(self):
        return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

    def _post_request(self, url, data, headers):
        return requests.post(url, headers=headers, json=data)

    def _handle_error(self, response):
        if response.status_code in {402, 409, 500}:
            error_message = response.json().get("error", "Unknown error occurred")
            raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
        else:
            if response.text:
                error_message = json.loads(response.text).get("error", "Unknown error occurred")
                raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
            raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")
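
# Illustrative usage sketch (hypothetical credentials): validating a Jina Reader
# key issues a single POST to https://r.jina.ai against a known-good URL.
#
#   auth = JinaAuth({"auth_type": "bearer", "config": {"api_key": "<jina api key>"}})
#   auth.validate_credentials()  # True on HTTP 200, raises on auth errors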
'use client'
import type { FC } from 'react'
import React from 'react'
import cn from '@/utils/classnames'
import Checkbox from '@/app/components/base/checkbox'
import Tooltip from '@/app/components/base/tooltip'

type Props = {
  className?: string
  isChecked: boolean
  onChange: (isChecked: boolean) => void
  label: string
  labelClassName?: string
  tooltip?: string
}

const CheckboxWithLabel: FC<Props> = ({
  className = '',
  isChecked,
  onChange,
  label,
  labelClassName,
  tooltip,
}) => {
  return (
    <label className={cn(className, 'flex items-center h-7 space-x-2')}>
      <Checkbox checked={isChecked} onCheck={() => onChange(!isChecked)} />
      <div className={cn(labelClassName, 'text-sm font-normal text-gray-800')}>{label}</div>
      {tooltip && (
        <Tooltip
          popupContent={
            <div className='w-[200px]'>{tooltip}</div>
          }
          triggerClassName='ml-0.5 w-4 h-4'
        />
      )}
    </label>
  )
}
export default React.memo(CheckboxWithLabel)
'use client'
import type { FC } from 'react'
import React from 'react'
import cn from '@/utils/classnames'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'

type Props = {
  className?: string
  title: string
  errorMsg?: string
}

const ErrorMessage: FC<Props> = ({
  className,
  title,
  errorMsg,
}) => {
  return (
    <div className={cn(className, 'py-2 px-4 border-t border-gray-200 bg-[#FFFAEB]')}>
      <div className='flex items-center h-5'>
        <AlertTriangle className='mr-2 w-4 h-4 text-[#F79009]' />
        <div className='text-sm font-medium text-[#DC6803]'>{title}</div>
      </div>
      {errorMsg && (
        <div className='mt-1 pl-6 leading-[18px] text-xs font-normal text-gray-700'>{errorMsg}</div>
      )}
    </div>
  )
}
export default React.memo(ErrorMessage)
'use client'
import type { FC } from 'react'
import React from 'react'
import Input from './input'
import cn from '@/utils/classnames'
import Tooltip from '@/app/components/base/tooltip'

type Props = {
  className?: string
  label: string
  labelClassName?: string
  value: string | number
  onChange: (value: string | number) => void
  isRequired?: boolean
  placeholder?: string
  isNumber?: boolean
  tooltip?: string
}

const Field: FC<Props> = ({
  className,
  label,
  labelClassName,
  value,
  onChange,
  isRequired = false,
  placeholder = '',
  isNumber = false,
  tooltip,
}) => {
  return (
    <div className={cn(className)}>
      <div className='flex py-[7px]'>
        <div className={cn(labelClassName, 'flex items-center h-[18px] text-[13px] font-medium text-gray-900')}>{label}</div>
        {isRequired && <span className='ml-0.5 text-xs font-semibold text-[#D92D20]'>*</span>}
        {tooltip && (
          <Tooltip
            popupContent={
              <div className='w-[200px]'>{tooltip}</div>
            }
            triggerClassName='ml-0.5 w-4 h-4'
          />
        )}
      </div>
      <Input
        value={value}
        onChange={onChange}
        placeholder={placeholder}
        isNumber={isNumber}
      />
    </div>
  )
}
export default React.memo(Field)
'use client'
import type { FC } from 'react'
import React, { useCallback } from 'react'

type Props = {
  value: string | number
  onChange: (value: string | number) => void
  placeholder?: string
  isNumber?: boolean
}

const MIN_VALUE = 0

const Input: FC<Props> = ({
  value,
  onChange,
  placeholder = '',
  isNumber = false,
}) => {
  const handleChange = useCallback((e: React.ChangeEvent<HTMLInputElement>) => {
    const value = e.target.value
    if (isNumber) {
      let numberValue = parseInt(value, 10) // integer only
      if (isNaN(numberValue)) {
        onChange('')
        return
      }
      if (numberValue < MIN_VALUE)
        numberValue = MIN_VALUE

      onChange(numberValue)
      return
    }
    onChange(value)
  }, [isNumber, onChange])

  const otherOption = (() => {
    if (isNumber) {
      return {
        min: MIN_VALUE,
      }
    }
    return {
    }
  })()
  return (
    <input
      type={isNumber ? 'number' : 'text'}
      {...otherOption}
      value={value}
      onChange={handleChange}
      className='flex h-9 w-full py-1 px-2 rounded-lg text-xs leading-normal bg-gray-100 caret-primary-600 hover:bg-gray-100 focus:ring-1 focus:ring-inset focus:ring-gray-200 focus-visible:outline-none focus:bg-gray-50 placeholder:text-gray-400'
      placeholder={placeholder}
    />
  )
}
export default React.memo(Input)
'use client'
import { useBoolean } from 'ahooks'
import type { FC } from 'react'
import React, { useEffect } from 'react'
import { useTranslation } from 'react-i18next'
import cn from '@/utils/classnames'
import { Settings04 } from '@/app/components/base/icons/src/vender/line/general'
import { ChevronRight } from '@/app/components/base/icons/src/vender/line/arrows'

const I18N_PREFIX = 'datasetCreation.stepOne.website'

type Props = {
  className?: string
  children: React.ReactNode
  controlFoldOptions?: number
}

const OptionsWrap: FC<Props> = ({
  className = '',
  children,
  controlFoldOptions,
}) => {
  const { t } = useTranslation()

  const [fold, {
    toggle: foldToggle,
    setTrue: foldHide,
  }] = useBoolean(false)

  useEffect(() => {
    if (controlFoldOptions)
      foldHide()
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [controlFoldOptions])
  return (
    <div className={cn(className, !fold ? 'mb-0' : 'mb-3')}>
      <div
        className='flex justify-between items-center h-[26px] py-1 cursor-pointer select-none'
        onClick={foldToggle}
      >
        <div className='flex items-center text-gray-700'>
          <Settings04 className='mr-1 w-4 h-4' />
          <div className='text-[13px] font-semibold text-gray-800 uppercase'>{t(`${I18N_PREFIX}.options`)}</div>
        </div>
        <ChevronRight className={cn(!fold && 'rotate-90', 'w-4 h-4 text-gray-500')} />
      </div>
      {!fold && (
        <div className='mb-4'>
          {children}
        </div>
      )}
    </div>
  )
}
export default React.memo(OptionsWrap)
| 'use client' | |||||
| import type { FC } from 'react' | |||||
| import React, { useCallback, useState } from 'react' | |||||
| import { useTranslation } from 'react-i18next' | |||||
| import Input from './input' | |||||
| import Button from '@/app/components/base/button' | |||||
| const I18N_PREFIX = 'datasetCreation.stepOne.website' | |||||
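| // UrlInput: URL field plus a Run button; clicks are ignored while a crawl is | |||||
| // already running, and the button shows a loading state instead of its label. | |||||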
| type Props = { | |||||
| isRunning: boolean | |||||
| onRun: (url: string) => void | |||||
| } | |||||
| const UrlInput: FC<Props> = ({ | |||||
| isRunning, | |||||
| onRun, | |||||
| }) => { | |||||
| const { t } = useTranslation() | |||||
| const [url, setUrl] = useState('') | |||||
| const handleUrlChange = useCallback((url: string | number) => { | |||||
| setUrl(url as string) | |||||
| }, []) | |||||
| const handleOnRun = useCallback(() => { | |||||
| if (isRunning) | |||||
| return | |||||
| onRun(url) | |||||
| }, [isRunning, onRun, url]) | |||||
| return ( | |||||
| <div className='flex items-center justify-between'> | |||||
| <Input | |||||
| value={url} | |||||
| onChange={handleUrlChange} | |||||
| placeholder='https://docs.dify.ai' | |||||
| /> | |||||
| <Button | |||||
| variant='primary' | |||||
| onClick={handleOnRun} | |||||
| className='ml-2' | |||||
| loading={isRunning} | |||||
| > | |||||
| {!isRunning ? t(`${I18N_PREFIX}.run`) : ''} | |||||
| </Button> | |||||
| </div> | |||||
| ) | |||||
| } | |||||
| export default React.memo(UrlInput) |
| 'use client' | |||||
| import type { FC } from 'react' | |||||
| import React, { useCallback } from 'react' | |||||
| import { useTranslation } from 'react-i18next' | |||||
| import cn from '@/utils/classnames' | |||||
| import type { CrawlResultItem as CrawlResultItemType } from '@/models/datasets' | |||||
| import Checkbox from '@/app/components/base/checkbox' | |||||
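| // CrawledResultItem: one selectable row of a crawl result, with a checkbox, | |||||
| // truncated title and source URL, and a hover-only "preview" action. | |||||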
| type Props = { | |||||
| payload: CrawlResultItemType | |||||
| isChecked: boolean | |||||
| isPreview: boolean | |||||
| onCheckChange: (checked: boolean) => void | |||||
| onPreview: () => void | |||||
| } | |||||
| const CrawledResultItem: FC<Props> = ({ | |||||
| isPreview, | |||||
| payload, | |||||
| isChecked, | |||||
| onCheckChange, | |||||
| onPreview, | |||||
| }) => { | |||||
| const { t } = useTranslation() | |||||
| const handleCheckChange = useCallback(() => { | |||||
| onCheckChange(!isChecked) | |||||
| }, [isChecked, onCheckChange]) | |||||
| return ( | |||||
| <div className={cn(isPreview ? 'border-[#D1E0FF] bg-primary-50 shadow-xs' : 'group hover:bg-gray-100', 'rounded-md px-2 py-[5px] cursor-pointer border border-transparent')}> | |||||
| <div className='flex items-center h-5'> | |||||
| <Checkbox className='group-hover:border-2 group-hover:border-primary-600 mr-2 shrink-0' checked={isChecked} onCheck={handleCheckChange} /> | |||||
| <div className='grow w-0 truncate text-sm font-medium text-gray-700' title={payload.title}>{payload.title}</div> | |||||
| <div onClick={onPreview} className='hidden group-hover:flex items-center h-6 px-2 text-xs rounded-md font-medium text-gray-500 uppercase hover:bg-gray-50'>{t('datasetCreation.stepOne.website.preview')}</div> | |||||
| </div> | |||||
| <div className='mt-0.5 truncate pl-6 leading-[18px] text-xs font-normal text-gray-500' title={payload.source_url}>{payload.source_url}</div> | |||||
| </div> | |||||
| ) | |||||
| } | |||||
| export default React.memo(CrawledResultItem) |
| 'use client' | |||||
| import type { FC } from 'react' | |||||
| import React, { useCallback } from 'react' | |||||
| import { useTranslation } from 'react-i18next' | |||||
| import CheckboxWithLabel from './base/checkbox-with-label' | |||||
| import CrawledResultItem from './crawled-result-item' | |||||
| import cn from '@/utils/classnames' | |||||
| import type { CrawlResultItem } from '@/models/datasets' | |||||
| const I18N_PREFIX = 'datasetCreation.stepOne.website' | |||||
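| // CrawledResult: the full result list with a select-all header, a scrape-time | |||||
| // summary, and per-item selection keyed by `source_url`. | |||||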
| type Props = { | |||||
| className?: string | |||||
| list: CrawlResultItem[] | |||||
| checkedList: CrawlResultItem[] | |||||
| onSelectedChange: (selected: CrawlResultItem[]) => void | |||||
| onPreview: (payload: CrawlResultItem) => void | |||||
| usedTime: number | |||||
| } | |||||
| const CrawledResult: FC<Props> = ({ | |||||
| className = '', | |||||
| list, | |||||
| checkedList, | |||||
| onSelectedChange, | |||||
| onPreview, | |||||
| usedTime, | |||||
| }) => { | |||||
| const { t } = useTranslation() | |||||
| // Treat an empty list as "not all checked" so the header checkbox stays clear. | |||||
| const isCheckAll = list.length > 0 && checkedList.length === list.length | |||||
| const handleCheckedAll = useCallback(() => { | |||||
| if (!isCheckAll) | |||||
| onSelectedChange(list) | |||||
| else | |||||
| onSelectedChange([]) | |||||
| }, [isCheckAll, list, onSelectedChange]) | |||||
| const handleItemCheckChange = useCallback((item: CrawlResultItem) => { | |||||
| return (checked: boolean) => { | |||||
| if (checked) | |||||
| onSelectedChange([...checkedList, item]) | |||||
| else | |||||
| onSelectedChange(checkedList.filter(checkedItem => checkedItem.source_url !== item.source_url)) | |||||
| } | |||||
| }, [checkedList, onSelectedChange]) | |||||
| const [previewIndex, setPreviewIndex] = React.useState<number>(-1) | |||||
| const handlePreview = useCallback((index: number) => { | |||||
| return () => { | |||||
| setPreviewIndex(index) | |||||
| onPreview(list[index]) | |||||
| } | |||||
| }, [list, onPreview]) | |||||
| return ( | |||||
| <div className={cn(className, 'border-t border-gray-200')}> | |||||
| <div className='flex items-center justify-between h-[34px] px-4 bg-gray-50 shadow-xs border-b-[0.5px] border-black/8 text-xs font-normal text-gray-700'> | |||||
| <CheckboxWithLabel | |||||
| isChecked={isCheckAll} | |||||
| onChange={handleCheckedAll} | |||||
| label={isCheckAll ? t(`${I18N_PREFIX}.resetAll`) : t(`${I18N_PREFIX}.selectAll`)} | |||||
| labelClassName='!font-medium' | |||||
| /> | |||||
| <div>{t(`${I18N_PREFIX}.scrapTimeInfo`, { | |||||
| total: list.length, | |||||
| time: usedTime.toFixed(1), | |||||
| })}</div> | |||||
| </div> | |||||
| <div className='p-2'> | |||||
| {list.map((item, index) => ( | |||||
| <CrawledResultItem | |||||
| key={item.source_url} | |||||
| isPreview={index === previewIndex} | |||||
| onPreview={handlePreview(index)} | |||||
| payload={item} | |||||
| isChecked={checkedList.some(checkedItem => checkedItem.source_url === item.source_url)} | |||||
| onCheckChange={handleItemCheckChange(item)} | |||||
| /> | |||||
| ))} | |||||
| </div> | |||||
| </div> | |||||
| ) | |||||
| } | |||||
| export default React.memo(CrawledResult) |
| 'use client' | |||||
| import type { FC } from 'react' | |||||
| import React from 'react' | |||||
| import { useTranslation } from 'react-i18next' | |||||
| import cn from '@/utils/classnames' | |||||
| import { RowStruct } from '@/app/components/base/icons/src/public/other' | |||||
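| // Crawling: progress header plus four skeleton rows shown while pages are | |||||
| // still being scraped. | |||||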
| type Props = { | |||||
| className?: string | |||||
| crawledNum: number | |||||
| totalNum: number | |||||
| } | |||||
| const Crawling: FC<Props> = ({ | |||||
| className = '', | |||||
| crawledNum, | |||||
| totalNum, | |||||
| }) => { | |||||
| const { t } = useTranslation() | |||||
| return ( | |||||
| <div className={cn(className, 'border-t border-gray-200')}> | |||||
| <div className='flex items-center h-[34px] px-4 bg-gray-50 shadow-xs border-b-[0.5px] border-black/8 text-xs font-normal text-gray-700'> | |||||
| {t('datasetCreation.stepOne.website.totalPageScraped')} {crawledNum}/{totalNum} | |||||
| </div> | |||||
| <div className='p-2'> | |||||
| {Array.from({ length: 4 }).map((_, index) => ( | |||||
| <div className='py-[5px]' key={index}> | |||||
| <RowStruct /> | |||||
| </div> | |||||
| ))} | |||||
| </div> | |||||
| </div> | |||||
| ) | |||||
| } | |||||
| export default React.memo(Crawling) |
| import type { CrawlResultItem } from '@/models/datasets' | |||||
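| // Static CrawlResultItem fixtures, presumably for previewing the result list | |||||
| // during development or in tests. | |||||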
| const result: CrawlResultItem[] = [ | |||||
| { | |||||
| title: 'Start the frontend Docker container separately', | |||||
| markdown: 'Markdown 1', | |||||
| description: 'Description 1', | |||||
| source_url: 'https://example.com/1', | |||||
| }, | |||||
| { | |||||
| title: 'Advanced Tool Integration', | |||||
| markdown: 'Markdown 2', | |||||
| description: 'Description 2', | |||||
| source_url: 'https://example.com/2', | |||||
| }, | |||||
| { | |||||
| title: 'Local Source Code Start | English | Dify', | |||||
| markdown: 'Markdown 3', | |||||
| description: 'Description 3', | |||||
| source_url: 'https://example.com/3', | |||||
| }, | |||||
| ] | |||||
| export default result |