Co-authored-by: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>tags/1.8.1
| @@ -334,7 +334,8 @@ class BaseAgentRunner(AppRunner): | |||
| """ | |||
| Save agent thought | |||
| """ | |||
| agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first() | |||
| stmt = select(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id) | |||
| agent_thought = db.session.scalar(stmt) | |||
| if not agent_thought: | |||
| raise ValueError("agent thought not found") | |||
| @@ -492,7 +493,8 @@ class BaseAgentRunner(AppRunner): | |||
| return result | |||
| def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: | |||
| files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all() | |||
| stmt = select(MessageFile).where(MessageFile.message_id == message.id) | |||
| files = db.session.scalars(stmt).all() | |||
| if not files: | |||
| return UserPromptMessage(content=message.query) | |||
| if message.app_model_config: | |||
| @@ -74,6 +74,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| app_record = session.scalar(select(App).where(App.id == app_config.app_id)) | |||
| if not app_record: | |||
| raise ValueError("App not found") | |||
| @@ -1,6 +1,8 @@ | |||
| import logging | |||
| from typing import cast | |||
| from sqlalchemy import select | |||
| from core.agent.cot_chat_agent_runner import CotChatAgentRunner | |||
| from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner | |||
| from core.agent.entities import AgentEntity | |||
| @@ -44,8 +46,8 @@ class AgentChatAppRunner(AppRunner): | |||
| """ | |||
| app_config = application_generate_entity.app_config | |||
| app_config = cast(AgentChatAppConfig, app_config) | |||
| app_record = db.session.query(App).where(App.id == app_config.app_id).first() | |||
| app_stmt = select(App).where(App.id == app_config.app_id) | |||
| app_record = db.session.scalar(app_stmt) | |||
| if not app_record: | |||
| raise ValueError("App not found") | |||
| @@ -182,11 +184,12 @@ class AgentChatAppRunner(AppRunner): | |||
| if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): | |||
| agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING | |||
| conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first() | |||
| conversation_stmt = select(Conversation).where(Conversation.id == conversation.id) | |||
| conversation_result = db.session.scalar(conversation_stmt) | |||
| if conversation_result is None: | |||
| raise ValueError("Conversation not found") | |||
| message_result = db.session.query(Message).where(Message.id == message.id).first() | |||
| msg_stmt = select(Message).where(Message.id == message.id) | |||
| message_result = db.session.scalar(msg_stmt) | |||
| if message_result is None: | |||
| raise ValueError("Message not found") | |||
| db.session.close() | |||
| @@ -1,6 +1,8 @@ | |||
| import logging | |||
| from typing import cast | |||
| from sqlalchemy import select | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| from core.app.apps.base_app_runner import AppRunner | |||
| from core.app.apps.chat.app_config_manager import ChatAppConfig | |||
| @@ -42,8 +44,8 @@ class ChatAppRunner(AppRunner): | |||
| """ | |||
| app_config = application_generate_entity.app_config | |||
| app_config = cast(ChatAppConfig, app_config) | |||
| app_record = db.session.query(App).where(App.id == app_config.app_id).first() | |||
| stmt = select(App).where(App.id == app_config.app_id) | |||
| app_record = db.session.scalar(stmt) | |||
| if not app_record: | |||
| raise ValueError("App not found") | |||
| @@ -6,6 +6,7 @@ from typing import Any, Literal, Union, overload | |||
| from flask import Flask, copy_current_request_context, current_app | |||
| from pydantic import ValidationError | |||
| from sqlalchemy import select | |||
| from configs import dify_config | |||
| from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter | |||
| @@ -248,17 +249,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| :param invoke_from: invoke from source | |||
| :param stream: is stream | |||
| """ | |||
| message = ( | |||
| db.session.query(Message) | |||
| .where( | |||
| Message.id == message_id, | |||
| Message.app_id == app_model.id, | |||
| Message.from_source == ("api" if isinstance(user, EndUser) else "console"), | |||
| Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), | |||
| Message.from_account_id == (user.id if isinstance(user, Account) else None), | |||
| ) | |||
| .first() | |||
| stmt = select(Message).where( | |||
| Message.id == message_id, | |||
| Message.app_id == app_model.id, | |||
| Message.from_source == ("api" if isinstance(user, EndUser) else "console"), | |||
| Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), | |||
| Message.from_account_id == (user.id if isinstance(user, Account) else None), | |||
| ) | |||
| message = db.session.scalar(stmt) | |||
| if not message: | |||
| raise MessageNotExistsError() | |||
| @@ -1,6 +1,8 @@ | |||
| import logging | |||
| from typing import cast | |||
| from sqlalchemy import select | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.apps.base_app_runner import AppRunner | |||
| from core.app.apps.completion.app_config_manager import CompletionAppConfig | |||
| @@ -35,8 +37,8 @@ class CompletionAppRunner(AppRunner): | |||
| """ | |||
| app_config = application_generate_entity.app_config | |||
| app_config = cast(CompletionAppConfig, app_config) | |||
| app_record = db.session.query(App).where(App.id == app_config.app_id).first() | |||
| stmt = select(App).where(App.id == app_config.app_id) | |||
| app_record = db.session.scalar(stmt) | |||
| if not app_record: | |||
| raise ValueError("App not found") | |||
| @@ -86,11 +86,10 @@ class MessageBasedAppGenerator(BaseAppGenerator): | |||
| def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig: | |||
| if conversation: | |||
| app_model_config = ( | |||
| db.session.query(AppModelConfig) | |||
| .where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) | |||
| .first() | |||
| stmt = select(AppModelConfig).where( | |||
| AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id | |||
| ) | |||
| app_model_config = db.session.scalar(stmt) | |||
| if not app_model_config: | |||
| raise AppModelConfigBrokenError() | |||
| @@ -1,6 +1,8 @@ | |||
| import logging | |||
| from typing import Optional | |||
| from sqlalchemy import select | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.rag.datasource.vdb.vector_factory import Vector | |||
| from extensions.ext_database import db | |||
| @@ -25,9 +27,8 @@ class AnnotationReplyFeature: | |||
| :param invoke_from: invoke from | |||
| :return: | |||
| """ | |||
| annotation_setting = ( | |||
| db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first() | |||
| ) | |||
| stmt = select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id) | |||
| annotation_setting = db.session.scalar(stmt) | |||
| if not annotation_setting: | |||
| return None | |||
| @@ -86,7 +86,8 @@ class MessageCycleManager: | |||
| def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): | |||
| with flask_app.app_context(): | |||
| # get conversation and message | |||
| conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first() | |||
| stmt = select(Conversation).where(Conversation.id == conversation_id) | |||
| conversation = db.session.scalar(stmt) | |||
| if not conversation: | |||
| return | |||
| @@ -1,6 +1,8 @@ | |||
| import logging | |||
| from collections.abc import Sequence | |||
| from sqlalchemy import select | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.app.entities.queue_entities import QueueRetrieverResourcesEvent | |||
| @@ -49,7 +51,8 @@ class DatasetIndexToolCallbackHandler: | |||
| for document in documents: | |||
| if document.metadata is not None: | |||
| document_id = document.metadata["document_id"] | |||
| dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() | |||
| dataset_document_stmt = select(DatasetDocument).where(DatasetDocument.id == document_id) | |||
| dataset_document = db.session.scalar(dataset_document_stmt) | |||
| if not dataset_document: | |||
| _logger.warning( | |||
| "Expected DatasetDocument record to exist, but none was found, document_id=%s", | |||
| @@ -57,15 +60,12 @@ class DatasetIndexToolCallbackHandler: | |||
| ) | |||
| continue | |||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||
| child_chunk = ( | |||
| db.session.query(ChildChunk) | |||
| .where( | |||
| ChildChunk.index_node_id == document.metadata["doc_id"], | |||
| ChildChunk.dataset_id == dataset_document.dataset_id, | |||
| ChildChunk.document_id == dataset_document.id, | |||
| ) | |||
| .first() | |||
| child_chunk_stmt = select(ChildChunk).where( | |||
| ChildChunk.index_node_id == document.metadata["doc_id"], | |||
| ChildChunk.dataset_id == dataset_document.dataset_id, | |||
| ChildChunk.document_id == dataset_document.id, | |||
| ) | |||
| child_chunk = db.session.scalar(child_chunk_stmt) | |||
| if child_chunk: | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| @@ -1,5 +1,7 @@ | |||
| from typing import Optional | |||
| from sqlalchemy import select | |||
| from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor | |||
| from core.external_data_tool.base import ExternalDataTool | |||
| from core.helper import encrypter | |||
| @@ -28,13 +30,11 @@ class ApiExternalDataTool(ExternalDataTool): | |||
| api_based_extension_id = config.get("api_based_extension_id") | |||
| if not api_based_extension_id: | |||
| raise ValueError("api_based_extension_id is required") | |||
| # get api_based_extension | |||
| api_based_extension = ( | |||
| db.session.query(APIBasedExtension) | |||
| .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) | |||
| .first() | |||
| stmt = select(APIBasedExtension).where( | |||
| APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id | |||
| ) | |||
| api_based_extension = db.session.scalar(stmt) | |||
| if not api_based_extension: | |||
| raise ValueError("api_based_extension_id is invalid") | |||
| @@ -52,13 +52,11 @@ class ApiExternalDataTool(ExternalDataTool): | |||
| raise ValueError(f"config is required, config: {self.config}") | |||
| api_based_extension_id = self.config.get("api_based_extension_id") | |||
| assert api_based_extension_id is not None, "api_based_extension_id is required" | |||
| # get api_based_extension | |||
| api_based_extension = ( | |||
| db.session.query(APIBasedExtension) | |||
| .where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id) | |||
| .first() | |||
| stmt = select(APIBasedExtension).where( | |||
| APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id | |||
| ) | |||
| api_based_extension = db.session.scalar(stmt) | |||
| if not api_based_extension: | |||
| raise ValueError( | |||
| @@ -8,6 +8,7 @@ import uuid | |||
| from typing import Any, Optional, cast | |||
| from flask import current_app | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm.exc import ObjectDeletedError | |||
| from configs import dify_config | |||
| @@ -56,13 +57,11 @@ class IndexingRunner: | |||
| if not dataset: | |||
| raise ValueError("no dataset found") | |||
| # get the process rule | |||
| processing_rule = ( | |||
| db.session.query(DatasetProcessRule) | |||
| .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) | |||
| .first() | |||
| stmt = select(DatasetProcessRule).where( | |||
| DatasetProcessRule.id == dataset_document.dataset_process_rule_id | |||
| ) | |||
| processing_rule = db.session.scalar(stmt) | |||
| if not processing_rule: | |||
| raise ValueError("no process rule found") | |||
| index_type = dataset_document.doc_form | |||
| @@ -123,11 +122,8 @@ class IndexingRunner: | |||
| db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() | |||
| db.session.commit() | |||
| # get the process rule | |||
| processing_rule = ( | |||
| db.session.query(DatasetProcessRule) | |||
| .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) | |||
| .first() | |||
| ) | |||
| stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) | |||
| processing_rule = db.session.scalar(stmt) | |||
| if not processing_rule: | |||
| raise ValueError("no process rule found") | |||
| @@ -208,7 +204,6 @@ class IndexingRunner: | |||
| child_documents.append(child_document) | |||
| document.children = child_documents | |||
| documents.append(document) | |||
| # build index | |||
| index_type = dataset_document.doc_form | |||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | |||
| @@ -310,7 +305,8 @@ class IndexingRunner: | |||
| # delete image files and related db records | |||
| image_upload_file_ids = get_image_upload_file_ids(document.page_content) | |||
| for upload_file_id in image_upload_file_ids: | |||
| image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() | |||
| stmt = select(UploadFile).where(UploadFile.id == upload_file_id) | |||
| image_file = db.session.scalar(stmt) | |||
| if image_file is None: | |||
| continue | |||
| try: | |||
| @@ -339,10 +335,8 @@ class IndexingRunner: | |||
| if dataset_document.data_source_type == "upload_file": | |||
| if not data_source_info or "upload_file_id" not in data_source_info: | |||
| raise ValueError("no upload file found") | |||
| file_detail = ( | |||
| db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none() | |||
| ) | |||
| stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]) | |||
| file_detail = db.session.scalars(stmt).one_or_none() | |||
| if file_detail: | |||
| extract_setting = ExtractSetting( | |||
| @@ -110,9 +110,9 @@ class TokenBufferMemory: | |||
| else: | |||
| message_limit = 500 | |||
| stmt = stmt.limit(message_limit) | |||
| msg_limit_stmt = stmt.limit(message_limit) | |||
| messages = db.session.scalars(stmt).all() | |||
| messages = db.session.scalars(msg_limit_stmt).all() | |||
| # instead of all messages from the conversation, we only need to extract messages | |||
| # that belong to the thread of last message | |||
| @@ -1,6 +1,7 @@ | |||
| from typing import Optional | |||
| from pydantic import BaseModel, Field | |||
| from sqlalchemy import select | |||
| from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor | |||
| from core.helper.encrypter import decrypt_token | |||
| @@ -87,10 +88,9 @@ class ApiModeration(Moderation): | |||
| @staticmethod | |||
| def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]: | |||
| extension = ( | |||
| db.session.query(APIBasedExtension) | |||
| .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) | |||
| .first() | |||
| stmt = select(APIBasedExtension).where( | |||
| APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id | |||
| ) | |||
| extension = db.session.scalar(stmt) | |||
| return extension | |||
| @@ -5,6 +5,7 @@ from typing import Optional | |||
| from urllib.parse import urljoin | |||
| from opentelemetry.trace import Link, Status, StatusCode | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session, sessionmaker | |||
| from core.ops.aliyun_trace.data_exporter.traceclient import ( | |||
| @@ -263,15 +264,15 @@ class AliyunDataTrace(BaseTraceInstance): | |||
| app_id = trace_info.metadata.get("app_id") | |||
| if not app_id: | |||
| raise ValueError("No app_id found in trace_info metadata") | |||
| app = session.query(App).where(App.id == app_id).first() | |||
| app_stmt = select(App).where(App.id == app_id) | |||
| app = session.scalar(app_stmt) | |||
| if not app: | |||
| raise ValueError(f"App with id {app_id} not found") | |||
| if not app.created_by: | |||
| raise ValueError(f"App with id {app_id} has no creator (created_by is None)") | |||
| service_account = session.query(Account).where(Account.id == app.created_by).first() | |||
| account_stmt = select(Account).where(Account.id == app.created_by) | |||
| service_account = session.scalar(account_stmt) | |||
| if not service_account: | |||
| raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") | |||
| current_tenant = ( | |||
| @@ -1,5 +1,6 @@ | |||
| from abc import ABC, abstractmethod | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from core.ops.entities.config_entity import BaseTracingConfig | |||
| @@ -44,14 +45,15 @@ class BaseTraceInstance(ABC): | |||
| """ | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| # Get the app to find its creator | |||
| app = session.query(App).where(App.id == app_id).first() | |||
| app_stmt = select(App).where(App.id == app_id) | |||
| app = session.scalar(app_stmt) | |||
| if not app: | |||
| raise ValueError(f"App with id {app_id} not found") | |||
| if not app.created_by: | |||
| raise ValueError(f"App with id {app_id} has no creator (created_by is None)") | |||
| service_account = session.query(Account).where(Account.id == app.created_by).first() | |||
| account_stmt = select(Account).where(Account.id == app.created_by) | |||
| service_account = session.scalar(account_stmt) | |||
| if not service_account: | |||
| raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") | |||
| @@ -226,9 +226,9 @@ class OpsTraceManager: | |||
| if not trace_config_data: | |||
| return None | |||
| # decrypt_token | |||
| app = db.session.query(App).where(App.id == app_id).first() | |||
| stmt = select(App).where(App.id == app_id) | |||
| app = db.session.scalar(stmt) | |||
| if not app: | |||
| raise ValueError("App not found") | |||
| @@ -295,20 +295,19 @@ class OpsTraceManager: | |||
| @classmethod | |||
| def get_app_config_through_message_id(cls, message_id: str): | |||
| app_model_config = None | |||
| message_data = db.session.query(Message).where(Message.id == message_id).first() | |||
| message_stmt = select(Message).where(Message.id == message_id) | |||
| message_data = db.session.scalar(message_stmt) | |||
| if not message_data: | |||
| return None | |||
| conversation_id = message_data.conversation_id | |||
| conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first() | |||
| conversation_stmt = select(Conversation).where(Conversation.id == conversation_id) | |||
| conversation_data = db.session.scalar(conversation_stmt) | |||
| if not conversation_data: | |||
| return None | |||
| if conversation_data.app_model_config_id: | |||
| app_model_config = ( | |||
| db.session.query(AppModelConfig) | |||
| .where(AppModelConfig.id == conversation_data.app_model_config_id) | |||
| .first() | |||
| ) | |||
| config_stmt = select(AppModelConfig).where(AppModelConfig.id == conversation_data.app_model_config_id) | |||
| app_model_config = db.session.scalar(config_stmt) | |||
| elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs: | |||
| app_model_config = conversation_data.override_model_configs | |||
| @@ -1,6 +1,8 @@ | |||
| from collections.abc import Generator, Mapping | |||
| from typing import Optional, Union | |||
| from sqlalchemy import select | |||
| from controllers.service_api.wraps import create_or_update_end_user_for_user_id | |||
| from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict | |||
| from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator | |||
| @@ -192,10 +194,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): | |||
| """ | |||
| get the user by user id | |||
| """ | |||
| user = db.session.query(EndUser).where(EndUser.id == user_id).first() | |||
| stmt = select(EndUser).where(EndUser.id == user_id) | |||
| user = db.session.scalar(stmt) | |||
| if not user: | |||
| user = db.session.query(Account).where(Account.id == user_id).first() | |||
| stmt = select(Account).where(Account.id == user_id) | |||
| user = db.session.scalar(stmt) | |||
| if not user: | |||
| raise ValueError("user not found") | |||
| @@ -276,15 +276,11 @@ class ProviderManager: | |||
| :param model_type: model type | |||
| :return: | |||
| """ | |||
| # Get the corresponding TenantDefaultModel record | |||
| default_model = ( | |||
| db.session.query(TenantDefaultModel) | |||
| .where( | |||
| TenantDefaultModel.tenant_id == tenant_id, | |||
| TenantDefaultModel.model_type == model_type.to_origin_model_type(), | |||
| ) | |||
| .first() | |||
| stmt = select(TenantDefaultModel).where( | |||
| TenantDefaultModel.tenant_id == tenant_id, | |||
| TenantDefaultModel.model_type == model_type.to_origin_model_type(), | |||
| ) | |||
| default_model = db.session.scalar(stmt) | |||
| # If it does not exist, get the first available provider model from get_configurations | |||
| # and update the TenantDefaultModel record | |||
| @@ -367,16 +363,11 @@ class ProviderManager: | |||
| model_names = [model.model for model in available_models] | |||
| if model not in model_names: | |||
| raise ValueError(f"Model {model} does not exist.") | |||
| # Get the list of available models from get_configurations and check if it is LLM | |||
| default_model = ( | |||
| db.session.query(TenantDefaultModel) | |||
| .where( | |||
| TenantDefaultModel.tenant_id == tenant_id, | |||
| TenantDefaultModel.model_type == model_type.to_origin_model_type(), | |||
| ) | |||
| .first() | |||
| stmt = select(TenantDefaultModel).where( | |||
| TenantDefaultModel.tenant_id == tenant_id, | |||
| TenantDefaultModel.model_type == model_type.to_origin_model_type(), | |||
| ) | |||
| default_model = db.session.scalar(stmt) | |||
| # create or update TenantDefaultModel record | |||
| if default_model: | |||
| @@ -598,16 +589,13 @@ class ProviderManager: | |||
| provider_name_to_provider_records_dict[provider_name].append(new_provider_record) | |||
| except IntegrityError: | |||
| db.session.rollback() | |||
| existed_provider_record = ( | |||
| db.session.query(Provider) | |||
| .where( | |||
| Provider.tenant_id == tenant_id, | |||
| Provider.provider_name == ModelProviderID(provider_name).provider_name, | |||
| Provider.provider_type == ProviderType.SYSTEM.value, | |||
| Provider.quota_type == ProviderQuotaType.TRIAL.value, | |||
| ) | |||
| .first() | |||
| stmt = select(Provider).where( | |||
| Provider.tenant_id == tenant_id, | |||
| Provider.provider_name == ModelProviderID(provider_name).provider_name, | |||
| Provider.provider_type == ProviderType.SYSTEM.value, | |||
| Provider.quota_type == ProviderQuotaType.TRIAL.value, | |||
| ) | |||
| existed_provider_record = db.session.scalar(stmt) | |||
| if not existed_provider_record: | |||
| continue | |||
| @@ -3,6 +3,7 @@ from typing import Any, Optional | |||
| import orjson | |||
| from pydantic import BaseModel | |||
| from sqlalchemy import select | |||
| from configs import dify_config | |||
| from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler | |||
| @@ -211,11 +212,10 @@ class Jieba(BaseKeyword): | |||
| return sorted_chunk_indices[:k] | |||
| def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): | |||
| document_segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id) | |||
| .first() | |||
| stmt = select(DocumentSegment).where( | |||
| DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id | |||
| ) | |||
| document_segment = db.session.scalar(stmt) | |||
| if document_segment: | |||
| document_segment.keywords = keywords | |||
| db.session.add(document_segment) | |||
| @@ -3,6 +3,7 @@ from concurrent.futures import ThreadPoolExecutor | |||
| from typing import Optional | |||
| from flask import Flask, current_app | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session, load_only | |||
| from configs import dify_config | |||
| @@ -127,7 +128,8 @@ class RetrievalService: | |||
| external_retrieval_model: Optional[dict] = None, | |||
| metadata_filtering_conditions: Optional[dict] = None, | |||
| ): | |||
| dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() | |||
| stmt = select(Dataset).where(Dataset.id == dataset_id) | |||
| dataset = db.session.scalar(stmt) | |||
| if not dataset: | |||
| return [] | |||
| metadata_condition = ( | |||
| @@ -316,10 +318,8 @@ class RetrievalService: | |||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||
| # Handle parent-child documents | |||
| child_index_node_id = document.metadata.get("doc_id") | |||
| child_chunk = ( | |||
| db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first() | |||
| ) | |||
| child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id) | |||
| child_chunk = db.session.scalar(child_chunk_stmt) | |||
| if not child_chunk: | |||
| continue | |||
| @@ -378,17 +378,13 @@ class RetrievalService: | |||
| index_node_id = document.metadata.get("doc_id") | |||
| if not index_node_id: | |||
| continue | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .where( | |||
| DocumentSegment.dataset_id == dataset_document.dataset_id, | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.status == "completed", | |||
| DocumentSegment.index_node_id == index_node_id, | |||
| ) | |||
| .first() | |||
| document_segment_stmt = select(DocumentSegment).where( | |||
| DocumentSegment.dataset_id == dataset_document.dataset_id, | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.status == "completed", | |||
| DocumentSegment.index_node_id == index_node_id, | |||
| ) | |||
| segment = db.session.scalar(document_segment_stmt) | |||
| if not segment: | |||
| continue | |||
| @@ -18,6 +18,7 @@ from qdrant_client.http.models import ( | |||
| TokenizerType, | |||
| ) | |||
| from qdrant_client.local.qdrant_local import QdrantLocal | |||
| from sqlalchemy import select | |||
| from configs import dify_config | |||
| from core.rag.datasource.vdb.field import Field | |||
| @@ -445,11 +446,8 @@ class QdrantVector(BaseVector): | |||
| class QdrantVectorFactory(AbstractVectorFactory): | |||
| def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector: | |||
| if dataset.collection_binding_id: | |||
| dataset_collection_binding = ( | |||
| db.session.query(DatasetCollectionBinding) | |||
| .where(DatasetCollectionBinding.id == dataset.collection_binding_id) | |||
| .one_or_none() | |||
| ) | |||
| stmt = select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == dataset.collection_binding_id) | |||
| dataset_collection_binding = db.session.scalars(stmt).one_or_none() | |||
| if dataset_collection_binding: | |||
| collection_name = dataset_collection_binding.collection_name | |||
| else: | |||
| @@ -20,6 +20,7 @@ from qdrant_client.http.models import ( | |||
| ) | |||
| from qdrant_client.local.qdrant_local import QdrantLocal | |||
| from requests.auth import HTTPDigestAuth | |||
| from sqlalchemy import select | |||
| from configs import dify_config | |||
| from core.rag.datasource.vdb.field import Field | |||
| @@ -416,16 +417,12 @@ class TidbOnQdrantVector(BaseVector): | |||
| class TidbOnQdrantVectorFactory(AbstractVectorFactory): | |||
| def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector: | |||
| tidb_auth_binding = ( | |||
| db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() | |||
| ) | |||
| stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id) | |||
| tidb_auth_binding = db.session.scalars(stmt).one_or_none() | |||
| if not tidb_auth_binding: | |||
| with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): | |||
| tidb_auth_binding = ( | |||
| db.session.query(TidbAuthBinding) | |||
| .where(TidbAuthBinding.tenant_id == dataset.tenant_id) | |||
| .one_or_none() | |||
| ) | |||
| stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id) | |||
| tidb_auth_binding = db.session.scalars(stmt).one_or_none() | |||
| if tidb_auth_binding: | |||
| TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" | |||
| @@ -3,6 +3,8 @@ import time | |||
| from abc import ABC, abstractmethod | |||
| from typing import Any, Optional | |||
| from sqlalchemy import select | |||
| from configs import dify_config | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| @@ -45,11 +47,10 @@ class Vector: | |||
| vector_type = self._dataset.index_struct_dict["type"] | |||
| else: | |||
| if dify_config.VECTOR_STORE_WHITELIST_ENABLE: | |||
| whitelist = ( | |||
| db.session.query(Whitelist) | |||
| .where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db") | |||
| .one_or_none() | |||
| stmt = select(Whitelist).where( | |||
| Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db" | |||
| ) | |||
| whitelist = db.session.scalars(stmt).one_or_none() | |||
| if whitelist: | |||
| vector_type = VectorType.TIDB_ON_QDRANT | |||
| @@ -1,7 +1,7 @@ | |||
| from collections.abc import Sequence | |||
| from typing import Any, Optional | |||
| from sqlalchemy import func | |||
| from sqlalchemy import func, select | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| @@ -41,9 +41,8 @@ class DatasetDocumentStore: | |||
| @property | |||
| def docs(self) -> dict[str, Document]: | |||
| document_segments = ( | |||
| db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all() | |||
| ) | |||
| stmt = select(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id) | |||
| document_segments = db.session.scalars(stmt).all() | |||
| output = {} | |||
| for document_segment in document_segments: | |||
| @@ -228,10 +227,9 @@ class DatasetDocumentStore: | |||
| return data | |||
| def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]: | |||
| document_segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) | |||
| .first() | |||
| stmt = select(DocumentSegment).where( | |||
| DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id | |||
| ) | |||
| document_segment = db.session.scalar(stmt) | |||
| return document_segment | |||
| @@ -4,6 +4,7 @@ import operator | |||
| from typing import Any, Optional, cast | |||
| import requests | |||
| from sqlalchemy import select | |||
| from configs import dify_config | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| @@ -367,18 +368,13 @@ class NotionExtractor(BaseExtractor): | |||
| @classmethod | |||
| def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: | |||
| data_source_binding = ( | |||
| db.session.query(DataSourceOauthBinding) | |||
| .where( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.disabled == False, | |||
| DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"', | |||
| ) | |||
| ) | |||
| .first() | |||
| stmt = select(DataSourceOauthBinding).where( | |||
| DataSourceOauthBinding.tenant_id == tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.disabled == False, | |||
| DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"', | |||
| ) | |||
| data_source_binding = db.session.scalar(stmt) | |||
| if not data_source_binding: | |||
| raise Exception( | |||
| @@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping | |||
| from typing import Any, Optional, Union, cast | |||
| from flask import Flask, current_app | |||
| from sqlalchemy import Float, and_, or_, text | |||
| from sqlalchemy import Float, and_, or_, select, text | |||
| from sqlalchemy import cast as sqlalchemy_cast | |||
| from sqlalchemy.orm import Session | |||
| @@ -135,7 +135,8 @@ class DatasetRetrieval: | |||
| available_datasets = [] | |||
| for dataset_id in dataset_ids: | |||
| # get dataset from dataset id | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id) | |||
| dataset = db.session.scalar(dataset_stmt) | |||
| # pass if dataset is not available | |||
| if not dataset: | |||
| @@ -240,15 +241,12 @@ class DatasetRetrieval: | |||
| for record in records: | |||
| segment = record.segment | |||
| dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() | |||
| document = ( | |||
| db.session.query(DatasetDocument) | |||
| .where( | |||
| DatasetDocument.id == segment.document_id, | |||
| DatasetDocument.enabled == True, | |||
| DatasetDocument.archived == False, | |||
| ) | |||
| .first() | |||
| dataset_document_stmt = select(DatasetDocument).where( | |||
| DatasetDocument.id == segment.document_id, | |||
| DatasetDocument.enabled == True, | |||
| DatasetDocument.archived == False, | |||
| ) | |||
| document = db.session.scalar(dataset_document_stmt) | |||
| if dataset and document: | |||
| source = RetrievalSourceMetadata( | |||
| dataset_id=dataset.id, | |||
| @@ -327,7 +325,8 @@ class DatasetRetrieval: | |||
| if dataset_id: | |||
| # get retrieval model config | |||
| dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() | |||
| dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) | |||
| dataset = db.session.scalar(dataset_stmt) | |||
| if dataset: | |||
| results = [] | |||
| if dataset.provider == "external": | |||
| @@ -514,22 +513,18 @@ class DatasetRetrieval: | |||
| dify_documents = [document for document in documents if document.provider == "dify"] | |||
| for document in dify_documents: | |||
| if document.metadata is not None: | |||
| dataset_document = ( | |||
| db.session.query(DatasetDocument) | |||
| .where(DatasetDocument.id == document.metadata["document_id"]) | |||
| .first() | |||
| dataset_document_stmt = select(DatasetDocument).where( | |||
| DatasetDocument.id == document.metadata["document_id"] | |||
| ) | |||
| dataset_document = db.session.scalar(dataset_document_stmt) | |||
| if dataset_document: | |||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||
| child_chunk = ( | |||
| db.session.query(ChildChunk) | |||
| .where( | |||
| ChildChunk.index_node_id == document.metadata["doc_id"], | |||
| ChildChunk.dataset_id == dataset_document.dataset_id, | |||
| ChildChunk.document_id == dataset_document.id, | |||
| ) | |||
| .first() | |||
| child_chunk_stmt = select(ChildChunk).where( | |||
| ChildChunk.index_node_id == document.metadata["doc_id"], | |||
| ChildChunk.dataset_id == dataset_document.dataset_id, | |||
| ChildChunk.document_id == dataset_document.id, | |||
| ) | |||
| child_chunk = db.session.scalar(child_chunk_stmt) | |||
| if child_chunk: | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| @@ -600,7 +595,8 @@ class DatasetRetrieval: | |||
| ): | |||
| with flask_app.app_context(): | |||
| with Session(db.engine) as session: | |||
| dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() | |||
| dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) | |||
| dataset = db.session.scalar(dataset_stmt) | |||
| if not dataset: | |||
| return [] | |||
| @@ -685,7 +681,8 @@ class DatasetRetrieval: | |||
| available_datasets = [] | |||
| for dataset_id in dataset_ids: | |||
| # get dataset from dataset id | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id) | |||
| dataset = db.session.scalar(dataset_stmt) | |||
| # pass if dataset is not available | |||
| if not dataset: | |||
| @@ -958,7 +955,8 @@ class DatasetRetrieval: | |||
| self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig | |||
| ) -> Optional[list[dict[str, Any]]]: | |||
| # get all metadata field | |||
| metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all() | |||
| metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)) | |||
| metadata_fields = db.session.scalars(metadata_stmt).all() | |||
| all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] | |||
| # get metadata model config | |||
| if metadata_model_config is None: | |||
| @@ -1,3 +1,5 @@ | |||
| from sqlalchemy import select | |||
| from core.tools.__base.tool_provider import ToolProviderController | |||
| from core.tools.builtin_tool.provider import BuiltinToolProviderController | |||
| from core.tools.custom_tool.provider import ApiToolProviderController | |||
| @@ -54,17 +56,13 @@ class ToolLabelManager: | |||
| return controller.tool_labels | |||
| else: | |||
| raise ValueError("Unsupported tool type") | |||
| labels = ( | |||
| db.session.query(ToolLabelBinding.label_name) | |||
| .where( | |||
| ToolLabelBinding.tool_id == provider_id, | |||
| ToolLabelBinding.tool_type == controller.provider_type.value, | |||
| ) | |||
| .all() | |||
| stmt = select(ToolLabelBinding.label_name).where( | |||
| ToolLabelBinding.tool_id == provider_id, | |||
| ToolLabelBinding.tool_type == controller.provider_type.value, | |||
| ) | |||
| labels = db.session.scalars(stmt).all() | |||
| return [label.label_name for label in labels] | |||
| return list(labels) | |||
| @classmethod | |||
| def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]: | |||
| @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast | |||
| import sqlalchemy as sa | |||
| from pydantic import TypeAdapter | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from yarl import URL | |||
| @@ -198,14 +199,11 @@ class ToolManager: | |||
| # get specific credentials | |||
| if is_valid_uuid(credential_id): | |||
| try: | |||
| builtin_provider = ( | |||
| db.session.query(BuiltinToolProvider) | |||
| .where( | |||
| BuiltinToolProvider.tenant_id == tenant_id, | |||
| BuiltinToolProvider.id == credential_id, | |||
| ) | |||
| .first() | |||
| builtin_provider_stmt = select(BuiltinToolProvider).where( | |||
| BuiltinToolProvider.tenant_id == tenant_id, | |||
| BuiltinToolProvider.id == credential_id, | |||
| ) | |||
| builtin_provider = db.session.scalar(builtin_provider_stmt) | |||
| except Exception as e: | |||
| builtin_provider = None | |||
| logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True) | |||
| @@ -317,11 +315,10 @@ class ToolManager: | |||
| ), | |||
| ) | |||
| elif provider_type == ToolProviderType.WORKFLOW: | |||
| workflow_provider = ( | |||
| db.session.query(WorkflowToolProvider) | |||
| .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) | |||
| .first() | |||
| workflow_provider_stmt = select(WorkflowToolProvider).where( | |||
| WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id | |||
| ) | |||
| workflow_provider = db.session.scalar(workflow_provider_stmt) | |||
| if workflow_provider is None: | |||
| raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") | |||
| @@ -3,6 +3,7 @@ from typing import Any | |||
| from flask import Flask, current_app | |||
| from pydantic import BaseModel, Field | |||
| from sqlalchemy import select | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.model_manager import ModelManager | |||
| @@ -85,17 +86,14 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): | |||
| document_context_list = [] | |||
| index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] | |||
| segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .where( | |||
| DocumentSegment.dataset_id.in_(self.dataset_ids), | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.status == "completed", | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.index_node_id.in_(index_node_ids), | |||
| ) | |||
| .all() | |||
| document_segment_stmt = select(DocumentSegment).where( | |||
| DocumentSegment.dataset_id.in_(self.dataset_ids), | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.status == "completed", | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.index_node_id.in_(index_node_ids), | |||
| ) | |||
| segments = db.session.scalars(document_segment_stmt).all() | |||
| if segments: | |||
| index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} | |||
| @@ -112,15 +110,12 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): | |||
| resource_number = 1 | |||
| for segment in sorted_segments: | |||
| dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() | |||
| document = ( | |||
| db.session.query(Document) | |||
| .where( | |||
| Document.id == segment.document_id, | |||
| Document.enabled == True, | |||
| Document.archived == False, | |||
| ) | |||
| .first() | |||
| document_stmt = select(Document).where( | |||
| Document.id == segment.document_id, | |||
| Document.enabled == True, | |||
| Document.archived == False, | |||
| ) | |||
| document = db.session.scalar(document_stmt) | |||
| if dataset and document: | |||
| source = RetrievalSourceMetadata( | |||
| position=resource_number, | |||
| @@ -162,9 +157,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): | |||
| hit_callbacks: list[DatasetIndexToolCallbackHandler], | |||
| ): | |||
| with flask_app.app_context(): | |||
| dataset = ( | |||
| db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first() | |||
| ) | |||
| stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id) | |||
| dataset = db.session.scalar(stmt) | |||
| if not dataset: | |||
| return [] | |||
| @@ -1,6 +1,7 @@ | |||
| from typing import Any, Optional, cast | |||
| from pydantic import BaseModel, Field | |||
| from sqlalchemy import select | |||
| from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| @@ -56,9 +57,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| ) | |||
| def _run(self, query: str) -> str: | |||
| dataset = ( | |||
| db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first() | |||
| ) | |||
| dataset_stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id) | |||
| dataset = db.session.scalar(dataset_stmt) | |||
| if not dataset: | |||
| return "" | |||
| @@ -188,15 +188,12 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| for record in records: | |||
| segment = record.segment | |||
| dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() | |||
| document = ( | |||
| db.session.query(DatasetDocument) # type: ignore | |||
| .where( | |||
| DatasetDocument.id == segment.document_id, | |||
| DatasetDocument.enabled == True, | |||
| DatasetDocument.archived == False, | |||
| ) | |||
| .first() | |||
| dataset_document_stmt = select(DatasetDocument).where( | |||
| DatasetDocument.id == segment.document_id, | |||
| DatasetDocument.enabled == True, | |||
| DatasetDocument.archived == False, | |||
| ) | |||
| document = db.session.scalar(dataset_document_stmt) # type: ignore | |||
| if dataset and document: | |||
| source = RetrievalSourceMetadata( | |||
| dataset_id=dataset.id, | |||
| @@ -3,6 +3,8 @@ import logging | |||
| from collections.abc import Generator | |||
| from typing import Any, Optional | |||
| from sqlalchemy import select | |||
| from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod | |||
| from core.tools.__base.tool import Tool | |||
| from core.tools.__base.tool_runtime import ToolRuntime | |||
| @@ -136,7 +138,8 @@ class WorkflowTool(Tool): | |||
| .first() | |||
| ) | |||
| else: | |||
| workflow = db.session.query(Workflow).where(Workflow.app_id == app_id, Workflow.version == version).first() | |||
| stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version) | |||
| workflow = db.session.scalar(stmt) | |||
| if not workflow: | |||
| raise ValueError("workflow not found or not published") | |||
| @@ -147,7 +150,8 @@ class WorkflowTool(Tool): | |||
| """ | |||
| get the app by app id | |||
| """ | |||
| app = db.session.query(App).where(App.id == app_id).first() | |||
| stmt = select(App).where(App.id == app_id) | |||
| app = db.session.scalar(stmt) | |||
| if not app: | |||
| raise ValueError("app not found") | |||
| @@ -6,7 +6,7 @@ from collections import defaultdict | |||
| from collections.abc import Mapping, Sequence | |||
| from typing import TYPE_CHECKING, Any, Optional, cast | |||
| from sqlalchemy import Float, and_, func, or_, text | |||
| from sqlalchemy import Float, and_, func, or_, select, text | |||
| from sqlalchemy import cast as sqlalchemy_cast | |||
| from sqlalchemy.orm import sessionmaker | |||
| @@ -367,15 +367,12 @@ class KnowledgeRetrievalNode(BaseNode): | |||
| for record in records: | |||
| segment = record.segment | |||
| dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore | |||
| document = ( | |||
| db.session.query(Document) | |||
| .where( | |||
| Document.id == segment.document_id, | |||
| Document.enabled == True, | |||
| Document.archived == False, | |||
| ) | |||
| .first() | |||
| stmt = select(Document).where( | |||
| Document.id == segment.document_id, | |||
| Document.enabled == True, | |||
| Document.archived == False, | |||
| ) | |||
| document = db.session.scalar(stmt) | |||
| if dataset and document: | |||
| source = { | |||
| "metadata": { | |||
| @@ -514,7 +511,8 @@ class KnowledgeRetrievalNode(BaseNode): | |||
| self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData | |||
| ) -> list[dict[str, Any]]: | |||
| # get all metadata field | |||
| metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all() | |||
| stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)) | |||
| metadata_fields = db.session.scalars(stmt).all() | |||
| all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] | |||
| if node_data.metadata_model_config is None: | |||
| raise ValueError("metadata_model_config is required") | |||