瀏覽代碼

Merge branch main into feat/rag-2

tags/2.0.0-beta.1
twwu 3 月之前
父節點
當前提交
bae2af0c85
共有 100 個文件被更改,包括 1011 次插入和 528 次刪除
  1. 27
    0
      .github/workflows/autofix.yml
  2. 10
    0
      api/.env.example
  3. 6
    1
      api/README.md
  4. 18
    18
      api/commands.py
  5. 36
    0
      api/configs/feature/__init__.py
  6. 5
    5
      api/controllers/console/admin.py
  7. 4
    4
      api/controllers/console/apikey.py
  8. 5
    5
      api/controllers/console/app/conversation.py
  9. 4
    4
      api/controllers/console/app/mcp_server.py
  10. 7
    7
      api/controllers/console/app/message.py
  11. 1
    1
      api/controllers/console/app/model_config.py
  12. 2
    2
      api/controllers/console/app/site.py
  13. 1
    1
      api/controllers/console/app/wraps.py
  14. 6
    0
      api/controllers/console/auth/error.py
  15. 2
    2
      api/controllers/console/datasets/data_source.py
  16. 8
    8
      api/controllers/console/datasets/datasets.py
  17. 10
    10
      api/controllers/console/datasets/datasets_document.py
  18. 14
    14
      api/controllers/console/datasets/datasets_segments.py
  19. 5
    5
      api/controllers/console/explore/installed_app.py
  20. 1
    1
      api/controllers/console/explore/wraps.py
  21. 1
    1
      api/controllers/console/workspace/__init__.py
  22. 13
    5
      api/controllers/console/workspace/account.py
  23. 1
    1
      api/controllers/console/workspace/members.py
  24. 114
    1
      api/controllers/console/workspace/plugin.py
  25. 6
    2
      api/controllers/console/workspace/tool_providers.py
  26. 3
    3
      api/controllers/inner_api/plugin/wraps.py
  27. 1
    1
      api/controllers/inner_api/wraps.py
  28. 2
    2
      api/controllers/mcp/mcp.py
  29. 6
    0
      api/controllers/service_api/app/completion.py
  30. 1
    1
      api/controllers/service_api/app/site.py
  31. 5
    1
      api/controllers/service_api/app/workflow.py
  32. 10
    10
      api/controllers/service_api/dataset/document.py
  33. 9
    9
      api/controllers/service_api/dataset/segment.py
  34. 2
    2
      api/controllers/service_api/dataset/upload_file.py
  35. 14
    14
      api/controllers/service_api/wraps.py
  36. 15
    14
      api/controllers/web/passport.py
  37. 1
    1
      api/controllers/web/site.py
  38. 4
    3
      api/controllers/web/wraps.py
  39. 3
    3
      api/core/agent/base_agent_runner.py
  40. 0
    48
      api/core/app/apps/README.md
  41. 51
    15
      api/core/app/apps/advanced_chat/app_generator.py
  42. 84
    33
      api/core/app/apps/advanced_chat/app_runner.py
  43. 4
    0
      api/core/app/apps/advanced_chat/generate_task_pipeline.py
  44. 3
    3
      api/core/app/apps/agent_chat/app_runner.py
  45. 1
    1
      api/core/app/apps/chat/app_runner.py
  46. 1
    1
      api/core/app/apps/completion/app_generator.py
  47. 1
    1
      api/core/app/apps/completion/app_runner.py
  48. 4
    10
      api/core/app/apps/message_based_app_generator.py
  49. 43
    11
      api/core/app/apps/workflow/app_generator.py
  50. 21
    40
      api/core/app/apps/workflow/app_runner.py
  51. 3
    0
      api/core/app/apps/workflow/generate_task_pipeline.py
  52. 12
    27
      api/core/app/apps/workflow_app_runner.py
  53. 1
    1
      api/core/app/features/annotation_reply/annotation_reply.py
  54. 1
    1
      api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
  55. 2
    2
      api/core/app/task_pipeline/message_cycle_manager.py
  56. 5
    5
      api/core/callback_handler/index_tool_callback_handler.py
  57. 7
    7
      api/core/entities/provider_configuration.py
  58. 2
    2
      api/core/external_data_tool/api/api.py
  59. 1
    1
      api/core/helper/encrypter.py
  60. 20
    0
      api/core/helper/marketplace.py
  61. 42
    0
      api/core/helper/trace_id_helper.py
  62. 9
    10
      api/core/indexing_runner.py
  63. 2
    1
      api/core/llm_generator/llm_generator.py
  64. 0
    1
      api/core/llm_generator/output_parser/suggested_questions_after_answer.py
  65. 2
    2
      api/core/mcp/server/streamable_http.py
  66. 1
    1
      api/core/memory/token_buffer_memory.py
  67. 17
    0
      api/core/model_runtime/entities/message_entities.py
  68. 1
    1
      api/core/moderation/api/api.py
  69. 5
    4
      api/core/ops/aliyun_trace/aliyun_trace.py
  70. 4
    3
      api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
  71. 2
    2
      api/core/ops/base_trace_instance.py
  72. 4
    3
      api/core/ops/langfuse_trace/langfuse_trace.py
  73. 3
    2
      api/core/ops/langsmith_trace/langsmith_trace.py
  74. 4
    3
      api/core/ops/opik_trace/opik_trace.py
  75. 12
    8
      api/core/ops/ops_trace_manager.py
  76. 3
    1
      api/core/ops/utils.py
  77. 3
    2
      api/core/ops/weave_trace/weave_trace.py
  78. 3
    3
      api/core/plugin/backwards_invocation/app.py
  79. 1
    0
      api/core/plugin/entities/plugin_daemon.py
  80. 35
    0
      api/core/plugin/impl/oauth.py
  81. 3
    3
      api/core/provider_manager.py
  82. 3
    3
      api/core/rag/datasource/keyword/jieba/jieba.py
  83. 6
    6
      api/core/rag/datasource/retrieval_service.py
  84. 2
    2
      api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
  85. 1
    1
      api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
  86. 93
    33
      api/core/rag/datasource/vdb/tablestore/tablestore_vector.py
  87. 2
    1
      api/core/rag/datasource/vdb/tencent/tencent_vector.py
  88. 3
    3
      api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
  89. 1
    1
      api/core/rag/datasource/vdb/vector_factory.py
  90. 4
    4
      api/core/rag/docstore/dataset_docstore.py
  91. 4
    3
      api/core/rag/extractor/notion_extractor.py
  92. 3
    3
      api/core/rag/index_processor/processor/parent_child_index_processor.py
  93. 14
    14
      api/core/rag/retrieval/dataset_retrieval.py
  94. 1
    1
      api/core/tools/custom_tool/provider.py
  95. 4
    4
      api/core/tools/tool_file_manager.py
  96. 3
    3
      api/core/tools/tool_label_manager.py
  97. 56
    16
      api/core/tools/tool_manager.py
  98. 3
    3
      api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py
  99. 2
    6
      api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py
  100. 0
    0
      api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py

+ 27
- 0
.github/workflows/autofix.yml 查看文件

@@ -0,0 +1,27 @@
name: autofix.ci
on:
  workflow_call:
  pull_request:
  push:
    branches: [ "main" ]
permissions:
  contents: read

jobs:
  autofix:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      # Use uv to ensure we have the same ruff version in CI and locally.
      - uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f
      - run: |
          cd api
          uv sync --dev
          # Fix lint errors
          uv run ruff check --fix-only .
          # Format code
          uv run ruff format .

      - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27


+ 10
- 0
api/.env.example 查看文件

@@ -471,6 +471,16 @@ APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1

# Celery schedule tasks configuration
ENABLE_CLEAN_EMBEDDING_CACHE_TASK=false
ENABLE_CLEAN_UNUSED_DATASETS_TASK=false
ENABLE_CREATE_TIDB_SERVERLESS_TASK=false
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false
ENABLE_CLEAN_MESSAGES=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true

# Position configuration
POSITION_TOOL_PINS=
POSITION_TOOL_INCLUDES=

+ 6
- 1
api/README.md 查看文件

@@ -74,7 +74,12 @@
10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.

```bash
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin
```

In addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal:
```bash
uv run celery -A app.celery beat
```

## Testing

+ 18
- 18
api/commands.py 查看文件

@@ -51,7 +51,7 @@ def reset_password(email, new_password, password_confirm):
click.echo(click.style("Passwords do not match.", fg="red"))
return

account = db.session.query(Account).filter(Account.email == email).one_or_none()
account = db.session.query(Account).where(Account.email == email).one_or_none()

if not account:
click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
@@ -90,7 +90,7 @@ def reset_email(email, new_email, email_confirm):
click.echo(click.style("New emails do not match.", fg="red"))
return

account = db.session.query(Account).filter(Account.email == email).one_or_none()
account = db.session.query(Account).where(Account.email == email).one_or_none()

if not account:
click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
@@ -137,8 +137,8 @@ def reset_encrypt_key_pair():

tenant.encrypt_public_key = generate_key_pair(tenant.id)

db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete()
db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
db.session.commit()

click.echo(
@@ -173,7 +173,7 @@ def migrate_annotation_vector_database():
per_page = 50
apps = (
db.session.query(App)
.filter(App.status == "normal")
.where(App.status == "normal")
.order_by(App.created_at.desc())
.limit(per_page)
.offset((page - 1) * per_page)
@@ -193,7 +193,7 @@ def migrate_annotation_vector_database():
try:
click.echo("Creating app annotation index: {}".format(app.id))
app_annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first()
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
)

if not app_annotation_setting:
@@ -203,13 +203,13 @@ def migrate_annotation_vector_database():
# get dataset_collection_binding info
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first()
)
if not dataset_collection_binding:
click.echo("App annotation collection binding not found: {}".format(app.id))
continue
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all()
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
@@ -306,7 +306,7 @@ def migrate_knowledge_vector_database():
while True:
try:
stmt = (
select(Dataset).filter(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
)

datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
@@ -333,7 +333,7 @@ def migrate_knowledge_vector_database():
if dataset.collection_binding_id:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
if dataset_collection_binding:
@@ -368,7 +368,7 @@ def migrate_knowledge_vector_database():

dataset_documents = (
db.session.query(DatasetDocument)
.filter(
.where(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
@@ -382,7 +382,7 @@ def migrate_knowledge_vector_database():
for dataset_document in dataset_documents:
segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
@@ -469,7 +469,7 @@ def convert_to_agent_apps():
app_id = str(i.id)
if app_id not in proceeded_app_ids:
proceeded_app_ids.append(app_id)
app = db.session.query(App).filter(App.id == app_id).first()
app = db.session.query(App).where(App.id == app_id).first()
if app is not None:
apps.append(app)

@@ -484,7 +484,7 @@ def convert_to_agent_apps():
db.session.commit()

# update conversation mode to agent
db.session.query(Conversation).filter(Conversation.app_id == app.id).update(
db.session.query(Conversation).where(Conversation.app_id == app.id).update(
{Conversation.mode: AppMode.AGENT_CHAT.value}
)

@@ -561,7 +561,7 @@ def old_metadata_migration():
try:
stmt = (
select(DatasetDocument)
.filter(DatasetDocument.doc_metadata.is_not(None))
.where(DatasetDocument.doc_metadata.is_not(None))
.order_by(DatasetDocument.created_at.desc())
)
documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
@@ -579,7 +579,7 @@ def old_metadata_migration():
else:
dataset_metadata = (
db.session.query(DatasetMetadata)
.filter(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key)
.where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key)
.first()
)
if not dataset_metadata:
@@ -603,7 +603,7 @@ def old_metadata_migration():
else:
dataset_metadata_binding = (
db.session.query(DatasetMetadataBinding) # type: ignore
.filter(
.where(
DatasetMetadataBinding.dataset_id == document.dataset_id,
DatasetMetadataBinding.document_id == document.id,
DatasetMetadataBinding.metadata_id == dataset_metadata.id,
@@ -718,7 +718,7 @@ where sites.id is null limit 1000"""
continue

try:
app = db.session.query(App).filter(App.id == app_id).first()
app = db.session.query(App).where(App.id == app_id).first()
if not app:
print(f"App {app_id} not found")
continue

+ 36
- 0
api/configs/feature/__init__.py 查看文件

@@ -832,6 +832,41 @@ class CeleryBeatConfig(BaseSettings):
)


class CeleryScheduleTasksConfig(BaseSettings):
ENABLE_CLEAN_EMBEDDING_CACHE_TASK: bool = Field(
description="Enable clean embedding cache task",
default=False,
)
ENABLE_CLEAN_UNUSED_DATASETS_TASK: bool = Field(
description="Enable clean unused datasets task",
default=False,
)
ENABLE_CREATE_TIDB_SERVERLESS_TASK: bool = Field(
description="Enable create tidb service job task",
default=False,
)
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK: bool = Field(
description="Enable update tidb service job status task",
default=False,
)
ENABLE_CLEAN_MESSAGES: bool = Field(
description="Enable clean messages task",
default=False,
)
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
description="Enable mail clean document notify task",
default=False,
)
ENABLE_DATASETS_QUEUE_MONITOR: bool = Field(
description="Enable queue monitor task",
default=False,
)
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: bool = Field(
description="Enable check upgradable plugin task",
default=True,
)


class PositionConfig(BaseSettings):
POSITION_PROVIDER_PINS: str = Field(
description="Comma-separated list of pinned model providers",
@@ -961,5 +996,6 @@ class FeatureConfig(
# hosted services config
HostedServiceConfig,
CeleryBeatConfig,
CeleryScheduleTasksConfig,
):
pass

+ 5
- 5
api/controllers/console/admin.py 查看文件

@@ -56,7 +56,7 @@ class InsertExploreAppListApi(Resource):
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
args = parser.parse_args()

app = db.session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
if not app:
raise NotFound(f"App '{args['app_id']}' is not found")

@@ -74,7 +74,7 @@ class InsertExploreAppListApi(Resource):

with Session(db.engine) as session:
recommended_app = session.execute(
select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"])
select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"])
).scalar_one_or_none()

if not recommended_app:
@@ -117,21 +117,21 @@ class InsertExploreAppApi(Resource):
def delete(self, app_id):
with Session(db.engine) as session:
recommended_app = session.execute(
select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id))
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
).scalar_one_or_none()

if not recommended_app:
return {"result": "success"}, 204

with Session(db.engine) as session:
app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none()
app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()

if app:
app.is_public = False

with Session(db.engine) as session:
installed_apps = session.execute(
select(InstalledApp).filter(
select(InstalledApp).where(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
)

+ 4
- 4
api/controllers/console/apikey.py 查看文件

@@ -61,7 +61,7 @@ class BaseApiKeyListResource(Resource):
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
keys = (
db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.all()
)
return {"items": keys}
@@ -76,7 +76,7 @@ class BaseApiKeyListResource(Resource):

current_key_count = (
db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.count()
)

@@ -117,7 +117,7 @@ class BaseApiKeyResource(Resource):

key = (
db.session.query(ApiToken)
.filter(
.where(
getattr(ApiToken, self.resource_id_field) == resource_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
@@ -128,7 +128,7 @@ class BaseApiKeyResource(Resource):
if key is None:
flask_restful.abort(404, message="API key not found")

db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()

return {"result": "success"}, 204

+ 5
- 5
api/controllers/console/app/conversation.py 查看文件

@@ -49,7 +49,7 @@ class CompletionConversationApi(Resource):
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion")

if args["keyword"]:
query = query.join(Message, Message.conversation_id == Conversation.id).filter(
query = query.join(Message, Message.conversation_id == Conversation.id).where(
or_(
Message.query.ilike("%{}%".format(args["keyword"])),
Message.answer.ilike("%{}%".format(args["keyword"])),
@@ -121,7 +121,7 @@ class CompletionConversationDetailApi(Resource):

conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)

@@ -181,7 +181,7 @@ class ChatConversationApi(Resource):
Message.conversation_id == Conversation.id,
)
.join(subquery, subquery.c.conversation_id == Conversation.id)
.filter(
.where(
or_(
Message.query.ilike(keyword_filter),
Message.answer.ilike(keyword_filter),
@@ -286,7 +286,7 @@ class ChatConversationDetailApi(Resource):

conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)

@@ -308,7 +308,7 @@ api.add_resource(ChatConversationDetailApi, "/apps/<uuid:app_id>/chat-conversati
def _get_conversation(app_model, conversation_id):
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)


+ 4
- 4
api/controllers/console/app/mcp_server.py 查看文件

@@ -26,7 +26,7 @@ class AppMCPServerController(Resource):
@get_app_model
@marshal_with(app_server_fields)
def get(self, app_model):
server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == app_model.id).first()
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
return server

@setup_required
@@ -73,7 +73,7 @@ class AppMCPServerController(Resource):
parser.add_argument("parameters", type=dict, required=True, location="json")
parser.add_argument("status", type=str, required=False, location="json")
args = parser.parse_args()
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
if not server:
raise NotFound()

@@ -104,8 +104,8 @@ class AppMCPServerRefreshController(Resource):
raise NotFound()
server = (
db.session.query(AppMCPServer)
.filter(AppMCPServer.id == server_id)
.filter(AppMCPServer.tenant_id == current_user.current_tenant_id)
.where(AppMCPServer.id == server_id)
.where(AppMCPServer.tenant_id == current_user.current_tenant_id)
.first()
)
if not server:

+ 7
- 7
api/controllers/console/app/message.py 查看文件

@@ -56,7 +56,7 @@ class ChatMessageListApi(Resource):

conversation = (
db.session.query(Conversation)
.filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
.where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
.first()
)

@@ -66,7 +66,7 @@ class ChatMessageListApi(Resource):
if args["first_id"]:
first_message = (
db.session.query(Message)
.filter(Message.conversation_id == conversation.id, Message.id == args["first_id"])
.where(Message.conversation_id == conversation.id, Message.id == args["first_id"])
.first()
)

@@ -75,7 +75,7 @@ class ChatMessageListApi(Resource):

history_messages = (
db.session.query(Message)
.filter(
.where(
Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at,
Message.id != first_message.id,
@@ -87,7 +87,7 @@ class ChatMessageListApi(Resource):
else:
history_messages = (
db.session.query(Message)
.filter(Message.conversation_id == conversation.id)
.where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(args["limit"])
.all()
@@ -98,7 +98,7 @@ class ChatMessageListApi(Resource):
current_page_first_message = history_messages[-1]
rest_count = (
db.session.query(Message)
.filter(
.where(
Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id,
@@ -167,7 +167,7 @@ class MessageAnnotationCountApi(Resource):
@account_initialization_required
@get_app_model
def get(self, app_model):
count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count()
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()

return {"count": count}

@@ -214,7 +214,7 @@ class MessageApi(Resource):
def get(self, app_model, message_id):
message_id = str(message_id)

message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()

if not message:
raise NotFound("Message Not Exists.")

+ 1
- 1
api/controllers/console/app/model_config.py 查看文件

@@ -42,7 +42,7 @@ class ModelConfigResource(Resource):
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
# get original app model config
original_app_model_config = (
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first()
)
if original_app_model_config is None:
raise ValueError("Original app model config not found")

+ 2
- 2
api/controllers/console/app/site.py 查看文件

@@ -49,7 +49,7 @@ class AppSite(Resource):
if not current_user.is_editor:
raise Forbidden()

site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise NotFound

@@ -93,7 +93,7 @@ class AppSiteAccessTokenReset(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()

site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()

if not site:
raise NotFound

+ 1
- 1
api/controllers/console/app/wraps.py 查看文件

@@ -11,7 +11,7 @@ from models import App, AppMode
def _load_app_model(app_id: str) -> Optional[App]:
app_model = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
return app_model

+ 6
- 0
api/controllers/console/auth/error.py 查看文件

@@ -113,3 +113,9 @@ class MemberNotInTenantError(BaseHTTPException):
error_code = "member_not_in_tenant"
description = "The member is not in the workspace."
code = 400


class AccountInFreezeError(BaseHTTPException):
error_code = "account_in_freeze"
description = "This email is temporarily unavailable."
code = 400

+ 2
- 2
api/controllers/console/datasets/data_source.py 查看文件

@@ -30,7 +30,7 @@ class DataSourceApi(Resource):
# get workspace data source integrates
data_source_integrates = (
db.session.query(DataSourceOauthBinding)
.filter(
.where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.disabled == False,
)
@@ -171,7 +171,7 @@ class DataSourceNotionApi(Resource):
page_id = str(page_id)
with Session(db.engine) as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter(
select(DataSourceOauthBinding).where(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",

+ 8
- 8
api/controllers/console/datasets/datasets.py 查看文件

@@ -421,7 +421,7 @@ class DatasetIndexingEstimateApi(Resource):
file_ids = args["info_list"]["file_info_list"]["file_ids"]
file_details = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
.all()
)

@@ -526,14 +526,14 @@ class DatasetIndexingStatusApi(Resource):
dataset_id = str(dataset_id)
documents = (
db.session.query(Document)
.filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
.where(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
.all()
)
documents_status = []
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
@@ -542,7 +542,7 @@ class DatasetIndexingStatusApi(Resource):
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
# Create a dictionary with document attributes and additional fields
@@ -577,7 +577,7 @@ class DatasetApiKeyApi(Resource):
def get(self):
keys = (
db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.all()
)
return {"items": keys}
@@ -593,7 +593,7 @@ class DatasetApiKeyApi(Resource):

current_key_count = (
db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.count()
)

@@ -629,7 +629,7 @@ class DatasetApiDeleteApi(Resource):

key = (
db.session.query(ApiToken)
.filter(
.where(
ApiToken.tenant_id == current_user.current_tenant_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
@@ -640,7 +640,7 @@ class DatasetApiDeleteApi(Resource):
if key is None:
flask_restful.abort(404, message="API key not found")

db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()

return {"result": "success"}, 204

+ 10
- 10
api/controllers/console/datasets/datasets_document.py 查看文件

@@ -126,7 +126,7 @@ class GetProcessRuleApi(Resource):
# get the latest process rule
dataset_process_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.dataset_id == document.dataset_id)
.where(DatasetProcessRule.dataset_id == document.dataset_id)
.order_by(DatasetProcessRule.created_at.desc())
.limit(1)
.one_or_none()
@@ -178,7 +178,7 @@ class DatasetDocumentListApi(Resource):

if search:
search = f"%{search}%"
query = query.filter(Document.name.like(search))
query = query.where(Document.name.like(search))

if sort.startswith("-"):
sort_logic = desc
@@ -214,7 +214,7 @@ class DatasetDocumentListApi(Resource):
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
@@ -223,7 +223,7 @@ class DatasetDocumentListApi(Resource):
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
@@ -419,7 +419,7 @@ class DocumentIndexingEstimateApi(DocumentResource):

file = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
.where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
.first()
)

@@ -494,7 +494,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id)
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id)
.first()
)

@@ -570,7 +570,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
@@ -579,7 +579,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
# Create a dictionary with document attributes and additional fields
@@ -613,7 +613,7 @@ class DocumentIndexingStatusApi(DocumentResource):

completed_segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != "re_segment",
@@ -622,7 +622,7 @@ class DocumentIndexingStatusApi(DocumentResource):
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
.count()
)


+ 14
- 14
api/controllers/console/datasets/datasets_segments.py 查看文件

@@ -78,7 +78,7 @@ class DatasetDocumentSegmentListApi(Resource):

query = (
select(DocumentSegment)
.filter(
.where(
DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id,
)
@@ -86,19 +86,19 @@ class DatasetDocumentSegmentListApi(Resource):
)

if status_list:
query = query.filter(DocumentSegment.status.in_(status_list))
query = query.where(DocumentSegment.status.in_(status_list))

if hit_count_gte is not None:
query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
query = query.where(DocumentSegment.hit_count >= hit_count_gte)

if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))

if args["enabled"].lower() != "all":
if args["enabled"].lower() == "true":
query = query.filter(DocumentSegment.enabled == True)
query = query.where(DocumentSegment.enabled == True)
elif args["enabled"].lower() == "false":
query = query.filter(DocumentSegment.enabled == False)
query = query.where(DocumentSegment.enabled == False)

segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

@@ -285,7 +285,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@@ -331,7 +331,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@@ -436,7 +436,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@@ -493,7 +493,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@@ -540,7 +540,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@@ -586,7 +586,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@@ -595,7 +595,7 @@ class ChildChunkUpdateApi(Resource):
child_chunk_id = str(child_chunk_id)
child_chunk = (
db.session.query(ChildChunk)
.filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first()
)
if not child_chunk:
@@ -635,7 +635,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@@ -644,7 +644,7 @@ class ChildChunkUpdateApi(Resource):
child_chunk_id = str(child_chunk_id)
child_chunk = (
db.session.query(ChildChunk)
.filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first()
)
if not child_chunk:

+ 5
- 5
api/controllers/console/explore/installed_app.py 查看文件

@@ -34,11 +34,11 @@ class InstalledAppsListApi(Resource):
if app_id:
installed_apps = (
db.session.query(InstalledApp)
.filter(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
.where(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
.all()
)
else:
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all()

current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
installed_app_list: list[dict[str, Any]] = [
@@ -94,12 +94,12 @@ class InstalledAppsListApi(Resource):
parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
args = parser.parse_args()

recommended_app = db.session.query(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]).first()
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
if recommended_app is None:
raise NotFound("App not found")

current_tenant_id = current_user.current_tenant_id
app = db.session.query(App).filter(App.id == args["app_id"]).first()
app = db.session.query(App).where(App.id == args["app_id"]).first()

if app is None:
raise NotFound("App not found")
@@ -109,7 +109,7 @@ class InstalledAppsListApi(Resource):

installed_app = (
db.session.query(InstalledApp)
.filter(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
.where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
.first()
)


+ 1
- 1
api/controllers/console/explore/wraps.py 查看文件

@@ -28,7 +28,7 @@ def installed_app_required(view=None):

installed_app = (
db.session.query(InstalledApp)
.filter(
.where(
InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id
)
.first()

+ 1
- 1
api/controllers/console/workspace/__init__.py 查看文件

@@ -21,7 +21,7 @@ def plugin_permission_required(
with Session(db.engine) as session:
permission = (
session.query(TenantPluginPermission)
.filter(
.where(
TenantPluginPermission.tenant_id == tenant_id,
)
.first()

+ 13
- 5
api/controllers/console/workspace/account.py 查看文件

@@ -9,6 +9,7 @@ from configs import dify_config
from constants.languages import supported_language
from controllers.console import api
from controllers.console.auth.error import (
AccountInFreezeError,
EmailAlreadyInUseError,
EmailChangeLimitError,
EmailCodeError,
@@ -68,7 +69,7 @@ class AccountInitApi(Resource):
# check invitation code
invitation_code = (
db.session.query(InvitationCode)
.filter(
.where(
InvitationCode.code == args["invitation_code"],
InvitationCode.status == "unused",
)
@@ -228,7 +229,7 @@ class AccountIntegrateApi(Resource):
def get(self):
account = current_user

account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all()
account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all()

base_url = request.url_root.rstrip("/")
oauth_base_path = "/console/api/oauth/login"
@@ -479,21 +480,28 @@ class ChangeEmailResetApi(Resource):
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()

if AccountService.is_account_in_freeze(args["new_email"]):
raise AccountInFreezeError()

if not AccountService.check_email_unique(args["new_email"]):
raise EmailAlreadyInUseError()

reset_data = AccountService.get_change_email_data(args["token"])
if not reset_data:
raise InvalidTokenError()

AccountService.revoke_change_email_token(args["token"])

if not AccountService.check_email_unique(args["new_email"]):
raise EmailAlreadyInUseError()

old_email = reset_data.get("old_email", "")
if current_user.email != old_email:
raise AccountNotFound()

updated_account = AccountService.update_account(current_user, email=args["new_email"])

AccountService.send_change_email_completed_notify_email(
email=args["new_email"],
)

return updated_account



+ 1
- 1
api/controllers/console/workspace/members.py 查看文件

@@ -108,7 +108,7 @@ class MemberCancelInviteApi(Resource):
@login_required
@account_initialization_required
def delete(self, member_id):
member = db.session.query(Account).filter(Account.id == str(member_id)).first()
member = db.session.query(Account).where(Account.id == str(member_id)).first()
if member is None:
abort(404)
else:

+ 114
- 1
api/controllers/console/workspace/plugin.py 查看文件

@@ -12,7 +12,8 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginDaemonClientSideError
from libs.login import login_required
from models.account import TenantPluginPermission
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
from services.plugin.plugin_parameter_service import PluginParameterService
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
@@ -534,6 +535,114 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options})


class PluginChangePreferencesApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()

req = reqparse.RequestParser()
req.add_argument("permission", type=dict, required=True, location="json")
req.add_argument("auto_upgrade", type=dict, required=True, location="json")
args = req.parse_args()

tenant_id = user.current_tenant_id

permission = args["permission"]

install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
debug_permission = TenantPluginPermission.DebugPermission(permission.get("debug_permission", "everyone"))

auto_upgrade = args["auto_upgrade"]

strategy_setting = TenantPluginAutoUpgradeStrategy.StrategySetting(
auto_upgrade.get("strategy_setting", "fix_only")
)
upgrade_time_of_day = auto_upgrade.get("upgrade_time_of_day", 0)
upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode(auto_upgrade.get("upgrade_mode", "exclude"))
exclude_plugins = auto_upgrade.get("exclude_plugins", [])
include_plugins = auto_upgrade.get("include_plugins", [])

# set permission
set_permission_result = PluginPermissionService.change_permission(
tenant_id,
install_permission,
debug_permission,
)
if not set_permission_result:
return jsonable_encoder({"success": False, "message": "Failed to set permission"})

# set auto upgrade strategy
set_auto_upgrade_strategy_result = PluginAutoUpgradeService.change_strategy(
tenant_id,
strategy_setting,
upgrade_time_of_day,
upgrade_mode,
exclude_plugins,
include_plugins,
)
if not set_auto_upgrade_strategy_result:
return jsonable_encoder({"success": False, "message": "Failed to set auto upgrade strategy"})

return jsonable_encoder({"success": True})


class PluginFetchPreferencesApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id

permission = PluginPermissionService.get_permission(tenant_id)
permission_dict = {
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
"debug_permission": TenantPluginPermission.DebugPermission.EVERYONE,
}

if permission:
permission_dict["install_permission"] = permission.install_permission
permission_dict["debug_permission"] = permission.debug_permission

auto_upgrade = PluginAutoUpgradeService.get_strategy(tenant_id)
auto_upgrade_dict = {
"strategy_setting": TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED,
"upgrade_time_of_day": 0,
"upgrade_mode": TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
"exclude_plugins": [],
"include_plugins": [],
}

if auto_upgrade:
auto_upgrade_dict = {
"strategy_setting": auto_upgrade.strategy_setting,
"upgrade_time_of_day": auto_upgrade.upgrade_time_of_day,
"upgrade_mode": auto_upgrade.upgrade_mode,
"exclude_plugins": auto_upgrade.exclude_plugins,
"include_plugins": auto_upgrade.include_plugins,
}

return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})


class PluginAutoUpgradeExcludePluginApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
# exclude one single plugin
tenant_id = current_user.current_tenant_id

req = reqparse.RequestParser()
req.add_argument("plugin_id", type=str, required=True, location="json")
args = req.parse_args()

return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})


api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions")
@@ -560,3 +669,7 @@ api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permissi
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")

api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options")

api.add_resource(PluginFetchPreferencesApi, "/workspaces/current/plugin/preferences/fetch")
api.add_resource(PluginChangePreferencesApi, "/workspaces/current/plugin/preferences/change")
api.add_resource(PluginAutoUpgradeExcludePluginApi, "/workspaces/current/plugin/preferences/autoupgrade/exclude")

+ 6
- 2
api/controllers/console/workspace/tool_providers.py 查看文件

@@ -739,7 +739,7 @@ class ToolOAuthCallback(Resource):
raise Forbidden("no oauth available client config found for this tool provider")

redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
credentials = oauth_handler.get_credentials(
credentials_response = oauth_handler.get_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
@@ -747,7 +747,10 @@ class ToolOAuthCallback(Resource):
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
request=request,
).credentials
)

credentials = credentials_response.credentials
expires_at = credentials_response.expires_at

if not credentials:
raise Exception("the plugin credentials failed")
@@ -758,6 +761,7 @@ class ToolOAuthCallback(Resource):
tenant_id=tenant_id,
provider=provider,
credentials=dict(credentials),
expires_at=expires_at,
api_type=CredentialType.OAUTH2,
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")

+ 3
- 3
api/controllers/inner_api/plugin/wraps.py 查看文件

@@ -22,7 +22,7 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
user_id = "DEFAULT-USER"

if user_id == "DEFAULT-USER":
user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
user_model = session.query(EndUser).where(EndUser.session_id == "DEFAULT-USER").first()
if not user_model:
user_model = EndUser(
tenant_id=tenant_id,
@@ -36,7 +36,7 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
else:
user_model = AccountService.load_user(user_id)
if not user_model:
user_model = session.query(EndUser).filter(EndUser.id == user_id).first()
user_model = session.query(EndUser).where(EndUser.id == user_id).first()
if not user_model:
raise ValueError("user not found")
except Exception:
@@ -71,7 +71,7 @@ def get_user_tenant(view: Optional[Callable] = None):
try:
tenant_model = (
db.session.query(Tenant)
.filter(
.where(
Tenant.id == tenant_id,
)
.first()

+ 1
- 1
api/controllers/inner_api/wraps.py 查看文件

@@ -55,7 +55,7 @@ def enterprise_inner_api_user_auth(view):
if signature_base64 != token:
return view(*args, **kwargs)

kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first()
kwargs["user"] = db.session.query(EndUser).where(EndUser.id == user_id).first()

return view(*args, **kwargs)


+ 2
- 2
api/controllers/mcp/mcp.py 查看文件

@@ -30,7 +30,7 @@ class MCPAppApi(Resource):

request_id = args.get("id")

server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first()
server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
if not server:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found")
@@ -41,7 +41,7 @@ class MCPAppApi(Resource):
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active")
)

app = db.session.query(App).filter(App.id == server.app_id).first()
app = db.session.query(App).where(App.id == server.app_id).first()
if not app:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found")

+ 6
- 0
api/controllers/service_api/app/completion.py 查看文件

@@ -1,5 +1,6 @@
import logging

from flask import request
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound

@@ -23,6 +24,7 @@ from core.errors.error import (
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
@@ -111,6 +113,10 @@ class ChatApi(Resource):

args = parser.parse_args()

external_trace_id = get_external_trace_id(request)
if external_trace_id:
args["external_trace_id"] = external_trace_id

streaming = args["response_mode"] == "streaming"

try:

+ 1
- 1
api/controllers/service_api/app/site.py 查看文件

@@ -16,7 +16,7 @@ class AppSiteApi(Resource):
@marshal_with(fields.site_fields)
def get(self, app_model: App):
"""Retrieve app site info."""
site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()

if not site:
raise Forbidden()

+ 5
- 1
api/controllers/service_api/app/workflow.py 查看文件

@@ -1,6 +1,7 @@
import logging

from dateutil.parser import isoparse
from flask import request
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session, sessionmaker
@@ -23,6 +24,7 @@ from core.errors.error import (
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from extensions.ext_database import db
@@ -90,7 +92,9 @@ class WorkflowRunApi(Resource):
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
args = parser.parse_args()

external_trace_id = get_external_trace_id(request)
if external_trace_id:
args["external_trace_id"] = external_trace_id
streaming = args.get("response_mode") == "streaming"

try:

+ 10
- 10
api/controllers/service_api/dataset/document.py 查看文件

@@ -63,7 +63,7 @@ class DocumentAddByTextApi(DatasetApiResource):

dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()

if not dataset:
raise ValueError("Dataset does not exist.")
@@ -136,7 +136,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()

if not dataset:
raise ValueError("Dataset does not exist.")
@@ -206,7 +206,7 @@ class DocumentAddByFileApi(DatasetApiResource):
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()

if not dataset:
raise ValueError("Dataset does not exist.")
@@ -299,7 +299,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()

if not dataset:
raise ValueError("Dataset does not exist.")
@@ -367,7 +367,7 @@ class DocumentDeleteApi(DatasetApiResource):
tenant_id = str(tenant_id)

# get dataset info
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()

if not dataset:
raise ValueError("Dataset does not exist.")
@@ -398,7 +398,7 @@ class DocumentListApi(DatasetApiResource):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")

@@ -406,7 +406,7 @@ class DocumentListApi(DatasetApiResource):

if search:
search = f"%{search}%"
query = query.filter(Document.name.like(search))
query = query.where(Document.name.like(search))

query = query.order_by(desc(Document.created_at), desc(Document.position))

@@ -430,7 +430,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
batch = str(batch)
tenant_id = str(tenant_id)
# get dataset
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# get documents
@@ -441,7 +441,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
@@ -450,7 +450,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
)
total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
# Create a dictionary with document attributes and additional fields

+ 9
- 9
api/controllers/service_api/dataset/segment.py 查看文件

@@ -42,7 +42,7 @@ class SegmentApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
@@ -89,7 +89,7 @@ class SegmentApi(DatasetApiResource):
tenant_id = str(tenant_id)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
@@ -146,7 +146,7 @@ class DatasetSegmentApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
@@ -170,7 +170,7 @@ class DatasetSegmentApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
@@ -216,7 +216,7 @@ class DatasetSegmentApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
@@ -246,7 +246,7 @@ class ChildChunkApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")

@@ -296,7 +296,7 @@ class ChildChunkApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")

@@ -343,7 +343,7 @@ class DatasetChildChunkApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")

@@ -382,7 +382,7 @@ class DatasetChildChunkApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")


+ 2
- 2
api/controllers/service_api/dataset/upload_file.py 查看文件

@@ -17,7 +17,7 @@ class UploadFileApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
@@ -31,7 +31,7 @@ class UploadFileApi(DatasetApiResource):
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("UploadFile not found.")
else:

+ 14
- 14
api/controllers/service_api/wraps.py 查看文件

@@ -44,7 +44,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
def decorated_view(*args, **kwargs):
api_token = validate_and_get_api_token("app")

app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
app_model = db.session.query(App).where(App.id == api_token.app_id).first()
if not app_model:
raise Forbidden("The app no longer exists.")

@@ -54,7 +54,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
if not app_model.enable_api:
raise Forbidden("The app's API service has been disabled.")

tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first()
tenant = db.session.query(Tenant).where(Tenant.id == app_model.tenant_id).first()
if tenant is None:
raise ValueError("Tenant does not exist.")
if tenant.status == TenantStatus.ARCHIVE:
@@ -62,15 +62,15 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio

tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == api_token.tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role.in_(["owner"]))
.filter(Tenant.status == TenantStatus.NORMAL)
.where(Tenant.id == api_token.tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role.in_(["owner"]))
.where(Tenant.status == TenantStatus.NORMAL)
.one_or_none()
) # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.query(Account).filter(Account.id == ta.account_id).first()
account = db.session.query(Account).where(Account.id == ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
@@ -213,15 +213,15 @@ def validate_dataset_token(view=None):
api_token = validate_and_get_api_token("dataset")
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == api_token.tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role.in_(["owner"]))
.filter(Tenant.status == TenantStatus.NORMAL)
.where(Tenant.id == api_token.tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role.in_(["owner"]))
.where(Tenant.status == TenantStatus.NORMAL)
.one_or_none()
) # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.query(Account).filter(Account.id == ta.account_id).first()
account = db.session.query(Account).where(Account.id == ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
@@ -293,7 +293,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]

end_user = (
db.session.query(EndUser)
.filter(
.where(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
@@ -320,7 +320,7 @@ class DatasetApiResource(Resource):
method_decorators = [validate_dataset_token]

def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first()

if not dataset:
raise NotFound("Dataset not found.")

+ 15
- 14
api/controllers/web/passport.py 查看文件

@@ -3,6 +3,7 @@ from datetime import UTC, datetime, timedelta

from flask import request
from flask_restful import Resource
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound, Unauthorized

from configs import dify_config
@@ -42,17 +43,17 @@ class PassportResource(Resource):
raise WebAppAuthRequiredError()

# get site from db and check if it is normal
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))
if not site:
raise NotFound()
# get app from db and check if it is normal and enable_site
app_model = db.session.query(App).filter(App.id == site.app_id).first()
app_model = db.session.scalar(select(App).where(App.id == site.app_id))
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()

if user_id:
end_user = (
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
end_user = db.session.scalar(
select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id)
)

if end_user:
@@ -121,11 +122,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
if not user_auth_type:
raise Unauthorized("Missing auth_type in the token.")

site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))
if not site:
raise NotFound()

app_model = db.session.query(App).filter(App.id == site.app_id).first()
app_model = db.session.scalar(select(App).where(App.id == site.app_id))
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()

@@ -140,16 +141,14 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:

end_user = None
if end_user_id:
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if session_id:
end_user = (
db.session.query(EndUser)
.filter(
end_user = db.session.scalar(
select(EndUser).where(
EndUser.session_id == session_id,
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
)
.first()
)
if not end_user:
if not session_id:
@@ -187,8 +186,8 @@ def _exchange_for_public_app_token(app_model, site, token_decoded):
user_id = token_decoded.get("user_id")
end_user = None
if user_id:
end_user = (
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
end_user = db.session.scalar(
select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id)
)

if not end_user:
@@ -224,6 +223,8 @@ def generate_session_id():
"""
while True:
session_id = str(uuid.uuid4())
existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count()
existing_count = db.session.scalar(
select(func.count()).select_from(EndUser).where(EndUser.session_id == session_id)
)
if existing_count == 0:
return session_id

+ 1
- 1
api/controllers/web/site.py 查看文件

@@ -57,7 +57,7 @@ class AppSiteApi(WebApiResource):
def get(self, app_model, end_user):
"""Retrieve app site info."""
# get site
site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()

if not site:
raise Forbidden()

+ 4
- 3
api/controllers/web/wraps.py 查看文件

@@ -3,6 +3,7 @@ from functools import wraps

from flask import request
from flask_restful import Resource
from sqlalchemy import select
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized

from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
@@ -48,8 +49,8 @@ def decode_jwt_token():
decoded = PassportService().verify(tk)
app_code = decoded.get("app_code")
app_id = decoded.get("app_id")
app_model = db.session.query(App).filter(App.id == app_id).first()
site = db.session.query(Site).filter(Site.code == app_code).first()
app_model = db.session.scalar(select(App).where(App.id == app_id))
site = db.session.scalar(select(Site).where(Site.code == app_code))
if not app_model:
raise NotFound()
if not app_code or not site:
@@ -57,7 +58,7 @@ def decode_jwt_token():
if app_model.enable_site is False:
raise BadRequest("Site is disabled.")
end_user_id = decoded.get("end_user_id")
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if not end_user:
raise NotFound()


+ 3
- 3
api/core/agent/base_agent_runner.py 查看文件

@@ -99,7 +99,7 @@ class BaseAgentRunner(AppRunner):
# get how many agent thoughts have been created
self.agent_thought_count = (
db.session.query(MessageAgentThought)
.filter(
.where(
MessageAgentThought.message_id == self.message.id,
)
.count()
@@ -336,7 +336,7 @@ class BaseAgentRunner(AppRunner):
Save agent thought
"""
updated_agent_thought = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought.id).first()
)
if not updated_agent_thought:
raise ValueError("agent thought not found")
@@ -496,7 +496,7 @@ class BaseAgentRunner(AppRunner):
return result

def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
if not files:
return UserPromptMessage(content=message.query)
if message.app_model_config:

+ 0
- 48
api/core/app/apps/README.md 查看文件

@@ -1,48 +0,0 @@
## Guidelines for Database Connection Management in App Runner and Task Pipeline

Due to the presence of tasks in App Runner that require long execution times, such as LLM generation and external requests, Flask-Sqlalchemy's strategy for database connection pooling is to allocate one connection (transaction) per request. This approach keeps a connection occupied even during non-DB tasks, leading to the inability to acquire new connections during high concurrency requests due to multiple long-running tasks.

Therefore, the database operations in App Runner and Task Pipeline must ensure connections are closed immediately after use, and it's better to pass IDs rather than Model objects to avoid detach errors.

Examples:

1. Creating a new record:

```python
app = App(id=1)
db.session.add(app)
db.session.commit()
db.session.refresh(app) # Retrieve table default values, like created_at, cached in the app object, won't affect after close
# Handle non-long-running tasks or store the content of the App instance in memory (via variable assignment).
db.session.close()
return app.id
```

2. Fetching a record from the table:

```python
app = db.session.query(App).filter(App.id == app_id).first()
created_at = app.created_at
db.session.close()
# Handle tasks (include long-running).
```

3. Updating a table field:

```python
app = db.session.query(App).filter(App.id == app_id).first()

app.updated_at = time.utcnow()
db.session.commit()
db.session.close()

return app_id
```

+ 51
- 15
api/core/app/apps/advanced_chat/app_generator.py 查看文件

@@ -7,7 +7,8 @@ from typing import Any, Literal, Optional, Union, overload

from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker

import contexts
from configs import dify_config
@@ -23,6 +24,7 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
@@ -112,7 +114,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
query = query.replace("\x00", "")
inputs = args["inputs"]

extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)}
extras = {
"auto_generate_conversation_name": args.get("auto_generate_name", False),
**extract_external_trace_id_from_args(args),
}

# get conversation
conversation = None
@@ -482,21 +487,52 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"""

with preserve_flask_contexts(flask_app, context_vars=context):
try:
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)

# chatbot app
runner = AdvancedChatAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
dialogue_count=self._dialogue_count,
variable_loader=variable_loader,
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)

with Session(db.engine, expire_on_commit=False) as session:
workflow = session.scalar(
select(Workflow).where(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
Workflow.app_id == application_generate_entity.app_config.app_id,
Workflow.id == application_generate_entity.app_config.workflow_id,
)
)
if workflow is None:
raise ValueError("Workflow not found")

# Determine system_user_id based on invocation source
is_external_api_call = application_generate_entity.invoke_from in {
InvokeFrom.WEB_APP,
InvokeFrom.SERVICE_API,
}

if is_external_api_call:
# For external API calls, use end user's session ID
end_user = session.scalar(select(EndUser).where(EndUser.id == application_generate_entity.user_id))
system_user_id = end_user.session_id if end_user else ""
else:
# For internal calls, use the original user ID
system_user_id = application_generate_entity.user_id

app = session.scalar(select(App).where(App.id == application_generate_entity.app_config.app_id))
if app is None:
raise ValueError("App not found")

runner = AdvancedChatAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
dialogue_count=self._dialogue_count,
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
app=app,
)

try:
runner.run()
except GenerateTaskStoppedError:
pass

+ 84
- 33
api/core/app/apps/advanced_chat/app_runner.py 查看文件

@@ -1,6 +1,6 @@
import logging
from collections.abc import Mapping
from typing import Any, cast
from typing import Any, Optional, cast

from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -9,13 +9,19 @@ from configs import dify_config
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
AppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent,
QueueStopEvent,
QueueTextChunkEvent,
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
@@ -23,8 +29,9 @@ from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, EndUser, Message
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable, WorkflowType

logger = logging.getLogger(__name__)
@@ -37,42 +44,38 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):

def __init__(
self,
*,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
dialogue_count: int,
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
app: App,
) -> None:
super().__init__(queue_manager, variable_loader)
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
app_id=application_generate_entity.app_config.app_id,
)
self.application_generate_entity = application_generate_entity
self.conversation = conversation
self.message = message
self._dialogue_count = dialogue_count
def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
self._workflow = workflow
self.system_user_id = system_user_id
self._app = app

def run(self) -> None:
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)

app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")

workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")

user_id: str | None = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id

workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
@@ -80,14 +83,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
)
@@ -98,7 +101,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):

# moderation
if self.handle_input_moderation(
app_record=app_record,
app_record=self._app,
app_generate_entity=self.application_generate_entity,
inputs=inputs,
query=query,
@@ -108,7 +111,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):

# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
app_record=self._app,
message=self.message,
query=query,
app_generate_entity=self.application_generate_entity,
@@ -128,7 +131,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
)
for variable in workflow.conversation_variables
for variable in self._workflow.conversation_variables
]
session.add_all(db_conversation_variables)
# Convert database entities to variables.
@@ -141,7 +144,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
query=query,
files=files,
conversation_id=self.conversation.id,
user_id=user_id,
user_id=self.system_user_id,
dialogue_count=self._dialogue_count,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
@@ -152,25 +155,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables),
)

# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)
graph = self._init_graph(graph_config=self._workflow.graph_dict)

db.session.close()

# RUN WORKFLOW
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
workflow_id=self._workflow.id,
workflow_type=WorkflowType.value_of(self._workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
@@ -241,3 +244,51 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._publish_event(QueueTextChunkEvent(text=text))

self._publish_event(QueueStopEvent(stopped_by=stopped_by))

def query_app_annotations_to_reply(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
"""
Query app annotations to reply
:param app_record: app record
:param message: message
:param query: query
:param user_id: user id
:param invoke_from: invoke from
:return:
"""
annotation_reply_feature = AnnotationReplyFeature()
return annotation_reply_feature.query(
app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from
)

def moderation_for_inputs(
self,
*,
app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: Mapping[str, Any],
query: str | None = None,
message_id: str,
) -> tuple[bool, Mapping[str, Any], str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id
:param tenant_id: tenant id
:param app_generate_entity: app generate entity
:param inputs: inputs
:param query: query
:param message_id: message id
:return:
"""
moderation_feature = InputModeration()
return moderation_feature.check(
app_id=app_id,
tenant_id=tenant_id,
app_config=app_generate_entity.app_config,
inputs=dict(inputs),
query=query or "",
message_id=message_id,
trace_manager=app_generate_entity.trace_manager,
)

+ 4
- 0
api/core/app/apps/advanced_chat/generate_task_pipeline.py 查看文件

@@ -559,6 +559,7 @@ class AdvancedChatAppGenerateTaskPipeline:
outputs=event.outputs,
conversation_id=self._conversation_id,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
@@ -590,6 +591,7 @@ class AdvancedChatAppGenerateTaskPipeline:
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
@@ -622,6 +624,7 @@ class AdvancedChatAppGenerateTaskPipeline:
conversation_id=self._conversation_id,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
@@ -653,6 +656,7 @@ class AdvancedChatAppGenerateTaskPipeline:
error_message=event.get_stop_reason(),
conversation_id=self._conversation_id,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,

+ 3
- 3
api/core/app/apps/agent_chat/app_runner.py 查看文件

@@ -45,7 +45,7 @@ class AgentChatAppRunner(AppRunner):
app_config = application_generate_entity.app_config
app_config = cast(AgentChatAppConfig, app_config)

app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")

@@ -183,10 +183,10 @@ class AgentChatAppRunner(AppRunner):
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING

conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first()
if conversation_result is None:
raise ValueError("Conversation not found")
message_result = db.session.query(Message).filter(Message.id == message.id).first()
message_result = db.session.query(Message).where(Message.id == message.id).first()
if message_result is None:
raise ValueError("Message not found")
db.session.close()

+ 1
- 1
api/core/app/apps/chat/app_runner.py 查看文件

@@ -43,7 +43,7 @@ class ChatAppRunner(AppRunner):
app_config = application_generate_entity.app_config
app_config = cast(ChatAppConfig, app_config)

app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")


+ 1
- 1
api/core/app/apps/completion/app_generator.py 查看文件

@@ -248,7 +248,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
"""
message = (
db.session.query(Message)
.filter(
.where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),

+ 1
- 1
api/core/app/apps/completion/app_runner.py 查看文件

@@ -36,7 +36,7 @@ class CompletionAppRunner(AppRunner):
app_config = application_generate_entity.app_config
app_config = cast(CompletionAppConfig, app_config)

app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")


+ 4
- 10
api/core/app/apps/message_based_app_generator.py 查看文件

@@ -85,7 +85,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
if conversation:
app_model_config = (
db.session.query(AppModelConfig)
.filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first()
)

@@ -151,13 +151,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
introduction = self._get_conversation_introduction(application_generate_entity)

# get conversation name
if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):
query = application_generate_entity.query or "New conversation"
else:
query = next(iter(application_generate_entity.inputs.values()), "New conversation")
if isinstance(query, int):
query = str(query)
query = query or "New conversation"
query = application_generate_entity.query or "New conversation"
conversation_name = (query[:20] + "…") if len(query) > 20 else query

if not conversation:
@@ -259,7 +253,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param conversation_id: conversation id
:return: conversation
"""
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()

if not conversation:
raise ConversationNotExistsError("Conversation not exists")
@@ -272,7 +266,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param message_id: message id
:return: message
"""
message = db.session.query(Message).filter(Message.id == message_id).first()
message = db.session.query(Message).where(Message.id == message_id).first()

if message is None:
raise MessageNotExistsError("Message not exists")

+ 43
- 11
api/core/app/apps/workflow/app_generator.py 查看文件

@@ -7,7 +7,8 @@ from typing import Any, Literal, Optional, Union, overload

from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker

import contexts
from configs import dify_config
@@ -22,6 +23,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories import DifyCoreRepositoryFactory
@@ -123,6 +125,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
)

inputs: Mapping[str, Any] = args["inputs"]

extras = {
**extract_external_trace_id_from_args(args),
}
workflow_run_id = str(uuid.uuid4())
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
@@ -142,6 +148,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
call_depth=call_depth,
trace_manager=trace_manager,
workflow_execution_id=workflow_run_id,
extras=extras,
)

contexts.plugin_tool_providers.set({})
@@ -439,17 +446,44 @@ class WorkflowAppGenerator(BaseAppGenerator):
"""

with preserve_flask_contexts(flask_app, context_vars=context):
try:
# workflow app
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
with Session(db.engine, expire_on_commit=False) as session:
workflow = session.scalar(
select(Workflow).where(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
Workflow.app_id == application_generate_entity.app_config.app_id,
Workflow.id == application_generate_entity.app_config.workflow_id,
)
)
if workflow is None:
raise ValueError("Workflow not found")

# Determine system_user_id based on invocation source
is_external_api_call = application_generate_entity.invoke_from in {
InvokeFrom.WEB_APP,
InvokeFrom.SERVICE_API,
}

if is_external_api_call:
# For external API calls, use end user's session ID
end_user = session.scalar(select(EndUser).where(EndUser.id == application_generate_entity.user_id))
system_user_id = end_user.session_id if end_user else ""
else:
# For internal calls, use the original user ID
system_user_id = application_generate_entity.user_id

runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
)

try:
runner.run()
except GenerateTaskStoppedError:
except GenerateTaskStoppedError as e:
logger.warning(f"Task stopped: {str(e)}")
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
@@ -465,8 +499,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.close()

def _handle_response(
self,

+ 21
- 40
api/core/app/apps/workflow/app_runner.py 查看文件

@@ -14,10 +14,8 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.enums import UserFrom
from models.model import App, EndUser
from models.workflow import WorkflowType
from models.workflow import Workflow, WorkflowType

logger = logging.getLogger(__name__)

@@ -29,22 +27,23 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):

def __init__(
self,
*,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
workflow: Workflow,
system_user_id: str,
) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
"""
super().__init__(queue_manager, variable_loader)
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
app_id=application_generate_entity.app_config.app_id,
)
self.application_generate_entity = application_generate_entity
self.workflow_thread_pool_id = workflow_thread_pool_id

def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
self._workflow = workflow
self._sys_user_id = system_user_id

def run(self) -> None:
"""
@@ -53,24 +52,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)

user_id = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id

app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")

workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")

db.session.close()

workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
@@ -79,14 +60,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
)
@@ -98,7 +79,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):

system_inputs = SystemVariable(
files=files,
user_id=user_id,
user_id=self._sys_user_id,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
@@ -107,21 +88,21 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
environment_variables=self._workflow.environment_variables,
conversation_variables=[],
)

# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)
graph = self._init_graph(graph_config=self._workflow.graph_dict)

# RUN WORKFLOW
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
workflow_id=self._workflow.id,
workflow_type=WorkflowType.value_of(self._workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT

+ 3
- 0
api/core/app/apps/workflow/generate_task_pipeline.py 查看文件

@@ -490,6 +490,7 @@ class WorkflowAppGenerateTaskPipeline:
outputs=event.outputs,
conversation_id=None,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)

# save workflow app log
@@ -524,6 +525,7 @@ class WorkflowAppGenerateTaskPipeline:
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)

# save workflow app log
@@ -561,6 +563,7 @@ class WorkflowAppGenerateTaskPipeline:
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)

# save workflow app log

+ 12
- 27
api/core/app/apps/workflow_app_runner.py 查看文件

@@ -1,8 +1,7 @@
from collections.abc import Mapping
from typing import Any, Optional, cast
from typing import Any, cast

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
@@ -65,18 +64,20 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
from models.workflow import Workflow


class WorkflowBasedAppRunner(AppRunner):
def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None:
self.queue_manager = queue_manager
class WorkflowBasedAppRunner:
def __init__(
self,
*,
queue_manager: AppQueueManager,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
app_id: str,
) -> None:
self._queue_manager = queue_manager
self._variable_loader = variable_loader

def _get_app_id(self) -> str:
raise NotImplementedError("not implemented")
self._app_id = app_id

def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
"""
@@ -693,21 +694,5 @@ class WorkflowBasedAppRunner(AppRunner):
)
)

def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)

# return workflow
return workflow

def _publish_event(self, event: AppQueueEvent) -> None:
self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)

+ 1
- 1
api/core/app/features/annotation_reply/annotation_reply.py 查看文件

@@ -26,7 +26,7 @@ class AnnotationReplyFeature:
:return:
"""
annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first()
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first()
)

if not annotation_setting:

+ 1
- 1
api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py 查看文件

@@ -471,7 +471,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:return:
"""
agent_thought: Optional[MessageAgentThought] = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)

if agent_thought:

+ 2
- 2
api/core/app/task_pipeline/message_cycle_manager.py 查看文件

@@ -81,7 +81,7 @@ class MessageCycleManager:
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():
# get conversation and message
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()

if not conversation:
return
@@ -140,7 +140,7 @@ class MessageCycleManager:
:param event: event
:return:
"""
message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()
message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first()

if message_file and message_file.url is not None:
# get tool file id

+ 5
- 5
api/core/callback_handler/index_tool_callback_handler.py 查看文件

@@ -49,7 +49,7 @@ class DatasetIndexToolCallbackHandler:
for document in documents:
if document.metadata is not None:
document_id = document.metadata["document_id"]
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
if not dataset_document:
_logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s",
@@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = (
db.session.query(ChildChunk)
.filter(
.where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
@@ -69,18 +69,18 @@ class DatasetIndexToolCallbackHandler:
if child_chunk:
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == child_chunk.segment_id)
.where(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
)
)
else:
query = db.session.query(DocumentSegment).filter(
query = db.session.query(DocumentSegment).where(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)

if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])

# add hit count to document segment
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)

+ 7
- 7
api/core/entities/provider_configuration.py 查看文件

@@ -191,7 +191,7 @@ class ProviderConfiguration(BaseModel):

provider_record = (
db.session.query(Provider)
.filter(
.where(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name.in_(provider_names),
@@ -351,7 +351,7 @@ class ProviderConfiguration(BaseModel):

provider_model_record = (
db.session.query(ProviderModel)
.filter(
.where(
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name.in_(provider_names),
ProviderModel.model_name == model,
@@ -481,7 +481,7 @@ class ProviderConfiguration(BaseModel):

return (
db.session.query(ProviderModelSetting)
.filter(
.where(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
@@ -560,7 +560,7 @@ class ProviderConfiguration(BaseModel):

return (
db.session.query(LoadBalancingModelConfig)
.filter(
.where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
@@ -583,7 +583,7 @@ class ProviderConfiguration(BaseModel):

load_balancing_config_count = (
db.session.query(LoadBalancingModelConfig)
.filter(
.where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
@@ -627,7 +627,7 @@ class ProviderConfiguration(BaseModel):

model_setting = (
db.session.query(ProviderModelSetting)
.filter(
.where(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
@@ -693,7 +693,7 @@ class ProviderConfiguration(BaseModel):

preferred_model_provider = (
db.session.query(TenantPreferredModelProvider)
.filter(
.where(
TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name.in_(provider_names),
)

+ 2
- 2
api/core/external_data_tool/api/api.py 查看文件

@@ -32,7 +32,7 @@ class ApiExternalDataTool(ExternalDataTool):
# get api_based_extension
api_based_extension = (
db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
)

@@ -56,7 +56,7 @@ class ApiExternalDataTool(ExternalDataTool):
# get api_based_extension
api_based_extension = (
db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
.where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
)


+ 1
- 1
api/core/helper/encrypter.py 查看文件

@@ -15,7 +15,7 @@ def encrypt_token(tenant_id: str, token: str):
from models.account import Tenant
from models.engine import db

if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
raise ValueError(f"Tenant with id {tenant_id} not found")
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode()

+ 20
- 0
api/core/helper/marketplace.py 查看文件

@@ -25,9 +25,29 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()

return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]


def batch_fetch_plugin_manifests_ignore_deserialization_error(
plugin_ids: list[str],
) -> Sequence[MarketplacePluginDeclaration]:
if len(plugin_ids) == 0:
return []

url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
result: list[MarketplacePluginDeclaration] = []
for plugin in response.json()["data"]["plugins"]:
try:
result.append(MarketplacePluginDeclaration(**plugin))
except Exception as e:
pass

return result


def record_install_plugin_event(plugin_unique_identifier: str):
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})

+ 42
- 0
api/core/helper/trace_id_helper.py 查看文件

@@ -0,0 +1,42 @@
import re
from collections.abc import Mapping
from typing import Any, Optional


def is_valid_trace_id(trace_id: str) -> bool:
"""
Check if the trace_id is valid.

Requirements: 1-128 characters, only letters, numbers, '-', and '_'.
"""
return bool(re.match(r"^[a-zA-Z0-9\-_]{1,128}$", trace_id))


def get_external_trace_id(request: Any) -> Optional[str]:
"""
Retrieve the trace_id from the request.

Priority: header ('X-Trace-Id'), then parameters, then JSON body. Returns None if not provided or invalid.
"""
trace_id = request.headers.get("X-Trace-Id")
if not trace_id:
trace_id = request.args.get("trace_id")
if not trace_id and getattr(request, "is_json", False):
json_data = getattr(request, "json", None)
if json_data:
trace_id = json_data.get("trace_id")
if isinstance(trace_id, str) and is_valid_trace_id(trace_id):
return trace_id
return None


def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict:
"""
Extract 'external_trace_id' from args.

Returns a dict suitable for use in extras. Returns an empty dict if not found.
"""
trace_id = args.get("external_trace_id")
if trace_id:
return {"external_trace_id": trace_id}
return {}

+ 9
- 10
api/core/indexing_runner.py 查看文件

@@ -59,7 +59,7 @@ class IndexingRunner:
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
if not processing_rule:
@@ -119,12 +119,12 @@ class IndexingRunner:
db.session.delete(document_segment)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# delete child chunks
db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete()
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
if not processing_rule:
@@ -212,7 +212,7 @@ class IndexingRunner:
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)

@@ -316,7 +316,7 @@ class IndexingRunner:
# delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if image_file is None:
continue
try:
@@ -346,7 +346,7 @@ class IndexingRunner:
raise ValueError("no upload file found")

file_detail = (
db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
)

if file_detail:
@@ -599,7 +599,7 @@ class IndexingRunner:
keyword.create(documents)
if dataset.indexing_technique != "high_quality":
document_ids = [document.metadata["doc_id"] for document in documents]
db.session.query(DocumentSegment).filter(
db.session.query(DocumentSegment).where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id.in_(document_ids),
@@ -630,7 +630,7 @@ class IndexingRunner:
index_processor.load(dataset, chunk_documents, with_keywords=False)

document_ids = [document.metadata["doc_id"] for document in chunk_documents]
db.session.query(DocumentSegment).filter(
db.session.query(DocumentSegment).where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(document_ids),
@@ -672,8 +672,7 @@ class IndexingRunner:

if extra_update_params:
update_params.update(extra_update_params)

db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params)
db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params) # type: ignore
db.session.commit()

@staticmethod

+ 2
- 1
api/core/llm_generator/llm_generator.py 查看文件

@@ -114,7 +114,8 @@ class LLMGenerator:
),
)

questions = output_parser.parse(cast(str, response.message.content))
text_content = response.message.get_text_content()
questions = output_parser.parse(text_content) if text_content else []
except InvokeError:
questions = []
except Exception:

+ 0
- 1
api/core/llm_generator/output_parser/suggested_questions_after_answer.py 查看文件

@@ -15,5 +15,4 @@ class SuggestedQuestionsAfterAnswerOutputParser:
json_obj = json.loads(action_match.group(0).strip())
else:
json_obj = []

return json_obj

+ 2
- 2
api/core/mcp/server/streamable_http.py 查看文件

@@ -28,7 +28,7 @@ class MCPServerStreamableHTTPRequestHandler:
):
self.app = app
self.request = request
mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.app.id).first()
mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first()
if not mcp_server:
raise ValueError("MCP server not found")
self.mcp_server: AppMCPServer = mcp_server
@@ -192,7 +192,7 @@ class MCPServerStreamableHTTPRequestHandler:
def retrieve_end_user(self):
return (
db.session.query(EndUser)
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first()
)


+ 1
- 1
api/core/memory/token_buffer_memory.py 查看文件

@@ -67,7 +67,7 @@ class TokenBufferMemory:

prompt_messages: list[PromptMessage] = []
for message in messages:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
if files:
file_extra_config = None
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:

+ 17
- 0
api/core/model_runtime/entities/message_entities.py 查看文件

@@ -156,6 +156,23 @@ class PromptMessage(ABC, BaseModel):
"""
return not self.content

def get_text_content(self) -> str:
"""
Get text content from prompt message.

:return: Text content as string, empty string if no text content
"""
if isinstance(self.content, str):
return self.content
elif isinstance(self.content, list):
text_parts = []
for item in self.content:
if isinstance(item, TextPromptMessageContent):
text_parts.append(item.data)
return "".join(text_parts)
else:
return ""

@field_validator("content", mode="before")
@classmethod
def validate_content(cls, v):

+ 1
- 1
api/core/moderation/api/api.py 查看文件

@@ -89,7 +89,7 @@ class ApiModeration(Moderation):
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
extension = (
db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
)


+ 5
- 4
api/core/ops/aliyun_trace/aliyun_trace.py 查看文件

@@ -101,7 +101,8 @@ class AliyunDataTrace(BaseTraceInstance):
raise ValueError(f"Aliyun get run url failed: {str(e)}")

def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = convert_to_trace_id(trace_info.workflow_run_id)
external_trace_id = trace_info.metadata.get("external_trace_id")
trace_id = external_trace_id or convert_to_trace_id(trace_info.workflow_run_id)
workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow")
self.add_workflow_span(trace_id, workflow_span_id, trace_info)

@@ -119,7 +120,7 @@ class AliyunDataTrace(BaseTraceInstance):
user_id = message_data.from_account_id
if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
user_id = end_user_data.session_id
@@ -243,14 +244,14 @@ class AliyunDataTrace(BaseTraceInstance):
if not app_id:
raise ValueError("No app_id found in trace_info metadata")

app = session.query(App).filter(App.id == app_id).first()
app = session.query(App).where(App.id == app_id).first()
if not app:
raise ValueError(f"App with id {app_id} not found")

if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")

service_account = session.query(Account).filter(Account.id == app.created_by).first()
service_account = session.query(Account).where(Account.id == app.created_by).first()
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = (

+ 4
- 3
api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py 查看文件

@@ -153,7 +153,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
}
workflow_metadata.update(trace_info.metadata)

trace_id = uuid_to_trace_id(trace_info.workflow_run_id)
external_trace_id = trace_info.metadata.get("external_trace_id")
trace_id = external_trace_id or uuid_to_trace_id(trace_info.workflow_run_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
@@ -296,7 +297,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
# Add end user data if available
if trace_info.message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == trace_info.message_data.from_end_user_id).first()
db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first()
)
if end_user_data is not None:
message_metadata["end_user_id"] = end_user_data.session_id
@@ -702,7 +703,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
WorkflowNodeExecutionModel.process_data,
WorkflowNodeExecutionModel.execution_metadata,
)
.filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.all()
)
return workflow_nodes

+ 2
- 2
api/core/ops/base_trace_instance.py 查看文件

@@ -44,14 +44,14 @@ class BaseTraceInstance(ABC):
"""
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app = session.query(App).filter(App.id == app_id).first()
app = session.query(App).where(App.id == app_id).first()
if not app:
raise ValueError(f"App with id {app_id} not found")

if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")

service_account = session.query(Account).filter(Account.id == app.created_by).first()
service_account = session.query(Account).where(Account.id == app.created_by).first()
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")


+ 4
- 3
api/core/ops/langfuse_trace/langfuse_trace.py 查看文件

@@ -67,13 +67,14 @@ class LangFuseDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)

def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.workflow_run_id
external_trace_id = trace_info.metadata.get("external_trace_id")
trace_id = external_trace_id or trace_info.workflow_run_id
user_id = trace_info.metadata.get("user_id")
metadata = trace_info.metadata
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id

if trace_info.message_id:
trace_id = trace_info.message_id
trace_id = external_trace_id or trace_info.message_id
name = TraceTaskName.MESSAGE_TRACE.value
trace_data = LangfuseTrace(
id=trace_id,
@@ -243,7 +244,7 @@ class LangFuseDataTrace(BaseTraceInstance):
user_id = message_data.from_account_id
if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
user_id = end_user_data.session_id

+ 3
- 2
api/core/ops/langsmith_trace/langsmith_trace.py 查看文件

@@ -65,7 +65,8 @@ class LangSmithDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)

def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.message_id or trace_info.workflow_run_id
external_trace_id = trace_info.metadata.get("external_trace_id")
trace_id = external_trace_id or trace_info.message_id or trace_info.workflow_run_id
if trace_info.start_time is None:
trace_info.start_time = datetime.now()
message_dotted_order = (
@@ -261,7 +262,7 @@ class LangSmithDataTrace(BaseTraceInstance):

if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
end_user_id = end_user_data.session_id

+ 4
- 3
api/core/ops/opik_trace/opik_trace.py 查看文件

@@ -96,7 +96,8 @@ class OpikDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)

def workflow_trace(self, trace_info: WorkflowTraceInfo):
dify_trace_id = trace_info.workflow_run_id
external_trace_id = trace_info.metadata.get("external_trace_id")
dify_trace_id = external_trace_id or trace_info.workflow_run_id
opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
workflow_metadata = wrap_metadata(
trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id
@@ -104,7 +105,7 @@ class OpikDataTrace(BaseTraceInstance):
root_span_id = None

if trace_info.message_id:
dify_trace_id = trace_info.message_id
dify_trace_id = external_trace_id or trace_info.message_id
opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)

trace_data = {
@@ -283,7 +284,7 @@ class OpikDataTrace(BaseTraceInstance):

if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
end_user_id = end_user_data.session_id

+ 12
- 8
api/core/ops/ops_trace_manager.py 查看文件

@@ -218,7 +218,7 @@ class OpsTraceManager:
"""
trace_config_data: Optional[TraceAppConfig] = (
db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)

@@ -226,7 +226,7 @@ class OpsTraceManager:
return None

# decrypt_token
app = db.session.query(App).filter(App.id == app_id).first()
app = db.session.query(App).where(App.id == app_id).first()
if not app:
raise ValueError("App not found")

@@ -253,7 +253,7 @@ class OpsTraceManager:
if app_id is None:
return None

app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
app: Optional[App] = db.session.query(App).where(App.id == app_id).first()

if app is None:
return None
@@ -293,18 +293,18 @@ class OpsTraceManager:
@classmethod
def get_app_config_through_message_id(cls, message_id: str):
app_model_config = None
message_data = db.session.query(Message).filter(Message.id == message_id).first()
message_data = db.session.query(Message).where(Message.id == message_id).first()
if not message_data:
return None
conversation_id = message_data.conversation_id
conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
if not conversation_data:
return None

if conversation_data.app_model_config_id:
app_model_config = (
db.session.query(AppModelConfig)
.filter(AppModelConfig.id == conversation_data.app_model_config_id)
.where(AppModelConfig.id == conversation_data.app_model_config_id)
.first()
)
elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
@@ -331,7 +331,7 @@ class OpsTraceManager:
if tracing_provider is not None:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")

app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first()
if not app_config:
raise ValueError("App not found")
app_config.tracing = json.dumps(
@@ -349,7 +349,7 @@ class OpsTraceManager:
:param app_id: app id
:return:
"""
app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
app: Optional[App] = db.session.query(App).where(App.id == app_id).first()
if not app:
raise ValueError("App not found")
if not app.tracing:
@@ -520,6 +520,10 @@ class TraceTask:
"app_id": workflow_run.app_id,
}

external_trace_id = self.kwargs.get("external_trace_id")
if external_trace_id:
metadata["external_trace_id"] = external_trace_id

workflow_trace_info = WorkflowTraceInfo(
workflow_data=workflow_run.to_dict(),
conversation_id=conversation_id,

+ 3
- 1
api/core/ops/utils.py 查看文件

@@ -3,6 +3,8 @@ from datetime import datetime
from typing import Optional, Union
from urllib.parse import urlparse

from sqlalchemy import select

from extensions.ext_database import db
from models.model import Message

@@ -20,7 +22,7 @@ def filter_none_values(data: dict):


def get_message_data(message_id: str):
return db.session.query(Message).filter(Message.id == message_id).first()
return db.session.scalar(select(Message).where(Message.id == message_id))


@contextmanager

+ 3
- 2
api/core/ops/weave_trace/weave_trace.py 查看文件

@@ -87,7 +87,8 @@ class WeaveDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)

def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.message_id or trace_info.workflow_run_id
external_trace_id = trace_info.metadata.get("external_trace_id")
trace_id = external_trace_id or trace_info.message_id or trace_info.workflow_run_id
if trace_info.start_time is None:
trace_info.start_time = datetime.now()

@@ -234,7 +235,7 @@ class WeaveDataTrace(BaseTraceInstance):

if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
end_user_id = end_user_data.session_id

+ 3
- 3
api/core/plugin/backwards_invocation/app.py 查看文件

@@ -193,9 +193,9 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
get the user by user id
"""

user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
user = db.session.query(EndUser).where(EndUser.id == user_id).first()
if not user:
user = db.session.query(Account).filter(Account.id == user_id).first()
user = db.session.query(Account).where(Account.id == user_id).first()

if not user:
raise ValueError("user not found")
@@ -208,7 +208,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
get app
"""
try:
app = db.session.query(App).filter(App.id == app_id).filter(App.tenant_id == tenant_id).first()
app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first()
except Exception:
raise ValueError("app not found")


+ 1
- 0
api/core/plugin/entities/plugin_daemon.py 查看文件

@@ -194,6 +194,7 @@ class PluginOAuthCredentialsResponse(BaseModel):
metadata: Mapping[str, Any] = Field(
default_factory=dict, description="The metadata of the OAuth, like avatar url, name, etc."
)
expires_at: int = Field(default=-1, description="The expires at time of the credentials. UTC timestamp.")
credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.")



+ 35
- 0
api/core/plugin/impl/oauth.py 查看文件

@@ -84,6 +84,41 @@ class OAuthHandler(BasePluginClient):
except Exception as e:
raise ValueError(f"Error getting credentials: {e}")

def refresh_credentials(
self,
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
redirect_uri: str,
system_credentials: Mapping[str, Any],
credentials: Mapping[str, Any],
) -> PluginOAuthCredentialsResponse:
try:
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/refresh_credentials",
PluginOAuthCredentialsResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"redirect_uri": redirect_uri,
"system_credentials": system_credentials,
"credentials": credentials,
},
},
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for refresh credentials request.")
except Exception as e:
raise ValueError(f"Error refreshing credentials: {e}")

def _convert_request_to_raw_data(self, request: Request) -> bytes:
"""
Convert a Request object to raw HTTP data.

+ 3
- 3
api/core/provider_manager.py 查看文件

@@ -275,7 +275,7 @@ class ProviderManager:
# Get the corresponding TenantDefaultModel record
default_model = (
db.session.query(TenantDefaultModel)
.filter(
.where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
@@ -367,7 +367,7 @@ class ProviderManager:
# Get the list of available models from get_configurations and check if it is LLM
default_model = (
db.session.query(TenantDefaultModel)
.filter(
.where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
@@ -541,7 +541,7 @@ class ProviderManager:
db.session.rollback()
existed_provider_record = (
db.session.query(Provider)
.filter(
.where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,

+ 3
- 3
api/core/rag/datasource/keyword/jieba/jieba.py 查看文件

@@ -94,11 +94,11 @@ class Jieba(BaseKeyword):

documents = []
for chunk_index in sorted_chunk_indices:
segment_query = db.session.query(DocumentSegment).filter(
segment_query = db.session.query(DocumentSegment).where(
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
)
if document_ids_filter:
segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter))
segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter))
segment = segment_query.first()

if segment:
@@ -215,7 +215,7 @@ class Jieba(BaseKeyword):
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
document_segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
.where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
.first()
)
if document_segment:

+ 6
- 6
api/core/rag/datasource/retrieval_service.py 查看文件

@@ -127,7 +127,7 @@ class RetrievalService:
external_retrieval_model: Optional[dict] = None,
metadata_filtering_conditions: Optional[dict] = None,
):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
return []
metadata_condition = (
@@ -145,7 +145,7 @@ class RetrievalService:
@classmethod
def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]:
with Session(db.engine) as session:
return session.query(Dataset).filter(Dataset.id == dataset_id).first()
return session.query(Dataset).where(Dataset.id == dataset_id).first()

@classmethod
def keyword_search(
@@ -294,7 +294,7 @@ class RetrievalService:
dataset_documents = {
doc.id: doc
for doc in db.session.query(DatasetDocument)
.filter(DatasetDocument.id.in_(document_ids))
.where(DatasetDocument.id.in_(document_ids))
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
.all()
}
@@ -318,7 +318,7 @@ class RetrievalService:
child_index_node_id = document.metadata.get("doc_id")

child_chunk = (
db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first()
db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first()
)

if not child_chunk:
@@ -326,7 +326,7 @@ class RetrievalService:

segment = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
@@ -381,7 +381,7 @@ class RetrievalService:

segment = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",

+ 2
- 2
api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py 查看文件

@@ -6,7 +6,7 @@ from uuid import UUID, uuid4
from numpy import ndarray
from pgvecto_rs.sqlalchemy import VECTOR # type: ignore
from pydantic import BaseModel, model_validator
from sqlalchemy import Float, String, create_engine, insert, select, text
from sqlalchemy import Float, create_engine, insert, select, text
from sqlalchemy import text as sql_text
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Mapped, Session, mapped_column
@@ -67,7 +67,7 @@ class PGVectoRS(BaseVector):
postgresql.UUID(as_uuid=True),
primary_key=True,
)
text: Mapped[str] = mapped_column(String)
text: Mapped[str]
meta: Mapped[dict] = mapped_column(postgresql.JSONB)
vector: Mapped[ndarray] = mapped_column(VECTOR(dim))


+ 1
- 1
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py 查看文件

@@ -443,7 +443,7 @@ class QdrantVectorFactory(AbstractVectorFactory):
if dataset.collection_binding_id:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
if dataset_collection_binding:

+ 93
- 33
api/core/rag/datasource/vdb/tablestore/tablestore_vector.py 查看文件

@@ -118,10 +118,21 @@ class TableStoreVector(BaseVector):

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
return self._search_by_vector(query_vector, top_k)
document_ids_filter = kwargs.get("document_ids_filter")
filtered_list = None
if document_ids_filter:
filtered_list = ["document_id=" + item for item in document_ids_filter]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._search_by_vector(query_vector, filtered_list, top_k, score_threshold)

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._search_by_full_text(query)
top_k = kwargs.get("top_k", 4)
document_ids_filter = kwargs.get("document_ids_filter")
filtered_list = None
if document_ids_filter:
filtered_list = ["document_id=" + item for item in document_ids_filter]

return self._search_by_full_text(query, filtered_list, top_k)

def delete(self) -> None:
self._delete_table_if_exist()
@@ -230,32 +241,51 @@ class TableStoreVector(BaseVector):
primary_key = [("id", id)]
row = tablestore.Row(primary_key)
self._tablestore_client.delete_row(self._table_name, row, None)
logging.info("Tablestore delete row successfully. id:%s", id)

def _search_by_metadata(self, key: str, value: str) -> list[str]:
query = tablestore.SearchQuery(
tablestore.TermQuery(self._tags_field, str(key) + "=" + str(value)),
limit=100,
limit=1000,
get_total_count=False,
)
rows: list[str] = []
next_token = None
while True:
if next_token is not None:
query.next_token = next_token

search_response = self._tablestore_client.search(
table_name=self._table_name,
index_name=self._index_name,
search_query=query,
columns_to_get=tablestore.ColumnsToGet(
column_names=[Field.PRIMARY_KEY.value], return_type=tablestore.ColumnReturnType.SPECIFIED
),
)

search_response = self._tablestore_client.search(
table_name=self._table_name,
index_name=self._index_name,
search_query=query,
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
)
if search_response is not None:
rows.extend([row[0][0][1] for row in search_response.rows])

return [row[0][0][1] for row in search_response.rows]
if search_response is None or search_response.next_token == b"":
break
else:
next_token = search_response.next_token

def _search_by_vector(self, query_vector: list[float], top_k: int) -> list[Document]:
ots_query = tablestore.KnnVectorQuery(
return rows

def _search_by_vector(
self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float
) -> list[Document]:
knn_vector_query = tablestore.KnnVectorQuery(
field_name=Field.VECTOR.value,
top_k=top_k,
float32_query_vector=query_vector,
)
if document_ids_filter:
knn_vector_query.filter = tablestore.TermsQuery(self._tags_field, document_ids_filter)

sort = tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)])
search_query = tablestore.SearchQuery(ots_query, limit=top_k, get_total_count=False, sort=sort)
search_query = tablestore.SearchQuery(knn_vector_query, limit=top_k, get_total_count=False, sort=sort)

search_response = self._tablestore_client.search(
table_name=self._table_name,
@@ -263,30 +293,42 @@ class TableStoreVector(BaseVector):
search_query=search_query,
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
)
logging.info(
"Tablestore search successfully. request_id:%s",
search_response.request_id,
)
return self._to_query_result(search_response)

def _to_query_result(self, search_response: tablestore.SearchResponse) -> list[Document]:
documents = []
for row in search_response.rows:
documents.append(
Document(
page_content=row[1][2][1],
vector=json.loads(row[1][3][1]),
metadata=json.loads(row[1][0][1]),
)
)
for search_hit in search_response.search_hits:
if search_hit.score > score_threshold:
ots_column_map = {}
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]

vector_str = ots_column_map.get(Field.VECTOR.value)
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)

vector = json.loads(vector_str) if vector_str else None
metadata = json.loads(metadata_str) if metadata_str else {}

metadata["score"] = search_hit.score

documents.append(
Document(
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
vector=vector,
metadata=metadata,
)
)
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents

def _search_by_full_text(self, query: str) -> list[Document]:
def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]:
bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[])
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value))

if document_ids_filter:
bool_query.filter_queries.append(tablestore.TermsQuery(self._tags_field, document_ids_filter))

search_query = tablestore.SearchQuery(
query=tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value),
query=bool_query,
sort=tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)]),
limit=100,
limit=top_k,
)
search_response = self._tablestore_client.search(
table_name=self._table_name,
@@ -295,7 +337,25 @@ class TableStoreVector(BaseVector):
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
)

return self._to_query_result(search_response)
documents = []
for search_hit in search_response.search_hits:
ots_column_map = {}
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]

vector_str = ots_column_map.get(Field.VECTOR.value)
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
vector = json.loads(vector_str) if vector_str else None
metadata = json.loads(metadata_str) if metadata_str else {}

documents.append(
Document(
page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
vector=vector,
metadata=metadata,
)
)
return documents


class TableStoreVectorFactory(AbstractVectorFactory):

+ 2
- 1
api/core/rag/datasource/vdb/tencent/tencent_vector.py 查看文件

@@ -284,7 +284,8 @@ class TencentVector(BaseVector):
# Compatible with version 1.1.3 and below.
meta = json.loads(meta)
score = 1 - result.get("score", 0.0)
score = result.get("score", 0.0)
else:
score = result.get("score", 0.0)
if score > score_threshold:
meta["score"] = score
doc = Document(page_content=result.get(self.field_text), metadata=meta)

+ 3
- 3
api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py 查看文件

@@ -418,13 +418,13 @@ class TidbOnQdrantVector(BaseVector):
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
tidb_auth_binding = (
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
)
if not tidb_auth_binding:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
.where(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none()
)
if tidb_auth_binding:
@@ -433,7 +433,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
else:
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)

+ 1
- 1
api/core/rag/datasource/vdb/vector_factory.py 查看文件

@@ -47,7 +47,7 @@ class Vector:
if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
whitelist = (
db.session.query(Whitelist)
.filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
.where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
.one_or_none()
)
if whitelist:

+ 4
- 4
api/core/rag/docstore/dataset_docstore.py 查看文件

@@ -42,7 +42,7 @@ class DatasetDocumentStore:
@property
def docs(self) -> dict[str, Document]:
document_segments = (
db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == self._dataset.id).all()
db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all()
)

output = {}
@@ -63,7 +63,7 @@ class DatasetDocumentStore:
def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None:
max_position = (
db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == self._document_id)
.where(DocumentSegment.document_id == self._document_id)
.scalar()
)

@@ -147,7 +147,7 @@ class DatasetDocumentStore:
segment_document.tokens = tokens
if save_child and doc.children:
# delete the existing child chunks
db.session.query(ChildChunk).filter(
db.session.query(ChildChunk).where(
ChildChunk.tenant_id == self._dataset.tenant_id,
ChildChunk.dataset_id == self._dataset.id,
ChildChunk.document_id == self._document_id,
@@ -230,7 +230,7 @@ class DatasetDocumentStore:
def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
document_segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
.where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
.first()
)


+ 4
- 3
api/core/rag/extractor/notion_extractor.py 查看文件

@@ -331,9 +331,10 @@ class NotionExtractor(BaseExtractor):
last_edited_time = self.get_notion_last_edited_time()
data_source_info = document_model.data_source_info_dict
data_source_info["last_edited_time"] = last_edited_time
update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)}

db.session.query(DocumentModel).filter_by(id=document_model.id).update(update_params)
db.session.query(DocumentModel).filter_by(id=document_model.id).update(
{DocumentModel.data_source_info: json.dumps(data_source_info)}
) # type: ignore
db.session.commit()

def get_notion_last_edited_time(self) -> str:
@@ -365,7 +366,7 @@ class NotionExtractor(BaseExtractor):
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
.where(
db.and_(
DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion",

+ 3
- 3
api/core/rag/index_processor/processor/parent_child_index_processor.py 查看文件

@@ -121,7 +121,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_node_ids = (
db.session.query(ChildChunk.index_node_id)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.filter(
.where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
ChildChunk.dataset_id == dataset.id,
@@ -131,7 +131,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_node_ids = [child_node_id[0] for child_node_id in child_node_ids]
vector.delete_by_ids(child_node_ids)
if delete_child_chunks:
db.session.query(ChildChunk).filter(
db.session.query(ChildChunk).where(
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
).delete()
db.session.commit()
@@ -139,7 +139,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
vector.delete()

if delete_child_chunks:
db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id).delete()
db.session.query(ChildChunk).where(ChildChunk.dataset_id == dataset.id).delete()
db.session.commit()

def retrieve(

+ 14
- 14
api/core/rag/retrieval/dataset_retrieval.py 查看文件

@@ -135,7 +135,7 @@ class DatasetRetrieval:
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()

# pass if dataset is not available
if not dataset:
@@ -242,7 +242,7 @@ class DatasetRetrieval:
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(DatasetDocument)
.filter(
.where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
@@ -327,7 +327,7 @@ class DatasetRetrieval:

if dataset_id:
# get retrieval model config
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
results = []
if dataset.provider == "external":
@@ -516,14 +516,14 @@ class DatasetRetrieval:
if document.metadata is not None:
dataset_document = (
db.session.query(DatasetDocument)
.filter(DatasetDocument.id == document.metadata["document_id"])
.where(DatasetDocument.id == document.metadata["document_id"])
.first()
)
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = (
db.session.query(ChildChunk)
.filter(
.where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
@@ -533,7 +533,7 @@ class DatasetRetrieval:
if child_chunk:
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == child_chunk.segment_id)
.where(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
@@ -541,13 +541,13 @@ class DatasetRetrieval:
)
db.session.commit()
else:
query = db.session.query(DocumentSegment).filter(
query = db.session.query(DocumentSegment).where(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)

# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])

# add hit count to document segment
query.update(
@@ -600,7 +600,7 @@ class DatasetRetrieval:
):
with flask_app.app_context():
with Session(db.engine) as session:
dataset = session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()

if not dataset:
return []
@@ -685,7 +685,7 @@ class DatasetRetrieval:
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()

# pass if dataset is not available
if not dataset:
@@ -862,7 +862,7 @@ class DatasetRetrieval:
metadata_filtering_conditions: Optional[MetadataFilteringCondition],
inputs: dict,
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
document_query = db.session.query(DatasetDocument).filter(
document_query = db.session.query(DatasetDocument).where(
DatasetDocument.dataset_id.in_(dataset_ids),
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
@@ -930,9 +930,9 @@ class DatasetRetrieval:
raise ValueError("Invalid metadata filtering mode")
if filters:
if metadata_filtering_conditions and metadata_filtering_conditions.logical_operator == "and": # type: ignore
document_query = document_query.filter(and_(*filters))
document_query = document_query.where(and_(*filters))
else:
document_query = document_query.filter(or_(*filters))
document_query = document_query.where(or_(*filters))
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
@@ -958,7 +958,7 @@ class DatasetRetrieval:
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
) -> Optional[list[dict[str, Any]]]:
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
# get metadata model config
if metadata_model_config is None:

+ 1
- 1
api/core/tools/custom_tool/provider.py 查看文件

@@ -178,7 +178,7 @@ class ApiToolProviderController(ToolProviderController):
# get tenant api providers
db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
.where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
.all()
)


+ 4
- 4
api/core/tools/tool_file_manager.py 查看文件

@@ -160,7 +160,7 @@ class ToolFileManager:
with Session(self._engine, expire_on_commit=False) as session:
tool_file: ToolFile | None = (
session.query(ToolFile)
.filter(
.where(
ToolFile.id == id,
)
.first()
@@ -184,7 +184,7 @@ class ToolFileManager:
with Session(self._engine, expire_on_commit=False) as session:
message_file: MessageFile | None = (
session.query(MessageFile)
.filter(
.where(
MessageFile.id == id,
)
.first()
@@ -204,7 +204,7 @@ class ToolFileManager:

tool_file: ToolFile | None = (
session.query(ToolFile)
.filter(
.where(
ToolFile.id == tool_file_id,
)
.first()
@@ -228,7 +228,7 @@ class ToolFileManager:
with Session(self._engine, expire_on_commit=False) as session:
tool_file: ToolFile | None = (
session.query(ToolFile)
.filter(
.where(
ToolFile.id == tool_file_id,
)
.first()

+ 3
- 3
api/core/tools/tool_label_manager.py 查看文件

@@ -29,7 +29,7 @@ class ToolLabelManager:
raise ValueError("Unsupported tool type")

# delete old labels
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete()
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete()

# insert new labels
for label in labels:
@@ -57,7 +57,7 @@ class ToolLabelManager:

labels = (
db.session.query(ToolLabelBinding.label_name)
.filter(
.where(
ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value,
)
@@ -90,7 +90,7 @@ class ToolLabelManager:
provider_ids.append(controller.provider_id)

labels: list[ToolLabelBinding] = (
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all()
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all()
)

tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}

+ 56
- 16
api/core/tools/tool_manager.py 查看文件

@@ -1,16 +1,19 @@
import json
import logging
import mimetypes
from collections.abc import Generator
import time
from collections.abc import Generator, Mapping
from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast

from pydantic import TypeAdapter
from yarl import URL

import contexts
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
@@ -195,7 +198,7 @@ class ToolManager:
try:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
@@ -213,7 +216,7 @@ class ToolManager:
# use the default provider
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(
.where(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
@@ -226,7 +229,7 @@ class ToolManager:
else:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
@@ -244,12 +247,47 @@ class ToolManager:
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
),
)

# decrypt the credentials
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)

# check if the credentials is expired
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
# TODO: circular import
from services.tools.builtin_tools_manage_service import BuiltinToolManageService

# refresh the credentials
tool_provider = ToolProviderID(provider_id)
provider_name = tool_provider.provider_name
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
oauth_handler = OAuthHandler()
# refresh the credentials
refreshed_credentials = oauth_handler.refresh_credentials(
tenant_id=tenant_id,
user_id=builtin_provider.user_id,
plugin_id=tool_provider.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
)
# update the credentials
builtin_provider.encrypted_credentials = (
TypeAdapter(dict[str, Any])
.dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
.decode("utf-8")
)
builtin_provider.expires_at = refreshed_credentials.expires_at
db.session.commit()
decrypted_credentials = refreshed_credentials.credentials

return cast(
BuiltinTool,
builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=encrypter.decrypt(builtin_provider.credentials),
credentials=dict(decrypted_credentials),
credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={},
invoke_from=invoke_from,
@@ -278,7 +316,7 @@ class ToolManager:
elif provider_type == ToolProviderType.WORKFLOW:
workflow_provider = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
)

@@ -578,7 +616,7 @@ class ToolManager:
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
"""
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()

@classmethod
def list_providers_from_api(
@@ -626,7 +664,7 @@ class ToolManager:
# get db api providers
if "api" in filters:
db_api_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all()
)

api_provider_controllers: list[dict[str, Any]] = [
@@ -649,7 +687,7 @@ class ToolManager:
if "workflow" in filters:
# get workflow providers
workflow_providers: list[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
)

workflow_provider_controllers: list[WorkflowToolProviderController] = []
@@ -693,7 +731,7 @@ class ToolManager:
"""
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(
.where(
ApiToolProvider.id == provider_id,
ApiToolProvider.tenant_id == tenant_id,
)
@@ -730,7 +768,7 @@ class ToolManager:
"""
provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider)
.filter(
.where(
MCPToolProvider.server_identifier == provider_id,
MCPToolProvider.tenant_id == tenant_id,
)
@@ -755,7 +793,7 @@ class ToolManager:
provider_name = provider
provider_obj: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider,
)
@@ -847,7 +885,7 @@ class ToolManager:
try:
workflow_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
)

@@ -864,7 +902,7 @@ class ToolManager:
try:
api_provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
.where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
.first()
)

@@ -881,7 +919,7 @@ class ToolManager:
try:
mcp_provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider)
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
.first()
)

@@ -973,7 +1011,9 @@ class ToolManager:
if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}:
elif tool_input.type == "constant":
parameter_value = tool_input.value
elif tool_input.type == "mixed":
segment_group = variable_pool.convert_template(str(tool_input.value))
parameter_value = segment_group.text
else:

+ 3
- 3
api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py 查看文件

@@ -87,7 +87,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
@@ -114,7 +114,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(Document)
.filter(
.where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
@@ -163,7 +163,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
):
with flask_app.app_context():
dataset = (
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
)

if not dataset:

+ 2
- 6
api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py 查看文件

@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Any, Optional
from typing import Optional

from msal_extensions.persistence import ABC # type: ignore
from pydantic import BaseModel, ConfigDict
@@ -21,11 +21,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True)

@abstractmethod
def _run(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def _run(self, query: str) -> str:
"""Use the tool.

Add run_manager: Optional[CallbackManagerForToolRun] = None

+ 0
- 0
api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py 查看文件


部分文件因文件數量過多而無法顯示

Loading…
取消
儲存