Signed-off-by: -LAN- <laipz8200@outlook.com>tags/1.4.1
| except AccountRegisterError as are: | except AccountRegisterError as are: | ||||
| raise AccountInFreezeError() | raise AccountInFreezeError() | ||||
| if account: | if account: | ||||
| tenant = TenantService.get_join_tenants(account) | |||||
| if not tenant: | |||||
| tenants = TenantService.get_join_tenants(account) | |||||
| if not tenants: | |||||
| workspaces = FeatureService.get_system_features().license.workspaces | workspaces = FeatureService.get_system_features().license.workspaces | ||||
| if not workspaces.is_available(): | if not workspaces.is_available(): | ||||
| raise WorkspacesLimitExceeded() | raise WorkspacesLimitExceeded() | ||||
| if not FeatureService.get_system_features().is_allow_create_workspace: | if not FeatureService.get_system_features().is_allow_create_workspace: | ||||
| raise NotAllowedCreateWorkspace() | raise NotAllowedCreateWorkspace() | ||||
| else: | else: | ||||
| tenant = TenantService.create_tenant(f"{account.name}'s Workspace") | |||||
| TenantService.create_tenant_member(tenant, account, role="owner") | |||||
| account.current_tenant = tenant | |||||
| tenant_was_created.send(tenant) | |||||
| new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace") | |||||
| TenantService.create_tenant_member(new_tenant, account, role="owner") | |||||
| account.current_tenant = new_tenant | |||||
| tenant_was_created.send(new_tenant) | |||||
| if account is None: | if account is None: | ||||
| try: | try: |
| account = _get_account_by_openid_or_email(provider, user_info) | account = _get_account_by_openid_or_email(provider, user_info) | ||||
| if account: | if account: | ||||
| tenant = TenantService.get_join_tenants(account) | |||||
| if not tenant: | |||||
| tenants = TenantService.get_join_tenants(account) | |||||
| if not tenants: | |||||
| if not FeatureService.get_system_features().is_allow_create_workspace: | if not FeatureService.get_system_features().is_allow_create_workspace: | ||||
| raise WorkSpaceNotAllowedCreateError() | raise WorkSpaceNotAllowedCreateError() | ||||
| else: | else: | ||||
| tenant = TenantService.create_tenant(f"{account.name}'s Workspace") | |||||
| TenantService.create_tenant_member(tenant, account, role="owner") | |||||
| account.current_tenant = tenant | |||||
| tenant_was_created.send(tenant) | |||||
| new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace") | |||||
| TenantService.create_tenant_member(new_tenant, account, role="owner") | |||||
| account.current_tenant = new_tenant | |||||
| tenant_was_created.send(new_tenant) | |||||
| if not account: | if not account: | ||||
| if not FeatureService.get_system_features().is_allow_register: | if not FeatureService.get_system_features().is_allow_register: |
| .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") | .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") | ||||
| .count() | .count() | ||||
| ) | ) | ||||
| document.completed_segments = completed_segments | |||||
| document.total_segments = total_segments | |||||
| documents_status.append(marshal(document, document_status_fields)) | |||||
| # Create a dictionary with document attributes and additional fields | |||||
| document_dict = { | |||||
| "id": document.id, | |||||
| "indexing_status": document.indexing_status, | |||||
| "processing_started_at": document.processing_started_at, | |||||
| "parsing_completed_at": document.parsing_completed_at, | |||||
| "cleaning_completed_at": document.cleaning_completed_at, | |||||
| "splitting_completed_at": document.splitting_completed_at, | |||||
| "completed_at": document.completed_at, | |||||
| "paused_at": document.paused_at, | |||||
| "error": document.error, | |||||
| "stopped_at": document.stopped_at, | |||||
| "completed_segments": completed_segments, | |||||
| "total_segments": total_segments, | |||||
| } | |||||
| documents_status.append(marshal(document_dict, document_status_fields)) | |||||
| data = {"data": documents_status} | data = {"data": documents_status} | ||||
| return data | return data | ||||
| .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") | .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") | ||||
| .count() | .count() | ||||
| ) | ) | ||||
| document.completed_segments = completed_segments | |||||
| document.total_segments = total_segments | |||||
| if document.is_paused: | |||||
| document.indexing_status = "paused" | |||||
| documents_status.append(marshal(document, document_status_fields)) | |||||
| # Create a dictionary with document attributes and additional fields | |||||
| document_dict = { | |||||
| "id": document.id, | |||||
| "indexing_status": "paused" if document.is_paused else document.indexing_status, | |||||
| "processing_started_at": document.processing_started_at, | |||||
| "parsing_completed_at": document.parsing_completed_at, | |||||
| "cleaning_completed_at": document.cleaning_completed_at, | |||||
| "splitting_completed_at": document.splitting_completed_at, | |||||
| "completed_at": document.completed_at, | |||||
| "paused_at": document.paused_at, | |||||
| "error": document.error, | |||||
| "stopped_at": document.stopped_at, | |||||
| "completed_segments": completed_segments, | |||||
| "total_segments": total_segments, | |||||
| } | |||||
| documents_status.append(marshal(document_dict, document_status_fields)) | |||||
| data = {"data": documents_status} | data = {"data": documents_status} | ||||
| return data | return data | ||||
| .count() | .count() | ||||
| ) | ) | ||||
| document.completed_segments = completed_segments | |||||
| document.total_segments = total_segments | |||||
| if document.is_paused: | |||||
| document.indexing_status = "paused" | |||||
| return marshal(document, document_status_fields) | |||||
| # Create a dictionary with document attributes and additional fields | |||||
| document_dict = { | |||||
| "id": document.id, | |||||
| "indexing_status": "paused" if document.is_paused else document.indexing_status, | |||||
| "processing_started_at": document.processing_started_at, | |||||
| "parsing_completed_at": document.parsing_completed_at, | |||||
| "cleaning_completed_at": document.cleaning_completed_at, | |||||
| "splitting_completed_at": document.splitting_completed_at, | |||||
| "completed_at": document.completed_at, | |||||
| "paused_at": document.paused_at, | |||||
| "error": document.error, | |||||
| "stopped_at": document.stopped_at, | |||||
| "completed_segments": completed_segments, | |||||
| "total_segments": total_segments, | |||||
| } | |||||
| return marshal(document_dict, document_status_fields) | |||||
| class DocumentDetailApi(DocumentResource): | class DocumentDetailApi(DocumentResource): |
| @account_initialization_required | @account_initialization_required | ||||
| def get(self): | def get(self): | ||||
| tenants = TenantService.get_join_tenants(current_user) | tenants = TenantService.get_join_tenants(current_user) | ||||
| tenant_dicts = [] | |||||
| for tenant in tenants: | for tenant in tenants: | ||||
| features = FeatureService.get_features(tenant.id) | features = FeatureService.get_features(tenant.id) | ||||
| if features.billing.enabled: | |||||
| tenant.plan = features.billing.subscription.plan | |||||
| else: | |||||
| tenant.plan = "sandbox" | |||||
| if tenant.id == current_user.current_tenant_id: | |||||
| tenant.current = True # Set current=True for current tenant | |||||
| return {"workspaces": marshal(tenants, tenants_fields)}, 200 | |||||
| # Create a dictionary with tenant attributes | |||||
| tenant_dict = { | |||||
| "id": tenant.id, | |||||
| "name": tenant.name, | |||||
| "status": tenant.status, | |||||
| "created_at": tenant.created_at, | |||||
| "plan": features.billing.subscription.plan if features.billing.enabled else "sandbox", | |||||
| "current": tenant.id == current_user.current_tenant_id, | |||||
| } | |||||
| tenant_dicts.append(tenant_dict) | |||||
| return {"workspaces": marshal(tenant_dicts, tenants_fields)}, 200 | |||||
| class WorkspaceListApi(Resource): | class WorkspaceListApi(Resource): |
| extension = guess_extension(tool_file.mimetype) or ".bin" | extension = guess_extension(tool_file.mimetype) or ".bin" | ||||
| preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension) | preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension) | ||||
| tool_file.mime_type = mimetype | |||||
| tool_file.extension = extension | |||||
| tool_file.preview_url = preview_url | |||||
| # Create a dictionary with all the necessary attributes | |||||
| result = { | |||||
| "id": tool_file.id, | |||||
| "user_id": tool_file.user_id, | |||||
| "tenant_id": tool_file.tenant_id, | |||||
| "conversation_id": tool_file.conversation_id, | |||||
| "file_key": tool_file.file_key, | |||||
| "mimetype": tool_file.mimetype, | |||||
| "original_url": tool_file.original_url, | |||||
| "name": tool_file.name, | |||||
| "size": tool_file.size, | |||||
| "mime_type": mimetype, | |||||
| "extension": extension, | |||||
| "preview_url": preview_url, | |||||
| } | |||||
| return result, 201 | |||||
| except services.errors.file.FileTooLargeError as file_too_large_error: | except services.errors.file.FileTooLargeError as file_too_large_error: | ||||
| raise FileTooLargeError(file_too_large_error.description) | raise FileTooLargeError(file_too_large_error.description) | ||||
| except services.errors.file.UnsupportedFileTypeError: | except services.errors.file.UnsupportedFileTypeError: |
| .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") | .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") | ||||
| .count() | .count() | ||||
| ) | ) | ||||
| document.completed_segments = completed_segments | |||||
| document.total_segments = total_segments | |||||
| if document.is_paused: | |||||
| document.indexing_status = "paused" | |||||
| documents_status.append(marshal(document, document_status_fields)) | |||||
| # Create a dictionary with document attributes and additional fields | |||||
| document_dict = { | |||||
| "id": document.id, | |||||
| "indexing_status": "paused" if document.is_paused else document.indexing_status, | |||||
| "processing_started_at": document.processing_started_at, | |||||
| "parsing_completed_at": document.parsing_completed_at, | |||||
| "cleaning_completed_at": document.cleaning_completed_at, | |||||
| "splitting_completed_at": document.splitting_completed_at, | |||||
| "completed_at": document.completed_at, | |||||
| "paused_at": document.paused_at, | |||||
| "error": document.error, | |||||
| "stopped_at": document.stopped_at, | |||||
| "completed_segments": completed_segments, | |||||
| "total_segments": total_segments, | |||||
| } | |||||
| documents_status.append(marshal(document_dict, document_status_fields)) | |||||
| data = {"data": documents_status} | data = {"data": documents_status} | ||||
| return data | return data | ||||
| record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore | record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore | ||||
| record["score"] = segment_child_map[record["segment"].id]["max_score"] | record["score"] = segment_child_map[record["segment"].id]["max_score"] | ||||
| return [RetrievalSegments(**record) for record in records] | |||||
| result = [] | |||||
| for record in records: | |||||
| # Extract segment | |||||
| segment = record["segment"] | |||||
| # Extract child_chunks, ensuring it's a list or None | |||||
| child_chunks = record.get("child_chunks") | |||||
| if not isinstance(child_chunks, list): | |||||
| child_chunks = None | |||||
| # Extract score, ensuring it's a float or None | |||||
| score_value = record.get("score") | |||||
| score = ( | |||||
| float(score_value) | |||||
| if score_value is not None and isinstance(score_value, int | float | str) | |||||
| else None | |||||
| ) | |||||
| # Create RetrievalSegments object | |||||
| retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score) | |||||
| result.append(retrieval_segment) | |||||
| return result | |||||
| except Exception as e: | except Exception as e: | ||||
| db.session.rollback() | db.session.rollback() | ||||
| raise e | raise e |
| yield provider | yield provider | ||||
| except Exception: | except Exception: | ||||
| logger.exception(f"load builtin provider {provider}") | |||||
| logger.exception(f"load builtin provider {provider_path}") | |||||
| continue | continue | ||||
| # set builtin providers loaded | # set builtin providers loaded | ||||
| cls._builtin_providers_loaded = True | cls._builtin_providers_loaded = True | ||||
| ) | ) | ||||
| workflow_provider_controllers: list[WorkflowToolProviderController] = [] | workflow_provider_controllers: list[WorkflowToolProviderController] = [] | ||||
| for provider in workflow_providers: | |||||
| for workflow_provider in workflow_providers: | |||||
| try: | try: | ||||
| workflow_provider_controllers.append( | workflow_provider_controllers.append( | ||||
| ToolTransformService.workflow_provider_to_controller(db_provider=provider) | |||||
| ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) | |||||
| ) | ) | ||||
| except Exception: | except Exception: | ||||
| # app has been deleted | # app has been deleted |
| import json | import json | ||||
| import logging | import logging | ||||
| from collections.abc import Generator | from collections.abc import Generator | ||||
| from typing import Any, Optional, Union, cast | |||||
| from typing import Any, Optional, cast | |||||
| from flask_login import current_user | |||||
| from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod | from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod | ||||
| from core.tools.__base.tool import Tool | from core.tools.__base.tool import Tool | ||||
| result = generator.generate( | result = generator.generate( | ||||
| app_model=app, | app_model=app, | ||||
| workflow=workflow, | workflow=workflow, | ||||
| user=self._get_user(user_id), | |||||
| user=cast("Account | EndUser", current_user), | |||||
| args={"inputs": tool_parameters, "files": files}, | args={"inputs": tool_parameters, "files": files}, | ||||
| invoke_from=self.runtime.invoke_from, | invoke_from=self.runtime.invoke_from, | ||||
| streaming=False, | streaming=False, | ||||
| yield self.create_text_message(json.dumps(outputs, ensure_ascii=False)) | yield self.create_text_message(json.dumps(outputs, ensure_ascii=False)) | ||||
| yield self.create_json_message(outputs) | yield self.create_json_message(outputs) | ||||
| def _get_user(self, user_id: str) -> Union[EndUser, Account]: | |||||
| """ | |||||
| get the user by user id | |||||
| """ | |||||
| user = db.session.query(EndUser).filter(EndUser.id == user_id).first() | |||||
| if not user: | |||||
| user = db.session.query(Account).filter(Account.id == user_id).first() | |||||
| if not user: | |||||
| raise ValueError("user not found") | |||||
| return user | |||||
| def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool": | def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool": | ||||
| """ | """ | ||||
| fork a new tool with metadata | fork a new tool with metadata |
| import flask_login # type: ignore | import flask_login # type: ignore | ||||
| from flask import Response, request | from flask import Response, request | ||||
| from flask_login import user_loaded_from_request, user_logged_in | from flask_login import user_loaded_from_request, user_logged_in | ||||
| from werkzeug.exceptions import Unauthorized | |||||
| from werkzeug.exceptions import NotFound, Unauthorized | |||||
| import contexts | import contexts | ||||
| from dify_app import DifyApp | from dify_app import DifyApp | ||||
| from extensions.ext_database import db | |||||
| from libs.passport import PassportService | from libs.passport import PassportService | ||||
| from models.account import Account | |||||
| from models.model import EndUser | |||||
| from services.account_service import AccountService | from services.account_service import AccountService | ||||
| login_manager = flask_login.LoginManager() | login_manager = flask_login.LoginManager() | ||||
| @login_manager.request_loader | @login_manager.request_loader | ||||
| def load_user_from_request(request_from_flask_login): | def load_user_from_request(request_from_flask_login): | ||||
| """Load user based on the request.""" | """Load user based on the request.""" | ||||
| if request.blueprint not in {"console", "inner_api"}: | |||||
| return None | |||||
| # Check if the user_id contains a dot, indicating the old format | |||||
| auth_header = request.headers.get("Authorization", "") | auth_header = request.headers.get("Authorization", "") | ||||
| if not auth_header: | |||||
| auth_token = request.args.get("_token") | |||||
| if not auth_token: | |||||
| raise Unauthorized("Invalid Authorization token.") | |||||
| else: | |||||
| auth_token: str | None = None | |||||
| if auth_header: | |||||
| if " " not in auth_header: | if " " not in auth_header: | ||||
| raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | ||||
| auth_scheme, auth_token = auth_header.split(None, 1) | |||||
| auth_scheme, auth_token = auth_header.split(maxsplit=1) | |||||
| auth_scheme = auth_scheme.lower() | auth_scheme = auth_scheme.lower() | ||||
| if auth_scheme != "bearer": | if auth_scheme != "bearer": | ||||
| raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | ||||
| else: | |||||
| auth_token = request.args.get("_token") | |||||
| decoded = PassportService().verify(auth_token) | |||||
| user_id = decoded.get("user_id") | |||||
| if request.blueprint in {"console", "inner_api"}: | |||||
| if not auth_token: | |||||
| raise Unauthorized("Invalid Authorization token.") | |||||
| decoded = PassportService().verify(auth_token) | |||||
| user_id = decoded.get("user_id") | |||||
| if not user_id: | |||||
| raise Unauthorized("Invalid Authorization token.") | |||||
| logged_in_account = AccountService.load_logged_in_account(account_id=user_id) | |||||
| return logged_in_account | |||||
| logged_in_account = AccountService.load_logged_in_account(account_id=user_id) | |||||
| return logged_in_account | |||||
| elif request.blueprint == "web": | |||||
| decoded = PassportService().verify(auth_token) | |||||
| end_user_id = decoded.get("end_user_id") | |||||
| if not end_user_id: | |||||
| raise Unauthorized("Invalid Authorization token.") | |||||
| end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first() | |||||
| if not end_user: | |||||
| raise NotFound("End user not found.") | |||||
| return end_user | |||||
| @user_logged_in.connect | @user_logged_in.connect | ||||
| @user_loaded_from_request.connect | @user_loaded_from_request.connect | ||||
| def on_user_logged_in(_sender, user): | def on_user_logged_in(_sender, user): | ||||
| """Called when a user logged in.""" | |||||
| if user: | |||||
| """Called when a user logged in. | |||||
| Note: AccountService.load_logged_in_account will populate user.current_tenant_id | |||||
| through the load_user method, which calls account.set_tenant_id(). | |||||
| """ | |||||
| if user and isinstance(user, Account) and user.current_tenant_id: | |||||
| contexts.tenant_id.set(user.current_tenant_id) | contexts.tenant_id.set(user.current_tenant_id) | ||||
| import enum | import enum | ||||
| import json | import json | ||||
| from typing import cast | |||||
| from typing import Optional, cast | |||||
| from flask_login import UserMixin # type: ignore | from flask_login import UserMixin # type: ignore | ||||
| from sqlalchemy import func | from sqlalchemy import func | ||||
| from sqlalchemy.orm import Mapped, mapped_column | |||||
| from sqlalchemy.orm import Mapped, mapped_column, reconstructor | |||||
| from models.base import Base | from models.base import Base | ||||
| from .types import StringUUID | from .types import StringUUID | ||||
| class TenantAccountRole(enum.StrEnum): | |||||
| OWNER = "owner" | |||||
| ADMIN = "admin" | |||||
| EDITOR = "editor" | |||||
| NORMAL = "normal" | |||||
| DATASET_OPERATOR = "dataset_operator" | |||||
| @staticmethod | |||||
| def is_valid_role(role: str) -> bool: | |||||
| if not role: | |||||
| return False | |||||
| return role in { | |||||
| TenantAccountRole.OWNER, | |||||
| TenantAccountRole.ADMIN, | |||||
| TenantAccountRole.EDITOR, | |||||
| TenantAccountRole.NORMAL, | |||||
| TenantAccountRole.DATASET_OPERATOR, | |||||
| } | |||||
| @staticmethod | |||||
| def is_privileged_role(role: Optional["TenantAccountRole"]) -> bool: | |||||
| if not role: | |||||
| return False | |||||
| return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} | |||||
| @staticmethod | |||||
| def is_admin_role(role: Optional["TenantAccountRole"]) -> bool: | |||||
| if not role: | |||||
| return False | |||||
| return role == TenantAccountRole.ADMIN | |||||
| @staticmethod | |||||
| def is_non_owner_role(role: Optional["TenantAccountRole"]) -> bool: | |||||
| if not role: | |||||
| return False | |||||
| return role in { | |||||
| TenantAccountRole.ADMIN, | |||||
| TenantAccountRole.EDITOR, | |||||
| TenantAccountRole.NORMAL, | |||||
| TenantAccountRole.DATASET_OPERATOR, | |||||
| } | |||||
| @staticmethod | |||||
| def is_editing_role(role: Optional["TenantAccountRole"]) -> bool: | |||||
| if not role: | |||||
| return False | |||||
| return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR} | |||||
| @staticmethod | |||||
| def is_dataset_edit_role(role: Optional["TenantAccountRole"]) -> bool: | |||||
| if not role: | |||||
| return False | |||||
| return role in { | |||||
| TenantAccountRole.OWNER, | |||||
| TenantAccountRole.ADMIN, | |||||
| TenantAccountRole.EDITOR, | |||||
| TenantAccountRole.DATASET_OPERATOR, | |||||
| } | |||||
| class AccountStatus(enum.StrEnum): | class AccountStatus(enum.StrEnum): | ||||
| PENDING = "pending" | PENDING = "pending" | ||||
| UNINITIALIZED = "uninitialized" | UNINITIALIZED = "uninitialized" | ||||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| @reconstructor | |||||
| def init_on_load(self): | |||||
| self.role: Optional[TenantAccountRole] = None | |||||
| self._current_tenant: Optional[Tenant] = None | |||||
| @property | @property | ||||
| def is_password_set(self): | def is_password_set(self): | ||||
| return self.password is not None | return self.password is not None | ||||
| @property | @property | ||||
| def current_tenant(self): | def current_tenant(self): | ||||
| return self._current_tenant # type: ignore | |||||
| return self._current_tenant | |||||
| @current_tenant.setter | @current_tenant.setter | ||||
| def current_tenant(self, value: "Tenant"): | |||||
| tenant = value | |||||
| def current_tenant(self, tenant: "Tenant"): | |||||
| ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first() | ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first() | ||||
| if ta: | if ta: | ||||
| tenant.current_role = ta.role | |||||
| else: | |||||
| tenant = None # type: ignore | |||||
| self._current_tenant = tenant | |||||
| self.role = TenantAccountRole(ta.role) | |||||
| self._current_tenant = tenant | |||||
| return | |||||
| self._current_tenant = None | |||||
| @property | @property | ||||
| def current_tenant_id(self) -> str | None: | def current_tenant_id(self) -> str | None: | ||||
| return | return | ||||
| tenant, join = tenant_account_join | tenant, join = tenant_account_join | ||||
| tenant.current_role = join.role | |||||
| self.role = join.role | |||||
| self._current_tenant = tenant | self._current_tenant = tenant | ||||
| @property | @property | ||||
| def current_role(self): | def current_role(self): | ||||
| return self._current_tenant.current_role | |||||
| return self.role | |||||
| def get_status(self) -> AccountStatus: | def get_status(self) -> AccountStatus: | ||||
| status_str = self.status | status_str = self.status | ||||
| # check current_user.current_tenant.current_role in ['admin', 'owner'] | # check current_user.current_tenant.current_role in ['admin', 'owner'] | ||||
| @property | @property | ||||
| def is_admin_or_owner(self): | def is_admin_or_owner(self): | ||||
| return TenantAccountRole.is_privileged_role(self._current_tenant.current_role) | |||||
| return TenantAccountRole.is_privileged_role(self.role) | |||||
| @property | @property | ||||
| def is_admin(self): | def is_admin(self): | ||||
| return TenantAccountRole.is_admin_role(self._current_tenant.current_role) | |||||
| return TenantAccountRole.is_admin_role(self.role) | |||||
| @property | @property | ||||
| def is_editor(self): | def is_editor(self): | ||||
| return TenantAccountRole.is_editing_role(self._current_tenant.current_role) | |||||
| return TenantAccountRole.is_editing_role(self.role) | |||||
| @property | @property | ||||
| def is_dataset_editor(self): | def is_dataset_editor(self): | ||||
| return TenantAccountRole.is_dataset_edit_role(self._current_tenant.current_role) | |||||
| return TenantAccountRole.is_dataset_edit_role(self.role) | |||||
| @property | @property | ||||
| def is_dataset_operator(self): | def is_dataset_operator(self): | ||||
| return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR | |||||
| return self.role == TenantAccountRole.DATASET_OPERATOR | |||||
| class TenantStatus(enum.StrEnum): | class TenantStatus(enum.StrEnum): | ||||
| ARCHIVE = "archive" | ARCHIVE = "archive" | ||||
| class TenantAccountRole(enum.StrEnum): | |||||
| OWNER = "owner" | |||||
| ADMIN = "admin" | |||||
| EDITOR = "editor" | |||||
| NORMAL = "normal" | |||||
| DATASET_OPERATOR = "dataset_operator" | |||||
| @staticmethod | |||||
| def is_valid_role(role: str) -> bool: | |||||
| if not role: | |||||
| return False | |||||
| return role in { | |||||
| TenantAccountRole.OWNER, | |||||
| TenantAccountRole.ADMIN, | |||||
| TenantAccountRole.EDITOR, | |||||
| TenantAccountRole.NORMAL, | |||||
| TenantAccountRole.DATASET_OPERATOR, | |||||
| } | |||||
| @staticmethod | |||||
| def is_privileged_role(role: str) -> bool: | |||||
| if not role: | |||||
| return False | |||||
| return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} | |||||
| @staticmethod | |||||
| def is_admin_role(role: str) -> bool: | |||||
| if not role: | |||||
| return False | |||||
| return role == TenantAccountRole.ADMIN | |||||
| @staticmethod | |||||
| def is_non_owner_role(role: str) -> bool: | |||||
| if not role: | |||||
| return False | |||||
| return role in { | |||||
| TenantAccountRole.ADMIN, | |||||
| TenantAccountRole.EDITOR, | |||||
| TenantAccountRole.NORMAL, | |||||
| TenantAccountRole.DATASET_OPERATOR, | |||||
| } | |||||
| @staticmethod | |||||
| def is_editing_role(role: str) -> bool: | |||||
| if not role: | |||||
| return False | |||||
| return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR} | |||||
| @staticmethod | |||||
| def is_dataset_edit_role(role: str) -> bool: | |||||
| if not role: | |||||
| return False | |||||
| return role in { | |||||
| TenantAccountRole.OWNER, | |||||
| TenantAccountRole.ADMIN, | |||||
| TenantAccountRole.EDITOR, | |||||
| TenantAccountRole.DATASET_OPERATOR, | |||||
| } | |||||
| class Tenant(Base): | class Tenant(Base): | ||||
| __tablename__ = "tenants" | __tablename__ = "tenants" | ||||
| __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) | __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) |
| from sqlalchemy.orm import declarative_base | |||||
| from sqlalchemy.orm import DeclarativeBase | |||||
| from models.engine import metadata | from models.engine import metadata | ||||
| Base = declarative_base(metadata=metadata) | |||||
| class Base(DeclarativeBase): | |||||
| metadata = metadata |
| db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") | db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") | ||||
| ) | ) | ||||
| @property | |||||
| def schema_type(self) -> ApiProviderSchemaType: | |||||
| return ApiProviderSchemaType.value_of(self.schema_type_str) | |||||
| @property | @property | ||||
| def user(self) -> Account | None: | def user(self) -> Account | None: | ||||
| return db.session.query(Account).filter(Account.id == self.user_id).first() | return db.session.query(Account).filter(Account.id == self.user_id).first() |
| from collections.abc import Mapping, Sequence | from collections.abc import Mapping, Sequence | ||||
| from datetime import UTC, datetime | from datetime import UTC, datetime | ||||
| from enum import Enum, StrEnum | from enum import Enum, StrEnum | ||||
| from typing import TYPE_CHECKING, Any, Optional, Self, Union | |||||
| from typing import TYPE_CHECKING, Any, Optional, Union | |||||
| from uuid import uuid4 | from uuid import uuid4 | ||||
| from core.variables import utils as variable_utils | from core.variables import utils as variable_utils | ||||
| conversation_variables: Sequence[Variable], | conversation_variables: Sequence[Variable], | ||||
| marked_name: str = "", | marked_name: str = "", | ||||
| marked_comment: str = "", | marked_comment: str = "", | ||||
| ) -> Self: | |||||
| ) -> "Workflow": | |||||
| workflow = Workflow() | workflow = Workflow() | ||||
| workflow.id = str(uuid4()) | workflow.id = str(uuid4()) | ||||
| workflow.tenant_id = tenant_id | workflow.tenant_id = tenant_id |
| ): | ): | ||||
| documents: list[Document] = [] | documents: list[Document] = [] | ||||
| document: Document | None = None | |||||
| for segment in segments: | for segment in segments: | ||||
| if doc_form == IndexType.PARENT_CHILD_INDEX: | if doc_form == IndexType.PARENT_CHILD_INDEX: | ||||
| document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first() | |||||
| if not document: | |||||
| dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first() | |||||
| if not dataset_document: | |||||
| _logger.warning( | _logger.warning( | ||||
| "Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s", | "Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s", | ||||
| segment.document_id, | segment.document_id, | ||||
| # get the process rule | # get the process rule | ||||
| processing_rule = ( | processing_rule = ( | ||||
| db.session.query(DatasetProcessRule) | db.session.query(DatasetProcessRule) | ||||
| .filter(DatasetProcessRule.id == document.dataset_process_rule_id) | |||||
| .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) | |||||
| .first() | .first() | ||||
| ) | ) | ||||
| if not processing_rule: | if not processing_rule: | ||||
| ) | ) | ||||
| else: | else: | ||||
| raise ValueError("The knowledge base index technique is not high quality!") | raise ValueError("The knowledge base index technique is not high quality!") | ||||
| cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False) | |||||
| cls.generate_child_chunks( | |||||
| segment, dataset_document, dataset, embedding_model_instance, processing_rule, False | |||||
| ) | |||||
| else: | else: | ||||
| document = Document( | |||||
| rag_document = Document( | |||||
| page_content=segment.content, | page_content=segment.content, | ||||
| metadata={ | metadata={ | ||||
| "doc_id": segment.index_node_id, | "doc_id": segment.index_node_id, | ||||
| "dataset_id": segment.dataset_id, | "dataset_id": segment.dataset_id, | ||||
| }, | }, | ||||
| ) | ) | ||||
| documents.append(document) | |||||
| documents.append(rag_document) | |||||
| if len(documents) > 0: | if len(documents) > 0: | ||||
| index_processor = IndexProcessorFactory(doc_form).init_index_processor() | index_processor = IndexProcessorFactory(doc_form).init_index_processor() | ||||
| index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list) | index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list) |
| raise DraftWorkflowDeletionError("Cannot delete draft workflow versions") | raise DraftWorkflowDeletionError("Cannot delete draft workflow versions") | ||||
| # Check if this workflow is currently referenced by an app | # Check if this workflow is currently referenced by an app | ||||
| stmt = select(App).where(App.workflow_id == workflow_id) | |||||
| app = session.scalar(stmt) | |||||
| app_stmt = select(App).where(App.workflow_id == workflow_id) | |||||
| app = session.scalar(app_stmt) | |||||
| if app: | if app: | ||||
| # Cannot delete a workflow that's currently in use by an app | # Cannot delete a workflow that's currently in use by an app | ||||
| raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'") | |||||
| raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.id}'") | |||||
| # Don't use workflow.tool_published as it's not accurate for specific workflow versions | # Don't use workflow.tool_published as it's not accurate for specific workflow versions | ||||
| # Check if there's a tool provider using this specific workflow version | # Check if there's a tool provider using this specific workflow version |
| logging.exception("add document to index failed") | logging.exception("add document to index failed") | ||||
| dataset_document.enabled = False | dataset_document.enabled = False | ||||
| dataset_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | dataset_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | ||||
| dataset_document.status = "error" | |||||
| dataset_document.indexing_status = "error" | |||||
| dataset_document.error = str(e) | dataset_document.error = str(e) | ||||
| db.session.commit() | db.session.commit() | ||||
| finally: | finally: |
| def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): | def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): | ||||
| # Get app's owner | # Get app's owner | ||||
| with Session(db.engine, expire_on_commit=False) as session: | with Session(db.engine, expire_on_commit=False) as session: | ||||
| stmt = select(Account).where(Account.id == App.owner_id).where(App.id == app_id) | |||||
| stmt = select(Account).where(Account.id == App.created_by).where(App.id == app_id) | |||||
| user = session.scalar(stmt) | user = session.scalar(stmt) | ||||
| if user is None: | if user is None: |
| # needs to patch those methods to avoid database access. | # needs to patch those methods to avoid database access. | ||||
| monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) | monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) | ||||
| monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) | monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) | ||||
| monkeypatch.setattr(tool, "_get_user", lambda *args, **kwargs: None) | |||||
| # replace `WorkflowAppGenerator.generate` 's return value. | # replace `WorkflowAppGenerator.generate` 's return value. | ||||
| monkeypatch.setattr( | monkeypatch.setattr( | ||||
| "core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", | "core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", | ||||
| lambda *args, **kwargs: {"data": {"error": "oops"}}, | lambda *args, **kwargs: {"data": {"error": "oops"}}, | ||||
| ) | ) | ||||
| monkeypatch.setattr("flask_login.current_user", lambda *args, **kwargs: None) | |||||
| with pytest.raises(ToolInvokeError) as exc_info: | with pytest.raises(ToolInvokeError) as exc_info: | ||||
| # WorkflowTool always returns a generator, so we need to iterate to | # WorkflowTool always returns a generator, so we need to iterate to |