Kaynağa Gözat

Feat/credential policy (#25151)

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
tags/1.9.0
Xiyuan Chen 1 ay önce
ebeveyn
işleme
64c9a2f678
No account linked to committer's email address

+ 1
- 5
api/controllers/console/app/workflow.py Dosyayı Görüntüle

@@ -11,11 +11,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from configs import dify_config
from controllers.console import api
from controllers.console.app.error import (
ConversationCompletedError,
DraftWorkflowNotExist,
DraftWorkflowNotSync,
)
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError

+ 27
- 1
api/core/entities/provider_configuration.py Dosyayı Görüntüle

@@ -42,6 +42,7 @@ from models.provider import (
ProviderType,
TenantPreferredModelProvider,
)
from services.enterprise.plugin_manager_service import PluginCredentialType

logger = logging.getLogger(__name__)

@@ -129,14 +130,38 @@ class ProviderConfiguration(BaseModel):
return copy_credentials
else:
credentials = None
current_credential_id = None

if self.custom_configuration.models:
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type == model_type and model_configuration.model == model:
credentials = model_configuration.credentials
current_credential_id = model_configuration.current_credential_id
break

if not credentials and self.custom_configuration.provider:
credentials = self.custom_configuration.provider.credentials
current_credential_id = self.custom_configuration.provider.current_credential_id

if current_credential_id:
from core.helper.credential_utils import check_credential_policy_compliance

check_credential_policy_compliance(
credential_id=current_credential_id,
provider=self.provider.provider,
credential_type=PluginCredentialType.MODEL,
)
else:
# no current credential id, check all available credentials
if self.custom_configuration.provider:
for credential_configuration in self.custom_configuration.provider.available_credentials:
from core.helper.credential_utils import check_credential_policy_compliance

check_credential_policy_compliance(
credential_id=credential_configuration.credential_id,
provider=self.provider.provider,
credential_type=PluginCredentialType.MODEL,
)

return credentials

@@ -266,7 +291,6 @@ class ProviderConfiguration(BaseModel):
:param credential_id: if provided, return the specified credential
:return:
"""

if credential_id:
return self._get_specific_provider_credential(credential_id)

@@ -738,6 +762,7 @@ class ProviderConfiguration(BaseModel):

current_credential_id = credential_record.id
current_credential_name = credential_record.credential_name

credentials = self.obfuscated_credentials(
credentials=credentials,
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
@@ -792,6 +817,7 @@ class ProviderConfiguration(BaseModel):
):
current_credential_id = model_configuration.current_credential_id
current_credential_name = model_configuration.current_credential_name

credentials = self.obfuscated_credentials(
credentials=model_configuration.credentials,
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas

+ 1
- 0
api/core/entities/provider_entities.py Dosyayı Görüntüle

@@ -145,6 +145,7 @@ class ModelLoadBalancingConfiguration(BaseModel):
name: str
credentials: dict
credential_source_type: str | None = None
credential_id: str | None = None


class ModelSettings(BaseModel):

+ 75
- 0
api/core/helper/credential_utils.py Dosyayı Görüntüle

@@ -0,0 +1,75 @@
"""
Credential utility functions for checking credential existence and policy compliance.
"""

from services.enterprise.plugin_manager_service import PluginCredentialType


def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool:
"""
Check if the credential still exists in the database.

:param credential_id: The credential ID to check
:param credential_type: The type of credential (MODEL or TOOL)
:return: True if credential exists, False otherwise
"""
from sqlalchemy import select
from sqlalchemy.orm import Session

from extensions.ext_database import db
from models.provider import ProviderCredential, ProviderModelCredential
from models.tools import BuiltinToolProvider

with Session(db.engine) as session:
if credential_type == PluginCredentialType.MODEL:
# Check both pre-defined and custom model credentials using a single UNION query
stmt = (
select(ProviderCredential.id)
.where(ProviderCredential.id == credential_id)
.union(select(ProviderModelCredential.id).where(ProviderModelCredential.id == credential_id))
)
return session.scalar(stmt) is not None

if credential_type == PluginCredentialType.TOOL:
return (
session.scalar(select(BuiltinToolProvider.id).where(BuiltinToolProvider.id == credential_id))
is not None
)

return False


def check_credential_policy_compliance(
credential_id: str, provider: str, credential_type: "PluginCredentialType", check_existence: bool = True
) -> None:
"""
Check credential policy compliance for the given credential ID.

:param credential_id: The credential ID to check
:param provider: The provider name
:param credential_type: The type of credential (MODEL or TOOL)
:param check_existence: Whether to check if credential exists in database first
:raises ValueError: If credential policy compliance check fails
"""
from services.enterprise.plugin_manager_service import (
CheckCredentialPolicyComplianceRequest,
PluginManagerService,
)
from services.feature_service import FeatureService

if not FeatureService.get_system_features().plugin_manager.enabled or not credential_id:
return

# Check if credential exists in database first (if requested)
if check_existence:
if not is_credential_exists(credential_id, credential_type):
raise ValueError(f"Credential with id {credential_id} for provider {provider} not found.")

# Check policy compliance
PluginManagerService.check_credential_policy_compliance(
CheckCredentialPolicyComplianceRequest(
dify_credential_id=credential_id,
provider=provider,
credential_type=credential_type,
)
)

+ 36
- 0
api/core/model_manager.py Dosyayı Görüntüle

@@ -23,6 +23,7 @@ from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.provider_manager import ProviderManager
from extensions.ext_redis import redis_client
from models.provider import ProviderType
from services.enterprise.plugin_manager_service import PluginCredentialType

logger = logging.getLogger(__name__)

@@ -362,6 +363,23 @@ class ModelInstance:
else:
raise last_exception

# Additional policy compliance check as fallback (in case fetch_next didn't catch it)
try:
from core.helper.credential_utils import check_credential_policy_compliance

if lb_config.credential_id:
check_credential_policy_compliance(
credential_id=lb_config.credential_id,
provider=self.provider,
credential_type=PluginCredentialType.MODEL,
)
except Exception as e:
logger.warning(
"Load balancing config %s failed policy compliance check in round-robin: %s", lb_config.id, str(e)
)
self.load_balancing_manager.cooldown(lb_config, expire=60)
continue

try:
if "credentials" in kwargs:
del kwargs["credentials"]
@@ -515,6 +533,24 @@ class LBModelManager:

continue

# Check policy compliance for the selected configuration
try:
from core.helper.credential_utils import check_credential_policy_compliance

if config.credential_id:
check_credential_policy_compliance(
credential_id=config.credential_id,
provider=self._provider,
credential_type=PluginCredentialType.MODEL,
)
except Exception as e:
logger.warning("Load balancing config %s failed policy compliance check: %s", config.id, str(e))
cooldown_load_balancing_configs.append(config)
if len(cooldown_load_balancing_configs) >= len(self._load_balancing_configs):
# all configs are in cooldown or failed policy compliance
return None
continue

if dify_config.DEBUG:
logger.info(
"""Model LB

+ 1
- 0
api/core/provider_manager.py Dosyayı Görüntüle

@@ -1129,6 +1129,7 @@ class ProviderManager:
name=load_balancing_model_config.name,
credentials=provider_model_credentials,
credential_source_type=load_balancing_model_config.credential_source_type,
credential_id=load_balancing_model_config.credential_id,
)
)


+ 4
- 0
api/core/tools/errors.py Dosyayı Görüntüle

@@ -29,6 +29,10 @@ class ToolApiSchemaError(ValueError):
pass


class ToolCredentialPolicyViolationError(ValueError):
pass


class ToolEngineInvokeError(Exception):
meta: ToolInvokeMeta


+ 12
- 3
api/core/tools/tool_manager.py Dosyayı Görüntüle

@@ -27,6 +27,7 @@ from core.tools.plugin_tool.tool import PluginTool
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.workflow.entities.variable_pool import VariablePool
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService

if TYPE_CHECKING:
@@ -55,9 +56,7 @@ from core.tools.entities.tool_entities import (
)
from core.tools.errors import ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
ToolParameterConfigurationManager,
)
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
@@ -237,6 +236,16 @@ class ToolManager:
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")

# check if the credential is allowed to be used
from core.helper.credential_utils import check_credential_policy_compliance

check_credential_policy_compliance(
credential_id=builtin_provider.id,
provider=provider_id,
credential_type=PluginCredentialType.TOOL,
check_existence=False,
)

encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id,
config=[

+ 17
- 5
api/services/enterprise/base.py Dosyayı Görüntüle

@@ -3,18 +3,30 @@ import os
import requests


class EnterpriseRequest:
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")

class BaseRequest:
proxies = {
"http": "",
"https": "",
}
base_url = ""
secret_key = ""
secret_key_header = ""

@classmethod
def send_request(cls, method, endpoint, json=None, params=None):
headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key}
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
url = f"{cls.base_url}{endpoint}"
response = requests.request(method, url, json=json, params=params, headers=headers, proxies=cls.proxies)
return response.json()


class EnterpriseRequest(BaseRequest):
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")
secret_key_header = "Enterprise-Api-Secret-Key"


class EnterprisePluginManagerRequest(BaseRequest):
base_url = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_URL", "ENTERPRISE_PLUGIN_MANAGER_API_URL")
secret_key = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY", "ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY")
secret_key_header = "Plugin-Manager-Inner-Api-Secret-Key"

+ 52
- 0
api/services/enterprise/plugin_manager_service.py Dosyayı Görüntüle

@@ -0,0 +1,52 @@
import enum
import logging

from pydantic import BaseModel

from services.enterprise.base import EnterprisePluginManagerRequest
from services.errors.base import BaseServiceError


class PluginCredentialType(enum.Enum):
MODEL = 0
TOOL = 1

def to_number(self):
return self.value


class CheckCredentialPolicyComplianceRequest(BaseModel):
dify_credential_id: str
provider: str
credential_type: PluginCredentialType

def model_dump(self, **kwargs):
data = super().model_dump(**kwargs)
data["credential_type"] = self.credential_type.to_number()
return data


class CredentialPolicyViolationError(BaseServiceError):
pass


class PluginManagerService:
@classmethod
def check_credential_policy_compliance(cls, body: CheckCredentialPolicyComplianceRequest):
try:
ret = EnterprisePluginManagerRequest.send_request(
"POST", "/check-credential-policy-compliance", json=body.model_dump()
)
if not isinstance(ret, dict) or "result" not in ret:
raise ValueError("Invalid response format from plugin manager API")
except Exception as e:
raise CredentialPolicyViolationError(
f"error occurred while checking credential policy compliance: {e}"
) from e

if not ret.get("result", False):
raise CredentialPolicyViolationError("Credentials not available: Please use ENTERPRISE global credentials")

logging.debug(
f"Credential policy compliance checked for {body.provider} with credential {body.dify_credential_id}, result: {ret.get('result', False)}"
)

+ 6
- 0
api/services/feature_service.py Dosyayı Görüntüle

@@ -134,6 +134,10 @@ class KnowledgeRateLimitModel(BaseModel):
subscription_plan: str = ""


class PluginManagerModel(BaseModel):
enabled: bool = False


class SystemFeatureModel(BaseModel):
sso_enforced_for_signin: bool = False
sso_enforced_for_signin_protocol: str = ""
@@ -150,6 +154,7 @@ class SystemFeatureModel(BaseModel):
webapp_auth: WebAppAuthModel = WebAppAuthModel()
plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
enable_change_email: bool = True
plugin_manager: PluginManagerModel = PluginManagerModel()


class FeatureService:
@@ -188,6 +193,7 @@ class FeatureService:
system_features.branding.enabled = True
system_features.webapp_auth.enabled = True
system_features.enable_change_email = False
system_features.plugin_manager.enabled = True
cls._fulfill_params_from_enterprise(system_features)

if dify_config.MARKETPLACE_ENABLED:

+ 263
- 11
api/services/workflow_service.py Dosyayı Görüntüle

@@ -36,22 +36,14 @@ from libs.datetime_utils import naive_utc_now
from models.account import Account
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import (
Workflow,
WorkflowNodeExecutionModel,
WorkflowNodeExecutionTriggeredFrom,
WorkflowType,
)
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType
from repositories.factory import DifyAPIRepositoryFactory
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError
from services.workflow.workflow_converter import WorkflowConverter

from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
from .workflow_draft_variable_service import (
DraftVariableSaver,
DraftVarLoader,
WorkflowDraftVariableService,
)
from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService


class WorkflowService:
@@ -271,6 +263,12 @@ class WorkflowService:
if not draft_workflow:
raise ValueError("No valid workflow found.")

# Validate credentials before publishing, for credential policy check
from services.feature_service import FeatureService

if FeatureService.get_system_features().plugin_manager.enabled:
self._validate_workflow_credentials(draft_workflow)

# create new workflow
workflow = Workflow.new(
tenant_id=app_model.tenant_id,
@@ -295,6 +293,260 @@ class WorkflowService:
# return new workflow
return workflow

def _validate_workflow_credentials(self, workflow: Workflow) -> None:
"""
Validate all credentials in workflow nodes before publishing.

:param workflow: The workflow to validate
:raises ValueError: If any credentials violate policy compliance
"""
graph_dict = workflow.graph_dict
nodes = graph_dict.get("nodes", [])

for node in nodes:
node_data = node.get("data", {})
node_type = node_data.get("type")
node_id = node.get("id", "unknown")

try:
# Extract and validate credentials based on node type
if node_type == "tool":
credential_id = node_data.get("credential_id")
provider = node_data.get("provider_id")
if provider:
if credential_id:
# Check specific credential
from core.helper.credential_utils import check_credential_policy_compliance

check_credential_policy_compliance(
credential_id=credential_id,
provider=provider,
credential_type=PluginCredentialType.TOOL,
)
else:
# Check default workspace credential for this provider
self._check_default_tool_credential(workflow.tenant_id, provider)

elif node_type == "agent":
agent_params = node_data.get("agent_parameters", {})

model_config = agent_params.get("model", {}).get("value", {})
if model_config.get("provider") and model_config.get("model"):
self._validate_llm_model_config(
workflow.tenant_id, model_config["provider"], model_config["model"]
)

# Validate load balancing credentials for agent model if load balancing is enabled
agent_model_node_data = {"model": model_config}
self._validate_load_balancing_credentials(workflow, agent_model_node_data, node_id)

# Validate agent tools
tools = agent_params.get("tools", {}).get("value", [])
for tool in tools:
# Agent tools store provider in provider_name field
provider = tool.get("provider_name")
credential_id = tool.get("credential_id")
if provider:
if credential_id:
from core.helper.credential_utils import check_credential_policy_compliance

check_credential_policy_compliance(credential_id, provider, PluginCredentialType.TOOL)
else:
self._check_default_tool_credential(workflow.tenant_id, provider)

elif node_type in ["llm", "knowledge_retrieval", "parameter_extractor", "question_classifier"]:
model_config = node_data.get("model", {})
provider = model_config.get("provider")
model_name = model_config.get("name")

if provider and model_name:
# Validate that the provider+model combination can fetch valid credentials
self._validate_llm_model_config(workflow.tenant_id, provider, model_name)
# Validate load balancing credentials if load balancing is enabled
self._validate_load_balancing_credentials(workflow, node_data, node_id)
else:
raise ValueError(f"Node {node_id} ({node_type}): Missing provider or model configuration")

except Exception as e:
if isinstance(e, ValueError):
raise e
else:
raise ValueError(f"Node {node_id} ({node_type}): {str(e)}")

def _validate_llm_model_config(self, tenant_id: str, provider: str, model_name: str) -> None:
"""
Validate that an LLM model configuration can fetch valid credentials.

This method attempts to get the model instance and validates that:
1. The provider exists and is configured
2. The model exists in the provider
3. Credentials can be fetched for the model
4. The credentials pass policy compliance checks

:param tenant_id: The tenant ID
:param provider: The provider name
:param model_name: The model name
:raises ValueError: If the model configuration is invalid or credentials fail policy checks
"""
try:
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType

# Get model instance to validate provider+model combination
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name
)

# The ModelInstance constructor will automatically check credential policy compliance
# via ProviderConfiguration.get_current_credentials() -> _check_credential_policy_compliance()
# If it fails, an exception will be raised

except Exception as e:
raise ValueError(
f"Failed to validate LLM model configuration (provider: {provider}, model: {model_name}): {str(e)}"
)

def _check_default_tool_credential(self, tenant_id: str, provider: str) -> None:
"""
Check credential policy compliance for the default workspace credential of a tool provider.

This method finds the default credential for the given provider and validates it.
Uses the same fallback logic as runtime to handle deauthorized credentials.

:param tenant_id: The tenant ID
:param provider: The tool provider name
:raises ValueError: If no default credential exists or if it fails policy compliance
"""
try:
from models.tools import BuiltinToolProvider

# Use the same fallback logic as runtime: get the first available credential
# ordered by is_default DESC, created_at ASC (same as tool_manager.py)
default_provider = (
db.session.query(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)

if not default_provider:
raise ValueError("No default credential found")

# Check credential policy compliance using the default credential ID
from core.helper.credential_utils import check_credential_policy_compliance

check_credential_policy_compliance(
credential_id=default_provider.id,
provider=provider,
credential_type=PluginCredentialType.TOOL,
check_existence=False,
)

except Exception as e:
raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}")

def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict, node_id: str) -> None:
"""
Validate load balancing credentials for a workflow node.

:param workflow: The workflow being validated
:param node_data: The node data containing model configuration
:param node_id: The node ID for error reporting
:raises ValueError: If load balancing credentials violate policy compliance
"""
# Extract model configuration
model_config = node_data.get("model", {})
provider = model_config.get("provider")
model_name = model_config.get("name")

if not provider or not model_name:
return # No model config to validate

# Check if this model has load balancing enabled
if self._is_load_balancing_enabled(workflow.tenant_id, provider, model_name):
# Get all load balancing configurations for this model
load_balancing_configs = self._get_load_balancing_configs(workflow.tenant_id, provider, model_name)
# Validate each load balancing configuration
try:
for config in load_balancing_configs:
if config.get("credential_id"):
from core.helper.credential_utils import check_credential_policy_compliance

check_credential_policy_compliance(
config["credential_id"], provider, PluginCredentialType.MODEL
)
except Exception as e:
raise ValueError(f"Invalid load balancing credentials for {provider}/{model_name}: {str(e)}")

def _is_load_balancing_enabled(self, tenant_id: str, provider: str, model_name: str) -> bool:
"""
Check if load balancing is enabled for a specific model.

:param tenant_id: The tenant ID
:param provider: The provider name
:param model_name: The model name
:return: True if load balancing is enabled, False otherwise
"""
try:
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager

# Get provider configurations
provider_manager = ProviderManager()
provider_configurations = provider_manager.get_configurations(tenant_id)
provider_configuration = provider_configurations.get(provider)

if not provider_configuration:
return False

# Get provider model setting
provider_model_setting = provider_configuration.get_provider_model_setting(
model_type=ModelType.LLM,
model=model_name,
)
return provider_model_setting is not None and provider_model_setting.load_balancing_enabled

except Exception:
# If we can't determine the status, assume load balancing is not enabled
return False

def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict]:
"""
Get all load balancing configurations for a model.

:param tenant_id: The tenant ID
:param provider: The provider name
:param model_name: The model name
:return: List of load balancing configuration dictionaries
"""
try:
from services.model_load_balancing_service import ModelLoadBalancingService

model_load_balancing_service = ModelLoadBalancingService()
_, configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id,
provider=provider,
model=model_name,
model_type="llm", # Load balancing is primarily used for LLM models
config_from="predefined-model", # Check both predefined and custom models
)

_, custom_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model"
)
all_configs = configs + custom_configs

return [config for config in all_configs if config.get("credential_id")]

except Exception:
# If we can't get the configurations, return empty list
# This will prevent validation errors from breaking the workflow
return []

def get_default_block_configs(self) -> list[dict]:
"""
Get default block configs

Loading…
İptal
Kaydet