Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
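The hunks below all apply the same mechanical migration: the legacy 1.x-style Query API (db.session.query(Model).where(...).all()) is replaced by the SQLAlchemy 2.0-style select() construct executed through Session.scalars(). A minimal before/after sketch of the pattern, using a hypothetical Widget model that is not part of this diff:

    from sqlalchemy import select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class Widget(Base):
        # hypothetical model used only for this sketch; not part of the diff
        __tablename__ = "widgets"
        id: Mapped[int] = mapped_column(primary_key=True)
        tenant_id: Mapped[str] = mapped_column()


    def load_widgets(session: Session, tenant_id: str):
        # legacy 1.x-style Query API (the form being removed throughout this diff)
        legacy = session.query(Widget).where(Widget.tenant_id == tenant_id).all()
        # 2.0-style: build a Select and run it via Session.scalars(), which
        # unwraps each result row into the mapped Widget instance
        modern = session.scalars(select(Widget).where(Widget.tenant_id == tenant_id)).all()
        return legacy, modern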
| @@ -212,7 +212,9 @@ def migrate_annotation_vector_database(): | |||
| if not dataset_collection_binding: | |||
| click.echo(f"App annotation collection binding not found: {app.id}") | |||
| continue | |||
| annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all() | |||
| annotations = db.session.scalars( | |||
| select(MessageAnnotation).where(MessageAnnotation.app_id == app.id) | |||
| ).all() | |||
| dataset = Dataset( | |||
| id=app.id, | |||
| tenant_id=app.tenant_id, | |||
| @@ -367,29 +369,25 @@ def migrate_knowledge_vector_database(): | |||
| ) | |||
| raise e | |||
| dataset_documents = ( | |||
| db.session.query(DatasetDocument) | |||
| .where( | |||
| dataset_documents = db.session.scalars( | |||
| select(DatasetDocument).where( | |||
| DatasetDocument.dataset_id == dataset.id, | |||
| DatasetDocument.indexing_status == "completed", | |||
| DatasetDocument.enabled == True, | |||
| DatasetDocument.archived == False, | |||
| ) | |||
| .all() | |||
| ) | |||
| ).all() | |||
| documents = [] | |||
| segments_count = 0 | |||
| for dataset_document in dataset_documents: | |||
| segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .where( | |||
| segments = db.session.scalars( | |||
| select(DocumentSegment).where( | |||
| DocumentSegment.document_id == dataset_document.id, | |||
| DocumentSegment.status == "completed", | |||
| DocumentSegment.enabled == True, | |||
| ) | |||
| .all() | |||
| ) | |||
| ).all() | |||
| for segment in segments: | |||
| document = Document( | |||
| @@ -60,11 +60,11 @@ class BaseApiKeyListResource(Resource): | |||
| assert self.resource_id_field is not None, "resource_id_field must be set" | |||
| resource_id = str(resource_id) | |||
| _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) | |||
| keys = ( | |||
| db.session.query(ApiToken) | |||
| .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) | |||
| .all() | |||
| ) | |||
| keys = db.session.scalars( | |||
| select(ApiToken).where( | |||
| ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id | |||
| ) | |||
| ).all() | |||
| return {"items": keys} | |||
| @marshal_with(api_key_fields) | |||
| @@ -29,14 +29,12 @@ class DataSourceApi(Resource): | |||
| @marshal_with(integrate_list_fields) | |||
| def get(self): | |||
| # get workspace data source integrates | |||
| data_source_integrates = ( | |||
| db.session.query(DataSourceOauthBinding) | |||
| .where( | |||
| data_source_integrates = db.session.scalars( | |||
| select(DataSourceOauthBinding).where( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.disabled == False, | |||
| ) | |||
| .all() | |||
| ) | |||
| ).all() | |||
| base_url = request.url_root.rstrip("/") | |||
| data_source_oauth_base_path = "/console/api/oauth/data-source" | |||
| @@ -2,6 +2,7 @@ import flask_restx | |||
| from flask import request | |||
| from flask_login import current_user | |||
| from flask_restx import Resource, marshal, marshal_with, reqparse | |||
| from sqlalchemy import select | |||
| from werkzeug.exceptions import Forbidden, NotFound | |||
| import services | |||
| @@ -411,11 +412,11 @@ class DatasetIndexingEstimateApi(Resource): | |||
| extract_settings = [] | |||
| if args["info_list"]["data_source_type"] == "upload_file": | |||
| file_ids = args["info_list"]["file_info_list"]["file_ids"] | |||
| file_details = ( | |||
| db.session.query(UploadFile) | |||
| .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) | |||
| .all() | |||
| ) | |||
| file_details = db.session.scalars( | |||
| select(UploadFile).where( | |||
| UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids) | |||
| ) | |||
| ).all() | |||
| if file_details is None: | |||
| raise NotFound("File not found.") | |||
| @@ -518,11 +519,11 @@ class DatasetIndexingStatusApi(Resource): | |||
| @account_initialization_required | |||
| def get(self, dataset_id): | |||
| dataset_id = str(dataset_id) | |||
| documents = ( | |||
| db.session.query(Document) | |||
| .where(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) | |||
| .all() | |||
| ) | |||
| documents = db.session.scalars( | |||
| select(Document).where( | |||
| Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id | |||
| ) | |||
| ).all() | |||
| documents_status = [] | |||
| for document in documents: | |||
| completed_segments = ( | |||
| @@ -569,11 +570,11 @@ class DatasetApiKeyApi(Resource): | |||
| @account_initialization_required | |||
| @marshal_with(api_key_list) | |||
| def get(self): | |||
| keys = ( | |||
| db.session.query(ApiToken) | |||
| .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) | |||
| .all() | |||
| ) | |||
| keys = db.session.scalars( | |||
| select(ApiToken).where( | |||
| ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id | |||
| ) | |||
| ).all() | |||
| return {"items": keys} | |||
| @setup_required | |||
| @@ -1,5 +1,6 @@ | |||
| import logging | |||
| from argparse import ArgumentTypeError | |||
| from collections.abc import Sequence | |||
| from typing import Literal, cast | |||
| from flask import request | |||
| @@ -79,7 +80,7 @@ class DocumentResource(Resource): | |||
| return document | |||
| def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]: | |||
| def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]: | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
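Several helpers in this diff (get_batch_documents here, and the DocumentService methods further down) change their return annotation from list[...] to Sequence[...]. That follows the typing of SQLAlchemy 2.0's ScalarResult.all(), which is declared to return a Sequence rather than a list. A minimal sketch of the adjusted signature, reusing the hypothetical Widget model from the sketch above:

    from collections.abc import Sequence

    from sqlalchemy import select
    from sqlalchemy.orm import Session


    def get_widgets(session: Session, tenant_id: str) -> Sequence[Widget]:
        # previously annotated -> list[Widget]; ScalarResult.all() is typed as Sequence
        return session.scalars(select(Widget).where(Widget.tenant_id == tenant_id)).all()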
| @@ -3,7 +3,7 @@ from typing import Any | |||
| from flask import request | |||
| from flask_restx import Resource, inputs, marshal_with, reqparse | |||
| from sqlalchemy import and_ | |||
| from sqlalchemy import and_, select | |||
| from werkzeug.exceptions import BadRequest, Forbidden, NotFound | |||
| from controllers.console import api | |||
| @@ -33,13 +33,15 @@ class InstalledAppsListApi(Resource): | |||
| current_tenant_id = current_user.current_tenant_id | |||
| if app_id: | |||
| installed_apps = ( | |||
| db.session.query(InstalledApp) | |||
| .where(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)) | |||
| .all() | |||
| ) | |||
| installed_apps = db.session.scalars( | |||
| select(InstalledApp).where( | |||
| and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id) | |||
| ) | |||
| ).all() | |||
| else: | |||
| installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all() | |||
| installed_apps = db.session.scalars( | |||
| select(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id) | |||
| ).all() | |||
| if current_user.current_tenant is None: | |||
| raise ValueError("current_user.current_tenant must not be None") | |||
| @@ -248,7 +248,9 @@ class AccountIntegrateApi(Resource): | |||
| raise ValueError("Invalid user account") | |||
| account = current_user | |||
| account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all() | |||
| account_integrates = db.session.scalars( | |||
| select(AccountIntegrate).where(AccountIntegrate.account_id == account.id) | |||
| ).all() | |||
| base_url = request.url_root.rstrip("/") | |||
| oauth_base_path = "/console/api/oauth/login" | |||
| @@ -32,11 +32,16 @@ class TokenBufferMemory: | |||
| self.model_instance = model_instance | |||
| def _build_prompt_message_with_files( | |||
| self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool | |||
| self, | |||
| message_files: Sequence[MessageFile], | |||
| text_content: str, | |||
| message: Message, | |||
| app_record, | |||
| is_user_message: bool, | |||
| ) -> PromptMessage: | |||
| """ | |||
| Build prompt message with files. | |||
| :param message_files: list of MessageFile objects | |||
| :param message_files: Sequence of MessageFile objects | |||
| :param text_content: text content of the message | |||
| :param message: Message object | |||
| :param app_record: app record | |||
| @@ -128,14 +133,12 @@ class TokenBufferMemory: | |||
| prompt_messages: list[PromptMessage] = [] | |||
| for message in messages: | |||
| # Process user message with files | |||
| user_files = ( | |||
| db.session.query(MessageFile) | |||
| .where( | |||
| user_files = db.session.scalars( | |||
| select(MessageFile).where( | |||
| MessageFile.message_id == message.id, | |||
| (MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)), | |||
| ) | |||
| .all() | |||
| ) | |||
| ).all() | |||
| if user_files: | |||
| user_prompt_message = self._build_prompt_message_with_files( | |||
| @@ -150,11 +153,9 @@ class TokenBufferMemory: | |||
| prompt_messages.append(UserPromptMessage(content=message.query)) | |||
| # Process assistant message with files | |||
| assistant_files = ( | |||
| db.session.query(MessageFile) | |||
| .where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant") | |||
| .all() | |||
| ) | |||
| assistant_files = db.session.scalars( | |||
| select(MessageFile).where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant") | |||
| ).all() | |||
| if assistant_files: | |||
| assistant_prompt_message = self._build_prompt_message_with_files( | |||
| @@ -15,6 +15,7 @@ from opentelemetry.sdk.resources import Resource | |||
| from opentelemetry.sdk.trace.export import SimpleSpanProcessor | |||
| from opentelemetry.sdk.trace.id_generator import RandomIdGenerator | |||
| from opentelemetry.trace import SpanContext, TraceFlags, TraceState | |||
| from sqlalchemy import select | |||
| from core.ops.base_trace_instance import BaseTraceInstance | |||
| from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig | |||
| @@ -699,8 +700,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): | |||
| def _get_workflow_nodes(self, workflow_run_id: str): | |||
| """Helper method to get workflow nodes""" | |||
| workflow_nodes = ( | |||
| db.session.query( | |||
| workflow_nodes = db.session.scalars( | |||
| select( | |||
| WorkflowNodeExecutionModel.id, | |||
| WorkflowNodeExecutionModel.tenant_id, | |||
| WorkflowNodeExecutionModel.app_id, | |||
| @@ -713,10 +714,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): | |||
| WorkflowNodeExecutionModel.elapsed_time, | |||
| WorkflowNodeExecutionModel.process_data, | |||
| WorkflowNodeExecutionModel.execution_metadata, | |||
| ) | |||
| .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) | |||
| .all() | |||
| ) | |||
| ).where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) | |||
| ).all() | |||
| return workflow_nodes | |||
| def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]: | |||
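A note on the column-level selects in this diff (the _get_workflow_nodes helper above, and select(TagBinding.target_id) / select(DatasetPermission.account_id) further down): Session.scalars() yields only the first selected column of each row, whereas Session.execute() yields full Row tuples. For a single-column select that is exactly the list of values the old code built by hand; for a multi-column select the two forms differ. A short sketch of the distinction, again with the hypothetical Widget model:

    from sqlalchemy import select
    from sqlalchemy.orm import Session


    def column_select_forms(session: Session):
        multi = select(Widget.id, Widget.tenant_id)
        single = select(Widget.tenant_id)
        # scalars(): first column only -> a sequence of Widget.id values
        ids = session.scalars(multi).all()
        # execute(): Row tuples -> each row exposes .id and .tenant_id
        rows = session.execute(multi).all()
        # for a single selected column, scalars() returns the plain values directly
        tenant_ids = session.scalars(single).all()
        return ids, rows, tenant_ids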
| @@ -1,5 +1,6 @@ | |||
| import time | |||
| import uuid | |||
| from collections.abc import Sequence | |||
| import requests | |||
| from requests.auth import HTTPDigestAuth | |||
| @@ -139,7 +140,7 @@ class TidbService: | |||
| @staticmethod | |||
| def batch_update_tidb_serverless_cluster_status( | |||
| tidb_serverless_list: list[TidbAuthBinding], | |||
| tidb_serverless_list: Sequence[TidbAuthBinding], | |||
| project_id: str, | |||
| api_url: str, | |||
| iam_url: str, | |||
| @@ -1,4 +1,5 @@ | |||
| from pydantic import Field | |||
| from sqlalchemy import select | |||
| from core.entities.provider_entities import ProviderConfig | |||
| from core.tools.__base.tool_provider import ToolProviderController | |||
| @@ -176,11 +177,11 @@ class ApiToolProviderController(ToolProviderController): | |||
| tools: list[ApiTool] = [] | |||
| # get tenant api providers | |||
| db_providers: list[ApiToolProvider] = ( | |||
| db.session.query(ApiToolProvider) | |||
| .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name) | |||
| .all() | |||
| ) | |||
| db_providers = db.session.scalars( | |||
| select(ApiToolProvider).where( | |||
| ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name | |||
| ) | |||
| ).all() | |||
| if db_providers and len(db_providers) != 0: | |||
| for db_provider in db_providers: | |||
| @@ -87,9 +87,7 @@ class ToolLabelManager: | |||
| assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController) | |||
| provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute] | |||
| labels: list[ToolLabelBinding] = ( | |||
| db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all() | |||
| ) | |||
| labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all() | |||
| tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} | |||
| @@ -667,9 +667,9 @@ class ToolManager: | |||
| # get db api providers | |||
| if "api" in filters: | |||
| db_api_providers: list[ApiToolProvider] = ( | |||
| db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() | |||
| ) | |||
| db_api_providers = db.session.scalars( | |||
| select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id) | |||
| ).all() | |||
| api_provider_controllers: list[dict[str, Any]] = [ | |||
| {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} | |||
| @@ -690,9 +690,9 @@ class ToolManager: | |||
| if "workflow" in filters: | |||
| # get workflow providers | |||
| workflow_providers: list[WorkflowToolProvider] = ( | |||
| db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all() | |||
| ) | |||
| workflow_providers = db.session.scalars( | |||
| select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id) | |||
| ).all() | |||
| workflow_provider_controllers: list[WorkflowToolProviderController] = [] | |||
| for workflow_provider in workflow_providers: | |||
| @@ -1,3 +1,5 @@ | |||
| from sqlalchemy import select | |||
| from events.app_event import app_model_config_was_updated | |||
| from extensions.ext_database import db | |||
| from models.dataset import AppDatasetJoin | |||
| @@ -13,7 +15,7 @@ def handle(sender, **kwargs): | |||
| dataset_ids = get_dataset_ids_from_model_config(app_model_config) | |||
| app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all() | |||
| app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all() | |||
| removed_dataset_ids: set[str] = set() | |||
| if not app_dataset_joins: | |||
| @@ -1,5 +1,7 @@ | |||
| from typing import cast | |||
| from sqlalchemy import select | |||
| from core.workflow.nodes import NodeType | |||
| from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData | |||
| from events.app_event import app_published_workflow_was_updated | |||
| @@ -15,7 +17,7 @@ def handle(sender, **kwargs): | |||
| published_workflow = cast(Workflow, published_workflow) | |||
| dataset_ids = get_dataset_ids_from_workflow(published_workflow) | |||
| app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all() | |||
| app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all() | |||
| removed_dataset_ids: set[str] = set() | |||
| if not app_dataset_joins: | |||
| @@ -218,10 +218,12 @@ class Tenant(Base): | |||
| updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) | |||
| def get_accounts(self) -> list[Account]: | |||
| return ( | |||
| db.session.query(Account) | |||
| .where(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id) | |||
| .all() | |||
| return list( | |||
| db.session.scalars( | |||
| select(Account).where( | |||
| Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id | |||
| ) | |||
| ).all() | |||
| ) | |||
| @property | |||
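Where a caller-facing annotation stays a concrete list (Tenant.get_accounts here, and TagService.get_tag_by_tag_name later on), the Sequence coming back from scalars().all() is wrapped in list(). A one-function sketch of that variant, with the same hypothetical Widget model:

    from sqlalchemy import select
    from sqlalchemy.orm import Session


    def get_widget_list(session: Session, tenant_id: str) -> list[Widget]:
        # list(...) preserves the declared list return type around the 2.0-style call
        return list(session.scalars(select(Widget).where(Widget.tenant_id == tenant_id)).all())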
| @@ -208,7 +208,9 @@ class Dataset(Base): | |||
| @property | |||
| def doc_metadata(self): | |||
| dataset_metadatas = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id).all() | |||
| dataset_metadatas = db.session.scalars( | |||
| select(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id) | |||
| ).all() | |||
| doc_metadata = [ | |||
| { | |||
| @@ -1055,13 +1057,11 @@ class ExternalKnowledgeApis(Base): | |||
| @property | |||
| def dataset_bindings(self) -> list[dict[str, Any]]: | |||
| external_knowledge_bindings = ( | |||
| db.session.query(ExternalKnowledgeBindings) | |||
| .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) | |||
| .all() | |||
| ) | |||
| external_knowledge_bindings = db.session.scalars( | |||
| select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) | |||
| ).all() | |||
| dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] | |||
| datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all() | |||
| datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all() | |||
| dataset_bindings: list[dict[str, Any]] = [] | |||
| for dataset in datasets: | |||
| dataset_bindings.append({"id": dataset.id, "name": dataset.name}) | |||
| @@ -812,7 +812,7 @@ class Conversation(Base): | |||
| @property | |||
| def status_count(self): | |||
| messages = db.session.query(Message).where(Message.conversation_id == self.id).all() | |||
| messages = db.session.scalars(select(Message).where(Message.conversation_id == self.id)).all() | |||
| status_counts = { | |||
| WorkflowExecutionStatus.RUNNING: 0, | |||
| WorkflowExecutionStatus.SUCCEEDED: 0, | |||
| @@ -1090,7 +1090,7 @@ class Message(Base): | |||
| @property | |||
| def feedbacks(self): | |||
| feedbacks = db.session.query(MessageFeedback).where(MessageFeedback.message_id == self.id).all() | |||
| feedbacks = db.session.scalars(select(MessageFeedback).where(MessageFeedback.message_id == self.id)).all() | |||
| return feedbacks | |||
| @property | |||
| @@ -1145,7 +1145,7 @@ class Message(Base): | |||
| def message_files(self) -> list[dict[str, Any]]: | |||
| from factories import file_factory | |||
| message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all() | |||
| message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all() | |||
| current_app = db.session.query(App).where(App.id == self.app_id).first() | |||
| if not current_app: | |||
| raise ValueError(f"App {self.app_id} not found") | |||
| @@ -96,11 +96,11 @@ def clean_unused_datasets_task(): | |||
| break | |||
| for dataset in datasets: | |||
| dataset_query = ( | |||
| db.session.query(DatasetQuery) | |||
| .where(DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id) | |||
| .all() | |||
| ) | |||
| dataset_query = db.session.scalars( | |||
| select(DatasetQuery).where( | |||
| DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id | |||
| ) | |||
| ).all() | |||
| if not dataset_query or len(dataset_query) == 0: | |||
| try: | |||
| @@ -121,15 +121,13 @@ def clean_unused_datasets_task(): | |||
| if should_clean: | |||
| # Add auto disable log if required | |||
| if add_logs: | |||
| documents = ( | |||
| db.session.query(Document) | |||
| .where( | |||
| documents = db.session.scalars( | |||
| select(Document).where( | |||
| Document.dataset_id == dataset.id, | |||
| Document.enabled == True, | |||
| Document.archived == False, | |||
| ) | |||
| .all() | |||
| ) | |||
| ).all() | |||
| for document in documents: | |||
| dataset_auto_disable_log = DatasetAutoDisableLog( | |||
| tenant_id=dataset.tenant_id, | |||
| @@ -3,6 +3,7 @@ import time | |||
| from collections import defaultdict | |||
| import click | |||
| from sqlalchemy import select | |||
| import app | |||
| from configs import dify_config | |||
| @@ -31,9 +32,9 @@ def mail_clean_document_notify_task(): | |||
| # send document clean notify mail | |||
| try: | |||
| dataset_auto_disable_logs = ( | |||
| db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False).all() | |||
| ) | |||
| dataset_auto_disable_logs = db.session.scalars( | |||
| select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False) | |||
| ).all() | |||
| # group by tenant_id | |||
| dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) | |||
| for dataset_auto_disable_log in dataset_auto_disable_logs: | |||
| @@ -1,6 +1,8 @@ | |||
| import time | |||
| from collections.abc import Sequence | |||
| import click | |||
| from sqlalchemy import select | |||
| import app | |||
| from configs import dify_config | |||
| @@ -15,11 +17,9 @@ def update_tidb_serverless_status_task(): | |||
| start_at = time.perf_counter() | |||
| try: | |||
| # check the number of idle tidb serverless | |||
| tidb_serverless_list = ( | |||
| db.session.query(TidbAuthBinding) | |||
| .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") | |||
| .all() | |||
| ) | |||
| tidb_serverless_list = db.session.scalars( | |||
| select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") | |||
| ).all() | |||
| if len(tidb_serverless_list) == 0: | |||
| return | |||
| # update tidb serverless status | |||
| @@ -32,7 +32,7 @@ def update_tidb_serverless_status_task(): | |||
| click.echo(click.style(f"Update tidb serverless status task success latency: {end_at - start_at}", fg="green")) | |||
| def update_clusters(tidb_serverless_list: list[TidbAuthBinding]): | |||
| def update_clusters(tidb_serverless_list: Sequence[TidbAuthBinding]): | |||
| try: | |||
| # batch 20 | |||
| for i in range(0, len(tidb_serverless_list), 20): | |||
| @@ -263,11 +263,9 @@ class AppAnnotationService: | |||
| db.session.delete(annotation) | |||
| annotation_hit_histories = ( | |||
| db.session.query(AppAnnotationHitHistory) | |||
| .where(AppAnnotationHitHistory.annotation_id == annotation_id) | |||
| .all() | |||
| ) | |||
| annotation_hit_histories = db.session.scalars( | |||
| select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation_id) | |||
| ).all() | |||
| if annotation_hit_histories: | |||
| for annotation_hit_history in annotation_hit_histories: | |||
| db.session.delete(annotation_hit_history) | |||
| @@ -1,5 +1,7 @@ | |||
| import json | |||
| from sqlalchemy import select | |||
| from core.helper import encrypter | |||
| from extensions.ext_database import db | |||
| from models.source import DataSourceApiKeyAuthBinding | |||
| @@ -9,11 +11,11 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory | |||
| class ApiKeyAuthService: | |||
| @staticmethod | |||
| def get_provider_auth_list(tenant_id: str): | |||
| data_source_api_key_bindings = ( | |||
| db.session.query(DataSourceApiKeyAuthBinding) | |||
| .where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)) | |||
| .all() | |||
| ) | |||
| data_source_api_key_bindings = db.session.scalars( | |||
| select(DataSourceApiKeyAuthBinding).where( | |||
| DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False) | |||
| ) | |||
| ).all() | |||
| return data_source_api_key_bindings | |||
| @staticmethod | |||
| @@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor | |||
| import click | |||
| from flask import Flask, current_app | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session, sessionmaker | |||
| from configs import dify_config | |||
| @@ -115,7 +116,7 @@ class ClearFreePlanTenantExpiredLogs: | |||
| @classmethod | |||
| def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int): | |||
| with flask_app.app_context(): | |||
| apps = db.session.query(App).where(App.tenant_id == tenant_id).all() | |||
| apps = db.session.scalars(select(App).where(App.tenant_id == tenant_id)).all() | |||
| app_ids = [app.id for app in apps] | |||
| while True: | |||
| with Session(db.engine).no_autoflush as session: | |||
| @@ -6,6 +6,7 @@ import secrets | |||
| import time | |||
| import uuid | |||
| from collections import Counter | |||
| from collections.abc import Sequence | |||
| from typing import Any, Literal, Optional | |||
| import sqlalchemy as sa | |||
| @@ -741,14 +742,12 @@ class DatasetService: | |||
| } | |||
| # get recent 30 days auto disable logs | |||
| start_date = datetime.datetime.now() - datetime.timedelta(days=30) | |||
| dataset_auto_disable_logs = ( | |||
| db.session.query(DatasetAutoDisableLog) | |||
| .where( | |||
| dataset_auto_disable_logs = db.session.scalars( | |||
| select(DatasetAutoDisableLog).where( | |||
| DatasetAutoDisableLog.dataset_id == dataset_id, | |||
| DatasetAutoDisableLog.created_at >= start_date, | |||
| ) | |||
| .all() | |||
| ) | |||
| ).all() | |||
| if dataset_auto_disable_logs: | |||
| return { | |||
| "document_ids": [log.document_id for log in dataset_auto_disable_logs], | |||
| @@ -885,69 +884,58 @@ class DocumentService: | |||
| return document | |||
| @staticmethod | |||
| def get_document_by_ids(document_ids: list[str]) -> list[Document]: | |||
| documents = ( | |||
| db.session.query(Document) | |||
| .where( | |||
| def get_document_by_ids(document_ids: list[str]) -> Sequence[Document]: | |||
| documents = db.session.scalars( | |||
| select(Document).where( | |||
| Document.id.in_(document_ids), | |||
| Document.enabled == True, | |||
| Document.indexing_status == "completed", | |||
| Document.archived == False, | |||
| ) | |||
| .all() | |||
| ) | |||
| ).all() | |||
| return documents | |||
| @staticmethod | |||
| def get_document_by_dataset_id(dataset_id: str) -> list[Document]: | |||
| documents = ( | |||
| db.session.query(Document) | |||
| .where( | |||
| def get_document_by_dataset_id(dataset_id: str) -> Sequence[Document]: | |||
| documents = db.session.scalars( | |||
| select(Document).where( | |||
| Document.dataset_id == dataset_id, | |||
| Document.enabled == True, | |||
| ) | |||
| .all() | |||
| ) | |||
| ).all() | |||
| return documents | |||
| @staticmethod | |||
| def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]: | |||
| documents = ( | |||
| db.session.query(Document) | |||
| .where( | |||
| def get_working_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]: | |||
| documents = db.session.scalars( | |||
| select(Document).where( | |||
| Document.dataset_id == dataset_id, | |||
| Document.enabled == True, | |||
| Document.indexing_status == "completed", | |||
| Document.archived == False, | |||
| ) | |||
| .all() | |||
| ) | |||
| ).all() | |||
| return documents | |||
| @staticmethod | |||
| def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]: | |||
| documents = ( | |||
| db.session.query(Document) | |||
| .where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) | |||
| .all() | |||
| ) | |||
| def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]: | |||
| documents = db.session.scalars( | |||
| select(Document).where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) | |||
| ).all() | |||
| return documents | |||
| @staticmethod | |||
| def get_batch_documents(dataset_id: str, batch: str) -> list[Document]: | |||
| def get_batch_documents(dataset_id: str, batch: str) -> Sequence[Document]: | |||
| assert isinstance(current_user, Account) | |||
| documents = ( | |||
| db.session.query(Document) | |||
| .where( | |||
| documents = db.session.scalars( | |||
| select(Document).where( | |||
| Document.batch == batch, | |||
| Document.dataset_id == dataset_id, | |||
| Document.tenant_id == current_user.current_tenant_id, | |||
| ) | |||
| .all() | |||
| ) | |||
| ).all() | |||
| return documents | |||
| @@ -984,7 +972,7 @@ class DocumentService: | |||
| # Check if document_ids is not empty to avoid WHERE false condition | |||
| if not document_ids or len(document_ids) == 0: | |||
| return | |||
| documents = db.session.query(Document).where(Document.id.in_(document_ids)).all() | |||
| documents = db.session.scalars(select(Document).where(Document.id.in_(document_ids))).all() | |||
| file_ids = [ | |||
| document.data_source_info_dict["upload_file_id"] | |||
| for document in documents | |||
| @@ -2424,16 +2412,14 @@ class SegmentService: | |||
| if not segment_ids or len(segment_ids) == 0: | |||
| return | |||
| if action == "enable": | |||
| segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .where( | |||
| segments = db.session.scalars( | |||
| select(DocumentSegment).where( | |||
| DocumentSegment.id.in_(segment_ids), | |||
| DocumentSegment.dataset_id == dataset.id, | |||
| DocumentSegment.document_id == document.id, | |||
| DocumentSegment.enabled == False, | |||
| ) | |||
| .all() | |||
| ) | |||
| ).all() | |||
| if not segments: | |||
| return | |||
| real_deal_segment_ids = [] | |||
| @@ -2451,16 +2437,14 @@ class SegmentService: | |||
| enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id) | |||
| elif action == "disable": | |||
| segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .where( | |||
| segments = db.session.scalars( | |||
| select(DocumentSegment).where( | |||
| DocumentSegment.id.in_(segment_ids), | |||
| DocumentSegment.dataset_id == dataset.id, | |||
| DocumentSegment.document_id == document.id, | |||
| DocumentSegment.enabled == True, | |||
| ) | |||
| .all() | |||
| ) | |||
| ).all() | |||
| if not segments: | |||
| return | |||
| real_deal_segment_ids = [] | |||
| @@ -2532,16 +2516,13 @@ class SegmentService: | |||
| dataset: Dataset, | |||
| ) -> list[ChildChunk]: | |||
| assert isinstance(current_user, Account) | |||
| child_chunks = ( | |||
| db.session.query(ChildChunk) | |||
| .where( | |||
| child_chunks = db.session.scalars( | |||
| select(ChildChunk).where( | |||
| ChildChunk.dataset_id == dataset.id, | |||
| ChildChunk.document_id == document.id, | |||
| ChildChunk.segment_id == segment.id, | |||
| ) | |||
| .all() | |||
| ) | |||
| ).all() | |||
| child_chunks_map = {chunk.id: chunk for chunk in child_chunks} | |||
| new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], [] | |||
| @@ -2751,13 +2732,11 @@ class DatasetCollectionBindingService: | |||
| class DatasetPermissionService: | |||
| @classmethod | |||
| def get_dataset_partial_member_list(cls, dataset_id): | |||
| user_list_query = ( | |||
| db.session.query( | |||
| user_list_query = db.session.scalars( | |||
| select( | |||
| DatasetPermission.account_id, | |||
| ) | |||
| .where(DatasetPermission.dataset_id == dataset_id) | |||
| .all() | |||
| ) | |||
| ).where(DatasetPermission.dataset_id == dataset_id) | |||
| ).all() | |||
| user_list = [] | |||
| for user in user_list_query: | |||
| @@ -3,7 +3,7 @@ import logging | |||
| from json import JSONDecodeError | |||
| from typing import Optional, Union | |||
| from sqlalchemy import or_ | |||
| from sqlalchemy import or_, select | |||
| from constants import HIDDEN_VALUE | |||
| from core.entities.provider_configuration import ProviderConfiguration | |||
| @@ -322,16 +322,14 @@ class ModelLoadBalancingService: | |||
| if not isinstance(configs, list): | |||
| raise ValueError("Invalid load balancing configs") | |||
| current_load_balancing_configs = ( | |||
| db.session.query(LoadBalancingModelConfig) | |||
| .where( | |||
| current_load_balancing_configs = db.session.scalars( | |||
| select(LoadBalancingModelConfig).where( | |||
| LoadBalancingModelConfig.tenant_id == tenant_id, | |||
| LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, | |||
| LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), | |||
| LoadBalancingModelConfig.model_name == model, | |||
| ) | |||
| .all() | |||
| ) | |||
| ).all() | |||
| # id as key, config as value | |||
| current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs} | |||
| @@ -1,5 +1,7 @@ | |||
| from typing import Optional | |||
| from sqlalchemy import select | |||
| from constants.languages import languages | |||
| from extensions.ext_database import db | |||
| from models.model import App, RecommendedApp | |||
| @@ -31,18 +33,14 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): | |||
| :param language: language | |||
| :return: | |||
| """ | |||
| recommended_apps = ( | |||
| db.session.query(RecommendedApp) | |||
| .where(RecommendedApp.is_listed == True, RecommendedApp.language == language) | |||
| .all() | |||
| ) | |||
| recommended_apps = db.session.scalars( | |||
| select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == language) | |||
| ).all() | |||
| if len(recommended_apps) == 0: | |||
| recommended_apps = ( | |||
| db.session.query(RecommendedApp) | |||
| .where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) | |||
| .all() | |||
| ) | |||
| recommended_apps = db.session.scalars( | |||
| select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) | |||
| ).all() | |||
| categories = set() | |||
| recommended_apps_result = [] | |||
| @@ -2,7 +2,7 @@ import uuid | |||
| from typing import Optional | |||
| from flask_login import current_user | |||
| from sqlalchemy import func | |||
| from sqlalchemy import func, select | |||
| from werkzeug.exceptions import NotFound | |||
| from extensions.ext_database import db | |||
| @@ -29,35 +29,30 @@ class TagService: | |||
| # Check if tag_ids is not empty to avoid WHERE false condition | |||
| if not tag_ids or len(tag_ids) == 0: | |||
| return [] | |||
| tags = ( | |||
| db.session.query(Tag) | |||
| .where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) | |||
| .all() | |||
| ) | |||
| tags = db.session.scalars( | |||
| select(Tag).where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) | |||
| ).all() | |||
| if not tags: | |||
| return [] | |||
| tag_ids = [tag.id for tag in tags] | |||
| # Check if tag_ids is not empty to avoid WHERE false condition | |||
| if not tag_ids or len(tag_ids) == 0: | |||
| return [] | |||
| tag_bindings = ( | |||
| db.session.query(TagBinding.target_id) | |||
| .where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id) | |||
| .all() | |||
| ) | |||
| if not tag_bindings: | |||
| return [] | |||
| results = [tag_binding.target_id for tag_binding in tag_bindings] | |||
| return results | |||
| tag_bindings = db.session.scalars( | |||
| select(TagBinding.target_id).where( | |||
| TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id | |||
| ) | |||
| ).all() | |||
| return tag_bindings | |||
| @staticmethod | |||
| def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str): | |||
| if not tag_type or not tag_name: | |||
| return [] | |||
| tags = ( | |||
| db.session.query(Tag) | |||
| .where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type) | |||
| .all() | |||
| tags = list( | |||
| db.session.scalars( | |||
| select(Tag).where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type) | |||
| ).all() | |||
| ) | |||
| if not tags: | |||
| return [] | |||
| @@ -117,7 +112,7 @@ class TagService: | |||
| raise NotFound("Tag not found") | |||
| db.session.delete(tag) | |||
| # delete tag binding | |||
| tag_bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).all() | |||
| tag_bindings = db.session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all() | |||
| if tag_bindings: | |||
| for tag_binding in tag_bindings: | |||
| db.session.delete(tag_binding) | |||
| @@ -4,6 +4,7 @@ from collections.abc import Mapping | |||
| from typing import Any, cast | |||
| from httpx import get | |||
| from sqlalchemy import select | |||
| from core.entities.provider_entities import ProviderConfig | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| @@ -443,9 +444,7 @@ class ApiToolManageService: | |||
| list api tools | |||
| """ | |||
| # get all api providers | |||
| db_providers: list[ApiToolProvider] = ( | |||
| db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or [] | |||
| ) | |||
| db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all() | |||
| result: list[ToolProviderApiEntity] = [] | |||
| @@ -3,7 +3,7 @@ from collections.abc import Mapping | |||
| from datetime import datetime | |||
| from typing import Any | |||
| from sqlalchemy import or_ | |||
| from sqlalchemy import or_, select | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.tools.__base.tool_provider import ToolProviderController | |||
| @@ -186,7 +186,9 @@ class WorkflowToolManageService: | |||
| :param tenant_id: the tenant id | |||
| :return: the list of tools | |||
| """ | |||
| db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all() | |||
| db_tools = db.session.scalars( | |||
| select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id) | |||
| ).all() | |||
| tools: list[WorkflowToolProviderController] = [] | |||
| for provider in db_tools: | |||
| @@ -3,6 +3,7 @@ import time | |||
| import click | |||
| from celery import shared_task | |||
| from sqlalchemy import select | |||
| from core.rag.datasource.vdb.vector_factory import Vector | |||
| from core.rag.models.document import Document | |||
| @@ -39,7 +40,7 @@ def enable_annotation_reply_task( | |||
| db.session.close() | |||
| return | |||
| annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).all() | |||
| annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all() | |||
| enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}" | |||
| enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}" | |||
| @@ -3,6 +3,7 @@ import time | |||
| import click | |||
| from celery import shared_task | |||
| from sqlalchemy import select | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from core.tools.utils.web_reader_tool import get_image_upload_file_ids | |||
| @@ -34,7 +35,9 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form | |||
| if not dataset: | |||
| raise Exception("Document has no dataset") | |||
| segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)).all() | |||
| segments = db.session.scalars( | |||
| select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) | |||
| ).all() | |||
| # check segment is exist | |||
| if segments: | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| @@ -59,7 +62,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form | |||
| db.session.commit() | |||
| if file_ids: | |||
| files = db.session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all() | |||
| files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all() | |||
| for file in files: | |||
| try: | |||
| storage.delete(file.key) | |||
| @@ -3,6 +3,7 @@ import time | |||
| import click | |||
| from celery import shared_task | |||
| from sqlalchemy import select | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from core.tools.utils.web_reader_tool import get_image_upload_file_ids | |||
| @@ -55,8 +56,8 @@ def clean_dataset_task( | |||
| index_struct=index_struct, | |||
| collection_binding_id=collection_binding_id, | |||
| ) | |||
| documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all() | |||
| segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all() | |||
| documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all() | |||
| segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all() | |||
| # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace | |||
| # This ensures all invalid doc_form values are properly handled | |||
| @@ -4,6 +4,7 @@ from typing import Optional | |||
| import click | |||
| from celery import shared_task | |||
| from sqlalchemy import select | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from core.tools.utils.web_reader_tool import get_image_upload_file_ids | |||
| @@ -35,7 +36,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i | |||
| if not dataset: | |||
| raise Exception("Document has no dataset") | |||
| segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() | |||
| segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() | |||
| # check segment is exist | |||
| if segments: | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| @@ -3,6 +3,7 @@ import time | |||
| import click | |||
| from celery import shared_task | |||
| from sqlalchemy import select | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from extensions.ext_database import db | |||
| @@ -34,7 +35,9 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): | |||
| document = db.session.query(Document).where(Document.id == document_id).first() | |||
| db.session.delete(document) | |||
| segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() | |||
| segments = db.session.scalars( | |||
| select(DocumentSegment).where(DocumentSegment.document_id == document_id) | |||
| ).all() | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) | |||
| @@ -4,6 +4,7 @@ from typing import Literal | |||
| import click | |||
| from celery import shared_task | |||
| from sqlalchemy import select | |||
| from core.rag.index_processor.constant.index_type import IndexType | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| @@ -36,16 +37,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a | |||
| if action == "remove": | |||
| index_processor.clean(dataset, None, with_keywords=False) | |||
| elif action == "add": | |||
| dataset_documents = ( | |||
| db.session.query(DatasetDocument) | |||
| .where( | |||
| dataset_documents = db.session.scalars( | |||
| select(DatasetDocument).where( | |||
| DatasetDocument.dataset_id == dataset_id, | |||
| DatasetDocument.indexing_status == "completed", | |||
| DatasetDocument.enabled == True, | |||
| DatasetDocument.archived == False, | |||
| ) | |||
| .all() | |||
| ) | |||
| ).all() | |||
| if dataset_documents: | |||
| dataset_documents_ids = [doc.id for doc in dataset_documents] | |||
| @@ -89,16 +88,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a | |||
| ) | |||
| db.session.commit() | |||
| elif action == "update": | |||
| dataset_documents = ( | |||
| db.session.query(DatasetDocument) | |||
| .where( | |||
| dataset_documents = db.session.scalars( | |||
| select(DatasetDocument).where( | |||
| DatasetDocument.dataset_id == dataset_id, | |||
| DatasetDocument.indexing_status == "completed", | |||
| DatasetDocument.enabled == True, | |||
| DatasetDocument.archived == False, | |||
| ) | |||
| .all() | |||
| ) | |||
| ).all() | |||
| # add new index | |||
| if dataset_documents: | |||
| # update document status | |||
| @@ -3,6 +3,7 @@ import time | |||
| import click | |||
| from celery import shared_task | |||
| from sqlalchemy import select | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from extensions.ext_database import db | |||
| @@ -44,15 +45,13 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen | |||
| # sync index processor | |||
| index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() | |||
| segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .where( | |||
| segments = db.session.scalars( | |||
| select(DocumentSegment).where( | |||
| DocumentSegment.id.in_(segment_ids), | |||
| DocumentSegment.dataset_id == dataset_id, | |||
| DocumentSegment.document_id == document_id, | |||
| ) | |||
| .all() | |||
| ) | |||
| ).all() | |||
| if not segments: | |||
| db.session.close() | |||
| @@ -3,6 +3,7 @@ import time | |||
| import click | |||
| from celery import shared_task | |||
| from sqlalchemy import select | |||
| from core.indexing_runner import DocumentIsPausedError, IndexingRunner | |||
| from core.rag.extractor.notion_extractor import NotionExtractor | |||
| @@ -85,7 +86,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): | |||
| index_type = document.doc_form | |||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | |||
| segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() | |||
| segments = db.session.scalars( | |||
| select(DocumentSegment).where(DocumentSegment.document_id == document_id) | |||
| ).all() | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| # delete from vector index | |||
| @@ -3,6 +3,7 @@ import time | |||
| import click | |||
| from celery import shared_task | |||
| from sqlalchemy import select | |||
| from core.indexing_runner import DocumentIsPausedError, IndexingRunner | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| @@ -45,7 +46,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): | |||
| index_type = document.doc_form | |||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | |||
| segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() | |||
| segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() | |||
| if segments: | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| @@ -3,6 +3,7 @@ import time | |||
| import click | |||
| from celery import shared_task | |||
| from sqlalchemy import select | |||
| from configs import dify_config | |||
| from core.indexing_runner import DocumentIsPausedError, IndexingRunner | |||
| @@ -79,7 +80,9 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): | |||
| index_type = document.doc_form | |||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | |||
| segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() | |||
| segments = db.session.scalars( | |||
| select(DocumentSegment).where(DocumentSegment.document_id == document_id) | |||
| ).all() | |||
| if segments: | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| @@ -3,6 +3,7 @@ import time | |||
| import click | |||
| from celery import shared_task | |||
| from sqlalchemy import select | |||
| from core.rag.index_processor.constant.index_type import IndexType | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| @@ -45,15 +46,13 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i | |||
| # sync index processor | |||
| index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() | |||
| segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .where( | |||
| segments = db.session.scalars( | |||
| select(DocumentSegment).where( | |||
| DocumentSegment.id.in_(segment_ids), | |||
| DocumentSegment.dataset_id == dataset_id, | |||
| DocumentSegment.document_id == document_id, | |||
| ) | |||
| .all() | |||
| ) | |||
| ).all() | |||
| if not segments: | |||
| logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan")) | |||
| db.session.close() | |||
| @@ -3,6 +3,7 @@ import time | |||
| import click | |||
| from celery import shared_task | |||
| from sqlalchemy import select | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from extensions.ext_database import db | |||
| @@ -45,7 +46,7 @@ def remove_document_from_index_task(document_id: str): | |||
| index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() | |||
| segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).all() | |||
| segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all() | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| if index_node_ids: | |||
| try: | |||
| @@ -3,6 +3,7 @@ import time | |||
| import click | |||
| from celery import shared_task | |||
| from sqlalchemy import select | |||
| from core.indexing_runner import IndexingRunner | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| @@ -69,7 +70,9 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): | |||
| # clean old data | |||
| index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() | |||
| segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() | |||
| segments = db.session.scalars( | |||
| select(DocumentSegment).where(DocumentSegment.document_id == document_id) | |||
| ).all() | |||
| if segments: | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| # delete from vector index | |||
| @@ -3,6 +3,7 @@ import time | |||
| import click | |||
| from celery import shared_task | |||
| from sqlalchemy import select | |||
| from core.indexing_runner import IndexingRunner | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| @@ -63,7 +64,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): | |||
| # clean old data | |||
| index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() | |||
| segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() | |||
| segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() | |||
| if segments: | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| # delete from vector index | |||
| @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch | |||
| import pytest | |||
| from faker import Faker | |||
| from sqlalchemy import select | |||
| from models.account import TenantAccountJoin, TenantAccountRole | |||
| from models.model import Account, Tenant | |||
| @@ -468,7 +469,7 @@ class TestModelLoadBalancingService: | |||
| assert load_balancing_config.id is not None | |||
| # Verify inherit config was created in database | |||
| inherit_configs = ( | |||
| db.session.query(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__").all() | |||
| ) | |||
| inherit_configs = db.session.scalars( | |||
| select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__") | |||
| ).all() | |||
| assert len(inherit_configs) == 1 | |||
| @@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch | |||
| import pytest | |||
| from faker import Faker | |||
| from sqlalchemy import select | |||
| from werkzeug.exceptions import NotFound | |||
| from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole | |||
| @@ -954,7 +955,9 @@ class TestTagService: | |||
| from extensions.ext_database import db | |||
| # Verify only one binding exists | |||
| bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all() | |||
| bindings = db.session.scalars( | |||
| select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id) | |||
| ).all() | |||
| assert len(bindings) == 1 | |||
| def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies): | |||
| @@ -1064,7 +1067,9 @@ class TestTagService: | |||
| # No error should be raised, and database state should remain unchanged | |||
| from extensions.ext_database import db | |||
| bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all() | |||
| bindings = db.session.scalars( | |||
| select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id) | |||
| ).all() | |||
| assert len(bindings) == 0 | |||
| def test_check_target_exists_knowledge_success( | |||
| @@ -2,6 +2,7 @@ from unittest.mock import patch | |||
| import pytest | |||
| from faker import Faker | |||
| from sqlalchemy import select | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from models.account import Account | |||
| @@ -354,16 +355,14 @@ class TestWebConversationService: | |||
| # Verify only one pinned conversation record exists | |||
| from extensions.ext_database import db | |||
| pinned_conversations = ( | |||
| db.session.query(PinnedConversation) | |||
| .where( | |||
| pinned_conversations = db.session.scalars( | |||
| select(PinnedConversation).where( | |||
| PinnedConversation.app_id == app.id, | |||
| PinnedConversation.conversation_id == conversation.id, | |||
| PinnedConversation.created_by_role == "account", | |||
| PinnedConversation.created_by == account.id, | |||
| ) | |||
| .all() | |||
| ) | |||
| ).all() | |||
| assert len(pinned_conversations) == 1 | |||
| @@ -28,18 +28,20 @@ class TestApiKeyAuthService: | |||
| mock_binding.provider = self.provider | |||
| mock_binding.disabled = False | |||
| mock_session.query.return_value.where.return_value.all.return_value = [mock_binding] | |||
| mock_session.scalars.return_value.all.return_value = [mock_binding] | |||
| result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id) | |||
| assert len(result) == 1 | |||
| assert result[0].tenant_id == self.tenant_id | |||
| mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding) | |||
| assert mock_session.scalars.call_count == 1 | |||
| select_arg = mock_session.scalars.call_args[0][0] | |||
| assert "data_source_api_key_auth_binding" in str(select_arg).lower() | |||
| @patch("services.auth.api_key_auth_service.db.session") | |||
| def test_get_provider_auth_list_empty(self, mock_session): | |||
| """Test get provider auth list - empty result""" | |||
| mock_session.query.return_value.where.return_value.all.return_value = [] | |||
| mock_session.scalars.return_value.all.return_value = [] | |||
| result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id) | |||
| @@ -48,13 +50,15 @@ class TestApiKeyAuthService: | |||
| @patch("services.auth.api_key_auth_service.db.session") | |||
| def test_get_provider_auth_list_filters_disabled(self, mock_session): | |||
| """Test get provider auth list - filters disabled items""" | |||
| mock_session.query.return_value.where.return_value.all.return_value = [] | |||
| mock_session.scalars.return_value.all.return_value = [] | |||
| ApiKeyAuthService.get_provider_auth_list(self.tenant_id) | |||
| # Verify where conditions include disabled.is_(False) | |||
| where_call = mock_session.query.return_value.where.call_args[0] | |||
| assert len(where_call) == 2 # tenant_id and disabled filter conditions | |||
| select_stmt = mock_session.scalars.call_args[0][0] | |||
| where_clauses = list(getattr(select_stmt, "_where_criteria", []) or []) | |||
| # Ensure both tenant filter and disabled filter exist | |||
| where_strs = [str(c).lower() for c in where_clauses] | |||
| assert any("tenant_id" in s for s in where_strs) | |||
| assert any("disabled" in s for s in where_strs) | |||
| @patch("services.auth.api_key_auth_service.db.session") | |||
| @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") | |||
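The unit-test mocks in this last group change in step with the new call chain: expectations move from session.query(...).where(...).all() to session.scalars(...).all(), and the Select statement handed to scalars() is inspected for its WHERE criteria. A hedged sketch of that mocking pattern, built on the hypothetical Widget model; _where_criteria is the same private Select attribute the updated test reads via getattr:

    from unittest.mock import MagicMock

    from sqlalchemy import select


    def test_scalars_mock_sketch():
        mock_session = MagicMock()
        mock_session.scalars.return_value.all.return_value = []
        # the code under test would run: mock_session.scalars(select(...).where(...)).all()
        stmt = select(Widget).where(Widget.tenant_id == "t1")
        assert mock_session.scalars(stmt).all() == []
        # the statement passed to scalars() can then be inspected
        captured = mock_session.scalars.call_args[0][0]
        criteria = [str(c).lower() for c in getattr(captured, "_where_criteria", [])]
        assert any("tenant_id" in c for c in criteria)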
| @@ -63,10 +63,10 @@ class TestAuthIntegration: | |||
| tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials) | |||
| tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials) | |||
| mock_session.query.return_value.where.return_value.all.return_value = [tenant1_binding] | |||
| mock_session.scalars.return_value.all.return_value = [tenant1_binding] | |||
| result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1) | |||
| mock_session.query.return_value.where.return_value.all.return_value = [tenant2_binding] | |||
| mock_session.scalars.return_value.all.return_value = [tenant2_binding] | |||
| result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2) | |||
| assert len(result1) == 1 | |||