Signed-off-by: -LAN- <laipz8200@outlook.com>
tags/1.4.1
@@ -202,18 +202,18 @@ class EmailCodeLoginApi(Resource):
         except AccountRegisterError as are:
             raise AccountInFreezeError()
         if account:
-            tenant = TenantService.get_join_tenants(account)
-            if not tenant:
+            tenants = TenantService.get_join_tenants(account)
+            if not tenants:
                 workspaces = FeatureService.get_system_features().license.workspaces
                 if not workspaces.is_available():
                     raise WorkspacesLimitExceeded()
                 if not FeatureService.get_system_features().is_allow_create_workspace:
                     raise NotAllowedCreateWorkspace()
                 else:
-                    tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
-                    TenantService.create_tenant_member(tenant, account, role="owner")
-                    account.current_tenant = tenant
-                    tenant_was_created.send(tenant)
+                    new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
+                    TenantService.create_tenant_member(new_tenant, account, role="owner")
+                    account.current_tenant = new_tenant
+                    tenant_was_created.send(new_tenant)
         if account is None:
             try:
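The rename matters because `TenantService.get_join_tenants` returns a list: the old name `tenant` read as a single object and was then rebound to the newly created workspace, which misled both readers and type checkers. A minimal sketch of the fixed shape, using stand-in types rather than Dify's real service layer:

```python
# Hypothetical stand-ins for TenantService and Tenant, to illustrate the
# list-vs-single-object naming fixed above.
from dataclasses import dataclass


@dataclass
class Tenant:
    name: str


def get_join_tenants(account_name: str) -> list[Tenant]:
    return []  # always a list, possibly empty


def ensure_workspace(account_name: str) -> Tenant:
    tenants = get_join_tenants(account_name)  # plural name matches the type
    if not tenants:  # empty list: the account has no workspace yet
        new_tenant = Tenant(name=f"{account_name}'s Workspace")
        return new_tenant  # distinct name: never rebinds the list variable
    return tenants[0]
```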
@@ -148,15 +148,15 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
     account = _get_account_by_openid_or_email(provider, user_info)
     if account:
-        tenant = TenantService.get_join_tenants(account)
-        if not tenant:
+        tenants = TenantService.get_join_tenants(account)
+        if not tenants:
             if not FeatureService.get_system_features().is_allow_create_workspace:
                 raise WorkSpaceNotAllowedCreateError()
             else:
-                tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
-                TenantService.create_tenant_member(tenant, account, role="owner")
-                account.current_tenant = tenant
-                tenant_was_created.send(tenant)
+                new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
+                TenantService.create_tenant_member(new_tenant, account, role="owner")
+                account.current_tenant = new_tenant
+                tenant_was_created.send(new_tenant)
     if not account:
         if not FeatureService.get_system_features().is_allow_register:
@@ -540,9 +540,22 @@ class DatasetIndexingStatusApi(Resource):
                 .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
                 .count()
             )
-            document.completed_segments = completed_segments
-            document.total_segments = total_segments
-            documents_status.append(marshal(document, document_status_fields))
+            # Create a dictionary with document attributes and additional fields
+            document_dict = {
+                "id": document.id,
+                "indexing_status": document.indexing_status,
+                "processing_started_at": document.processing_started_at,
+                "parsing_completed_at": document.parsing_completed_at,
+                "cleaning_completed_at": document.cleaning_completed_at,
+                "splitting_completed_at": document.splitting_completed_at,
+                "completed_at": document.completed_at,
+                "paused_at": document.paused_at,
+                "error": document.error,
+                "stopped_at": document.stopped_at,
+                "completed_segments": completed_segments,
+                "total_segments": total_segments,
+            }
+            documents_status.append(marshal(document_dict, document_status_fields))
         data = {"data": documents_status}
         return data
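Building a plain dict and marshaling that, instead of setting ad-hoc attributes such as `completed_segments` on the SQLAlchemy model, keeps transient response data off tracked ORM instances. `flask_restful.marshal` accepts dicts as readily as objects. A minimal sketch using Flask-RESTful's public API; the field list is abridged, not the full `document_status_fields`:

```python
# Abridged field list for illustration; marshal() filters the dict down to
# exactly the declared fields.
from flask_restful import fields, marshal

document_status_fields = {
    "id": fields.String,
    "indexing_status": fields.String,
    "completed_segments": fields.Integer,
    "total_segments": fields.Integer,
}

document_dict = {
    "id": "doc-1",
    "indexing_status": "indexing",
    "completed_segments": 3,
    "total_segments": 10,
    "internal_note": "dropped by marshal",  # not declared, so not emitted
}

payload = marshal(document_dict, document_status_fields)
assert "internal_note" not in payload
```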
@@ -583,11 +583,22 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
                 .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
                 .count()
             )
-            document.completed_segments = completed_segments
-            document.total_segments = total_segments
-            if document.is_paused:
-                document.indexing_status = "paused"
-            documents_status.append(marshal(document, document_status_fields))
+            # Create a dictionary with document attributes and additional fields
+            document_dict = {
+                "id": document.id,
+                "indexing_status": "paused" if document.is_paused else document.indexing_status,
+                "processing_started_at": document.processing_started_at,
+                "parsing_completed_at": document.parsing_completed_at,
+                "cleaning_completed_at": document.cleaning_completed_at,
+                "splitting_completed_at": document.splitting_completed_at,
+                "completed_at": document.completed_at,
+                "paused_at": document.paused_at,
+                "error": document.error,
+                "stopped_at": document.stopped_at,
+                "completed_segments": completed_segments,
+                "total_segments": total_segments,
+            }
+            documents_status.append(marshal(document_dict, document_status_fields))
         data = {"data": documents_status}
         return data
@@ -616,11 +627,22 @@ class DocumentIndexingStatusApi(DocumentResource):
             .count()
         )
-        document.completed_segments = completed_segments
-        document.total_segments = total_segments
-        if document.is_paused:
-            document.indexing_status = "paused"
-        return marshal(document, document_status_fields)
+        # Create a dictionary with document attributes and additional fields
+        document_dict = {
+            "id": document.id,
+            "indexing_status": "paused" if document.is_paused else document.indexing_status,
+            "processing_started_at": document.processing_started_at,
+            "parsing_completed_at": document.parsing_completed_at,
+            "cleaning_completed_at": document.cleaning_completed_at,
+            "splitting_completed_at": document.splitting_completed_at,
+            "completed_at": document.completed_at,
+            "paused_at": document.paused_at,
+            "error": document.error,
+            "stopped_at": document.stopped_at,
+            "completed_segments": completed_segments,
+            "total_segments": total_segments,
+        }
+        return marshal(document_dict, document_status_fields)
 class DocumentDetailApi(DocumentResource):
@@ -68,16 +68,24 @@ class TenantListApi(Resource):
     @account_initialization_required
     def get(self):
         tenants = TenantService.get_join_tenants(current_user)
+        tenant_dicts = []
         for tenant in tenants:
             features = FeatureService.get_features(tenant.id)
-            if features.billing.enabled:
-                tenant.plan = features.billing.subscription.plan
-            else:
-                tenant.plan = "sandbox"
-            if tenant.id == current_user.current_tenant_id:
-                tenant.current = True  # Set current=True for current tenant
-        return {"workspaces": marshal(tenants, tenants_fields)}, 200
+            # Create a dictionary with tenant attributes
+            tenant_dict = {
+                "id": tenant.id,
+                "name": tenant.name,
+                "status": tenant.status,
+                "created_at": tenant.created_at,
+                "plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
+                "current": tenant.id == current_user.current_tenant_id,
+            }
+            tenant_dicts.append(tenant_dict)
+        return {"workspaces": marshal(tenant_dicts, tenants_fields)}, 200
 class WorkspaceListApi(Resource):
@@ -64,9 +64,24 @@ class PluginUploadFileApi(Resource):
             extension = guess_extension(tool_file.mimetype) or ".bin"
             preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension)
-            tool_file.mime_type = mimetype
-            tool_file.extension = extension
-            tool_file.preview_url = preview_url
+            # Create a dictionary with all the necessary attributes
+            result = {
+                "id": tool_file.id,
+                "user_id": tool_file.user_id,
+                "tenant_id": tool_file.tenant_id,
+                "conversation_id": tool_file.conversation_id,
+                "file_key": tool_file.file_key,
+                "mimetype": tool_file.mimetype,
+                "original_url": tool_file.original_url,
+                "name": tool_file.name,
+                "size": tool_file.size,
+                "mime_type": mimetype,
+                "extension": extension,
+                "preview_url": preview_url,
+            }
+            return result, 201
         except services.errors.file.FileTooLargeError as file_too_large_error:
             raise FileTooLargeError(file_too_large_error.description)
         except services.errors.file.UnsupportedFileTypeError:
@@ -388,11 +388,22 @@ class DocumentIndexingStatusApi(DatasetApiResource):
                 .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
                 .count()
             )
-            document.completed_segments = completed_segments
-            document.total_segments = total_segments
-            if document.is_paused:
-                document.indexing_status = "paused"
-            documents_status.append(marshal(document, document_status_fields))
+            # Create a dictionary with document attributes and additional fields
+            document_dict = {
+                "id": document.id,
+                "indexing_status": "paused" if document.is_paused else document.indexing_status,
+                "processing_started_at": document.processing_started_at,
+                "parsing_completed_at": document.parsing_completed_at,
+                "cleaning_completed_at": document.cleaning_completed_at,
+                "splitting_completed_at": document.splitting_completed_at,
+                "completed_at": document.completed_at,
+                "paused_at": document.paused_at,
+                "error": document.error,
+                "stopped_at": document.stopped_at,
+                "completed_segments": completed_segments,
+                "total_segments": total_segments,
+            }
+            documents_status.append(marshal(document_dict, document_status_fields))
         data = {"data": documents_status}
         return data
@@ -405,7 +405,29 @@ class RetrievalService:
                     record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks")  # type: ignore
                     record["score"] = segment_child_map[record["segment"].id]["max_score"]
-            return [RetrievalSegments(**record) for record in records]
+            result = []
+            for record in records:
+                # Extract segment
+                segment = record["segment"]
+                # Extract child_chunks, ensuring it's a list or None
+                child_chunks = record.get("child_chunks")
+                if not isinstance(child_chunks, list):
+                    child_chunks = None
+                # Extract score, ensuring it's a float or None
+                score_value = record.get("score")
+                score = (
+                    float(score_value)
+                    if score_value is not None and isinstance(score_value, int | float | str)
+                    else None
+                )
+                # Create RetrievalSegments object
+                retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score)
+                result.append(retrieval_segment)
+            return result
         except Exception as e:
             db.session.rollback()
             raise e
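One hedge worth noting on the narrowing above: `isinstance(score_value, int | float | str)` requires Python 3.10+, and `float()` over a non-numeric string raises `ValueError`, so the conversion assumes scores arrive as numbers or numeric strings. A standalone sketch of the same narrowing; `narrow_score` is a hypothetical helper, not part of `RetrievalService`:

```python
# Standalone version of the score narrowing used above.
def narrow_score(score_value: object) -> float | None:
    if score_value is not None and isinstance(score_value, int | float | str):
        return float(score_value)  # ValueError if the string is not numeric
    return None


assert narrow_score(0.87) == 0.87
assert narrow_score("0.5") == 0.5
assert narrow_score(None) is None
assert narrow_score([1, 2]) is None
```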
@@ -528,7 +528,7 @@ class ToolManager:
                 yield provider
             except Exception:
-                logger.exception(f"load builtin provider {provider}")
+                logger.exception(f"load builtin provider {provider_path}")
                 continue
         # set builtin providers loaded
         cls._builtin_providers_loaded = True
@@ -644,10 +644,10 @@ class ToolManager:
         )
         workflow_provider_controllers: list[WorkflowToolProviderController] = []
-        for provider in workflow_providers:
+        for workflow_provider in workflow_providers:
             try:
                 workflow_provider_controllers.append(
-                    ToolTransformService.workflow_provider_to_controller(db_provider=provider)
+                    ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
                 )
             except Exception:
                 # app has been deleted
@@ -1,7 +1,9 @@
 import json
 import logging
 from collections.abc import Generator
-from typing import Any, Optional, Union, cast
+from typing import Any, Optional, cast
+from flask_login import current_user
 from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
 from core.tools.__base.tool import Tool
@@ -87,7 +89,7 @@ class WorkflowTool(Tool):
         result = generator.generate(
             app_model=app,
             workflow=workflow,
-            user=self._get_user(user_id),
+            user=cast("Account | EndUser", current_user),
             args={"inputs": tool_parameters, "files": files},
             invoke_from=self.runtime.invoke_from,
             streaming=False,
@@ -111,20 +113,6 @@
         yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
         yield self.create_json_message(outputs)
-    def _get_user(self, user_id: str) -> Union[EndUser, Account]:
-        """
-        get the user by user id
-        """
-        user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
-        if not user:
-            user = db.session.query(Account).filter(Account.id == user_id).first()
-        if not user:
-            raise ValueError("user not found")
-        return user
     def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
         """
         fork a new tool with metadata
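Dropping `_get_user` removes up to two queries per invocation and delegates identity to `flask_login.current_user`. The `cast` is purely for the type checker; it performs no runtime validation, so this change assumes the tool only runs inside a request where a user is already loaded. A minimal sketch of that assumption, with stand-in classes rather than Dify's models:

```python
# cast() only narrows the static type; at runtime, flask_login.current_user is
# a werkzeug LocalProxy that must already hold an Account or EndUser.
from typing import Union, cast


class Account: ...
class EndUser: ...


def run_as(current_user: object) -> Union[Account, EndUser]:
    user = cast("Account | EndUser", current_user)  # no runtime check here
    return user
```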
@@ -3,11 +3,14 @@ import json
 import flask_login  # type: ignore
 from flask import Response, request
 from flask_login import user_loaded_from_request, user_logged_in
-from werkzeug.exceptions import Unauthorized
+from werkzeug.exceptions import NotFound, Unauthorized
 import contexts
 from dify_app import DifyApp
+from extensions.ext_database import db
 from libs.passport import PassportService
+from models.account import Account
+from models.model import EndUser
 from services.account_service import AccountService
 login_manager = flask_login.LoginManager()
@@ -17,34 +20,48 @@ login_manager = flask_login.LoginManager()
 @login_manager.request_loader
 def load_user_from_request(request_from_flask_login):
     """Load user based on the request."""
-    if request.blueprint not in {"console", "inner_api"}:
-        return None
-    # Check if the user_id contains a dot, indicating the old format
     auth_header = request.headers.get("Authorization", "")
-    if not auth_header:
-        auth_token = request.args.get("_token")
-        if not auth_token:
-            raise Unauthorized("Invalid Authorization token.")
-    else:
+    auth_token: str | None = None
+    if auth_header:
         if " " not in auth_header:
             raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
-        auth_scheme, auth_token = auth_header.split(None, 1)
+        auth_scheme, auth_token = auth_header.split(maxsplit=1)
         auth_scheme = auth_scheme.lower()
         if auth_scheme != "bearer":
             raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
+    else:
+        auth_token = request.args.get("_token")
-    decoded = PassportService().verify(auth_token)
-    user_id = decoded.get("user_id")
+    if request.blueprint in {"console", "inner_api"}:
+        if not auth_token:
+            raise Unauthorized("Invalid Authorization token.")
+        decoded = PassportService().verify(auth_token)
+        user_id = decoded.get("user_id")
         if not user_id:
             raise Unauthorized("Invalid Authorization token.")
-    logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
-    return logged_in_account
+        logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
+        return logged_in_account
+    elif request.blueprint == "web":
+        decoded = PassportService().verify(auth_token)
+        end_user_id = decoded.get("end_user_id")
+        if not end_user_id:
+            raise Unauthorized("Invalid Authorization token.")
+        end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first()
+        if not end_user:
+            raise NotFound("End user not found.")
+        return end_user
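The loader now resolves the token once (Authorization header or `_token` query parameter) and only then branches on blueprint, so web-app end users are loaded too. `str.split(maxsplit=1)` is behaviorally identical to the old `str.split(None, 1)`; both split on arbitrary whitespace at most once, so the keyword form clarifies intent without changing parsing. A standalone sketch of the header handling, with a hypothetical `parse_bearer` helper:

```python
# Hypothetical helper mirroring the Bearer parsing above.
def parse_bearer(auth_header: str) -> str:
    if " " not in auth_header:
        raise ValueError("Expected 'Bearer <api-key>' format.")
    auth_scheme, auth_token = auth_header.split(maxsplit=1)
    if auth_scheme.lower() != "bearer":
        raise ValueError("Expected 'Bearer <api-key>' format.")
    return auth_token


assert "Bearer abc.def".split(None, 1) == "Bearer abc.def".split(maxsplit=1)
assert parse_bearer("Bearer abc.def") == "abc.def"
```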
 @user_logged_in.connect
 @user_loaded_from_request.connect
 def on_user_logged_in(_sender, user):
-    """Called when a user logged in."""
-    if user:
+    """Called when a user logged in.
+
+    Note: AccountService.load_logged_in_account will populate user.current_tenant_id
+    through the load_user method, which calls account.set_tenant_id().
+    """
+    if user and isinstance(user, Account) and user.current_tenant_id:
         contexts.tenant_id.set(user.current_tenant_id)
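The extra guard is needed because the loader can now return an `EndUser` (web blueprint), and only `Account` carries `current_tenant_id`. `contexts.tenant_id` is a `ContextVar`, so the value is scoped per execution context rather than global. A minimal `ContextVar` sketch with an illustrative variable name and default:

```python
# Illustrative ContextVar mirroring contexts.tenant_id: each execution context
# sees its own value, so concurrent requests do not clobber one another.
from contextvars import ContextVar, copy_context

tenant_id: ContextVar[str] = ContextVar("tenant_id", default="none")


def handle_request(tid: str) -> str:
    tenant_id.set(tid)
    return tenant_id.get()


ctx = copy_context()
assert ctx.run(handle_request, "tenant-a") == "tenant-a"
assert tenant_id.get() == "none"  # outer context untouched
```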
@@ -1,10 +1,10 @@
 import enum
 import json
-from typing import cast
+from typing import Optional, cast
 from flask_login import UserMixin  # type: ignore
 from sqlalchemy import func
-from sqlalchemy.orm import Mapped, mapped_column
+from sqlalchemy.orm import Mapped, mapped_column, reconstructor
 from models.base import Base
@@ -12,6 +12,66 @@ from .engine import db
 from .types import StringUUID
+class TenantAccountRole(enum.StrEnum):
+    OWNER = "owner"
+    ADMIN = "admin"
+    EDITOR = "editor"
+    NORMAL = "normal"
+    DATASET_OPERATOR = "dataset_operator"
+
+    @staticmethod
+    def is_valid_role(role: str) -> bool:
+        if not role:
+            return False
+        return role in {
+            TenantAccountRole.OWNER,
+            TenantAccountRole.ADMIN,
+            TenantAccountRole.EDITOR,
+            TenantAccountRole.NORMAL,
+            TenantAccountRole.DATASET_OPERATOR,
+        }
+
+    @staticmethod
+    def is_privileged_role(role: Optional["TenantAccountRole"]) -> bool:
+        if not role:
+            return False
+        return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}
+
+    @staticmethod
+    def is_admin_role(role: Optional["TenantAccountRole"]) -> bool:
+        if not role:
+            return False
+        return role == TenantAccountRole.ADMIN
+
+    @staticmethod
+    def is_non_owner_role(role: Optional["TenantAccountRole"]) -> bool:
+        if not role:
+            return False
+        return role in {
+            TenantAccountRole.ADMIN,
+            TenantAccountRole.EDITOR,
+            TenantAccountRole.NORMAL,
+            TenantAccountRole.DATASET_OPERATOR,
+        }
+
+    @staticmethod
+    def is_editing_role(role: Optional["TenantAccountRole"]) -> bool:
+        if not role:
+            return False
+        return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
+
+    @staticmethod
+    def is_dataset_edit_role(role: Optional["TenantAccountRole"]) -> bool:
+        if not role:
+            return False
+        return role in {
+            TenantAccountRole.OWNER,
+            TenantAccountRole.ADMIN,
+            TenantAccountRole.EDITOR,
+            TenantAccountRole.DATASET_OPERATOR,
+        }
 class AccountStatus(enum.StrEnum):
     PENDING = "pending"
     UNINITIALIZED = "uninitialized"
@@ -41,24 +101,27 @@ class Account(UserMixin, Base):
     created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    @reconstructor
+    def init_on_load(self):
+        self.role: Optional[TenantAccountRole] = None
+        self._current_tenant: Optional[Tenant] = None
     @property
     def is_password_set(self):
         return self.password is not None
     @property
     def current_tenant(self):
-        return self._current_tenant  # type: ignore
+        return self._current_tenant
     @current_tenant.setter
-    def current_tenant(self, value: "Tenant"):
-        tenant = value
+    def current_tenant(self, tenant: "Tenant"):
         ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first()
         if ta:
-            tenant.current_role = ta.role
-        else:
-            tenant = None  # type: ignore
-        self._current_tenant = tenant
+            self.role = TenantAccountRole(ta.role)
+            self._current_tenant = tenant
+            return
+        self._current_tenant = None
     @property
     def current_tenant_id(self) -> str | None:
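SQLAlchemy does not call `__init__` when it loads a row, so plain Python attributes like `role` and `_current_tenant` would simply not exist on instances coming out of a query; `@reconstructor` marks a method the ORM runs at load time to establish them. Moving the role onto the account (`self.role`) also stops the setter from writing `current_role` onto shared `Tenant` rows. A self-contained demonstration with an illustrative model, not Dify's `Account`:

```python
# Minimal @reconstructor demo: the attribute set in init_on_load exists on
# instances loaded from the database even though __init__ is skipped on load.
from typing import Optional

from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, reconstructor


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "users"
    id: Mapped[str] = mapped_column(String, primary_key=True)

    @reconstructor
    def init_on_load(self) -> None:
        # Runs when a row is loaded from the database.
        self.role: Optional[str] = None


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(User(id="u1"))
    session.commit()
with Session(engine) as session:
    loaded = session.get(User, "u1")
    assert loaded is not None and loaded.role is None  # attribute exists
```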
@@ -80,12 +143,12 @@ class Account(UserMixin, Base):
             return
         tenant, join = tenant_account_join
-        tenant.current_role = join.role
+        self.role = join.role
         self._current_tenant = tenant
     @property
     def current_role(self):
-        return self._current_tenant.current_role
+        return self.role
     def get_status(self) -> AccountStatus:
         status_str = self.status
@@ -105,23 +168,23 @@ class Account(UserMixin, Base):
     # check current_user.current_tenant.current_role in ['admin', 'owner']
     @property
     def is_admin_or_owner(self):
-        return TenantAccountRole.is_privileged_role(self._current_tenant.current_role)
+        return TenantAccountRole.is_privileged_role(self.role)
     @property
     def is_admin(self):
-        return TenantAccountRole.is_admin_role(self._current_tenant.current_role)
+        return TenantAccountRole.is_admin_role(self.role)
     @property
     def is_editor(self):
-        return TenantAccountRole.is_editing_role(self._current_tenant.current_role)
+        return TenantAccountRole.is_editing_role(self.role)
     @property
     def is_dataset_editor(self):
-        return TenantAccountRole.is_dataset_edit_role(self._current_tenant.current_role)
+        return TenantAccountRole.is_dataset_edit_role(self.role)
     @property
     def is_dataset_operator(self):
-        return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR
+        return self.role == TenantAccountRole.DATASET_OPERATOR
 class TenantStatus(enum.StrEnum):
@@ -129,66 +192,6 @@ class TenantStatus(enum.StrEnum):
     ARCHIVE = "archive"
-class TenantAccountRole(enum.StrEnum):
-    OWNER = "owner"
-    ADMIN = "admin"
-    EDITOR = "editor"
-    NORMAL = "normal"
-    DATASET_OPERATOR = "dataset_operator"
-
-    @staticmethod
-    def is_valid_role(role: str) -> bool:
-        if not role:
-            return False
-        return role in {
-            TenantAccountRole.OWNER,
-            TenantAccountRole.ADMIN,
-            TenantAccountRole.EDITOR,
-            TenantAccountRole.NORMAL,
-            TenantAccountRole.DATASET_OPERATOR,
-        }
-
-    @staticmethod
-    def is_privileged_role(role: str) -> bool:
-        if not role:
-            return False
-        return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}
-
-    @staticmethod
-    def is_admin_role(role: str) -> bool:
-        if not role:
-            return False
-        return role == TenantAccountRole.ADMIN
-
-    @staticmethod
-    def is_non_owner_role(role: str) -> bool:
-        if not role:
-            return False
-        return role in {
-            TenantAccountRole.ADMIN,
-            TenantAccountRole.EDITOR,
-            TenantAccountRole.NORMAL,
-            TenantAccountRole.DATASET_OPERATOR,
-        }
-
-    @staticmethod
-    def is_editing_role(role: str) -> bool:
-        if not role:
-            return False
-        return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
-
-    @staticmethod
-    def is_dataset_edit_role(role: str) -> bool:
-        if not role:
-            return False
-        return role in {
-            TenantAccountRole.OWNER,
-            TenantAccountRole.ADMIN,
-            TenantAccountRole.EDITOR,
-            TenantAccountRole.DATASET_OPERATOR,
-        }
 class Tenant(Base):
     __tablename__ = "tenants"
     __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
@@ -1,5 +1,7 @@
-from sqlalchemy.orm import declarative_base
+from sqlalchemy.orm import DeclarativeBase
 from models.engine import metadata
-Base = declarative_base(metadata=metadata)
+class Base(DeclarativeBase):
+    metadata = metadata
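`declarative_base()` returns a dynamically constructed class that static type checkers cannot model well; subclassing `DeclarativeBase` (the SQLAlchemy 2.0 style) produces an ordinary class with the same behavior, and assigning `metadata` on the subclass keeps the project's existing `MetaData` object. A minimal sketch with an illustrative model:

```python
# Sketch of the new declarative base style; MyModel is illustrative only.
from sqlalchemy import MetaData, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

metadata = MetaData()


class Base(DeclarativeBase):
    # Reuse an existing MetaData object, as models.base does above.
    metadata = metadata


class MyModel(Base):
    __tablename__ = "my_model"
    id: Mapped[str] = mapped_column(String, primary_key=True)


assert MyModel.__table__.metadata is metadata
```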
@@ -172,10 +172,6 @@ class WorkflowToolProvider(Base):
         db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
     )
-    @property
-    def schema_type(self) -> ApiProviderSchemaType:
-        return ApiProviderSchemaType.value_of(self.schema_type_str)
     @property
     def user(self) -> Account | None:
         return db.session.query(Account).filter(Account.id == self.user_id).first()
@@ -3,7 +3,7 @@ import logging
 from collections.abc import Mapping, Sequence
 from datetime import UTC, datetime
 from enum import Enum, StrEnum
-from typing import TYPE_CHECKING, Any, Optional, Self, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 from uuid import uuid4
 from core.variables import utils as variable_utils
@@ -150,7 +150,7 @@ class Workflow(Base):
         conversation_variables: Sequence[Variable],
         marked_name: str = "",
         marked_comment: str = "",
-    ) -> Self:
+    ) -> "Workflow":
         workflow = Workflow()
         workflow.id = str(uuid4())
         workflow.tenant_id = tenant_id
@@ -23,11 +23,10 @@ class VectorService:
     ):
         documents: list[Document] = []
-        document: Document | None = None
         for segment in segments:
             if doc_form == IndexType.PARENT_CHILD_INDEX:
-                document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
-                if not document:
+                dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
+                if not dataset_document:
                     _logger.warning(
                         "Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
                         segment.document_id,
@@ -37,7 +36,7 @@ class VectorService:
                 # get the process rule
                 processing_rule = (
                     db.session.query(DatasetProcessRule)
-                    .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
+                    .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
                     .first()
                 )
                 if not processing_rule:
@@ -61,9 +60,11 @@ class VectorService:
                     )
                 else:
                     raise ValueError("The knowledge base index technique is not high quality!")
-                cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False)
+                cls.generate_child_chunks(
+                    segment, dataset_document, dataset, embedding_model_instance, processing_rule, False
+                )
             else:
-                document = Document(
+                rag_document = Document(
                     page_content=segment.content,
                     metadata={
                         "doc_id": segment.index_node_id,
@@ -72,7 +73,7 @@ class VectorService:
                         "dataset_id": segment.dataset_id,
                     },
                 )
-                documents.append(document)
+                documents.append(rag_document)
         if len(documents) > 0:
             index_processor = IndexProcessorFactory(doc_form).init_index_processor()
             index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
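The renames here untangle one variable that was being rebound across unrelated types: a `DatasetDocument` ORM row in the parent-child branch and a RAG `Document` in the fallback branch, on top of the `document: Document | None = None` pre-declaration. A condensed sketch of the hazard, using stand-in classes rather than Dify's models:

```python
# Stand-ins for the ORM DatasetDocument and the in-memory RAG Document; the
# point is that one name rebound across types defeats type inference.
class DatasetDocument:
    dataset_process_rule_id = "rule-1"


class RagDocument:
    def __init__(self, page_content: str) -> None:
        self.page_content = page_content


def build(parent_child: bool) -> str:
    if parent_child:
        dataset_document = DatasetDocument()  # one name per type
        return dataset_document.dataset_process_rule_id
    rag_document = RagDocument(page_content="...")  # never rebinds the other name
    return rag_document.page_content
```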
@@ -508,11 +508,11 @@ class WorkflowService:
             raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")
         # Check if this workflow is currently referenced by an app
-        stmt = select(App).where(App.workflow_id == workflow_id)
-        app = session.scalar(stmt)
+        app_stmt = select(App).where(App.workflow_id == workflow_id)
+        app = session.scalar(app_stmt)
         if app:
             # Cannot delete a workflow that's currently in use by an app
-            raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'")
+            raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.id}'")
         # Don't use workflow.tool_published as it's not accurate for specific workflow versions
         # Check if there's a tool provider using this specific workflow version
@@ -111,7 +111,7 @@ def add_document_to_index_task(dataset_document_id: str):
         logging.exception("add document to index failed")
         dataset_document.enabled = False
         dataset_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
-        dataset_document.status = "error"
+        dataset_document.indexing_status = "error"
         dataset_document.error = str(e)
         db.session.commit()
     finally:
@@ -193,7 +193,7 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
 def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
     # Get app's owner
     with Session(db.engine, expire_on_commit=False) as session:
-        stmt = select(Account).where(Account.id == App.owner_id).where(App.id == app_id)
+        stmt = select(Account).where(Account.id == App.created_by).where(App.id == app_id)
         user = session.scalar(stmt)
     if user is None:
@@ -34,13 +34,13 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
     # needs to patch those methods to avoid database access.
     monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
     monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
-    monkeypatch.setattr(tool, "_get_user", lambda *args, **kwargs: None)
     # replace `WorkflowAppGenerator.generate` 's return value.
     monkeypatch.setattr(
         "core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
        lambda *args, **kwargs: {"data": {"error": "oops"}},
     )
+    monkeypatch.setattr("flask_login.current_user", lambda *args, **kwargs: None)
     with pytest.raises(ToolInvokeError) as exc_info:
         # WorkflowTool always returns a generator, so we need to iterate to
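The final context line is cut off in this view; the point it makes is that `WorkflowTool._invoke` returns a generator, so the error only surfaces once the test consumes it (for example via `list(...)`). A standalone sketch of that iterate-to-raise behavior, with an illustrative generator rather than `WorkflowTool` itself:

```python
# Generator bodies do not run until consumed, so the exception is only raised
# on iteration, not at call time.
from collections.abc import Generator


class ToolInvokeError(Exception):
    pass


def invoke() -> Generator[str, None, None]:
    raise ToolInvokeError("oops")
    yield "never reached"  # makes this function a generator


gen = invoke()  # no error yet: the body has not started executing
try:
    list(gen)  # consuming the generator triggers the raise
except ToolInvokeError:
    pass
```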