| @@ -571,6 +571,11 @@ class DataSetConfig(BaseSettings): | |||
| default=False, | |||
| ) | |||
| TIDB_SERVERLESS_NUMBER: PositiveInt = Field( | |||
| description="number of tidb serverless cluster", | |||
| default=500, | |||
| ) | |||
| class WorkspaceConfig(BaseSettings): | |||
| """ | |||
| @@ -27,6 +27,7 @@ from configs.middleware.vdb.pgvectors_config import PGVectoRSConfig | |||
| from configs.middleware.vdb.qdrant_config import QdrantConfig | |||
| from configs.middleware.vdb.relyt_config import RelytConfig | |||
| from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig | |||
| from configs.middleware.vdb.tidb_on_qdrant_config import TidbOnQdrantConfig | |||
| from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig | |||
| from configs.middleware.vdb.upstash_config import UpstashConfig | |||
| from configs.middleware.vdb.vikingdb_config import VikingDBConfig | |||
| @@ -54,6 +55,11 @@ class VectorStoreConfig(BaseSettings): | |||
| default=None, | |||
| ) | |||
| VECTOR_STORE_WHITELIST_ENABLE: Optional[bool] = Field( | |||
| description="Enable whitelist for vector store.", | |||
| default=False, | |||
| ) | |||
| class KeywordStoreConfig(BaseSettings): | |||
| KEYWORD_STORE: str = Field( | |||
| @@ -248,5 +254,6 @@ class MiddlewareConfig( | |||
| InternalTestConfig, | |||
| VikingDBConfig, | |||
| UpstashConfig, | |||
| TidbOnQdrantConfig, | |||
| ): | |||
| pass | |||
| @@ -0,0 +1,65 @@ | |||
| from typing import Optional | |||
| from pydantic import Field, NonNegativeInt, PositiveInt | |||
| from pydantic_settings import BaseSettings | |||
| class TidbOnQdrantConfig(BaseSettings): | |||
| """ | |||
| Tidb on Qdrant configs | |||
| """ | |||
| TIDB_ON_QDRANT_URL: Optional[str] = Field( | |||
| description="Tidb on Qdrant url", | |||
| default=None, | |||
| ) | |||
| TIDB_ON_QDRANT_API_KEY: Optional[str] = Field( | |||
| description="Tidb on Qdrant api key", | |||
| default=None, | |||
| ) | |||
| TIDB_ON_QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field( | |||
| description="Tidb on Qdrant client timeout in seconds", | |||
| default=20, | |||
| ) | |||
| TIDB_ON_QDRANT_GRPC_ENABLED: bool = Field( | |||
| description="whether enable grpc support for Tidb on Qdrant connection", | |||
| default=False, | |||
| ) | |||
| TIDB_ON_QDRANT_GRPC_PORT: PositiveInt = Field( | |||
| description="Tidb on Qdrant grpc port", | |||
| default=6334, | |||
| ) | |||
| TIDB_PUBLIC_KEY: Optional[str] = Field( | |||
| description="Tidb account public key", | |||
| default=None, | |||
| ) | |||
| TIDB_PRIVATE_KEY: Optional[str] = Field( | |||
| description="Tidb account private key", | |||
| default=None, | |||
| ) | |||
| TIDB_API_URL: Optional[str] = Field( | |||
| description="Tidb API url", | |||
| default=None, | |||
| ) | |||
| TIDB_IAM_API_URL: Optional[str] = Field( | |||
| description="Tidb IAM API url", | |||
| default=None, | |||
| ) | |||
| TIDB_REGION: Optional[str] = Field( | |||
| description="Tidb serverless region", | |||
| default="regions/aws-us-east-1", | |||
| ) | |||
| TIDB_PROJECT_ID: Optional[str] = Field( | |||
| description="Tidb project id", | |||
| default=None, | |||
| ) | |||
| @@ -639,6 +639,7 @@ class DatasetRetrievalSettingApi(Resource): | |||
| | VectorType.ORACLE | |||
| | VectorType.ELASTICSEARCH | |||
| | VectorType.PGVECTOR | |||
| | VectorType.TIDB_ON_QDRANT | |||
| ): | |||
| return { | |||
| "retrieval_method": [ | |||
| @@ -0,0 +1,17 @@ | |||
| from typing import Optional | |||
| from pydantic import BaseModel | |||
| class ClusterEntity(BaseModel): | |||
| """ | |||
| Model Config Entity. | |||
| """ | |||
| name: str | |||
| cluster_id: str | |||
| displayName: str | |||
| region: str | |||
| spendingLimit: Optional[int] = 1000 | |||
| version: str | |||
| createdBy: str | |||
| @@ -0,0 +1,526 @@ | |||
| import json | |||
| import os | |||
| import uuid | |||
| from collections.abc import Generator, Iterable, Sequence | |||
| from itertools import islice | |||
| from typing import TYPE_CHECKING, Any, Optional, Union, cast | |||
| import qdrant_client | |||
| import requests | |||
| from flask import current_app | |||
| from pydantic import BaseModel | |||
| from qdrant_client.http import models as rest | |||
| from qdrant_client.http.models import ( | |||
| FilterSelector, | |||
| HnswConfigDiff, | |||
| PayloadSchemaType, | |||
| TextIndexParams, | |||
| TextIndexType, | |||
| TokenizerType, | |||
| ) | |||
| from qdrant_client.local.qdrant_local import QdrantLocal | |||
| from requests.auth import HTTPDigestAuth | |||
| from configs import dify_config | |||
| from core.rag.datasource.vdb.field import Field | |||
| from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService | |||
| from core.rag.datasource.vdb.vector_base import BaseVector | |||
| from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory | |||
| from core.rag.datasource.vdb.vector_type import VectorType | |||
| from core.rag.embedding.embedding_base import Embeddings | |||
| from core.rag.models.document import Document | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset, TidbAuthBinding | |||
| if TYPE_CHECKING: | |||
| from qdrant_client import grpc # noqa | |||
| from qdrant_client.conversions import common_types | |||
| from qdrant_client.http import models as rest | |||
| DictFilter = dict[str, Union[str, int, bool, dict, list]] | |||
| MetadataFilter = Union[DictFilter, common_types.Filter] | |||
| class TidbOnQdrantConfig(BaseModel): | |||
| endpoint: str | |||
| api_key: Optional[str] = None | |||
| timeout: float = 20 | |||
| root_path: Optional[str] = None | |||
| grpc_port: int = 6334 | |||
| prefer_grpc: bool = False | |||
| def to_qdrant_params(self): | |||
| if self.endpoint and self.endpoint.startswith("path:"): | |||
| path = self.endpoint.replace("path:", "") | |||
| if not os.path.isabs(path): | |||
| path = os.path.join(self.root_path, path) | |||
| return {"path": path} | |||
| else: | |||
| return { | |||
| "url": self.endpoint, | |||
| "api_key": self.api_key, | |||
| "timeout": self.timeout, | |||
| "verify": False, | |||
| "grpc_port": self.grpc_port, | |||
| "prefer_grpc": self.prefer_grpc, | |||
| } | |||
| class TidbConfig(BaseModel): | |||
| api_url: str | |||
| public_key: str | |||
| private_key: str | |||
| class TidbOnQdrantVector(BaseVector): | |||
| def __init__(self, collection_name: str, group_id: str, config: TidbOnQdrantConfig, distance_func: str = "Cosine"): | |||
| super().__init__(collection_name) | |||
| self._client_config = config | |||
| self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params()) | |||
| self._distance_func = distance_func.upper() | |||
| self._group_id = group_id | |||
| def get_type(self) -> str: | |||
| return VectorType.TIDB_ON_QDRANT | |||
| def to_index_struct(self) -> dict: | |||
| return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} | |||
| def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | |||
| if texts: | |||
| # get embedding vector size | |||
| vector_size = len(embeddings[0]) | |||
| # get collection name | |||
| collection_name = self._collection_name | |||
| # create collection | |||
| self.create_collection(collection_name, vector_size) | |||
| self.add_texts(texts, embeddings, **kwargs) | |||
| def create_collection(self, collection_name: str, vector_size: int): | |||
| lock_name = "vector_indexing_lock_{}".format(collection_name) | |||
| with redis_client.lock(lock_name, timeout=20): | |||
| collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) | |||
| if redis_client.get(collection_exist_cache_key): | |||
| return | |||
| collection_name = collection_name or uuid.uuid4().hex | |||
| all_collection_name = [] | |||
| collections_response = self._client.get_collections() | |||
| collection_list = collections_response.collections | |||
| for collection in collection_list: | |||
| all_collection_name.append(collection.name) | |||
| if collection_name not in all_collection_name: | |||
| from qdrant_client.http import models as rest | |||
| vectors_config = rest.VectorParams( | |||
| size=vector_size, | |||
| distance=rest.Distance[self._distance_func], | |||
| ) | |||
| hnsw_config = HnswConfigDiff( | |||
| m=0, | |||
| payload_m=16, | |||
| ef_construct=100, | |||
| full_scan_threshold=10000, | |||
| max_indexing_threads=0, | |||
| on_disk=False, | |||
| ) | |||
| self._client.recreate_collection( | |||
| collection_name=collection_name, | |||
| vectors_config=vectors_config, | |||
| hnsw_config=hnsw_config, | |||
| timeout=int(self._client_config.timeout), | |||
| ) | |||
| # create group_id payload index | |||
| self._client.create_payload_index( | |||
| collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD | |||
| ) | |||
| # create doc_id payload index | |||
| self._client.create_payload_index( | |||
| collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD | |||
| ) | |||
| # create full text index | |||
| text_index_params = TextIndexParams( | |||
| type=TextIndexType.TEXT, | |||
| tokenizer=TokenizerType.MULTILINGUAL, | |||
| min_token_len=2, | |||
| max_token_len=20, | |||
| lowercase=True, | |||
| ) | |||
| self._client.create_payload_index( | |||
| collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params | |||
| ) | |||
| redis_client.set(collection_exist_cache_key, 1, ex=3600) | |||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | |||
| uuids = self._get_uuids(documents) | |||
| texts = [d.page_content for d in documents] | |||
| metadatas = [d.metadata for d in documents] | |||
| added_ids = [] | |||
| for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id): | |||
| self._client.upsert(collection_name=self._collection_name, points=points) | |||
| added_ids.extend(batch_ids) | |||
| return added_ids | |||
| def _generate_rest_batches( | |||
| self, | |||
| texts: Iterable[str], | |||
| embeddings: list[list[float]], | |||
| metadatas: Optional[list[dict]] = None, | |||
| ids: Optional[Sequence[str]] = None, | |||
| batch_size: int = 64, | |||
| group_id: Optional[str] = None, | |||
| ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]: | |||
| from qdrant_client.http import models as rest | |||
| texts_iterator = iter(texts) | |||
| embeddings_iterator = iter(embeddings) | |||
| metadatas_iterator = iter(metadatas or []) | |||
| ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)]) | |||
| while batch_texts := list(islice(texts_iterator, batch_size)): | |||
| # Take the corresponding metadata and id for each text in a batch | |||
| batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None | |||
| batch_ids = list(islice(ids_iterator, batch_size)) | |||
| # Generate the embeddings for all the texts in a batch | |||
| batch_embeddings = list(islice(embeddings_iterator, batch_size)) | |||
| points = [ | |||
| rest.PointStruct( | |||
| id=point_id, | |||
| vector=vector, | |||
| payload=payload, | |||
| ) | |||
| for point_id, vector, payload in zip( | |||
| batch_ids, | |||
| batch_embeddings, | |||
| self._build_payloads( | |||
| batch_texts, | |||
| batch_metadatas, | |||
| Field.CONTENT_KEY.value, | |||
| Field.METADATA_KEY.value, | |||
| group_id, | |||
| Field.GROUP_KEY.value, | |||
| ), | |||
| ) | |||
| ] | |||
| yield batch_ids, points | |||
| @classmethod | |||
| def _build_payloads( | |||
| cls, | |||
| texts: Iterable[str], | |||
| metadatas: Optional[list[dict]], | |||
| content_payload_key: str, | |||
| metadata_payload_key: str, | |||
| group_id: str, | |||
| group_payload_key: str, | |||
| ) -> list[dict]: | |||
| payloads = [] | |||
| for i, text in enumerate(texts): | |||
| if text is None: | |||
| raise ValueError( | |||
| "At least one of the texts is None. Please remove it before " | |||
| "calling .from_texts or .add_texts on Qdrant instance." | |||
| ) | |||
| metadata = metadatas[i] if metadatas is not None else None | |||
| payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id}) | |||
| return payloads | |||
| def delete_by_metadata_field(self, key: str, value: str): | |||
| from qdrant_client.http import models | |||
| from qdrant_client.http.exceptions import UnexpectedResponse | |||
| try: | |||
| filter = models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key=f"metadata.{key}", | |||
| match=models.MatchValue(value=value), | |||
| ), | |||
| ], | |||
| ) | |||
| self._reload_if_needed() | |||
| self._client.delete( | |||
| collection_name=self._collection_name, | |||
| points_selector=FilterSelector(filter=filter), | |||
| ) | |||
| except UnexpectedResponse as e: | |||
| # Collection does not exist, so return | |||
| if e.status_code == 404: | |||
| return | |||
| # Some other error occurred, so re-raise the exception | |||
| else: | |||
| raise e | |||
| def delete(self): | |||
| from qdrant_client.http.exceptions import UnexpectedResponse | |||
| try: | |||
| self._client.delete_collection(collection_name=self._collection_name) | |||
| except UnexpectedResponse as e: | |||
| # Collection does not exist, so return | |||
| if e.status_code == 404: | |||
| return | |||
| # Some other error occurred, so re-raise the exception | |||
| else: | |||
| raise e | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| from qdrant_client.http import models | |||
| from qdrant_client.http.exceptions import UnexpectedResponse | |||
| for node_id in ids: | |||
| try: | |||
| filter = models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key="metadata.doc_id", | |||
| match=models.MatchValue(value=node_id), | |||
| ), | |||
| ], | |||
| ) | |||
| self._client.delete( | |||
| collection_name=self._collection_name, | |||
| points_selector=FilterSelector(filter=filter), | |||
| ) | |||
| except UnexpectedResponse as e: | |||
| # Collection does not exist, so return | |||
| if e.status_code == 404: | |||
| return | |||
| # Some other error occurred, so re-raise the exception | |||
| else: | |||
| raise e | |||
| def text_exists(self, id: str) -> bool: | |||
| all_collection_name = [] | |||
| collections_response = self._client.get_collections() | |||
| collection_list = collections_response.collections | |||
| for collection in collection_list: | |||
| all_collection_name.append(collection.name) | |||
| if self._collection_name not in all_collection_name: | |||
| return False | |||
| response = self._client.retrieve(collection_name=self._collection_name, ids=[id]) | |||
| return len(response) > 0 | |||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||
| from qdrant_client.http import models | |||
| filter = models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key="group_id", | |||
| match=models.MatchValue(value=self._group_id), | |||
| ), | |||
| ], | |||
| ) | |||
| results = self._client.search( | |||
| collection_name=self._collection_name, | |||
| query_vector=query_vector, | |||
| query_filter=filter, | |||
| limit=kwargs.get("top_k", 4), | |||
| with_payload=True, | |||
| with_vectors=True, | |||
| score_threshold=kwargs.get("score_threshold", 0.0), | |||
| ) | |||
| docs = [] | |||
| for result in results: | |||
| metadata = result.payload.get(Field.METADATA_KEY.value) or {} | |||
| # duplicate check score threshold | |||
| score_threshold = kwargs.get("score_threshold") or 0.0 | |||
| if result.score > score_threshold: | |||
| metadata["score"] = result.score | |||
| doc = Document( | |||
| page_content=result.payload.get(Field.CONTENT_KEY.value), | |||
| metadata=metadata, | |||
| ) | |||
| docs.append(doc) | |||
| # Sort the documents by score in descending order | |||
| docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) | |||
| return docs | |||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | |||
| """Return docs most similar by bm25. | |||
| Returns: | |||
| List of documents most similar to the query text and distance for each. | |||
| """ | |||
| from qdrant_client.http import models | |||
| scroll_filter = models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key="page_content", | |||
| match=models.MatchText(text=query), | |||
| ) | |||
| ] | |||
| ) | |||
| response = self._client.scroll( | |||
| collection_name=self._collection_name, | |||
| scroll_filter=scroll_filter, | |||
| limit=kwargs.get("top_k", 2), | |||
| with_payload=True, | |||
| with_vectors=True, | |||
| ) | |||
| results = response[0] | |||
| documents = [] | |||
| for result in results: | |||
| if result: | |||
| document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value) | |||
| document.metadata["vector"] = result.vector | |||
| documents.append(document) | |||
| return documents | |||
| def _reload_if_needed(self): | |||
| if isinstance(self._client, QdrantLocal): | |||
| self._client = cast(QdrantLocal, self._client) | |||
| self._client._load() | |||
| @classmethod | |||
| def _document_from_scored_point( | |||
| cls, | |||
| scored_point: Any, | |||
| content_payload_key: str, | |||
| metadata_payload_key: str, | |||
| ) -> Document: | |||
| return Document( | |||
| page_content=scored_point.payload.get(content_payload_key), | |||
| metadata=scored_point.payload.get(metadata_payload_key) or {}, | |||
| ) | |||
| class TidbOnQdrantVectorFactory(AbstractVectorFactory): | |||
| def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector: | |||
| tidb_auth_binding = ( | |||
| db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() | |||
| ) | |||
| if not tidb_auth_binding: | |||
| idle_tidb_auth_binding = ( | |||
| db.session.query(TidbAuthBinding) | |||
| .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") | |||
| .limit(1) | |||
| .one_or_none() | |||
| ) | |||
| if idle_tidb_auth_binding: | |||
| idle_tidb_auth_binding.active = True | |||
| idle_tidb_auth_binding.tenant_id = dataset.tenant_id | |||
| db.session.commit() | |||
| TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}" | |||
| else: | |||
| with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): | |||
| tidb_auth_binding = ( | |||
| db.session.query(TidbAuthBinding) | |||
| .filter(TidbAuthBinding.tenant_id == dataset.tenant_id) | |||
| .one_or_none() | |||
| ) | |||
| if tidb_auth_binding: | |||
| TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" | |||
| else: | |||
| new_cluster = TidbService.create_tidb_serverless_cluster( | |||
| dify_config.TIDB_PROJECT_ID, | |||
| dify_config.TIDB_API_URL, | |||
| dify_config.TIDB_IAM_API_URL, | |||
| dify_config.TIDB_PUBLIC_KEY, | |||
| dify_config.TIDB_PRIVATE_KEY, | |||
| dify_config.TIDB_REGION, | |||
| ) | |||
| new_tidb_auth_binding = TidbAuthBinding( | |||
| cluster_id=new_cluster["cluster_id"], | |||
| cluster_name=new_cluster["cluster_name"], | |||
| account=new_cluster["account"], | |||
| password=new_cluster["password"], | |||
| tenant_id=dataset.tenant_id, | |||
| active=True, | |||
| status="ACTIVE", | |||
| ) | |||
| db.session.add(new_tidb_auth_binding) | |||
| db.session.commit() | |||
| TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}" | |||
| else: | |||
| TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" | |||
| if dataset.index_struct_dict: | |||
| class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] | |||
| collection_name = class_prefix | |||
| else: | |||
| dataset_id = dataset.id | |||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id) | |||
| dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_ON_QDRANT, collection_name)) | |||
| config = current_app.config | |||
| return TidbOnQdrantVector( | |||
| collection_name=collection_name, | |||
| group_id=dataset.id, | |||
| config=TidbOnQdrantConfig( | |||
| endpoint=dify_config.TIDB_ON_QDRANT_URL, | |||
| api_key=TIDB_ON_QDRANT_API_KEY, | |||
| root_path=config.root_path, | |||
| timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT, | |||
| grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT, | |||
| prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED, | |||
| ), | |||
| ) | |||
| def create_tidb_serverless_cluster(self, tidb_config: TidbConfig, display_name: str, region: str): | |||
| """ | |||
| Creates a new TiDB Serverless cluster. | |||
| :param tidb_config: The configuration for the TiDB Cloud API. | |||
| :param display_name: The user-friendly display name of the cluster (required). | |||
| :param region: The region where the cluster will be created (required). | |||
| :return: The response from the API. | |||
| """ | |||
| region_object = { | |||
| "name": region, | |||
| } | |||
| labels = { | |||
| "tidb.cloud/project": "1372813089454548012", | |||
| } | |||
| cluster_data = {"displayName": display_name, "region": region_object, "labels": labels} | |||
| response = requests.post( | |||
| f"{tidb_config.api_url}/clusters", | |||
| json=cluster_data, | |||
| auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key), | |||
| ) | |||
| if response.status_code == 200: | |||
| return response.json() | |||
| else: | |||
| response.raise_for_status() | |||
| def change_tidb_serverless_root_password(self, tidb_config: TidbConfig, cluster_id: str, new_password: str): | |||
| """ | |||
| Changes the root password of a specific TiDB Serverless cluster. | |||
| :param tidb_config: The configuration for the TiDB Cloud API. | |||
| :param cluster_id: The ID of the cluster for which the password is to be changed (required). | |||
| :param new_password: The new password for the root user (required). | |||
| :return: The response from the API. | |||
| """ | |||
| body = {"password": new_password} | |||
| response = requests.put( | |||
| f"{tidb_config.api_url}/clusters/{cluster_id}/password", | |||
| json=body, | |||
| auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key), | |||
| ) | |||
| if response.status_code == 200: | |||
| return response.json() | |||
| else: | |||
| response.raise_for_status() | |||
| @@ -0,0 +1,250 @@ | |||
| import time | |||
| import uuid | |||
| import requests | |||
| from requests.auth import HTTPDigestAuth | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import TidbAuthBinding | |||
| class TidbService: | |||
| @staticmethod | |||
| def create_tidb_serverless_cluster( | |||
| project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str | |||
| ): | |||
| """ | |||
| Creates a new TiDB Serverless cluster. | |||
| :param project_id: The project ID of the TiDB Cloud project (required). | |||
| :param api_url: The URL of the TiDB Cloud API (required). | |||
| :param iam_url: The URL of the TiDB Cloud IAM API (required). | |||
| :param public_key: The public key for the API (required). | |||
| :param private_key: The private key for the API (required). | |||
| :param display_name: The user-friendly display name of the cluster (required). | |||
| :param region: The region where the cluster will be created (required). | |||
| :return: The response from the API. | |||
| """ | |||
| region_object = { | |||
| "name": region, | |||
| } | |||
| labels = { | |||
| "tidb.cloud/project": project_id, | |||
| } | |||
| spending_limit = { | |||
| "monthly": 100, | |||
| } | |||
| password = str(uuid.uuid4()).replace("-", "")[:16] | |||
| display_name = str(uuid.uuid4()).replace("-", "")[:16] | |||
| cluster_data = { | |||
| "displayName": display_name, | |||
| "region": region_object, | |||
| "labels": labels, | |||
| "spendingLimit": spending_limit, | |||
| "rootPassword": password, | |||
| } | |||
| response = requests.post(f"{api_url}/clusters", json=cluster_data, auth=HTTPDigestAuth(public_key, private_key)) | |||
| if response.status_code == 200: | |||
| response_data = response.json() | |||
| cluster_id = response_data["clusterId"] | |||
| retry_count = 0 | |||
| max_retries = 30 | |||
| while retry_count < max_retries: | |||
| cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id) | |||
| if cluster_response["state"] == "ACTIVE": | |||
| user_prefix = cluster_response["userPrefix"] | |||
| return { | |||
| "cluster_id": cluster_id, | |||
| "cluster_name": display_name, | |||
| "account": f"{user_prefix}.root", | |||
| "password": password, | |||
| } | |||
| time.sleep(30) # wait 30 seconds before retrying | |||
| retry_count += 1 | |||
| else: | |||
| response.raise_for_status() | |||
| @staticmethod | |||
| def delete_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str): | |||
| """ | |||
| Deletes a specific TiDB Serverless cluster. | |||
| :param api_url: The URL of the TiDB Cloud API (required). | |||
| :param public_key: The public key for the API (required). | |||
| :param private_key: The private key for the API (required). | |||
| :param cluster_id: The ID of the cluster to be deleted (required). | |||
| :return: The response from the API. | |||
| """ | |||
| response = requests.delete(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key)) | |||
| if response.status_code == 200: | |||
| return response.json() | |||
| else: | |||
| response.raise_for_status() | |||
| @staticmethod | |||
| def get_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str): | |||
| """ | |||
| Deletes a specific TiDB Serverless cluster. | |||
| :param api_url: The URL of the TiDB Cloud API (required). | |||
| :param public_key: The public key for the API (required). | |||
| :param private_key: The private key for the API (required). | |||
| :param cluster_id: The ID of the cluster to be deleted (required). | |||
| :return: The response from the API. | |||
| """ | |||
| response = requests.get(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key)) | |||
| if response.status_code == 200: | |||
| return response.json() | |||
| else: | |||
| response.raise_for_status() | |||
| @staticmethod | |||
| def change_tidb_serverless_root_password( | |||
| api_url: str, public_key: str, private_key: str, cluster_id: str, account: str, new_password: str | |||
| ): | |||
| """ | |||
| Changes the root password of a specific TiDB Serverless cluster. | |||
| :param api_url: The URL of the TiDB Cloud API (required). | |||
| :param public_key: The public key for the API (required). | |||
| :param private_key: The private key for the API (required). | |||
| :param cluster_id: The ID of the cluster for which the password is to be changed (required).+ | |||
| :param account: The account for which the password is to be changed (required). | |||
| :param new_password: The new password for the root user (required). | |||
| :return: The response from the API. | |||
| """ | |||
| body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []} | |||
| response = requests.patch( | |||
| f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}", | |||
| json=body, | |||
| auth=HTTPDigestAuth(public_key, private_key), | |||
| ) | |||
| if response.status_code == 200: | |||
| return response.json() | |||
| else: | |||
| response.raise_for_status() | |||
| @staticmethod | |||
| def batch_update_tidb_serverless_cluster_status( | |||
| tidb_serverless_list: list[TidbAuthBinding], | |||
| project_id: str, | |||
| api_url: str, | |||
| iam_url: str, | |||
| public_key: str, | |||
| private_key: str, | |||
| ) -> list[dict]: | |||
| """ | |||
| Update the status of a new TiDB Serverless cluster. | |||
| :param project_id: The project ID of the TiDB Cloud project (required). | |||
| :param api_url: The URL of the TiDB Cloud API (required). | |||
| :param iam_url: The URL of the TiDB Cloud IAM API (required). | |||
| :param public_key: The public key for the API (required). | |||
| :param private_key: The private key for the API (required). | |||
| :param display_name: The user-friendly display name of the cluster (required). | |||
| :param region: The region where the cluster will be created (required). | |||
| :return: The response from the API. | |||
| """ | |||
| clusters = [] | |||
| tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list} | |||
| cluster_ids = [item.cluster_id for item in tidb_serverless_list] | |||
| params = {"clusterIds": cluster_ids, "view": "FULL"} | |||
| response = requests.get( | |||
| f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key) | |||
| ) | |||
| if response.status_code == 200: | |||
| response_data = response.json() | |||
| cluster_infos = [] | |||
| for item in response_data["clusters"]: | |||
| state = item["state"] | |||
| userPrefix = item["userPrefix"] | |||
| if state == "ACTIVE" and len(userPrefix) > 0: | |||
| cluster_info = tidb_serverless_list_map[item["clusterId"]] | |||
| cluster_info.status = "ACTIVE" | |||
| cluster_info.account = f"{userPrefix}.root" | |||
| db.session.add(cluster_info) | |||
| db.session.commit() | |||
| else: | |||
| response.raise_for_status() | |||
| @staticmethod | |||
| def batch_create_tidb_serverless_cluster( | |||
| batch_size: int, project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str | |||
| ) -> list[dict]: | |||
| """ | |||
| Creates a new TiDB Serverless cluster. | |||
| :param project_id: The project ID of the TiDB Cloud project (required). | |||
| :param api_url: The URL of the TiDB Cloud API (required). | |||
| :param iam_url: The URL of the TiDB Cloud IAM API (required). | |||
| :param public_key: The public key for the API (required). | |||
| :param private_key: The private key for the API (required). | |||
| :param display_name: The user-friendly display name of the cluster (required). | |||
| :param region: The region where the cluster will be created (required). | |||
| :return: The response from the API. | |||
| """ | |||
| clusters = [] | |||
| for _ in range(batch_size): | |||
| region_object = { | |||
| "name": region, | |||
| } | |||
| labels = { | |||
| "tidb.cloud/project": project_id, | |||
| } | |||
| spending_limit = { | |||
| "monthly": 10, | |||
| } | |||
| password = str(uuid.uuid4()).replace("-", "")[:16] | |||
| display_name = str(uuid.uuid4()).replace("-", "") | |||
| cluster_data = { | |||
| "cluster": { | |||
| "displayName": display_name, | |||
| "region": region_object, | |||
| "labels": labels, | |||
| "spendingLimit": spending_limit, | |||
| "rootPassword": password, | |||
| } | |||
| } | |||
| cache_key = f"tidb_serverless_cluster_password:{display_name}" | |||
| redis_client.setex(cache_key, 3600, password) | |||
| clusters.append(cluster_data) | |||
| request_body = {"requests": clusters} | |||
| response = requests.post( | |||
| f"{api_url}/clusters:batchCreate", json=request_body, auth=HTTPDigestAuth(public_key, private_key) | |||
| ) | |||
| if response.status_code == 200: | |||
| response_data = response.json() | |||
| cluster_infos = [] | |||
| for item in response_data["clusters"]: | |||
| cache_key = f"tidb_serverless_cluster_password:{item['displayName']}" | |||
| password = redis_client.get(cache_key) | |||
| if not password: | |||
| continue | |||
| cluster_info = { | |||
| "cluster_id": item["clusterId"], | |||
| "cluster_name": item["displayName"], | |||
| "account": "root", | |||
| "password": password.decode("utf-8"), | |||
| } | |||
| cluster_infos.append(cluster_info) | |||
| return cluster_infos | |||
| else: | |||
| response.raise_for_status() | |||
| @@ -9,8 +9,9 @@ from core.rag.datasource.vdb.vector_type import VectorType | |||
| from core.rag.embedding.cached_embedding import CacheEmbedding | |||
| from core.rag.embedding.embedding_base import Embeddings | |||
| from core.rag.models.document import Document | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset | |||
| from models.dataset import Dataset, Whitelist | |||
| class AbstractVectorFactory(ABC): | |||
| @@ -35,8 +36,18 @@ class Vector: | |||
| def _init_vector(self) -> BaseVector: | |||
| vector_type = dify_config.VECTOR_STORE | |||
| if self._dataset.index_struct_dict: | |||
| vector_type = self._dataset.index_struct_dict["type"] | |||
| else: | |||
| if dify_config.VECTOR_STORE_WHITELIST_ENABLE: | |||
| whitelist = ( | |||
| db.session.query(Whitelist) | |||
| .filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db") | |||
| .one_or_none() | |||
| ) | |||
| if whitelist: | |||
| vector_type = VectorType.TIDB_ON_QDRANT | |||
| if not vector_type: | |||
| raise ValueError("Vector store must be specified.") | |||
| @@ -115,6 +126,10 @@ class Vector: | |||
| from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory | |||
| return UpstashVectorFactory | |||
| case VectorType.TIDB_ON_QDRANT: | |||
| from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory | |||
| return TidbOnQdrantVectorFactory | |||
| case _: | |||
| raise ValueError(f"Vector store {vector_type} is not supported.") | |||
| @@ -19,3 +19,4 @@ class VectorType(str, Enum): | |||
| BAIDU = "baidu" | |||
| VIKINGDB = "vikingdb" | |||
| UPSTASH = "upstash" | |||
| TIDB_ON_QDRANT = "tidb_on_qdrant" | |||
| @@ -1,6 +1,7 @@ | |||
| from datetime import timedelta | |||
| from celery import Celery, Task | |||
| from celery.schedules import crontab | |||
| from flask import Flask | |||
| from configs import dify_config | |||
| @@ -55,6 +56,8 @@ def init_app(app: Flask) -> Celery: | |||
| imports = [ | |||
| "schedule.clean_embedding_cache_task", | |||
| "schedule.clean_unused_datasets_task", | |||
| "schedule.create_tidb_serverless_task", | |||
| "schedule.update_tidb_serverless_status_task", | |||
| ] | |||
| day = dify_config.CELERY_BEAT_SCHEDULER_TIME | |||
| beat_schedule = { | |||
| @@ -66,6 +69,14 @@ def init_app(app: Flask) -> Celery: | |||
| "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task", | |||
| "schedule": timedelta(days=day), | |||
| }, | |||
| "create_tidb_serverless_task": { | |||
| "task": "schedule.create_tidb_serverless_task.create_tidb_serverless_task", | |||
| "schedule": crontab(minute="0", hour="*"), | |||
| }, | |||
| "update_tidb_serverless_status_task": { | |||
| "task": "schedule.update_tidb_serverless_status_task.update_tidb_serverless_status_task", | |||
| "schedule": crontab(minute="30", hour="*"), | |||
| }, | |||
| } | |||
| celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) | |||
| @@ -0,0 +1,51 @@ | |||
| """add-tidb-auth-binding | |||
| Revision ID: 0251a1c768cc | |||
| Revises: 63a83fcf12ba | |||
| Create Date: 2024-08-15 09:56:59.012490 | |||
| """ | |||
| import sqlalchemy as sa | |||
| from alembic import op | |||
| import models as models | |||
| # revision identifiers, used by Alembic. | |||
| revision = '0251a1c768cc' | |||
| down_revision = 'bbadea11becb' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| op.create_table('tidb_auth_bindings', | |||
| sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||
| sa.Column('tenant_id', models.types.StringUUID(), nullable=True), | |||
| sa.Column('cluster_id', sa.String(length=255), nullable=False), | |||
| sa.Column('cluster_name', sa.String(length=255), nullable=False), | |||
| sa.Column('active', sa.Boolean(), server_default=sa.text('false'), nullable=False), | |||
| sa.Column('status', sa.String(length=255), server_default=sa.text("'CREATING'::character varying"), nullable=False), | |||
| sa.Column('account', sa.String(length=255), nullable=False), | |||
| sa.Column('password', sa.String(length=255), nullable=False), | |||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), | |||
| sa.PrimaryKeyConstraint('id', name='tidb_auth_bindings_pkey') | |||
| ) | |||
| with op.batch_alter_table('tidb_auth_bindings', schema=None) as batch_op: | |||
| batch_op.create_index('tidb_auth_bindings_active_idx', ['active'], unique=False) | |||
| batch_op.create_index('tidb_auth_bindings_status_idx', ['status'], unique=False) | |||
| batch_op.create_index('tidb_auth_bindings_created_at_idx', ['created_at'], unique=False) | |||
| batch_op.create_index('tidb_auth_bindings_tenant_idx', ['tenant_id'], unique=False) | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('tidb_auth_bindings', schema=None) as batch_op: | |||
| batch_op.drop_index('tidb_auth_bindings_tenant_idx') | |||
| batch_op.drop_index('tidb_auth_bindings_created_at_idx') | |||
| batch_op.drop_index('tidb_auth_bindings_active_idx') | |||
| batch_op.drop_index('tidb_auth_bindings_status_idx') | |||
| op.drop_table('tidb_auth_bindings') | |||
| # ### end Alembic commands ### | |||
| @@ -0,0 +1,42 @@ | |||
| """add_white_list | |||
| Revision ID: 43fa78bc3b7d | |||
| Revises: 0251a1c768cc | |||
| Create Date: 2024-10-22 09:59:23.713716 | |||
| """ | |||
| from alembic import op | |||
| import models as models | |||
| import sqlalchemy as sa | |||
| from sqlalchemy.dialects import postgresql | |||
| # revision identifiers, used by Alembic. | |||
| revision = '43fa78bc3b7d' | |||
| down_revision = '0251a1c768cc' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| op.create_table('whitelists', | |||
| sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||
| sa.Column('tenant_id', models.types.StringUUID(), nullable=True), | |||
| sa.Column('category', sa.String(length=255), nullable=False), | |||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), | |||
| sa.PrimaryKeyConstraint('id', name='whitelists_pkey') | |||
| ) | |||
| with op.batch_alter_table('whitelists', schema=None) as batch_op: | |||
| batch_op.create_index('whitelists_tenant_idx', ['tenant_id'], unique=False) | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('whitelists', schema=None) as batch_op: | |||
| batch_op.drop_index('whitelists_tenant_idx') | |||
| op.drop_table('whitelists') | |||
| # ### end Alembic commands ### | |||
| @@ -704,6 +704,38 @@ class DatasetCollectionBinding(db.Model): | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| class TidbAuthBinding(db.Model): | |||
| __tablename__ = "tidb_auth_bindings" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"), | |||
| db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"), | |||
| db.Index("tidb_auth_bindings_active_idx", "active"), | |||
| db.Index("tidb_auth_bindings_created_at_idx", "created_at"), | |||
| db.Index("tidb_auth_bindings_status_idx", "status"), | |||
| ) | |||
| id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=True) | |||
| cluster_id = db.Column(db.String(255), nullable=False) | |||
| cluster_name = db.Column(db.String(255), nullable=False) | |||
| active = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) | |||
| status = db.Column(db.String(255), nullable=False, server_default=db.text("CREATING")) | |||
| account = db.Column(db.String(255), nullable=False) | |||
| password = db.Column(db.String(255), nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| class Whitelist(db.Model): | |||
| __tablename__ = "whitelists" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="whitelists_pkey"), | |||
| db.Index("whitelists_tenant_idx", "tenant_id"), | |||
| ) | |||
| id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=True) | |||
| category = db.Column(db.String(255), nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| class DatasetPermission(db.Model): | |||
| __tablename__ = "dataset_permissions" | |||
| __table_args__ = ( | |||
| @@ -0,0 +1,56 @@ | |||
| import time | |||
| import click | |||
| import app | |||
| from configs import dify_config | |||
| from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService | |||
| from extensions.ext_database import db | |||
| from models.dataset import TidbAuthBinding | |||
| @app.celery.task(queue="dataset") | |||
| def create_tidb_serverless_task(): | |||
| click.echo(click.style("Start create tidb serverless task.", fg="green")) | |||
| tidb_serverless_number = dify_config.TIDB_SERVERLESS_NUMBER | |||
| start_at = time.perf_counter() | |||
| while True: | |||
| try: | |||
| # check the number of idle tidb serverless | |||
| idle_tidb_serverless_number = TidbAuthBinding.query.filter(TidbAuthBinding.active == False).count() | |||
| if idle_tidb_serverless_number >= tidb_serverless_number: | |||
| break | |||
| # create tidb serverless | |||
| iterations_per_thread = 20 | |||
| create_clusters(iterations_per_thread) | |||
| except Exception as e: | |||
| click.echo(click.style(f"Error: {e}", fg="red")) | |||
| break | |||
| end_at = time.perf_counter() | |||
| click.echo(click.style("Create tidb serverless task success latency: {}".format(end_at - start_at), fg="green")) | |||
| def create_clusters(batch_size): | |||
| try: | |||
| new_clusters = TidbService.batch_create_tidb_serverless_cluster( | |||
| batch_size, | |||
| dify_config.TIDB_PROJECT_ID, | |||
| dify_config.TIDB_API_URL, | |||
| dify_config.TIDB_IAM_API_URL, | |||
| dify_config.TIDB_PUBLIC_KEY, | |||
| dify_config.TIDB_PRIVATE_KEY, | |||
| dify_config.TIDB_REGION, | |||
| ) | |||
| for new_cluster in new_clusters: | |||
| tidb_auth_binding = TidbAuthBinding( | |||
| cluster_id=new_cluster["cluster_id"], | |||
| cluster_name=new_cluster["cluster_name"], | |||
| account=new_cluster["account"], | |||
| password=new_cluster["password"], | |||
| ) | |||
| db.session.add(tidb_auth_binding) | |||
| db.session.commit() | |||
| except Exception as e: | |||
| click.echo(click.style(f"Error: {e}", fg="red")) | |||
| @@ -0,0 +1,51 @@ | |||
| import time | |||
| import click | |||
| import app | |||
| from configs import dify_config | |||
| from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService | |||
| from models.dataset import TidbAuthBinding | |||
| @app.celery.task(queue="dataset") | |||
| def update_tidb_serverless_status_task(): | |||
| click.echo(click.style("Update tidb serverless status task.", fg="green")) | |||
| start_at = time.perf_counter() | |||
| while True: | |||
| try: | |||
| # check the number of idle tidb serverless | |||
| tidb_serverless_list = TidbAuthBinding.query.filter( | |||
| TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING" | |||
| ).all() | |||
| if len(tidb_serverless_list) == 0: | |||
| break | |||
| # update tidb serverless status | |||
| iterations_per_thread = 20 | |||
| update_clusters(tidb_serverless_list) | |||
| except Exception as e: | |||
| click.echo(click.style(f"Error: {e}", fg="red")) | |||
| break | |||
| end_at = time.perf_counter() | |||
| click.echo( | |||
| click.style("Update tidb serverless status task success latency: {}".format(end_at - start_at), fg="green") | |||
| ) | |||
| def update_clusters(tidb_serverless_list: list[TidbAuthBinding]): | |||
| try: | |||
| # batch 20 | |||
| for i in range(0, len(tidb_serverless_list), 20): | |||
| items = tidb_serverless_list[i : i + 20] | |||
| TidbService.batch_update_tidb_serverless_cluster_status( | |||
| items, | |||
| dify_config.TIDB_PROJECT_ID, | |||
| dify_config.TIDB_API_URL, | |||
| dify_config.TIDB_IAM_API_URL, | |||
| dify_config.TIDB_PUBLIC_KEY, | |||
| dify_config.TIDB_PRIVATE_KEY, | |||
| ) | |||
| except Exception as e: | |||
| click.echo(click.style(f"Error: {e}", fg="red")) | |||
| @@ -0,0 +1,44 @@ | |||
| import json | |||
| import requests | |||
| from services.auth.api_key_auth_base import ApiKeyAuthBase | |||
| class JinaAuth(ApiKeyAuthBase): | |||
| def __init__(self, credentials: dict): | |||
| super().__init__(credentials) | |||
| auth_type = credentials.get("auth_type") | |||
| if auth_type != "bearer": | |||
| raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer") | |||
| self.api_key = credentials.get("config").get("api_key", None) | |||
| if not self.api_key: | |||
| raise ValueError("No API key provided") | |||
| def validate_credentials(self): | |||
| headers = self._prepare_headers() | |||
| options = { | |||
| "url": "https://example.com", | |||
| } | |||
| response = self._post_request("https://r.jina.ai", options, headers) | |||
| if response.status_code == 200: | |||
| return True | |||
| else: | |||
| self._handle_error(response) | |||
| def _prepare_headers(self): | |||
| return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} | |||
| def _post_request(self, url, data, headers): | |||
| return requests.post(url, headers=headers, json=data) | |||
| def _handle_error(self, response): | |||
| if response.status_code in {402, 409, 500}: | |||
| error_message = response.json().get("error", "Unknown error occurred") | |||
| raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") | |||
| else: | |||
| if response.text: | |||
| error_message = json.loads(response.text).get("error", "Unknown error occurred") | |||
| raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") | |||
| raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}") | |||
| @@ -0,0 +1,40 @@ | |||
| 'use client' | |||
| import type { FC } from 'react' | |||
| import React from 'react' | |||
| import cn from '@/utils/classnames' | |||
| import Checkbox from '@/app/components/base/checkbox' | |||
| import Tooltip from '@/app/components/base/tooltip' | |||
| type Props = { | |||
| className?: string | |||
| isChecked: boolean | |||
| onChange: (isChecked: boolean) => void | |||
| label: string | |||
| labelClassName?: string | |||
| tooltip?: string | |||
| } | |||
| const CheckboxWithLabel: FC<Props> = ({ | |||
| className = '', | |||
| isChecked, | |||
| onChange, | |||
| label, | |||
| labelClassName, | |||
| tooltip, | |||
| }) => { | |||
| return ( | |||
| <label className={cn(className, 'flex items-center h-7 space-x-2')}> | |||
| <Checkbox checked={isChecked} onCheck={() => onChange(!isChecked)} /> | |||
| <div className={cn(labelClassName, 'text-sm font-normal text-gray-800')}>{label}</div> | |||
| {tooltip && ( | |||
| <Tooltip | |||
| popupContent={ | |||
| <div className='w-[200px]'>{tooltip}</div> | |||
| } | |||
| triggerClassName='ml-0.5 w-4 h-4' | |||
| /> | |||
| )} | |||
| </label> | |||
| ) | |||
| } | |||
| export default React.memo(CheckboxWithLabel) | |||
| @@ -0,0 +1,30 @@ | |||
| 'use client' | |||
| import type { FC } from 'react' | |||
| import React from 'react' | |||
| import cn from '@/utils/classnames' | |||
| import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' | |||
| type Props = { | |||
| className?: string | |||
| title: string | |||
| errorMsg?: string | |||
| } | |||
| const ErrorMessage: FC<Props> = ({ | |||
| className, | |||
| title, | |||
| errorMsg, | |||
| }) => { | |||
| return ( | |||
| <div className={cn(className, 'py-2 px-4 border-t border-gray-200 bg-[#FFFAEB]')}> | |||
| <div className='flex items-center h-5'> | |||
| <AlertTriangle className='mr-2 w-4 h-4 text-[#F79009]' /> | |||
| <div className='text-sm font-medium text-[#DC6803]'>{title}</div> | |||
| </div> | |||
| {errorMsg && ( | |||
| <div className='mt-1 pl-6 leading-[18px] text-xs font-normal text-gray-700'>{errorMsg}</div> | |||
| )} | |||
| </div> | |||
| ) | |||
| } | |||
| export default React.memo(ErrorMessage) | |||
| @@ -0,0 +1,54 @@ | |||
| 'use client' | |||
| import type { FC } from 'react' | |||
| import React from 'react' | |||
| import Input from './input' | |||
| import cn from '@/utils/classnames' | |||
| import Tooltip from '@/app/components/base/tooltip' | |||
| type Props = { | |||
| className?: string | |||
| label: string | |||
| labelClassName?: string | |||
| value: string | number | |||
| onChange: (value: string | number) => void | |||
| isRequired?: boolean | |||
| placeholder?: string | |||
| isNumber?: boolean | |||
| tooltip?: string | |||
| } | |||
| const Field: FC<Props> = ({ | |||
| className, | |||
| label, | |||
| labelClassName, | |||
| value, | |||
| onChange, | |||
| isRequired = false, | |||
| placeholder = '', | |||
| isNumber = false, | |||
| tooltip, | |||
| }) => { | |||
| return ( | |||
| <div className={cn(className)}> | |||
| <div className='flex py-[7px]'> | |||
| <div className={cn(labelClassName, 'flex items-center h-[18px] text-[13px] font-medium text-gray-900')}>{label} </div> | |||
| {isRequired && <span className='ml-0.5 text-xs font-semibold text-[#D92D20]'>*</span>} | |||
| {tooltip && ( | |||
| <Tooltip | |||
| popupContent={ | |||
| <div className='w-[200px]'>{tooltip}</div> | |||
| } | |||
| triggerClassName='ml-0.5 w-4 h-4' | |||
| /> | |||
| )} | |||
| </div> | |||
| <Input | |||
| value={value} | |||
| onChange={onChange} | |||
| placeholder={placeholder} | |||
| isNumber={isNumber} | |||
| /> | |||
| </div> | |||
| ) | |||
| } | |||
| export default React.memo(Field) | |||
| @@ -0,0 +1,58 @@ | |||
| 'use client' | |||
| import type { FC } from 'react' | |||
| import React, { useCallback } from 'react' | |||
| type Props = { | |||
| value: string | number | |||
| onChange: (value: string | number) => void | |||
| placeholder?: string | |||
| isNumber?: boolean | |||
| } | |||
| const MIN_VALUE = 0 | |||
| const Input: FC<Props> = ({ | |||
| value, | |||
| onChange, | |||
| placeholder = '', | |||
| isNumber = false, | |||
| }) => { | |||
| const handleChange = useCallback((e: React.ChangeEvent<HTMLInputElement>) => { | |||
| const value = e.target.value | |||
| if (isNumber) { | |||
| let numberValue = parseInt(value, 10) // integer only | |||
| if (isNaN(numberValue)) { | |||
| onChange('') | |||
| return | |||
| } | |||
| if (numberValue < MIN_VALUE) | |||
| numberValue = MIN_VALUE | |||
| onChange(numberValue) | |||
| return | |||
| } | |||
| onChange(value) | |||
| }, [isNumber, onChange]) | |||
| const otherOption = (() => { | |||
| if (isNumber) { | |||
| return { | |||
| min: MIN_VALUE, | |||
| } | |||
| } | |||
| return { | |||
| } | |||
| })() | |||
| return ( | |||
| <input | |||
| type={isNumber ? 'number' : 'text'} | |||
| {...otherOption} | |||
| value={value} | |||
| onChange={handleChange} | |||
| className='flex h-9 w-full py-1 px-2 rounded-lg text-xs leading-normal bg-gray-100 caret-primary-600 hover:bg-gray-100 focus:ring-1 focus:ring-inset focus:ring-gray-200 focus-visible:outline-none focus:bg-gray-50 placeholder:text-gray-400' | |||
| placeholder={placeholder} | |||
| /> | |||
| ) | |||
| } | |||
| export default React.memo(Input) | |||
| @@ -0,0 +1,55 @@ | |||
| 'use client' | |||
| import { useBoolean } from 'ahooks' | |||
| import type { FC } from 'react' | |||
| import React, { useEffect } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import cn from '@/utils/classnames' | |||
| import { Settings04 } from '@/app/components/base/icons/src/vender/line/general' | |||
| import { ChevronRight } from '@/app/components/base/icons/src/vender/line/arrows' | |||
| const I18N_PREFIX = 'datasetCreation.stepOne.website' | |||
| type Props = { | |||
| className?: string | |||
| children: React.ReactNode | |||
| controlFoldOptions?: number | |||
| } | |||
| const OptionsWrap: FC<Props> = ({ | |||
| className = '', | |||
| children, | |||
| controlFoldOptions, | |||
| }) => { | |||
| const { t } = useTranslation() | |||
| const [fold, { | |||
| toggle: foldToggle, | |||
| setTrue: foldHide, | |||
| }] = useBoolean(false) | |||
| useEffect(() => { | |||
| if (controlFoldOptions) | |||
| foldHide() | |||
| // eslint-disable-next-line react-hooks/exhaustive-deps | |||
| }, [controlFoldOptions]) | |||
| return ( | |||
| <div className={cn(className, !fold ? 'mb-0' : 'mb-3')}> | |||
| <div | |||
| className='flex justify-between items-center h-[26px] py-1 cursor-pointer select-none' | |||
| onClick={foldToggle} | |||
| > | |||
| <div className='flex items-center text-gray-700'> | |||
| <Settings04 className='mr-1 w-4 h-4' /> | |||
| <div className='text-[13px] font-semibold text-gray-800 uppercase'>{t(`${I18N_PREFIX}.options`)}</div> | |||
| </div> | |||
| <ChevronRight className={cn(!fold && 'rotate-90', 'w-4 h-4 text-gray-500')} /> | |||
| </div> | |||
| {!fold && ( | |||
| <div className='mb-4'> | |||
| {children} | |||
| </div> | |||
| )} | |||
| </div> | |||
| ) | |||
| } | |||
| export default React.memo(OptionsWrap) | |||
| @@ -0,0 +1,48 @@ | |||
| 'use client' | |||
| import type { FC } from 'react' | |||
| import React, { useCallback, useState } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import Input from './input' | |||
| import Button from '@/app/components/base/button' | |||
| const I18N_PREFIX = 'datasetCreation.stepOne.website' | |||
| type Props = { | |||
| isRunning: boolean | |||
| onRun: (url: string) => void | |||
| } | |||
| const UrlInput: FC<Props> = ({ | |||
| isRunning, | |||
| onRun, | |||
| }) => { | |||
| const { t } = useTranslation() | |||
| const [url, setUrl] = useState('') | |||
| const handleUrlChange = useCallback((url: string | number) => { | |||
| setUrl(url as string) | |||
| }, []) | |||
| const handleOnRun = useCallback(() => { | |||
| if (isRunning) | |||
| return | |||
| onRun(url) | |||
| }, [isRunning, onRun, url]) | |||
| return ( | |||
| <div className='flex items-center justify-between'> | |||
| <Input | |||
| value={url} | |||
| onChange={handleUrlChange} | |||
| placeholder='https://docs.dify.ai' | |||
| /> | |||
| <Button | |||
| variant='primary' | |||
| onClick={handleOnRun} | |||
| className='ml-2' | |||
| loading={isRunning} | |||
| > | |||
| {!isRunning ? t(`${I18N_PREFIX}.run`) : ''} | |||
| </Button> | |||
| </div> | |||
| ) | |||
| } | |||
| export default React.memo(UrlInput) | |||
| @@ -0,0 +1,40 @@ | |||
| 'use client' | |||
| import type { FC } from 'react' | |||
| import React, { useCallback } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import cn from '@/utils/classnames' | |||
| import type { CrawlResultItem as CrawlResultItemType } from '@/models/datasets' | |||
| import Checkbox from '@/app/components/base/checkbox' | |||
| type Props = { | |||
| payload: CrawlResultItemType | |||
| isChecked: boolean | |||
| isPreview: boolean | |||
| onCheckChange: (checked: boolean) => void | |||
| onPreview: () => void | |||
| } | |||
| const CrawledResultItem: FC<Props> = ({ | |||
| isPreview, | |||
| payload, | |||
| isChecked, | |||
| onCheckChange, | |||
| onPreview, | |||
| }) => { | |||
| const { t } = useTranslation() | |||
| const handleCheckChange = useCallback(() => { | |||
| onCheckChange(!isChecked) | |||
| }, [isChecked, onCheckChange]) | |||
| return ( | |||
| <div className={cn(isPreview ? 'border-[#D1E0FF] bg-primary-50 shadow-xs' : 'group hover:bg-gray-100', 'rounded-md px-2 py-[5px] cursor-pointer border border-transparent')}> | |||
| <div className='flex items-center h-5'> | |||
| <Checkbox className='group-hover:border-2 group-hover:border-primary-600 mr-2 shrink-0' checked={isChecked} onCheck={handleCheckChange} /> | |||
| <div className='grow w-0 truncate text-sm font-medium text-gray-700' title={payload.title}>{payload.title}</div> | |||
| <div onClick={onPreview} className='hidden group-hover:flex items-center h-6 px-2 text-xs rounded-md font-medium text-gray-500 uppercase hover:bg-gray-50'>{t('datasetCreation.stepOne.website.preview')}</div> | |||
| </div> | |||
| <div className='mt-0.5 truncate pl-6 leading-[18px] text-xs font-normal text-gray-500' title={payload.source_url}>{payload.source_url}</div> | |||
| </div> | |||
| ) | |||
| } | |||
| export default React.memo(CrawledResultItem) | |||
| @@ -0,0 +1,87 @@ | |||
| 'use client' | |||
| import type { FC } from 'react' | |||
| import React, { useCallback } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import CheckboxWithLabel from './base/checkbox-with-label' | |||
| import CrawledResultItem from './crawled-result-item' | |||
| import cn from '@/utils/classnames' | |||
| import type { CrawlResultItem } from '@/models/datasets' | |||
| const I18N_PREFIX = 'datasetCreation.stepOne.website' | |||
| type Props = { | |||
| className?: string | |||
| list: CrawlResultItem[] | |||
| checkedList: CrawlResultItem[] | |||
| onSelectedChange: (selected: CrawlResultItem[]) => void | |||
| onPreview: (payload: CrawlResultItem) => void | |||
| usedTime: number | |||
| } | |||
| const CrawledResult: FC<Props> = ({ | |||
| className = '', | |||
| list, | |||
| checkedList, | |||
| onSelectedChange, | |||
| onPreview, | |||
| usedTime, | |||
| }) => { | |||
| const { t } = useTranslation() | |||
| const isCheckAll = checkedList.length === list.length | |||
| const handleCheckedAll = useCallback(() => { | |||
| if (!isCheckAll) | |||
| onSelectedChange(list) | |||
| else | |||
| onSelectedChange([]) | |||
| }, [isCheckAll, list, onSelectedChange]) | |||
| const handleItemCheckChange = useCallback((item: CrawlResultItem) => { | |||
| return (checked: boolean) => { | |||
| if (checked) | |||
| onSelectedChange([...checkedList, item]) | |||
| else | |||
| onSelectedChange(checkedList.filter(checkedItem => checkedItem.source_url !== item.source_url)) | |||
| } | |||
| }, [checkedList, onSelectedChange]) | |||
| const [previewIndex, setPreviewIndex] = React.useState<number>(-1) | |||
| const handlePreview = useCallback((index: number) => { | |||
| return () => { | |||
| setPreviewIndex(index) | |||
| onPreview(list[index]) | |||
| } | |||
| }, [list, onPreview]) | |||
| return ( | |||
| <div className={cn(className, 'border-t border-gray-200')}> | |||
| <div className='flex items-center justify-between h-[34px] px-4 bg-gray-50 shadow-xs border-b-[0.5px] border-black/8 text-xs font-normal text-gray-700'> | |||
| <CheckboxWithLabel | |||
| isChecked={isCheckAll} | |||
| onChange={handleCheckedAll} label={isCheckAll ? t(`${I18N_PREFIX}.resetAll`) : t(`${I18N_PREFIX}.selectAll`)} | |||
| labelClassName='!font-medium' | |||
| /> | |||
| <div>{t(`${I18N_PREFIX}.scrapTimeInfo`, { | |||
| total: list.length, | |||
| time: usedTime.toFixed(1), | |||
| })}</div> | |||
| </div> | |||
| <div className='p-2'> | |||
| {list.map((item, index) => ( | |||
| <CrawledResultItem | |||
| key={item.source_url} | |||
| isPreview={index === previewIndex} | |||
| onPreview={handlePreview(index)} | |||
| payload={item} | |||
| isChecked={checkedList.some(checkedItem => checkedItem.source_url === item.source_url)} | |||
| onCheckChange={handleItemCheckChange(item)} | |||
| /> | |||
| ))} | |||
| </div> | |||
| </div> | |||
| ) | |||
| } | |||
| export default React.memo(CrawledResult) | |||
| @@ -0,0 +1,37 @@ | |||
| 'use client' | |||
| import type { FC } from 'react' | |||
| import React from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import cn from '@/utils/classnames' | |||
| import { RowStruct } from '@/app/components/base/icons/src/public/other' | |||
| type Props = { | |||
| className?: string | |||
| crawledNum: number | |||
| totalNum: number | |||
| } | |||
| const Crawling: FC<Props> = ({ | |||
| className = '', | |||
| crawledNum, | |||
| totalNum, | |||
| }) => { | |||
| const { t } = useTranslation() | |||
| return ( | |||
| <div className={cn(className, 'border-t border-gray-200')}> | |||
| <div className='flex items-center h-[34px] px-4 bg-gray-50 shadow-xs border-b-[0.5px] border-black/8 text-xs font-normal text-gray-700'> | |||
| {t('datasetCreation.stepOne.website.totalPageScraped')} {crawledNum}/{totalNum} | |||
| </div> | |||
| <div className='p-2'> | |||
| {['', '', '', ''].map((item, index) => ( | |||
| <div className='py-[5px]' key={index}> | |||
| <RowStruct /> | |||
| </div> | |||
| ))} | |||
| </div> | |||
| </div> | |||
| ) | |||
| } | |||
| export default React.memo(Crawling) | |||
| @@ -0,0 +1,24 @@ | |||
| import type { CrawlResultItem } from '@/models/datasets' | |||
| const result: CrawlResultItem[] = [ | |||
| { | |||
| title: 'Start the frontend Docker container separately', | |||
| markdown: 'Markdown 1', | |||
| description: 'Description 1', | |||
| source_url: 'https://example.com/1', | |||
| }, | |||
| { | |||
| title: 'Advanced Tool Integration', | |||
| markdown: 'Markdown 2', | |||
| description: 'Description 2', | |||
| source_url: 'https://example.com/2', | |||
| }, | |||
| { | |||
| title: 'Local Source Code Start | English | Dify', | |||
| markdown: 'Markdown 3', | |||
| description: 'Description 3', | |||
| source_url: 'https://example.com/3', | |||
| }, | |||
| ] | |||
| export default result | |||