Browse Source

Merge branch 'main' into fix/chore-fix

tags/1.0.0-beta.1
Yeuoly 9 months ago
Parent
Current Commit
fb309462ad
100 changed files with 2805 additions and 1931 deletions
  1. api/.env.example (+3, -0)
  2. api/app.py (+4, -1)
  3. api/configs/feature/__init__.py (+10, -0)
  4. api/configs/middleware/vdb/milvus_config.py (+6, -0)
  5. api/configs/packaging/__init__.py (+1, -1)
  6. api/controllers/console/app/app.py (+2, -1)
  7. api/controllers/console/app/completion.py (+2, -3)
  8. api/controllers/console/app/statistic.py (+1, -2)
  9. api/controllers/console/datasets/datasets.py (+2, -0)
  10. api/controllers/console/datasets/datasets_document.py (+2, -1)
  11. api/controllers/console/explore/completion.py (+5, -1)
  12. api/controllers/console/explore/workflow.py (+5, -1)
  13. api/controllers/service_api/app/completion.py (+2, -3)
  14. api/controllers/service_api/app/workflow.py (+1, -2)
  15. api/controllers/service_api/dataset/document.py (+8, -2)
  16. api/controllers/service_api/wraps.py (+22, -15)
  17. api/controllers/web/completion.py (+5, -1)
  18. api/controllers/web/workflow.py (+5, -1)
  19. api/core/app/apps/advanced_chat/app_generator.py (+2, -2)
  20. api/core/app/apps/advanced_chat/generate_task_pipeline.py (+140, -109)
  21. api/core/app/apps/agent_chat/app_generator.py (+2, -2)
  22. api/core/app/apps/chat/app_generator.py (+2, -2)
  23. api/core/app/apps/completion/app_generator.py (+2, -2)
  24. api/core/app/apps/workflow/app_generator.py (+3, -2)
  25. api/core/app/apps/workflow/generate_task_pipeline.py (+89, -72)
  26. api/core/app/entities/app_invoke_entities.py (+1, -1)
  27. api/core/app/task_pipeline/based_generate_task_pipeline.py (+0, -11)
  28. api/core/app/task_pipeline/message_cycle_manage.py (+13, -4)
  29. api/core/app/task_pipeline/workflow_cycle_manage.py (+37, -23)
  30. api/core/app/task_pipeline/workflow_cycle_state_manager.py (+0, -0)
  31. api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py (+42, -4)
  32. api/core/rag/datasource/vdb/baidu/baidu_vector.py (+2, -0)
  33. api/core/rag/datasource/vdb/chroma/chroma_vector.py (+2, -0)
  34. api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py (+104, -0)
  35. api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py (+2, -0)
  36. api/core/rag/datasource/vdb/field.py (+2, -0)
  37. api/core/rag/datasource/vdb/milvus/milvus_vector.py (+171, -29)
  38. api/core/rag/datasource/vdb/myscale/myscale_vector.py (+2, -0)
  39. api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py (+2, -0)
  40. api/core/rag/datasource/vdb/oracle/oraclevector.py (+2, -0)
  41. api/core/rag/datasource/vdb/pgvector/pgvector.py (+5, -0)
  42. api/core/rag/datasource/vdb/tencent/tencent_vector.py (+2, -0)
  43. api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py (+18, -19)
  44. api/core/rag/datasource/vdb/vector_factory.py (+6, -0)
  45. api/core/rag/datasource/vdb/vector_type.py (+1, -0)
  46. api/core/rag/extractor/pdf_extractor.py (+2, -3)
  47. api/core/rag/index_processor/processor/parent_child_index_processor.py (+5, -0)
  48. api/core/tools/custom_tool/tool.py (+17, -2)
  49. api/core/workflow/nodes/document_extractor/node.py (+61, -19)
  50. api/core/workflow/nodes/http_request/entities.py (+16, -1)
  51. api/core/workflow/nodes/http_request/executor.py (+43, -10)
  52. api/core/workflow/workflow_entry.py (+4, -0)
  53. api/docker/entrypoint.sh (+1, -0)
  54. api/extensions/ext_logging.py (+1, -1)
  55. api/factories/file_factory.py (+1, -1)
  56. api/libs/oauth_data_source.py (+2, -1)
  57. api/migrations/versions/2025_01_01_2000-a91b476a53de_change_workflow_runs_total_tokens_to_.py (+41, -0)
  58. api/models/workflow.py (+2, -2)
  59. api/poetry.lock (+1247, -1168)
  60. api/pyproject.toml (+4, -4)
  61. api/schedule/clean_unused_datasets_task.py (+0, -17)
  62. api/services/account_service.py (+1, -1)
  63. api/services/app_dsl_service.py (+9, -4)
  64. api/services/app_service.py (+4, -1)
  65. api/services/billing_service.py (+4, -4)
  66. api/services/dataset_service.py (+26, -10)
  67. api/services/workflow_app_service.py (+1, -1)
  68. api/tasks/deal_dataset_vector_index_task.py (+4, -1)
  69. api/tests/integration_tests/model_runtime/gpustack/test_speech2text.py (+55, -0)
  70. api/tests/integration_tests/model_runtime/gpustack/test_tts.py (+24, -0)
  71. api/tests/integration_tests/vdb/milvus/test_milvus.py (+2, -2)
  72. docker-legacy/docker-compose.yaml (+3, -3)
  73. docker/.env.example (+15, -4)
  74. docker/docker-compose-template.yaml (+17, -26)
  75. docker/docker-compose.yaml (+33, -30)
  76. docker/elasticsearch/docker-entrypoint.sh (+25, -0)
  77. web/.env.example (+3, -0)
  78. web/app/(commonLayout)/apps/Apps.tsx (+11, -2)
  79. web/app/(commonLayout)/datasets/template/template.en.mdx (+38, -1)
  80. web/app/(commonLayout)/datasets/template/template.zh.mdx (+40, -3)
  81. web/app/components/app/configuration/config-prompt/prompt-editor-height-resize-wrap.tsx (+4, -3)
  82. web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx (+17, -1)
  83. web/app/components/app/configuration/dataset-config/params-config/config-content.tsx (+49, -70)
  84. web/app/components/app/configuration/dataset-config/params-config/index.tsx (+19, -18)
  85. web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx (+2, -2)
  86. web/app/components/app/configuration/dataset-config/settings-modal/index.tsx (+5, -15)
  87. web/app/components/app/configuration/index.tsx (+3, -3)
  88. web/app/components/base/chat/chat/chat-input-area/index.tsx (+9, -1)
  89. web/app/components/base/chat/chat/index.tsx (+1, -0)
  90. web/app/components/base/chat/chat/question.tsx (+2, -2)
  91. web/app/components/base/markdown.tsx (+2, -2)
  92. web/app/components/base/param-item/top-k-item.tsx (+7, -1)
  93. web/app/components/datasets/common/check-rerank-model.ts (+8, -7)
  94. web/app/components/datasets/common/economical-retrieval-method-config/index.tsx (+4, -1)
  95. web/app/components/datasets/common/retrieval-method-config/index.tsx (+71, -45)
  96. web/app/components/datasets/common/retrieval-param-config/index.tsx (+43, -59)
  97. web/app/components/datasets/create/embedding-process/index.tsx (+3, -0)
  98. web/app/components/datasets/create/step-two/index.tsx (+52, -60)
  99. web/app/components/datasets/create/step-two/option-card.tsx (+2, -2)
  100. web/app/components/datasets/documents/detail/completed/index.tsx (+0, -0)

+ 3 - 0  api/.env.example  View File

# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60

+# Refresh token expiration time in days
+REFRESH_TOKEN_EXPIRE_DAYS=30

# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1



+ 4 - 1  api/app.py  View File

app = create_migrations_app()
else:
-if os.environ.get("FLASK_DEBUG", "False") != "True":
+# It seems that JetBrains Python debugger does not work well with gevent,
+# so we need to disable gevent in debug mode.
+# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
+if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
from gevent import monkey  # type: ignore

# gevent
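A minimal, self-contained sketch of what the new guard does (the surrounding app.py scaffolding such as create_migrations_app() is omitted, and the monkey.patch_all() call is assumed from the usual gevent setup rather than visible in the excerpt above):

import os

# Patch gevent only when FLASK_DEBUG is explicitly "false"/"0"/"no", so that
# debuggers that conflict with gevent (e.g. the JetBrains debugger) keep working.
flask_debug = os.environ.get("FLASK_DEBUG", "0")
if flask_debug and flask_debug.lower() in {"false", "0", "no"}:
    from gevent import monkey  # type: ignore

    monkey.patch_all()  # assumed: the standard gevent patching that follows the "# gevent" comment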

+ 10 - 0  api/configs/feature/__init__.py  View File

default=60,
)

+REFRESH_TOKEN_EXPIRE_DAYS: PositiveFloat = Field(
+description="Expiration time for refresh tokens in days",
+default=30,
+)

LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
default=86400,
default=4000,
)

+CHILD_CHUNKS_PREVIEW_NUMBER: PositiveInt = Field(
+description="Maximum number of child chunks to preview",
+default=50,
+)


class MultiModalTransferConfig(BaseSettings):
MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
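For context, a hedged sketch of how minute/day expiry settings like these are typically consumed when issuing tokens; the function and constants below are illustrative, not taken from the diff:

from datetime import datetime, timedelta, timezone

ACCESS_TOKEN_EXPIRE_MINUTES = 60   # mirrors the existing default above
REFRESH_TOKEN_EXPIRE_DAYS = 30     # mirrors the new default above

def token_expiries(now: datetime | None = None) -> tuple[datetime, datetime]:
    # Convert the two settings into absolute expiry timestamps.
    now = now or datetime.now(timezone.utc)
    return (
        now + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
        now + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS),
    )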

+ 6 - 0  api/configs/middleware/vdb/milvus_config.py  View File

description="Name of the Milvus database to connect to (default is 'default')",
default="default",
)

+MILVUS_ENABLE_HYBRID_SEARCH: bool = Field(
+description="Enable hybrid search features (requires Milvus >= 2.5.0). Set to false for compatibility with "
+"older versions",
+default=True,
+)
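Purely as an illustration of the intent behind the new flag (none of the names below come from milvus_vector.py): a retrieval wrapper would typically check such a setting and fall back to dense-only search when it is disabled for pre-2.5.0 servers.

from dataclasses import dataclass

@dataclass
class MilvusSettings:
    # Hypothetical mirror of MILVUS_ENABLE_HYBRID_SEARCH for this sketch.
    enable_hybrid_search: bool = True

def choose_search_mode(settings: MilvusSettings) -> str:
    # Hybrid (dense + full-text) search needs Milvus >= 2.5.0; otherwise stay dense-only.
    return "hybrid" if settings.enable_hybrid_search else "dense_only"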

+ 1 - 1  api/configs/packaging/__init__.py  View File

CURRENT_VERSION: str = Field(
description="Dify version",
-default="0.14.2",
+default="0.15.0",
)

COMMIT_SHA: str = Field(

+ 2 - 1  api/controllers/console/app/app.py  View File

)
parser.add_argument("name", type=str, location="args", required=False)
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
+parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)

args = parser.parse_args()

# get app list
app_service = AppService()
-app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
+app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}



+ 2 - 3  api/controllers/console/app/completion.py  View File

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
-AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
-except (ValueError, AppInvokeQuotaExceededError) as e:
+except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InvokeRateLimitHttpError(ex.description)
except InvokeError as e:
raise CompletionRequestError(e.description)
-except (ValueError, AppInvokeQuotaExceededError) as e:
+except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")

+ 1 - 2  api/controllers/console/app/statistic.py  View File

messages m
ON c.id = m.conversation_id
WHERE
-c.override_model_configs IS NULL
-AND c.app_id = :app_id"""
+c.app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}

timezone = pytz.timezone(account.timezone)

+ 2 - 0  api/controllers/console/datasets/datasets.py  View File

| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
+| VectorType.ELASTICSEARCH_JA
| VectorType.PGVECTOR
| VectorType.TIDB_ON_QDRANT
| VectorType.LINDORM
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
+| VectorType.ELASTICSEARCH_JA
| VectorType.COUCHBASE
| VectorType.PGVECTOR
| VectorType.LINDORM

+ 2 - 1  api/controllers/console/datasets/datasets_document.py  View File

parser.add_argument("original_document_id", type=str, required=False, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
+parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
+parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
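For illustration, a request body to this console endpoint could now carry the two new optional fields; only the field names come from the parser arguments above, every value below is hypothetical:

# Hypothetical JSON body for the document-create request; field names match the
# parser arguments, field values are made up for the example.
payload = {
    "doc_form": "text_model",
    "doc_language": "English",
    "retrieval_model": {},
    "embedding_model": "some-embedding-model",
    "embedding_model_provider": "some-provider",
}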

+ 5 - 1  api/controllers/console/explore/completion.py  View File

from controllers.console.explore.wraps import InstalledAppResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+ModelCurrentlyNotSupportError,
+ProviderTokenNotInitError,
+QuotaExceededError,
+)
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs import helper

+ 5 - 1  api/controllers/console/explore/workflow.py  View File

from controllers.console.explore.wraps import InstalledAppResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+ModelCurrentlyNotSupportError,
+ProviderTokenNotInitError,
+QuotaExceededError,
+)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.login import current_user

+ 2 - 3  api/controllers/service_api/app/completion.py  View File

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
-AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
-except (ValueError, AppInvokeQuotaExceededError) as e:
+except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
-except (ValueError, AppInvokeQuotaExceededError) as e:
+except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")

+ 1 - 2  api/controllers/service_api/app/workflow.py  View File

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
-AppInvokeQuotaExceededError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
-except (ValueError, AppInvokeQuotaExceededError) as e:
+except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")

+ 8 - 2  api/controllers/service_api/dataset/document.py  View File

user=current_user,
source="datasets",
)
-data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
+data_source = {
+"type": "upload_file",
+"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
+}
args["data_source"] = data_source
# validate args
knowledge_config = KnowledgeConfig(**args)
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
-data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
+data_source = {
+"type": "upload_file",
+"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
+}
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)

+ 22 - 15  api/controllers/service_api/wraps.py  View File

from collections.abc import Callable
-from datetime import UTC, datetime
+from datetime import UTC, datetime, timedelta
from enum import Enum
from functools import wraps
from typing import Optional
from flask_login import user_logged_in  # type: ignore
from flask_restful import Resource  # type: ignore
from pydantic import BaseModel
+from sqlalchemy import select, update
+from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, Unauthorized

from extensions.ext_database import db
return decorator


-def validate_and_get_api_token(scope=None):
+def validate_and_get_api_token(scope: str | None = None):
"""
Validate and get API token.
"""
if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'")

-api_token = (
-db.session.query(ApiToken)
-.filter(
-ApiToken.token == auth_token,
-ApiToken.type == scope,
+current_time = datetime.now(UTC).replace(tzinfo=None)
+cutoff_time = current_time - timedelta(minutes=1)
+with Session(db.engine, expire_on_commit=False) as session:
+update_stmt = (
+update(ApiToken)
+.where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
+.values(last_used_at=current_time)
+.returning(ApiToken)
)
-.first()
-)

-if not api_token:
-raise Unauthorized("Access token is invalid")

-api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
-db.session.commit()
+result = session.execute(update_stmt)
+api_token = result.scalar_one_or_none()

+if not api_token:
+stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
+api_token = session.scalar(stmt)
+if not api_token:
+raise Unauthorized("Access token is invalid")
+else:
+session.commit()

return api_token


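The rewritten helper throttles writes to last_used_at: a single UPDATE ... RETURNING only touches tokens last used more than a minute ago, and a plain SELECT is the read-only fallback before rejecting the request. A self-contained sketch of that pattern (engine, ApiToken, auth_token and scope are stand-ins for the objects used in wraps.py):

from datetime import UTC, datetime, timedelta

from sqlalchemy import select, update
from sqlalchemy.orm import Session


def touch_and_fetch_token(engine, ApiToken, auth_token: str, scope: str | None):
    current_time = datetime.now(UTC).replace(tzinfo=None)
    cutoff_time = current_time - timedelta(minutes=1)
    with Session(engine, expire_on_commit=False) as session:
        # Write last_used_at at most once per minute per token.
        stmt = (
            update(ApiToken)
            .where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
            .values(last_used_at=current_time)
            .returning(ApiToken)
        )
        api_token = session.execute(stmt).scalar_one_or_none()
        if api_token is None:
            # Nothing updated: either the token is invalid or it was used very recently,
            # so fall back to a read-only lookup before rejecting the request.
            api_token = session.scalar(
                select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
            )
            if api_token is None:
                raise PermissionError("Access token is invalid")  # Unauthorized in the real code
        else:
            session.commit()
        return api_token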

+ 5 - 1  api/controllers/web/completion.py  View File

from controllers.web.wraps import WebApiResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+ModelCurrentlyNotSupportError,
+ProviderTokenNotInitError,
+QuotaExceededError,
+)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value

+ 5 - 1  api/controllers/web/workflow.py  View File

from controllers.web.wraps import WebApiResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from core.errors.error import (
+ModelCurrentlyNotSupportError,
+ProviderTokenNotInitError,
+QuotaExceededError,
+)
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from models.model import App, AppMode, EndUser

+ 2 - 2  api/core/app/apps/advanced_chat/app_generator.py  View File

from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
-from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from extensions.ext_database import db
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
-except (ValueError, InvokeError) as e:
+except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

+ 140 - 109  api/core/app/apps/advanced_chat/generate_task_pipeline.py  View File

from models.enums import CreatedByRole
from models.workflow import (
Workflow,
-WorkflowNodeExecution,
WorkflowRunStatus,
)

logger = logging.getLogger(__name__)


-class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage):
+class AdvancedChatAppGenerateTaskPipeline:
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""

-_task_state: WorkflowTaskState
-_application_generate_entity: AdvancedChatAppGenerateEntity
-_workflow_system_variables: dict[SystemVariableKey, Any]
-_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
-_conversation_name_generate_thread: Optional[Thread] = None

def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
stream: bool,
dialogue_count: int,
) -> None:
-super().__init__(
+self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream,
else:
raise NotImplementedError(f"User type not supported: {type(user)}")

+self._workflow_cycle_manager = WorkflowCycleManage(
+application_generate_entity=application_generate_entity,
+workflow_system_variables={
+SystemVariableKey.QUERY: message.query,
+SystemVariableKey.FILES: application_generate_entity.files,
+SystemVariableKey.CONVERSATION_ID: conversation.id,
+SystemVariableKey.USER_ID: user_session_id,
+SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
+SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
+SystemVariableKey.WORKFLOW_ID: workflow.id,
+SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
+},
+)

+self._task_state = WorkflowTaskState()
+self._message_cycle_manager = MessageCycleManage(
+application_generate_entity=application_generate_entity, task_state=self._task_state
+)

+self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict

self._conversation_id = conversation.id
self._conversation_mode = conversation.mode

self._message_id = message.id
self._message_created_at = int(message.created_at.timestamp())

-self._workflow_system_variables = {
-SystemVariableKey.QUERY: message.query,
-SystemVariableKey.FILES: application_generate_entity.files,
-SystemVariableKey.CONVERSATION_ID: conversation.id,
-SystemVariableKey.USER_ID: user_session_id,
-SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
-SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
-SystemVariableKey.WORKFLOW_ID: workflow.id,
-SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
-}

-self._task_state = WorkflowTaskState()
-self._wip_workflow_node_executions = {}
-self._wip_workflow_agent_logs = {}

-self._conversation_name_generate_thread = None
+self._conversation_name_generate_thread: Thread | None = None
self._recorded_files: list[Mapping[str, Any]] = []
-self._workflow_run_id = ""
+self._workflow_run_id: str = ""


def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
:return:
"""
# start generate conversation name thread
-self._conversation_name_generate_thread = self._generate_conversation_name(
+self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query
)

generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)

-if self._stream:
+if self._base_task_pipeline._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
# init fake graph runtime state
graph_runtime_state: Optional[GraphRuntimeState] = None

-for queue_message in self._queue_manager.listen():
+for queue_message in self._base_task_pipeline._queue_manager.listen():
event = queue_message.event

if isinstance(event, QueuePingEvent):
-yield self._ping_stream_response()
+yield self._base_task_pipeline._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
-with Session(db.engine) as session:
-err = self._handle_error(event=event, session=session, message_id=self._message_id)
+with Session(db.engine, expire_on_commit=False) as session:
+err = self._base_task_pipeline._handle_error(
+event=event, session=session, message_id=self._message_id
+)
session.commit()
-yield self._error_to_stream_response(err)
+yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state

-with Session(db.engine) as session:
+with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
-workflow_run = self._handle_workflow_run_start(
+workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
user_id=self._user_id,
if not message:
raise ValueError(f"Message not found: {self._message_id}")
message.workflow_run_id = workflow_run.id
-workflow_start_resp = self._workflow_start_to_stream_response(
+workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

-with Session(db.engine) as session:
-workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-workflow_node_execution = self._handle_workflow_node_execution_retried(
+with Session(db.engine, expire_on_commit=False) as session:
+workflow_run = self._workflow_cycle_manager._get_workflow_run(
+session=session, workflow_run_id=self._workflow_run_id
+)
+workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event
)
-node_retry_resp = self._workflow_node_retry_to_stream_response(
+node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

-with Session(db.engine) as session:
-workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-workflow_node_execution = self._handle_node_execution_start(
+with Session(db.engine, expire_on_commit=False) as session:
+workflow_run = self._workflow_cycle_manager._get_workflow_run(
+session=session, workflow_run_id=self._workflow_run_id
+)
+workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event
)

-node_start_resp = self._workflow_node_start_to_stream_response(
+node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
elif isinstance(event, QueueNodeSucceededEvent):
# Record files if it's an answer node or end node
if event.node_type in [NodeType.ANSWER, NodeType.END]:
-self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
+self._recorded_files.extend(
+self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {})
+)

-with Session(db.engine) as session:
-workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
+with Session(db.engine, expire_on_commit=False) as session:
+workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
+session=session, event=event
+)

-node_finish_resp = self._workflow_node_finish_to_stream_response(
+node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
if node_finish_resp:
yield node_finish_resp
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
-with Session(db.engine) as session:
-workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event)
+with Session(db.engine, expire_on_commit=False) as session:
+workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
+session=session, event=event
+)

-node_finish_resp = self._workflow_node_finish_to_stream_response(
+node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

-with Session(db.engine) as session:
-workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
-session=session,
-task_id=self._application_generate_entity.task_id,
-workflow_run=workflow_run,
-event=event,
+with Session(db.engine, expire_on_commit=False) as session:
+workflow_run = self._workflow_cycle_manager._get_workflow_run(
+session=session, workflow_run_id=self._workflow_run_id
+)
+parallel_start_resp = (
+self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
+session=session,
+task_id=self._application_generate_entity.task_id,
+workflow_run=workflow_run,
+event=event,
+)
)

yield parallel_start_resp
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

-with Session(db.engine) as session:
-workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
-session=session,
-task_id=self._application_generate_entity.task_id,
-workflow_run=workflow_run,
-event=event,
+with Session(db.engine, expire_on_commit=False) as session:
+workflow_run = self._workflow_cycle_manager._get_workflow_run(
+session=session, workflow_run_id=self._workflow_run_id
+)
+parallel_finish_resp = (
+self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
+session=session,
+task_id=self._application_generate_entity.task_id,
+workflow_run=workflow_run,
+event=event,
+)
)

yield parallel_finish_resp
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

-with Session(db.engine) as session:
-workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-iter_start_resp = self._workflow_iteration_start_to_stream_response(
+with Session(db.engine, expire_on_commit=False) as session:
+workflow_run = self._workflow_cycle_manager._get_workflow_run(
+session=session, workflow_run_id=self._workflow_run_id
+)
+iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

-with Session(db.engine) as session:
-workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-iter_next_resp = self._workflow_iteration_next_to_stream_response(
+with Session(db.engine, expire_on_commit=False) as session:
+workflow_run = self._workflow_cycle_manager._get_workflow_run(
+session=session, workflow_run_id=self._workflow_run_id
+)
+iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")

-with Session(db.engine) as session:
-workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
-iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
+with Session(db.engine, expire_on_commit=False) as session:
+workflow_run = self._workflow_cycle_manager._get_workflow_run(
+session=session, workflow_run_id=self._workflow_run_id
+)
+iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
if not graph_runtime_state:
raise ValueError("workflow run not initialized.")

-with Session(db.engine) as session:
-workflow_run = self._handle_workflow_run_success(
+with Session(db.engine, expire_on_commit=False) as session:
+workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
trace_manager=trace_manager,
)

-workflow_finish_resp = self._workflow_finish_to_stream_response(
+workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()

yield workflow_finish_resp
-self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
+self._base_task_pipeline._queue_manager.publish(
+QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
+)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")

-with Session(db.engine) as session:
-workflow_run = self._handle_workflow_run_partial_success(
+with Session(db.engine, expire_on_commit=False) as session:
+workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
conversation_id=None,
trace_manager=trace_manager,
)
-workflow_finish_resp = self._workflow_finish_to_stream_response(
+workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()

yield workflow_finish_resp
-self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
+self._base_task_pipeline._queue_manager.publish(
+QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
+)
elif isinstance(event, QueueWorkflowFailedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")

-with Session(db.engine) as session:
-workflow_run = self._handle_workflow_run_failed(
+with Session(db.engine, expire_on_commit=False) as session:
+workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
)
-workflow_finish_resp = self._workflow_finish_to_stream_response(
+workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
-err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
+err = self._base_task_pipeline._handle_error(
+event=err_event, session=session, message_id=self._message_id
+)
session.commit()

yield workflow_finish_resp
-yield self._error_to_stream_response(err)
+yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueStopEvent):
if self._workflow_run_id and graph_runtime_state:
-with Session(db.engine) as session:
-workflow_run = self._handle_workflow_run_failed(
+with Session(db.engine, expire_on_commit=False) as session:
+workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
conversation_id=self._conversation_id,
trace_manager=trace_manager,
)
-workflow_finish_resp = self._workflow_finish_to_stream_response(
+workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
yield self._message_end_to_stream_response()
break
elif isinstance(event, QueueRetrieverResourcesEvent):
-self._handle_retriever_resources(event)
+self._message_cycle_manager._handle_retriever_resources(event)

-with Session(db.engine) as session:
+with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
session.commit()
elif isinstance(event, QueueAnnotationReplyEvent):
-self._handle_annotation_reply(event)
+self._message_cycle_manager._handle_annotation_reply(event)

-with Session(db.engine) as session:
+with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
tts_publisher.publish(queue_message)

self._task_state.answer += delta_text
-yield self._message_to_stream_response(
+yield self._message_cycle_manager._message_to_stream_response(
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
-yield self._message_replace_to_stream_response(answer=event.text)
+yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")

-output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
+output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
+self._task_state.answer
+)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
-yield self._message_replace_to_stream_response(answer=output_moderation_answer)
+yield self._message_cycle_manager._message_replace_to_stream_response(
+answer=output_moderation_answer
+)

# Save message
-with Session(db.engine) as session:
+with Session(db.engine, expire_on_commit=False) as session:
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
session.commit()

def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
message = self._get_message(session=session)
message.answer = self._task_state.answer
-message.provider_response_latency = time.perf_counter() - self._start_at
+message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
:param text: text
:return: True if output moderation should direct output, otherwise False
"""
-if self._output_moderation_handler:
-if self._output_moderation_handler.should_direct_output():
+if self._base_task_pipeline._output_moderation_handler:
+if self._base_task_pipeline._output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
-self._task_state.answer = self._output_moderation_handler.get_final_output()
-self._queue_manager.publish(
+self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
+self._base_task_pipeline._queue_manager.publish(
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
)

-self._queue_manager.publish(
+self._base_task_pipeline._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
)
return True
else:
-self._output_moderation_handler.append_new_token(text)
+self._base_task_pipeline._output_moderation_handler.append_new_token(text)

return False


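The core of this refactor swaps multiple inheritance for composition: the pipeline now holds a BasedGenerateTaskPipeline, a WorkflowCycleManage and a MessageCycleManage as members instead of inheriting from them. A stripped-down sketch of that shape (the real constructors take many more arguments than shown here, and the collaborator classes below are simplified stand-ins):

# Minimal sketch of the composition-based layout; everything here is simplified.
class BasedGenerateTaskPipeline:
    def __init__(self, stream: bool) -> None:
        self._stream = stream


class WorkflowCycleManage:
    def __init__(self, workflow_system_variables: dict) -> None:
        self._workflow_system_variables = workflow_system_variables


class MessageCycleManage:
    def __init__(self, task_state: dict) -> None:
        self._task_state = task_state


class AdvancedChatAppGenerateTaskPipeline:
    # Previously: class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline,
    # WorkflowCycleManage, MessageCycleManage). Now each collaborator is an attribute,
    # so calls become self._workflow_cycle_manager._handle_workflow_run_start(...) etc.
    def __init__(self, stream: bool, system_variables: dict) -> None:
        self._base_task_pipeline = BasedGenerateTaskPipeline(stream=stream)
        self._workflow_cycle_manager = WorkflowCycleManage(workflow_system_variables=system_variables)
        self._task_state: dict = {}
        self._message_cycle_manager = MessageCycleManage(task_state=self._task_state)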

+ 2 - 2  api/core/app/apps/agent_chat/app_generator.py  View File

from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
-from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
-except (ValueError, InvokeError) as e:
+except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

+ 2 - 2  api/core/app/apps/chat/app_generator.py  View File

from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
-from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
-except (ValueError, InvokeError) as e:
+except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

+ 2 - 2  api/core/app/apps/completion/app_generator.py  View File

from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
-from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
-except (ValueError, InvokeError) as e:
+except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

+ 3
- 2
api/core/app/apps/workflow/app_generator.py View File

from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
workflow_run_id=str(uuid.uuid4()),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)

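Note the added workflow_run_id=str(uuid.uuid4()) above: the run id is generated before the run starts, so it can be injected into the workflow system variables and later reused as the id of the persisted run record. A small sketch of that idea; everything except uuid4 is invented for illustration:

import uuid
from dataclasses import dataclass, field


@dataclass
class GenerateEntity:
    # Hypothetical stand-in for WorkflowAppGenerateEntity.
    inputs: dict
    workflow_run_id: str = field(default_factory=lambda: str(uuid.uuid4()))


entity = GenerateEntity(inputs={"query": "hello"})
# The same id can be placed into system variables and later used as the
# primary key of the run record, instead of being minted at persistence time.
system_variables = {"workflow_run_id": entity.workflow_run_id}
print(system_variables)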
+ 89
- 72
api/core/app/apps/workflow/generate_task_pipeline.py View File

import logging
import time
from collections.abc import Generator
from typing import Any, Optional, Union
from typing import Optional, Union


from sqlalchemy.orm import Session


Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowRun,
WorkflowRunStatus,
)
logger = logging.getLogger(__name__)




class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
class WorkflowAppGenerateTaskPipeline:
"""
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""


_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]

def __init__(
self,
application_generate_entity: WorkflowAppGenerateEntity,
user: Union[Account, EndUser],
stream: bool,
) -> None:
super().__init__(
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream,
else:
raise ValueError(f"Invalid user type: {type(user)}")


self._workflow_cycle_manager = WorkflowCycleManage(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
)

self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict

self._workflow_system_variables = {
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
}

self._task_state = WorkflowTaskState()
self._workflow_run_id = ""


:return:
"""
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
if self._base_task_pipeline._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
"""
graph_runtime_state = None


for queue_message in self._queue_manager.listen():
for queue_message in self._base_task_pipeline._queue_manager.listen():
event = queue_message.event


if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
yield self._base_task_pipeline._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event=event)
yield self._error_to_stream_response(err)
err = self._base_task_pipeline._handle_error(event=event)
yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state


with Session(db.engine) as session:
with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_run = self._handle_workflow_run_start(
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
user_id=self._user_id,
created_by_role=self._created_by_role,
)
self._workflow_run_id = workflow_run.id
start_resp = self._workflow_start_to_stream_response(
start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_workflow_node_execution_retried(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event
)
response = self._workflow_node_retry_to_stream_response(
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")


with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_node_execution_start(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event
)
node_start_response = self._workflow_node_start_to_stream_response(
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
if node_start_response:
yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent):
with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
node_success_response = self._workflow_node_finish_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
session=session, event=event
)
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
if node_success_response:
yield node_success_response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_failed(
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session,
event=event,
)
node_failed_response = self._workflow_node_finish_to_stream_response(
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")


with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_start_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)


yield parallel_start_resp
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")


with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
parallel_finish_resp = (
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
)


yield parallel_finish_resp
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")


with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_start_resp = self._workflow_iteration_start_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")


with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_next_resp = self._workflow_iteration_next_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")


with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._get_workflow_run(
session=session, workflow_run_id=self._workflow_run_id
)
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")


with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_success(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)


workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")


with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_partial_success(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)


workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")


with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_failed(
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)


workflow_finish_resp = self._workflow_finish_to_stream_response(
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()

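The pipeline above no longer inherits from BasedGenerateTaskPipeline and WorkflowCycleManage; it now holds them as explicit collaborators (self._base_task_pipeline and self._workflow_cycle_manager), and every former self._... call is routed through one of them. A rough sketch of that composition-over-inheritance shape, with both helpers reduced to stubs:

class BasePipelineStub:
    # Stands in for BasedGenerateTaskPipeline: queue access, errors, pings.
    def __init__(self, stream: bool) -> None:
        self._stream = stream


class CycleManagerStub:
    # Stands in for WorkflowCycleManage: run/node persistence helpers.
    def handle_workflow_run_start(self) -> str:
        return "workflow-run-id"


class GenerateTaskPipeline:
    def __init__(self, stream: bool) -> None:
        # Composition keeps the dependencies visible and the MRO flat.
        self._base_task_pipeline = BasePipelineStub(stream=stream)
        self._workflow_cycle_manager = CycleManagerStub()

    def run(self) -> str:
        if self._base_task_pipeline._stream:
            return self._workflow_cycle_manager.handle_workflow_run_start()
        return "blocking"


print(GenerateTaskPipeline(stream=True).run())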
+ 1
- 1
api/core/app/entities/app_invoke_entities.py View File



# app config
app_config: WorkflowUIBasedAppConfig
workflow_run_id: Optional[str] = None
workflow_run_id: str


class SingleIterationRunEntity(BaseModel):
"""

+ 0
- 11
api/core/app/task_pipeline/based_generate_task_pipeline.py View File

from core.app.entities.task_entities import (
ErrorStreamResponse,
PingStreamResponse,
TaskState,
)
from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""


_task_state: TaskState
_application_generate_entity: AppGenerateEntity

def __init__(
self,
application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager,
stream: bool,
) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param user: user
:param stream: stream
"""
self._application_generate_entity = application_generate_entity
self._queue_manager = queue_manager
self._start_at = time.perf_counter()

+ 13
- 4
api/core/app/task_pipeline/message_cycle_manage.py View File





class MessageCycleManage:
_application_generate_entity: Union[
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
]
_task_state: Union[EasyUITaskState, WorkflowTaskState]
def __init__(
self,
*,
application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
],
task_state: Union[EasyUITaskState, WorkflowTaskState],
) -> None:
self._application_generate_entity = application_generate_entity
self._task_state = task_state


def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
"""

+ 37
- 23
api/core/app/task_pipeline/workflow_cycle_manage.py View File

ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder
WorkflowRunStatus,
)


from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError
from .exc import WorkflowRunNotFoundError




class WorkflowCycleManage:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_task_state: WorkflowTaskState
_workflow_system_variables: dict[SystemVariableKey, Any]
def __init__(
self,
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
workflow_system_variables: dict[SystemVariableKey, Any],
) -> None:
self._workflow_run: WorkflowRun | None = None
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables


def _handle_workflow_run_start(
self,
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})


# init workflow run
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4()))
# TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())


workflow_run = WorkflowRun()
workflow_run.id = workflow_run_id
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count


stmt = select(WorkflowNodeExecution).where(
stmt = select(WorkflowNodeExecution.node_execution_id).where(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
)

running_workflow_node_executions = session.scalars(stmt).all()
ids = session.scalars(stmt).all()
# Use self._get_workflow_node_execution here to make sure the cache is updated
running_workflow_node_executions = [
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
]


for workflow_node_execution in running_workflow_node_executions:
now = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
finish_at = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.finished_at = finish_at
workflow_node_execution.elapsed_time = (finish_at - workflow_node_execution.created_at).total_seconds()
workflow_node_execution.finished_at = now
workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()


if trace_manager:
trace_manager.add_trace_task(
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)


session.add(workflow_node_execution)

self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution


def _handle_workflow_node_execution_success(
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time


workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution


def _handle_workflow_node_execution_failed(
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata


workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution


def _handle_workflow_node_execution_retried(
workflow_node_execution.index = event.node_run_index


session.add(workflow_node_execution)

self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution


#################################################
return None


def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run
:param workflow_run_id: workflow run id
:return:
"""
if self._workflow_run and self._workflow_run.id == workflow_run_id:
cached_workflow_run = self._workflow_run
cached_workflow_run = session.merge(cached_workflow_run)
return cached_workflow_run
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
workflow_run = session.scalar(stmt)
if not workflow_run:
raise WorkflowRunNotFoundError(workflow_run_id)
self._workflow_run = workflow_run


return workflow_run


def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id)
workflow_node_execution = session.scalar(stmt)
if not workflow_node_execution:
raise WorkflowNodeExecutionNotFoundError(node_execution_id)

return workflow_node_execution
if node_execution_id not in self._workflow_node_executions:
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
return cached_workflow_node_execution


def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
"""

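A second theme in this hunk: WorkflowCycleManage now caches the WorkflowRun it created (self._workflow_run, re-attached to the active session via session.merge) and keeps every WorkflowNodeExecution it has written in self._workflow_node_executions, so later events resolve node executions from memory instead of re-querying by node_execution_id. A rough sketch of the node-execution side of that cache, using a plain dict instead of the SQLAlchemy models:

class NodeExecutionCache:
    # Hypothetical stand-in for self._workflow_node_executions.
    def __init__(self) -> None:
        self._executions: dict[str, dict] = {}

    def remember(self, node_execution_id: str, execution: dict) -> None:
        # Called right after the execution row is added to the session.
        self._executions[node_execution_id] = execution

    def get(self, node_execution_id: str) -> dict:
        # Later events look the object up in memory; a miss is a bug.
        if node_execution_id not in self._executions:
            raise ValueError(f"Workflow node execution not found: {node_execution_id}")
        return self._executions[node_execution_id]


cache = NodeExecutionCache()
cache.remember("node-1", {"status": "running"})
print(cache.get("node-1"))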
+ 0
- 0
api/core/app/task_pipeline/workflow_cycle_state_manager.py View File


+ 42
- 4
api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py View File

import tiktoken
from threading import Lock
from typing import Any

_tokenizer: Any = None
_lock = Lock()




class GPT2Tokenizer:
@staticmethod
def _get_num_tokens_by_gpt2(text: str) -> int:
"""
use gpt2 tokenizer to get num tokens
"""
_tokenizer = GPT2Tokenizer.get_encoder()
tokens = _tokenizer.encode(text)
return len(tokens)

@staticmethod
def get_num_tokens(text: str) -> int:
encoding = tiktoken.encoding_for_model("gpt2")
tiktoken_vec = encoding.encode(text)
return len(tiktoken_vec)
# Because this process needs more cpu resource, we turn this back before we find a better way to handle it.
#
# future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
# result = future.result()
# return cast(int, result)
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)

@staticmethod
def get_encoder() -> Any:
global _tokenizer, _lock
with _lock:
if _tokenizer is None:
# Try to use tiktoken to get the tokenizer because it is faster
#
try:
import tiktoken

_tokenizer = tiktoken.get_encoding("gpt2")
except Exception:
from os.path import abspath, dirname, join

from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore

base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
_tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)

return _tokenizer

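The tokenizer above is now built lazily behind a module-level lock: tiktoken's gpt2 encoding is preferred, and the bundled transformers GPT-2 tokenizer is only loaded if tiktoken fails. A condensed sketch of the same lazy-singleton pattern; the transformers fallback is stubbed out here, so this is not a drop-in replacement for the real module:

from threading import Lock
from typing import Any

_tokenizer: Any = None
_lock = Lock()


def get_encoder() -> Any:
    # Build the encoder once; every later call reuses the cached instance.
    global _tokenizer
    with _lock:
        if _tokenizer is None:
            try:
                import tiktoken

                _tokenizer = tiktoken.get_encoding("gpt2")
            except Exception as e:
                # The real module falls back to a bundled transformers tokenizer.
                raise RuntimeError("tiktoken unavailable; no fallback in this sketch") from e
        return _tokenizer


def get_num_tokens(text: str) -> int:
    return len(get_encoder().encode(text))


print(get_num_tokens("hello world"))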
+ 2
- 0
api/core/rag/datasource/vdb/baidu/baidu_vector.py View File

return False


def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
quoted_ids = [f"'{id}'" for id in ids]
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")



+ 2
- 0
api/core/rag/datasource/vdb/chroma/chroma_vector.py View File

self._client.delete_collection(self._collection_name)


def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
collection = self._client.get_or_create_collection(self._collection_name)
collection.delete(ids=ids)



+ 104
- 0
api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py View File

import json
import logging
from typing import Any, Optional

from flask import current_app

from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchConfig,
ElasticSearchVector,
ElasticSearchVectorFactory,
)
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_redis import redis_client
from models.dataset import Dataset

logger = logging.getLogger(__name__)


class ElasticSearchJaVector(ElasticSearchVector):
def create_collection(
self,
embeddings: list[list[float]],
metadatas: Optional[list[dict[Any, Any]]] = None,
index_params: Optional[dict] = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name} already exists.")
return

if not self._client.indices.exists(index=self._collection_name):
dim = len(embeddings[0])
settings = {
"analysis": {
"analyzer": {
"ja_analyzer": {
"type": "custom",
"char_filter": [
"icu_normalizer",
"kuromoji_iteration_mark",
],
"tokenizer": "kuromoji_tokenizer",
"filter": [
"kuromoji_baseform",
"kuromoji_part_of_speech",
"ja_stop",
"kuromoji_number",
"kuromoji_stemmer",
],
}
}
}
}
mappings = {
"properties": {
Field.CONTENT_KEY.value: {
"type": "text",
"analyzer": "ja_analyzer",
"search_analyzer": "ja_analyzer",
},
Field.VECTOR.value: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"index": True,
"similarity": "cosine",
},
Field.METADATA_KEY.value: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
},
},
}
}
self._client.indices.create(index=self._collection_name, settings=settings, mappings=mappings)

redis_client.set(collection_exist_cache_key, 1, ex=3600)


class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name))

config = current_app.config
return ElasticSearchJaVector(
index_name=collection_name,
config=ElasticSearchConfig(
host=config.get("ELASTICSEARCH_HOST", "localhost"),
port=config.get("ELASTICSEARCH_PORT", 9200),
username=config.get("ELASTICSEARCH_USERNAME", ""),
password=config.get("ELASTICSEARCH_PASSWORD", ""),
),
attributes=[],
)

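ElasticSearchJaVector only changes how the index is created: the content field is analyzed with a custom ja_analyzer built from the kuromoji tokenizer and ICU normalization, which in practice requires the analysis-kuromoji and analysis-icu plugins on the Elasticsearch nodes. A hedged sketch of creating a comparable index directly with the official Python client; the index name, field names, and dimension are placeholders:

from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")

settings = {
    "analysis": {
        "analyzer": {
            "ja_analyzer": {
                "type": "custom",
                "char_filter": ["icu_normalizer", "kuromoji_iteration_mark"],
                "tokenizer": "kuromoji_tokenizer",
                "filter": ["kuromoji_baseform", "ja_stop", "kuromoji_stemmer"],
            }
        }
    }
}
mappings = {
    "properties": {
        "page_content": {"type": "text", "analyzer": "ja_analyzer", "search_analyzer": "ja_analyzer"},
        "vector": {"type": "dense_vector", "dims": 768, "index": True, "similarity": "cosine"},
        "metadata": {"properties": {"doc_id": {"type": "keyword"}}},
    }
}
es.indices.create(index="example_ja_index", settings=settings, mappings=mappings)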
+ 2
- 0
api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py View File

return bool(self._client.exists(index=self._collection_name, id=id))


def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
for id in ids:
self._client.delete(index=self._collection_name, id=id)



+ 2
- 0
api/core/rag/datasource/vdb/field.py View File

METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR = "vector"
# Sparse Vector aims to support full text search
SPARSE_VECTOR = "sparse_vector"
TEXT_KEY = "text"
PRIMARY_KEY = "id"
DOC_ID = "metadata.doc_id"

+ 171
- 29
api/core/rag/datasource/vdb/milvus/milvus_vector.py View File

import logging
from typing import Any, Optional


from packaging import version
from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException  # type: ignore
from pymilvus.milvus_client import IndexParams  # type: ignore




class MilvusConfig(BaseModel):
uri: str
token: Optional[str] = None
user: str
password: str
batch_size: int = 100
database: str = "default"
"""
Configuration class for Milvus connection.
"""

uri: str # Milvus server URI
token: Optional[str] = None # Optional token for authentication
user: str # Username for authentication
password: str # Password for authentication
batch_size: int = 100 # Batch size for operations
database: str = "default" # Database name
enable_hybrid_search: bool = False # Flag to enable hybrid search


@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
"""
Validate the configuration values.
Raises ValueError if required fields are missing.
"""
if not values.get("uri"):
raise ValueError("config MILVUS_URI is required")
if not values.get("user"):
return values


def to_milvus_params(self):
"""
Convert the configuration to a dictionary of Milvus connection parameters.
"""
return {
"uri": self.uri,
"token": self.token,




class MilvusVector(BaseVector):
"""
Milvus vector storage implementation.
"""

def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = "Session"
self._fields: list[str] = []
self._consistency_level = "Session" # Consistency level for Milvus operations
self._fields: list[str] = [] # List of fields in the collection
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported

def _check_hybrid_search_support(self) -> bool:
"""
Check if the current Milvus version supports hybrid search.
Returns True if the version is >= 2.5.0, otherwise False.
"""
if not self._client_config.enable_hybrid_search:
return False

try:
milvus_version = self._client.get_server_version()
return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version
except Exception as e:
logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.")
return False


def get_type(self) -> str:
"""
Get the type of vector storage (Milvus).
"""
return VectorType.MILVUS


def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""
Create a collection and add texts with embeddings.
"""
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)


def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""
Add texts and their embeddings to the collection.
"""
insert_dict_list = []
for i in range(len(documents)):
insert_dict = {
# Do not need to insert the sparse_vector field separately, as the text_bm25_emb
# function will automatically convert the native text into a sparse vector for us.
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
insert_dict_list.append(insert_dict)
# Total insert count
total_count = len(insert_dict_list)

pks: list[str] = []


for i in range(0, total_count, 1000):
batch_insert_list = insert_dict_list[i : i + 1000]
# Insert into the collection.
batch_insert_list = insert_dict_list[i : i + 1000]
try:
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids)
return pks return pks


def get_ids_by_metadata_field(self, key: str, value: str):
"""
Get document IDs by metadata field key and value.
"""
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
)
return None


def delete_by_metadata_field(self, key: str, value: str):
"""
Delete documents by metadata field key and value.
"""
if self._client.has_collection(self._collection_name):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)


def delete_by_ids(self, ids: list[str]) -> None:
"""
Delete documents by their IDs.
"""
if self._client.has_collection(self._collection_name):
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
self._client.delete(collection_name=self._collection_name, pks=ids)


def delete(self) -> None:
"""
Delete the entire collection.
"""
if self._client.has_collection(self._collection_name):
self._client.drop_collection(self._collection_name, None)


def text_exists(self, id: str) -> bool:
"""
Check if a text with the given ID exists in the collection.
"""
if not self._client.has_collection(self._collection_name):
return False




return len(result) > 0


def field_exists(self, field: str) -> bool:
"""
Check if a field exists in the collection.
"""
return field in self._fields

def _process_search_results(
self, results: list[Any], output_fields: list[str], score_threshold: float = 0.0
) -> list[Document]:
"""
Common method to process search results

:param results: Search results
:param output_fields: Fields to be output
:param score_threshold: Score threshold for filtering
:return: List of documents
"""
docs = []
for result in results[0]:
metadata = result["entity"].get(output_fields[1], {})
metadata["score"] = result["distance"]

if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(output_fields[0], ""), metadata=metadata)
docs.append(doc)

return docs

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Set search parameters.
"""
Search for documents by vector similarity.
"""
results = self._client.search(
collection_name=self._collection_name,
data=[query_vector],
anns_field=Field.VECTOR.value,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
# Organize results.
docs = []
for result in results[0]:
metadata = result["entity"].get(Field.METADATA_KEY.value)
metadata["score"] = result["distance"]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
return docs

return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)


def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz doesn't support bm25 search
return []
"""
Search for documents by full-text search (if hybrid search is enabled).
"""
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
return []

results = self._client.search(
collection_name=self._collection_name,
data=[query],
anns_field=Field.SPARSE_VECTOR.value,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)

return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)


def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
"""
Create a new collection in Milvus with the specified schema and index parameters.
"""
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
return
# Grab the existing collection if it exists
if not self._client.has_collection(self._collection_name):
from pymilvus import CollectionSchema, DataType, FieldSchema  # type: ignore
from pymilvus import CollectionSchema, DataType, FieldSchema, Function, FunctionType  # type: ignore
from pymilvus.orm.types import infer_dtype_bydata  # type: ignore


# Determine embedding dim
if metadatas:
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))


# Create the text field
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535))
# Create the text field, enable_analyzer will be set True to support milvus automatically
# transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
fields.append(
FieldSchema(
Field.CONTENT_KEY.value,
DataType.VARCHAR,
max_length=65_535,
enable_analyzer=self._hybrid_search_enabled,
)
)
# Create the primary key field
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR))


# Create the schema for the collection
schema = CollectionSchema(fields)


# Create custom function to support text to sparse vector by BM25
if self._hybrid_search_enabled:
bm25_function = Function(
name="text_bm25_emb",
input_field_names=[Field.CONTENT_KEY.value],
output_field_names=[Field.SPARSE_VECTOR.value],
function_type=FunctionType.BM25,
)
schema.add_function(bm25_function)

for x in schema.fields:
self._fields.append(x.name)
# Since primary field is auto-id, no need to track it
index_params_obj = IndexParams()
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)


# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
index_params_obj.add_index(
field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25"
)

# Create the collection
collection_name = self._collection_name
self._client.create_collection(
collection_name=collection_name,
collection_name=self._collection_name,
schema=schema,
index_params=index_params_obj,
consistency_level=self._consistency_level,
redis_client.set(collection_exist_cache_key, 1, ex=3600)


def _init_client(self, config) -> MilvusClient:
"""
Initialize and return a Milvus client.
"""
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
return client




class MilvusVectorFactory(AbstractVectorFactory):
"""
Factory class for creating MilvusVector instances.
"""

def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
"""
Initialize a MilvusVector instance for the given dataset.
"""
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
user=dify_config.MILVUS_USER or "",
password=dify_config.MILVUS_PASSWORD or "",
database=dify_config.MILVUS_DATABASE or "",
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
),
)

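The Milvus changes gate BM25 full-text search behind both a new enable_hybrid_search config flag and a server version check: the sparse-vector field and the text_bm25_emb function are only created when the server reports 2.5.0 or newer. A hedged sketch of the version gate, using packaging.version as the hunk does and a stubbed client in place of a live pymilvus connection:

from packaging import version


class MilvusClientStub:
    # Stands in for pymilvus.MilvusClient in this sketch.
    def get_server_version(self) -> str:
        return "v2.5.3"


def hybrid_search_supported(client: MilvusClientStub, enabled: bool) -> bool:
    # BM25 search is used only when the flag is on and the server is new enough.
    if not enabled:
        return False
    try:
        server = client.get_server_version().lstrip("v")
        return version.parse(server) >= version.parse("2.5.0")
    except Exception:
        # If the version cannot be determined, fall back to dense-only search.
        return False


print(hybrid_search_supported(MilvusClientStub(), enabled=True))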
+ 2
- 0
api/core/rag/datasource/vdb/myscale/myscale_vector.py View File

return results.row_count > 0


def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
)

+ 2
- 0
api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py View File

return bool(cur.rowcount != 0)


def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
self._client.delete(table_name=self._collection_name, ids=ids)


def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:

+ 2
- 0
api/core/rag/datasource/vdb/oracle/oraclevector.py View File

return docs


def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))



+ 5
- 0
api/core/rag/datasource/vdb/pgvector/pgvector.py View File

return docs


def delete_by_ids(self, ids: list[str]) -> None:
# Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
# Scenario 1: extract a document fails, resulting in a table not being created.
# Then clicking the retry button triggers a delete operation on an empty list.
if not ids:
return
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))



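Several vector stores in this change gain the same guard: delete_by_ids returns early when the id list is empty, so retrying a failed document extraction no longer builds a DELETE ... IN () statement against a table that may not even exist. A minimal sketch of the guard with the database call stubbed out:

def delete_by_ids(ids: list[str], execute=print) -> None:
    # Skip the DELETE entirely when there is nothing to delete.
    if not ids:
        return
    # tuple([]) would render as an empty IN clause, which drivers reject;
    # the early return above keeps that malformed SQL from being built.
    execute("DELETE FROM embeddings WHERE id IN %s", tuple(ids))


delete_by_ids([])            # no-op
delete_by_ids(["a", "b"])    # issues one DELETE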
+ 2
- 0
api/core/rag/datasource/vdb/tencent/tencent_vector.py View File

return False


def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
self._db.collection(self._collection_name).delete(document_ids=ids)


def delete_by_metadata_field(self, key: str, value: str) -> None:

+ 18
- 19
api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py View File

db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
)
if not tidb_auth_binding:
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)
if idle_tidb_auth_binding:
idle_tidb_auth_binding.active = True
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
else:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none()
)
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"

else:
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"

if idle_tidb_auth_binding:
idle_tidb_auth_binding.active = True
idle_tidb_auth_binding.tenant_id = dataset.tenant_id
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}"
else:
new_cluster = TidbService.create_tidb_serverless_cluster(
dify_config.TIDB_PROJECT_ID or "",
db.session.add(new_tidb_auth_binding)
db.session.commit()
TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}"

else:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"



+ 6
- 0
api/core/rag/datasource/vdb/vector_factory.py

from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory


return ElasticSearchVectorFactory
case VectorType.ELASTICSEARCH_JA:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
ElasticSearchJaVectorFactory,
)

return ElasticSearchJaVectorFactory
case VectorType.TIDB_VECTOR:
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory



+ 1
- 0
api/core/rag/datasource/vdb/vector_type.py

TENCENT = "tencent" TENCENT = "tencent"
ORACLE = "oracle" ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch" ELASTICSEARCH = "elasticsearch"
ELASTICSEARCH_JA = "elasticsearch-ja"
LINDORM = "lindorm" LINDORM = "lindorm"
COUCHBASE = "couchbase" COUCHBASE = "couchbase"
BAIDU = "baidu" BAIDU = "baidu"

+ 2
- 3
api/core/rag/extractor/pdf_extractor.py

self._file_cache_key = file_cache_key


def extract(self) -> list[Document]:
plaintext_file_key = ""
plaintext_file_exists = False
if self._file_cache_key:
try:
text = "\n\n".join(text_list)


# save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode("utf-8"))
if not plaintext_file_exists and self._file_cache_key:
storage.save(self._file_cache_key, text.encode("utf-8"))


return documents



+ 5
- 0
api/core/rag/index_processor/processor/parent_child_index_processor.py

import uuid
from typing import Optional


from configs import dify_config
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
child_nodes = self._split_child_nodes(
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
)
if kwargs.get("preview"):
if len(child_nodes) > dify_config.CHILD_CHUNKS_PREVIEW_NUMBER:
child_nodes = child_nodes[: dify_config.CHILD_CHUNKS_PREVIEW_NUMBER]

document.children = child_nodes
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)

+ 17
- 2
api/core/tools/custom_tool/tool.py

else:
body = body


if method in {"get", "head", "post", "put", "delete", "patch"}:
response: httpx.Response = getattr(ssrf_proxy, method)(
if method in {
"get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
}:
response: httpx.Response = getattr(ssrf_proxy, method.lower())(
url,
params=params,
headers=headers,

+ 61
- 19
api/core/workflow/nodes/document_extractor/node.py

import io
import json
import logging
import operator
import os
import tempfile
from typing import cast
from collections.abc import Mapping, Sequence
from typing import Any, cast


import docx
import pandas as pd
import pypdfium2 # type: ignore
import yaml # type: ignore
from docx.table import Table
from docx.text.paragraph import Paragraph


from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
process_data=process_data,
)


@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: DocumentExtractorNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {node_id + ".files": node_data.variable_selector}



def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
"""Extract text from a file based on its MIME type."""
doc_file = io.BytesIO(file_content)
doc = docx.Document(doc_file)
text = []
# Process paragraphs
for paragraph in doc.paragraphs:
if paragraph.text.strip():
text.append(paragraph.text)


# Process tables
for table in doc.tables:
# Table header
try:
# table maybe cause errors so ignore it.
if len(table.rows) > 0 and table.rows[0].cells is not None:
# Keep track of paragraph and table positions
content_items: list[tuple[int, str, Table | Paragraph]] = []

# Process paragraphs and tables
for i, paragraph in enumerate(doc.paragraphs):
if paragraph.text.strip():
content_items.append((i, "paragraph", paragraph))

for i, table in enumerate(doc.tables):
content_items.append((i, "table", table))

# Sort content items based on their original position
content_items.sort(key=operator.itemgetter(0))

# Process sorted content
for _, item_type, item in content_items:
if item_type == "paragraph":
if isinstance(item, Table):
continue
text.append(item.text)
elif item_type == "table":
# Process tables
if not isinstance(item, Table):
continue
try:
# Check if any cell in the table has text
has_content = False
for row in table.rows:
for row in item.rows:
if any(cell.text.strip() for cell in row.cells):
has_content = True
break


if has_content:
markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n"
markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n"
for row in table.rows[1:]:
markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n"
cell_texts = [cell.text.replace("\n", "<br>") for cell in item.rows[0].cells]
markdown_table = f"| {' | '.join(cell_texts)} |\n"
markdown_table += f"| {' | '.join(['---'] * len(item.rows[0].cells))} |\n"

for row in item.rows[1:]:
# Replace newlines with <br> in each cell
row_cells = [cell.text.replace("\n", "<br>") for cell in row.cells]
markdown_table += "| " + " | ".join(row_cells) + " |\n"

text.append(markdown_table)
except Exception as e:
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
continue
except Exception as e:
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
continue


return "\n".join(text) return "\n".join(text)

except Exception as e:
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
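For reference, the new table handling emits GitHub-style Markdown and swaps newlines inside cells for <br>; a small sketch of that conversion with made-up row data (not taken from the diff):

# Hypothetical header and one data row, rendered the same way as the extractor above.
header = ["Name", "Notes"]
rows = [["Alice", "line one\nline two"]]
md = f"| {' | '.join(header)} |\n"
md += f"| {' | '.join(['---'] * len(header))} |\n"
for row in rows:
    md += "| " + " | ".join(cell.replace("\n", "<br>") for cell in row) + " |\n"
print(md)  # | Name | Notes | / | --- | --- | / | Alice | line one<br>line two |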



+ 16
- 1
api/core/workflow/nodes/http_request/entities.py

Code Node Data.
"""


method: Literal["get", "post", "put", "patch", "delete", "head"]
method: Literal[
"get",
"post",
"put",
"patch",
"delete",
"head",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
]
url: str
authorization: HttpRequestNodeAuthorization
headers: str

+ 43
- 10
api/core/workflow/nodes/http_request/executor.py





class Executor:
method: Literal["get", "head", "post", "put", "delete", "patch"]
method: Literal[
"get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
]
url: str
params: list[tuple[str, str]] | None
content: str | bytes | None
node_data.authorization.config.api_key
).text


# check if node_data.url is a valid URL
if not node_data.url:
raise InvalidURLError("url is required")
if not node_data.url.startswith(("http://", "https://")):
raise InvalidURLError("url should start with http:// or https://")

self.url: str = node_data.url
self.method = node_data.method
self.auth = node_data.authorization
def _init_url(self):
self.url = self.variable_pool.convert_template(self.node_data.url).text


# check if url is a valid URL
if not self.url:
raise InvalidURLError("url is required")
if not self.url.startswith(("http://", "https://")):
raise InvalidURLError("url should start with http:// or https://")

def _init_params(self):
"""
Almost same as _init_headers(), difference:
if len(data) != 1:
raise RequestBodyError("json body type should have exactly one item")
json_string = self.variable_pool.convert_template(data[0].value).text
json_object = json.loads(json_string, strict=False)
try:
json_object = json.loads(json_string, strict=False)
except json.JSONDecodeError as e:
raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e
self.json = json_object
# self.json = self._parse_object_contains_variables(json_object)
case "binary":
"""
do http request depending on api bundle
"""
if self.method not in {"get", "head", "post", "put", "delete", "patch"}:
if self.method not in {
"get",
"head",
"post",
"put",
"delete",
"patch",
"options",
"GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
}:
raise InvalidHttpMethodError(f"Invalid http method {self.method}")


request_args = {
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
response = getattr(ssrf_proxy, self.method)(**request_args)
response = getattr(ssrf_proxy, self.method.lower())(**request_args)
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
raise HttpRequestNodeError(str(e))
# FIXME: fix type ignore, this maybe httpx type issue
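Because the method literal may now arrive in upper case, the dispatch only works if it is lower-cased before the getattr lookup; a minimal sketch of the idea using plain httpx (ssrf_proxy itself is not reproduced here):

import httpx

def dispatch(client: httpx.Client, method: str, url: str, **kwargs) -> httpx.Response:
    # "OPTIONS" and "options" both resolve to client.options after normalization.
    return getattr(client, method.lower())(url, **kwargs)

# Example: dispatch(httpx.Client(), "GET", "https://example.com")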

+ 4
- 0
api/core/workflow/workflow_entry.py

):
raise ValueError(f"Variable key {node_variable} not found in user inputs.")


# environment variable already exist in variable pool, not from user inputs
if variable_pool.get(variable_selector):
continue

# fetch variable node id from variable selector
variable_node_id = variable_selector[0]
variable_key_list = variable_selector[1:]

+ 1
- 0
api/docker/entrypoint.sh

--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
--workers ${SERVER_WORKER_AMOUNT:-1} \
--worker-class ${SERVER_WORKER_CLASS:-gevent} \
--worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
--timeout ${GUNICORN_TIMEOUT:-200} \
app:app
fi

+ 1
- 1
api/extensions/ext_logging.py

timezone = pytz.timezone(log_tz)


def time_converter(seconds):
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()


for handler in logging.root.handlers:
if handler.formatter:
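The logging change avoids a subtle timezone bug: utcfromtimestamp() returns a naive datetime, and astimezone() then treats it as host-local time, so the old code was only correct on machines running in UTC (the call is also deprecated in Python 3.12). A quick check, using Asia/Tokyo purely as an example zone:

from datetime import datetime
import pytz

tz = pytz.timezone("Asia/Tokyo")
ts = 0  # 1970-01-01T00:00:00Z
print(datetime.fromtimestamp(ts, tz=tz))             # 1970-01-01 09:00:00+09:00, independent of the host zone
print(datetime.utcfromtimestamp(ts).astimezone(tz))  # depends on the host's local zone, wrong unless the host runs UTC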

+ 1
- 1
api/factories/file_factory.py

tenant_id: str,
transfer_method: FileTransferMethod,
) -> File:
url = mapping.get("url")
url = mapping.get("url") or mapping.get("remote_url")
if not url:
raise ValueError("Invalid file url")



+ 2
- 1
api/libs/oauth_data_source.py

response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response_json = response.json() response_json = response.json()
if response.status_code != 200: if response.status_code != 200:
raise ValueError(f"Error fetching block parent page ID: {response_json.message}")
message = response_json.get("message", "unknown error")
raise ValueError(f"Error fetching block parent page ID: {message}")
parent = response_json["parent"] parent = response_json["parent"]
parent_type = parent["type"] parent_type = parent["type"]
if parent_type == "block_id": if parent_type == "block_id":

+ 41
- 0
api/migrations/versions/2025_01_01_2000-a91b476a53de_change_workflow_runs_total_tokens_to_.py

"""change workflow_runs.total_tokens to bigint

Revision ID: a91b476a53de
Revises: 923752d42eb6
Create Date: 2025-01-01 20:00:01.207369

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'a91b476a53de'
down_revision = '923752d42eb6'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.alter_column('total_tokens',
existing_type=sa.INTEGER(),
type_=sa.BigInteger(),
existing_nullable=False,
existing_server_default=sa.text('0'))

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.alter_column('total_tokens',
existing_type=sa.BigInteger(),
type_=sa.INTEGER(),
existing_nullable=False,
existing_server_default=sa.text('0'))

# ### end Alembic commands ###

+ 2
- 2
api/models/workflow.py

status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
error: Mapped[Optional[str]] = mapped_column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
total_tokens: Mapped[int] = mapped_column(server_default=db.text("0"))
elapsed_time = db.Column(db.Float, nullable=False, server_default=sa.text("0"))
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
total_steps = db.Column(db.Integer, server_default=db.text("0"))
created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
created_by = db.Column(StringUUID, nullable=False)

+ 1247
- 1168
api/poetry.lock
File diff too large to display


+ 4
- 4
api/pyproject.toml

pypdfium2 = "~4.30.0" pypdfium2 = "~4.30.0"
python = ">=3.11,<3.13" python = ">=3.11,<3.13"
python-docx = "~1.1.0" python-docx = "~1.1.0"
python-dotenv = "1.0.0"
python-dotenv = "1.0.1"
pyyaml = "~6.0.1" pyyaml = "~6.0.1"
readabilipy = "0.2.0" readabilipy = "0.2.0"
redis = { version = "~5.0.3", extras = ["hiredis"] } redis = { version = "~5.0.3", extras = ["hiredis"] }
sentry-sdk = { version = "~1.44.1", extras = ["flask"] } sentry-sdk = { version = "~1.44.1", extras = ["flask"] }
sqlalchemy = "~2.0.29" sqlalchemy = "~2.0.29"
starlette = "0.41.0" starlette = "0.41.0"
tencentcloud-sdk-python-hunyuan = "~3.0.1158"
tencentcloud-sdk-python-hunyuan = "~3.0.1294"
tiktoken = "~0.8.0" tiktoken = "~0.8.0"
tokenizers = "~0.15.0" tokenizers = "~0.15.0"
transformers = "~4.35.0" transformers = "~4.35.0"
volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"} volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"}
websocket-client = "~1.7.0" websocket-client = "~1.7.0"
xinference-client = "0.15.2" xinference-client = "0.15.2"
yarl = "~1.9.4"
yarl = "~1.18.3"
youtube-transcript-api = "~0.6.2" youtube-transcript-api = "~0.6.2"
zhipuai = "~2.1.5" zhipuai = "~2.1.5"
# Before adding new dependency, consider place it in alphabet order (a-z) and suitable group. # Before adding new dependency, consider place it in alphabet order (a-z) and suitable group.
oracledb = "~2.2.1" oracledb = "~2.2.1"
pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] } pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] }
pgvector = "0.2.5" pgvector = "0.2.5"
pymilvus = "~2.4.4"
pymilvus = "~2.5.0"
pymochow = "1.3.1" pymochow = "1.3.1"
pyobvector = "~0.1.6" pyobvector = "~0.1.6"
qdrant-client = "1.7.3" qdrant-client = "1.7.3"

+ 0
- 17
api/schedule/clean_unused_datasets_task.py

else:
plan = plan_cache.decode()
if plan == "sandbox":
# add auto disable log
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
)
db.session.add(dataset_auto_disable_log)
# remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None)

+ 1
- 1
api/services/account_service.py



REFRESH_TOKEN_PREFIX = "refresh_token:"
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
REFRESH_TOKEN_EXPIRY = timedelta(days=30)
REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)




class AccountService:

+ 9
- 4
api/services/app_dsl_service.py

import uuid
from enum import StrEnum
from typing import Optional, cast
from urllib.parse import urlparse
from uuid import uuid4


import yaml # type: ignore
raise ValueError(f"Invalid import_mode: {import_mode}")


# Get YAML content
content: bytes | str = b""
content: str = ""
if mode == ImportMode.YAML_URL:
if not yaml_url:
return Import(
error="yaml_url is required when import_mode is yaml-url",
)
try:
# tricky way to handle url from github to github raw url
if yaml_url.startswith("https://github.com") and yaml_url.endswith((".yml", ".yaml")):
parsed_url = urlparse(yaml_url)
if (
parsed_url.scheme == "https"
and parsed_url.netloc == "github.com"
and parsed_url.path.endswith((".yml", ".yaml"))
):
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com") yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
yaml_url = yaml_url.replace("/blob/", "/") yaml_url = yaml_url.replace("/blob/", "/")
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10)) response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
response.raise_for_status() response.raise_for_status()
content = response.content
content = response.content.decode()


if len(content) > DSL_MAX_SIZE:
return Import(

+ 4
- 1
api/services/app_service.py





class AppService:
def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None:
def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None:
"""
Get app list with pagination Get app list with pagination
:param user_id: user id
:param tenant_id: tenant id
:param args: request args
:return:
elif args["mode"] == "channel":
filters.append(App.mode == AppMode.CHANNEL.value)


if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
if args.get("name"): if args.get("name"):
name = args["name"][:30] name = args["name"][:30]
filters.append(App.name.ilike(f"%{name}%")) filters.append(App.name.ilike(f"%{name}%"))

+ 4
- 4
api/services/billing_service.py

import os
from typing import Optional
from typing import Literal, Optional


import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
params = {"tenant_id": tenant_id}


billing_info = cls._send_request("GET", "/subscription/info", params=params)

return billing_info


@classmethod
retry=retry_if_exception_type(httpx.RequestError),
reraise=True,
)
def _send_request(cls, method, endpoint, json=None, params=None):
def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None):
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}


url = f"{cls.base_url}{endpoint}" url = f"{cls.base_url}{endpoint}"
response = httpx.request(method, url, json=json, params=params, headers=headers) response = httpx.request(method, url, json=json, params=params, headers=headers)

if method == "GET" and response.status_code != httpx.codes.OK:
raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
return response.json()


@staticmethod

+ 26
- 10
api/services/dataset_service.py

else:
return [], 0
else:
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
if user.current_role != TenantAccountRole.OWNER:
# show all datasets that the user has permission to access
if permitted_dataset_ids:
query = query.filter(
if dataset.tenant_id != user.current_tenant_id:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")
if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
if user.current_role != TenantAccountRole.OWNER:
if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.")
if not user:
raise ValueError("User not found")


if user.current_role not in (TenantAccountRole.OWNER, TenantAccountRole.ADMIN):
if user.current_role != TenantAccountRole.OWNER:
if dataset.permission == DatasetPermissionEnum.ONLY_ME:
if dataset.created_by != user.id:
raise NoPermissionError("You do not have permission to access this dataset.")


@staticmethod
def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
return {
"document_ids": [],
"count": 0,
}
# get recent 30 days auto disable logs
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(
dataset.indexing_technique = knowledge_config.indexing_technique
if knowledge_config.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
dataset_embedding_model = knowledge_config.embedding_model
dataset_embedding_model_provider = knowledge_config.embedding_model_provider
else:
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset_embedding_model = embedding_model.model
dataset_embedding_model_provider = embedding_model.provider
dataset.embedding_model = dataset_embedding_model
dataset.embedding_model_provider = dataset_embedding_model_provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
dataset_embedding_model_provider, dataset_embedding_model
)
dataset.collection_binding_id = dataset_collection_binding.id
if not dataset.retrieval_model:
"score_threshold_enabled": False,
}


dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model # type: ignore
dataset.retrieval_model = (
knowledge_config.retrieval_model.model_dump()
if knowledge_config.retrieval_model
else default_retrieval_model
) # type: ignore


documents = []
if knowledge_config.original_document_id:

+ 1
- 1
api/services/workflow_app_service.py

query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id)


if keyword:
keyword_like_val = f"%{args['keyword'][:30]}%"
keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
keyword_conditions = [
WorkflowRun.inputs.ilike(keyword_like_val),
WorkflowRun.outputs.ilike(keyword_like_val),
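The new LIKE pattern escapes the keyword because inputs/outputs are stored as JSON with non-ASCII characters encoded as \uXXXX, and the backslash is doubled so LIKE does not treat it as an escape character. A quick illustration (the sample keyword is arbitrary):

keyword = "你好"
like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
print(like_val)  # %\\u4f60\\u597d% -- matches the escaped form stored in the JSON columns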

+ 4
- 1
api/tasks/deal_dataset_vector_index_task.py



if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "remove":
index_processor.clean(dataset, None, with_keywords=False)
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)


end_at = time.perf_counter()
logging.info(

+ 55
- 0
api/tests/integration_tests/model_runtime/gpustack/test_speech2text.py

import os
from pathlib import Path

import pytest

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.gpustack.speech2text.speech2text import GPUStackSpeech2TextModel


def test_validate_credentials():
model = GPUStackSpeech2TextModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="faster-whisper-medium",
credentials={
"endpoint_url": "invalid_url",
"api_key": "invalid_api_key",
},
)

model.validate_credentials(
model="faster-whisper-medium",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
)


def test_invoke_model():
model = GPUStackSpeech2TextModel()

# Get the directory of the current file
current_dir = os.path.dirname(os.path.abspath(__file__))

# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), "assets")

# Construct the path to the audio file
audio_file_path = os.path.join(assets_dir, "audio.mp3")

file = Path(audio_file_path).read_bytes()

result = model.invoke(
model="faster-whisper-medium",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
file=file,
)

assert isinstance(result, str)
assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"

+ 24
- 0
api/tests/integration_tests/model_runtime/gpustack/test_tts.py

import os

from core.model_runtime.model_providers.gpustack.tts.tts import GPUStackText2SpeechModel


def test_invoke_model():
model = GPUStackText2SpeechModel()

result = model.invoke(
model="cosyvoice-300m-sft",
tenant_id="test",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
content_text="Hello world",
voice="Chinese Female",
)

content = b""
for chunk in result:
content += chunk

assert content != b""

+ 2
- 2
api/tests/integration_tests/vdb/milvus/test_milvus.py

)


def search_by_full_text(self):
# milvus dos not support full text searching yet in < 2.3.x
# milvus support BM25 full text search after version 2.5.0-beta
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0
assert len(hits_by_full_text) >= 0


def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)

+ 3
- 3
docker-legacy/docker-compose.yaml

services:
# API service
api:
image: langgenius/dify-api:0.14.2
image: langgenius/dify-api:0.15.0
restart: always
environment:
# Startup mode, 'api' starts the API server.
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.14.2
image: langgenius/dify-api:0.15.0
restart: always
environment:
CONSOLE_WEB_URL: ''


# Frontend web application.
web:
image: langgenius/dify-web:0.14.2
image: langgenius/dify-web:0.15.0
restart: always
environment:
# The base URL of console application api server, refers to the Console base URL of WEB service if console domain is

+ 15
- 4
docker/.env.example

# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60


# Refresh token expiration time in days
REFRESH_TOKEN_EXPIRE_DAYS=30

# The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
APP_MAX_ACTIVE_REQUESTS=0
APP_MAX_EXECUTION_TIME=1200
# The number of API server workers, i.e., the number of workers.
# Formula: number of cpu cores x 2 + 1 for sync, 1 for Gevent
# Reference: https://docs.gunicorn.org/en/stable/design.html#how-many-workers
SERVER_WORKER_AMOUNT=
SERVER_WORKER_AMOUNT=1


# Defaults to gevent. If using windows, it can be switched to sync or solo.
SERVER_WORKER_CLASS=
SERVER_WORKER_CLASS=gevent

# Default number of worker connections, the default is 10.
SERVER_WORKER_CONNECTIONS=10


# Similar to SERVER_WORKER_CLASS.
# If using windows, it can be switched to sync or solo.
# ------------------------------


# The type of vector store to use.
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`.
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`.
VECTOR_STORE=weaviate


# The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
MILVUS_TOKEN=
MILVUS_USER=root
MILVUS_PASSWORD=Milvus
MILVUS_ENABLE_HYBRID_SEARCH=False


# MyScale configuration, only available when VECTOR_STORE is `myscale`
# For multi-language support, please set MYSCALE_FTS_PARAMS with referring to:
TENCENT_VECTOR_DB_REPLICAS=2


# ElasticSearch configuration, only available when VECTOR_STORE is `elasticsearch`
ELASTICSEARCH_HOST=0.0.0.0
ELASTICSEARCH_HOST=elasticsearch
ELASTICSEARCH_PORT=9200
ELASTICSEARCH_USERNAME=elastic
ELASTICSEARCH_PASSWORD=elastic
# Maximum number of submitted thread count in a ThreadPool for parallel node execution
MAX_SUBMIT_COUNT=100


# The maximum number of top-k value for RAG.
TOP_K_MAX_VALUE=10

# ------------------------------
# Plugin Daemon Configuration
# ------------------------------


MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace-plugin.dify.dev


+ 17
- 26
docker/docker-compose-template.yaml

CSP_WHITELIST: ${CSP_WHITELIST:-} CSP_WHITELIST: ${CSP_WHITELIST:-}
MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev} MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev}
MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace-plugin.dify.dev} MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace-plugin.dify.dev}
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-}


# The postgres database. # The postgres database.
db: db:
volumes: volumes:
- ./volumes/db/data:/var/lib/postgresql/data - ./volumes/db/data:/var/lib/postgresql/data
healthcheck: healthcheck:
test: ['CMD', 'pg_isready']
test: [ 'CMD', 'pg_isready' ]
interval: 1s interval: 1s
timeout: 3s timeout: 3s
retries: 30 retries: 30
# Set the redis password when startup redis server. # Set the redis password when startup redis server.
command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456} command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456}
healthcheck: healthcheck:
test: ['CMD', 'redis-cli', 'ping']
test: [ 'CMD', 'redis-cli', 'ping' ]


# The DifySandbox # The DifySandbox
sandbox: sandbox:
volumes: volumes:
- ./volumes/sandbox/dependencies:/dependencies - ./volumes/sandbox/dependencies:/dependencies
healthcheck: healthcheck:
test: ['CMD', 'curl', '-f', 'http://localhost:8194/health']
test: [ 'CMD', 'curl', '-f', 'http://localhost:8194/health' ]
networks: networks:
- ssrf_proxy_network - ssrf_proxy_network


volumes: volumes:
- ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template - ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template
- ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh - ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh
entrypoint:
[
'sh',
'-c',
"cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
]
entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
environment: environment:
# pls clearly modify the squid env vars to fit your network environment. # pls clearly modify the squid env vars to fit your network environment.
HTTP_PORT: ${SSRF_HTTP_PORT:-3128} HTTP_PORT: ${SSRF_HTTP_PORT:-3128}
- CERTBOT_EMAIL=${CERTBOT_EMAIL} - CERTBOT_EMAIL=${CERTBOT_EMAIL}
- CERTBOT_DOMAIN=${CERTBOT_DOMAIN} - CERTBOT_DOMAIN=${CERTBOT_DOMAIN}
- CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-} - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-}
entrypoint: ['/docker-entrypoint.sh']
command: ['tail', '-f', '/dev/null']
entrypoint: [ '/docker-entrypoint.sh' ]
command: [ 'tail', '-f', '/dev/null' ]


# The nginx reverse proxy. # The nginx reverse proxy.
# used for reverse proxying the API service and Web service. # used for reverse proxying the API service and Web service.
- ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container) - ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container)
- ./volumes/certbot/conf:/etc/letsencrypt - ./volumes/certbot/conf:/etc/letsencrypt
- ./volumes/certbot/www:/var/www/html - ./volumes/certbot/www:/var/www/html
entrypoint:
[
'sh',
'-c',
"cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
]
entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
environment: environment:
NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_} NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_}
NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false} NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false}
working_dir: /opt/couchbase working_dir: /opt/couchbase
stdin_open: true stdin_open: true
tty: true tty: true
entrypoint: [""]
entrypoint: [ "" ]
command: sh -c "/opt/couchbase/init/init-cbserver.sh" command: sh -c "/opt/couchbase/init/init-cbserver.sh"
volumes: volumes:
- ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data - ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data
volumes: volumes:
- ./volumes/pgvector/data:/var/lib/postgresql/data - ./volumes/pgvector/data:/var/lib/postgresql/data
healthcheck: healthcheck:
test: ['CMD', 'pg_isready']
test: [ 'CMD', 'pg_isready' ]
interval: 1s interval: 1s
timeout: 3s timeout: 3s
retries: 30 retries: 30
volumes: volumes:
- ./volumes/pgvecto_rs/data:/var/lib/postgresql/data - ./volumes/pgvecto_rs/data:/var/lib/postgresql/data
healthcheck: healthcheck:
test: ['CMD', 'pg_isready']
test: [ 'CMD', 'pg_isready' ]
interval: 1s interval: 1s
timeout: 3s timeout: 3s
retries: 30 retries: 30
- ./volumes/milvus/etcd:/etcd - ./volumes/milvus/etcd:/etcd
command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
healthcheck: healthcheck:
test: ['CMD', 'etcdctl', 'endpoint', 'health']
test: [ 'CMD', 'etcdctl', 'endpoint', 'health' ]
interval: 30s interval: 30s
timeout: 20s timeout: 20s
retries: 3 retries: 3
- ./volumes/milvus/minio:/minio_data - ./volumes/milvus/minio:/minio_data
command: minio server /minio_data --console-address ":9001" command: minio server /minio_data --console-address ":9001"
healthcheck: healthcheck:
test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live']
test: [ 'CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live' ]
interval: 30s interval: 30s
timeout: 20s timeout: 20s
retries: 3 retries: 3
image: milvusdb/milvus:v2.3.1 image: milvusdb/milvus:v2.3.1
profiles: profiles:
- milvus - milvus
command: ['milvus', 'run', 'standalone']
command: [ 'milvus', 'run', 'standalone' ]
environment: environment:
ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379} ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379}
MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000} MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000}
volumes: volumes:
- ./volumes/milvus/milvus:/var/lib/milvus - ./volumes/milvus/milvus:/var/lib/milvus
healthcheck: healthcheck:
test: ['CMD', 'curl', '-f', 'http://localhost:9091/healthz']
test: [ 'CMD', 'curl', '-f', 'http://localhost:9091/healthz' ]
interval: 30s interval: 30s
start_period: 90s start_period: 90s
timeout: 20s timeout: 20s
ports: ports:
- ${ELASTICSEARCH_PORT:-9200}:9200 - ${ELASTICSEARCH_PORT:-9200}:9200
healthcheck: healthcheck:
test: ['CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty']
test: [ 'CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty' ]
interval: 30s interval: 30s
timeout: 10s timeout: 10s
retries: 50 retries: 50
ports: ports:
- ${KIBANA_PORT:-5601}:5601 - ${KIBANA_PORT:-5601}:5601
healthcheck: healthcheck:
test: ['CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1']
test: [ 'CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1' ]
interval: 30s interval: 30s
timeout: 10s timeout: 10s
retries: 3 retries: 3

+ 33
- 30
docker/docker-compose.yaml

MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true} MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true}
FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300} FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300}
ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60}
REFRESH_TOKEN_EXPIRE_DAYS: ${REFRESH_TOKEN_EXPIRE_DAYS:-30}
APP_MAX_ACTIVE_REQUESTS: ${APP_MAX_ACTIVE_REQUESTS:-0} APP_MAX_ACTIVE_REQUESTS: ${APP_MAX_ACTIVE_REQUESTS:-0}
APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200} APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200}
DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0} DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0}
DIFY_PORT: ${DIFY_PORT:-5001} DIFY_PORT: ${DIFY_PORT:-5001}
SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-}
SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-}
SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-1}
SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-gevent}
SERVER_WORKER_CONNECTIONS: ${SERVER_WORKER_CONNECTIONS:-10}
CELERY_WORKER_CLASS: ${CELERY_WORKER_CLASS:-} CELERY_WORKER_CLASS: ${CELERY_WORKER_CLASS:-}
GUNICORN_TIMEOUT: ${GUNICORN_TIMEOUT:-360} GUNICORN_TIMEOUT: ${GUNICORN_TIMEOUT:-360}
CELERY_WORKER_AMOUNT: ${CELERY_WORKER_AMOUNT:-} CELERY_WORKER_AMOUNT: ${CELERY_WORKER_AMOUNT:-}
MILVUS_TOKEN: ${MILVUS_TOKEN:-} MILVUS_TOKEN: ${MILVUS_TOKEN:-}
MILVUS_USER: ${MILVUS_USER:-root} MILVUS_USER: ${MILVUS_USER:-root}
MILVUS_PASSWORD: ${MILVUS_PASSWORD:-Milvus} MILVUS_PASSWORD: ${MILVUS_PASSWORD:-Milvus}
MILVUS_ENABLE_HYBRID_SEARCH: ${MILVUS_ENABLE_HYBRID_SEARCH:-False}
MYSCALE_HOST: ${MYSCALE_HOST:-myscale} MYSCALE_HOST: ${MYSCALE_HOST:-myscale}
MYSCALE_PORT: ${MYSCALE_PORT:-8123} MYSCALE_PORT: ${MYSCALE_PORT:-8123}
MYSCALE_USER: ${MYSCALE_USER:-default} MYSCALE_USER: ${MYSCALE_USER:-default}
ENDPOINT_URL_TEMPLATE: ${ENDPOINT_URL_TEMPLATE:-http://localhost/e/{hook_id}} ENDPOINT_URL_TEMPLATE: ${ENDPOINT_URL_TEMPLATE:-http://localhost/e/{hook_id}}
MARKETPLACE_ENABLED: ${MARKETPLACE_ENABLED:-true} MARKETPLACE_ENABLED: ${MARKETPLACE_ENABLED:-true}
MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev} MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev}
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-10}


services: services:
# API service # API service
CSP_WHITELIST: ${CSP_WHITELIST:-} CSP_WHITELIST: ${CSP_WHITELIST:-}
MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev} MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace-plugin.dify.dev}
MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace-plugin.dify.dev} MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace-plugin.dify.dev}
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-}


# The postgres database. # The postgres database.
db: db:
volumes: volumes:
- ./volumes/db/data:/var/lib/postgresql/data - ./volumes/db/data:/var/lib/postgresql/data
healthcheck: healthcheck:
test: ['CMD', 'pg_isready']
test: [ 'CMD', 'pg_isready' ]
interval: 1s interval: 1s
timeout: 3s timeout: 3s
retries: 30 retries: 30
# Set the redis password when startup redis server. # Set the redis password when startup redis server.
command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456} command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456}
healthcheck: healthcheck:
test: ['CMD', 'redis-cli', 'ping']
test: [ 'CMD', 'redis-cli', 'ping' ]


# The DifySandbox # The DifySandbox
sandbox: sandbox:
volumes: volumes:
- ./volumes/sandbox/dependencies:/dependencies - ./volumes/sandbox/dependencies:/dependencies
healthcheck: healthcheck:
test: ['CMD', 'curl', '-f', 'http://localhost:8194/health']
test: [ 'CMD', 'curl', '-f', 'http://localhost:8194/health' ]
networks: networks:
- ssrf_proxy_network - ssrf_proxy_network


volumes: volumes:
- ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template - ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template
- ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh - ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh
entrypoint:
[
'sh',
'-c',
"cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
]
entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
environment: environment:
# pls clearly modify the squid env vars to fit your network environment. # pls clearly modify the squid env vars to fit your network environment.
HTTP_PORT: ${SSRF_HTTP_PORT:-3128} HTTP_PORT: ${SSRF_HTTP_PORT:-3128}
- CERTBOT_EMAIL=${CERTBOT_EMAIL} - CERTBOT_EMAIL=${CERTBOT_EMAIL}
- CERTBOT_DOMAIN=${CERTBOT_DOMAIN} - CERTBOT_DOMAIN=${CERTBOT_DOMAIN}
- CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-} - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-}
entrypoint: ['/docker-entrypoint.sh']
command: ['tail', '-f', '/dev/null']
entrypoint: [ '/docker-entrypoint.sh' ]
command: [ 'tail', '-f', '/dev/null' ]


# The nginx reverse proxy. # The nginx reverse proxy.
# used for reverse proxying the API service and Web service. # used for reverse proxying the API service and Web service.
- ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container) - ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container)
- ./volumes/certbot/conf:/etc/letsencrypt - ./volumes/certbot/conf:/etc/letsencrypt
- ./volumes/certbot/www:/var/www/html - ./volumes/certbot/www:/var/www/html
entrypoint:
[
'sh',
'-c',
"cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh",
]
entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ]
environment: environment:
NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_} NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_}
NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false} NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false}
working_dir: /opt/couchbase working_dir: /opt/couchbase
stdin_open: true stdin_open: true
tty: true tty: true
entrypoint: [""]
entrypoint: [ "" ]
command: sh -c "/opt/couchbase/init/init-cbserver.sh" command: sh -c "/opt/couchbase/init/init-cbserver.sh"
volumes: volumes:
- ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data - ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data
volumes: volumes:
- ./volumes/pgvector/data:/var/lib/postgresql/data - ./volumes/pgvector/data:/var/lib/postgresql/data
healthcheck: healthcheck:
test: ['CMD', 'pg_isready']
test: [ 'CMD', 'pg_isready' ]
interval: 1s interval: 1s
timeout: 3s timeout: 3s
retries: 30 retries: 30
volumes: volumes:
- ./volumes/pgvecto_rs/data:/var/lib/postgresql/data - ./volumes/pgvecto_rs/data:/var/lib/postgresql/data
healthcheck: healthcheck:
test: ['CMD', 'pg_isready']
test: [ 'CMD', 'pg_isready' ]
interval: 1s interval: 1s
timeout: 3s timeout: 3s
retries: 30 retries: 30
- ./volumes/milvus/etcd:/etcd - ./volumes/milvus/etcd:/etcd
command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
healthcheck: healthcheck:
test: ['CMD', 'etcdctl', 'endpoint', 'health']
test: [ 'CMD', 'etcdctl', 'endpoint', 'health' ]
interval: 30s interval: 30s
timeout: 20s timeout: 20s
retries: 3 retries: 3
- ./volumes/milvus/minio:/minio_data - ./volumes/milvus/minio:/minio_data
command: minio server /minio_data --console-address ":9001" command: minio server /minio_data --console-address ":9001"
healthcheck: healthcheck:
test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live']
test: [ 'CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live' ]
interval: 30s interval: 30s
timeout: 20s timeout: 20s
retries: 3 retries: 3


milvus-standalone: milvus-standalone:
container_name: milvus-standalone container_name: milvus-standalone
image: milvusdb/milvus:v2.3.1
image: milvusdb/milvus:v2.5.0-beta
profiles: profiles:
- milvus - milvus
command: ['milvus', 'run', 'standalone']
command: [ 'milvus', 'run', 'standalone' ]
environment: environment:
ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379} ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379}
MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000} MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000}
volumes: volumes:
- ./volumes/milvus/milvus:/var/lib/milvus - ./volumes/milvus/milvus:/var/lib/milvus
healthcheck: healthcheck:
test: ['CMD', 'curl', '-f', 'http://localhost:9091/healthz']
test: [ 'CMD', 'curl', '-f', 'http://localhost:9091/healthz' ]
interval: 30s interval: 30s
start_period: 90s start_period: 90s
timeout: 20s timeout: 20s
container_name: elasticsearch container_name: elasticsearch
profiles: profiles:
- elasticsearch - elasticsearch
- elasticsearch-ja
restart: always restart: always
volumes: volumes:
- ./elasticsearch/docker-entrypoint.sh:/docker-entrypoint-mount.sh
- dify_es01_data:/usr/share/elasticsearch/data - dify_es01_data:/usr/share/elasticsearch/data
environment: environment:
ELASTIC_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} ELASTIC_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic}
VECTOR_STORE: ${VECTOR_STORE:-}
cluster.name: dify-es-cluster cluster.name: dify-es-cluster
node.name: dify-es0 node.name: dify-es0
discovery.type: single-node discovery.type: single-node
xpack.license.self_generated.type: trial
xpack.license.self_generated.type: basic
xpack.security.enabled: 'true' xpack.security.enabled: 'true'
xpack.security.enrollment.enabled: 'false' xpack.security.enrollment.enabled: 'false'
xpack.security.http.ssl.enabled: 'false' xpack.security.http.ssl.enabled: 'false'
ports: ports:
- ${ELASTICSEARCH_PORT:-9200}:9200 - ${ELASTICSEARCH_PORT:-9200}:9200
deploy:
resources:
limits:
memory: 2g
entrypoint: [ 'sh', '-c', "sh /docker-entrypoint-mount.sh" ]
healthcheck: healthcheck:
test: ['CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty']
test: [ 'CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty' ]
interval: 30s interval: 30s
timeout: 10s timeout: 10s
retries: 50 retries: 50
ports: ports:
- ${KIBANA_PORT:-5601}:5601 - ${KIBANA_PORT:-5601}:5601
healthcheck: healthcheck:
test: ['CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1']
test: [ 'CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1' ]
interval: 30s interval: 30s
timeout: 10s timeout: 10s
retries: 3 retries: 3

+ 25
- 0
docker/elasticsearch/docker-entrypoint.sh

#!/bin/bash

set -e

if [ "${VECTOR_STORE}" = "elasticsearch-ja" ]; then
# Check if the ICU tokenizer plugin is installed
if ! /usr/share/elasticsearch/bin/elasticsearch-plugin list | grep -q analysis-icu; then
printf '%s\n' "Installing the ICU tokenizer plugin"
if ! /usr/share/elasticsearch/bin/elasticsearch-plugin install analysis-icu; then
printf '%s\n' "Failed to install the ICU tokenizer plugin"
exit 1
fi
fi
# Check if the Japanese language analyzer plugin is installed
if ! /usr/share/elasticsearch/bin/elasticsearch-plugin list | grep -q analysis-kuromoji; then
printf '%s\n' "Installing the Japanese language analyzer plugin"
if ! /usr/share/elasticsearch/bin/elasticsearch-plugin install analysis-kuromoji; then
printf '%s\n' "Failed to install the Japanese language analyzer plugin"
exit 1
fi
fi
fi

# Run the original entrypoint script
exec /bin/tini -- /usr/local/bin/docker-entrypoint.sh
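Read together with the compose changes above, this entrypoint only installs the analysis-icu and analysis-kuromoji plugins when `VECTOR_STORE` is set to `elasticsearch-ja`. A minimal sketch of enabling the Japanese-analyzer variant; the docker/.env location and the exact compose invocation are assumptions based on the files touched in this PR:

```bash
# Sketch only: enable the Japanese-capable Elasticsearch profile.
# VECTOR_STORE is read both by the API and by the entrypoint script above.
echo "VECTOR_STORE=elasticsearch-ja" >> docker/.env
docker compose --profile elasticsearch-ja up -d elasticsearch
```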

+ 3
- 0
web/.env.example View file



# CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP
NEXT_PUBLIC_CSP_WHITELIST=

# The maximum number of top-k value for RAG.
NEXT_PUBLIC_TOP_K_MAX_VALUE=10
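This new variable caps the top-k slider rendered by top-k-item.tsx further down in this PR (it falls back to 10 when unset or not a number). A minimal sketch of overriding it; using `web/.env.local` is the usual Next.js convention and is assumed here:

```bash
# Sketch only: raise the RAG top-k ceiling exposed to the web UI, then rebuild the frontend.
echo "NEXT_PUBLIC_TOP_K_MAX_VALUE=20" >> web/.env.local
```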

+ 11
- 2
web/app/(commonLayout)/apps/Apps.tsx View file

import { useStore as useTagStore } from '@/app/components/base/tag-management/store'
import TagManagementModal from '@/app/components/base/tag-management'
import TagFilter from '@/app/components/base/tag-management/filter'
import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label'


const getKey = (
pageIndex: number,
previousPageData: AppListResponse,
activeTab: string,
isCreatedByMe: boolean,
tags: string[],
keywords: string,
) => {
if (!pageIndex || previousPageData.has_more) {
const params: any = { url: 'apps', params: { page: pageIndex + 1, limit: 30, name: keywords } }
const params: any = { url: 'apps', params: { page: pageIndex + 1, limit: 30, name: keywords, is_created_by_me: isCreatedByMe } }


if (activeTab !== 'all')
params.params.mode = activeTab
defaultTab: 'all',
})
const { query: { tagIDs = [], keywords = '' }, setQuery } = useAppsQueryState()
const [isCreatedByMe, setIsCreatedByMe] = useState(false)
const [tagFilterValue, setTagFilterValue] = useState<string[]>(tagIDs)
const [searchKeywords, setSearchKeywords] = useState(keywords)
const setKeywords = useCallback((keywords: string) => {
}, [setQuery])


const { data, isLoading, setSize, mutate } = useSWRInfinite(
(pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, tagIDs, searchKeywords),
(pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, isCreatedByMe, tagIDs, searchKeywords),
fetchAppList,
{ revalidateFirstPage: true },
)
options={options}
/>
<div className='flex items-center gap-2'>
<CheckboxWithLabel
className='mr-2'
label={t('app.showMyCreatedAppsOnly')}
isChecked={isCreatedByMe}
onChange={() => setIsCreatedByMe(!isCreatedByMe)}
/>
<TagFilter type='app' value={tagFilterValue} onChange={handleTagsChange} />
<Input
showLeftIcon

+ 38
- 1
web/app/(commonLayout)/datasets/template/template.en.mdx View file

- <code>high_quality</code> High quality: embedding using embedding model, built as vector database index
- <code>economy</code> Economy: Build using inverted index of keyword table index
</Property>
<Property name='doc_form' type='string' key='doc_form'>
Format of indexed content
- <code>text_model</code> Text documents are directly embedded; `economy` mode defaults to using this form
- <code>hierarchical_model</code> Parent-child mode
- <code>qa_model</code> Q&A Mode: Generates Q&A pairs for segmented documents and then embeds the questions
</Property>
<Property name='doc_language' type='string' key='doc_language'>
In Q&A mode, specify the language of the document, for example: <code>English</code>, <code>Chinese</code>
</Property>
<Property name='process_rule' type='object' key='process_rule'>
Processing rules
- <code>mode</code> (string) Cleaning, segmentation mode, automatic / custom
- <code>segmentation</code> (object) Segmentation rules
- <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n
- <code>max_tokens</code> Maximum length (token) defaults to 1000
- <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval
- <code>subchunk_segmentation</code> (object) Child chunk rules
- <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code>
- <code>max_tokens</code> The maximum length (tokens) must be validated to be shorter than the length of the parent chunk
- <code>chunk_overlap</code> Define the overlap between adjacent chunks (optional)
</Property>
</Properties>
</Col>
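Since the property list above now covers <code>doc_form</code>, <code>parent_mode</code> and <code>subchunk_segmentation</code>, a request-body sketch may help; the endpoint path, API key variables and the exact nesting under <code>rules</code> are illustrative assumptions, while the field names come from the properties documented here:

```bash
# Sketch only: create a document in parent-child (hierarchical) mode.
# Path, key and nesting are assumptions; field names follow the property list above.
curl -X POST "$DIFY_API_BASE/datasets/<dataset_id>/document/create-by-text" \
  -H "Authorization: Bearer $DATASET_API_KEY" \
  -H "Content-Type: application/json" \
  -d '{
    "name": "manual",
    "text": "Full document text ...",
    "indexing_technique": "high_quality",
    "doc_form": "hierarchical_model",
    "process_rule": {
      "mode": "custom",
      "rules": {
        "segmentation": { "separator": "\n", "max_tokens": 1000 },
        "parent_mode": "paragraph",
        "subchunk_segmentation": { "separator": "***", "max_tokens": 512 }
      }
    }
  }'
```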
- <code>high_quality</code> High quality: embedding using embedding model, built as vector database index
- <code>economy</code> Economy: Build using inverted index of keyword table index


- <code>doc_form</code> Format of indexed content
- <code>text_model</code> Text documents are directly embedded; `economy` mode defaults to using this form
- <code>hierarchical_model</code> Parent-child mode
- <code>qa_model</code> Q&A Mode: Generates Q&A pairs for segmented documents and then embeds the questions

- <code>doc_language</code> In Q&A mode, specify the language of the document, for example: <code>English</code>, <code>Chinese</code>

- <code>process_rule</code> Processing rules
- <code>mode</code> (string) Cleaning, segmentation mode, automatic / custom
- <code>rules</code> (object) Custom rules (in automatic mode, this field is empty)
- <code>segmentation</code> (object) Segmentation rules
- <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n
- <code>max_tokens</code> Maximum length (token) defaults to 1000
- <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval
- <code>subchunk_segmentation</code> (object) Child chunk rules
- <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code>
- <code>max_tokens</code> The maximum length (tokens) must be validated to be shorter than the length of the parent chunk
- <code>chunk_overlap</code> Define the overlap between adjacent chunks (optional)
</Property>
<Property name='file' type='multipart/form-data' key='file'>
Files that need to be uploaded.
- <code>segmentation</code> (object) Segmentation rules
- <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n
- <code>max_tokens</code> Maximum length (token) defaults to 1000
- <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval
- <code>subchunk_segmentation</code> (object) Child chunk rules
- <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code>
- <code>max_tokens</code> The maximum length (tokens) must be validated to be shorter than the length of the parent chunk
- <code>chunk_overlap</code> Define the overlap between adjacent chunks (optional)
</Property>
</Properties>
</Col>
- <code>segmentation</code> (object) Segmentation rules
- <code>separator</code> Custom segment identifier, currently only allows one delimiter to be set. Default is \n
- <code>max_tokens</code> Maximum length (token) defaults to 1000
- <code>parent_mode</code> Retrieval mode of parent chunks: <code>full-doc</code> full text retrieval / <code>paragraph</code> paragraph retrieval
- <code>subchunk_segmentation</code> (object) Child chunk rules
- <code>separator</code> Segmentation identifier. Currently, only one delimiter is allowed. The default is <code>***</code>
- <code>max_tokens</code> The maximum length (tokens) must be validated to be shorter than the length of the parent chunk
- <code>chunk_overlap</code> Define the overlap between adjacent chunks (optional)
</Property>
</Properties>
</Col>
<Heading
url='/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}'
method='POST'
title='Update a Chunk in a Document '
title='Update a Chunk in a Document'
name='#update_segment'
/>
<Row>
- <code>answer</code> (text) Answer content, passed if the knowledge is in Q&A mode (optional)
- <code>keywords</code> (list) Keyword (optional)
- <code>enabled</code> (bool) False / true (optional)
- <code>regenerate_child_chunks</code> (bool) Whether to regenerate child chunks (optional)
</Property>
</Properties>
</Col>
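The heading above gives the update path, so a minimal call sketch follows; the <code>segment</code> wrapper and the base-URL/key variables are assumptions, while <code>content</code>, <code>keywords</code>, <code>enabled</code> and <code>regenerate_child_chunks</code> come from the field list above:

```bash
# Sketch only: update one chunk and ask the API to regenerate its child chunks.
curl -X POST "$DIFY_API_BASE/datasets/<dataset_id>/documents/<document_id>/segments/<segment_id>" \
  -H "Authorization: Bearer $DATASET_API_KEY" \
  -H "Content-Type: application/json" \
  -d '{
    "segment": {
      "content": "Updated chunk text",
      "keywords": ["example"],
      "enabled": true,
      "regenerate_child_chunks": true
    }
  }'
```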

+ 40
- 3
web/app/(commonLayout)/datasets/template/template.zh.mdx View file

- <code>high_quality</code> 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引 - <code>high_quality</code> 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引
- <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建 - <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建
</Property> </Property>
<Property name='doc_form' type='string' key='doc_form'>
索引内容的形式
- <code>text_model</code> text 文档直接 embedding,经济模式默认为该模式
- <code>hierarchical_model</code> parent-child 模式
- <code>qa_model</code> Q&A 模式:为分片文档生成 Q&A 对,然后对问题进行 embedding
</Property>
<Property name='doc_language' type='string' key='doc_language'>
在 Q&A 模式下,指定文档的语言,例如:<code>English</code>、<code>Chinese</code>
</Property>
<Property name='process_rule' type='object' key='process_rule'> <Property name='process_rule' type='object' key='process_rule'>
处理规则 处理规则
- <code>mode</code> (string) 清洗、分段模式 ,automatic 自动 / custom 自定义 - <code>mode</code> (string) 清洗、分段模式 ,automatic 自动 / custom 自定义
- <code>remove_urls_emails</code> 删除 URL、电子邮件地址 - <code>remove_urls_emails</code> 删除 URL、电子邮件地址
- <code>enabled</code> (bool) 是否选中该规则,不传入文档 ID 时代表默认值 - <code>enabled</code> (bool) 是否选中该规则,不传入文档 ID 时代表默认值
- <code>segmentation</code> (object) 分段规则 - <code>segmentation</code> (object) 分段规则
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 <code>\n</code>
- <code>max_tokens</code> 最大长度(token)默认为 1000 - <code>max_tokens</code> 最大长度(token)默认为 1000
- <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回
- <code>subchunk_segmentation</code> (object) 子分段规则
- <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code>
- <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度
- <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填)
</Property> </Property>
</Properties> </Properties>
</Col> </Col>
- <code>high_quality</code> 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引 - <code>high_quality</code> 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引
- <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建 - <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建


- <code>doc_form</code> 索引内容的形式
- <code>text_model</code> text 文档直接 embedding,经济模式默认为该模式
- <code>hierarchical_model</code> parent-child 模式
- <code>qa_model</code> Q&A 模式:为分片文档生成 Q&A 对,然后对问题进行 embedding

- <code>doc_language</code> 在 Q&A 模式下,指定文档的语言,例如:<code>English</code>、<code>Chinese</code>

- <code>process_rule</code> 处理规则 - <code>process_rule</code> 处理规则
- <code>mode</code> (string) 清洗、分段模式 ,automatic 自动 / custom 自定义 - <code>mode</code> (string) 清洗、分段模式 ,automatic 自动 / custom 自定义
- <code>rules</code> (object) 自定义规则(自动模式下,该字段为空) - <code>rules</code> (object) 自定义规则(自动模式下,该字段为空)
- <code>segmentation</code> (object) 分段规则 - <code>segmentation</code> (object) 分段规则
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n
- <code>max_tokens</code> 最大长度(token)默认为 1000 - <code>max_tokens</code> 最大长度(token)默认为 1000
- <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回
- <code>subchunk_segmentation</code> (object) 子分段规则
- <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code>
- <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度
- <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填)
</Property> </Property>
<Property name='file' type='multipart/form-data' key='file'> <Property name='file' type='multipart/form-data' key='file'>
需要上传的文件。 需要上传的文件。
<Heading <Heading
url='/datasets/{dataset_id}/documents/{document_id}/update-by-text' url='/datasets/{dataset_id}/documents/{document_id}/update-by-text'
method='POST' method='POST'
title='通过文本更新文档 '
title='通过文本更新文档'
name='#update-by-text' name='#update-by-text'
/> />
<Row> <Row>
- <code>segmentation</code> (object) 分段规则 - <code>segmentation</code> (object) 分段规则
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n
- <code>max_tokens</code> 最大长度(token)默认为 1000 - <code>max_tokens</code> 最大长度(token)默认为 1000
- <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回
- <code>subchunk_segmentation</code> (object) 子分段规则
- <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code>
- <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度
- <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填)
</Property> </Property>
</Properties> </Properties>
</Col> </Col>
<Heading <Heading
url='/datasets/{dataset_id}/documents/{document_id}/update-by-file' url='/datasets/{dataset_id}/documents/{document_id}/update-by-file'
method='POST' method='POST'
title='通过文件更新文档 '
title='通过文件更新文档'
name='#update-by-file' name='#update-by-file'
/> />
<Row> <Row>
- <code>segmentation</code> (object) 分段规则 - <code>segmentation</code> (object) 分段规则
- <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - <code>separator</code> 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n
- <code>max_tokens</code> 最大长度(token)默认为 1000 - <code>max_tokens</code> 最大长度(token)默认为 1000
- <code>parent_mode</code> 父分段的召回模式 <code>full-doc</code> 全文召回 / <code>paragraph</code> 段落召回
- <code>subchunk_segmentation</code> (object) 子分段规则
- <code>separator</code> 分段标识符,目前仅允许设置一个分隔符。默认为 <code>***</code>
- <code>max_tokens</code> 最大长度 (token) 需要校验小于父级的长度
- <code>chunk_overlap</code> 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填)
</Property> </Property>
</Properties> </Properties>
</Col> </Col>
- <code>answer</code> (text) 答案内容,非必填,如果知识库的模式为 Q&A 模式则传值 - <code>answer</code> (text) 答案内容,非必填,如果知识库的模式为 Q&A 模式则传值
- <code>keywords</code> (list) 关键字,非必填 - <code>keywords</code> (list) 关键字,非必填
- <code>enabled</code> (bool) false/true,非必填 - <code>enabled</code> (bool) false/true,非必填
- <code>regenerate_child_chunks</code> (bool) 是否重新生成子分段,非必填
</Property> </Property>
</Properties> </Properties>
</Col> </Col>

+ 4
- 3
web/app/components/app/configuration/config-prompt/prompt-editor-height-resize-wrap.tsx View file

const [clientY, setClientY] = useState(0)
const [isResizing, setIsResizing] = useState(false)
const [prevUserSelectStyle, setPrevUserSelectStyle] = useState(getComputedStyle(document.body).userSelect)
const [oldHeight, setOldHeight] = useState(height)


const handleStartResize = useCallback((e: React.MouseEvent<HTMLElement>) => {
setClientY(e.clientY)
setIsResizing(true)
setOldHeight(height)
setPrevUserSelectStyle(getComputedStyle(document.body).userSelect)
document.body.style.userSelect = 'none'
}, [])
}, [height])


const handleStopResize = useCallback(() => {
setIsResizing(false)
return


const offset = e.clientY - clientY
let newHeight = height + offset
setClientY(e.clientY)
let newHeight = oldHeight + offset
if (newHeight < minHeight)
newHeight = minHeight
onHeightChange(newHeight)

+ 17
- 1
web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx View file

import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block' import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block'
import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from '@/app/components/base/prompt-editor/plugins/update-block' import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from '@/app/components/base/prompt-editor/plugins/update-block'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import { useFeaturesStore } from '@/app/components/base/features/hooks'


export type ISimplePromptInput = { export type ISimplePromptInput = {
mode: AppType mode: AppType
const { t } = useTranslation() const { t } = useTranslation()
const media = useBreakpoints() const media = useBreakpoints()
const isMobile = media === MediaType.mobile const isMobile = media === MediaType.mobile
const featuresStore = useFeaturesStore()
const {
features,
setFeatures,
} = featuresStore!.getState()


const { eventEmitter } = useEventEmitterContextContext() const { eventEmitter } = useEventEmitterContextContext()
const { const {
}) })
setModelConfig(newModelConfig) setModelConfig(newModelConfig)
setPrevPromptConfig(modelConfig.configs) setPrevPromptConfig(modelConfig.configs)
if (mode !== AppType.completion)

if (mode !== AppType.completion) {
setIntroduction(res.opening_statement) setIntroduction(res.opening_statement)
const newFeatures = produce(features, (draft) => {
draft.opening = {
...draft.opening,
enabled: !!res.opening_statement,
opening_statement: res.opening_statement,
}
})
setFeatures(newFeatures)
}
showAutomaticFalse() showAutomaticFalse()
} }
const minHeight = initEditorHeight || 228 const minHeight = initEditorHeight || 228

+ 49
- 70
web/app/components/app/configuration/dataset-config/params-config/config-content.tsx View file



const { const {
modelList: rerankModelList, modelList: rerankModelList,
defaultModel: rerankDefaultModel,
currentModel: isRerankDefaultModelValid,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)


const { const {
currentModel: currentRerankModel, currentModel: currentRerankModel,
} = useCurrentProviderAndModel( } = useCurrentProviderAndModel(
rerankModelList, rerankModelList,
rerankDefaultModel
? {
...rerankDefaultModel,
provider: rerankDefaultModel.provider.provider,
}
: undefined,
{
provider: datasetConfigs.reranking_model?.reranking_provider_name,
model: datasetConfigs.reranking_model?.reranking_model_name,
},
) )


const rerankModel = (() => {
if (datasetConfigs.reranking_model?.reranking_provider_name) {
return {
provider_name: datasetConfigs.reranking_model.reranking_provider_name,
model_name: datasetConfigs.reranking_model.reranking_model_name,
}
const rerankModel = useMemo(() => {
return {
provider_name: datasetConfigs?.reranking_model?.reranking_provider_name ?? '',
model_name: datasetConfigs?.reranking_model?.reranking_model_name ?? '',
} }
else if (rerankDefaultModel) {
return {
provider_name: rerankDefaultModel.provider.provider,
model_name: rerankDefaultModel.model,
}
}
})()
}, [datasetConfigs.reranking_model])


const handleParamChange = (key: string, value: number) => { const handleParamChange = (key: string, value: number) => {
if (key === 'top_k') { if (key === 'top_k') {
} }


const handleRerankModeChange = (mode: RerankingModeEnum) => { const handleRerankModeChange = (mode: RerankingModeEnum) => {
if (mode === datasetConfigs.reranking_mode)
return

if (mode === RerankingModeEnum.RerankingModel && !currentRerankModel)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })

onChange({ onChange({
...datasetConfigs, ...datasetConfigs,
reranking_mode: mode, reranking_mode: mode,


const canManuallyToggleRerank = useMemo(() => { const canManuallyToggleRerank = useMemo(() => {
return (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic) return (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic)
|| selectedDatasetsMode.allExternal
|| selectedDatasetsMode.allExternal
}, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal]) }, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal])


const showRerankModel = useMemo(() => { const showRerankModel = useMemo(() => {
if (!canManuallyToggleRerank) if (!canManuallyToggleRerank)
return true return true
else if (canManuallyToggleRerank && !isRerankDefaultModelValid)
return false


return datasetConfigs.reranking_enable return datasetConfigs.reranking_enable
}, [canManuallyToggleRerank, datasetConfigs.reranking_enable, isRerankDefaultModelValid])
}, [datasetConfigs.reranking_enable, canManuallyToggleRerank])


const handleDisabledSwitchClick = useCallback(() => {
if (!currentRerankModel && !showRerankModel)
const handleDisabledSwitchClick = useCallback((enable: boolean) => {
if (!currentRerankModel && enable)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
}, [currentRerankModel, showRerankModel, t])

useEffect(() => {
if (canManuallyToggleRerank && showRerankModel !== datasetConfigs.reranking_enable) {
onChange({
...datasetConfigs,
reranking_enable: showRerankModel,
})
}
}, [canManuallyToggleRerank, showRerankModel, datasetConfigs, onChange])
onChange({
...datasetConfigs,
reranking_enable: enable,
})
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [currentRerankModel, datasetConfigs, onChange])


return ( return (
<div> <div>
<div className='flex items-center'> <div className='flex items-center'>
{ {
selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && ( selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && (
<div
className='flex items-center'
onClick={handleDisabledSwitchClick}
>
<Switch
size='md'
defaultValue={showRerankModel}
disabled={!currentRerankModel || !canManuallyToggleRerank}
onChange={(v) => {
if (canManuallyToggleRerank) {
onChange({
...datasetConfigs,
reranking_enable: v,
})
}
}}
/>
</div>
<Switch
size='md'
defaultValue={showRerankModel}
disabled={!canManuallyToggleRerank}
onChange={handleDisabledSwitchClick}
/>
) )
} }
<div className='leading-[32px] ml-1 text-text-secondary system-sm-semibold'>{t('common.modelProvider.rerankModel.key')}</div> <div className='leading-[32px] ml-1 text-text-secondary system-sm-semibold'>{t('common.modelProvider.rerankModel.key')}</div>
triggerClassName='ml-1 w-4 h-4' triggerClassName='ml-1 w-4 h-4'
/> />
</div> </div>
<div>
<ModelSelector
defaultModel={rerankModel && { provider: rerankModel?.provider_name, model: rerankModel?.model_name }}
onSelect={(v) => {
onChange({
...datasetConfigs,
reranking_model: {
reranking_provider_name: v.provider,
reranking_model_name: v.model,
},
})
}}
modelList={rerankModelList}
/>
</div>
{
showRerankModel && (
<div>
<ModelSelector
defaultModel={rerankModel && { provider: rerankModel?.provider_name, model: rerankModel?.model_name }}
onSelect={(v) => {
onChange({
...datasetConfigs,
reranking_model: {
reranking_provider_name: v.provider,
reranking_model_name: v.model,
},
})
}}
modelList={rerankModelList}
/>
</div>
)}
</div> </div>
) )
} }

+ 19
- 18
web/app/components/app/configuration/dataset-config/params-config/index.tsx View file

import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import { RETRIEVE_TYPE } from '@/types/app' import { RETRIEVE_TYPE } from '@/types/app'
import Toast from '@/app/components/base/toast' import Toast from '@/app/components/base/toast'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { RerankingModeEnum } from '@/models/datasets' import { RerankingModeEnum } from '@/models/datasets'
import type { DataSet } from '@/models/datasets' import type { DataSet } from '@/models/datasets'
}, [datasetConfigs]) }, [datasetConfigs])


const { const {
defaultModel: rerankDefaultModel,
currentModel: isRerankDefaultModelValid,
modelList: rerankModelList,
currentModel: rerankDefaultModel,
currentProvider: rerankDefaultProvider, currentProvider: rerankDefaultProvider,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)


const {
currentModel: isCurrentRerankModelValid,
} = useCurrentProviderAndModel(
rerankModelList,
{
provider: tempDataSetConfigs.reranking_model?.reranking_provider_name ?? '',
model: tempDataSetConfigs.reranking_model?.reranking_model_name ?? '',
},
)

const isValid = () => { const isValid = () => {
let errMsg = '' let errMsg = ''
if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) { if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) {
if (tempDataSetConfigs.reranking_enable if (tempDataSetConfigs.reranking_enable
&& tempDataSetConfigs.reranking_mode === RerankingModeEnum.RerankingModel && tempDataSetConfigs.reranking_mode === RerankingModeEnum.RerankingModel
&& !isRerankDefaultModelValid
&& !isCurrentRerankModelValid
) )
errMsg = t('appDebug.datasetConfig.rerankModelRequired') errMsg = t('appDebug.datasetConfig.rerankModelRequired')
} }
const handleSave = () => { const handleSave = () => {
if (!isValid()) if (!isValid())
return return
const config = { ...tempDataSetConfigs }
if (config.retrieval_model === RETRIEVE_TYPE.multiWay
&& config.reranking_mode === RerankingModeEnum.RerankingModel
&& !config.reranking_model) {
config.reranking_model = {
reranking_provider_name: rerankDefaultModel?.provider?.provider,
reranking_model_name: rerankDefaultModel?.model,
} as any
}
setDatasetConfigs(config)
setDatasetConfigs(tempDataSetConfigs)
setRerankSettingModalOpen(false) setRerankSettingModalOpen(false)
} }


reranking_enable: restConfigs.reranking_enable, reranking_enable: restConfigs.reranking_enable,
}, selectedDatasets, selectedDatasets, { }, selectedDatasets, selectedDatasets, {
provider: rerankDefaultProvider?.provider, provider: rerankDefaultProvider?.provider,
model: isRerankDefaultModelValid?.model,
model: rerankDefaultModel?.model,
}) })


setTempDataSetConfigs({ setTempDataSetConfigs({
...retrievalConfig, ...retrievalConfig,
reranking_model: restConfigs.reranking_model && {
reranking_provider_name: restConfigs.reranking_model.reranking_provider_name,
reranking_model_name: restConfigs.reranking_model.reranking_model_name,
reranking_model: {
reranking_provider_name: retrievalConfig.reranking_model?.provider || '',
reranking_model_name: retrievalConfig.reranking_model?.model || '',
}, },
retrieval_model, retrieval_model,
score_threshold_enabled, score_threshold_enabled,

+ 2
- 2
web/app/components/app/configuration/dataset-config/params-config/weighted-score.tsx View file



return (
<div>
<div className='px-3 pt-5 h-[52px] space-x-3 rounded-lg border border-components-panel-border'>
<div className='px-3 pt-5 pb-2 space-x-3 rounded-lg border border-components-panel-border'>
<Slider
className={cn('grow h-0.5 !bg-util-colors-teal-teal-500 rounded-full')}
max={1.0}
onChange={v => onChange({ value: [v, (10 - v * 10) / 10] })}
trackClassName='weightedScoreSliderTrack'
/>
<div className='flex justify-between mt-1'>
<div className='flex justify-between mt-3'>
<div className='shrink-0 flex items-center w-[90px] system-xs-semibold-uppercase text-util-colors-blue-light-blue-light-500'>
<div className='mr-1 truncate uppercase' title={t('dataset.weightedScore.semantic') || ''}>
{t('dataset.weightedScore.semantic')}

+ 5
- 15
web/app/components/app/configuration/dataset-config/settings-modal/index.tsx View file

import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import Input from '@/app/components/base/input' import Input from '@/app/components/base/input'
import Textarea from '@/app/components/base/textarea' import Textarea from '@/app/components/base/textarea'
import { type DataSet, RerankingModeEnum } from '@/models/datasets'
import { type DataSet } from '@/models/datasets'
import { useToastContext } from '@/app/components/base/toast' import { useToastContext } from '@/app/components/base/toast'
import { updateDatasetSetting } from '@/service/datasets' import { updateDatasetSetting } from '@/service/datasets'
import { useAppContext } from '@/context/app-context' import { useAppContext } from '@/context/app-context'
import RetrievalSettings from '@/app/components/datasets/external-knowledge-base/create/RetrievalSettings' import RetrievalSettings from '@/app/components/datasets/external-knowledge-base/create/RetrievalSettings'
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config' import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
import PermissionSelector from '@/app/components/datasets/settings/permission-selector' import PermissionSelector from '@/app/components/datasets/settings/permission-selector'
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
} }
if ( if (
!isReRankModelSelected({ !isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelValid: !!isRerankDefaultModelValid,
rerankModelList, rerankModelList,
retrievalConfig, retrievalConfig,
indexMethod, indexMethod,
notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') }) notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') })
return return
} }
const postRetrievalConfig = ensureRerankModelSelected({
rerankDefaultModel: rerankDefaultModel!,
retrievalConfig: {
...retrievalConfig,
reranking_enable: retrievalConfig.reranking_mode === RerankingModeEnum.RerankingModel,
},
indexMethod,
})
try { try {
setLoading(true) setLoading(true)
const { id, name, description, permission } = localeCurrentDataset const { id, name, description, permission } = localeCurrentDataset
permission, permission,
indexing_technique: indexMethod, indexing_technique: indexMethod,
retrieval_model: { retrieval_model: {
...postRetrievalConfig,
score_threshold: postRetrievalConfig.score_threshold_enabled ? postRetrievalConfig.score_threshold : 0,
...retrievalConfig,
score_threshold: retrievalConfig.score_threshold_enabled ? retrievalConfig.score_threshold : 0,
}, },
embedding_model: localeCurrentDataset.embedding_model, embedding_model: localeCurrentDataset.embedding_model,
embedding_model_provider: localeCurrentDataset.embedding_model_provider, embedding_model_provider: localeCurrentDataset.embedding_model_provider,
onSave({ onSave({
...localeCurrentDataset, ...localeCurrentDataset,
indexing_technique: indexMethod, indexing_technique: indexMethod,
retrieval_model_dict: postRetrievalConfig,
retrieval_model_dict: retrievalConfig,
}) })
} }
catch (e) { catch (e) {

+ 3
- 3
web/app/components/app/configuration/index.tsx View file



setDatasetConfigs({
...retrievalConfig,
reranking_model: restConfigs.reranking_model && {
reranking_provider_name: restConfigs.reranking_model.reranking_provider_name,
reranking_model_name: restConfigs.reranking_model.reranking_model_name,
reranking_model: {
reranking_provider_name: retrievalConfig?.reranking_model?.provider || '',
reranking_model_name: retrievalConfig?.reranking_model?.model || '',
},
retrieval_model,
score_threshold_enabled,

+ 9
- 1
web/app/components/base/chat/chat/chat-input-area/index.tsx View file

inputs?: Record<string, any>
inputsForm?: InputForm[]
theme?: Theme | null
isResponding?: boolean
}
const ChatInputArea = ({
showFeatureBar,
inputs = {},
inputsForm = [],
theme,
isResponding,
}: ChatInputAreaProps) => {
const { t } = useTranslation()
const { notify } = useToastContext()
const historyRef = useRef([''])
const [currentIndex, setCurrentIndex] = useState(-1)
const handleSend = () => {
if (isResponding) {
notify({ type: 'info', message: t('appDebug.errorMessage.waitForResponse') })
return
}

if (onSend) {
const { files, setFiles } = filesStore.getState()
if (files.find(item => item.transferMethod === TransferMethod.local_file && !item.uploadedId)) {
setQuery(historyRef.current[currentIndex + 1])
}
else if (currentIndex === historyRef.current.length - 1) {
// If it is the last element, clear the input box
// If it is the last element, clear the input box
setCurrentIndex(historyRef.current.length)
setQuery('')
}
'p-1 w-full leading-6 body-lg-regular text-text-tertiary outline-none',
)}
placeholder={t('common.chat.inputPlaceholder') || ''}
autoFocus
autoSize={{ minRows: 1 }}
onResize={handleTextareaResize}
value={query}

+ 1
- 0
web/app/components/base/chat/chat/index.tsx View file

inputs={inputs}
inputsForm={inputsForm}
theme={themeBuilder?.theme}
isResponding={isResponding}
/>
)
}

+ 2
- 2
web/app/components/base/chat/chat/question.tsx View file

} = item


return (
<div className='flex justify-end mb-2 last:mb-0 pl-10'>
<div className='group relative mr-4'>
<div className='flex justify-end mb-2 last:mb-0 pl-14'>
<div className='group relative mr-4 max-w-full'>
<div
className='px-4 py-3 bg-[#D1E9FF]/50 rounded-2xl text-sm text-gray-900'
style={theme?.chatBubbleColorStyle ? CssTransform(theme.chatBubbleColorStyle) : {}}

+ 2
- 2
web/app/components/base/markdown.tsx View file

}
else if (language === 'echarts') {
return (
<div style={{ minHeight: '350px', minWidth: '700px' }}>
<div style={{ minHeight: '350px', minWidth: '100%', overflowX: 'scroll' }}>
<ErrorBoundary>
<ReactEcharts option={chartData} />
<ReactEcharts option={chartData} style={{ minWidth: '700px' }} />
</ErrorBoundary>
</div>
)

+ 7
- 1
web/app/components/base/param-item/top-k-item.tsx View file

enable: boolean
}


const maxTopK = (() => {
const configValue = parseInt(globalThis.document?.body?.getAttribute('data-public-top-k-max-value') || '', 10)
if (configValue && !isNaN(configValue))
return configValue
return 10
})()
const VALUE_LIMIT = {
default: 2,
step: 1,
min: 1,
max: 10,
max: maxTopK,
}


const key = 'top_k'

+ 8
- 7
web/app/components/datasets/common/check-rerank-model.ts View file

import { RerankingModeEnum } from '@/models/datasets' import { RerankingModeEnum } from '@/models/datasets'


export const isReRankModelSelected = ({ export const isReRankModelSelected = ({
rerankDefaultModel,
isRerankDefaultModelValid,
retrievalConfig, retrievalConfig,
rerankModelList, rerankModelList,
indexMethod, indexMethod,
}: { }: {
rerankDefaultModel?: DefaultModelResponse
isRerankDefaultModelValid: boolean
retrievalConfig: RetrievalConfig retrievalConfig: RetrievalConfig
rerankModelList: Model[] rerankModelList: Model[]
indexMethod?: string indexMethod?: string
return provider?.models.find(({ model }) => model === retrievalConfig.reranking_model?.reranking_model_name) return provider?.models.find(({ model }) => model === retrievalConfig.reranking_model?.reranking_model_name)
} }


if (isRerankDefaultModelValid)
return !!rerankDefaultModel

return false return false
})() })()


if (
indexMethod === 'high_quality'
&& ([RETRIEVE_METHOD.semantic, RETRIEVE_METHOD.fullText].includes(retrievalConfig.search_method))
&& retrievalConfig.reranking_enable
&& !rerankModelSelected
)
return false

if ( if (
indexMethod === 'high_quality' indexMethod === 'high_quality'
&& (retrievalConfig.search_method === RETRIEVE_METHOD.hybrid && retrievalConfig.reranking_mode !== RerankingModeEnum.WeightedScore) && (retrievalConfig.search_method === RETRIEVE_METHOD.hybrid && retrievalConfig.reranking_mode !== RerankingModeEnum.WeightedScore)

+ 4
- 1
web/app/components/datasets/common/economical-retrieval-method-config/index.tsx View file

import type { RetrievalConfig } from '@/types/app'


type Props = {
disabled?: boolean
value: RetrievalConfig
onChange: (value: RetrievalConfig) => void
}


const EconomicalRetrievalMethodConfig: FC<Props> = ({
disabled = false,
value,
onChange,
}) => {


return (
<div className='space-y-2'>
<OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />}
<OptionCard
disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />}
title={t('dataset.retrieval.invertedIndex.title')}
description={t('dataset.retrieval.invertedIndex.description')} isActive
activeHeaderClassName='bg-dataset-option-card-purple-gradient'

+ 71
- 45
web/app/components/datasets/common/retrieval-method-config/index.tsx View file

'use client' 'use client'
import type { FC } from 'react' import type { FC } from 'react'
import React from 'react'
import React, { useCallback } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import Image from 'next/image' import Image from 'next/image'
import RetrievalParamConfig from '../retrieval-param-config' import RetrievalParamConfig from '../retrieval-param-config'
import type { RetrievalConfig } from '@/types/app' import type { RetrievalConfig } from '@/types/app'
import { RETRIEVE_METHOD } from '@/types/app' import { RETRIEVE_METHOD } from '@/types/app'
import { useProviderContext } from '@/context/provider-context' import { useProviderContext } from '@/context/provider-context'
import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { import {
DEFAULT_WEIGHTED_SCORE, DEFAULT_WEIGHTED_SCORE,
import Badge from '@/app/components/base/badge' import Badge from '@/app/components/base/badge'


type Props = { type Props = {
disabled?: boolean
value: RetrievalConfig value: RetrievalConfig
onChange: (value: RetrievalConfig) => void onChange: (value: RetrievalConfig) => void
} }


const RetrievalMethodConfig: FC<Props> = ({ const RetrievalMethodConfig: FC<Props> = ({
value: passValue,
disabled = false,
value,
onChange, onChange,
}) => { }) => {
const { t } = useTranslation() const { t } = useTranslation()
const { supportRetrievalMethods } = useProviderContext() const { supportRetrievalMethods } = useProviderContext()
const { data: rerankDefaultModel } = useDefaultModel(ModelTypeEnum.rerank)
const value = (() => {
if (!passValue.reranking_model.reranking_model_name) {
return {
...passValue,
reranking_model: {
reranking_provider_name: rerankDefaultModel?.provider.provider || '',
reranking_model_name: rerankDefaultModel?.model || '',
},
reranking_mode: passValue.reranking_mode || (rerankDefaultModel ? RerankingModeEnum.RerankingModel : RerankingModeEnum.WeightedScore),
weights: passValue.weights || {
weight_type: WeightedScoreEnum.Customized,
vector_setting: {
vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic,
embedding_provider_name: '',
embedding_model_name: '',
},
keyword_setting: {
keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword,
},
},
}
const {
defaultModel: rerankDefaultModel,
currentModel: isRerankDefaultModelValid,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const onSwitch = useCallback((retrieveMethod: RETRIEVE_METHOD) => {
if ([RETRIEVE_METHOD.semantic, RETRIEVE_METHOD.fullText].includes(retrieveMethod)) {
onChange({
...value,
search_method: retrieveMethod,
...(!value.reranking_model.reranking_model_name
? {
reranking_model: {
reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '',
reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '',
},
reranking_enable: !!isRerankDefaultModelValid,
}
: {
reranking_enable: true,
}),
})
} }
return passValue
})()
if (retrieveMethod === RETRIEVE_METHOD.hybrid) {
onChange({
...value,
search_method: retrieveMethod,
...(!value.reranking_model.reranking_model_name
? {
reranking_model: {
reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '',
reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '',
},
reranking_enable: !!isRerankDefaultModelValid,
reranking_mode: isRerankDefaultModelValid ? RerankingModeEnum.RerankingModel : RerankingModeEnum.WeightedScore,
}
: {
reranking_enable: true,
reranking_mode: RerankingModeEnum.RerankingModel,
}),
...(!value.weights
? {
weights: {
weight_type: WeightedScoreEnum.Customized,
vector_setting: {
vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic,
embedding_provider_name: '',
embedding_model_name: '',
},
keyword_setting: {
keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword,
},
},
}
: {}),
})
}
}, [value, rerankDefaultModel, isRerankDefaultModelValid, onChange])

return ( return (
<div className='space-y-2'> <div className='space-y-2'>
{supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && ( {supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && (
<OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />}
<OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.vector} alt='' />}
title={t('dataset.retrieval.semantic_search.title')} title={t('dataset.retrieval.semantic_search.title')}
description={t('dataset.retrieval.semantic_search.description')} description={t('dataset.retrieval.semantic_search.description')}
isActive={ isActive={
value.search_method === RETRIEVE_METHOD.semantic value.search_method === RETRIEVE_METHOD.semantic
} }
onSwitched={() => onChange({
...value,
search_method: RETRIEVE_METHOD.semantic,
})}
onSwitched={() => onSwitch(RETRIEVE_METHOD.semantic)}
effectImg={Effect.src} effectImg={Effect.src}
activeHeaderClassName='bg-dataset-option-card-purple-gradient' activeHeaderClassName='bg-dataset-option-card-purple-gradient'
> >
/> />
</OptionCard> </OptionCard>
)} )}
{supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && (
<OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.fullText} alt='' />}
{supportRetrievalMethods.includes(RETRIEVE_METHOD.fullText) && (
<OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.fullText} alt='' />}
title={t('dataset.retrieval.full_text_search.title')} title={t('dataset.retrieval.full_text_search.title')}
description={t('dataset.retrieval.full_text_search.description')} description={t('dataset.retrieval.full_text_search.description')}
isActive={ isActive={
value.search_method === RETRIEVE_METHOD.fullText value.search_method === RETRIEVE_METHOD.fullText
} }
onSwitched={() => onChange({
...value,
search_method: RETRIEVE_METHOD.fullText,
})}
onSwitched={() => onSwitch(RETRIEVE_METHOD.fullText)}
effectImg={Effect.src} effectImg={Effect.src}
activeHeaderClassName='bg-dataset-option-card-purple-gradient' activeHeaderClassName='bg-dataset-option-card-purple-gradient'
> >
/> />
</OptionCard> </OptionCard>
)} )}
{supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && (
<OptionCard icon={<Image className='w-4 h-4' src={retrievalIcon.hybrid} alt='' />}
{supportRetrievalMethods.includes(RETRIEVE_METHOD.hybrid) && (
<OptionCard disabled={disabled} icon={<Image className='w-4 h-4' src={retrievalIcon.hybrid} alt='' />}
title={ title={
<div className='flex items-center space-x-1'> <div className='flex items-center space-x-1'>
<div>{t('dataset.retrieval.hybrid_search.title')}</div> <div>{t('dataset.retrieval.hybrid_search.title')}</div>
description={t('dataset.retrieval.hybrid_search.description')} isActive={ description={t('dataset.retrieval.hybrid_search.description')} isActive={
value.search_method === RETRIEVE_METHOD.hybrid value.search_method === RETRIEVE_METHOD.hybrid
} }
onSwitched={() => onChange({
...value,
search_method: RETRIEVE_METHOD.hybrid,
reranking_enable: true,
})}
onSwitched={() => onSwitch(RETRIEVE_METHOD.hybrid)}
effectImg={Effect.src} effectImg={Effect.src}
activeHeaderClassName='bg-dataset-option-card-purple-gradient' activeHeaderClassName='bg-dataset-option-card-purple-gradient'
> >

+ 43
- 59
web/app/components/datasets/common/retrieval-param-config/index.tsx View file

'use client' 'use client'
import type { FC } from 'react' import type { FC } from 'react'
import React, { useCallback } from 'react'
import React, { useCallback, useMemo } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'


import Image from 'next/image' import Image from 'next/image'
const { t } = useTranslation() const { t } = useTranslation()
const canToggleRerankModalEnable = type !== RETRIEVE_METHOD.hybrid const canToggleRerankModalEnable = type !== RETRIEVE_METHOD.hybrid
const isEconomical = type === RETRIEVE_METHOD.invertedIndex const isEconomical = type === RETRIEVE_METHOD.invertedIndex
const isHybridSearch = type === RETRIEVE_METHOD.hybrid
const { const {
defaultModel: rerankDefaultModel,
modelList: rerankModelList, modelList: rerankModelList,
} = useModelListAndDefaultModel(ModelTypeEnum.rerank) } = useModelListAndDefaultModel(ModelTypeEnum.rerank)


currentModel, currentModel,
} = useCurrentProviderAndModel( } = useCurrentProviderAndModel(
rerankModelList, rerankModelList,
rerankDefaultModel
? {
...rerankDefaultModel,
provider: rerankDefaultModel.provider.provider,
}
: undefined,
{
provider: value.reranking_model?.reranking_provider_name ?? '',
model: value.reranking_model?.reranking_model_name ?? '',
},
) )


const handleDisabledSwitchClick = useCallback(() => {
if (!currentModel)
const handleDisabledSwitchClick = useCallback((enable: boolean) => {
if (enable && !currentModel)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
}, [currentModel, rerankDefaultModel, t])

const isHybridSearch = type === RETRIEVE_METHOD.hybrid
onChange({
...value,
reranking_enable: enable,
})
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [currentModel, onChange, value])


const rerankModel = (() => {
if (value.reranking_model) {
return {
provider_name: value.reranking_model.reranking_provider_name,
model_name: value.reranking_model.reranking_model_name,
}
}
else if (rerankDefaultModel) {
return {
provider_name: rerankDefaultModel.provider.provider,
model_name: rerankDefaultModel.model,
}
const rerankModel = useMemo(() => {
return {
provider_name: value.reranking_model.reranking_provider_name,
model_name: value.reranking_model.reranking_model_name,
} }
})()
}, [value.reranking_model])


const handleChangeRerankMode = (v: RerankingModeEnum) => { const handleChangeRerankMode = (v: RerankingModeEnum) => {
if (v === value.reranking_mode) if (v === value.reranking_mode)
}, },
} }
} }
if (v === RerankingModeEnum.RerankingModel && !currentModel)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
onChange(result) onChange(result)
} }


<div> <div>
<div className='flex items-center space-x-2 mb-2'> <div className='flex items-center space-x-2 mb-2'>
{canToggleRerankModalEnable && ( {canToggleRerankModalEnable && (
<div
className='flex items-center'
onClick={handleDisabledSwitchClick}
>
<Switch
size='md'
defaultValue={currentModel ? value.reranking_enable : false}
onChange={(v) => {
onChange({
...value,
reranking_enable: v,
})
}}
disabled={!currentModel}
/>
</div>
<Switch
size='md'
defaultValue={value.reranking_enable}
onChange={handleDisabledSwitchClick}
/>
)} )}
<div className='flex items-center'> <div className='flex items-center'>
<span className='mr-0.5 system-sm-semibold text-text-secondary'>{t('common.modelProvider.rerankModel.key')}</span> <span className='mr-0.5 system-sm-semibold text-text-secondary'>{t('common.modelProvider.rerankModel.key')}</span>
/> />
</div> </div>
</div> </div>
<ModelSelector
triggerClassName={`${!value.reranking_enable && '!opacity-60 !cursor-not-allowed'}`}
defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }}
modelList={rerankModelList}
readonly={!value.reranking_enable}
onSelect={(v) => {
onChange({
...value,
reranking_model: {
reranking_provider_name: v.provider,
reranking_model_name: v.model,
},
})
}}
/>
{
value.reranking_enable && (
<ModelSelector
defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }}
modelList={rerankModelList}
onSelect={(v) => {
onChange({
...value,
reranking_model: {
reranking_provider_name: v.provider,
reranking_model_name: v.model,
},
})
}}
/>
)
}
</div> </div>
)} )}
{ {
{ {
value.reranking_mode !== RerankingModeEnum.WeightedScore && ( value.reranking_mode !== RerankingModeEnum.WeightedScore && (
<ModelSelector <ModelSelector
triggerClassName={`${!value.reranking_enable && '!opacity-60 !cursor-not-allowed'}`}
defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }} defaultModel={rerankModel && { provider: rerankModel.provider_name, model: rerankModel.model_name }}
modelList={rerankModelList} modelList={rerankModelList}
readonly={!value.reranking_enable}
onSelect={(v) => { onSelect={(v) => {
onChange({ onChange({
...value, ...value,

+ 3
- 0
web/app/components/datasets/create/embedding-process/index.tsx View file

import { sleep } from '@/utils'
import { RETRIEVE_METHOD } from '@/types/app'
import Tooltip from '@/app/components/base/tooltip'
import { useInvalidDocumentList } from '@/service/knowledge/use-document'

type Props = {
datasetId: string
})

const router = useRouter()
const invalidDocumentList = useInvalidDocumentList()
const navToDocumentList = () => {
invalidDocumentList()
router.push(`/datasets/${datasetId}/documents`)
}
const navToApiDocs = () => {
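The change above invalidates the cached document list before navigating to it, so the documents page refetches fresh data once indexing has started. A hedged sketch of what a hook like useInvalidDocumentList could look like, assuming the list is fetched with TanStack Query; the 'document-list' query key is an assumed placeholder, not necessarily the key used in '@/service/knowledge/use-document':

```ts
import { useQueryClient } from '@tanstack/react-query'

// Returns a function the caller can invoke right before navigating,
// marking the cached document list as stale so it refetches on mount.
export const useInvalidDocumentList = () => {
  const queryClient = useQueryClient()
  return () => queryClient.invalidateQueries({ queryKey: ['document-list'] })
}
```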

web/app/components/datasets/create/step-two/index.tsx (+52, -60)

import { DelimiterInput, MaxLengthInput, OverlapInput } from './inputs'
import cn from '@/utils/classnames'
import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, DocumentItem, FullDocumentDetail, ParentMode, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets'
import { ChunkingMode, DataSourceType, ProcessMode } from '@/models/datasets'

import Button from '@/app/components/base/button'
import FloatRightContainer from '@/app/components/base/float-right-container'
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
import { type RetrievalConfig } from '@/types/app'
import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
import Toast from '@/app/components/base/toast'
import type { NotionPage } from '@/models/common'
import { DataSourceProvider } from '@/models/common'
import { ChunkingMode, DataSourceType, RerankingModeEnum } from '@/models/datasets'
import { useDatasetDetailContext } from '@/context/dataset-detail'
import I18n from '@/context/i18n'
import { RETRIEVE_METHOD } from '@/types/app'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import Checkbox from '@/app/components/base/checkbox'
import RadioCard from '@/app/components/base/radio-card'
import { IS_CE_EDITION } from '@/config'
import { FULL_DOC_PREVIEW_LENGTH, IS_CE_EDITION } from '@/config'
import Divider from '@/app/components/base/divider'
import { getNotionInfo, getWebsiteInfo, useCreateDocument, useCreateFirstDocument, useFetchDefaultProcessRule, useFetchFileIndexingEstimateForFile, useFetchFileIndexingEstimateForNotion, useFetchFileIndexingEstimateForWeb } from '@/service/knowledge/use-create-dataset'
import Badge from '@/app/components/base/badge'
onCancel?: () => void
}

export enum SegmentType {
AUTO = 'automatic',
CUSTOM = 'custom',
}
export enum IndexingType {
QUALIFIED = 'high_quality',
ECONOMICAL = 'economy',
}

const DEFAULT_SEGMENT_IDENTIFIER = '\\n\\n'
const DEFAULT_MAXMIMUM_CHUNK_LENGTH = 500
const DEFAULT_MAXIMUM_CHUNK_LENGTH = 500
const DEFAULT_OVERLAP = 50


type ParentChildConfig = {
isSetting,
documentDetail,
isAPIKeySet,
onSetting,
datasetId,
indexingType,
dataSourceType: inCreatePageDataSourceType,

const isInCreatePage = !datasetId || (datasetId && !currentDataset?.data_source_type)
const dataSourceType = isInCreatePage ? inCreatePageDataSourceType : currentDataset?.data_source_type
const [segmentationType, setSegmentationType] = useState<SegmentType>(SegmentType.CUSTOM)
const [segmentationType, setSegmentationType] = useState<ProcessMode>(ProcessMode.general)
const [segmentIdentifier, doSetSegmentIdentifier] = useState(DEFAULT_SEGMENT_IDENTIFIER)
const setSegmentIdentifier = useCallback((value: string, canEmpty?: boolean) => {
doSetSegmentIdentifier(value ? escape(value) : (canEmpty ? '' : DEFAULT_SEGMENT_IDENTIFIER))
}, [])
const [maxChunkLength, setMaxChunkLength] = useState(DEFAULT_MAXMIMUM_CHUNK_LENGTH) // default chunk length
const [maxChunkLength, setMaxChunkLength] = useState(DEFAULT_MAXIMUM_CHUNK_LENGTH) // default chunk length
const [limitMaxChunkLength, setLimitMaxChunkLength] = useState(4000)
const [overlap, setOverlap] = useState(DEFAULT_OVERLAP)
const [rules, setRules] = useState<PreProcessingRule[]>([])
)


// QA Related
const [isLanguageSelectDisabled, _setIsLanguageSelectDisabled] = useState(false)
const [isQAConfirmDialogOpen, setIsQAConfirmDialogOpen] = useState(false)
const [docForm, setDocForm] = useState<ChunkingMode>(
(datasetId && documentDetail) ? documentDetail.doc_form as ChunkingMode : ChunkingMode.text,
}


const updatePreview = () => {
if (segmentationType === SegmentType.CUSTOM && maxChunkLength > 4000) {
if (segmentationType === ProcessMode.general && maxChunkLength > 4000) {
Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck') })
return
}
model: defaultEmbeddingModel?.model || '',
},
)
const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict || {
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: false,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
top_k: 3,
score_threshold_enabled: false,
score_threshold: 0.5,
} as RetrievalConfig)

useEffect(() => {
if (currentDataset?.retrieval_model_dict)
return
setRetrievalConfig({
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: !!isRerankDefaultModelValid,
reranking_model: {
reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider.provider ?? '' : '',
reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '',
},
top_k: 3,
score_threshold_enabled: false,
score_threshold: 0.5,
})
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [rerankDefaultModel, isRerankDefaultModelValid])

const getCreationParams = () => {
let params
if (segmentationType === SegmentType.CUSTOM && overlap > maxChunkLength) {
if (segmentationType === ProcessMode.general && overlap > maxChunkLength) {
Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.overlapCheck') })
return
}
if (segmentationType === SegmentType.CUSTOM && maxChunkLength > limitMaxChunkLength) {
if (segmentationType === ProcessMode.general && maxChunkLength > limitMaxChunkLength) {
Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck', { limit: limitMaxChunkLength }) })
return
}
doc_form: currentDocForm,
doc_language: docLanguage,
process_rule: getProcessRule(),
// eslint-disable-next-line @typescript-eslint/no-use-before-define
retrieval_model: retrievalConfig, // Readonly. If want to changed, just go to settings page.
embedding_model: embeddingModel.model, // Readonly
embedding_model_provider: embeddingModel.provider, // Readonly
const indexMethod = getIndexing_technique()
if (
!isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelValid: !!isRerankDefaultModelValid,
rerankModelList,
// eslint-disable-next-line @typescript-eslint/no-use-before-define
retrievalConfig,
indexMethod: indexMethod as string,
})
Toast.notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') })
return
}
const postRetrievalConfig = ensureRerankModelSelected({
rerankDefaultModel: rerankDefaultModel!,
retrievalConfig: {
// eslint-disable-next-line @typescript-eslint/no-use-before-define
...retrievalConfig,
// eslint-disable-next-line @typescript-eslint/no-use-before-define
reranking_enable: retrievalConfig.reranking_mode === RerankingModeEnum.RerankingModel,
},
indexMethod: indexMethod as string,
})
params = {
data_source: {
type: dataSourceType,
process_rule: getProcessRule(),
doc_form: currentDocForm,
doc_language: docLanguage,

retrieval_model: postRetrievalConfig,
retrieval_model: retrievalConfig,
embedding_model: embeddingModel.model,
embedding_model_provider: embeddingModel.provider,
} as CreateDocumentReq
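In this hunk the creation flow stops rewriting the retrieval config through ensureRerankModelSelected and only validates it with isReRankModelSelected before passing it through unchanged. A condensed, hedged sketch of that flow; the stub check below only approximates the real helper in check-rerank-model:

```ts
// Hedged sketch of the simplified flow in getCreationParams after this diff.
type RetrievalConfig = {
  reranking_enable: boolean
  reranking_model: { reranking_provider_name: string; reranking_model_name: string }
}

// Rough stand-in for isReRankModelSelected: economy mode and disabled reranking
// always pass; otherwise a provider and model must both be set.
const isReRankModelSelectedStub = (cfg: RetrievalConfig, indexMethod: string): boolean => {
  if (indexMethod !== 'high_quality' || !cfg.reranking_enable)
    return true
  return Boolean(cfg.reranking_model.reranking_provider_name && cfg.reranking_model.reranking_model_name)
}

export const buildRetrievalModel = (cfg: RetrievalConfig, indexMethod: string): RetrievalConfig | undefined => {
  if (!isReRankModelSelectedStub(cfg, indexMethod))
    return undefined // the component raises a Toast error and aborts here
  return cfg // sent as retrieval_model without the former ensureRerankModelSelected rewrite
}
```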


const getDefaultMode = () => {
if (documentDetail)
// @ts-expect-error fix after api refactored
setSegmentationType(documentDetail.dataset_process_rule.mode)
}

onSuccess(data) {
updateIndexingTypeCache && updateIndexingTypeCache(indexType as string)
updateResultCache && updateResultCache(data)
// eslint-disable-next-line @typescript-eslint/no-use-before-define
updateRetrievalMethodCache && updateRetrievalMethodCache(retrievalConfig.search_method as string)
},
},
isSetting && onSave && onSave()
}


const changeToEconomicalType = () => {
if (docForm !== ChunkingMode.text)
return

if (!hasSetIndexType)
setIndexType(IndexingType.ECONOMICAL)
}

useEffect(() => {
// fetch rules
if (!isSetting) {
setIndexType(isAPIKeySet ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL)
}, [isAPIKeySet, indexingType, datasetId])


const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict || {
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: false,
reranking_model: {
reranking_provider_name: rerankDefaultModel?.provider.provider,
reranking_model_name: rerankDefaultModel?.model,
},
top_k: 3,
score_threshold_enabled: false,
score_threshold: 0.5,
} as RetrievalConfig)

const economyDomRef = useRef<HTMLDivElement>(null)
const isHoveringEconomy = useHover(economyDomRef)


<div className={cn('system-md-semibold mb-1', datasetId && 'flex justify-between items-center')}>{t('datasetSettings.form.embeddingModel')}</div>
<ModelSelector
readonly={!!datasetId}
triggerClassName={datasetId ? 'opacity-50' : ''}
defaultModel={embeddingModel}
modelList={embeddingModelList}
onSelect={(model: DefaultModel) => {
getIndexing_technique() === IndexingType.QUALIFIED
? (
<RetrievalMethodConfig
disabled={!!datasetId}
value={retrievalConfig}
onChange={setRetrievalConfig}
/>
)
: (
<EconomicalRetrievalMethodConfig
disabled={!!datasetId}
value={retrievalConfig}
onChange={setRetrievalConfig}
/>
)
: (
<div className='flex items-center mt-8 py-2'>
<Button loading={isCreating} variant='primary' onClick={createHandle}>{t('datasetCreation.stepTwo.save')}</Button>
{!datasetId && <Button loading={isCreating} variant='primary' onClick={createHandle}>{t('datasetCreation.stepTwo.save')}</Button>}
<Button className='ml-2' onClick={onCancel}>{t('datasetCreation.stepTwo.cancel')}</Button>
</div>
)}
}
{
currentDocForm !== ChunkingMode.qa
&& <Badge text={t(
'datasetCreation.stepTwo.previewChunkCount', {
count: estimate?.total_segments || 0,
}) as string}
/>
}
</div>
</PreviewHeader>}
{currentDocForm === ChunkingMode.parentChild && currentEstimateMutation.data?.preview && (
estimate?.preview?.map((item, index) => {
const indexForLabel = index + 1
const childChunks = parentChildConfig.chunkForContext === 'full-doc'
? item.child_chunks.slice(0, FULL_DOC_PREVIEW_LENGTH)
: item.child_chunks
return (
<ChunkContainer
key={item.content}
characterCount={item.content.length}
>
<FormattedText>
{item.child_chunks.map((child, index) => {
{childChunks.map((child, index) => {
const indexForLabel = index + 1
return (
<PreviewSlice
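The parent/child preview above now caps the rendered child chunks in 'full-doc' mode. A small sketch of that capping, with FULL_DOC_PREVIEW_LENGTH set to a placeholder value (the real constant comes from '@/config') and a simplified chunk shape:

```ts
// Placeholder value; the actual limit is exported from '@/config'.
const FULL_DOC_PREVIEW_LENGTH = 50

type ParentChunk = { content: string; child_chunks: string[] }

// In 'full-doc' mode only the first FULL_DOC_PREVIEW_LENGTH child chunks are previewed;
// otherwise every child chunk is shown.
export const childChunksForPreview = (item: ParentChunk, chunkForContext: 'full-doc' | 'paragraph'): string[] =>
  chunkForContext === 'full-doc'
    ? item.child_chunks.slice(0, FULL_DOC_PREVIEW_LENGTH)
    : item.child_chunks
```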

web/app/components/datasets/create/step-two/option-card.tsx (+2, -2)



const TriangleArrow: FC<ComponentProps<'svg'>> = props => (
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="11" viewBox="0 0 24 11" fill="none" {...props}>
<path d="M9.87868 1.12132C11.0503 -0.0502525 12.9497 -0.0502525 14.1213 1.12132L23.3137 10.3137H0.686292L9.87868 1.12132Z" fill="currentColor"/>
<path d="M9.87868 1.12132C11.0503 -0.0502525 12.9497 -0.0502525 14.1213 1.12132L23.3137 10.3137H0.686292L9.87868 1.12132Z" fill="currentColor" />
</svg>
)

(isActive && !noHighlight)
? 'border-[1.5px] border-components-option-card-option-selected-border'
: 'border border-components-option-card-option-border',
disabled && 'opacity-50 cursor-not-allowed',
disabled && 'opacity-50 pointer-events-none',
className,
)}
style={{

web/app/components/datasets/documents/detail/completed/index.tsx (+0, -0)


Some files were not shown because too many files changed in this diff.
