Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>tags/1.9.0
| @@ -11,11 +11,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound | |||
| import services | |||
| from configs import dify_config | |||
| from controllers.console import api | |||
| from controllers.console.app.error import ( | |||
| ConversationCompletedError, | |||
| DraftWorkflowNotExist, | |||
| DraftWorkflowNotSync, | |||
| ) | |||
| from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync | |||
| from controllers.console.app.wraps import get_app_model | |||
| from controllers.console.wraps import account_initialization_required, setup_required | |||
| from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError | |||
| @@ -42,6 +42,7 @@ from models.provider import ( | |||
| ProviderType, | |||
| TenantPreferredModelProvider, | |||
| ) | |||
| from services.enterprise.plugin_manager_service import PluginCredentialType | |||
| logger = logging.getLogger(__name__) | |||
| @@ -129,14 +130,38 @@ class ProviderConfiguration(BaseModel): | |||
| return copy_credentials | |||
| else: | |||
| credentials = None | |||
| current_credential_id = None | |||
| if self.custom_configuration.models: | |||
| for model_configuration in self.custom_configuration.models: | |||
| if model_configuration.model_type == model_type and model_configuration.model == model: | |||
| credentials = model_configuration.credentials | |||
| current_credential_id = model_configuration.current_credential_id | |||
| break | |||
| if not credentials and self.custom_configuration.provider: | |||
| credentials = self.custom_configuration.provider.credentials | |||
| current_credential_id = self.custom_configuration.provider.current_credential_id | |||
| if current_credential_id: | |||
| from core.helper.credential_utils import check_credential_policy_compliance | |||
| check_credential_policy_compliance( | |||
| credential_id=current_credential_id, | |||
| provider=self.provider.provider, | |||
| credential_type=PluginCredentialType.MODEL, | |||
| ) | |||
| else: | |||
| # no current credential id, check all available credentials | |||
| if self.custom_configuration.provider: | |||
| for credential_configuration in self.custom_configuration.provider.available_credentials: | |||
| from core.helper.credential_utils import check_credential_policy_compliance | |||
| check_credential_policy_compliance( | |||
| credential_id=credential_configuration.credential_id, | |||
| provider=self.provider.provider, | |||
| credential_type=PluginCredentialType.MODEL, | |||
| ) | |||
| return credentials | |||
| @@ -266,7 +291,6 @@ class ProviderConfiguration(BaseModel): | |||
| :param credential_id: if provided, return the specified credential | |||
| :return: | |||
| """ | |||
| if credential_id: | |||
| return self._get_specific_provider_credential(credential_id) | |||
| @@ -738,6 +762,7 @@ class ProviderConfiguration(BaseModel): | |||
| current_credential_id = credential_record.id | |||
| current_credential_name = credential_record.credential_name | |||
| credentials = self.obfuscated_credentials( | |||
| credentials=credentials, | |||
| credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas | |||
| @@ -792,6 +817,7 @@ class ProviderConfiguration(BaseModel): | |||
| ): | |||
| current_credential_id = model_configuration.current_credential_id | |||
| current_credential_name = model_configuration.current_credential_name | |||
| credentials = self.obfuscated_credentials( | |||
| credentials=model_configuration.credentials, | |||
| credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas | |||
| @@ -145,6 +145,7 @@ class ModelLoadBalancingConfiguration(BaseModel): | |||
| name: str | |||
| credentials: dict | |||
| credential_source_type: str | None = None | |||
| credential_id: str | None = None | |||
| class ModelSettings(BaseModel): | |||
| @@ -0,0 +1,75 @@ | |||
| """ | |||
| Credential utility functions for checking credential existence and policy compliance. | |||
| """ | |||
| from services.enterprise.plugin_manager_service import PluginCredentialType | |||
| def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool: | |||
| """ | |||
| Check if the credential still exists in the database. | |||
| :param credential_id: The credential ID to check | |||
| :param credential_type: The type of credential (MODEL or TOOL) | |||
| :return: True if credential exists, False otherwise | |||
| """ | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from extensions.ext_database import db | |||
| from models.provider import ProviderCredential, ProviderModelCredential | |||
| from models.tools import BuiltinToolProvider | |||
| with Session(db.engine) as session: | |||
| if credential_type == PluginCredentialType.MODEL: | |||
| # Check both pre-defined and custom model credentials using a single UNION query | |||
| stmt = ( | |||
| select(ProviderCredential.id) | |||
| .where(ProviderCredential.id == credential_id) | |||
| .union(select(ProviderModelCredential.id).where(ProviderModelCredential.id == credential_id)) | |||
| ) | |||
| return session.scalar(stmt) is not None | |||
| if credential_type == PluginCredentialType.TOOL: | |||
| return ( | |||
| session.scalar(select(BuiltinToolProvider.id).where(BuiltinToolProvider.id == credential_id)) | |||
| is not None | |||
| ) | |||
| return False | |||
| def check_credential_policy_compliance( | |||
| credential_id: str, provider: str, credential_type: "PluginCredentialType", check_existence: bool = True | |||
| ) -> None: | |||
| """ | |||
| Check credential policy compliance for the given credential ID. | |||
| :param credential_id: The credential ID to check | |||
| :param provider: The provider name | |||
| :param credential_type: The type of credential (MODEL or TOOL) | |||
| :param check_existence: Whether to check if credential exists in database first | |||
| :raises ValueError: If credential policy compliance check fails | |||
| """ | |||
| from services.enterprise.plugin_manager_service import ( | |||
| CheckCredentialPolicyComplianceRequest, | |||
| PluginManagerService, | |||
| ) | |||
| from services.feature_service import FeatureService | |||
| if not FeatureService.get_system_features().plugin_manager.enabled or not credential_id: | |||
| return | |||
| # Check if credential exists in database first (if requested) | |||
| if check_existence: | |||
| if not is_credential_exists(credential_id, credential_type): | |||
| raise ValueError(f"Credential with id {credential_id} for provider {provider} not found.") | |||
| # Check policy compliance | |||
| PluginManagerService.check_credential_policy_compliance( | |||
| CheckCredentialPolicyComplianceRequest( | |||
| dify_credential_id=credential_id, | |||
| provider=provider, | |||
| credential_type=credential_type, | |||
| ) | |||
| ) | |||
| @@ -23,6 +23,7 @@ from core.model_runtime.model_providers.__base.tts_model import TTSModel | |||
| from core.provider_manager import ProviderManager | |||
| from extensions.ext_redis import redis_client | |||
| from models.provider import ProviderType | |||
| from services.enterprise.plugin_manager_service import PluginCredentialType | |||
| logger = logging.getLogger(__name__) | |||
| @@ -362,6 +363,23 @@ class ModelInstance: | |||
| else: | |||
| raise last_exception | |||
| # Additional policy compliance check as fallback (in case fetch_next didn't catch it) | |||
| try: | |||
| from core.helper.credential_utils import check_credential_policy_compliance | |||
| if lb_config.credential_id: | |||
| check_credential_policy_compliance( | |||
| credential_id=lb_config.credential_id, | |||
| provider=self.provider, | |||
| credential_type=PluginCredentialType.MODEL, | |||
| ) | |||
| except Exception as e: | |||
| logger.warning( | |||
| "Load balancing config %s failed policy compliance check in round-robin: %s", lb_config.id, str(e) | |||
| ) | |||
| self.load_balancing_manager.cooldown(lb_config, expire=60) | |||
| continue | |||
| try: | |||
| if "credentials" in kwargs: | |||
| del kwargs["credentials"] | |||
| @@ -515,6 +533,24 @@ class LBModelManager: | |||
| continue | |||
| # Check policy compliance for the selected configuration | |||
| try: | |||
| from core.helper.credential_utils import check_credential_policy_compliance | |||
| if config.credential_id: | |||
| check_credential_policy_compliance( | |||
| credential_id=config.credential_id, | |||
| provider=self._provider, | |||
| credential_type=PluginCredentialType.MODEL, | |||
| ) | |||
| except Exception as e: | |||
| logger.warning("Load balancing config %s failed policy compliance check: %s", config.id, str(e)) | |||
| cooldown_load_balancing_configs.append(config) | |||
| if len(cooldown_load_balancing_configs) >= len(self._load_balancing_configs): | |||
| # all configs are in cooldown or failed policy compliance | |||
| return None | |||
| continue | |||
| if dify_config.DEBUG: | |||
| logger.info( | |||
| """Model LB | |||
| @@ -1129,6 +1129,7 @@ class ProviderManager: | |||
| name=load_balancing_model_config.name, | |||
| credentials=provider_model_credentials, | |||
| credential_source_type=load_balancing_model_config.credential_source_type, | |||
| credential_id=load_balancing_model_config.credential_id, | |||
| ) | |||
| ) | |||
| @@ -29,6 +29,10 @@ class ToolApiSchemaError(ValueError): | |||
| pass | |||
| class ToolCredentialPolicyViolationError(ValueError): | |||
| pass | |||
| class ToolEngineInvokeError(Exception): | |||
| meta: ToolInvokeMeta | |||
| @@ -27,6 +27,7 @@ from core.tools.plugin_tool.tool import PluginTool | |||
| from core.tools.utils.uuid_utils import is_valid_uuid | |||
| from core.tools.workflow_as_tool.provider import WorkflowToolProviderController | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from services.enterprise.plugin_manager_service import PluginCredentialType | |||
| from services.tools.mcp_tools_manage_service import MCPToolManageService | |||
| if TYPE_CHECKING: | |||
| @@ -55,9 +56,7 @@ from core.tools.entities.tool_entities import ( | |||
| ) | |||
| from core.tools.errors import ToolProviderNotFoundError | |||
| from core.tools.tool_label_manager import ToolLabelManager | |||
| from core.tools.utils.configuration import ( | |||
| ToolParameterConfigurationManager, | |||
| ) | |||
| from core.tools.utils.configuration import ToolParameterConfigurationManager | |||
| from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter | |||
| from core.tools.workflow_as_tool.tool import WorkflowTool | |||
| from extensions.ext_database import db | |||
| @@ -237,6 +236,16 @@ class ToolManager: | |||
| if builtin_provider is None: | |||
| raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") | |||
| # check if the credential is allowed to be used | |||
| from core.helper.credential_utils import check_credential_policy_compliance | |||
| check_credential_policy_compliance( | |||
| credential_id=builtin_provider.id, | |||
| provider=provider_id, | |||
| credential_type=PluginCredentialType.TOOL, | |||
| check_existence=False, | |||
| ) | |||
| encrypter, cache = create_provider_encrypter( | |||
| tenant_id=tenant_id, | |||
| config=[ | |||
| @@ -3,18 +3,30 @@ import os | |||
| import requests | |||
| class EnterpriseRequest: | |||
| base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") | |||
| secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") | |||
| class BaseRequest: | |||
| proxies = { | |||
| "http": "", | |||
| "https": "", | |||
| } | |||
| base_url = "" | |||
| secret_key = "" | |||
| secret_key_header = "" | |||
| @classmethod | |||
| def send_request(cls, method, endpoint, json=None, params=None): | |||
| headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key} | |||
| headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key} | |||
| url = f"{cls.base_url}{endpoint}" | |||
| response = requests.request(method, url, json=json, params=params, headers=headers, proxies=cls.proxies) | |||
| return response.json() | |||
| class EnterpriseRequest(BaseRequest): | |||
| base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") | |||
| secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") | |||
| secret_key_header = "Enterprise-Api-Secret-Key" | |||
| class EnterprisePluginManagerRequest(BaseRequest): | |||
| base_url = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_URL", "ENTERPRISE_PLUGIN_MANAGER_API_URL") | |||
| secret_key = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY", "ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY") | |||
| secret_key_header = "Plugin-Manager-Inner-Api-Secret-Key" | |||
| @@ -0,0 +1,52 @@ | |||
| import enum | |||
| import logging | |||
| from pydantic import BaseModel | |||
| from services.enterprise.base import EnterprisePluginManagerRequest | |||
| from services.errors.base import BaseServiceError | |||
| class PluginCredentialType(enum.Enum): | |||
| MODEL = 0 | |||
| TOOL = 1 | |||
| def to_number(self): | |||
| return self.value | |||
| class CheckCredentialPolicyComplianceRequest(BaseModel): | |||
| dify_credential_id: str | |||
| provider: str | |||
| credential_type: PluginCredentialType | |||
| def model_dump(self, **kwargs): | |||
| data = super().model_dump(**kwargs) | |||
| data["credential_type"] = self.credential_type.to_number() | |||
| return data | |||
| class CredentialPolicyViolationError(BaseServiceError): | |||
| pass | |||
| class PluginManagerService: | |||
| @classmethod | |||
| def check_credential_policy_compliance(cls, body: CheckCredentialPolicyComplianceRequest): | |||
| try: | |||
| ret = EnterprisePluginManagerRequest.send_request( | |||
| "POST", "/check-credential-policy-compliance", json=body.model_dump() | |||
| ) | |||
| if not isinstance(ret, dict) or "result" not in ret: | |||
| raise ValueError("Invalid response format from plugin manager API") | |||
| except Exception as e: | |||
| raise CredentialPolicyViolationError( | |||
| f"error occurred while checking credential policy compliance: {e}" | |||
| ) from e | |||
| if not ret.get("result", False): | |||
| raise CredentialPolicyViolationError("Credentials not available: Please use ENTERPRISE global credentials") | |||
| logging.debug( | |||
| f"Credential policy compliance checked for {body.provider} with credential {body.dify_credential_id}, result: {ret.get('result', False)}" | |||
| ) | |||
| @@ -134,6 +134,10 @@ class KnowledgeRateLimitModel(BaseModel): | |||
| subscription_plan: str = "" | |||
| class PluginManagerModel(BaseModel): | |||
| enabled: bool = False | |||
| class SystemFeatureModel(BaseModel): | |||
| sso_enforced_for_signin: bool = False | |||
| sso_enforced_for_signin_protocol: str = "" | |||
| @@ -150,6 +154,7 @@ class SystemFeatureModel(BaseModel): | |||
| webapp_auth: WebAppAuthModel = WebAppAuthModel() | |||
| plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel() | |||
| enable_change_email: bool = True | |||
| plugin_manager: PluginManagerModel = PluginManagerModel() | |||
| class FeatureService: | |||
| @@ -188,6 +193,7 @@ class FeatureService: | |||
| system_features.branding.enabled = True | |||
| system_features.webapp_auth.enabled = True | |||
| system_features.enable_change_email = False | |||
| system_features.plugin_manager.enabled = True | |||
| cls._fulfill_params_from_enterprise(system_features) | |||
| if dify_config.MARKETPLACE_ENABLED: | |||
| @@ -36,22 +36,14 @@ from libs.datetime_utils import naive_utc_now | |||
| from models.account import Account | |||
| from models.model import App, AppMode | |||
| from models.tools import WorkflowToolProvider | |||
| from models.workflow import ( | |||
| Workflow, | |||
| WorkflowNodeExecutionModel, | |||
| WorkflowNodeExecutionTriggeredFrom, | |||
| WorkflowType, | |||
| ) | |||
| from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType | |||
| from repositories.factory import DifyAPIRepositoryFactory | |||
| from services.enterprise.plugin_manager_service import PluginCredentialType | |||
| from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError | |||
| from services.workflow.workflow_converter import WorkflowConverter | |||
| from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError | |||
| from .workflow_draft_variable_service import ( | |||
| DraftVariableSaver, | |||
| DraftVarLoader, | |||
| WorkflowDraftVariableService, | |||
| ) | |||
| from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService | |||
| class WorkflowService: | |||
| @@ -271,6 +263,12 @@ class WorkflowService: | |||
| if not draft_workflow: | |||
| raise ValueError("No valid workflow found.") | |||
| # Validate credentials before publishing, for credential policy check | |||
| from services.feature_service import FeatureService | |||
| if FeatureService.get_system_features().plugin_manager.enabled: | |||
| self._validate_workflow_credentials(draft_workflow) | |||
| # create new workflow | |||
| workflow = Workflow.new( | |||
| tenant_id=app_model.tenant_id, | |||
| @@ -295,6 +293,260 @@ class WorkflowService: | |||
| # return new workflow | |||
| return workflow | |||
| def _validate_workflow_credentials(self, workflow: Workflow) -> None: | |||
| """ | |||
| Validate all credentials in workflow nodes before publishing. | |||
| :param workflow: The workflow to validate | |||
| :raises ValueError: If any credentials violate policy compliance | |||
| """ | |||
| graph_dict = workflow.graph_dict | |||
| nodes = graph_dict.get("nodes", []) | |||
| for node in nodes: | |||
| node_data = node.get("data", {}) | |||
| node_type = node_data.get("type") | |||
| node_id = node.get("id", "unknown") | |||
| try: | |||
| # Extract and validate credentials based on node type | |||
| if node_type == "tool": | |||
| credential_id = node_data.get("credential_id") | |||
| provider = node_data.get("provider_id") | |||
| if provider: | |||
| if credential_id: | |||
| # Check specific credential | |||
| from core.helper.credential_utils import check_credential_policy_compliance | |||
| check_credential_policy_compliance( | |||
| credential_id=credential_id, | |||
| provider=provider, | |||
| credential_type=PluginCredentialType.TOOL, | |||
| ) | |||
| else: | |||
| # Check default workspace credential for this provider | |||
| self._check_default_tool_credential(workflow.tenant_id, provider) | |||
| elif node_type == "agent": | |||
| agent_params = node_data.get("agent_parameters", {}) | |||
| model_config = agent_params.get("model", {}).get("value", {}) | |||
| if model_config.get("provider") and model_config.get("model"): | |||
| self._validate_llm_model_config( | |||
| workflow.tenant_id, model_config["provider"], model_config["model"] | |||
| ) | |||
| # Validate load balancing credentials for agent model if load balancing is enabled | |||
| agent_model_node_data = {"model": model_config} | |||
| self._validate_load_balancing_credentials(workflow, agent_model_node_data, node_id) | |||
| # Validate agent tools | |||
| tools = agent_params.get("tools", {}).get("value", []) | |||
| for tool in tools: | |||
| # Agent tools store provider in provider_name field | |||
| provider = tool.get("provider_name") | |||
| credential_id = tool.get("credential_id") | |||
| if provider: | |||
| if credential_id: | |||
| from core.helper.credential_utils import check_credential_policy_compliance | |||
| check_credential_policy_compliance(credential_id, provider, PluginCredentialType.TOOL) | |||
| else: | |||
| self._check_default_tool_credential(workflow.tenant_id, provider) | |||
| elif node_type in ["llm", "knowledge_retrieval", "parameter_extractor", "question_classifier"]: | |||
| model_config = node_data.get("model", {}) | |||
| provider = model_config.get("provider") | |||
| model_name = model_config.get("name") | |||
| if provider and model_name: | |||
| # Validate that the provider+model combination can fetch valid credentials | |||
| self._validate_llm_model_config(workflow.tenant_id, provider, model_name) | |||
| # Validate load balancing credentials if load balancing is enabled | |||
| self._validate_load_balancing_credentials(workflow, node_data, node_id) | |||
| else: | |||
| raise ValueError(f"Node {node_id} ({node_type}): Missing provider or model configuration") | |||
| except Exception as e: | |||
| if isinstance(e, ValueError): | |||
| raise e | |||
| else: | |||
| raise ValueError(f"Node {node_id} ({node_type}): {str(e)}") | |||
| def _validate_llm_model_config(self, tenant_id: str, provider: str, model_name: str) -> None: | |||
| """ | |||
| Validate that an LLM model configuration can fetch valid credentials. | |||
| This method attempts to get the model instance and validates that: | |||
| 1. The provider exists and is configured | |||
| 2. The model exists in the provider | |||
| 3. Credentials can be fetched for the model | |||
| 4. The credentials pass policy compliance checks | |||
| :param tenant_id: The tenant ID | |||
| :param provider: The provider name | |||
| :param model_name: The model name | |||
| :raises ValueError: If the model configuration is invalid or credentials fail policy checks | |||
| """ | |||
| try: | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| # Get model instance to validate provider+model combination | |||
| model_manager = ModelManager() | |||
| model_manager.get_model_instance( | |||
| tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name | |||
| ) | |||
| # The ModelInstance constructor will automatically check credential policy compliance | |||
| # via ProviderConfiguration.get_current_credentials() -> _check_credential_policy_compliance() | |||
| # If it fails, an exception will be raised | |||
| except Exception as e: | |||
| raise ValueError( | |||
| f"Failed to validate LLM model configuration (provider: {provider}, model: {model_name}): {str(e)}" | |||
| ) | |||
| def _check_default_tool_credential(self, tenant_id: str, provider: str) -> None: | |||
| """ | |||
| Check credential policy compliance for the default workspace credential of a tool provider. | |||
| This method finds the default credential for the given provider and validates it. | |||
| Uses the same fallback logic as runtime to handle deauthorized credentials. | |||
| :param tenant_id: The tenant ID | |||
| :param provider: The tool provider name | |||
| :raises ValueError: If no default credential exists or if it fails policy compliance | |||
| """ | |||
| try: | |||
| from models.tools import BuiltinToolProvider | |||
| # Use the same fallback logic as runtime: get the first available credential | |||
| # ordered by is_default DESC, created_at ASC (same as tool_manager.py) | |||
| default_provider = ( | |||
| db.session.query(BuiltinToolProvider) | |||
| .where( | |||
| BuiltinToolProvider.tenant_id == tenant_id, | |||
| BuiltinToolProvider.provider == provider, | |||
| ) | |||
| .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) | |||
| .first() | |||
| ) | |||
| if not default_provider: | |||
| raise ValueError("No default credential found") | |||
| # Check credential policy compliance using the default credential ID | |||
| from core.helper.credential_utils import check_credential_policy_compliance | |||
| check_credential_policy_compliance( | |||
| credential_id=default_provider.id, | |||
| provider=provider, | |||
| credential_type=PluginCredentialType.TOOL, | |||
| check_existence=False, | |||
| ) | |||
| except Exception as e: | |||
| raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}") | |||
| def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict, node_id: str) -> None: | |||
| """ | |||
| Validate load balancing credentials for a workflow node. | |||
| :param workflow: The workflow being validated | |||
| :param node_data: The node data containing model configuration | |||
| :param node_id: The node ID for error reporting | |||
| :raises ValueError: If load balancing credentials violate policy compliance | |||
| """ | |||
| # Extract model configuration | |||
| model_config = node_data.get("model", {}) | |||
| provider = model_config.get("provider") | |||
| model_name = model_config.get("name") | |||
| if not provider or not model_name: | |||
| return # No model config to validate | |||
| # Check if this model has load balancing enabled | |||
| if self._is_load_balancing_enabled(workflow.tenant_id, provider, model_name): | |||
| # Get all load balancing configurations for this model | |||
| load_balancing_configs = self._get_load_balancing_configs(workflow.tenant_id, provider, model_name) | |||
| # Validate each load balancing configuration | |||
| try: | |||
| for config in load_balancing_configs: | |||
| if config.get("credential_id"): | |||
| from core.helper.credential_utils import check_credential_policy_compliance | |||
| check_credential_policy_compliance( | |||
| config["credential_id"], provider, PluginCredentialType.MODEL | |||
| ) | |||
| except Exception as e: | |||
| raise ValueError(f"Invalid load balancing credentials for {provider}/{model_name}: {str(e)}") | |||
| def _is_load_balancing_enabled(self, tenant_id: str, provider: str, model_name: str) -> bool: | |||
| """ | |||
| Check if load balancing is enabled for a specific model. | |||
| :param tenant_id: The tenant ID | |||
| :param provider: The provider name | |||
| :param model_name: The model name | |||
| :return: True if load balancing is enabled, False otherwise | |||
| """ | |||
| try: | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.provider_manager import ProviderManager | |||
| # Get provider configurations | |||
| provider_manager = ProviderManager() | |||
| provider_configurations = provider_manager.get_configurations(tenant_id) | |||
| provider_configuration = provider_configurations.get(provider) | |||
| if not provider_configuration: | |||
| return False | |||
| # Get provider model setting | |||
| provider_model_setting = provider_configuration.get_provider_model_setting( | |||
| model_type=ModelType.LLM, | |||
| model=model_name, | |||
| ) | |||
| return provider_model_setting is not None and provider_model_setting.load_balancing_enabled | |||
| except Exception: | |||
| # If we can't determine the status, assume load balancing is not enabled | |||
| return False | |||
| def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict]: | |||
| """ | |||
| Get all load balancing configurations for a model. | |||
| :param tenant_id: The tenant ID | |||
| :param provider: The provider name | |||
| :param model_name: The model name | |||
| :return: List of load balancing configuration dictionaries | |||
| """ | |||
| try: | |||
| from services.model_load_balancing_service import ModelLoadBalancingService | |||
| model_load_balancing_service = ModelLoadBalancingService() | |||
| _, configs = model_load_balancing_service.get_load_balancing_configs( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=model_name, | |||
| model_type="llm", # Load balancing is primarily used for LLM models | |||
| config_from="predefined-model", # Check both predefined and custom models | |||
| ) | |||
| _, custom_configs = model_load_balancing_service.get_load_balancing_configs( | |||
| tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model" | |||
| ) | |||
| all_configs = configs + custom_configs | |||
| return [config for config in all_configs if config.get("credential_id")] | |||
| except Exception: | |||
| # If we can't get the configurations, return empty list | |||
| # This will prevent validation errors from breaking the workflow | |||
| return [] | |||
| def get_default_block_configs(self) -> list[dict]: | |||
| """ | |||
| Get default block configs | |||