Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>tags/2.0.0-beta.2^2
| @@ -1,9 +1,9 @@ | |||
| from flask_login import current_user | |||
| from flask_restx import Resource, reqparse | |||
| from controllers.console import api | |||
| from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required | |||
| from libs.login import login_required | |||
| from libs.login import current_user, login_required | |||
| from models.model import Account | |||
| from services.billing_service import BillingService | |||
| @@ -17,9 +17,10 @@ class Subscription(Resource): | |||
| parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"]) | |||
| parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) | |||
| args = parser.parse_args() | |||
| assert isinstance(current_user, Account) | |||
| BillingService.is_tenant_owner_or_admin(current_user) | |||
| assert current_user.current_tenant_id is not None | |||
| return BillingService.get_subscription( | |||
| args["plan"], args["interval"], current_user.email, current_user.current_tenant_id | |||
| ) | |||
| @@ -31,7 +32,9 @@ class Invoices(Resource): | |||
| @account_initialization_required | |||
| @only_edition_cloud | |||
| def get(self): | |||
| assert isinstance(current_user, Account) | |||
| BillingService.is_tenant_owner_or_admin(current_user) | |||
| assert current_user.current_tenant_id is not None | |||
| return BillingService.get_invoices(current_user.email, current_user.current_tenant_id) | |||
| @@ -2,7 +2,6 @@ import threading | |||
| from typing import Any, Optional | |||
| import pytz | |||
| from flask_login import current_user | |||
| import contexts | |||
| from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager | |||
| @@ -10,6 +9,7 @@ from core.plugin.impl.agent import PluginAgentClient | |||
| from core.plugin.impl.exc import PluginDaemonClientSideError | |||
| from core.tools.tool_manager import ToolManager | |||
| from extensions.ext_database import db | |||
| from libs.login import current_user | |||
| from models.account import Account | |||
| from models.model import App, Conversation, EndUser, Message, MessageAgentThought | |||
| @@ -61,7 +61,8 @@ class AgentService: | |||
| executor = executor.name | |||
| else: | |||
| executor = "Unknown" | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.timezone is not None | |||
| timezone = pytz.timezone(current_user.timezone) | |||
| app_model_config = app_model.app_model_config | |||
| @@ -2,7 +2,6 @@ import uuid | |||
| from typing import Optional | |||
| import pandas as pd | |||
| from flask_login import current_user | |||
| from sqlalchemy import or_, select | |||
| from werkzeug.datastructures import FileStorage | |||
| from werkzeug.exceptions import NotFound | |||
| @@ -10,6 +9,8 @@ from werkzeug.exceptions import NotFound | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from libs.datetime_utils import naive_utc_now | |||
| from libs.login import current_user | |||
| from models.account import Account | |||
| from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation | |||
| from services.feature_service import FeatureService | |||
| from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task | |||
| @@ -24,6 +25,7 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: | |||
| # get app info | |||
| assert isinstance(current_user, Account) | |||
| app = ( | |||
| db.session.query(App) | |||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| @@ -62,6 +64,7 @@ class AppAnnotationService: | |||
| db.session.commit() | |||
| # if annotation reply is enabled , add annotation to index | |||
| annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() | |||
| assert current_user.current_tenant_id is not None | |||
| if annotation_setting: | |||
| add_annotation_to_index_task.delay( | |||
| annotation.id, | |||
| @@ -84,6 +87,8 @@ class AppAnnotationService: | |||
| enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}" | |||
| # send batch add segments task | |||
| redis_client.setnx(enable_app_annotation_job_key, "waiting") | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| enable_annotation_reply_task.delay( | |||
| str(job_id), | |||
| app_id, | |||
| @@ -97,6 +102,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def disable_app_annotation(cls, app_id: str): | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}" | |||
| cache_result = redis_client.get(disable_app_annotation_key) | |||
| if cache_result is not None: | |||
| @@ -113,6 +120,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str): | |||
| # get app info | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| app = ( | |||
| db.session.query(App) | |||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| @@ -145,6 +154,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def export_annotation_list_by_app_id(cls, app_id: str): | |||
| # get app info | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| app = ( | |||
| db.session.query(App) | |||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| @@ -164,6 +175,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: | |||
| # get app info | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| app = ( | |||
| db.session.query(App) | |||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| @@ -193,6 +206,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): | |||
| # get app info | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| app = ( | |||
| db.session.query(App) | |||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| @@ -230,6 +245,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def delete_app_annotation(cls, app_id: str, annotation_id: str): | |||
| # get app info | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| app = ( | |||
| db.session.query(App) | |||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| @@ -269,6 +286,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]): | |||
| # get app info | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| app = ( | |||
| db.session.query(App) | |||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| @@ -317,6 +336,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def batch_import_app_annotations(cls, app_id, file: FileStorage): | |||
| # get app info | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| app = ( | |||
| db.session.query(App) | |||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| @@ -355,6 +376,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| # get app info | |||
| app = ( | |||
| db.session.query(App) | |||
| @@ -425,6 +448,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def get_app_annotation_setting_by_app_id(cls, app_id: str): | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| # get app info | |||
| app = ( | |||
| db.session.query(App) | |||
| @@ -451,6 +476,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict): | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| # get app info | |||
| app = ( | |||
| db.session.query(App) | |||
| @@ -491,6 +518,8 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def clear_all_annotations(cls, app_id: str): | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| app = ( | |||
| db.session.query(App) | |||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| @@ -2,7 +2,6 @@ import json | |||
| import logging | |||
| from typing import Optional, TypedDict, cast | |||
| from flask_login import current_user | |||
| from flask_sqlalchemy.pagination import Pagination | |||
| from configs import dify_config | |||
| @@ -17,6 +16,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager | |||
| from events.app_event import app_was_created | |||
| from extensions.ext_database import db | |||
| from libs.datetime_utils import naive_utc_now | |||
| from libs.login import current_user | |||
| from models.account import Account | |||
| from models.model import App, AppMode, AppModelConfig, Site | |||
| from models.tools import ApiToolProvider | |||
| @@ -168,6 +168,8 @@ class AppService: | |||
| """ | |||
| Get App | |||
| """ | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| # get original app model config | |||
| if app.mode == AppMode.AGENT_CHAT.value or app.is_agent: | |||
| model_config = app.app_model_config | |||
| @@ -242,6 +244,7 @@ class AppService: | |||
| :param args: request args | |||
| :return: App instance | |||
| """ | |||
| assert current_user is not None | |||
| app.name = args["name"] | |||
| app.description = args["description"] | |||
| app.icon_type = args["icon_type"] | |||
| @@ -262,6 +265,7 @@ class AppService: | |||
| :param name: new name | |||
| :return: App instance | |||
| """ | |||
| assert current_user is not None | |||
| app.name = name | |||
| app.updated_by = current_user.id | |||
| app.updated_at = naive_utc_now() | |||
| @@ -277,6 +281,7 @@ class AppService: | |||
| :param icon_background: new icon_background | |||
| :return: App instance | |||
| """ | |||
| assert current_user is not None | |||
| app.icon = icon | |||
| app.icon_background = icon_background | |||
| app.updated_by = current_user.id | |||
| @@ -294,7 +299,7 @@ class AppService: | |||
| """ | |||
| if enable_site == app.enable_site: | |||
| return app | |||
| assert current_user is not None | |||
| app.enable_site = enable_site | |||
| app.updated_by = current_user.id | |||
| app.updated_at = naive_utc_now() | |||
| @@ -311,6 +316,7 @@ class AppService: | |||
| """ | |||
| if enable_api == app.enable_api: | |||
| return app | |||
| assert current_user is not None | |||
| app.enable_api = enable_api | |||
| app.updated_by = current_user.id | |||
| @@ -70,7 +70,7 @@ class BillingService: | |||
| return response.json() | |||
| @staticmethod | |||
| def is_tenant_owner_or_admin(current_user): | |||
| def is_tenant_owner_or_admin(current_user: Account): | |||
| tenant_id = current_user.current_tenant_id | |||
| join: Optional[TenantAccountJoin] = ( | |||
| @@ -8,7 +8,7 @@ import uuid | |||
| from collections import Counter | |||
| from typing import Any, Literal, Optional | |||
| from flask_login import current_user | |||
| import sqlalchemy as sa | |||
| from sqlalchemy import exists, func, select | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import NotFound | |||
| @@ -27,6 +27,7 @@ from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from libs import helper | |||
| from libs.datetime_utils import naive_utc_now | |||
| from libs.login import current_user | |||
| from models.account import Account, TenantAccountRole | |||
| from models.dataset import ( | |||
| AppDatasetJoin, | |||
| @@ -498,8 +499,11 @@ class DatasetService: | |||
| data: Update data dictionary | |||
| filtered_data: Filtered update data to modify | |||
| """ | |||
| # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None | |||
| try: | |||
| model_manager = ModelManager() | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| embedding_model = model_manager.get_model_instance( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=data["embedding_model_provider"], | |||
| @@ -611,8 +615,12 @@ class DatasetService: | |||
| data: Update data dictionary | |||
| filtered_data: Filtered update data to modify | |||
| """ | |||
| # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None | |||
| model_manager = ModelManager() | |||
| try: | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| embedding_model = model_manager.get_model_instance( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=data["embedding_model_provider"], | |||
| @@ -720,6 +728,8 @@ class DatasetService: | |||
| @staticmethod | |||
| def get_dataset_auto_disable_logs(dataset_id: str): | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if not features.billing.enabled or features.billing.subscription.plan == "sandbox": | |||
| return { | |||
| @@ -924,6 +934,8 @@ class DocumentService: | |||
| @staticmethod | |||
| def get_batch_documents(dataset_id: str, batch: str) -> list[Document]: | |||
| assert isinstance(current_user, Account) | |||
| documents = ( | |||
| db.session.query(Document) | |||
| .where( | |||
| @@ -983,6 +995,8 @@ class DocumentService: | |||
| @staticmethod | |||
| def rename_document(dataset_id: str, document_id: str, name: str) -> Document: | |||
| assert isinstance(current_user, Account) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise ValueError("Dataset not found.") | |||
| @@ -1012,6 +1026,7 @@ class DocumentService: | |||
| if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}: | |||
| raise DocumentIndexingError() | |||
| # update document to be paused | |||
| assert current_user is not None | |||
| document.is_paused = True | |||
| document.paused_by = current_user.id | |||
| document.paused_at = naive_utc_now() | |||
| @@ -1098,6 +1113,9 @@ class DocumentService: | |||
| # check doc_form | |||
| DatasetService.check_doc_form(dataset, knowledge_config.doc_form) | |||
| # check document limit | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if features.billing.enabled: | |||
| @@ -1434,6 +1452,8 @@ class DocumentService: | |||
| @staticmethod | |||
| def get_tenant_documents_count(): | |||
| assert isinstance(current_user, Account) | |||
| documents_count = ( | |||
| db.session.query(Document) | |||
| .where( | |||
| @@ -1454,6 +1474,8 @@ class DocumentService: | |||
| dataset_process_rule: Optional[DatasetProcessRule] = None, | |||
| created_from: str = "web", | |||
| ): | |||
| assert isinstance(current_user, Account) | |||
| DatasetService.check_dataset_model_setting(dataset) | |||
| document = DocumentService.get_document(dataset.id, document_data.original_document_id) | |||
| if document is None: | |||
| @@ -1513,7 +1535,7 @@ class DocumentService: | |||
| data_source_binding = ( | |||
| db.session.query(DataSourceOauthBinding) | |||
| .where( | |||
| db.and_( | |||
| sa.and_( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.disabled == False, | |||
| @@ -1574,6 +1596,9 @@ class DocumentService: | |||
| @staticmethod | |||
| def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account): | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if features.billing.enabled: | |||
| @@ -2013,6 +2038,9 @@ class SegmentService: | |||
| @classmethod | |||
| def create_segment(cls, args: dict, document: Document, dataset: Dataset): | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| content = args["content"] | |||
| doc_id = str(uuid.uuid4()) | |||
| segment_hash = helper.generate_text_hash(content) | |||
| @@ -2075,6 +2103,9 @@ class SegmentService: | |||
| @classmethod | |||
| def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset): | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| lock_name = f"multi_add_segment_lock_document_id_{document.id}" | |||
| increment_word_count = 0 | |||
| with redis_client.lock(lock_name, timeout=600): | |||
| @@ -2158,6 +2189,9 @@ class SegmentService: | |||
| @classmethod | |||
| def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset): | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| indexing_cache_key = f"segment_{segment.id}_indexing" | |||
| cache_result = redis_client.get(indexing_cache_key) | |||
| if cache_result is not None: | |||
| @@ -2349,6 +2383,7 @@ class SegmentService: | |||
| @classmethod | |||
| def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): | |||
| assert isinstance(current_user, Account) | |||
| segments = ( | |||
| db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count) | |||
| .where( | |||
| @@ -2379,6 +2414,8 @@ class SegmentService: | |||
| def update_segments_status( | |||
| cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document | |||
| ): | |||
| assert current_user is not None | |||
| # Check if segment_ids is not empty to avoid WHERE false condition | |||
| if not segment_ids or len(segment_ids) == 0: | |||
| return | |||
| @@ -2441,6 +2478,8 @@ class SegmentService: | |||
| def create_child_chunk( | |||
| cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset | |||
| ) -> ChildChunk: | |||
| assert isinstance(current_user, Account) | |||
| lock_name = f"add_child_lock_{segment.id}" | |||
| with redis_client.lock(lock_name, timeout=20): | |||
| index_node_id = str(uuid.uuid4()) | |||
| @@ -2488,6 +2527,8 @@ class SegmentService: | |||
| document: Document, | |||
| dataset: Dataset, | |||
| ) -> list[ChildChunk]: | |||
| assert isinstance(current_user, Account) | |||
| child_chunks = ( | |||
| db.session.query(ChildChunk) | |||
| .where( | |||
| @@ -2562,6 +2603,8 @@ class SegmentService: | |||
| document: Document, | |||
| dataset: Dataset, | |||
| ) -> ChildChunk: | |||
| assert current_user is not None | |||
| try: | |||
| child_chunk.content = content | |||
| child_chunk.word_count = len(content) | |||
| @@ -2592,6 +2635,8 @@ class SegmentService: | |||
| def get_child_chunks( | |||
| cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None | |||
| ): | |||
| assert isinstance(current_user, Account) | |||
| query = ( | |||
| select(ChildChunk) | |||
| .filter_by( | |||
| @@ -3,7 +3,6 @@ import os | |||
| import uuid | |||
| from typing import Any, Literal, Union | |||
| from flask_login import current_user | |||
| from werkzeug.exceptions import NotFound | |||
| from configs import dify_config | |||
| @@ -19,6 +18,7 @@ from extensions.ext_database import db | |||
| from extensions.ext_storage import storage | |||
| from libs.datetime_utils import naive_utc_now | |||
| from libs.helper import extract_tenant_id | |||
| from libs.login import current_user | |||
| from models.account import Account | |||
| from models.enums import CreatorUserRole | |||
| from models.model import EndUser, UploadFile | |||
| @@ -111,6 +111,9 @@ class FileService: | |||
| @staticmethod | |||
| def upload_text(text: str, text_name: str) -> UploadFile: | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| if len(text_name) > 200: | |||
| text_name = text_name[:200] | |||
| # user uuid as file name | |||
| @@ -1,10 +1,11 @@ | |||
| import json | |||
| from unittest.mock import MagicMock, patch | |||
| from unittest.mock import MagicMock, create_autospec, patch | |||
| import pytest | |||
| from faker import Faker | |||
| from core.plugin.impl.exc import PluginDaemonClientSideError | |||
| from models.account import Account | |||
| from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought | |||
| from services.account_service import AccountService, TenantService | |||
| from services.agent_service import AgentService | |||
| @@ -21,7 +22,7 @@ class TestAgentService: | |||
| patch("services.agent_service.PluginAgentClient") as mock_plugin_agent_client, | |||
| patch("services.agent_service.ToolManager") as mock_tool_manager, | |||
| patch("services.agent_service.AgentConfigManager") as mock_agent_config_manager, | |||
| patch("services.agent_service.current_user") as mock_current_user, | |||
| patch("services.agent_service.current_user", create_autospec(Account, instance=True)) as mock_current_user, | |||
| patch("services.app_service.FeatureService") as mock_feature_service, | |||
| patch("services.app_service.EnterpriseService") as mock_enterprise_service, | |||
| patch("services.app_service.ModelManager") as mock_model_manager, | |||
| @@ -1,9 +1,10 @@ | |||
| from unittest.mock import patch | |||
| from unittest.mock import create_autospec, patch | |||
| import pytest | |||
| from faker import Faker | |||
| from werkzeug.exceptions import NotFound | |||
| from models.account import Account | |||
| from models.model import MessageAnnotation | |||
| from services.annotation_service import AppAnnotationService | |||
| from services.app_service import AppService | |||
| @@ -24,7 +25,9 @@ class TestAnnotationService: | |||
| patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task, | |||
| patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task, | |||
| patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task, | |||
| patch("services.annotation_service.current_user") as mock_current_user, | |||
| patch( | |||
| "services.annotation_service.current_user", create_autospec(Account, instance=True) | |||
| ) as mock_current_user, | |||
| ): | |||
| # Setup default mock returns | |||
| mock_account_feature_service.get_features.return_value.billing.enabled = False | |||
| @@ -1,9 +1,10 @@ | |||
| from unittest.mock import patch | |||
| from unittest.mock import create_autospec, patch | |||
| import pytest | |||
| from faker import Faker | |||
| from constants.model_template import default_app_templates | |||
| from models.account import Account | |||
| from models.model import App, Site | |||
| from services.account_service import AccountService, TenantService | |||
| from services.app_service import AppService | |||
| @@ -161,8 +162,13 @@ class TestAppService: | |||
| app_service = AppService() | |||
| created_app = app_service.create_app(tenant.id, app_args, account) | |||
| # Get app using the service | |||
| retrieved_app = app_service.get_app(created_app) | |||
| # Get app using the service - needs current_user mock | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.id = account.id | |||
| mock_current_user.current_tenant_id = account.current_tenant_id | |||
| with patch("services.app_service.current_user", mock_current_user): | |||
| retrieved_app = app_service.get_app(created_app) | |||
| # Verify retrieved app matches created app | |||
| assert retrieved_app.id == created_app.id | |||
| @@ -406,7 +412,11 @@ class TestAppService: | |||
| "use_icon_as_answer_icon": True, | |||
| } | |||
| with patch("flask_login.utils._get_user", return_value=account): | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.id = account.id | |||
| mock_current_user.current_tenant_id = account.current_tenant_id | |||
| with patch("services.app_service.current_user", mock_current_user): | |||
| updated_app = app_service.update_app(app, update_args) | |||
| # Verify updated fields | |||
| @@ -456,7 +466,11 @@ class TestAppService: | |||
| # Update app name | |||
| new_name = "New App Name" | |||
| with patch("flask_login.utils._get_user", return_value=account): | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.id = account.id | |||
| mock_current_user.current_tenant_id = account.current_tenant_id | |||
| with patch("services.app_service.current_user", mock_current_user): | |||
| updated_app = app_service.update_app_name(app, new_name) | |||
| assert updated_app.name == new_name | |||
| @@ -504,7 +518,11 @@ class TestAppService: | |||
| # Update app icon | |||
| new_icon = "🌟" | |||
| new_icon_background = "#FFD93D" | |||
| with patch("flask_login.utils._get_user", return_value=account): | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.id = account.id | |||
| mock_current_user.current_tenant_id = account.current_tenant_id | |||
| with patch("services.app_service.current_user", mock_current_user): | |||
| updated_app = app_service.update_app_icon(app, new_icon, new_icon_background) | |||
| assert updated_app.icon == new_icon | |||
| @@ -551,13 +569,17 @@ class TestAppService: | |||
| original_site_status = app.enable_site | |||
| # Update site status to disabled | |||
| with patch("flask_login.utils._get_user", return_value=account): | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.id = account.id | |||
| mock_current_user.current_tenant_id = account.current_tenant_id | |||
| with patch("services.app_service.current_user", mock_current_user): | |||
| updated_app = app_service.update_app_site_status(app, False) | |||
| assert updated_app.enable_site is False | |||
| assert updated_app.updated_by == account.id | |||
| # Update site status back to enabled | |||
| with patch("flask_login.utils._get_user", return_value=account): | |||
| with patch("services.app_service.current_user", mock_current_user): | |||
| updated_app = app_service.update_app_site_status(updated_app, True) | |||
| assert updated_app.enable_site is True | |||
| assert updated_app.updated_by == account.id | |||
| @@ -602,13 +624,17 @@ class TestAppService: | |||
| original_api_status = app.enable_api | |||
| # Update API status to disabled | |||
| with patch("flask_login.utils._get_user", return_value=account): | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.id = account.id | |||
| mock_current_user.current_tenant_id = account.current_tenant_id | |||
| with patch("services.app_service.current_user", mock_current_user): | |||
| updated_app = app_service.update_app_api_status(app, False) | |||
| assert updated_app.enable_api is False | |||
| assert updated_app.updated_by == account.id | |||
| # Update API status back to enabled | |||
| with patch("flask_login.utils._get_user", return_value=account): | |||
| with patch("services.app_service.current_user", mock_current_user): | |||
| updated_app = app_service.update_app_api_status(updated_app, True) | |||
| assert updated_app.enable_api is True | |||
| assert updated_app.updated_by == account.id | |||
| @@ -1,6 +1,6 @@ | |||
| import hashlib | |||
| from io import BytesIO | |||
| from unittest.mock import patch | |||
| from unittest.mock import create_autospec, patch | |||
| import pytest | |||
| from faker import Faker | |||
| @@ -417,11 +417,12 @@ class TestFileService: | |||
| text = "This is a test text content" | |||
| text_name = "test_text.txt" | |||
| # Mock current_user | |||
| with patch("services.file_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = str(fake.uuid4()) | |||
| mock_current_user.id = str(fake.uuid4()) | |||
| # Mock current_user using create_autospec | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = str(fake.uuid4()) | |||
| mock_current_user.id = str(fake.uuid4()) | |||
| with patch("services.file_service.current_user", mock_current_user): | |||
| upload_file = FileService.upload_text(text=text, text_name=text_name) | |||
| assert upload_file is not None | |||
| @@ -443,11 +444,12 @@ class TestFileService: | |||
| text = "test content" | |||
| long_name = "a" * 250 # Longer than 200 characters | |||
| # Mock current_user | |||
| with patch("services.file_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = str(fake.uuid4()) | |||
| mock_current_user.id = str(fake.uuid4()) | |||
| # Mock current_user using create_autospec | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = str(fake.uuid4()) | |||
| mock_current_user.id = str(fake.uuid4()) | |||
| with patch("services.file_service.current_user", mock_current_user): | |||
| upload_file = FileService.upload_text(text=text, text_name=long_name) | |||
| # Verify name was truncated | |||
| @@ -846,11 +848,12 @@ class TestFileService: | |||
| text = "" | |||
| text_name = "empty.txt" | |||
| # Mock current_user | |||
| with patch("services.file_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = str(fake.uuid4()) | |||
| mock_current_user.id = str(fake.uuid4()) | |||
| # Mock current_user using create_autospec | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = str(fake.uuid4()) | |||
| mock_current_user.id = str(fake.uuid4()) | |||
| with patch("services.file_service.current_user", mock_current_user): | |||
| upload_file = FileService.upload_text(text=text, text_name=text_name) | |||
| assert upload_file is not None | |||
| @@ -1,4 +1,4 @@ | |||
| from unittest.mock import patch | |||
| from unittest.mock import create_autospec, patch | |||
| import pytest | |||
| from faker import Faker | |||
| @@ -17,7 +17,9 @@ class TestMetadataService: | |||
| def mock_external_service_dependencies(self): | |||
| """Mock setup for external service dependencies.""" | |||
| with ( | |||
| patch("services.metadata_service.current_user") as mock_current_user, | |||
| patch( | |||
| "services.metadata_service.current_user", create_autospec(Account, instance=True) | |||
| ) as mock_current_user, | |||
| patch("services.metadata_service.redis_client") as mock_redis_client, | |||
| patch("services.dataset_service.DocumentService") as mock_document_service, | |||
| ): | |||
| @@ -1,4 +1,4 @@ | |||
| from unittest.mock import patch | |||
| from unittest.mock import create_autospec, patch | |||
| import pytest | |||
| from faker import Faker | |||
| @@ -17,7 +17,7 @@ class TestTagService: | |||
| def mock_external_service_dependencies(self): | |||
| """Mock setup for external service dependencies.""" | |||
| with ( | |||
| patch("services.tag_service.current_user") as mock_current_user, | |||
| patch("services.tag_service.current_user", create_autospec(Account, instance=True)) as mock_current_user, | |||
| ): | |||
| # Setup default mock returns | |||
| mock_current_user.current_tenant_id = "test-tenant-id" | |||
| @@ -1,5 +1,5 @@ | |||
| from datetime import datetime | |||
| from unittest.mock import MagicMock, patch | |||
| from unittest.mock import MagicMock, create_autospec, patch | |||
| import pytest | |||
| from faker import Faker | |||
| @@ -231,9 +231,10 @@ class TestWebsiteService: | |||
| fake = Faker() | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request | |||
| api_request = WebsiteCrawlApiRequest( | |||
| provider="firecrawl", | |||
| @@ -285,9 +286,10 @@ class TestWebsiteService: | |||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request | |||
| api_request = WebsiteCrawlApiRequest( | |||
| provider="watercrawl", | |||
| @@ -336,9 +338,10 @@ class TestWebsiteService: | |||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request for single page crawling | |||
| api_request = WebsiteCrawlApiRequest( | |||
| provider="jinareader", | |||
| @@ -389,9 +392,10 @@ class TestWebsiteService: | |||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request with invalid provider | |||
| api_request = WebsiteCrawlApiRequest( | |||
| provider="invalid_provider", | |||
| @@ -419,9 +423,10 @@ class TestWebsiteService: | |||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request | |||
| api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="test_job_id_123") | |||
| @@ -463,9 +468,10 @@ class TestWebsiteService: | |||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request | |||
| api_request = WebsiteCrawlStatusApiRequest(provider="watercrawl", job_id="watercrawl_job_123") | |||
| @@ -502,9 +508,10 @@ class TestWebsiteService: | |||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request | |||
| api_request = WebsiteCrawlStatusApiRequest(provider="jinareader", job_id="jina_job_123") | |||
| @@ -544,9 +551,10 @@ class TestWebsiteService: | |||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request with invalid provider | |||
| api_request = WebsiteCrawlStatusApiRequest(provider="invalid_provider", job_id="test_job_id_123") | |||
| @@ -569,9 +577,10 @@ class TestWebsiteService: | |||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Mock missing credentials | |||
| mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = None | |||
| @@ -597,9 +606,10 @@ class TestWebsiteService: | |||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Mock missing API key in config | |||
| mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = { | |||
| "config": {"base_url": "https://api.example.com"} | |||
| @@ -995,9 +1005,10 @@ class TestWebsiteService: | |||
| account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request for sub-page crawling | |||
| api_request = WebsiteCrawlApiRequest( | |||
| provider="jinareader", | |||
| @@ -1054,9 +1065,10 @@ class TestWebsiteService: | |||
| mock_external_service_dependencies["requests"].get.return_value = mock_failed_response | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request | |||
| api_request = WebsiteCrawlApiRequest( | |||
| provider="jinareader", | |||
| @@ -1096,9 +1108,10 @@ class TestWebsiteService: | |||
| mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance | |||
| # Mock current_user for the test | |||
| with patch("services.website_service.current_user") as mock_current_user: | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| mock_current_user = create_autospec(Account, instance=True) | |||
| mock_current_user.current_tenant_id = account.current_tenant.id | |||
| with patch("services.website_service.current_user", mock_current_user): | |||
| # Create API request | |||
| api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="active_job_123") | |||
| @@ -2,11 +2,12 @@ import datetime | |||
| from typing import Any, Optional | |||
| # Mock redis_client before importing dataset_service | |||
| from unittest.mock import Mock, patch | |||
| from unittest.mock import Mock, create_autospec, patch | |||
| import pytest | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from models.account import Account | |||
| from models.dataset import Dataset, ExternalKnowledgeBindings | |||
| from services.dataset_service import DatasetService | |||
| from services.errors.account import NoPermissionError | |||
| @@ -78,7 +79,7 @@ class DatasetUpdateTestDataFactory: | |||
| @staticmethod | |||
| def create_current_user_mock(tenant_id: str = "tenant-123") -> Mock: | |||
| """Create a mock current user.""" | |||
| current_user = Mock() | |||
| current_user = create_autospec(Account, instance=True) | |||
| current_user.current_tenant_id = tenant_id | |||
| return current_user | |||
| @@ -135,7 +136,9 @@ class TestDatasetServiceUpdateDataset: | |||
| "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" | |||
| ) as mock_get_binding, | |||
| patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task, | |||
| patch("services.dataset_service.current_user") as mock_current_user, | |||
| patch( | |||
| "services.dataset_service.current_user", create_autospec(Account, instance=True) | |||
| ) as mock_current_user, | |||
| ): | |||
| mock_current_user.current_tenant_id = "tenant-123" | |||
| yield { | |||
| @@ -1,9 +1,10 @@ | |||
| from unittest.mock import Mock, patch | |||
| from unittest.mock import Mock, create_autospec, patch | |||
| import pytest | |||
| from flask_restx import reqparse | |||
| from werkzeug.exceptions import BadRequest | |||
| from models.account import Account | |||
| from services.entities.knowledge_entities.knowledge_entities import MetadataArgs | |||
| from services.metadata_service import MetadataService | |||
| @@ -35,19 +36,21 @@ class TestMetadataBugCompleteValidation: | |||
| mock_metadata_args.name = None | |||
| mock_metadata_args.type = "string" | |||
| with patch("services.metadata_service.current_user") as mock_user: | |||
| mock_user.current_tenant_id = "tenant-123" | |||
| mock_user.id = "user-456" | |||
| mock_user = create_autospec(Account, instance=True) | |||
| mock_user.current_tenant_id = "tenant-123" | |||
| mock_user.id = "user-456" | |||
| with patch("services.metadata_service.current_user", mock_user): | |||
| # Should crash with TypeError | |||
| with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): | |||
| MetadataService.create_metadata("dataset-123", mock_metadata_args) | |||
| # Test update method as well | |||
| with patch("services.metadata_service.current_user") as mock_user: | |||
| mock_user.current_tenant_id = "tenant-123" | |||
| mock_user.id = "user-456" | |||
| mock_user = create_autospec(Account, instance=True) | |||
| mock_user.current_tenant_id = "tenant-123" | |||
| mock_user.id = "user-456" | |||
| with patch("services.metadata_service.current_user", mock_user): | |||
| with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): | |||
| MetadataService.update_metadata_name("dataset-123", "metadata-456", None) | |||
| @@ -1,8 +1,9 @@ | |||
| from unittest.mock import Mock, patch | |||
| from unittest.mock import Mock, create_autospec, patch | |||
| import pytest | |||
| from flask_restx import reqparse | |||
| from models.account import Account | |||
| from services.entities.knowledge_entities.knowledge_entities import MetadataArgs | |||
| from services.metadata_service import MetadataService | |||
| @@ -24,20 +25,22 @@ class TestMetadataNullableBug: | |||
| mock_metadata_args.name = None # This will cause len() to crash | |||
| mock_metadata_args.type = "string" | |||
| with patch("services.metadata_service.current_user") as mock_user: | |||
| mock_user.current_tenant_id = "tenant-123" | |||
| mock_user.id = "user-456" | |||
| mock_user = create_autospec(Account, instance=True) | |||
| mock_user.current_tenant_id = "tenant-123" | |||
| mock_user.id = "user-456" | |||
| with patch("services.metadata_service.current_user", mock_user): | |||
| # This should crash with TypeError when calling len(None) | |||
| with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): | |||
| MetadataService.create_metadata("dataset-123", mock_metadata_args) | |||
| def test_metadata_service_update_with_none_name_crashes(self): | |||
| """Test that MetadataService.update_metadata_name crashes when name is None.""" | |||
| with patch("services.metadata_service.current_user") as mock_user: | |||
| mock_user.current_tenant_id = "tenant-123" | |||
| mock_user.id = "user-456" | |||
| mock_user = create_autospec(Account, instance=True) | |||
| mock_user.current_tenant_id = "tenant-123" | |||
| mock_user.id = "user-456" | |||
| with patch("services.metadata_service.current_user", mock_user): | |||
| # This should crash with TypeError when calling len(None) | |||
| with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): | |||
| MetadataService.update_metadata_name("dataset-123", "metadata-456", None) | |||
| @@ -81,10 +84,11 @@ class TestMetadataNullableBug: | |||
| mock_metadata_args.name = None # From args["name"] | |||
| mock_metadata_args.type = None # From args["type"] | |||
| with patch("services.metadata_service.current_user") as mock_user: | |||
| mock_user.current_tenant_id = "tenant-123" | |||
| mock_user.id = "user-456" | |||
| mock_user = create_autospec(Account, instance=True) | |||
| mock_user.current_tenant_id = "tenant-123" | |||
| mock_user.id = "user-456" | |||
| with patch("services.metadata_service.current_user", mock_user): | |||
| # Step 4: Service layer crashes on len(None) | |||
| with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): | |||
| MetadataService.create_metadata("dataset-123", mock_metadata_args) | |||