| @@ -23,6 +23,9 @@ FILES_ACCESS_TIMEOUT=300 | |||
| # Access token expiration time in minutes | |||
| ACCESS_TOKEN_EXPIRE_MINUTES=60 | |||
| # Refresh token expiration time in days | |||
| REFRESH_TOKEN_EXPIRE_DAYS=30 | |||
| # celery configuration | |||
| CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1 | |||
| @@ -14,7 +14,10 @@ if is_db_command(): | |||
| app = create_migrations_app() | |||
| else: | |||
| if os.environ.get("FLASK_DEBUG", "False") != "True": | |||
| # It seems that the JetBrains Python debugger does not work well with gevent, | |||
| # so we need to disable gevent in debug mode. | |||
| # If you are using debugpy and have set GEVENT_SUPPORT=True, you can debug with gevent. | |||
| if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}: | |||
| from gevent import monkey # type: ignore | |||
| # gevent | |||
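Note on the reworked debug check above: gevent monkey-patching now happens only when FLASK_DEBUG is unset (defaulting to "0") or explicitly falsy. A minimal sketch of how the new expression evaluates; `should_patch_gevent` is a hypothetical helper written only to illustrate the condition, not part of the patch:

```python
def should_patch_gevent(env: dict[str, str]) -> bool:
    # Mirrors the hunk above: patch only when FLASK_DEBUG is missing
    # (defaults to "0") or is an explicit "false" / "0" / "no" (any case).
    flask_debug = env.get("FLASK_DEBUG", "0")
    return bool(flask_debug) and flask_debug.lower() in {"false", "0", "no"}


assert should_patch_gevent({}) is True                        # default -> gevent on
assert should_patch_gevent({"FLASK_DEBUG": "False"}) is True  # explicit off -> gevent on
assert should_patch_gevent({"FLASK_DEBUG": "True"}) is False  # debugging -> gevent off
assert should_patch_gevent({"FLASK_DEBUG": "1"}) is False
```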
| @@ -546,6 +546,11 @@ class AuthConfig(BaseSettings): | |||
| default=60, | |||
| ) | |||
| REFRESH_TOKEN_EXPIRE_DAYS: PositiveFloat = Field( | |||
| description="Expiration time for refresh tokens in days", | |||
| default=30, | |||
| ) | |||
| LOGIN_LOCKOUT_DURATION: PositiveInt = Field( | |||
| description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.", | |||
| default=86400, | |||
| @@ -725,6 +730,11 @@ class IndexingConfig(BaseSettings): | |||
| default=4000, | |||
| ) | |||
| CHILD_CHUNKS_PREVIEW_NUMBER: PositiveInt = Field( | |||
| description="Maximum number of child chunks to preview", | |||
| default=50, | |||
| ) | |||
| class MultiModalTransferConfig(BaseSettings): | |||
| MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field( | |||
| @@ -33,3 +33,9 @@ class MilvusConfig(BaseSettings): | |||
| description="Name of the Milvus database to connect to (default is 'default')", | |||
| default="default", | |||
| ) | |||
| MILVUS_ENABLE_HYBRID_SEARCH: bool = Field( | |||
| description="Enable hybrid search features (requires Milvus >= 2.5.0). Set to false for compatibility with " | |||
| "older versions", | |||
| default=True, | |||
| ) | |||
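Because MILVUS_ENABLE_HYBRID_SEARCH defaults to true, deployments still on Milvus < 2.5.0 have to opt out explicitly. A possible .env override (the boolean format shown is an assumption, consistent with the other flags in this file):

```
# Hybrid search requires Milvus >= 2.5.0; disable it on older clusters
MILVUS_ENABLE_HYBRID_SEARCH=false
```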
| @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): | |||
| CURRENT_VERSION: str = Field( | |||
| description="Dify version", | |||
| default="0.14.2", | |||
| default="0.15.0", | |||
| ) | |||
| COMMIT_SHA: str = Field( | |||
| @@ -57,12 +57,13 @@ class AppListApi(Resource): | |||
| ) | |||
| parser.add_argument("name", type=str, location="args", required=False) | |||
| parser.add_argument("tag_ids", type=uuid_list, location="args", required=False) | |||
| parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False) | |||
| args = parser.parse_args() | |||
| # get app list | |||
| app_service = AppService() | |||
| app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args) | |||
| app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args) | |||
| if not app_pagination: | |||
| return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} | |||
| @@ -20,7 +20,6 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.errors.error import ( | |||
| AppInvokeQuotaExceededError, | |||
| ModelCurrentlyNotSupportError, | |||
| ProviderTokenNotInitError, | |||
| QuotaExceededError, | |||
| @@ -76,7 +75,7 @@ class CompletionMessageApi(Resource): | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(e.description) | |||
| except (ValueError, AppInvokeQuotaExceededError) as e: | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| logging.exception("internal server error.") | |||
| @@ -141,7 +140,7 @@ class ChatMessageApi(Resource): | |||
| raise InvokeRateLimitHttpError(ex.description) | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(e.description) | |||
| except (ValueError, AppInvokeQuotaExceededError) as e: | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| logging.exception("internal server error.") | |||
| @@ -273,8 +273,7 @@ FROM | |||
| messages m | |||
| ON c.id = m.conversation_id | |||
| WHERE | |||
| c.override_model_configs IS NULL | |||
| AND c.app_id = :app_id""" | |||
| c.app_id = :app_id""" | |||
| arg_dict = {"tz": account.timezone, "app_id": app_model.id} | |||
| timezone = pytz.timezone(account.timezone) | |||
| @@ -640,6 +640,7 @@ class DatasetRetrievalSettingApi(Resource): | |||
| | VectorType.MYSCALE | |||
| | VectorType.ORACLE | |||
| | VectorType.ELASTICSEARCH | |||
| | VectorType.ELASTICSEARCH_JA | |||
| | VectorType.PGVECTOR | |||
| | VectorType.TIDB_ON_QDRANT | |||
| | VectorType.LINDORM | |||
| @@ -683,6 +684,7 @@ class DatasetRetrievalSettingMockApi(Resource): | |||
| | VectorType.MYSCALE | |||
| | VectorType.ORACLE | |||
| | VectorType.ELASTICSEARCH | |||
| | VectorType.ELASTICSEARCH_JA | |||
| | VectorType.COUCHBASE | |||
| | VectorType.PGVECTOR | |||
| | VectorType.LINDORM | |||
| @@ -269,7 +269,8 @@ class DatasetDocumentListApi(Resource): | |||
| parser.add_argument("original_document_id", type=str, required=False, location="json") | |||
| parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") | |||
| parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") | |||
| parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") | |||
| parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") | |||
| parser.add_argument( | |||
| "doc_language", type=str, default="English", required=False, nullable=False, location="json" | |||
| ) | |||
| @@ -18,7 +18,11 @@ from controllers.console.explore.error import NotChatAppError, NotCompletionAppE | |||
| from controllers.console.explore.wraps import InstalledAppResource | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError | |||
| from core.errors.error import ( | |||
| ModelCurrentlyNotSupportError, | |||
| ProviderTokenNotInitError, | |||
| QuotaExceededError, | |||
| ) | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from extensions.ext_database import db | |||
| from libs import helper | |||
| @@ -13,7 +13,11 @@ from controllers.console.explore.error import NotWorkflowAppError | |||
| from controllers.console.explore.wraps import InstalledAppResource | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError | |||
| from core.errors.error import ( | |||
| ModelCurrentlyNotSupportError, | |||
| ProviderTokenNotInitError, | |||
| QuotaExceededError, | |||
| ) | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from libs import helper | |||
| from libs.login import current_user | |||
| @@ -18,7 +18,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.errors.error import ( | |||
| AppInvokeQuotaExceededError, | |||
| ModelCurrentlyNotSupportError, | |||
| ProviderTokenNotInitError, | |||
| QuotaExceededError, | |||
| @@ -74,7 +73,7 @@ class CompletionApi(Resource): | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(e.description) | |||
| except (ValueError, AppInvokeQuotaExceededError) as e: | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| logging.exception("internal server error.") | |||
| @@ -133,7 +132,7 @@ class ChatApi(Resource): | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(e.description) | |||
| except (ValueError, AppInvokeQuotaExceededError) as e: | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| logging.exception("internal server error.") | |||
| @@ -16,7 +16,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.errors.error import ( | |||
| AppInvokeQuotaExceededError, | |||
| ModelCurrentlyNotSupportError, | |||
| ProviderTokenNotInitError, | |||
| QuotaExceededError, | |||
| @@ -94,7 +93,7 @@ class WorkflowRunApi(Resource): | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(e.description) | |||
| except (ValueError, AppInvokeQuotaExceededError) as e: | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| logging.exception("internal server error.") | |||
| @@ -190,7 +190,10 @@ class DocumentAddByFileApi(DatasetApiResource): | |||
| user=current_user, | |||
| source="datasets", | |||
| ) | |||
| data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} | |||
| data_source = { | |||
| "type": "upload_file", | |||
| "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, | |||
| } | |||
| args["data_source"] = data_source | |||
| # validate args | |||
| knowledge_config = KnowledgeConfig(**args) | |||
| @@ -254,7 +257,10 @@ class DocumentUpdateByFileApi(DatasetApiResource): | |||
| raise FileTooLargeError(file_too_large_error.description) | |||
| except services.errors.file.UnsupportedFileTypeError: | |||
| raise UnsupportedFileTypeError() | |||
| data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} | |||
| data_source = { | |||
| "type": "upload_file", | |||
| "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, | |||
| } | |||
| args["data_source"] = data_source | |||
| # validate args | |||
| args["original_document_id"] = str(document_id) | |||
| @@ -1,5 +1,5 @@ | |||
| from collections.abc import Callable | |||
| from datetime import UTC, datetime | |||
| from datetime import UTC, datetime, timedelta | |||
| from enum import Enum | |||
| from functools import wraps | |||
| from typing import Optional | |||
| @@ -8,6 +8,8 @@ from flask import current_app, request | |||
| from flask_login import user_logged_in # type: ignore | |||
| from flask_restful import Resource # type: ignore | |||
| from pydantic import BaseModel | |||
| from sqlalchemy import select, update | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import Forbidden, Unauthorized | |||
| from extensions.ext_database import db | |||
| @@ -174,7 +176,7 @@ def validate_dataset_token(view=None): | |||
| return decorator | |||
| def validate_and_get_api_token(scope=None): | |||
| def validate_and_get_api_token(scope: str | None = None): | |||
| """ | |||
| Validate and get API token. | |||
| """ | |||
| @@ -188,20 +190,25 @@ def validate_and_get_api_token(scope=None): | |||
| if auth_scheme != "bearer": | |||
| raise Unauthorized("Authorization scheme must be 'Bearer'") | |||
| api_token = ( | |||
| db.session.query(ApiToken) | |||
| .filter( | |||
| ApiToken.token == auth_token, | |||
| ApiToken.type == scope, | |||
| current_time = datetime.now(UTC).replace(tzinfo=None) | |||
| cutoff_time = current_time - timedelta(minutes=1) | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| update_stmt = ( | |||
| update(ApiToken) | |||
| .where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope) | |||
| .values(last_used_at=current_time) | |||
| .returning(ApiToken) | |||
| ) | |||
| .first() | |||
| ) | |||
| if not api_token: | |||
| raise Unauthorized("Access token is invalid") | |||
| api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None) | |||
| db.session.commit() | |||
| result = session.execute(update_stmt) | |||
| api_token = result.scalar_one_or_none() | |||
| if not api_token: | |||
| stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope) | |||
| api_token = session.scalar(stmt) | |||
| if not api_token: | |||
| raise Unauthorized("Access token is invalid") | |||
| else: | |||
| session.commit() | |||
| return api_token | |||
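The rewritten validate_and_get_api_token above turns the per-request read-then-write into a single conditional UPDATE ... RETURNING: last_used_at is only rewritten when it is more than a minute stale, and a plain SELECT serves as the fallback to tell a recently used token apart from an invalid one. A minimal standalone sketch of the same pattern, assuming a backend with RETURNING support (e.g. PostgreSQL); the Token model and touch_token helper below are illustrative stand-ins, not Dify's ApiToken:

```python
from datetime import datetime, timedelta, timezone
from typing import Optional

from sqlalchemy import Column, DateTime, String, select, update
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Token(Base):  # stand-in for ApiToken, for illustration only
    __tablename__ = "tokens"
    token = Column(String, primary_key=True)
    last_used_at = Column(DateTime)


def touch_token(session: Session, raw_token: str) -> Optional[Token]:
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    cutoff = now - timedelta(minutes=1)
    # Update only stale rows; RETURNING hands the row back in the same round
    # trip, so tokens used within the last minute cause no write at all.
    stmt = (
        update(Token)
        .where(Token.token == raw_token, Token.last_used_at < cutoff)
        .values(last_used_at=now)
        .returning(Token)
    )
    row = session.execute(stmt).scalar_one_or_none()
    if row is not None:
        session.commit()
        return row
    # Either the token was touched less than a minute ago or it does not
    # exist; a read-only SELECT distinguishes the two without another write.
    return session.scalar(select(Token).where(Token.token == raw_token))
```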
| @@ -19,7 +19,11 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr | |||
| from controllers.web.wraps import WebApiResource | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError | |||
| from core.errors.error import ( | |||
| ModelCurrentlyNotSupportError, | |||
| ProviderTokenNotInitError, | |||
| QuotaExceededError, | |||
| ) | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from libs import helper | |||
| from libs.helper import uuid_value | |||
| @@ -14,7 +14,11 @@ from controllers.web.error import ( | |||
| from controllers.web.wraps import WebApiResource | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError | |||
| from core.errors.error import ( | |||
| ModelCurrentlyNotSupportError, | |||
| ProviderTokenNotInitError, | |||
| QuotaExceededError, | |||
| ) | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from libs import helper | |||
| from models.model import App, AppMode, EndUser | |||
| @@ -21,7 +21,7 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator | |||
| from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager | |||
| from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom | |||
| from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.prompt.utils.get_thread_messages_length import get_thread_messages_length | |||
| from extensions.ext_database import db | |||
| @@ -346,7 +346,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| except ValidationError as e: | |||
| logger.exception("Validation Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| except (ValueError, InvokeError) as e: | |||
| except ValueError as e: | |||
| if dify_config.DEBUG: | |||
| logger.exception("Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| @@ -68,24 +68,17 @@ from models.account import Account | |||
| from models.enums import CreatedByRole | |||
| from models.workflow import ( | |||
| Workflow, | |||
| WorkflowNodeExecution, | |||
| WorkflowRunStatus, | |||
| ) | |||
| logger = logging.getLogger(__name__) | |||
| class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage): | |||
| class AdvancedChatAppGenerateTaskPipeline: | |||
| """ | |||
| AdvancedChatAppGenerateTaskPipeline is a class that generates stream output and manages state for the application. | |||
| """ | |||
| _task_state: WorkflowTaskState | |||
| _application_generate_entity: AdvancedChatAppGenerateEntity | |||
| _workflow_system_variables: dict[SystemVariableKey, Any] | |||
| _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] | |||
| _conversation_name_generate_thread: Optional[Thread] = None | |||
| def __init__( | |||
| self, | |||
| application_generate_entity: AdvancedChatAppGenerateEntity, | |||
| @@ -97,7 +90,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| stream: bool, | |||
| dialogue_count: int, | |||
| ) -> None: | |||
| super().__init__( | |||
| self._base_task_pipeline = BasedGenerateTaskPipeline( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| stream=stream, | |||
| @@ -114,33 +107,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| else: | |||
| raise NotImplementedError(f"User type not supported: {type(user)}") | |||
| self._workflow_cycle_manager = WorkflowCycleManage( | |||
| application_generate_entity=application_generate_entity, | |||
| workflow_system_variables={ | |||
| SystemVariableKey.QUERY: message.query, | |||
| SystemVariableKey.FILES: application_generate_entity.files, | |||
| SystemVariableKey.CONVERSATION_ID: conversation.id, | |||
| SystemVariableKey.USER_ID: user_session_id, | |||
| SystemVariableKey.DIALOGUE_COUNT: dialogue_count, | |||
| SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, | |||
| SystemVariableKey.WORKFLOW_ID: workflow.id, | |||
| SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, | |||
| }, | |||
| ) | |||
| self._task_state = WorkflowTaskState() | |||
| self._message_cycle_manager = MessageCycleManage( | |||
| application_generate_entity=application_generate_entity, task_state=self._task_state | |||
| ) | |||
| self._application_generate_entity = application_generate_entity | |||
| self._workflow_id = workflow.id | |||
| self._workflow_features_dict = workflow.features_dict | |||
| self._conversation_id = conversation.id | |||
| self._conversation_mode = conversation.mode | |||
| self._message_id = message.id | |||
| self._message_created_at = int(message.created_at.timestamp()) | |||
| self._workflow_system_variables = { | |||
| SystemVariableKey.QUERY: message.query, | |||
| SystemVariableKey.FILES: application_generate_entity.files, | |||
| SystemVariableKey.CONVERSATION_ID: conversation.id, | |||
| SystemVariableKey.USER_ID: user_session_id, | |||
| SystemVariableKey.DIALOGUE_COUNT: dialogue_count, | |||
| SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, | |||
| SystemVariableKey.WORKFLOW_ID: workflow.id, | |||
| SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, | |||
| } | |||
| self._task_state = WorkflowTaskState() | |||
| self._wip_workflow_node_executions = {} | |||
| self._wip_workflow_agent_logs = {} | |||
| self._conversation_name_generate_thread = None | |||
| self._conversation_name_generate_thread: Thread | None = None | |||
| self._recorded_files: list[Mapping[str, Any]] = [] | |||
| self._workflow_run_id = "" | |||
| self._workflow_run_id: str = "" | |||
| def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: | |||
| """ | |||
| @@ -148,13 +143,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| :return: | |||
| """ | |||
| # start generate conversation name thread | |||
| self._conversation_name_generate_thread = self._generate_conversation_name( | |||
| self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name( | |||
| conversation_id=self._conversation_id, query=self._application_generate_entity.query | |||
| ) | |||
| generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) | |||
| if self._stream: | |||
| if self._base_task_pipeline._stream: | |||
| return self._to_stream_response(generator) | |||
| else: | |||
| return self._to_blocking_response(generator) | |||
| @@ -273,24 +268,26 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| # init fake graph runtime state | |||
| graph_runtime_state: Optional[GraphRuntimeState] = None | |||
| for queue_message in self._queue_manager.listen(): | |||
| for queue_message in self._base_task_pipeline._queue_manager.listen(): | |||
| event = queue_message.event | |||
| if isinstance(event, QueuePingEvent): | |||
| yield self._ping_stream_response() | |||
| yield self._base_task_pipeline._ping_stream_response() | |||
| elif isinstance(event, QueueErrorEvent): | |||
| with Session(db.engine) as session: | |||
| err = self._handle_error(event=event, session=session, message_id=self._message_id) | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| err = self._base_task_pipeline._handle_error( | |||
| event=event, session=session, message_id=self._message_id | |||
| ) | |||
| session.commit() | |||
| yield self._error_to_stream_response(err) | |||
| yield self._base_task_pipeline._error_to_stream_response(err) | |||
| break | |||
| elif isinstance(event, QueueWorkflowStartedEvent): | |||
| # override graph runtime state | |||
| graph_runtime_state = event.graph_runtime_state | |||
| with Session(db.engine) as session: | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| # init workflow run | |||
| workflow_run = self._handle_workflow_run_start( | |||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_start( | |||
| session=session, | |||
| workflow_id=self._workflow_id, | |||
| user_id=self._user_id, | |||
| @@ -301,7 +298,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| if not message: | |||
| raise ValueError(f"Message not found: {self._message_id}") | |||
| message.workflow_run_id = workflow_run.id | |||
| workflow_start_resp = self._workflow_start_to_stream_response( | |||
| workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| session.commit() | |||
| @@ -314,12 +311,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||
| workflow_node_execution = self._handle_workflow_node_execution_retried( | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( | |||
| session=session, workflow_run=workflow_run, event=event | |||
| ) | |||
| node_retry_resp = self._workflow_node_retry_to_stream_response( | |||
| node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( | |||
| session=session, | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| @@ -333,13 +332,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||
| workflow_node_execution = self._handle_node_execution_start( | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( | |||
| session=session, workflow_run=workflow_run, event=event | |||
| ) | |||
| node_start_resp = self._workflow_node_start_to_stream_response( | |||
| node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response( | |||
| session=session, | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| @@ -352,12 +353,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| elif isinstance(event, QueueNodeSucceededEvent): | |||
| # Record files if it's an answer node or end node | |||
| if event.node_type in [NodeType.ANSWER, NodeType.END]: | |||
| self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) | |||
| self._recorded_files.extend( | |||
| self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {}) | |||
| ) | |||
| with Session(db.engine) as session: | |||
| workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event) | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( | |||
| session=session, event=event | |||
| ) | |||
| node_finish_resp = self._workflow_node_finish_to_stream_response( | |||
| node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( | |||
| session=session, | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| @@ -368,10 +373,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| if node_finish_resp: | |||
| yield node_finish_resp | |||
| elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): | |||
| with Session(db.engine) as session: | |||
| workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event) | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( | |||
| session=session, event=event | |||
| ) | |||
| node_finish_resp = self._workflow_node_finish_to_stream_response( | |||
| node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( | |||
| session=session, | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| @@ -385,13 +392,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||
| parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| parallel_start_resp = ( | |||
| self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| ) | |||
| ) | |||
| yield parallel_start_resp | |||
| @@ -399,13 +410,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||
| parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| parallel_finish_resp = ( | |||
| self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| ) | |||
| ) | |||
| yield parallel_finish_resp | |||
| @@ -413,9 +428,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||
| iter_start_resp = self._workflow_iteration_start_to_stream_response( | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| @@ -427,9 +444,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||
| iter_next_resp = self._workflow_iteration_next_to_stream_response( | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| @@ -441,9 +460,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||
| iter_finish_resp = self._workflow_iteration_completed_to_stream_response( | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| @@ -458,8 +479,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| if not graph_runtime_state: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._handle_workflow_run_success( | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_success( | |||
| session=session, | |||
| workflow_run_id=self._workflow_run_id, | |||
| start_at=graph_runtime_state.start_at, | |||
| @@ -470,21 +491,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| trace_manager=trace_manager, | |||
| ) | |||
| workflow_finish_resp = self._workflow_finish_to_stream_response( | |||
| workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| session.commit() | |||
| yield workflow_finish_resp | |||
| self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) | |||
| self._base_task_pipeline._queue_manager.publish( | |||
| QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE | |||
| ) | |||
| elif isinstance(event, QueueWorkflowPartialSuccessEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| if not graph_runtime_state: | |||
| raise ValueError("graph runtime state not initialized.") | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._handle_workflow_run_partial_success( | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success( | |||
| session=session, | |||
| workflow_run_id=self._workflow_run_id, | |||
| start_at=graph_runtime_state.start_at, | |||
| @@ -495,21 +518,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| workflow_finish_resp = self._workflow_finish_to_stream_response( | |||
| workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| session.commit() | |||
| yield workflow_finish_resp | |||
| self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) | |||
| self._base_task_pipeline._queue_manager.publish( | |||
| QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE | |||
| ) | |||
| elif isinstance(event, QueueWorkflowFailedEvent): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| if not graph_runtime_state: | |||
| raise ValueError("graph runtime state not initialized.") | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._handle_workflow_run_failed( | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed( | |||
| session=session, | |||
| workflow_run_id=self._workflow_run_id, | |||
| start_at=graph_runtime_state.start_at, | |||
| @@ -521,20 +546,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| trace_manager=trace_manager, | |||
| exceptions_count=event.exceptions_count, | |||
| ) | |||
| workflow_finish_resp = self._workflow_finish_to_stream_response( | |||
| workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) | |||
| err = self._handle_error(event=err_event, session=session, message_id=self._message_id) | |||
| err = self._base_task_pipeline._handle_error( | |||
| event=err_event, session=session, message_id=self._message_id | |||
| ) | |||
| session.commit() | |||
| yield workflow_finish_resp | |||
| yield self._error_to_stream_response(err) | |||
| yield self._base_task_pipeline._error_to_stream_response(err) | |||
| break | |||
| elif isinstance(event, QueueStopEvent): | |||
| if self._workflow_run_id and graph_runtime_state: | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._handle_workflow_run_failed( | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed( | |||
| session=session, | |||
| workflow_run_id=self._workflow_run_id, | |||
| start_at=graph_runtime_state.start_at, | |||
| @@ -545,7 +572,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| conversation_id=self._conversation_id, | |||
| trace_manager=trace_manager, | |||
| ) | |||
| workflow_finish_resp = self._workflow_finish_to_stream_response( | |||
| workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| @@ -559,18 +586,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| yield self._message_end_to_stream_response() | |||
| break | |||
| elif isinstance(event, QueueRetrieverResourcesEvent): | |||
| self._handle_retriever_resources(event) | |||
| self._message_cycle_manager._handle_retriever_resources(event) | |||
| with Session(db.engine) as session: | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| message = self._get_message(session=session) | |||
| message.message_metadata = ( | |||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||
| ) | |||
| session.commit() | |||
| elif isinstance(event, QueueAnnotationReplyEvent): | |||
| self._handle_annotation_reply(event) | |||
| self._message_cycle_manager._handle_annotation_reply(event) | |||
| with Session(db.engine) as session: | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| message = self._get_message(session=session) | |||
| message.message_metadata = ( | |||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||
| @@ -591,23 +618,27 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| tts_publisher.publish(queue_message) | |||
| self._task_state.answer += delta_text | |||
| yield self._message_to_stream_response( | |||
| yield self._message_cycle_manager._message_to_stream_response( | |||
| answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector | |||
| ) | |||
| elif isinstance(event, QueueMessageReplaceEvent): | |||
| # published by moderation | |||
| yield self._message_replace_to_stream_response(answer=event.text) | |||
| yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text) | |||
| elif isinstance(event, QueueAdvancedChatMessageEndEvent): | |||
| if not graph_runtime_state: | |||
| raise ValueError("graph runtime state not initialized.") | |||
| output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) | |||
| output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished( | |||
| self._task_state.answer | |||
| ) | |||
| if output_moderation_answer: | |||
| self._task_state.answer = output_moderation_answer | |||
| yield self._message_replace_to_stream_response(answer=output_moderation_answer) | |||
| yield self._message_cycle_manager._message_replace_to_stream_response( | |||
| answer=output_moderation_answer | |||
| ) | |||
| # Save message | |||
| with Session(db.engine) as session: | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| self._save_message(session=session, graph_runtime_state=graph_runtime_state) | |||
| session.commit() | |||
| @@ -627,7 +658,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: | |||
| message = self._get_message(session=session) | |||
| message.answer = self._task_state.answer | |||
| message.provider_response_latency = time.perf_counter() - self._start_at | |||
| message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at | |||
| message.message_metadata = ( | |||
| json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None | |||
| ) | |||
| @@ -691,20 +722,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc | |||
| :param text: text | |||
| :return: True if output moderation should direct output, otherwise False | |||
| """ | |||
| if self._output_moderation_handler: | |||
| if self._output_moderation_handler.should_direct_output(): | |||
| if self._base_task_pipeline._output_moderation_handler: | |||
| if self._base_task_pipeline._output_moderation_handler.should_direct_output(): | |||
| # stop subscribe new token when output moderation should direct output | |||
| self._task_state.answer = self._output_moderation_handler.get_final_output() | |||
| self._queue_manager.publish( | |||
| self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output() | |||
| self._base_task_pipeline._queue_manager.publish( | |||
| QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE | |||
| ) | |||
| self._queue_manager.publish( | |||
| self._base_task_pipeline._queue_manager.publish( | |||
| QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE | |||
| ) | |||
| return True | |||
| else: | |||
| self._output_moderation_handler.append_new_token(text) | |||
| self._base_task_pipeline._output_moderation_handler.append_new_token(text) | |||
| return False | |||
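The broader refactor in this file (and in the workflow pipeline below) replaces multiple inheritance from BasedGenerateTaskPipeline, WorkflowCycleManage and MessageCycleManage with composition: the pipeline owns each collaborator and delegates to it explicitly, as in the `self._base_task_pipeline...` and `self._workflow_cycle_manager...` calls above. A reduced sketch of the shape of that change, with made-up class and method names standing in for the real ones:

```python
class BasePipeline:
    def ping(self) -> str:
        return "ping"


class CycleManager:
    def start_run(self, run_id: str) -> str:
        return f"run {run_id} started"


# Before: class Pipeline(BasePipeline, CycleManager): ...
# After: the pipeline holds its collaborators and forwards calls to them,
# which keeps each dependency explicit at the call site.
class Pipeline:
    def __init__(self) -> None:
        self._base = BasePipeline()
        self._cycle = CycleManager()

    def process(self, run_id: str) -> list[str]:
        return [self._base.ping(), self._cycle.start_run(run_id)]


assert Pipeline().process("r1") == ["ping", "run r1 started"]
```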
| @@ -19,7 +19,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt | |||
| from core.app.apps.message_based_app_generator import MessageBasedAppGenerator | |||
| from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager | |||
| from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from extensions.ext_database import db | |||
| from factories import file_factory | |||
| @@ -251,7 +251,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): | |||
| except ValidationError as e: | |||
| logger.exception("Validation Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| except (ValueError, InvokeError) as e: | |||
| except ValueError as e: | |||
| if dify_config.DEBUG: | |||
| logger.exception("Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| @@ -18,7 +18,7 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon | |||
| from core.app.apps.message_based_app_generator import MessageBasedAppGenerator | |||
| from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager | |||
| from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from extensions.ext_database import db | |||
| from factories import file_factory | |||
| @@ -237,7 +237,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): | |||
| except ValidationError as e: | |||
| logger.exception("Validation Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| except (ValueError, InvokeError) as e: | |||
| except ValueError as e: | |||
| if dify_config.DEBUG: | |||
| logger.exception("Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| @@ -17,7 +17,7 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe | |||
| from core.app.apps.message_based_app_generator import MessageBasedAppGenerator | |||
| from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager | |||
| from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from extensions.ext_database import db | |||
| from factories import file_factory | |||
| @@ -214,7 +214,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| except ValidationError as e: | |||
| logger.exception("Validation Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| except (ValueError, InvokeError) as e: | |||
| except ValueError as e: | |||
| if dify_config.DEBUG: | |||
| logger.exception("Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| @@ -20,7 +20,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera | |||
| from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline | |||
| from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity | |||
| from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from extensions.ext_database import db | |||
| from factories import file_factory | |||
| @@ -235,6 +235,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( | |||
| node_id=node_id, inputs=args["inputs"] | |||
| ), | |||
| workflow_run_id=str(uuid.uuid4()), | |||
| ) | |||
| contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) | |||
| contexts.plugin_tool_providers.set({}) | |||
| @@ -286,7 +287,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| except ValidationError as e: | |||
| logger.exception("Validation Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| except (ValueError, InvokeError) as e: | |||
| except ValueError as e: | |||
| if dify_config.DEBUG: | |||
| logger.exception("Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| @@ -1,7 +1,7 @@ | |||
| import logging | |||
| import time | |||
| from collections.abc import Generator | |||
| from typing import Any, Optional, Union | |||
| from typing import Optional, Union | |||
| from sqlalchemy.orm import Session | |||
| @@ -59,7 +59,6 @@ from models.workflow import ( | |||
| Workflow, | |||
| WorkflowAppLog, | |||
| WorkflowAppLogCreatedFrom, | |||
| WorkflowNodeExecution, | |||
| WorkflowRun, | |||
| WorkflowRunStatus, | |||
| ) | |||
| @@ -67,16 +66,11 @@ from models.workflow import ( | |||
| logger = logging.getLogger(__name__) | |||
| class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage): | |||
| class WorkflowAppGenerateTaskPipeline: | |||
| """ | |||
| WorkflowAppGenerateTaskPipeline is a class that generates stream output and manages state for the application. | |||
| """ | |||
| _task_state: WorkflowTaskState | |||
| _application_generate_entity: WorkflowAppGenerateEntity | |||
| _workflow_system_variables: dict[SystemVariableKey, Any] | |||
| _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] | |||
| def __init__( | |||
| self, | |||
| application_generate_entity: WorkflowAppGenerateEntity, | |||
| @@ -85,7 +79,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| user: Union[Account, EndUser], | |||
| stream: bool, | |||
| ) -> None: | |||
| super().__init__( | |||
| self._base_task_pipeline = BasedGenerateTaskPipeline( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| stream=stream, | |||
| @@ -102,17 +96,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| else: | |||
| raise ValueError(f"Invalid user type: {type(user)}") | |||
| self._workflow_cycle_manager = WorkflowCycleManage( | |||
| application_generate_entity=application_generate_entity, | |||
| workflow_system_variables={ | |||
| SystemVariableKey.FILES: application_generate_entity.files, | |||
| SystemVariableKey.USER_ID: user_session_id, | |||
| SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, | |||
| SystemVariableKey.WORKFLOW_ID: workflow.id, | |||
| SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, | |||
| }, | |||
| ) | |||
| self._application_generate_entity = application_generate_entity | |||
| self._workflow_id = workflow.id | |||
| self._workflow_features_dict = workflow.features_dict | |||
| self._workflow_system_variables = { | |||
| SystemVariableKey.FILES: application_generate_entity.files, | |||
| SystemVariableKey.USER_ID: user_session_id, | |||
| SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, | |||
| SystemVariableKey.WORKFLOW_ID: workflow.id, | |||
| SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, | |||
| } | |||
| self._task_state = WorkflowTaskState() | |||
| self._workflow_run_id = "" | |||
| @@ -122,7 +119,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| :return: | |||
| """ | |||
| generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) | |||
| if self._stream: | |||
| if self._base_task_pipeline._stream: | |||
| return self._to_stream_response(generator) | |||
| else: | |||
| return self._to_blocking_response(generator) | |||
| @@ -239,29 +236,29 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| """ | |||
| graph_runtime_state = None | |||
| for queue_message in self._queue_manager.listen(): | |||
| for queue_message in self._base_task_pipeline._queue_manager.listen(): | |||
| event = queue_message.event | |||
| if isinstance(event, QueuePingEvent): | |||
| yield self._ping_stream_response() | |||
| yield self._base_task_pipeline._ping_stream_response() | |||
| elif isinstance(event, QueueErrorEvent): | |||
| err = self._handle_error(event=event) | |||
| yield self._error_to_stream_response(err) | |||
| err = self._base_task_pipeline._handle_error(event=event) | |||
| yield self._base_task_pipeline._error_to_stream_response(err) | |||
| break | |||
| elif isinstance(event, QueueWorkflowStartedEvent): | |||
| # override graph runtime state | |||
| graph_runtime_state = event.graph_runtime_state | |||
| with Session(db.engine) as session: | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| # init workflow run | |||
| workflow_run = self._handle_workflow_run_start( | |||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_start( | |||
| session=session, | |||
| workflow_id=self._workflow_id, | |||
| user_id=self._user_id, | |||
| created_by_role=self._created_by_role, | |||
| ) | |||
| self._workflow_run_id = workflow_run.id | |||
| start_resp = self._workflow_start_to_stream_response( | |||
| start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| session.commit() | |||
| @@ -273,12 +270,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| ): | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||
| workflow_node_execution = self._handle_workflow_node_execution_retried( | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( | |||
| session=session, workflow_run=workflow_run, event=event | |||
| ) | |||
| response = self._workflow_node_retry_to_stream_response( | |||
| response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( | |||
| session=session, | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| @@ -292,12 +291,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||
| workflow_node_execution = self._handle_node_execution_start( | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( | |||
| session=session, workflow_run=workflow_run, event=event | |||
| ) | |||
| node_start_response = self._workflow_node_start_to_stream_response( | |||
| node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response( | |||
| session=session, | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| @@ -308,9 +309,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| if node_start_response: | |||
| yield node_start_response | |||
| elif isinstance(event, QueueNodeSucceededEvent): | |||
| with Session(db.engine) as session: | |||
| workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event) | |||
| node_success_response = self._workflow_node_finish_to_stream_response( | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( | |||
| session=session, event=event | |||
| ) | |||
| node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( | |||
| session=session, | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| @@ -321,12 +324,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| if node_success_response: | |||
| yield node_success_response | |||
| elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): | |||
| with Session(db.engine) as session: | |||
| workflow_node_execution = self._handle_workflow_node_execution_failed( | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( | |||
| session=session, | |||
| event=event, | |||
| ) | |||
| node_failed_response = self._workflow_node_finish_to_stream_response( | |||
| node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( | |||
| session=session, | |||
| event=event, | |||
| task_id=self._application_generate_entity.task_id, | |||
| @@ -341,13 +344,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||
| parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| parallel_start_resp = ( | |||
| self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| ) | |||
| ) | |||
| yield parallel_start_resp | |||
| @@ -356,13 +363,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||
| parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| parallel_finish_resp = ( | |||
| self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| event=event, | |||
| ) | |||
| ) | |||
| yield parallel_finish_resp | |||
| @@ -371,9 +382,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||
| iter_start_resp = self._workflow_iteration_start_to_stream_response( | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| @@ -386,9 +399,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||
| iter_next_resp = self._workflow_iteration_next_to_stream_response( | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| @@ -401,9 +416,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| if not self._workflow_run_id: | |||
| raise ValueError("workflow run not initialized.") | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id) | |||
| iter_finish_resp = self._workflow_iteration_completed_to_stream_response( | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._get_workflow_run( | |||
| session=session, workflow_run_id=self._workflow_run_id | |||
| ) | |||
| iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| @@ -418,8 +435,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| if not graph_runtime_state: | |||
| raise ValueError("graph runtime state not initialized.") | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._handle_workflow_run_success( | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_success( | |||
| session=session, | |||
| workflow_run_id=self._workflow_run_id, | |||
| start_at=graph_runtime_state.start_at, | |||
| @@ -433,7 +450,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| # save workflow app log | |||
| self._save_workflow_app_log(session=session, workflow_run=workflow_run) | |||
| workflow_finish_resp = self._workflow_finish_to_stream_response( | |||
| workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( | |||
| session=session, | |||
| task_id=self._application_generate_entity.task_id, | |||
| workflow_run=workflow_run, | |||
| @@ -447,8 +464,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| if not graph_runtime_state: | |||
| raise ValueError("graph runtime state not initialized.") | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._handle_workflow_run_partial_success( | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success( | |||
| session=session, | |||
| workflow_run_id=self._workflow_run_id, | |||
| start_at=graph_runtime_state.start_at, | |||
| @@ -463,7 +480,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| # save workflow app log | |||
| self._save_workflow_app_log(session=session, workflow_run=workflow_run) | |||
| workflow_finish_resp = self._workflow_finish_to_stream_response( | |||
| workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| session.commit() | |||
| @@ -475,8 +492,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| if not graph_runtime_state: | |||
| raise ValueError("graph runtime state not initialized.") | |||
| with Session(db.engine) as session: | |||
| workflow_run = self._handle_workflow_run_failed( | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed( | |||
| session=session, | |||
| workflow_run_id=self._workflow_run_id, | |||
| start_at=graph_runtime_state.start_at, | |||
| @@ -494,7 +511,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa | |||
| # save workflow app log | |||
| self._save_workflow_app_log(session=session, workflow_run=workflow_run) | |||
| workflow_finish_resp = self._workflow_finish_to_stream_response( | |||
| workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( | |||
| session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run | |||
| ) | |||
| session.commit() | |||
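The hunks above replace the pipeline's inherited cycle-management calls with delegation to a dedicated `_workflow_cycle_manager`, and open every session with `expire_on_commit=False` so the ORM objects returned by the handlers stay readable after `session.commit()` and after the session closes. A minimal, self-contained sketch of why that flag matters (the model below is a stand-in, not the project's real `WorkflowRun`):

```python
from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class WorkflowRunDemo(Base):  # stand-in for the real WorkflowRun model
    __tablename__ = "workflow_run_demo"
    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str] = mapped_column(String(32), default="running")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine, expire_on_commit=False) as session:
    run = WorkflowRunDemo(status="succeeded")
    session.add(run)
    session.commit()

# With the default expire_on_commit=True, this attribute access would trigger a
# refresh against a closed session and raise DetachedInstanceError; with the flag
# set to False the loaded values remain usable for building stream responses.
print(run.status)  # -> "succeeded"
```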
| @@ -195,7 +195,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): | |||
| # app config | |||
| app_config: WorkflowUIBasedAppConfig | |||
| workflow_run_id: Optional[str] = None | |||
| workflow_run_id: str | |||
| class SingleIterationRunEntity(BaseModel): | |||
| """ | |||
| @@ -15,7 +15,6 @@ from core.app.entities.queue_entities import ( | |||
| from core.app.entities.task_entities import ( | |||
| ErrorStreamResponse, | |||
| PingStreamResponse, | |||
| TaskState, | |||
| ) | |||
| from core.errors.error import QuotaExceededError | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||
| @@ -30,22 +29,12 @@ class BasedGenerateTaskPipeline: | |||
| BasedGenerateTaskPipeline is a class that generates stream output and manages state for applications. | |||
| """ | |||
| _task_state: TaskState | |||
| _application_generate_entity: AppGenerateEntity | |||
| def __init__( | |||
| self, | |||
| application_generate_entity: AppGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| stream: bool, | |||
| ) -> None: | |||
| """ | |||
| Initialize GenerateTaskPipeline. | |||
| :param application_generate_entity: application generate entity | |||
| :param queue_manager: queue manager | |||
| :param user: user | |||
| :param stream: stream | |||
| """ | |||
| self._application_generate_entity = application_generate_entity | |||
| self._queue_manager = queue_manager | |||
| self._start_at = time.perf_counter() | |||
| @@ -31,10 +31,19 @@ from services.annotation_service import AppAnnotationService | |||
| class MessageCycleManage: | |||
| _application_generate_entity: Union[ | |||
| ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity | |||
| ] | |||
| _task_state: Union[EasyUITaskState, WorkflowTaskState] | |||
| def __init__( | |||
| self, | |||
| *, | |||
| application_generate_entity: Union[ | |||
| ChatAppGenerateEntity, | |||
| CompletionAppGenerateEntity, | |||
| AgentChatAppGenerateEntity, | |||
| AdvancedChatAppGenerateEntity, | |||
| ], | |||
| task_state: Union[EasyUITaskState, WorkflowTaskState], | |||
| ) -> None: | |||
| self._application_generate_entity = application_generate_entity | |||
| self._task_state = task_state | |||
| def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: | |||
| """ | |||
| @@ -36,7 +36,6 @@ from core.app.entities.task_entities import ( | |||
| ParallelBranchStartStreamResponse, | |||
| WorkflowFinishStreamResponse, | |||
| WorkflowStartStreamResponse, | |||
| WorkflowTaskState, | |||
| ) | |||
| from core.file import FILE_MODEL_IDENTITY, File | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| @@ -60,13 +59,20 @@ from models.workflow import ( | |||
| WorkflowRunStatus, | |||
| ) | |||
| from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError | |||
| from .exc import WorkflowRunNotFoundError | |||
| class WorkflowCycleManage: | |||
| _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] | |||
| _task_state: WorkflowTaskState | |||
| _workflow_system_variables: dict[SystemVariableKey, Any] | |||
| def __init__( | |||
| self, | |||
| *, | |||
| application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], | |||
| workflow_system_variables: dict[SystemVariableKey, Any], | |||
| ) -> None: | |||
| self._workflow_run: WorkflowRun | None = None | |||
| self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {} | |||
| self._application_generate_entity = application_generate_entity | |||
| self._workflow_system_variables = workflow_system_variables | |||
| def _handle_workflow_run_start( | |||
| self, | |||
| @@ -104,7 +110,8 @@ class WorkflowCycleManage: | |||
| inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) | |||
| # init workflow run | |||
| workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4())) | |||
| # TODO: This workflow_run_id should never be None; maybe we can find a more elegant way to handle this | |||
| workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4()) | |||
| workflow_run = WorkflowRun() | |||
| workflow_run.id = workflow_run_id | |||
| @@ -241,7 +248,7 @@ class WorkflowCycleManage: | |||
| workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) | |||
| workflow_run.exceptions_count = exceptions_count | |||
| stmt = select(WorkflowNodeExecution).where( | |||
| stmt = select(WorkflowNodeExecution.node_execution_id).where( | |||
| WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, | |||
| WorkflowNodeExecution.app_id == workflow_run.app_id, | |||
| WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, | |||
| @@ -249,15 +256,18 @@ class WorkflowCycleManage: | |||
| WorkflowNodeExecution.workflow_run_id == workflow_run.id, | |||
| WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, | |||
| ) | |||
| running_workflow_node_executions = session.scalars(stmt).all() | |||
| ids = session.scalars(stmt).all() | |||
| # Use self._get_workflow_node_execution here to make sure the cache is updated | |||
| running_workflow_node_executions = [ | |||
| self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id | |||
| ] | |||
| for workflow_node_execution in running_workflow_node_executions: | |||
| now = datetime.now(UTC).replace(tzinfo=None) | |||
| workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value | |||
| workflow_node_execution.error = error | |||
| finish_at = datetime.now(UTC).replace(tzinfo=None) | |||
| workflow_node_execution.finished_at = finish_at | |||
| workflow_node_execution.elapsed_time = (finish_at - workflow_node_execution.created_at).total_seconds() | |||
| workflow_node_execution.finished_at = now | |||
| workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds() | |||
| if trace_manager: | |||
| trace_manager.add_trace_task( | |||
| @@ -299,6 +309,8 @@ class WorkflowCycleManage: | |||
| workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) | |||
| session.add(workflow_node_execution) | |||
| self._workflow_node_executions[event.node_execution_id] = workflow_node_execution | |||
| return workflow_node_execution | |||
| def _handle_workflow_node_execution_success( | |||
| @@ -325,6 +337,7 @@ class WorkflowCycleManage: | |||
| workflow_node_execution.finished_at = finished_at | |||
| workflow_node_execution.elapsed_time = elapsed_time | |||
| workflow_node_execution = session.merge(workflow_node_execution) | |||
| return workflow_node_execution | |||
| def _handle_workflow_node_execution_failed( | |||
| @@ -364,6 +377,7 @@ class WorkflowCycleManage: | |||
| workflow_node_execution.elapsed_time = elapsed_time | |||
| workflow_node_execution.execution_metadata = execution_metadata | |||
| workflow_node_execution = session.merge(workflow_node_execution) | |||
| return workflow_node_execution | |||
| def _handle_workflow_node_execution_retried( | |||
| @@ -415,6 +429,8 @@ class WorkflowCycleManage: | |||
| workflow_node_execution.index = event.node_run_index | |||
| session.add(workflow_node_execution) | |||
| self._workflow_node_executions[event.node_execution_id] = workflow_node_execution | |||
| return workflow_node_execution | |||
| ################################################# | |||
| @@ -811,25 +827,23 @@ class WorkflowCycleManage: | |||
| return None | |||
| def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun: | |||
| """ | |||
| Refetch workflow run | |||
| :param workflow_run_id: workflow run id | |||
| :return: | |||
| """ | |||
| if self._workflow_run and self._workflow_run.id == workflow_run_id: | |||
| cached_workflow_run = self._workflow_run | |||
| cached_workflow_run = session.merge(cached_workflow_run) | |||
| return cached_workflow_run | |||
| stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) | |||
| workflow_run = session.scalar(stmt) | |||
| if not workflow_run: | |||
| raise WorkflowRunNotFoundError(workflow_run_id) | |||
| self._workflow_run = workflow_run | |||
| return workflow_run | |||
| def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: | |||
| stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id) | |||
| workflow_node_execution = session.scalar(stmt) | |||
| if not workflow_node_execution: | |||
| raise WorkflowNodeExecutionNotFoundError(node_execution_id) | |||
| return workflow_node_execution | |||
| if node_execution_id not in self._workflow_node_executions: | |||
| raise ValueError(f"Workflow node execution not found: {node_execution_id}") | |||
| cached_workflow_node_execution = self._workflow_node_executions[node_execution_id] | |||
| return cached_workflow_node_execution | |||
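Instead of re-querying `WorkflowNodeExecution` by `node_execution_id`, the manager now records every execution it creates in `self._workflow_node_executions` and serves later lookups from that dict, while `_get_workflow_run` re-attaches its cached run to the current session via `session.merge()`. A plain-Python sketch of the dict-backed lookup (names are illustrative, not the project's real API):

```python
# Illustrative sketch of the in-memory cache behind _get_workflow_node_execution:
# the start/retry handlers store the object, and lookups only ever read the dict.
class NodeExecutionCache:
    def __init__(self) -> None:
        self._executions: dict[str, object] = {}

    def remember(self, node_execution_id: str, execution: object) -> None:
        # called when a node starts (or is retried), alongside session.add(...)
        self._executions[node_execution_id] = execution

    def lookup(self, node_execution_id: str) -> object:
        if node_execution_id not in self._executions:
            raise ValueError(f"Workflow node execution not found: {node_execution_id}")
        return self._executions[node_execution_id]
```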
| def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: | |||
| """ | |||
| @@ -1,9 +1,47 @@ | |||
| import tiktoken | |||
| from threading import Lock | |||
| from typing import Any | |||
| _tokenizer: Any = None | |||
| _lock = Lock() | |||
| class GPT2Tokenizer: | |||
| @staticmethod | |||
| def _get_num_tokens_by_gpt2(text: str) -> int: | |||
| """ | |||
| use gpt2 tokenizer to get num tokens | |||
| """ | |||
| _tokenizer = GPT2Tokenizer.get_encoder() | |||
| tokens = _tokenizer.encode(text) | |||
| return len(tokens) | |||
| @staticmethod | |||
| def get_num_tokens(text: str) -> int: | |||
| encoding = tiktoken.encoding_for_model("gpt2") | |||
| tiktoken_vec = encoding.encode(text) | |||
| return len(tiktoken_vec) | |||
| # Because this process needs more CPU resources, we revert to the direct call until we find a better way to handle it. | |||
| # | |||
| # future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text) | |||
| # result = future.result() | |||
| # return cast(int, result) | |||
| return GPT2Tokenizer._get_num_tokens_by_gpt2(text) | |||
| @staticmethod | |||
| def get_encoder() -> Any: | |||
| global _tokenizer, _lock | |||
| with _lock: | |||
| if _tokenizer is None: | |||
| # Try to use tiktoken to get the tokenizer because it is faster | |||
| # | |||
| try: | |||
| import tiktoken | |||
| _tokenizer = tiktoken.get_encoding("gpt2") | |||
| except Exception: | |||
| from os.path import abspath, dirname, join | |||
| from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore | |||
| base_path = abspath(__file__) | |||
| gpt2_tokenizer_path = join(dirname(base_path), "gpt2") | |||
| _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path) | |||
| return _tokenizer | |||
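The rewritten tokenizer module builds a single shared encoder lazily under a module-level lock, preferring `tiktoken` and falling back to the bundled Hugging Face GPT-2 tokenizer when `tiktoken` is unavailable. A short usage sketch (the import is abbreviated; the exact count depends on which backend gets picked):

```python
# from <module containing GPT2Tokenizer> import GPT2Tokenizer  # adjust to the real path

# The first call initializes the shared encoder under the lock; later calls reuse it.
count = GPT2Tokenizer.get_num_tokens("Dify is an LLM application development platform.")
print(count)  # an integer token count
```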
| @@ -113,6 +113,8 @@ class BaiduVector(BaseVector): | |||
| return False | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| if not ids: | |||
| return | |||
| quoted_ids = [f"'{id}'" for id in ids] | |||
| self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})") | |||
| @@ -83,6 +83,8 @@ class ChromaVector(BaseVector): | |||
| self._client.delete_collection(self._collection_name) | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| if not ids: | |||
| return | |||
| collection = self._client.get_or_create_collection(self._collection_name) | |||
| collection.delete(ids=ids) | |||
| @@ -0,0 +1,104 @@ | |||
| import json | |||
| import logging | |||
| from typing import Any, Optional | |||
| from flask import current_app | |||
| from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ( | |||
| ElasticSearchConfig, | |||
| ElasticSearchVector, | |||
| ElasticSearchVectorFactory, | |||
| ) | |||
| from core.rag.datasource.vdb.field import Field | |||
| from core.rag.datasource.vdb.vector_type import VectorType | |||
| from core.rag.embedding.embedding_base import Embeddings | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset | |||
| logger = logging.getLogger(__name__) | |||
| class ElasticSearchJaVector(ElasticSearchVector): | |||
| def create_collection( | |||
| self, | |||
| embeddings: list[list[float]], | |||
| metadatas: Optional[list[dict[Any, Any]]] = None, | |||
| index_params: Optional[dict] = None, | |||
| ): | |||
| lock_name = f"vector_indexing_lock_{self._collection_name}" | |||
| with redis_client.lock(lock_name, timeout=20): | |||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | |||
| if redis_client.get(collection_exist_cache_key): | |||
| logger.info(f"Collection {self._collection_name} already exists.") | |||
| return | |||
| if not self._client.indices.exists(index=self._collection_name): | |||
| dim = len(embeddings[0]) | |||
| settings = { | |||
| "analysis": { | |||
| "analyzer": { | |||
| "ja_analyzer": { | |||
| "type": "custom", | |||
| "char_filter": [ | |||
| "icu_normalizer", | |||
| "kuromoji_iteration_mark", | |||
| ], | |||
| "tokenizer": "kuromoji_tokenizer", | |||
| "filter": [ | |||
| "kuromoji_baseform", | |||
| "kuromoji_part_of_speech", | |||
| "ja_stop", | |||
| "kuromoji_number", | |||
| "kuromoji_stemmer", | |||
| ], | |||
| } | |||
| } | |||
| } | |||
| } | |||
| mappings = { | |||
| "properties": { | |||
| Field.CONTENT_KEY.value: { | |||
| "type": "text", | |||
| "analyzer": "ja_analyzer", | |||
| "search_analyzer": "ja_analyzer", | |||
| }, | |||
| Field.VECTOR.value: { # Make sure the dimension is correct here | |||
| "type": "dense_vector", | |||
| "dims": dim, | |||
| "index": True, | |||
| "similarity": "cosine", | |||
| }, | |||
| Field.METADATA_KEY.value: { | |||
| "type": "object", | |||
| "properties": { | |||
| "doc_id": {"type": "keyword"} # Map doc_id to keyword type | |||
| }, | |||
| }, | |||
| } | |||
| } | |||
| self._client.indices.create(index=self._collection_name, settings=settings, mappings=mappings) | |||
| redis_client.set(collection_exist_cache_key, 1, ex=3600) | |||
| class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory): | |||
| def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector: | |||
| if dataset.index_struct_dict: | |||
| class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] | |||
| collection_name = class_prefix | |||
| else: | |||
| dataset_id = dataset.id | |||
| collection_name = Dataset.gen_collection_name_by_id(dataset_id) | |||
| dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name)) | |||
| config = current_app.config | |||
| return ElasticSearchJaVector( | |||
| index_name=collection_name, | |||
| config=ElasticSearchConfig( | |||
| host=config.get("ELASTICSEARCH_HOST", "localhost"), | |||
| port=config.get("ELASTICSEARCH_PORT", 9200), | |||
| username=config.get("ELASTICSEARCH_USERNAME", ""), | |||
| password=config.get("ELASTICSEARCH_PASSWORD", ""), | |||
| ), | |||
| attributes=[], | |||
| ) | |||
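The new factory reuses the standard Elasticsearch connection settings but creates the index with a kuromoji tokenizer plus ICU normalization, so Japanese text is segmented properly for full-text retrieval. A hedged usage sketch; the `dataset`, `documents`, and embedding objects come from the surrounding codebase, and the cluster needs the `analysis-kuromoji` and `analysis-icu` plugins installed for the settings above to be accepted:

```python
# Hedged sketch: exercising the Japanese-analyzer backend directly through its factory.
factory = ElasticSearchJaVectorFactory()
vector = factory.init_vector(dataset=dataset, attributes=[], embeddings=embedding_model)

# create_collection() builds the index with the ja_analyzer settings shown above;
# the base-class create(texts, embeddings) signature is assumed here.
vector.create(texts=documents, embeddings=document_embeddings)
```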
| @@ -98,6 +98,8 @@ class ElasticSearchVector(BaseVector): | |||
| return bool(self._client.exists(index=self._collection_name, id=id)) | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| if not ids: | |||
| return | |||
| for id in ids: | |||
| self._client.delete(index=self._collection_name, id=id) | |||
| @@ -6,6 +6,8 @@ class Field(Enum): | |||
| METADATA_KEY = "metadata" | |||
| GROUP_KEY = "group_id" | |||
| VECTOR = "vector" | |||
| # Sparse vector field to support full-text search | |||
| SPARSE_VECTOR = "sparse_vector" | |||
| TEXT_KEY = "text" | |||
| PRIMARY_KEY = "id" | |||
| DOC_ID = "metadata.doc_id" | |||
| @@ -2,6 +2,7 @@ import json | |||
| import logging | |||
| from typing import Any, Optional | |||
| from packaging import version | |||
| from pydantic import BaseModel, model_validator | |||
| from pymilvus import MilvusClient, MilvusException # type: ignore | |||
| from pymilvus.milvus_client import IndexParams # type: ignore | |||
| @@ -20,16 +21,25 @@ logger = logging.getLogger(__name__) | |||
| class MilvusConfig(BaseModel): | |||
| uri: str | |||
| token: Optional[str] = None | |||
| user: str | |||
| password: str | |||
| batch_size: int = 100 | |||
| database: str = "default" | |||
| """ | |||
| Configuration class for Milvus connection. | |||
| """ | |||
| uri: str # Milvus server URI | |||
| token: Optional[str] = None # Optional token for authentication | |||
| user: str # Username for authentication | |||
| password: str # Password for authentication | |||
| batch_size: int = 100 # Batch size for operations | |||
| database: str = "default" # Database name | |||
| enable_hybrid_search: bool = False # Flag to enable hybrid search | |||
| @model_validator(mode="before") | |||
| @classmethod | |||
| def validate_config(cls, values: dict) -> dict: | |||
| """ | |||
| Validate the configuration values. | |||
| Raises ValueError if required fields are missing. | |||
| """ | |||
| if not values.get("uri"): | |||
| raise ValueError("config MILVUS_URI is required") | |||
| if not values.get("user"): | |||
| @@ -39,6 +49,9 @@ class MilvusConfig(BaseModel): | |||
| return values | |||
| def to_milvus_params(self): | |||
| """ | |||
| Convert the configuration to a dictionary of Milvus connection parameters. | |||
| """ | |||
| return { | |||
| "uri": self.uri, | |||
| "token": self.token, | |||
| @@ -49,26 +62,57 @@ class MilvusConfig(BaseModel): | |||
| class MilvusVector(BaseVector): | |||
| """ | |||
| Milvus vector storage implementation. | |||
| """ | |||
| def __init__(self, collection_name: str, config: MilvusConfig): | |||
| super().__init__(collection_name) | |||
| self._client_config = config | |||
| self._client = self._init_client(config) | |||
| self._consistency_level = "Session" | |||
| self._fields: list[str] = [] | |||
| self._consistency_level = "Session" # Consistency level for Milvus operations | |||
| self._fields: list[str] = [] # List of fields in the collection | |||
| self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported | |||
| def _check_hybrid_search_support(self) -> bool: | |||
| """ | |||
| Check if the current Milvus version supports hybrid search. | |||
| Returns True if the version is >= 2.5.0, otherwise False. | |||
| """ | |||
| if not self._client_config.enable_hybrid_search: | |||
| return False | |||
| try: | |||
| milvus_version = self._client.get_server_version() | |||
| return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version | |||
| except Exception as e: | |||
| logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.") | |||
| return False | |||
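`_check_hybrid_search_support` gates the feature on the reported server version. Note that `base_version` is a plain string, so the `>=` above compares lexicographically, whereas parsed `Version` objects compare numerically component by component. A quick illustration with `packaging.version`:

```python
from packaging import version

# Version objects compare numerically, so this holds:
assert version.parse("2.10.0") > version.parse("2.5.0")

# base_version is a string, so a direct comparison is lexicographic and would
# order a hypothetical "2.10.0" release *before* "2.5.0":
assert version.parse("2.10.0").base_version < version.parse("2.5.0").base_version
```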
| def get_type(self) -> str: | |||
| """ | |||
| Get the type of vector storage (Milvus). | |||
| """ | |||
| return VectorType.MILVUS | |||
| def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | |||
| """ | |||
| Create a collection and add texts with embeddings. | |||
| """ | |||
| index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}} | |||
| metadatas = [d.metadata if d.metadata is not None else {} for d in texts] | |||
| self.create_collection(embeddings, metadatas, index_params) | |||
| self.add_texts(texts, embeddings) | |||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | |||
| """ | |||
| Add texts and their embeddings to the collection. | |||
| """ | |||
| insert_dict_list = [] | |||
| for i in range(len(documents)): | |||
| insert_dict = { | |||
| # No need to insert the sparse_vector field separately; the text_bm25_emb | |||
| # function automatically converts the raw text into a sparse vector for us. | |||
| Field.CONTENT_KEY.value: documents[i].page_content, | |||
| Field.VECTOR.value: embeddings[i], | |||
| Field.METADATA_KEY.value: documents[i].metadata, | |||
| @@ -76,12 +120,11 @@ class MilvusVector(BaseVector): | |||
| insert_dict_list.append(insert_dict) | |||
| # Total insert count | |||
| total_count = len(insert_dict_list) | |||
| pks: list[str] = [] | |||
| for i in range(0, total_count, 1000): | |||
| batch_insert_list = insert_dict_list[i : i + 1000] | |||
| # Insert into the collection. | |||
| batch_insert_list = insert_dict_list[i : i + 1000] | |||
| try: | |||
| ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list) | |||
| pks.extend(ids) | |||
| @@ -91,6 +134,9 @@ class MilvusVector(BaseVector): | |||
| return pks | |||
| def get_ids_by_metadata_field(self, key: str, value: str): | |||
| """ | |||
| Get document IDs by metadata field key and value. | |||
| """ | |||
| result = self._client.query( | |||
| collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"] | |||
| ) | |||
| @@ -100,12 +146,18 @@ class MilvusVector(BaseVector): | |||
| return None | |||
| def delete_by_metadata_field(self, key: str, value: str): | |||
| """ | |||
| Delete documents by metadata field key and value. | |||
| """ | |||
| if self._client.has_collection(self._collection_name): | |||
| ids = self.get_ids_by_metadata_field(key, value) | |||
| if ids: | |||
| self._client.delete(collection_name=self._collection_name, pks=ids) | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| """ | |||
| Delete documents by their IDs. | |||
| """ | |||
| if self._client.has_collection(self._collection_name): | |||
| result = self._client.query( | |||
| collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"] | |||
| @@ -115,10 +167,16 @@ class MilvusVector(BaseVector): | |||
| self._client.delete(collection_name=self._collection_name, pks=ids) | |||
| def delete(self) -> None: | |||
| """ | |||
| Delete the entire collection. | |||
| """ | |||
| if self._client.has_collection(self._collection_name): | |||
| self._client.drop_collection(self._collection_name, None) | |||
| def text_exists(self, id: str) -> bool: | |||
| """ | |||
| Check if a text with the given ID exists in the collection. | |||
| """ | |||
| if not self._client.has_collection(self._collection_name): | |||
| return False | |||
| @@ -128,32 +186,80 @@ class MilvusVector(BaseVector): | |||
| return len(result) > 0 | |||
| def field_exists(self, field: str) -> bool: | |||
| """ | |||
| Check if a field exists in the collection. | |||
| """ | |||
| return field in self._fields | |||
| def _process_search_results( | |||
| self, results: list[Any], output_fields: list[str], score_threshold: float = 0.0 | |||
| ) -> list[Document]: | |||
| """ | |||
| Common method to process search results | |||
| :param results: Search results | |||
| :param output_fields: Fields to be output | |||
| :param score_threshold: Score threshold for filtering | |||
| :return: List of documents | |||
| """ | |||
| docs = [] | |||
| for result in results[0]: | |||
| metadata = result["entity"].get(output_fields[1], {}) | |||
| metadata["score"] = result["distance"] | |||
| if result["distance"] > score_threshold: | |||
| doc = Document(page_content=result["entity"].get(output_fields[0], ""), metadata=metadata) | |||
| docs.append(doc) | |||
| return docs | |||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||
| # Set search parameters. | |||
| """ | |||
| Search for documents by vector similarity. | |||
| """ | |||
| results = self._client.search( | |||
| collection_name=self._collection_name, | |||
| data=[query_vector], | |||
| anns_field=Field.VECTOR.value, | |||
| limit=kwargs.get("top_k", 4), | |||
| output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], | |||
| ) | |||
| # Organize results. | |||
| docs = [] | |||
| for result in results[0]: | |||
| metadata = result["entity"].get(Field.METADATA_KEY.value) | |||
| metadata["score"] = result["distance"] | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| if result["distance"] > score_threshold: | |||
| doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata) | |||
| docs.append(doc) | |||
| return docs | |||
| return self._process_search_results( | |||
| results, | |||
| output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], | |||
| score_threshold=float(kwargs.get("score_threshold") or 0.0), | |||
| ) | |||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | |||
| # milvus/zilliz doesn't support bm25 search | |||
| return [] | |||
| """ | |||
| Search for documents by full-text search (if hybrid search is enabled). | |||
| """ | |||
| if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value): | |||
| logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)") | |||
| return [] | |||
| results = self._client.search( | |||
| collection_name=self._collection_name, | |||
| data=[query], | |||
| anns_field=Field.SPARSE_VECTOR.value, | |||
| limit=kwargs.get("top_k", 4), | |||
| output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], | |||
| ) | |||
| return self._process_search_results( | |||
| results, | |||
| output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], | |||
| score_threshold=float(kwargs.get("score_threshold") or 0.0), | |||
| ) | |||
| def create_collection( | |||
| self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None | |||
| ): | |||
| """ | |||
| Create a new collection in Milvus with the specified schema and index parameters. | |||
| """ | |||
| lock_name = "vector_indexing_lock_{}".format(self._collection_name) | |||
| with redis_client.lock(lock_name, timeout=20): | |||
| collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) | |||
| @@ -161,7 +267,7 @@ class MilvusVector(BaseVector): | |||
| return | |||
| # Grab the existing collection if it exists | |||
| if not self._client.has_collection(self._collection_name): | |||
| from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore | |||
| from pymilvus import CollectionSchema, DataType, FieldSchema, Function, FunctionType # type: ignore | |||
| from pymilvus.orm.types import infer_dtype_bydata # type: ignore | |||
| # Determine embedding dim | |||
| @@ -170,16 +276,36 @@ class MilvusVector(BaseVector): | |||
| if metadatas: | |||
| fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) | |||
| # Create the text field | |||
| fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)) | |||
| # Create the text field; enable_analyzer is set to True so that Milvus automatically | |||
| # converts the text into a sparse vector. Reference: https://milvus.io/docs/full-text-search.md | |||
| fields.append( | |||
| FieldSchema( | |||
| Field.CONTENT_KEY.value, | |||
| DataType.VARCHAR, | |||
| max_length=65_535, | |||
| enable_analyzer=self._hybrid_search_enabled, | |||
| ) | |||
| ) | |||
| # Create the primary key field | |||
| fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True)) | |||
| # Create the vector field, supports binary or float vectors | |||
| fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)) | |||
| # Create Sparse Vector Index for the collection | |||
| if self._hybrid_search_enabled: | |||
| fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR)) | |||
| # Create the schema for the collection | |||
| schema = CollectionSchema(fields) | |||
| # Create custom function to support text to sparse vector by BM25 | |||
| if self._hybrid_search_enabled: | |||
| bm25_function = Function( | |||
| name="text_bm25_emb", | |||
| input_field_names=[Field.CONTENT_KEY.value], | |||
| output_field_names=[Field.SPARSE_VECTOR.value], | |||
| function_type=FunctionType.BM25, | |||
| ) | |||
| schema.add_function(bm25_function) | |||
| for x in schema.fields: | |||
| self._fields.append(x.name) | |||
| # Since primary field is auto-id, no need to track it | |||
| @@ -189,10 +315,15 @@ class MilvusVector(BaseVector): | |||
| index_params_obj = IndexParams() | |||
| index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params) | |||
| # Create Sparse Vector Index for the collection | |||
| if self._hybrid_search_enabled: | |||
| index_params_obj.add_index( | |||
| field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25" | |||
| ) | |||
| # Create the collection | |||
| collection_name = self._collection_name | |||
| self._client.create_collection( | |||
| collection_name=collection_name, | |||
| collection_name=self._collection_name, | |||
| schema=schema, | |||
| index_params=index_params_obj, | |||
| consistency_level=self._consistency_level, | |||
| @@ -200,12 +331,22 @@ class MilvusVector(BaseVector): | |||
| redis_client.set(collection_exist_cache_key, 1, ex=3600) | |||
| def _init_client(self, config) -> MilvusClient: | |||
| """ | |||
| Initialize and return a Milvus client. | |||
| """ | |||
| client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database) | |||
| return client | |||
| class MilvusVectorFactory(AbstractVectorFactory): | |||
| """ | |||
| Factory class for creating MilvusVector instances. | |||
| """ | |||
| def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector: | |||
| """ | |||
| Initialize a MilvusVector instance for the given dataset. | |||
| """ | |||
| if dataset.index_struct_dict: | |||
| class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] | |||
| collection_name = class_prefix | |||
| @@ -222,5 +363,6 @@ class MilvusVectorFactory(AbstractVectorFactory): | |||
| user=dify_config.MILVUS_USER or "", | |||
| password=dify_config.MILVUS_PASSWORD or "", | |||
| database=dify_config.MILVUS_DATABASE or "", | |||
| enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False, | |||
| ), | |||
| ) | |||
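With `MILVUS_ENABLE_HYBRID_SEARCH` on and a server at 2.5.0 or newer, the collection gains a BM25-backed sparse-vector field and `search_by_full_text` becomes usable alongside `search_by_vector`. A hedged usage sketch; the URI and credentials are placeholders, and `documents`, `embeddings`, and `query_embedding` are assumed to come from the caller:

```python
config = MilvusConfig(
    uri="http://localhost:19530",
    user="root",
    password="Milvus",
    enable_hybrid_search=True,  # populated from MILVUS_ENABLE_HYBRID_SEARCH
)
vector = MilvusVector(collection_name="dify_demo_collection", config=config)
vector.create(texts=documents, embeddings=embeddings)

# Dense retrieval, always available.
dense_hits = vector.search_by_vector(query_embedding, top_k=4)

# BM25 full-text retrieval; returns [] with a warning when the server is older
# than 2.5.0 or hybrid search was disabled in the config.
sparse_hits = vector.search_by_full_text("how do I enable hybrid search?", top_k=4)
```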
| @@ -100,6 +100,8 @@ class MyScaleVector(BaseVector): | |||
| return results.row_count > 0 | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| if not ids: | |||
| return | |||
| self._client.command( | |||
| f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}" | |||
| ) | |||
| @@ -134,6 +134,8 @@ class OceanBaseVector(BaseVector): | |||
| return bool(cur.rowcount != 0) | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| if not ids: | |||
| return | |||
| self._client.delete(table_name=self._collection_name, ids=ids) | |||
| def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]: | |||
| @@ -167,6 +167,8 @@ class OracleVector(BaseVector): | |||
| return docs | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| if not ids: | |||
| return | |||
| with self._get_cursor() as cur: | |||
| cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),)) | |||
| @@ -129,6 +129,11 @@ class PGVector(BaseVector): | |||
| return docs | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| # Avoid crashes caused by delete operations on empty id lists in certain scenarios. | |||
| # Scenario 1: extracting a document fails, so the table never gets created; | |||
| # clicking the retry button then triggers a delete operation on an empty list. | |||
| if not ids: | |||
| return | |||
| with self._get_cursor() as cur: | |||
| cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) | |||
| @@ -140,6 +140,8 @@ class TencentVector(BaseVector): | |||
| return False | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| if not ids: | |||
| return | |||
| self._db.collection(self._collection_name).delete(document_ids=ids) | |||
| def delete_by_metadata_field(self, key: str, value: str) -> None: | |||
| @@ -409,27 +409,27 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): | |||
| db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() | |||
| ) | |||
| if not tidb_auth_binding: | |||
| idle_tidb_auth_binding = ( | |||
| db.session.query(TidbAuthBinding) | |||
| .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") | |||
| .limit(1) | |||
| .one_or_none() | |||
| ) | |||
| if idle_tidb_auth_binding: | |||
| idle_tidb_auth_binding.active = True | |||
| idle_tidb_auth_binding.tenant_id = dataset.tenant_id | |||
| db.session.commit() | |||
| TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}" | |||
| else: | |||
| with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): | |||
| tidb_auth_binding = ( | |||
| with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): | |||
| tidb_auth_binding = ( | |||
| db.session.query(TidbAuthBinding) | |||
| .filter(TidbAuthBinding.tenant_id == dataset.tenant_id) | |||
| .one_or_none() | |||
| ) | |||
| if tidb_auth_binding: | |||
| TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" | |||
| else: | |||
| idle_tidb_auth_binding = ( | |||
| db.session.query(TidbAuthBinding) | |||
| .filter(TidbAuthBinding.tenant_id == dataset.tenant_id) | |||
| .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") | |||
| .limit(1) | |||
| .one_or_none() | |||
| ) | |||
| if tidb_auth_binding: | |||
| TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" | |||
| if idle_tidb_auth_binding: | |||
| idle_tidb_auth_binding.active = True | |||
| idle_tidb_auth_binding.tenant_id = dataset.tenant_id | |||
| db.session.commit() | |||
| TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}" | |||
| else: | |||
| new_cluster = TidbService.create_tidb_serverless_cluster( | |||
| dify_config.TIDB_PROJECT_ID or "", | |||
| @@ -451,7 +451,6 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): | |||
| db.session.add(new_tidb_auth_binding) | |||
| db.session.commit() | |||
| TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}" | |||
| else: | |||
| TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" | |||
| @@ -90,6 +90,12 @@ class Vector: | |||
| from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory | |||
| return ElasticSearchVectorFactory | |||
| case VectorType.ELASTICSEARCH_JA: | |||
| from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import ( | |||
| ElasticSearchJaVectorFactory, | |||
| ) | |||
| return ElasticSearchJaVectorFactory | |||
| case VectorType.TIDB_VECTOR: | |||
| from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory | |||
| @@ -16,6 +16,7 @@ class VectorType(StrEnum): | |||
| TENCENT = "tencent" | |||
| ORACLE = "oracle" | |||
| ELASTICSEARCH = "elasticsearch" | |||
| ELASTICSEARCH_JA = "elasticsearch-ja" | |||
| LINDORM = "lindorm" | |||
| COUCHBASE = "couchbase" | |||
| BAIDU = "baidu" | |||
| @@ -23,7 +23,6 @@ class PdfExtractor(BaseExtractor): | |||
| self._file_cache_key = file_cache_key | |||
| def extract(self) -> list[Document]: | |||
| plaintext_file_key = "" | |||
| plaintext_file_exists = False | |||
| if self._file_cache_key: | |||
| try: | |||
| @@ -39,8 +38,8 @@ class PdfExtractor(BaseExtractor): | |||
| text = "\n\n".join(text_list) | |||
| # save plaintext file for caching | |||
| if not plaintext_file_exists and plaintext_file_key: | |||
| storage.save(plaintext_file_key, text.encode("utf-8")) | |||
| if not plaintext_file_exists and self._file_cache_key: | |||
| storage.save(self._file_cache_key, text.encode("utf-8")) | |||
| return documents | |||
| @@ -3,6 +3,7 @@ | |||
| import uuid | |||
| from typing import Optional | |||
| from configs import dify_config | |||
| from core.model_manager import ModelInstance | |||
| from core.rag.cleaner.clean_processor import CleanProcessor | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| @@ -80,6 +81,10 @@ class ParentChildIndexProcessor(BaseIndexProcessor): | |||
| child_nodes = self._split_child_nodes( | |||
| document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") | |||
| ) | |||
| if kwargs.get("preview"): | |||
| if len(child_nodes) > dify_config.CHILD_CHUNKS_PREVIEW_NUMBER: | |||
| child_nodes = child_nodes[: dify_config.CHILD_CHUNKS_PREVIEW_NUMBER] | |||
| document.children = child_nodes | |||
| doc_id = str(uuid.uuid4()) | |||
| hash = helper.generate_text_hash(document.page_content) | |||
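During preview, the number of generated child chunks is now capped by `CHILD_CHUNKS_PREVIEW_NUMBER`, so a very large document no longer floods the preview response. The truncation itself is just a slice, as a small standalone check shows (the limit value below is only an example):

```python
def cap_preview_chunks(child_nodes: list, is_preview: bool, limit: int = 50) -> list:
    # Mirrors the preview branch above: only trim when previewing and over the limit.
    if is_preview and len(child_nodes) > limit:
        return child_nodes[:limit]
    return child_nodes


assert len(cap_preview_chunks(list(range(200)), is_preview=True)) == 50
assert len(cap_preview_chunks(list(range(200)), is_preview=False)) == 200
```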
| @@ -212,8 +212,23 @@ class ApiTool(Tool): | |||
| else: | |||
| body = body | |||
| if method in {"get", "head", "post", "put", "delete", "patch"}: | |||
| response: httpx.Response = getattr(ssrf_proxy, method)( | |||
| if method in { | |||
| "get", | |||
| "head", | |||
| "post", | |||
| "put", | |||
| "delete", | |||
| "patch", | |||
| "options", | |||
| "GET", | |||
| "POST", | |||
| "PUT", | |||
| "PATCH", | |||
| "DELETE", | |||
| "HEAD", | |||
| "OPTIONS", | |||
| }: | |||
| response: httpx.Response = getattr(ssrf_proxy, method.lower())( | |||
| url, | |||
| params=params, | |||
| headers=headers, | |||
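The allow-list now contains both lower- and upper-case method names, and the actual call dispatches through `method.lower()`, so `"GET"` and `"get"` resolve to the same `ssrf_proxy` helper. A small standalone sketch of the normalization (the proxy object is a stand-in for `ssrf_proxy`):

```python
import httpx

ALLOWED_METHODS = {"get", "head", "post", "put", "delete", "patch", "options"}


def dispatch(proxy, method: str, url: str, **kwargs) -> httpx.Response:
    # Accept "GET", "get", "Post", ...; reject anything outside the allow-list.
    normalized = method.lower()
    if normalized not in ALLOWED_METHODS:
        raise ValueError(f"Invalid http method {method}")
    # proxy is expected to expose get/head/post/... callables, as ssrf_proxy does.
    return getattr(proxy, normalized)(url, **kwargs)
```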
| @@ -2,14 +2,18 @@ import csv | |||
| import io | |||
| import json | |||
| import logging | |||
| import operator | |||
| import os | |||
| import tempfile | |||
| from typing import cast | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import Any, cast | |||
| import docx | |||
| import pandas as pd | |||
| import pypdfium2 # type: ignore | |||
| import yaml # type: ignore | |||
| from docx.table import Table | |||
| from docx.text.paragraph import Paragraph | |||
| from configs import dify_config | |||
| from core.file import File, FileTransferMethod, file_manager | |||
| @@ -78,6 +82,23 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): | |||
| process_data=process_data, | |||
| ) | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| *, | |||
| graph_config: Mapping[str, Any], | |||
| node_id: str, | |||
| node_data: DocumentExtractorNodeData, | |||
| ) -> Mapping[str, Sequence[str]]: | |||
| """ | |||
| Extract variable selector to variable mapping | |||
| :param graph_config: graph config | |||
| :param node_id: node id | |||
| :param node_data: node data | |||
| :return: | |||
| """ | |||
| return {node_id + ".files": node_data.variable_selector} | |||
| def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: | |||
| """Extract text from a file based on its MIME type.""" | |||
| @@ -189,35 +210,56 @@ def _extract_text_from_doc(file_content: bytes) -> str: | |||
| doc_file = io.BytesIO(file_content) | |||
| doc = docx.Document(doc_file) | |||
| text = [] | |||
| # Process paragraphs | |||
| for paragraph in doc.paragraphs: | |||
| if paragraph.text.strip(): | |||
| text.append(paragraph.text) | |||
| # Process tables | |||
| for table in doc.tables: | |||
| # Table header | |||
| try: | |||
| # table maybe cause errors so ignore it. | |||
| if len(table.rows) > 0 and table.rows[0].cells is not None: | |||
| # Keep track of paragraph and table positions | |||
| content_items: list[tuple[int, str, Table | Paragraph]] = [] | |||
| # Process paragraphs and tables | |||
| for i, paragraph in enumerate(doc.paragraphs): | |||
| if paragraph.text.strip(): | |||
| content_items.append((i, "paragraph", paragraph)) | |||
| for i, table in enumerate(doc.tables): | |||
| content_items.append((i, "table", table)) | |||
| # Sort content items based on their original position | |||
| content_items.sort(key=operator.itemgetter(0)) | |||
| # Process sorted content | |||
| for _, item_type, item in content_items: | |||
| if item_type == "paragraph": | |||
| if isinstance(item, Table): | |||
| continue | |||
| text.append(item.text) | |||
| elif item_type == "table": | |||
| # Process tables | |||
| if not isinstance(item, Table): | |||
| continue | |||
| try: | |||
| # Check if any cell in the table has text | |||
| has_content = False | |||
| for row in table.rows: | |||
| for row in item.rows: | |||
| if any(cell.text.strip() for cell in row.cells): | |||
| has_content = True | |||
| break | |||
| if has_content: | |||
| markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n" | |||
| markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n" | |||
| for row in table.rows[1:]: | |||
| markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n" | |||
| cell_texts = [cell.text.replace("\n", "<br>") for cell in item.rows[0].cells] | |||
| markdown_table = f"| {' | '.join(cell_texts)} |\n" | |||
| markdown_table += f"| {' | '.join(['---'] * len(item.rows[0].cells))} |\n" | |||
| for row in item.rows[1:]: | |||
| # Replace newlines with <br> in each cell | |||
| row_cells = [cell.text.replace("\n", "<br>") for cell in row.cells] | |||
| markdown_table += "| " + " | ".join(row_cells) + " |\n" | |||
| text.append(markdown_table) | |||
| except Exception as e: | |||
| logger.warning(f"Failed to extract table from DOC/DOCX: {e}") | |||
| continue | |||
| except Exception as e: | |||
| logger.warning(f"Failed to extract table from DOC/DOCX: {e}") | |||
| continue | |||
| return "\n".join(text) | |||
| except Exception as e: | |||
| raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e | |||
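Paragraphs and tables are now interleaved by their original position, and each table is rendered as a Markdown table with in-cell newlines converted to `<br>` so every record stays on one line. The conversion boils down to the following (a standalone rendering of the same logic, with plain string cells standing in for docx cells):

```python
def table_to_markdown(rows: list[list[str]]) -> str:
    header = [cell.replace("\n", "<br>") for cell in rows[0]]
    md = f"| {' | '.join(header)} |\n"
    md += f"| {' | '.join(['---'] * len(header))} |\n"
    for row in rows[1:]:
        cells = [cell.replace("\n", "<br>") for cell in row]
        md += "| " + " | ".join(cells) + " |\n"
    return md


print(table_to_markdown([["Name", "Role"], ["Alice", "Maintainer\nReviewer"]]))
# | Name | Role |
# | --- | --- |
# | Alice | Maintainer<br>Reviewer |
```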
| @@ -68,7 +68,22 @@ class HttpRequestNodeData(BaseNodeData): | |||
| Code Node Data. | |||
| """ | |||
| method: Literal["get", "post", "put", "patch", "delete", "head"] | |||
| method: Literal[ | |||
| "get", | |||
| "post", | |||
| "put", | |||
| "patch", | |||
| "delete", | |||
| "head", | |||
| "options", | |||
| "GET", | |||
| "POST", | |||
| "PUT", | |||
| "PATCH", | |||
| "DELETE", | |||
| "HEAD", | |||
| "OPTIONS", | |||
| ] | |||
| url: str | |||
| authorization: HttpRequestNodeAuthorization | |||
| headers: str | |||
| @@ -37,7 +37,22 @@ BODY_TYPE_TO_CONTENT_TYPE = { | |||
| class Executor: | |||
| method: Literal["get", "head", "post", "put", "delete", "patch"] | |||
| method: Literal[ | |||
| "get", | |||
| "head", | |||
| "post", | |||
| "put", | |||
| "delete", | |||
| "patch", | |||
| "options", | |||
| "GET", | |||
| "POST", | |||
| "PUT", | |||
| "PATCH", | |||
| "DELETE", | |||
| "HEAD", | |||
| "OPTIONS", | |||
| ] | |||
| url: str | |||
| params: list[tuple[str, str]] | None | |||
| content: str | bytes | None | |||
| @@ -67,12 +82,6 @@ class Executor: | |||
| node_data.authorization.config.api_key | |||
| ).text | |||
| # check if node_data.url is a valid URL | |||
| if not node_data.url: | |||
| raise InvalidURLError("url is required") | |||
| if not node_data.url.startswith(("http://", "https://")): | |||
| raise InvalidURLError("url should start with http:// or https://") | |||
| self.url: str = node_data.url | |||
| self.method = node_data.method | |||
| self.auth = node_data.authorization | |||
| @@ -99,6 +108,12 @@ class Executor: | |||
| def _init_url(self): | |||
| self.url = self.variable_pool.convert_template(self.node_data.url).text | |||
| # check if url is a valid URL | |||
| if not self.url: | |||
| raise InvalidURLError("url is required") | |||
| if not self.url.startswith(("http://", "https://")): | |||
| raise InvalidURLError("url should start with http:// or https://") | |||
| def _init_params(self): | |||
| """ | |||
| Almost same as _init_headers(), difference: | |||
| @@ -158,7 +173,10 @@ class Executor: | |||
| if len(data) != 1: | |||
| raise RequestBodyError("json body type should have exactly one item") | |||
| json_string = self.variable_pool.convert_template(data[0].value).text | |||
| json_object = json.loads(json_string, strict=False) | |||
| try: | |||
| json_object = json.loads(json_string, strict=False) | |||
| except json.JSONDecodeError as e: | |||
| raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e | |||
| self.json = json_object | |||
| # self.json = self._parse_object_contains_variables(json_object) | |||
| case "binary": | |||
| @@ -246,7 +264,22 @@ class Executor: | |||
| """ | |||
| do http request depending on api bundle | |||
| """ | |||
| if self.method not in {"get", "head", "post", "put", "delete", "patch"}: | |||
| if self.method not in { | |||
| "get", | |||
| "head", | |||
| "post", | |||
| "put", | |||
| "delete", | |||
| "patch", | |||
| "options", | |||
| "GET", | |||
| "POST", | |||
| "PUT", | |||
| "PATCH", | |||
| "DELETE", | |||
| "HEAD", | |||
| "OPTIONS", | |||
| }: | |||
| raise InvalidHttpMethodError(f"Invalid http method {self.method}") | |||
| request_args = { | |||
| @@ -263,7 +296,7 @@ class Executor: | |||
| } | |||
| # request_args = {k: v for k, v in request_args.items() if v is not None} | |||
| try: | |||
| response = getattr(ssrf_proxy, self.method)(**request_args) | |||
| response = getattr(ssrf_proxy, self.method.lower())(**request_args) | |||
| except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e: | |||
| raise HttpRequestNodeError(str(e)) | |||
| # FIXME: fix type ignore, this maybe httpx type issue | |||
| @@ -340,6 +340,10 @@ class WorkflowEntry: | |||
| ): | |||
| raise ValueError(f"Variable key {node_variable} not found in user inputs.") | |||
| # environment variables already exist in the variable pool and are not taken from user inputs | |||
| if variable_pool.get(variable_selector): | |||
| continue | |||
| # fetch variable node id from variable selector | |||
| variable_node_id = variable_selector[0] | |||
| variable_key_list = variable_selector[1:] | |||
| @@ -33,6 +33,7 @@ else | |||
| --bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \ | |||
| --workers ${SERVER_WORKER_AMOUNT:-1} \ | |||
| --worker-class ${SERVER_WORKER_CLASS:-gevent} \ | |||
| --worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \ | |||
| --timeout ${GUNICORN_TIMEOUT:-200} \ | |||
| app:app | |||
| fi | |||
| @@ -46,7 +46,7 @@ def init_app(app: DifyApp): | |||
| timezone = pytz.timezone(log_tz) | |||
| def time_converter(seconds): | |||
| return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() | |||
| return datetime.fromtimestamp(seconds, tz=timezone).timetuple() | |||
| for handler in logging.root.handlers: | |||
| if handler.formatter: | |||
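`datetime.utcfromtimestamp` returns a naive datetime (and is deprecated in recent Python versions), so chaining `.astimezone(timezone)` onto it first interpreted the value in the host's local zone; `fromtimestamp(seconds, tz=...)` converts the epoch value directly into the configured zone. A quick check:

```python
from datetime import datetime

import pytz  # already used by the surrounding logging setup

tz = pytz.timezone("Asia/Shanghai")
seconds = 1735689600  # 2025-01-01 00:00:00 UTC

print(datetime.fromtimestamp(seconds, tz=tz))
# 2025-01-01 08:00:00+08:00 -- correct regardless of the host's local timezone

# The old form produced a naive UTC datetime and then localized it via the
# host's local timezone, so log timestamps shifted depending on where the
# server happened to run.
print(datetime.utcfromtimestamp(seconds).astimezone(tz))
```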
| @@ -158,7 +158,7 @@ def _build_from_remote_url( | |||
| tenant_id: str, | |||
| transfer_method: FileTransferMethod, | |||
| ) -> File: | |||
| url = mapping.get("url") | |||
| url = mapping.get("url") or mapping.get("remote_url") | |||
| if not url: | |||
| raise ValueError("Invalid file url") | |||
| @@ -255,7 +255,8 @@ class NotionOAuth(OAuthDataSource): | |||
| response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) | |||
| response_json = response.json() | |||
| if response.status_code != 200: | |||
| raise ValueError(f"Error fetching block parent page ID: {response_json.message}") | |||
| message = response_json.get("message", "unknown error") | |||
| raise ValueError(f"Error fetching block parent page ID: {message}") | |||
| parent = response_json["parent"] | |||
| parent_type = parent["type"] | |||
| if parent_type == "block_id": | |||
| @@ -0,0 +1,41 @@ | |||
| """change workflow_runs.total_tokens to bigint | |||
| Revision ID: a91b476a53de | |||
| Revises: 923752d42eb6 | |||
| Create Date: 2025-01-01 20:00:01.207369 | |||
| """ | |||
| from alembic import op | |||
| import models as models | |||
| import sqlalchemy as sa | |||
| # revision identifiers, used by Alembic. | |||
| revision = 'a91b476a53de' | |||
| down_revision = '923752d42eb6' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('workflow_runs', schema=None) as batch_op: | |||
| batch_op.alter_column('total_tokens', | |||
| existing_type=sa.INTEGER(), | |||
| type_=sa.BigInteger(), | |||
| existing_nullable=False, | |||
| existing_server_default=sa.text('0')) | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('workflow_runs', schema=None) as batch_op: | |||
| batch_op.alter_column('total_tokens', | |||
| existing_type=sa.BigInteger(), | |||
| type_=sa.INTEGER(), | |||
| existing_nullable=False, | |||
| existing_server_default=sa.text('0')) | |||
| # ### end Alembic commands ### | |||
| @@ -415,8 +415,8 @@ class WorkflowRun(Base): | |||
| status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded | |||
| outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") | |||
| error: Mapped[Optional[str]] = mapped_column(db.Text) | |||
| elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) | |||
| total_tokens: Mapped[int] = mapped_column(server_default=db.text("0")) | |||
| elapsed_time = db.Column(db.Float, nullable=False, server_default=sa.text("0")) | |||
| total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) | |||
| total_steps = db.Column(db.Integer, server_default=db.text("0")) | |||
| created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user | |||
| created_by = db.Column(StringUUID, nullable=False) | |||
| @@ -71,7 +71,7 @@ pyjwt = "~2.8.0" | |||
| pypdfium2 = "~4.30.0" | |||
| python = ">=3.11,<3.13" | |||
| python-docx = "~1.1.0" | |||
| python-dotenv = "1.0.0" | |||
| python-dotenv = "1.0.1" | |||
| pyyaml = "~6.0.1" | |||
| readabilipy = "0.2.0" | |||
| redis = { version = "~5.0.3", extras = ["hiredis"] } | |||
| @@ -82,7 +82,7 @@ scikit-learn = "~1.5.1" | |||
| sentry-sdk = { version = "~1.44.1", extras = ["flask"] } | |||
| sqlalchemy = "~2.0.29" | |||
| starlette = "0.41.0" | |||
| tencentcloud-sdk-python-hunyuan = "~3.0.1158" | |||
| tencentcloud-sdk-python-hunyuan = "~3.0.1294" | |||
| tiktoken = "~0.8.0" | |||
| tokenizers = "~0.15.0" | |||
| transformers = "~4.35.0" | |||
| @@ -92,7 +92,7 @@ validators = "0.21.0" | |||
| volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"} | |||
| websocket-client = "~1.7.0" | |||
| xinference-client = "0.15.2" | |||
| yarl = "~1.9.4" | |||
| yarl = "~1.18.3" | |||
| youtube-transcript-api = "~0.6.2" | |||
| zhipuai = "~2.1.5" | |||
| # Before adding a new dependency, consider placing it in alphabetical order (a-z) and in a suitable group. | |||
| @@ -157,7 +157,7 @@ opensearch-py = "2.4.0" | |||
| oracledb = "~2.2.1" | |||
| pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] } | |||
| pgvector = "0.2.5" | |||
| pymilvus = "~2.4.4" | |||
| pymilvus = "~2.5.0" | |||
| pymochow = "1.3.1" | |||
| pyobvector = "~0.1.6" | |||
| qdrant-client = "1.7.3" | |||
| @@ -168,23 +168,6 @@ def clean_unused_datasets_task(): | |||
| else: | |||
| plan = plan_cache.decode() | |||
| if plan == "sandbox": | |||
| # add auto disable log | |||
| documents = ( | |||
| db.session.query(Document) | |||
| .filter( | |||
| Document.dataset_id == dataset.id, | |||
| Document.enabled == True, | |||
| Document.archived == False, | |||
| ) | |||
| .all() | |||
| ) | |||
| for document in documents: | |||
| dataset_auto_disable_log = DatasetAutoDisableLog( | |||
| tenant_id=dataset.tenant_id, | |||
| dataset_id=dataset.id, | |||
| document_id=document.id, | |||
| ) | |||
| db.session.add(dataset_auto_disable_log) | |||
| # remove index | |||
| index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() | |||
| index_processor.clean(dataset, None) | |||
| @@ -66,7 +66,7 @@ class TokenPair(BaseModel): | |||
| REFRESH_TOKEN_PREFIX = "refresh_token:" | |||
| ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:" | |||
| REFRESH_TOKEN_EXPIRY = timedelta(days=30) | |||
| REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS) | |||
| class AccountService: | |||
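With this change the refresh-token lifetime follows the `REFRESH_TOKEN_EXPIRE_DAYS` setting instead of a hard-coded 30 days. A simplified sketch of the resulting value (the standalone constant below stands in for `dify_config`):

```python
from datetime import timedelta

REFRESH_TOKEN_EXPIRE_DAYS = 7.5   # hypothetical override, e.g. set via the environment

REFRESH_TOKEN_EXPIRY = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
print(REFRESH_TOKEN_EXPIRY)       # 7 days, 12:00:00 -- PositiveFloat permits fractional days
```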
| @@ -2,6 +2,7 @@ import logging | |||
| import uuid | |||
| from enum import StrEnum | |||
| from typing import Optional, cast | |||
| from urllib.parse import urlparse | |||
| from uuid import uuid4 | |||
| import yaml # type: ignore | |||
| @@ -124,7 +125,7 @@ class AppDslService: | |||
| raise ValueError(f"Invalid import_mode: {import_mode}") | |||
| # Get YAML content | |||
| content: bytes | str = b"" | |||
| content: str = "" | |||
| if mode == ImportMode.YAML_URL: | |||
| if not yaml_url: | |||
| return Import( | |||
| @@ -133,13 +134,17 @@ class AppDslService: | |||
| error="yaml_url is required when import_mode is yaml-url", | |||
| ) | |||
| try: | |||
| # tricky way to handle url from github to github raw url | |||
| if yaml_url.startswith("https://github.com") and yaml_url.endswith((".yml", ".yaml")): | |||
| parsed_url = urlparse(yaml_url) | |||
| if ( | |||
| parsed_url.scheme == "https" | |||
| and parsed_url.netloc == "github.com" | |||
| and parsed_url.path.endswith((".yml", ".yaml")) | |||
| ): | |||
| yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com") | |||
| yaml_url = yaml_url.replace("/blob/", "/") | |||
| response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10)) | |||
| response.raise_for_status() | |||
| content = response.content | |||
| content = response.content.decode() | |||
| if len(content) > DSL_MAX_SIZE: | |||
| return Import( | |||
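The URL handling above swaps a fragile `startswith` check for `urlparse`, then rewrites GitHub blob links to their raw equivalents before fetching the DSL. A worked example of the rewrite (the repository path is made up):

```python
from urllib.parse import urlparse

yaml_url = "https://github.com/example-org/example-repo/blob/main/app.yaml"  # hypothetical DSL link
parsed_url = urlparse(yaml_url)
if (
    parsed_url.scheme == "https"
    and parsed_url.netloc == "github.com"
    and parsed_url.path.endswith((".yml", ".yaml"))
):
    yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
    yaml_url = yaml_url.replace("/blob/", "/")

print(yaml_url)  # https://raw.githubusercontent.com/example-org/example-repo/main/app.yaml
```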
| @@ -26,9 +26,10 @@ from tasks.remove_app_and_related_data_task import remove_app_and_related_data_t | |||
| class AppService: | |||
| def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None: | |||
| def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None: | |||
| """ | |||
| Get app list with pagination | |||
| :param user_id: user id | |||
| :param tenant_id: tenant id | |||
| :param args: request args | |||
| :return: | |||
| @@ -44,6 +45,8 @@ class AppService: | |||
| elif args["mode"] == "channel": | |||
| filters.append(App.mode == AppMode.CHANNEL.value) | |||
| if args.get("is_created_by_me", False): | |||
| filters.append(App.created_by == user_id) | |||
| if args.get("name"): | |||
| name = args["name"][:30] | |||
| filters.append(App.name.ilike(f"%{name}%")) | |||
| @@ -1,5 +1,5 @@ | |||
| import os | |||
| from typing import Optional | |||
| from typing import Literal, Optional | |||
| import httpx | |||
| from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed | |||
| @@ -17,7 +17,6 @@ class BillingService: | |||
| params = {"tenant_id": tenant_id} | |||
| billing_info = cls._send_request("GET", "/subscription/info", params=params) | |||
| return billing_info | |||
| @classmethod | |||
| @@ -47,12 +46,13 @@ class BillingService: | |||
| retry=retry_if_exception_type(httpx.RequestError), | |||
| reraise=True, | |||
| ) | |||
| def _send_request(cls, method, endpoint, json=None, params=None): | |||
| def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None): | |||
| headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} | |||
| url = f"{cls.base_url}{endpoint}" | |||
| response = httpx.request(method, url, json=json, params=params, headers=headers) | |||
| if method == "GET" and response.status_code != httpx.codes.OK: | |||
| raise ValueError("Unable to retrieve billing information. Please try again later or contact support.") | |||
| return response.json() | |||
| @staticmethod | |||
| @@ -86,7 +86,7 @@ class DatasetService: | |||
| else: | |||
| return [], 0 | |||
| else: | |||
| if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN): | |||
| if user.current_role != TenantAccountRole.OWNER: | |||
| # show all datasets that the user has permission to access | |||
| if permitted_dataset_ids: | |||
| query = query.filter( | |||
| @@ -382,7 +382,7 @@ class DatasetService: | |||
| if dataset.tenant_id != user.current_tenant_id: | |||
| logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") | |||
| raise NoPermissionError("You do not have permission to access this dataset.") | |||
| if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN): | |||
| if user.current_role != TenantAccountRole.OWNER: | |||
| if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id: | |||
| logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") | |||
| raise NoPermissionError("You do not have permission to access this dataset.") | |||
| @@ -404,7 +404,7 @@ class DatasetService: | |||
| if not user: | |||
| raise ValueError("User not found") | |||
| if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN): | |||
| if user.current_role != TenantAccountRole.OWNER: | |||
| if dataset.permission == DatasetPermissionEnum.ONLY_ME: | |||
| if dataset.created_by != user.id: | |||
| raise NoPermissionError("You do not have permission to access this dataset.") | |||
| @@ -434,6 +434,12 @@ class DatasetService: | |||
| @staticmethod | |||
| def get_dataset_auto_disable_logs(dataset_id: str) -> dict: | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if not features.billing.enabled or features.billing.subscription.plan == "sandbox": | |||
| return { | |||
| "document_ids": [], | |||
| "count": 0, | |||
| } | |||
| # get recent 30 days auto disable logs | |||
| start_date = datetime.datetime.now() - datetime.timedelta(days=30) | |||
| dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter( | |||
| @@ -786,13 +792,19 @@ class DocumentService: | |||
| dataset.indexing_technique = knowledge_config.indexing_technique | |||
| if knowledge_config.indexing_technique == "high_quality": | |||
| model_manager = ModelManager() | |||
| embedding_model = model_manager.get_default_model_instance( | |||
| tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING | |||
| ) | |||
| dataset.embedding_model = embedding_model.model | |||
| dataset.embedding_model_provider = embedding_model.provider | |||
| if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: | |||
| dataset_embedding_model = knowledge_config.embedding_model | |||
| dataset_embedding_model_provider = knowledge_config.embedding_model_provider | |||
| else: | |||
| embedding_model = model_manager.get_default_model_instance( | |||
| tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING | |||
| ) | |||
| dataset_embedding_model = embedding_model.model | |||
| dataset_embedding_model_provider = embedding_model.provider | |||
| dataset.embedding_model = dataset_embedding_model | |||
| dataset.embedding_model_provider = dataset_embedding_model_provider | |||
| dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( | |||
| embedding_model.provider, embedding_model.model | |||
| dataset_embedding_model_provider, dataset_embedding_model | |||
| ) | |||
| dataset.collection_binding_id = dataset_collection_binding.id | |||
| if not dataset.retrieval_model: | |||
| @@ -804,7 +816,11 @@ class DocumentService: | |||
| "score_threshold_enabled": False, | |||
| } | |||
| dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model # type: ignore | |||
| dataset.retrieval_model = ( | |||
| knowledge_config.retrieval_model.model_dump() | |||
| if knowledge_config.retrieval_model | |||
| else default_retrieval_model | |||
| ) # type: ignore | |||
| documents = [] | |||
| if knowledge_config.original_document_id: | |||
| @@ -27,7 +27,7 @@ class WorkflowAppService: | |||
| query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id) | |||
| if keyword: | |||
| keyword_like_val = f"%{args['keyword'][:30]}%" | |||
| keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u") | |||
| keyword_conditions = [ | |||
| WorkflowRun.inputs.ilike(keyword_like_val), | |||
| WorkflowRun.outputs.ilike(keyword_like_val), | |||
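The escaping change above handles columns whose JSON was serialized with ASCII escapes: non-ASCII characters live in the stored text as literal `\uXXXX` sequences, so the keyword must be escaped the same way and its backslashes doubled so the `ILIKE` pattern matches them literally. A small illustration (the keyword is arbitrary):

```python
import json

keyword = "你好"                                   # hypothetical search keyword
stored = json.dumps({"query": keyword})            # column content: {"query": "\u4f60\u597d"}

escaped = keyword[:30].encode("unicode_escape").decode("utf-8")   # "\u4f60\u597d" as literal text
keyword_like_val = f"%{escaped}%".replace(r"\u", r"\\u")          # double the backslashes for ILIKE

print(stored)
print(keyword_like_val)   # %\\u4f60\\u597d% -- matches the escaped form written into the column
```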
| @@ -28,7 +28,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): | |||
| if not dataset: | |||
| raise Exception("Dataset not found") | |||
| index_type = dataset.doc_form | |||
| index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX | |||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | |||
| if action == "remove": | |||
| index_processor.clean(dataset, None, with_keywords=False) | |||
| @@ -157,6 +157,9 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): | |||
| {"indexing_status": "error", "error": str(e)}, synchronize_session=False | |||
| ) | |||
| db.session.commit() | |||
| else: | |||
| # clean collection | |||
| index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) | |||
| end_at = time.perf_counter() | |||
| logging.info( | |||
| @@ -0,0 +1,55 @@ | |||
| import os | |||
| from pathlib import Path | |||
| import pytest | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.gpustack.speech2text.speech2text import GPUStackSpeech2TextModel | |||
| def test_validate_credentials(): | |||
| model = GPUStackSpeech2TextModel() | |||
| with pytest.raises(CredentialsValidateFailedError): | |||
| model.validate_credentials( | |||
| model="faster-whisper-medium", | |||
| credentials={ | |||
| "endpoint_url": "invalid_url", | |||
| "api_key": "invalid_api_key", | |||
| }, | |||
| ) | |||
| model.validate_credentials( | |||
| model="faster-whisper-medium", | |||
| credentials={ | |||
| "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), | |||
| "api_key": os.environ.get("GPUSTACK_API_KEY"), | |||
| }, | |||
| ) | |||
| def test_invoke_model(): | |||
| model = GPUStackSpeech2TextModel() | |||
| # Get the directory of the current file | |||
| current_dir = os.path.dirname(os.path.abspath(__file__)) | |||
| # Get assets directory | |||
| assets_dir = os.path.join(os.path.dirname(current_dir), "assets") | |||
| # Construct the path to the audio file | |||
| audio_file_path = os.path.join(assets_dir, "audio.mp3") | |||
| file = Path(audio_file_path).read_bytes() | |||
| result = model.invoke( | |||
| model="faster-whisper-medium", | |||
| credentials={ | |||
| "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), | |||
| "api_key": os.environ.get("GPUSTACK_API_KEY"), | |||
| }, | |||
| file=file, | |||
| ) | |||
| assert isinstance(result, str) | |||
| assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" | |||
| @@ -0,0 +1,24 @@ | |||
| import os | |||
| from core.model_runtime.model_providers.gpustack.tts.tts import GPUStackText2SpeechModel | |||
| def test_invoke_model(): | |||
| model = GPUStackText2SpeechModel() | |||
| result = model.invoke( | |||
| model="cosyvoice-300m-sft", | |||
| tenant_id="test", | |||
| credentials={ | |||
| "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), | |||
| "api_key": os.environ.get("GPUSTACK_API_KEY"), | |||
| }, | |||
| content_text="Hello world", | |||
| voice="Chinese Female", | |||
| ) | |||
| content = b"" | |||
| for chunk in result: | |||
| content += chunk | |||
| assert content != b"" | |||
| @@ -19,9 +19,9 @@ class MilvusVectorTest(AbstractVectorTest): | |||
| ) | |||
| def search_by_full_text(self): | |||
| # milvus does not support full text searching yet in < 2.3.x | |||
| # Milvus supports BM25 full-text search since version 2.5.0-beta | |||
| hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) | |||
| assert len(hits_by_full_text) == 0 | |||
| assert len(hits_by_full_text) >= 0 | |||
| def get_ids_by_metadata_field(self): | |||
| ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) | |||
| @@ -2,7 +2,7 @@ version: '3' | |||
| services: | |||
| # API service | |||
| api: | |||
| image: langgenius/dify-api:0.14.2 | |||
| image: langgenius/dify-api:0.15.0 | |||
| restart: always | |||
| environment: | |||
| # Startup mode, 'api' starts the API server. | |||
| @@ -227,7 +227,7 @@ services: | |||
| # worker service | |||
| # The Celery worker for processing the queue. | |||
| worker: | |||
| image: langgenius/dify-api:0.14.2 | |||
| image: langgenius/dify-api:0.15.0 | |||
| restart: always | |||
| environment: | |||
| CONSOLE_WEB_URL: '' | |||
| @@ -397,7 +397,7 @@ services: | |||
| # Frontend web application. | |||
| web: | |||
| image: langgenius/dify-web:0.14.2 | |||
| image: langgenius/dify-web:0.15.0 | |||
| restart: always | |||
| environment: | |||
| # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is | |||
| @@ -105,6 +105,9 @@ FILES_ACCESS_TIMEOUT=300 | |||
| # Access token expiration time in minutes | |||
| ACCESS_TOKEN_EXPIRE_MINUTES=60 | |||
| # Refresh token expiration time in days | |||
| REFRESH_TOKEN_EXPIRE_DAYS=30 | |||
| # The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer. | |||
| APP_MAX_ACTIVE_REQUESTS=0 | |||
| APP_MAX_EXECUTION_TIME=1200 | |||
| @@ -123,10 +126,13 @@ DIFY_PORT=5001 | |||
| # The number of API server workers, i.e., the number of Gunicorn workers. | |||
| # Formula: number of cpu cores x 2 + 1 for sync, 1 for Gevent | |||
| # Reference: https://docs.gunicorn.org/en/stable/design.html#how-many-workers | |||
| SERVER_WORKER_AMOUNT= | |||
| SERVER_WORKER_AMOUNT=1 | |||
| # Defaults to gevent. If using Windows, it can be switched to sync or solo. | |||
| SERVER_WORKER_CLASS= | |||
| SERVER_WORKER_CLASS=gevent | |||
| # The number of worker connections; defaults to 10. | |||
| SERVER_WORKER_CONNECTIONS=10 | |||
| # Similar to SERVER_WORKER_CLASS. | |||
| # If using Windows, it can be switched to sync or solo. | |||
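The worker settings now ship with explicit defaults, and the sizing formula in the comment is easy to sanity-check. A tiny helper along those lines (the function name is made up):

```python
import multiprocessing

def suggested_worker_amount(worker_class: str) -> int:
    # Mirror the .env guidance: sync workers scale with CPU cores, while a single
    # gevent worker relies on SERVER_WORKER_CONNECTIONS for concurrency.
    if worker_class == "gevent":
        return 1
    return multiprocessing.cpu_count() * 2 + 1

print(suggested_worker_amount("sync"))    # e.g. 9 on a 4-core machine
print(suggested_worker_amount("gevent"))  # 1
```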
| @@ -377,7 +383,7 @@ SUPABASE_URL=your-server-url | |||
| # ------------------------------ | |||
| # The type of vector store to use. | |||
| # Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`. | |||
| # Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`. | |||
| VECTOR_STORE=weaviate | |||
| # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`. | |||
| @@ -397,6 +403,7 @@ MILVUS_URI=http://127.0.0.1:19530 | |||
| MILVUS_TOKEN= | |||
| MILVUS_USER=root | |||
| MILVUS_PASSWORD=Milvus | |||
| MILVUS_ENABLE_HYBRID_SEARCH=False | |||
| # MyScale configuration, only available when VECTOR_STORE is `myscale` | |||
| # For multi-language support, please set MYSCALE_FTS_PARAMS with referring to: | |||
| @@ -506,7 +513,7 @@ TENCENT_VECTOR_DB_SHARD=1 | |||
| TENCENT_VECTOR_DB_REPLICAS=2 | |||
| # ElasticSearch configuration, only available when VECTOR_STORE is `elasticsearch` | |||
| ELASTICSEARCH_HOST=0.0.0.0 | |||
| ELASTICSEARCH_HOST=elasticsearch | |||
| ELASTICSEARCH_PORT=9200 | |||
| ELASTICSEARCH_USERNAME=elastic | |||
| ELASTICSEARCH_PASSWORD=elastic | |||
| @@ -923,6 +930,9 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false | |||
| # Maximum number of submitted thread count in a ThreadPool for parallel node execution | |||
| MAX_SUBMIT_COUNT=100 | |||
| # The maximum top-k value for RAG. | |||
| TOP_K_MAX_VALUE=10 | |||
| # ------------------------------ | |||
| # Plugin Daemon Configuration | |||
| # ------------------------------ | |||
| @@ -947,3 +957,4 @@ ENDPOINT_URL_TEMPLATE=http://localhost/e/{hook_id} | |||
| MARKETPLACE_ENABLED=true | |||
| MARKETPLACE_API_URL=https://marketplace-plugin.dify.dev | |||
| @@ -73,6 +73,7 @@ services: | |||
| CSP_WHITELIST: ${CSP_WHITELIST:-} | |||
| MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev} | |||
| MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace-plugin.dify.dev} | |||
| TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-} | |||
| # The postgres database. | |||
| db: | |||
| @@ -92,7 +93,7 @@ services: | |||
| volumes: | |||
| - ./volumes/db/data:/var/lib/postgresql/data | |||
| healthcheck: | |||
| test: ['CMD', 'pg_isready'] | |||
| test: [ 'CMD', 'pg_isready' ] | |||
| interval: 1s | |||
| timeout: 3s | |||
| retries: 30 | |||
| @@ -111,7 +112,7 @@ services: | |||
| # Set the redis password when startup redis server. | |||
| command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456} | |||
| healthcheck: | |||
| test: ['CMD', 'redis-cli', 'ping'] | |||
| test: [ 'CMD', 'redis-cli', 'ping' ] | |||
| # The DifySandbox | |||
| sandbox: | |||
| @@ -131,7 +132,7 @@ services: | |||
| volumes: | |||
| - ./volumes/sandbox/dependencies:/dependencies | |||
| healthcheck: | |||
| test: ['CMD', 'curl', '-f', 'http://localhost:8194/health'] | |||
| test: [ 'CMD', 'curl', '-f', 'http://localhost:8194/health' ] | |||
| networks: | |||
| - ssrf_proxy_network | |||
| @@ -167,12 +168,7 @@ services: | |||
| volumes: | |||
| - ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template | |||
| - ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh | |||
| entrypoint: | |||
| [ | |||
| 'sh', | |||
| '-c', | |||
| "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh", | |||
| ] | |||
| entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ] | |||
| environment: | |||
| # Please modify the squid env vars to fit your network environment. | |||
| HTTP_PORT: ${SSRF_HTTP_PORT:-3128} | |||
| @@ -201,8 +197,8 @@ services: | |||
| - CERTBOT_EMAIL=${CERTBOT_EMAIL} | |||
| - CERTBOT_DOMAIN=${CERTBOT_DOMAIN} | |||
| - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-} | |||
| entrypoint: ['/docker-entrypoint.sh'] | |||
| command: ['tail', '-f', '/dev/null'] | |||
| entrypoint: [ '/docker-entrypoint.sh' ] | |||
| command: [ 'tail', '-f', '/dev/null' ] | |||
| # The nginx reverse proxy. | |||
| # used for reverse proxying the API service and Web service. | |||
| @@ -219,12 +215,7 @@ services: | |||
| - ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container) | |||
| - ./volumes/certbot/conf:/etc/letsencrypt | |||
| - ./volumes/certbot/www:/var/www/html | |||
| entrypoint: | |||
| [ | |||
| 'sh', | |||
| '-c', | |||
| "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh", | |||
| ] | |||
| entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ] | |||
| environment: | |||
| NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_} | |||
| NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false} | |||
| @@ -316,7 +307,7 @@ services: | |||
| working_dir: /opt/couchbase | |||
| stdin_open: true | |||
| tty: true | |||
| entrypoint: [""] | |||
| entrypoint: [ "" ] | |||
| command: sh -c "/opt/couchbase/init/init-cbserver.sh" | |||
| volumes: | |||
| - ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data | |||
| @@ -345,7 +336,7 @@ services: | |||
| volumes: | |||
| - ./volumes/pgvector/data:/var/lib/postgresql/data | |||
| healthcheck: | |||
| test: ['CMD', 'pg_isready'] | |||
| test: [ 'CMD', 'pg_isready' ] | |||
| interval: 1s | |||
| timeout: 3s | |||
| retries: 30 | |||
| @@ -367,7 +358,7 @@ services: | |||
| volumes: | |||
| - ./volumes/pgvecto_rs/data:/var/lib/postgresql/data | |||
| healthcheck: | |||
| test: ['CMD', 'pg_isready'] | |||
| test: [ 'CMD', 'pg_isready' ] | |||
| interval: 1s | |||
| timeout: 3s | |||
| retries: 30 | |||
| @@ -432,7 +423,7 @@ services: | |||
| - ./volumes/milvus/etcd:/etcd | |||
| command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd | |||
| healthcheck: | |||
| test: ['CMD', 'etcdctl', 'endpoint', 'health'] | |||
| test: [ 'CMD', 'etcdctl', 'endpoint', 'health' ] | |||
| interval: 30s | |||
| timeout: 20s | |||
| retries: 3 | |||
| @@ -451,7 +442,7 @@ services: | |||
| - ./volumes/milvus/minio:/minio_data | |||
| command: minio server /minio_data --console-address ":9001" | |||
| healthcheck: | |||
| test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live'] | |||
| test: [ 'CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live' ] | |||
| interval: 30s | |||
| timeout: 20s | |||
| retries: 3 | |||
| @@ -463,7 +454,7 @@ services: | |||
| image: milvusdb/milvus:v2.3.1 | |||
| profiles: | |||
| - milvus | |||
| command: ['milvus', 'run', 'standalone'] | |||
| command: [ 'milvus', 'run', 'standalone' ] | |||
| environment: | |||
| ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379} | |||
| MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000} | |||
| @@ -471,7 +462,7 @@ services: | |||
| volumes: | |||
| - ./volumes/milvus/milvus:/var/lib/milvus | |||
| healthcheck: | |||
| test: ['CMD', 'curl', '-f', 'http://localhost:9091/healthz'] | |||
| test: [ 'CMD', 'curl', '-f', 'http://localhost:9091/healthz' ] | |||
| interval: 30s | |||
| start_period: 90s | |||
| timeout: 20s | |||
| @@ -559,7 +550,7 @@ services: | |||
| ports: | |||
| - ${ELASTICSEARCH_PORT:-9200}:9200 | |||
| healthcheck: | |||
| test: ['CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty'] | |||
| test: [ 'CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty' ] | |||
| interval: 30s | |||
| timeout: 10s | |||
| retries: 50 | |||
| @@ -587,7 +578,7 @@ services: | |||
| ports: | |||
| - ${KIBANA_PORT:-5601}:5601 | |||
| healthcheck: | |||
| test: ['CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1'] | |||
| test: [ 'CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1' ] | |||
| interval: 30s | |||
| timeout: 10s | |||
| retries: 3 | |||
| @@ -27,12 +27,14 @@ x-shared-env: &shared-api-worker-env | |||
| MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true} | |||
| FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300} | |||
| ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} | |||
| REFRESH_TOKEN_EXPIRE_DAYS: ${REFRESH_TOKEN_EXPIRE_DAYS:-30} | |||
| APP_MAX_ACTIVE_REQUESTS: ${APP_MAX_ACTIVE_REQUESTS:-0} | |||
| APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200} | |||
| DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0} | |||
| DIFY_PORT: ${DIFY_PORT:-5001} | |||
| SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-} | |||
| SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-} | |||
| SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-1} | |||
| SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-gevent} | |||
| SERVER_WORKER_CONNECTIONS: ${SERVER_WORKER_CONNECTIONS:-10} | |||
| CELERY_WORKER_CLASS: ${CELERY_WORKER_CLASS:-} | |||
| GUNICORN_TIMEOUT: ${GUNICORN_TIMEOUT:-360} | |||
| CELERY_WORKER_AMOUNT: ${CELERY_WORKER_AMOUNT:-} | |||
| @@ -136,6 +138,7 @@ x-shared-env: &shared-api-worker-env | |||
| MILVUS_TOKEN: ${MILVUS_TOKEN:-} | |||
| MILVUS_USER: ${MILVUS_USER:-root} | |||
| MILVUS_PASSWORD: ${MILVUS_PASSWORD:-Milvus} | |||
| MILVUS_ENABLE_HYBRID_SEARCH: ${MILVUS_ENABLE_HYBRID_SEARCH:-False} | |||
| MYSCALE_HOST: ${MYSCALE_HOST:-myscale} | |||
| MYSCALE_PORT: ${MYSCALE_PORT:-8123} | |||
| MYSCALE_USER: ${MYSCALE_USER:-default} | |||
| @@ -401,6 +404,7 @@ x-shared-env: &shared-api-worker-env | |||
| ENDPOINT_URL_TEMPLATE: ${ENDPOINT_URL_TEMPLATE:-http://localhost/e/{hook_id}} | |||
| MARKETPLACE_ENABLED: ${MARKETPLACE_ENABLED:-true} | |||
| MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev} | |||
| TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-10} | |||
| services: | |||
| # API service | |||
| @@ -476,6 +480,7 @@ services: | |||
| CSP_WHITELIST: ${CSP_WHITELIST:-} | |||
| MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev} | |||
| MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace-plugin.dify.dev} | |||
| TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-} | |||
| # The postgres database. | |||
| db: | |||
| @@ -495,7 +500,7 @@ services: | |||
| volumes: | |||
| - ./volumes/db/data:/var/lib/postgresql/data | |||
| healthcheck: | |||
| test: ['CMD', 'pg_isready'] | |||
| test: [ 'CMD', 'pg_isready' ] | |||
| interval: 1s | |||
| timeout: 3s | |||
| retries: 30 | |||
| @@ -514,7 +519,7 @@ services: | |||
| # Set the redis password when startup redis server. | |||
| command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456} | |||
| healthcheck: | |||
| test: ['CMD', 'redis-cli', 'ping'] | |||
| test: [ 'CMD', 'redis-cli', 'ping' ] | |||
| # The DifySandbox | |||
| sandbox: | |||
| @@ -534,7 +539,7 @@ services: | |||
| volumes: | |||
| - ./volumes/sandbox/dependencies:/dependencies | |||
| healthcheck: | |||
| test: ['CMD', 'curl', '-f', 'http://localhost:8194/health'] | |||
| test: [ 'CMD', 'curl', '-f', 'http://localhost:8194/health' ] | |||
| networks: | |||
| - ssrf_proxy_network | |||
| @@ -571,12 +576,7 @@ services: | |||
| volumes: | |||
| - ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template | |||
| - ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh | |||
| entrypoint: | |||
| [ | |||
| 'sh', | |||
| '-c', | |||
| "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh", | |||
| ] | |||
| entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ] | |||
| environment: | |||
| # Please modify the squid env vars to fit your network environment. | |||
| HTTP_PORT: ${SSRF_HTTP_PORT:-3128} | |||
| @@ -605,8 +605,8 @@ services: | |||
| - CERTBOT_EMAIL=${CERTBOT_EMAIL} | |||
| - CERTBOT_DOMAIN=${CERTBOT_DOMAIN} | |||
| - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-} | |||
| entrypoint: ['/docker-entrypoint.sh'] | |||
| command: ['tail', '-f', '/dev/null'] | |||
| entrypoint: [ '/docker-entrypoint.sh' ] | |||
| command: [ 'tail', '-f', '/dev/null' ] | |||
| # The nginx reverse proxy. | |||
| # used for reverse proxying the API service and Web service. | |||
| @@ -623,12 +623,7 @@ services: | |||
| - ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container) | |||
| - ./volumes/certbot/conf:/etc/letsencrypt | |||
| - ./volumes/certbot/www:/var/www/html | |||
| entrypoint: | |||
| [ | |||
| 'sh', | |||
| '-c', | |||
| "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh", | |||
| ] | |||
| entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ] | |||
| environment: | |||
| NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_} | |||
| NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false} | |||
| @@ -720,7 +715,7 @@ services: | |||
| working_dir: /opt/couchbase | |||
| stdin_open: true | |||
| tty: true | |||
| entrypoint: [""] | |||
| entrypoint: [ "" ] | |||
| command: sh -c "/opt/couchbase/init/init-cbserver.sh" | |||
| volumes: | |||
| - ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data | |||
| @@ -749,7 +744,7 @@ services: | |||
| volumes: | |||
| - ./volumes/pgvector/data:/var/lib/postgresql/data | |||
| healthcheck: | |||
| test: ['CMD', 'pg_isready'] | |||
| test: [ 'CMD', 'pg_isready' ] | |||
| interval: 1s | |||
| timeout: 3s | |||
| retries: 30 | |||
| @@ -771,7 +766,7 @@ services: | |||
| volumes: | |||
| - ./volumes/pgvecto_rs/data:/var/lib/postgresql/data | |||
| healthcheck: | |||
| test: ['CMD', 'pg_isready'] | |||
| test: [ 'CMD', 'pg_isready' ] | |||
| interval: 1s | |||
| timeout: 3s | |||
| retries: 30 | |||
| @@ -836,7 +831,7 @@ services: | |||
| - ./volumes/milvus/etcd:/etcd | |||
| command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd | |||
| healthcheck: | |||
| test: ['CMD', 'etcdctl', 'endpoint', 'health'] | |||
| test: [ 'CMD', 'etcdctl', 'endpoint', 'health' ] | |||
| interval: 30s | |||
| timeout: 20s | |||
| retries: 3 | |||
| @@ -855,7 +850,7 @@ services: | |||
| - ./volumes/milvus/minio:/minio_data | |||
| command: minio server /minio_data --console-address ":9001" | |||
| healthcheck: | |||
| test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live'] | |||
| test: [ 'CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live' ] | |||
| interval: 30s | |||
| timeout: 20s | |||
| retries: 3 | |||
| @@ -864,10 +859,10 @@ services: | |||
| milvus-standalone: | |||
| container_name: milvus-standalone | |||
| image: milvusdb/milvus:v2.3.1 | |||
| image: milvusdb/milvus:v2.5.0-beta | |||
| profiles: | |||
| - milvus | |||
| command: ['milvus', 'run', 'standalone'] | |||
| command: [ 'milvus', 'run', 'standalone' ] | |||
| environment: | |||
| ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379} | |||
| MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000} | |||
| @@ -875,7 +870,7 @@ services: | |||
| volumes: | |||
| - ./volumes/milvus/milvus:/var/lib/milvus | |||
| healthcheck: | |||
| test: ['CMD', 'curl', '-f', 'http://localhost:9091/healthz'] | |||
| test: [ 'CMD', 'curl', '-f', 'http://localhost:9091/healthz' ] | |||
| interval: 30s | |||
| start_period: 90s | |||
| timeout: 20s | |||
| @@ -948,22 +943,30 @@ services: | |||
| container_name: elasticsearch | |||
| profiles: | |||
| - elasticsearch | |||
| - elasticsearch-ja | |||
| restart: always | |||
| volumes: | |||
| - ./elasticsearch/docker-entrypoint.sh:/docker-entrypoint-mount.sh | |||
| - dify_es01_data:/usr/share/elasticsearch/data | |||
| environment: | |||
| ELASTIC_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} | |||
| VECTOR_STORE: ${VECTOR_STORE:-} | |||
| cluster.name: dify-es-cluster | |||
| node.name: dify-es0 | |||
| discovery.type: single-node | |||
| xpack.license.self_generated.type: trial | |||
| xpack.license.self_generated.type: basic | |||
| xpack.security.enabled: 'true' | |||
| xpack.security.enrollment.enabled: 'false' | |||
| xpack.security.http.ssl.enabled: 'false' | |||
| ports: | |||
| - ${ELASTICSEARCH_PORT:-9200}:9200 | |||
| deploy: | |||
| resources: | |||
| limits: | |||
| memory: 2g | |||
| entrypoint: [ 'sh', '-c', "sh /docker-entrypoint-mount.sh" ] | |||
| healthcheck: | |||
| test: ['CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty'] | |||
| test: [ 'CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty' ] | |||
| interval: 30s | |||
| timeout: 10s | |||
| retries: 50 | |||
| @@ -991,7 +994,7 @@ services: | |||
| ports: | |||
| - ${KIBANA_PORT:-5601}:5601 | |||
| healthcheck: | |||
| test: ['CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1'] | |||
| test: [ 'CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1' ] | |||
| interval: 30s | |||
| timeout: 10s | |||
| retries: 3 | |||
| @@ -0,0 +1,25 @@ | |||
| #!/bin/bash | |||
| set -e | |||
| if [ "${VECTOR_STORE}" = "elasticsearch-ja" ]; then | |||
| # Check if the ICU tokenizer plugin is installed | |||
| if ! /usr/share/elasticsearch/bin/elasticsearch-plugin list | grep -q analysis-icu; then | |||
| printf '%s\n' "Installing the ICU tokenizer plugin" | |||
| if ! /usr/share/elasticsearch/bin/elasticsearch-plugin install analysis-icu; then | |||
| printf '%s\n' "Failed to install the ICU tokenizer plugin" | |||
| exit 1 | |||
| fi | |||
| fi | |||
| # Check if the Japanese language analyzer plugin is installed | |||
| if ! /usr/share/elasticsearch/bin/elasticsearch-plugin list | grep -q analysis-kuromoji; then | |||
| printf '%s\n' "Installing the Japanese language analyzer plugin" | |||
| if ! /usr/share/elasticsearch/bin/elasticsearch-plugin install analysis-kuromoji; then | |||
| printf '%s\n' "Failed to install the Japanese language analyzer plugin" | |||
| exit 1 | |||
| fi | |||
| fi | |||
| fi | |||
| # Run the original entrypoint script | |||
| exec /bin/tini -- /usr/local/bin/docker-entrypoint.sh | |||
| @@ -25,3 +25,6 @@ NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=60000 | |||
| # CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP | |||
| NEXT_PUBLIC_CSP_WHITELIST= | |||
| # The maximum top-k value for RAG. | |||
| NEXT_PUBLIC_TOP_K_MAX_VALUE=10 | |||
| @@ -25,16 +25,18 @@ import Input from '@/app/components/base/input' | |||
| import { useStore as useTagStore } from '@/app/components/base/tag-management/store' | |||
| import TagManagementModal from '@/app/components/base/tag-management' | |||
| import TagFilter from '@/app/components/base/tag-management/filter' | |||
| import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label' | |||
| const getKey = ( | |||
| pageIndex: number, | |||
| previousPageData: AppListResponse, | |||
| activeTab: string, | |||
| isCreatedByMe: boolean, | |||
| tags: string[], | |||
| keywords: string, | |||
| ) => { | |||
| if (!pageIndex || previousPageData.has_more) { | |||
| const params: any = { url: 'apps', params: { page: pageIndex + 1, limit: 30, name: keywords } } | |||
| const params: any = { url: 'apps', params: { page: pageIndex + 1, limit: 30, name: keywords, is_created_by_me: isCreatedByMe } } | |||
| if (activeTab !== 'all') | |||
| params.params.mode = activeTab | |||
| @@ -58,6 +60,7 @@ const Apps = () => { | |||
| defaultTab: 'all', | |||
| }) | |||
| const { query: { tagIDs = [], keywords = '' }, setQuery } = useAppsQueryState() | |||
| const [isCreatedByMe, setIsCreatedByMe] = useState(false) | |||
| const [tagFilterValue, setTagFilterValue] = useState<string[]>(tagIDs) | |||
| const [searchKeywords, setSearchKeywords] = useState(keywords) | |||
| const setKeywords = useCallback((keywords: string) => { | |||
| @@ -68,7 +71,7 @@ const Apps = () => { | |||
| }, [setQuery]) | |||
| const { data, isLoading, setSize, mutate } = useSWRInfinite( | |||
| (pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, tagIDs, searchKeywords), | |||
| (pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, isCreatedByMe, tagIDs, searchKeywords), | |||
| fetchAppList, | |||
| { revalidateFirstPage: true }, | |||
| ) | |||
| @@ -132,6 +135,12 @@ const Apps = () => { | |||
| options={options} | |||
| /> | |||
| <div className='flex items-center gap-2'> | |||
| <CheckboxWithLabel | |||
| className='mr-2' | |||
| label={t('app.showMyCreatedAppsOnly')} | |||
| isChecked={isCreatedByMe} | |||
| onChange={() => setIsCreatedByMe(!isCreatedByMe)} | |||
| /> | |||
| <TagFilter type='app' value={tagFilterValue} onChange={handleTagsChange} /> | |||
| <Input | |||
| showLeftIcon | |||
| @@ -52,6 +52,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from | |||
| - <code>high_quality</code> High quality: embedding using embedding model, built as vector database index | |||
| - <code>economy</code> Economy: Build using inverted index of keyword table index | |||
| </Property> | |||
| <Property name='doc_form' type='string' key='doc_form'> | |||
| Format of indexed content | |||
| - <code>text_model</code> Text documents are directly embedded; `economy` mode defaults to using this form | |||
| - <code>hierarchical_model</code> Parent-child mode | |||
| - <code>qa_model</code> Q&A Mode: Generates Q&A pairs for segmented documents and then embeds the questions | |||
| </Property> | |||
| <Property name='doc_language' type='string' key='doc_language'> | |||
| In Q&A mode, specify the language of the document, for example: <code>English</code>, <code>Chinese</code> | |||
| </Property> | |||
| <Property name='process_rule' type='object' key='process_rule'> | |||
| Processing rules | |||
| - <code>mode</code> (string) Cleaning, segmentation mode, automatic / custom | |||
| @@ -65,6 +74,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from | |||
| - <code>segmentation</code> (object) Segmentation rules | |||
| - <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n | |||
| - <code>max_tokens</code> Maximum length (token) defaults to 1000 | |||
| - <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval | |||
| - <code>subchunk_segmentation</code> (object) Child chunk rules | |||
| - <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code> | |||
| - <code>max_tokens</code> The maximum length (tokens) must be validated to be shorter than the length of the parent chunk | |||
| - <code>chunk_overlap</code> Define the overlap between adjacent chunks (optional) | |||
| </Property> | |||
| </Properties> | |||
| </Col> | |||
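Taken together, the new `doc_form` and the parent/child chunk rules extend the existing create-document request body. A hedged sketch of a create-by-text call using them -- the base URL, dataset ID, and API key are placeholders, the endpoint path is assumed from the hyphenated naming (`update-by-text`) used elsewhere in these docs, and the nesting of `parent_mode`/`subchunk_segmentation` under `rules` follows the bullets above, so adjust to your instance and version:

```python
import requests

BASE_URL = "https://api.dify.ai/v1"        # placeholder; use your own instance URL
DATASET_ID = "your-dataset-id"             # placeholder
API_KEY = "your-knowledge-api-key"         # placeholder

payload = {
    "name": "handbook",
    "text": "Long source text that will be split into parent and child chunks...",
    "indexing_technique": "high_quality",
    "doc_form": "hierarchical_model",      # parent-child mode
    "process_rule": {
        "mode": "custom",
        "rules": {
            "pre_processing_rules": [
                {"id": "remove_extra_spaces", "enabled": True},
            ],
            "segmentation": {"separator": "\n", "max_tokens": 1000},
            "parent_mode": "paragraph",
            "subchunk_segmentation": {"separator": "***", "max_tokens": 200},
        },
    },
}

resp = requests.post(
    f"{BASE_URL}/datasets/{DATASET_ID}/document/create-by-text",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json=payload,
)
print(resp.status_code, resp.json())
```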
| @@ -155,6 +169,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from | |||
| - <code>high_quality</code> High quality: embedding using embedding model, built as vector database index | |||
| - <code>economy</code> Economy: Build using inverted index of keyword table index | |||
| - <code>doc_form</code> Format of indexed content | |||
| - <code>text_model</code> Text documents are directly embedded; `economy` mode defaults to using this form | |||
| - <code>hierarchical_model</code> Parent-child mode | |||
| - <code>qa_model</code> Q&A Mode: Generates Q&A pairs for segmented documents and then embeds the questions | |||
| - <code>doc_language</code> In Q&A mode, specify the language of the document, for example: <code>English</code>, <code>Chinese</code> | |||
| - <code>process_rule</code> Processing rules | |||
| - <code>mode</code> (string) Cleaning, segmentation mode, automatic / custom | |||
| - <code>rules</code> (object) Custom rules (in automatic mode, this field is empty) | |||
| @@ -167,6 +188,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from | |||
| - <code>segmentation</code> (object) Segmentation rules | |||
| - <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n | |||
| - <code>max_tokens</code> Maximum length (token) defaults to 1000 | |||
| - <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval | |||
| - <code>subchunk_segmentation</code> (object) Child chunk rules | |||
| - <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code> | |||
| - <code>max_tokens</code> The maximum length (tokens) must be validated to be shorter than the length of the parent chunk | |||
| - <code>chunk_overlap</code> Define the overlap between adjacent chunks (optional) | |||
| </Property> | |||
| <Property name='file' type='multipart/form-data' key='file'> | |||
| Files that need to be uploaded. | |||
| @@ -449,6 +475,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from | |||
| - <code>segmentation</code> (object) Segmentation rules | |||
| - <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n | |||
| - <code>max_tokens</code> Maximum length (token) defaults to 1000 | |||
| - <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval | |||
| - <code>subchunk_segmentation</code> (object) Child chunk rules | |||
| - <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code> | |||
| - <code>max_tokens</code> The maximum length (tokens) must be validated to be shorter than the length of the parent chunk | |||
| - <code>chunk_overlap</code> Define the overlap between adjacent chunks (optional) | |||
| </Property> | |||
| </Properties> | |||
| </Col> | |||
| @@ -546,6 +577,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from | |||
| - <code>segmentation</code> (object) Segmentation rules | |||
| - <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n | |||
| - <code>max_tokens</code> Maximum length (token) defaults to 1000 | |||
| - <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval | |||
| - <code>subchunk_segmentation</code> (object) Child chunk rules | |||
| - <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code> | |||
| - <code>max_tokens</code> The maximum length (tokens) must be validated to be shorter than the length of the parent chunk | |||
| - <code>chunk_overlap</code> Define the overlap between adjacent chunks (optional) | |||
| </Property> | |||
| </Properties> | |||
| </Col> | |||
| @@ -984,7 +1020,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from | |||
| <Heading | |||
| url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}' | |||
| method='POST' | |||
| title='Update a Chunk in a Document ' | |||
| title='Update a Chunk in a Document' | |||
| name='#update_segment' | |||
| /> | |||
| <Row> | |||
| @@ -1009,6 +1045,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from | |||
| - <code>answer</code> (text) Answer content, passed if the knowledge is in Q&A mode (optional) | |||
| - <code>keywords</code> (list) Keyword (optional) | |||
| - <code>enabled</code> (bool) False / true (optional) | |||
| - <code>regenerate_child_chunks</code> (bool) Whether to regenerate child chunks (optional) | |||
| </Property> | |||
| </Properties> | |||
| </Col> | |||
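The new `regenerate_child_chunks` flag rides along in the same segment-update body documented above. A brief sketch, assuming the `segment` wrapper this endpoint documents and using placeholder IDs and key:

```python
import requests

BASE_URL = "https://api.dify.ai/v1"          # placeholder
API_KEY = "your-knowledge-api-key"           # placeholder
path = "/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}"

body = {
    "segment": {
        "content": "Revised parent chunk text...",
        "enabled": True,
        "regenerate_child_chunks": True,     # re-split child chunks after editing the parent
    }
}

resp = requests.post(
    BASE_URL + path.format(
        dataset_id="your-dataset-id",
        document_id="your-document-id",
        segment_id="your-segment-id",
    ),
    headers={"Authorization": f"Bearer {API_KEY}"},
    json=body,
)
print(resp.json())
```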
| @@ -52,6 +52,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from | |||
| - <code>high_quality</code> High quality: embed with an embedding model and build a vector database index | |||
| - <code>economy</code> Economy: build with the inverted index of a keyword table index | |||
| </Property> | |||
| <Property name='doc_form' type='string' key='doc_form'> | |||
| Format of the indexed content | |||
| - <code>text_model</code> Text documents are embedded directly; economy mode defaults to this form | |||
| - <code>hierarchical_model</code> Parent-child mode | |||
| - <code>qa_model</code> Q&A mode: generate Q&A pairs for the segmented documents, then embed the questions | |||
| </Property> | |||
| <Property name='doc_language' type='string' key='doc_language'> | |||
| In Q&A mode, specify the language of the document, for example: <code>English</code>, <code>Chinese</code> | |||
| </Property> | |||
| <Property name='process_rule' type='object' key='process_rule'> | |||
| Processing rules | |||
| - <code>mode</code> (string) Cleaning and segmentation mode: automatic / custom | |||
| @@ -63,8 +72,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from | |||
| - <code>remove_urls_emails</code> Remove URLs and email addresses | |||
| - <code>enabled</code> (bool) Whether this rule is selected; when no document ID is passed, it represents the default value | |||
| - <code>segmentation</code> (object) Segmentation rules | |||
| - <code>separator</code> Custom segment identifier; currently only one delimiter can be set. Defaults to \n | |||
| - <code>separator</code> Custom segment identifier; currently only one delimiter can be set. Defaults to <code>\n</code> | |||
| - <code>max_tokens</code> Maximum length (tokens), defaults to 1000 | |||
| - <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full-document retrieval / <code>paragraph</code> paragraph retrieval | |||
| - <code>subchunk_segmentation</code> (object) Child chunk rules | |||
| - <code>separator</code> Segment identifier; currently only one delimiter can be set. Defaults to <code>***</code> | |||
| - <code>max_tokens</code> Maximum length (tokens); must be shorter than the parent chunk's length | |||
| - <code>chunk_overlap</code> Overlap between adjacent chunks when the data is segmented (optional) | |||
| </Property> | |||
| </Properties> | |||
| </Col> | |||
| @@ -155,6 +169,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from | |||
| - <code>high_quality</code> High quality: embed with an embedding model and build a vector database index | |||
| - <code>economy</code> Economy: build with the inverted index of a keyword table index | |||
| - <code>doc_form</code> Format of the indexed content | |||
| - <code>text_model</code> Text documents are embedded directly; economy mode defaults to this form | |||
| - <code>hierarchical_model</code> Parent-child mode | |||
| - <code>qa_model</code> Q&A mode: generate Q&A pairs for the segmented documents, then embed the questions | |||
| - <code>doc_language</code> In Q&A mode, specify the language of the document, for example: <code>English</code>, <code>Chinese</code> | |||
| - <code>process_rule</code> Processing rules | |||
| - <code>mode</code> (string) Cleaning and segmentation mode: automatic / custom | |||
| - <code>rules</code> (object) Custom rules (in automatic mode, this field is empty) | |||
| @@ -167,6 +188,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from | |||
| - <code>segmentation</code> (object) Segmentation rules | |||
| - <code>separator</code> Custom segment identifier; currently only one delimiter can be set. Defaults to \n | |||
| - <code>max_tokens</code> Maximum length (tokens), defaults to 1000 | |||
| - <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full-document retrieval / <code>paragraph</code> paragraph retrieval | |||
| - <code>subchunk_segmentation</code> (object) Child chunk rules | |||
| - <code>separator</code> Segment identifier; currently only one delimiter can be set. Defaults to <code>***</code> | |||
| - <code>max_tokens</code> Maximum length (tokens); must be shorter than the parent chunk's length | |||
| - <code>chunk_overlap</code> Overlap between adjacent chunks when the data is segmented (optional) | |||
| </Property> | |||
| <Property name='file' type='multipart/form-data' key='file'> | |||
| The file to be uploaded. | |||
| @@ -411,7 +437,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from | |||
| <Heading | |||
| url='/datasets/{dataset_id}/documents/{document_id}/update-by-text' | |||
| method='POST' | |||
| title='Update a Document via Text ' | |||
| title='Update a Document via Text' | |||
| name='#update-by-text' | |||
| /> | |||
| <Row> | |||
| @@ -449,6 +475,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from | |||
| - <code>segmentation</code> (object) Segmentation rules | |||
| - <code>separator</code> Custom segment identifier; currently only one delimiter can be set. Defaults to \n | |||
| - <code>max_tokens</code> Maximum length (tokens), defaults to 1000 | |||
| - <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full-document retrieval / <code>paragraph</code> paragraph retrieval | |||
| - <code>subchunk_segmentation</code> (object) Child chunk rules | |||
| - <code>separator</code> Segment identifier; currently only one delimiter can be set. Defaults to <code>***</code> | |||
| - <code>max_tokens</code> Maximum length (tokens); must be shorter than the parent chunk's length | |||
| - <code>chunk_overlap</code> Overlap between adjacent chunks when the data is segmented (optional) | |||
| </Property> | |||
| </Properties> | |||
| </Col> | |||
| @@ -508,7 +539,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from | |||
| <Heading | |||
| url='/datasets/{dataset_id}/documents/{document_id}/update-by-file' | |||
| method='POST' | |||
| title='Update a Document via File ' | |||
| title='Update a Document via File' | |||
| name='#update-by-file' | |||
| /> | |||
| <Row> | |||
| @@ -546,6 +577,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from | |||
| - <code>segmentation</code> (object) Segmentation rules | |||
| - <code>separator</code> Custom segment identifier; currently only one delimiter can be set. Defaults to \n | |||
| - <code>max_tokens</code> Maximum length (tokens), defaults to 1000 | |||
| - <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full-document retrieval / <code>paragraph</code> paragraph retrieval | |||
| - <code>subchunk_segmentation</code> (object) Child chunk rules | |||
| - <code>separator</code> Segment identifier; currently only one delimiter can be set. Defaults to <code>***</code> | |||
| - <code>max_tokens</code> Maximum length (tokens); must be shorter than the parent chunk's length | |||
| - <code>chunk_overlap</code> Overlap between adjacent chunks when the data is segmented (optional) | |||
| </Property> | |||
| </Properties> | |||
| </Col> | |||
| @@ -1009,6 +1045,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from | |||
| - <code>answer</code> (text) Answer content; optional, pass it when the knowledge base is in Q&A mode | |||
| - <code>keywords</code> (list) Keywords, optional | |||
| - <code>enabled</code> (bool) false/true, optional | |||
| - <code>regenerate_child_chunks</code> (bool) Whether to regenerate child chunks, optional | |||
| </Property> | |||
| </Properties> | |||
| </Col> | |||
| @@ -26,13 +26,15 @@ const PromptEditorHeightResizeWrap: FC<Props> = ({ | |||
| const [clientY, setClientY] = useState(0) | |||
| const [isResizing, setIsResizing] = useState(false) | |||
| const [prevUserSelectStyle, setPrevUserSelectStyle] = useState(getComputedStyle(document.body).userSelect) | |||
| const [oldHeight, setOldHeight] = useState(height) | |||
| const handleStartResize = useCallback((e: React.MouseEvent<HTMLElement>) => { | |||
| setClientY(e.clientY) | |||
| setIsResizing(true) | |||
| setOldHeight(height) | |||
| setPrevUserSelectStyle(getComputedStyle(document.body).userSelect) | |||
| document.body.style.userSelect = 'none' | |||
| }, []) | |||
| }, [height]) | |||
| const handleStopResize = useCallback(() => { | |||
| setIsResizing(false) | |||
| @@ -44,8 +46,7 @@ const PromptEditorHeightResizeWrap: FC<Props> = ({ | |||
| return | |||
| const offset = e.clientY - clientY | |||
| let newHeight = height + offset | |||
| setClientY(e.clientY) | |||
| let newHeight = oldHeight + offset | |||
| if (newHeight < minHeight) | |||
| newHeight = minHeight | |||
| onHeightChange(newHeight) | |||
| @@ -27,6 +27,7 @@ import { ADD_EXTERNAL_DATA_TOOL } from '@/app/components/app/configuration/confi | |||
| import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block' | |||
| import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from '@/app/components/base/prompt-editor/plugins/update-block' | |||
| import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' | |||
| import { useFeaturesStore } from '@/app/components/base/features/hooks' | |||
| export type ISimplePromptInput = { | |||
| mode: AppType | |||
| @@ -54,6 +55,11 @@ const Prompt: FC<ISimplePromptInput> = ({ | |||
| const { t } = useTranslation() | |||
| const media = useBreakpoints() | |||
| const isMobile = media === MediaType.mobile | |||
| const featuresStore = useFeaturesStore() | |||
| const { | |||
| features, | |||
| setFeatures, | |||
| } = featuresStore!.getState() | |||
| const { eventEmitter } = useEventEmitterContextContext() | |||
| const { | |||
| @@ -137,8 +143,18 @@ const Prompt: FC<ISimplePromptInput> = ({ | |||
| }) | |||
| setModelConfig(newModelConfig) | |||
| setPrevPromptConfig(modelConfig.configs) | |||
| if (mode !== AppType.completion) | |||
| if (mode !== AppType.completion) { | |||
| setIntroduction(res.opening_statement) | |||
| const newFeatures = produce(features, (draft) => { | |||
| draft.opening = { | |||
| ...draft.opening, | |||
| enabled: !!res.opening_statement, | |||
| opening_statement: res.opening_statement, | |||
| } | |||
| }) | |||
| setFeatures(newFeatures) | |||
| } | |||
| showAutomaticFalse() | |||
| } | |||
| const minHeight = initEditorHeight || 228 | |||
| @@ -59,36 +59,24 @@ const ConfigContent: FC<Props> = ({ | |||
| const { | |||
| modelList: rerankModelList, | |||
| defaultModel: rerankDefaultModel, | |||
| currentModel: isRerankDefaultModelValid, | |||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | |||
| const { | |||
| currentModel: currentRerankModel, | |||
| } = useCurrentProviderAndModel( | |||
| rerankModelList, | |||
| rerankDefaultModel | |||
| ? { | |||
| ...rerankDefaultModel, | |||
| provider: rerankDefaultModel.provider.provider, | |||
| } | |||
| : undefined, | |||
| { | |||
| provider: datasetConfigs.reranking_model?.reranking_provider_name, | |||
| model: datasetConfigs.reranking_model?.reranking_model_name, | |||
| }, | |||
| ) | |||
| const rerankModel = (() => { | |||
| if (datasetConfigs.reranking_model?.reranking_provider_name) { | |||
| return { | |||
| provider_name: datasetConfigs.reranking_model.reranking_provider_name, | |||
| model_name: datasetConfigs.reranking_model.reranking_model_name, | |||
| } | |||
| const rerankModel = useMemo(() => { | |||
| return { | |||
| provider_name: datasetConfigs?.reranking_model?.reranking_provider_name ?? '', | |||
| model_name: datasetConfigs?.reranking_model?.reranking_model_name ?? '', | |||
| } | |||
| else if (rerankDefaultModel) { | |||
| return { | |||
| provider_name: rerankDefaultModel.provider.provider, | |||
| model_name: rerankDefaultModel.model, | |||
| } | |||
| } | |||
| })() | |||
| }, [datasetConfigs.reranking_model]) | |||
| const handleParamChange = (key: string, value: number) => { | |||
| if (key === 'top_k') { | |||
| @@ -133,6 +121,12 @@ const ConfigContent: FC<Props> = ({ | |||
| } | |||
| const handleRerankModeChange = (mode: RerankingModeEnum) => { | |||
| if (mode === datasetConfigs.reranking_mode) | |||
| return | |||
| if (mode === RerankingModeEnum.RerankingModel && !currentRerankModel) | |||
| Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) | |||
| onChange({ | |||
| ...datasetConfigs, | |||
| reranking_mode: mode, | |||
| @@ -162,31 +156,25 @@ const ConfigContent: FC<Props> = ({ | |||
| const canManuallyToggleRerank = useMemo(() => { | |||
| return (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic) | |||
| || selectedDatasetsMode.allExternal | |||
| || selectedDatasetsMode.allExternal | |||
| }, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal]) | |||
| const showRerankModel = useMemo(() => { | |||
| if (!canManuallyToggleRerank) | |||
| return true | |||
| else if (canManuallyToggleRerank && !isRerankDefaultModelValid) | |||
| return false | |||
| return datasetConfigs.reranking_enable | |||
| }, [canManuallyToggleRerank, datasetConfigs.reranking_enable, isRerankDefaultModelValid]) | |||
| }, [datasetConfigs.reranking_enable, canManuallyToggleRerank]) | |||
| const handleDisabledSwitchClick = useCallback(() => { | |||
| if (!currentRerankModel && !showRerankModel) | |||
| const handleDisabledSwitchClick = useCallback((enable: boolean) => { | |||
| if (!currentRerankModel && enable) | |||
| Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) | |||
| }, [currentRerankModel, showRerankModel, t]) | |||
| useEffect(() => { | |||
| if (canManuallyToggleRerank && showRerankModel !== datasetConfigs.reranking_enable) { | |||
| onChange({ | |||
| ...datasetConfigs, | |||
| reranking_enable: showRerankModel, | |||
| }) | |||
| } | |||
| }, [canManuallyToggleRerank, showRerankModel, datasetConfigs, onChange]) | |||
| onChange({ | |||
| ...datasetConfigs, | |||
| reranking_enable: enable, | |||
| }) | |||
| // eslint-disable-next-line react-hooks/exhaustive-deps | |||
| }, [currentRerankModel, datasetConfigs, onChange]) | |||
| return ( | |||
| <div> | |||
| @@ -267,24 +255,12 @@ const ConfigContent: FC<Props> = ({ | |||
| <div className='flex items-center'> | |||
| { | |||
| selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && ( | |||
| <div | |||
| className='flex items-center' | |||
| onClick={handleDisabledSwitchClick} | |||
| > | |||
| <Switch | |||
| size='md' | |||
| defaultValue={showRerankModel} | |||
| disabled={!currentRerankModel || !canManuallyToggleRerank} | |||
| onChange={(v) => { | |||
| if (canManuallyToggleRerank) { | |||
| onChange({ | |||
| ...datasetConfigs, | |||
| reranking_enable: v, | |||
| }) | |||
| } | |||
| }} | |||
| /> | |||
| </div> | |||
| <Switch | |||
| size='md' | |||
| defaultValue={showRerankModel} | |||
| disabled={!canManuallyToggleRerank} | |||
| onChange={handleDisabledSwitchClick} | |||
| /> | |||
| ) | |||
| } | |||
| <div className='leading-[32px] ml-1 text-text-secondary system-sm-semibold'>{t('common.modelProvider.rerankModel.key')}</div> | |||
| @@ -298,21 +274,24 @@ const ConfigContent: FC<Props> = ({ | |||
| triggerClassName='ml-1 w-4 h-4' | |||
| /> | |||
| </div> | |||
| <div> | |||
| <ModelSelector | |||
| defaultModel={rerankModel && { provider: rerankModel?.provider_name, model: rerankModel?.model_name }} | |||
| onSelect={(v) => { | |||
| onChange({ | |||
| ...datasetConfigs, | |||
| reranking_model: { | |||
| reranking_provider_name: v.provider, | |||
| reranking_model_name: v.model, | |||
| }, | |||
| }) | |||
| }} | |||
| modelList={rerankModelList} | |||
| /> | |||
| </div> | |||
| { | |||
| showRerankModel && ( | |||
| <div> | |||
| <ModelSelector | |||
| defaultModel={rerankModel && { provider: rerankModel?.provider_name, model: rerankModel?.model_name }} | |||
| onSelect={(v) => { | |||
| onChange({ | |||
| ...datasetConfigs, | |||
| reranking_model: { | |||
| reranking_provider_name: v.provider, | |||
| reranking_model_name: v.model, | |||
| }, | |||
| }) | |||
| }} | |||
| modelList={rerankModelList} | |||
| /> | |||
| </div> | |||
| )} | |||
| </div> | |||
| ) | |||
| } | |||
| @@ -10,7 +10,7 @@ import Modal from '@/app/components/base/modal' | |||
| import Button from '@/app/components/base/button' | |||
| import { RETRIEVE_TYPE } from '@/types/app' | |||
| import Toast from '@/app/components/base/toast' | |||
| import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||
| import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||
| import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| import { RerankingModeEnum } from '@/models/datasets' | |||
| import type { DataSet } from '@/models/datasets' | |||
| @@ -41,17 +41,27 @@ const ParamsConfig = ({ | |||
| }, [datasetConfigs]) | |||
| const { | |||
| defaultModel: rerankDefaultModel, | |||
| currentModel: isRerankDefaultModelValid, | |||
| modelList: rerankModelList, | |||
| currentModel: rerankDefaultModel, | |||
| currentProvider: rerankDefaultProvider, | |||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | |||
| const { | |||
| currentModel: isCurrentRerankModelValid, | |||
| } = useCurrentProviderAndModel( | |||
| rerankModelList, | |||
| { | |||
| provider: tempDataSetConfigs.reranking_model?.reranking_provider_name ?? '', | |||
| model: tempDataSetConfigs.reranking_model?.reranking_model_name ?? '', | |||
| }, | |||
| ) | |||
| const isValid = () => { | |||
| let errMsg = '' | |||
| if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) { | |||
| if (tempDataSetConfigs.reranking_enable | |||
| && tempDataSetConfigs.reranking_mode === RerankingModeEnum.RerankingModel | |||
| && !isRerankDefaultModelValid | |||
| && !isCurrentRerankModelValid | |||
| ) | |||
| errMsg = t('appDebug.datasetConfig.rerankModelRequired') | |||
| } | |||
| @@ -66,16 +76,7 @@ const ParamsConfig = ({ | |||
| const handleSave = () => { | |||
| if (!isValid()) | |||
| return | |||
| const config = { ...tempDataSetConfigs } | |||
| if (config.retrieval_model === RETRIEVE_TYPE.multiWay | |||
| && config.reranking_mode === RerankingModeEnum.RerankingModel | |||
| && !config.reranking_model) { | |||
| config.reranking_model = { | |||
| reranking_provider_name: rerankDefaultModel?.provider?.provider, | |||
| reranking_model_name: rerankDefaultModel?.model, | |||
| } as any | |||
| } | |||
| setDatasetConfigs(config) | |||
| setDatasetConfigs(tempDataSetConfigs) | |||
| setRerankSettingModalOpen(false) | |||
| } | |||
| @@ -94,14 +95,14 @@ const ParamsConfig = ({ | |||
| reranking_enable: restConfigs.reranking_enable, | |||
| }, selectedDatasets, selectedDatasets, { | |||
| provider: rerankDefaultProvider?.provider, | |||
| model: isRerankDefaultModelValid?.model, | |||
| model: rerankDefaultModel?.model, | |||
| }) | |||
| setTempDataSetConfigs({ | |||
| ...retrievalConfig, | |||
| reranking_model: restConfigs.reranking_model && { | |||
| reranking_provider_name: restConfigs.reranking_model.reranking_provider_name, | |||
| reranking_model_name: restConfigs.reranking_model.reranking_model_name, | |||
| reranking_model: { | |||
| reranking_provider_name: retrievalConfig.reranking_model?.provider || '', | |||
| reranking_model_name: retrievalConfig.reranking_model?.model || '', | |||
| }, | |||
| retrieval_model, | |||
| score_threshold_enabled, | |||
| @@ -29,7 +29,7 @@ const WeightedScore = ({ | |||
| return ( | |||
| <div> | |||
| <div className='px-3 pt-5 h-[52px] space-x-3 rounded-lg border border-components-panel-border'> | |||
| <div className='px-3 pt-5 pb-2 space-x-3 rounded-lg border border-components-panel-border'> | |||
| <Slider | |||
| className={cn('grow h-0.5 !bg-util-colors-teal-teal-500 rounded-full')} | |||
| max={1.0} | |||
| @@ -39,7 +39,7 @@ const WeightedScore = ({ | |||
| onChange={v => onChange({ value: [v, (10 - v * 10) / 10] })} | |||
| trackClassName='weightedScoreSliderTrack' | |||
| /> | |||
| <div className='flex justify-between mt-1'> | |||
| <div className='flex justify-between mt-3'> | |||
| <div className='shrink-0 flex items-center w-[90px] system-xs-semibold-uppercase text-util-colors-blue-light-blue-light-500'> | |||
| <div className='mr-1 truncate uppercase' title={t('dataset.weightedScore.semantic') || ''}> | |||
| {t('dataset.weightedScore.semantic')} | |||
| @@ -12,7 +12,7 @@ import Divider from '@/app/components/base/divider' | |||
| import Button from '@/app/components/base/button' | |||
| import Input from '@/app/components/base/input' | |||
| import Textarea from '@/app/components/base/textarea' | |||
| import { type DataSet, RerankingModeEnum } from '@/models/datasets' | |||
| import { type DataSet } from '@/models/datasets' | |||
| import { useToastContext } from '@/app/components/base/toast' | |||
| import { updateDatasetSetting } from '@/service/datasets' | |||
| import { useAppContext } from '@/context/app-context' | |||
| @@ -21,7 +21,7 @@ import type { RetrievalConfig } from '@/types/app' | |||
| import RetrievalSettings from '@/app/components/datasets/external-knowledge-base/create/RetrievalSettings' | |||
| import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' | |||
| import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config' | |||
| import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' | |||
| import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' | |||
| import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' | |||
| import PermissionSelector from '@/app/components/datasets/settings/permission-selector' | |||
| import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' | |||
| @@ -99,8 +99,6 @@ const SettingsModal: FC<SettingsModalProps> = ({ | |||
| } | |||
| if ( | |||
| !isReRankModelSelected({ | |||
| rerankDefaultModel, | |||
| isRerankDefaultModelValid: !!isRerankDefaultModelValid, | |||
| rerankModelList, | |||
| retrievalConfig, | |||
| indexMethod, | |||
| @@ -109,14 +107,6 @@ const SettingsModal: FC<SettingsModalProps> = ({ | |||
| notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') }) | |||
| return | |||
| } | |||
| const postRetrievalConfig = ensureRerankModelSelected({ | |||
| rerankDefaultModel: rerankDefaultModel!, | |||
| retrievalConfig: { | |||
| ...retrievalConfig, | |||
| reranking_enable: retrievalConfig.reranking_mode === RerankingModeEnum.RerankingModel, | |||
| }, | |||
| indexMethod, | |||
| }) | |||
| try { | |||
| setLoading(true) | |||
| const { id, name, description, permission } = localeCurrentDataset | |||
| @@ -128,8 +118,8 @@ const SettingsModal: FC<SettingsModalProps> = ({ | |||
| permission, | |||
| indexing_technique: indexMethod, | |||
| retrieval_model: { | |||
| ...postRetrievalConfig, | |||
| score_threshold: postRetrievalConfig.score_threshold_enabled ? postRetrievalConfig.score_threshold : 0, | |||
| ...retrievalConfig, | |||
| score_threshold: retrievalConfig.score_threshold_enabled ? retrievalConfig.score_threshold : 0, | |||
| }, | |||
| embedding_model: localeCurrentDataset.embedding_model, | |||
| embedding_model_provider: localeCurrentDataset.embedding_model_provider, | |||
| @@ -157,7 +147,7 @@ const SettingsModal: FC<SettingsModalProps> = ({ | |||
| onSave({ | |||
| ...localeCurrentDataset, | |||
| indexing_technique: indexMethod, | |||
| retrieval_model_dict: postRetrievalConfig, | |||
| retrieval_model_dict: retrievalConfig, | |||
| }) | |||
| } | |||
| catch (e) { | |||
| @@ -287,9 +287,9 @@ const Configuration: FC = () => { | |||
| setDatasetConfigs({ | |||
| ...retrievalConfig, | |||
| reranking_model: restConfigs.reranking_model && { | |||
| reranking_provider_name: restConfigs.reranking_model.reranking_provider_name, | |||
| reranking_model_name: restConfigs.reranking_model.reranking_model_name, | |||
| reranking_model: { | |||
| reranking_provider_name: retrievalConfig?.reranking_model?.provider || '', | |||
| reranking_model_name: retrievalConfig?.reranking_model?.model || '', | |||
| }, | |||
| retrieval_model, | |||
| score_threshold_enabled, | |||
| @@ -39,6 +39,7 @@ type ChatInputAreaProps = { | |||
| inputs?: Record<string, any> | |||
| inputsForm?: InputForm[] | |||
| theme?: Theme | null | |||
| isResponding?: boolean | |||
| } | |||
| const ChatInputArea = ({ | |||
| showFeatureBar, | |||
| @@ -51,6 +52,7 @@ const ChatInputArea = ({ | |||
| inputs = {}, | |||
| inputsForm = [], | |||
| theme, | |||
| isResponding, | |||
| }: ChatInputAreaProps) => { | |||
| const { t } = useTranslation() | |||
| const { notify } = useToastContext() | |||
| @@ -77,6 +79,11 @@ const ChatInputArea = ({ | |||
| const historyRef = useRef(['']) | |||
| const [currentIndex, setCurrentIndex] = useState(-1) | |||
| const handleSend = () => { | |||
| if (isResponding) { | |||
| notify({ type: 'info', message: t('appDebug.errorMessage.waitForResponse') }) | |||
| return | |||
| } | |||
| if (onSend) { | |||
| const { files, setFiles } = filesStore.getState() | |||
| if (files.find(item => item.transferMethod === TransferMethod.local_file && !item.uploadedId)) { | |||
| @@ -116,7 +123,7 @@ const ChatInputArea = ({ | |||
| setQuery(historyRef.current[currentIndex + 1]) | |||
| } | |||
| else if (currentIndex === historyRef.current.length - 1) { | |||
| // If it is the last element, clear the input box | |||
| setCurrentIndex(historyRef.current.length) | |||
| setQuery('') | |||
| } | |||
| @@ -169,6 +176,7 @@ const ChatInputArea = ({ | |||
| 'p-1 w-full leading-6 body-lg-regular text-text-tertiary outline-none', | |||
| )} | |||
| placeholder={t('common.chat.inputPlaceholder') || ''} | |||
| autoFocus | |||
| autoSize={{ minRows: 1 }} | |||
| onResize={handleTextareaResize} | |||
| value={query} | |||
| @@ -292,6 +292,7 @@ const Chat: FC<ChatProps> = ({ | |||
| inputs={inputs} | |||
| inputsForm={inputsForm} | |||
| theme={themeBuilder?.theme} | |||
| isResponding={isResponding} | |||
| /> | |||
| ) | |||
| } | |||
| @@ -28,8 +28,8 @@ const Question: FC<QuestionProps> = ({ | |||
| } = item | |||
| return ( | |||
| <div className='flex justify-end mb-2 last:mb-0 pl-10'> | |||
| <div className='group relative mr-4'> | |||
| <div className='flex justify-end mb-2 last:mb-0 pl-14'> | |||
| <div className='group relative mr-4 max-w-full'> | |||
| <div | |||
| className='px-4 py-3 bg-[#D1E9FF]/50 rounded-2xl text-sm text-gray-900' | |||
| style={theme?.chatBubbleColorStyle ? CssTransform(theme.chatBubbleColorStyle) : {}} | |||
| @@ -111,9 +111,9 @@ const CodeBlock: CodeComponent = memo(({ inline, className, children, ...props } | |||
| } | |||
| else if (language === 'echarts') { | |||
| return ( | |||
| <div style={{ minHeight: '350px', minWidth: '700px' }}> | |||
| <div style={{ minHeight: '350px', minWidth: '100%', overflowX: 'scroll' }}> | |||
| <ErrorBoundary> | |||
| <ReactEcharts option={chartData} /> | |||
| <ReactEcharts option={chartData} style={{ minWidth: '700px' }} /> | |||
| </ErrorBoundary> | |||
| </div> | |||
| ) | |||
| @@ -11,11 +11,17 @@ type Props = { | |||
| enable: boolean | |||
| } | |||
| const maxTopK = (() => { | |||
| const configValue = parseInt(globalThis.document?.body?.getAttribute('data-public-top-k-max-value') || '', 10) | |||
| if (configValue && !isNaN(configValue)) | |||
| return configValue | |||
| return 10 | |||
| })() | |||
| const VALUE_LIMIT = { | |||
| default: 2, | |||
| step: 1, | |||
| min: 1, | |||
| max: 10, | |||
| max: maxTopK, | |||
| } | |||
| const key = 'top_k' | |||
| @@ -6,14 +6,10 @@ import type { | |||
| import { RerankingModeEnum } from '@/models/datasets' | |||
| export const isReRankModelSelected = ({ | |||
| rerankDefaultModel, | |||
| isRerankDefaultModelValid, | |||
| retrievalConfig, | |||
| rerankModelList, | |||
| indexMethod, | |||
| }: { | |||
| rerankDefaultModel?: DefaultModelResponse | |||
| isRerankDefaultModelValid: boolean | |||
| retrievalConfig: RetrievalConfig | |||
| rerankModelList: Model[] | |||
| indexMethod?: string | |||
| @@ -25,12 +21,17 @@ export const isReRankModelSelected = ({ | |||
| return provider?.models.find(({ model }) => model === retrievalConfig.reranking_model?.reranking_model_name) | |||
| } | |||
| if (isRerankDefaultModelValid) | |||
| return !!rerankDefaultModel | |||
| return false | |||
| })() | |||
| if ( | |||
| indexMethod === 'high_quality' | |||
| && ([RETRIEVE_METHOD.semantic, RETRIEVE_METHOD.fullText].includes(retrievalConfig.search_method)) | |||
| && retrievalConfig.reranking_enable | |||
| && !rerankModelSelected | |||
| ) | |||
| return false | |||
| if ( | |||
| indexMethod === 'high_quality' | |||
| && (retrievalConfig.search_method === RETRIEVE_METHOD.hybrid && retrievalConfig.reranking_mode !== RerankingModeEnum.WeightedScore) | |||
| @@ -10,11 +10,13 @@ import { RETRIEVE_METHOD } from '@/types/app' | |||
| import type { RetrievalConfig } from '@/types/app' | |||
| type Props = { | |||
| disabled?: boolean | |||
| value: RetrievalConfig | |||
| onChange: (value: RetrievalConfig) => void | |||
| } | |||
| const EconomicalRetrievalMethodConfig: FC<Props> = ({ | |||
| disabled = false, | |||
| value, | |||
| onChange, | |||
| }) => { | |||
| @@ -22,7 +24,8 @@ const EconomicalRetrievalMethodConfig: FC<Props> = ({ | |||
| return ( | |||
| <div className='space-y-2'> | |||
| <OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />} | |||
| <OptionCard | |||
| disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />} | |||
| title={t('dataset.retrieval.invertedIndex.title')} | |||
| description={t('dataset.retrieval.invertedIndex.description')} isActive | |||
| activeHeaderClassName='bg-dataset-option-card-purple-gradient' | |||
| @@ -1,6 +1,6 @@ | |||
| 'use client' | |||
| import type { FC } from 'react' | |||
| import React from 'react' | |||
| import React, { useCallback } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import Image from 'next/image' | |||
| import RetrievalParamConfig from '../retrieval-param-config' | |||
| @@ -10,7 +10,7 @@ import { retrievalIcon } from '../../create/icons' | |||
| import type { RetrievalConfig } from '@/types/app' | |||
| import { RETRIEVE_METHOD } from '@/types/app' | |||
| import { useProviderContext } from '@/context/provider-context' | |||
| import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||
| import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||
| import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| import { | |||
| DEFAULT_WEIGHTED_SCORE, | |||
| @@ -20,54 +20,87 @@ import { | |||
| import Badge from '@/app/components/base/badge' | |||
| type Props = { | |||
| disabled?: boolean | |||
| value: RetrievalConfig | |||
| onChange: (value: RetrievalConfig) => void | |||
| } | |||
| const RetrievalMethodConfig: FC<Props> = ({ | |||
| value: passValue, | |||
| disabled = false, | |||
| value, | |||
| onChange, | |||
| }) => { | |||
| const { t } = useTranslation() | |||
| const { supportRetrievalMethods } = useProviderContext() | |||
| const { data: rerankDefaultModel } = useDefaultModel(ModelTypeEnum.rerank) | |||
| const value = (() => { | |||
| if (!passValue.reranking_model.reranking_model_name) { | |||
| return { | |||
| ...passValue, | |||
| reranking_model: { | |||
| reranking_provider_name: rerankDefaultModel?.provider.provider || '', | |||
| reranking_model_name: rerankDefaultModel?.model || '', | |||
| }, | |||
| reranking_mode: passValue.reranking_mode || (rerankDefaultModel ? RerankingModeEnum.RerankingModel : RerankingModeEnum.WeightedScore), | |||
| weights: passValue.weights || { | |||
| weight_type: WeightedScoreEnum.Customized, | |||
| vector_setting: { | |||
| vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic, | |||
| embedding_provider_name: '', | |||
| embedding_model_name: '', | |||
| }, | |||
| keyword_setting: { | |||
| keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword, | |||
| }, | |||
| }, | |||
| } | |||
| const { | |||
| defaultModel: rerankDefaultModel, | |||
| currentModel: isRerankDefaultModelValid, | |||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | |||
| const onSwitch = useCallback((retrieveMethod: RETRIEVE_METHOD) => { | |||
| if ([RETRIEVE_METHOD.semantic, RETRIEVE_METHOD.fullText].includes(retrieveMethod)) { | |||
| onChange({ | |||
| ...value, | |||
| search_method: retrieveMethod, | |||
| ...(!value.reranking_model.reranking_model_name | |||
| ? { | |||
| reranking_model: { | |||
| reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '', | |||
| reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '', | |||
| }, | |||
| reranking_enable: !!isRerankDefaultModelValid, | |||
| } | |||
| : { | |||
| reranking_enable: true, | |||
| }), | |||
| }) | |||
| } | |||
| return passValue | |||
| })() | |||
| if (retrieveMethod === RETRIEVE_METHOD.hybrid) { | |||
| onChange({ | |||
| ...value, | |||
| search_method: retrieveMethod, | |||
| ...(!value.reranking_model.reranking_model_name | |||
| ? { | |||
| reranking_model: { | |||
| reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '', | |||
| reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '', | |||
| }, | |||
| reranking_enable: !!isRerankDefaultModelValid, | |||
| reranking_mode: isRerankDefaultModelValid ? RerankingModeEnum.RerankingModel : RerankingModeEnum.WeightedScore, | |||
| } | |||
| : { | |||
| reranking_enable: true, | |||
| reranking_mode: RerankingModeEnum.RerankingModel, | |||
| }), | |||
| ...(!value.weights | |||
| ? { | |||
| weights: { | |||
| weight_type: WeightedScoreEnum.Customized, | |||
| vector_setting: { | |||
| vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic, | |||
| embedding_provider_name: '', | |||
| embedding_model_name: '', | |||
| }, | |||
| keyword_setting: { | |||
| keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword, | |||
| }, | |||
| }, | |||
| } | |||
| : {}), | |||
| }) | |||
| } | |||
| }, [value, rerankDefaultModel, isRerankDefaultModelValid, onChange]) | |||
| return ( | |||
| <div className='space-y-2'> | |||
| {supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && ( | |||
| <OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />} | |||
| <OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />} | |||
| title={t('dataset.retrieval.semantic_search.title')} | |||
| description={t('dataset.retrieval.semantic_search.description')} | |||
| isActive={ | |||
| value.search_method === RETRIEVE_METHOD.semantic | |||
| } | |||
| onSwitched={() => onChange({ | |||
| ...value, | |||
| search_method: RETRIEVE_METHOD.semantic, | |||
| })} | |||
| onSwitched={() => onSwitch(RETRIEVE_METHOD.semantic)} | |||
| effectImg={Effect.src} | |||
| activeHeaderClassName='bg-dataset-option-card-purple-gradient' | |||
| > | |||
| @@ -78,17 +111,14 @@ const RetrievalMethodConfig: FC<Props> = ({ | |||
| /> | |||
| </OptionCard> | |||
| )} | |||
| {supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && ( | |||
| <OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.fullText} alt='' />} | |||
| {supportRetrievalMethods.includes(RETRIEVE_METHOD.fullText) && ( | |||
| <OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.fullText} alt='' />} | |||
| title={t('dataset.retrieval.full_text_search.title')} | |||
| description={t('dataset.retrieval.full_text_search.description')} | |||
| isActive={ | |||
| value.search_method === RETRIEVE_METHOD.fullText | |||
| } | |||
| onSwitched={() => onChange({ | |||
| ...value, | |||
| search_method: RETRIEVE_METHOD.fullText, | |||
| })} | |||
| onSwitched={() => onSwitch(RETRIEVE_METHOD.fullText)} | |||
| effectImg={Effect.src} | |||
| activeHeaderClassName='bg-dataset-option-card-purple-gradient' | |||
| > | |||
| @@ -99,8 +129,8 @@ const RetrievalMethodConfig: FC<Props> = ({ | |||
| /> | |||
| </OptionCard> | |||
| )} | |||
| {supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && ( | |||
| <OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.hybrid} alt='' />} | |||
| {supportRetrievalMethods.includes(RETRIEVE_METHOD.hybrid) && ( | |||
| <OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.hybrid} alt='' />} | |||
| title={ | |||
| <div className='flex items-center space-x-1'> | |||
| <div>{t('dataset.retrieval.hybrid_search.title')}</div> | |||
| @@ -110,11 +140,7 @@ const RetrievalMethodConfig: FC<Props> = ({ | |||
| description={t('dataset.retrieval.hybrid_search.description')} isActive={ | |||
| value.search_method === RETRIEVE_METHOD.hybrid | |||
| } | |||
| onSwitched={() => onChange({ | |||
| ...value, | |||
| search_method: RETRIEVE_METHOD.hybrid, | |||
| reranking_enable: true, | |||
| })} | |||
| onSwitched={() => onSwitch(RETRIEVE_METHOD.hybrid)} | |||
| effectImg={Effect.src} | |||
| activeHeaderClassName='bg-dataset-option-card-purple-gradient' | |||
| > | |||
| @@ -1,6 +1,6 @@ | |||
| 'use client' | |||
| import type { FC } from 'react' | |||
| import React, { useCallback } from 'react' | |||
| import React, { useCallback, useMemo } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import Image from 'next/image' | |||
| @@ -39,8 +39,8 @@ const RetrievalParamConfig: FC<Props> = ({ | |||
| const { t } = useTranslation() | |||
| const canToggleRerankModalEnable = type !== RETRIEVE_METHOD.hybrid | |||
| const isEconomical = type === RETRIEVE_METHOD.invertedIndex | |||
| const isHybridSearch = type === RETRIEVE_METHOD.hybrid | |||
| const { | |||
| defaultModel: rerankDefaultModel, | |||
| modelList: rerankModelList, | |||
| } = useModelListAndDefaultModel(ModelTypeEnum.rerank) | |||
| @@ -48,35 +48,28 @@ const RetrievalParamConfig: FC<Props> = ({ | |||
| currentModel, | |||
| } = useCurrentProviderAndModel( | |||
| rerankModelList, | |||
| rerankDefaultModel | |||
| ? { | |||
| ...rerankDefaultModel, | |||
| provider: rerankDefaultModel.provider.provider, | |||
| } | |||
| : undefined, | |||
| { | |||
| provider: value.reranking_model?.reranking_provider_name ?? '', | |||
| model: value.reranking_model?.reranking_model_name ?? '', | |||
| }, | |||
| ) | |||
| const handleDisabledSwitchClick = useCallback(() => { | |||
| if (!currentModel) | |||
| const handleDisabledSwitchClick = useCallback((enable: boolean) => { | |||
| if (enable && !currentModel) | |||
| Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) | |||
| }, [currentModel, rerankDefaultModel, t]) | |||
| const isHybridSearch = type === RETRIEVE_METHOD.hybrid | |||
| onChange({ | |||
| ...value, | |||
| reranking_enable: enable, | |||
| }) | |||
| // eslint-disable-next-line react-hooks/exhaustive-deps | |||
| }, [currentModel, onChange, value]) | |||
| const rerankModel = (() => { | |||
| if (value.reranking_model) { | |||
| return { | |||
| provider_name: value.reranking_model.reranking_provider_name, | |||
| model_name: value.reranking_model.reranking_model_name, | |||
| } | |||
| } | |||
| else if (rerankDefaultModel) { | |||
| return { | |||
| provider_name: rerankDefaultModel.provider.provider, | |||
| model_name: rerankDefaultModel.model, | |||
| } | |||
| const rerankModel = useMemo(() => { | |||
| return { | |||
| provider_name: value.reranking_model.reranking_provider_name, | |||
| model_name: value.reranking_model.reranking_model_name, | |||
| } | |||
| })() | |||
| }, [value.reranking_model]) | |||
| const handleChangeRerankMode = (v: RerankingModeEnum) => { | |||
| if (v === value.reranking_mode) | |||
| @@ -100,6 +93,8 @@ const RetrievalParamConfig: FC<Props> = ({ | |||
| }, | |||
| } | |||
| } | |||
| if (v === RerankingModeEnum.RerankingModel && !currentModel) | |||
| Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) | |||
| onChange(result) | |||
| } | |||
| @@ -122,22 +117,11 @@ const RetrievalParamConfig: FC<Props> = ({ | |||
| <div> | |||
| <div className='flex items-center space-x-2 mb-2'> | |||
| {canToggleRerankModalEnable && ( | |||
| <div | |||
| className='flex items-center' | |||
| onClick={handleDisabledSwitchClick} | |||
| > | |||
| <Switch | |||
| size='md' | |||
| defaultValue={currentModel ? value.reranking_enable : false} | |||
| onChange={(v) => { | |||
| onChange({ | |||
| ...value, | |||
| reranking_enable: v, | |||
| }) | |||
| }} | |||
| disabled={!currentModel} | |||
| /> | |||
| </div> | |||
| <Switch | |||
| size='md' | |||
| defaultValue={value.reranking_enable} | |||
| onChange={handleDisabledSwitchClick} | |||
| /> | |||
| )} | |||
| <div className='flex items-center'> | |||
| <span className='mr-0.5 system-sm-semibold text-text-secondary'>{t('common.modelProvider.rerankModel.key')}</span> | |||
| @@ -148,21 +132,23 @@ const RetrievalParamConfig: FC<Props> = ({ | |||
| /> | |||
| </div> | |||
| </div> | |||
| <ModelSelector | |||
| triggerClassName={`${!value.reranking_enable && '!opacity-60 !cursor-not-allowed'}`} | |||
| defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }} | |||
| modelList={rerankModelList} | |||
| readonly={!value.reranking_enable} | |||
| onSelect={(v) => { | |||
| onChange({ | |||
| ...value, | |||
| reranking_model: { | |||
| reranking_provider_name: v.provider, | |||
| reranking_model_name: v.model, | |||
| }, | |||
| }) | |||
| }} | |||
| /> | |||
| { | |||
| value.reranking_enable && ( | |||
| <ModelSelector | |||
| defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }} | |||
| modelList={rerankModelList} | |||
| onSelect={(v) => { | |||
| onChange({ | |||
| ...value, | |||
| reranking_model: { | |||
| reranking_provider_name: v.provider, | |||
| reranking_model_name: v.model, | |||
| }, | |||
| }) | |||
| }} | |||
| /> | |||
| ) | |||
| } | |||
| </div> | |||
| )} | |||
| { | |||
| @@ -255,10 +241,8 @@ const RetrievalParamConfig: FC<Props> = ({ | |||
| { | |||
| value.reranking_mode !== RerankingModeEnum.WeightedScore && ( | |||
| <ModelSelector | |||
| triggerClassName={`${!value.reranking_enable && '!opacity-60 !cursor-not-allowed'}`} | |||
| defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }} | |||
| modelList={rerankModelList} | |||
| readonly={!value.reranking_enable} | |||
| onSelect={(v) => { | |||
| onChange({ | |||
| ...value, | |||
| @@ -30,6 +30,7 @@ import { useProviderContext } from '@/context/provider-context' | |||
| import { sleep } from '@/utils' | |||
| import { RETRIEVE_METHOD } from '@/types/app' | |||
| import Tooltip from '@/app/components/base/tooltip' | |||
| import { useInvalidDocumentList } from '@/service/knowledge/use-document' | |||
| type Props = { | |||
| datasetId: string | |||
| @@ -207,7 +208,9 @@ const EmbeddingProcess: FC<Props> = ({ datasetId, batchId, documents = [], index | |||
| }) | |||
| const router = useRouter() | |||
| const invalidDocumentList = useInvalidDocumentList() | |||
| const navToDocumentList = () => { | |||
| invalidDocumentList() | |||
| router.push(`/datasets/${datasetId}/documents`) | |||
| } | |||
| const navToApiDocs = () => { | |||
| @@ -31,17 +31,17 @@ import LanguageSelect from './language-select' | |||
| import { DelimiterInput, MaxLengthInput, OverlapInput } from './inputs' | |||
| import cn from '@/utils/classnames' | |||
| import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, DocumentItem, FullDocumentDetail, ParentMode, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets' | |||
| import { ChunkingMode, DataSourceType, ProcessMode } from '@/models/datasets' | |||
| import Button from '@/app/components/base/button' | |||
| import FloatRightContainer from '@/app/components/base/float-right-container' | |||
| import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' | |||
| import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config' | |||
| import { type RetrievalConfig } from '@/types/app' | |||
| import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' | |||
| import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' | |||
| import Toast from '@/app/components/base/toast' | |||
| import type { NotionPage } from '@/models/common' | |||
| import { DataSourceProvider } from '@/models/common' | |||
| import { ChunkingMode, DataSourceType, RerankingModeEnum } from '@/models/datasets' | |||
| import { useDatasetDetailContext } from '@/context/dataset-detail' | |||
| import I18n from '@/context/i18n' | |||
| import { RETRIEVE_METHOD } from '@/types/app' | |||
| @@ -53,7 +53,7 @@ import type { DefaultModel } from '@/app/components/header/account-setting/model | |||
| import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| import Checkbox from '@/app/components/base/checkbox' | |||
| import RadioCard from '@/app/components/base/radio-card' | |||
| import { IS_CE_EDITION } from '@/config' | |||
| import { FULL_DOC_PREVIEW_LENGTH, IS_CE_EDITION } from '@/config' | |||
| import Divider from '@/app/components/base/divider' | |||
| import { getNotionInfo, getWebsiteInfo, useCreateDocument, useCreateFirstDocument, useFetchDefaultProcessRule, useFetchFileIndexingEstimateForFile, useFetchFileIndexingEstimateForNotion, useFetchFileIndexingEstimateForWeb } from '@/service/knowledge/use-create-dataset' | |||
| import Badge from '@/app/components/base/badge' | |||
| @@ -90,17 +90,13 @@ type StepTwoProps = { | |||
| onCancel?: () => void | |||
| } | |||
| export enum SegmentType { | |||
| AUTO = 'automatic', | |||
| CUSTOM = 'custom', | |||
| } | |||
| export enum IndexingType { | |||
| QUALIFIED = 'high_quality', | |||
| ECONOMICAL = 'economy', | |||
| } | |||
| const DEFAULT_SEGMENT_IDENTIFIER = '\\n\\n' | |||
| const DEFAULT_MAXMIMUM_CHUNK_LENGTH = 500 | |||
| const DEFAULT_MAXIMUM_CHUNK_LENGTH = 500 | |||
| const DEFAULT_OVERLAP = 50 | |||
| type ParentChildConfig = { | |||
| @@ -131,7 +127,6 @@ const StepTwo = ({ | |||
| isSetting, | |||
| documentDetail, | |||
| isAPIKeySet, | |||
| onSetting, | |||
| datasetId, | |||
| indexingType, | |||
| dataSourceType: inCreatePageDataSourceType, | |||
| @@ -162,12 +157,12 @@ const StepTwo = ({ | |||
| const isInCreatePage = !datasetId || (datasetId && !currentDataset?.data_source_type) | |||
| const dataSourceType = isInCreatePage ? inCreatePageDataSourceType : currentDataset?.data_source_type | |||
| const [segmentationType, setSegmentationType] = useState<SegmentType>(SegmentType.CUSTOM) | |||
| const [segmentationType, setSegmentationType] = useState<ProcessMode>(ProcessMode.general) | |||
| const [segmentIdentifier, doSetSegmentIdentifier] = useState(DEFAULT_SEGMENT_IDENTIFIER) | |||
| const setSegmentIdentifier = useCallback((value: string, canEmpty?: boolean) => { | |||
| doSetSegmentIdentifier(value ? escape(value) : (canEmpty ? '' : DEFAULT_SEGMENT_IDENTIFIER)) | |||
| }, []) | |||
| const [maxChunkLength, setMaxChunkLength] = useState(DEFAULT_MAXMIMUM_CHUNK_LENGTH) // default chunk length | |||
| const [maxChunkLength, setMaxChunkLength] = useState(DEFAULT_MAXIMUM_CHUNK_LENGTH) // default chunk length | |||
| const [limitMaxChunkLength, setLimitMaxChunkLength] = useState(4000) | |||
| const [overlap, setOverlap] = useState(DEFAULT_OVERLAP) | |||
| const [rules, setRules] = useState<PreProcessingRule[]>([]) | |||
| @@ -198,7 +193,6 @@ const StepTwo = ({ | |||
| ) | |||
| // QA Related | |||
| const [isLanguageSelectDisabled, _setIsLanguageSelectDisabled] = useState(false) | |||
| const [isQAConfirmDialogOpen, setIsQAConfirmDialogOpen] = useState(false) | |||
| const [docForm, setDocForm] = useState<ChunkingMode>( | |||
| (datasetId && documentDetail) ? documentDetail.doc_form as ChunkingMode : ChunkingMode.text, | |||
| @@ -348,7 +342,7 @@ const StepTwo = ({ | |||
| } | |||
| const updatePreview = () => { | |||
| if (segmentationType === SegmentType.CUSTOM && maxChunkLength > 4000) { | |||
| if (segmentationType === ProcessMode.general && maxChunkLength > 4000) { | |||
| Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck') }) | |||
| return | |||
| } | |||
| @@ -373,13 +367,42 @@ const StepTwo = ({ | |||
| model: defaultEmbeddingModel?.model || '', | |||
| }, | |||
| ) | |||
| const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict || { | |||
| search_method: RETRIEVE_METHOD.semantic, | |||
| reranking_enable: false, | |||
| reranking_model: { | |||
| reranking_provider_name: '', | |||
| reranking_model_name: '', | |||
| }, | |||
| top_k: 3, | |||
| score_threshold_enabled: false, | |||
| score_threshold: 0.5, | |||
| } as RetrievalConfig) | |||
| useEffect(() => { | |||
| if (currentDataset?.retrieval_model_dict) | |||
| return | |||
| setRetrievalConfig({ | |||
| search_method: RETRIEVE_METHOD.semantic, | |||
| reranking_enable: !!isRerankDefaultModelValid, | |||
| reranking_model: { | |||
| reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider.provider ?? '' : '', | |||
| reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '', | |||
| }, | |||
| top_k: 3, | |||
| score_threshold_enabled: false, | |||
| score_threshold: 0.5, | |||
| }) | |||
| // eslint-disable-next-line react-hooks/exhaustive-deps | |||
| }, [rerankDefaultModel, isRerankDefaultModelValid]) | |||
| const getCreationParams = () => { | |||
| let params | |||
| if (segmentationType === SegmentType.CUSTOM && overlap > maxChunkLength) { | |||
| if (segmentationType === ProcessMode.general && overlap > maxChunkLength) { | |||
| Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.overlapCheck') }) | |||
| return | |||
| } | |||
| if (segmentationType === SegmentType.CUSTOM && maxChunkLength > limitMaxChunkLength) { | |||
| if (segmentationType === ProcessMode.general && maxChunkLength > limitMaxChunkLength) { | |||
| Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck', { limit: limitMaxChunkLength }) }) | |||
| return | |||
| } | |||
| @@ -389,7 +412,6 @@ const StepTwo = ({ | |||
| doc_form: currentDocForm, | |||
| doc_language: docLanguage, | |||
| process_rule: getProcessRule(), | |||
| // eslint-disable-next-line @typescript-eslint/no-use-before-define | |||
| retrieval_model: retrievalConfig, // Readonly. If want to changed, just go to settings page. | |||
| embedding_model: embeddingModel.model, // Readonly | |||
| embedding_model_provider: embeddingModel.provider, // Readonly | |||
| @@ -400,10 +422,7 @@ const StepTwo = ({ | |||
| const indexMethod = getIndexing_technique() | |||
| if ( | |||
| !isReRankModelSelected({ | |||
| rerankDefaultModel, | |||
| isRerankDefaultModelValid: !!isRerankDefaultModelValid, | |||
| rerankModelList, | |||
| // eslint-disable-next-line @typescript-eslint/no-use-before-define | |||
| retrievalConfig, | |||
| indexMethod: indexMethod as string, | |||
| }) | |||
| @@ -411,16 +430,6 @@ const StepTwo = ({ | |||
| Toast.notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') }) | |||
| return | |||
| } | |||
| const postRetrievalConfig = ensureRerankModelSelected({ | |||
| rerankDefaultModel: rerankDefaultModel!, | |||
| retrievalConfig: { | |||
| // eslint-disable-next-line @typescript-eslint/no-use-before-define | |||
| ...retrievalConfig, | |||
| // eslint-disable-next-line @typescript-eslint/no-use-before-define | |||
| reranking_enable: retrievalConfig.reranking_mode === RerankingModeEnum.RerankingModel, | |||
| }, | |||
| indexMethod: indexMethod as string, | |||
| }) | |||
| params = { | |||
| data_source: { | |||
| type: dataSourceType, | |||
| @@ -432,8 +441,7 @@ const StepTwo = ({ | |||
| process_rule: getProcessRule(), | |||
| doc_form: currentDocForm, | |||
| doc_language: docLanguage, | |||
| retrieval_model: postRetrievalConfig, | |||
| retrieval_model: retrievalConfig, | |||
| embedding_model: embeddingModel.model, | |||
| embedding_model_provider: embeddingModel.provider, | |||
| } as CreateDocumentReq | |||
| @@ -490,7 +498,6 @@ const StepTwo = ({ | |||
| const getDefaultMode = () => { | |||
| if (documentDetail) | |||
| // @ts-expect-error fix after api refactored | |||
| setSegmentationType(documentDetail.dataset_process_rule.mode) | |||
| } | |||
| @@ -525,7 +532,6 @@ const StepTwo = ({ | |||
| onSuccess(data) { | |||
| updateIndexingTypeCache && updateIndexingTypeCache(indexType as string) | |||
| updateResultCache && updateResultCache(data) | |||
| // eslint-disable-next-line @typescript-eslint/no-use-before-define | |||
| updateRetrievalMethodCache && updateRetrievalMethodCache(retrievalConfig.search_method as string) | |||
| }, | |||
| }, | |||
| @@ -545,14 +551,6 @@ const StepTwo = ({ | |||
| isSetting && onSave && onSave() | |||
| } | |||
| const changeToEconomicalType = () => { | |||
| if (docForm !== ChunkingMode.text) | |||
| return | |||
| if (!hasSetIndexType) | |||
| setIndexType(IndexingType.ECONOMICAL) | |||
| } | |||
| useEffect(() => { | |||
| // fetch rules | |||
| if (!isSetting) { | |||
| @@ -574,18 +572,6 @@ const StepTwo = ({ | |||
| setIndexType(isAPIKeySet ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL) | |||
| }, [isAPIKeySet, indexingType, datasetId]) | |||
| const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict || { | |||
| search_method: RETRIEVE_METHOD.semantic, | |||
| reranking_enable: false, | |||
| reranking_model: { | |||
| reranking_provider_name: rerankDefaultModel?.provider.provider, | |||
| reranking_model_name: rerankDefaultModel?.model, | |||
| }, | |||
| top_k: 3, | |||
| score_threshold_enabled: false, | |||
| score_threshold: 0.5, | |||
| } as RetrievalConfig) | |||
| const economyDomRef = useRef<HTMLDivElement>(null) | |||
| const isHoveringEconomy = useHover(economyDomRef) | |||
| @@ -946,6 +932,7 @@ const StepTwo = ({ | |||
| <div className={cn('system-md-semibold mb-1', datasetId && 'flex justify-between items-center')}>{t('datasetSettings.form.embeddingModel')}</div> | |||
| <ModelSelector | |||
| readonly={!!datasetId} | |||
| triggerClassName={datasetId ? 'opacity-50' : ''} | |||
| defaultModel={embeddingModel} | |||
| modelList={embeddingModelList} | |||
| onSelect={(model: DefaultModel) => { | |||
| @@ -984,12 +971,14 @@ const StepTwo = ({ | |||
| getIndexing_technique() === IndexingType.QUALIFIED | |||
| ? ( | |||
| <RetrievalMethodConfig | |||
| disabled={!!datasetId} | |||
| value={retrievalConfig} | |||
| onChange={setRetrievalConfig} | |||
| /> | |||
| ) | |||
| : ( | |||
| <EconomicalRetrievalMethodConfig | |||
| disabled={!!datasetId} | |||
| value={retrievalConfig} | |||
| onChange={setRetrievalConfig} | |||
| /> | |||
| @@ -1010,7 +999,7 @@ const StepTwo = ({ | |||
| ) | |||
| : ( | |||
| <div className='flex items-center mt-8 py-2'> | |||
| <Button loading={isCreating} variant='primary' onClick={createHandle}>{t('datasetCreation.stepTwo.save')}</Button> | |||
| {!datasetId && <Button loading={isCreating} variant='primary' onClick={createHandle}>{t('datasetCreation.stepTwo.save')}</Button>} | |||
| <Button className='ml-2' onClick={onCancel}>{t('datasetCreation.stepTwo.cancel')}</Button> | |||
| </div> | |||
| )} | |||
| @@ -1081,11 +1070,11 @@ const StepTwo = ({ | |||
| } | |||
| { | |||
| currentDocForm !== ChunkingMode.qa | |||
| && <Badge text={t( | |||
| 'datasetCreation.stepTwo.previewChunkCount', { | |||
| count: estimate?.total_segments || 0, | |||
| }) as string} | |||
| /> | |||
| } | |||
| </div> | |||
| </PreviewHeader>} | |||
| @@ -1117,6 +1106,9 @@ const StepTwo = ({ | |||
| {currentDocForm === ChunkingMode.parentChild && currentEstimateMutation.data?.preview && ( | |||
| estimate?.preview?.map((item, index) => { | |||
| const indexForLabel = index + 1 | |||
| const childChunks = parentChildConfig.chunkForContext === 'full-doc' | |||
| ? item.child_chunks.slice(0, FULL_DOC_PREVIEW_LENGTH) | |||
| : item.child_chunks | |||
| return ( | |||
| <ChunkContainer | |||
| key={item.content} | |||
| @@ -1124,7 +1116,7 @@ const StepTwo = ({ | |||
| characterCount={item.content.length} | |||
| > | |||
| <FormattedText> | |||
| {item.child_chunks.map((child, index) => { | |||
| {childChunks.map((child, index) => { | |||
| const indexForLabel = index + 1 | |||
| return ( | |||
| <PreviewSlice | |||
| @@ -4,7 +4,7 @@ import classNames from '@/utils/classnames' | |||
| const TriangleArrow: FC<ComponentProps<'svg'>> = props => ( | |||
| <svg xmlns="http://www.w3.org/2000/svg" width="24" height="11" viewBox="0 0 24 11" fill="none" {...props}> | |||
| <path d="M9.87868 1.12132C11.0503 -0.0502525 12.9497 -0.0502525 14.1213 1.12132L23.3137 10.3137H0.686292L9.87868 1.12132Z" fill="currentColor"/> | |||
| <path d="M9.87868 1.12132C11.0503 -0.0502525 12.9497 -0.0502525 14.1213 1.12132L23.3137 10.3137H0.686292L9.87868 1.12132Z" fill="currentColor" /> | |||
| </svg> | |||
| ) | |||
| @@ -65,7 +65,7 @@ export const OptionCard: FC<OptionCardProps> = forwardRef((props, ref) => { | |||
| (isActive && !noHighlight) | |||
| ? 'border-[1.5px] border-components-option-card-option-selected-border' | |||
| : 'border border-components-option-card-option-border', | |||
| disabled && 'opacity-50 cursor-not-allowed', | |||
| disabled && 'opacity-50 pointer-events-none', | |||
| className, | |||
| )} | |||
| style={{ | |||