# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60
+ # Refresh token expiration time in days
+ REFRESH_TOKEN_EXPIRE_DAYS=30

# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1

    app = create_migrations_app()
else:
-     if os.environ.get("FLASK_DEBUG", "False") != "True":
+     # It seems that JetBrains Python debugger does not work well with gevent,
+     # so we need to disable gevent in debug mode.
+     # If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
+     if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
        from gevent import monkey  # type: ignore

        # gevent
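The reworked check above patches gevent only when FLASK_DEBUG is unset or spelled as a "false-ish" value, which is a wider net than the old exact comparison against "True". A minimal sketch of the same predicate for experimenting with different values (the helper name and the env-dict parameter are ours, not part of the diff):

def gevent_patching_enabled(env: dict[str, str]) -> bool:
    # Mirrors the new condition: patch only for unset or "false"/"0"/"no" (any case).
    flask_debug = env.get("FLASK_DEBUG", "0")
    return bool(flask_debug) and flask_debug.lower() in {"false", "0", "no"}


if __name__ == "__main__":
    assert gevent_patching_enabled({}) is True                      # unset -> default "0" -> patch
    assert gevent_patching_enabled({"FLASK_DEBUG": "False"}) is True # false-ish -> patch
    assert gevent_patching_enabled({"FLASK_DEBUG": "1"}) is False    # debug on -> no patching
    assert gevent_patching_enabled({"FLASK_DEBUG": ""}) is False     # empty string -> no patching

Note that "1" disables patching under the new check, whereas the old `!= "True"` comparison would still have patched it.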
        default=60,
    )
+     REFRESH_TOKEN_EXPIRE_DAYS: PositiveFloat = Field(
+         description="Expiration time for refresh tokens in days",
+         default=30,
+     )
    LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
        description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
        default=86400,

        default=4000,
    )
+     CHILD_CHUNKS_PREVIEW_NUMBER: PositiveInt = Field(
+         description="Maximum number of child chunks to preview",
+         default=50,
+     )

class MultiModalTransferConfig(BaseSettings):
    MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
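The new REFRESH_TOKEN_EXPIRE_DAYS field mirrors the .env entry added earlier and sits next to the existing minute-based access token setting. A minimal sketch of how the two units translate into expiry timestamps (the helper and the local constants are ours; in the application the values would come from the settings object, e.g. dify_config):

from datetime import UTC, datetime, timedelta

# Values mirror the defaults above, for illustration only.
ACCESS_TOKEN_EXPIRE_MINUTES = 60
REFRESH_TOKEN_EXPIRE_DAYS = 30.0  # PositiveFloat, so fractional days are allowed


def token_expiries(now: datetime | None = None) -> tuple[datetime, datetime]:
    """Return (access_token_expires_at, refresh_token_expires_at)."""
    now = now or datetime.now(UTC)
    access_expires_at = now + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    refresh_expires_at = now + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    return access_expires_at, refresh_expires_at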
| description="Name of the Milvus database to connect to (default is 'default')", | description="Name of the Milvus database to connect to (default is 'default')", | ||||
| default="default", | default="default", | ||||
| ) | ) | ||||
| MILVUS_ENABLE_HYBRID_SEARCH: bool = Field( | |||||
| description="Enable hybrid search features (requires Milvus >= 2.5.0). Set to false for compatibility with " | |||||
| "older versions", | |||||
| default=True, | |||||
| ) | 
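The field description ties hybrid search to Milvus >= 2.5.0. A small sketch of how a caller might gate the feature on both the flag and a reported server version string (the helper and the version-parsing approach are ours, not the project's code):

def hybrid_search_available(enable_hybrid_search: bool, server_version: str) -> bool:
    """True only when the config flag is on and the server reports >= 2.5.0."""
    if not enable_hybrid_search:
        return False
    parts = server_version.lstrip("v").split(".")
    try:
        major, minor = int(parts[0]), int(parts[1])
    except (IndexError, ValueError):
        return False  # unknown version string: stay on the compatible path
    return (major, minor) >= (2, 5)


assert hybrid_search_available(True, "v2.5.1")
assert not hybrid_search_available(True, "2.4.9")
assert not hybrid_search_available(False, "v2.5.1")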
    CURRENT_VERSION: str = Field(
        description="Dify version",
-         default="0.14.2",
+         default="0.15.0",
    )

    COMMIT_SHA: str = Field(
        )
        parser.add_argument("name", type=str, location="args", required=False)
        parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
+         parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
        args = parser.parse_args()

        # get app list
        app_service = AppService()
-         app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
+         app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
        if not app_pagination:
            return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
-     AppInvokeQuotaExceededError,
    ModelCurrentlyNotSupportError,
    ProviderTokenNotInitError,
    QuotaExceededError,

            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
-         except (ValueError, AppInvokeQuotaExceededError) as e:
+         except ValueError as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")

            raise InvokeRateLimitHttpError(ex.description)
        except InvokeError as e:
            raise CompletionRequestError(e.description)
-         except (ValueError, AppInvokeQuotaExceededError) as e:
+         except ValueError as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
            messages m
            ON c.id = m.conversation_id
        WHERE
-             c.override_model_configs IS NULL
-             AND c.app_id = :app_id"""
+             c.app_id = :app_id"""
        arg_dict = {"tz": account.timezone, "app_id": app_model.id}
        timezone = pytz.timezone(account.timezone)
                | VectorType.MYSCALE
                | VectorType.ORACLE
                | VectorType.ELASTICSEARCH
+                 | VectorType.ELASTICSEARCH_JA
                | VectorType.PGVECTOR
                | VectorType.TIDB_ON_QDRANT
                | VectorType.LINDORM

                | VectorType.MYSCALE
                | VectorType.ORACLE
                | VectorType.ELASTICSEARCH
+                 | VectorType.ELASTICSEARCH_JA
                | VectorType.COUCHBASE
                | VectorType.PGVECTOR
                | VectorType.LINDORM
| parser.add_argument("original_document_id", type=str, required=False, location="json") | parser.add_argument("original_document_id", type=str, required=False, location="json") | ||||
| parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") | parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") | ||||
| parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") | parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") | ||||
| parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") | |||||
| parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") | |||||
| parser.add_argument( | parser.add_argument( | ||||
| "doc_language", type=str, default="English", required=False, nullable=False, location="json" | "doc_language", type=str, default="English", required=False, nullable=False, location="json" | ||||
| ) | ) | 
from controllers.console.explore.wraps import InstalledAppResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
- from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+ from core.errors.error import (
+     ModelCurrentlyNotSupportError,
+     ProviderTokenNotInitError,
+     QuotaExceededError,
+ )
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs import helper
from controllers.console.explore.wraps import InstalledAppResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
- from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+ from core.errors.error import (
+     ModelCurrentlyNotSupportError,
+     ProviderTokenNotInitError,
+     QuotaExceededError,
+ )
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.login import current_user
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
-     AppInvokeQuotaExceededError,
    ModelCurrentlyNotSupportError,
    ProviderTokenNotInitError,
    QuotaExceededError,

            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
-         except (ValueError, AppInvokeQuotaExceededError) as e:
+         except ValueError as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")

            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
-         except (ValueError, AppInvokeQuotaExceededError) as e:
+         except ValueError as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
-     AppInvokeQuotaExceededError,
    ModelCurrentlyNotSupportError,
    ProviderTokenNotInitError,
    QuotaExceededError,

            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
-         except (ValueError, AppInvokeQuotaExceededError) as e:
+         except ValueError as e:
            raise e
        except Exception as e:
            logging.exception("internal server error.")
            user=current_user,
            source="datasets",
        )
-         data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
+         data_source = {
+             "type": "upload_file",
+             "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
+         }
        args["data_source"] = data_source
        # validate args
        knowledge_config = KnowledgeConfig(**args)

            raise FileTooLargeError(file_too_large_error.description)
        except services.errors.file.UnsupportedFileTypeError:
            raise UnsupportedFileTypeError()
-         data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
+         data_source = {
+             "type": "upload_file",
+             "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
+         }
        args["data_source"] = data_source
        # validate args
        args["original_document_id"] = str(document_id)
from collections.abc import Callable
- from datetime import UTC, datetime
+ from datetime import UTC, datetime, timedelta
from enum import Enum
from functools import wraps
from typing import Optional

from flask_login import user_logged_in  # type: ignore
from flask_restful import Resource  # type: ignore
from pydantic import BaseModel
+ from sqlalchemy import select, update
+ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, Unauthorized

from extensions.ext_database import db

    return decorator


- def validate_and_get_api_token(scope=None):
+ def validate_and_get_api_token(scope: str | None = None):
    """
    Validate and get API token.
    """

    if auth_scheme != "bearer":
        raise Unauthorized("Authorization scheme must be 'Bearer'")

-     api_token = (
-         db.session.query(ApiToken)
-         .filter(
-             ApiToken.token == auth_token,
-             ApiToken.type == scope,
+     current_time = datetime.now(UTC).replace(tzinfo=None)
+     cutoff_time = current_time - timedelta(minutes=1)
+     with Session(db.engine, expire_on_commit=False) as session:
+         update_stmt = (
+             update(ApiToken)
+             .where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
+             .values(last_used_at=current_time)
+             .returning(ApiToken)
        )
-         .first()
-     )
-     if not api_token:
-         raise Unauthorized("Access token is invalid")
-     api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
-     db.session.commit()
+         result = session.execute(update_stmt)
+         api_token = result.scalar_one_or_none()
+         if not api_token:
+             stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
+             api_token = session.scalar(stmt)
+             if not api_token:
+                 raise Unauthorized("Access token is invalid")
+         else:
+             session.commit()

    return api_token
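The rewritten validate_and_get_api_token replaces a read-then-write on every request with a single conditional UPDATE ... RETURNING that only touches last_used_at when the stored value is more than a minute old, falling back to a plain SELECT when nothing was updated. A stripped-down sketch of the same pattern against a throwaway SQLite table (table and column names are ours; RETURNING on SQLite needs SQLAlchemy 2.x and SQLite >= 3.35, while the project runs the hunk above against its own engine):

from datetime import UTC, datetime, timedelta

from sqlalchemy import Column, DateTime, MetaData, String, Table, create_engine, select, update
from sqlalchemy.orm import Session

metadata = MetaData()
api_tokens = Table(
    "api_tokens",
    metadata,
    Column("token", String, primary_key=True),
    Column("last_used_at", DateTime),
)
engine = create_engine("sqlite+pysqlite:///:memory:")
metadata.create_all(engine)


def touch_and_fetch(token: str):
    now = datetime.now(UTC).replace(tzinfo=None)
    cutoff = now - timedelta(minutes=1)
    with Session(engine, expire_on_commit=False) as session:
        # Update last_used_at only if it is stale; RETURNING hands back the row in the same round trip.
        stmt = (
            update(api_tokens)
            .where(api_tokens.c.token == token, api_tokens.c.last_used_at < cutoff)
            .values(last_used_at=now)
            .returning(api_tokens.c.token, api_tokens.c.last_used_at)
        )
        row = session.execute(stmt).first()
        if row is None:
            # Either the token does not exist or it was already touched within the last minute.
            row = session.execute(select(api_tokens).where(api_tokens.c.token == token)).first()
            if row is None:
                raise LookupError("token is invalid")
        else:
            session.commit()
        return row


with Session(engine) as s:
    s.execute(api_tokens.insert().values(token="t1", last_used_at=datetime(2000, 1, 1)))
    s.commit()

print(touch_and_fetch("t1"))  # stale -> updated and returned
print(touch_and_fetch("t1"))  # fresh -> fetched without another write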
from controllers.web.wraps import WebApiResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
- from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+ from core.errors.error import (
+     ModelCurrentlyNotSupportError,
+     ProviderTokenNotInitError,
+     QuotaExceededError,
+ )
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
from controllers.web.wraps import WebApiResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
- from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+ from core.errors.error import (
+     ModelCurrentlyNotSupportError,
+     ProviderTokenNotInitError,
+     QuotaExceededError,
+ )
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from models.model import App, AppMode, EndUser
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
- from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from extensions.ext_database import db

        except ValidationError as e:
            logger.exception("Validation Error when generating")
            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
-         except (ValueError, InvokeError) as e:
+         except ValueError as e:
            if dify_config.DEBUG:
                logger.exception("Error when generating")
            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
from models.enums import CreatedByRole
from models.workflow import (
    Workflow,
-     WorkflowNodeExecution,
    WorkflowRunStatus,
)

logger = logging.getLogger(__name__)


- class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage):
+ class AdvancedChatAppGenerateTaskPipeline:
    """
    AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
    """

-     _task_state: WorkflowTaskState
-     _application_generate_entity: AdvancedChatAppGenerateEntity
-     _workflow_system_variables: dict[SystemVariableKey, Any]
-     _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
-     _conversation_name_generate_thread: Optional[Thread] = None

    def __init__(
        self,
        application_generate_entity: AdvancedChatAppGenerateEntity,
        stream: bool,
        dialogue_count: int,
    ) -> None:
-         super().__init__(
+         self._base_task_pipeline = BasedGenerateTaskPipeline(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            stream=stream,

        else:
            raise NotImplementedError(f"User type not supported: {type(user)}")

+         self._workflow_cycle_manager = WorkflowCycleManage(
+             application_generate_entity=application_generate_entity,
+             workflow_system_variables={
+                 SystemVariableKey.QUERY: message.query,
+                 SystemVariableKey.FILES: application_generate_entity.files,
+                 SystemVariableKey.CONVERSATION_ID: conversation.id,
+                 SystemVariableKey.USER_ID: user_session_id,
+                 SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
+                 SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
+                 SystemVariableKey.WORKFLOW_ID: workflow.id,
+                 SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
+             },
+         )
+         self._task_state = WorkflowTaskState()
+         self._message_cycle_manager = MessageCycleManage(
+             application_generate_entity=application_generate_entity, task_state=self._task_state
+         )
+         self._application_generate_entity = application_generate_entity
        self._workflow_id = workflow.id
        self._workflow_features_dict = workflow.features_dict
        self._conversation_id = conversation.id
        self._conversation_mode = conversation.mode
        self._message_id = message.id
        self._message_created_at = int(message.created_at.timestamp())
-         self._workflow_system_variables = {
-             SystemVariableKey.QUERY: message.query,
-             SystemVariableKey.FILES: application_generate_entity.files,
-             SystemVariableKey.CONVERSATION_ID: conversation.id,
-             SystemVariableKey.USER_ID: user_session_id,
-             SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
-             SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
-             SystemVariableKey.WORKFLOW_ID: workflow.id,
-             SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
-         }
-         self._task_state = WorkflowTaskState()
-         self._wip_workflow_node_executions = {}
-         self._wip_workflow_agent_logs = {}
-         self._conversation_name_generate_thread = None
+         self._conversation_name_generate_thread: Thread | None = None
        self._recorded_files: list[Mapping[str, Any]] = []
-         self._workflow_run_id = ""
+         self._workflow_run_id: str = ""
    def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
        """
        :return:
        """
        # start generate conversation name thread
-         self._conversation_name_generate_thread = self._generate_conversation_name(
+         self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
            conversation_id=self._conversation_id, query=self._application_generate_entity.query
        )

        generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
-         if self._stream:
+         if self._base_task_pipeline._stream:
            return self._to_stream_response(generator)
        else:
            return self._to_blocking_response(generator)

        # init fake graph runtime state
        graph_runtime_state: Optional[GraphRuntimeState] = None

-         for queue_message in self._queue_manager.listen():
+         for queue_message in self._base_task_pipeline._queue_manager.listen():
            event = queue_message.event

            if isinstance(event, QueuePingEvent):
-                 yield self._ping_stream_response()
+                 yield self._base_task_pipeline._ping_stream_response()
            elif isinstance(event, QueueErrorEvent):
-                 with Session(db.engine) as session:
-                     err = self._handle_error(event=event, session=session, message_id=self._message_id)
+                 with Session(db.engine, expire_on_commit=False) as session:
+                     err = self._base_task_pipeline._handle_error(
+                         event=event, session=session, message_id=self._message_id
+                     )
                    session.commit()
-                 yield self._error_to_stream_response(err)
+                 yield self._base_task_pipeline._error_to_stream_response(err)
                break
            elif isinstance(event, QueueWorkflowStartedEvent):
                # override graph runtime state
                graph_runtime_state = event.graph_runtime_state
-                 with Session(db.engine) as session:
+                 with Session(db.engine, expire_on_commit=False) as session:
                    # init workflow run
-                     workflow_run = self._handle_workflow_run_start(
+                     workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
                        session=session,
                        workflow_id=self._workflow_id,
                        user_id=self._user_id,

                    if not message:
                        raise ValueError(f"Message not found: {self._message_id}")
                    message.workflow_run_id = workflow_run.id
-                     workflow_start_resp = self._workflow_start_to_stream_response(
+                     workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                    )
                    session.commit()
                if not self._workflow_run_id:
                    raise ValueError("workflow run not initialized.")

-                 with Session(db.engine) as session:
-                     workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                     workflow_node_execution = self._handle_workflow_node_execution_retried(
+                 with Session(db.engine, expire_on_commit=False) as session:
+                     workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                         session=session, workflow_run_id=self._workflow_run_id
+                     )
+                     workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
                        session=session, workflow_run=workflow_run, event=event
                    )
-                     node_retry_resp = self._workflow_node_retry_to_stream_response(
+                     node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
                        session=session,
                        event=event,
                        task_id=self._application_generate_entity.task_id,

                if not self._workflow_run_id:
                    raise ValueError("workflow run not initialized.")

-                 with Session(db.engine) as session:
-                     workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                     workflow_node_execution = self._handle_node_execution_start(
+                 with Session(db.engine, expire_on_commit=False) as session:
+                     workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                         session=session, workflow_run_id=self._workflow_run_id
+                     )
+                     workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
                        session=session, workflow_run=workflow_run, event=event
                    )
-                     node_start_resp = self._workflow_node_start_to_stream_response(
+                     node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
                        session=session,
                        event=event,
                        task_id=self._application_generate_entity.task_id,

            elif isinstance(event, QueueNodeSucceededEvent):
                # Record files if it's an answer node or end node
                if event.node_type in [NodeType.ANSWER, NodeType.END]:
-                     self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
+                     self._recorded_files.extend(
+                         self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {})
+                     )
-                 with Session(db.engine) as session:
-                     workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
+                 with Session(db.engine, expire_on_commit=False) as session:
+                     workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
+                         session=session, event=event
+                     )
-                     node_finish_resp = self._workflow_node_finish_to_stream_response(
+                     node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
                        session=session,
                        event=event,
                        task_id=self._application_generate_entity.task_id,

                if node_finish_resp:
                    yield node_finish_resp
            elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
-                 with Session(db.engine) as session:
-                     workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event)
+                 with Session(db.engine, expire_on_commit=False) as session:
+                     workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
+                         session=session, event=event
+                     )
-                     node_finish_resp = self._workflow_node_finish_to_stream_response(
+                     node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
                        session=session,
                        event=event,
                        task_id=self._application_generate_entity.task_id,
                if not self._workflow_run_id:
                    raise ValueError("workflow run not initialized.")

-                 with Session(db.engine) as session:
-                     workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                     parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
-                         session=session,
-                         task_id=self._application_generate_entity.task_id,
-                         workflow_run=workflow_run,
-                         event=event,
+                 with Session(db.engine, expire_on_commit=False) as session:
+                     workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                         session=session, workflow_run_id=self._workflow_run_id
+                     )
+                     parallel_start_resp = (
+                         self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
+                             session=session,
+                             task_id=self._application_generate_entity.task_id,
+                             workflow_run=workflow_run,
+                             event=event,
+                         )
                    )

                yield parallel_start_resp

                if not self._workflow_run_id:
                    raise ValueError("workflow run not initialized.")

-                 with Session(db.engine) as session:
-                     workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                     parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
-                         session=session,
-                         task_id=self._application_generate_entity.task_id,
-                         workflow_run=workflow_run,
-                         event=event,
+                 with Session(db.engine, expire_on_commit=False) as session:
+                     workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                         session=session, workflow_run_id=self._workflow_run_id
+                     )
+                     parallel_finish_resp = (
+                         self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
+                             session=session,
+                             task_id=self._application_generate_entity.task_id,
+                             workflow_run=workflow_run,
+                             event=event,
+                         )
                    )

                yield parallel_finish_resp

                if not self._workflow_run_id:
                    raise ValueError("workflow run not initialized.")

-                 with Session(db.engine) as session:
-                     workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                     iter_start_resp = self._workflow_iteration_start_to_stream_response(
+                 with Session(db.engine, expire_on_commit=False) as session:
+                     workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                         session=session, workflow_run_id=self._workflow_run_id
+                     )
+                     iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
                        session=session,
                        task_id=self._application_generate_entity.task_id,
                        workflow_run=workflow_run,
                if not self._workflow_run_id:
                    raise ValueError("workflow run not initialized.")

-                 with Session(db.engine) as session:
-                     workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                     iter_next_resp = self._workflow_iteration_next_to_stream_response(
+                 with Session(db.engine, expire_on_commit=False) as session:
+                     workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                         session=session, workflow_run_id=self._workflow_run_id
+                     )
+                     iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
                        session=session,
                        task_id=self._application_generate_entity.task_id,
                        workflow_run=workflow_run,

                if not self._workflow_run_id:
                    raise ValueError("workflow run not initialized.")

-                 with Session(db.engine) as session:
-                     workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-                     iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
+                 with Session(db.engine, expire_on_commit=False) as session:
+                     workflow_run = self._workflow_cycle_manager._get_workflow_run(
+                         session=session, workflow_run_id=self._workflow_run_id
+                     )
+                     iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
                        session=session,
                        task_id=self._application_generate_entity.task_id,
                        workflow_run=workflow_run,

                if not graph_runtime_state:
                    raise ValueError("workflow run not initialized.")

-                 with Session(db.engine) as session:
-                     workflow_run = self._handle_workflow_run_success(
+                 with Session(db.engine, expire_on_commit=False) as session:
+                     workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
                        session=session,
                        workflow_run_id=self._workflow_run_id,
                        start_at=graph_runtime_state.start_at,
                        trace_manager=trace_manager,
                    )
-                     workflow_finish_resp = self._workflow_finish_to_stream_response(
+                     workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                    )
                    session.commit()

                yield workflow_finish_resp
-                 self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
+                 self._base_task_pipeline._queue_manager.publish(
+                     QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
+                 )
            elif isinstance(event, QueueWorkflowPartialSuccessEvent):
                if not self._workflow_run_id:
                    raise ValueError("workflow run not initialized.")
                if not graph_runtime_state:
                    raise ValueError("graph runtime state not initialized.")

-                 with Session(db.engine) as session:
-                     workflow_run = self._handle_workflow_run_partial_success(
+                 with Session(db.engine, expire_on_commit=False) as session:
+                     workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
                        session=session,
                        workflow_run_id=self._workflow_run_id,
                        start_at=graph_runtime_state.start_at,
                        conversation_id=None,
                        trace_manager=trace_manager,
                    )
-                     workflow_finish_resp = self._workflow_finish_to_stream_response(
+                     workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                    )
                    session.commit()

                yield workflow_finish_resp
-                 self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
+                 self._base_task_pipeline._queue_manager.publish(
+                     QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
+                 )
            elif isinstance(event, QueueWorkflowFailedEvent):
                if not self._workflow_run_id:
                    raise ValueError("workflow run not initialized.")
                if not graph_runtime_state:
                    raise ValueError("graph runtime state not initialized.")

-                 with Session(db.engine) as session:
-                     workflow_run = self._handle_workflow_run_failed(
+                 with Session(db.engine, expire_on_commit=False) as session:
+                     workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
                        session=session,
                        workflow_run_id=self._workflow_run_id,
                        start_at=graph_runtime_state.start_at,
                        trace_manager=trace_manager,
                        exceptions_count=event.exceptions_count,
                    )
-                     workflow_finish_resp = self._workflow_finish_to_stream_response(
+                     workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                        session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                    )
                    err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
-                     err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
+                     err = self._base_task_pipeline._handle_error(
+                         event=err_event, session=session, message_id=self._message_id
+                     )
                    session.commit()

                yield workflow_finish_resp
-                 yield self._error_to_stream_response(err)
+                 yield self._base_task_pipeline._error_to_stream_response(err)
                break
            elif isinstance(event, QueueStopEvent):
                if self._workflow_run_id and graph_runtime_state:
-                     with Session(db.engine) as session:
-                         workflow_run = self._handle_workflow_run_failed(
+                     with Session(db.engine, expire_on_commit=False) as session:
+                         workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
                            session=session,
                            workflow_run_id=self._workflow_run_id,
                            start_at=graph_runtime_state.start_at,
                            conversation_id=self._conversation_id,
                            trace_manager=trace_manager,
                        )
-                         workflow_finish_resp = self._workflow_finish_to_stream_response(
+                         workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
                            session=session,
                            task_id=self._application_generate_entity.task_id,
                            workflow_run=workflow_run,
                yield self._message_end_to_stream_response()
                break
            elif isinstance(event, QueueRetrieverResourcesEvent):
-                 self._handle_retriever_resources(event)
+                 self._message_cycle_manager._handle_retriever_resources(event)

-                 with Session(db.engine) as session:
+                 with Session(db.engine, expire_on_commit=False) as session:
                    message = self._get_message(session=session)
                    message.message_metadata = (
                        json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
                    )
                    session.commit()
            elif isinstance(event, QueueAnnotationReplyEvent):
-                 self._handle_annotation_reply(event)
+                 self._message_cycle_manager._handle_annotation_reply(event)

-                 with Session(db.engine) as session:
+                 with Session(db.engine, expire_on_commit=False) as session:
                    message = self._get_message(session=session)
                    message.message_metadata = (
                        json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None

                    tts_publisher.publish(queue_message)

                self._task_state.answer += delta_text
-                 yield self._message_to_stream_response(
+                 yield self._message_cycle_manager._message_to_stream_response(
                    answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
                )
            elif isinstance(event, QueueMessageReplaceEvent):
                # published by moderation
-                 yield self._message_replace_to_stream_response(answer=event.text)
+                 yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text)
            elif isinstance(event, QueueAdvancedChatMessageEndEvent):
                if not graph_runtime_state:
                    raise ValueError("graph runtime state not initialized.")

-                 output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
+                 output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
+                     self._task_state.answer
+                 )
                if output_moderation_answer:
                    self._task_state.answer = output_moderation_answer
-                     yield self._message_replace_to_stream_response(answer=output_moderation_answer)
+                     yield self._message_cycle_manager._message_replace_to_stream_response(
+                         answer=output_moderation_answer
+                     )

                # Save message
-                 with Session(db.engine) as session:
+                 with Session(db.engine, expire_on_commit=False) as session:
                    self._save_message(session=session, graph_runtime_state=graph_runtime_state)
                    session.commit()

    def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
        message = self._get_message(session=session)
        message.answer = self._task_state.answer
-         message.provider_response_latency = time.perf_counter() - self._start_at
+         message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
        message.message_metadata = (
            json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
        )

        :param text: text
        :return: True if output moderation should direct output, otherwise False
        """
-         if self._output_moderation_handler:
-             if self._output_moderation_handler.should_direct_output():
+         if self._base_task_pipeline._output_moderation_handler:
+             if self._base_task_pipeline._output_moderation_handler.should_direct_output():
                # stop subscribe new token when output moderation should direct output
-                 self._task_state.answer = self._output_moderation_handler.get_final_output()
-                 self._queue_manager.publish(
+                 self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
+                 self._base_task_pipeline._queue_manager.publish(
                    QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
                )
-                 self._queue_manager.publish(
+                 self._base_task_pipeline._queue_manager.publish(
                    QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
                )
                return True
            else:
-                 self._output_moderation_handler.append_new_token(text)
+                 self._base_task_pipeline._output_moderation_handler.append_new_token(text)

        return False
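Taken together, the hunks above move AdvancedChatAppGenerateTaskPipeline from multiple inheritance (BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage) to plain composition: the pipeline now owns _base_task_pipeline, _workflow_cycle_manager, and _message_cycle_manager objects and forwards to them explicitly. A toy sketch of the shape of that change (class and method names here are illustrative stand-ins, not the real collaborators):

class BasePipeline:
    def ping(self) -> str:
        return "ping"


class CycleManager:
    def start_run(self, run_id: str) -> str:
        return f"run {run_id} started"


# Before: behaviour mixed in via inheritance, every helper shares one `self`.
class TaskPipelineInherited(BasePipeline, CycleManager):
    def process(self, run_id: str) -> list[str]:
        return [self.ping(), self.start_run(run_id)]


# After: collaborators are owned and called through named attributes,
# so each helper keeps its own state and call sites show where behaviour lives.
class TaskPipelineComposed:
    def __init__(self) -> None:
        self._base_task_pipeline = BasePipeline()
        self._workflow_cycle_manager = CycleManager()

    def process(self, run_id: str) -> list[str]:
        return [
            self._base_task_pipeline.ping(),
            self._workflow_cycle_manager.start_run(run_id),
        ]


assert TaskPipelineInherited().process("r1") == TaskPipelineComposed().process("r1")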
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
- from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory

        except ValidationError as e:
            logger.exception("Validation Error when generating")
            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
-         except (ValueError, InvokeError) as e:
+         except ValueError as e:
            if dify_config.DEBUG:
                logger.exception("Error when generating")
            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
- from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory

        except ValidationError as e:
            logger.exception("Validation Error when generating")
            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
-         except (ValueError, InvokeError) as e:
+         except ValueError as e:
            if dify_config.DEBUG:
                logger.exception("Error when generating")
            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
- from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory

        except ValidationError as e:
            logger.exception("Validation Error when generating")
            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
-         except (ValueError, InvokeError) as e:
+         except ValueError as e:
            if dify_config.DEBUG:
                logger.exception("Error when generating")
            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
| from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline | from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline | ||||
| from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity | from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity | ||||
| from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse | from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse | ||||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||||
| from core.ops.ops_trace_manager import TraceQueueManager | from core.ops.ops_trace_manager import TraceQueueManager | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from factories import file_factory | from factories import file_factory | ||||
| single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( | single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( | ||||
| node_id=node_id, inputs=args["inputs"] | node_id=node_id, inputs=args["inputs"] | ||||
| ), | ), | ||||
| workflow_run_id=str(uuid.uuid4()), | |||||
| ) | ) | ||||
| contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | ||||
| contexts.plugin_tool_providers.set({}) | contexts.plugin_tool_providers.set({}) | ||||
| except ValidationError as e: | except ValidationError as e: | ||||
| logger.exception("Validation Error when generating") | logger.exception("Validation Error when generating") | ||||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | ||||
| except (ValueError, InvokeError) as e: | |||||
| except ValueError as e: | |||||
| if dify_config.DEBUG: | if dify_config.DEBUG: | ||||
| logger.exception("Error when generating") | logger.exception("Error when generating") | ||||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | 
| import logging | import logging | ||||
| import time | import time | ||||
| from collections.abc import Generator | from collections.abc import Generator | ||||
| from typing import Any, Optional, Union | |||||
| from typing import Optional, Union | |||||
| from sqlalchemy.orm import Session | from sqlalchemy.orm import Session | ||||
| Workflow, | Workflow, | ||||
| WorkflowAppLog, | WorkflowAppLog, | ||||
| WorkflowAppLogCreatedFrom, | WorkflowAppLogCreatedFrom, | ||||
| WorkflowNodeExecution, | |||||
| WorkflowRun, | WorkflowRun, | ||||
| WorkflowRunStatus, | WorkflowRunStatus, | ||||
| ) | ) | ||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage): | |||||
| class WorkflowAppGenerateTaskPipeline: | |||||
| """ | """ | ||||
| WorkflowAppGenerateTaskPipeline is a class that generates stream output and manages state for the application. | WorkflowAppGenerateTaskPipeline is a class that generates stream output and manages state for the application. | ||||
| """ | """ | ||||
| _task_state: WorkflowTaskState | |||||
| _application_generate_entity: WorkflowAppGenerateEntity | |||||
| _workflow_system_variables: dict[SystemVariableKey, Any] | |||||
| _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] | |||||
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| application_generate_entity: WorkflowAppGenerateEntity, | application_generate_entity: WorkflowAppGenerateEntity, | ||||
| user: Union[Account, EndUser], | user: Union[Account, EndUser], | ||||
| stream: bool, | stream: bool, | ||||
| ) -> None: | ) -> None: | ||||
| super().__init__( | |||||
| self._base_task_pipeline = BasedGenerateTaskPipeline( | |||||
| application_generate_entity=application_generate_entity, | application_generate_entity=application_generate_entity, | ||||
| queue_manager=queue_manager, | queue_manager=queue_manager, | ||||
| stream=stream, | stream=stream, | ||||
| else: | else: | ||||
| raise ValueError(f"Invalid user type: {type(user)}") | raise ValueError(f"Invalid user type: {type(user)}") | ||||
| self._workflow_cycle_manager = WorkflowCycleManage( | |||||
| application_generate_entity=application_generate_entity, | |||||
| workflow_system_variables={ | |||||
| SystemVariableKey.FILES: application_generate_entity.files, | |||||
| SystemVariableKey.USER_ID: user_session_id, | |||||
| SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, | |||||
| SystemVariableKey.WORKFLOW_ID: workflow.id, | |||||
| SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, | |||||
| }, | |||||
| ) | |||||
| self._application_generate_entity = application_generate_entity | |||||
| self._workflow_id = workflow.id | self._workflow_id = workflow.id | ||||
| self._workflow_features_dict = workflow.features_dict | self._workflow_features_dict = workflow.features_dict | ||||
| self._workflow_system_variables = { | |||||
| SystemVariableKey.FILES: application_generate_entity.files, | |||||
| SystemVariableKey.USER_ID: user_session_id, | |||||
| SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, | |||||
| SystemVariableKey.WORKFLOW_ID: workflow.id, | |||||
| SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, | |||||
| } | |||||
| self._task_state = WorkflowTaskState() | self._task_state = WorkflowTaskState() | ||||
| self._workflow_run_id = "" | self._workflow_run_id = "" | ||||
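This constructor hunk is the core of the refactor: WorkflowAppGenerateTaskPipeline no longer inherits from BasedGenerateTaskPipeline and WorkflowCycleManage but holds them as members (_base_task_pipeline, _workflow_cycle_manager) and delegates to them, as the later hunks show. A reduced sketch of that composition shape, using hypothetical stand-in classes rather than the real ones.

class BasePipelineStub:
    # stand-in for BasedGenerateTaskPipeline
    def __init__(self, stream: bool) -> None:
        self._stream = stream

    def _ping_stream_response(self) -> str:
        return "ping"


class CycleManagerStub:
    # stand-in for WorkflowCycleManage
    def _workflow_start_to_stream_response(self, task_id: str) -> str:
        return f"workflow_started:{task_id}"


class TaskPipelineStub:
    # composition instead of multiple inheritance
    def __init__(self, stream: bool) -> None:
        self._base_task_pipeline = BasePipelineStub(stream=stream)
        self._workflow_cycle_manager = CycleManagerStub()

    def process(self, task_id: str) -> list[str]:
        responses = [self._workflow_cycle_manager._workflow_start_to_stream_response(task_id)]
        if self._base_task_pipeline._stream:
            responses.append(self._base_task_pipeline._ping_stream_response())
        return responses


print(TaskPipelineStub(stream=True).process("task-1"))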
| :return: | :return: | ||||
| """ | """ | ||||
| generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) | generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) | ||||
| if self._stream: | |||||
| if self._base_task_pipeline._stream: | |||||
| return self._to_stream_response(generator) | return self._to_stream_response(generator) | ||||
| else: | else: | ||||
| return self._to_blocking_response(generator) | return self._to_blocking_response(generator) | ||||
| """ | """ | ||||
| graph_runtime_state = None | graph_runtime_state = None | ||||
| for queue_message in self._queue_manager.listen(): | |||||
| for queue_message in self._base_task_pipeline._queue_manager.listen(): | |||||
| event = queue_message.event | event = queue_message.event | ||||
| if isinstance(event, QueuePingEvent): | if isinstance(event, QueuePingEvent): | ||||
| yield self._ping_stream_response() | |||||
| yield self._base_task_pipeline._ping_stream_response() | |||||
| elif isinstance(event, QueueErrorEvent): | elif isinstance(event, QueueErrorEvent): | ||||
| err = self._handle_error(event=event) | |||||
| yield self._error_to_stream_response(err) | |||||
| err = self._base_task_pipeline._handle_error(event=event) | |||||
| yield self._base_task_pipeline._error_to_stream_response(err) | |||||
| break | break | ||||
| elif isinstance(event, QueueWorkflowStartedEvent): | elif isinstance(event, QueueWorkflowStartedEvent): | ||||
| # override graph runtime state | # override graph runtime state | ||||
| graph_runtime_state = event.graph_runtime_state | graph_runtime_state = event.graph_runtime_state | ||||
| with Session(db.engine) as session: | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| # init workflow run | # init workflow run | ||||
| workflow_run = self._handle_workflow_run_start( | |||||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_start( | |||||
| session=session, | session=session, | ||||
| workflow_id=self._workflow_id, | workflow_id=self._workflow_id, | ||||
| user_id=self._user_id, | user_id=self._user_id, | ||||
| created_by_role=self._created_by_role, | created_by_role=self._created_by_role, | ||||
| ) | ) | ||||
| self._workflow_run_id = workflow_run.id | self._workflow_run_id = workflow_run.id | ||||
| start_resp = self._workflow_start_to_stream_response( | |||||
| start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response( | |||||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | ||||
| ) | ) | ||||
| session.commit() | session.commit() | ||||
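The pipeline's sessions now pass expire_on_commit=False, so ORM objects such as workflow_run remain readable after session.commit() and after the with-block disposes of the session. A standalone sketch of what the flag changes, using a hypothetical stub model rather than Dify's WorkflowRun.

from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class WorkflowRunStub(Base):  # hypothetical stand-in model
    __tablename__ = "workflow_run_stub"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    status: Mapped[str] = mapped_column(String, default="running")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine, expire_on_commit=False) as session:
    run = WorkflowRunStub(id="run-1", status="succeeded")
    session.add(run)
    session.commit()

# With the default expire_on_commit=True this access would try to refresh the
# row on a closed session; with False the loaded value is still available.
print(run.status)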
| ): | ): | ||||
| if not self._workflow_run_id: | if not self._workflow_run_id: | ||||
| raise ValueError("workflow run not initialized.") | raise ValueError("workflow run not initialized.") | ||||
| with Session(db.engine) as session: | |||||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||||
| workflow_node_execution = self._handle_workflow_node_execution_retried( | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||||
| session=session, workflow_run_id=self._workflow_run_id | |||||
| ) | |||||
| workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( | |||||
| session=session, workflow_run=workflow_run, event=event | session=session, workflow_run=workflow_run, event=event | ||||
| ) | ) | ||||
| response = self._workflow_node_retry_to_stream_response( | |||||
| response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( | |||||
| session=session, | session=session, | ||||
| event=event, | event=event, | ||||
| task_id=self._application_generate_entity.task_id, | task_id=self._application_generate_entity.task_id, | ||||
| if not self._workflow_run_id: | if not self._workflow_run_id: | ||||
| raise ValueError("workflow run not initialized.") | raise ValueError("workflow run not initialized.") | ||||
| with Session(db.engine) as session: | |||||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||||
| workflow_node_execution = self._handle_node_execution_start( | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||||
| session=session, workflow_run_id=self._workflow_run_id | |||||
| ) | |||||
| workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( | |||||
| session=session, workflow_run=workflow_run, event=event | session=session, workflow_run=workflow_run, event=event | ||||
| ) | ) | ||||
| node_start_response = self._workflow_node_start_to_stream_response( | |||||
| node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response( | |||||
| session=session, | session=session, | ||||
| event=event, | event=event, | ||||
| task_id=self._application_generate_entity.task_id, | task_id=self._application_generate_entity.task_id, | ||||
| if node_start_response: | if node_start_response: | ||||
| yield node_start_response | yield node_start_response | ||||
| elif isinstance(event, QueueNodeSucceededEvent): | elif isinstance(event, QueueNodeSucceededEvent): | ||||
| with Session(db.engine) as session: | |||||
| workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event) | |||||
| node_success_response = self._workflow_node_finish_to_stream_response( | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( | |||||
| session=session, event=event | |||||
| ) | |||||
| node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( | |||||
| session=session, | session=session, | ||||
| event=event, | event=event, | ||||
| task_id=self._application_generate_entity.task_id, | task_id=self._application_generate_entity.task_id, | ||||
| if node_success_response: | if node_success_response: | ||||
| yield node_success_response | yield node_success_response | ||||
| elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): | elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): | ||||
| with Session(db.engine) as session: | |||||
| workflow_node_execution = self._handle_workflow_node_execution_failed( | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( | |||||
| session=session, | session=session, | ||||
| event=event, | event=event, | ||||
| ) | ) | ||||
| node_failed_response = self._workflow_node_finish_to_stream_response( | |||||
| node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( | |||||
| session=session, | session=session, | ||||
| event=event, | event=event, | ||||
| task_id=self._application_generate_entity.task_id, | task_id=self._application_generate_entity.task_id, | ||||
| if not self._workflow_run_id: | if not self._workflow_run_id: | ||||
| raise ValueError("workflow run not initialized.") | raise ValueError("workflow run not initialized.") | ||||
| with Session(db.engine) as session: | |||||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||||
| parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response( | |||||
| session=session, | |||||
| task_id=self._application_generate_entity.task_id, | |||||
| workflow_run=workflow_run, | |||||
| event=event, | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||||
| session=session, workflow_run_id=self._workflow_run_id | |||||
| ) | |||||
| parallel_start_resp = ( | |||||
| self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response( | |||||
| session=session, | |||||
| task_id=self._application_generate_entity.task_id, | |||||
| workflow_run=workflow_run, | |||||
| event=event, | |||||
| ) | |||||
| ) | ) | ||||
| yield parallel_start_resp | yield parallel_start_resp | ||||
| if not self._workflow_run_id: | if not self._workflow_run_id: | ||||
| raise ValueError("workflow run not initialized.") | raise ValueError("workflow run not initialized.") | ||||
| with Session(db.engine) as session: | |||||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||||
| parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response( | |||||
| session=session, | |||||
| task_id=self._application_generate_entity.task_id, | |||||
| workflow_run=workflow_run, | |||||
| event=event, | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||||
| session=session, workflow_run_id=self._workflow_run_id | |||||
| ) | |||||
| parallel_finish_resp = ( | |||||
| self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response( | |||||
| session=session, | |||||
| task_id=self._application_generate_entity.task_id, | |||||
| workflow_run=workflow_run, | |||||
| event=event, | |||||
| ) | |||||
| ) | ) | ||||
| yield parallel_finish_resp | yield parallel_finish_resp | ||||
| if not self._workflow_run_id: | if not self._workflow_run_id: | ||||
| raise ValueError("workflow run not initialized.") | raise ValueError("workflow run not initialized.") | ||||
| with Session(db.engine) as session: | |||||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||||
| iter_start_resp = self._workflow_iteration_start_to_stream_response( | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||||
| session=session, workflow_run_id=self._workflow_run_id | |||||
| ) | |||||
| iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response( | |||||
| session=session, | session=session, | ||||
| task_id=self._application_generate_entity.task_id, | task_id=self._application_generate_entity.task_id, | ||||
| workflow_run=workflow_run, | workflow_run=workflow_run, | ||||
| if not self._workflow_run_id: | if not self._workflow_run_id: | ||||
| raise ValueError("workflow run not initialized.") | raise ValueError("workflow run not initialized.") | ||||
| with Session(db.engine) as session: | |||||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||||
| iter_next_resp = self._workflow_iteration_next_to_stream_response( | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||||
| session=session, workflow_run_id=self._workflow_run_id | |||||
| ) | |||||
| iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response( | |||||
| session=session, | session=session, | ||||
| task_id=self._application_generate_entity.task_id, | task_id=self._application_generate_entity.task_id, | ||||
| workflow_run=workflow_run, | workflow_run=workflow_run, | ||||
| if not self._workflow_run_id: | if not self._workflow_run_id: | ||||
| raise ValueError("workflow run not initialized.") | raise ValueError("workflow run not initialized.") | ||||
| with Session(db.engine) as session: | |||||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||||
| iter_finish_resp = self._workflow_iteration_completed_to_stream_response( | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||||
| session=session, workflow_run_id=self._workflow_run_id | |||||
| ) | |||||
| iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response( | |||||
| session=session, | session=session, | ||||
| task_id=self._application_generate_entity.task_id, | task_id=self._application_generate_entity.task_id, | ||||
| workflow_run=workflow_run, | workflow_run=workflow_run, | ||||
| if not graph_runtime_state: | if not graph_runtime_state: | ||||
| raise ValueError("graph runtime state not initialized.") | raise ValueError("graph runtime state not initialized.") | ||||
| with Session(db.engine) as session: | |||||
| workflow_run = self._handle_workflow_run_success( | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_success( | |||||
| session=session, | session=session, | ||||
| workflow_run_id=self._workflow_run_id, | workflow_run_id=self._workflow_run_id, | ||||
| start_at=graph_runtime_state.start_at, | start_at=graph_runtime_state.start_at, | ||||
| # save workflow app log | # save workflow app log | ||||
| self._save_workflow_app_log(session=session, workflow_run=workflow_run) | self._save_workflow_app_log(session=session, workflow_run=workflow_run) | ||||
| workflow_finish_resp = self._workflow_finish_to_stream_response( | |||||
| workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( | |||||
| session=session, | session=session, | ||||
| task_id=self._application_generate_entity.task_id, | task_id=self._application_generate_entity.task_id, | ||||
| workflow_run=workflow_run, | workflow_run=workflow_run, | ||||
| if not graph_runtime_state: | if not graph_runtime_state: | ||||
| raise ValueError("graph runtime state not initialized.") | raise ValueError("graph runtime state not initialized.") | ||||
| with Session(db.engine) as session: | |||||
| workflow_run = self._handle_workflow_run_partial_success( | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success( | |||||
| session=session, | session=session, | ||||
| workflow_run_id=self._workflow_run_id, | workflow_run_id=self._workflow_run_id, | ||||
| start_at=graph_runtime_state.start_at, | start_at=graph_runtime_state.start_at, | ||||
| # save workflow app log | # save workflow app log | ||||
| self._save_workflow_app_log(session=session, workflow_run=workflow_run) | self._save_workflow_app_log(session=session, workflow_run=workflow_run) | ||||
| workflow_finish_resp = self._workflow_finish_to_stream_response( | |||||
| workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( | |||||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | ||||
| ) | ) | ||||
| session.commit() | session.commit() | ||||
| if not graph_runtime_state: | if not graph_runtime_state: | ||||
| raise ValueError("graph runtime state not initialized.") | raise ValueError("graph runtime state not initialized.") | ||||
| with Session(db.engine) as session: | |||||
| workflow_run = self._handle_workflow_run_failed( | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed( | |||||
| session=session, | session=session, | ||||
| workflow_run_id=self._workflow_run_id, | workflow_run_id=self._workflow_run_id, | ||||
| start_at=graph_runtime_state.start_at, | start_at=graph_runtime_state.start_at, | ||||
| # save workflow app log | # save workflow app log | ||||
| self._save_workflow_app_log(session=session, workflow_run=workflow_run) | self._save_workflow_app_log(session=session, workflow_run=workflow_run) | ||||
| workflow_finish_resp = self._workflow_finish_to_stream_response( | |||||
| workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( | |||||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | ||||
| ) | ) | ||||
| session.commit() | session.commit() | 
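Viewed as a whole, the stream processing above is an isinstance-based dispatch over the queue manager's events: each queue event type maps to a stream response, and an error event terminates the stream. A compressed sketch of that control flow with hypothetical event classes.

from collections.abc import Iterator
from dataclasses import dataclass
from typing import Union


@dataclass
class PingEvent:
    pass


@dataclass
class ErrorEvent:
    message: str


@dataclass
class FinishedEvent:
    outputs: dict


Event = Union[PingEvent, ErrorEvent, FinishedEvent]


def process(events: Iterator[Event]) -> Iterator[str]:
    for event in events:
        if isinstance(event, PingEvent):
            yield "ping"
        elif isinstance(event, ErrorEvent):
            yield f"error: {event.message}"
            break  # an error ends the stream, mirroring the QueueErrorEvent branch
        elif isinstance(event, FinishedEvent):
            yield f"finished: {event.outputs}"


print(list(process(iter([PingEvent(), FinishedEvent(outputs={"answer": 42})]))))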
| # app config | # app config | ||||
| app_config: WorkflowUIBasedAppConfig | app_config: WorkflowUIBasedAppConfig | ||||
| workflow_run_id: Optional[str] = None | |||||
| workflow_run_id: str | |||||
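workflow_run_id on the generate entity is now a required str rather than Optional[str] = None, which is why the workflow generator hunk earlier fills it with str(uuid.uuid4()) when building the entity. A minimal sketch with a hypothetical stand-in entity.

import uuid

from pydantic import BaseModel


class GenerateEntityStub(BaseModel):  # hypothetical stand-in for WorkflowAppGenerateEntity
    task_id: str
    workflow_run_id: str  # required; no longer Optional[str] = None


entity = GenerateEntityStub(task_id=str(uuid.uuid4()), workflow_run_id=str(uuid.uuid4()))
print(entity.workflow_run_id)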
| class SingleIterationRunEntity(BaseModel): | class SingleIterationRunEntity(BaseModel): | ||||
| """ | """ | 
| from core.app.entities.task_entities import ( | from core.app.entities.task_entities import ( | ||||
| ErrorStreamResponse, | ErrorStreamResponse, | ||||
| PingStreamResponse, | PingStreamResponse, | ||||
| TaskState, | |||||
| ) | ) | ||||
| from core.errors.error import QuotaExceededError | from core.errors.error import QuotaExceededError | ||||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | ||||
| BasedGenerateTaskPipeline is a class that generates stream output and manages state for the application. | BasedGenerateTaskPipeline is a class that generates stream output and manages state for the application. | ||||
| """ | """ | ||||
| _task_state: TaskState | |||||
| _application_generate_entity: AppGenerateEntity | |||||
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| application_generate_entity: AppGenerateEntity, | application_generate_entity: AppGenerateEntity, | ||||
| queue_manager: AppQueueManager, | queue_manager: AppQueueManager, | ||||
| stream: bool, | stream: bool, | ||||
| ) -> None: | ) -> None: | ||||
| """ | |||||
| Initialize GenerateTaskPipeline. | |||||
| :param application_generate_entity: application generate entity | |||||
| :param queue_manager: queue manager | |||||
| :param user: user | |||||
| :param stream: stream | |||||
| """ | |||||
| self._application_generate_entity = application_generate_entity | self._application_generate_entity = application_generate_entity | ||||
| self._queue_manager = queue_manager | self._queue_manager = queue_manager | ||||
| self._start_at = time.perf_counter() | self._start_at = time.perf_counter() | 
| class MessageCycleManage: | class MessageCycleManage: | ||||
| _application_generate_entity: Union[ | |||||
| ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity | |||||
| ] | |||||
| _task_state: Union[EasyUITaskState, WorkflowTaskState] | |||||
| def __init__( | |||||
| self, | |||||
| *, | |||||
| application_generate_entity: Union[ | |||||
| ChatAppGenerateEntity, | |||||
| CompletionAppGenerateEntity, | |||||
| AgentChatAppGenerateEntity, | |||||
| AdvancedChatAppGenerateEntity, | |||||
| ], | |||||
| task_state: Union[EasyUITaskState, WorkflowTaskState], | |||||
| ) -> None: | |||||
| self._application_generate_entity = application_generate_entity | |||||
| self._task_state = task_state | |||||
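MessageCycleManage switches from class-level attribute annotations to an explicit __init__ whose bare * makes every argument keyword-only, so call sites must name application_generate_entity and task_state. A tiny sketch of that constructor style with a hypothetical stub class.

class CycleManageStub:  # hypothetical minimal stand-in
    def __init__(self, *, application_generate_entity: object, task_state: object) -> None:
        self._application_generate_entity = application_generate_entity
        self._task_state = task_state


CycleManageStub(application_generate_entity=object(), task_state=object())
# CycleManageStub(object(), object()) would raise TypeError: arguments must be passed by keyword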
| def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: | def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: | ||||
| """ | """ | 
| ParallelBranchStartStreamResponse, | ParallelBranchStartStreamResponse, | ||||
| WorkflowFinishStreamResponse, | WorkflowFinishStreamResponse, | ||||
| WorkflowStartStreamResponse, | WorkflowStartStreamResponse, | ||||
| WorkflowTaskState, | |||||
| ) | ) | ||||
| from core.file import FILE_MODEL_IDENTITY, File | from core.file import FILE_MODEL_IDENTITY, File | ||||
| from core.model_runtime.utils.encoders import jsonable_encoder | from core.model_runtime.utils.encoders import jsonable_encoder | ||||
| WorkflowRunStatus, | WorkflowRunStatus, | ||||
| ) | ) | ||||
| from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError | |||||
| from .exc import WorkflowRunNotFoundError | |||||
| class WorkflowCycleManage: | class WorkflowCycleManage: | ||||
| _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] | |||||
| _task_state: WorkflowTaskState | |||||
| _workflow_system_variables: dict[SystemVariableKey, Any] | |||||
| def __init__( | |||||
| self, | |||||
| *, | |||||
| application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], | |||||
| workflow_system_variables: dict[SystemVariableKey, Any], | |||||
| ) -> None: | |||||
| self._workflow_run: WorkflowRun | None = None | |||||
| self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {} | |||||
| self._application_generate_entity = application_generate_entity | |||||
| self._workflow_system_variables = workflow_system_variables | |||||
| def _handle_workflow_run_start( | def _handle_workflow_run_start( | ||||
| self, | self, | ||||
| inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) | inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) | ||||
| # init workflow run | # init workflow run | ||||
| workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4())) | |||||
| # TODO: This workflow_run_id should never be None; maybe we can use a more elegant way to handle this | |||||
| workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4()) | |||||
| workflow_run = WorkflowRun() | workflow_run = WorkflowRun() | ||||
| workflow_run.id = workflow_run_id | workflow_run.id = workflow_run_id | ||||
| workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) | workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) | ||||
| workflow_run.exceptions_count = exceptions_count | workflow_run.exceptions_count = exceptions_count | ||||
| stmt = select(WorkflowNodeExecution).where( | |||||
| stmt = select(WorkflowNodeExecution.node_execution_id).where( | |||||
| WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, | WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, | ||||
| WorkflowNodeExecution.app_id == workflow_run.app_id, | WorkflowNodeExecution.app_id == workflow_run.app_id, | ||||
| WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, | WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, | ||||
| WorkflowNodeExecution.workflow_run_id == workflow_run.id, | WorkflowNodeExecution.workflow_run_id == workflow_run.id, | ||||
| WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, | WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, | ||||
| ) | ) | ||||
| running_workflow_node_executions = session.scalars(stmt).all() | |||||
| ids = session.scalars(stmt).all() | |||||
| # Use self._get_workflow_node_execution here to make sure the cache is updated | |||||
| running_workflow_node_executions = [ | |||||
| self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id | |||||
| ] | |||||
| for workflow_node_execution in running_workflow_node_executions: | for workflow_node_execution in running_workflow_node_executions: | ||||
| now = datetime.now(UTC).replace(tzinfo=None) | |||||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value | workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value | ||||
| workflow_node_execution.error = error | workflow_node_execution.error = error | ||||
| finish_at = datetime.now(UTC).replace(tzinfo=None) | |||||
| workflow_node_execution.finished_at = finish_at | |||||
| workflow_node_execution.elapsed_time = (finish_at - workflow_node_execution.created_at).total_seconds() | |||||
| workflow_node_execution.finished_at = now | |||||
| workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds() | |||||
| if trace_manager: | if trace_manager: | ||||
| trace_manager.add_trace_task( | trace_manager.add_trace_task( | ||||
| workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) | workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) | ||||
| session.add(workflow_node_execution) | session.add(workflow_node_execution) | ||||
| self._workflow_node_executions[event.node_execution_id] = workflow_node_execution | |||||
| return workflow_node_execution | return workflow_node_execution | ||||
| def _handle_workflow_node_execution_success( | def _handle_workflow_node_execution_success( | ||||
| workflow_node_execution.finished_at = finished_at | workflow_node_execution.finished_at = finished_at | ||||
| workflow_node_execution.elapsed_time = elapsed_time | workflow_node_execution.elapsed_time = elapsed_time | ||||
| workflow_node_execution = session.merge(workflow_node_execution) | |||||
| return workflow_node_execution | return workflow_node_execution | ||||
| def _handle_workflow_node_execution_failed( | def _handle_workflow_node_execution_failed( | ||||
| workflow_node_execution.elapsed_time = elapsed_time | workflow_node_execution.elapsed_time = elapsed_time | ||||
| workflow_node_execution.execution_metadata = execution_metadata | workflow_node_execution.execution_metadata = execution_metadata | ||||
| workflow_node_execution = session.merge(workflow_node_execution) | |||||
| return workflow_node_execution | return workflow_node_execution | ||||
| def _handle_workflow_node_execution_retried( | def _handle_workflow_node_execution_retried( | ||||
| workflow_node_execution.index = event.node_run_index | workflow_node_execution.index = event.node_run_index | ||||
| session.add(workflow_node_execution) | session.add(workflow_node_execution) | ||||
| self._workflow_node_executions[event.node_execution_id] = workflow_node_execution | |||||
| return workflow_node_execution | return workflow_node_execution | ||||
| ################################################# | ################################################# | ||||
| return None | return None | ||||
| def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun: | def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun: | ||||
| """ | |||||
| Refetch workflow run | |||||
| :param workflow_run_id: workflow run id | |||||
| :return: | |||||
| """ | |||||
| if self._workflow_run and self._workflow_run.id == workflow_run_id: | |||||
| cached_workflow_run = self._workflow_run | |||||
| cached_workflow_run = session.merge(cached_workflow_run) | |||||
| return cached_workflow_run | |||||
| stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) | stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) | ||||
| workflow_run = session.scalar(stmt) | workflow_run = session.scalar(stmt) | ||||
| if not workflow_run: | if not workflow_run: | ||||
| raise WorkflowRunNotFoundError(workflow_run_id) | raise WorkflowRunNotFoundError(workflow_run_id) | ||||
| self._workflow_run = workflow_run | |||||
| return workflow_run | return workflow_run | ||||
| def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: | def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: | ||||
| stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id) | |||||
| workflow_node_execution = session.scalar(stmt) | |||||
| if not workflow_node_execution: | |||||
| raise WorkflowNodeExecutionNotFoundError(node_execution_id) | |||||
| return workflow_node_execution | |||||
| if node_execution_id not in self._workflow_node_executions: | |||||
| raise ValueError(f"Workflow node execution not found: {node_execution_id}") | |||||
| cached_workflow_node_execution = self._workflow_node_executions[node_execution_id] | |||||
| return cached_workflow_node_execution | |||||
| def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: | def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: | ||||
| """ | """ | 
| import tiktoken | |||||
| from threading import Lock | |||||
| from typing import Any | |||||
| _tokenizer: Any = None | |||||
| _lock = Lock() | |||||
| class GPT2Tokenizer: | class GPT2Tokenizer: | ||||
| @staticmethod | |||||
| def _get_num_tokens_by_gpt2(text: str) -> int: | |||||
| """ | |||||
| Use the GPT-2 tokenizer to get the number of tokens. | |||||
| """ | |||||
| _tokenizer = GPT2Tokenizer.get_encoder() | |||||
| tokens = _tokenizer.encode(text) | |||||
| return len(tokens) | |||||
| @staticmethod | @staticmethod | ||||
| def get_num_tokens(text: str) -> int: | def get_num_tokens(text: str) -> int: | ||||
| encoding = tiktoken.encoding_for_model("gpt2") | |||||
| tiktoken_vec = encoding.encode(text) | |||||
| return len(tiktoken_vec) | |||||
| # Because this process needs more CPU resources, we revert to the direct call until we find a better way to handle it. | |||||
| # | |||||
| # future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text) | |||||
| # result = future.result() | |||||
| # return cast(int, result) | |||||
| return GPT2Tokenizer._get_num_tokens_by_gpt2(text) | |||||
| @staticmethod | |||||
| def get_encoder() -> Any: | |||||
| global _tokenizer, _lock | |||||
| with _lock: | |||||
| if _tokenizer is None: | |||||
| # Try to use tiktoken to get the tokenizer because it is faster | |||||
| # | |||||
| try: | |||||
| import tiktoken | |||||
| _tokenizer = tiktoken.get_encoding("gpt2") | |||||
| except Exception: | |||||
| from os.path import abspath, dirname, join | |||||
| from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore | |||||
| base_path = abspath(__file__) | |||||
| gpt2_tokenizer_path = join(dirname(base_path), "gpt2") | |||||
| _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path) | |||||
| return _tokenizer | 
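The hunk above replaces the per-call tiktoken lookup with a lazily built, lock-protected encoder: tiktoken's gpt2 encoding when available, otherwise the bundled transformers GPT-2 tokenizer files. A hedged usage sketch, assuming at least one of the two backends is installed.

if __name__ == "__main__":
    sample = "Dify estimates prompt sizes with a GPT-2 compatible tokenizer."
    # The first call builds and caches the encoder; subsequent calls reuse it.
    print(GPT2Tokenizer.get_num_tokens(sample))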
| return False | return False | ||||
| def delete_by_ids(self, ids: list[str]) -> None: | def delete_by_ids(self, ids: list[str]) -> None: | ||||
| if not ids: | |||||
| return | |||||
| quoted_ids = [f"'{id}'" for id in ids] | quoted_ids = [f"'{id}'" for id in ids] | ||||
| self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})") | self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})") | ||||
| self._client.delete_collection(self._collection_name) | self._client.delete_collection(self._collection_name) | ||||
| def delete_by_ids(self, ids: list[str]) -> None: | def delete_by_ids(self, ids: list[str]) -> None: | ||||
| if not ids: | |||||
| return | |||||
| collection = self._client.get_or_create_collection(self._collection_name) | collection = self._client.get_or_create_collection(self._collection_name) | ||||
| collection.delete(ids=ids) | collection.delete(ids=ids) | ||||
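A recurring change across these vector-store hunks is the early return when ids is empty, so delete_by_ids never builds an empty IN() filter or loops over nothing. The guard reduces to the following shape.

def delete_by_ids(ids: list[str]) -> None:
    if not ids:
        return  # nothing to delete; avoids an empty filter such as "id IN()"
    print(f"deleting {len(ids)} ids")


delete_by_ids([])            # no-op
delete_by_ids(["a", "b"])    # performs the delete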
| import json | |||||
| import logging | |||||
| from typing import Any, Optional | |||||
| from flask import current_app | |||||
| from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ( | |||||
| ElasticSearchConfig, | |||||
| ElasticSearchVector, | |||||
| ElasticSearchVectorFactory, | |||||
| ) | |||||
| from core.rag.datasource.vdb.field import Field | |||||
| from core.rag.datasource.vdb.vector_type import VectorType | |||||
| from core.rag.embedding.embedding_base import Embeddings | |||||
| from extensions.ext_redis import redis_client | |||||
| from models.dataset import Dataset | |||||
| logger = logging.getLogger(__name__) | |||||
| class ElasticSearchJaVector(ElasticSearchVector): | |||||
| def create_collection( | |||||
| self, | |||||
| embeddings: list[list[float]], | |||||
| metadatas: Optional[list[dict[Any, Any]]] = None, | |||||
| index_params: Optional[dict] = None, | |||||
| ): | |||||
| lock_name = f"vector_indexing_lock_{self._collection_name}" | |||||
| with redis_client.lock(lock_name, timeout=20): | |||||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | |||||
| if redis_client.get(collection_exist_cache_key): | |||||
| logger.info(f"Collection {self._collection_name} already exists.") | |||||
| return | |||||
| if not self._client.indices.exists(index=self._collection_name): | |||||
| dim = len(embeddings[0]) | |||||
| settings = { | |||||
| "analysis": { | |||||
| "analyzer": { | |||||
| "ja_analyzer": { | |||||
| "type": "custom", | |||||
| "char_filter": [ | |||||
| "icu_normalizer", | |||||
| "kuromoji_iteration_mark", | |||||
| ], | |||||
| "tokenizer": "kuromoji_tokenizer", | |||||
| "filter": [ | |||||
| "kuromoji_baseform", | |||||
| "kuromoji_part_of_speech", | |||||
| "ja_stop", | |||||
| "kuromoji_number", | |||||
| "kuromoji_stemmer", | |||||
| ], | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| mappings = { | |||||
| "properties": { | |||||
| Field.CONTENT_KEY.value: { | |||||
| "type": "text", | |||||
| "analyzer": "ja_analyzer", | |||||
| "search_analyzer": "ja_analyzer", | |||||
| }, | |||||
| Field.VECTOR.value: { # Make sure the dimension is correct here | |||||
| "type": "dense_vector", | |||||
| "dims": dim, | |||||
| "index": True, | |||||
| "similarity": "cosine", | |||||
| }, | |||||
| Field.METADATA_KEY.value: { | |||||
| "type": "object", | |||||
| "properties": { | |||||
| "doc_id": {"type": "keyword"} # Map doc_id to keyword type | |||||
| }, | |||||
| }, | |||||
| } | |||||
| } | |||||
| self._client.indices.create(index=self._collection_name, settings=settings, mappings=mappings) | |||||
| redis_client.set(collection_exist_cache_key, 1, ex=3600) | |||||
| class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory): | |||||
| def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector: | |||||
| if dataset.index_struct_dict: | |||||
| class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] | |||||
| collection_name = class_prefix | |||||
| else: | |||||
| dataset_id = dataset.id | |||||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id) | |||||
| dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name)) | |||||
| config = current_app.config | |||||
| return ElasticSearchJaVector( | |||||
| index_name=collection_name, | |||||
| config=ElasticSearchConfig( | |||||
| host=config.get("ELASTICSEARCH_HOST", "localhost"), | |||||
| port=config.get("ELASTICSEARCH_PORT", 9200), | |||||
| username=config.get("ELASTICSEARCH_USERNAME", ""), | |||||
| password=config.get("ELASTICSEARCH_PASSWORD", ""), | |||||
| ), | |||||
| attributes=[], | |||||
| ) | 
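The new ElasticSearchJaVector subclasses ElasticSearchVector to create its index with a kuromoji/ICU analyzer for Japanese text. A standalone sketch applying the same analyzer settings directly with the elasticsearch client; it assumes a local cluster with the analysis-kuromoji and analysis-icu plugins installed, and the index name is made up.

from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")
ja_settings = {
    "analysis": {
        "analyzer": {
            "ja_analyzer": {
                "type": "custom",
                "char_filter": ["icu_normalizer", "kuromoji_iteration_mark"],
                "tokenizer": "kuromoji_tokenizer",
                "filter": [
                    "kuromoji_baseform",
                    "kuromoji_part_of_speech",
                    "ja_stop",
                    "kuromoji_number",
                    "kuromoji_stemmer",
                ],
            }
        }
    }
}
es.indices.create(index="dify-ja-demo", settings=ja_settings)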
| return bool(self._client.exists(index=self._collection_name, id=id)) | return bool(self._client.exists(index=self._collection_name, id=id)) | ||||
| def delete_by_ids(self, ids: list[str]) -> None: | def delete_by_ids(self, ids: list[str]) -> None: | ||||
| if not ids: | |||||
| return | |||||
| for id in ids: | for id in ids: | ||||
| self._client.delete(index=self._collection_name, id=id) | self._client.delete(index=self._collection_name, id=id) | ||||
| METADATA_KEY = "metadata" | METADATA_KEY = "metadata" | ||||
| GROUP_KEY = "group_id" | GROUP_KEY = "group_id" | ||||
| VECTOR = "vector" | VECTOR = "vector" | ||||
| # Sparse Vector aims to support full text search | |||||
| SPARSE_VECTOR = "sparse_vector" | |||||
| TEXT_KEY = "text" | TEXT_KEY = "text" | ||||
| PRIMARY_KEY = "id" | PRIMARY_KEY = "id" | ||||
| DOC_ID = "metadata.doc_id" | DOC_ID = "metadata.doc_id" | 
| import logging | import logging | ||||
| from typing import Any, Optional | from typing import Any, Optional | ||||
| from packaging import version | |||||
| from pydantic import BaseModel, model_validator | from pydantic import BaseModel, model_validator | ||||
| from pymilvus import MilvusClient, MilvusException # type: ignore | from pymilvus import MilvusClient, MilvusException # type: ignore | ||||
| from pymilvus.milvus_client import IndexParams # type: ignore | from pymilvus.milvus_client import IndexParams # type: ignore | ||||
| class MilvusConfig(BaseModel): | class MilvusConfig(BaseModel): | ||||
| uri: str | |||||
| token: Optional[str] = None | |||||
| user: str | |||||
| password: str | |||||
| batch_size: int = 100 | |||||
| database: str = "default" | |||||
| """ | |||||
| Configuration class for Milvus connection. | |||||
| """ | |||||
| uri: str # Milvus server URI | |||||
| token: Optional[str] = None # Optional token for authentication | |||||
| user: str # Username for authentication | |||||
| password: str # Password for authentication | |||||
| batch_size: int = 100 # Batch size for operations | |||||
| database: str = "default" # Database name | |||||
| enable_hybrid_search: bool = False # Flag to enable hybrid search | |||||
| @model_validator(mode="before") | @model_validator(mode="before") | ||||
| @classmethod | @classmethod | ||||
| def validate_config(cls, values: dict) -> dict: | def validate_config(cls, values: dict) -> dict: | ||||
| """ | |||||
| Validate the configuration values. | |||||
| Raises ValueError if required fields are missing. | |||||
| """ | |||||
| if not values.get("uri"): | if not values.get("uri"): | ||||
| raise ValueError("config MILVUS_URI is required") | raise ValueError("config MILVUS_URI is required") | ||||
| if not values.get("user"): | if not values.get("user"): | ||||
| return values | return values | ||||
| def to_milvus_params(self): | def to_milvus_params(self): | ||||
| """ | |||||
| Convert the configuration to a dictionary of Milvus connection parameters. | |||||
| """ | |||||
| return { | return { | ||||
| "uri": self.uri, | "uri": self.uri, | ||||
| "token": self.token, | "token": self.token, | ||||
| class MilvusVector(BaseVector): | class MilvusVector(BaseVector): | ||||
| """ | |||||
| Milvus vector storage implementation. | |||||
| """ | |||||
| def __init__(self, collection_name: str, config: MilvusConfig): | def __init__(self, collection_name: str, config: MilvusConfig): | ||||
| super().__init__(collection_name) | super().__init__(collection_name) | ||||
| self._client_config = config | self._client_config = config | ||||
| self._client = self._init_client(config) | self._client = self._init_client(config) | ||||
| self._consistency_level = "Session" | |||||
| self._fields: list[str] = [] | |||||
| self._consistency_level = "Session" # Consistency level for Milvus operations | |||||
| self._fields: list[str] = [] # List of fields in the collection | |||||
| self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported | |||||
| def _check_hybrid_search_support(self) -> bool: | |||||
| """ | |||||
| Check if the current Milvus version supports hybrid search. | |||||
| Returns True if the version is >= 2.5.0, otherwise False. | |||||
| """ | |||||
| if not self._client_config.enable_hybrid_search: | |||||
| return False | |||||
| try: | |||||
| milvus_version = self._client.get_server_version() | |||||
| return version.parse(version.parse(milvus_version).base_version) >= version.parse("2.5.0") | |||||
| except Exception as e: | |||||
| logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.") | |||||
| return False | |||||
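Hybrid search is gated on the reported server version; comparing parsed Version objects from packaging orders releases correctly where plain string comparison does not (for example "2.10.0" versus "2.5.0"). A two-line illustration.

from packaging import version

assert version.parse("2.10.0") >= version.parse("2.5.0")  # parsed comparison orders correctly
assert not ("2.10.0" >= "2.5.0")  # lexicographic string comparison gets this wrong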
| def get_type(self) -> str: | def get_type(self) -> str: | ||||
| """ | |||||
| Get the type of vector storage (Milvus). | |||||
| """ | |||||
| return VectorType.MILVUS | return VectorType.MILVUS | ||||
| def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | ||||
| """ | |||||
| Create a collection and add texts with embeddings. | |||||
| """ | |||||
| index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}} | index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}} | ||||
| metadatas = [d.metadata if d.metadata is not None else {} for d in texts] | metadatas = [d.metadata if d.metadata is not None else {} for d in texts] | ||||
| self.create_collection(embeddings, metadatas, index_params) | self.create_collection(embeddings, metadatas, index_params) | ||||
| self.add_texts(texts, embeddings) | self.add_texts(texts, embeddings) | ||||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | ||||
| """ | |||||
| Add texts and their embeddings to the collection. | |||||
| """ | |||||
| insert_dict_list = [] | insert_dict_list = [] | ||||
| for i in range(len(documents)): | for i in range(len(documents)): | ||||
| insert_dict = { | insert_dict = { | ||||
| # No need to insert the sparse_vector field separately, as the text_bm25_emb | |||||
| # function will automatically convert the native text into a sparse vector for us. | |||||
| Field.CONTENT_KEY.value: documents[i].page_content, | Field.CONTENT_KEY.value: documents[i].page_content, | ||||
| Field.VECTOR.value: embeddings[i], | Field.VECTOR.value: embeddings[i], | ||||
| Field.METADATA_KEY.value: documents[i].metadata, | Field.METADATA_KEY.value: documents[i].metadata, | ||||
| insert_dict_list.append(insert_dict) | insert_dict_list.append(insert_dict) | ||||
| # Total insert count | # Total insert count | ||||
| total_count = len(insert_dict_list) | total_count = len(insert_dict_list) | ||||
| pks: list[str] = [] | pks: list[str] = [] | ||||
| for i in range(0, total_count, 1000): | for i in range(0, total_count, 1000): | ||||
| batch_insert_list = insert_dict_list[i : i + 1000] | |||||
| # Insert into the collection. | # Insert into the collection. | ||||
| batch_insert_list = insert_dict_list[i : i + 1000] | |||||
| try: | try: | ||||
| ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list) | ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list) | ||||
| pks.extend(ids) | pks.extend(ids) | ||||
| return pks | return pks | ||||
| def get_ids_by_metadata_field(self, key: str, value: str): | def get_ids_by_metadata_field(self, key: str, value: str): | ||||
| """ | |||||
| Get document IDs by metadata field key and value. | |||||
| """ | |||||
| result = self._client.query( | result = self._client.query( | ||||
| collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"] | collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"] | ||||
| ) | ) | ||||
| return None | return None | ||||
| def delete_by_metadata_field(self, key: str, value: str): | def delete_by_metadata_field(self, key: str, value: str): | ||||
| """ | |||||
| Delete documents by metadata field key and value. | |||||
| """ | |||||
| if self._client.has_collection(self._collection_name): | if self._client.has_collection(self._collection_name): | ||||
| ids = self.get_ids_by_metadata_field(key, value) | ids = self.get_ids_by_metadata_field(key, value) | ||||
| if ids: | if ids: | ||||
| self._client.delete(collection_name=self._collection_name, pks=ids) | self._client.delete(collection_name=self._collection_name, pks=ids) | ||||
| def delete_by_ids(self, ids: list[str]) -> None: | def delete_by_ids(self, ids: list[str]) -> None: | ||||
| """ | |||||
| Delete documents by their IDs. | |||||
| """ | |||||
| if self._client.has_collection(self._collection_name): | if self._client.has_collection(self._collection_name): | ||||
| result = self._client.query( | result = self._client.query( | ||||
| collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"] | collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"] | ||||
| self._client.delete(collection_name=self._collection_name, pks=ids) | self._client.delete(collection_name=self._collection_name, pks=ids) | ||||
| def delete(self) -> None: | def delete(self) -> None: | ||||
| """ | |||||
| Delete the entire collection. | |||||
| """ | |||||
| if self._client.has_collection(self._collection_name): | if self._client.has_collection(self._collection_name): | ||||
| self._client.drop_collection(self._collection_name, None) | self._client.drop_collection(self._collection_name, None) | ||||
| def text_exists(self, id: str) -> bool: | def text_exists(self, id: str) -> bool: | ||||
| """ | |||||
| Check if a text with the given ID exists in the collection. | |||||
| """ | |||||
| if not self._client.has_collection(self._collection_name): | if not self._client.has_collection(self._collection_name): | ||||
| return False | return False | ||||
| return len(result) > 0 | return len(result) > 0 | ||||
| def field_exists(self, field: str) -> bool: | |||||
| """ | |||||
| Check if a field exists in the collection. | |||||
| """ | |||||
| return field in self._fields | |||||
| def _process_search_results( | |||||
| self, results: list[Any], output_fields: list[str], score_threshold: float = 0.0 | |||||
| ) -> list[Document]: | |||||
| """ | |||||
| Common method to process search results | |||||
| :param results: Search results | |||||
| :param output_fields: Fields to be output | |||||
| :param score_threshold: Score threshold for filtering | |||||
| :return: List of documents | |||||
| """ | |||||
| docs = [] | |||||
| for result in results[0]: | |||||
| metadata = result["entity"].get(output_fields[1], {}) | |||||
| metadata["score"] = result["distance"] | |||||
| if result["distance"] > score_threshold: | |||||
| doc = Document(page_content=result["entity"].get(output_fields[0], ""), metadata=metadata) | |||||
| docs.append(doc) | |||||
| return docs | |||||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | ||||
| # Set search parameters. | |||||
| """ | |||||
| Search for documents by vector similarity. | |||||
| """ | |||||
| results = self._client.search( | results = self._client.search( | ||||
| collection_name=self._collection_name, | collection_name=self._collection_name, | ||||
| data=[query_vector], | data=[query_vector], | ||||
| anns_field=Field.VECTOR.value, | |||||
| limit=kwargs.get("top_k", 4), | limit=kwargs.get("top_k", 4), | ||||
| output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], | output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], | ||||
| ) | ) | ||||
| # Organize results. | |||||
| docs = [] | |||||
| for result in results[0]: | |||||
| metadata = result["entity"].get(Field.METADATA_KEY.value) | |||||
| metadata["score"] = result["distance"] | |||||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||||
| if result["distance"] > score_threshold: | |||||
| doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata) | |||||
| docs.append(doc) | |||||
| return docs | |||||
| return self._process_search_results( | |||||
| results, | |||||
| output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], | |||||
| score_threshold=float(kwargs.get("score_threshold") or 0.0), | |||||
| ) | |||||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | ||||
| # milvus/zilliz doesn't support bm25 search | |||||
| return [] | |||||
| """ | |||||
| Perform a full-text (BM25) search when hybrid search is enabled. | |||||
| """ | |||||
| if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value): | |||||
| logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)") | |||||
| return [] | |||||
| results = self._client.search( | |||||
| collection_name=self._collection_name, | |||||
| data=[query], | |||||
| anns_field=Field.SPARSE_VECTOR.value, | |||||
| limit=kwargs.get("top_k", 4), | |||||
| output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], | |||||
| ) | |||||
| return self._process_search_results( | |||||
| results, | |||||
| output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], | |||||
| score_threshold=float(kwargs.get("score_threshold") or 0.0), | |||||
| ) | |||||
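With hybrid search enabled and a collection created by create_collection below, search_by_full_text issues a BM25 query against the sparse-vector field. A hedged usage sketch; the URI, credentials and collection name are assumptions for a local Milvus >= 2.5.0 deployment, not values taken from this diff.

config = MilvusConfig(
    uri="http://localhost:19530",  # assumed local Milvus endpoint
    user="root",                   # assumed credentials
    password="Milvus",
    enable_hybrid_search=True,
)
vector = MilvusVector(collection_name="dify_demo_collection", config=config)
docs = vector.search_by_full_text("hybrid search with BM25", top_k=4, score_threshold=0.2)
for doc in docs:
    print(doc.metadata.get("score"), doc.page_content[:80])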
| def create_collection( | def create_collection( | ||||
| self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None | self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None | ||||
| ): | ): | ||||
| """ | |||||
| Create a new collection in Milvus with the specified schema and index parameters. | |||||
| """ | |||||
| lock_name = "vector_indexing_lock_{}".format(self._collection_name) | lock_name = "vector_indexing_lock_{}".format(self._collection_name) | ||||
| with redis_client.lock(lock_name, timeout=20): | with redis_client.lock(lock_name, timeout=20): | ||||
| collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) | collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) | ||||
| return | return | ||||
| # Grab the existing collection if it exists | # Grab the existing collection if it exists | ||||
| if not self._client.has_collection(self._collection_name): | if not self._client.has_collection(self._collection_name): | ||||
| from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore | |||||
| from pymilvus import CollectionSchema, DataType, FieldSchema, Function, FunctionType # type: ignore | |||||
| from pymilvus.orm.types import infer_dtype_bydata # type: ignore | from pymilvus.orm.types import infer_dtype_bydata # type: ignore | ||||
| # Determine embedding dim | # Determine embedding dim | ||||
| if metadatas: | if metadatas: | ||||
| fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) | fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) | ||||
| # Create the text field | |||||
| fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)) | |||||
| # Create the text field; enable_analyzer is set to True so that Milvus can automatically | |||||
| # convert the text into a sparse vector, reference: https://milvus.io/docs/full-text-search.md | |||||
| fields.append( | |||||
| FieldSchema( | |||||
| Field.CONTENT_KEY.value, | |||||
| DataType.VARCHAR, | |||||
| max_length=65_535, | |||||
| enable_analyzer=self._hybrid_search_enabled, | |||||
| ) | |||||
| ) | |||||
| # Create the primary key field | # Create the primary key field | ||||
| fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True)) | fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True)) | ||||
| # Create the vector field, supports binary or float vectors | # Create the vector field, supports binary or float vectors | ||||
| fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)) | fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)) | ||||
| # Create the sparse vector field used for BM25 full-text search | |||||
| if self._hybrid_search_enabled: | |||||
| fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR)) | |||||
| # Create the schema for the collection | |||||
| schema = CollectionSchema(fields) | schema = CollectionSchema(fields) | ||||
| # Add a custom function that converts the text field to a sparse vector via BM25 | |||||
| if self._hybrid_search_enabled: | |||||
| bm25_function = Function( | |||||
| name="text_bm25_emb", | |||||
| input_field_names=[Field.CONTENT_KEY.value], | |||||
| output_field_names=[Field.SPARSE_VECTOR.value], | |||||
| function_type=FunctionType.BM25, | |||||
| ) | |||||
| schema.add_function(bm25_function) | |||||
| for x in schema.fields: | for x in schema.fields: | ||||
| self._fields.append(x.name) | self._fields.append(x.name) | ||||
| # Since primary field is auto-id, no need to track it | # Since primary field is auto-id, no need to track it | ||||
| index_params_obj = IndexParams() | index_params_obj = IndexParams() | ||||
| index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params) | index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params) | ||||
| # Create Sparse Vector Index for the collection | |||||
| if self._hybrid_search_enabled: | |||||
| index_params_obj.add_index( | |||||
| field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25" | |||||
| ) | |||||
| # Create the collection | # Create the collection | ||||
| collection_name = self._collection_name | |||||
| self._client.create_collection( | self._client.create_collection( | ||||
| collection_name=collection_name, | |||||
| collection_name=self._collection_name, | |||||
| schema=schema, | schema=schema, | ||||
| index_params=index_params_obj, | index_params=index_params_obj, | ||||
| consistency_level=self._consistency_level, | consistency_level=self._consistency_level, | ||||
| redis_client.set(collection_exist_cache_key, 1, ex=3600) | redis_client.set(collection_exist_cache_key, 1, ex=3600) | ||||
| def _init_client(self, config) -> MilvusClient: | def _init_client(self, config) -> MilvusClient: | ||||
| """ | |||||
| Initialize and return a Milvus client. | |||||
| """ | |||||
| client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database) | client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database) | ||||
| return client | return client | ||||
| class MilvusVectorFactory(AbstractVectorFactory): | class MilvusVectorFactory(AbstractVectorFactory): | ||||
| """ | |||||
| Factory class for creating MilvusVector instances. | |||||
| """ | |||||
| def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector: | def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector: | ||||
| """ | |||||
| Initialize a MilvusVector instance for the given dataset. | |||||
| """ | |||||
| if dataset.index_struct_dict: | if dataset.index_struct_dict: | ||||
| class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] | class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] | ||||
| collection_name = class_prefix | collection_name = class_prefix | ||||
| user=dify_config.MILVUS_USER or "", | user=dify_config.MILVUS_USER or "", | ||||
| password=dify_config.MILVUS_PASSWORD or "", | password=dify_config.MILVUS_PASSWORD or "", | ||||
| database=dify_config.MILVUS_DATABASE or "", | database=dify_config.MILVUS_DATABASE or "", | ||||
| enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False, | |||||
| ), | ), | ||||
| ) | ) | 
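For readers unfamiliar with the new hybrid-search path, here is a minimal standalone sketch of the same BM25 setup using the pymilvus 2.5 client. The collection name, field names, vector dimension, and local URI are illustrative assumptions, not Dify's actual configuration.

```python
from pymilvus import (
    CollectionSchema,
    DataType,
    FieldSchema,
    Function,
    FunctionType,
    MilvusClient,
)

# Assumed local Milvus >= 2.5.0; the URI and all names below are illustrative only.
client = MilvusClient(uri="http://localhost:19530")

fields = [
    FieldSchema("id", DataType.INT64, is_primary=True, auto_id=True),
    # enable_analyzer lets Milvus tokenize the text for the BM25 function.
    FieldSchema("page_content", DataType.VARCHAR, max_length=65_535, enable_analyzer=True),
    FieldSchema("vector", DataType.FLOAT_VECTOR, dim=768),
    FieldSchema("sparse_vector", DataType.SPARSE_FLOAT_VECTOR),
]
schema = CollectionSchema(fields)
# Server-side function that derives the sparse vector from the text field.
schema.add_function(
    Function(
        name="text_bm25_emb",
        input_field_names=["page_content"],
        output_field_names=["sparse_vector"],
        function_type=FunctionType.BM25,
    )
)

index_params = client.prepare_index_params()
index_params.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="IP")
index_params.add_index(field_name="sparse_vector", index_type="AUTOINDEX", metric_type="BM25")
client.create_collection("demo_docs", schema=schema, index_params=index_params)

# Full-text search: pass the raw query string against the sparse field;
# Milvus computes BM25 scores server-side, so no client-side embedding is needed.
results = client.search(
    collection_name="demo_docs",
    data=["what is dify"],
    anns_field="sparse_vector",
    limit=4,
    output_fields=["page_content"],
)
```

The diff above follows the same shape: when hybrid search is enabled it adds the analyzer-enabled text field, the sparse vector field, the BM25 function, and an AUTOINDEX/BM25 index, and then searches the sparse field directly with the query string.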
| return results.row_count > 0 | return results.row_count > 0 | ||||
| def delete_by_ids(self, ids: list[str]) -> None: | def delete_by_ids(self, ids: list[str]) -> None: | ||||
| if not ids: | |||||
| return | |||||
| self._client.command( | self._client.command( | ||||
| f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}" | f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}" | ||||
| ) | ) | 
| return bool(cur.rowcount != 0) | return bool(cur.rowcount != 0) | ||||
| def delete_by_ids(self, ids: list[str]) -> None: | def delete_by_ids(self, ids: list[str]) -> None: | ||||
| if not ids: | |||||
| return | |||||
| self._client.delete(table_name=self._collection_name, ids=ids) | self._client.delete(table_name=self._collection_name, ids=ids) | ||||
| def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]: | def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]: | 
| return docs | return docs | ||||
| def delete_by_ids(self, ids: list[str]) -> None: | def delete_by_ids(self, ids: list[str]) -> None: | ||||
| if not ids: | |||||
| return | |||||
| with self._get_cursor() as cur: | with self._get_cursor() as cur: | ||||
| cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),)) | cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),)) | ||||
| return docs | return docs | ||||
| def delete_by_ids(self, ids: list[str]) -> None: | def delete_by_ids(self, ids: list[str]) -> None: | ||||
| # Avoid crashes caused by performing delete operations on empty lists in certain scenarios | |||||
| # Scenario 1: extracting a document fails, so its table is never created. | |||||
| # Clicking the retry button then triggers a delete operation on an empty list. | |||||
| if not ids: | |||||
| return | |||||
| with self._get_cursor() as cur: | with self._get_cursor() as cur: | ||||
| cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) | cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) | ||||
| return False | return False | ||||
| def delete_by_ids(self, ids: list[str]) -> None: | def delete_by_ids(self, ids: list[str]) -> None: | ||||
| if not ids: | |||||
| return | |||||
| self._db.collection(self._collection_name).delete(document_ids=ids) | self._db.collection(self._collection_name).delete(document_ids=ids) | ||||
| def delete_by_metadata_field(self, key: str, value: str) -> None: | def delete_by_metadata_field(self, key: str, value: str) -> None: | 
| db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() | db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() | ||||
| ) | ) | ||||
| if not tidb_auth_binding: | if not tidb_auth_binding: | ||||
| idle_tidb_auth_binding = ( | |||||
| db.session.query(TidbAuthBinding) | |||||
| .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") | |||||
| .limit(1) | |||||
| .one_or_none() | |||||
| ) | |||||
| if idle_tidb_auth_binding: | |||||
| idle_tidb_auth_binding.active = True | |||||
| idle_tidb_auth_binding.tenant_id = dataset.tenant_id | |||||
| db.session.commit() | |||||
| TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}" | |||||
| else: | |||||
| with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): | |||||
| tidb_auth_binding = ( | |||||
| with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): | |||||
| tidb_auth_binding = ( | |||||
| db.session.query(TidbAuthBinding) | |||||
| .filter(TidbAuthBinding.tenant_id == dataset.tenant_id) | |||||
| .one_or_none() | |||||
| ) | |||||
| if tidb_auth_binding: | |||||
| TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" | |||||
| else: | |||||
| idle_tidb_auth_binding = ( | |||||
| db.session.query(TidbAuthBinding) | db.session.query(TidbAuthBinding) | ||||
| .filter(TidbAuthBinding.tenant_id == dataset.tenant_id) | |||||
| .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") | |||||
| .limit(1) | |||||
| .one_or_none() | .one_or_none() | ||||
| ) | ) | ||||
| if tidb_auth_binding: | |||||
| TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" | |||||
| if idle_tidb_auth_binding: | |||||
| idle_tidb_auth_binding.active = True | |||||
| idle_tidb_auth_binding.tenant_id = dataset.tenant_id | |||||
| db.session.commit() | |||||
| TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}" | |||||
| else: | else: | ||||
| new_cluster = TidbService.create_tidb_serverless_cluster( | new_cluster = TidbService.create_tidb_serverless_cluster( | ||||
| dify_config.TIDB_PROJECT_ID or "", | dify_config.TIDB_PROJECT_ID or "", | ||||
| db.session.add(new_tidb_auth_binding) | db.session.add(new_tidb_auth_binding) | ||||
| db.session.commit() | db.session.commit() | ||||
| TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}" | TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}" | ||||
| else: | else: | ||||
| TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" | TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" | ||||
| from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory | from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory | ||||
| return ElasticSearchVectorFactory | return ElasticSearchVectorFactory | ||||
| case VectorType.ELASTICSEARCH_JA: | |||||
| from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import ( | |||||
| ElasticSearchJaVectorFactory, | |||||
| ) | |||||
| return ElasticSearchJaVectorFactory | |||||
| case VectorType.TIDB_VECTOR: | case VectorType.TIDB_VECTOR: | ||||
| from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory | from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory | ||||
| TENCENT = "tencent" | TENCENT = "tencent" | ||||
| ORACLE = "oracle" | ORACLE = "oracle" | ||||
| ELASTICSEARCH = "elasticsearch" | ELASTICSEARCH = "elasticsearch" | ||||
| ELASTICSEARCH_JA = "elasticsearch-ja" | |||||
| LINDORM = "lindorm" | LINDORM = "lindorm" | ||||
| COUCHBASE = "couchbase" | COUCHBASE = "couchbase" | ||||
| BAIDU = "baidu" | BAIDU = "baidu" | 
| self._file_cache_key = file_cache_key | self._file_cache_key = file_cache_key | ||||
| def extract(self) -> list[Document]: | def extract(self) -> list[Document]: | ||||
| plaintext_file_key = "" | |||||
| plaintext_file_exists = False | plaintext_file_exists = False | ||||
| if self._file_cache_key: | if self._file_cache_key: | ||||
| try: | try: | ||||
| text = "\n\n".join(text_list) | text = "\n\n".join(text_list) | ||||
| # save plaintext file for caching | # save plaintext file for caching | ||||
| if not plaintext_file_exists and plaintext_file_key: | |||||
| storage.save(plaintext_file_key, text.encode("utf-8")) | |||||
| if not plaintext_file_exists and self._file_cache_key: | |||||
| storage.save(self._file_cache_key, text.encode("utf-8")) | |||||
| return documents | return documents | ||||
| import uuid | import uuid | ||||
| from typing import Optional | from typing import Optional | ||||
| from configs import dify_config | |||||
| from core.model_manager import ModelInstance | from core.model_manager import ModelInstance | ||||
| from core.rag.cleaner.clean_processor import CleanProcessor | from core.rag.cleaner.clean_processor import CleanProcessor | ||||
| from core.rag.datasource.retrieval_service import RetrievalService | from core.rag.datasource.retrieval_service import RetrievalService | ||||
| child_nodes = self._split_child_nodes( | child_nodes = self._split_child_nodes( | ||||
| document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") | document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") | ||||
| ) | ) | ||||
| if kwargs.get("preview"): | |||||
| if len(child_nodes) > dify_config.CHILD_CHUNKS_PREVIEW_NUMBER: | |||||
| child_nodes = child_nodes[: dify_config.CHILD_CHUNKS_PREVIEW_NUMBER] | |||||
| document.children = child_nodes | document.children = child_nodes | ||||
| doc_id = str(uuid.uuid4()) | doc_id = str(uuid.uuid4()) | ||||
| hash = helper.generate_text_hash(document.page_content) | hash = helper.generate_text_hash(document.page_content) | 
| else: | else: | ||||
| body = body | body = body | ||||
| if method in {"get", "head", "post", "put", "delete", "patch"}: | |||||
| response: httpx.Response = getattr(ssrf_proxy, method)( | |||||
| if method in { | |||||
| "get", | |||||
| "head", | |||||
| "post", | |||||
| "put", | |||||
| "delete", | |||||
| "patch", | |||||
| "options", | |||||
| "GET", | |||||
| "POST", | |||||
| "PUT", | |||||
| "PATCH", | |||||
| "DELETE", | |||||
| "HEAD", | |||||
| "OPTIONS", | |||||
| }: | |||||
| response: httpx.Response = getattr(ssrf_proxy, method.lower())( | |||||
| url, | url, | ||||
| params=params, | params=params, | ||||
| headers=headers, | headers=headers, | 
| import io | import io | ||||
| import json | import json | ||||
| import logging | import logging | ||||
| import operator | |||||
| import os | import os | ||||
| import tempfile | import tempfile | ||||
| from typing import cast | |||||
| from collections.abc import Mapping, Sequence | |||||
| from typing import Any, cast | |||||
| import docx | import docx | ||||
| import pandas as pd | import pandas as pd | ||||
| import pypdfium2 # type: ignore | import pypdfium2 # type: ignore | ||||
| import yaml # type: ignore | import yaml # type: ignore | ||||
| from docx.table import Table | |||||
| from docx.text.paragraph import Paragraph | |||||
| from configs import dify_config | from configs import dify_config | ||||
| from core.file import File, FileTransferMethod, file_manager | from core.file import File, FileTransferMethod, file_manager | ||||
| process_data=process_data, | process_data=process_data, | ||||
| ) | ) | ||||
| @classmethod | |||||
| def _extract_variable_selector_to_variable_mapping( | |||||
| cls, | |||||
| *, | |||||
| graph_config: Mapping[str, Any], | |||||
| node_id: str, | |||||
| node_data: DocumentExtractorNodeData, | |||||
| ) -> Mapping[str, Sequence[str]]: | |||||
| """ | |||||
| Extract variable selector to variable mapping | |||||
| :param graph_config: graph config | |||||
| :param node_id: node id | |||||
| :param node_data: node data | |||||
| :return: | |||||
| """ | |||||
| return {node_id + ".files": node_data.variable_selector} | |||||
| def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: | def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: | ||||
| """Extract text from a file based on its MIME type.""" | """Extract text from a file based on its MIME type.""" | ||||
| doc_file = io.BytesIO(file_content) | doc_file = io.BytesIO(file_content) | ||||
| doc = docx.Document(doc_file) | doc = docx.Document(doc_file) | ||||
| text = [] | text = [] | ||||
| # Process paragraphs | |||||
| for paragraph in doc.paragraphs: | |||||
| if paragraph.text.strip(): | |||||
| text.append(paragraph.text) | |||||
| # Process tables | |||||
| for table in doc.tables: | |||||
| # Table header | |||||
| try: | |||||
| # table maybe cause errors so ignore it. | |||||
| if len(table.rows) > 0 and table.rows[0].cells is not None: | |||||
| # Keep track of paragraph and table positions | |||||
| content_items: list[tuple[int, str, Table | Paragraph]] = [] | |||||
| # Process paragraphs and tables | |||||
| for i, paragraph in enumerate(doc.paragraphs): | |||||
| if paragraph.text.strip(): | |||||
| content_items.append((i, "paragraph", paragraph)) | |||||
| for i, table in enumerate(doc.tables): | |||||
| content_items.append((i, "table", table)) | |||||
| # Sort content items based on their original position | |||||
| content_items.sort(key=operator.itemgetter(0)) | |||||
| # Process sorted content | |||||
| for _, item_type, item in content_items: | |||||
| if item_type == "paragraph": | |||||
| if isinstance(item, Table): | |||||
| continue | |||||
| text.append(item.text) | |||||
| elif item_type == "table": | |||||
| # Process tables | |||||
| if not isinstance(item, Table): | |||||
| continue | |||||
| try: | |||||
| # Check if any cell in the table has text | # Check if any cell in the table has text | ||||
| has_content = False | has_content = False | ||||
| for row in table.rows: | |||||
| for row in item.rows: | |||||
| if any(cell.text.strip() for cell in row.cells): | if any(cell.text.strip() for cell in row.cells): | ||||
| has_content = True | has_content = True | ||||
| break | break | ||||
| if has_content: | if has_content: | ||||
| markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n" | |||||
| markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n" | |||||
| for row in table.rows[1:]: | |||||
| markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n" | |||||
| cell_texts = [cell.text.replace("\n", "<br>") for cell in item.rows[0].cells] | |||||
| markdown_table = f"| {' | '.join(cell_texts)} |\n" | |||||
| markdown_table += f"| {' | '.join(['---'] * len(item.rows[0].cells))} |\n" | |||||
| for row in item.rows[1:]: | |||||
| # Replace newlines with <br> in each cell | |||||
| row_cells = [cell.text.replace("\n", "<br>") for cell in row.cells] | |||||
| markdown_table += "| " + " | ".join(row_cells) + " |\n" | |||||
| text.append(markdown_table) | text.append(markdown_table) | ||||
| except Exception as e: | |||||
| logger.warning(f"Failed to extract table from DOC/DOCX: {e}") | |||||
| continue | |||||
| except Exception as e: | |||||
| logger.warning(f"Failed to extract table from DOC/DOCX: {e}") | |||||
| continue | |||||
| return "\n".join(text) | return "\n".join(text) | ||||
| except Exception as e: | except Exception as e: | ||||
| raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e | raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e | ||||
| Code Node Data. | Code Node Data. | ||||
| """ | """ | ||||
| method: Literal["get", "post", "put", "patch", "delete", "head"] | |||||
| method: Literal[ | |||||
| "get", | |||||
| "post", | |||||
| "put", | |||||
| "patch", | |||||
| "delete", | |||||
| "head", | |||||
| "options", | |||||
| "GET", | |||||
| "POST", | |||||
| "PUT", | |||||
| "PATCH", | |||||
| "DELETE", | |||||
| "HEAD", | |||||
| "OPTIONS", | |||||
| ] | |||||
| url: str | url: str | ||||
| authorization: HttpRequestNodeAuthorization | authorization: HttpRequestNodeAuthorization | ||||
| headers: str | headers: str | 
| class Executor: | class Executor: | ||||
| method: Literal["get", "head", "post", "put", "delete", "patch"] | |||||
| method: Literal[ | |||||
| "get", | |||||
| "head", | |||||
| "post", | |||||
| "put", | |||||
| "delete", | |||||
| "patch", | |||||
| "options", | |||||
| "GET", | |||||
| "POST", | |||||
| "PUT", | |||||
| "PATCH", | |||||
| "DELETE", | |||||
| "HEAD", | |||||
| "OPTIONS", | |||||
| ] | |||||
| url: str | url: str | ||||
| params: list[tuple[str, str]] | None | params: list[tuple[str, str]] | None | ||||
| content: str | bytes | None | content: str | bytes | None | ||||
| node_data.authorization.config.api_key | node_data.authorization.config.api_key | ||||
| ).text | ).text | ||||
| # check if node_data.url is a valid URL | |||||
| if not node_data.url: | |||||
| raise InvalidURLError("url is required") | |||||
| if not node_data.url.startswith(("http://", "https://")): | |||||
| raise InvalidURLError("url should start with http:// or https://") | |||||
| self.url: str = node_data.url | self.url: str = node_data.url | ||||
| self.method = node_data.method | self.method = node_data.method | ||||
| self.auth = node_data.authorization | self.auth = node_data.authorization | ||||
| def _init_url(self): | def _init_url(self): | ||||
| self.url = self.variable_pool.convert_template(self.node_data.url).text | self.url = self.variable_pool.convert_template(self.node_data.url).text | ||||
| # check if url is a valid URL | |||||
| if not self.url: | |||||
| raise InvalidURLError("url is required") | |||||
| if not self.url.startswith(("http://", "https://")): | |||||
| raise InvalidURLError("url should start with http:// or https://") | |||||
| def _init_params(self): | def _init_params(self): | ||||
| """ | """ | ||||
| Almost same as _init_headers(), difference: | Almost same as _init_headers(), difference: | ||||
| if len(data) != 1: | if len(data) != 1: | ||||
| raise RequestBodyError("json body type should have exactly one item") | raise RequestBodyError("json body type should have exactly one item") | ||||
| json_string = self.variable_pool.convert_template(data[0].value).text | json_string = self.variable_pool.convert_template(data[0].value).text | ||||
| json_object = json.loads(json_string, strict=False) | |||||
| try: | |||||
| json_object = json.loads(json_string, strict=False) | |||||
| except json.JSONDecodeError as e: | |||||
| raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e | |||||
| self.json = json_object | self.json = json_object | ||||
| # self.json = self._parse_object_contains_variables(json_object) | # self.json = self._parse_object_contains_variables(json_object) | ||||
| case "binary": | case "binary": | ||||
| """ | """ | ||||
| do http request depending on api bundle | do http request depending on api bundle | ||||
| """ | """ | ||||
| if self.method not in {"get", "head", "post", "put", "delete", "patch"}: | |||||
| if self.method not in { | |||||
| "get", | |||||
| "head", | |||||
| "post", | |||||
| "put", | |||||
| "delete", | |||||
| "patch", | |||||
| "options", | |||||
| "GET", | |||||
| "POST", | |||||
| "PUT", | |||||
| "PATCH", | |||||
| "DELETE", | |||||
| "HEAD", | |||||
| "OPTIONS", | |||||
| }: | |||||
| raise InvalidHttpMethodError(f"Invalid http method {self.method}") | raise InvalidHttpMethodError(f"Invalid http method {self.method}") | ||||
| request_args = { | request_args = { | ||||
| } | } | ||||
| # request_args = {k: v for k, v in request_args.items() if v is not None} | # request_args = {k: v for k, v in request_args.items() if v is not None} | ||||
| try: | try: | ||||
| response = getattr(ssrf_proxy, self.method)(**request_args) | |||||
| response = getattr(ssrf_proxy, self.method.lower())(**request_args) | |||||
| except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e: | except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e: | ||||
| raise HttpRequestNodeError(str(e)) | raise HttpRequestNodeError(str(e)) | ||||
| # FIXME: fix type ignore, this maybe httpx type issue | # FIXME: fix type ignore, this maybe httpx type issue | 
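Since the executor now accepts both lower- and upper-case verbs (plus OPTIONS) and lowercases the method before dispatch, the pattern reduces to the sketch below. Plain httpx is used purely for illustration; the real executor goes through Dify's ssrf_proxy wrapper.

```python
import httpx

ALLOWED_METHODS = {"get", "head", "post", "put", "delete", "patch", "options"}


def dispatch(method: str, url: str, **kwargs) -> httpx.Response:
    # Normalize case once so "GET" and "get" resolve to the same client call.
    normalized = method.lower()
    if normalized not in ALLOWED_METHODS:
        raise ValueError(f"Invalid http method {method}")
    # getattr-based dispatch mirrors the executor; httpx exposes get/post/put/
    # patch/delete/head/options as module-level helpers.
    return getattr(httpx, normalized)(url, **kwargs)


response = dispatch("GET", "https://example.com")
print(response.status_code)
```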
| ): | ): | ||||
| raise ValueError(f"Variable key {node_variable} not found in user inputs.") | raise ValueError(f"Variable key {node_variable} not found in user inputs.") | ||||
| # environment variable already exist in variable pool, not from user inputs | |||||
| if variable_pool.get(variable_selector): | |||||
| continue | |||||
| # fetch variable node id from variable selector | # fetch variable node id from variable selector | ||||
| variable_node_id = variable_selector[0] | variable_node_id = variable_selector[0] | ||||
| variable_key_list = variable_selector[1:] | variable_key_list = variable_selector[1:] | 
| --bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \ | --bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \ | ||||
| --workers ${SERVER_WORKER_AMOUNT:-1} \ | --workers ${SERVER_WORKER_AMOUNT:-1} \ | ||||
| --worker-class ${SERVER_WORKER_CLASS:-gevent} \ | --worker-class ${SERVER_WORKER_CLASS:-gevent} \ | ||||
| --worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \ | |||||
| --timeout ${GUNICORN_TIMEOUT:-200} \ | --timeout ${GUNICORN_TIMEOUT:-200} \ | ||||
| app:app | app:app | ||||
| fi | fi | 
| timezone = pytz.timezone(log_tz) | timezone = pytz.timezone(log_tz) | ||||
| def time_converter(seconds): | def time_converter(seconds): | ||||
| return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() | |||||
| return datetime.fromtimestamp(seconds, tz=timezone).timetuple() | |||||
| for handler in logging.root.handlers: | for handler in logging.root.handlers: | ||||
| if handler.formatter: | if handler.formatter: | 
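The time_converter change above replaces the deprecated datetime.utcfromtimestamp() (which returns a naive datetime that astimezone() then misreads as host-local time) with a timezone-aware conversion. A minimal reproduction, with an assumed LOG_TZ value for illustration:

```python
from datetime import datetime

import pytz

tz = pytz.timezone("America/New_York")  # assumed LOG_TZ value for illustration
ts = 1735689600                          # 2025-01-01 00:00:00 UTC

# Old behaviour: utcfromtimestamp() returns a naive datetime, and astimezone()
# then assumes it is in the host's local time, so the result drifts on any
# server whose local time zone is not UTC (and the API is deprecated in 3.12).
old = datetime.utcfromtimestamp(ts).astimezone(tz)

# New behaviour: build an aware datetime directly in the target time zone.
new = datetime.fromtimestamp(ts, tz=tz)

print(old.isoformat(), new.isoformat())
```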
| tenant_id: str, | tenant_id: str, | ||||
| transfer_method: FileTransferMethod, | transfer_method: FileTransferMethod, | ||||
| ) -> File: | ) -> File: | ||||
| url = mapping.get("url") | |||||
| url = mapping.get("url") or mapping.get("remote_url") | |||||
| if not url: | if not url: | ||||
| raise ValueError("Invalid file url") | raise ValueError("Invalid file url") | ||||
| response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) | response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) | ||||
| response_json = response.json() | response_json = response.json() | ||||
| if response.status_code != 200: | if response.status_code != 200: | ||||
| raise ValueError(f"Error fetching block parent page ID: {response_json.message}") | |||||
| message = response_json.get("message", "unknown error") | |||||
| raise ValueError(f"Error fetching block parent page ID: {message}") | |||||
| parent = response_json["parent"] | parent = response_json["parent"] | ||||
| parent_type = parent["type"] | parent_type = parent["type"] | ||||
| if parent_type == "block_id": | if parent_type == "block_id": | 
| """change workflow_runs.total_tokens to bigint | |||||
| Revision ID: a91b476a53de | |||||
| Revises: 923752d42eb6 | |||||
| Create Date: 2025-01-01 20:00:01.207369 | |||||
| """ | |||||
| from alembic import op | |||||
| import models as models | |||||
| import sqlalchemy as sa | |||||
| # revision identifiers, used by Alembic. | |||||
| revision = 'a91b476a53de' | |||||
| down_revision = '923752d42eb6' | |||||
| branch_labels = None | |||||
| depends_on = None | |||||
| def upgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| with op.batch_alter_table('workflow_runs', schema=None) as batch_op: | |||||
| batch_op.alter_column('total_tokens', | |||||
| existing_type=sa.INTEGER(), | |||||
| type_=sa.BigInteger(), | |||||
| existing_nullable=False, | |||||
| existing_server_default=sa.text('0')) | |||||
| # ### end Alembic commands ### | |||||
| def downgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| with op.batch_alter_table('workflow_runs', schema=None) as batch_op: | |||||
| batch_op.alter_column('total_tokens', | |||||
| existing_type=sa.BigInteger(), | |||||
| type_=sa.INTEGER(), | |||||
| existing_nullable=False, | |||||
| existing_server_default=sa.text('0')) | |||||
| # ### end Alembic commands ### | 
| status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded | status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded | ||||
| outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") | outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") | ||||
| error: Mapped[Optional[str]] = mapped_column(db.Text) | error: Mapped[Optional[str]] = mapped_column(db.Text) | ||||
| elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) | |||||
| total_tokens: Mapped[int] = mapped_column(server_default=db.text("0")) | |||||
| elapsed_time = db.Column(db.Float, nullable=False, server_default=sa.text("0")) | |||||
| total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) | |||||
| total_steps = db.Column(db.Integer, server_default=db.text("0")) | total_steps = db.Column(db.Integer, server_default=db.text("0")) | ||||
| created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user | created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user | ||||
| created_by = db.Column(StringUUID, nullable=False) | created_by = db.Column(StringUUID, nullable=False) | 
| pypdfium2 = "~4.30.0" | pypdfium2 = "~4.30.0" | ||||
| python = ">=3.11,<3.13" | python = ">=3.11,<3.13" | ||||
| python-docx = "~1.1.0" | python-docx = "~1.1.0" | ||||
| python-dotenv = "1.0.0" | |||||
| python-dotenv = "1.0.1" | |||||
| pyyaml = "~6.0.1" | pyyaml = "~6.0.1" | ||||
| readabilipy = "0.2.0" | readabilipy = "0.2.0" | ||||
| redis = { version = "~5.0.3", extras = ["hiredis"] } | redis = { version = "~5.0.3", extras = ["hiredis"] } | ||||
| sentry-sdk = { version = "~1.44.1", extras = ["flask"] } | sentry-sdk = { version = "~1.44.1", extras = ["flask"] } | ||||
| sqlalchemy = "~2.0.29" | sqlalchemy = "~2.0.29" | ||||
| starlette = "0.41.0" | starlette = "0.41.0" | ||||
| tencentcloud-sdk-python-hunyuan = "~3.0.1158" | |||||
| tencentcloud-sdk-python-hunyuan = "~3.0.1294" | |||||
| tiktoken = "~0.8.0" | tiktoken = "~0.8.0" | ||||
| tokenizers = "~0.15.0" | tokenizers = "~0.15.0" | ||||
| transformers = "~4.35.0" | transformers = "~4.35.0" | ||||
| volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"} | volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"} | ||||
| websocket-client = "~1.7.0" | websocket-client = "~1.7.0" | ||||
| xinference-client = "0.15.2" | xinference-client = "0.15.2" | ||||
| yarl = "~1.9.4" | |||||
| yarl = "~1.18.3" | |||||
| youtube-transcript-api = "~0.6.2" | youtube-transcript-api = "~0.6.2" | ||||
| zhipuai = "~2.1.5" | zhipuai = "~2.1.5" | ||||
| # Before adding new dependency, consider place it in alphabet order (a-z) and suitable group. | # Before adding new dependency, consider place it in alphabet order (a-z) and suitable group. | ||||
| oracledb = "~2.2.1" | oracledb = "~2.2.1" | ||||
| pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] } | pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] } | ||||
| pgvector = "0.2.5" | pgvector = "0.2.5" | ||||
| pymilvus = "~2.4.4" | |||||
| pymilvus = "~2.5.0" | |||||
| pymochow = "1.3.1" | pymochow = "1.3.1" | ||||
| pyobvector = "~0.1.6" | pyobvector = "~0.1.6" | ||||
| qdrant-client = "1.7.3" | qdrant-client = "1.7.3" | 
| else: | else: | ||||
| plan = plan_cache.decode() | plan = plan_cache.decode() | ||||
| if plan == "sandbox": | if plan == "sandbox": | ||||
| # add auto disable log | |||||
| documents = ( | |||||
| db.session.query(Document) | |||||
| .filter( | |||||
| Document.dataset_id == dataset.id, | |||||
| Document.enabled == True, | |||||
| Document.archived == False, | |||||
| ) | |||||
| .all() | |||||
| ) | |||||
| for document in documents: | |||||
| dataset_auto_disable_log = DatasetAutoDisableLog( | |||||
| tenant_id=dataset.tenant_id, | |||||
| dataset_id=dataset.id, | |||||
| document_id=document.id, | |||||
| ) | |||||
| db.session.add(dataset_auto_disable_log) | |||||
| # remove index | # remove index | ||||
| index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() | index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() | ||||
| index_processor.clean(dataset, None) | index_processor.clean(dataset, None) | 
| REFRESH_TOKEN_PREFIX = "refresh_token:" | REFRESH_TOKEN_PREFIX = "refresh_token:" | ||||
| ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:" | ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:" | ||||
| REFRESH_TOKEN_EXPIRY = timedelta(days=30) | |||||
| REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS) | |||||
| class AccountService: | class AccountService: | 
| import uuid | import uuid | ||||
| from enum import StrEnum | from enum import StrEnum | ||||
| from typing import Optional, cast | from typing import Optional, cast | ||||
| from urllib.parse import urlparse | |||||
| from uuid import uuid4 | from uuid import uuid4 | ||||
| import yaml # type: ignore | import yaml # type: ignore | ||||
| raise ValueError(f"Invalid import_mode: {import_mode}") | raise ValueError(f"Invalid import_mode: {import_mode}") | ||||
| # Get YAML content | # Get YAML content | ||||
| content: bytes | str = b"" | |||||
| content: str = "" | |||||
| if mode == ImportMode.YAML_URL: | if mode == ImportMode.YAML_URL: | ||||
| if not yaml_url: | if not yaml_url: | ||||
| return Import( | return Import( | ||||
| error="yaml_url is required when import_mode is yaml-url", | error="yaml_url is required when import_mode is yaml-url", | ||||
| ) | ) | ||||
| try: | try: | ||||
| # Rewrite a github.com file URL (.../blob/...) to its raw.githubusercontent.com equivalent | |||||
| if yaml_url.startswith("https://github.com") and yaml_url.endswith((".yml", ".yaml")): | |||||
| parsed_url = urlparse(yaml_url) | |||||
| if ( | |||||
| parsed_url.scheme == "https" | |||||
| and parsed_url.netloc == "github.com" | |||||
| and parsed_url.path.endswith((".yml", ".yaml")) | |||||
| ): | |||||
| yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com") | yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com") | ||||
| yaml_url = yaml_url.replace("/blob/", "/") | yaml_url = yaml_url.replace("/blob/", "/") | ||||
| response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10)) | response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10)) | ||||
| response.raise_for_status() | response.raise_for_status() | ||||
| content = response.content | |||||
| content = response.content.decode() | |||||
| if len(content) > DSL_MAX_SIZE: | if len(content) > DSL_MAX_SIZE: | ||||
| return Import( | return Import( | 
| class AppService: | class AppService: | ||||
| def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None: | |||||
| def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None: | |||||
| """ | """ | ||||
| Get app list with pagination | Get app list with pagination | ||||
| :param user_id: user id | |||||
| :param tenant_id: tenant id | :param tenant_id: tenant id | ||||
| :param args: request args | :param args: request args | ||||
| :return: | :return: | ||||
| elif args["mode"] == "channel": | elif args["mode"] == "channel": | ||||
| filters.append(App.mode == AppMode.CHANNEL.value) | filters.append(App.mode == AppMode.CHANNEL.value) | ||||
| if args.get("is_created_by_me", False): | |||||
| filters.append(App.created_by == user_id) | |||||
| if args.get("name"): | if args.get("name"): | ||||
| name = args["name"][:30] | name = args["name"][:30] | ||||
| filters.append(App.name.ilike(f"%{name}%")) | filters.append(App.name.ilike(f"%{name}%")) | 
| import os | import os | ||||
| from typing import Optional | |||||
| from typing import Literal, Optional | |||||
| import httpx | import httpx | ||||
| from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed | from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed | ||||
| params = {"tenant_id": tenant_id} | params = {"tenant_id": tenant_id} | ||||
| billing_info = cls._send_request("GET", "/subscription/info", params=params) | billing_info = cls._send_request("GET", "/subscription/info", params=params) | ||||
| return billing_info | return billing_info | ||||
| @classmethod | @classmethod | ||||
| retry=retry_if_exception_type(httpx.RequestError), | retry=retry_if_exception_type(httpx.RequestError), | ||||
| reraise=True, | reraise=True, | ||||
| ) | ) | ||||
| def _send_request(cls, method, endpoint, json=None, params=None): | |||||
| def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None): | |||||
| headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} | headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} | ||||
| url = f"{cls.base_url}{endpoint}" | url = f"{cls.base_url}{endpoint}" | ||||
| response = httpx.request(method, url, json=json, params=params, headers=headers) | response = httpx.request(method, url, json=json, params=params, headers=headers) | ||||
| if method == "GET" and response.status_code != httpx.codes.OK: | |||||
| raise ValueError("Unable to retrieve billing information. Please try again later or contact support.") | |||||
| return response.json() | return response.json() | ||||
| @staticmethod | @staticmethod | 
| else: | else: | ||||
| return [], 0 | return [], 0 | ||||
| else: | else: | ||||
| if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN): | |||||
| if user.current_role != TenantAccountRole.OWNER: | |||||
| # show all datasets that the user has permission to access | # show all datasets that the user has permission to access | ||||
| if permitted_dataset_ids: | if permitted_dataset_ids: | ||||
| query = query.filter( | query = query.filter( | ||||
| if dataset.tenant_id != user.current_tenant_id: | if dataset.tenant_id != user.current_tenant_id: | ||||
| logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") | logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") | ||||
| raise NoPermissionError("You do not have permission to access this dataset.") | raise NoPermissionError("You do not have permission to access this dataset.") | ||||
| if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN): | |||||
| if user.current_role != TenantAccountRole.OWNER: | |||||
| if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id: | if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id: | ||||
| logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") | logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") | ||||
| raise NoPermissionError("You do not have permission to access this dataset.") | raise NoPermissionError("You do not have permission to access this dataset.") | ||||
| if not user: | if not user: | ||||
| raise ValueError("User not found") | raise ValueError("User not found") | ||||
| if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN): | |||||
| if user.current_role != TenantAccountRole.OWNER: | |||||
| if dataset.permission == DatasetPermissionEnum.ONLY_ME: | if dataset.permission == DatasetPermissionEnum.ONLY_ME: | ||||
| if dataset.created_by != user.id: | if dataset.created_by != user.id: | ||||
| raise NoPermissionError("You do not have permission to access this dataset.") | raise NoPermissionError("You do not have permission to access this dataset.") | ||||
| @staticmethod | @staticmethod | ||||
| def get_dataset_auto_disable_logs(dataset_id: str) -> dict: | def get_dataset_auto_disable_logs(dataset_id: str) -> dict: | ||||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||||
| if not features.billing.enabled or features.billing.subscription.plan == "sandbox": | |||||
| return { | |||||
| "document_ids": [], | |||||
| "count": 0, | |||||
| } | |||||
| # get recent 30 days auto disable logs | # get recent 30 days auto disable logs | ||||
| start_date = datetime.datetime.now() - datetime.timedelta(days=30) | start_date = datetime.datetime.now() - datetime.timedelta(days=30) | ||||
| dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter( | dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter( | ||||
| dataset.indexing_technique = knowledge_config.indexing_technique | dataset.indexing_technique = knowledge_config.indexing_technique | ||||
| if knowledge_config.indexing_technique == "high_quality": | if knowledge_config.indexing_technique == "high_quality": | ||||
| model_manager = ModelManager() | model_manager = ModelManager() | ||||
| embedding_model = model_manager.get_default_model_instance( | |||||
| tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING | |||||
| ) | |||||
| dataset.embedding_model = embedding_model.model | |||||
| dataset.embedding_model_provider = embedding_model.provider | |||||
| if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: | |||||
| dataset_embedding_model = knowledge_config.embedding_model | |||||
| dataset_embedding_model_provider = knowledge_config.embedding_model_provider | |||||
| else: | |||||
| embedding_model = model_manager.get_default_model_instance( | |||||
| tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING | |||||
| ) | |||||
| dataset_embedding_model = embedding_model.model | |||||
| dataset_embedding_model_provider = embedding_model.provider | |||||
| dataset.embedding_model = dataset_embedding_model | |||||
| dataset.embedding_model_provider = dataset_embedding_model_provider | |||||
| dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( | dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( | ||||
| embedding_model.provider, embedding_model.model | |||||
| dataset_embedding_model_provider, dataset_embedding_model | |||||
| ) | ) | ||||
| dataset.collection_binding_id = dataset_collection_binding.id | dataset.collection_binding_id = dataset_collection_binding.id | ||||
| if not dataset.retrieval_model: | if not dataset.retrieval_model: | ||||
| "score_threshold_enabled": False, | "score_threshold_enabled": False, | ||||
| } | } | ||||
| dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model # type: ignore | |||||
| dataset.retrieval_model = ( | |||||
| knowledge_config.retrieval_model.model_dump() | |||||
| if knowledge_config.retrieval_model | |||||
| else default_retrieval_model | |||||
| ) # type: ignore | |||||
| documents = [] | documents = [] | ||||
| if knowledge_config.original_document_id: | if knowledge_config.original_document_id: | 
| query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id) | query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id) | ||||
| if keyword: | if keyword: | ||||
| keyword_like_val = f"%{args['keyword'][:30]}%" | |||||
| keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u") | |||||
| keyword_conditions = [ | keyword_conditions = [ | ||||
| WorkflowRun.inputs.ilike(keyword_like_val), | WorkflowRun.inputs.ilike(keyword_like_val), | ||||
| WorkflowRun.outputs.ilike(keyword_like_val), | WorkflowRun.outputs.ilike(keyword_like_val), | 
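The new LIKE pattern above presumably targets JSON columns serialized with ASCII escapes, where non-ASCII text is stored as \uXXXX sequences and backslashes in a LIKE pattern must be doubled. The snippet below only demonstrates the string transformation itself, with a hypothetical CJK keyword:

```python
import json

keyword = "你好"  # hypothetical CJK search keyword

# With json's ensure_ascii default, non-ASCII text is stored as \uXXXX escapes.
stored = json.dumps({"query": keyword})
print(stored)        # {"query": "\u4f60\u597d"}

# Same transformation as the service: escape the keyword, then double the
# backslashes so they survive SQL LIKE's backslash escaping.
escaped = keyword[:30].encode("unicode_escape").decode("utf-8")
like_val = f"%{escaped}%".replace(r"\u", r"\\u")
print(like_val)      # %\\u4f60\\u597d%
```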
| if not dataset: | if not dataset: | ||||
| raise Exception("Dataset not found") | raise Exception("Dataset not found") | ||||
| index_type = dataset.doc_form | |||||
| index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX | |||||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | index_processor = IndexProcessorFactory(index_type).init_index_processor() | ||||
| if action == "remove": | if action == "remove": | ||||
| index_processor.clean(dataset, None, with_keywords=False) | index_processor.clean(dataset, None, with_keywords=False) | ||||
| {"indexing_status": "error", "error": str(e)}, synchronize_session=False | {"indexing_status": "error", "error": str(e)}, synchronize_session=False | ||||
| ) | ) | ||||
| db.session.commit() | db.session.commit() | ||||
| else: | |||||
| # clean collection | |||||
| index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) | |||||
| end_at = time.perf_counter() | end_at = time.perf_counter() | ||||
| logging.info( | logging.info( | 
| import os | |||||
| from pathlib import Path | |||||
| import pytest | |||||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||||
| from core.model_runtime.model_providers.gpustack.speech2text.speech2text import GPUStackSpeech2TextModel | |||||
| def test_validate_credentials(): | |||||
| model = GPUStackSpeech2TextModel() | |||||
| with pytest.raises(CredentialsValidateFailedError): | |||||
| model.validate_credentials( | |||||
| model="faster-whisper-medium", | |||||
| credentials={ | |||||
| "endpoint_url": "invalid_url", | |||||
| "api_key": "invalid_api_key", | |||||
| }, | |||||
| ) | |||||
| model.validate_credentials( | |||||
| model="faster-whisper-medium", | |||||
| credentials={ | |||||
| "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), | |||||
| "api_key": os.environ.get("GPUSTACK_API_KEY"), | |||||
| }, | |||||
| ) | |||||
| def test_invoke_model(): | |||||
| model = GPUStackSpeech2TextModel() | |||||
| # Get the directory of the current file | |||||
| current_dir = os.path.dirname(os.path.abspath(__file__)) | |||||
| # Get assets directory | |||||
| assets_dir = os.path.join(os.path.dirname(current_dir), "assets") | |||||
| # Construct the path to the audio file | |||||
| audio_file_path = os.path.join(assets_dir, "audio.mp3") | |||||
| file = Path(audio_file_path).read_bytes() | |||||
| result = model.invoke( | |||||
| model="faster-whisper-medium", | |||||
| credentials={ | |||||
| "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), | |||||
| "api_key": os.environ.get("GPUSTACK_API_KEY"), | |||||
| }, | |||||
| file=file, | |||||
| ) | |||||
| assert isinstance(result, str) | |||||
| assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" | 
| import os | |||||
| from core.model_runtime.model_providers.gpustack.tts.tts import GPUStackText2SpeechModel | |||||
| def test_invoke_model(): | |||||
| model = GPUStackText2SpeechModel() | |||||
| result = model.invoke( | |||||
| model="cosyvoice-300m-sft", | |||||
| tenant_id="test", | |||||
| credentials={ | |||||
| "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), | |||||
| "api_key": os.environ.get("GPUSTACK_API_KEY"), | |||||
| }, | |||||
| content_text="Hello world", | |||||
| voice="Chinese Female", | |||||
| ) | |||||
| content = b"" | |||||
| for chunk in result: | |||||
| content += chunk | |||||
| assert content != b"" | 
| ) | ) | ||||
| def search_by_full_text(self): | def search_by_full_text(self): | ||||
| # milvus dos not support full text searching yet in < 2.3.x | |||||
| # Milvus supports BM25 full-text search since version 2.5.0-beta | |||||
| hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) | hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) | ||||
| assert len(hits_by_full_text) == 0 | |||||
| assert len(hits_by_full_text) >= 0 | |||||
| def get_ids_by_metadata_field(self): | def get_ids_by_metadata_field(self): | ||||
| ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) | ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) | 
| services: | services: | ||||
| # API service | # API service | ||||
| api: | api: | ||||
| image: langgenius/dify-api:0.14.2 | |||||
| image: langgenius/dify-api:0.15.0 | |||||
| restart: always | restart: always | ||||
| environment: | environment: | ||||
| # Startup mode, 'api' starts the API server. | # Startup mode, 'api' starts the API server. | ||||
| # worker service | # worker service | ||||
| # The Celery worker for processing the queue. | # The Celery worker for processing the queue. | ||||
| worker: | worker: | ||||
| image: langgenius/dify-api:0.14.2 | |||||
| image: langgenius/dify-api:0.15.0 | |||||
| restart: always | restart: always | ||||
| environment: | environment: | ||||
| CONSOLE_WEB_URL: '' | CONSOLE_WEB_URL: '' | ||||
| # Frontend web application. | # Frontend web application. | ||||
| web: | web: | ||||
| image: langgenius/dify-web:0.14.2 | |||||
| image: langgenius/dify-web:0.15.0 | |||||
| restart: always | restart: always | ||||
| environment: | environment: | ||||
| # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is | # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is | 
| # Access token expiration time in minutes | # Access token expiration time in minutes | ||||
| ACCESS_TOKEN_EXPIRE_MINUTES=60 | ACCESS_TOKEN_EXPIRE_MINUTES=60 | ||||
| # Refresh token expiration time in days | |||||
| REFRESH_TOKEN_EXPIRE_DAYS=30 | |||||
| # The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer. | # The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer. | ||||
| APP_MAX_ACTIVE_REQUESTS=0 | APP_MAX_ACTIVE_REQUESTS=0 | ||||
| APP_MAX_EXECUTION_TIME=1200 | APP_MAX_EXECUTION_TIME=1200 | ||||
| # The number of API server workers, i.e., the number of workers. | # The number of API server workers, i.e., the number of workers. | ||||
| # Formula: number of cpu cores x 2 + 1 for sync, 1 for Gevent | # Formula: number of cpu cores x 2 + 1 for sync, 1 for Gevent | ||||
| # Reference: https://docs.gunicorn.org/en/stable/design.html#how-many-workers | # Reference: https://docs.gunicorn.org/en/stable/design.html#how-many-workers | ||||
| SERVER_WORKER_AMOUNT= | |||||
| SERVER_WORKER_AMOUNT=1 | |||||
| # Defaults to gevent. If using windows, it can be switched to sync or solo. | # Defaults to gevent. If using windows, it can be switched to sync or solo. | ||||
| SERVER_WORKER_CLASS= | |||||
| SERVER_WORKER_CLASS=gevent | |||||
| # The number of worker connections, the default is 10. | |||||
| SERVER_WORKER_CONNECTIONS=10 | |||||
| # Similar to SERVER_WORKER_CLASS. | # Similar to SERVER_WORKER_CLASS. | ||||
| # If using windows, it can be switched to sync or solo. | # If using windows, it can be switched to sync or solo. | ||||
| # ------------------------------ | # ------------------------------ | ||||
| # The type of vector store to use. | # The type of vector store to use. | ||||
| # Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`. | |||||
| # Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`. | |||||
| VECTOR_STORE=weaviate | VECTOR_STORE=weaviate | ||||
| # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`. | # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`. | ||||
| MILVUS_TOKEN= | MILVUS_TOKEN= | ||||
| MILVUS_USER=root | MILVUS_USER=root | ||||
| MILVUS_PASSWORD=Milvus | MILVUS_PASSWORD=Milvus | ||||
| MILVUS_ENABLE_HYBRID_SEARCH=False | |||||
| # MyScale configuration, only available when VECTOR_STORE is `myscale` | # MyScale configuration, only available when VECTOR_STORE is `myscale` | ||||
| # For multi-language support, please set MYSCALE_FTS_PARAMS with referring to: | # For multi-language support, please set MYSCALE_FTS_PARAMS with referring to: | ||||
| TENCENT_VECTOR_DB_REPLICAS=2 | TENCENT_VECTOR_DB_REPLICAS=2 | ||||
| # ElasticSearch configuration, only available when VECTOR_STORE is `elasticsearch` | # ElasticSearch configuration, only available when VECTOR_STORE is `elasticsearch` | ||||
| ELASTICSEARCH_HOST=0.0.0.0 | |||||
| ELASTICSEARCH_HOST=elasticsearch | |||||
| ELASTICSEARCH_PORT=9200 | ELASTICSEARCH_PORT=9200 | ||||
| ELASTICSEARCH_USERNAME=elastic | ELASTICSEARCH_USERNAME=elastic | ||||
| ELASTICSEARCH_PASSWORD=elastic | ELASTICSEARCH_PASSWORD=elastic | ||||
| # Maximum number of submitted thread count in a ThreadPool for parallel node execution | # Maximum number of submitted thread count in a ThreadPool for parallel node execution | ||||
| MAX_SUBMIT_COUNT=100 | MAX_SUBMIT_COUNT=100 | ||||
| # The maximum top-k value for RAG. | |||||
| TOP_K_MAX_VALUE=10 | |||||
| # ------------------------------ | # ------------------------------ | ||||
| # Plugin Daemon Configuration | # Plugin Daemon Configuration | ||||
| # ------------------------------ | # ------------------------------ | ||||
| MARKETPLACE_ENABLED=true | MARKETPLACE_ENABLED=true | ||||
| MARKETPLACE_API_URL=https://marketplace-plugin.dify.dev | MARKETPLACE_API_URL=https://marketplace-plugin.dify.dev | ||||
| CSP_WHITELIST: ${CSP_WHITELIST:-} | CSP_WHITELIST: ${CSP_WHITELIST:-} | ||||
| MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev} | MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev} | ||||
| MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace-plugin.dify.dev} | MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace-plugin.dify.dev} | ||||
| TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-} | |||||
| # The postgres database. | # The postgres database. | ||||
| db: | db: | ||||
| volumes: | volumes: | ||||
| - ./volumes/db/data:/var/lib/postgresql/data | - ./volumes/db/data:/var/lib/postgresql/data | ||||
| healthcheck: | healthcheck: | ||||
| test: ['CMD', 'pg_isready'] | |||||
| test: [ 'CMD', 'pg_isready' ] | |||||
| interval: 1s | interval: 1s | ||||
| timeout: 3s | timeout: 3s | ||||
| retries: 30 | retries: 30 | ||||
| # Set the redis password when startup redis server. | # Set the redis password when startup redis server. | ||||
| command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456} | command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456} | ||||
| healthcheck: | healthcheck: | ||||
| test: ['CMD', 'redis-cli', 'ping'] | |||||
| test: [ 'CMD', 'redis-cli', 'ping' ] | |||||
| # The DifySandbox | # The DifySandbox | ||||
| sandbox: | sandbox: | ||||
| volumes: | volumes: | ||||
| - ./volumes/sandbox/dependencies:/dependencies | - ./volumes/sandbox/dependencies:/dependencies | ||||
| healthcheck: | healthcheck: | ||||
| test: ['CMD', 'curl', '-f', 'http://localhost:8194/health'] | |||||
| test: [ 'CMD', 'curl', '-f', 'http://localhost:8194/health' ] | |||||
| networks: | networks: | ||||
| - ssrf_proxy_network | - ssrf_proxy_network | ||||
| volumes: | volumes: | ||||
| - ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template | - ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template | ||||
| - ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh | - ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh | ||||
| entrypoint: | |||||
| [ | |||||
| 'sh', | |||||
| '-c', | |||||
| "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh", | |||||
| ] | |||||
| entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ] | |||||
| environment: | environment: | ||||
| # pls clearly modify the squid env vars to fit your network environment. | # pls clearly modify the squid env vars to fit your network environment. | ||||
| HTTP_PORT: ${SSRF_HTTP_PORT:-3128} | HTTP_PORT: ${SSRF_HTTP_PORT:-3128} | ||||
| - CERTBOT_EMAIL=${CERTBOT_EMAIL} | - CERTBOT_EMAIL=${CERTBOT_EMAIL} | ||||
| - CERTBOT_DOMAIN=${CERTBOT_DOMAIN} | - CERTBOT_DOMAIN=${CERTBOT_DOMAIN} | ||||
| - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-} | - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-} | ||||
| entrypoint: ['/docker-entrypoint.sh'] | |||||
| command: ['tail', '-f', '/dev/null'] | |||||
| entrypoint: [ '/docker-entrypoint.sh' ] | |||||
| command: [ 'tail', '-f', '/dev/null' ] | |||||
| # The nginx reverse proxy. | # The nginx reverse proxy. | ||||
| # used for reverse proxying the API service and Web service. | # used for reverse proxying the API service and Web service. | ||||
| - ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container) | - ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container) | ||||
| - ./volumes/certbot/conf:/etc/letsencrypt | - ./volumes/certbot/conf:/etc/letsencrypt | ||||
| - ./volumes/certbot/www:/var/www/html | - ./volumes/certbot/www:/var/www/html | ||||
| entrypoint: | |||||
| [ | |||||
| 'sh', | |||||
| '-c', | |||||
| "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh", | |||||
| ] | |||||
| entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ] | |||||
| environment: | environment: | ||||
| NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_} | NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_} | ||||
| NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false} | NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false} | ||||
| working_dir: /opt/couchbase | working_dir: /opt/couchbase | ||||
| stdin_open: true | stdin_open: true | ||||
| tty: true | tty: true | ||||
| entrypoint: [""] | |||||
| entrypoint: [ "" ] | |||||
| command: sh -c "/opt/couchbase/init/init-cbserver.sh" | command: sh -c "/opt/couchbase/init/init-cbserver.sh" | ||||
| volumes: | volumes: | ||||
| - ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data | - ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data | ||||
| volumes: | volumes: | ||||
| - ./volumes/pgvector/data:/var/lib/postgresql/data | - ./volumes/pgvector/data:/var/lib/postgresql/data | ||||
| healthcheck: | healthcheck: | ||||
| test: ['CMD', 'pg_isready'] | |||||
| test: [ 'CMD', 'pg_isready' ] | |||||
| interval: 1s | interval: 1s | ||||
| timeout: 3s | timeout: 3s | ||||
| retries: 30 | retries: 30 | ||||
| volumes: | volumes: | ||||
| - ./volumes/pgvecto_rs/data:/var/lib/postgresql/data | - ./volumes/pgvecto_rs/data:/var/lib/postgresql/data | ||||
| healthcheck: | healthcheck: | ||||
| test: ['CMD', 'pg_isready'] | |||||
| test: [ 'CMD', 'pg_isready' ] | |||||
| interval: 1s | interval: 1s | ||||
| timeout: 3s | timeout: 3s | ||||
| retries: 30 | retries: 30 | ||||
| - ./volumes/milvus/etcd:/etcd | - ./volumes/milvus/etcd:/etcd | ||||
| command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd | command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd | ||||
| healthcheck: | healthcheck: | ||||
| test: ['CMD', 'etcdctl', 'endpoint', 'health'] | |||||
| test: [ 'CMD', 'etcdctl', 'endpoint', 'health' ] | |||||
| interval: 30s | interval: 30s | ||||
| timeout: 20s | timeout: 20s | ||||
| retries: 3 | retries: 3 | ||||
| - ./volumes/milvus/minio:/minio_data | - ./volumes/milvus/minio:/minio_data | ||||
| command: minio server /minio_data --console-address ":9001" | command: minio server /minio_data --console-address ":9001" | ||||
| healthcheck: | healthcheck: | ||||
| test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live'] | |||||
| test: [ 'CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live' ] | |||||
| interval: 30s | interval: 30s | ||||
| timeout: 20s | timeout: 20s | ||||
| retries: 3 | retries: 3 | ||||
| image: milvusdb/milvus:v2.3.1 | image: milvusdb/milvus:v2.3.1 | ||||
| profiles: | profiles: | ||||
| - milvus | - milvus | ||||
| command: ['milvus', 'run', 'standalone'] | |||||
| command: [ 'milvus', 'run', 'standalone' ] | |||||
| environment: | environment: | ||||
| ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379} | ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379} | ||||
| MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000} | MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000} | ||||
| volumes: | volumes: | ||||
| - ./volumes/milvus/milvus:/var/lib/milvus | - ./volumes/milvus/milvus:/var/lib/milvus | ||||
| healthcheck: | healthcheck: | ||||
| test: ['CMD', 'curl', '-f', 'http://localhost:9091/healthz'] | |||||
| test: [ 'CMD', 'curl', '-f', 'http://localhost:9091/healthz' ] | |||||
| interval: 30s | interval: 30s | ||||
| start_period: 90s | start_period: 90s | ||||
| timeout: 20s | timeout: 20s | ||||
| ports: | ports: | ||||
| - ${ELASTICSEARCH_PORT:-9200}:9200 | - ${ELASTICSEARCH_PORT:-9200}:9200 | ||||
| healthcheck: | healthcheck: | ||||
| test: ['CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty'] | |||||
| test: [ 'CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty' ] | |||||
| interval: 30s | interval: 30s | ||||
| timeout: 10s | timeout: 10s | ||||
| retries: 50 | retries: 50 | ||||
| ports: | ports: | ||||
| - ${KIBANA_PORT:-5601}:5601 | - ${KIBANA_PORT:-5601}:5601 | ||||
| healthcheck: | healthcheck: | ||||
| test: ['CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1'] | |||||
| test: [ 'CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1' ] | |||||
| interval: 30s | interval: 30s | ||||
| timeout: 10s | timeout: 10s | ||||
| retries: 3 | retries: 3 | 
| MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true} | MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true} | ||||
| FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300} | FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300} | ||||
| ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} | ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} | ||||
| REFRESH_TOKEN_EXPIRE_DAYS: ${REFRESH_TOKEN_EXPIRE_DAYS:-30} | |||||
| APP_MAX_ACTIVE_REQUESTS: ${APP_MAX_ACTIVE_REQUESTS:-0} | APP_MAX_ACTIVE_REQUESTS: ${APP_MAX_ACTIVE_REQUESTS:-0} | ||||
| APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200} | APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200} | ||||
| DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0} | DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0} | ||||
| DIFY_PORT: ${DIFY_PORT:-5001} | DIFY_PORT: ${DIFY_PORT:-5001} | ||||
| SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-} | |||||
| SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-} | |||||
| SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-1} | |||||
| SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-gevent} | |||||
| SERVER_WORKER_CONNECTIONS: ${SERVER_WORKER_CONNECTIONS:-10} | |||||
| CELERY_WORKER_CLASS: ${CELERY_WORKER_CLASS:-} | CELERY_WORKER_CLASS: ${CELERY_WORKER_CLASS:-} | ||||
| GUNICORN_TIMEOUT: ${GUNICORN_TIMEOUT:-360} | GUNICORN_TIMEOUT: ${GUNICORN_TIMEOUT:-360} | ||||
| CELERY_WORKER_AMOUNT: ${CELERY_WORKER_AMOUNT:-} | CELERY_WORKER_AMOUNT: ${CELERY_WORKER_AMOUNT:-} | ||||
| MILVUS_TOKEN: ${MILVUS_TOKEN:-} | MILVUS_TOKEN: ${MILVUS_TOKEN:-} | ||||
| MILVUS_USER: ${MILVUS_USER:-root} | MILVUS_USER: ${MILVUS_USER:-root} | ||||
| MILVUS_PASSWORD: ${MILVUS_PASSWORD:-Milvus} | MILVUS_PASSWORD: ${MILVUS_PASSWORD:-Milvus} | ||||
| MILVUS_ENABLE_HYBRID_SEARCH: ${MILVUS_ENABLE_HYBRID_SEARCH:-False} | |||||
| MYSCALE_HOST: ${MYSCALE_HOST:-myscale} | MYSCALE_HOST: ${MYSCALE_HOST:-myscale} | ||||
| MYSCALE_PORT: ${MYSCALE_PORT:-8123} | MYSCALE_PORT: ${MYSCALE_PORT:-8123} | ||||
| MYSCALE_USER: ${MYSCALE_USER:-default} | MYSCALE_USER: ${MYSCALE_USER:-default} | ||||
| ENDPOINT_URL_TEMPLATE: ${ENDPOINT_URL_TEMPLATE:-http://localhost/e/{hook_id}} | ENDPOINT_URL_TEMPLATE: ${ENDPOINT_URL_TEMPLATE:-http://localhost/e/{hook_id}} | ||||
| MARKETPLACE_ENABLED: ${MARKETPLACE_ENABLED:-true} | MARKETPLACE_ENABLED: ${MARKETPLACE_ENABLED:-true} | ||||
| MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev} | MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev} | ||||
| TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-10} | |||||
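For reference, the Gunicorn-related variables introduced above (SERVER_WORKER_AMOUNT, SERVER_WORKER_CLASS, SERVER_WORKER_CONNECTIONS, GUNICORN_TIMEOUT) can be tuned from the deployment's .env file. A minimal sketch, assuming the api container's entrypoint forwards these values to its gunicorn invocation (which the variable names suggest but this excerpt does not show); the concrete numbers are illustrative only:

```bash
# Illustrative .env overrides; variable names come from the compose defaults above.
# Assumption: the api entrypoint builds its gunicorn command roughly as
#   gunicorn --workers $SERVER_WORKER_AMOUNT --worker-class $SERVER_WORKER_CLASS \
#            --worker-connections $SERVER_WORKER_CONNECTIONS --timeout $GUNICORN_TIMEOUT ...
SERVER_WORKER_AMOUNT=2          # number of worker processes
SERVER_WORKER_CLASS=gevent      # async worker class (the new default)
SERVER_WORKER_CONNECTIONS=100   # concurrent connections per gevent worker
GUNICORN_TIMEOUT=360            # worker timeout in seconds
```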
| services: | services: | ||||
| # API service | # API service | ||||
| CSP_WHITELIST: ${CSP_WHITELIST:-} | CSP_WHITELIST: ${CSP_WHITELIST:-} | ||||
| MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev} | MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev} | ||||
| MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace-plugin.dify.dev} | MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace-plugin.dify.dev} | ||||
| TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-} | |||||
| # The postgres database. | # The postgres database. | ||||
| db: | db: | ||||
| volumes: | volumes: | ||||
| - ./volumes/db/data:/var/lib/postgresql/data | - ./volumes/db/data:/var/lib/postgresql/data | ||||
| healthcheck: | healthcheck: | ||||
| test: ['CMD', 'pg_isready'] | |||||
| test: [ 'CMD', 'pg_isready' ] | |||||
| interval: 1s | interval: 1s | ||||
| timeout: 3s | timeout: 3s | ||||
| retries: 30 | retries: 30 | ||||
| # Set the Redis password when starting the Redis server. | # Set the Redis password when starting the Redis server. | ||||
| command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456} | command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456} | ||||
| healthcheck: | healthcheck: | ||||
| test: ['CMD', 'redis-cli', 'ping'] | |||||
| test: [ 'CMD', 'redis-cli', 'ping' ] | |||||
| # The DifySandbox | # The DifySandbox | ||||
| sandbox: | sandbox: | ||||
| volumes: | volumes: | ||||
| - ./volumes/sandbox/dependencies:/dependencies | - ./volumes/sandbox/dependencies:/dependencies | ||||
| healthcheck: | healthcheck: | ||||
| test: ['CMD', 'curl', '-f', 'http://localhost:8194/health'] | |||||
| test: [ 'CMD', 'curl', '-f', 'http://localhost:8194/health' ] | |||||
| networks: | networks: | ||||
| - ssrf_proxy_network | - ssrf_proxy_network | ||||
| volumes: | volumes: | ||||
| - ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template | - ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template | ||||
| - ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh | - ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh | ||||
| entrypoint: | |||||
| [ | |||||
| 'sh', | |||||
| '-c', | |||||
| "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh", | |||||
| ] | |||||
| entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ] | |||||
| environment: | environment: | ||||
| # Please modify the Squid environment variables to fit your network environment. | # Please modify the Squid environment variables to fit your network environment. | ||||
| HTTP_PORT: ${SSRF_HTTP_PORT:-3128} | HTTP_PORT: ${SSRF_HTTP_PORT:-3128} | ||||
| - CERTBOT_EMAIL=${CERTBOT_EMAIL} | - CERTBOT_EMAIL=${CERTBOT_EMAIL} | ||||
| - CERTBOT_DOMAIN=${CERTBOT_DOMAIN} | - CERTBOT_DOMAIN=${CERTBOT_DOMAIN} | ||||
| - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-} | - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-} | ||||
| entrypoint: ['/docker-entrypoint.sh'] | |||||
| command: ['tail', '-f', '/dev/null'] | |||||
| entrypoint: [ '/docker-entrypoint.sh' ] | |||||
| command: [ 'tail', '-f', '/dev/null' ] | |||||
| # The nginx reverse proxy. | # The nginx reverse proxy. | ||||
| # used for reverse proxying the API service and Web service. | # used for reverse proxying the API service and Web service. | ||||
| - ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container) | - ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container) | ||||
| - ./volumes/certbot/conf:/etc/letsencrypt | - ./volumes/certbot/conf:/etc/letsencrypt | ||||
| - ./volumes/certbot/www:/var/www/html | - ./volumes/certbot/www:/var/www/html | ||||
| entrypoint: | |||||
| [ | |||||
| 'sh', | |||||
| '-c', | |||||
| "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh", | |||||
| ] | |||||
| entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ] | |||||
| environment: | environment: | ||||
| NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_} | NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_} | ||||
| NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false} | NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false} | ||||
| working_dir: /opt/couchbase | working_dir: /opt/couchbase | ||||
| stdin_open: true | stdin_open: true | ||||
| tty: true | tty: true | ||||
| entrypoint: [""] | |||||
| entrypoint: [ "" ] | |||||
| command: sh -c "/opt/couchbase/init/init-cbserver.sh" | command: sh -c "/opt/couchbase/init/init-cbserver.sh" | ||||
| volumes: | volumes: | ||||
| - ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data | - ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data | ||||
| volumes: | volumes: | ||||
| - ./volumes/pgvector/data:/var/lib/postgresql/data | - ./volumes/pgvector/data:/var/lib/postgresql/data | ||||
| healthcheck: | healthcheck: | ||||
| test: ['CMD', 'pg_isready'] | |||||
| test: [ 'CMD', 'pg_isready' ] | |||||
| interval: 1s | interval: 1s | ||||
| timeout: 3s | timeout: 3s | ||||
| retries: 30 | retries: 30 | ||||
| volumes: | volumes: | ||||
| - ./volumes/pgvecto_rs/data:/var/lib/postgresql/data | - ./volumes/pgvecto_rs/data:/var/lib/postgresql/data | ||||
| healthcheck: | healthcheck: | ||||
| test: ['CMD', 'pg_isready'] | |||||
| test: [ 'CMD', 'pg_isready' ] | |||||
| interval: 1s | interval: 1s | ||||
| timeout: 3s | timeout: 3s | ||||
| retries: 30 | retries: 30 | ||||
| - ./volumes/milvus/etcd:/etcd | - ./volumes/milvus/etcd:/etcd | ||||
| command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd | command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd | ||||
| healthcheck: | healthcheck: | ||||
| test: ['CMD', 'etcdctl', 'endpoint', 'health'] | |||||
| test: [ 'CMD', 'etcdctl', 'endpoint', 'health' ] | |||||
| interval: 30s | interval: 30s | ||||
| timeout: 20s | timeout: 20s | ||||
| retries: 3 | retries: 3 | ||||
| - ./volumes/milvus/minio:/minio_data | - ./volumes/milvus/minio:/minio_data | ||||
| command: minio server /minio_data --console-address ":9001" | command: minio server /minio_data --console-address ":9001" | ||||
| healthcheck: | healthcheck: | ||||
| test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live'] | |||||
| test: [ 'CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live' ] | |||||
| interval: 30s | interval: 30s | ||||
| timeout: 20s | timeout: 20s | ||||
| retries: 3 | retries: 3 | ||||
| milvus-standalone: | milvus-standalone: | ||||
| container_name: milvus-standalone | container_name: milvus-standalone | ||||
| image: milvusdb/milvus:v2.3.1 | |||||
| image: milvusdb/milvus:v2.5.0-beta | |||||
| profiles: | profiles: | ||||
| - milvus | - milvus | ||||
| command: ['milvus', 'run', 'standalone'] | |||||
| command: [ 'milvus', 'run', 'standalone' ] | |||||
| environment: | environment: | ||||
| ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379} | ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379} | ||||
| MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000} | MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000} | ||||
| volumes: | volumes: | ||||
| - ./volumes/milvus/milvus:/var/lib/milvus | - ./volumes/milvus/milvus:/var/lib/milvus | ||||
| healthcheck: | healthcheck: | ||||
| test: ['CMD', 'curl', '-f', 'http://localhost:9091/healthz'] | |||||
| test: [ 'CMD', 'curl', '-f', 'http://localhost:9091/healthz' ] | |||||
| interval: 30s | interval: 30s | ||||
| start_period: 90s | start_period: 90s | ||||
| timeout: 20s | timeout: 20s | ||||
| container_name: elasticsearch | container_name: elasticsearch | ||||
| profiles: | profiles: | ||||
| - elasticsearch | - elasticsearch | ||||
| - elasticsearch-ja | |||||
| restart: always | restart: always | ||||
| volumes: | volumes: | ||||
| - ./elasticsearch/docker-entrypoint.sh:/docker-entrypoint-mount.sh | |||||
| - dify_es01_data:/usr/share/elasticsearch/data | - dify_es01_data:/usr/share/elasticsearch/data | ||||
| environment: | environment: | ||||
| ELASTIC_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} | ELASTIC_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} | ||||
| VECTOR_STORE: ${VECTOR_STORE:-} | |||||
| cluster.name: dify-es-cluster | cluster.name: dify-es-cluster | ||||
| node.name: dify-es0 | node.name: dify-es0 | ||||
| discovery.type: single-node | discovery.type: single-node | ||||
| xpack.license.self_generated.type: trial | |||||
| xpack.license.self_generated.type: basic | |||||
| xpack.security.enabled: 'true' | xpack.security.enabled: 'true' | ||||
| xpack.security.enrollment.enabled: 'false' | xpack.security.enrollment.enabled: 'false' | ||||
| xpack.security.http.ssl.enabled: 'false' | xpack.security.http.ssl.enabled: 'false' | ||||
| ports: | ports: | ||||
| - ${ELASTICSEARCH_PORT:-9200}:9200 | - ${ELASTICSEARCH_PORT:-9200}:9200 | ||||
| deploy: | |||||
| resources: | |||||
| limits: | |||||
| memory: 2g | |||||
| entrypoint: [ 'sh', '-c', "sh /docker-entrypoint-mount.sh" ] | |||||
| healthcheck: | healthcheck: | ||||
| test: ['CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty'] | |||||
| test: [ 'CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty' ] | |||||
| interval: 30s | interval: 30s | ||||
| timeout: 10s | timeout: 10s | ||||
| retries: 50 | retries: 50 | ||||
| ports: | ports: | ||||
| - ${KIBANA_PORT:-5601}:5601 | - ${KIBANA_PORT:-5601}:5601 | ||||
| healthcheck: | healthcheck: | ||||
| test: ['CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1'] | |||||
| test: [ 'CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1' ] | |||||
| interval: 30s | interval: 30s | ||||
| timeout: 10s | timeout: 10s | ||||
| retries: 3 | retries: 3 | 
| #!/bin/bash | |||||
| set -e | |||||
| if [ "${VECTOR_STORE}" = "elasticsearch-ja" ]; then | |||||
| # Check if the ICU tokenizer plugin is installed | |||||
| if ! /usr/share/elasticsearch/bin/elasticsearch-plugin list | grep -q analysis-icu; then | |||||
| printf '%s\n' "Installing the ICU tokenizer plugin" | |||||
| if ! /usr/share/elasticsearch/bin/elasticsearch-plugin install analysis-icu; then | |||||
| printf '%s\n' "Failed to install the ICU tokenizer plugin" | |||||
| exit 1 | |||||
| fi | |||||
| fi | |||||
| # Check if the Japanese language analyzer plugin is installed | |||||
| if ! /usr/share/elasticsearch/bin/elasticsearch-plugin list | grep -q analysis-kuromoji; then | |||||
| printf '%s\n' "Installing the Japanese language analyzer plugin" | |||||
| if ! /usr/share/elasticsearch/bin/elasticsearch-plugin install analysis-kuromoji; then | |||||
| printf '%s\n' "Failed to install the Japanese language analyzer plugin" | |||||
| exit 1 | |||||
| fi | |||||
| fi | |||||
| fi | |||||
| # Run the original entrypoint script | |||||
| exec /bin/tini -- /usr/local/bin/docker-entrypoint.sh | 
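The compose additions above mount this script into the Elasticsearch container and pass VECTOR_STORE through, so the plugin installation only runs when the Japanese-oriented store is selected. A minimal opt-in sketch, assuming the service is named elasticsearch and that the stack reads its variables from .env:

```bash
# Sketch: enable the Japanese analyzers (profile and variable names taken from the diff above).
echo 'VECTOR_STORE=elasticsearch-ja' >> .env   # the entrypoint then installs analysis-icu and analysis-kuromoji
docker compose --profile elasticsearch-ja up -d elasticsearch
```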
| # CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP | # CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP | ||||
| NEXT_PUBLIC_CSP_WHITELIST= | NEXT_PUBLIC_CSP_WHITELIST= | ||||
| # The maximum top-k value for RAG. | |||||
| NEXT_PUBLIC_TOP_K_MAX_VALUE=10 | 
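Note that the retrieval top-k ceiling now exists in two places: TOP_K_MAX_VALUE on the API/compose side and NEXT_PUBLIC_TOP_K_MAX_VALUE in the web build. Presumably the former bounds server-side validation and the latter the UI control, so raising the limit likely means setting both; a sketch with an illustrative value:

```bash
# Illustrative only: raise the maximum selectable top-k for RAG retrieval.
TOP_K_MAX_VALUE=20               # API / docker compose environment
NEXT_PUBLIC_TOP_K_MAX_VALUE=20   # web frontend environment
```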
| import { useStore as useTagStore } from '@/app/components/base/tag-management/store' | import { useStore as useTagStore } from '@/app/components/base/tag-management/store' | ||||
| import TagManagementModal from '@/app/components/base/tag-management' | import TagManagementModal from '@/app/components/base/tag-management' | ||||
| import TagFilter from '@/app/components/base/tag-management/filter' | import TagFilter from '@/app/components/base/tag-management/filter' | ||||
| import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label' | |||||
| const getKey = ( | const getKey = ( | ||||
| pageIndex: number, | pageIndex: number, | ||||
| previousPageData: AppListResponse, | previousPageData: AppListResponse, | ||||
| activeTab: string, | activeTab: string, | ||||
| isCreatedByMe: boolean, | |||||
| tags: string[], | tags: string[], | ||||
| keywords: string, | keywords: string, | ||||
| ) => { | ) => { | ||||
| if (!pageIndex || previousPageData.has_more) { | if (!pageIndex || previousPageData.has_more) { | ||||
| const params: any = { url: 'apps', params: { page: pageIndex + 1, limit: 30, name: keywords } } | |||||
| const params: any = { url: 'apps', params: { page: pageIndex + 1, limit: 30, name: keywords, is_created_by_me: isCreatedByMe } } | |||||
| if (activeTab !== 'all') | if (activeTab !== 'all') | ||||
| params.params.mode = activeTab | params.params.mode = activeTab | ||||
| defaultTab: 'all', | defaultTab: 'all', | ||||
| }) | }) | ||||
| const { query: { tagIDs = [], keywords = '' }, setQuery } = useAppsQueryState() | const { query: { tagIDs = [], keywords = '' }, setQuery } = useAppsQueryState() | ||||
| const [isCreatedByMe, setIsCreatedByMe] = useState(false) | |||||
| const [tagFilterValue, setTagFilterValue] = useState<string[]>(tagIDs) | const [tagFilterValue, setTagFilterValue] = useState<string[]>(tagIDs) | ||||
| const [searchKeywords, setSearchKeywords] = useState(keywords) | const [searchKeywords, setSearchKeywords] = useState(keywords) | ||||
| const setKeywords = useCallback((keywords: string) => { | const setKeywords = useCallback((keywords: string) => { | ||||
| }, [setQuery]) | }, [setQuery]) | ||||
| const { data, isLoading, setSize, mutate } = useSWRInfinite( | const { data, isLoading, setSize, mutate } = useSWRInfinite( | ||||
| (pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, tagIDs, searchKeywords), | |||||
| (pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, isCreatedByMe, tagIDs, searchKeywords), | |||||
| fetchAppList, | fetchAppList, | ||||
| { revalidateFirstPage: true }, | { revalidateFirstPage: true }, | ||||
| ) | ) | ||||
| options={options} | options={options} | ||||
| /> | /> | ||||
| <div className='flex items-center gap-2'> | <div className='flex items-center gap-2'> | ||||
| <CheckboxWithLabel | |||||
| className='mr-2' | |||||
| label={t('app.showMyCreatedAppsOnly')} | |||||
| isChecked={isCreatedByMe} | |||||
| onChange={() => setIsCreatedByMe(!isCreatedByMe)} | |||||
| /> | |||||
| <TagFilter type='app' value={tagFilterValue} onChange={handleTagsChange} /> | <TagFilter type='app' value={tagFilterValue} onChange={handleTagsChange} /> | ||||
| <Input | <Input | ||||
| showLeftIcon | showLeftIcon | 
| - <code>high_quality</code> High quality: embedding using embedding model, built as vector database index | - <code>high_quality</code> High quality: embedding using embedding model, built as vector database index | ||||
| - <code>economy</code> Economy: Build using inverted index of keyword table index | - <code>economy</code> Economy: Build using inverted index of keyword table index | ||||
| </Property> | </Property> | ||||
| <Property name='doc_form' type='string' key='doc_form'> | |||||
| Format of indexed content | |||||
| - <code>text_model</code> Text documents are directly embedded; <code>economy</code> mode defaults to using this form | |||||
| - <code>hierarchical_model</code> Parent-child mode | |||||
| - <code>qa_model</code> Q&A Mode: Generates Q&A pairs for segmented documents and then embeds the questions | |||||
| </Property> | |||||
| <Property name='doc_language' type='string' key='doc_language'> | |||||
| In Q&A mode, specify the language of the document, for example: <code>English</code>, <code>Chinese</code> | |||||
| </Property> | |||||
| <Property name='process_rule' type='object' key='process_rule'> | <Property name='process_rule' type='object' key='process_rule'> | ||||
| Processing rules | Processing rules | ||||
| - <code>mode</code> (string) Cleaning, segmentation mode, automatic / custom | - <code>mode</code> (string) Cleaning, segmentation mode, automatic / custom | ||||
| - <code>segmentation</code> (object) Segmentation rules | - <code>segmentation</code> (object) Segmentation rules | ||||
| - <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n | - <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n | ||||
| - <code>max_tokens</code> Maximum length (token) defaults to 1000 | - <code>max_tokens</code> Maximum length (token) defaults to 1000 | ||||
| - <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval | |||||
| - <code>subchunk_segmentation</code> (object) Child chunk rules | |||||
| - <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code> | |||||
| - <code>max_tokens</code> Maximum length (tokens); must be less than the parent chunk's maximum length | |||||
| - <code>chunk_overlap</code> Defines the overlap between adjacent chunks (optional) | |||||
| </Property> | </Property> | ||||
| </Properties> | </Properties> | ||||
| </Col> | </Col> | ||||
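To make the new fields above concrete (doc_form, doc_language, parent_mode, subchunk_segmentation), here is a hedged request sketch. The {api_base_url} placeholder, the create-by-text path, the Authorization header, and the name/text fields are assumptions drawn from the wider API reference rather than from this excerpt:

```bash
# Sketch only: base URL, endpoint path, auth header, and the name/text fields are assumed.
curl -X POST '{api_base_url}/datasets/{dataset_id}/document/create-by-text' \
  --header 'Authorization: Bearer {api_key}' \
  --header 'Content-Type: application/json' \
  --data-raw '{
    "name": "example-doc",
    "text": "Parent paragraph one.***child chunk a***child chunk b",
    "indexing_technique": "high_quality",
    "doc_form": "hierarchical_model",
    "process_rule": {
      "mode": "custom",
      "rules": {
        "segmentation": { "separator": "\n", "max_tokens": 1000 },
        "parent_mode": "paragraph",
        "subchunk_segmentation": { "separator": "***", "max_tokens": 200 }
      }
    }
  }'
```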
| - <code>high_quality</code> High quality: embedding using embedding model, built as vector database index | - <code>high_quality</code> High quality: embedding using embedding model, built as vector database index | ||||
| - <code>economy</code> Economy: Build using inverted index of keyword table index | - <code>economy</code> Economy: Build using inverted index of keyword table index | ||||
| - <code>doc_form</code> Format of indexed content | |||||
| - <code>text_model</code> Text documents are directly embedded; <code>economy</code> mode defaults to using this form | |||||
| - <code>hierarchical_model</code> Parent-child mode | |||||
| - <code>qa_model</code> Q&A Mode: Generates Q&A pairs for segmented documents and then embeds the questions | |||||
| - <code>doc_language</code> In Q&A mode, specify the language of the document, for example: <code>English</code>, <code>Chinese</code> | |||||
| - <code>process_rule</code> Processing rules | - <code>process_rule</code> Processing rules | ||||
| - <code>mode</code> (string) Cleaning, segmentation mode, automatic / custom | - <code>mode</code> (string) Cleaning, segmentation mode, automatic / custom | ||||
| - <code>rules</code> (object) Custom rules (in automatic mode, this field is empty) | - <code>rules</code> (object) Custom rules (in automatic mode, this field is empty) | ||||
| - <code>segmentation</code> (object) Segmentation rules | - <code>segmentation</code> (object) Segmentation rules | ||||
| - <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n | - <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n | ||||
| - <code>max_tokens</code> Maximum length (token) defaults to 1000 | - <code>max_tokens</code> Maximum length (token) defaults to 1000 | ||||
| - <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval | |||||
| - <code>subchunk_segmentation</code> (object) Child chunk rules | |||||
| - <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code> | |||||
| - <code>max_tokens</code> Maximum length (tokens); must be less than the parent chunk's maximum length | |||||
| - <code>chunk_overlap</code> Defines the overlap between adjacent chunks (optional) | |||||
| </Property> | </Property> | ||||
| <Property name='file' type='multipart/form-data' key='file'> | <Property name='file' type='multipart/form-data' key='file'> | ||||
| Files that need to be uploaded. | Files that need to be uploaded. | ||||
| - <code>segmentation</code> (object) Segmentation rules | - <code>segmentation</code> (object) Segmentation rules | ||||
| - <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n | - <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n | ||||
| - <code>max_tokens</code> Maximum length (token) defaults to 1000 | - <code>max_tokens</code> Maximum length (token) defaults to 1000 | ||||
| - <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval | |||||
| - <code>subchunk_segmentation</code> (object) Child chunk rules | |||||
| - <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code> | |||||
| - <code>max_tokens</code> Maximum length (tokens); must be less than the parent chunk's maximum length | |||||
| - <code>chunk_overlap</code> Defines the overlap between adjacent chunks (optional) | |||||
| </Property> | </Property> | ||||
| </Properties> | </Properties> | ||||
| </Col> | </Col> | ||||
| - <code>segmentation</code> (object) Segmentation rules | - <code>segmentation</code> (object) Segmentation rules | ||||
| - <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n | - <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n | ||||
| - <code>max_tokens</code> Maximum length (token) defaults to 1000 | - <code>max_tokens</code> Maximum length (token) defaults to 1000 | ||||
| - <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval | |||||
| - <code>subchunk_segmentation</code> (object) Child chunk rules | |||||
| - <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code> | |||||
| - <code>max_tokens</code> Maximum length (tokens); must be less than the parent chunk's maximum length | |||||
| - <code>chunk_overlap</code> Defines the overlap between adjacent chunks (optional) | |||||
| </Property> | </Property> | ||||
| </Properties> | </Properties> | ||||
| </Col> | </Col> | ||||
| <Heading | <Heading | ||||
| url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}' | url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}' | ||||
| method='POST' | method='POST' | ||||
| title='Update a Chunk in a Document ' | |||||
| title='Update a Chunk in a Document' | |||||
| name='#update_segment' | name='#update_segment' | ||||
| /> | /> | ||||
| <Row> | <Row> | ||||
| - <code>answer</code> (text) Answer content, passed if the knowledge is in Q&A mode (optional) | - <code>answer</code> (text) Answer content, passed if the knowledge is in Q&A mode (optional) | ||||
| - <code>keywords</code> (list) Keyword (optional) | - <code>keywords</code> (list) Keyword (optional) | ||||
| - <code>enabled</code> (bool) False / true (optional) | - <code>enabled</code> (bool) False / true (optional) | ||||
| - <code>regenerate_child_chunks</code> (bool) Whether to regenerate child chunks (optional) | |||||
| </Property> | </Property> | ||||
| </Properties> | </Properties> | ||||
| </Col> | </Col> | 
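A similar sketch for the chunk-update endpoint documented above, showing the new regenerate_child_chunks flag. The segment wrapper and the content field are assumptions from the wider reference; only answer, keywords, enabled, and regenerate_child_chunks appear in this excerpt:

```bash
# Sketch only: base URL, auth header, and the segment/content body shape are assumed.
curl -X POST '{api_base_url}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}' \
  --header 'Authorization: Bearer {api_key}' \
  --header 'Content-Type: application/json' \
  --data-raw '{
    "segment": {
      "content": "Updated parent chunk text",
      "enabled": true,
      "regenerate_child_chunks": true
    }
  }'
```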
| - <code>high_quality</code> 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引 | - <code>high_quality</code> 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引 | ||||
| - <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建 | - <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建 | ||||
| </Property> | </Property> | ||||
| <Property name='doc_form' type='string' key='doc_form'> | |||||
| 索引内容的形式 | |||||
| - <code>text_model</code> text 文档直接 embedding,经济模式默认为该模式 | |||||
| - <code>hierarchical_model</code> parent-child 模式 | |||||
| - <code>qa_model</code> Q&A 模式:为分片文档生成 Q&A 对,然后对问题进行 embedding | |||||
| </Property> | |||||
| <Property name='doc_language' type='string' key='doc_language'> | |||||
| 在 Q&A 模式下,指定文档的语言,例如:<code>English</code>、<code>Chinese</code> | |||||
| </Property> | |||||
| <Property name='process_rule' type='object' key='process_rule'> | <Property name='process_rule' type='object' key='process_rule'> | ||||
| 处理规则 | 处理规则 | ||||
| - <code>mode</code> (string) 清洗、分段模式 ,automatic 自动 / custom 自定义 | - <code>mode</code> (string) 清洗、分段模式 ,automatic 自动 / custom 自定义 | ||||
| - <code>remove_urls_emails</code> 删除 URL、电子邮件地址 | - <code>remove_urls_emails</code> 删除 URL、电子邮件地址 | ||||
| - <code>enabled</code> (bool) 是否选中该规则,不传入文档 ID 时代表默认值 | - <code>enabled</code> (bool) 是否选中该规则,不传入文档 ID 时代表默认值 | ||||
| - <code>segmentation</code> (object) 分段规则 | - <code>segmentation</code> (object) 分段规则 | ||||
| - <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n | |||||
| - <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 <code>\n</code> | |||||
| - <code>max_tokens</code> 最大长度(token)默认为 1000 | - <code>max_tokens</code> 最大长度(token)默认为 1000 | ||||
| - <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回 | |||||
| - <code>subchunk_segmentation</code> (object) 子分段规则 | |||||
| - <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code> | |||||
| - <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度 | |||||
| - <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填) | |||||
| </Property> | </Property> | ||||
| </Properties> | </Properties> | ||||
| </Col> | </Col> | ||||
| - <code>high_quality</code> 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引 | - <code>high_quality</code> 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引 | ||||
| - <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建 | - <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建 | ||||
| - <code>doc_form</code> 索引内容的形式 | |||||
| - <code>text_model</code> text 文档直接 embedding,经济模式默认为该模式 | |||||
| - <code>hierarchical_model</code> parent-child 模式 | |||||
| - <code>qa_model</code> Q&A 模式:为分片文档生成 Q&A 对,然后对问题进行 embedding | |||||
| - <code>doc_language</code> 在 Q&A 模式下,指定文档的语言,例如:<code>English</code>、<code>Chinese</code> | |||||
| - <code>process_rule</code> 处理规则 | - <code>process_rule</code> 处理规则 | ||||
| - <code>mode</code> (string) 清洗、分段模式 ,automatic 自动 / custom 自定义 | - <code>mode</code> (string) 清洗、分段模式 ,automatic 自动 / custom 自定义 | ||||
| - <code>rules</code> (object) 自定义规则(自动模式下,该字段为空) | - <code>rules</code> (object) 自定义规则(自动模式下,该字段为空) | ||||
| - <code>segmentation</code> (object) 分段规则 | - <code>segmentation</code> (object) 分段规则 | ||||
| - <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n | - <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n | ||||
| - <code>max_tokens</code> 最大长度(token)默认为 1000 | - <code>max_tokens</code> 最大长度(token)默认为 1000 | ||||
| - <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回 | |||||
| - <code>subchunk_segmentation</code> (object) 子分段规则 | |||||
| - <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code> | |||||
| - <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度 | |||||
| - <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填) | |||||
| </Property> | </Property> | ||||
| <Property name='file' type='multipart/form-data' key='file'> | <Property name='file' type='multipart/form-data' key='file'> | ||||
| 需要上传的文件。 | 需要上传的文件。 | ||||
| <Heading | <Heading | ||||
| url='/datasets/{dataset_id}/documents/{document_id}/update-by-text' | url='/datasets/{dataset_id}/documents/{document_id}/update-by-text' | ||||
| method='POST' | method='POST' | ||||
| title='通过文本更新文档 ' | |||||
| title='通过文本更新文档' | |||||
| name='#update-by-text' | name='#update-by-text' | ||||
| /> | /> | ||||
| <Row> | <Row> | ||||
| - <code>segmentation</code> (object) 分段规则 | - <code>segmentation</code> (object) 分段规则 | ||||
| - <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n | - <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n | ||||
| - <code>max_tokens</code> 最大长度(token)默认为 1000 | - <code>max_tokens</code> 最大长度(token)默认为 1000 | ||||
| - <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回 | |||||
| - <code>subchunk_segmentation</code> (object) 子分段规则 | |||||
| - <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code> | |||||
| - <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度 | |||||
| - <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填) | |||||
| </Property> | </Property> | ||||
| </Properties> | </Properties> | ||||
| </Col> | </Col> | ||||
| <Heading | <Heading | ||||
| url='/datasets/{dataset_id}/documents/{document_id}/update-by-file' | url='/datasets/{dataset_id}/documents/{document_id}/update-by-file' | ||||
| method='POST' | method='POST' | ||||
| title='通过文件更新文档 ' | |||||
| title='通过文件更新文档' | |||||
| name='#update-by-file' | name='#update-by-file' | ||||
| /> | /> | ||||
| <Row> | <Row> | ||||
| - <code>segmentation</code> (object) 分段规则 | - <code>segmentation</code> (object) 分段规则 | ||||
| - <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n | - <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n | ||||
| - <code>max_tokens</code> 最大长度(token)默认为 1000 | - <code>max_tokens</code> 最大长度(token)默认为 1000 | ||||
| - <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回 | |||||
| - <code>subchunk_segmentation</code> (object) 子分段规则 | |||||
| - <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code> | |||||
| - <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度 | |||||
| - <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填) | |||||
| </Property> | </Property> | ||||
| </Properties> | </Properties> | ||||
| </Col> | </Col> | ||||
| - <code>answer</code> (text) 答案内容,非必填,如果知识库的模式为 Q&A 模式则传值 | - <code>answer</code> (text) 答案内容,非必填,如果知识库的模式为 Q&A 模式则传值 | ||||
| - <code>keywords</code> (list) 关键字,非必填 | - <code>keywords</code> (list) 关键字,非必填 | ||||
| - <code>enabled</code> (bool) false/true,非必填 | - <code>enabled</code> (bool) false/true,非必填 | ||||
| - <code>regenerate_child_chunks</code> (bool) 是否重新生成子分段,非必填 | |||||
| </Property> | </Property> | ||||
| </Properties> | </Properties> | ||||
| </Col> | </Col> | 
| const [clientY, setClientY] = useState(0) | const [clientY, setClientY] = useState(0) | ||||
| const [isResizing, setIsResizing] = useState(false) | const [isResizing, setIsResizing] = useState(false) | ||||
| const [prevUserSelectStyle, setPrevUserSelectStyle] = useState(getComputedStyle(document.body).userSelect) | const [prevUserSelectStyle, setPrevUserSelectStyle] = useState(getComputedStyle(document.body).userSelect) | ||||
| const [oldHeight, setOldHeight] = useState(height) | |||||
| const handleStartResize = useCallback((e: React.MouseEvent<HTMLElement>) => { | const handleStartResize = useCallback((e: React.MouseEvent<HTMLElement>) => { | ||||
| setClientY(e.clientY) | setClientY(e.clientY) | ||||
| setIsResizing(true) | setIsResizing(true) | ||||
| setOldHeight(height) | |||||
| setPrevUserSelectStyle(getComputedStyle(document.body).userSelect) | setPrevUserSelectStyle(getComputedStyle(document.body).userSelect) | ||||
| document.body.style.userSelect = 'none' | document.body.style.userSelect = 'none' | ||||
| }, []) | |||||
| }, [height]) | |||||
| const handleStopResize = useCallback(() => { | const handleStopResize = useCallback(() => { | ||||
| setIsResizing(false) | setIsResizing(false) | ||||
| return | return | ||||
| const offset = e.clientY - clientY | const offset = e.clientY - clientY | ||||
| let newHeight = height + offset | |||||
| setClientY(e.clientY) | |||||
| let newHeight = oldHeight + offset | |||||
| if (newHeight < minHeight) | if (newHeight < minHeight) | ||||
| newHeight = minHeight | newHeight = minHeight | ||||
| onHeightChange(newHeight) | onHeightChange(newHeight) | 
| import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block' | import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block' | ||||
| import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from '@/app/components/base/prompt-editor/plugins/update-block' | import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from '@/app/components/base/prompt-editor/plugins/update-block' | ||||
| import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' | import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' | ||||
| import { useFeaturesStore } from '@/app/components/base/features/hooks' | |||||
| export type ISimplePromptInput = { | export type ISimplePromptInput = { | ||||
| mode: AppType | mode: AppType | ||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| const media = useBreakpoints() | const media = useBreakpoints() | ||||
| const isMobile = media === MediaType.mobile | const isMobile = media === MediaType.mobile | ||||
| const featuresStore = useFeaturesStore() | |||||
| const { | |||||
| features, | |||||
| setFeatures, | |||||
| } = featuresStore!.getState() | |||||
| const { eventEmitter } = useEventEmitterContextContext() | const { eventEmitter } = useEventEmitterContextContext() | ||||
| const { | const { | ||||
| }) | }) | ||||
| setModelConfig(newModelConfig) | setModelConfig(newModelConfig) | ||||
| setPrevPromptConfig(modelConfig.configs) | setPrevPromptConfig(modelConfig.configs) | ||||
| if (mode !== AppType.completion) | |||||
| if (mode !== AppType.completion) { | |||||
| setIntroduction(res.opening_statement) | setIntroduction(res.opening_statement) | ||||
| const newFeatures = produce(features, (draft) => { | |||||
| draft.opening = { | |||||
| ...draft.opening, | |||||
| enabled: !!res.opening_statement, | |||||
| opening_statement: res.opening_statement, | |||||
| } | |||||
| }) | |||||
| setFeatures(newFeatures) | |||||
| } | |||||
| showAutomaticFalse() | showAutomaticFalse() | ||||
| } | } | ||||
| const minHeight = initEditorHeight || 228 | const minHeight = initEditorHeight || 228 | 
| const { | const { | ||||
| modelList: rerankModelList, | modelList: rerankModelList, | ||||
| defaultModel: rerankDefaultModel, | |||||
| currentModel: isRerankDefaultModelValid, | |||||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | ||||
| const { | const { | ||||
| currentModel: currentRerankModel, | currentModel: currentRerankModel, | ||||
| } = useCurrentProviderAndModel( | } = useCurrentProviderAndModel( | ||||
| rerankModelList, | rerankModelList, | ||||
| rerankDefaultModel | |||||
| ? { | |||||
| ...rerankDefaultModel, | |||||
| provider: rerankDefaultModel.provider.provider, | |||||
| } | |||||
| : undefined, | |||||
| { | |||||
| provider: datasetConfigs.reranking_model?.reranking_provider_name, | |||||
| model: datasetConfigs.reranking_model?.reranking_model_name, | |||||
| }, | |||||
| ) | ) | ||||
| const rerankModel = (() => { | |||||
| if (datasetConfigs.reranking_model?.reranking_provider_name) { | |||||
| return { | |||||
| provider_name: datasetConfigs.reranking_model.reranking_provider_name, | |||||
| model_name: datasetConfigs.reranking_model.reranking_model_name, | |||||
| } | |||||
| const rerankModel = useMemo(() => { | |||||
| return { | |||||
| provider_name: datasetConfigs?.reranking_model?.reranking_provider_name ?? '', | |||||
| model_name: datasetConfigs?.reranking_model?.reranking_model_name ?? '', | |||||
| } | } | ||||
| else if (rerankDefaultModel) { | |||||
| return { | |||||
| provider_name: rerankDefaultModel.provider.provider, | |||||
| model_name: rerankDefaultModel.model, | |||||
| } | |||||
| } | |||||
| })() | |||||
| }, [datasetConfigs.reranking_model]) | |||||
| const handleParamChange = (key: string, value: number) => { | const handleParamChange = (key: string, value: number) => { | ||||
| if (key === 'top_k') { | if (key === 'top_k') { | ||||
| } | } | ||||
| const handleRerankModeChange = (mode: RerankingModeEnum) => { | const handleRerankModeChange = (mode: RerankingModeEnum) => { | ||||
| if (mode === datasetConfigs.reranking_mode) | |||||
| return | |||||
| if (mode === RerankingModeEnum.RerankingModel && !currentRerankModel) | |||||
| Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) | |||||
| onChange({ | onChange({ | ||||
| ...datasetConfigs, | ...datasetConfigs, | ||||
| reranking_mode: mode, | reranking_mode: mode, | ||||
| const canManuallyToggleRerank = useMemo(() => { | const canManuallyToggleRerank = useMemo(() => { | ||||
| return (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic) | return (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic) | ||||
| || selectedDatasetsMode.allExternal | |||||
| || selectedDatasetsMode.allExternal | |||||
| }, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal]) | }, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal]) | ||||
| const showRerankModel = useMemo(() => { | const showRerankModel = useMemo(() => { | ||||
| if (!canManuallyToggleRerank) | if (!canManuallyToggleRerank) | ||||
| return true | return true | ||||
| else if (canManuallyToggleRerank && !isRerankDefaultModelValid) | |||||
| return false | |||||
| return datasetConfigs.reranking_enable | return datasetConfigs.reranking_enable | ||||
| }, [canManuallyToggleRerank, datasetConfigs.reranking_enable, isRerankDefaultModelValid]) | |||||
| }, [datasetConfigs.reranking_enable, canManuallyToggleRerank]) | |||||
| const handleDisabledSwitchClick = useCallback(() => { | |||||
| if (!currentRerankModel && !showRerankModel) | |||||
| const handleDisabledSwitchClick = useCallback((enable: boolean) => { | |||||
| if (!currentRerankModel && enable) | |||||
| Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) | Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) | ||||
| }, [currentRerankModel, showRerankModel, t]) | |||||
| useEffect(() => { | |||||
| if (canManuallyToggleRerank && showRerankModel !== datasetConfigs.reranking_enable) { | |||||
| onChange({ | |||||
| ...datasetConfigs, | |||||
| reranking_enable: showRerankModel, | |||||
| }) | |||||
| } | |||||
| }, [canManuallyToggleRerank, showRerankModel, datasetConfigs, onChange]) | |||||
| onChange({ | |||||
| ...datasetConfigs, | |||||
| reranking_enable: enable, | |||||
| }) | |||||
| // eslint-disable-next-line react-hooks/exhaustive-deps | |||||
| }, [currentRerankModel, datasetConfigs, onChange]) | |||||
| return ( | return ( | ||||
| <div> | <div> | ||||
| <div className='flex items-center'> | <div className='flex items-center'> | ||||
| { | { | ||||
| selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && ( | selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && ( | ||||
| <div | |||||
| className='flex items-center' | |||||
| onClick={handleDisabledSwitchClick} | |||||
| > | |||||
| <Switch | |||||
| size='md' | |||||
| defaultValue={showRerankModel} | |||||
| disabled={!currentRerankModel || !canManuallyToggleRerank} | |||||
| onChange={(v) => { | |||||
| if (canManuallyToggleRerank) { | |||||
| onChange({ | |||||
| ...datasetConfigs, | |||||
| reranking_enable: v, | |||||
| }) | |||||
| } | |||||
| }} | |||||
| /> | |||||
| </div> | |||||
| <Switch | |||||
| size='md' | |||||
| defaultValue={showRerankModel} | |||||
| disabled={!canManuallyToggleRerank} | |||||
| onChange={handleDisabledSwitchClick} | |||||
| /> | |||||
| ) | ) | ||||
| } | } | ||||
| <div className='leading-[32px] ml-1 text-text-secondary system-sm-semibold'>{t('common.modelProvider.rerankModel.key')}</div> | <div className='leading-[32px] ml-1 text-text-secondary system-sm-semibold'>{t('common.modelProvider.rerankModel.key')}</div> | ||||
| triggerClassName='ml-1 w-4 h-4' | triggerClassName='ml-1 w-4 h-4' | ||||
| /> | /> | ||||
| </div> | </div> | ||||
| <div> | |||||
| <ModelSelector | |||||
| defaultModel={rerankModel && { provider: rerankModel?.provider_name, model: rerankModel?.model_name }} | |||||
| onSelect={(v) => { | |||||
| onChange({ | |||||
| ...datasetConfigs, | |||||
| reranking_model: { | |||||
| reranking_provider_name: v.provider, | |||||
| reranking_model_name: v.model, | |||||
| }, | |||||
| }) | |||||
| }} | |||||
| modelList={rerankModelList} | |||||
| /> | |||||
| </div> | |||||
| { | |||||
| showRerankModel && ( | |||||
| <div> | |||||
| <ModelSelector | |||||
| defaultModel={rerankModel && { provider: rerankModel?.provider_name, model: rerankModel?.model_name }} | |||||
| onSelect={(v) => { | |||||
| onChange({ | |||||
| ...datasetConfigs, | |||||
| reranking_model: { | |||||
| reranking_provider_name: v.provider, | |||||
| reranking_model_name: v.model, | |||||
| }, | |||||
| }) | |||||
| }} | |||||
| modelList={rerankModelList} | |||||
| /> | |||||
| </div> | |||||
| )} | |||||
| </div> | </div> | ||||
| ) | ) | ||||
| } | } | 
| import Button from '@/app/components/base/button' | import Button from '@/app/components/base/button' | ||||
| import { RETRIEVE_TYPE } from '@/types/app' | import { RETRIEVE_TYPE } from '@/types/app' | ||||
| import Toast from '@/app/components/base/toast' | import Toast from '@/app/components/base/toast' | ||||
| import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||||
| import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||||
| import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | ||||
| import { RerankingModeEnum } from '@/models/datasets' | import { RerankingModeEnum } from '@/models/datasets' | ||||
| import type { DataSet } from '@/models/datasets' | import type { DataSet } from '@/models/datasets' | ||||
| }, [datasetConfigs]) | }, [datasetConfigs]) | ||||
| const { | const { | ||||
| defaultModel: rerankDefaultModel, | |||||
| currentModel: isRerankDefaultModelValid, | |||||
| modelList: rerankModelList, | |||||
| currentModel: rerankDefaultModel, | |||||
| currentProvider: rerankDefaultProvider, | currentProvider: rerankDefaultProvider, | ||||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | ||||
| const { | |||||
| currentModel: isCurrentRerankModelValid, | |||||
| } = useCurrentProviderAndModel( | |||||
| rerankModelList, | |||||
| { | |||||
| provider: tempDataSetConfigs.reranking_model?.reranking_provider_name ?? '', | |||||
| model: tempDataSetConfigs.reranking_model?.reranking_model_name ?? '', | |||||
| }, | |||||
| ) | |||||
| const isValid = () => { | const isValid = () => { | ||||
| let errMsg = '' | let errMsg = '' | ||||
| if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) { | if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) { | ||||
| if (tempDataSetConfigs.reranking_enable | if (tempDataSetConfigs.reranking_enable | ||||
| && tempDataSetConfigs.reranking_mode === RerankingModeEnum.RerankingModel | && tempDataSetConfigs.reranking_mode === RerankingModeEnum.RerankingModel | ||||
| && !isRerankDefaultModelValid | |||||
| && !isCurrentRerankModelValid | |||||
| ) | ) | ||||
| errMsg = t('appDebug.datasetConfig.rerankModelRequired') | errMsg = t('appDebug.datasetConfig.rerankModelRequired') | ||||
| } | } | ||||
| const handleSave = () => { | const handleSave = () => { | ||||
| if (!isValid()) | if (!isValid()) | ||||
| return | return | ||||
| const config = { ...tempDataSetConfigs } | |||||
| if (config.retrieval_model === RETRIEVE_TYPE.multiWay | |||||
| && config.reranking_mode === RerankingModeEnum.RerankingModel | |||||
| && !config.reranking_model) { | |||||
| config.reranking_model = { | |||||
| reranking_provider_name: rerankDefaultModel?.provider?.provider, | |||||
| reranking_model_name: rerankDefaultModel?.model, | |||||
| } as any | |||||
| } | |||||
| setDatasetConfigs(config) | |||||
| setDatasetConfigs(tempDataSetConfigs) | |||||
| setRerankSettingModalOpen(false) | setRerankSettingModalOpen(false) | ||||
| } | } | ||||
| reranking_enable: restConfigs.reranking_enable, | reranking_enable: restConfigs.reranking_enable, | ||||
| }, selectedDatasets, selectedDatasets, { | }, selectedDatasets, selectedDatasets, { | ||||
| provider: rerankDefaultProvider?.provider, | provider: rerankDefaultProvider?.provider, | ||||
| model: isRerankDefaultModelValid?.model, | |||||
| model: rerankDefaultModel?.model, | |||||
| }) | }) | ||||
| setTempDataSetConfigs({ | setTempDataSetConfigs({ | ||||
| ...retrievalConfig, | ...retrievalConfig, | ||||
| reranking_model: restConfigs.reranking_model && { | |||||
| reranking_provider_name: restConfigs.reranking_model.reranking_provider_name, | |||||
| reranking_model_name: restConfigs.reranking_model.reranking_model_name, | |||||
| reranking_model: { | |||||
| reranking_provider_name: retrievalConfig.reranking_model?.provider || '', | |||||
| reranking_model_name: retrievalConfig.reranking_model?.model || '', | |||||
| }, | }, | ||||
| retrieval_model, | retrieval_model, | ||||
| score_threshold_enabled, | score_threshold_enabled, | 
| return ( | return ( | ||||
| <div> | <div> | ||||
| <div className='px-3 pt-5 h-[52px] space-x-3 rounded-lg border border-components-panel-border'> | |||||
| <div className='px-3 pt-5 pb-2 space-x-3 rounded-lg border border-components-panel-border'> | |||||
| <Slider | <Slider | ||||
| className={cn('grow h-0.5 !bg-util-colors-teal-teal-500 rounded-full')} | className={cn('grow h-0.5 !bg-util-colors-teal-teal-500 rounded-full')} | ||||
| max={1.0} | max={1.0} | ||||
| onChange={v => onChange({ value: [v, (10 - v * 10) / 10] })} | onChange={v => onChange({ value: [v, (10 - v * 10) / 10] })} | ||||
| trackClassName='weightedScoreSliderTrack' | trackClassName='weightedScoreSliderTrack' | ||||
| /> | /> | ||||
| <div className='flex justify-between mt-1'> | |||||
| <div className='flex justify-between mt-3'> | |||||
| <div className='shrink-0 flex items-center w-[90px] system-xs-semibold-uppercase text-util-colors-blue-light-blue-light-500'> | <div className='shrink-0 flex items-center w-[90px] system-xs-semibold-uppercase text-util-colors-blue-light-blue-light-500'> | ||||
| <div className='mr-1 truncate uppercase' title={t('dataset.weightedScore.semantic') || ''}> | <div className='mr-1 truncate uppercase' title={t('dataset.weightedScore.semantic') || ''}> | ||||
| {t('dataset.weightedScore.semantic')} | {t('dataset.weightedScore.semantic')} | 
| import Button from '@/app/components/base/button' | import Button from '@/app/components/base/button' | ||||
| import Input from '@/app/components/base/input' | import Input from '@/app/components/base/input' | ||||
| import Textarea from '@/app/components/base/textarea' | import Textarea from '@/app/components/base/textarea' | ||||
| import { type DataSet, RerankingModeEnum } from '@/models/datasets' | |||||
| import { type DataSet } from '@/models/datasets' | |||||
| import { useToastContext } from '@/app/components/base/toast' | import { useToastContext } from '@/app/components/base/toast' | ||||
| import { updateDatasetSetting } from '@/service/datasets' | import { updateDatasetSetting } from '@/service/datasets' | ||||
| import { useAppContext } from '@/context/app-context' | import { useAppContext } from '@/context/app-context' | ||||
| import RetrievalSettings from '@/app/components/datasets/external-knowledge-base/create/RetrievalSettings' | import RetrievalSettings from '@/app/components/datasets/external-knowledge-base/create/RetrievalSettings' | ||||
| import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' | import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' | ||||
| import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config' | import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config' | ||||
| import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' | |||||
| import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' | |||||
| import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' | import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' | ||||
| import PermissionSelector from '@/app/components/datasets/settings/permission-selector' | import PermissionSelector from '@/app/components/datasets/settings/permission-selector' | ||||
| import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' | import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' | ||||
| } | } | ||||
| if ( | if ( | ||||
| !isReRankModelSelected({ | !isReRankModelSelected({ | ||||
| rerankDefaultModel, | |||||
| isRerankDefaultModelValid: !!isRerankDefaultModelValid, | |||||
| rerankModelList, | rerankModelList, | ||||
| retrievalConfig, | retrievalConfig, | ||||
| indexMethod, | indexMethod, | ||||
| notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') }) | notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') }) | ||||
| return | return | ||||
| } | } | ||||
| const postRetrievalConfig = ensureRerankModelSelected({ | |||||
| rerankDefaultModel: rerankDefaultModel!, | |||||
| retrievalConfig: { | |||||
| ...retrievalConfig, | |||||
| reranking_enable: retrievalConfig.reranking_mode === RerankingModeEnum.RerankingModel, | |||||
| }, | |||||
| indexMethod, | |||||
| }) | |||||
| try { | try { | ||||
| setLoading(true) | setLoading(true) | ||||
| const { id, name, description, permission } = localeCurrentDataset | const { id, name, description, permission } = localeCurrentDataset | ||||
| permission, | permission, | ||||
| indexing_technique: indexMethod, | indexing_technique: indexMethod, | ||||
| retrieval_model: { | retrieval_model: { | ||||
| ...postRetrievalConfig, | |||||
| score_threshold: postRetrievalConfig.score_threshold_enabled ? postRetrievalConfig.score_threshold : 0, | |||||
| ...retrievalConfig, | |||||
| score_threshold: retrievalConfig.score_threshold_enabled ? retrievalConfig.score_threshold : 0, | |||||
| }, | }, | ||||
| embedding_model: localeCurrentDataset.embedding_model, | embedding_model: localeCurrentDataset.embedding_model, | ||||
| embedding_model_provider: localeCurrentDataset.embedding_model_provider, | embedding_model_provider: localeCurrentDataset.embedding_model_provider, | ||||
| onSave({ | onSave({ | ||||
| ...localeCurrentDataset, | ...localeCurrentDataset, | ||||
| indexing_technique: indexMethod, | indexing_technique: indexMethod, | ||||
| retrieval_model_dict: postRetrievalConfig, | |||||
| retrieval_model_dict: retrievalConfig, | |||||
| }) | }) | ||||
| } | } | ||||
| catch (e) { | catch (e) { | 
| setDatasetConfigs({ | setDatasetConfigs({ | ||||
| ...retrievalConfig, | ...retrievalConfig, | ||||
| reranking_model: restConfigs.reranking_model && { | |||||
| reranking_provider_name: restConfigs.reranking_model.reranking_provider_name, | |||||
| reranking_model_name: restConfigs.reranking_model.reranking_model_name, | |||||
| reranking_model: { | |||||
| reranking_provider_name: retrievalConfig?.reranking_model?.provider || '', | |||||
| reranking_model_name: retrievalConfig?.reranking_model?.model || '', | |||||
| }, | }, | ||||
| retrieval_model, | retrieval_model, | ||||
| score_threshold_enabled, | score_threshold_enabled, | 
| inputs?: Record<string, any> | inputs?: Record<string, any> | ||||
| inputsForm?: InputForm[] | inputsForm?: InputForm[] | ||||
| theme?: Theme | null | theme?: Theme | null | ||||
| isResponding?: boolean | |||||
| } | } | ||||
| const ChatInputArea = ({ | const ChatInputArea = ({ | ||||
| showFeatureBar, | showFeatureBar, | ||||
| inputs = {}, | inputs = {}, | ||||
| inputsForm = [], | inputsForm = [], | ||||
| theme, | theme, | ||||
| isResponding, | |||||
| }: ChatInputAreaProps) => { | }: ChatInputAreaProps) => { | ||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| const { notify } = useToastContext() | const { notify } = useToastContext() | ||||
| const historyRef = useRef(['']) | const historyRef = useRef(['']) | ||||
| const [currentIndex, setCurrentIndex] = useState(-1) | const [currentIndex, setCurrentIndex] = useState(-1) | ||||
| const handleSend = () => { | const handleSend = () => { | ||||
| if (isResponding) { | |||||
| notify({ type: 'info', message: t('appDebug.errorMessage.waitForResponse') }) | |||||
| return | |||||
| } | |||||
| if (onSend) { | if (onSend) { | ||||
| const { files, setFiles } = filesStore.getState() | const { files, setFiles } = filesStore.getState() | ||||
| if (files.find(item => item.transferMethod === TransferMethod.local_file && !item.uploadedId)) { | if (files.find(item => item.transferMethod === TransferMethod.local_file && !item.uploadedId)) { | ||||
| setQuery(historyRef.current[currentIndex + 1]) | setQuery(historyRef.current[currentIndex + 1]) | ||||
| } | } | ||||
| else if (currentIndex === historyRef.current.length - 1) { | else if (currentIndex === historyRef.current.length - 1) { | ||||
| // If it is the last element, clear the input box | |||||
| // If it is the last element, clear the input box | |||||
| setCurrentIndex(historyRef.current.length) | setCurrentIndex(historyRef.current.length) | ||||
| setQuery('') | setQuery('') | ||||
| } | } | ||||
| 'p-1 w-full leading-6 body-lg-regular text-text-tertiary outline-none', | 'p-1 w-full leading-6 body-lg-regular text-text-tertiary outline-none', | ||||
| )} | )} | ||||
| placeholder={t('common.chat.inputPlaceholder') || ''} | placeholder={t('common.chat.inputPlaceholder') || ''} | ||||
| autoFocus | |||||
| autoSize={{ minRows: 1 }} | autoSize={{ minRows: 1 }} | ||||
| onResize={handleTextareaResize} | onResize={handleTextareaResize} | ||||
| value={query} | value={query} | 
| inputs={inputs} | inputs={inputs} | ||||
| inputsForm={inputsForm} | inputsForm={inputsForm} | ||||
| theme={themeBuilder?.theme} | theme={themeBuilder?.theme} | ||||
| isResponding={isResponding} | |||||
| /> | /> | ||||
| ) | ) | ||||
| } | } | 
| } = item | } = item | ||||
| return ( | return ( | ||||
| <div className='flex justify-end mb-2 last:mb-0 pl-10'> | |||||
| <div className='group relative mr-4'> | |||||
| <div className='flex justify-end mb-2 last:mb-0 pl-14'> | |||||
| <div className='group relative mr-4 max-w-full'> | |||||
| <div | <div | ||||
| className='px-4 py-3 bg-[#D1E9FF]/50 rounded-2xl text-sm text-gray-900' | className='px-4 py-3 bg-[#D1E9FF]/50 rounded-2xl text-sm text-gray-900' | ||||
| style={theme?.chatBubbleColorStyle ? CssTransform(theme.chatBubbleColorStyle) : {}} | style={theme?.chatBubbleColorStyle ? CssTransform(theme.chatBubbleColorStyle) : {}} | 
| } | } | ||||
| else if (language === 'echarts') { | else if (language === 'echarts') { | ||||
| return ( | return ( | ||||
| <div style={{ minHeight: '350px', minWidth: '700px' }}> | |||||
| <div style={{ minHeight: '350px', minWidth: '100%', overflowX: 'scroll' }}> | |||||
| <ErrorBoundary> | <ErrorBoundary> | ||||
| <ReactEcharts option={chartData} /> | |||||
| <ReactEcharts option={chartData} style={{ minWidth: '700px' }} /> | |||||
| </ErrorBoundary> | </ErrorBoundary> | ||||
| </div> | </div> | ||||
| ) | ) | 
| enable: boolean | enable: boolean | ||||
| } | } | ||||
| const maxTopK = (() => { | |||||
| const configValue = parseInt(globalThis.document?.body?.getAttribute('data-public-top-k-max-value') || '', 10) | |||||
| if (configValue && !isNaN(configValue)) | |||||
| return configValue | |||||
| return 10 | |||||
| })() | |||||
| const VALUE_LIMIT = { | const VALUE_LIMIT = { | ||||
| default: 2, | default: 2, | ||||
| step: 1, | step: 1, | ||||
| min: 1, | min: 1, | ||||
| max: 10, | |||||
| max: maxTopK, | |||||
| } | } | ||||
| const key = 'top_k' | const key = 'top_k' | 
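The hunk above makes the top_k upper bound configurable by reading a `data-public-top-k-max-value` attribute from `document.body`, falling back to 10 when the attribute is missing or not a usable number. A minimal sketch of the same pattern, assuming a hypothetical `readNumericBodyAttr` helper (not part of the diff):

```ts
// Read an optional numeric limit from a data attribute on <body>;
// fall back to the default when the attribute is absent, unparsable, or non-positive.
function readNumericBodyAttr(attr: string, fallback: number): number {
  const raw = globalThis.document?.body?.getAttribute(attr) || ''
  const parsed = Number.parseInt(raw, 10)
  return Number.isNaN(parsed) || parsed <= 0 ? fallback : parsed
}

const maxTopK = readNumericBodyAttr('data-public-top-k-max-value', 10)
```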
| import { RerankingModeEnum } from '@/models/datasets' | import { RerankingModeEnum } from '@/models/datasets' | ||||
| export const isReRankModelSelected = ({ | export const isReRankModelSelected = ({ | ||||
| rerankDefaultModel, | |||||
| isRerankDefaultModelValid, | |||||
| retrievalConfig, | retrievalConfig, | ||||
| rerankModelList, | rerankModelList, | ||||
| indexMethod, | indexMethod, | ||||
| }: { | }: { | ||||
| rerankDefaultModel?: DefaultModelResponse | |||||
| isRerankDefaultModelValid: boolean | |||||
| retrievalConfig: RetrievalConfig | retrievalConfig: RetrievalConfig | ||||
| rerankModelList: Model[] | rerankModelList: Model[] | ||||
| indexMethod?: string | indexMethod?: string | ||||
| return provider?.models.find(({ model }) => model === retrievalConfig.reranking_model?.reranking_model_name) | return provider?.models.find(({ model }) => model === retrievalConfig.reranking_model?.reranking_model_name) | ||||
| } | } | ||||
| if (isRerankDefaultModelValid) | |||||
| return !!rerankDefaultModel | |||||
| return false | return false | ||||
| })() | })() | ||||
| if ( | |||||
| indexMethod === 'high_quality' | |||||
| && ([RETRIEVE_METHOD.semantic, RETRIEVE_METHOD.fullText].includes(retrievalConfig.search_method)) | |||||
| && retrievalConfig.reranking_enable | |||||
| && !rerankModelSelected | |||||
| ) | |||||
| return false | |||||
| if ( | if ( | ||||
| indexMethod === 'high_quality' | indexMethod === 'high_quality' | ||||
| && (retrievalConfig.search_method === RETRIEVE_METHOD.hybrid && retrievalConfig.reranking_mode !== RerankingModeEnum.WeightedScore) | && (retrievalConfig.search_method === RETRIEVE_METHOD.hybrid && retrievalConfig.reranking_mode !== RerankingModeEnum.WeightedScore) | 
| import type { RetrievalConfig } from '@/types/app' | import type { RetrievalConfig } from '@/types/app' | ||||
| type Props = { | type Props = { | ||||
| disabled?: boolean | |||||
| value: RetrievalConfig | value: RetrievalConfig | ||||
| onChange: (value: RetrievalConfig) => void | onChange: (value: RetrievalConfig) => void | ||||
| } | } | ||||
| const EconomicalRetrievalMethodConfig: FC<Props> = ({ | const EconomicalRetrievalMethodConfig: FC<Props> = ({ | ||||
| disabled = false, | |||||
| value, | value, | ||||
| onChange, | onChange, | ||||
| }) => { | }) => { | ||||
| return ( | return ( | ||||
| <div className='space-y-2'> | <div className='space-y-2'> | ||||
| <OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />} | |||||
| <OptionCard | |||||
| disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />} | |||||
| title={t('dataset.retrieval.invertedIndex.title')} | title={t('dataset.retrieval.invertedIndex.title')} | ||||
| description={t('dataset.retrieval.invertedIndex.description')} isActive | description={t('dataset.retrieval.invertedIndex.description')} isActive | ||||
| activeHeaderClassName='bg-dataset-option-card-purple-gradient' | activeHeaderClassName='bg-dataset-option-card-purple-gradient' | 
| 'use client' | 'use client' | ||||
| import type { FC } from 'react' | import type { FC } from 'react' | ||||
| import React from 'react' | |||||
| import React, { useCallback } from 'react' | |||||
| import { useTranslation } from 'react-i18next' | import { useTranslation } from 'react-i18next' | ||||
| import Image from 'next/image' | import Image from 'next/image' | ||||
| import RetrievalParamConfig from '../retrieval-param-config' | import RetrievalParamConfig from '../retrieval-param-config' | ||||
| import type { RetrievalConfig } from '@/types/app' | import type { RetrievalConfig } from '@/types/app' | ||||
| import { RETRIEVE_METHOD } from '@/types/app' | import { RETRIEVE_METHOD } from '@/types/app' | ||||
| import { useProviderContext } from '@/context/provider-context' | import { useProviderContext } from '@/context/provider-context' | ||||
| import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||||
| import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||||
| import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | ||||
| import { | import { | ||||
| DEFAULT_WEIGHTED_SCORE, | DEFAULT_WEIGHTED_SCORE, | ||||
| import Badge from '@/app/components/base/badge' | import Badge from '@/app/components/base/badge' | ||||
| type Props = { | type Props = { | ||||
| disabled?: boolean | |||||
| value: RetrievalConfig | value: RetrievalConfig | ||||
| onChange: (value: RetrievalConfig) => void | onChange: (value: RetrievalConfig) => void | ||||
| } | } | ||||
| const RetrievalMethodConfig: FC<Props> = ({ | const RetrievalMethodConfig: FC<Props> = ({ | ||||
| value: passValue, | |||||
| disabled = false, | |||||
| value, | |||||
| onChange, | onChange, | ||||
| }) => { | }) => { | ||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| const { supportRetrievalMethods } = useProviderContext() | const { supportRetrievalMethods } = useProviderContext() | ||||
| const { data: rerankDefaultModel } = useDefaultModel(ModelTypeEnum.rerank) | |||||
| const value = (() => { | |||||
| if (!passValue.reranking_model.reranking_model_name) { | |||||
| return { | |||||
| ...passValue, | |||||
| reranking_model: { | |||||
| reranking_provider_name: rerankDefaultModel?.provider.provider || '', | |||||
| reranking_model_name: rerankDefaultModel?.model || '', | |||||
| }, | |||||
| reranking_mode: passValue.reranking_mode || (rerankDefaultModel ? RerankingModeEnum.RerankingModel : RerankingModeEnum.WeightedScore), | |||||
| weights: passValue.weights || { | |||||
| weight_type: WeightedScoreEnum.Customized, | |||||
| vector_setting: { | |||||
| vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic, | |||||
| embedding_provider_name: '', | |||||
| embedding_model_name: '', | |||||
| }, | |||||
| keyword_setting: { | |||||
| keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword, | |||||
| }, | |||||
| }, | |||||
| } | |||||
| const { | |||||
| defaultModel: rerankDefaultModel, | |||||
| currentModel: isRerankDefaultModelValid, | |||||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | |||||
| const onSwitch = useCallback((retrieveMethod: RETRIEVE_METHOD) => { | |||||
| if ([RETRIEVE_METHOD.semantic, RETRIEVE_METHOD.fullText].includes(retrieveMethod)) { | |||||
| onChange({ | |||||
| ...value, | |||||
| search_method: retrieveMethod, | |||||
| ...(!value.reranking_model.reranking_model_name | |||||
| ? { | |||||
| reranking_model: { | |||||
| reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '', | |||||
| reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '', | |||||
| }, | |||||
| reranking_enable: !!isRerankDefaultModelValid, | |||||
| } | |||||
| : { | |||||
| reranking_enable: true, | |||||
| }), | |||||
| }) | |||||
| } | } | ||||
| return passValue | |||||
| })() | |||||
| if (retrieveMethod === RETRIEVE_METHOD.hybrid) { | |||||
| onChange({ | |||||
| ...value, | |||||
| search_method: retrieveMethod, | |||||
| ...(!value.reranking_model.reranking_model_name | |||||
| ? { | |||||
| reranking_model: { | |||||
| reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '', | |||||
| reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '', | |||||
| }, | |||||
| reranking_enable: !!isRerankDefaultModelValid, | |||||
| reranking_mode: isRerankDefaultModelValid ? RerankingModeEnum.RerankingModel : RerankingModeEnum.WeightedScore, | |||||
| } | |||||
| : { | |||||
| reranking_enable: true, | |||||
| reranking_mode: RerankingModeEnum.RerankingModel, | |||||
| }), | |||||
| ...(!value.weights | |||||
| ? { | |||||
| weights: { | |||||
| weight_type: WeightedScoreEnum.Customized, | |||||
| vector_setting: { | |||||
| vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic, | |||||
| embedding_provider_name: '', | |||||
| embedding_model_name: '', | |||||
| }, | |||||
| keyword_setting: { | |||||
| keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword, | |||||
| }, | |||||
| }, | |||||
| } | |||||
| : {}), | |||||
| }) | |||||
| } | |||||
| }, [value, rerankDefaultModel, isRerankDefaultModelValid, onChange]) | |||||
| return ( | return ( | ||||
| <div className='space-y-2'> | <div className='space-y-2'> | ||||
| {supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && ( | {supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && ( | ||||
| <OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />} | |||||
| <OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />} | |||||
| title={t('dataset.retrieval.semantic_search.title')} | title={t('dataset.retrieval.semantic_search.title')} | ||||
| description={t('dataset.retrieval.semantic_search.description')} | description={t('dataset.retrieval.semantic_search.description')} | ||||
| isActive={ | isActive={ | ||||
| value.search_method === RETRIEVE_METHOD.semantic | value.search_method === RETRIEVE_METHOD.semantic | ||||
| } | } | ||||
| onSwitched={() => onChange({ | |||||
| ...value, | |||||
| search_method: RETRIEVE_METHOD.semantic, | |||||
| })} | |||||
| onSwitched={() => onSwitch(RETRIEVE_METHOD.semantic)} | |||||
| effectImg={Effect.src} | effectImg={Effect.src} | ||||
| activeHeaderClassName='bg-dataset-option-card-purple-gradient' | activeHeaderClassName='bg-dataset-option-card-purple-gradient' | ||||
| > | > | ||||
| /> | /> | ||||
| </OptionCard> | </OptionCard> | ||||
| )} | )} | ||||
| {supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && ( | |||||
| <OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.fullText} alt='' />} | |||||
| {supportRetrievalMethods.includes(RETRIEVE_METHOD.fullText) && ( | |||||
| <OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.fullText} alt='' />} | |||||
| title={t('dataset.retrieval.full_text_search.title')} | title={t('dataset.retrieval.full_text_search.title')} | ||||
| description={t('dataset.retrieval.full_text_search.description')} | description={t('dataset.retrieval.full_text_search.description')} | ||||
| isActive={ | isActive={ | ||||
| value.search_method === RETRIEVE_METHOD.fullText | value.search_method === RETRIEVE_METHOD.fullText | ||||
| } | } | ||||
| onSwitched={() => onChange({ | |||||
| ...value, | |||||
| search_method: RETRIEVE_METHOD.fullText, | |||||
| })} | |||||
| onSwitched={() => onSwitch(RETRIEVE_METHOD.fullText)} | |||||
| effectImg={Effect.src} | effectImg={Effect.src} | ||||
| activeHeaderClassName='bg-dataset-option-card-purple-gradient' | activeHeaderClassName='bg-dataset-option-card-purple-gradient' | ||||
| > | > | ||||
| /> | /> | ||||
| </OptionCard> | </OptionCard> | ||||
| )} | )} | ||||
| {supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && ( | |||||
| <OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.hybrid} alt='' />} | |||||
| {supportRetrievalMethods.includes(RETRIEVE_METHOD.hybrid) && ( | |||||
| <OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.hybrid} alt='' />} | |||||
| title={ | title={ | ||||
| <div className='flex items-center space-x-1'> | <div className='flex items-center space-x-1'> | ||||
| <div>{t('dataset.retrieval.hybrid_search.title')}</div> | <div>{t('dataset.retrieval.hybrid_search.title')}</div> | ||||
| description={t('dataset.retrieval.hybrid_search.description')} isActive={ | description={t('dataset.retrieval.hybrid_search.description')} isActive={ | ||||
| value.search_method === RETRIEVE_METHOD.hybrid | value.search_method === RETRIEVE_METHOD.hybrid | ||||
| } | } | ||||
| onSwitched={() => onChange({ | |||||
| ...value, | |||||
| search_method: RETRIEVE_METHOD.hybrid, | |||||
| reranking_enable: true, | |||||
| })} | |||||
| onSwitched={() => onSwitch(RETRIEVE_METHOD.hybrid)} | |||||
| effectImg={Effect.src} | effectImg={Effect.src} | ||||
| activeHeaderClassName='bg-dataset-option-card-purple-gradient' | activeHeaderClassName='bg-dataset-option-card-purple-gradient' | ||||
| > | > | 
| 'use client' | 'use client' | ||||
| import type { FC } from 'react' | import type { FC } from 'react' | ||||
| import React, { useCallback } from 'react' | |||||
| import React, { useCallback, useMemo } from 'react' | |||||
| import { useTranslation } from 'react-i18next' | import { useTranslation } from 'react-i18next' | ||||
| import Image from 'next/image' | import Image from 'next/image' | ||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| const canToggleRerankModalEnable = type !== RETRIEVE_METHOD.hybrid | const canToggleRerankModalEnable = type !== RETRIEVE_METHOD.hybrid | ||||
| const isEconomical = type === RETRIEVE_METHOD.invertedIndex | const isEconomical = type === RETRIEVE_METHOD.invertedIndex | ||||
| const isHybridSearch = type === RETRIEVE_METHOD.hybrid | |||||
| const { | const { | ||||
| defaultModel: rerankDefaultModel, | |||||
| modelList: rerankModelList, | modelList: rerankModelList, | ||||
| } = useModelListAndDefaultModel(ModelTypeEnum.rerank) | } = useModelListAndDefaultModel(ModelTypeEnum.rerank) | ||||
| currentModel, | currentModel, | ||||
| } = useCurrentProviderAndModel( | } = useCurrentProviderAndModel( | ||||
| rerankModelList, | rerankModelList, | ||||
| rerankDefaultModel | |||||
| ? { | |||||
| ...rerankDefaultModel, | |||||
| provider: rerankDefaultModel.provider.provider, | |||||
| } | |||||
| : undefined, | |||||
| { | |||||
| provider: value.reranking_model?.reranking_provider_name ?? '', | |||||
| model: value.reranking_model?.reranking_model_name ?? '', | |||||
| }, | |||||
| ) | ) | ||||
| const handleDisabledSwitchClick = useCallback(() => { | |||||
| if (!currentModel) | |||||
| const handleDisabledSwitchClick = useCallback((enable: boolean) => { | |||||
| if (enable && !currentModel) | |||||
| Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) | Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) | ||||
| }, [currentModel, rerankDefaultModel, t]) | |||||
| const isHybridSearch = type === RETRIEVE_METHOD.hybrid | |||||
| onChange({ | |||||
| ...value, | |||||
| reranking_enable: enable, | |||||
| }) | |||||
| // eslint-disable-next-line react-hooks/exhaustive-deps | |||||
| }, [currentModel, onChange, value]) | |||||
| const rerankModel = (() => { | |||||
| if (value.reranking_model) { | |||||
| return { | |||||
| provider_name: value.reranking_model.reranking_provider_name, | |||||
| model_name: value.reranking_model.reranking_model_name, | |||||
| } | |||||
| } | |||||
| else if (rerankDefaultModel) { | |||||
| return { | |||||
| provider_name: rerankDefaultModel.provider.provider, | |||||
| model_name: rerankDefaultModel.model, | |||||
| } | |||||
| const rerankModel = useMemo(() => { | |||||
| return { | |||||
| provider_name: value.reranking_model.reranking_provider_name, | |||||
| model_name: value.reranking_model.reranking_model_name, | |||||
| } | } | ||||
| })() | |||||
| }, [value.reranking_model]) | |||||
| const handleChangeRerankMode = (v: RerankingModeEnum) => { | const handleChangeRerankMode = (v: RerankingModeEnum) => { | ||||
| if (v === value.reranking_mode) | if (v === value.reranking_mode) | ||||
| }, | }, | ||||
| } | } | ||||
| } | } | ||||
| if (v === RerankingModeEnum.RerankingModel && !currentModel) | |||||
| Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) | |||||
| onChange(result) | onChange(result) | ||||
| } | } | ||||
| <div> | <div> | ||||
| <div className='flex items-center space-x-2 mb-2'> | <div className='flex items-center space-x-2 mb-2'> | ||||
| {canToggleRerankModalEnable && ( | {canToggleRerankModalEnable && ( | ||||
| <div | |||||
| className='flex items-center' | |||||
| onClick={handleDisabledSwitchClick} | |||||
| > | |||||
| <Switch | |||||
| size='md' | |||||
| defaultValue={currentModel ? value.reranking_enable : false} | |||||
| onChange={(v) => { | |||||
| onChange({ | |||||
| ...value, | |||||
| reranking_enable: v, | |||||
| }) | |||||
| }} | |||||
| disabled={!currentModel} | |||||
| /> | |||||
| </div> | |||||
| <Switch | |||||
| size='md' | |||||
| defaultValue={value.reranking_enable} | |||||
| onChange={handleDisabledSwitchClick} | |||||
| /> | |||||
| )} | )} | ||||
| <div className='flex items-center'> | <div className='flex items-center'> | ||||
| <span className='mr-0.5 system-sm-semibold text-text-secondary'>{t('common.modelProvider.rerankModel.key')}</span> | <span className='mr-0.5 system-sm-semibold text-text-secondary'>{t('common.modelProvider.rerankModel.key')}</span> | ||||
| /> | /> | ||||
| </div> | </div> | ||||
| </div> | </div> | ||||
| <ModelSelector | |||||
| triggerClassName={`${!value.reranking_enable && '!opacity-60 !cursor-not-allowed'}`} | |||||
| defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }} | |||||
| modelList={rerankModelList} | |||||
| readonly={!value.reranking_enable} | |||||
| onSelect={(v) => { | |||||
| onChange({ | |||||
| ...value, | |||||
| reranking_model: { | |||||
| reranking_provider_name: v.provider, | |||||
| reranking_model_name: v.model, | |||||
| }, | |||||
| }) | |||||
| }} | |||||
| /> | |||||
| { | |||||
| value.reranking_enable && ( | |||||
| <ModelSelector | |||||
| defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }} | |||||
| modelList={rerankModelList} | |||||
| onSelect={(v) => { | |||||
| onChange({ | |||||
| ...value, | |||||
| reranking_model: { | |||||
| reranking_provider_name: v.provider, | |||||
| reranking_model_name: v.model, | |||||
| }, | |||||
| }) | |||||
| }} | |||||
| /> | |||||
| ) | |||||
| } | |||||
| </div> | </div> | ||||
| )} | )} | ||||
| { | { | ||||
| { | { | ||||
| value.reranking_mode !== RerankingModeEnum.WeightedScore && ( | value.reranking_mode !== RerankingModeEnum.WeightedScore && ( | ||||
| <ModelSelector | <ModelSelector | ||||
| triggerClassName={`${!value.reranking_enable && '!opacity-60 !cursor-not-allowed'}`} | |||||
| defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }} | defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }} | ||||
| modelList={rerankModelList} | modelList={rerankModelList} | ||||
| readonly={!value.reranking_enable} | |||||
| onSelect={(v) => { | onSelect={(v) => { | ||||
| onChange({ | onChange({ | ||||
| ...value, | ...value, | 
| import { sleep } from '@/utils' | import { sleep } from '@/utils' | ||||
| import { RETRIEVE_METHOD } from '@/types/app' | import { RETRIEVE_METHOD } from '@/types/app' | ||||
| import Tooltip from '@/app/components/base/tooltip' | import Tooltip from '@/app/components/base/tooltip' | ||||
| import { useInvalidDocumentList } from '@/service/knowledge/use-document' | |||||
| type Props = { | type Props = { | ||||
| datasetId: string | datasetId: string | ||||
| }) | }) | ||||
| const router = useRouter() | const router = useRouter() | ||||
| const invalidDocumentList = useInvalidDocumentList() | |||||
| const navToDocumentList = () => { | const navToDocumentList = () => { | ||||
| invalidDocumentList() | |||||
| router.push(`/datasets/${datasetId}/documents`) | router.push(`/datasets/${datasetId}/documents`) | ||||
| } | } | ||||
| const navToApiDocs = () => { | const navToApiDocs = () => { | 
| import { DelimiterInput, MaxLengthInput, OverlapInput } from './inputs' | import { DelimiterInput, MaxLengthInput, OverlapInput } from './inputs' | ||||
| import cn from '@/utils/classnames' | import cn from '@/utils/classnames' | ||||
| import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, DocumentItem, FullDocumentDetail, ParentMode, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets' | import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, DocumentItem, FullDocumentDetail, ParentMode, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets' | ||||
| import { ChunkingMode, DataSourceType, ProcessMode } from '@/models/datasets' | |||||
| import Button from '@/app/components/base/button' | import Button from '@/app/components/base/button' | ||||
| import FloatRightContainer from '@/app/components/base/float-right-container' | import FloatRightContainer from '@/app/components/base/float-right-container' | ||||
| import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' | import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' | ||||
| import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config' | import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config' | ||||
| import { type RetrievalConfig } from '@/types/app' | import { type RetrievalConfig } from '@/types/app' | ||||
| import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' | |||||
| import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' | |||||
| import Toast from '@/app/components/base/toast' | import Toast from '@/app/components/base/toast' | ||||
| import type { NotionPage } from '@/models/common' | import type { NotionPage } from '@/models/common' | ||||
| import { DataSourceProvider } from '@/models/common' | import { DataSourceProvider } from '@/models/common' | ||||
| import { ChunkingMode, DataSourceType, RerankingModeEnum } from '@/models/datasets' | |||||
| import { useDatasetDetailContext } from '@/context/dataset-detail' | import { useDatasetDetailContext } from '@/context/dataset-detail' | ||||
| import I18n from '@/context/i18n' | import I18n from '@/context/i18n' | ||||
| import { RETRIEVE_METHOD } from '@/types/app' | import { RETRIEVE_METHOD } from '@/types/app' | ||||
| import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | ||||
| import Checkbox from '@/app/components/base/checkbox' | import Checkbox from '@/app/components/base/checkbox' | ||||
| import RadioCard from '@/app/components/base/radio-card' | import RadioCard from '@/app/components/base/radio-card' | ||||
| import { IS_CE_EDITION } from '@/config' | |||||
| import { FULL_DOC_PREVIEW_LENGTH, IS_CE_EDITION } from '@/config' | |||||
| import Divider from '@/app/components/base/divider' | import Divider from '@/app/components/base/divider' | ||||
| import { getNotionInfo, getWebsiteInfo, useCreateDocument, useCreateFirstDocument, useFetchDefaultProcessRule, useFetchFileIndexingEstimateForFile, useFetchFileIndexingEstimateForNotion, useFetchFileIndexingEstimateForWeb } from '@/service/knowledge/use-create-dataset' | import { getNotionInfo, getWebsiteInfo, useCreateDocument, useCreateFirstDocument, useFetchDefaultProcessRule, useFetchFileIndexingEstimateForFile, useFetchFileIndexingEstimateForNotion, useFetchFileIndexingEstimateForWeb } from '@/service/knowledge/use-create-dataset' | ||||
| import Badge from '@/app/components/base/badge' | import Badge from '@/app/components/base/badge' | ||||
| onCancel?: () => void | onCancel?: () => void | ||||
| } | } | ||||
| export enum SegmentType { | |||||
| AUTO = 'automatic', | |||||
| CUSTOM = 'custom', | |||||
| } | |||||
| export enum IndexingType { | export enum IndexingType { | ||||
| QUALIFIED = 'high_quality', | QUALIFIED = 'high_quality', | ||||
| ECONOMICAL = 'economy', | ECONOMICAL = 'economy', | ||||
| } | } | ||||
| const DEFAULT_SEGMENT_IDENTIFIER = '\\n\\n' | const DEFAULT_SEGMENT_IDENTIFIER = '\\n\\n' | ||||
| const DEFAULT_MAXMIMUM_CHUNK_LENGTH = 500 | |||||
| const DEFAULT_MAXIMUM_CHUNK_LENGTH = 500 | |||||
| const DEFAULT_OVERLAP = 50 | const DEFAULT_OVERLAP = 50 | ||||
| type ParentChildConfig = { | type ParentChildConfig = { | ||||
| isSetting, | isSetting, | ||||
| documentDetail, | documentDetail, | ||||
| isAPIKeySet, | isAPIKeySet, | ||||
| onSetting, | |||||
| datasetId, | datasetId, | ||||
| indexingType, | indexingType, | ||||
| dataSourceType: inCreatePageDataSourceType, | dataSourceType: inCreatePageDataSourceType, | ||||
| const isInCreatePage = !datasetId || (datasetId && !currentDataset?.data_source_type) | const isInCreatePage = !datasetId || (datasetId && !currentDataset?.data_source_type) | ||||
| const dataSourceType = isInCreatePage ? inCreatePageDataSourceType : currentDataset?.data_source_type | const dataSourceType = isInCreatePage ? inCreatePageDataSourceType : currentDataset?.data_source_type | ||||
| const [segmentationType, setSegmentationType] = useState<SegmentType>(SegmentType.CUSTOM) | |||||
| const [segmentationType, setSegmentationType] = useState<ProcessMode>(ProcessMode.general) | |||||
| const [segmentIdentifier, doSetSegmentIdentifier] = useState(DEFAULT_SEGMENT_IDENTIFIER) | const [segmentIdentifier, doSetSegmentIdentifier] = useState(DEFAULT_SEGMENT_IDENTIFIER) | ||||
| const setSegmentIdentifier = useCallback((value: string, canEmpty?: boolean) => { | const setSegmentIdentifier = useCallback((value: string, canEmpty?: boolean) => { | ||||
| doSetSegmentIdentifier(value ? escape(value) : (canEmpty ? '' : DEFAULT_SEGMENT_IDENTIFIER)) | doSetSegmentIdentifier(value ? escape(value) : (canEmpty ? '' : DEFAULT_SEGMENT_IDENTIFIER)) | ||||
| }, []) | }, []) | ||||
| const [maxChunkLength, setMaxChunkLength] = useState(DEFAULT_MAXMIMUM_CHUNK_LENGTH) // default chunk length | |||||
| const [maxChunkLength, setMaxChunkLength] = useState(DEFAULT_MAXIMUM_CHUNK_LENGTH) // default chunk length | |||||
| const [limitMaxChunkLength, setLimitMaxChunkLength] = useState(4000) | const [limitMaxChunkLength, setLimitMaxChunkLength] = useState(4000) | ||||
| const [overlap, setOverlap] = useState(DEFAULT_OVERLAP) | const [overlap, setOverlap] = useState(DEFAULT_OVERLAP) | ||||
| const [rules, setRules] = useState<PreProcessingRule[]>([]) | const [rules, setRules] = useState<PreProcessingRule[]>([]) | ||||
| ) | ) | ||||
| // QA Related | // QA Related | ||||
| const [isLanguageSelectDisabled, _setIsLanguageSelectDisabled] = useState(false) | |||||
| const [isQAConfirmDialogOpen, setIsQAConfirmDialogOpen] = useState(false) | const [isQAConfirmDialogOpen, setIsQAConfirmDialogOpen] = useState(false) | ||||
| const [docForm, setDocForm] = useState<ChunkingMode>( | const [docForm, setDocForm] = useState<ChunkingMode>( | ||||
| (datasetId && documentDetail) ? documentDetail.doc_form as ChunkingMode : ChunkingMode.text, | (datasetId && documentDetail) ? documentDetail.doc_form as ChunkingMode : ChunkingMode.text, | ||||
| } | } | ||||
| const updatePreview = () => { | const updatePreview = () => { | ||||
| if (segmentationType === SegmentType.CUSTOM && maxChunkLength > 4000) { | |||||
| if (segmentationType === ProcessMode.general && maxChunkLength > 4000) { | |||||
| Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck') }) | Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck') }) | ||||
| return | return | ||||
| } | } | ||||
| model: defaultEmbeddingModel?.model || '', | model: defaultEmbeddingModel?.model || '', | ||||
| }, | }, | ||||
| ) | ) | ||||
| const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict || { | |||||
| search_method: RETRIEVE_METHOD.semantic, | |||||
| reranking_enable: false, | |||||
| reranking_model: { | |||||
| reranking_provider_name: '', | |||||
| reranking_model_name: '', | |||||
| }, | |||||
| top_k: 3, | |||||
| score_threshold_enabled: false, | |||||
| score_threshold: 0.5, | |||||
| } as RetrievalConfig) | |||||
| useEffect(() => { | |||||
| if (currentDataset?.retrieval_model_dict) | |||||
| return | |||||
| setRetrievalConfig({ | |||||
| search_method: RETRIEVE_METHOD.semantic, | |||||
| reranking_enable: !!isRerankDefaultModelValid, | |||||
| reranking_model: { | |||||
| reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider.provider ?? '' : '', | |||||
| reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '', | |||||
| }, | |||||
| top_k: 3, | |||||
| score_threshold_enabled: false, | |||||
| score_threshold: 0.5, | |||||
| }) | |||||
| // eslint-disable-next-line react-hooks/exhaustive-deps | |||||
| }, [rerankDefaultModel, isRerankDefaultModelValid]) | |||||
| const getCreationParams = () => { | const getCreationParams = () => { | ||||
| let params | let params | ||||
| if (segmentationType === SegmentType.CUSTOM && overlap > maxChunkLength) { | |||||
| if (segmentationType === ProcessMode.general && overlap > maxChunkLength) { | |||||
| Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.overlapCheck') }) | Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.overlapCheck') }) | ||||
| return | return | ||||
| } | } | ||||
| if (segmentationType === SegmentType.CUSTOM && maxChunkLength > limitMaxChunkLength) { | |||||
| if (segmentationType === ProcessMode.general && maxChunkLength > limitMaxChunkLength) { | |||||
| Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck', { limit: limitMaxChunkLength }) }) | Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck', { limit: limitMaxChunkLength }) }) | ||||
| return | return | ||||
| } | } | ||||
| doc_form: currentDocForm, | doc_form: currentDocForm, | ||||
| doc_language: docLanguage, | doc_language: docLanguage, | ||||
| process_rule: getProcessRule(), | process_rule: getProcessRule(), | ||||
| // eslint-disable-next-line @typescript-eslint/no-use-before-define | |||||
| retrieval_model: retrievalConfig, // Readonly. To change it, go to the settings page. | retrieval_model: retrievalConfig, // Readonly. To change it, go to the settings page. | ||||
| embedding_model: embeddingModel.model, // Readonly | embedding_model: embeddingModel.model, // Readonly | ||||
| embedding_model_provider: embeddingModel.provider, // Readonly | embedding_model_provider: embeddingModel.provider, // Readonly | ||||
| const indexMethod = getIndexing_technique() | const indexMethod = getIndexing_technique() | ||||
| if ( | if ( | ||||
| !isReRankModelSelected({ | !isReRankModelSelected({ | ||||
| rerankDefaultModel, | |||||
| isRerankDefaultModelValid: !!isRerankDefaultModelValid, | |||||
| rerankModelList, | rerankModelList, | ||||
| // eslint-disable-next-line @typescript-eslint/no-use-before-define | |||||
| retrievalConfig, | retrievalConfig, | ||||
| indexMethod: indexMethod as string, | indexMethod: indexMethod as string, | ||||
| }) | }) | ||||
| Toast.notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') }) | Toast.notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') }) | ||||
| return | return | ||||
| } | } | ||||
| const postRetrievalConfig = ensureRerankModelSelected({ | |||||
| rerankDefaultModel: rerankDefaultModel!, | |||||
| retrievalConfig: { | |||||
| // eslint-disable-next-line @typescript-eslint/no-use-before-define | |||||
| ...retrievalConfig, | |||||
| // eslint-disable-next-line @typescript-eslint/no-use-before-define | |||||
| reranking_enable: retrievalConfig.reranking_mode === RerankingModeEnum.RerankingModel, | |||||
| }, | |||||
| indexMethod: indexMethod as string, | |||||
| }) | |||||
| params = { | params = { | ||||
| data_source: { | data_source: { | ||||
| type: dataSourceType, | type: dataSourceType, | ||||
| process_rule: getProcessRule(), | process_rule: getProcessRule(), | ||||
| doc_form: currentDocForm, | doc_form: currentDocForm, | ||||
| doc_language: docLanguage, | doc_language: docLanguage, | ||||
| retrieval_model: postRetrievalConfig, | |||||
| retrieval_model: retrievalConfig, | |||||
| embedding_model: embeddingModel.model, | embedding_model: embeddingModel.model, | ||||
| embedding_model_provider: embeddingModel.provider, | embedding_model_provider: embeddingModel.provider, | ||||
| } as CreateDocumentReq | } as CreateDocumentReq | ||||
| const getDefaultMode = () => { | const getDefaultMode = () => { | ||||
| if (documentDetail) | if (documentDetail) | ||||
| // @ts-expect-error fix after api refactored | |||||
| setSegmentationType(documentDetail.dataset_process_rule.mode) | setSegmentationType(documentDetail.dataset_process_rule.mode) | ||||
| } | } | ||||
| onSuccess(data) { | onSuccess(data) { | ||||
| updateIndexingTypeCache && updateIndexingTypeCache(indexType as string) | updateIndexingTypeCache && updateIndexingTypeCache(indexType as string) | ||||
| updateResultCache && updateResultCache(data) | updateResultCache && updateResultCache(data) | ||||
| // eslint-disable-next-line @typescript-eslint/no-use-before-define | |||||
| updateRetrievalMethodCache && updateRetrievalMethodCache(retrievalConfig.search_method as string) | updateRetrievalMethodCache && updateRetrievalMethodCache(retrievalConfig.search_method as string) | ||||
| }, | }, | ||||
| }, | }, | ||||
| isSetting && onSave && onSave() | isSetting && onSave && onSave() | ||||
| } | } | ||||
| const changeToEconomicalType = () => { | |||||
| if (docForm !== ChunkingMode.text) | |||||
| return | |||||
| if (!hasSetIndexType) | |||||
| setIndexType(IndexingType.ECONOMICAL) | |||||
| } | |||||
| useEffect(() => { | useEffect(() => { | ||||
| // fetch rules | // fetch rules | ||||
| if (!isSetting) { | if (!isSetting) { | ||||
| setIndexType(isAPIKeySet ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL) | setIndexType(isAPIKeySet ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL) | ||||
| }, [isAPIKeySet, indexingType, datasetId]) | }, [isAPIKeySet, indexingType, datasetId]) | ||||
| const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict || { | |||||
| search_method: RETRIEVE_METHOD.semantic, | |||||
| reranking_enable: false, | |||||
| reranking_model: { | |||||
| reranking_provider_name: rerankDefaultModel?.provider.provider, | |||||
| reranking_model_name: rerankDefaultModel?.model, | |||||
| }, | |||||
| top_k: 3, | |||||
| score_threshold_enabled: false, | |||||
| score_threshold: 0.5, | |||||
| } as RetrievalConfig) | |||||
| const economyDomRef = useRef<HTMLDivElement>(null) | const economyDomRef = useRef<HTMLDivElement>(null) | ||||
| const isHoveringEconomy = useHover(economyDomRef) | const isHoveringEconomy = useHover(economyDomRef) | ||||
| <div className={cn('system-md-semibold mb-1', datasetId && 'flex justify-between items-center')}>{t('datasetSettings.form.embeddingModel')}</div> | <div className={cn('system-md-semibold mb-1', datasetId && 'flex justify-between items-center')}>{t('datasetSettings.form.embeddingModel')}</div> | ||||
| <ModelSelector | <ModelSelector | ||||
| readonly={!!datasetId} | readonly={!!datasetId} | ||||
| triggerClassName={datasetId ? 'opacity-50' : ''} | |||||
| defaultModel={embeddingModel} | defaultModel={embeddingModel} | ||||
| modelList={embeddingModelList} | modelList={embeddingModelList} | ||||
| onSelect={(model: DefaultModel) => { | onSelect={(model: DefaultModel) => { | ||||
| getIndexing_technique() === IndexingType.QUALIFIED | getIndexing_technique() === IndexingType.QUALIFIED | ||||
| ? ( | ? ( | ||||
| <RetrievalMethodConfig | <RetrievalMethodConfig | ||||
| disabled={!!datasetId} | |||||
| value={retrievalConfig} | value={retrievalConfig} | ||||
| onChange={setRetrievalConfig} | onChange={setRetrievalConfig} | ||||
| /> | /> | ||||
| ) | ) | ||||
| : ( | : ( | ||||
| <EconomicalRetrievalMethodConfig | <EconomicalRetrievalMethodConfig | ||||
| disabled={!!datasetId} | |||||
| value={retrievalConfig} | value={retrievalConfig} | ||||
| onChange={setRetrievalConfig} | onChange={setRetrievalConfig} | ||||
| /> | /> | ||||
| ) | ) | ||||
| : ( | : ( | ||||
| <div className='flex items-center mt-8 py-2'> | <div className='flex items-center mt-8 py-2'> | ||||
| <Button loading={isCreating} variant='primary' onClick={createHandle}>{t('datasetCreation.stepTwo.save')}</Button> | |||||
| {!datasetId && <Button loading={isCreating} variant='primary' onClick={createHandle}>{t('datasetCreation.stepTwo.save')}</Button>} | |||||
| <Button className='ml-2' onClick={onCancel}>{t('datasetCreation.stepTwo.cancel')}</Button> | <Button className='ml-2' onClick={onCancel}>{t('datasetCreation.stepTwo.cancel')}</Button> | ||||
| </div> | </div> | ||||
| )} | )} | ||||
| } | } | ||||
| { | { | ||||
| currentDocForm !== ChunkingMode.qa | currentDocForm !== ChunkingMode.qa | ||||
| && <Badge text={t( | |||||
| 'datasetCreation.stepTwo.previewChunkCount', { | |||||
| count: estimate?.total_segments || 0, | |||||
| }) as string} | |||||
| /> | |||||
| && <Badge text={t( | |||||
| 'datasetCreation.stepTwo.previewChunkCount', { | |||||
| count: estimate?.total_segments || 0, | |||||
| }) as string} | |||||
| /> | |||||
| } | } | ||||
| </div> | </div> | ||||
| </PreviewHeader>} | </PreviewHeader>} | ||||
| {currentDocForm === ChunkingMode.parentChild && currentEstimateMutation.data?.preview && ( | {currentDocForm === ChunkingMode.parentChild && currentEstimateMutation.data?.preview && ( | ||||
| estimate?.preview?.map((item, index) => { | estimate?.preview?.map((item, index) => { | ||||
| const indexForLabel = index + 1 | const indexForLabel = index + 1 | ||||
| const childChunks = parentChildConfig.chunkForContext === 'full-doc' | |||||
| ? item.child_chunks.slice(0, FULL_DOC_PREVIEW_LENGTH) | |||||
| : item.child_chunks | |||||
| return ( | return ( | ||||
| <ChunkContainer | <ChunkContainer | ||||
| key={item.content} | key={item.content} | ||||
| characterCount={item.content.length} | characterCount={item.content.length} | ||||
| > | > | ||||
| <FormattedText> | <FormattedText> | ||||
| {item.child_chunks.map((child, index) => { | |||||
| {childChunks.map((child, index) => { | |||||
| const indexForLabel = index + 1 | const indexForLabel = index + 1 | ||||
| return ( | return ( | ||||
| <PreviewSlice | <PreviewSlice | 
| const TriangleArrow: FC<ComponentProps<'svg'>> = props => ( | const TriangleArrow: FC<ComponentProps<'svg'>> = props => ( | ||||
| <svg xmlns="http://www.w3.org/2000/svg" width="24" height="11" viewBox="0 0 24 11" fill="none" {...props}> | <svg xmlns="http://www.w3.org/2000/svg" width="24" height="11" viewBox="0 0 24 11" fill="none" {...props}> | ||||
| <path d="M9.87868 1.12132C11.0503 -0.0502525 12.9497 -0.0502525 14.1213 1.12132L23.3137 10.3137H0.686292L9.87868 1.12132Z" fill="currentColor"/> | |||||
| <path d="M9.87868 1.12132C11.0503 -0.0502525 12.9497 -0.0502525 14.1213 1.12132L23.3137 10.3137H0.686292L9.87868 1.12132Z" fill="currentColor" /> | |||||
| </svg> | </svg> | ||||
| ) | ) | ||||
| (isActive && !noHighlight) | (isActive && !noHighlight) | ||||
| ? 'border-[1.5px] border-components-option-card-option-selected-border' | ? 'border-[1.5px] border-components-option-card-option-selected-border' | ||||
| : 'border border-components-option-card-option-border', | : 'border border-components-option-card-option-border', | ||||
| disabled && 'opacity-50 cursor-not-allowed', | |||||
| disabled && 'opacity-50 pointer-events-none', | |||||
| className, | className, | ||||
| )} | )} | ||||
| style={{ | style={{ |