Sfoglia il codice sorgente

Feat/credential policy (#25151)

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
tags/1.9.0
Xiyuan Chen 1 mese fa
parent
commit
64c9a2f678
Nessun account collegato all'indirizzo email del committer

+ 1
- 5
api/controllers/console/app/workflow.py Vedi File

import services import services
from configs import dify_config from configs import dify_config
from controllers.console import api from controllers.console import api
from controllers.console.app.error import (
ConversationCompletedError,
DraftWorkflowNotExist,
DraftWorkflowNotSync,
)
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError

+ 27
- 1
api/core/entities/provider_configuration.py Vedi File

ProviderType, ProviderType,
TenantPreferredModelProvider, TenantPreferredModelProvider,
) )
from services.enterprise.plugin_manager_service import PluginCredentialType


logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)


return copy_credentials return copy_credentials
else: else:
credentials = None credentials = None
current_credential_id = None

if self.custom_configuration.models: if self.custom_configuration.models:
for model_configuration in self.custom_configuration.models: for model_configuration in self.custom_configuration.models:
if model_configuration.model_type == model_type and model_configuration.model == model: if model_configuration.model_type == model_type and model_configuration.model == model:
credentials = model_configuration.credentials credentials = model_configuration.credentials
current_credential_id = model_configuration.current_credential_id
break break


if not credentials and self.custom_configuration.provider: if not credentials and self.custom_configuration.provider:
credentials = self.custom_configuration.provider.credentials credentials = self.custom_configuration.provider.credentials
current_credential_id = self.custom_configuration.provider.current_credential_id

if current_credential_id:
from core.helper.credential_utils import check_credential_policy_compliance

check_credential_policy_compliance(
credential_id=current_credential_id,
provider=self.provider.provider,
credential_type=PluginCredentialType.MODEL,
)
else:
# no current credential id, check all available credentials
if self.custom_configuration.provider:
for credential_configuration in self.custom_configuration.provider.available_credentials:
from core.helper.credential_utils import check_credential_policy_compliance

check_credential_policy_compliance(
credential_id=credential_configuration.credential_id,
provider=self.provider.provider,
credential_type=PluginCredentialType.MODEL,
)


return credentials return credentials


:param credential_id: if provided, return the specified credential :param credential_id: if provided, return the specified credential
:return: :return:
""" """

if credential_id: if credential_id:
return self._get_specific_provider_credential(credential_id) return self._get_specific_provider_credential(credential_id)




current_credential_id = credential_record.id current_credential_id = credential_record.id
current_credential_name = credential_record.credential_name current_credential_name = credential_record.credential_name

credentials = self.obfuscated_credentials( credentials = self.obfuscated_credentials(
credentials=credentials, credentials=credentials,
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
): ):
current_credential_id = model_configuration.current_credential_id current_credential_id = model_configuration.current_credential_id
current_credential_name = model_configuration.current_credential_name current_credential_name = model_configuration.current_credential_name

credentials = self.obfuscated_credentials( credentials = self.obfuscated_credentials(
credentials=model_configuration.credentials, credentials=model_configuration.credentials,
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas

+ 1
- 0
api/core/entities/provider_entities.py Vedi File

name: str name: str
credentials: dict credentials: dict
credential_source_type: str | None = None credential_source_type: str | None = None
credential_id: str | None = None




class ModelSettings(BaseModel): class ModelSettings(BaseModel):

+ 75
- 0
api/core/helper/credential_utils.py Vedi File

"""
Credential utility functions for checking credential existence and policy compliance.
"""

from services.enterprise.plugin_manager_service import PluginCredentialType


def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool:
"""
Check if the credential still exists in the database.

:param credential_id: The credential ID to check
:param credential_type: The type of credential (MODEL or TOOL)
:return: True if credential exists, False otherwise
"""
from sqlalchemy import select
from sqlalchemy.orm import Session

from extensions.ext_database import db
from models.provider import ProviderCredential, ProviderModelCredential
from models.tools import BuiltinToolProvider

with Session(db.engine) as session:
if credential_type == PluginCredentialType.MODEL:
# Check both pre-defined and custom model credentials using a single UNION query
stmt = (
select(ProviderCredential.id)
.where(ProviderCredential.id == credential_id)
.union(select(ProviderModelCredential.id).where(ProviderModelCredential.id == credential_id))
)
return session.scalar(stmt) is not None

if credential_type == PluginCredentialType.TOOL:
return (
session.scalar(select(BuiltinToolProvider.id).where(BuiltinToolProvider.id == credential_id))
is not None
)

return False


def check_credential_policy_compliance(
credential_id: str, provider: str, credential_type: "PluginCredentialType", check_existence: bool = True
) -> None:
"""
Check credential policy compliance for the given credential ID.

:param credential_id: The credential ID to check
:param provider: The provider name
:param credential_type: The type of credential (MODEL or TOOL)
:param check_existence: Whether to check if credential exists in database first
:raises ValueError: If credential policy compliance check fails
"""
from services.enterprise.plugin_manager_service import (
CheckCredentialPolicyComplianceRequest,
PluginManagerService,
)
from services.feature_service import FeatureService

if not FeatureService.get_system_features().plugin_manager.enabled or not credential_id:
return

# Check if credential exists in database first (if requested)
if check_existence:
if not is_credential_exists(credential_id, credential_type):
raise ValueError(f"Credential with id {credential_id} for provider {provider} not found.")

# Check policy compliance
PluginManagerService.check_credential_policy_compliance(
CheckCredentialPolicyComplianceRequest(
dify_credential_id=credential_id,
provider=provider,
credential_type=credential_type,
)
)

+ 36
- 0
api/core/model_manager.py Vedi File

from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.provider import ProviderType from models.provider import ProviderType
from services.enterprise.plugin_manager_service import PluginCredentialType


logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)


else: else:
raise last_exception raise last_exception


# Additional policy compliance check as fallback (in case fetch_next didn't catch it)
try:
from core.helper.credential_utils import check_credential_policy_compliance

if lb_config.credential_id:
check_credential_policy_compliance(
credential_id=lb_config.credential_id,
provider=self.provider,
credential_type=PluginCredentialType.MODEL,
)
except Exception as e:
logger.warning(
"Load balancing config %s failed policy compliance check in round-robin: %s", lb_config.id, str(e)
)
self.load_balancing_manager.cooldown(lb_config, expire=60)
continue

try: try:
if "credentials" in kwargs: if "credentials" in kwargs:
del kwargs["credentials"] del kwargs["credentials"]


continue continue


# Check policy compliance for the selected configuration
try:
from core.helper.credential_utils import check_credential_policy_compliance

if config.credential_id:
check_credential_policy_compliance(
credential_id=config.credential_id,
provider=self._provider,
credential_type=PluginCredentialType.MODEL,
)
except Exception as e:
logger.warning("Load balancing config %s failed policy compliance check: %s", config.id, str(e))
cooldown_load_balancing_configs.append(config)
if len(cooldown_load_balancing_configs) >= len(self._load_balancing_configs):
# all configs are in cooldown or failed policy compliance
return None
continue

if dify_config.DEBUG: if dify_config.DEBUG:
logger.info( logger.info(
"""Model LB """Model LB

+ 1
- 0
api/core/provider_manager.py Vedi File

name=load_balancing_model_config.name, name=load_balancing_model_config.name,
credentials=provider_model_credentials, credentials=provider_model_credentials,
credential_source_type=load_balancing_model_config.credential_source_type, credential_source_type=load_balancing_model_config.credential_source_type,
credential_id=load_balancing_model_config.credential_id,
) )
) )



+ 4
- 0
api/core/tools/errors.py Vedi File

pass pass




class ToolCredentialPolicyViolationError(ValueError):
pass


class ToolEngineInvokeError(Exception): class ToolEngineInvokeError(Exception):
meta: ToolInvokeMeta meta: ToolInvokeMeta



+ 12
- 3
api/core/tools/tool_manager.py Vedi File

from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.tools.mcp_tools_manage_service import MCPToolManageService from services.tools.mcp_tools_manage_service import MCPToolManageService


if TYPE_CHECKING: if TYPE_CHECKING:
) )
from core.tools.errors import ToolProviderNotFoundError from core.tools.errors import ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
ToolParameterConfigurationManager,
)
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.tool import WorkflowTool from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db from extensions.ext_database import db
if builtin_provider is None: if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")


# check if the credential is allowed to be used
from core.helper.credential_utils import check_credential_policy_compliance

check_credential_policy_compliance(
credential_id=builtin_provider.id,
provider=provider_id,
credential_type=PluginCredentialType.TOOL,
check_existence=False,
)

encrypter, cache = create_provider_encrypter( encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[ config=[

+ 17
- 5
api/services/enterprise/base.py Vedi File

import requests import requests




class EnterpriseRequest:
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")

class BaseRequest:
proxies = { proxies = {
"http": "", "http": "",
"https": "", "https": "",
} }
base_url = ""
secret_key = ""
secret_key_header = ""


@classmethod @classmethod
def send_request(cls, method, endpoint, json=None, params=None): def send_request(cls, method, endpoint, json=None, params=None):
headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key}
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
url = f"{cls.base_url}{endpoint}" url = f"{cls.base_url}{endpoint}"
response = requests.request(method, url, json=json, params=params, headers=headers, proxies=cls.proxies) response = requests.request(method, url, json=json, params=params, headers=headers, proxies=cls.proxies)
return response.json() return response.json()


class EnterpriseRequest(BaseRequest):
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")
secret_key_header = "Enterprise-Api-Secret-Key"


class EnterprisePluginManagerRequest(BaseRequest):
base_url = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_URL", "ENTERPRISE_PLUGIN_MANAGER_API_URL")
secret_key = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY", "ENTERPRISE_PLUGIN_MANAGER_API_SECRET_KEY")
secret_key_header = "Plugin-Manager-Inner-Api-Secret-Key"

+ 52
- 0
api/services/enterprise/plugin_manager_service.py Vedi File

import enum
import logging

from pydantic import BaseModel

from services.enterprise.base import EnterprisePluginManagerRequest
from services.errors.base import BaseServiceError


class PluginCredentialType(enum.Enum):
MODEL = 0
TOOL = 1

def to_number(self):
return self.value


class CheckCredentialPolicyComplianceRequest(BaseModel):
dify_credential_id: str
provider: str
credential_type: PluginCredentialType

def model_dump(self, **kwargs):
data = super().model_dump(**kwargs)
data["credential_type"] = self.credential_type.to_number()
return data


class CredentialPolicyViolationError(BaseServiceError):
pass


class PluginManagerService:
@classmethod
def check_credential_policy_compliance(cls, body: CheckCredentialPolicyComplianceRequest):
try:
ret = EnterprisePluginManagerRequest.send_request(
"POST", "/check-credential-policy-compliance", json=body.model_dump()
)
if not isinstance(ret, dict) or "result" not in ret:
raise ValueError("Invalid response format from plugin manager API")
except Exception as e:
raise CredentialPolicyViolationError(
f"error occurred while checking credential policy compliance: {e}"
) from e

if not ret.get("result", False):
raise CredentialPolicyViolationError("Credentials not available: Please use ENTERPRISE global credentials")

logging.debug(
f"Credential policy compliance checked for {body.provider} with credential {body.dify_credential_id}, result: {ret.get('result', False)}"
)

+ 6
- 0
api/services/feature_service.py Vedi File

subscription_plan: str = "" subscription_plan: str = ""




class PluginManagerModel(BaseModel):
enabled: bool = False


class SystemFeatureModel(BaseModel): class SystemFeatureModel(BaseModel):
sso_enforced_for_signin: bool = False sso_enforced_for_signin: bool = False
sso_enforced_for_signin_protocol: str = "" sso_enforced_for_signin_protocol: str = ""
webapp_auth: WebAppAuthModel = WebAppAuthModel() webapp_auth: WebAppAuthModel = WebAppAuthModel()
plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel() plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
enable_change_email: bool = True enable_change_email: bool = True
plugin_manager: PluginManagerModel = PluginManagerModel()




class FeatureService: class FeatureService:
system_features.branding.enabled = True system_features.branding.enabled = True
system_features.webapp_auth.enabled = True system_features.webapp_auth.enabled = True
system_features.enable_change_email = False system_features.enable_change_email = False
system_features.plugin_manager.enabled = True
cls._fulfill_params_from_enterprise(system_features) cls._fulfill_params_from_enterprise(system_features)


if dify_config.MARKETPLACE_ENABLED: if dify_config.MARKETPLACE_ENABLED:

+ 263
- 11
api/services/workflow_service.py Vedi File

from models.account import Account from models.account import Account
from models.model import App, AppMode from models.model import App, AppMode
from models.tools import WorkflowToolProvider from models.tools import WorkflowToolProvider
from models.workflow import (
Workflow,
WorkflowNodeExecutionModel,
WorkflowNodeExecutionTriggeredFrom,
WorkflowType,
)
from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType
from repositories.factory import DifyAPIRepositoryFactory from repositories.factory import DifyAPIRepositoryFactory
from services.enterprise.plugin_manager_service import PluginCredentialType
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError
from services.workflow.workflow_converter import WorkflowConverter from services.workflow.workflow_converter import WorkflowConverter


from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
from .workflow_draft_variable_service import (
DraftVariableSaver,
DraftVarLoader,
WorkflowDraftVariableService,
)
from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService




class WorkflowService: class WorkflowService:
if not draft_workflow: if not draft_workflow:
raise ValueError("No valid workflow found.") raise ValueError("No valid workflow found.")


# Validate credentials before publishing, for credential policy check
from services.feature_service import FeatureService

if FeatureService.get_system_features().plugin_manager.enabled:
self._validate_workflow_credentials(draft_workflow)

# create new workflow # create new workflow
workflow = Workflow.new( workflow = Workflow.new(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
# return new workflow # return new workflow
return workflow return workflow


def _validate_workflow_credentials(self, workflow: Workflow) -> None:
"""
Validate all credentials in workflow nodes before publishing.

:param workflow: The workflow to validate
:raises ValueError: If any credentials violate policy compliance
"""
graph_dict = workflow.graph_dict
nodes = graph_dict.get("nodes", [])

for node in nodes:
node_data = node.get("data", {})
node_type = node_data.get("type")
node_id = node.get("id", "unknown")

try:
# Extract and validate credentials based on node type
if node_type == "tool":
credential_id = node_data.get("credential_id")
provider = node_data.get("provider_id")
if provider:
if credential_id:
# Check specific credential
from core.helper.credential_utils import check_credential_policy_compliance

check_credential_policy_compliance(
credential_id=credential_id,
provider=provider,
credential_type=PluginCredentialType.TOOL,
)
else:
# Check default workspace credential for this provider
self._check_default_tool_credential(workflow.tenant_id, provider)

elif node_type == "agent":
agent_params = node_data.get("agent_parameters", {})

model_config = agent_params.get("model", {}).get("value", {})
if model_config.get("provider") and model_config.get("model"):
self._validate_llm_model_config(
workflow.tenant_id, model_config["provider"], model_config["model"]
)

# Validate load balancing credentials for agent model if load balancing is enabled
agent_model_node_data = {"model": model_config}
self._validate_load_balancing_credentials(workflow, agent_model_node_data, node_id)

# Validate agent tools
tools = agent_params.get("tools", {}).get("value", [])
for tool in tools:
# Agent tools store provider in provider_name field
provider = tool.get("provider_name")
credential_id = tool.get("credential_id")
if provider:
if credential_id:
from core.helper.credential_utils import check_credential_policy_compliance

check_credential_policy_compliance(credential_id, provider, PluginCredentialType.TOOL)
else:
self._check_default_tool_credential(workflow.tenant_id, provider)

elif node_type in ["llm", "knowledge_retrieval", "parameter_extractor", "question_classifier"]:
model_config = node_data.get("model", {})
provider = model_config.get("provider")
model_name = model_config.get("name")

if provider and model_name:
# Validate that the provider+model combination can fetch valid credentials
self._validate_llm_model_config(workflow.tenant_id, provider, model_name)
# Validate load balancing credentials if load balancing is enabled
self._validate_load_balancing_credentials(workflow, node_data, node_id)
else:
raise ValueError(f"Node {node_id} ({node_type}): Missing provider or model configuration")

except Exception as e:
if isinstance(e, ValueError):
raise e
else:
raise ValueError(f"Node {node_id} ({node_type}): {str(e)}")

def _validate_llm_model_config(self, tenant_id: str, provider: str, model_name: str) -> None:
"""
Validate that an LLM model configuration can fetch valid credentials.

This method attempts to get the model instance and validates that:
1. The provider exists and is configured
2. The model exists in the provider
3. Credentials can be fetched for the model
4. The credentials pass policy compliance checks

:param tenant_id: The tenant ID
:param provider: The provider name
:param model_name: The model name
:raises ValueError: If the model configuration is invalid or credentials fail policy checks
"""
try:
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType

# Get model instance to validate provider+model combination
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name
)

# The ModelInstance constructor will automatically check credential policy compliance
# via ProviderConfiguration.get_current_credentials() -> _check_credential_policy_compliance()
# If it fails, an exception will be raised

except Exception as e:
raise ValueError(
f"Failed to validate LLM model configuration (provider: {provider}, model: {model_name}): {str(e)}"
)

def _check_default_tool_credential(self, tenant_id: str, provider: str) -> None:
"""
Check credential policy compliance for the default workspace credential of a tool provider.

This method finds the default credential for the given provider and validates it.
Uses the same fallback logic as runtime to handle deauthorized credentials.

:param tenant_id: The tenant ID
:param provider: The tool provider name
:raises ValueError: If no default credential exists or if it fails policy compliance
"""
try:
from models.tools import BuiltinToolProvider

# Use the same fallback logic as runtime: get the first available credential
# ordered by is_default DESC, created_at ASC (same as tool_manager.py)
default_provider = (
db.session.query(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)

if not default_provider:
raise ValueError("No default credential found")

# Check credential policy compliance using the default credential ID
from core.helper.credential_utils import check_credential_policy_compliance

check_credential_policy_compliance(
credential_id=default_provider.id,
provider=provider,
credential_type=PluginCredentialType.TOOL,
check_existence=False,
)

except Exception as e:
raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}")

def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict, node_id: str) -> None:
"""
Validate load balancing credentials for a workflow node.

:param workflow: The workflow being validated
:param node_data: The node data containing model configuration
:param node_id: The node ID for error reporting
:raises ValueError: If load balancing credentials violate policy compliance
"""
# Extract model configuration
model_config = node_data.get("model", {})
provider = model_config.get("provider")
model_name = model_config.get("name")

if not provider or not model_name:
return # No model config to validate

# Check if this model has load balancing enabled
if self._is_load_balancing_enabled(workflow.tenant_id, provider, model_name):
# Get all load balancing configurations for this model
load_balancing_configs = self._get_load_balancing_configs(workflow.tenant_id, provider, model_name)
# Validate each load balancing configuration
try:
for config in load_balancing_configs:
if config.get("credential_id"):
from core.helper.credential_utils import check_credential_policy_compliance

check_credential_policy_compliance(
config["credential_id"], provider, PluginCredentialType.MODEL
)
except Exception as e:
raise ValueError(f"Invalid load balancing credentials for {provider}/{model_name}: {str(e)}")

def _is_load_balancing_enabled(self, tenant_id: str, provider: str, model_name: str) -> bool:
"""
Check if load balancing is enabled for a specific model.

:param tenant_id: The tenant ID
:param provider: The provider name
:param model_name: The model name
:return: True if load balancing is enabled, False otherwise
"""
try:
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager

# Get provider configurations
provider_manager = ProviderManager()
provider_configurations = provider_manager.get_configurations(tenant_id)
provider_configuration = provider_configurations.get(provider)

if not provider_configuration:
return False

# Get provider model setting
provider_model_setting = provider_configuration.get_provider_model_setting(
model_type=ModelType.LLM,
model=model_name,
)
return provider_model_setting is not None and provider_model_setting.load_balancing_enabled

except Exception:
# If we can't determine the status, assume load balancing is not enabled
return False

def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict]:
"""
Get all load balancing configurations for a model.

:param tenant_id: The tenant ID
:param provider: The provider name
:param model_name: The model name
:return: List of load balancing configuration dictionaries
"""
try:
from services.model_load_balancing_service import ModelLoadBalancingService

model_load_balancing_service = ModelLoadBalancingService()
_, configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id,
provider=provider,
model=model_name,
model_type="llm", # Load balancing is primarily used for LLM models
config_from="predefined-model", # Check both predefined and custom models
)

_, custom_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model"
)
all_configs = configs + custom_configs

return [config for config in all_configs if config.get("credential_id")]

except Exception:
# If we can't get the configurations, return empty list
# This will prevent validation errors from breaking the workflow
return []

def get_default_block_configs(self) -> list[dict]: def get_default_block_configs(self) -> list[dict]:
""" """
Get default block configs Get default block configs

Loading…
Annulla
Salva