Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com>tags/1.9.0
| if not dataset_collection_binding: | if not dataset_collection_binding: | ||||
| click.echo(f"App annotation collection binding not found: {app.id}") | click.echo(f"App annotation collection binding not found: {app.id}") | ||||
| continue | continue | ||||
| annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all() | |||||
| annotations = db.session.scalars( | |||||
| select(MessageAnnotation).where(MessageAnnotation.app_id == app.id) | |||||
| ).all() | |||||
| dataset = Dataset( | dataset = Dataset( | ||||
| id=app.id, | id=app.id, | ||||
| tenant_id=app.tenant_id, | tenant_id=app.tenant_id, | ||||
| ) | ) | ||||
| raise e | raise e | ||||
| dataset_documents = ( | |||||
| db.session.query(DatasetDocument) | |||||
| .where( | |||||
| dataset_documents = db.session.scalars( | |||||
| select(DatasetDocument).where( | |||||
| DatasetDocument.dataset_id == dataset.id, | DatasetDocument.dataset_id == dataset.id, | ||||
| DatasetDocument.indexing_status == "completed", | DatasetDocument.indexing_status == "completed", | ||||
| DatasetDocument.enabled == True, | DatasetDocument.enabled == True, | ||||
| DatasetDocument.archived == False, | DatasetDocument.archived == False, | ||||
| ) | ) | ||||
| .all() | |||||
| ) | |||||
| ).all() | |||||
| documents = [] | documents = [] | ||||
| segments_count = 0 | segments_count = 0 | ||||
| for dataset_document in dataset_documents: | for dataset_document in dataset_documents: | ||||
| segments = ( | |||||
| db.session.query(DocumentSegment) | |||||
| .where( | |||||
| segments = db.session.scalars( | |||||
| select(DocumentSegment).where( | |||||
| DocumentSegment.document_id == dataset_document.id, | DocumentSegment.document_id == dataset_document.id, | ||||
| DocumentSegment.status == "completed", | DocumentSegment.status == "completed", | ||||
| DocumentSegment.enabled == True, | DocumentSegment.enabled == True, | ||||
| ) | ) | ||||
| .all() | |||||
| ) | |||||
| ).all() | |||||
| for segment in segments: | for segment in segments: | ||||
| document = Document( | document = Document( |
| assert self.resource_id_field is not None, "resource_id_field must be set" | assert self.resource_id_field is not None, "resource_id_field must be set" | ||||
| resource_id = str(resource_id) | resource_id = str(resource_id) | ||||
| _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) | _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) | ||||
| keys = ( | |||||
| db.session.query(ApiToken) | |||||
| .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) | |||||
| .all() | |||||
| ) | |||||
| keys = db.session.scalars( | |||||
| select(ApiToken).where( | |||||
| ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id | |||||
| ) | |||||
| ).all() | |||||
| return {"items": keys} | return {"items": keys} | ||||
| @marshal_with(api_key_fields) | @marshal_with(api_key_fields) |
| @marshal_with(integrate_list_fields) | @marshal_with(integrate_list_fields) | ||||
| def get(self): | def get(self): | ||||
| # get workspace data source integrates | # get workspace data source integrates | ||||
| data_source_integrates = ( | |||||
| db.session.query(DataSourceOauthBinding) | |||||
| .where( | |||||
| data_source_integrates = db.session.scalars( | |||||
| select(DataSourceOauthBinding).where( | |||||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | ||||
| DataSourceOauthBinding.disabled == False, | DataSourceOauthBinding.disabled == False, | ||||
| ) | ) | ||||
| .all() | |||||
| ) | |||||
| ).all() | |||||
| base_url = request.url_root.rstrip("/") | base_url = request.url_root.rstrip("/") | ||||
| data_source_oauth_base_path = "/console/api/oauth/data-source" | data_source_oauth_base_path = "/console/api/oauth/data-source" |
| from flask import request | from flask import request | ||||
| from flask_login import current_user | from flask_login import current_user | ||||
| from flask_restx import Resource, marshal, marshal_with, reqparse | from flask_restx import Resource, marshal, marshal_with, reqparse | ||||
| from sqlalchemy import select | |||||
| from werkzeug.exceptions import Forbidden, NotFound | from werkzeug.exceptions import Forbidden, NotFound | ||||
| import services | import services | ||||
| extract_settings = [] | extract_settings = [] | ||||
| if args["info_list"]["data_source_type"] == "upload_file": | if args["info_list"]["data_source_type"] == "upload_file": | ||||
| file_ids = args["info_list"]["file_info_list"]["file_ids"] | file_ids = args["info_list"]["file_info_list"]["file_ids"] | ||||
| file_details = ( | |||||
| db.session.query(UploadFile) | |||||
| .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) | |||||
| .all() | |||||
| ) | |||||
| file_details = db.session.scalars( | |||||
| select(UploadFile).where( | |||||
| UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids) | |||||
| ) | |||||
| ).all() | |||||
| if file_details is None: | if file_details is None: | ||||
| raise NotFound("File not found.") | raise NotFound("File not found.") | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def get(self, dataset_id): | def get(self, dataset_id): | ||||
| dataset_id = str(dataset_id) | dataset_id = str(dataset_id) | ||||
| documents = ( | |||||
| db.session.query(Document) | |||||
| .where(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) | |||||
| .all() | |||||
| ) | |||||
| documents = db.session.scalars( | |||||
| select(Document).where( | |||||
| Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id | |||||
| ) | |||||
| ).all() | |||||
| documents_status = [] | documents_status = [] | ||||
| for document in documents: | for document in documents: | ||||
| completed_segments = ( | completed_segments = ( | ||||
| @account_initialization_required | @account_initialization_required | ||||
| @marshal_with(api_key_list) | @marshal_with(api_key_list) | ||||
| def get(self): | def get(self): | ||||
| keys = ( | |||||
| db.session.query(ApiToken) | |||||
| .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) | |||||
| .all() | |||||
| ) | |||||
| keys = db.session.scalars( | |||||
| select(ApiToken).where( | |||||
| ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id | |||||
| ) | |||||
| ).all() | |||||
| return {"items": keys} | return {"items": keys} | ||||
| @setup_required | @setup_required |
| import logging | import logging | ||||
| from argparse import ArgumentTypeError | from argparse import ArgumentTypeError | ||||
| from collections.abc import Sequence | |||||
| from typing import Literal, cast | from typing import Literal, cast | ||||
| from flask import request | from flask import request | ||||
| return document | return document | ||||
| def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]: | |||||
| def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]: | |||||
| dataset = DatasetService.get_dataset(dataset_id) | dataset = DatasetService.get_dataset(dataset_id) | ||||
| if not dataset: | if not dataset: | ||||
| raise NotFound("Dataset not found.") | raise NotFound("Dataset not found.") |
| from flask import request | from flask import request | ||||
| from flask_restx import Resource, inputs, marshal_with, reqparse | from flask_restx import Resource, inputs, marshal_with, reqparse | ||||
| from sqlalchemy import and_ | |||||
| from sqlalchemy import and_, select | |||||
| from werkzeug.exceptions import BadRequest, Forbidden, NotFound | from werkzeug.exceptions import BadRequest, Forbidden, NotFound | ||||
| from controllers.console import api | from controllers.console import api | ||||
| current_tenant_id = current_user.current_tenant_id | current_tenant_id = current_user.current_tenant_id | ||||
| if app_id: | if app_id: | ||||
| installed_apps = ( | |||||
| db.session.query(InstalledApp) | |||||
| .where(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)) | |||||
| .all() | |||||
| ) | |||||
| installed_apps = db.session.scalars( | |||||
| select(InstalledApp).where( | |||||
| and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id) | |||||
| ) | |||||
| ).all() | |||||
| else: | else: | ||||
| installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all() | |||||
| installed_apps = db.session.scalars( | |||||
| select(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id) | |||||
| ).all() | |||||
| if current_user.current_tenant is None: | if current_user.current_tenant is None: | ||||
| raise ValueError("current_user.current_tenant must not be None") | raise ValueError("current_user.current_tenant must not be None") |
| raise ValueError("Invalid user account") | raise ValueError("Invalid user account") | ||||
| account = current_user | account = current_user | ||||
| account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all() | |||||
| account_integrates = db.session.scalars( | |||||
| select(AccountIntegrate).where(AccountIntegrate.account_id == account.id) | |||||
| ).all() | |||||
| base_url = request.url_root.rstrip("/") | base_url = request.url_root.rstrip("/") | ||||
| oauth_base_path = "/console/api/oauth/login" | oauth_base_path = "/console/api/oauth/login" |
| self.model_instance = model_instance | self.model_instance = model_instance | ||||
| def _build_prompt_message_with_files( | def _build_prompt_message_with_files( | ||||
| self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool | |||||
| self, | |||||
| message_files: Sequence[MessageFile], | |||||
| text_content: str, | |||||
| message: Message, | |||||
| app_record, | |||||
| is_user_message: bool, | |||||
| ) -> PromptMessage: | ) -> PromptMessage: | ||||
| """ | """ | ||||
| Build prompt message with files. | Build prompt message with files. | ||||
| :param message_files: list of MessageFile objects | |||||
| :param message_files: Sequence of MessageFile objects | |||||
| :param text_content: text content of the message | :param text_content: text content of the message | ||||
| :param message: Message object | :param message: Message object | ||||
| :param app_record: app record | :param app_record: app record | ||||
| prompt_messages: list[PromptMessage] = [] | prompt_messages: list[PromptMessage] = [] | ||||
| for message in messages: | for message in messages: | ||||
| # Process user message with files | # Process user message with files | ||||
| user_files = ( | |||||
| db.session.query(MessageFile) | |||||
| .where( | |||||
| user_files = db.session.scalars( | |||||
| select(MessageFile).where( | |||||
| MessageFile.message_id == message.id, | MessageFile.message_id == message.id, | ||||
| (MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)), | (MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)), | ||||
| ) | ) | ||||
| .all() | |||||
| ) | |||||
| ).all() | |||||
| if user_files: | if user_files: | ||||
| user_prompt_message = self._build_prompt_message_with_files( | user_prompt_message = self._build_prompt_message_with_files( | ||||
| prompt_messages.append(UserPromptMessage(content=message.query)) | prompt_messages.append(UserPromptMessage(content=message.query)) | ||||
| # Process assistant message with files | # Process assistant message with files | ||||
| assistant_files = ( | |||||
| db.session.query(MessageFile) | |||||
| .where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant") | |||||
| .all() | |||||
| ) | |||||
| assistant_files = db.session.scalars( | |||||
| select(MessageFile).where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant") | |||||
| ).all() | |||||
| if assistant_files: | if assistant_files: | ||||
| assistant_prompt_message = self._build_prompt_message_with_files( | assistant_prompt_message = self._build_prompt_message_with_files( |
| from opentelemetry.sdk.trace.export import SimpleSpanProcessor | from opentelemetry.sdk.trace.export import SimpleSpanProcessor | ||||
| from opentelemetry.sdk.trace.id_generator import RandomIdGenerator | from opentelemetry.sdk.trace.id_generator import RandomIdGenerator | ||||
| from opentelemetry.trace import SpanContext, TraceFlags, TraceState | from opentelemetry.trace import SpanContext, TraceFlags, TraceState | ||||
| from sqlalchemy import select | |||||
| from core.ops.base_trace_instance import BaseTraceInstance | from core.ops.base_trace_instance import BaseTraceInstance | ||||
| from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig | from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig | ||||
| def _get_workflow_nodes(self, workflow_run_id: str): | def _get_workflow_nodes(self, workflow_run_id: str): | ||||
| """Helper method to get workflow nodes""" | """Helper method to get workflow nodes""" | ||||
| workflow_nodes = ( | |||||
| db.session.query( | |||||
| workflow_nodes = db.session.scalars( | |||||
| select( | |||||
| WorkflowNodeExecutionModel.id, | WorkflowNodeExecutionModel.id, | ||||
| WorkflowNodeExecutionModel.tenant_id, | WorkflowNodeExecutionModel.tenant_id, | ||||
| WorkflowNodeExecutionModel.app_id, | WorkflowNodeExecutionModel.app_id, | ||||
| WorkflowNodeExecutionModel.elapsed_time, | WorkflowNodeExecutionModel.elapsed_time, | ||||
| WorkflowNodeExecutionModel.process_data, | WorkflowNodeExecutionModel.process_data, | ||||
| WorkflowNodeExecutionModel.execution_metadata, | WorkflowNodeExecutionModel.execution_metadata, | ||||
| ) | |||||
| .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) | |||||
| .all() | |||||
| ) | |||||
| ).where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) | |||||
| ).all() | |||||
| return workflow_nodes | return workflow_nodes | ||||
| def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]: | def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]: |
| import time | import time | ||||
| import uuid | import uuid | ||||
| from collections.abc import Sequence | |||||
| import requests | import requests | ||||
| from requests.auth import HTTPDigestAuth | from requests.auth import HTTPDigestAuth | ||||
| @staticmethod | @staticmethod | ||||
| def batch_update_tidb_serverless_cluster_status( | def batch_update_tidb_serverless_cluster_status( | ||||
| tidb_serverless_list: list[TidbAuthBinding], | |||||
| tidb_serverless_list: Sequence[TidbAuthBinding], | |||||
| project_id: str, | project_id: str, | ||||
| api_url: str, | api_url: str, | ||||
| iam_url: str, | iam_url: str, |
| from pydantic import Field | from pydantic import Field | ||||
| from sqlalchemy import select | |||||
| from core.entities.provider_entities import ProviderConfig | from core.entities.provider_entities import ProviderConfig | ||||
| from core.tools.__base.tool_provider import ToolProviderController | from core.tools.__base.tool_provider import ToolProviderController | ||||
| tools: list[ApiTool] = [] | tools: list[ApiTool] = [] | ||||
| # get tenant api providers | # get tenant api providers | ||||
| db_providers: list[ApiToolProvider] = ( | |||||
| db.session.query(ApiToolProvider) | |||||
| .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name) | |||||
| .all() | |||||
| ) | |||||
| db_providers = db.session.scalars( | |||||
| select(ApiToolProvider).where( | |||||
| ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name | |||||
| ) | |||||
| ).all() | |||||
| if db_providers and len(db_providers) != 0: | if db_providers and len(db_providers) != 0: | ||||
| for db_provider in db_providers: | for db_provider in db_providers: |
| assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController) | assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController) | ||||
| provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute] | provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute] | ||||
| labels: list[ToolLabelBinding] = ( | |||||
| db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all() | |||||
| ) | |||||
| labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all() | |||||
| tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} | tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} | ||||
| # get db api providers | # get db api providers | ||||
| if "api" in filters: | if "api" in filters: | ||||
| db_api_providers: list[ApiToolProvider] = ( | |||||
| db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() | |||||
| ) | |||||
| db_api_providers = db.session.scalars( | |||||
| select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id) | |||||
| ).all() | |||||
| api_provider_controllers: list[dict[str, Any]] = [ | api_provider_controllers: list[dict[str, Any]] = [ | ||||
| {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} | {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} | ||||
| if "workflow" in filters: | if "workflow" in filters: | ||||
| # get workflow providers | # get workflow providers | ||||
| workflow_providers: list[WorkflowToolProvider] = ( | |||||
| db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all() | |||||
| ) | |||||
| workflow_providers = db.session.scalars( | |||||
| select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id) | |||||
| ).all() | |||||
| workflow_provider_controllers: list[WorkflowToolProviderController] = [] | workflow_provider_controllers: list[WorkflowToolProviderController] = [] | ||||
| for workflow_provider in workflow_providers: | for workflow_provider in workflow_providers: |
| from sqlalchemy import select | |||||
| from events.app_event import app_model_config_was_updated | from events.app_event import app_model_config_was_updated | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import AppDatasetJoin | from models.dataset import AppDatasetJoin | ||||
| dataset_ids = get_dataset_ids_from_model_config(app_model_config) | dataset_ids = get_dataset_ids_from_model_config(app_model_config) | ||||
| app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all() | |||||
| app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all() | |||||
| removed_dataset_ids: set[str] = set() | removed_dataset_ids: set[str] = set() | ||||
| if not app_dataset_joins: | if not app_dataset_joins: |
| from typing import cast | from typing import cast | ||||
| from sqlalchemy import select | |||||
| from core.workflow.nodes import NodeType | from core.workflow.nodes import NodeType | ||||
| from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData | from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData | ||||
| from events.app_event import app_published_workflow_was_updated | from events.app_event import app_published_workflow_was_updated | ||||
| published_workflow = cast(Workflow, published_workflow) | published_workflow = cast(Workflow, published_workflow) | ||||
| dataset_ids = get_dataset_ids_from_workflow(published_workflow) | dataset_ids = get_dataset_ids_from_workflow(published_workflow) | ||||
| app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all() | |||||
| app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all() | |||||
| removed_dataset_ids: set[str] = set() | removed_dataset_ids: set[str] = set() | ||||
| if not app_dataset_joins: | if not app_dataset_joins: |
| updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) | updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) | ||||
| def get_accounts(self) -> list[Account]: | def get_accounts(self) -> list[Account]: | ||||
| return ( | |||||
| db.session.query(Account) | |||||
| .where(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id) | |||||
| .all() | |||||
| return list( | |||||
| db.session.scalars( | |||||
| select(Account).where( | |||||
| Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id | |||||
| ) | |||||
| ).all() | |||||
| ) | ) | ||||
| @property | @property |
| @property | @property | ||||
| def doc_metadata(self): | def doc_metadata(self): | ||||
| dataset_metadatas = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id).all() | |||||
| dataset_metadatas = db.session.scalars( | |||||
| select(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id) | |||||
| ).all() | |||||
| doc_metadata = [ | doc_metadata = [ | ||||
| { | { | ||||
| @property | @property | ||||
| def dataset_bindings(self) -> list[dict[str, Any]]: | def dataset_bindings(self) -> list[dict[str, Any]]: | ||||
| external_knowledge_bindings = ( | |||||
| db.session.query(ExternalKnowledgeBindings) | |||||
| .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) | |||||
| .all() | |||||
| ) | |||||
| external_knowledge_bindings = db.session.scalars( | |||||
| select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) | |||||
| ).all() | |||||
| dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] | dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] | ||||
| datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all() | |||||
| datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all() | |||||
| dataset_bindings: list[dict[str, Any]] = [] | dataset_bindings: list[dict[str, Any]] = [] | ||||
| for dataset in datasets: | for dataset in datasets: | ||||
| dataset_bindings.append({"id": dataset.id, "name": dataset.name}) | dataset_bindings.append({"id": dataset.id, "name": dataset.name}) |
| @property | @property | ||||
| def status_count(self): | def status_count(self): | ||||
| messages = db.session.query(Message).where(Message.conversation_id == self.id).all() | |||||
| messages = db.session.scalars(select(Message).where(Message.conversation_id == self.id)).all() | |||||
| status_counts = { | status_counts = { | ||||
| WorkflowExecutionStatus.RUNNING: 0, | WorkflowExecutionStatus.RUNNING: 0, | ||||
| WorkflowExecutionStatus.SUCCEEDED: 0, | WorkflowExecutionStatus.SUCCEEDED: 0, | ||||
| @property | @property | ||||
| def feedbacks(self): | def feedbacks(self): | ||||
| feedbacks = db.session.query(MessageFeedback).where(MessageFeedback.message_id == self.id).all() | |||||
| feedbacks = db.session.scalars(select(MessageFeedback).where(MessageFeedback.message_id == self.id)).all() | |||||
| return feedbacks | return feedbacks | ||||
| @property | @property | ||||
| def message_files(self) -> list[dict[str, Any]]: | def message_files(self) -> list[dict[str, Any]]: | ||||
| from factories import file_factory | from factories import file_factory | ||||
| message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all() | |||||
| message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all() | |||||
| current_app = db.session.query(App).where(App.id == self.app_id).first() | current_app = db.session.query(App).where(App.id == self.app_id).first() | ||||
| if not current_app: | if not current_app: | ||||
| raise ValueError(f"App {self.app_id} not found") | raise ValueError(f"App {self.app_id} not found") |
| break | break | ||||
| for dataset in datasets: | for dataset in datasets: | ||||
| dataset_query = ( | |||||
| db.session.query(DatasetQuery) | |||||
| .where(DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id) | |||||
| .all() | |||||
| ) | |||||
| dataset_query = db.session.scalars( | |||||
| select(DatasetQuery).where( | |||||
| DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id | |||||
| ) | |||||
| ).all() | |||||
| if not dataset_query or len(dataset_query) == 0: | if not dataset_query or len(dataset_query) == 0: | ||||
| try: | try: | ||||
| if should_clean: | if should_clean: | ||||
| # Add auto disable log if required | # Add auto disable log if required | ||||
| if add_logs: | if add_logs: | ||||
| documents = ( | |||||
| db.session.query(Document) | |||||
| .where( | |||||
| documents = db.session.scalars( | |||||
| select(Document).where( | |||||
| Document.dataset_id == dataset.id, | Document.dataset_id == dataset.id, | ||||
| Document.enabled == True, | Document.enabled == True, | ||||
| Document.archived == False, | Document.archived == False, | ||||
| ) | ) | ||||
| .all() | |||||
| ) | |||||
| ).all() | |||||
| for document in documents: | for document in documents: | ||||
| dataset_auto_disable_log = DatasetAutoDisableLog( | dataset_auto_disable_log = DatasetAutoDisableLog( | ||||
| tenant_id=dataset.tenant_id, | tenant_id=dataset.tenant_id, |
| from collections import defaultdict | from collections import defaultdict | ||||
| import click | import click | ||||
| from sqlalchemy import select | |||||
| import app | import app | ||||
| from configs import dify_config | from configs import dify_config | ||||
| # send document clean notify mail | # send document clean notify mail | ||||
| try: | try: | ||||
| dataset_auto_disable_logs = ( | |||||
| db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False).all() | |||||
| ) | |||||
| dataset_auto_disable_logs = db.session.scalars( | |||||
| select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False) | |||||
| ).all() | |||||
| # group by tenant_id | # group by tenant_id | ||||
| dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) | dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) | ||||
| for dataset_auto_disable_log in dataset_auto_disable_logs: | for dataset_auto_disable_log in dataset_auto_disable_logs: |
| import time | import time | ||||
| from collections.abc import Sequence | |||||
| import click | import click | ||||
| from sqlalchemy import select | |||||
| import app | import app | ||||
| from configs import dify_config | from configs import dify_config | ||||
| start_at = time.perf_counter() | start_at = time.perf_counter() | ||||
| try: | try: | ||||
| # check the number of idle tidb serverless | # check the number of idle tidb serverless | ||||
| tidb_serverless_list = ( | |||||
| db.session.query(TidbAuthBinding) | |||||
| .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") | |||||
| .all() | |||||
| ) | |||||
| tidb_serverless_list = db.session.scalars( | |||||
| select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") | |||||
| ).all() | |||||
| if len(tidb_serverless_list) == 0: | if len(tidb_serverless_list) == 0: | ||||
| return | return | ||||
| # update tidb serverless status | # update tidb serverless status | ||||
| click.echo(click.style(f"Update tidb serverless status task success latency: {end_at - start_at}", fg="green")) | click.echo(click.style(f"Update tidb serverless status task success latency: {end_at - start_at}", fg="green")) | ||||
| def update_clusters(tidb_serverless_list: list[TidbAuthBinding]): | |||||
| def update_clusters(tidb_serverless_list: Sequence[TidbAuthBinding]): | |||||
| try: | try: | ||||
| # batch 20 | # batch 20 | ||||
| for i in range(0, len(tidb_serverless_list), 20): | for i in range(0, len(tidb_serverless_list), 20): |
| db.session.delete(annotation) | db.session.delete(annotation) | ||||
| annotation_hit_histories = ( | |||||
| db.session.query(AppAnnotationHitHistory) | |||||
| .where(AppAnnotationHitHistory.annotation_id == annotation_id) | |||||
| .all() | |||||
| ) | |||||
| annotation_hit_histories = db.session.scalars( | |||||
| select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation_id) | |||||
| ).all() | |||||
| if annotation_hit_histories: | if annotation_hit_histories: | ||||
| for annotation_hit_history in annotation_hit_histories: | for annotation_hit_history in annotation_hit_histories: | ||||
| db.session.delete(annotation_hit_history) | db.session.delete(annotation_hit_history) |
| import json | import json | ||||
| from sqlalchemy import select | |||||
| from core.helper import encrypter | from core.helper import encrypter | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.source import DataSourceApiKeyAuthBinding | from models.source import DataSourceApiKeyAuthBinding | ||||
| class ApiKeyAuthService: | class ApiKeyAuthService: | ||||
| @staticmethod | @staticmethod | ||||
| def get_provider_auth_list(tenant_id: str): | def get_provider_auth_list(tenant_id: str): | ||||
| data_source_api_key_bindings = ( | |||||
| db.session.query(DataSourceApiKeyAuthBinding) | |||||
| .where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)) | |||||
| .all() | |||||
| ) | |||||
| data_source_api_key_bindings = db.session.scalars( | |||||
| select(DataSourceApiKeyAuthBinding).where( | |||||
| DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False) | |||||
| ) | |||||
| ).all() | |||||
| return data_source_api_key_bindings | return data_source_api_key_bindings | ||||
| @staticmethod | @staticmethod |
| import click | import click | ||||
| from flask import Flask, current_app | from flask import Flask, current_app | ||||
| from sqlalchemy import select | |||||
| from sqlalchemy.orm import Session, sessionmaker | from sqlalchemy.orm import Session, sessionmaker | ||||
| from configs import dify_config | from configs import dify_config | ||||
| @classmethod | @classmethod | ||||
| def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int): | def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int): | ||||
| with flask_app.app_context(): | with flask_app.app_context(): | ||||
| apps = db.session.query(App).where(App.tenant_id == tenant_id).all() | |||||
| apps = db.session.scalars(select(App).where(App.tenant_id == tenant_id)).all() | |||||
| app_ids = [app.id for app in apps] | app_ids = [app.id for app in apps] | ||||
| while True: | while True: | ||||
| with Session(db.engine).no_autoflush as session: | with Session(db.engine).no_autoflush as session: |
| import time | import time | ||||
| import uuid | import uuid | ||||
| from collections import Counter | from collections import Counter | ||||
| from collections.abc import Sequence | |||||
| from typing import Any, Literal, Optional | from typing import Any, Literal, Optional | ||||
| import sqlalchemy as sa | import sqlalchemy as sa | ||||
| } | } | ||||
| # get recent 30 days auto disable logs | # get recent 30 days auto disable logs | ||||
| start_date = datetime.datetime.now() - datetime.timedelta(days=30) | start_date = datetime.datetime.now() - datetime.timedelta(days=30) | ||||
| dataset_auto_disable_logs = ( | |||||
| db.session.query(DatasetAutoDisableLog) | |||||
| .where( | |||||
| dataset_auto_disable_logs = db.session.scalars( | |||||
| select(DatasetAutoDisableLog).where( | |||||
| DatasetAutoDisableLog.dataset_id == dataset_id, | DatasetAutoDisableLog.dataset_id == dataset_id, | ||||
| DatasetAutoDisableLog.created_at >= start_date, | DatasetAutoDisableLog.created_at >= start_date, | ||||
| ) | ) | ||||
| .all() | |||||
| ) | |||||
| ).all() | |||||
| if dataset_auto_disable_logs: | if dataset_auto_disable_logs: | ||||
| return { | return { | ||||
| "document_ids": [log.document_id for log in dataset_auto_disable_logs], | "document_ids": [log.document_id for log in dataset_auto_disable_logs], | ||||
| return document | return document | ||||
| @staticmethod | @staticmethod | ||||
| def get_document_by_ids(document_ids: list[str]) -> list[Document]: | |||||
| documents = ( | |||||
| db.session.query(Document) | |||||
| .where( | |||||
| def get_document_by_ids(document_ids: list[str]) -> Sequence[Document]: | |||||
| documents = db.session.scalars( | |||||
| select(Document).where( | |||||
| Document.id.in_(document_ids), | Document.id.in_(document_ids), | ||||
| Document.enabled == True, | Document.enabled == True, | ||||
| Document.indexing_status == "completed", | Document.indexing_status == "completed", | ||||
| Document.archived == False, | Document.archived == False, | ||||
| ) | ) | ||||
| .all() | |||||
| ) | |||||
| ).all() | |||||
| return documents | return documents | ||||
| @staticmethod | @staticmethod | ||||
| def get_document_by_dataset_id(dataset_id: str) -> list[Document]: | |||||
| documents = ( | |||||
| db.session.query(Document) | |||||
| .where( | |||||
| def get_document_by_dataset_id(dataset_id: str) -> Sequence[Document]: | |||||
| documents = db.session.scalars( | |||||
| select(Document).where( | |||||
| Document.dataset_id == dataset_id, | Document.dataset_id == dataset_id, | ||||
| Document.enabled == True, | Document.enabled == True, | ||||
| ) | ) | ||||
| .all() | |||||
| ) | |||||
| ).all() | |||||
| return documents | return documents | ||||
| @staticmethod | @staticmethod | ||||
| def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]: | |||||
| documents = ( | |||||
| db.session.query(Document) | |||||
| .where( | |||||
| def get_working_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]: | |||||
| documents = db.session.scalars( | |||||
| select(Document).where( | |||||
| Document.dataset_id == dataset_id, | Document.dataset_id == dataset_id, | ||||
| Document.enabled == True, | Document.enabled == True, | ||||
| Document.indexing_status == "completed", | Document.indexing_status == "completed", | ||||
| Document.archived == False, | Document.archived == False, | ||||
| ) | ) | ||||
| .all() | |||||
| ) | |||||
| ).all() | |||||
| return documents | return documents | ||||
| @staticmethod | @staticmethod | ||||
| def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]: | |||||
| documents = ( | |||||
| db.session.query(Document) | |||||
| .where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) | |||||
| .all() | |||||
| ) | |||||
| def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]: | |||||
| documents = db.session.scalars( | |||||
| select(Document).where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) | |||||
| ).all() | |||||
| return documents | return documents | ||||
| @staticmethod | @staticmethod | ||||
| def get_batch_documents(dataset_id: str, batch: str) -> list[Document]: | |||||
| def get_batch_documents(dataset_id: str, batch: str) -> Sequence[Document]: | |||||
| assert isinstance(current_user, Account) | assert isinstance(current_user, Account) | ||||
| documents = ( | |||||
| db.session.query(Document) | |||||
| .where( | |||||
| documents = db.session.scalars( | |||||
| select(Document).where( | |||||
| Document.batch == batch, | Document.batch == batch, | ||||
| Document.dataset_id == dataset_id, | Document.dataset_id == dataset_id, | ||||
| Document.tenant_id == current_user.current_tenant_id, | Document.tenant_id == current_user.current_tenant_id, | ||||
| ) | ) | ||||
| .all() | |||||
| ) | |||||
| ).all() | |||||
| return documents | return documents | ||||
| # Check if document_ids is not empty to avoid WHERE false condition | # Check if document_ids is not empty to avoid WHERE false condition | ||||
| if not document_ids or len(document_ids) == 0: | if not document_ids or len(document_ids) == 0: | ||||
| return | return | ||||
| documents = db.session.query(Document).where(Document.id.in_(document_ids)).all() | |||||
| documents = db.session.scalars(select(Document).where(Document.id.in_(document_ids))).all() | |||||
| file_ids = [ | file_ids = [ | ||||
| document.data_source_info_dict["upload_file_id"] | document.data_source_info_dict["upload_file_id"] | ||||
| for document in documents | for document in documents | ||||
| if not segment_ids or len(segment_ids) == 0: | if not segment_ids or len(segment_ids) == 0: | ||||
| return | return | ||||
| if action == "enable": | if action == "enable": | ||||
| segments = ( | |||||
| db.session.query(DocumentSegment) | |||||
| .where( | |||||
| segments = db.session.scalars( | |||||
| select(DocumentSegment).where( | |||||
| DocumentSegment.id.in_(segment_ids), | DocumentSegment.id.in_(segment_ids), | ||||
| DocumentSegment.dataset_id == dataset.id, | DocumentSegment.dataset_id == dataset.id, | ||||
| DocumentSegment.document_id == document.id, | DocumentSegment.document_id == document.id, | ||||
| DocumentSegment.enabled == False, | DocumentSegment.enabled == False, | ||||
| ) | ) | ||||
| .all() | |||||
| ) | |||||
| ).all() | |||||
| if not segments: | if not segments: | ||||
| return | return | ||||
| real_deal_segment_ids = [] | real_deal_segment_ids = [] | ||||
| enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id) | enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id) | ||||
| elif action == "disable": | elif action == "disable": | ||||
| segments = ( | |||||
| db.session.query(DocumentSegment) | |||||
| .where( | |||||
| segments = db.session.scalars( | |||||
| select(DocumentSegment).where( | |||||
| DocumentSegment.id.in_(segment_ids), | DocumentSegment.id.in_(segment_ids), | ||||
| DocumentSegment.dataset_id == dataset.id, | DocumentSegment.dataset_id == dataset.id, | ||||
| DocumentSegment.document_id == document.id, | DocumentSegment.document_id == document.id, | ||||
| DocumentSegment.enabled == True, | DocumentSegment.enabled == True, | ||||
| ) | ) | ||||
| .all() | |||||
| ) | |||||
| ).all() | |||||
| if not segments: | if not segments: | ||||
| return | return | ||||
| real_deal_segment_ids = [] | real_deal_segment_ids = [] | ||||
| dataset: Dataset, | dataset: Dataset, | ||||
| ) -> list[ChildChunk]: | ) -> list[ChildChunk]: | ||||
| assert isinstance(current_user, Account) | assert isinstance(current_user, Account) | ||||
| child_chunks = ( | |||||
| db.session.query(ChildChunk) | |||||
| .where( | |||||
| child_chunks = db.session.scalars( | |||||
| select(ChildChunk).where( | |||||
| ChildChunk.dataset_id == dataset.id, | ChildChunk.dataset_id == dataset.id, | ||||
| ChildChunk.document_id == document.id, | ChildChunk.document_id == document.id, | ||||
| ChildChunk.segment_id == segment.id, | ChildChunk.segment_id == segment.id, | ||||
| ) | ) | ||||
| .all() | |||||
| ) | |||||
| ).all() | |||||
| child_chunks_map = {chunk.id: chunk for chunk in child_chunks} | child_chunks_map = {chunk.id: chunk for chunk in child_chunks} | ||||
| new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], [] | new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], [] | ||||
| class DatasetPermissionService: | class DatasetPermissionService: | ||||
| @classmethod | @classmethod | ||||
| def get_dataset_partial_member_list(cls, dataset_id): | def get_dataset_partial_member_list(cls, dataset_id): | ||||
| user_list_query = ( | |||||
| db.session.query( | |||||
| user_list_query = db.session.scalars( | |||||
| select( | |||||
| DatasetPermission.account_id, | DatasetPermission.account_id, | ||||
| ) | |||||
| .where(DatasetPermission.dataset_id == dataset_id) | |||||
| .all() | |||||
| ) | |||||
| ).where(DatasetPermission.dataset_id == dataset_id) | |||||
| ).all() | |||||
| user_list = [] | user_list = [] | ||||
| for user in user_list_query: | for user in user_list_query: |
| from json import JSONDecodeError | from json import JSONDecodeError | ||||
| from typing import Optional, Union | from typing import Optional, Union | ||||
| from sqlalchemy import or_ | |||||
| from sqlalchemy import or_, select | |||||
| from constants import HIDDEN_VALUE | from constants import HIDDEN_VALUE | ||||
| from core.entities.provider_configuration import ProviderConfiguration | from core.entities.provider_configuration import ProviderConfiguration | ||||
| if not isinstance(configs, list): | if not isinstance(configs, list): | ||||
| raise ValueError("Invalid load balancing configs") | raise ValueError("Invalid load balancing configs") | ||||
| current_load_balancing_configs = ( | |||||
| db.session.query(LoadBalancingModelConfig) | |||||
| .where( | |||||
| current_load_balancing_configs = db.session.scalars( | |||||
| select(LoadBalancingModelConfig).where( | |||||
| LoadBalancingModelConfig.tenant_id == tenant_id, | LoadBalancingModelConfig.tenant_id == tenant_id, | ||||
| LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, | LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, | ||||
| LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), | LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), | ||||
| LoadBalancingModelConfig.model_name == model, | LoadBalancingModelConfig.model_name == model, | ||||
| ) | ) | ||||
| .all() | |||||
| ) | |||||
| ).all() | |||||
| # id as key, config as value | # id as key, config as value | ||||
| current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs} | current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs} |
| from typing import Optional | from typing import Optional | ||||
| from sqlalchemy import select | |||||
| from constants.languages import languages | from constants.languages import languages | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.model import App, RecommendedApp | from models.model import App, RecommendedApp | ||||
| :param language: language | :param language: language | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| recommended_apps = ( | |||||
| db.session.query(RecommendedApp) | |||||
| .where(RecommendedApp.is_listed == True, RecommendedApp.language == language) | |||||
| .all() | |||||
| ) | |||||
| recommended_apps = db.session.scalars( | |||||
| select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == language) | |||||
| ).all() | |||||
| if len(recommended_apps) == 0: | if len(recommended_apps) == 0: | ||||
| recommended_apps = ( | |||||
| db.session.query(RecommendedApp) | |||||
| .where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) | |||||
| .all() | |||||
| ) | |||||
| recommended_apps = db.session.scalars( | |||||
| select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) | |||||
| ).all() | |||||
| categories = set() | categories = set() | ||||
| recommended_apps_result = [] | recommended_apps_result = [] |
| from typing import Optional | from typing import Optional | ||||
| from flask_login import current_user | from flask_login import current_user | ||||
| from sqlalchemy import func | |||||
| from sqlalchemy import func, select | |||||
| from werkzeug.exceptions import NotFound | from werkzeug.exceptions import NotFound | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| # Check if tag_ids is not empty to avoid WHERE false condition | # Check if tag_ids is not empty to avoid WHERE false condition | ||||
| if not tag_ids or len(tag_ids) == 0: | if not tag_ids or len(tag_ids) == 0: | ||||
| return [] | return [] | ||||
| tags = ( | |||||
| db.session.query(Tag) | |||||
| .where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) | |||||
| .all() | |||||
| ) | |||||
| tags = db.session.scalars( | |||||
| select(Tag).where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) | |||||
| ).all() | |||||
| if not tags: | if not tags: | ||||
| return [] | return [] | ||||
| tag_ids = [tag.id for tag in tags] | tag_ids = [tag.id for tag in tags] | ||||
| # Check if tag_ids is not empty to avoid WHERE false condition | # Check if tag_ids is not empty to avoid WHERE false condition | ||||
| if not tag_ids or len(tag_ids) == 0: | if not tag_ids or len(tag_ids) == 0: | ||||
| return [] | return [] | ||||
| tag_bindings = ( | |||||
| db.session.query(TagBinding.target_id) | |||||
| .where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id) | |||||
| .all() | |||||
| ) | |||||
| if not tag_bindings: | |||||
| return [] | |||||
| results = [tag_binding.target_id for tag_binding in tag_bindings] | |||||
| return results | |||||
| tag_bindings = db.session.scalars( | |||||
| select(TagBinding.target_id).where( | |||||
| TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id | |||||
| ) | |||||
| ).all() | |||||
| return tag_bindings | |||||
| @staticmethod | @staticmethod | ||||
| def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str): | def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str): | ||||
| if not tag_type or not tag_name: | if not tag_type or not tag_name: | ||||
| return [] | return [] | ||||
| tags = ( | |||||
| db.session.query(Tag) | |||||
| .where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type) | |||||
| .all() | |||||
| tags = list( | |||||
| db.session.scalars( | |||||
| select(Tag).where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type) | |||||
| ).all() | |||||
| ) | ) | ||||
| if not tags: | if not tags: | ||||
| return [] | return [] | ||||
| raise NotFound("Tag not found") | raise NotFound("Tag not found") | ||||
| db.session.delete(tag) | db.session.delete(tag) | ||||
| # delete tag binding | # delete tag binding | ||||
| tag_bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).all() | |||||
| tag_bindings = db.session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all() | |||||
| if tag_bindings: | if tag_bindings: | ||||
| for tag_binding in tag_bindings: | for tag_binding in tag_bindings: | ||||
| db.session.delete(tag_binding) | db.session.delete(tag_binding) |
| from typing import Any, cast | from typing import Any, cast | ||||
| from httpx import get | from httpx import get | ||||
| from sqlalchemy import select | |||||
| from core.entities.provider_entities import ProviderConfig | from core.entities.provider_entities import ProviderConfig | ||||
| from core.model_runtime.utils.encoders import jsonable_encoder | from core.model_runtime.utils.encoders import jsonable_encoder | ||||
| list api tools | list api tools | ||||
| """ | """ | ||||
| # get all api providers | # get all api providers | ||||
| db_providers: list[ApiToolProvider] = ( | |||||
| db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or [] | |||||
| ) | |||||
| db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all() | |||||
| result: list[ToolProviderApiEntity] = [] | result: list[ToolProviderApiEntity] = [] | ||||
| from datetime import datetime | from datetime import datetime | ||||
| from typing import Any | from typing import Any | ||||
| from sqlalchemy import or_ | |||||
| from sqlalchemy import or_, select | |||||
| from core.model_runtime.utils.encoders import jsonable_encoder | from core.model_runtime.utils.encoders import jsonable_encoder | ||||
| from core.tools.__base.tool_provider import ToolProviderController | from core.tools.__base.tool_provider import ToolProviderController | ||||
| :param tenant_id: the tenant id | :param tenant_id: the tenant id | ||||
| :return: the list of tools | :return: the list of tools | ||||
| """ | """ | ||||
| db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all() | |||||
| db_tools = db.session.scalars( | |||||
| select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id) | |||||
| ).all() | |||||
| tools: list[WorkflowToolProviderController] = [] | tools: list[WorkflowToolProviderController] = [] | ||||
| for provider in db_tools: | for provider in db_tools: |
| import click | import click | ||||
| from celery import shared_task | from celery import shared_task | ||||
| from sqlalchemy import select | |||||
| from core.rag.datasource.vdb.vector_factory import Vector | from core.rag.datasource.vdb.vector_factory import Vector | ||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| db.session.close() | db.session.close() | ||||
| return | return | ||||
| annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).all() | |||||
| annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all() | |||||
| enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}" | enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}" | ||||
| enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}" | enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}" | ||||
| import click | import click | ||||
| from celery import shared_task | from celery import shared_task | ||||
| from sqlalchemy import select | |||||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | ||||
| from core.tools.utils.web_reader_tool import get_image_upload_file_ids | from core.tools.utils.web_reader_tool import get_image_upload_file_ids | ||||
| if not dataset: | if not dataset: | ||||
| raise Exception("Document has no dataset") | raise Exception("Document has no dataset") | ||||
| segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)).all() | |||||
| segments = db.session.scalars( | |||||
| select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) | |||||
| ).all() | |||||
| # check segment is exist | # check segment is exist | ||||
| if segments: | if segments: | ||||
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| db.session.commit() | db.session.commit() | ||||
| if file_ids: | if file_ids: | ||||
| files = db.session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all() | |||||
| files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all() | |||||
| for file in files: | for file in files: | ||||
| try: | try: | ||||
| storage.delete(file.key) | storage.delete(file.key) |
| import click | import click | ||||
| from celery import shared_task | from celery import shared_task | ||||
| from sqlalchemy import select | |||||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | ||||
| from core.tools.utils.web_reader_tool import get_image_upload_file_ids | from core.tools.utils.web_reader_tool import get_image_upload_file_ids | ||||
| index_struct=index_struct, | index_struct=index_struct, | ||||
| collection_binding_id=collection_binding_id, | collection_binding_id=collection_binding_id, | ||||
| ) | ) | ||||
| documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all() | |||||
| segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all() | |||||
| documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all() | |||||
| segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all() | |||||
| # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace | # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace | ||||
| # This ensures all invalid doc_form values are properly handled | # This ensures all invalid doc_form values are properly handled |
| import click | import click | ||||
| from celery import shared_task | from celery import shared_task | ||||
| from sqlalchemy import select | |||||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | ||||
| from core.tools.utils.web_reader_tool import get_image_upload_file_ids | from core.tools.utils.web_reader_tool import get_image_upload_file_ids | ||||
| if not dataset: | if not dataset: | ||||
| raise Exception("Document has no dataset") | raise Exception("Document has no dataset") | ||||
| segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() | |||||
| segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() | |||||
| # check segment is exist | # check segment is exist | ||||
| if segments: | if segments: | ||||
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] |
| import click | import click | ||||
| from celery import shared_task | from celery import shared_task | ||||
| from sqlalchemy import select | |||||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| document = db.session.query(Document).where(Document.id == document_id).first() | document = db.session.query(Document).where(Document.id == document_id).first() | ||||
| db.session.delete(document) | db.session.delete(document) | ||||
| segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() | |||||
| segments = db.session.scalars( | |||||
| select(DocumentSegment).where(DocumentSegment.document_id == document_id) | |||||
| ).all() | |||||
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) | index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) |
| import click | import click | ||||
| from celery import shared_task | from celery import shared_task | ||||
| from sqlalchemy import select | |||||
| from core.rag.index_processor.constant.index_type import IndexType | from core.rag.index_processor.constant.index_type import IndexType | ||||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | ||||
| if action == "remove": | if action == "remove": | ||||
| index_processor.clean(dataset, None, with_keywords=False) | index_processor.clean(dataset, None, with_keywords=False) | ||||
| elif action == "add": | elif action == "add": | ||||
| dataset_documents = ( | |||||
| db.session.query(DatasetDocument) | |||||
| .where( | |||||
| dataset_documents = db.session.scalars( | |||||
| select(DatasetDocument).where( | |||||
| DatasetDocument.dataset_id == dataset_id, | DatasetDocument.dataset_id == dataset_id, | ||||
| DatasetDocument.indexing_status == "completed", | DatasetDocument.indexing_status == "completed", | ||||
| DatasetDocument.enabled == True, | DatasetDocument.enabled == True, | ||||
| DatasetDocument.archived == False, | DatasetDocument.archived == False, | ||||
| ) | ) | ||||
| .all() | |||||
| ) | |||||
| ).all() | |||||
| if dataset_documents: | if dataset_documents: | ||||
| dataset_documents_ids = [doc.id for doc in dataset_documents] | dataset_documents_ids = [doc.id for doc in dataset_documents] | ||||
| ) | ) | ||||
| db.session.commit() | db.session.commit() | ||||
| elif action == "update": | elif action == "update": | ||||
| dataset_documents = ( | |||||
| db.session.query(DatasetDocument) | |||||
| .where( | |||||
| dataset_documents = db.session.scalars( | |||||
| select(DatasetDocument).where( | |||||
| DatasetDocument.dataset_id == dataset_id, | DatasetDocument.dataset_id == dataset_id, | ||||
| DatasetDocument.indexing_status == "completed", | DatasetDocument.indexing_status == "completed", | ||||
| DatasetDocument.enabled == True, | DatasetDocument.enabled == True, | ||||
| DatasetDocument.archived == False, | DatasetDocument.archived == False, | ||||
| ) | ) | ||||
| .all() | |||||
| ) | |||||
| ).all() | |||||
| # add new index | # add new index | ||||
| if dataset_documents: | if dataset_documents: | ||||
| # update document status | # update document status |
| import click | import click | ||||
| from celery import shared_task | from celery import shared_task | ||||
| from sqlalchemy import select | |||||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| # sync index processor | # sync index processor | ||||
| index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() | index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() | ||||
| segments = ( | |||||
| db.session.query(DocumentSegment) | |||||
| .where( | |||||
| segments = db.session.scalars( | |||||
| select(DocumentSegment).where( | |||||
| DocumentSegment.id.in_(segment_ids), | DocumentSegment.id.in_(segment_ids), | ||||
| DocumentSegment.dataset_id == dataset_id, | DocumentSegment.dataset_id == dataset_id, | ||||
| DocumentSegment.document_id == document_id, | DocumentSegment.document_id == document_id, | ||||
| ) | ) | ||||
| .all() | |||||
| ) | |||||
| ).all() | |||||
| if not segments: | if not segments: | ||||
| db.session.close() | db.session.close() |
| import click | import click | ||||
| from celery import shared_task | from celery import shared_task | ||||
| from sqlalchemy import select | |||||
| from core.indexing_runner import DocumentIsPausedError, IndexingRunner | from core.indexing_runner import DocumentIsPausedError, IndexingRunner | ||||
| from core.rag.extractor.notion_extractor import NotionExtractor | from core.rag.extractor.notion_extractor import NotionExtractor | ||||
| index_type = document.doc_form | index_type = document.doc_form | ||||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | index_processor = IndexProcessorFactory(index_type).init_index_processor() | ||||
| segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() | |||||
| segments = db.session.scalars( | |||||
| select(DocumentSegment).where(DocumentSegment.document_id == document_id) | |||||
| ).all() | |||||
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| # delete from vector index | # delete from vector index |
| import click | import click | ||||
| from celery import shared_task | from celery import shared_task | ||||
| from sqlalchemy import select | |||||
| from core.indexing_runner import DocumentIsPausedError, IndexingRunner | from core.indexing_runner import DocumentIsPausedError, IndexingRunner | ||||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | ||||
| index_type = document.doc_form | index_type = document.doc_form | ||||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | index_processor = IndexProcessorFactory(index_type).init_index_processor() | ||||
| segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() | |||||
| segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() | |||||
| if segments: | if segments: | ||||
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| import click | import click | ||||
| from celery import shared_task | from celery import shared_task | ||||
| from sqlalchemy import select | |||||
| from configs import dify_config | from configs import dify_config | ||||
| from core.indexing_runner import DocumentIsPausedError, IndexingRunner | from core.indexing_runner import DocumentIsPausedError, IndexingRunner | ||||
| index_type = document.doc_form | index_type = document.doc_form | ||||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | index_processor = IndexProcessorFactory(index_type).init_index_processor() | ||||
| segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() | |||||
| segments = db.session.scalars( | |||||
| select(DocumentSegment).where(DocumentSegment.document_id == document_id) | |||||
| ).all() | |||||
| if segments: | if segments: | ||||
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| import click | import click | ||||
| from celery import shared_task | from celery import shared_task | ||||
| from sqlalchemy import select | |||||
| from core.rag.index_processor.constant.index_type import IndexType | from core.rag.index_processor.constant.index_type import IndexType | ||||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | ||||
| # sync index processor | # sync index processor | ||||
| index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() | index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() | ||||
| segments = ( | |||||
| db.session.query(DocumentSegment) | |||||
| .where( | |||||
| segments = db.session.scalars( | |||||
| select(DocumentSegment).where( | |||||
| DocumentSegment.id.in_(segment_ids), | DocumentSegment.id.in_(segment_ids), | ||||
| DocumentSegment.dataset_id == dataset_id, | DocumentSegment.dataset_id == dataset_id, | ||||
| DocumentSegment.document_id == document_id, | DocumentSegment.document_id == document_id, | ||||
| ) | ) | ||||
| .all() | |||||
| ) | |||||
| ).all() | |||||
| if not segments: | if not segments: | ||||
| logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan")) | logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan")) | ||||
| db.session.close() | db.session.close() |
| import click | import click | ||||
| from celery import shared_task | from celery import shared_task | ||||
| from sqlalchemy import select | |||||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() | index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() | ||||
| segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).all() | |||||
| segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all() | |||||
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| if index_node_ids: | if index_node_ids: | ||||
| try: | try: |
| import click | import click | ||||
| from celery import shared_task | from celery import shared_task | ||||
| from sqlalchemy import select | |||||
| from core.indexing_runner import IndexingRunner | from core.indexing_runner import IndexingRunner | ||||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | ||||
| # clean old data | # clean old data | ||||
| index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() | index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() | ||||
| segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() | |||||
| segments = db.session.scalars( | |||||
| select(DocumentSegment).where(DocumentSegment.document_id == document_id) | |||||
| ).all() | |||||
| if segments: | if segments: | ||||
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| # delete from vector index | # delete from vector index |
| import click | import click | ||||
| from celery import shared_task | from celery import shared_task | ||||
| from sqlalchemy import select | |||||
| from core.indexing_runner import IndexingRunner | from core.indexing_runner import IndexingRunner | ||||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | ||||
| # clean old data | # clean old data | ||||
| index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() | index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() | ||||
| segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() | |||||
| segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() | |||||
| if segments: | if segments: | ||||
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| # delete from vector index | # delete from vector index |
| import pytest | import pytest | ||||
| from faker import Faker | from faker import Faker | ||||
| from sqlalchemy import select | |||||
| from models.account import TenantAccountJoin, TenantAccountRole | from models.account import TenantAccountJoin, TenantAccountRole | ||||
| from models.model import Account, Tenant | from models.model import Account, Tenant | ||||
| assert load_balancing_config.id is not None | assert load_balancing_config.id is not None | ||||
| # Verify inherit config was created in database | # Verify inherit config was created in database | ||||
| inherit_configs = ( | |||||
| db.session.query(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__").all() | |||||
| ) | |||||
| inherit_configs = db.session.scalars( | |||||
| select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__") | |||||
| ).all() | |||||
| assert len(inherit_configs) == 1 | assert len(inherit_configs) == 1 |
| import pytest | import pytest | ||||
| from faker import Faker | from faker import Faker | ||||
| from sqlalchemy import select | |||||
| from werkzeug.exceptions import NotFound | from werkzeug.exceptions import NotFound | ||||
| from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole | from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| # Verify only one binding exists | # Verify only one binding exists | ||||
| bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all() | |||||
| bindings = db.session.scalars( | |||||
| select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id) | |||||
| ).all() | |||||
| assert len(bindings) == 1 | assert len(bindings) == 1 | ||||
| def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies): | def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies): | ||||
| # No error should be raised, and database state should remain unchanged | # No error should be raised, and database state should remain unchanged | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all() | |||||
| bindings = db.session.scalars( | |||||
| select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id) | |||||
| ).all() | |||||
| assert len(bindings) == 0 | assert len(bindings) == 0 | ||||
| def test_check_target_exists_knowledge_success( | def test_check_target_exists_knowledge_success( |
| import pytest | import pytest | ||||
| from faker import Faker | from faker import Faker | ||||
| from sqlalchemy import select | |||||
| from core.app.entities.app_invoke_entities import InvokeFrom | from core.app.entities.app_invoke_entities import InvokeFrom | ||||
| from models.account import Account | from models.account import Account | ||||
| # Verify only one pinned conversation record exists | # Verify only one pinned conversation record exists | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| pinned_conversations = ( | |||||
| db.session.query(PinnedConversation) | |||||
| .where( | |||||
| pinned_conversations = db.session.scalars( | |||||
| select(PinnedConversation).where( | |||||
| PinnedConversation.app_id == app.id, | PinnedConversation.app_id == app.id, | ||||
| PinnedConversation.conversation_id == conversation.id, | PinnedConversation.conversation_id == conversation.id, | ||||
| PinnedConversation.created_by_role == "account", | PinnedConversation.created_by_role == "account", | ||||
| PinnedConversation.created_by == account.id, | PinnedConversation.created_by == account.id, | ||||
| ) | ) | ||||
| .all() | |||||
| ) | |||||
| ).all() | |||||
| assert len(pinned_conversations) == 1 | assert len(pinned_conversations) == 1 | ||||
| mock_binding.provider = self.provider | mock_binding.provider = self.provider | ||||
| mock_binding.disabled = False | mock_binding.disabled = False | ||||
| mock_session.query.return_value.where.return_value.all.return_value = [mock_binding] | |||||
| mock_session.scalars.return_value.all.return_value = [mock_binding] | |||||
| result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id) | result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id) | ||||
| assert len(result) == 1 | assert len(result) == 1 | ||||
| assert result[0].tenant_id == self.tenant_id | assert result[0].tenant_id == self.tenant_id | ||||
| mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding) | |||||
| assert mock_session.scalars.call_count == 1 | |||||
| select_arg = mock_session.scalars.call_args[0][0] | |||||
| assert "data_source_api_key_auth_binding" in str(select_arg).lower() | |||||
| @patch("services.auth.api_key_auth_service.db.session") | @patch("services.auth.api_key_auth_service.db.session") | ||||
| def test_get_provider_auth_list_empty(self, mock_session): | def test_get_provider_auth_list_empty(self, mock_session): | ||||
| """Test get provider auth list - empty result""" | """Test get provider auth list - empty result""" | ||||
| mock_session.query.return_value.where.return_value.all.return_value = [] | |||||
| mock_session.scalars.return_value.all.return_value = [] | |||||
| result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id) | result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id) | ||||
| @patch("services.auth.api_key_auth_service.db.session") | @patch("services.auth.api_key_auth_service.db.session") | ||||
| def test_get_provider_auth_list_filters_disabled(self, mock_session): | def test_get_provider_auth_list_filters_disabled(self, mock_session): | ||||
| """Test get provider auth list - filters disabled items""" | """Test get provider auth list - filters disabled items""" | ||||
| mock_session.query.return_value.where.return_value.all.return_value = [] | |||||
| mock_session.scalars.return_value.all.return_value = [] | |||||
| ApiKeyAuthService.get_provider_auth_list(self.tenant_id) | ApiKeyAuthService.get_provider_auth_list(self.tenant_id) | ||||
| # Verify where conditions include disabled.is_(False) | |||||
| where_call = mock_session.query.return_value.where.call_args[0] | |||||
| assert len(where_call) == 2 # tenant_id and disabled filter conditions | |||||
| select_stmt = mock_session.scalars.call_args[0][0] | |||||
| where_clauses = list(getattr(select_stmt, "_where_criteria", []) or []) | |||||
| # Ensure both tenant filter and disabled filter exist | |||||
| where_strs = [str(c).lower() for c in where_clauses] | |||||
| assert any("tenant_id" in s for s in where_strs) | |||||
| assert any("disabled" in s for s in where_strs) | |||||
| @patch("services.auth.api_key_auth_service.db.session") | @patch("services.auth.api_key_auth_service.db.session") | ||||
| @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") | @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") |
| tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials) | tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials) | ||||
| tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials) | tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials) | ||||
| mock_session.query.return_value.where.return_value.all.return_value = [tenant1_binding] | |||||
| mock_session.scalars.return_value.all.return_value = [tenant1_binding] | |||||
| result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1) | result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1) | ||||
| mock_session.query.return_value.where.return_value.all.return_value = [tenant2_binding] | |||||
| mock_session.scalars.return_value.all.return_value = [tenant2_binding] | |||||
| result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2) | result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2) | ||||
| assert len(result1) == 1 | assert len(result1) == 1 |