Browse Source

Migrate SQLAlchemy from 1.x to 2.0 with automated and manual adjustments (#23224)

Co-authored-by: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
tags/1.8.1
Yongtao Huang 2 months ago
parent
commit
be3af1e234
No account linked to committer's email address
33 changed files with 226 additions and 260 deletions
  1. 4
    2
      api/core/agent/base_agent_runner.py
  2. 1
    0
      api/core/app/apps/advanced_chat/app_runner.py
  3. 8
    5
      api/core/app/apps/agent_chat/app_runner.py
  4. 4
    2
      api/core/app/apps/chat/app_runner.py
  5. 8
    10
      api/core/app/apps/completion/app_generator.py
  6. 4
    2
      api/core/app/apps/completion/app_runner.py
  7. 3
    4
      api/core/app/apps/message_based_app_generator.py
  8. 4
    3
      api/core/app/features/annotation_reply/annotation_reply.py
  9. 2
    1
      api/core/app/task_pipeline/message_cycle_manager.py
  10. 9
    9
      api/core/callback_handler/index_tool_callback_handler.py
  11. 8
    10
      api/core/external_data_tool/api/api.py
  12. 10
    16
      api/core/indexing_runner.py
  13. 2
    2
      api/core/memory/token_buffer_memory.py
  14. 4
    4
      api/core/moderation/api/api.py
  15. 5
    4
      api/core/ops/aliyun_trace/aliyun_trace.py
  16. 5
    3
      api/core/ops/base_trace_instance.py
  17. 8
    9
      api/core/ops/ops_trace_manager.py
  18. 6
    3
      api/core/plugin/backwards_invocation/app.py
  19. 14
    26
      api/core/provider_manager.py
  20. 4
    4
      api/core/rag/datasource/keyword/jieba/jieba.py
  21. 11
    15
      api/core/rag/datasource/retrieval_service.py
  22. 3
    5
      api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
  23. 5
    8
      api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
  24. 5
    4
      api/core/rag/datasource/vdb/vector_factory.py
  25. 6
    8
      api/core/rag/docstore/dataset_docstore.py
  26. 7
    11
      api/core/rag/extractor/notion_extractor.py
  27. 24
    26
      api/core/rag/retrieval/dataset_retrieval.py
  28. 7
    9
      api/core/tools/tool_label_manager.py
  29. 8
    11
      api/core/tools/tool_manager.py
  30. 15
    21
      api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py
  31. 8
    11
      api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
  32. 6
    2
      api/core/tools/workflow_as_tool/tool.py
  33. 8
    10
      api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py

+ 4
- 2
api/core/agent/base_agent_runner.py View File

""" """
Save agent thought Save agent thought
""" """
agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first()
stmt = select(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id)
agent_thought = db.session.scalar(stmt)
if not agent_thought: if not agent_thought:
raise ValueError("agent thought not found") raise ValueError("agent thought not found")


return result return result


def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
stmt = select(MessageFile).where(MessageFile.message_id == message.id)
files = db.session.scalars(stmt).all()
if not files: if not files:
return UserPromptMessage(content=message.query) return UserPromptMessage(content=message.query)
if message.app_model_config: if message.app_model_config:

+ 1
- 0
api/core/app/apps/advanced_chat/app_runner.py View File



with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
app_record = session.scalar(select(App).where(App.id == app_config.app_id)) app_record = session.scalar(select(App).where(App.id == app_config.app_id))

if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")



+ 8
- 5
api/core/app/apps/agent_chat/app_runner.py View File

import logging import logging
from typing import cast from typing import cast


from sqlalchemy import select

from core.agent.cot_chat_agent_runner import CotChatAgentRunner from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from core.agent.entities import AgentEntity from core.agent.entities import AgentEntity
""" """
app_config = application_generate_entity.app_config app_config = application_generate_entity.app_config
app_config = cast(AgentChatAppConfig, app_config) app_config = cast(AgentChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
app_stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(app_stmt)
if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")




if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first()
conversation_stmt = select(Conversation).where(Conversation.id == conversation.id)
conversation_result = db.session.scalar(conversation_stmt)
if conversation_result is None: if conversation_result is None:
raise ValueError("Conversation not found") raise ValueError("Conversation not found")
message_result = db.session.query(Message).where(Message.id == message.id).first()
msg_stmt = select(Message).where(Message.id == message.id)
message_result = db.session.scalar(msg_stmt)
if message_result is None: if message_result is None:
raise ValueError("Message not found") raise ValueError("Message not found")
db.session.close() db.session.close()

+ 4
- 2
api/core/app/apps/chat/app_runner.py View File

import logging import logging
from typing import cast from typing import cast


from sqlalchemy import select

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner from core.app.apps.base_app_runner import AppRunner
from core.app.apps.chat.app_config_manager import ChatAppConfig from core.app.apps.chat.app_config_manager import ChatAppConfig
""" """
app_config = application_generate_entity.app_config app_config = application_generate_entity.app_config
app_config = cast(ChatAppConfig, app_config) app_config = cast(ChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")



+ 8
- 10
api/core/app/apps/completion/app_generator.py View File



from flask import Flask, copy_current_request_context, current_app from flask import Flask, copy_current_request_context, current_app
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy import select


from configs import dify_config from configs import dify_config
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
:param invoke_from: invoke from source :param invoke_from: invoke from source
:param stream: is stream :param stream: is stream
""" """
message = (
db.session.query(Message)
.where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
.first()
stmt = select(Message).where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
) )
message = db.session.scalar(stmt)


if not message: if not message:
raise MessageNotExistsError() raise MessageNotExistsError()

+ 4
- 2
api/core/app/apps/completion/app_runner.py View File

import logging import logging
from typing import cast from typing import cast


from sqlalchemy import select

from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner from core.app.apps.base_app_runner import AppRunner
from core.app.apps.completion.app_config_manager import CompletionAppConfig from core.app.apps.completion.app_config_manager import CompletionAppConfig
""" """
app_config = application_generate_entity.app_config app_config = application_generate_entity.app_config
app_config = cast(CompletionAppConfig, app_config) app_config = cast(CompletionAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")



+ 3
- 4
api/core/app/apps/message_based_app_generator.py View File



def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig: def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
if conversation: if conversation:
app_model_config = (
db.session.query(AppModelConfig)
.where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first()
stmt = select(AppModelConfig).where(
AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
) )
app_model_config = db.session.scalar(stmt)


if not app_model_config: if not app_model_config:
raise AppModelConfigBrokenError() raise AppModelConfigBrokenError()

+ 4
- 3
api/core/app/features/annotation_reply/annotation_reply.py View File

import logging import logging
from typing import Optional from typing import Optional


from sqlalchemy import select

from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db from extensions.ext_database import db
:param invoke_from: invoke from :param invoke_from: invoke from
:return: :return:
""" """
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first()
)
stmt = select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id)
annotation_setting = db.session.scalar(stmt)


if not annotation_setting: if not annotation_setting:
return None return None

+ 2
- 1
api/core/app/task_pipeline/message_cycle_manager.py View File

def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context(): with flask_app.app_context():
# get conversation and message # get conversation and message
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
stmt = select(Conversation).where(Conversation.id == conversation_id)
conversation = db.session.scalar(stmt)


if not conversation: if not conversation:
return return

+ 9
- 9
api/core/callback_handler/index_tool_callback_handler.py View File

import logging import logging
from collections.abc import Sequence from collections.abc import Sequence


from sqlalchemy import select

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
for document in documents: for document in documents:
if document.metadata is not None: if document.metadata is not None:
document_id = document.metadata["document_id"] document_id = document.metadata["document_id"]
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
dataset_document_stmt = select(DatasetDocument).where(DatasetDocument.id == document_id)
dataset_document = db.session.scalar(dataset_document_stmt)
if not dataset_document: if not dataset_document:
_logger.warning( _logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s", "Expected DatasetDocument record to exist, but none was found, document_id=%s",
) )
continue continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = (
db.session.query(ChildChunk)
.where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
) )
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk: if child_chunk:
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)

+ 8
- 10
api/core/external_data_tool/api/api.py View File

from typing import Optional from typing import Optional


from sqlalchemy import select

from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
from core.external_data_tool.base import ExternalDataTool from core.external_data_tool.base import ExternalDataTool
from core.helper import encrypter from core.helper import encrypter
api_based_extension_id = config.get("api_based_extension_id") api_based_extension_id = config.get("api_based_extension_id")
if not api_based_extension_id: if not api_based_extension_id:
raise ValueError("api_based_extension_id is required") raise ValueError("api_based_extension_id is required")

# get api_based_extension # get api_based_extension
api_based_extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
) )
api_based_extension = db.session.scalar(stmt)


if not api_based_extension: if not api_based_extension:
raise ValueError("api_based_extension_id is invalid") raise ValueError("api_based_extension_id is invalid")
raise ValueError(f"config is required, config: {self.config}") raise ValueError(f"config is required, config: {self.config}")
api_based_extension_id = self.config.get("api_based_extension_id") api_based_extension_id = self.config.get("api_based_extension_id")
assert api_based_extension_id is not None, "api_based_extension_id is required" assert api_based_extension_id is not None, "api_based_extension_id is required"

# get api_based_extension # get api_based_extension
api_based_extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id
) )
api_based_extension = db.session.scalar(stmt)


if not api_based_extension: if not api_based_extension:
raise ValueError( raise ValueError(

+ 10
- 16
api/core/indexing_runner.py View File

from typing import Any, Optional, cast from typing import Any, Optional, cast


from flask import current_app from flask import current_app
from sqlalchemy import select
from sqlalchemy.orm.exc import ObjectDeletedError from sqlalchemy.orm.exc import ObjectDeletedError


from configs import dify_config from configs import dify_config


if not dataset: if not dataset:
raise ValueError("no dataset found") raise ValueError("no dataset found")

# get the process rule # get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
stmt = select(DatasetProcessRule).where(
DatasetProcessRule.id == dataset_document.dataset_process_rule_id
) )
processing_rule = db.session.scalar(stmt)
if not processing_rule: if not processing_rule:
raise ValueError("no process rule found") raise ValueError("no process rule found")
index_type = dataset_document.doc_form index_type = dataset_document.doc_form
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit() db.session.commit()
# get the process rule # get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
processing_rule = db.session.scalar(stmt)
if not processing_rule: if not processing_rule:
raise ValueError("no process rule found") raise ValueError("no process rule found")


child_documents.append(child_document) child_documents.append(child_document)
document.children = child_documents document.children = child_documents
documents.append(document) documents.append(document)

# build index # build index
index_type = dataset_document.doc_form index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor = IndexProcessorFactory(index_type).init_index_processor()
# delete image files and related db records # delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content) image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids: for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
stmt = select(UploadFile).where(UploadFile.id == upload_file_id)
image_file = db.session.scalar(stmt)
if image_file is None: if image_file is None:
continue continue
try: try:
if dataset_document.data_source_type == "upload_file": if dataset_document.data_source_type == "upload_file":
if not data_source_info or "upload_file_id" not in data_source_info: if not data_source_info or "upload_file_id" not in data_source_info:
raise ValueError("no upload file found") raise ValueError("no upload file found")

file_detail = (
db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
)
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
file_detail = db.session.scalars(stmt).one_or_none()


if file_detail: if file_detail:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(

+ 2
- 2
api/core/memory/token_buffer_memory.py View File

else: else:
message_limit = 500 message_limit = 500


stmt = stmt.limit(message_limit)
msg_limit_stmt = stmt.limit(message_limit)


messages = db.session.scalars(stmt).all()
messages = db.session.scalars(msg_limit_stmt).all()


# instead of all messages from the conversation, we only need to extract messages # instead of all messages from the conversation, we only need to extract messages
# that belong to the thread of last message # that belong to the thread of last message

+ 4
- 4
api/core/moderation/api/api.py View File

from typing import Optional from typing import Optional


from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import select


from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
from core.helper.encrypter import decrypt_token from core.helper.encrypter import decrypt_token


@staticmethod @staticmethod
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]: def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
) )
extension = db.session.scalar(stmt)


return extension return extension

+ 5
- 4
api/core/ops/aliyun_trace/aliyun_trace.py View File

from urllib.parse import urljoin from urllib.parse import urljoin


from opentelemetry.trace import Link, Status, StatusCode from opentelemetry.trace import Link, Status, StatusCode
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker


from core.ops.aliyun_trace.data_exporter.traceclient import ( from core.ops.aliyun_trace.data_exporter.traceclient import (
app_id = trace_info.metadata.get("app_id") app_id = trace_info.metadata.get("app_id")
if not app_id: if not app_id:
raise ValueError("No app_id found in trace_info metadata") raise ValueError("No app_id found in trace_info metadata")
app = session.query(App).where(App.id == app_id).first()
app_stmt = select(App).where(App.id == app_id)
app = session.scalar(app_stmt)
if not app: if not app:
raise ValueError(f"App with id {app_id} not found") raise ValueError(f"App with id {app_id} not found")


if not app.created_by: if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)") raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).where(Account.id == app.created_by).first()
account_stmt = select(Account).where(Account.id == app.created_by)
service_account = session.scalar(account_stmt)
if not service_account: if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = ( current_tenant = (

+ 5
- 3
api/core/ops/base_trace_instance.py View File

from abc import ABC, abstractmethod from abc import ABC, abstractmethod


from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session


from core.ops.entities.config_entity import BaseTracingConfig from core.ops.entities.config_entity import BaseTracingConfig
""" """
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator # Get the app to find its creator
app = session.query(App).where(App.id == app_id).first()
app_stmt = select(App).where(App.id == app_id)
app = session.scalar(app_stmt)
if not app: if not app:
raise ValueError(f"App with id {app_id} not found") raise ValueError(f"App with id {app_id} not found")


if not app.created_by: if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)") raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).where(Account.id == app.created_by).first()
account_stmt = select(Account).where(Account.id == app.created_by)
service_account = session.scalar(account_stmt)
if not service_account: if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")



+ 8
- 9
api/core/ops/ops_trace_manager.py View File



if not trace_config_data: if not trace_config_data:
return None return None

# decrypt_token # decrypt_token
app = db.session.query(App).where(App.id == app_id).first()
stmt = select(App).where(App.id == app_id)
app = db.session.scalar(stmt)
if not app: if not app:
raise ValueError("App not found") raise ValueError("App not found")


@classmethod @classmethod
def get_app_config_through_message_id(cls, message_id: str): def get_app_config_through_message_id(cls, message_id: str):
app_model_config = None app_model_config = None
message_data = db.session.query(Message).where(Message.id == message_id).first()
message_stmt = select(Message).where(Message.id == message_id)
message_data = db.session.scalar(message_stmt)
if not message_data: if not message_data:
return None return None
conversation_id = message_data.conversation_id conversation_id = message_data.conversation_id
conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
conversation_stmt = select(Conversation).where(Conversation.id == conversation_id)
conversation_data = db.session.scalar(conversation_stmt)
if not conversation_data: if not conversation_data:
return None return None


if conversation_data.app_model_config_id: if conversation_data.app_model_config_id:
app_model_config = (
db.session.query(AppModelConfig)
.where(AppModelConfig.id == conversation_data.app_model_config_id)
.first()
)
config_stmt = select(AppModelConfig).where(AppModelConfig.id == conversation_data.app_model_config_id)
app_model_config = db.session.scalar(config_stmt)
elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs: elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
app_model_config = conversation_data.override_model_configs app_model_config = conversation_data.override_model_configs



+ 6
- 3
api/core/plugin/backwards_invocation/app.py View File

from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from typing import Optional, Union from typing import Optional, Union


from sqlalchemy import select

from controllers.service_api.wraps import create_or_update_end_user_for_user_id from controllers.service_api.wraps import create_or_update_end_user_for_user_id
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
""" """
get the user by user id get the user by user id
""" """
user = db.session.query(EndUser).where(EndUser.id == user_id).first()
stmt = select(EndUser).where(EndUser.id == user_id)
user = db.session.scalar(stmt)
if not user: if not user:
user = db.session.query(Account).where(Account.id == user_id).first()
stmt = select(Account).where(Account.id == user_id)
user = db.session.scalar(stmt)


if not user: if not user:
raise ValueError("user not found") raise ValueError("user not found")

+ 14
- 26
api/core/provider_manager.py View File

:param model_type: model type :param model_type: model type
:return: :return:
""" """
# Get the corresponding TenantDefaultModel record
default_model = (
db.session.query(TenantDefaultModel)
.where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
.first()
stmt = select(TenantDefaultModel).where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
) )
default_model = db.session.scalar(stmt)


# If it does not exist, get the first available provider model from get_configurations # If it does not exist, get the first available provider model from get_configurations
# and update the TenantDefaultModel record # and update the TenantDefaultModel record
model_names = [model.model for model in available_models] model_names = [model.model for model in available_models]
if model not in model_names: if model not in model_names:
raise ValueError(f"Model {model} does not exist.") raise ValueError(f"Model {model} does not exist.")

# Get the list of available models from get_configurations and check if it is LLM
default_model = (
db.session.query(TenantDefaultModel)
.where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
.first()
stmt = select(TenantDefaultModel).where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
) )
default_model = db.session.scalar(stmt)


# create or update TenantDefaultModel record # create or update TenantDefaultModel record
if default_model: if default_model:
provider_name_to_provider_records_dict[provider_name].append(new_provider_record) provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
except IntegrityError: except IntegrityError:
db.session.rollback() db.session.rollback()
existed_provider_record = (
db.session.query(Provider)
.where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
)
.first()
stmt = select(Provider).where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
) )
existed_provider_record = db.session.scalar(stmt)
if not existed_provider_record: if not existed_provider_record:
continue continue



+ 4
- 4
api/core/rag/datasource/keyword/jieba/jieba.py View File



import orjson import orjson
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import select


from configs import dify_config from configs import dify_config
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
return sorted_chunk_indices[:k] return sorted_chunk_indices[:k]


def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
document_segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
.first()
stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id
) )
document_segment = db.session.scalar(stmt)
if document_segment: if document_segment:
document_segment.keywords = keywords document_segment.keywords = keywords
db.session.add(document_segment) db.session.add(document_segment)

+ 11
- 15
api/core/rag/datasource/retrieval_service.py View File

from typing import Optional from typing import Optional


from flask import Flask, current_app from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only from sqlalchemy.orm import Session, load_only


from configs import dify_config from configs import dify_config
external_retrieval_model: Optional[dict] = None, external_retrieval_model: Optional[dict] = None,
metadata_filtering_conditions: Optional[dict] = None, metadata_filtering_conditions: Optional[dict] = None,
): ):
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset: if not dataset:
return [] return []
metadata_condition = ( metadata_condition = (
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# Handle parent-child documents # Handle parent-child documents
child_index_node_id = document.metadata.get("doc_id") child_index_node_id = document.metadata.get("doc_id")

child_chunk = (
db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first()
)
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = db.session.scalar(child_chunk_stmt)


if not child_chunk: if not child_chunk:
continue continue
index_node_id = document.metadata.get("doc_id") index_node_id = document.metadata.get("doc_id")
if not index_node_id: if not index_node_id:
continue continue

segment = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
) )
segment = db.session.scalar(document_segment_stmt)


if not segment: if not segment:
continue continue

+ 3
- 5
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py View File

TokenizerType, TokenizerType,
) )
from qdrant_client.local.qdrant_local import QdrantLocal from qdrant_client.local.qdrant_local import QdrantLocal
from sqlalchemy import select


from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field
class QdrantVectorFactory(AbstractVectorFactory): class QdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector: def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector:
if dataset.collection_binding_id: if dataset.collection_binding_id:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
stmt = select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == dataset.collection_binding_id)
dataset_collection_binding = db.session.scalars(stmt).one_or_none()
if dataset_collection_binding: if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name collection_name = dataset_collection_binding.collection_name
else: else:

+ 5
- 8
api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py View File

) )
from qdrant_client.local.qdrant_local import QdrantLocal from qdrant_client.local.qdrant_local import QdrantLocal
from requests.auth import HTTPDigestAuth from requests.auth import HTTPDigestAuth
from sqlalchemy import select


from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field


class TidbOnQdrantVectorFactory(AbstractVectorFactory): class TidbOnQdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector: def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
tidb_auth_binding = (
db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
)
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
if not tidb_auth_binding: if not tidb_auth_binding:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.where(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none()
)
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
if tidb_auth_binding: if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"



+ 5
- 4
api/core/rag/datasource/vdb/vector_factory.py View File

from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Optional from typing import Any, Optional


from sqlalchemy import select

from configs import dify_config from configs import dify_config
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
vector_type = self._dataset.index_struct_dict["type"] vector_type = self._dataset.index_struct_dict["type"]
else: else:
if dify_config.VECTOR_STORE_WHITELIST_ENABLE: if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
whitelist = (
db.session.query(Whitelist)
.where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
.one_or_none()
stmt = select(Whitelist).where(
Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db"
) )
whitelist = db.session.scalars(stmt).one_or_none()
if whitelist: if whitelist:
vector_type = VectorType.TIDB_ON_QDRANT vector_type = VectorType.TIDB_ON_QDRANT



+ 6
- 8
api/core/rag/docstore/dataset_docstore.py View File

from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, Optional from typing import Any, Optional


from sqlalchemy import func
from sqlalchemy import func, select


from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType


@property @property
def docs(self) -> dict[str, Document]: def docs(self) -> dict[str, Document]:
document_segments = (
db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all()
)
stmt = select(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id)
document_segments = db.session.scalars(stmt).all()


output = {} output = {}
for document_segment in document_segments: for document_segment in document_segments:
return data return data


def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
    """Fetch a segment of this docstore's dataset by its index node id.

    Uses a SQLAlchemy 2.0 ``select()`` executed via ``Session.scalar``, which
    yields the first matching row or ``None``.

    :param doc_id: the segment's ``index_node_id``
    :return: the matching ``DocumentSegment``, or ``None`` if none exists
    """
    stmt = select(DocumentSegment).where(
        DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id
    )
    document_segment = db.session.scalar(stmt)

    return document_segment

+ 7
- 11
api/core/rag/extractor/notion_extractor.py View File

from typing import Any, Optional, cast from typing import Any, Optional, cast


import requests import requests
from sqlalchemy import select


from configs import dify_config from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.extractor_base import BaseExtractor


@classmethod @classmethod
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.where(
db.and_(
DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
)
)
.first()
stmt = select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
) )
data_source_binding = db.session.scalar(stmt)


if not data_source_binding: if not data_source_binding:
raise Exception( raise Exception(

+ 24
- 26
api/core/rag/retrieval/dataset_retrieval.py View File

from typing import Any, Optional, Union, cast from typing import Any, Optional, Union, cast


from flask import Flask, current_app from flask import Flask, current_app
from sqlalchemy import Float, and_, or_, text
from sqlalchemy import Float, and_, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import Session from sqlalchemy.orm import Session


available_datasets = [] available_datasets = []
for dataset_id in dataset_ids: for dataset_id in dataset_ids:
# get dataset from dataset id # get dataset from dataset id
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)


# pass if dataset is not available # pass if dataset is not available
if not dataset: if not dataset:
for record in records: for record in records:
segment = record.segment segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(DatasetDocument)
.where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.first()
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
) )
document = db.session.scalar(dataset_document_stmt)
if dataset and document: if dataset and document:
source = RetrievalSourceMetadata( source = RetrievalSourceMetadata(
dataset_id=dataset.id, dataset_id=dataset.id,


if dataset_id: if dataset_id:
# get retrieval model config # get retrieval model config
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)
if dataset: if dataset:
results = [] results = []
if dataset.provider == "external": if dataset.provider == "external":
dify_documents = [document for document in documents if document.provider == "dify"] dify_documents = [document for document in documents if document.provider == "dify"]
for document in dify_documents: for document in dify_documents:
if document.metadata is not None: if document.metadata is not None:
dataset_document = (
db.session.query(DatasetDocument)
.where(DatasetDocument.id == document.metadata["document_id"])
.first()
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == document.metadata["document_id"]
) )
dataset_document = db.session.scalar(dataset_document_stmt)
if dataset_document: if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = (
db.session.query(ChildChunk)
.where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
) )
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk: if child_chunk:
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
): ):
with flask_app.app_context(): with flask_app.app_context():
with Session(db.engine) as session: with Session(db.engine) as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)


if not dataset: if not dataset:
return [] return []
available_datasets = [] available_datasets = []
for dataset_id in dataset_ids: for dataset_id in dataset_ids:
# get dataset from dataset id # get dataset from dataset id
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)


# pass if dataset is not available # pass if dataset is not available
if not dataset: if not dataset:
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
) -> Optional[list[dict[str, Any]]]: ) -> Optional[list[dict[str, Any]]]:
# get all metadata field # get all metadata field
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
metadata_fields = db.session.scalars(metadata_stmt).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
# get metadata model config # get metadata model config
if metadata_model_config is None: if metadata_model_config is None:

+ 7
- 9
api/core/tools/tool_label_manager.py View File

from sqlalchemy import select

from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.custom_tool.provider import ApiToolProviderController from core.tools.custom_tool.provider import ApiToolProviderController
return controller.tool_labels return controller.tool_labels
else: else:
raise ValueError("Unsupported tool type") raise ValueError("Unsupported tool type")

labels = (
db.session.query(ToolLabelBinding.label_name)
.where(
ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value,
)
.all()
stmt = select(ToolLabelBinding.label_name).where(
ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value,
) )
labels = db.session.scalars(stmt).all()


return [label.label_name for label in labels]
return list(labels)


@classmethod @classmethod
def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]: def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:

+ 8
- 11
api/core/tools/tool_manager.py View File



import sqlalchemy as sa import sqlalchemy as sa
from pydantic import TypeAdapter from pydantic import TypeAdapter
from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from yarl import URL from yarl import URL


# get specific credentials # get specific credentials
if is_valid_uuid(credential_id): if is_valid_uuid(credential_id):
try: try:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
builtin_provider_stmt = select(BuiltinToolProvider).where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
) )
builtin_provider = db.session.scalar(builtin_provider_stmt)
except Exception as e: except Exception as e:
builtin_provider = None builtin_provider = None
logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True) logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
), ),
) )
elif provider_type == ToolProviderType.WORKFLOW: elif provider_type == ToolProviderType.WORKFLOW:
workflow_provider = (
db.session.query(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
workflow_provider_stmt = select(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
) )
workflow_provider = db.session.scalar(workflow_provider_stmt)


if workflow_provider is None: if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

+ 15
- 21
api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py View File



from flask import Flask, current_app from flask import Flask, current_app
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import select


from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelManager from core.model_manager import ModelManager


document_context_list = [] document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
)
.all()
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
) )
segments = db.session.scalars(document_segment_stmt).all()


if segments: if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
resource_number = 1 resource_number = 1
for segment in sorted_segments: for segment in sorted_segments:
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(Document)
.where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
.first()
document_stmt = select(Document).where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
) )
document = db.session.scalar(document_stmt)
if dataset and document: if dataset and document:
source = RetrievalSourceMetadata( source = RetrievalSourceMetadata(
position=resource_number, position=resource_number,
hit_callbacks: list[DatasetIndexToolCallbackHandler], hit_callbacks: list[DatasetIndexToolCallbackHandler],
): ):
with flask_app.app_context(): with flask_app.app_context():
dataset = (
db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
)
stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)


if not dataset: if not dataset:
return [] return []

+ 8
- 11
api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py View File

from typing import Any, Optional, cast from typing import Any, Optional, cast


from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import select


from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
) )


def _run(self, query: str) -> str: def _run(self, query: str) -> str:
dataset = (
db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
)
dataset_stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id)
dataset = db.session.scalar(dataset_stmt)


if not dataset: if not dataset:
return "" return ""
for record in records: for record in records:
segment = record.segment segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(DatasetDocument) # type: ignore
.where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.first()
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
) )
document = db.session.scalar(dataset_document_stmt) # type: ignore
if dataset and document: if dataset and document:
source = RetrievalSourceMetadata( source = RetrievalSourceMetadata(
dataset_id=dataset.id, dataset_id=dataset.id,

+ 6
- 2
api/core/tools/workflow_as_tool/tool.py View File

from collections.abc import Generator from collections.abc import Generator
from typing import Any, Optional from typing import Any, Optional


from sqlalchemy import select

from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime
.first() .first()
) )
else: else:
workflow = db.session.query(Workflow).where(Workflow.app_id == app_id, Workflow.version == version).first()
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
workflow = db.session.scalar(stmt)


if not workflow: if not workflow:
raise ValueError("workflow not found or not published") raise ValueError("workflow not found or not published")
""" """
get the app by app id get the app by app id
""" """
app = db.session.query(App).where(App.id == app_id).first()
stmt = select(App).where(App.id == app_id)
app = db.session.scalar(stmt)
if not app: if not app:
raise ValueError("app not found") raise ValueError("app not found")



+ 8
- 10
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py View File

from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast from typing import TYPE_CHECKING, Any, Optional, cast


from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import Float, and_, func, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker


for record in records: for record in records:
segment = record.segment segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
document = (
db.session.query(Document)
.where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
.first()
stmt = select(Document).where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
) )
document = db.session.scalar(stmt)
if dataset and document: if dataset and document:
source = { source = {
"metadata": { "metadata": {
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
# get all metadata field # get all metadata field
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
metadata_fields = db.session.scalars(stmt).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
if node_data.metadata_model_config is None: if node_data.metadata_model_config is None:
raise ValueError("metadata_model_config is required") raise ValueError("metadata_model_config is required")

Loading…
Cancel
Save