Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>tags/1.9.0
| import services | import services | ||||
| from configs import dify_config | from configs import dify_config | ||||
| from controllers.console import api | from controllers.console import api | ||||
| from controllers.console.app.error import ( | |||||
| ConversationCompletedError, | |||||
| DraftWorkflowNotExist, | |||||
| DraftWorkflowNotSync, | |||||
| ) | |||||
| from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync | |||||
| from controllers.console.app.wraps import get_app_model | from controllers.console.app.wraps import get_app_model | ||||
| from controllers.console.wraps import account_initialization_required, setup_required | from controllers.console.wraps import account_initialization_required, setup_required | ||||
| from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError | from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError |
| ProviderType, | ProviderType, | ||||
| TenantPreferredModelProvider, | TenantPreferredModelProvider, | ||||
| ) | ) | ||||
| from services.enterprise.plugin_manager_service import PluginCredentialType | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| return copy_credentials | return copy_credentials | ||||
| else: | else: | ||||
| credentials = None | credentials = None | ||||
| current_credential_id = None | |||||
| if self.custom_configuration.models: | if self.custom_configuration.models: | ||||
| for model_configuration in self.custom_configuration.models: | for model_configuration in self.custom_configuration.models: | ||||
| if model_configuration.model_type == model_type and model_configuration.model == model: | if model_configuration.model_type == model_type and model_configuration.model == model: | ||||
| credentials = model_configuration.credentials | credentials = model_configuration.credentials | ||||
| current_credential_id = model_configuration.current_credential_id | |||||
| break | break | ||||
| if not credentials and self.custom_configuration.provider: | if not credentials and self.custom_configuration.provider: | ||||
| credentials = self.custom_configuration.provider.credentials | credentials = self.custom_configuration.provider.credentials | ||||
| current_credential_id = self.custom_configuration.provider.current_credential_id | |||||
| if current_credential_id: | |||||
| from core.helper.credential_utils import check_credential_policy_compliance | |||||
| check_credential_policy_compliance( | |||||
| credential_id=current_credential_id, | |||||
| provider=self.provider.provider, | |||||
| credential_type=PluginCredentialType.MODEL, | |||||
| ) | |||||
| else: | |||||
| # no current credential id, check all available credentials | |||||
| if self.custom_configuration.provider: | |||||
| for credential_configuration in self.custom_configuration.provider.available_credentials: | |||||
| from core.helper.credential_utils import check_credential_policy_compliance | |||||
| check_credential_policy_compliance( | |||||
| credential_id=credential_configuration.credential_id, | |||||
| provider=self.provider.provider, | |||||
| credential_type=PluginCredentialType.MODEL, | |||||
| ) | |||||
| return credentials | return credentials | ||||
| :param credential_id: if provided, return the specified credential | :param credential_id: if provided, return the specified credential | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| if credential_id: | if credential_id: | ||||
| return self._get_specific_provider_credential(credential_id) | return self._get_specific_provider_credential(credential_id) | ||||
| current_credential_id = credential_record.id | current_credential_id = credential_record.id | ||||
| current_credential_name = credential_record.credential_name | current_credential_name = credential_record.credential_name | ||||
| credentials = self.obfuscated_credentials( | credentials = self.obfuscated_credentials( | ||||
| credentials=credentials, | credentials=credentials, | ||||
| credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas | credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas | ||||
| ): | ): | ||||
| current_credential_id = model_configuration.current_credential_id | current_credential_id = model_configuration.current_credential_id | ||||
| current_credential_name = model_configuration.current_credential_name | current_credential_name = model_configuration.current_credential_name | ||||
| credentials = self.obfuscated_credentials( | credentials = self.obfuscated_credentials( | ||||
| credentials=model_configuration.credentials, | credentials=model_configuration.credentials, | ||||
| credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas | credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas |
| name: str | name: str | ||||
| credentials: dict | credentials: dict | ||||
| credential_source_type: str | None = None | credential_source_type: str | None = None | ||||
| credential_id: str | None = None | |||||
| class ModelSettings(BaseModel): | class ModelSettings(BaseModel): |
| """ | |||||
| Credential utility functions for checking credential existence and policy compliance. | |||||
| """ | |||||
| from services.enterprise.plugin_manager_service import PluginCredentialType | |||||
| def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool: | |||||
| """ | |||||
| Check if the credential still exists in the database. | |||||
| :param credential_id: The credential ID to check | |||||
| :param credential_type: The type of credential (MODEL or TOOL) | |||||
| :return: True if credential exists, False otherwise | |||||
| """ | |||||
| from sqlalchemy import select | |||||
| from sqlalchemy.orm import Session | |||||
| from extensions.ext_database import db | |||||
| from models.provider import ProviderCredential, ProviderModelCredential | |||||
| from models.tools import BuiltinToolProvider | |||||
| with Session(db.engine) as session: | |||||
| if credential_type == PluginCredentialType.MODEL: | |||||
| # Check both pre-defined and custom model credentials using a single UNION query | |||||
| stmt = ( | |||||
| select(ProviderCredential.id) | |||||
| .where(ProviderCredential.id == credential_id) | |||||
| .union(select(ProviderModelCredential.id).where(ProviderModelCredential.id == credential_id)) | |||||
| ) | |||||
| return session.scalar(stmt) is not None | |||||
| if credential_type == PluginCredentialType.TOOL: | |||||
| return ( | |||||
| session.scalar(select(BuiltinToolProvider.id).where(BuiltinToolProvider.id == credential_id)) | |||||
| is not None | |||||
| ) | |||||
| return False | |||||
| def check_credential_policy_compliance( | |||||
| credential_id: str, provider: str, credential_type: "PluginCredentialType", check_existence: bool = True | |||||
| ) -> None: | |||||
| """ | |||||
| Check credential policy compliance for the given credential ID. | |||||
| :param credential_id: The credential ID to check | |||||
| :param provider: The provider name | |||||
| :param credential_type: The type of credential (MODEL or TOOL) | |||||
| :param check_existence: Whether to check if credential exists in database first | |||||
| :raises ValueError: If credential policy compliance check fails | |||||
| """ | |||||
| from services.enterprise.plugin_manager_service import ( | |||||
| CheckCredentialPolicyComplianceRequest, | |||||
| PluginManagerService, | |||||
| ) | |||||
| from services.feature_service import FeatureService | |||||
| if not FeatureService.get_system_features().plugin_manager.enabled or not credential_id: | |||||
| return | |||||
| # Check if credential exists in database first (if requested) | |||||
| if check_existence: | |||||
| if not is_credential_exists(credential_id, credential_type): | |||||
| raise ValueError(f"Credential with id {credential_id} for provider {provider} not found.") | |||||
| # Check policy compliance | |||||
| PluginManagerService.check_credential_policy_compliance( | |||||
| CheckCredentialPolicyComplianceRequest( | |||||
| dify_credential_id=credential_id, | |||||
| provider=provider, | |||||
| credential_type=credential_type, | |||||
| ) | |||||
| ) |
| from core.provider_manager import ProviderManager | from core.provider_manager import ProviderManager | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from models.provider import ProviderType | from models.provider import ProviderType | ||||
| from services.enterprise.plugin_manager_service import PluginCredentialType | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| else: | else: | ||||
| raise last_exception | raise last_exception | ||||
| # Additional policy compliance check as fallback (in case fetch_next didn't catch it) | |||||
| try: | |||||
| from core.helper.credential_utils import check_credential_policy_compliance | |||||
| if lb_config.credential_id: | |||||
| check_credential_policy_compliance( | |||||
| credential_id=lb_config.credential_id, | |||||
| provider=self.provider, | |||||
| credential_type=PluginCredentialType.MODEL, | |||||
| ) | |||||
| except Exception as e: | |||||
| logger.warning( | |||||
| "Load balancing config %s failed policy compliance check in round-robin: %s", lb_config.id, str(e) | |||||
| ) | |||||
| self.load_balancing_manager.cooldown(lb_config, expire=60) | |||||
| continue | |||||
| try: | try: | ||||
| if "credentials" in kwargs: | if "credentials" in kwargs: | ||||
| del kwargs["credentials"] | del kwargs["credentials"] | ||||
| continue | continue | ||||
| # Check policy compliance for the selected configuration | |||||
| try: | |||||
| from core.helper.credential_utils import check_credential_policy_compliance | |||||
| if config.credential_id: | |||||
| check_credential_policy_compliance( | |||||
| credential_id=config.credential_id, | |||||
| provider=self._provider, | |||||
| credential_type=PluginCredentialType.MODEL, | |||||
| ) | |||||
| except Exception as e: | |||||
| logger.warning("Load balancing config %s failed policy compliance check: %s", config.id, str(e)) | |||||
| cooldown_load_balancing_configs.append(config) | |||||
| if len(cooldown_load_balancing_configs) >= len(self._load_balancing_configs): | |||||
| # all configs are in cooldown or failed policy compliance | |||||
| return None | |||||
| continue | |||||
| if dify_config.DEBUG: | if dify_config.DEBUG: | ||||
| logger.info( | logger.info( | ||||
| """Model LB | """Model LB |
| name=load_balancing_model_config.name, | name=load_balancing_model_config.name, | ||||
| credentials=provider_model_credentials, | credentials=provider_model_credentials, | ||||
| credential_source_type=load_balancing_model_config.credential_source_type, | credential_source_type=load_balancing_model_config.credential_source_type, | ||||
| credential_id=load_balancing_model_config.credential_id, | |||||
| ) | ) | ||||
| ) | ) | ||||
| pass | pass | ||||
| class ToolCredentialPolicyViolationError(ValueError): | |||||
| pass | |||||
| class ToolEngineInvokeError(Exception): | class ToolEngineInvokeError(Exception): | ||||
| meta: ToolInvokeMeta | meta: ToolInvokeMeta | ||||
| from core.tools.utils.uuid_utils import is_valid_uuid | from core.tools.utils.uuid_utils import is_valid_uuid | ||||
| from core.tools.workflow_as_tool.provider import WorkflowToolProviderController | from core.tools.workflow_as_tool.provider import WorkflowToolProviderController | ||||
| from core.workflow.entities.variable_pool import VariablePool | from core.workflow.entities.variable_pool import VariablePool | ||||
| from services.enterprise.plugin_manager_service import PluginCredentialType | |||||
| from services.tools.mcp_tools_manage_service import MCPToolManageService | from services.tools.mcp_tools_manage_service import MCPToolManageService | ||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| ) | ) | ||||
| from core.tools.errors import ToolProviderNotFoundError | from core.tools.errors import ToolProviderNotFoundError | ||||
| from core.tools.tool_label_manager import ToolLabelManager | from core.tools.tool_label_manager import ToolLabelManager | ||||
| from core.tools.utils.configuration import ( | |||||
| ToolParameterConfigurationManager, | |||||
| ) | |||||
| from core.tools.utils.configuration import ToolParameterConfigurationManager | |||||
| from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter | from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter | ||||
| from core.tools.workflow_as_tool.tool import WorkflowTool | from core.tools.workflow_as_tool.tool import WorkflowTool | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| if builtin_provider is None: | if builtin_provider is None: | ||||
| raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") | raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") | ||||
| # check if the credential is allowed to be used | |||||
| from core.helper.credential_utils import check_credential_policy_compliance | |||||
| check_credential_policy_compliance( | |||||
| credential_id=builtin_provider.id, | |||||
| provider=provider_id, | |||||
| credential_type=PluginCredentialType.TOOL, | |||||
| check_existence=False, | |||||
| ) | |||||
| encrypter, cache = create_provider_encrypter( | encrypter, cache = create_provider_encrypter( | ||||
| tenant_id=tenant_id, | tenant_id=tenant_id, | ||||
| config=[ | config=[ |
| import requests | import requests | ||||
| class EnterpriseRequest: | |||||
| base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") | |||||
| secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") | |||||
| class BaseRequest: | |||||
| proxies = { | proxies = { | ||||
| "http": "", | "http": "", | ||||
| "https": "", | "https": "", | ||||
| } | } | ||||
| base_url = "" | |||||
| secret_key = "" | |||||
| secret_key_header = "" | |||||
| @classmethod | @classmethod | ||||
| def send_request(cls, method, endpoint, json=None, params=None): | def send_request(cls, method, endpoint, json=None, params=None): | ||||
| headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key} | |||||
| headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key} | |||||
| url = f"{cls.base_url}{endpoint}" | url = f"{cls.base_url}{endpoint}" | ||||
| response = requests.request(method, url, json=json, params=params, headers=headers, proxies=cls.proxies) | response = requests.request(method, url, json=json, params=params, headers=headers, proxies=cls.proxies) | ||||
| return response.json() | return response.json() | ||||
| class EnterpriseRequest(BaseRequest): | |||||
| base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") | |||||
| secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") | |||||
| secret_key_header = "Enterprise-Api-Secret-Key" | |||||
| class EnterprisePluginManagerRequest(BaseRequest): | |||||
| base_url = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_URL", "ENTERPRISE_PLUGIN_MANAGER_API_URL") | |||||
| secret_key = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY", "ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY") | |||||
| secret_key_header = "Plugin-Manager-Inner-Api-Secret-Key" |
| import enum | |||||
| import logging | |||||
| from pydantic import BaseModel | |||||
| from services.enterprise.base import EnterprisePluginManagerRequest | |||||
| from services.errors.base import BaseServiceError | |||||
| class PluginCredentialType(enum.Enum): | |||||
| MODEL = 0 | |||||
| TOOL = 1 | |||||
| def to_number(self): | |||||
| return self.value | |||||
| class CheckCredentialPolicyComplianceRequest(BaseModel): | |||||
| dify_credential_id: str | |||||
| provider: str | |||||
| credential_type: PluginCredentialType | |||||
| def model_dump(self, **kwargs): | |||||
| data = super().model_dump(**kwargs) | |||||
| data["credential_type"] = self.credential_type.to_number() | |||||
| return data | |||||
| class CredentialPolicyViolationError(BaseServiceError): | |||||
| pass | |||||
| class PluginManagerService: | |||||
| @classmethod | |||||
| def check_credential_policy_compliance(cls, body: CheckCredentialPolicyComplianceRequest): | |||||
| try: | |||||
| ret = EnterprisePluginManagerRequest.send_request( | |||||
| "POST", "/check-credential-policy-compliance", json=body.model_dump() | |||||
| ) | |||||
| if not isinstance(ret, dict) or "result" not in ret: | |||||
| raise ValueError("Invalid response format from plugin manager API") | |||||
| except Exception as e: | |||||
| raise CredentialPolicyViolationError( | |||||
| f"error occurred while checking credential policy compliance: {e}" | |||||
| ) from e | |||||
| if not ret.get("result", False): | |||||
| raise CredentialPolicyViolationError("Credentials not available: Please use ENTERPRISE global credentials") | |||||
| logging.debug( | |||||
| f"Credential policy compliance checked for {body.provider} with credential {body.dify_credential_id}, result: {ret.get('result', False)}" | |||||
| ) |
| subscription_plan: str = "" | subscription_plan: str = "" | ||||
| class PluginManagerModel(BaseModel): | |||||
| enabled: bool = False | |||||
| class SystemFeatureModel(BaseModel): | class SystemFeatureModel(BaseModel): | ||||
| sso_enforced_for_signin: bool = False | sso_enforced_for_signin: bool = False | ||||
| sso_enforced_for_signin_protocol: str = "" | sso_enforced_for_signin_protocol: str = "" | ||||
| webapp_auth: WebAppAuthModel = WebAppAuthModel() | webapp_auth: WebAppAuthModel = WebAppAuthModel() | ||||
| plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel() | plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel() | ||||
| enable_change_email: bool = True | enable_change_email: bool = True | ||||
| plugin_manager: PluginManagerModel = PluginManagerModel() | |||||
| class FeatureService: | class FeatureService: | ||||
| system_features.branding.enabled = True | system_features.branding.enabled = True | ||||
| system_features.webapp_auth.enabled = True | system_features.webapp_auth.enabled = True | ||||
| system_features.enable_change_email = False | system_features.enable_change_email = False | ||||
| system_features.plugin_manager.enabled = True | |||||
| cls._fulfill_params_from_enterprise(system_features) | cls._fulfill_params_from_enterprise(system_features) | ||||
| if dify_config.MARKETPLACE_ENABLED: | if dify_config.MARKETPLACE_ENABLED: |
| from models.account import Account | from models.account import Account | ||||
| from models.model import App, AppMode | from models.model import App, AppMode | ||||
| from models.tools import WorkflowToolProvider | from models.tools import WorkflowToolProvider | ||||
| from models.workflow import ( | |||||
| Workflow, | |||||
| WorkflowNodeExecutionModel, | |||||
| WorkflowNodeExecutionTriggeredFrom, | |||||
| WorkflowType, | |||||
| ) | |||||
| from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType | |||||
| from repositories.factory import DifyAPIRepositoryFactory | from repositories.factory import DifyAPIRepositoryFactory | ||||
| from services.enterprise.plugin_manager_service import PluginCredentialType | |||||
| from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError | from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError | ||||
| from services.workflow.workflow_converter import WorkflowConverter | from services.workflow.workflow_converter import WorkflowConverter | ||||
| from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError | from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError | ||||
| from .workflow_draft_variable_service import ( | |||||
| DraftVariableSaver, | |||||
| DraftVarLoader, | |||||
| WorkflowDraftVariableService, | |||||
| ) | |||||
| from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService | |||||
| class WorkflowService: | class WorkflowService: | ||||
| if not draft_workflow: | if not draft_workflow: | ||||
| raise ValueError("No valid workflow found.") | raise ValueError("No valid workflow found.") | ||||
| # Validate credentials before publishing, for credential policy check | |||||
| from services.feature_service import FeatureService | |||||
| if FeatureService.get_system_features().plugin_manager.enabled: | |||||
| self._validate_workflow_credentials(draft_workflow) | |||||
| # create new workflow | # create new workflow | ||||
| workflow = Workflow.new( | workflow = Workflow.new( | ||||
| tenant_id=app_model.tenant_id, | tenant_id=app_model.tenant_id, | ||||
| # return new workflow | # return new workflow | ||||
| return workflow | return workflow | ||||
| def _validate_workflow_credentials(self, workflow: Workflow) -> None: | |||||
| """ | |||||
| Validate all credentials in workflow nodes before publishing. | |||||
| :param workflow: The workflow to validate | |||||
| :raises ValueError: If any credentials violate policy compliance | |||||
| """ | |||||
| graph_dict = workflow.graph_dict | |||||
| nodes = graph_dict.get("nodes", []) | |||||
| for node in nodes: | |||||
| node_data = node.get("data", {}) | |||||
| node_type = node_data.get("type") | |||||
| node_id = node.get("id", "unknown") | |||||
| try: | |||||
| # Extract and validate credentials based on node type | |||||
| if node_type == "tool": | |||||
| credential_id = node_data.get("credential_id") | |||||
| provider = node_data.get("provider_id") | |||||
| if provider: | |||||
| if credential_id: | |||||
| # Check specific credential | |||||
| from core.helper.credential_utils import check_credential_policy_compliance | |||||
| check_credential_policy_compliance( | |||||
| credential_id=credential_id, | |||||
| provider=provider, | |||||
| credential_type=PluginCredentialType.TOOL, | |||||
| ) | |||||
| else: | |||||
| # Check default workspace credential for this provider | |||||
| self._check_default_tool_credential(workflow.tenant_id, provider) | |||||
| elif node_type == "agent": | |||||
| agent_params = node_data.get("agent_parameters", {}) | |||||
| model_config = agent_params.get("model", {}).get("value", {}) | |||||
| if model_config.get("provider") and model_config.get("model"): | |||||
| self._validate_llm_model_config( | |||||
| workflow.tenant_id, model_config["provider"], model_config["model"] | |||||
| ) | |||||
| # Validate load balancing credentials for agent model if load balancing is enabled | |||||
| agent_model_node_data = {"model": model_config} | |||||
| self._validate_load_balancing_credentials(workflow, agent_model_node_data, node_id) | |||||
| # Validate agent tools | |||||
| tools = agent_params.get("tools", {}).get("value", []) | |||||
| for tool in tools: | |||||
| # Agent tools store provider in provider_name field | |||||
| provider = tool.get("provider_name") | |||||
| credential_id = tool.get("credential_id") | |||||
| if provider: | |||||
| if credential_id: | |||||
| from core.helper.credential_utils import check_credential_policy_compliance | |||||
| check_credential_policy_compliance(credential_id, provider, PluginCredentialType.TOOL) | |||||
| else: | |||||
| self._check_default_tool_credential(workflow.tenant_id, provider) | |||||
| elif node_type in ["llm", "knowledge_retrieval", "parameter_extractor", "question_classifier"]: | |||||
| model_config = node_data.get("model", {}) | |||||
| provider = model_config.get("provider") | |||||
| model_name = model_config.get("name") | |||||
| if provider and model_name: | |||||
| # Validate that the provider+model combination can fetch valid credentials | |||||
| self._validate_llm_model_config(workflow.tenant_id, provider, model_name) | |||||
| # Validate load balancing credentials if load balancing is enabled | |||||
| self._validate_load_balancing_credentials(workflow, node_data, node_id) | |||||
| else: | |||||
| raise ValueError(f"Node {node_id} ({node_type}): Missing provider or model configuration") | |||||
| except Exception as e: | |||||
| if isinstance(e, ValueError): | |||||
| raise e | |||||
| else: | |||||
| raise ValueError(f"Node {node_id} ({node_type}): {str(e)}") | |||||
| def _validate_llm_model_config(self, tenant_id: str, provider: str, model_name: str) -> None: | |||||
| """ | |||||
| Validate that an LLM model configuration can fetch valid credentials. | |||||
| This method attempts to get the model instance and validates that: | |||||
| 1. The provider exists and is configured | |||||
| 2. The model exists in the provider | |||||
| 3. Credentials can be fetched for the model | |||||
| 4. The credentials pass policy compliance checks | |||||
| :param tenant_id: The tenant ID | |||||
| :param provider: The provider name | |||||
| :param model_name: The model name | |||||
| :raises ValueError: If the model configuration is invalid or credentials fail policy checks | |||||
| """ | |||||
| try: | |||||
| from core.model_manager import ModelManager | |||||
| from core.model_runtime.entities.model_entities import ModelType | |||||
| # Get model instance to validate provider+model combination | |||||
| model_manager = ModelManager() | |||||
| model_manager.get_model_instance( | |||||
| tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name | |||||
| ) | |||||
| # The ModelInstance constructor will automatically check credential policy compliance | |||||
| # via ProviderConfiguration.get_current_credentials() -> _check_credential_policy_compliance() | |||||
| # If it fails, an exception will be raised | |||||
| except Exception as e: | |||||
| raise ValueError( | |||||
| f"Failed to validate LLM model configuration (provider: {provider}, model: {model_name}): {str(e)}" | |||||
| ) | |||||
| def _check_default_tool_credential(self, tenant_id: str, provider: str) -> None: | |||||
| """ | |||||
| Check credential policy compliance for the default workspace credential of a tool provider. | |||||
| This method finds the default credential for the given provider and validates it. | |||||
| Uses the same fallback logic as runtime to handle deauthorized credentials. | |||||
| :param tenant_id: The tenant ID | |||||
| :param provider: The tool provider name | |||||
| :raises ValueError: If no default credential exists or if it fails policy compliance | |||||
| """ | |||||
| try: | |||||
| from models.tools import BuiltinToolProvider | |||||
| # Use the same fallback logic as runtime: get the first available credential | |||||
| # ordered by is_default DESC, created_at ASC (same as tool_manager.py) | |||||
| default_provider = ( | |||||
| db.session.query(BuiltinToolProvider) | |||||
| .where( | |||||
| BuiltinToolProvider.tenant_id == tenant_id, | |||||
| BuiltinToolProvider.provider == provider, | |||||
| ) | |||||
| .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) | |||||
| .first() | |||||
| ) | |||||
| if not default_provider: | |||||
| raise ValueError("No default credential found") | |||||
| # Check credential policy compliance using the default credential ID | |||||
| from core.helper.credential_utils import check_credential_policy_compliance | |||||
| check_credential_policy_compliance( | |||||
| credential_id=default_provider.id, | |||||
| provider=provider, | |||||
| credential_type=PluginCredentialType.TOOL, | |||||
| check_existence=False, | |||||
| ) | |||||
| except Exception as e: | |||||
| raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}") | |||||
| def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict, node_id: str) -> None: | |||||
| """ | |||||
| Validate load balancing credentials for a workflow node. | |||||
| :param workflow: The workflow being validated | |||||
| :param node_data: The node data containing model configuration | |||||
| :param node_id: The node ID for error reporting | |||||
| :raises ValueError: If load balancing credentials violate policy compliance | |||||
| """ | |||||
| # Extract model configuration | |||||
| model_config = node_data.get("model", {}) | |||||
| provider = model_config.get("provider") | |||||
| model_name = model_config.get("name") | |||||
| if not provider or not model_name: | |||||
| return # No model config to validate | |||||
| # Check if this model has load balancing enabled | |||||
| if self._is_load_balancing_enabled(workflow.tenant_id, provider, model_name): | |||||
| # Get all load balancing configurations for this model | |||||
| load_balancing_configs = self._get_load_balancing_configs(workflow.tenant_id, provider, model_name) | |||||
| # Validate each load balancing configuration | |||||
| try: | |||||
| for config in load_balancing_configs: | |||||
| if config.get("credential_id"): | |||||
| from core.helper.credential_utils import check_credential_policy_compliance | |||||
| check_credential_policy_compliance( | |||||
| config["credential_id"], provider, PluginCredentialType.MODEL | |||||
| ) | |||||
| except Exception as e: | |||||
| raise ValueError(f"Invalid load balancing credentials for {provider}/{model_name}: {str(e)}") | |||||
| def _is_load_balancing_enabled(self, tenant_id: str, provider: str, model_name: str) -> bool: | |||||
| """ | |||||
| Check if load balancing is enabled for a specific model. | |||||
| :param tenant_id: The tenant ID | |||||
| :param provider: The provider name | |||||
| :param model_name: The model name | |||||
| :return: True if load balancing is enabled, False otherwise | |||||
| """ | |||||
| try: | |||||
| from core.model_runtime.entities.model_entities import ModelType | |||||
| from core.provider_manager import ProviderManager | |||||
| # Get provider configurations | |||||
| provider_manager = ProviderManager() | |||||
| provider_configurations = provider_manager.get_configurations(tenant_id) | |||||
| provider_configuration = provider_configurations.get(provider) | |||||
| if not provider_configuration: | |||||
| return False | |||||
| # Get provider model setting | |||||
| provider_model_setting = provider_configuration.get_provider_model_setting( | |||||
| model_type=ModelType.LLM, | |||||
| model=model_name, | |||||
| ) | |||||
| return provider_model_setting is not None and provider_model_setting.load_balancing_enabled | |||||
| except Exception: | |||||
| # If we can't determine the status, assume load balancing is not enabled | |||||
| return False | |||||
| def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict]: | |||||
| """ | |||||
| Get all load balancing configurations for a model. | |||||
| :param tenant_id: The tenant ID | |||||
| :param provider: The provider name | |||||
| :param model_name: The model name | |||||
| :return: List of load balancing configuration dictionaries | |||||
| """ | |||||
| try: | |||||
| from services.model_load_balancing_service import ModelLoadBalancingService | |||||
| model_load_balancing_service = ModelLoadBalancingService() | |||||
| _, configs = model_load_balancing_service.get_load_balancing_configs( | |||||
| tenant_id=tenant_id, | |||||
| provider=provider, | |||||
| model=model_name, | |||||
| model_type="llm", # Load balancing is primarily used for LLM models | |||||
| config_from="predefined-model", # Check both predefined and custom models | |||||
| ) | |||||
| _, custom_configs = model_load_balancing_service.get_load_balancing_configs( | |||||
| tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model" | |||||
| ) | |||||
| all_configs = configs + custom_configs | |||||
| return [config for config in all_configs if config.get("credential_id")] | |||||
| except Exception: | |||||
| # If we can't get the configurations, return empty list | |||||
| # This will prevent validation errors from breaking the workflow | |||||
| return [] | |||||
| def get_default_block_configs(self) -> list[dict]: | def get_default_block_configs(self) -> list[dict]: | ||||
| """ | """ | ||||
| Get default block configs | Get default block configs |