Procházet zdrojové kódy

Migrate SQLAlchemy from 1.x to 2.0 with automated and manual adjustments (#23224)

Co-authored-by: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
tags/1.8.1
Yongtao Huang před 2 měsíci
rodič
revize
be3af1e234
Žádný účet není propojen s e-mailovou adresou tvůrce revize
33 změnil soubory, kde provedl 226 přidání a 260 odebrání
  1. 4
    2
      api/core/agent/base_agent_runner.py
  2. 1
    0
      api/core/app/apps/advanced_chat/app_runner.py
  3. 8
    5
      api/core/app/apps/agent_chat/app_runner.py
  4. 4
    2
      api/core/app/apps/chat/app_runner.py
  5. 8
    10
      api/core/app/apps/completion/app_generator.py
  6. 4
    2
      api/core/app/apps/completion/app_runner.py
  7. 3
    4
      api/core/app/apps/message_based_app_generator.py
  8. 4
    3
      api/core/app/features/annotation_reply/annotation_reply.py
  9. 2
    1
      api/core/app/task_pipeline/message_cycle_manager.py
  10. 9
    9
      api/core/callback_handler/index_tool_callback_handler.py
  11. 8
    10
      api/core/external_data_tool/api/api.py
  12. 10
    16
      api/core/indexing_runner.py
  13. 2
    2
      api/core/memory/token_buffer_memory.py
  14. 4
    4
      api/core/moderation/api/api.py
  15. 5
    4
      api/core/ops/aliyun_trace/aliyun_trace.py
  16. 5
    3
      api/core/ops/base_trace_instance.py
  17. 8
    9
      api/core/ops/ops_trace_manager.py
  18. 6
    3
      api/core/plugin/backwards_invocation/app.py
  19. 14
    26
      api/core/provider_manager.py
  20. 4
    4
      api/core/rag/datasource/keyword/jieba/jieba.py
  21. 11
    15
      api/core/rag/datasource/retrieval_service.py
  22. 3
    5
      api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
  23. 5
    8
      api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
  24. 5
    4
      api/core/rag/datasource/vdb/vector_factory.py
  25. 6
    8
      api/core/rag/docstore/dataset_docstore.py
  26. 7
    11
      api/core/rag/extractor/notion_extractor.py
  27. 24
    26
      api/core/rag/retrieval/dataset_retrieval.py
  28. 7
    9
      api/core/tools/tool_label_manager.py
  29. 8
    11
      api/core/tools/tool_manager.py
  30. 15
    21
      api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py
  31. 8
    11
      api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
  32. 6
    2
      api/core/tools/workflow_as_tool/tool.py
  33. 8
    10
      api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py

+ 4
- 2
api/core/agent/base_agent_runner.py Zobrazit soubor

@@ -334,7 +334,8 @@ class BaseAgentRunner(AppRunner):
"""
Save agent thought
"""
agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first()
stmt = select(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id)
agent_thought = db.session.scalar(stmt)
if not agent_thought:
raise ValueError("agent thought not found")

@@ -492,7 +493,8 @@ class BaseAgentRunner(AppRunner):
return result

def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
stmt = select(MessageFile).where(MessageFile.message_id == message.id)
files = db.session.scalars(stmt).all()
if not files:
return UserPromptMessage(content=message.query)
if message.app_model_config:

+ 1
- 0
api/core/app/apps/advanced_chat/app_runner.py Zobrazit soubor

@@ -74,6 +74,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):

with Session(db.engine, expire_on_commit=False) as session:
app_record = session.scalar(select(App).where(App.id == app_config.app_id))

if not app_record:
raise ValueError("App not found")


+ 8
- 5
api/core/app/apps/agent_chat/app_runner.py Zobrazit soubor

@@ -1,6 +1,8 @@
import logging
from typing import cast

from sqlalchemy import select

from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from core.agent.entities import AgentEntity
@@ -44,8 +46,8 @@ class AgentChatAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(AgentChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
app_stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(app_stmt)
if not app_record:
raise ValueError("App not found")

@@ -182,11 +184,12 @@ class AgentChatAppRunner(AppRunner):

if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first()
conversation_stmt = select(Conversation).where(Conversation.id == conversation.id)
conversation_result = db.session.scalar(conversation_stmt)
if conversation_result is None:
raise ValueError("Conversation not found")
message_result = db.session.query(Message).where(Message.id == message.id).first()
msg_stmt = select(Message).where(Message.id == message.id)
message_result = db.session.scalar(msg_stmt)
if message_result is None:
raise ValueError("Message not found")
db.session.close()

+ 4
- 2
api/core/app/apps/chat/app_runner.py Zobrazit soubor

@@ -1,6 +1,8 @@
import logging
from typing import cast

from sqlalchemy import select

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.chat.app_config_manager import ChatAppConfig
@@ -42,8 +44,8 @@ class ChatAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(ChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
if not app_record:
raise ValueError("App not found")


+ 8
- 10
api/core/app/apps/completion/app_generator.py Zobrazit soubor

@@ -6,6 +6,7 @@ from typing import Any, Literal, Union, overload

from flask import Flask, copy_current_request_context, current_app
from pydantic import ValidationError
from sqlalchemy import select

from configs import dify_config
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
@@ -248,17 +249,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
message = (
db.session.query(Message)
.where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
.first()
stmt = select(Message).where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
message = db.session.scalar(stmt)

if not message:
raise MessageNotExistsError()

+ 4
- 2
api/core/app/apps/completion/app_runner.py Zobrazit soubor

@@ -1,6 +1,8 @@
import logging
from typing import cast

from sqlalchemy import select

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.completion.app_config_manager import CompletionAppConfig
@@ -35,8 +37,8 @@ class CompletionAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(CompletionAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
if not app_record:
raise ValueError("App not found")


+ 3
- 4
api/core/app/apps/message_based_app_generator.py Zobrazit soubor

@@ -86,11 +86,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):

def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
if conversation:
app_model_config = (
db.session.query(AppModelConfig)
.where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first()
stmt = select(AppModelConfig).where(
AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
)
app_model_config = db.session.scalar(stmt)

if not app_model_config:
raise AppModelConfigBrokenError()

+ 4
- 3
api/core/app/features/annotation_reply/annotation_reply.py Zobrazit soubor

@@ -1,6 +1,8 @@
import logging
from typing import Optional

from sqlalchemy import select

from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
@@ -25,9 +27,8 @@ class AnnotationReplyFeature:
:param invoke_from: invoke from
:return:
"""
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first()
)
stmt = select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id)
annotation_setting = db.session.scalar(stmt)

if not annotation_setting:
return None

+ 2
- 1
api/core/app/task_pipeline/message_cycle_manager.py Zobrazit soubor

@@ -86,7 +86,8 @@ class MessageCycleManager:
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():
# get conversation and message
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
stmt = select(Conversation).where(Conversation.id == conversation_id)
conversation = db.session.scalar(stmt)

if not conversation:
return

+ 9
- 9
api/core/callback_handler/index_tool_callback_handler.py Zobrazit soubor

@@ -1,6 +1,8 @@
import logging
from collections.abc import Sequence

from sqlalchemy import select

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
@@ -49,7 +51,8 @@ class DatasetIndexToolCallbackHandler:
for document in documents:
if document.metadata is not None:
document_id = document.metadata["document_id"]
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
dataset_document_stmt = select(DatasetDocument).where(DatasetDocument.id == document_id)
dataset_document = db.session.scalar(dataset_document_stmt)
if not dataset_document:
_logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s",
@@ -57,15 +60,12 @@ class DatasetIndexToolCallbackHandler:
)
continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = (
db.session.query(ChildChunk)
.where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk:
segment = (
db.session.query(DocumentSegment)

+ 8
- 10
api/core/external_data_tool/api/api.py Zobrazit soubor

@@ -1,5 +1,7 @@
from typing import Optional

from sqlalchemy import select

from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
from core.external_data_tool.base import ExternalDataTool
from core.helper import encrypter
@@ -28,13 +30,11 @@ class ApiExternalDataTool(ExternalDataTool):
api_based_extension_id = config.get("api_based_extension_id")
if not api_based_extension_id:
raise ValueError("api_based_extension_id is required")

# get api_based_extension
api_based_extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
)
api_based_extension = db.session.scalar(stmt)

if not api_based_extension:
raise ValueError("api_based_extension_id is invalid")
@@ -52,13 +52,11 @@ class ApiExternalDataTool(ExternalDataTool):
raise ValueError(f"config is required, config: {self.config}")
api_based_extension_id = self.config.get("api_based_extension_id")
assert api_based_extension_id is not None, "api_based_extension_id is required"

# get api_based_extension
api_based_extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id
)
api_based_extension = db.session.scalar(stmt)

if not api_based_extension:
raise ValueError(

+ 10
- 16
api/core/indexing_runner.py Zobrazit soubor

@@ -8,6 +8,7 @@ import uuid
from typing import Any, Optional, cast

from flask import current_app
from sqlalchemy import select
from sqlalchemy.orm.exc import ObjectDeletedError

from configs import dify_config
@@ -56,13 +57,11 @@ class IndexingRunner:

if not dataset:
raise ValueError("no dataset found")

# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
stmt = select(DatasetProcessRule).where(
DatasetProcessRule.id == dataset_document.dataset_process_rule_id
)
processing_rule = db.session.scalar(stmt)
if not processing_rule:
raise ValueError("no process rule found")
index_type = dataset_document.doc_form
@@ -123,11 +122,8 @@ class IndexingRunner:
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
processing_rule = db.session.scalar(stmt)
if not processing_rule:
raise ValueError("no process rule found")

@@ -208,7 +204,6 @@ class IndexingRunner:
child_documents.append(child_document)
document.children = child_documents
documents.append(document)

# build index
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
@@ -310,7 +305,8 @@ class IndexingRunner:
# delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
stmt = select(UploadFile).where(UploadFile.id == upload_file_id)
image_file = db.session.scalar(stmt)
if image_file is None:
continue
try:
@@ -339,10 +335,8 @@ class IndexingRunner:
if dataset_document.data_source_type == "upload_file":
if not data_source_info or "upload_file_id" not in data_source_info:
raise ValueError("no upload file found")

file_detail = (
db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
)
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
file_detail = db.session.scalars(stmt).one_or_none()

if file_detail:
extract_setting = ExtractSetting(

+ 2
- 2
api/core/memory/token_buffer_memory.py Zobrazit soubor

@@ -110,9 +110,9 @@ class TokenBufferMemory:
else:
message_limit = 500

stmt = stmt.limit(message_limit)
msg_limit_stmt = stmt.limit(message_limit)

messages = db.session.scalars(stmt).all()
messages = db.session.scalars(msg_limit_stmt).all()

# instead of all messages from the conversation, we only need to extract messages
# that belong to the thread of last message

+ 4
- 4
api/core/moderation/api/api.py Zobrazit soubor

@@ -1,6 +1,7 @@
from typing import Optional

from pydantic import BaseModel, Field
from sqlalchemy import select

from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
from core.helper.encrypter import decrypt_token
@@ -87,10 +88,9 @@ class ApiModeration(Moderation):

@staticmethod
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
)
extension = db.session.scalar(stmt)

return extension

+ 5
- 4
api/core/ops/aliyun_trace/aliyun_trace.py Zobrazit soubor

@@ -5,6 +5,7 @@ from typing import Optional
from urllib.parse import urljoin

from opentelemetry.trace import Link, Status, StatusCode
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker

from core.ops.aliyun_trace.data_exporter.traceclient import (
@@ -263,15 +264,15 @@ class AliyunDataTrace(BaseTraceInstance):
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app = session.query(App).where(App.id == app_id).first()
app_stmt = select(App).where(App.id == app_id)
app = session.scalar(app_stmt)
if not app:
raise ValueError(f"App with id {app_id} not found")

if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).where(Account.id == app.created_by).first()
account_stmt = select(Account).where(Account.id == app.created_by)
service_account = session.scalar(account_stmt)
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = (

+ 5
- 3
api/core/ops/base_trace_instance.py Zobrazit soubor

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.ops.entities.config_entity import BaseTracingConfig
@@ -44,14 +45,15 @@ class BaseTraceInstance(ABC):
"""
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app = session.query(App).where(App.id == app_id).first()
app_stmt = select(App).where(App.id == app_id)
app = session.scalar(app_stmt)
if not app:
raise ValueError(f"App with id {app_id} not found")

if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).where(Account.id == app.created_by).first()
account_stmt = select(Account).where(Account.id == app.created_by)
service_account = session.scalar(account_stmt)
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")


+ 8
- 9
api/core/ops/ops_trace_manager.py Zobrazit soubor

@@ -226,9 +226,9 @@ class OpsTraceManager:

if not trace_config_data:
return None

# decrypt_token
app = db.session.query(App).where(App.id == app_id).first()
stmt = select(App).where(App.id == app_id)
app = db.session.scalar(stmt)
if not app:
raise ValueError("App not found")

@@ -295,20 +295,19 @@ class OpsTraceManager:
@classmethod
def get_app_config_through_message_id(cls, message_id: str):
app_model_config = None
message_data = db.session.query(Message).where(Message.id == message_id).first()
message_stmt = select(Message).where(Message.id == message_id)
message_data = db.session.scalar(message_stmt)
if not message_data:
return None
conversation_id = message_data.conversation_id
conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
conversation_stmt = select(Conversation).where(Conversation.id == conversation_id)
conversation_data = db.session.scalar(conversation_stmt)
if not conversation_data:
return None

if conversation_data.app_model_config_id:
app_model_config = (
db.session.query(AppModelConfig)
.where(AppModelConfig.id == conversation_data.app_model_config_id)
.first()
)
config_stmt = select(AppModelConfig).where(AppModelConfig.id == conversation_data.app_model_config_id)
app_model_config = db.session.scalar(config_stmt)
elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
app_model_config = conversation_data.override_model_configs


+ 6
- 3
api/core/plugin/backwards_invocation/app.py Zobrazit soubor

@@ -1,6 +1,8 @@
from collections.abc import Generator, Mapping
from typing import Optional, Union

from sqlalchemy import select

from controllers.service_api.wraps import create_or_update_end_user_for_user_id
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
@@ -192,10 +194,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"""
get the user by user id
"""
user = db.session.query(EndUser).where(EndUser.id == user_id).first()
stmt = select(EndUser).where(EndUser.id == user_id)
user = db.session.scalar(stmt)
if not user:
user = db.session.query(Account).where(Account.id == user_id).first()
stmt = select(Account).where(Account.id == user_id)
user = db.session.scalar(stmt)

if not user:
raise ValueError("user not found")

+ 14
- 26
api/core/provider_manager.py Zobrazit soubor

@@ -276,15 +276,11 @@ class ProviderManager:
:param model_type: model type
:return:
"""
# Get the corresponding TenantDefaultModel record
default_model = (
db.session.query(TenantDefaultModel)
.where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
.first()
stmt = select(TenantDefaultModel).where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
default_model = db.session.scalar(stmt)

# If it does not exist, get the first available provider model from get_configurations
# and update the TenantDefaultModel record
@@ -367,16 +363,11 @@ class ProviderManager:
model_names = [model.model for model in available_models]
if model not in model_names:
raise ValueError(f"Model {model} does not exist.")

# Get the list of available models from get_configurations and check if it is LLM
default_model = (
db.session.query(TenantDefaultModel)
.where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
.first()
stmt = select(TenantDefaultModel).where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
default_model = db.session.scalar(stmt)

# create or update TenantDefaultModel record
if default_model:
@@ -598,16 +589,13 @@ class ProviderManager:
provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
except IntegrityError:
db.session.rollback()
existed_provider_record = (
db.session.query(Provider)
.where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
)
.first()
stmt = select(Provider).where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
)
existed_provider_record = db.session.scalar(stmt)
if not existed_provider_record:
continue


+ 4
- 4
api/core/rag/datasource/keyword/jieba/jieba.py Zobrazit soubor

@@ -3,6 +3,7 @@ from typing import Any, Optional

import orjson
from pydantic import BaseModel
from sqlalchemy import select

from configs import dify_config
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
@@ -211,11 +212,10 @@ class Jieba(BaseKeyword):
return sorted_chunk_indices[:k]

def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
document_segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
.first()
stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id
)
document_segment = db.session.scalar(stmt)
if document_segment:
document_segment.keywords = keywords
db.session.add(document_segment)

+ 11
- 15
api/core/rag/datasource/retrieval_service.py Zobrazit soubor

@@ -3,6 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
from typing import Optional

from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only

from configs import dify_config
@@ -127,7 +128,8 @@ class RetrievalService:
external_retrieval_model: Optional[dict] = None,
metadata_filtering_conditions: Optional[dict] = None,
):
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
return []
metadata_condition = (
@@ -316,10 +318,8 @@ class RetrievalService:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# Handle parent-child documents
child_index_node_id = document.metadata.get("doc_id")

child_chunk = (
db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first()
)
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = db.session.scalar(child_chunk_stmt)

if not child_chunk:
continue
@@ -378,17 +378,13 @@ class RetrievalService:
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue

segment = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = db.session.scalar(document_segment_stmt)

if not segment:
continue

+ 3
- 5
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py Zobrazit soubor

@@ -18,6 +18,7 @@ from qdrant_client.http.models import (
TokenizerType,
)
from qdrant_client.local.qdrant_local import QdrantLocal
from sqlalchemy import select

from configs import dify_config
from core.rag.datasource.vdb.field import Field
@@ -445,11 +446,8 @@ class QdrantVector(BaseVector):
class QdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector:
if dataset.collection_binding_id:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
stmt = select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == dataset.collection_binding_id)
dataset_collection_binding = db.session.scalars(stmt).one_or_none()
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:

+ 5
- 8
api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py Zobrazit soubor

@@ -20,6 +20,7 @@ from qdrant_client.http.models import (
)
from qdrant_client.local.qdrant_local import QdrantLocal
from requests.auth import HTTPDigestAuth
from sqlalchemy import select

from configs import dify_config
from core.rag.datasource.vdb.field import Field
@@ -416,16 +417,12 @@ class TidbOnQdrantVector(BaseVector):

class TidbOnQdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
tidb_auth_binding = (
db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
)
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
if not tidb_auth_binding:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.where(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none()
)
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"


+ 5
- 4
api/core/rag/datasource/vdb/vector_factory.py Zobrazit soubor

@@ -3,6 +3,8 @@ import time
from abc import ABC, abstractmethod
from typing import Any, Optional

from sqlalchemy import select

from configs import dify_config
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
@@ -45,11 +47,10 @@ class Vector:
vector_type = self._dataset.index_struct_dict["type"]
else:
if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
whitelist = (
db.session.query(Whitelist)
.where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
.one_or_none()
stmt = select(Whitelist).where(
Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db"
)
whitelist = db.session.scalars(stmt).one_or_none()
if whitelist:
vector_type = VectorType.TIDB_ON_QDRANT


+ 6
- 8
api/core/rag/docstore/dataset_docstore.py Zobrazit soubor

@@ -1,7 +1,7 @@
from collections.abc import Sequence
from typing import Any, Optional

from sqlalchemy import func
from sqlalchemy import func, select

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
@@ -41,9 +41,8 @@ class DatasetDocumentStore:

@property
def docs(self) -> dict[str, Document]:
document_segments = (
db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all()
)
stmt = select(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id)
document_segments = db.session.scalars(stmt).all()

output = {}
for document_segment in document_segments:
@@ -228,10 +227,9 @@ class DatasetDocumentStore:
return data

def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
document_segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
.first()
stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id
)
document_segment = db.session.scalar(stmt)

return document_segment

+ 7
- 11
api/core/rag/extractor/notion_extractor.py Zobrazit soubor

@@ -4,6 +4,7 @@ import operator
from typing import Any, Optional, cast

import requests
from sqlalchemy import select

from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
@@ -367,18 +368,13 @@ class NotionExtractor(BaseExtractor):

@classmethod
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.where(
db.and_(
DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
)
)
.first()
stmt = select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
)
data_source_binding = db.session.scalar(stmt)

if not data_source_binding:
raise Exception(

+ 24
- 26
api/core/rag/retrieval/dataset_retrieval.py Zobrazit soubor

@@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
from typing import Any, Optional, Union, cast

from flask import Flask, current_app
from sqlalchemy import Float, and_, or_, text
from sqlalchemy import Float, and_, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import Session

@@ -135,7 +135,8 @@ class DatasetRetrieval:
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)

# pass if dataset is not available
if not dataset:
@@ -240,15 +241,12 @@ class DatasetRetrieval:
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(DatasetDocument)
.where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.first()
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
document = db.session.scalar(dataset_document_stmt)
if dataset and document:
source = RetrievalSourceMetadata(
dataset_id=dataset.id,
@@ -327,7 +325,8 @@ class DatasetRetrieval:

if dataset_id:
# get retrieval model config
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)
if dataset:
results = []
if dataset.provider == "external":
@@ -514,22 +513,18 @@ class DatasetRetrieval:
dify_documents = [document for document in documents if document.provider == "dify"]
for document in dify_documents:
if document.metadata is not None:
dataset_document = (
db.session.query(DatasetDocument)
.where(DatasetDocument.id == document.metadata["document_id"])
.first()
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == document.metadata["document_id"]
)
dataset_document = db.session.scalar(dataset_document_stmt)
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = (
db.session.query(ChildChunk)
.where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk:
segment = (
db.session.query(DocumentSegment)
@@ -600,7 +595,8 @@ class DatasetRetrieval:
):
with flask_app.app_context():
with Session(db.engine) as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)

if not dataset:
return []
@@ -685,7 +681,8 @@ class DatasetRetrieval:
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)

# pass if dataset is not available
if not dataset:
@@ -958,7 +955,8 @@ class DatasetRetrieval:
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
) -> Optional[list[dict[str, Any]]]:
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
metadata_fields = db.session.scalars(metadata_stmt).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
# get metadata model config
if metadata_model_config is None:

+ 7
- 9
api/core/tools/tool_label_manager.py Zobrazit soubor

@@ -1,3 +1,5 @@
from sqlalchemy import select

from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.custom_tool.provider import ApiToolProviderController
@@ -54,17 +56,13 @@ class ToolLabelManager:
return controller.tool_labels
else:
raise ValueError("Unsupported tool type")

labels = (
db.session.query(ToolLabelBinding.label_name)
.where(
ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value,
)
.all()
stmt = select(ToolLabelBinding.label_name).where(
ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value,
)
labels = db.session.scalars(stmt).all()

return [label.label_name for label in labels]
return list(labels)

@classmethod
def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:

+ 8
- 11
api/core/tools/tool_manager.py Zobrazit soubor

@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast

import sqlalchemy as sa
from pydantic import TypeAdapter
from sqlalchemy import select
from sqlalchemy.orm import Session
from yarl import URL

@@ -198,14 +199,11 @@ class ToolManager:
# get specific credentials
if is_valid_uuid(credential_id):
try:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
builtin_provider_stmt = select(BuiltinToolProvider).where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
builtin_provider = db.session.scalar(builtin_provider_stmt)
except Exception as e:
builtin_provider = None
logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
@@ -317,11 +315,10 @@ class ToolManager:
),
)
elif provider_type == ToolProviderType.WORKFLOW:
workflow_provider = (
db.session.query(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
workflow_provider_stmt = select(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
)
workflow_provider = db.session.scalar(workflow_provider_stmt)

if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

+ 15
- 21
api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py Zobrazit soubor

@@ -3,6 +3,7 @@ from typing import Any

from flask import Flask, current_app
from pydantic import BaseModel, Field
from sqlalchemy import select

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelManager
@@ -85,17 +86,14 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):

document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
)
.all()
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
)
segments = db.session.scalars(document_segment_stmt).all()

if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
@@ -112,15 +110,12 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
resource_number = 1
for segment in sorted_segments:
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(Document)
.where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
.first()
document_stmt = select(Document).where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
document = db.session.scalar(document_stmt)
if dataset and document:
source = RetrievalSourceMetadata(
position=resource_number,
@@ -162,9 +157,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
hit_callbacks: list[DatasetIndexToolCallbackHandler],
):
with flask_app.app_context():
dataset = (
db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
)
stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)

if not dataset:
return []

+ 8
- 11
api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py Zobrazit soubor

@@ -1,6 +1,7 @@
from typing import Any, Optional, cast

from pydantic import BaseModel, Field
from sqlalchemy import select

from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.datasource.retrieval_service import RetrievalService
@@ -56,9 +57,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
)

def _run(self, query: str) -> str:
dataset = (
db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
)
dataset_stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id)
dataset = db.session.scalar(dataset_stmt)

if not dataset:
return ""
@@ -188,15 +188,12 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(DatasetDocument) # type: ignore
.where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.first()
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
document = db.session.scalar(dataset_document_stmt) # type: ignore
if dataset and document:
source = RetrievalSourceMetadata(
dataset_id=dataset.id,

+ 6
- 2
api/core/tools/workflow_as_tool/tool.py Zobrazit soubor

@@ -3,6 +3,8 @@ import logging
from collections.abc import Generator
from typing import Any, Optional

from sqlalchemy import select

from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
@@ -136,7 +138,8 @@ class WorkflowTool(Tool):
.first()
)
else:
workflow = db.session.query(Workflow).where(Workflow.app_id == app_id, Workflow.version == version).first()
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
workflow = db.session.scalar(stmt)

if not workflow:
raise ValueError("workflow not found or not published")
@@ -147,7 +150,8 @@ class WorkflowTool(Tool):
"""
get the app by app id
"""
app = db.session.query(App).where(App.id == app_id).first()
stmt = select(App).where(App.id == app_id)
app = db.session.scalar(stmt)
if not app:
raise ValueError("app not found")


+ 8
- 10
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py Zobrazit soubor

@@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast

from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import Float, and_, func, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import sessionmaker

@@ -367,15 +367,12 @@ class KnowledgeRetrievalNode(BaseNode):
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
document = (
db.session.query(Document)
.where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
.first()
stmt = select(Document).where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
document = db.session.scalar(stmt)
if dataset and document:
source = {
"metadata": {
@@ -514,7 +511,8 @@ class KnowledgeRetrievalNode(BaseNode):
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> list[dict[str, Any]]:
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
metadata_fields = db.session.scalars(stmt).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
if node_data.metadata_model_config is None:
raise ValueError("metadata_model_config is required")

Načítá se…
Zrušit
Uložit