Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>tags/1.8.0
| @@ -10,6 +10,7 @@ from controllers.console.wraps import account_initialization_required, setup_req | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from libs.helper import StrLen, uuid_value | |||
| from libs.login import login_required | |||
| from services.billing_service import BillingService | |||
| from services.model_provider_service import ModelProviderService | |||
| @@ -45,67 +46,71 @@ class ModelProviderCredentialApi(Resource): | |||
| @account_initialization_required | |||
| def get(self, provider: str): | |||
| tenant_id = current_user.current_tenant_id | |||
| # if credential_id is not provided, return current used credential | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider) | |||
| credentials = model_provider_service.get_provider_credential( | |||
| tenant_id=tenant_id, provider=provider, credential_id=args.get("credential_id") | |||
| ) | |||
| return {"credentials": credentials} | |||
| class ModelProviderValidateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| tenant_id = current_user.current_tenant_id | |||
| model_provider_service = ModelProviderService() | |||
| result = True | |||
| error = "" | |||
| try: | |||
| model_provider_service.provider_credentials_validate( | |||
| tenant_id=tenant_id, provider=provider, credentials=args["credentials"] | |||
| model_provider_service.create_provider_credential( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider, | |||
| credentials=args["credentials"], | |||
| credential_name=args["name"], | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| result = False | |||
| error = str(ex) | |||
| response = {"result": "success" if result else "error"} | |||
| if not result: | |||
| response["error"] = error or "Unknown error" | |||
| return response | |||
| raise ValueError(str(ex)) | |||
| return {"result": "success"}, 201 | |||
| class ModelProviderApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| def put(self, provider: str): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| try: | |||
| model_provider_service.save_provider_credentials( | |||
| tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"] | |||
| model_provider_service.update_provider_credential( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider, | |||
| credentials=args["credentials"], | |||
| credential_id=args["credential_id"], | |||
| credential_name=args["name"], | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| raise ValueError(str(ex)) | |||
| return {"result": "success"}, 201 | |||
| return {"result": "success"} | |||
| @setup_required | |||
| @login_required | |||
| @@ -113,13 +118,70 @@ class ModelProviderApi(Resource): | |||
| def delete(self, provider: str): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider) | |||
| model_provider_service.remove_provider_credential( | |||
| tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"] | |||
| ) | |||
| return {"result": "success"}, 204 | |||
| class ModelProviderCredentialSwitchApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| service = ModelProviderService() | |||
| service.switch_active_provider_credential( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider, | |||
| credential_id=args["credential_id"], | |||
| ) | |||
| return {"result": "success"} | |||
| class ModelProviderValidateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| tenant_id = current_user.current_tenant_id | |||
| model_provider_service = ModelProviderService() | |||
| result = True | |||
| error = "" | |||
| try: | |||
| model_provider_service.validate_provider_credentials( | |||
| tenant_id=tenant_id, provider=provider, credentials=args["credentials"] | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| result = False | |||
| error = str(ex) | |||
| response = {"result": "success" if result else "error"} | |||
| if not result: | |||
| response["error"] = error or "Unknown error" | |||
| return response | |||
| class ModelProviderIconApi(Resource): | |||
| """ | |||
| Get model provider icon | |||
| @@ -187,8 +249,10 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): | |||
| api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") | |||
| api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials") | |||
| api.add_resource( | |||
| ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers/<path:provider>/credentials/switch" | |||
| ) | |||
| api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate") | |||
| api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<path:provider>") | |||
| api.add_resource( | |||
| PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type" | |||
| @@ -9,6 +9,7 @@ from controllers.console.wraps import account_initialization_required, setup_req | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from libs.helper import StrLen, uuid_value | |||
| from libs.login import login_required | |||
| from services.model_load_balancing_service import ModelLoadBalancingService | |||
| from services.model_provider_service import ModelProviderService | |||
| @@ -98,6 +99,7 @@ class ModelProviderModelApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| # To save the model's load balance configs | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| @@ -113,22 +115,26 @@ class ModelProviderModelApi(Resource): | |||
| choices=[mt.value for mt in ModelType], | |||
| location="json", | |||
| ) | |||
| parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") | |||
| parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json") | |||
| parser.add_argument("config_from", type=str, required=False, nullable=True, location="json") | |||
| parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| model_load_balancing_service = ModelLoadBalancingService() | |||
| if args.get("config_from", "") == "custom-model": | |||
| if not args.get("credential_id"): | |||
| raise ValueError("credential_id is required when configuring a custom-model") | |||
| service = ModelProviderService() | |||
| service.switch_active_custom_model_credential( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider, | |||
| model_type=args["model_type"], | |||
| model=args["model"], | |||
| credential_id=args["credential_id"], | |||
| ) | |||
| if ( | |||
| "load_balancing" in args | |||
| and args["load_balancing"] | |||
| and "enabled" in args["load_balancing"] | |||
| and args["load_balancing"]["enabled"] | |||
| ): | |||
| if "configs" not in args["load_balancing"]: | |||
| raise ValueError("invalid load balancing configs") | |||
| model_load_balancing_service = ModelLoadBalancingService() | |||
| if "load_balancing" in args and args["load_balancing"] and "configs" in args["load_balancing"]: | |||
| # save load balancing configs | |||
| model_load_balancing_service.update_load_balancing_configs( | |||
| tenant_id=tenant_id, | |||
| @@ -136,37 +142,17 @@ class ModelProviderModelApi(Resource): | |||
| model=args["model"], | |||
| model_type=args["model_type"], | |||
| configs=args["load_balancing"]["configs"], | |||
| config_from=args.get("config_from", ""), | |||
| ) | |||
| # enable load balancing | |||
| model_load_balancing_service.enable_model_load_balancing( | |||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||
| ) | |||
| else: | |||
| # disable load balancing | |||
| model_load_balancing_service.disable_model_load_balancing( | |||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||
| ) | |||
| if args.get("config_from", "") != "predefined-model": | |||
| model_provider_service = ModelProviderService() | |||
| try: | |||
| model_provider_service.save_model_credentials( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args["model"], | |||
| model_type=args["model_type"], | |||
| credentials=args["credentials"], | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| logging.exception( | |||
| "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s", | |||
| tenant_id, | |||
| args.get("model"), | |||
| args.get("model_type"), | |||
| ) | |||
| raise ValueError(str(ex)) | |||
| if args.get("load_balancing", {}).get("enabled"): | |||
| model_load_balancing_service.enable_model_load_balancing( | |||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||
| ) | |||
| else: | |||
| model_load_balancing_service.disable_model_load_balancing( | |||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||
| ) | |||
| return {"result": "success"}, 200 | |||
| @@ -192,7 +178,7 @@ class ModelProviderModelApi(Resource): | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| model_provider_service.remove_model_credentials( | |||
| model_provider_service.remove_model( | |||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||
| ) | |||
| @@ -216,11 +202,17 @@ class ModelProviderModelCredentialApi(Resource): | |||
| choices=[mt.value for mt in ModelType], | |||
| location="args", | |||
| ) | |||
| parser.add_argument("config_from", type=str, required=False, nullable=True, location="args") | |||
| parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| credentials = model_provider_service.get_model_credentials( | |||
| tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"] | |||
| current_credential = model_provider_service.get_model_credential( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model_type=args["model_type"], | |||
| model=args["model"], | |||
| credential_id=args.get("credential_id"), | |||
| ) | |||
| model_load_balancing_service = ModelLoadBalancingService() | |||
| @@ -228,10 +220,173 @@ class ModelProviderModelCredentialApi(Resource): | |||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||
| ) | |||
| return { | |||
| "credentials": credentials, | |||
| "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, | |||
| } | |||
| if args.get("config_from", "") == "predefined-model": | |||
| available_credentials = model_provider_service.provider_manager.get_provider_available_credentials( | |||
| tenant_id=tenant_id, provider_name=provider | |||
| ) | |||
| else: | |||
| model_type = ModelType.value_of(args["model_type"]).to_origin_model_type() | |||
| available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials( | |||
| tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args["model"] | |||
| ) | |||
| return jsonable_encoder( | |||
| { | |||
| "credentials": current_credential.get("credentials") if current_credential else {}, | |||
| "current_credential_id": current_credential.get("current_credential_id") | |||
| if current_credential | |||
| else None, | |||
| "current_credential_name": current_credential.get("current_credential_name") | |||
| if current_credential | |||
| else None, | |||
| "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, | |||
| "available_credentials": available_credentials, | |||
| } | |||
| ) | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "model_type", | |||
| type=str, | |||
| required=True, | |||
| nullable=False, | |||
| choices=[mt.value for mt in ModelType], | |||
| location="json", | |||
| ) | |||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| tenant_id = current_user.current_tenant_id | |||
| model_provider_service = ModelProviderService() | |||
| try: | |||
| model_provider_service.create_model_credential( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args["model"], | |||
| model_type=args["model_type"], | |||
| credentials=args["credentials"], | |||
| credential_name=args["name"], | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| logging.exception( | |||
| "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s", | |||
| tenant_id, | |||
| args.get("model"), | |||
| args.get("model_type"), | |||
| ) | |||
| raise ValueError(str(ex)) | |||
| return {"result": "success"}, 201 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def put(self, provider: str): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "model_type", | |||
| type=str, | |||
| required=True, | |||
| nullable=False, | |||
| choices=[mt.value for mt in ModelType], | |||
| location="json", | |||
| ) | |||
| parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| try: | |||
| model_provider_service.update_model_credential( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider, | |||
| model_type=args["model_type"], | |||
| model=args["model"], | |||
| credentials=args["credentials"], | |||
| credential_id=args["credential_id"], | |||
| credential_name=args["name"], | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| raise ValueError(str(ex)) | |||
| return {"result": "success"} | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def delete(self, provider: str): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "model_type", | |||
| type=str, | |||
| required=True, | |||
| nullable=False, | |||
| choices=[mt.value for mt in ModelType], | |||
| location="json", | |||
| ) | |||
| parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| model_provider_service.remove_model_credential( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider, | |||
| model_type=args["model_type"], | |||
| model=args["model"], | |||
| credential_id=args["credential_id"], | |||
| ) | |||
| return {"result": "success"}, 204 | |||
| class ModelProviderModelCredentialSwitchApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "model_type", | |||
| type=str, | |||
| required=True, | |||
| nullable=False, | |||
| choices=[mt.value for mt in ModelType], | |||
| location="json", | |||
| ) | |||
| parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| service = ModelProviderService() | |||
| service.add_model_credential_to_model_list( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider, | |||
| model_type=args["model_type"], | |||
| model=args["model"], | |||
| credential_id=args["credential_id"], | |||
| ) | |||
| return {"result": "success"} | |||
| class ModelProviderModelEnableApi(Resource): | |||
| @@ -314,7 +469,7 @@ class ModelProviderModelValidateApi(Resource): | |||
| error = "" | |||
| try: | |||
| model_provider_service.model_credentials_validate( | |||
| model_provider_service.validate_model_credentials( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args["model"], | |||
| @@ -379,6 +534,10 @@ api.add_resource( | |||
| api.add_resource( | |||
| ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials" | |||
| ) | |||
| api.add_resource( | |||
| ModelProviderModelCredentialSwitchApi, | |||
| "/workspaces/current/model-providers/<path:provider>/models/credentials/switch", | |||
| ) | |||
| api.add_resource( | |||
| ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate" | |||
| ) | |||
| @@ -19,6 +19,7 @@ class ModelStatus(Enum): | |||
| QUOTA_EXCEEDED = "quota-exceeded" | |||
| NO_PERMISSION = "no-permission" | |||
| DISABLED = "disabled" | |||
| CREDENTIAL_REMOVED = "credential-removed" | |||
| class SimpleModelProviderEntity(BaseModel): | |||
| @@ -54,6 +55,7 @@ class ProviderModelWithStatusEntity(ProviderModel): | |||
| status: ModelStatus | |||
| load_balancing_enabled: bool = False | |||
| has_invalid_load_balancing_configs: bool = False | |||
| def raise_for_status(self) -> None: | |||
| """ | |||
| @@ -69,6 +69,15 @@ class QuotaConfiguration(BaseModel): | |||
| restrict_models: list[RestrictModel] = [] | |||
| class CredentialConfiguration(BaseModel): | |||
| """ | |||
| Model class for credential configuration. | |||
| """ | |||
| credential_id: str | |||
| credential_name: str | |||
| class SystemConfiguration(BaseModel): | |||
| """ | |||
| Model class for provider system configuration. | |||
| @@ -86,6 +95,9 @@ class CustomProviderConfiguration(BaseModel): | |||
| """ | |||
| credentials: dict | |||
| current_credential_id: Optional[str] = None | |||
| current_credential_name: Optional[str] = None | |||
| available_credentials: list[CredentialConfiguration] = [] | |||
| class CustomModelConfiguration(BaseModel): | |||
| @@ -95,7 +107,10 @@ class CustomModelConfiguration(BaseModel): | |||
| model: str | |||
| model_type: ModelType | |||
| credentials: dict | |||
| credentials: dict | None | |||
| current_credential_id: Optional[str] = None | |||
| current_credential_name: Optional[str] = None | |||
| available_model_credentials: list[CredentialConfiguration] = [] | |||
| # pydantic configs | |||
| model_config = ConfigDict(protected_namespaces=()) | |||
| @@ -118,6 +133,7 @@ class ModelLoadBalancingConfiguration(BaseModel): | |||
| id: str | |||
| name: str | |||
| credentials: dict | |||
| credential_source_type: str | None = None | |||
| class ModelSettings(BaseModel): | |||
| @@ -201,7 +201,7 @@ class ModelProviderFactory: | |||
| return filtered_credentials | |||
| def get_model_schema( | |||
| self, *, provider: str, model_type: ModelType, model: str, credentials: dict | |||
| self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None | |||
| ) -> AIModelEntity | None: | |||
| """ | |||
| Get model schema | |||
| @@ -12,6 +12,7 @@ from configs import dify_config | |||
| from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity | |||
| from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle | |||
| from core.entities.provider_entities import ( | |||
| CredentialConfiguration, | |||
| CustomConfiguration, | |||
| CustomModelConfiguration, | |||
| CustomProviderConfiguration, | |||
| @@ -40,7 +41,9 @@ from extensions.ext_redis import redis_client | |||
| from models.provider import ( | |||
| LoadBalancingModelConfig, | |||
| Provider, | |||
| ProviderCredential, | |||
| ProviderModel, | |||
| ProviderModelCredential, | |||
| ProviderModelSetting, | |||
| ProviderType, | |||
| TenantDefaultModel, | |||
| @@ -488,6 +491,61 @@ class ProviderManager: | |||
| return provider_name_to_provider_load_balancing_model_configs_dict | |||
| @staticmethod | |||
| def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]: | |||
| """ | |||
| Get provider all credentials. | |||
| :param tenant_id: workspace id | |||
| :param provider_name: provider name | |||
| :return: | |||
| """ | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| stmt = ( | |||
| select(ProviderCredential) | |||
| .where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name) | |||
| .order_by(ProviderCredential.created_at.desc()) | |||
| ) | |||
| available_credentials = session.scalars(stmt).all() | |||
| return [ | |||
| CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name) | |||
| for credential in available_credentials | |||
| ] | |||
| @staticmethod | |||
| def get_provider_model_available_credentials( | |||
| tenant_id: str, provider_name: str, model_name: str, model_type: str | |||
| ) -> list[CredentialConfiguration]: | |||
| """ | |||
| Get provider custom model all credentials. | |||
| :param tenant_id: workspace id | |||
| :param provider_name: provider name | |||
| :param model_name: model name | |||
| :param model_type: model type | |||
| :return: | |||
| """ | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| stmt = ( | |||
| select(ProviderModelCredential) | |||
| .where( | |||
| ProviderModelCredential.tenant_id == tenant_id, | |||
| ProviderModelCredential.provider_name == provider_name, | |||
| ProviderModelCredential.model_name == model_name, | |||
| ProviderModelCredential.model_type == model_type, | |||
| ) | |||
| .order_by(ProviderModelCredential.created_at.desc()) | |||
| ) | |||
| available_credentials = session.scalars(stmt).all() | |||
| return [ | |||
| CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name) | |||
| for credential in available_credentials | |||
| ] | |||
| @staticmethod | |||
| def _init_trial_provider_records( | |||
| tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]] | |||
| @@ -590,9 +648,6 @@ class ProviderManager: | |||
| if provider_record.provider_type == ProviderType.SYSTEM.value: | |||
| continue | |||
| if not provider_record.encrypted_config: | |||
| continue | |||
| custom_provider_record = provider_record | |||
| # Get custom provider credentials | |||
| @@ -611,8 +666,8 @@ class ProviderManager: | |||
| try: | |||
| # fix origin data | |||
| if custom_provider_record.encrypted_config is None: | |||
| raise ValueError("No credentials found") | |||
| if not custom_provider_record.encrypted_config.startswith("{"): | |||
| provider_credentials = {} | |||
| elif not custom_provider_record.encrypted_config.startswith("{"): | |||
| provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} | |||
| else: | |||
| provider_credentials = json.loads(custom_provider_record.encrypted_config) | |||
| @@ -637,7 +692,14 @@ class ProviderManager: | |||
| else: | |||
| provider_credentials = cached_provider_credentials | |||
| custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials) | |||
| custom_provider_configuration = CustomProviderConfiguration( | |||
| credentials=provider_credentials, | |||
| current_credential_name=custom_provider_record.credential_name, | |||
| current_credential_id=custom_provider_record.credential_id, | |||
| available_credentials=self.get_provider_available_credentials( | |||
| tenant_id, custom_provider_record.provider_name | |||
| ), | |||
| ) | |||
| # Get provider model credential secret variables | |||
| model_credential_secret_variables = self._extract_secret_variables( | |||
| @@ -649,8 +711,12 @@ class ProviderManager: | |||
| # Get custom provider model credentials | |||
| custom_model_configurations = [] | |||
| for provider_model_record in provider_model_records: | |||
| if not provider_model_record.encrypted_config: | |||
| continue | |||
| available_model_credentials = self.get_provider_model_available_credentials( | |||
| tenant_id, | |||
| provider_model_record.provider_name, | |||
| provider_model_record.model_name, | |||
| provider_model_record.model_type, | |||
| ) | |||
| provider_model_credentials_cache = ProviderCredentialsCache( | |||
| tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL | |||
| @@ -659,7 +725,7 @@ class ProviderManager: | |||
| # Get cached provider model credentials | |||
| cached_provider_model_credentials = provider_model_credentials_cache.get() | |||
| if not cached_provider_model_credentials: | |||
| if not cached_provider_model_credentials and provider_model_record.encrypted_config: | |||
| try: | |||
| provider_model_credentials = json.loads(provider_model_record.encrypted_config) | |||
| except JSONDecodeError: | |||
| @@ -688,6 +754,9 @@ class ProviderManager: | |||
| model=provider_model_record.model_name, | |||
| model_type=ModelType.value_of(provider_model_record.model_type), | |||
| credentials=provider_model_credentials, | |||
| current_credential_id=provider_model_record.credential_id, | |||
| current_credential_name=provider_model_record.credential_name, | |||
| available_model_credentials=available_model_credentials, | |||
| ) | |||
| ) | |||
| @@ -899,6 +968,18 @@ class ProviderManager: | |||
| load_balancing_model_config.model_name == provider_model_setting.model_name | |||
| and load_balancing_model_config.model_type == provider_model_setting.model_type | |||
| ): | |||
| if load_balancing_model_config.name == "__delete__": | |||
| # to calculate current model whether has invalidate lb configs | |||
| load_balancing_configs.append( | |||
| ModelLoadBalancingConfiguration( | |||
| id=load_balancing_model_config.id, | |||
| name=load_balancing_model_config.name, | |||
| credentials={}, | |||
| credential_source_type=load_balancing_model_config.credential_source_type, | |||
| ) | |||
| ) | |||
| continue | |||
| if not load_balancing_model_config.enabled: | |||
| continue | |||
| @@ -955,6 +1036,7 @@ class ProviderManager: | |||
| id=load_balancing_model_config.id, | |||
| name=load_balancing_model_config.name, | |||
| credentials=provider_model_credentials, | |||
| credential_source_type=load_balancing_model_config.credential_source_type, | |||
| ) | |||
| ) | |||
| @@ -0,0 +1,177 @@ | |||
| """Add provider multi credential support | |||
| Revision ID: e8446f481c1e | |||
| Revises: 8bcc02c9bd07 | |||
| Create Date: 2025-08-09 15:53:54.341341 | |||
| """ | |||
| from alembic import op | |||
| import models as models | |||
| import sqlalchemy as sa | |||
| from sqlalchemy.sql import table, column | |||
| import uuid | |||
| # revision identifiers, used by Alembic. | |||
| revision = 'e8446f481c1e' | |||
| down_revision = 'fa8b0fa6f407' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # Create provider_credentials table | |||
| op.create_table('provider_credentials', | |||
| sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||
| sa.Column('tenant_id', models.types.StringUUID(), nullable=False), | |||
| sa.Column('provider_name', sa.String(length=255), nullable=False), | |||
| sa.Column('credential_name', sa.String(length=255), nullable=False), | |||
| sa.Column('encrypted_config', sa.Text(), nullable=False), | |||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), | |||
| sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), | |||
| sa.PrimaryKeyConstraint('id', name='provider_credential_pkey') | |||
| ) | |||
| # Create index for provider_credentials | |||
| with op.batch_alter_table('provider_credentials', schema=None) as batch_op: | |||
| batch_op.create_index('provider_credential_tenant_provider_idx', ['tenant_id', 'provider_name'], unique=False) | |||
| # Add credential_id to providers table | |||
| with op.batch_alter_table('providers', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True)) | |||
| # Add credential_id to load_balancing_model_configs table | |||
| with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True)) | |||
| migrate_existing_providers_data() | |||
| # Remove encrypted_config column from providers table after migration | |||
| with op.batch_alter_table('providers', schema=None) as batch_op: | |||
| batch_op.drop_column('encrypted_config') | |||
| def migrate_existing_providers_data(): | |||
| """migrate providers table data to provider_credentials""" | |||
| # Define table structure for data manipulation | |||
| providers_table = table('providers', | |||
| column('id', models.types.StringUUID()), | |||
| column('tenant_id', models.types.StringUUID()), | |||
| column('provider_name', sa.String()), | |||
| column('encrypted_config', sa.Text()), | |||
| column('created_at', sa.DateTime()), | |||
| column('updated_at', sa.DateTime()), | |||
| column('credential_id', models.types.StringUUID()), | |||
| ) | |||
| provider_credential_table = table('provider_credentials', | |||
| column('id', models.types.StringUUID()), | |||
| column('tenant_id', models.types.StringUUID()), | |||
| column('provider_name', sa.String()), | |||
| column('credential_name', sa.String()), | |||
| column('encrypted_config', sa.Text()), | |||
| column('created_at', sa.DateTime()), | |||
| column('updated_at', sa.DateTime()) | |||
| ) | |||
| # Get database connection | |||
| conn = op.get_bind() | |||
| # Query all existing providers data | |||
| existing_providers = conn.execute( | |||
| sa.select(providers_table.c.id, providers_table.c.tenant_id, | |||
| providers_table.c.provider_name, providers_table.c.encrypted_config, | |||
| providers_table.c.created_at, providers_table.c.updated_at) | |||
| .where(providers_table.c.encrypted_config.isnot(None)) | |||
| ).fetchall() | |||
| # Iterate through each provider and insert into provider_credentials | |||
| for provider in existing_providers: | |||
| credential_id = str(uuid.uuid4()) | |||
| if not provider.encrypted_config or provider.encrypted_config.strip() == '': | |||
| continue | |||
| # Insert into provider_credentials table | |||
| conn.execute( | |||
| provider_credential_table.insert().values( | |||
| id=credential_id, | |||
| tenant_id=provider.tenant_id, | |||
| provider_name=provider.provider_name, | |||
| credential_name='API_KEY1', # Use a default name | |||
| encrypted_config=provider.encrypted_config, | |||
| created_at=provider.created_at, | |||
| updated_at=provider.updated_at | |||
| ) | |||
| ) | |||
| # Update original providers table, set credential_id | |||
| conn.execute( | |||
| providers_table.update() | |||
| .where(providers_table.c.id == provider.id) | |||
| .values( | |||
| credential_id=credential_id, | |||
| ) | |||
| ) | |||
| def downgrade(): | |||
| # Re-add encrypted_config column to providers table | |||
| with op.batch_alter_table('providers', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True)) | |||
| # Migrate data back from provider_credentials to providers | |||
| migrate_data_back_to_providers() | |||
| # Remove credential_id columns | |||
| with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: | |||
| batch_op.drop_column('credential_id') | |||
| with op.batch_alter_table('providers', schema=None) as batch_op: | |||
| batch_op.drop_column('credential_id') | |||
| # Drop provider_credentials table | |||
| op.drop_table('provider_credentials') | |||
| def migrate_data_back_to_providers(): | |||
| """Migrate data back from provider_credentials to providers table for downgrade""" | |||
| # Define table structure for data manipulation | |||
| providers_table = table('providers', | |||
| column('id', models.types.StringUUID()), | |||
| column('tenant_id', models.types.StringUUID()), | |||
| column('provider_name', sa.String()), | |||
| column('encrypted_config', sa.Text()), | |||
| column('credential_id', models.types.StringUUID()), | |||
| ) | |||
| provider_credential_table = table('provider_credentials', | |||
| column('id', models.types.StringUUID()), | |||
| column('tenant_id', models.types.StringUUID()), | |||
| column('provider_name', sa.String()), | |||
| column('credential_name', sa.String()), | |||
| column('encrypted_config', sa.Text()), | |||
| ) | |||
| # Get database connection | |||
| conn = op.get_bind() | |||
| # Query providers that have credential_id | |||
| providers_with_credentials = conn.execute( | |||
| sa.select(providers_table.c.id, providers_table.c.credential_id) | |||
| .where(providers_table.c.credential_id.isnot(None)) | |||
| ).fetchall() | |||
| # For each provider, get the credential data and update providers table | |||
| for provider in providers_with_credentials: | |||
| credential = conn.execute( | |||
| sa.select(provider_credential_table.c.encrypted_config) | |||
| .where(provider_credential_table.c.id == provider.credential_id) | |||
| ).fetchone() | |||
| if credential: | |||
| # Update providers table with encrypted_config from credential | |||
| conn.execute( | |||
| providers_table.update() | |||
| .where(providers_table.c.id == provider.id) | |||
| .values(encrypted_config=credential.encrypted_config) | |||
| ) | |||
| @@ -0,0 +1,186 @@ | |||
| """Add provider model multi credential support | |||
| Revision ID: 0e154742a5fa | |||
| Revises: e8446f481c1e | |||
| Create Date: 2025-08-13 16:05:42.657730 | |||
| """ | |||
| import uuid | |||
| from alembic import op | |||
| import models as models | |||
| import sqlalchemy as sa | |||
| from sqlalchemy.sql import table, column | |||
| # revision identifiers, used by Alembic. | |||
| revision = '0e154742a5fa' | |||
| down_revision = 'e8446f481c1e' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # Create provider_model_credentials table | |||
| op.create_table('provider_model_credentials', | |||
| sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||
| sa.Column('tenant_id', models.types.StringUUID(), nullable=False), | |||
| sa.Column('provider_name', sa.String(length=255), nullable=False), | |||
| sa.Column('model_name', sa.String(length=255), nullable=False), | |||
| sa.Column('model_type', sa.String(length=40), nullable=False), | |||
| sa.Column('credential_name', sa.String(length=255), nullable=False), | |||
| sa.Column('encrypted_config', sa.Text(), nullable=False), | |||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), | |||
| sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), | |||
| sa.PrimaryKeyConstraint('id', name='provider_model_credential_pkey') | |||
| ) | |||
| # Create index for provider_model_credentials | |||
| with op.batch_alter_table('provider_model_credentials', schema=None) as batch_op: | |||
| batch_op.create_index('provider_model_credential_tenant_provider_model_idx', ['tenant_id', 'provider_name', 'model_name', 'model_type'], unique=False) | |||
| # Add credential_id to provider_models table | |||
| with op.batch_alter_table('provider_models', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True)) | |||
| # Add credential_source_type to load_balancing_model_configs table | |||
| with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('credential_source_type', sa.String(length=40), nullable=True)) | |||
| # Migrate existing provider_models data | |||
| migrate_existing_provider_models_data() | |||
| # Remove encrypted_config column from provider_models table after migration | |||
| with op.batch_alter_table('provider_models', schema=None) as batch_op: | |||
| batch_op.drop_column('encrypted_config') | |||
| def migrate_existing_provider_models_data(): | |||
| """migrate provider_models table data to provider_model_credentials""" | |||
| # Define table structure for data manipulation | |||
| provider_models_table = table('provider_models', | |||
| column('id', models.types.StringUUID()), | |||
| column('tenant_id', models.types.StringUUID()), | |||
| column('provider_name', sa.String()), | |||
| column('model_name', sa.String()), | |||
| column('model_type', sa.String()), | |||
| column('encrypted_config', sa.Text()), | |||
| column('created_at', sa.DateTime()), | |||
| column('updated_at', sa.DateTime()), | |||
| column('credential_id', models.types.StringUUID()), | |||
| ) | |||
| provider_model_credentials_table = table('provider_model_credentials', | |||
| column('id', models.types.StringUUID()), | |||
| column('tenant_id', models.types.StringUUID()), | |||
| column('provider_name', sa.String()), | |||
| column('model_name', sa.String()), | |||
| column('model_type', sa.String()), | |||
| column('credential_name', sa.String()), | |||
| column('encrypted_config', sa.Text()), | |||
| column('created_at', sa.DateTime()), | |||
| column('updated_at', sa.DateTime()) | |||
| ) | |||
| # Get database connection | |||
| conn = op.get_bind() | |||
| # Query all existing provider_models data with encrypted_config | |||
| existing_provider_models = conn.execute( | |||
| sa.select(provider_models_table.c.id, provider_models_table.c.tenant_id, | |||
| provider_models_table.c.provider_name, provider_models_table.c.model_name, | |||
| provider_models_table.c.model_type, provider_models_table.c.encrypted_config, | |||
| provider_models_table.c.created_at, provider_models_table.c.updated_at) | |||
| .where(provider_models_table.c.encrypted_config.isnot(None)) | |||
| ).fetchall() | |||
| # Iterate through each provider_model and insert into provider_model_credentials | |||
| for provider_model in existing_provider_models: | |||
| if not provider_model.encrypted_config or provider_model.encrypted_config.strip() == '': | |||
| continue | |||
| credential_id = str(uuid.uuid4()) | |||
| # Insert into provider_model_credentials table | |||
| conn.execute( | |||
| provider_model_credentials_table.insert().values( | |||
| id=credential_id, | |||
| tenant_id=provider_model.tenant_id, | |||
| provider_name=provider_model.provider_name, | |||
| model_name=provider_model.model_name, | |||
| model_type=provider_model.model_type, | |||
| credential_name='API_KEY1', # Use a default name | |||
| encrypted_config=provider_model.encrypted_config, | |||
| created_at=provider_model.created_at, | |||
| updated_at=provider_model.updated_at | |||
| ) | |||
| ) | |||
| # Update original provider_models table, set credential_id | |||
| conn.execute( | |||
| provider_models_table.update() | |||
| .where(provider_models_table.c.id == provider_model.id) | |||
| .values(credential_id=credential_id) | |||
| ) | |||
| def downgrade(): | |||
| # Re-add encrypted_config column to provider_models table | |||
| with op.batch_alter_table('provider_models', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True)) | |||
| # Migrate data back from provider_model_credentials to provider_models | |||
| migrate_data_back_to_provider_models() | |||
| with op.batch_alter_table('provider_models', schema=None) as batch_op: | |||
| batch_op.drop_column('credential_id') | |||
| # Remove credential_source_type column from load_balancing_model_configs | |||
| with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: | |||
| batch_op.drop_column('credential_source_type') | |||
| # Drop provider_model_credentials table | |||
| op.drop_table('provider_model_credentials') | |||
| def migrate_data_back_to_provider_models(): | |||
| """Migrate data back from provider_model_credentials to provider_models table for downgrade""" | |||
| # Define table structure for data manipulation | |||
| provider_models_table = table('provider_models', | |||
| column('id', models.types.StringUUID()), | |||
| column('encrypted_config', sa.Text()), | |||
| column('credential_id', models.types.StringUUID()), | |||
| ) | |||
| provider_model_credentials_table = table('provider_model_credentials', | |||
| column('id', models.types.StringUUID()), | |||
| column('encrypted_config', sa.Text()), | |||
| ) | |||
| # Get database connection | |||
| conn = op.get_bind() | |||
| # Query provider_models that have credential_id | |||
| provider_models_with_credentials = conn.execute( | |||
| sa.select(provider_models_table.c.id, provider_models_table.c.credential_id) | |||
| .where(provider_models_table.c.credential_id.isnot(None)) | |||
| ).fetchall() | |||
| # For each provider_model, get the credential data and update provider_models table | |||
| for provider_model in provider_models_with_credentials: | |||
| credential = conn.execute( | |||
| sa.select(provider_model_credentials_table.c.encrypted_config) | |||
| .where(provider_model_credentials_table.c.id == provider_model.credential_id) | |||
| ).fetchone() | |||
| if credential: | |||
| # Update provider_models table with encrypted_config from credential | |||
| conn.execute( | |||
| provider_models_table.update() | |||
| .where(provider_models_table.c.id == provider_model.id) | |||
| .values(encrypted_config=credential.encrypted_config) | |||
| ) | |||
| @@ -1,5 +1,6 @@ | |||
| from datetime import datetime | |||
| from enum import Enum | |||
| from functools import cached_property | |||
| from typing import Optional | |||
| import sqlalchemy as sa | |||
| @@ -7,6 +8,7 @@ from sqlalchemy import DateTime, String, func, text | |||
| from sqlalchemy.orm import Mapped, mapped_column | |||
| from .base import Base | |||
| from .engine import db | |||
| from .types import StringUUID | |||
| @@ -60,9 +62,9 @@ class Provider(Base): | |||
| provider_type: Mapped[str] = mapped_column( | |||
| String(40), nullable=False, server_default=text("'custom'::character varying") | |||
| ) | |||
| encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) | |||
| is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) | |||
| last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) | |||
| credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) | |||
| quota_type: Mapped[Optional[str]] = mapped_column( | |||
| String(40), nullable=True, server_default=text("''::character varying") | |||
| @@ -79,6 +81,21 @@ class Provider(Base): | |||
| f" provider_type='{self.provider_type}')>" | |||
| ) | |||
| @cached_property | |||
| def credential(self): | |||
| if self.credential_id: | |||
| return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first() | |||
| @property | |||
| def credential_name(self): | |||
| credential = self.credential | |||
| return credential.credential_name if credential else None | |||
| @property | |||
| def encrypted_config(self): | |||
| credential = self.credential | |||
| return credential.encrypted_config if credential else None | |||
| @property | |||
| def token_is_set(self): | |||
| """ | |||
| @@ -116,11 +133,30 @@ class ProviderModel(Base): | |||
| provider_name: Mapped[str] = mapped_column(String(255), nullable=False) | |||
| model_name: Mapped[str] = mapped_column(String(255), nullable=False) | |||
| model_type: Mapped[str] = mapped_column(String(40), nullable=False) | |||
| encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) | |||
| credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) | |||
| is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) | |||
| created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| @cached_property | |||
| def credential(self): | |||
| if self.credential_id: | |||
| return ( | |||
| db.session.query(ProviderModelCredential) | |||
| .where(ProviderModelCredential.id == self.credential_id) | |||
| .first() | |||
| ) | |||
| @property | |||
| def credential_name(self): | |||
| credential = self.credential | |||
| return credential.credential_name if credential else None | |||
| @property | |||
| def encrypted_config(self): | |||
| credential = self.credential | |||
| return credential.encrypted_config if credential else None | |||
| class TenantDefaultModel(Base): | |||
| __tablename__ = "tenant_default_models" | |||
| @@ -220,6 +256,56 @@ class LoadBalancingModelConfig(Base): | |||
| model_type: Mapped[str] = mapped_column(String(40), nullable=False) | |||
| name: Mapped[str] = mapped_column(String(255), nullable=False) | |||
| encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) | |||
| credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) | |||
| credential_source_type: Mapped[Optional[str]] = mapped_column(String(40), nullable=True) | |||
| enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true")) | |||
| created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| class ProviderCredential(Base): | |||
| """ | |||
| Provider credential - stores multiple named credentials for each provider | |||
| """ | |||
| __tablename__ = "provider_credentials" | |||
| __table_args__ = ( | |||
| sa.PrimaryKeyConstraint("id", name="provider_credential_pkey"), | |||
| sa.Index("provider_credential_tenant_provider_idx", "tenant_id", "provider_name"), | |||
| ) | |||
| id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) | |||
| tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||
| provider_name: Mapped[str] = mapped_column(String(255), nullable=False) | |||
| credential_name: Mapped[str] = mapped_column(String(255), nullable=False) | |||
| encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False) | |||
| created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| class ProviderModelCredential(Base): | |||
| """ | |||
| Provider model credential - stores multiple named credentials for each provider model | |||
| """ | |||
| __tablename__ = "provider_model_credentials" | |||
| __table_args__ = ( | |||
| sa.PrimaryKeyConstraint("id", name="provider_model_credential_pkey"), | |||
| sa.Index( | |||
| "provider_model_credential_tenant_provider_model_idx", | |||
| "tenant_id", | |||
| "provider_name", | |||
| "model_name", | |||
| "model_type", | |||
| ), | |||
| ) | |||
| id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) | |||
| tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||
| provider_name: Mapped[str] = mapped_column(String(255), nullable=False) | |||
| model_name: Mapped[str] = mapped_column(String(255), nullable=False) | |||
| model_type: Mapped[str] = mapped_column(String(40), nullable=False) | |||
| credential_name: Mapped[str] = mapped_column(String(255), nullable=False) | |||
| encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False) | |||
| created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| @@ -8,7 +8,12 @@ from core.entities.model_entities import ( | |||
| ModelWithProviderEntity, | |||
| ProviderModelWithStatusEntity, | |||
| ) | |||
| from core.entities.provider_entities import ProviderQuotaType, QuotaConfiguration | |||
| from core.entities.provider_entities import ( | |||
| CredentialConfiguration, | |||
| CustomModelConfiguration, | |||
| ProviderQuotaType, | |||
| QuotaConfiguration, | |||
| ) | |||
| from core.model_runtime.entities.common_entities import I18nObject | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.entities.provider_entities import ( | |||
| @@ -36,6 +41,10 @@ class CustomConfigurationResponse(BaseModel): | |||
| """ | |||
| status: CustomConfigurationStatus | |||
| current_credential_id: Optional[str] = None | |||
| current_credential_name: Optional[str] = None | |||
| available_credentials: Optional[list[CredentialConfiguration]] = None | |||
| custom_models: Optional[list[CustomModelConfiguration]] = None | |||
| class SystemConfigurationResponse(BaseModel): | |||
| @@ -3,3 +3,7 @@ from services.errors.base import BaseServiceError | |||
| class AppModelConfigBrokenError(BaseServiceError): | |||
| pass | |||
| class ProviderNotFoundError(BaseServiceError): | |||
| pass | |||
| @@ -17,7 +17,7 @@ from core.model_runtime.model_providers.model_provider_factory import ModelProvi | |||
| from core.provider_manager import ProviderManager | |||
| from extensions.ext_database import db | |||
| from libs.datetime_utils import naive_utc_now | |||
| from models.provider import LoadBalancingModelConfig | |||
| from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential | |||
| logger = logging.getLogger(__name__) | |||
| @@ -185,6 +185,7 @@ class ModelLoadBalancingService: | |||
| "id": load_balancing_config.id, | |||
| "name": load_balancing_config.name, | |||
| "credentials": credentials, | |||
| "credential_id": load_balancing_config.credential_id, | |||
| "enabled": load_balancing_config.enabled, | |||
| "in_cooldown": in_cooldown, | |||
| "ttl": ttl, | |||
| @@ -280,7 +281,7 @@ class ModelLoadBalancingService: | |||
| return inherit_config | |||
| def update_load_balancing_configs( | |||
| self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict] | |||
| self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict], config_from: str | |||
| ) -> None: | |||
| """ | |||
| Update load balancing configurations. | |||
| @@ -289,6 +290,7 @@ class ModelLoadBalancingService: | |||
| :param model: model name | |||
| :param model_type: model type | |||
| :param configs: load balancing configs | |||
| :param config_from: predefined-model or custom-model | |||
| :return: | |||
| """ | |||
| # Get all provider configurations of the current workspace | |||
| @@ -327,8 +329,37 @@ class ModelLoadBalancingService: | |||
| config_id = config.get("id") | |||
| name = config.get("name") | |||
| credentials = config.get("credentials") | |||
| credential_id = config.get("credential_id") | |||
| enabled = config.get("enabled") | |||
| if credential_id: | |||
| credential_record: ProviderCredential | ProviderModelCredential | None = None | |||
| if config_from == "predefined-model": | |||
| credential_record = ( | |||
| db.session.query(ProviderCredential) | |||
| .filter_by( | |||
| id=credential_id, | |||
| tenant_id=tenant_id, | |||
| provider_name=provider_configuration.provider.provider, | |||
| ) | |||
| .first() | |||
| ) | |||
| else: | |||
| credential_record = ( | |||
| db.session.query(ProviderModelCredential) | |||
| .filter_by( | |||
| id=credential_id, | |||
| tenant_id=tenant_id, | |||
| provider_name=provider_configuration.provider.provider, | |||
| model_name=model, | |||
| model_type=model_type_enum.to_origin_model_type(), | |||
| ) | |||
| .first() | |||
| ) | |||
| if not credential_record: | |||
| raise ValueError(f"Provider credential with id {credential_id} not found") | |||
| name = credential_record.credential_name | |||
| if not name: | |||
| raise ValueError("Invalid load balancing config name") | |||
| @@ -346,11 +377,6 @@ class ModelLoadBalancingService: | |||
| load_balancing_config = current_load_balancing_configs_dict[config_id] | |||
| # check duplicate name | |||
| for current_load_balancing_config in current_load_balancing_configs: | |||
| if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name: | |||
| raise ValueError(f"Load balancing config name {name} already exists") | |||
| if credentials: | |||
| if not isinstance(credentials, dict): | |||
| raise ValueError("Invalid load balancing config credentials") | |||
| @@ -377,39 +403,48 @@ class ModelLoadBalancingService: | |||
| self._clear_credentials_cache(tenant_id, config_id) | |||
| else: | |||
| # create load balancing config | |||
| if name == "__inherit__": | |||
| if name in {"__inherit__", "__delete__"}: | |||
| raise ValueError("Invalid load balancing config name") | |||
| # check duplicate name | |||
| for current_load_balancing_config in current_load_balancing_configs: | |||
| if current_load_balancing_config.name == name: | |||
| raise ValueError(f"Load balancing config name {name} already exists") | |||
| if not credentials: | |||
| raise ValueError("Invalid load balancing config credentials") | |||
| if credential_id: | |||
| credential_source = "provider" if config_from == "predefined-model" else "custom_model" | |||
| assert credential_record is not None | |||
| load_balancing_model_config = LoadBalancingModelConfig( | |||
| tenant_id=tenant_id, | |||
| provider_name=provider_configuration.provider.provider, | |||
| model_type=model_type_enum.to_origin_model_type(), | |||
| model_name=model, | |||
| name=credential_record.credential_name, | |||
| encrypted_config=credential_record.encrypted_config, | |||
| credential_id=credential_id, | |||
| credential_source_type=credential_source, | |||
| ) | |||
| else: | |||
| if not credentials: | |||
| raise ValueError("Invalid load balancing config credentials") | |||
| if not isinstance(credentials, dict): | |||
| raise ValueError("Invalid load balancing config credentials") | |||
| if not isinstance(credentials, dict): | |||
| raise ValueError("Invalid load balancing config credentials") | |||
| # validate custom provider config | |||
| credentials = self._custom_credentials_validate( | |||
| tenant_id=tenant_id, | |||
| provider_configuration=provider_configuration, | |||
| model_type=model_type_enum, | |||
| model=model, | |||
| credentials=credentials, | |||
| validate=False, | |||
| ) | |||
| # validate custom provider config | |||
| credentials = self._custom_credentials_validate( | |||
| tenant_id=tenant_id, | |||
| provider_configuration=provider_configuration, | |||
| model_type=model_type_enum, | |||
| model=model, | |||
| credentials=credentials, | |||
| validate=False, | |||
| ) | |||
| # create load balancing config | |||
| load_balancing_model_config = LoadBalancingModelConfig( | |||
| tenant_id=tenant_id, | |||
| provider_name=provider_configuration.provider.provider, | |||
| model_type=model_type_enum.to_origin_model_type(), | |||
| model_name=model, | |||
| name=name, | |||
| encrypted_config=json.dumps(credentials), | |||
| ) | |||
| # create load balancing config | |||
| load_balancing_model_config = LoadBalancingModelConfig( | |||
| tenant_id=tenant_id, | |||
| provider_name=provider_configuration.provider.provider, | |||
| model_type=model_type_enum.to_origin_model_type(), | |||
| model_name=model, | |||
| name=name, | |||
| encrypted_config=json.dumps(credentials), | |||
| ) | |||
| db.session.add(load_balancing_model_config) | |||
| db.session.commit() | |||
| @@ -16,6 +16,7 @@ from services.entities.model_provider_entities import ( | |||
| SimpleProviderEntityResponse, | |||
| SystemConfigurationResponse, | |||
| ) | |||
| from services.errors.app_model_config import ProviderNotFoundError | |||
| logger = logging.getLogger(__name__) | |||
| @@ -28,6 +29,29 @@ class ModelProviderService: | |||
| def __init__(self) -> None: | |||
| self.provider_manager = ProviderManager() | |||
| def _get_provider_configuration(self, tenant_id: str, provider: str): | |||
| """ | |||
| Get provider configuration or raise exception if not found. | |||
| Args: | |||
| tenant_id: Workspace identifier | |||
| provider: Provider name | |||
| Returns: | |||
| Provider configuration instance | |||
| Raises: | |||
| ProviderNotFoundError: If provider doesn't exist | |||
| """ | |||
| # Get all provider configurations of the current workspace | |||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||
| provider_configuration = provider_configurations.get(provider) | |||
| if not provider_configuration: | |||
| raise ProviderNotFoundError(f"Provider {provider} does not exist.") | |||
| return provider_configuration | |||
| def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]: | |||
| """ | |||
| get provider list. | |||
| @@ -46,6 +70,9 @@ class ModelProviderService: | |||
| if model_type_entity not in provider_configuration.provider.supported_model_types: | |||
| continue | |||
| provider_config = provider_configuration.custom_configuration.provider | |||
| model_config = provider_configuration.custom_configuration.models | |||
| provider_response = ProviderResponse( | |||
| tenant_id=tenant_id, | |||
| provider=provider_configuration.provider.provider, | |||
| @@ -63,7 +90,11 @@ class ModelProviderService: | |||
| custom_configuration=CustomConfigurationResponse( | |||
| status=CustomConfigurationStatus.ACTIVE | |||
| if provider_configuration.is_custom_configuration_available() | |||
| else CustomConfigurationStatus.NO_CONFIGURE | |||
| else CustomConfigurationStatus.NO_CONFIGURE, | |||
| current_credential_id=getattr(provider_config, "current_credential_id", None), | |||
| current_credential_name=getattr(provider_config, "current_credential_name", None), | |||
| available_credentials=getattr(provider_config, "available_credentials", []), | |||
| custom_models=model_config, | |||
| ), | |||
| system_configuration=SystemConfigurationResponse( | |||
| enabled=provider_configuration.system_configuration.enabled, | |||
| @@ -82,8 +113,8 @@ class ModelProviderService: | |||
| For the model provider page, | |||
| only supports passing in a single provider to query the list of supported models. | |||
| :param tenant_id: | |||
| :param provider: | |||
| :param tenant_id: workspace id | |||
| :param provider: provider name | |||
| :return: | |||
| """ | |||
| # Get all provider configurations of the current workspace | |||
| @@ -95,150 +126,236 @@ class ModelProviderService: | |||
| for model in provider_configurations.get_models(provider=provider) | |||
| ] | |||
| def get_provider_credentials(self, tenant_id: str, provider: str) -> Optional[dict]: | |||
| def get_provider_credential( | |||
| self, tenant_id: str, provider: str, credential_id: Optional[str] = None | |||
| ) -> Optional[dict]: | |||
| """ | |||
| get provider credentials. | |||
| """ | |||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||
| provider_configuration = provider_configurations.get(provider) | |||
| if not provider_configuration: | |||
| raise ValueError(f"Provider {provider} does not exist.") | |||
| return provider_configuration.get_custom_credentials(obfuscated=True) | |||
| def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None: | |||
| :param tenant_id: workspace id | |||
| :param provider: provider name | |||
| :param credential_id: credential id, if not provided, return current used credentials | |||
| :return: | |||
| """ | |||
| validate provider credentials. | |||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||
| return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore | |||
| :param tenant_id: | |||
| :param provider: | |||
| :param credentials: | |||
| def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None: | |||
| """ | |||
| # Get all provider configurations of the current workspace | |||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||
| validate provider credentials before saving. | |||
| # Get provider configuration | |||
| provider_configuration = provider_configurations.get(provider) | |||
| if not provider_configuration: | |||
| raise ValueError(f"Provider {provider} does not exist.") | |||
| provider_configuration.custom_credentials_validate(credentials) | |||
| :param tenant_id: workspace id | |||
| :param provider: provider name | |||
| :param credentials: provider credentials dict | |||
| """ | |||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||
| provider_configuration.validate_provider_credentials(credentials) | |||
| def save_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None: | |||
| def create_provider_credential( | |||
| self, tenant_id: str, provider: str, credentials: dict, credential_name: str | |||
| ) -> None: | |||
| """ | |||
| save custom provider config. | |||
| Create and save new provider credentials. | |||
| :param tenant_id: workspace id | |||
| :param provider: provider name | |||
| :param credentials: provider credentials | |||
| :param credentials: provider credentials dict | |||
| :param credential_name: credential name | |||
| :return: | |||
| """ | |||
| # Get all provider configurations of the current workspace | |||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||
| provider_configuration.create_provider_credential(credentials, credential_name) | |||
| # Get provider configuration | |||
| provider_configuration = provider_configurations.get(provider) | |||
| if not provider_configuration: | |||
| raise ValueError(f"Provider {provider} does not exist.") | |||
| def update_provider_credential( | |||
| self, | |||
| tenant_id: str, | |||
| provider: str, | |||
| credentials: dict, | |||
| credential_id: str, | |||
| credential_name: str, | |||
| ) -> None: | |||
| """ | |||
| update a saved provider credential (by credential_id). | |||
| # Add or update custom provider credentials. | |||
| provider_configuration.add_or_update_custom_credentials(credentials) | |||
| :param tenant_id: workspace id | |||
| :param provider: provider name | |||
| :param credentials: provider credentials dict | |||
| :param credential_id: credential id | |||
| :param credential_name: credential name | |||
| :return: | |||
| """ | |||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||
| provider_configuration.update_provider_credential( | |||
| credential_id=credential_id, | |||
| credentials=credentials, | |||
| credential_name=credential_name, | |||
| ) | |||
| def remove_provider_credentials(self, tenant_id: str, provider: str) -> None: | |||
| def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None: | |||
| """ | |||
| remove custom provider config. | |||
| remove a saved provider credential (by credential_id). | |||
| :param tenant_id: workspace id | |||
| :param provider: provider name | |||
| :param credential_id: credential id | |||
| :return: | |||
| """ | |||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||
| provider_configuration.delete_provider_credential(credential_id=credential_id) | |||
| def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None: | |||
| """ | |||
| :param tenant_id: workspace id | |||
| :param provider: provider name | |||
| :param credential_id: credential id | |||
| :return: | |||
| """ | |||
| # Get all provider configurations of the current workspace | |||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||
| provider_configuration.switch_active_provider_credential(credential_id=credential_id) | |||
| # Get provider configuration | |||
| provider_configuration = provider_configurations.get(provider) | |||
| if not provider_configuration: | |||
| raise ValueError(f"Provider {provider} does not exist.") | |||
| def get_model_credential( | |||
| self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None | |||
| ) -> Optional[dict]: | |||
| """ | |||
| Retrieve model-specific credentials. | |||
| # Remove custom provider credentials. | |||
| provider_configuration.delete_custom_credentials() | |||
| :param tenant_id: workspace id | |||
| :param provider: provider name | |||
| :param model_type: model type | |||
| :param model: model name | |||
| :param credential_id: Optional credential ID, uses current if not provided | |||
| :return: | |||
| """ | |||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||
| return provider_configuration.get_custom_model_credential( # type: ignore | |||
| model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id | |||
| ) | |||
| def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> Optional[dict]: | |||
| def validate_model_credentials( | |||
| self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict | |||
| ) -> None: | |||
| """ | |||
| get model credentials. | |||
| validate model credentials. | |||
| :param tenant_id: workspace id | |||
| :param provider: provider name | |||
| :param model_type: model type | |||
| :param model: model name | |||
| :param credentials: model credentials dict | |||
| :return: | |||
| """ | |||
| # Get all provider configurations of the current workspace | |||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||
| provider_configuration.validate_custom_model_credentials( | |||
| model_type=ModelType.value_of(model_type), model=model, credentials=credentials | |||
| ) | |||
| # Get provider configuration | |||
| provider_configuration = provider_configurations.get(provider) | |||
| if not provider_configuration: | |||
| raise ValueError(f"Provider {provider} does not exist.") | |||
| def create_model_credential( | |||
| self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | |||
| ) -> None: | |||
| """ | |||
| create and save model credentials. | |||
| # Get model custom credentials from ProviderModel if exists | |||
| return provider_configuration.get_custom_model_credentials( | |||
| model_type=ModelType.value_of(model_type), model=model, obfuscated=True | |||
| :param tenant_id: workspace id | |||
| :param provider: provider name | |||
| :param model_type: model type | |||
| :param model: model name | |||
| :param credentials: model credentials dict | |||
| :param credential_name: credential name | |||
| :return: | |||
| """ | |||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||
| provider_configuration.create_custom_model_credential( | |||
| model_type=ModelType.value_of(model_type), | |||
| model=model, | |||
| credentials=credentials, | |||
| credential_name=credential_name, | |||
| ) | |||
| def model_credentials_validate( | |||
| self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict | |||
| def update_model_credential( | |||
| self, | |||
| tenant_id: str, | |||
| provider: str, | |||
| model_type: str, | |||
| model: str, | |||
| credentials: dict, | |||
| credential_id: str, | |||
| credential_name: str, | |||
| ) -> None: | |||
| """ | |||
| validate model credentials. | |||
| update model credentials. | |||
| :param tenant_id: workspace id | |||
| :param provider: provider name | |||
| :param model_type: model type | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :param credentials: model credentials dict | |||
| :param credential_id: credential id | |||
| :param credential_name: credential name | |||
| :return: | |||
| """ | |||
| # Get all provider configurations of the current workspace | |||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||
| provider_configuration.update_custom_model_credential( | |||
| model_type=ModelType.value_of(model_type), | |||
| model=model, | |||
| credentials=credentials, | |||
| credential_id=credential_id, | |||
| credential_name=credential_name, | |||
| ) | |||
| # Get provider configuration | |||
| provider_configuration = provider_configurations.get(provider) | |||
| if not provider_configuration: | |||
| raise ValueError(f"Provider {provider} does not exist.") | |||
| def remove_model_credential( | |||
| self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | |||
| ) -> None: | |||
| """ | |||
| remove model credentials. | |||
| # Validate model credentials | |||
| provider_configuration.custom_model_credentials_validate( | |||
| model_type=ModelType.value_of(model_type), model=model, credentials=credentials | |||
| :param tenant_id: workspace id | |||
| :param provider: provider name | |||
| :param model_type: model type | |||
| :param model: model name | |||
| :param credential_id: credential id | |||
| :return: | |||
| """ | |||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||
| provider_configuration.delete_custom_model_credential( | |||
| model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id | |||
| ) | |||
| def save_model_credentials( | |||
| self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict | |||
| def switch_active_custom_model_credential( | |||
| self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | |||
| ) -> None: | |||
| """ | |||
| save model credentials. | |||
| switch model credentials. | |||
| :param tenant_id: workspace id | |||
| :param provider: provider name | |||
| :param model_type: model type | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :param credential_id: credential id | |||
| :return: | |||
| """ | |||
| # Get all provider configurations of the current workspace | |||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||
| provider_configuration.switch_custom_model_credential( | |||
| model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id | |||
| ) | |||
| # Get provider configuration | |||
| provider_configuration = provider_configurations.get(provider) | |||
| if not provider_configuration: | |||
| raise ValueError(f"Provider {provider} does not exist.") | |||
| def add_model_credential_to_model_list( | |||
| self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | |||
| ) -> None: | |||
| """ | |||
| add model credentials to model list. | |||
| # Add or update custom model credentials | |||
| provider_configuration.add_or_update_custom_model_credentials( | |||
| model_type=ModelType.value_of(model_type), model=model, credentials=credentials | |||
| :param tenant_id: workspace id | |||
| :param provider: provider name | |||
| :param model_type: model type | |||
| :param model: model name | |||
| :param credential_id: credential id | |||
| :return: | |||
| """ | |||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||
| provider_configuration.add_model_credential_to_model( | |||
| model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id | |||
| ) | |||
| def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None: | |||
| def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str) -> None: | |||
| """ | |||
| remove model credentials. | |||
| @@ -248,16 +365,8 @@ class ModelProviderService: | |||
| :param model: model name | |||
| :return: | |||
| """ | |||
| # Get all provider configurations of the current workspace | |||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||
| # Get provider configuration | |||
| provider_configuration = provider_configurations.get(provider) | |||
| if not provider_configuration: | |||
| raise ValueError(f"Provider {provider} does not exist.") | |||
| # Remove custom model credentials | |||
| provider_configuration.delete_custom_model_credentials(model_type=ModelType.value_of(model_type), model=model) | |||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||
| provider_configuration.delete_custom_model(model_type=ModelType.value_of(model_type), model=model) | |||
| def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]: | |||
| """ | |||
| @@ -331,13 +440,7 @@ class ModelProviderService: | |||
| :param model: model name | |||
| :return: | |||
| """ | |||
| # Get all provider configurations of the current workspace | |||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||
| # Get provider configuration | |||
| provider_configuration = provider_configurations.get(provider) | |||
| if not provider_configuration: | |||
| raise ValueError(f"Provider {provider} does not exist.") | |||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||
| # fetch credentials | |||
| credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model) | |||
| @@ -424,17 +527,11 @@ class ModelProviderService: | |||
| :param preferred_provider_type: preferred provider type | |||
| :return: | |||
| """ | |||
| # Get all provider configurations of the current workspace | |||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||
| # Convert preferred_provider_type to ProviderType | |||
| preferred_provider_type_enum = ProviderType.value_of(preferred_provider_type) | |||
| # Get provider configuration | |||
| provider_configuration = provider_configurations.get(provider) | |||
| if not provider_configuration: | |||
| raise ValueError(f"Provider {provider} does not exist.") | |||
| # Switch preferred provider type | |||
| provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum) | |||
| @@ -448,15 +545,7 @@ class ModelProviderService: | |||
| :param model_type: model type | |||
| :return: | |||
| """ | |||
| # Get all provider configurations of the current workspace | |||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||
| # Get provider configuration | |||
| provider_configuration = provider_configurations.get(provider) | |||
| if not provider_configuration: | |||
| raise ValueError(f"Provider {provider} does not exist.") | |||
| # Enable model | |||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||
| provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type)) | |||
| def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: | |||
| @@ -469,13 +558,5 @@ class ModelProviderService: | |||
| :param model_type: model type | |||
| :return: | |||
| """ | |||
| # Get all provider configurations of the current workspace | |||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||
| # Get provider configuration | |||
| provider_configuration = provider_configurations.get(provider) | |||
| if not provider_configuration: | |||
| raise ValueError(f"Provider {provider} does not exist.") | |||
| # Enable model | |||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||
| provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type)) | |||
| @@ -235,10 +235,17 @@ class TestModelProviderService: | |||
| mock_provider_entity.provider_credential_schema = None | |||
| mock_provider_entity.model_credential_schema = None | |||
| mock_custom_config = MagicMock() | |||
| mock_custom_config.provider.current_credential_id = "credential-123" | |||
| mock_custom_config.provider.current_credential_name = "test-credential" | |||
| mock_custom_config.provider.available_credentials = [] | |||
| mock_custom_config.models = [] | |||
| mock_provider_config = MagicMock() | |||
| mock_provider_config.provider = mock_provider_entity | |||
| mock_provider_config.preferred_provider_type = ProviderType.CUSTOM | |||
| mock_provider_config.is_custom_configuration_available.return_value = True | |||
| mock_provider_config.custom_configuration = mock_custom_config | |||
| mock_provider_config.system_configuration.enabled = True | |||
| mock_provider_config.system_configuration.current_quota_type = "free" | |||
| mock_provider_config.system_configuration.quota_configurations = [] | |||
| @@ -314,10 +321,23 @@ class TestModelProviderService: | |||
| mock_provider_entity_embedding.provider_credential_schema = None | |||
| mock_provider_entity_embedding.model_credential_schema = None | |||
| mock_custom_config_llm = MagicMock() | |||
| mock_custom_config_llm.provider.current_credential_id = "credential-123" | |||
| mock_custom_config_llm.provider.current_credential_name = "test-credential" | |||
| mock_custom_config_llm.provider.available_credentials = [] | |||
| mock_custom_config_llm.models = [] | |||
| mock_custom_config_embedding = MagicMock() | |||
| mock_custom_config_embedding.provider.current_credential_id = "credential-456" | |||
| mock_custom_config_embedding.provider.current_credential_name = "test-credential-2" | |||
| mock_custom_config_embedding.provider.available_credentials = [] | |||
| mock_custom_config_embedding.models = [] | |||
| mock_provider_config_llm = MagicMock() | |||
| mock_provider_config_llm.provider = mock_provider_entity_llm | |||
| mock_provider_config_llm.preferred_provider_type = ProviderType.CUSTOM | |||
| mock_provider_config_llm.is_custom_configuration_available.return_value = True | |||
| mock_provider_config_llm.custom_configuration = mock_custom_config_llm | |||
| mock_provider_config_llm.system_configuration.enabled = True | |||
| mock_provider_config_llm.system_configuration.current_quota_type = "free" | |||
| mock_provider_config_llm.system_configuration.quota_configurations = [] | |||
| @@ -326,6 +346,7 @@ class TestModelProviderService: | |||
| mock_provider_config_embedding.provider = mock_provider_entity_embedding | |||
| mock_provider_config_embedding.preferred_provider_type = ProviderType.CUSTOM | |||
| mock_provider_config_embedding.is_custom_configuration_available.return_value = True | |||
| mock_provider_config_embedding.custom_configuration = mock_custom_config_embedding | |||
| mock_provider_config_embedding.system_configuration.enabled = True | |||
| mock_provider_config_embedding.system_configuration.current_quota_type = "free" | |||
| mock_provider_config_embedding.system_configuration.quota_configurations = [] | |||
| @@ -497,20 +518,29 @@ class TestModelProviderService: | |||
| } | |||
| mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} | |||
| # Expected result structure | |||
| expected_credentials = { | |||
| "credentials": { | |||
| "api_key": "sk-***123", | |||
| "base_url": "https://api.openai.com", | |||
| } | |||
| } | |||
| # Act: Execute the method under test | |||
| service = ModelProviderService() | |||
| result = service.get_provider_credentials(tenant.id, "openai") | |||
| with patch.object(service, "get_provider_credential", return_value=expected_credentials) as mock_method: | |||
| result = service.get_provider_credential(tenant.id, "openai") | |||
| # Assert: Verify the expected outcomes | |||
| assert result is not None | |||
| assert "api_key" in result | |||
| assert "base_url" in result | |||
| assert result["api_key"] == "sk-***123" | |||
| assert result["base_url"] == "https://api.openai.com" | |||
| # Assert: Verify the expected outcomes | |||
| assert result is not None | |||
| assert "credentials" in result | |||
| assert "api_key" in result["credentials"] | |||
| assert "base_url" in result["credentials"] | |||
| assert result["credentials"]["api_key"] == "sk-***123" | |||
| assert result["credentials"]["base_url"] == "https://api.openai.com" | |||
| # Verify mock interactions | |||
| mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) | |||
| mock_provider_configuration.get_custom_credentials.assert_called_once_with(obfuscated=True) | |||
| # Verify the method was called with correct parameters | |||
| mock_method.assert_called_once_with(tenant.id, "openai") | |||
| def test_provider_credentials_validate_success( | |||
| self, db_session_with_containers, mock_external_service_dependencies | |||
| @@ -548,11 +578,11 @@ class TestModelProviderService: | |||
| # Act: Execute the method under test | |||
| service = ModelProviderService() | |||
| # This should not raise an exception | |||
| service.provider_credentials_validate(tenant.id, "openai", test_credentials) | |||
| service.validate_provider_credentials(tenant.id, "openai", test_credentials) | |||
| # Assert: Verify mock interactions | |||
| mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) | |||
| mock_provider_configuration.custom_credentials_validate.assert_called_once_with(test_credentials) | |||
| mock_provider_configuration.validate_provider_credentials.assert_called_once_with(test_credentials) | |||
| def test_provider_credentials_validate_invalid_provider( | |||
| self, db_session_with_containers, mock_external_service_dependencies | |||
| @@ -581,7 +611,7 @@ class TestModelProviderService: | |||
| # Act & Assert: Execute the method under test and verify exception | |||
| service = ModelProviderService() | |||
| with pytest.raises(ValueError, match="Provider nonexistent does not exist."): | |||
| service.provider_credentials_validate(tenant.id, "nonexistent", test_credentials) | |||
| service.validate_provider_credentials(tenant.id, "nonexistent", test_credentials) | |||
| # Verify mock interactions | |||
| mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) | |||
| @@ -817,22 +847,29 @@ class TestModelProviderService: | |||
| } | |||
| mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} | |||
| # Expected result structure | |||
| expected_credentials = { | |||
| "credentials": { | |||
| "api_key": "sk-***123", | |||
| "base_url": "https://api.openai.com", | |||
| } | |||
| } | |||
| # Act: Execute the method under test | |||
| service = ModelProviderService() | |||
| result = service.get_model_credentials(tenant.id, "openai", "llm", "gpt-4") | |||
| with patch.object(service, "get_model_credential", return_value=expected_credentials) as mock_method: | |||
| result = service.get_model_credential(tenant.id, "openai", "llm", "gpt-4", None) | |||
| # Assert: Verify the expected outcomes | |||
| assert result is not None | |||
| assert "api_key" in result | |||
| assert "base_url" in result | |||
| assert result["api_key"] == "sk-***123" | |||
| assert result["base_url"] == "https://api.openai.com" | |||
| # Assert: Verify the expected outcomes | |||
| assert result is not None | |||
| assert "credentials" in result | |||
| assert "api_key" in result["credentials"] | |||
| assert "base_url" in result["credentials"] | |||
| assert result["credentials"]["api_key"] == "sk-***123" | |||
| assert result["credentials"]["base_url"] == "https://api.openai.com" | |||
| # Verify mock interactions | |||
| mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) | |||
| mock_provider_configuration.get_custom_model_credentials.assert_called_once_with( | |||
| model_type=ModelType.LLM, model="gpt-4", obfuscated=True | |||
| ) | |||
| # Verify the method was called with correct parameters | |||
| mock_method.assert_called_once_with(tenant.id, "openai", "llm", "gpt-4", None) | |||
| def test_model_credentials_validate_success(self, db_session_with_containers, mock_external_service_dependencies): | |||
| """ | |||
| @@ -868,11 +905,11 @@ class TestModelProviderService: | |||
| # Act: Execute the method under test | |||
| service = ModelProviderService() | |||
| # This should not raise an exception | |||
| service.model_credentials_validate(tenant.id, "openai", "llm", "gpt-4", test_credentials) | |||
| service.validate_model_credentials(tenant.id, "openai", "llm", "gpt-4", test_credentials) | |||
| # Assert: Verify mock interactions | |||
| mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) | |||
| mock_provider_configuration.custom_model_credentials_validate.assert_called_once_with( | |||
| mock_provider_configuration.validate_custom_model_credentials.assert_called_once_with( | |||
| model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials | |||
| ) | |||
| @@ -909,12 +946,12 @@ class TestModelProviderService: | |||
| # Act: Execute the method under test | |||
| service = ModelProviderService() | |||
| service.save_model_credentials(tenant.id, "openai", "llm", "gpt-4", test_credentials) | |||
| service.create_model_credential(tenant.id, "openai", "llm", "gpt-4", test_credentials, "testname") | |||
| # Assert: Verify mock interactions | |||
| mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) | |||
| mock_provider_configuration.add_or_update_custom_model_credentials.assert_called_once_with( | |||
| model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials | |||
| mock_provider_configuration.create_custom_model_credential.assert_called_once_with( | |||
| model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials, credential_name="testname" | |||
| ) | |||
| def test_remove_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): | |||
| @@ -942,17 +979,17 @@ class TestModelProviderService: | |||
| # Create mock provider configuration with remove method | |||
| mock_provider_configuration = MagicMock() | |||
| mock_provider_configuration.delete_custom_model_credentials.return_value = None | |||
| mock_provider_configuration.delete_custom_model_credential.return_value = None | |||
| mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} | |||
| # Act: Execute the method under test | |||
| service = ModelProviderService() | |||
| service.remove_model_credentials(tenant.id, "openai", "llm", "gpt-4") | |||
| service.remove_model_credential(tenant.id, "openai", "llm", "gpt-4", "5540007c-b988-46e0-b1c7-9b5fb9f330d6") | |||
| # Assert: Verify mock interactions | |||
| mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) | |||
| mock_provider_configuration.delete_custom_model_credentials.assert_called_once_with( | |||
| model_type=ModelType.LLM, model="gpt-4" | |||
| mock_provider_configuration.delete_custom_model_credential.assert_called_once_with( | |||
| model_type=ModelType.LLM, model="gpt-4", credential_id="5540007c-b988-46e0-b1c7-9b5fb9f330d6" | |||
| ) | |||
| def test_get_models_by_model_type_success(self, db_session_with_containers, mock_external_service_dependencies): | |||
| @@ -0,0 +1,308 @@ | |||
| from unittest.mock import Mock, patch | |||
| import pytest | |||
| from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus | |||
| from core.entities.provider_entities import ( | |||
| CustomConfiguration, | |||
| ModelSettings, | |||
| ProviderQuotaType, | |||
| QuotaConfiguration, | |||
| QuotaUnit, | |||
| RestrictModel, | |||
| SystemConfiguration, | |||
| ) | |||
| from core.model_runtime.entities.common_entities import I18nObject | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity | |||
| from models.provider import Provider, ProviderType | |||
| @pytest.fixture | |||
| def mock_provider_entity(): | |||
| """Mock provider entity with basic configuration""" | |||
| provider_entity = ProviderEntity( | |||
| provider="openai", | |||
| label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), | |||
| description=I18nObject(en_US="OpenAI provider", zh_Hans="OpenAI 提供商"), | |||
| icon_small=I18nObject(en_US="icon.png", zh_Hans="icon.png"), | |||
| icon_large=I18nObject(en_US="icon.png", zh_Hans="icon.png"), | |||
| background="background.png", | |||
| help=None, | |||
| supported_model_types=[ModelType.LLM], | |||
| configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], | |||
| provider_credential_schema=None, | |||
| model_credential_schema=None, | |||
| ) | |||
| return provider_entity | |||
| @pytest.fixture | |||
| def mock_system_configuration(): | |||
| """Mock system configuration""" | |||
| quota_config = QuotaConfiguration( | |||
| quota_type=ProviderQuotaType.TRIAL, | |||
| quota_unit=QuotaUnit.TOKENS, | |||
| quota_limit=1000, | |||
| quota_used=0, | |||
| is_valid=True, | |||
| restrict_models=[RestrictModel(model="gpt-4", reason="Experimental", model_type=ModelType.LLM)], | |||
| ) | |||
| system_config = SystemConfiguration( | |||
| enabled=True, | |||
| credentials={"openai_api_key": "test_key"}, | |||
| quota_configurations=[quota_config], | |||
| current_quota_type=ProviderQuotaType.TRIAL, | |||
| ) | |||
| return system_config | |||
| @pytest.fixture | |||
| def mock_custom_configuration(): | |||
| """Mock custom configuration""" | |||
| custom_config = CustomConfiguration(provider=None, models=[]) | |||
| return custom_config | |||
| @pytest.fixture | |||
| def provider_configuration(mock_provider_entity, mock_system_configuration, mock_custom_configuration): | |||
| """Create a test provider configuration instance""" | |||
| with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}): | |||
| return ProviderConfiguration( | |||
| tenant_id="test_tenant", | |||
| provider=mock_provider_entity, | |||
| preferred_provider_type=ProviderType.SYSTEM, | |||
| using_provider_type=ProviderType.SYSTEM, | |||
| system_configuration=mock_system_configuration, | |||
| custom_configuration=mock_custom_configuration, | |||
| model_settings=[], | |||
| ) | |||
| class TestProviderConfiguration: | |||
| """Test cases for ProviderConfiguration class""" | |||
| def test_get_current_credentials_system_provider_success(self, provider_configuration): | |||
| """Test successfully getting credentials from system provider""" | |||
| # Arrange | |||
| provider_configuration.using_provider_type = ProviderType.SYSTEM | |||
| # Act | |||
| credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4") | |||
| # Assert | |||
| assert credentials == {"openai_api_key": "test_key"} | |||
| def test_get_current_credentials_model_disabled(self, provider_configuration): | |||
| """Test getting credentials when model is disabled""" | |||
| # Arrange | |||
| model_setting = ModelSettings( | |||
| model="gpt-4", | |||
| model_type=ModelType.LLM, | |||
| enabled=False, | |||
| load_balancing_configs=[], | |||
| has_invalid_load_balancing_configs=False, | |||
| ) | |||
| provider_configuration.model_settings = [model_setting] | |||
| # Act & Assert | |||
| with pytest.raises(ValueError, match="Model gpt-4 is disabled"): | |||
| provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4") | |||
| def test_get_current_credentials_custom_provider_with_models(self, provider_configuration): | |||
| """Test getting credentials from custom provider with model configurations""" | |||
| # Arrange | |||
| provider_configuration.using_provider_type = ProviderType.CUSTOM | |||
| mock_model_config = Mock() | |||
| mock_model_config.model_type = ModelType.LLM | |||
| mock_model_config.model = "gpt-4" | |||
| mock_model_config.credentials = {"openai_api_key": "custom_key"} | |||
| provider_configuration.custom_configuration.models = [mock_model_config] | |||
| # Act | |||
| credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4") | |||
| # Assert | |||
| assert credentials == {"openai_api_key": "custom_key"} | |||
| def test_get_system_configuration_status_active(self, provider_configuration): | |||
| """Test getting active system configuration status""" | |||
| # Arrange | |||
| provider_configuration.system_configuration.enabled = True | |||
| # Act | |||
| status = provider_configuration.get_system_configuration_status() | |||
| # Assert | |||
| assert status == SystemConfigurationStatus.ACTIVE | |||
| def test_get_system_configuration_status_unsupported(self, provider_configuration): | |||
| """Test getting unsupported system configuration status""" | |||
| # Arrange | |||
| provider_configuration.system_configuration.enabled = False | |||
| # Act | |||
| status = provider_configuration.get_system_configuration_status() | |||
| # Assert | |||
| assert status == SystemConfigurationStatus.UNSUPPORTED | |||
| def test_get_system_configuration_status_quota_exceeded(self, provider_configuration): | |||
| """Test getting quota exceeded system configuration status""" | |||
| # Arrange | |||
| provider_configuration.system_configuration.enabled = True | |||
| quota_config = provider_configuration.system_configuration.quota_configurations[0] | |||
| quota_config.is_valid = False | |||
| # Act | |||
| status = provider_configuration.get_system_configuration_status() | |||
| # Assert | |||
| assert status == SystemConfigurationStatus.QUOTA_EXCEEDED | |||
| def test_is_custom_configuration_available_with_provider(self, provider_configuration): | |||
| """Test custom configuration availability with provider credentials""" | |||
| # Arrange | |||
| mock_provider = Mock() | |||
| mock_provider.available_credentials = ["openai_api_key"] | |||
| provider_configuration.custom_configuration.provider = mock_provider | |||
| provider_configuration.custom_configuration.models = [] | |||
| # Act | |||
| result = provider_configuration.is_custom_configuration_available() | |||
| # Assert | |||
| assert result is True | |||
| def test_is_custom_configuration_available_with_models(self, provider_configuration): | |||
| """Test custom configuration availability with model configurations""" | |||
| # Arrange | |||
| provider_configuration.custom_configuration.provider = None | |||
| provider_configuration.custom_configuration.models = [Mock()] | |||
| # Act | |||
| result = provider_configuration.is_custom_configuration_available() | |||
| # Assert | |||
| assert result is True | |||
| def test_is_custom_configuration_available_false(self, provider_configuration): | |||
| """Test custom configuration not available""" | |||
| # Arrange | |||
| provider_configuration.custom_configuration.provider = None | |||
| provider_configuration.custom_configuration.models = [] | |||
| # Act | |||
| result = provider_configuration.is_custom_configuration_available() | |||
| # Assert | |||
| assert result is False | |||
| @patch("core.entities.provider_configuration.Session") | |||
| def test_get_provider_record_found(self, mock_session, provider_configuration): | |||
| """Test getting provider record successfully""" | |||
| # Arrange | |||
| mock_provider = Mock(spec=Provider) | |||
| mock_session_instance = Mock() | |||
| mock_session.return_value.__enter__.return_value = mock_session_instance | |||
| mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_provider | |||
| # Act | |||
| result = provider_configuration._get_provider_record(mock_session_instance) | |||
| # Assert | |||
| assert result == mock_provider | |||
| @patch("core.entities.provider_configuration.Session") | |||
| def test_get_provider_record_not_found(self, mock_session, provider_configuration): | |||
| """Test getting provider record when not found""" | |||
| # Arrange | |||
| mock_session_instance = Mock() | |||
| mock_session.return_value.__enter__.return_value = mock_session_instance | |||
| mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None | |||
| # Act | |||
| result = provider_configuration._get_provider_record(mock_session_instance) | |||
| # Assert | |||
| assert result is None | |||
| def test_init_with_customizable_model_only( | |||
| self, mock_provider_entity, mock_system_configuration, mock_custom_configuration | |||
| ): | |||
| """Test initialization with customizable model only configuration""" | |||
| # Arrange | |||
| mock_provider_entity.configurate_methods = [ConfigurateMethod.CUSTOMIZABLE_MODEL] | |||
| # Act | |||
| with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}): | |||
| config = ProviderConfiguration( | |||
| tenant_id="test_tenant", | |||
| provider=mock_provider_entity, | |||
| preferred_provider_type=ProviderType.SYSTEM, | |||
| using_provider_type=ProviderType.SYSTEM, | |||
| system_configuration=mock_system_configuration, | |||
| custom_configuration=mock_custom_configuration, | |||
| model_settings=[], | |||
| ) | |||
| # Assert | |||
| assert ConfigurateMethod.PREDEFINED_MODEL in config.provider.configurate_methods | |||
| def test_get_current_credentials_with_restricted_models(self, provider_configuration): | |||
| """Test getting credentials with model restrictions""" | |||
| # Arrange | |||
| provider_configuration.using_provider_type = ProviderType.SYSTEM | |||
| # Act | |||
| credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-3.5-turbo") | |||
| # Assert | |||
| assert credentials is not None | |||
| assert "openai_api_key" in credentials | |||
| @patch("core.entities.provider_configuration.Session") | |||
| def test_get_specific_provider_credential_success(self, mock_session, provider_configuration): | |||
| """Test getting specific provider credential successfully""" | |||
| # Arrange | |||
| credential_id = "test_credential_id" | |||
| mock_credential = Mock() | |||
| mock_credential.encrypted_config = '{"openai_api_key": "encrypted_key"}' | |||
| mock_session_instance = Mock() | |||
| mock_session.return_value.__enter__.return_value = mock_session_instance | |||
| mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_credential | |||
| # Act | |||
| with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get: | |||
| mock_get.return_value = {"openai_api_key": "test_key"} | |||
| result = provider_configuration._get_specific_provider_credential(credential_id) | |||
| # Assert | |||
| assert result == {"openai_api_key": "test_key"} | |||
| @patch("core.entities.provider_configuration.Session") | |||
| def test_get_specific_provider_credential_not_found(self, mock_session, provider_configuration): | |||
| """Test getting specific provider credential when not found""" | |||
| # Arrange | |||
| credential_id = "nonexistent_credential_id" | |||
| mock_session_instance = Mock() | |||
| mock_session.return_value.__enter__.return_value = mock_session_instance | |||
| mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None | |||
| # Act & Assert | |||
| with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get: | |||
| mock_get.return_value = None | |||
| result = provider_configuration._get_specific_provider_credential(credential_id) | |||
| assert result is None | |||
| # Act | |||
| credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4") | |||
| # Assert | |||
| assert credentials == {"openai_api_key": "test_key"} | |||
| @@ -1,190 +1,185 @@ | |||
| # from core.entities.provider_entities import ModelSettings | |||
| # from core.model_runtime.entities.model_entities import ModelType | |||
| # from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory | |||
| # from core.provider_manager import ProviderManager | |||
| # from models.provider import LoadBalancingModelConfig, ProviderModelSetting | |||
| # def test__to_model_settings(mocker): | |||
| # # Get all provider entities | |||
| # model_provider_factory = ModelProviderFactory("test_tenant") | |||
| # provider_entities = model_provider_factory.get_providers() | |||
| # provider_entity = None | |||
| # for provider in provider_entities: | |||
| # if provider.provider == "openai": | |||
| # provider_entity = provider | |||
| # # Mocking the inputs | |||
| # provider_model_settings = [ | |||
| # ProviderModelSetting( | |||
| # id="id", | |||
| # tenant_id="tenant_id", | |||
| # provider_name="openai", | |||
| # model_name="gpt-4", | |||
| # model_type="text-generation", | |||
| # enabled=True, | |||
| # load_balancing_enabled=True, | |||
| # ) | |||
| # ] | |||
| # load_balancing_model_configs = [ | |||
| # LoadBalancingModelConfig( | |||
| # id="id1", | |||
| # tenant_id="tenant_id", | |||
| # provider_name="openai", | |||
| # model_name="gpt-4", | |||
| # model_type="text-generation", | |||
| # name="__inherit__", | |||
| # encrypted_config=None, | |||
| # enabled=True, | |||
| # ), | |||
| # LoadBalancingModelConfig( | |||
| # id="id2", | |||
| # tenant_id="tenant_id", | |||
| # provider_name="openai", | |||
| # model_name="gpt-4", | |||
| # model_type="text-generation", | |||
| # name="first", | |||
| # encrypted_config='{"openai_api_key": "fake_key"}', | |||
| # enabled=True, | |||
| # ), | |||
| # ] | |||
| # mocker.patch( | |||
| # "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} | |||
| # ) | |||
| # provider_manager = ProviderManager() | |||
| # # Running the method | |||
| # result = provider_manager._to_model_settings(provider_entity, | |||
| # provider_model_settings, load_balancing_model_configs) | |||
| # # Asserting that the result is as expected | |||
| # assert len(result) == 1 | |||
| # assert isinstance(result[0], ModelSettings) | |||
| # assert result[0].model == "gpt-4" | |||
| # assert result[0].model_type == ModelType.LLM | |||
| # assert result[0].enabled is True | |||
| # assert len(result[0].load_balancing_configs) == 2 | |||
| # assert result[0].load_balancing_configs[0].name == "__inherit__" | |||
| # assert result[0].load_balancing_configs[1].name == "first" | |||
| # def test__to_model_settings_only_one_lb(mocker): | |||
| # # Get all provider entities | |||
| # model_provider_factory = ModelProviderFactory("test_tenant") | |||
| # provider_entities = model_provider_factory.get_providers() | |||
| # provider_entity = None | |||
| # for provider in provider_entities: | |||
| # if provider.provider == "openai": | |||
| # provider_entity = provider | |||
| # # Mocking the inputs | |||
| # provider_model_settings = [ | |||
| # ProviderModelSetting( | |||
| # id="id", | |||
| # tenant_id="tenant_id", | |||
| # provider_name="openai", | |||
| # model_name="gpt-4", | |||
| # model_type="text-generation", | |||
| # enabled=True, | |||
| # load_balancing_enabled=True, | |||
| # ) | |||
| # ] | |||
| # load_balancing_model_configs = [ | |||
| # LoadBalancingModelConfig( | |||
| # id="id1", | |||
| # tenant_id="tenant_id", | |||
| # provider_name="openai", | |||
| # model_name="gpt-4", | |||
| # model_type="text-generation", | |||
| # name="__inherit__", | |||
| # encrypted_config=None, | |||
| # enabled=True, | |||
| # ) | |||
| # ] | |||
| # mocker.patch( | |||
| # "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} | |||
| # ) | |||
| # provider_manager = ProviderManager() | |||
| # # Running the method | |||
| # result = provider_manager._to_model_settings( | |||
| # provider_entity, provider_model_settings, load_balancing_model_configs) | |||
| # # Asserting that the result is as expected | |||
| # assert len(result) == 1 | |||
| # assert isinstance(result[0], ModelSettings) | |||
| # assert result[0].model == "gpt-4" | |||
| # assert result[0].model_type == ModelType.LLM | |||
| # assert result[0].enabled is True | |||
| # assert len(result[0].load_balancing_configs) == 0 | |||
| # def test__to_model_settings_lb_disabled(mocker): | |||
| # # Get all provider entities | |||
| # model_provider_factory = ModelProviderFactory("test_tenant") | |||
| # provider_entities = model_provider_factory.get_providers() | |||
| # provider_entity = None | |||
| # for provider in provider_entities: | |||
| # if provider.provider == "openai": | |||
| # provider_entity = provider | |||
| # # Mocking the inputs | |||
| # provider_model_settings = [ | |||
| # ProviderModelSetting( | |||
| # id="id", | |||
| # tenant_id="tenant_id", | |||
| # provider_name="openai", | |||
| # model_name="gpt-4", | |||
| # model_type="text-generation", | |||
| # enabled=True, | |||
| # load_balancing_enabled=False, | |||
| # ) | |||
| # ] | |||
| # load_balancing_model_configs = [ | |||
| # LoadBalancingModelConfig( | |||
| # id="id1", | |||
| # tenant_id="tenant_id", | |||
| # provider_name="openai", | |||
| # model_name="gpt-4", | |||
| # model_type="text-generation", | |||
| # name="__inherit__", | |||
| # encrypted_config=None, | |||
| # enabled=True, | |||
| # ), | |||
| # LoadBalancingModelConfig( | |||
| # id="id2", | |||
| # tenant_id="tenant_id", | |||
| # provider_name="openai", | |||
| # model_name="gpt-4", | |||
| # model_type="text-generation", | |||
| # name="first", | |||
| # encrypted_config='{"openai_api_key": "fake_key"}', | |||
| # enabled=True, | |||
| # ), | |||
| # ] | |||
| # mocker.patch( | |||
| # "core.helper.model_provider_cache.ProviderCredentialsCache.get", | |||
| # return_value={"openai_api_key": "fake_key"} | |||
| # ) | |||
| # provider_manager = ProviderManager() | |||
| # # Running the method | |||
| # result = provider_manager._to_model_settings(provider_entity, | |||
| # provider_model_settings, load_balancing_model_configs) | |||
| # # Asserting that the result is as expected | |||
| # assert len(result) == 1 | |||
| # assert isinstance(result[0], ModelSettings) | |||
| # assert result[0].model == "gpt-4" | |||
| # assert result[0].model_type == ModelType.LLM | |||
| # assert result[0].enabled is True | |||
| # assert len(result[0].load_balancing_configs) == 0 | |||
| import pytest | |||
| from core.entities.provider_entities import ModelSettings | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.provider_manager import ProviderManager | |||
| from models.provider import LoadBalancingModelConfig, ProviderModelSetting | |||
| @pytest.fixture | |||
| def mock_provider_entity(mocker): | |||
| mock_entity = mocker.Mock() | |||
| mock_entity.provider = "openai" | |||
| mock_entity.configurate_methods = ["predefined-model"] | |||
| mock_entity.supported_model_types = [ModelType.LLM] | |||
| mock_entity.model_credential_schema = mocker.Mock() | |||
| mock_entity.model_credential_schema.credential_form_schemas = [] | |||
| return mock_entity | |||
| def test__to_model_settings(mocker, mock_provider_entity): | |||
| # Mocking the inputs | |||
| provider_model_settings = [ | |||
| ProviderModelSetting( | |||
| id="id", | |||
| tenant_id="tenant_id", | |||
| provider_name="openai", | |||
| model_name="gpt-4", | |||
| model_type="text-generation", | |||
| enabled=True, | |||
| load_balancing_enabled=True, | |||
| ) | |||
| ] | |||
| load_balancing_model_configs = [ | |||
| LoadBalancingModelConfig( | |||
| id="id1", | |||
| tenant_id="tenant_id", | |||
| provider_name="openai", | |||
| model_name="gpt-4", | |||
| model_type="text-generation", | |||
| name="__inherit__", | |||
| encrypted_config=None, | |||
| enabled=True, | |||
| ), | |||
| LoadBalancingModelConfig( | |||
| id="id2", | |||
| tenant_id="tenant_id", | |||
| provider_name="openai", | |||
| model_name="gpt-4", | |||
| model_type="text-generation", | |||
| name="first", | |||
| encrypted_config='{"openai_api_key": "fake_key"}', | |||
| enabled=True, | |||
| ), | |||
| ] | |||
| mocker.patch( | |||
| "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} | |||
| ) | |||
| provider_manager = ProviderManager() | |||
| # Running the method | |||
| result = provider_manager._to_model_settings( | |||
| provider_entity=mock_provider_entity, | |||
| provider_model_settings=provider_model_settings, | |||
| load_balancing_model_configs=load_balancing_model_configs, | |||
| ) | |||
| # Asserting that the result is as expected | |||
| assert len(result) == 1 | |||
| assert isinstance(result[0], ModelSettings) | |||
| assert result[0].model == "gpt-4" | |||
| assert result[0].model_type == ModelType.LLM | |||
| assert result[0].enabled is True | |||
| assert len(result[0].load_balancing_configs) == 2 | |||
| assert result[0].load_balancing_configs[0].name == "__inherit__" | |||
| assert result[0].load_balancing_configs[1].name == "first" | |||
| def test__to_model_settings_only_one_lb(mocker, mock_provider_entity): | |||
| # Mocking the inputs | |||
| provider_model_settings = [ | |||
| ProviderModelSetting( | |||
| id="id", | |||
| tenant_id="tenant_id", | |||
| provider_name="openai", | |||
| model_name="gpt-4", | |||
| model_type="text-generation", | |||
| enabled=True, | |||
| load_balancing_enabled=True, | |||
| ) | |||
| ] | |||
| load_balancing_model_configs = [ | |||
| LoadBalancingModelConfig( | |||
| id="id1", | |||
| tenant_id="tenant_id", | |||
| provider_name="openai", | |||
| model_name="gpt-4", | |||
| model_type="text-generation", | |||
| name="__inherit__", | |||
| encrypted_config=None, | |||
| enabled=True, | |||
| ) | |||
| ] | |||
| mocker.patch( | |||
| "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} | |||
| ) | |||
| provider_manager = ProviderManager() | |||
| # Running the method | |||
| result = provider_manager._to_model_settings( | |||
| provider_entity=mock_provider_entity, | |||
| provider_model_settings=provider_model_settings, | |||
| load_balancing_model_configs=load_balancing_model_configs, | |||
| ) | |||
| # Asserting that the result is as expected | |||
| assert len(result) == 1 | |||
| assert isinstance(result[0], ModelSettings) | |||
| assert result[0].model == "gpt-4" | |||
| assert result[0].model_type == ModelType.LLM | |||
| assert result[0].enabled is True | |||
| assert len(result[0].load_balancing_configs) == 0 | |||
| def test__to_model_settings_lb_disabled(mocker, mock_provider_entity): | |||
| # Mocking the inputs | |||
| provider_model_settings = [ | |||
| ProviderModelSetting( | |||
| id="id", | |||
| tenant_id="tenant_id", | |||
| provider_name="openai", | |||
| model_name="gpt-4", | |||
| model_type="text-generation", | |||
| enabled=True, | |||
| load_balancing_enabled=False, | |||
| ) | |||
| ] | |||
| load_balancing_model_configs = [ | |||
| LoadBalancingModelConfig( | |||
| id="id1", | |||
| tenant_id="tenant_id", | |||
| provider_name="openai", | |||
| model_name="gpt-4", | |||
| model_type="text-generation", | |||
| name="__inherit__", | |||
| encrypted_config=None, | |||
| enabled=True, | |||
| ), | |||
| LoadBalancingModelConfig( | |||
| id="id2", | |||
| tenant_id="tenant_id", | |||
| provider_name="openai", | |||
| model_name="gpt-4", | |||
| model_type="text-generation", | |||
| name="first", | |||
| encrypted_config='{"openai_api_key": "fake_key"}', | |||
| enabled=True, | |||
| ), | |||
| ] | |||
| mocker.patch( | |||
| "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} | |||
| ) | |||
| provider_manager = ProviderManager() | |||
| # Running the method | |||
| result = provider_manager._to_model_settings( | |||
| provider_entity=mock_provider_entity, | |||
| provider_model_settings=provider_model_settings, | |||
| load_balancing_model_configs=load_balancing_model_configs, | |||
| ) | |||
| # Asserting that the result is as expected | |||
| assert len(result) == 1 | |||
| assert isinstance(result[0], ModelSettings) | |||
| assert result[0].model == "gpt-4" | |||
| assert result[0].model_type == ModelType.LLM | |||
| assert result[0].enabled is True | |||
| assert len(result[0].load_balancing_configs) == 0 | |||
| @@ -30,7 +30,7 @@ const BaseField = ({ | |||
| inputClassName, | |||
| formSchema, | |||
| field, | |||
| disabled, | |||
| disabled: propsDisabled, | |||
| }: BaseFieldProps) => { | |||
| const renderI18nObject = useRenderI18nObject() | |||
| const { | |||
| @@ -40,7 +40,9 @@ const BaseField = ({ | |||
| options, | |||
| labelClassName: formLabelClassName, | |||
| show_on = [], | |||
| disabled: formSchemaDisabled, | |||
| } = formSchema | |||
| const disabled = propsDisabled || formSchemaDisabled | |||
| const memorizedLabel = useMemo(() => { | |||
| if (isValidElement(label)) | |||
| @@ -72,7 +74,7 @@ const BaseField = ({ | |||
| }) | |||
| const memorizedOptions = useMemo(() => { | |||
| return options?.filter((option) => { | |||
| if (!option.show_on?.length) | |||
| if (!option.show_on || option.show_on.length === 0) | |||
| return true | |||
| return option.show_on.every((condition) => { | |||
| @@ -85,7 +87,7 @@ const BaseField = ({ | |||
| value: option.value, | |||
| } | |||
| }) || [] | |||
| }, [options, renderI18nObject]) | |||
| }, [options, renderI18nObject, optionValues]) | |||
| const value = useStore(field.form.store, s => s.values[field.name]) | |||
| const values = useStore(field.form.store, (s) => { | |||
| return show_on.reduce((acc, condition) => { | |||
| @@ -182,9 +184,10 @@ const BaseField = ({ | |||
| className={cn( | |||
| 'system-sm-regular hover:bg-components-option-card-option-hover-bg hover:border-components-option-card-option-hover-border flex h-8 flex-[1] grow cursor-pointer items-center justify-center rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg p-2 text-text-secondary', | |||
| value === option.value && 'border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary shadow-xs', | |||
| disabled && 'cursor-not-allowed opacity-50', | |||
| inputClassName, | |||
| )} | |||
| onClick={() => field.handleChange(option.value)} | |||
| onClick={() => !disabled && field.handleChange(option.value)} | |||
| > | |||
| { | |||
| formSchema.showRadioUI && ( | |||
| @@ -1,34 +1,52 @@ | |||
| import { useCallback } from 'react' | |||
| import { | |||
| isValidElement, | |||
| useCallback, | |||
| } from 'react' | |||
| import type { ReactNode } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import type { FormSchema } from '../types' | |||
| import { useRenderI18nObject } from '@/hooks/use-i18n' | |||
| export const useGetValidators = () => { | |||
| const { t } = useTranslation() | |||
| const renderI18nObject = useRenderI18nObject() | |||
| const getLabel = useCallback((label: string | Record<string, string> | ReactNode) => { | |||
| if (isValidElement(label)) | |||
| return '' | |||
| if (typeof label === 'string') | |||
| return label | |||
| if (typeof label === 'object' && label !== null) | |||
| return renderI18nObject(label as Record<string, string>) | |||
| }, []) | |||
| const getValidators = useCallback((formSchema: FormSchema) => { | |||
| const { | |||
| name, | |||
| validators, | |||
| required, | |||
| label, | |||
| } = formSchema | |||
| let mergedValidators = validators | |||
| const memorizedLabel = getLabel(label) | |||
| if (required && !validators) { | |||
| mergedValidators = { | |||
| onMount: ({ value }: any) => { | |||
| if (!value) | |||
| return t('common.errorMsg.fieldRequired', { field: name }) | |||
| return t('common.errorMsg.fieldRequired', { field: memorizedLabel || name }) | |||
| }, | |||
| onChange: ({ value }: any) => { | |||
| if (!value) | |||
| return t('common.errorMsg.fieldRequired', { field: name }) | |||
| return t('common.errorMsg.fieldRequired', { field: memorizedLabel || name }) | |||
| }, | |||
| onBlur: ({ value }: any) => { | |||
| if (!value) | |||
| return t('common.errorMsg.fieldRequired', { field: name }) | |||
| return t('common.errorMsg.fieldRequired', { field: memorizedLabel }) | |||
| }, | |||
| } | |||
| } | |||
| return mergedValidators | |||
| }, [t]) | |||
| }, [t, getLabel]) | |||
| return { | |||
| getValidators, | |||
| @@ -59,6 +59,7 @@ export type FormSchema = { | |||
| labelClassName?: string | |||
| validators?: AnyValidators | |||
| showRadioUI?: boolean | |||
| disabled?: boolean | |||
| } | |||
| export type FormValues = Record<string, any> | |||
| @@ -86,6 +86,7 @@ export enum ModelStatusEnum { | |||
| quotaExceeded = 'quota-exceeded', | |||
| noPermission = 'no-permission', | |||
| disabled = 'disabled', | |||
| credentialRemoved = 'credential-removed', | |||
| } | |||
| export const MODEL_STATUS_TEXT: { [k: string]: TypeWithI18N } = { | |||
| @@ -153,6 +154,7 @@ export type ModelItem = { | |||
| model_properties: Record<string, string | number> | |||
| load_balancing_enabled: boolean | |||
| deprecated?: boolean | |||
| has_invalid_load_balancing_configs?: boolean | |||
| } | |||
| export enum PreferredProviderTypeEnum { | |||
| @@ -181,6 +183,29 @@ export type QuotaConfiguration = { | |||
| is_valid: boolean | |||
| } | |||
| export type Credential = { | |||
| credential_id: string | |||
| credential_name?: string | |||
| from_enterprise?: boolean | |||
| not_allowed_to_use?: boolean | |||
| } | |||
| export type CustomModel = { | |||
| model: string | |||
| model_type: ModelTypeEnum | |||
| } | |||
| export type CustomModelCredential = CustomModel & { | |||
| credentials?: Record<string, any> | |||
| available_model_credentials?: Credential[] | |||
| current_credential_id?: string | |||
| } | |||
| export type CredentialWithModel = Credential & { | |||
| model: string | |||
| model_type: ModelTypeEnum | |||
| } | |||
| export type ModelProvider = { | |||
| provider: string | |||
| label: TypeWithI18N | |||
| @@ -207,12 +232,17 @@ export type ModelProvider = { | |||
| preferred_provider_type: PreferredProviderTypeEnum | |||
| custom_configuration: { | |||
| status: CustomConfigurationStatusEnum | |||
| current_credential_id?: string | |||
| current_credential_name?: string | |||
| available_credentials?: Credential[] | |||
| custom_models?: CustomModelCredential[] | |||
| } | |||
| system_configuration: { | |||
| enabled: boolean | |||
| current_quota_type: CurrentSystemQuotaTypeEnum | |||
| quota_configurations: QuotaConfiguration[] | |||
| } | |||
| allow_custom_token?: boolean | |||
| } | |||
| export type Model = { | |||
| @@ -272,9 +302,24 @@ export type ModelLoadBalancingConfigEntry = { | |||
| in_cooldown?: boolean | |||
| /** cooldown time (in seconds) */ | |||
| ttl?: number | |||
| credential_id?: string | |||
| } | |||
| export type ModelLoadBalancingConfig = { | |||
| enabled: boolean | |||
| configs: ModelLoadBalancingConfigEntry[] | |||
| } | |||
| export type ProviderCredential = { | |||
| credentials: Record<string, any> | |||
| name: string | |||
| credential_id: string | |||
| } | |||
| export type ModelCredential = { | |||
| credentials: Record<string, any> | |||
| load_balancing: ModelLoadBalancingConfig | |||
| available_credentials: Credential[] | |||
| current_credential_id?: string | |||
| current_credential_name?: string | |||
| } | |||
| @@ -7,7 +7,9 @@ import { | |||
| import useSWR, { useSWRConfig } from 'swr' | |||
| import { useContext } from 'use-context-selector' | |||
| import type { | |||
| Credential, | |||
| CustomConfigurationModelFixedFields, | |||
| CustomModel, | |||
| DefaultModel, | |||
| DefaultModelResponse, | |||
| Model, | |||
| @@ -77,16 +79,17 @@ export const useProviderCredentialsAndLoadBalancing = ( | |||
| configurationMethod: ConfigurationMethodEnum, | |||
| configured?: boolean, | |||
| currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, | |||
| credentialId?: string, | |||
| ) => { | |||
| const { data: predefinedFormSchemasValue, mutate: mutatePredefined } = useSWR( | |||
| (configurationMethod === ConfigurationMethodEnum.predefinedModel && configured) | |||
| ? `/workspaces/current/model-providers/${provider}/credentials` | |||
| const { data: predefinedFormSchemasValue, mutate: mutatePredefined, isLoading: isPredefinedLoading } = useSWR( | |||
| (configurationMethod === ConfigurationMethodEnum.predefinedModel && configured && credentialId) | |||
| ? `/workspaces/current/model-providers/${provider}/credentials${credentialId ? `?credential_id=${credentialId}` : ''}` | |||
| : null, | |||
| fetchModelProviderCredentials, | |||
| ) | |||
| const { data: customFormSchemasValue, mutate: mutateCustomized } = useSWR( | |||
| (configurationMethod === ConfigurationMethodEnum.customizableModel && currentCustomConfigurationModelFixedFields) | |||
| ? `/workspaces/current/model-providers/${provider}/models/credentials?model=${currentCustomConfigurationModelFixedFields?.__model_name}&model_type=${currentCustomConfigurationModelFixedFields?.__model_type}` | |||
| const { data: customFormSchemasValue, mutate: mutateCustomized, isLoading: isCustomizedLoading } = useSWR( | |||
| (configurationMethod === ConfigurationMethodEnum.customizableModel && currentCustomConfigurationModelFixedFields && credentialId) | |||
| ? `/workspaces/current/model-providers/${provider}/models/credentials?model=${currentCustomConfigurationModelFixedFields?.__model_name}&model_type=${currentCustomConfigurationModelFixedFields?.__model_type}${credentialId ? `&credential_id=${credentialId}` : ''}` | |||
| : null, | |||
| fetchModelProviderCredentials, | |||
| ) | |||
| @@ -102,6 +105,7 @@ export const useProviderCredentialsAndLoadBalancing = ( | |||
| : undefined | |||
| }, [ | |||
| configurationMethod, | |||
| credentialId, | |||
| currentCustomConfigurationModelFixedFields, | |||
| customFormSchemasValue?.credentials, | |||
| predefinedFormSchemasValue?.credentials, | |||
| @@ -119,6 +123,7 @@ export const useProviderCredentialsAndLoadBalancing = ( | |||
| : customFormSchemasValue | |||
| )?.load_balancing, | |||
| mutate, | |||
| isLoading: isPredefinedLoading || isCustomizedLoading, | |||
| } | |||
| // as ([Record<string, string | boolean | undefined> | undefined, ModelLoadBalancingConfig | undefined]) | |||
| } | |||
| @@ -313,40 +318,59 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText: | |||
| } | |||
| } | |||
| export const useModelModalHandler = () => { | |||
| const setShowModelModal = useModalContextSelector(state => state.setShowModelModal) | |||
| export const useRefreshModel = () => { | |||
| const { eventEmitter } = useEventEmitterContextContext() | |||
| const updateModelProviders = useUpdateModelProviders() | |||
| const updateModelList = useUpdateModelList() | |||
| const { eventEmitter } = useEventEmitterContextContext() | |||
| const handleRefreshModel = useCallback((provider: ModelProvider, configurationMethod: ConfigurationMethodEnum, CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => { | |||
| updateModelProviders() | |||
| provider.supported_model_types.forEach((type) => { | |||
| updateModelList(type) | |||
| }) | |||
| if (configurationMethod === ConfigurationMethodEnum.customizableModel | |||
| && provider.custom_configuration.status === CustomConfigurationStatusEnum.active) { | |||
| eventEmitter?.emit({ | |||
| type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST, | |||
| payload: provider.provider, | |||
| } as any) | |||
| if (CustomConfigurationModelFixedFields?.__model_type) | |||
| updateModelList(CustomConfigurationModelFixedFields.__model_type) | |||
| } | |||
| }, [eventEmitter, updateModelList, updateModelProviders]) | |||
| return { | |||
| handleRefreshModel, | |||
| } | |||
| } | |||
| export const useModelModalHandler = () => { | |||
| const setShowModelModal = useModalContextSelector(state => state.setShowModelModal) | |||
| const { handleRefreshModel } = useRefreshModel() | |||
| return ( | |||
| provider: ModelProvider, | |||
| configurationMethod: ConfigurationMethodEnum, | |||
| CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, | |||
| isModelCredential?: boolean, | |||
| credential?: Credential, | |||
| model?: CustomModel, | |||
| onUpdate?: () => void, | |||
| ) => { | |||
| setShowModelModal({ | |||
| payload: { | |||
| currentProvider: provider, | |||
| currentConfigurationMethod: configurationMethod, | |||
| currentCustomConfigurationModelFixedFields: CustomConfigurationModelFixedFields, | |||
| isModelCredential, | |||
| credential, | |||
| model, | |||
| }, | |||
| onSaveCallback: () => { | |||
| updateModelProviders() | |||
| provider.supported_model_types.forEach((type) => { | |||
| updateModelList(type) | |||
| }) | |||
| if (configurationMethod === ConfigurationMethodEnum.customizableModel | |||
| && provider.custom_configuration.status === CustomConfigurationStatusEnum.active) { | |||
| eventEmitter?.emit({ | |||
| type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST, | |||
| payload: provider.provider, | |||
| } as any) | |||
| if (CustomConfigurationModelFixedFields?.__model_type) | |||
| updateModelList(CustomConfigurationModelFixedFields.__model_type) | |||
| } | |||
| handleRefreshModel(provider, configurationMethod, CustomConfigurationModelFixedFields) | |||
| onUpdate?.() | |||
| }, | |||
| }) | |||
| } | |||
| @@ -8,8 +8,6 @@ import { | |||
| import SystemModelSelector from './system-model-selector' | |||
| import ProviderAddedCard from './provider-added-card' | |||
| import type { | |||
| ConfigurationMethodEnum, | |||
| CustomConfigurationModelFixedFields, | |||
| ModelProvider, | |||
| } from './declarations' | |||
| import { | |||
| @@ -18,7 +16,6 @@ import { | |||
| } from './declarations' | |||
| import { | |||
| useDefaultModel, | |||
| useModelModalHandler, | |||
| } from './hooks' | |||
| import InstallFromMarketplace from './install-from-marketplace' | |||
| import { useProviderContext } from '@/context/provider-context' | |||
| @@ -84,8 +81,6 @@ const ModelProviderPage = ({ searchText }: Props) => { | |||
| return [filteredConfiguredProviders, filteredNotConfiguredProviders] | |||
| }, [configuredProviders, debouncedSearchText, notConfiguredProviders]) | |||
| const handleOpenModal = useModelModalHandler() | |||
| return ( | |||
| <div className='relative -mt-2 pt-1'> | |||
| <div className={cn('mb-2 flex items-center')}> | |||
| @@ -126,7 +121,6 @@ const ModelProviderPage = ({ searchText }: Props) => { | |||
| <ProviderAddedCard | |||
| key={provider.provider} | |||
| provider={provider} | |||
| onOpenModal={(configurationMethod: ConfigurationMethodEnum, currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => handleOpenModal(provider, configurationMethod, currentCustomConfigurationModelFixedFields)} | |||
| /> | |||
| ))} | |||
| </div> | |||
| @@ -140,7 +134,6 @@ const ModelProviderPage = ({ searchText }: Props) => { | |||
| notConfigured | |||
| key={provider.provider} | |||
| provider={provider} | |||
| onOpenModal={(configurationMethod: ConfigurationMethodEnum, currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => handleOpenModal(provider, configurationMethod, currentCustomConfigurationModelFixedFields)} | |||
| /> | |||
| ))} | |||
| </div> | |||
| @@ -0,0 +1,115 @@ | |||
| import { | |||
| memo, | |||
| useCallback, | |||
| useMemo, | |||
| } from 'react' | |||
| import { RiAddLine } from '@remixicon/react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import { Authorized } from '@/app/components/header/account-setting/model-provider-page/model-auth' | |||
| import cn from '@/utils/classnames' | |||
| import type { | |||
| Credential, | |||
| CustomModelCredential, | |||
| ModelCredential, | |||
| ModelProvider, | |||
| } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| import Tooltip from '@/app/components/base/tooltip' | |||
| type AddCredentialInLoadBalancingProps = { | |||
| provider: ModelProvider | |||
| model: CustomModelCredential | |||
| configurationMethod: ConfigurationMethodEnum | |||
| modelCredential: ModelCredential | |||
| onSelectCredential: (credential: Credential) => void | |||
| onUpdate?: () => void | |||
| } | |||
| const AddCredentialInLoadBalancing = ({ | |||
| provider, | |||
| model, | |||
| configurationMethod, | |||
| modelCredential, | |||
| onSelectCredential, | |||
| onUpdate, | |||
| }: AddCredentialInLoadBalancingProps) => { | |||
| const { t } = useTranslation() | |||
| const { | |||
| available_credentials, | |||
| } = modelCredential | |||
| const customModel = configurationMethod === ConfigurationMethodEnum.customizableModel | |||
| const notAllowCustomCredential = provider.allow_custom_token === false | |||
| const ButtonComponent = useMemo(() => { | |||
| const Item = ( | |||
| <div className={cn( | |||
| 'system-sm-medium flex h-8 items-center rounded-lg px-3 text-text-accent hover:bg-state-base-hover', | |||
| notAllowCustomCredential && 'cursor-not-allowed opacity-50', | |||
| )}> | |||
| <RiAddLine className='mr-2 h-4 w-4' /> | |||
| { | |||
| customModel | |||
| ? t('common.modelProvider.auth.addCredential') | |||
| : t('common.modelProvider.auth.addApiKey') | |||
| } | |||
| </div> | |||
| ) | |||
| if (notAllowCustomCredential) { | |||
| return ( | |||
| <Tooltip | |||
| asChild | |||
| popupContent={t('plugin.auth.credentialUnavailable')} | |||
| > | |||
| {Item} | |||
| </Tooltip> | |||
| ) | |||
| } | |||
| return Item | |||
| }, [notAllowCustomCredential, t, customModel]) | |||
| const renderTrigger = useCallback((open?: boolean) => { | |||
| const Item = ( | |||
| <div className={cn( | |||
| 'system-sm-medium flex h-8 items-center rounded-lg px-3 text-text-accent hover:bg-state-base-hover', | |||
| open && 'bg-state-base-hover', | |||
| )}> | |||
| <RiAddLine className='mr-2 h-4 w-4' /> | |||
| { | |||
| customModel | |||
| ? t('common.modelProvider.auth.addCredential') | |||
| : t('common.modelProvider.auth.addApiKey') | |||
| } | |||
| </div> | |||
| ) | |||
| return Item | |||
| }, [t, customModel]) | |||
| if (!available_credentials?.length) | |||
| return ButtonComponent | |||
| return ( | |||
| <Authorized | |||
| provider={provider} | |||
| renderTrigger={renderTrigger} | |||
| items={[ | |||
| { | |||
| title: customModel ? t('common.modelProvider.auth.modelCredentials') : t('common.modelProvider.auth.apiKeys'), | |||
| model: customModel ? model : undefined, | |||
| credentials: available_credentials ?? [], | |||
| }, | |||
| ]} | |||
| configurationMethod={configurationMethod} | |||
| currentCustomConfigurationModelFixedFields={customModel ? { | |||
| __model_name: model.model, | |||
| __model_type: model.model_type, | |||
| } : undefined} | |||
| onItemClick={onSelectCredential} | |||
| placement='bottom-start' | |||
| onUpdate={onUpdate} | |||
| isModelCredential={customModel} | |||
| /> | |||
| ) | |||
| } | |||
| export default memo(AddCredentialInLoadBalancing) | |||
| @@ -0,0 +1,111 @@ | |||
| import { | |||
| memo, | |||
| useCallback, | |||
| useMemo, | |||
| } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import { | |||
| RiAddCircleFill, | |||
| } from '@remixicon/react' | |||
| import { | |||
| Button, | |||
| } from '@/app/components/base/button' | |||
| import type { | |||
| CustomConfigurationModelFixedFields, | |||
| ModelProvider, | |||
| } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| import Authorized from './authorized' | |||
| import { | |||
| useAuth, | |||
| useCustomModels, | |||
| } from './hooks' | |||
| import cn from '@/utils/classnames' | |||
| import Tooltip from '@/app/components/base/tooltip' | |||
| type AddCustomModelProps = { | |||
| provider: ModelProvider, | |||
| configurationMethod: ConfigurationMethodEnum, | |||
| currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, | |||
| } | |||
| const AddCustomModel = ({ | |||
| provider, | |||
| configurationMethod, | |||
| currentCustomConfigurationModelFixedFields, | |||
| }: AddCustomModelProps) => { | |||
| const { t } = useTranslation() | |||
| const customModels = useCustomModels(provider) | |||
| const noModels = !customModels.length | |||
| const { | |||
| handleOpenModal, | |||
| } = useAuth(provider, configurationMethod, currentCustomConfigurationModelFixedFields, true) | |||
| const notAllowCustomCredential = provider.allow_custom_token === false | |||
| const handleClick = useCallback(() => { | |||
| if (notAllowCustomCredential) | |||
| return | |||
| handleOpenModal() | |||
| }, [handleOpenModal, notAllowCustomCredential]) | |||
| const ButtonComponent = useMemo(() => { | |||
| const Item = ( | |||
| <Button | |||
| variant='ghost-accent' | |||
| size='small' | |||
| onClick={handleClick} | |||
| className={cn( | |||
| notAllowCustomCredential && 'cursor-not-allowed opacity-50', | |||
| )} | |||
| > | |||
| <RiAddCircleFill className='mr-1 h-3.5 w-3.5' /> | |||
| {t('common.modelProvider.addModel')} | |||
| </Button> | |||
| ) | |||
| if (notAllowCustomCredential) { | |||
| return ( | |||
| <Tooltip | |||
| asChild | |||
| popupContent={t('plugin.auth.credentialUnavailable')} | |||
| > | |||
| {Item} | |||
| </Tooltip> | |||
| ) | |||
| } | |||
| return Item | |||
| }, [handleClick, notAllowCustomCredential, t]) | |||
| const renderTrigger = useCallback((open?: boolean) => { | |||
| const Item = ( | |||
| <Button | |||
| variant='ghost' | |||
| size='small' | |||
| className={cn( | |||
| open && 'bg-components-button-ghost-bg-hover', | |||
| )} | |||
| > | |||
| <RiAddCircleFill className='mr-1 h-3.5 w-3.5' /> | |||
| {t('common.modelProvider.addModel')} | |||
| </Button> | |||
| ) | |||
| return Item | |||
| }, [t]) | |||
| if (noModels) | |||
| return ButtonComponent | |||
| return ( | |||
| <Authorized | |||
| provider={provider} | |||
| configurationMethod={ConfigurationMethodEnum.customizableModel} | |||
| items={customModels.map(model => ({ | |||
| model, | |||
| credentials: model.available_model_credentials ?? [], | |||
| }))} | |||
| renderTrigger={renderTrigger} | |||
| isModelCredential | |||
| enableAddModelCredential | |||
| bottomAddModelCredentialText={t('common.modelProvider.auth.addNewModel')} | |||
| /> | |||
| ) | |||
| } | |||
| export default memo(AddCustomModel) | |||
| @@ -0,0 +1,101 @@ | |||
| import { | |||
| memo, | |||
| useCallback, | |||
| } from 'react' | |||
| import { RiAddLine } from '@remixicon/react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import CredentialItem from './credential-item' | |||
| import type { | |||
| Credential, | |||
| CustomModel, | |||
| CustomModelCredential, | |||
| } from '../../declarations' | |||
| import Button from '@/app/components/base/button' | |||
| import Tooltip from '@/app/components/base/tooltip' | |||
| type AuthorizedItemProps = { | |||
| model?: CustomModelCredential | |||
| title?: string | |||
| disabled?: boolean | |||
| onDelete?: (credential?: Credential, model?: CustomModel) => void | |||
| onEdit?: (credential?: Credential, model?: CustomModel) => void | |||
| showItemSelectedIcon?: boolean | |||
| selectedCredentialId?: string | |||
| credentials: Credential[] | |||
| onItemClick?: (credential: Credential, model?: CustomModel) => void | |||
| enableAddModelCredential?: boolean | |||
| notAllowCustomCredential?: boolean | |||
| } | |||
| export const AuthorizedItem = ({ | |||
| model, | |||
| title, | |||
| credentials, | |||
| disabled, | |||
| onDelete, | |||
| onEdit, | |||
| showItemSelectedIcon, | |||
| selectedCredentialId, | |||
| onItemClick, | |||
| enableAddModelCredential, | |||
| notAllowCustomCredential, | |||
| }: AuthorizedItemProps) => { | |||
| const { t } = useTranslation() | |||
| const handleEdit = useCallback((credential?: Credential) => { | |||
| onEdit?.(credential, model) | |||
| }, [onEdit, model]) | |||
| const handleDelete = useCallback((credential?: Credential) => { | |||
| onDelete?.(credential, model) | |||
| }, [onDelete, model]) | |||
| const handleItemClick = useCallback((credential: Credential) => { | |||
| onItemClick?.(credential, model) | |||
| }, [onItemClick, model]) | |||
| return ( | |||
| <div className='p-1'> | |||
| <div | |||
| className='flex h-9 items-center' | |||
| > | |||
| <div className='h-5 w-5 shrink-0'></div> | |||
| <div | |||
| className='system-md-medium mx-1 grow truncate text-text-primary' | |||
| title={title ?? model?.model} | |||
| > | |||
| {title ?? model?.model} | |||
| </div> | |||
| { | |||
| enableAddModelCredential && !notAllowCustomCredential && ( | |||
| <Tooltip | |||
| asChild | |||
| popupContent={t('common.modelProvider.auth.addModelCredential')} | |||
| > | |||
| <Button | |||
| className='h-6 w-6 shrink-0 rounded-full p-0' | |||
| size='small' | |||
| variant='secondary-accent' | |||
| onClick={() => handleEdit?.()} | |||
| > | |||
| <RiAddLine className='h-4 w-4' /> | |||
| </Button> | |||
| </Tooltip> | |||
| ) | |||
| } | |||
| </div> | |||
| { | |||
| credentials.map(credential => ( | |||
| <CredentialItem | |||
| key={credential.credential_id} | |||
| credential={credential} | |||
| disabled={disabled} | |||
| onDelete={handleDelete} | |||
| onEdit={handleEdit} | |||
| showSelectedIcon={showItemSelectedIcon} | |||
| selectedCredentialId={selectedCredentialId} | |||
| onItemClick={handleItemClick} | |||
| /> | |||
| )) | |||
| } | |||
| </div> | |||
| ) | |||
| } | |||
| export default memo(AuthorizedItem) | |||
| @@ -0,0 +1,137 @@ | |||
| import { | |||
| memo, | |||
| useMemo, | |||
| } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import { | |||
| RiCheckLine, | |||
| RiDeleteBinLine, | |||
| RiEqualizer2Line, | |||
| } from '@remixicon/react' | |||
| import Indicator from '@/app/components/header/indicator' | |||
| import ActionButton from '@/app/components/base/action-button' | |||
| import Tooltip from '@/app/components/base/tooltip' | |||
| import cn from '@/utils/classnames' | |||
| import type { Credential } from '../../declarations' | |||
| import Badge from '@/app/components/base/badge' | |||
| type CredentialItemProps = { | |||
| credential: Credential | |||
| disabled?: boolean | |||
| onDelete?: (credential: Credential) => void | |||
| onEdit?: (credential?: Credential) => void | |||
| onItemClick?: (credential: Credential) => void | |||
| disableRename?: boolean | |||
| disableEdit?: boolean | |||
| disableDelete?: boolean | |||
| showSelectedIcon?: boolean | |||
| selectedCredentialId?: string | |||
| } | |||
| const CredentialItem = ({ | |||
| credential, | |||
| disabled, | |||
| onDelete, | |||
| onEdit, | |||
| onItemClick, | |||
| disableRename, | |||
| disableEdit, | |||
| disableDelete, | |||
| showSelectedIcon, | |||
| selectedCredentialId, | |||
| }: CredentialItemProps) => { | |||
| const { t } = useTranslation() | |||
| const showAction = useMemo(() => { | |||
| return !(disableRename && disableEdit && disableDelete) | |||
| }, [disableRename, disableEdit, disableDelete]) | |||
| const Item = ( | |||
| <div | |||
| key={credential.credential_id} | |||
| className={cn( | |||
| 'group flex h-8 items-center rounded-lg p-1 hover:bg-state-base-hover', | |||
| (disabled || credential.not_allowed_to_use) && 'cursor-not-allowed opacity-50', | |||
| )} | |||
| onClick={() => { | |||
| if (disabled || credential.not_allowed_to_use) | |||
| return | |||
| onItemClick?.(credential) | |||
| }} | |||
| > | |||
| <div className='flex w-0 grow items-center space-x-1.5'> | |||
| { | |||
| showSelectedIcon && ( | |||
| <div className='h-4 w-4'> | |||
| { | |||
| selectedCredentialId === credential.credential_id && ( | |||
| <RiCheckLine className='h-4 w-4 text-text-accent' /> | |||
| ) | |||
| } | |||
| </div> | |||
| ) | |||
| } | |||
| <Indicator className='ml-2 mr-1.5 shrink-0' /> | |||
| <div | |||
| className='system-md-regular truncate text-text-secondary' | |||
| title={credential.credential_name} | |||
| > | |||
| {credential.credential_name} | |||
| </div> | |||
| </div> | |||
| { | |||
| credential.from_enterprise && ( | |||
| <Badge className='shrink-0'> | |||
| Enterprise | |||
| </Badge> | |||
| ) | |||
| } | |||
| { | |||
| showAction && ( | |||
| <div className='ml-2 hidden shrink-0 items-center group-hover:flex'> | |||
| { | |||
| !disableEdit && !credential.not_allowed_to_use && !credential.from_enterprise && ( | |||
| <Tooltip popupContent={t('common.operation.edit')}> | |||
| <ActionButton | |||
| disabled={disabled} | |||
| onClick={(e) => { | |||
| e.stopPropagation() | |||
| onEdit?.(credential) | |||
| }} | |||
| > | |||
| <RiEqualizer2Line className='h-4 w-4 text-text-tertiary' /> | |||
| </ActionButton> | |||
| </Tooltip> | |||
| ) | |||
| } | |||
| { | |||
| !disableDelete && !credential.from_enterprise && ( | |||
| <Tooltip popupContent={t('common.operation.delete')}> | |||
| <ActionButton | |||
| className='hover:bg-transparent' | |||
| disabled={disabled} | |||
| onClick={(e) => { | |||
| e.stopPropagation() | |||
| onDelete?.(credential) | |||
| }} | |||
| > | |||
| <RiDeleteBinLine className='h-4 w-4 text-text-tertiary hover:text-text-destructive' /> | |||
| </ActionButton> | |||
| </Tooltip> | |||
| ) | |||
| } | |||
| </div> | |||
| ) | |||
| } | |||
| </div> | |||
| ) | |||
| if (credential.not_allowed_to_use) { | |||
| return ( | |||
| <Tooltip popupContent={t('plugin.auth.customCredentialUnavailable')}> | |||
| {Item} | |||
| </Tooltip> | |||
| ) | |||
| } | |||
| return Item | |||
| } | |||
| export default memo(CredentialItem) | |||
| @@ -0,0 +1,222 @@ | |||
| import { | |||
| memo, | |||
| useCallback, | |||
| useMemo, | |||
| useState, | |||
| } from 'react' | |||
| import { | |||
| RiAddLine, | |||
| RiEqualizer2Line, | |||
| } from '@remixicon/react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import { | |||
| PortalToFollowElem, | |||
| PortalToFollowElemContent, | |||
| PortalToFollowElemTrigger, | |||
| } from '@/app/components/base/portal-to-follow-elem' | |||
| import type { | |||
| PortalToFollowElemOptions, | |||
| } from '@/app/components/base/portal-to-follow-elem' | |||
| import Button from '@/app/components/base/button' | |||
| import cn from '@/utils/classnames' | |||
| import Confirm from '@/app/components/base/confirm' | |||
| import type { | |||
| ConfigurationMethodEnum, | |||
| Credential, | |||
| CustomConfigurationModelFixedFields, | |||
| CustomModel, | |||
| ModelProvider, | |||
| } from '../../declarations' | |||
| import { useAuth } from '../hooks' | |||
| import AuthorizedItem from './authorized-item' | |||
| type AuthorizedProps = { | |||
| provider: ModelProvider, | |||
| configurationMethod: ConfigurationMethodEnum, | |||
| currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, | |||
| isModelCredential?: boolean | |||
| items: { | |||
| title?: string | |||
| model?: CustomModel | |||
| credentials: Credential[] | |||
| }[] | |||
| selectedCredential?: Credential | |||
| disabled?: boolean | |||
| renderTrigger?: (open?: boolean) => React.ReactNode | |||
| isOpen?: boolean | |||
| onOpenChange?: (open: boolean) => void | |||
| offset?: PortalToFollowElemOptions['offset'] | |||
| placement?: PortalToFollowElemOptions['placement'] | |||
| triggerPopupSameWidth?: boolean | |||
| popupClassName?: string | |||
| showItemSelectedIcon?: boolean | |||
| onUpdate?: () => void | |||
| onItemClick?: (credential: Credential, model?: CustomModel) => void | |||
| enableAddModelCredential?: boolean | |||
| bottomAddModelCredentialText?: string | |||
| } | |||
| const Authorized = ({ | |||
| provider, | |||
| configurationMethod, | |||
| currentCustomConfigurationModelFixedFields, | |||
| items, | |||
| isModelCredential, | |||
| selectedCredential, | |||
| disabled, | |||
| renderTrigger, | |||
| isOpen, | |||
| onOpenChange, | |||
| offset = 8, | |||
| placement = 'bottom-end', | |||
| triggerPopupSameWidth = false, | |||
| popupClassName, | |||
| showItemSelectedIcon, | |||
| onUpdate, | |||
| onItemClick, | |||
| enableAddModelCredential, | |||
| bottomAddModelCredentialText, | |||
| }: AuthorizedProps) => { | |||
| const { t } = useTranslation() | |||
| const [isLocalOpen, setIsLocalOpen] = useState(false) | |||
| const mergedIsOpen = isOpen ?? isLocalOpen | |||
| const setMergedIsOpen = useCallback((open: boolean) => { | |||
| if (onOpenChange) | |||
| onOpenChange(open) | |||
| setIsLocalOpen(open) | |||
| }, [onOpenChange]) | |||
| const { | |||
| openConfirmDelete, | |||
| closeConfirmDelete, | |||
| doingAction, | |||
| handleActiveCredential, | |||
| handleConfirmDelete, | |||
| deleteCredentialId, | |||
| handleOpenModal, | |||
| } = useAuth(provider, configurationMethod, currentCustomConfigurationModelFixedFields, isModelCredential, onUpdate) | |||
| const handleEdit = useCallback((credential?: Credential, model?: CustomModel) => { | |||
| handleOpenModal(credential, model) | |||
| setMergedIsOpen(false) | |||
| }, [handleOpenModal, setMergedIsOpen]) | |||
| const handleItemClick = useCallback((credential: Credential, model?: CustomModel) => { | |||
| if (onItemClick) | |||
| onItemClick(credential, model) | |||
| else | |||
| handleActiveCredential(credential, model) | |||
| setMergedIsOpen(false) | |||
| }, [handleActiveCredential, onItemClick, setMergedIsOpen]) | |||
| const notAllowCustomCredential = provider.allow_custom_token === false | |||
| const Trigger = useMemo(() => { | |||
| const Item = ( | |||
| <Button | |||
| className='grow' | |||
| size='small' | |||
| > | |||
| <RiEqualizer2Line className='mr-1 h-3.5 w-3.5' /> | |||
| {t('common.operation.config')} | |||
| </Button> | |||
| ) | |||
| return Item | |||
| }, [t]) | |||
| return ( | |||
| <> | |||
| <PortalToFollowElem | |||
| open={mergedIsOpen} | |||
| onOpenChange={setMergedIsOpen} | |||
| placement={placement} | |||
| offset={offset} | |||
| triggerPopupSameWidth={triggerPopupSameWidth} | |||
| > | |||
| <PortalToFollowElemTrigger | |||
| onClick={() => { | |||
| setMergedIsOpen(!mergedIsOpen) | |||
| }} | |||
| asChild | |||
| > | |||
| { | |||
| renderTrigger | |||
| ? renderTrigger(mergedIsOpen) | |||
| : Trigger | |||
| } | |||
| </PortalToFollowElemTrigger> | |||
| <PortalToFollowElemContent className='z-[100]'> | |||
| <div className={cn( | |||
| 'w-[360px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg', | |||
| popupClassName, | |||
| )}> | |||
| <div className='max-h-[304px] overflow-y-auto'> | |||
| { | |||
| items.map((item, index) => ( | |||
| <AuthorizedItem | |||
| key={index} | |||
| title={item.title} | |||
| model={item.model} | |||
| credentials={item.credentials} | |||
| disabled={disabled} | |||
| onDelete={openConfirmDelete} | |||
| onEdit={handleEdit} | |||
| showItemSelectedIcon={showItemSelectedIcon} | |||
| selectedCredentialId={selectedCredential?.credential_id} | |||
| onItemClick={handleItemClick} | |||
| enableAddModelCredential={enableAddModelCredential} | |||
| notAllowCustomCredential={notAllowCustomCredential} | |||
| /> | |||
| )) | |||
| } | |||
| </div> | |||
| <div className='h-[1px] bg-divider-subtle'></div> | |||
| { | |||
| isModelCredential && !notAllowCustomCredential && ( | |||
| <div | |||
| onClick={() => handleEdit( | |||
| undefined, | |||
| currentCustomConfigurationModelFixedFields | |||
| ? { | |||
| model: currentCustomConfigurationModelFixedFields.__model_name, | |||
| model_type: currentCustomConfigurationModelFixedFields.__model_type, | |||
| } | |||
| : undefined, | |||
| )} | |||
| className='system-xs-medium flex h-[30px] cursor-pointer items-center px-3 text-text-accent-light-mode-only' | |||
| > | |||
| <RiAddLine className='mr-1 h-4 w-4' /> | |||
| {bottomAddModelCredentialText ?? t('common.modelProvider.auth.addModelCredential')} | |||
| </div> | |||
| ) | |||
| } | |||
| { | |||
| !isModelCredential && !notAllowCustomCredential && ( | |||
| <div className='p-2'> | |||
| <Button | |||
| onClick={() => handleEdit()} | |||
| className='w-full' | |||
| > | |||
| {t('common.modelProvider.auth.addApiKey')} | |||
| </Button> | |||
| </div> | |||
| ) | |||
| } | |||
| </div> | |||
| </PortalToFollowElemContent> | |||
| </PortalToFollowElem> | |||
| { | |||
| deleteCredentialId && ( | |||
| <Confirm | |||
| isShow | |||
| title={t('common.modelProvider.confirmDelete')} | |||
| isDisabled={doingAction} | |||
| onCancel={closeConfirmDelete} | |||
| onConfirm={handleConfirmDelete} | |||
| /> | |||
| ) | |||
| } | |||
| </> | |||
| ) | |||
| } | |||
| export default memo(Authorized) | |||
| @@ -0,0 +1,76 @@ | |||
| import { memo } from 'react' | |||
| import { | |||
| RiEqualizer2Line, | |||
| RiScales3Line, | |||
| } from '@remixicon/react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import Button from '@/app/components/base/button' | |||
| import Indicator from '@/app/components/header/indicator' | |||
| import cn from '@/utils/classnames' | |||
| type ConfigModelProps = { | |||
| onClick?: () => void | |||
| loadBalancingEnabled?: boolean | |||
| loadBalancingInvalid?: boolean | |||
| credentialRemoved?: boolean | |||
| } | |||
| const ConfigModel = ({ | |||
| onClick, | |||
| loadBalancingEnabled, | |||
| loadBalancingInvalid, | |||
| credentialRemoved, | |||
| }: ConfigModelProps) => { | |||
| const { t } = useTranslation() | |||
| if (loadBalancingInvalid) { | |||
| return ( | |||
| <div | |||
| className='system-2xs-medium-uppercase relative flex h-[18px] items-center rounded-[5px] border border-text-warning bg-components-badge-bg-dimm px-1.5 text-text-warning' | |||
| onClick={onClick} | |||
| > | |||
| <RiScales3Line className='mr-0.5 h-3 w-3' /> | |||
| {t('common.modelProvider.auth.authorizationError')} | |||
| <Indicator color='orange' className='absolute right-[-1px] top-[-1px] h-1.5 w-1.5' /> | |||
| </div> | |||
| ) | |||
| } | |||
| return ( | |||
| <Button | |||
| variant='secondary' | |||
| size='small' | |||
| className={cn( | |||
| 'hidden shrink-0 group-hover:flex', | |||
| credentialRemoved && 'flex', | |||
| )} | |||
| onClick={onClick} | |||
| > | |||
| { | |||
| credentialRemoved && ( | |||
| <> | |||
| {t('common.modelProvider.auth.credentialRemoved')} | |||
| <Indicator color='red' className='ml-2' /> | |||
| </> | |||
| ) | |||
| } | |||
| { | |||
| !loadBalancingEnabled && !credentialRemoved && !loadBalancingInvalid && ( | |||
| <> | |||
| <RiEqualizer2Line className='mr-1 h-4 w-4' /> | |||
| {t('common.operation.config')} | |||
| </> | |||
| ) | |||
| } | |||
| { | |||
| loadBalancingEnabled && !credentialRemoved && !loadBalancingInvalid && ( | |||
| <> | |||
| <RiScales3Line className='mr-1 h-4 w-4' /> | |||
| {t('common.modelProvider.auth.configLoadBalancing')} | |||
| </> | |||
| ) | |||
| } | |||
| </Button> | |||
| ) | |||
| } | |||
| export default memo(ConfigModel) | |||
| @@ -0,0 +1,96 @@ | |||
| import { | |||
| memo, | |||
| useCallback, | |||
| useMemo, | |||
| } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import { | |||
| RiEqualizer2Line, | |||
| } from '@remixicon/react' | |||
| import { | |||
| Button, | |||
| } from '@/app/components/base/button' | |||
| import type { | |||
| CustomConfigurationModelFixedFields, | |||
| ModelProvider, | |||
| } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| import Authorized from './authorized' | |||
| import { useAuth, useCredentialStatus } from './hooks' | |||
| import Tooltip from '@/app/components/base/tooltip' | |||
| import cn from '@/utils/classnames' | |||
| type ConfigProviderProps = { | |||
| provider: ModelProvider, | |||
| configurationMethod: ConfigurationMethodEnum, | |||
| currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, | |||
| } | |||
| const ConfigProvider = ({ | |||
| provider, | |||
| configurationMethod, | |||
| currentCustomConfigurationModelFixedFields, | |||
| }: ConfigProviderProps) => { | |||
| const { t } = useTranslation() | |||
| const { | |||
| handleOpenModal, | |||
| } = useAuth(provider, configurationMethod, currentCustomConfigurationModelFixedFields) | |||
| const { | |||
| hasCredential, | |||
| authorized, | |||
| current_credential_id, | |||
| current_credential_name, | |||
| available_credentials, | |||
| } = useCredentialStatus(provider) | |||
| const notAllowCustomCredential = provider.allow_custom_token === false | |||
| const handleClick = useCallback(() => { | |||
| if (!hasCredential && !notAllowCustomCredential) | |||
| handleOpenModal() | |||
| }, [handleOpenModal, hasCredential, notAllowCustomCredential]) | |||
| const ButtonComponent = useMemo(() => { | |||
| const Item = ( | |||
| <Button | |||
| className={cn('grow', notAllowCustomCredential && 'cursor-not-allowed opacity-50')} | |||
| size='small' | |||
| onClick={handleClick} | |||
| variant={!authorized ? 'secondary-accent' : 'secondary'} | |||
| > | |||
| <RiEqualizer2Line className='mr-1 h-3.5 w-3.5' /> | |||
| {t('common.operation.setup')} | |||
| </Button> | |||
| ) | |||
| if (notAllowCustomCredential) { | |||
| return ( | |||
| <Tooltip | |||
| asChild | |||
| popupContent={t('plugin.auth.credentialUnavailable')} | |||
| > | |||
| {Item} | |||
| </Tooltip> | |||
| ) | |||
| } | |||
| return Item | |||
| }, [handleClick, authorized, notAllowCustomCredential, t]) | |||
| if (!hasCredential) | |||
| return ButtonComponent | |||
| return ( | |||
| <Authorized | |||
| provider={provider} | |||
| configurationMethod={ConfigurationMethodEnum.predefinedModel} | |||
| items={[ | |||
| { | |||
| title: t('common.modelProvider.auth.apiKeys'), | |||
| credentials: available_credentials ?? [], | |||
| }, | |||
| ]} | |||
| selectedCredential={{ | |||
| credential_id: current_credential_id ?? '', | |||
| credential_name: current_credential_name ?? '', | |||
| }} | |||
| showItemSelectedIcon | |||
| /> | |||
| ) | |||
| } | |||
| export default memo(ConfigProvider) | |||
| @@ -0,0 +1,6 @@ | |||
| export * from './use-model-form-schemas' | |||
| export * from './use-credential-status' | |||
| export * from './use-custom-models' | |||
| export * from './use-auth' | |||
| export * from './use-auth-service' | |||
| export * from './use-credential-data' | |||
| @@ -0,0 +1,57 @@ | |||
| import { useCallback } from 'react' | |||
| import { | |||
| useActiveModelCredential, | |||
| useActiveProviderCredential, | |||
| useAddModelCredential, | |||
| useAddProviderCredential, | |||
| useDeleteModelCredential, | |||
| useDeleteProviderCredential, | |||
| useEditModelCredential, | |||
| useEditProviderCredential, | |||
| useGetModelCredential, | |||
| useGetProviderCredential, | |||
| } from '@/service/use-models' | |||
| import type { | |||
| CustomModel, | |||
| } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| export const useGetCredential = (provider: string, isModelCredential?: boolean, credentialId?: string, model?: CustomModel, configFrom?: string) => { | |||
| const providerData = useGetProviderCredential(!isModelCredential && !!credentialId, provider, credentialId) | |||
| const modelData = useGetModelCredential(!!isModelCredential && !!credentialId, provider, credentialId, model?.model, model?.model_type, configFrom) | |||
| return isModelCredential ? modelData : providerData | |||
| } | |||
| export const useAuthService = (provider: string) => { | |||
| const { mutateAsync: addProviderCredential } = useAddProviderCredential(provider) | |||
| const { mutateAsync: editProviderCredential } = useEditProviderCredential(provider) | |||
| const { mutateAsync: deleteProviderCredential } = useDeleteProviderCredential(provider) | |||
| const { mutateAsync: activeProviderCredential } = useActiveProviderCredential(provider) | |||
| const { mutateAsync: addModelCredential } = useAddModelCredential(provider) | |||
| const { mutateAsync: activeModelCredential } = useActiveModelCredential(provider) | |||
| const { mutateAsync: deleteModelCredential } = useDeleteModelCredential(provider) | |||
| const { mutateAsync: editModelCredential } = useEditModelCredential(provider) | |||
| const getAddCredentialService = useCallback((isModel: boolean) => { | |||
| return isModel ? addModelCredential : addProviderCredential | |||
| }, [addModelCredential, addProviderCredential]) | |||
| const getEditCredentialService = useCallback((isModel: boolean) => { | |||
| return isModel ? editModelCredential : editProviderCredential | |||
| }, [editModelCredential, editProviderCredential]) | |||
| const getDeleteCredentialService = useCallback((isModel: boolean) => { | |||
| return isModel ? deleteModelCredential : deleteProviderCredential | |||
| }, [deleteModelCredential, deleteProviderCredential]) | |||
| const getActiveCredentialService = useCallback((isModel: boolean) => { | |||
| return isModel ? activeModelCredential : activeProviderCredential | |||
| }, [activeModelCredential, activeProviderCredential]) | |||
| return { | |||
| getAddCredentialService, | |||
| getEditCredentialService, | |||
| getDeleteCredentialService, | |||
| getActiveCredentialService, | |||
| } | |||
| } | |||
| @@ -0,0 +1,158 @@ | |||
| import { | |||
| useCallback, | |||
| useRef, | |||
| useState, | |||
| } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import { useToastContext } from '@/app/components/base/toast' | |||
| import { useAuthService } from './use-auth-service' | |||
| import type { | |||
| ConfigurationMethodEnum, | |||
| Credential, | |||
| CustomConfigurationModelFixedFields, | |||
| CustomModel, | |||
| ModelProvider, | |||
| } from '../../declarations' | |||
| import { | |||
| useModelModalHandler, | |||
| useRefreshModel, | |||
| } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||
| export const useAuth = ( | |||
| provider: ModelProvider, | |||
| configurationMethod: ConfigurationMethodEnum, | |||
| currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, | |||
| isModelCredential?: boolean, | |||
| onUpdate?: () => void, | |||
| ) => { | |||
| const { t } = useTranslation() | |||
| const { notify } = useToastContext() | |||
| const { | |||
| getDeleteCredentialService, | |||
| getActiveCredentialService, | |||
| getEditCredentialService, | |||
| getAddCredentialService, | |||
| } = useAuthService(provider.provider) | |||
| const handleOpenModelModal = useModelModalHandler() | |||
| const { handleRefreshModel } = useRefreshModel() | |||
| const pendingOperationCredentialId = useRef<string | null>(null) | |||
| const pendingOperationModel = useRef<CustomModel | null>(null) | |||
| const [deleteCredentialId, setDeleteCredentialId] = useState<string | null>(null) | |||
| const openConfirmDelete = useCallback((credential?: Credential, model?: CustomModel) => { | |||
| if (credential) | |||
| pendingOperationCredentialId.current = credential.credential_id | |||
| if (model) | |||
| pendingOperationModel.current = model | |||
| setDeleteCredentialId(pendingOperationCredentialId.current) | |||
| }, []) | |||
| const closeConfirmDelete = useCallback(() => { | |||
| setDeleteCredentialId(null) | |||
| pendingOperationCredentialId.current = null | |||
| }, []) | |||
| const [doingAction, setDoingAction] = useState(false) | |||
| const doingActionRef = useRef(doingAction) | |||
| const handleSetDoingAction = useCallback((doing: boolean) => { | |||
| doingActionRef.current = doing | |||
| setDoingAction(doing) | |||
| }, []) | |||
| const handleActiveCredential = useCallback(async (credential: Credential, model?: CustomModel) => { | |||
| if (doingActionRef.current) | |||
| return | |||
| try { | |||
| handleSetDoingAction(true) | |||
| await getActiveCredentialService(!!model)({ | |||
| credential_id: credential.credential_id, | |||
| model: model?.model, | |||
| model_type: model?.model_type, | |||
| }) | |||
| notify({ | |||
| type: 'success', | |||
| message: t('common.api.actionSuccess'), | |||
| }) | |||
| onUpdate?.() | |||
| handleRefreshModel(provider, configurationMethod, undefined) | |||
| } | |||
| finally { | |||
| handleSetDoingAction(false) | |||
| } | |||
| }, [getActiveCredentialService, onUpdate, notify, t, handleSetDoingAction]) | |||
| const handleConfirmDelete = useCallback(async () => { | |||
| if (doingActionRef.current) | |||
| return | |||
| if (!pendingOperationCredentialId.current) { | |||
| setDeleteCredentialId(null) | |||
| return | |||
| } | |||
| try { | |||
| handleSetDoingAction(true) | |||
| await getDeleteCredentialService(!!isModelCredential)({ | |||
| credential_id: pendingOperationCredentialId.current, | |||
| model: pendingOperationModel.current?.model, | |||
| model_type: pendingOperationModel.current?.model_type, | |||
| }) | |||
| notify({ | |||
| type: 'success', | |||
| message: t('common.api.actionSuccess'), | |||
| }) | |||
| onUpdate?.() | |||
| handleRefreshModel(provider, configurationMethod, undefined) | |||
| setDeleteCredentialId(null) | |||
| pendingOperationCredentialId.current = null | |||
| pendingOperationModel.current = null | |||
| } | |||
| finally { | |||
| handleSetDoingAction(false) | |||
| } | |||
| }, [onUpdate, notify, t, handleSetDoingAction, getDeleteCredentialService, isModelCredential]) | |||
| const handleAddCredential = useCallback((model?: CustomModel) => { | |||
| if (model) | |||
| pendingOperationModel.current = model | |||
| }, []) | |||
| const handleSaveCredential = useCallback(async (payload: Record<string, any>) => { | |||
| if (doingActionRef.current) | |||
| return | |||
| try { | |||
| handleSetDoingAction(true) | |||
| let res: { result?: string } = {} | |||
| if (payload.credential_id) | |||
| res = await getEditCredentialService(!!isModelCredential)(payload as any) | |||
| else | |||
| res = await getAddCredentialService(!!isModelCredential)(payload as any) | |||
| if (res.result === 'success') { | |||
| notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) | |||
| onUpdate?.() | |||
| } | |||
| } | |||
| finally { | |||
| handleSetDoingAction(false) | |||
| } | |||
| }, [onUpdate, notify, t, handleSetDoingAction, getEditCredentialService, getAddCredentialService]) | |||
| const handleOpenModal = useCallback((credential?: Credential, model?: CustomModel) => { | |||
| handleOpenModelModal( | |||
| provider, | |||
| configurationMethod, | |||
| currentCustomConfigurationModelFixedFields, | |||
| isModelCredential, | |||
| credential, | |||
| model, | |||
| onUpdate, | |||
| ) | |||
| }, [handleOpenModelModal, provider, configurationMethod, currentCustomConfigurationModelFixedFields, isModelCredential, onUpdate]) | |||
| return { | |||
| pendingOperationCredentialId, | |||
| pendingOperationModel, | |||
| openConfirmDelete, | |||
| closeConfirmDelete, | |||
| doingAction, | |||
| handleActiveCredential, | |||
| handleConfirmDelete, | |||
| handleAddCredential, | |||
| deleteCredentialId, | |||
| handleSaveCredential, | |||
| handleOpenModal, | |||
| } | |||
| } | |||
| @@ -0,0 +1,24 @@ | |||
| import { useMemo } from 'react' | |||
| import { useGetCredential } from './use-auth-service' | |||
| import type { | |||
| Credential, | |||
| CustomModelCredential, | |||
| ModelProvider, | |||
| } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| export const useCredentialData = (provider: ModelProvider, providerFormSchemaPredefined: boolean, isModelCredential?: boolean, credential?: Credential, model?: CustomModelCredential) => { | |||
| const configFrom = useMemo(() => { | |||
| if (providerFormSchemaPredefined) | |||
| return 'predefined-model' | |||
| return 'custom-model' | |||
| }, [providerFormSchemaPredefined]) | |||
| const { | |||
| isLoading, | |||
| data: credentialData = {}, | |||
| } = useGetCredential(provider.provider, isModelCredential, credential?.credential_id, model, configFrom) | |||
| return { | |||
| isLoading, | |||
| credentialData, | |||
| } | |||
| } | |||
| @@ -0,0 +1,26 @@ | |||
| import { useMemo } from 'react' | |||
| import type { | |||
| ModelProvider, | |||
| } from '../../declarations' | |||
| export const useCredentialStatus = (provider: ModelProvider) => { | |||
| const { | |||
| current_credential_id, | |||
| current_credential_name, | |||
| available_credentials, | |||
| } = provider.custom_configuration | |||
| const hasCredential = !!available_credentials?.length | |||
| const authorized = current_credential_id && current_credential_name | |||
| const authRemoved = hasCredential && !current_credential_id && !current_credential_name | |||
| const currentCredential = available_credentials?.find(credential => credential.credential_id === current_credential_id) | |||
| return useMemo(() => ({ | |||
| hasCredential, | |||
| authorized, | |||
| authRemoved, | |||
| current_credential_id, | |||
| current_credential_name, | |||
| available_credentials, | |||
| notAllowedToUse: currentCredential?.not_allowed_to_use, | |||
| }), [hasCredential, authorized, authRemoved, current_credential_id, current_credential_name, available_credentials]) | |||
| } | |||
| @@ -0,0 +1,9 @@ | |||
| import type { | |||
| ModelProvider, | |||
| } from '../../declarations' | |||
| export const useCustomModels = (provider: ModelProvider) => { | |||
| const { custom_models } = provider.custom_configuration | |||
| return custom_models || [] | |||
| } | |||
| @@ -0,0 +1,83 @@ | |||
| import { useMemo } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import type { | |||
| Credential, | |||
| CustomModelCredential, | |||
| ModelLoadBalancingConfig, | |||
| ModelProvider, | |||
| } from '../../declarations' | |||
| import { | |||
| genModelNameFormSchema, | |||
| genModelTypeFormSchema, | |||
| } from '../../utils' | |||
| import { FormTypeEnum } from '@/app/components/base/form/types' | |||
| export const useModelFormSchemas = ( | |||
| provider: ModelProvider, | |||
| providerFormSchemaPredefined: boolean, | |||
| credentials?: Record<string, any>, | |||
| credential?: Credential, | |||
| model?: CustomModelCredential, | |||
| draftConfig?: ModelLoadBalancingConfig, | |||
| ) => { | |||
| const { t } = useTranslation() | |||
| const { | |||
| provider_credential_schema, | |||
| supported_model_types, | |||
| model_credential_schema, | |||
| } = provider | |||
| const formSchemas = useMemo(() => { | |||
| const modelTypeSchema = genModelTypeFormSchema(supported_model_types) | |||
| const modelNameSchema = genModelNameFormSchema(model_credential_schema?.model) | |||
| if (!!model) { | |||
| modelTypeSchema.disabled = true | |||
| modelNameSchema.disabled = true | |||
| } | |||
| return providerFormSchemaPredefined | |||
| ? provider_credential_schema.credential_form_schemas | |||
| : [ | |||
| modelTypeSchema, | |||
| modelNameSchema, | |||
| ...(draftConfig?.enabled ? [] : model_credential_schema.credential_form_schemas), | |||
| ] | |||
| }, [ | |||
| providerFormSchemaPredefined, | |||
| provider_credential_schema?.credential_form_schemas, | |||
| supported_model_types, | |||
| model_credential_schema?.credential_form_schemas, | |||
| model_credential_schema?.model, | |||
| draftConfig?.enabled, | |||
| model, | |||
| ]) | |||
| const formSchemasWithAuthorizationName = useMemo(() => { | |||
| const authorizationNameSchema = { | |||
| type: FormTypeEnum.textInput, | |||
| variable: '__authorization_name__', | |||
| label: t('plugin.auth.authorizationName'), | |||
| required: true, | |||
| } | |||
| return [ | |||
| authorizationNameSchema, | |||
| ...formSchemas, | |||
| ] | |||
| }, [formSchemas, t]) | |||
| const formValues = useMemo(() => { | |||
| let result = {} | |||
| if (credential) { | |||
| result = { ...result, __authorization_name__: credential?.credential_name } | |||
| if (credentials) | |||
| result = { ...result, ...credentials } | |||
| } | |||
| if (model) | |||
| result = { ...result, __model_name: model?.model, __model_type: model?.model_type } | |||
| return result | |||
| }, [credentials, credential, model]) | |||
| return { | |||
| formSchemas: formSchemasWithAuthorizationName, | |||
| formValues, | |||
| } | |||
| } | |||
| @@ -0,0 +1,6 @@ | |||
| export { default as Authorized } from './authorized' | |||
| export { default as SwitchCredentialInLoadBalancing } from './switch-credential-in-load-balancing' | |||
| export { default as AddCredentialInLoadBalancing } from './add-credential-in-load-balancing' | |||
| export { default as AddCustomModel } from './add-custom-model' | |||
| export { default as ConfigProvider } from './config-provider' | |||
| export { default as ConfigModel } from './config-model' | |||
| @@ -0,0 +1,122 @@ | |||
| import type { Dispatch, SetStateAction } from 'react' | |||
| import { | |||
| memo, | |||
| useCallback, | |||
| } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import { RiArrowDownSLine } from '@remixicon/react' | |||
| import Button from '@/app/components/base/button' | |||
| import Indicator from '@/app/components/header/indicator' | |||
| import Authorized from './authorized' | |||
| import type { | |||
| Credential, | |||
| CustomModel, | |||
| ModelProvider, | |||
| } from '../declarations' | |||
| import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| import cn from '@/utils/classnames' | |||
| import Tooltip from '@/app/components/base/tooltip' | |||
| import Badge from '@/app/components/base/badge' | |||
| type SwitchCredentialInLoadBalancingProps = { | |||
| provider: ModelProvider | |||
| model: CustomModel | |||
| credentials?: Credential[] | |||
| customModelCredential?: Credential | |||
| setCustomModelCredential: Dispatch<SetStateAction<Credential | undefined>> | |||
| } | |||
| const SwitchCredentialInLoadBalancing = ({ | |||
| provider, | |||
| model, | |||
| customModelCredential, | |||
| setCustomModelCredential, | |||
| credentials, | |||
| }: SwitchCredentialInLoadBalancingProps) => { | |||
| const { t } = useTranslation() | |||
| const handleItemClick = useCallback((credential: Credential) => { | |||
| setCustomModelCredential(credential) | |||
| }, [setCustomModelCredential]) | |||
| const renderTrigger = useCallback(() => { | |||
| const selectedCredentialId = customModelCredential?.credential_id | |||
| const authRemoved = !selectedCredentialId && !!credentials?.length | |||
| let color = 'green' | |||
| if (authRemoved && !customModelCredential?.not_allowed_to_use) | |||
| color = 'red' | |||
| if (customModelCredential?.not_allowed_to_use) | |||
| color = 'gray' | |||
| const Item = ( | |||
| <Button | |||
| variant='secondary' | |||
| className={cn( | |||
| 'shrink-0 space-x-1', | |||
| authRemoved && 'text-components-button-destructive-secondary-text', | |||
| customModelCredential?.not_allowed_to_use && 'cursor-not-allowed opacity-50', | |||
| )} | |||
| > | |||
| <Indicator | |||
| className='mr-2' | |||
| color={color as any} | |||
| /> | |||
| { | |||
| authRemoved && !customModelCredential?.not_allowed_to_use && t('common.modelProvider.auth.authRemoved') | |||
| } | |||
| { | |||
| !authRemoved && customModelCredential?.not_allowed_to_use && t('plugin.auth.credentialUnavailable') | |||
| } | |||
| { | |||
| !authRemoved && !customModelCredential?.not_allowed_to_use && customModelCredential?.credential_name | |||
| } | |||
| { | |||
| customModelCredential?.from_enterprise && ( | |||
| <Badge className='ml-2'>Enterprise</Badge> | |||
| ) | |||
| } | |||
| <RiArrowDownSLine className='h-4 w-4' /> | |||
| </Button> | |||
| ) | |||
| if (customModelCredential?.not_allowed_to_use) { | |||
| return ( | |||
| <Tooltip | |||
| asChild | |||
| popupContent={t('plugin.auth.credentialUnavailable')} | |||
| > | |||
| {Item} | |||
| </Tooltip> | |||
| ) | |||
| } | |||
| return Item | |||
| }, [customModelCredential, t, credentials]) | |||
| return ( | |||
| <Authorized | |||
| provider={provider} | |||
| configurationMethod={ConfigurationMethodEnum.customizableModel} | |||
| items={[ | |||
| { | |||
| title: t('common.modelProvider.auth.modelCredentials'), | |||
| model, | |||
| credentials: credentials || [], | |||
| }, | |||
| ]} | |||
| renderTrigger={renderTrigger} | |||
| onItemClick={handleItemClick} | |||
| isModelCredential | |||
| enableAddModelCredential | |||
| bottomAddModelCredentialText={t('common.modelProvider.auth.addModelCredential')} | |||
| selectedCredential={ | |||
| customModelCredential | |||
| ? { | |||
| credential_id: customModelCredential?.credential_id || '', | |||
| credential_name: customModelCredential?.credential_name || '', | |||
| } | |||
| : undefined | |||
| } | |||
| showItemSelectedIcon | |||
| /> | |||
| ) | |||
| } | |||
| export default memo(SwitchCredentialInLoadBalancing) | |||
| @@ -13,12 +13,14 @@ type ModelIconProps = { | |||
| provider?: Model | ModelProvider | |||
| modelName?: string | |||
| className?: string | |||
| iconClassName?: string | |||
| isDeprecated?: boolean | |||
| } | |||
| const ModelIcon: FC<ModelIconProps> = ({ | |||
| provider, | |||
| className, | |||
| modelName, | |||
| iconClassName, | |||
| isDeprecated = false, | |||
| }) => { | |||
| const language = useLanguage() | |||
| @@ -34,7 +36,7 @@ const ModelIcon: FC<ModelIconProps> = ({ | |||
| if (provider?.icon_small) { | |||
| return ( | |||
| <div className={cn('flex h-5 w-5 items-center justify-center', isDeprecated && 'opacity-50', className)}> | |||
| <img alt='model-icon' src={renderI18nObject(provider.icon_small, language)}/> | |||
| <img alt='model-icon' src={renderI18nObject(provider.icon_small, language)} className={iconClassName} /> | |||
| </div> | |||
| ) | |||
| } | |||
| @@ -44,7 +46,7 @@ const ModelIcon: FC<ModelIconProps> = ({ | |||
| 'flex h-5 w-5 items-center justify-center rounded-md border-[0.5px] border-components-panel-border-subtle bg-background-default-subtle', | |||
| className, | |||
| )}> | |||
| <div className='flex h-5 w-5 items-center justify-center opacity-35'> | |||
| <div className={cn('flex h-5 w-5 items-center justify-center opacity-35', iconClassName)}> | |||
| <Group className='h-3 w-3 text-text-tertiary' /> | |||
| </div> | |||
| </div> | |||
| @@ -2,43 +2,22 @@ import type { FC } from 'react' | |||
| import { | |||
| memo, | |||
| useCallback, | |||
| useEffect, | |||
| useMemo, | |||
| useState, | |||
| useRef, | |||
| } from 'react' | |||
| import { RiCloseLine } from '@remixicon/react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import { | |||
| RiErrorWarningFill, | |||
| } from '@remixicon/react' | |||
| import type { | |||
| CredentialFormSchema, | |||
| CredentialFormSchemaRadio, | |||
| CredentialFormSchemaSelect, | |||
| CustomConfigurationModelFixedFields, | |||
| FormValue, | |||
| ModelLoadBalancingConfig, | |||
| ModelLoadBalancingConfigEntry, | |||
| ModelProvider, | |||
| } from '../declarations' | |||
| import { | |||
| ConfigurationMethodEnum, | |||
| CustomConfigurationStatusEnum, | |||
| FormTypeEnum, | |||
| } from '../declarations' | |||
| import { | |||
| genModelNameFormSchema, | |||
| genModelTypeFormSchema, | |||
| removeCredentials, | |||
| saveCredentials, | |||
| } from '../utils' | |||
| import { | |||
| useLanguage, | |||
| useProviderCredentialsAndLoadBalancing, | |||
| } from '../hooks' | |||
| import { useValidate } from '../../key-validator/hooks' | |||
| import { ValidatedStatus } from '../../key-validator/declarations' | |||
| import ModelLoadBalancingConfigs from '../provider-added-card/model-load-balancing-configs' | |||
| import Form from './Form' | |||
| import Button from '@/app/components/base/button' | |||
| import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' | |||
| import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general' | |||
| @@ -46,9 +25,26 @@ import { | |||
| PortalToFollowElem, | |||
| PortalToFollowElemContent, | |||
| } from '@/app/components/base/portal-to-follow-elem' | |||
| import { useToastContext } from '@/app/components/base/toast' | |||
| import Confirm from '@/app/components/base/confirm' | |||
| import { useAppContext } from '@/context/app-context' | |||
| import AuthForm from '@/app/components/base/form/form-scenarios/auth' | |||
| import type { | |||
| FormRefObject, | |||
| FormSchema, | |||
| } from '@/app/components/base/form/types' | |||
| import { useModelFormSchemas } from '../model-auth/hooks' | |||
| import type { | |||
| Credential, | |||
| CustomModel, | |||
| } from '../declarations' | |||
| import Loading from '@/app/components/base/loading' | |||
| import { | |||
| useAuth, | |||
| useCredentialData, | |||
| } from '@/app/components/header/account-setting/model-provider-page/model-auth/hooks' | |||
| import ModelIcon from '@/app/components/header/account-setting/model-provider-page/model-icon' | |||
| import Badge from '@/app/components/base/badge' | |||
| import { useRenderI18nObject } from '@/hooks/use-i18n' | |||
| type ModelModalProps = { | |||
| provider: ModelProvider | |||
| @@ -56,6 +52,9 @@ type ModelModalProps = { | |||
| currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields | |||
| onCancel: () => void | |||
| onSave: () => void | |||
| model?: CustomModel | |||
| credential?: Credential | |||
| isModelCredential?: boolean | |||
| } | |||
| const ModelModal: FC<ModelModalProps> = ({ | |||
| @@ -64,244 +63,173 @@ const ModelModal: FC<ModelModalProps> = ({ | |||
| currentCustomConfigurationModelFixedFields, | |||
| onCancel, | |||
| onSave, | |||
| model, | |||
| credential, | |||
| isModelCredential, | |||
| }) => { | |||
| const renderI18nObject = useRenderI18nObject() | |||
| const providerFormSchemaPredefined = configurateMethod === ConfigurationMethodEnum.predefinedModel | |||
| const { | |||
| isLoading, | |||
| credentialData, | |||
| } = useCredentialData(provider, providerFormSchemaPredefined, isModelCredential, credential, model) | |||
| const { | |||
| handleSaveCredential, | |||
| handleConfirmDelete, | |||
| deleteCredentialId, | |||
| closeConfirmDelete, | |||
| openConfirmDelete, | |||
| doingAction, | |||
| } = useAuth(provider, configurateMethod, currentCustomConfigurationModelFixedFields, isModelCredential, onSave) | |||
| const { | |||
| credentials: formSchemasValue, | |||
| loadBalancing: originalConfig, | |||
| mutate, | |||
| } = useProviderCredentialsAndLoadBalancing( | |||
| provider.provider, | |||
| configurateMethod, | |||
| providerFormSchemaPredefined && provider.custom_configuration.status === CustomConfigurationStatusEnum.active, | |||
| currentCustomConfigurationModelFixedFields, | |||
| ) | |||
| } = credentialData as any | |||
| const { isCurrentWorkspaceManager } = useAppContext() | |||
| const isEditMode = !!formSchemasValue && isCurrentWorkspaceManager | |||
| const { t } = useTranslation() | |||
| const { notify } = useToastContext() | |||
| const language = useLanguage() | |||
| const [loading, setLoading] = useState(false) | |||
| const [showConfirm, setShowConfirm] = useState(false) | |||
| const [draftConfig, setDraftConfig] = useState<ModelLoadBalancingConfig>() | |||
| const originalConfigMap = useMemo(() => { | |||
| if (!originalConfig) | |||
| return {} | |||
| return originalConfig?.configs.reduce((prev, config) => { | |||
| if (config.id) | |||
| prev[config.id] = config | |||
| return prev | |||
| }, {} as Record<string, ModelLoadBalancingConfigEntry>) | |||
| }, [originalConfig]) | |||
| useEffect(() => { | |||
| if (originalConfig && !draftConfig) | |||
| setDraftConfig(originalConfig) | |||
| }, [draftConfig, originalConfig]) | |||
| const formSchemas = useMemo(() => { | |||
| return providerFormSchemaPredefined | |||
| ? provider.provider_credential_schema.credential_form_schemas | |||
| : [ | |||
| genModelTypeFormSchema(provider.supported_model_types), | |||
| genModelNameFormSchema(provider.model_credential_schema?.model), | |||
| ...(draftConfig?.enabled ? [] : provider.model_credential_schema.credential_form_schemas), | |||
| ] | |||
| }, [ | |||
| providerFormSchemaPredefined, | |||
| provider.provider_credential_schema?.credential_form_schemas, | |||
| provider.supported_model_types, | |||
| provider.model_credential_schema?.credential_form_schemas, | |||
| provider.model_credential_schema?.model, | |||
| draftConfig?.enabled, | |||
| ]) | |||
| const [ | |||
| requiredFormSchemas, | |||
| defaultFormSchemaValue, | |||
| showOnVariableMap, | |||
| ] = useMemo(() => { | |||
| const requiredFormSchemas: CredentialFormSchema[] = [] | |||
| const defaultFormSchemaValue: Record<string, string | number> = {} | |||
| const showOnVariableMap: Record<string, string[]> = {} | |||
| formSchemas.forEach((formSchema) => { | |||
| if (formSchema.required) | |||
| requiredFormSchemas.push(formSchema) | |||
| if (formSchema.default) | |||
| defaultFormSchemaValue[formSchema.variable] = formSchema.default | |||
| if (formSchema.show_on.length) { | |||
| formSchema.show_on.forEach((showOnItem) => { | |||
| if (!showOnVariableMap[showOnItem.variable]) | |||
| showOnVariableMap[showOnItem.variable] = [] | |||
| if (!showOnVariableMap[showOnItem.variable].includes(formSchema.variable)) | |||
| showOnVariableMap[showOnItem.variable].push(formSchema.variable) | |||
| }) | |||
| } | |||
| if (formSchema.type === FormTypeEnum.select || formSchema.type === FormTypeEnum.radio) { | |||
| (formSchema as (CredentialFormSchemaRadio | CredentialFormSchemaSelect)).options.forEach((option) => { | |||
| if (option.show_on.length) { | |||
| option.show_on.forEach((showOnItem) => { | |||
| if (!showOnVariableMap[showOnItem.variable]) | |||
| showOnVariableMap[showOnItem.variable] = [] | |||
| if (!showOnVariableMap[showOnItem.variable].includes(formSchema.variable)) | |||
| showOnVariableMap[showOnItem.variable].push(formSchema.variable) | |||
| }) | |||
| } | |||
| }) | |||
| } | |||
| }) | |||
| return [ | |||
| requiredFormSchemas, | |||
| defaultFormSchemaValue, | |||
| showOnVariableMap, | |||
| ] | |||
| }, [formSchemas]) | |||
| const initialFormSchemasValue: Record<string, string | number> = useMemo(() => { | |||
| return { | |||
| ...defaultFormSchemaValue, | |||
| ...formSchemasValue, | |||
| } as unknown as Record<string, string | number> | |||
| }, [formSchemasValue, defaultFormSchemaValue]) | |||
| const [value, setValue] = useState(initialFormSchemasValue) | |||
| useEffect(() => { | |||
| setValue(initialFormSchemasValue) | |||
| }, [initialFormSchemasValue]) | |||
| const [_, validating, validatedStatusState] = useValidate(value) | |||
| const filteredRequiredFormSchemas = requiredFormSchemas.filter((requiredFormSchema) => { | |||
| if (requiredFormSchema.show_on.length && requiredFormSchema.show_on.every(showOnItem => value[showOnItem.variable] === showOnItem.value)) | |||
| return true | |||
| if (!requiredFormSchema.show_on.length) | |||
| return true | |||
| const { | |||
| formSchemas, | |||
| formValues, | |||
| } = useModelFormSchemas(provider, providerFormSchemaPredefined, formSchemasValue, credential, model) | |||
| const formRef = useRef<FormRefObject>(null) | |||
| return false | |||
| }) | |||
| const handleSave = useCallback(async () => { | |||
| const { | |||
| isCheckValidated, | |||
| values, | |||
| } = formRef.current?.getFormValues({ | |||
| needCheckValidatedValues: true, | |||
| needTransformWhenSecretFieldIsPristine: true, | |||
| }) || { isCheckValidated: false, values: {} } | |||
| if (!isCheckValidated) | |||
| return | |||
| const handleValueChange = (v: FormValue) => { | |||
| setValue(v) | |||
| } | |||
| const { | |||
| __authorization_name__, | |||
| __model_name, | |||
| __model_type, | |||
| ...rest | |||
| } = values | |||
| if (__model_name && __model_type) { | |||
| handleSaveCredential({ | |||
| credential_id: credential?.credential_id, | |||
| credentials: rest, | |||
| name: __authorization_name__, | |||
| model: __model_name, | |||
| model_type: __model_type, | |||
| }) | |||
| } | |||
| else { | |||
| handleSaveCredential({ | |||
| credential_id: credential?.credential_id, | |||
| credentials: rest, | |||
| name: __authorization_name__, | |||
| }) | |||
| } | |||
| }, [handleSaveCredential, credential?.credential_id, model]) | |||
| const extendedSecretFormSchemas = useMemo( | |||
| () => | |||
| (providerFormSchemaPredefined | |||
| ? provider.provider_credential_schema.credential_form_schemas | |||
| : [ | |||
| genModelTypeFormSchema(provider.supported_model_types), | |||
| genModelNameFormSchema(provider.model_credential_schema?.model), | |||
| ...provider.model_credential_schema.credential_form_schemas, | |||
| ]).filter(({ type }) => type === FormTypeEnum.secretInput), | |||
| [ | |||
| provider.model_credential_schema?.credential_form_schemas, | |||
| provider.model_credential_schema?.model, | |||
| provider.provider_credential_schema?.credential_form_schemas, | |||
| provider.supported_model_types, | |||
| providerFormSchemaPredefined, | |||
| ], | |||
| ) | |||
| const modalTitle = useMemo(() => { | |||
| if (!providerFormSchemaPredefined && !model) { | |||
| return ( | |||
| <div className='flex items-center'> | |||
| <ModelIcon | |||
| className='mr-2 h-10 w-10 shrink-0' | |||
| iconClassName='h-10 w-10' | |||
| provider={provider} | |||
| /> | |||
| <div> | |||
| <div className='system-xs-medium-uppercase text-text-tertiary'>{t('common.modelProvider.auth.apiKeyModal.addModel')}</div> | |||
| <div className='system-md-semibold text-text-primary'>{renderI18nObject(provider.label)}</div> | |||
| </div> | |||
| </div> | |||
| ) | |||
| } | |||
| let label = t('common.modelProvider.auth.apiKeyModal.title') | |||
| const encodeSecretValues = useCallback((v: FormValue) => { | |||
| const result = { ...v } | |||
| extendedSecretFormSchemas.forEach(({ variable }) => { | |||
| if (result[variable] === formSchemasValue?.[variable] && result[variable] !== undefined) | |||
| result[variable] = '[__HIDDEN__]' | |||
| }) | |||
| return result | |||
| }, [extendedSecretFormSchemas, formSchemasValue]) | |||
| if (model) | |||
| label = t('common.modelProvider.auth.addModelCredential') | |||
| const encodeConfigEntrySecretValues = useCallback((entry: ModelLoadBalancingConfigEntry) => { | |||
| const result = { ...entry } | |||
| extendedSecretFormSchemas.forEach(({ variable }) => { | |||
| if (entry.id && result.credentials[variable] === originalConfigMap[entry.id]?.credentials?.[variable]) | |||
| result.credentials[variable] = '[__HIDDEN__]' | |||
| }) | |||
| return result | |||
| }, [extendedSecretFormSchemas, originalConfigMap]) | |||
| return ( | |||
| <div className='title-2xl-semi-bold text-text-primary'> | |||
| {label} | |||
| </div> | |||
| ) | |||
| }, [providerFormSchemaPredefined, t, model, renderI18nObject]) | |||
| const handleSave = async () => { | |||
| try { | |||
| setLoading(true) | |||
| const res = await saveCredentials( | |||
| providerFormSchemaPredefined, | |||
| provider.provider, | |||
| encodeSecretValues(value), | |||
| { | |||
| ...draftConfig, | |||
| enabled: Boolean(draftConfig?.enabled), | |||
| configs: draftConfig?.configs.map(encodeConfigEntrySecretValues) || [], | |||
| }, | |||
| const modalDesc = useMemo(() => { | |||
| if (providerFormSchemaPredefined) { | |||
| return ( | |||
| <div className='system-xs-regular mt-1 text-text-tertiary'> | |||
| {t('common.modelProvider.auth.apiKeyModal.desc')} | |||
| </div> | |||
| ) | |||
| if (res.result === 'success') { | |||
| notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) | |||
| mutate() | |||
| onSave() | |||
| onCancel() | |||
| } | |||
| } | |||
| finally { | |||
| setLoading(false) | |||
| } | |||
| } | |||
| const handleRemove = async () => { | |||
| try { | |||
| setLoading(true) | |||
| return null | |||
| }, [providerFormSchemaPredefined, t]) | |||
| const res = await removeCredentials( | |||
| providerFormSchemaPredefined, | |||
| provider.provider, | |||
| value, | |||
| const modalModel = useMemo(() => { | |||
| if (model) { | |||
| return ( | |||
| <div className='mt-2 flex items-center'> | |||
| <ModelIcon | |||
| className='mr-2 h-4 w-4 shrink-0' | |||
| provider={provider} | |||
| modelName={model.model} | |||
| /> | |||
| <div className='system-md-regular mr-1 text-text-secondary'>{model.model}</div> | |||
| <Badge>{model.model_type}</Badge> | |||
| </div> | |||
| ) | |||
| if (res.result === 'success') { | |||
| notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) | |||
| mutate() | |||
| onSave() | |||
| onCancel() | |||
| } | |||
| } | |||
| finally { | |||
| setLoading(false) | |||
| } | |||
| } | |||
| const renderTitlePrefix = () => { | |||
| const prefix = isEditMode ? t('common.operation.setup') : t('common.operation.add') | |||
| return `${prefix} ${provider.label[language] || provider.label.en_US}` | |||
| } | |||
| return null | |||
| }, [model, provider]) | |||
| return ( | |||
| <PortalToFollowElem open> | |||
| <PortalToFollowElemContent className='z-[60] h-full w-full'> | |||
| <div className='fixed inset-0 flex items-center justify-center bg-black/[.25]'> | |||
| <div className='mx-2 w-[640px] overflow-auto rounded-2xl bg-components-panel-bg shadow-xl'> | |||
| <div className='px-8 pt-8'> | |||
| <div className='mb-2 flex items-center'> | |||
| <div className='text-xl font-semibold text-text-primary'>{renderTitlePrefix()}</div> | |||
| <div className='relative w-[640px] rounded-2xl bg-components-panel-bg shadow-xl'> | |||
| <div | |||
| className='absolute right-5 top-5 flex h-8 w-8 cursor-pointer items-center justify-center' | |||
| onClick={onCancel} | |||
| > | |||
| <RiCloseLine className='h-4 w-4 text-text-tertiary' /> | |||
| </div> | |||
| <div className='px-6 pt-6'> | |||
| <div className='pb-3'> | |||
| {modalTitle} | |||
| {modalDesc} | |||
| {modalModel} | |||
| </div> | |||
| <div className='max-h-[calc(100vh-320px)] overflow-y-auto'> | |||
| <Form | |||
| value={value} | |||
| onChange={handleValueChange} | |||
| formSchemas={formSchemas} | |||
| validating={validating} | |||
| validatedSuccess={validatedStatusState.status === ValidatedStatus.Success} | |||
| showOnVariableMap={showOnVariableMap} | |||
| isEditMode={isEditMode} | |||
| /> | |||
| <div className='mb-4 mt-1 border-t-[0.5px] border-t-divider-regular' /> | |||
| <ModelLoadBalancingConfigs withSwitch {...{ | |||
| draftConfig, | |||
| setDraftConfig, | |||
| provider, | |||
| currentCustomConfigurationModelFixedFields, | |||
| configurationMethod: configurateMethod, | |||
| }} /> | |||
| { | |||
| isLoading && ( | |||
| <div className='flex items-center justify-center'> | |||
| <Loading /> | |||
| </div> | |||
| ) | |||
| } | |||
| { | |||
| !isLoading && ( | |||
| <AuthForm | |||
| formSchemas={formSchemas.map((formSchema) => { | |||
| return { | |||
| ...formSchema, | |||
| name: formSchema.variable, | |||
| showRadioUI: formSchema.type === FormTypeEnum.radio, | |||
| } | |||
| }) as FormSchema[]} | |||
| defaultValues={formValues} | |||
| inputClassName='justify-start' | |||
| ref={formRef} | |||
| /> | |||
| ) | |||
| } | |||
| </div> | |||
| <div className='sticky bottom-0 -mx-2 mt-2 flex flex-wrap items-center justify-between gap-y-2 bg-components-panel-bg px-2 pb-6 pt-4'> | |||
| @@ -327,7 +255,7 @@ const ModelModal: FC<ModelModalProps> = ({ | |||
| variant='warning' | |||
| size='large' | |||
| className='mr-2' | |||
| onClick={() => setShowConfirm(true)} | |||
| onClick={() => openConfirmDelete(credential, model)} | |||
| > | |||
| {t('common.operation.remove')} | |||
| </Button> | |||
| @@ -344,12 +272,7 @@ const ModelModal: FC<ModelModalProps> = ({ | |||
| size='large' | |||
| variant='primary' | |||
| onClick={handleSave} | |||
| disabled={ | |||
| loading | |||
| || filteredRequiredFormSchemas.some(item => value[item.variable] === undefined) | |||
| || (draftConfig?.enabled && (draftConfig?.configs.filter(config => config.enabled).length ?? 0) < 2) | |||
| } | |||
| disabled={isLoading || doingAction} | |||
| > | |||
| {t('common.operation.save')} | |||
| </Button> | |||
| @@ -357,38 +280,28 @@ const ModelModal: FC<ModelModalProps> = ({ | |||
| </div> | |||
| </div> | |||
| <div className='border-t-[0.5px] border-t-divider-regular'> | |||
| { | |||
| (validatedStatusState.status === ValidatedStatus.Error && validatedStatusState.message) | |||
| ? ( | |||
| <div className='flex bg-background-section-burn px-[10px] py-3 text-xs text-[#D92D20]'> | |||
| <RiErrorWarningFill className='mr-2 mt-[1px] h-[14px] w-[14px]' /> | |||
| {validatedStatusState.message} | |||
| </div> | |||
| ) | |||
| : ( | |||
| <div className='flex items-center justify-center bg-background-section-burn py-3 text-xs text-text-tertiary'> | |||
| <Lock01 className='mr-1 h-3 w-3 text-text-tertiary' /> | |||
| {t('common.modelProvider.encrypted.front')} | |||
| <a | |||
| className='mx-1 text-text-accent' | |||
| target='_blank' rel='noopener noreferrer' | |||
| href='https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html' | |||
| > | |||
| PKCS1_OAEP | |||
| </a> | |||
| {t('common.modelProvider.encrypted.back')} | |||
| </div> | |||
| ) | |||
| } | |||
| <div className='flex items-center justify-center rounded-b-2xl bg-background-section-burn py-3 text-xs text-text-tertiary'> | |||
| <Lock01 className='mr-1 h-3 w-3 text-text-tertiary' /> | |||
| {t('common.modelProvider.encrypted.front')} | |||
| <a | |||
| className='mx-1 text-text-accent' | |||
| target='_blank' rel='noopener noreferrer' | |||
| href='https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html' | |||
| > | |||
| PKCS1_OAEP | |||
| </a> | |||
| {t('common.modelProvider.encrypted.back')} | |||
| </div> | |||
| </div> | |||
| </div> | |||
| { | |||
| showConfirm && ( | |||
| deleteCredentialId && ( | |||
| <Confirm | |||
| isShow | |||
| title={t('common.modelProvider.confirmDelete')} | |||
| isShow={showConfirm} | |||
| onCancel={() => setShowConfirm(false)} | |||
| onConfirm={handleRemove} | |||
| isDisabled={doingAction} | |||
| onCancel={closeConfirmDelete} | |||
| onConfirm={handleConfirmDelete} | |||
| /> | |||
| ) | |||
| } | |||
| @@ -1,348 +0,0 @@ | |||
| import type { FC } from 'react' | |||
| import { | |||
| memo, | |||
| useCallback, | |||
| useEffect, | |||
| useMemo, | |||
| useState, | |||
| } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import { | |||
| RiErrorWarningFill, | |||
| } from '@remixicon/react' | |||
| import type { | |||
| CredentialFormSchema, | |||
| CredentialFormSchemaRadio, | |||
| CredentialFormSchemaSelect, | |||
| CredentialFormSchemaTextInput, | |||
| CustomConfigurationModelFixedFields, | |||
| FormValue, | |||
| ModelLoadBalancingConfigEntry, | |||
| ModelProvider, | |||
| } from '../declarations' | |||
| import { | |||
| ConfigurationMethodEnum, | |||
| FormTypeEnum, | |||
| } from '../declarations' | |||
| import { | |||
| useLanguage, | |||
| } from '../hooks' | |||
| import { useValidate } from '../../key-validator/hooks' | |||
| import { ValidatedStatus } from '../../key-validator/declarations' | |||
| import { validateLoadBalancingCredentials } from '../utils' | |||
| import Form from './Form' | |||
| import Button from '@/app/components/base/button' | |||
| import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' | |||
| import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general' | |||
| import { | |||
| PortalToFollowElem, | |||
| PortalToFollowElemContent, | |||
| } from '@/app/components/base/portal-to-follow-elem' | |||
| import { useToastContext } from '@/app/components/base/toast' | |||
| import Confirm from '@/app/components/base/confirm' | |||
| type ModelModalProps = { | |||
| provider: ModelProvider | |||
| configurationMethod: ConfigurationMethodEnum | |||
| currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields | |||
| entry?: ModelLoadBalancingConfigEntry | |||
| onCancel: () => void | |||
| onSave: (entry: ModelLoadBalancingConfigEntry) => void | |||
| onRemove: () => void | |||
| } | |||
| const ModelLoadBalancingEntryModal: FC<ModelModalProps> = ({ | |||
| provider, | |||
| configurationMethod, | |||
| currentCustomConfigurationModelFixedFields, | |||
| entry, | |||
| onCancel, | |||
| onSave, | |||
| onRemove, | |||
| }) => { | |||
| const providerFormSchemaPredefined = configurationMethod === ConfigurationMethodEnum.predefinedModel | |||
| // const { credentials: formSchemasValue } = useProviderCredentialsAndLoadBalancing( | |||
| // provider.provider, | |||
| // configurationMethod, | |||
| // providerFormSchemaPredefined && provider.custom_configuration.status === CustomConfigurationStatusEnum.active, | |||
| // currentCustomConfigurationModelFixedFields, | |||
| // ) | |||
| const isEditMode = !!entry | |||
| const { t } = useTranslation() | |||
| const { notify } = useToastContext() | |||
| const language = useLanguage() | |||
| const [loading, setLoading] = useState(false) | |||
| const [showConfirm, setShowConfirm] = useState(false) | |||
| const formSchemas = useMemo(() => { | |||
| return [ | |||
| { | |||
| type: FormTypeEnum.textInput, | |||
| label: { | |||
| en_US: 'Config Name', | |||
| zh_Hans: '配置名称', | |||
| }, | |||
| variable: 'name', | |||
| required: true, | |||
| show_on: [], | |||
| placeholder: { | |||
| en_US: 'Enter your Config Name here', | |||
| zh_Hans: '输入配置名称', | |||
| }, | |||
| } as CredentialFormSchemaTextInput, | |||
| ...( | |||
| providerFormSchemaPredefined | |||
| ? provider.provider_credential_schema.credential_form_schemas | |||
| : provider.model_credential_schema.credential_form_schemas | |||
| ), | |||
| ] | |||
| }, [ | |||
| providerFormSchemaPredefined, | |||
| provider.provider_credential_schema?.credential_form_schemas, | |||
| provider.model_credential_schema?.credential_form_schemas, | |||
| ]) | |||
| const [ | |||
| requiredFormSchemas, | |||
| secretFormSchemas, | |||
| defaultFormSchemaValue, | |||
| showOnVariableMap, | |||
| ] = useMemo(() => { | |||
| const requiredFormSchemas: CredentialFormSchema[] = [] | |||
| const secretFormSchemas: CredentialFormSchema[] = [] | |||
| const defaultFormSchemaValue: Record<string, string | number> = {} | |||
| const showOnVariableMap: Record<string, string[]> = {} | |||
| formSchemas.forEach((formSchema) => { | |||
| if (formSchema.required) | |||
| requiredFormSchemas.push(formSchema) | |||
| if (formSchema.type === FormTypeEnum.secretInput) | |||
| secretFormSchemas.push(formSchema) | |||
| if (formSchema.default) | |||
| defaultFormSchemaValue[formSchema.variable] = formSchema.default | |||
| if (formSchema.show_on.length) { | |||
| formSchema.show_on.forEach((showOnItem) => { | |||
| if (!showOnVariableMap[showOnItem.variable]) | |||
| showOnVariableMap[showOnItem.variable] = [] | |||
| if (!showOnVariableMap[showOnItem.variable].includes(formSchema.variable)) | |||
| showOnVariableMap[showOnItem.variable].push(formSchema.variable) | |||
| }) | |||
| } | |||
| if (formSchema.type === FormTypeEnum.select || formSchema.type === FormTypeEnum.radio) { | |||
| (formSchema as (CredentialFormSchemaRadio | CredentialFormSchemaSelect)).options.forEach((option) => { | |||
| if (option.show_on.length) { | |||
| option.show_on.forEach((showOnItem) => { | |||
| if (!showOnVariableMap[showOnItem.variable]) | |||
| showOnVariableMap[showOnItem.variable] = [] | |||
| if (!showOnVariableMap[showOnItem.variable].includes(formSchema.variable)) | |||
| showOnVariableMap[showOnItem.variable].push(formSchema.variable) | |||
| }) | |||
| } | |||
| }) | |||
| } | |||
| }) | |||
| return [ | |||
| requiredFormSchemas, | |||
| secretFormSchemas, | |||
| defaultFormSchemaValue, | |||
| showOnVariableMap, | |||
| ] | |||
| }, [formSchemas]) | |||
| const [initialValue, setInitialValue] = useState<ModelLoadBalancingConfigEntry['credentials']>() | |||
| useEffect(() => { | |||
| if (entry && !initialValue) { | |||
| setInitialValue({ | |||
| ...defaultFormSchemaValue, | |||
| ...entry.credentials, | |||
| id: entry.id, | |||
| name: entry.name, | |||
| } as Record<string, string | undefined | boolean>) | |||
| } | |||
| }, [entry, defaultFormSchemaValue, initialValue]) | |||
| const formSchemasValue = useMemo(() => ({ | |||
| ...currentCustomConfigurationModelFixedFields, | |||
| ...initialValue, | |||
| }), [currentCustomConfigurationModelFixedFields, initialValue]) | |||
| const initialFormSchemasValue: Record<string, string | number> = useMemo(() => { | |||
| return { | |||
| ...defaultFormSchemaValue, | |||
| ...formSchemasValue, | |||
| } as Record<string, string | number> | |||
| }, [formSchemasValue, defaultFormSchemaValue]) | |||
| const [value, setValue] = useState(initialFormSchemasValue) | |||
| useEffect(() => { | |||
| setValue(initialFormSchemasValue) | |||
| }, [initialFormSchemasValue]) | |||
| const [_, validating, validatedStatusState] = useValidate(value) | |||
| const filteredRequiredFormSchemas = requiredFormSchemas.filter((requiredFormSchema) => { | |||
| if (requiredFormSchema.show_on.length && requiredFormSchema.show_on.every(showOnItem => value[showOnItem.variable] === showOnItem.value)) | |||
| return true | |||
| if (!requiredFormSchema.show_on.length) | |||
| return true | |||
| return false | |||
| }) | |||
| const getSecretValues = useCallback((v: FormValue) => { | |||
| return secretFormSchemas.reduce((prev, next) => { | |||
| if (isEditMode && v[next.variable] && v[next.variable] === initialFormSchemasValue[next.variable]) | |||
| prev[next.variable] = '[__HIDDEN__]' | |||
| return prev | |||
| }, {} as Record<string, string>) | |||
| }, [initialFormSchemasValue, isEditMode, secretFormSchemas]) | |||
| // const handleValueChange = ({ __model_type, __model_name, ...v }: FormValue) => { | |||
| const handleValueChange = (v: FormValue) => { | |||
| setValue(v) | |||
| } | |||
| const handleSave = async () => { | |||
| try { | |||
| setLoading(true) | |||
| const res = await validateLoadBalancingCredentials( | |||
| providerFormSchemaPredefined, | |||
| provider.provider, | |||
| { | |||
| ...value, | |||
| ...getSecretValues(value), | |||
| }, | |||
| entry?.id, | |||
| ) | |||
| if (res.status === ValidatedStatus.Success) { | |||
| // notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) | |||
| const { __model_type, __model_name, name, ...credentials } = value | |||
| onSave({ | |||
| ...(entry || {}), | |||
| name: name as string, | |||
| credentials: credentials as Record<string, string | boolean | undefined>, | |||
| }) | |||
| // onCancel() | |||
| } | |||
| else { | |||
| notify({ type: 'error', message: res.message || '' }) | |||
| } | |||
| } | |||
| finally { | |||
| setLoading(false) | |||
| } | |||
| } | |||
| const handleRemove = () => { | |||
| onRemove?.() | |||
| } | |||
| return ( | |||
| <PortalToFollowElem open> | |||
| <PortalToFollowElemContent className='z-[60] h-full w-full'> | |||
| <div className='fixed inset-0 flex items-center justify-center bg-black/[.25]'> | |||
| <div className='mx-2 max-h-[calc(100vh-120px)] w-[640px] overflow-y-auto rounded-2xl bg-white shadow-xl'> | |||
| <div className='px-8 pt-8'> | |||
| <div className='mb-2 flex items-center justify-between'> | |||
| <div className='text-xl font-semibold text-gray-900'>{t(isEditMode ? 'common.modelProvider.editConfig' : 'common.modelProvider.addConfig')}</div> | |||
| </div> | |||
| <Form | |||
| value={value} | |||
| onChange={handleValueChange} | |||
| formSchemas={formSchemas} | |||
| validating={validating} | |||
| validatedSuccess={validatedStatusState.status === ValidatedStatus.Success} | |||
| showOnVariableMap={showOnVariableMap} | |||
| isEditMode={isEditMode} | |||
| /> | |||
| <div className='sticky bottom-0 flex flex-wrap items-center justify-between gap-y-2 bg-white py-6'> | |||
| { | |||
| (provider.help && (provider.help.title || provider.help.url)) | |||
| ? ( | |||
| <a | |||
| href={provider.help?.url[language] || provider.help?.url.en_US} | |||
| target='_blank' rel='noopener noreferrer' | |||
| className='inline-flex items-center text-xs text-primary-600' | |||
| onClick={e => !provider.help.url && e.preventDefault()} | |||
| > | |||
| {provider.help.title?.[language] || provider.help.url[language] || provider.help.title?.en_US || provider.help.url.en_US} | |||
| <LinkExternal02 className='ml-1 h-3 w-3' /> | |||
| </a> | |||
| ) | |||
| : <div /> | |||
| } | |||
| <div> | |||
| { | |||
| isEditMode && ( | |||
| <Button | |||
| size='large' | |||
| className='mr-2 text-[#D92D20]' | |||
| onClick={() => setShowConfirm(true)} | |||
| > | |||
| {t('common.operation.remove')} | |||
| </Button> | |||
| ) | |||
| } | |||
| <Button | |||
| size='large' | |||
| className='mr-2' | |||
| onClick={onCancel} | |||
| > | |||
| {t('common.operation.cancel')} | |||
| </Button> | |||
| <Button | |||
| size='large' | |||
| variant='primary' | |||
| onClick={handleSave} | |||
| disabled={loading || filteredRequiredFormSchemas.some(item => value[item.variable] === undefined)} | |||
| > | |||
| {t('common.operation.save')} | |||
| </Button> | |||
| </div> | |||
| </div> | |||
| </div> | |||
| <div className='border-t-[0.5px] border-t-black/5'> | |||
| { | |||
| (validatedStatusState.status === ValidatedStatus.Error && validatedStatusState.message) | |||
| ? ( | |||
| <div className='flex bg-[#FEF3F2] px-[10px] py-3 text-xs text-[#D92D20]'> | |||
| <RiErrorWarningFill className='mr-2 mt-[1px] h-[14px] w-[14px]' /> | |||
| {validatedStatusState.message} | |||
| </div> | |||
| ) | |||
| : ( | |||
| <div className='flex items-center justify-center bg-gray-50 py-3 text-xs text-gray-500'> | |||
| <Lock01 className='mr-1 h-3 w-3 text-gray-500' /> | |||
| {t('common.modelProvider.encrypted.front')} | |||
| <a | |||
| className='mx-1 text-primary-600' | |||
| target='_blank' rel='noopener noreferrer' | |||
| href='https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html' | |||
| > | |||
| PKCS1_OAEP | |||
| </a> | |||
| {t('common.modelProvider.encrypted.back')} | |||
| </div> | |||
| ) | |||
| } | |||
| </div> | |||
| </div> | |||
| { | |||
| showConfirm && ( | |||
| <Confirm | |||
| title={t('common.modelProvider.confirmDelete')} | |||
| isShow={showConfirm} | |||
| onCancel={() => setShowConfirm(false)} | |||
| onConfirm={handleRemove} | |||
| /> | |||
| ) | |||
| } | |||
| </div> | |||
| </PortalToFollowElemContent> | |||
| </PortalToFollowElem> | |||
| ) | |||
| } | |||
| export default memo(ModelLoadBalancingEntryModal) | |||
| @@ -1,7 +1,8 @@ | |||
| import type { FC } from 'react' | |||
| import { useMemo } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import { RiEqualizer2Line } from '@remixicon/react' | |||
| import type { ModelProvider } from '../declarations' | |||
| import type { | |||
| ModelProvider, | |||
| } from '../declarations' | |||
| import { | |||
| ConfigurationMethodEnum, | |||
| CustomConfigurationStatusEnum, | |||
| @@ -15,19 +16,19 @@ import PrioritySelector from './priority-selector' | |||
| import PriorityUseTip from './priority-use-tip' | |||
| import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './index' | |||
| import Indicator from '@/app/components/header/indicator' | |||
| import Button from '@/app/components/base/button' | |||
| import { changeModelProviderPriority } from '@/service/common' | |||
| import { useToastContext } from '@/app/components/base/toast' | |||
| import { useEventEmitterContextContext } from '@/context/event-emitter' | |||
| import cn from '@/utils/classnames' | |||
| import { useCredentialStatus } from '@/app/components/header/account-setting/model-provider-page/model-auth/hooks' | |||
| import { ConfigProvider } from '@/app/components/header/account-setting/model-provider-page/model-auth' | |||
| type CredentialPanelProps = { | |||
| provider: ModelProvider | |||
| onSetup: () => void | |||
| } | |||
| const CredentialPanel: FC<CredentialPanelProps> = ({ | |||
| const CredentialPanel = ({ | |||
| provider, | |||
| onSetup, | |||
| }) => { | |||
| }: CredentialPanelProps) => { | |||
| const { t } = useTranslation() | |||
| const { notify } = useToastContext() | |||
| const { eventEmitter } = useEventEmitterContextContext() | |||
| @@ -38,6 +39,13 @@ const CredentialPanel: FC<CredentialPanelProps> = ({ | |||
| const priorityUseType = provider.preferred_provider_type | |||
| const isCustomConfigured = customConfig.status === CustomConfigurationStatusEnum.active | |||
| const configurateMethods = provider.configurate_methods | |||
| const { | |||
| hasCredential, | |||
| authorized, | |||
| authRemoved, | |||
| current_credential_name, | |||
| notAllowedToUse, | |||
| } = useCredentialStatus(provider) | |||
| const handleChangePriority = async (key: PreferredProviderTypeEnum) => { | |||
| const res = await changeModelProviderPriority({ | |||
| @@ -61,25 +69,50 @@ const CredentialPanel: FC<CredentialPanelProps> = ({ | |||
| } as any) | |||
| } | |||
| } | |||
| const credentialLabel = useMemo(() => { | |||
| if (!hasCredential) | |||
| return t('common.modelProvider.auth.unAuthorized') | |||
| if (authorized) | |||
| return current_credential_name | |||
| if (authRemoved) | |||
| return t('common.modelProvider.auth.authRemoved') | |||
| return '' | |||
| }, [authorized, authRemoved, current_credential_name, hasCredential]) | |||
| const color = useMemo(() => { | |||
| if (authRemoved) | |||
| return 'red' | |||
| if (notAllowedToUse) | |||
| return 'gray' | |||
| return 'green' | |||
| }, [authRemoved, notAllowedToUse]) | |||
| return ( | |||
| <> | |||
| { | |||
| provider.provider_credential_schema && ( | |||
| <div className='relative ml-1 w-[112px] shrink-0 rounded-lg border-[0.5px] border-components-panel-border bg-white/[0.18] p-1'> | |||
| <div className='system-xs-medium-uppercase mb-1 flex h-5 items-center justify-between pl-2 pr-[7px] pt-1 text-text-tertiary'> | |||
| API-KEY | |||
| <Indicator color={isCustomConfigured ? 'green' : 'red'} /> | |||
| <div className={cn( | |||
| 'relative ml-1 w-[120px] shrink-0 rounded-lg border-[0.5px] border-components-panel-border bg-white/[0.18] p-1', | |||
| authRemoved && 'border-state-destructive-border bg-state-destructive-hover', | |||
| )}> | |||
| <div className='system-xs-medium mb-1 flex h-5 items-center justify-between pl-2 pr-[7px] pt-1 text-text-tertiary'> | |||
| <div | |||
| className={cn( | |||
| 'grow truncate', | |||
| authRemoved && 'text-text-destructive', | |||
| )} | |||
| title={credentialLabel} | |||
| > | |||
| {credentialLabel} | |||
| </div> | |||
| <Indicator className='shrink-0' color={color} /> | |||
| </div> | |||
| <div className='flex items-center gap-0.5'> | |||
| <Button | |||
| className='grow' | |||
| size='small' | |||
| onClick={onSetup} | |||
| > | |||
| <RiEqualizer2Line className='mr-1 h-3.5 w-3.5' /> | |||
| {t('common.operation.setup')} | |||
| </Button> | |||
| <ConfigProvider | |||
| provider={provider} | |||
| configurationMethod={ConfigurationMethodEnum.predefinedModel} | |||
| /> | |||
| { | |||
| systemConfig.enabled && isCustomConfigured && ( | |||
| <PrioritySelector | |||
| @@ -7,7 +7,6 @@ import { | |||
| RiLoader2Line, | |||
| } from '@remixicon/react' | |||
| import type { | |||
| CustomConfigurationModelFixedFields, | |||
| ModelItem, | |||
| ModelProvider, | |||
| } from '../declarations' | |||
| @@ -21,23 +20,21 @@ import ModelBadge from '../model-badge' | |||
| import CredentialPanel from './credential-panel' | |||
| import QuotaPanel from './quota-panel' | |||
| import ModelList from './model-list' | |||
| import AddModelButton from './add-model-button' | |||
| import { fetchModelProviderModelList } from '@/service/common' | |||
| import { useEventEmitterContextContext } from '@/context/event-emitter' | |||
| import { IS_CE_EDITION } from '@/config' | |||
| import { useAppContext } from '@/context/app-context' | |||
| import cn from '@/utils/classnames' | |||
| import { AddCustomModel } from '@/app/components/header/account-setting/model-provider-page/model-auth' | |||
| export const UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST = 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST' | |||
| type ProviderAddedCardProps = { | |||
| notConfigured?: boolean | |||
| provider: ModelProvider | |||
| onOpenModal: (configurationMethod: ConfigurationMethodEnum, currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => void | |||
| } | |||
| const ProviderAddedCard: FC<ProviderAddedCardProps> = ({ | |||
| notConfigured, | |||
| provider, | |||
| onOpenModal, | |||
| }) => { | |||
| const { t } = useTranslation() | |||
| const { eventEmitter } = useEventEmitterContextContext() | |||
| @@ -114,7 +111,6 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({ | |||
| { | |||
| showCredential && ( | |||
| <CredentialPanel | |||
| onSetup={() => onOpenModal(ConfigurationMethodEnum.predefinedModel)} | |||
| provider={provider} | |||
| /> | |||
| ) | |||
| @@ -159,9 +155,9 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({ | |||
| )} | |||
| { | |||
| configurationMethods.includes(ConfigurationMethodEnum.customizableModel) && isCurrentWorkspaceManager && ( | |||
| <AddModelButton | |||
| onClick={() => onOpenModal(ConfigurationMethodEnum.customizableModel)} | |||
| className='flex' | |||
| <AddCustomModel | |||
| provider={provider} | |||
| configurationMethod={ConfigurationMethodEnum.customizableModel} | |||
| /> | |||
| ) | |||
| } | |||
| @@ -174,7 +170,6 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({ | |||
| provider={provider} | |||
| models={modelList} | |||
| onCollapse={() => setCollapsed(true)} | |||
| onConfig={currentCustomConfigurationModelFixedFields => onOpenModal(ConfigurationMethodEnum.customizableModel, currentCustomConfigurationModelFixedFields)} | |||
| onChange={(provider: string) => getModelList(provider)} | |||
| /> | |||
| ) | |||
| @@ -1,31 +1,29 @@ | |||
| import { memo, useCallback } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import { useDebounceFn } from 'ahooks' | |||
| import type { CustomConfigurationModelFixedFields, ModelItem, ModelProvider } from '../declarations' | |||
| import { ConfigurationMethodEnum, ModelStatusEnum } from '../declarations' | |||
| import ModelBadge from '../model-badge' | |||
| import type { ModelItem, ModelProvider } from '../declarations' | |||
| import { ModelStatusEnum } from '../declarations' | |||
| import ModelIcon from '../model-icon' | |||
| import ModelName from '../model-name' | |||
| import classNames from '@/utils/classnames' | |||
| import Button from '@/app/components/base/button' | |||
| import { Balance } from '@/app/components/base/icons/src/vender/line/financeAndECommerce' | |||
| import { Settings01 } from '@/app/components/base/icons/src/vender/line/general' | |||
| import Switch from '@/app/components/base/switch' | |||
| import Tooltip from '@/app/components/base/tooltip' | |||
| import { useProviderContext, useProviderContextSelector } from '@/context/provider-context' | |||
| import { disableModel, enableModel } from '@/service/common' | |||
| import { Plan } from '@/app/components/billing/type' | |||
| import { useAppContext } from '@/context/app-context' | |||
| import { ConfigModel } from '../model-auth' | |||
| import Badge from '@/app/components/base/badge' | |||
| export type ModelListItemProps = { | |||
| model: ModelItem | |||
| provider: ModelProvider | |||
| isConfigurable: boolean | |||
| onConfig: (currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => void | |||
| onModifyLoadBalancing?: (model: ModelItem) => void | |||
| } | |||
| const ModelListItem = ({ model, provider, isConfigurable, onConfig, onModifyLoadBalancing }: ModelListItemProps) => { | |||
| const ModelListItem = ({ model, provider, isConfigurable, onModifyLoadBalancing }: ModelListItemProps) => { | |||
| const { t } = useTranslation() | |||
| const { plan } = useProviderContext() | |||
| const modelLoadBalancingEnabled = useProviderContextSelector(state => state.modelLoadBalancingEnabled) | |||
| @@ -46,7 +44,7 @@ const ModelListItem = ({ model, provider, isConfigurable, onConfig, onModifyLoad | |||
| return ( | |||
| <div | |||
| key={model.model} | |||
| key={`${model.model}-${model.fetch_from}`} | |||
| className={classNames( | |||
| 'group flex h-8 items-center rounded-lg pl-2 pr-2.5', | |||
| isConfigurable && 'hover:bg-components-panel-on-panel-item-bg-hover', | |||
| @@ -65,38 +63,22 @@ const ModelListItem = ({ model, provider, isConfigurable, onConfig, onModifyLoad | |||
| showMode | |||
| showContextSize | |||
| > | |||
| {modelLoadBalancingEnabled && !model.deprecated && model.load_balancing_enabled && ( | |||
| <ModelBadge className='ml-1 border-text-accent-secondary uppercase text-text-accent-secondary'> | |||
| <Balance className='mr-0.5 h-3 w-3' /> | |||
| {t('common.modelProvider.loadBalancingHeadline')} | |||
| </ModelBadge> | |||
| )} | |||
| </ModelName> | |||
| <div className='flex shrink-0 items-center'> | |||
| {modelLoadBalancingEnabled && !model.deprecated && model.load_balancing_enabled && !model.has_invalid_load_balancing_configs && ( | |||
| <Badge className='mr-1 h-[18px] w-[18px] items-center justify-center border-text-accent-secondary p-0'> | |||
| <Balance className='h-3 w-3 text-text-accent-secondary' /> | |||
| </Badge> | |||
| )} | |||
| { | |||
| model.fetch_from === ConfigurationMethodEnum.customizableModel | |||
| ? (isCurrentWorkspaceManager && ( | |||
| <Button | |||
| size='small' | |||
| className='hidden group-hover:flex' | |||
| onClick={() => onConfig({ __model_name: model.model, __model_type: model.model_type })} | |||
| > | |||
| <Settings01 className='mr-1 h-3.5 w-3.5' /> | |||
| {t('common.modelProvider.config')} | |||
| </Button> | |||
| )) | |||
| : (isCurrentWorkspaceManager && (modelLoadBalancingEnabled || plan.type === Plan.sandbox) && !model.deprecated && [ModelStatusEnum.active, ModelStatusEnum.disabled].includes(model.status)) | |||
| ? ( | |||
| <Button | |||
| size='small' | |||
| className='opacity-0 transition-opacity group-hover:opacity-100' | |||
| onClick={() => onModifyLoadBalancing?.(model)} | |||
| > | |||
| <Balance className='mr-1 h-3.5 w-3.5' /> | |||
| {t('common.modelProvider.configLoadBalancing')} | |||
| </Button> | |||
| ) | |||
| : null | |||
| (isCurrentWorkspaceManager && (modelLoadBalancingEnabled || plan.type === Plan.sandbox) && !model.deprecated && [ModelStatusEnum.active, ModelStatusEnum.disabled].includes(model.status)) && ( | |||
| <ConfigModel | |||
| onClick={() => onModifyLoadBalancing?.(model)} | |||
| loadBalancingEnabled={model.load_balancing_enabled} | |||
| loadBalancingInvalid={model.has_invalid_load_balancing_configs} | |||
| credentialRemoved={model.status === ModelStatusEnum.credentialRemoved} | |||
| /> | |||
| ) | |||
| } | |||
| { | |||
| model.deprecated | |||
| @@ -5,7 +5,7 @@ import { | |||
| RiArrowRightSLine, | |||
| } from '@remixicon/react' | |||
| import type { | |||
| CustomConfigurationModelFixedFields, | |||
| Credential, | |||
| ModelItem, | |||
| ModelProvider, | |||
| } from '../declarations' | |||
| @@ -13,34 +13,33 @@ import { | |||
| ConfigurationMethodEnum, | |||
| } from '../declarations' | |||
| // import Tab from './tab' | |||
| import AddModelButton from './add-model-button' | |||
| import ModelListItem from './model-list-item' | |||
| import { useModalContextSelector } from '@/context/modal-context' | |||
| import { useAppContext } from '@/context/app-context' | |||
| import { AddCustomModel } from '@/app/components/header/account-setting/model-provider-page/model-auth' | |||
| type ModelListProps = { | |||
| provider: ModelProvider | |||
| models: ModelItem[] | |||
| onCollapse: () => void | |||
| onConfig: (currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => void | |||
| onChange?: (provider: string) => void | |||
| } | |||
| const ModelList: FC<ModelListProps> = ({ | |||
| provider, | |||
| models, | |||
| onCollapse, | |||
| onConfig, | |||
| onChange, | |||
| }) => { | |||
| const { t } = useTranslation() | |||
| const configurativeMethods = provider.configurate_methods.filter(method => method !== ConfigurationMethodEnum.fetchFromRemote) | |||
| const { isCurrentWorkspaceManager } = useAppContext() | |||
| const isConfigurable = configurativeMethods.includes(ConfigurationMethodEnum.customizableModel) | |||
| const setShowModelLoadBalancingModal = useModalContextSelector(state => state.setShowModelLoadBalancingModal) | |||
| const onModifyLoadBalancing = useCallback((model: ModelItem) => { | |||
| const onModifyLoadBalancing = useCallback((model: ModelItem, credential?: Credential) => { | |||
| setShowModelLoadBalancingModal({ | |||
| provider, | |||
| credential, | |||
| configurateMethod: model.fetch_from, | |||
| model: model!, | |||
| open: !!model, | |||
| onClose: () => setShowModelLoadBalancingModal(null), | |||
| @@ -65,17 +64,14 @@ const ModelList: FC<ModelListProps> = ({ | |||
| <RiArrowRightSLine className='mr-0.5 h-4 w-4 rotate-90' /> | |||
| </span> | |||
| </span> | |||
| {/* { | |||
| isConfigurable && canSystemConfig && ( | |||
| <span className='flex items-center'> | |||
| <Tab active='all' onSelect={() => {}} /> | |||
| </span> | |||
| ) | |||
| } */} | |||
| { | |||
| isConfigurable && isCurrentWorkspaceManager && ( | |||
| <div className='flex grow justify-end'> | |||
| <AddModelButton onClick={() => onConfig()} /> | |||
| <AddCustomModel | |||
| provider={provider} | |||
| configurationMethod={ConfigurationMethodEnum.customizableModel} | |||
| currentCustomConfigurationModelFixedFields={undefined} | |||
| /> | |||
| </div> | |||
| ) | |||
| } | |||
| @@ -83,12 +79,11 @@ const ModelList: FC<ModelListProps> = ({ | |||
| { | |||
| models.map(model => ( | |||
| <ModelListItem | |||
| key={model.model} | |||
| key={`${model.model}-${model.fetch_from}`} | |||
| {...{ | |||
| model, | |||
| provider, | |||
| isConfigurable, | |||
| onConfig, | |||
| onModifyLoadBalancing, | |||
| }} | |||
| /> | |||
| @@ -1,24 +1,35 @@ | |||
| import type { Dispatch, SetStateAction } from 'react' | |||
| import { useCallback } from 'react' | |||
| import { useCallback, useMemo } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import { | |||
| RiDeleteBinLine, | |||
| RiEqualizer2Line, | |||
| } from '@remixicon/react' | |||
| import type { ConfigurationMethodEnum, CustomConfigurationModelFixedFields, ModelLoadBalancingConfig, ModelLoadBalancingConfigEntry, ModelProvider } from '../declarations' | |||
| import type { | |||
| Credential, | |||
| CustomConfigurationModelFixedFields, | |||
| CustomModelCredential, | |||
| ModelCredential, | |||
| ModelLoadBalancingConfig, | |||
| ModelLoadBalancingConfigEntry, | |||
| ModelProvider, | |||
| } from '../declarations' | |||
| import { ConfigurationMethodEnum } from '../declarations' | |||
| import Indicator from '../../../indicator' | |||
| import CooldownTimer from './cooldown-timer' | |||
| import classNames from '@/utils/classnames' | |||
| import Tooltip from '@/app/components/base/tooltip' | |||
| import Switch from '@/app/components/base/switch' | |||
| import { Balance } from '@/app/components/base/icons/src/vender/line/financeAndECommerce' | |||
| import { Edit02, Plus02 } from '@/app/components/base/icons/src/vender/line/general' | |||
| import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' | |||
| import { useModalContextSelector } from '@/context/modal-context' | |||
| import UpgradeBtn from '@/app/components/billing/upgrade-btn' | |||
| import s from '@/app/components/custom/style.module.css' | |||
| import GridMask from '@/app/components/base/grid-mask' | |||
| import { useProviderContextSelector } from '@/context/provider-context' | |||
| import { IS_CE_EDITION } from '@/config' | |||
| import { AddCredentialInLoadBalancing } from '@/app/components/header/account-setting/model-provider-page/model-auth' | |||
| import { useModelModalHandler } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||
| import Badge from '@/app/components/base/badge/index' | |||
| export type ModelLoadBalancingConfigsProps = { | |||
| draftConfig?: ModelLoadBalancingConfig | |||
| @@ -28,19 +39,27 @@ export type ModelLoadBalancingConfigsProps = { | |||
| currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields | |||
| withSwitch?: boolean | |||
| className?: string | |||
| modelCredential: ModelCredential | |||
| onUpdate?: () => void | |||
| model: CustomModelCredential | |||
| } | |||
| const ModelLoadBalancingConfigs = ({ | |||
| draftConfig, | |||
| setDraftConfig, | |||
| provider, | |||
| model, | |||
| configurationMethod, | |||
| currentCustomConfigurationModelFixedFields, | |||
| withSwitch = false, | |||
| className, | |||
| modelCredential, | |||
| onUpdate, | |||
| }: ModelLoadBalancingConfigsProps) => { | |||
| const { t } = useTranslation() | |||
| const providerFormSchemaPredefined = configurationMethod === ConfigurationMethodEnum.predefinedModel | |||
| const modelLoadBalancingEnabled = useProviderContextSelector(state => state.modelLoadBalancingEnabled) | |||
| const handleOpenModal = useModelModalHandler() | |||
| const updateConfigEntry = useCallback( | |||
| ( | |||
| @@ -65,6 +84,21 @@ const ModelLoadBalancingConfigs = ({ | |||
| [setDraftConfig], | |||
| ) | |||
| const addConfigEntry = useCallback((credential: Credential) => { | |||
| setDraftConfig((prev: any) => { | |||
| if (!prev) | |||
| return prev | |||
| return { | |||
| ...prev, | |||
| configs: [...prev.configs, { | |||
| credential_id: credential.credential_id, | |||
| enabled: true, | |||
| name: credential.credential_name, | |||
| }], | |||
| } | |||
| }) | |||
| }, [setDraftConfig]) | |||
| const toggleModalBalancing = useCallback((enabled: boolean) => { | |||
| if ((modelLoadBalancingEnabled || !enabled) && draftConfig) { | |||
| setDraftConfig({ | |||
| @@ -81,54 +115,6 @@ const ModelLoadBalancingConfigs = ({ | |||
| })) | |||
| }, [updateConfigEntry]) | |||
| const setShowModelLoadBalancingEntryModal = useModalContextSelector(state => state.setShowModelLoadBalancingEntryModal) | |||
| const toggleEntryModal = useCallback((index?: number, entry?: ModelLoadBalancingConfigEntry) => { | |||
| setShowModelLoadBalancingEntryModal({ | |||
| payload: { | |||
| currentProvider: provider, | |||
| currentConfigurationMethod: configurationMethod, | |||
| currentCustomConfigurationModelFixedFields, | |||
| entry, | |||
| index, | |||
| }, | |||
| onSaveCallback: ({ entry: result }) => { | |||
| if (entry) { | |||
| // edit | |||
| setDraftConfig(prev => ({ | |||
| ...prev, | |||
| enabled: !!prev?.enabled, | |||
| configs: prev?.configs.map((config, i) => i === index ? result! : config) || [], | |||
| })) | |||
| } | |||
| else { | |||
| // add | |||
| setDraftConfig(prev => ({ | |||
| ...prev, | |||
| enabled: !!prev?.enabled, | |||
| configs: (prev?.configs || []).concat([{ ...result!, enabled: true }]), | |||
| })) | |||
| } | |||
| }, | |||
| onRemoveCallback: ({ index }) => { | |||
| if (index !== undefined && (draftConfig?.configs?.length ?? 0) > index) { | |||
| setDraftConfig(prev => ({ | |||
| ...prev, | |||
| enabled: !!prev?.enabled, | |||
| configs: prev?.configs.filter((_, i) => i !== index) || [], | |||
| })) | |||
| } | |||
| }, | |||
| }) | |||
| }, [ | |||
| configurationMethod, | |||
| currentCustomConfigurationModelFixedFields, | |||
| draftConfig?.configs?.length, | |||
| provider, | |||
| setDraftConfig, | |||
| setShowModelLoadBalancingEntryModal, | |||
| ]) | |||
| const clearCountdown = useCallback((index: number) => { | |||
| updateConfigEntry(index, ({ ttl: _, ...entry }) => { | |||
| return { | |||
| @@ -138,6 +124,12 @@ const ModelLoadBalancingConfigs = ({ | |||
| }) | |||
| }, [updateConfigEntry]) | |||
| const validDraftConfigList = useMemo(() => { | |||
| if (!draftConfig) | |||
| return [] | |||
| return draftConfig.configs | |||
| }, [draftConfig]) | |||
| if (!draftConfig) | |||
| return null | |||
| @@ -181,8 +173,9 @@ const ModelLoadBalancingConfigs = ({ | |||
| </div> | |||
| {draftConfig.enabled && ( | |||
| <div className='flex flex-col gap-1 px-3 pb-3'> | |||
| {draftConfig.configs.map((config, index) => { | |||
| {validDraftConfigList.map((config, index) => { | |||
| const isProviderManaged = config.name === '__inherit__' | |||
| const credential = modelCredential.available_credentials.find(c => c.credential_id === config.credential_id) | |||
| return ( | |||
| <div key={config.id || index} className='group flex h-10 items-center rounded-lg border border-components-panel-border bg-components-panel-on-panel-item-bg px-3 shadow-xs'> | |||
| <div className='flex grow items-center'> | |||
| @@ -200,54 +193,81 @@ const ModelLoadBalancingConfigs = ({ | |||
| <div className='mr-1 text-[13px]'> | |||
| {isProviderManaged ? t('common.modelProvider.defaultConfig') : config.name} | |||
| </div> | |||
| {isProviderManaged && ( | |||
| <span className='rounded-[5px] border border-divider-regular px-1 text-2xs uppercase text-text-tertiary'>{t('common.modelProvider.providerManaged')}</span> | |||
| {isProviderManaged && providerFormSchemaPredefined && ( | |||
| <Badge className='ml-2'>{t('common.modelProvider.providerManaged')}</Badge> | |||
| )} | |||
| { | |||
| credential?.from_enterprise && ( | |||
| <Badge className='ml-2'>Enterprise</Badge> | |||
| ) | |||
| } | |||
| </div> | |||
| <div className='flex items-center gap-1'> | |||
| {!isProviderManaged && ( | |||
| <> | |||
| <div className='flex items-center gap-1 opacity-0 transition-opacity group-hover:opacity-100'> | |||
| <span | |||
| className='flex h-8 w-8 cursor-pointer items-center justify-center rounded-lg bg-components-button-secondary-bg text-text-tertiary transition-colors hover:bg-components-button-secondary-bg-hover' | |||
| onClick={() => toggleEntryModal(index, config)} | |||
| > | |||
| <Edit02 className='h-4 w-4' /> | |||
| </span> | |||
| { | |||
| config.credential_id && !credential?.not_allowed_to_use && !credential?.from_enterprise && ( | |||
| <span | |||
| className='flex h-8 w-8 cursor-pointer items-center justify-center rounded-lg bg-components-button-secondary-bg text-text-tertiary transition-colors hover:bg-components-button-secondary-bg-hover' | |||
| onClick={() => { | |||
| handleOpenModal( | |||
| provider, | |||
| configurationMethod, | |||
| currentCustomConfigurationModelFixedFields, | |||
| configurationMethod === ConfigurationMethodEnum.customizableModel, | |||
| (config.credential_id && config.name) ? { | |||
| credential_id: config.credential_id, | |||
| credential_name: config.name, | |||
| } : undefined, | |||
| model, | |||
| ) | |||
| }} | |||
| > | |||
| <RiEqualizer2Line className='h-4 w-4' /> | |||
| </span> | |||
| ) | |||
| } | |||
| <span | |||
| className='flex h-8 w-8 cursor-pointer items-center justify-center rounded-lg bg-components-button-secondary-bg text-text-tertiary transition-colors hover:bg-components-button-secondary-bg-hover' | |||
| onClick={() => updateConfigEntry(index, () => undefined)} | |||
| > | |||
| <RiDeleteBinLine className='h-4 w-4' /> | |||
| </span> | |||
| <span className='mr-2 h-3 border-r border-r-divider-subtle' /> | |||
| </div> | |||
| </> | |||
| )} | |||
| <Switch | |||
| defaultValue={Boolean(config.enabled)} | |||
| size='md' | |||
| className='justify-self-end' | |||
| onChange={value => toggleConfigEntryEnabled(index, value)} | |||
| /> | |||
| { | |||
| (config.credential_id || config.name === '__inherit__') && ( | |||
| <> | |||
| <span className='mr-2 h-3 border-r border-r-divider-subtle' /> | |||
| <Switch | |||
| defaultValue={Boolean(config.enabled)} | |||
| size='md' | |||
| className='justify-self-end' | |||
| onChange={value => toggleConfigEntryEnabled(index, value)} | |||
| disabled={credential?.not_allowed_to_use} | |||
| /> | |||
| </> | |||
| ) | |||
| } | |||
| </div> | |||
| </div> | |||
| ) | |||
| })} | |||
| <div | |||
| className='mt-1 flex h-8 items-center px-3 text-[13px] font-medium text-primary-600' | |||
| onClick={() => toggleEntryModal()} | |||
| > | |||
| <div className='flex cursor-pointer items-center'> | |||
| <Plus02 className='mr-2 h-3 w-3' />{t('common.modelProvider.addConfig')} | |||
| </div> | |||
| </div> | |||
| <AddCredentialInLoadBalancing | |||
| provider={provider} | |||
| model={model} | |||
| configurationMethod={configurationMethod} | |||
| modelCredential={modelCredential} | |||
| onSelectCredential={addConfigEntry} | |||
| onUpdate={onUpdate} | |||
| /> | |||
| </div> | |||
| )} | |||
| { | |||
| draftConfig.enabled && draftConfig.configs.length < 2 && ( | |||
| <div className='flex h-[34px] items-center border-t border-t-divider-subtle bg-components-panel-bg px-6 text-xs text-text-secondary'> | |||
| draftConfig.enabled && validDraftConfigList.length < 2 && ( | |||
| <div className='flex h-[34px] items-center rounded-b-xl border-t border-t-divider-subtle bg-components-panel-bg px-6 text-xs text-text-secondary'> | |||
| <AlertTriangle className='mr-1 h-3 w-3 text-[#f79009]' /> | |||
| {t('common.modelProvider.loadBalancingLeastKeyWarning')} | |||
| </div> | |||
| @@ -1,40 +1,69 @@ | |||
| import { memo, useCallback, useEffect, useMemo, useState } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import useSWR from 'swr' | |||
| import type { ModelItem, ModelLoadBalancingConfig, ModelLoadBalancingConfigEntry, ModelProvider } from '../declarations' | |||
| import { FormTypeEnum } from '../declarations' | |||
| import type { | |||
| Credential, | |||
| ModelItem, | |||
| ModelLoadBalancingConfig, | |||
| ModelLoadBalancingConfigEntry, | |||
| ModelProvider, | |||
| } from '../declarations' | |||
| import { | |||
| ConfigurationMethodEnum, | |||
| FormTypeEnum, | |||
| } from '../declarations' | |||
| import ModelIcon from '../model-icon' | |||
| import ModelName from '../model-name' | |||
| import { savePredefinedLoadBalancingConfig } from '../utils' | |||
| import ModelLoadBalancingConfigs from './model-load-balancing-configs' | |||
| import classNames from '@/utils/classnames' | |||
| import Modal from '@/app/components/base/modal' | |||
| import Button from '@/app/components/base/button' | |||
| import { fetchModelLoadBalancingConfig } from '@/service/common' | |||
| import Loading from '@/app/components/base/loading' | |||
| import { useToastContext } from '@/app/components/base/toast' | |||
| import { SwitchCredentialInLoadBalancing } from '@/app/components/header/account-setting/model-provider-page/model-auth' | |||
| import { | |||
| useGetModelCredential, | |||
| useUpdateModelLoadBalancingConfig, | |||
| } from '@/service/use-models' | |||
| export type ModelLoadBalancingModalProps = { | |||
| provider: ModelProvider | |||
| configurateMethod: ConfigurationMethodEnum | |||
| model: ModelItem | |||
| credential?: Credential | |||
| open?: boolean | |||
| onClose?: () => void | |||
| onSave?: (provider: string) => void | |||
| } | |||
| // model balancing config modal | |||
| const ModelLoadBalancingModal = ({ provider, model, open = false, onClose, onSave }: ModelLoadBalancingModalProps) => { | |||
| const ModelLoadBalancingModal = ({ | |||
| provider, | |||
| configurateMethod, | |||
| model, | |||
| credential, | |||
| open = false, | |||
| onClose, | |||
| onSave, | |||
| }: ModelLoadBalancingModalProps) => { | |||
| const { t } = useTranslation() | |||
| const { notify } = useToastContext() | |||
| const [loading, setLoading] = useState(false) | |||
| const { data, mutate } = useSWR( | |||
| `/workspaces/current/model-providers/${provider.provider}/models/credentials?model=${model.model}&model_type=${model.model_type}`, | |||
| fetchModelLoadBalancingConfig, | |||
| ) | |||
| const originalConfig = data?.load_balancing | |||
| const providerFormSchemaPredefined = configurateMethod === ConfigurationMethodEnum.predefinedModel | |||
| const configFrom = providerFormSchemaPredefined ? 'predefined-model' : 'custom-model' | |||
| const { | |||
| isLoading, | |||
| data, | |||
| refetch, | |||
| } = useGetModelCredential(true, provider.provider, credential?.credential_id, model.model, model.model_type, configFrom) | |||
| const modelCredential = data | |||
| const { | |||
| load_balancing, | |||
| current_credential_id, | |||
| available_credentials, | |||
| current_credential_name, | |||
| } = modelCredential ?? {} | |||
| const originalConfig = load_balancing | |||
| const [draftConfig, setDraftConfig] = useState<ModelLoadBalancingConfig>() | |||
| const originalConfigMap = useMemo(() => { | |||
| if (!originalConfig) | |||
| @@ -60,10 +89,17 @@ const ModelLoadBalancingModal = ({ provider, model, open = false, onClose, onSav | |||
| }, [draftConfig]) | |||
| const extendedSecretFormSchemas = useMemo( | |||
| () => provider.provider_credential_schema.credential_form_schemas.filter( | |||
| ({ type }) => type === FormTypeEnum.secretInput, | |||
| ), | |||
| [provider.provider_credential_schema.credential_form_schemas], | |||
| () => { | |||
| if (providerFormSchemaPredefined) { | |||
| return provider?.provider_credential_schema?.credential_form_schemas?.filter( | |||
| ({ type }) => type === FormTypeEnum.secretInput, | |||
| ) ?? [] | |||
| } | |||
| return provider?.model_credential_schema?.credential_form_schemas?.filter( | |||
| ({ type }) => type === FormTypeEnum.secretInput, | |||
| ) ?? [] | |||
| }, | |||
| [provider?.model_credential_schema?.credential_form_schemas, provider?.provider_credential_schema?.credential_form_schemas, providerFormSchemaPredefined], | |||
| ) | |||
| const encodeConfigEntrySecretValues = useCallback((entry: ModelLoadBalancingConfigEntry) => { | |||
| @@ -75,25 +111,34 @@ const ModelLoadBalancingModal = ({ provider, model, open = false, onClose, onSav | |||
| return result | |||
| }, [extendedSecretFormSchemas, originalConfigMap]) | |||
| const { mutateAsync: updateModelLoadBalancingConfig } = useUpdateModelLoadBalancingConfig(provider.provider) | |||
| const initialCustomModelCredential = useMemo(() => { | |||
| if (!current_credential_id) | |||
| return undefined | |||
| return { | |||
| credential_id: current_credential_id, | |||
| credential_name: current_credential_name, | |||
| } | |||
| }, [current_credential_id, current_credential_name]) | |||
| const [customModelCredential, setCustomModelCredential] = useState<Credential | undefined>(initialCustomModelCredential) | |||
| const handleSave = async () => { | |||
| try { | |||
| setLoading(true) | |||
| const res = await savePredefinedLoadBalancingConfig( | |||
| provider.provider, | |||
| ({ | |||
| ...(data?.credentials ?? {}), | |||
| __model_type: model.model_type, | |||
| __model_name: model.model, | |||
| }), | |||
| const res = await updateModelLoadBalancingConfig( | |||
| { | |||
| ...draftConfig, | |||
| enabled: Boolean(draftConfig?.enabled), | |||
| configs: draftConfig!.configs.map(encodeConfigEntrySecretValues), | |||
| credential_id: customModelCredential?.credential_id || current_credential_id, | |||
| config_from: configFrom, | |||
| model: model.model, | |||
| model_type: model.model_type, | |||
| load_balancing: { | |||
| ...draftConfig, | |||
| configs: draftConfig!.configs.map(encodeConfigEntrySecretValues), | |||
| enabled: Boolean(draftConfig?.enabled), | |||
| }, | |||
| }, | |||
| ) | |||
| if (res.result === 'success') { | |||
| notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) | |||
| mutate() | |||
| onSave?.(provider.provider) | |||
| onClose?.() | |||
| } | |||
| @@ -110,7 +155,11 @@ const ModelLoadBalancingModal = ({ provider, model, open = false, onClose, onSav | |||
| className='w-[640px] max-w-none px-8 pt-8' | |||
| title={ | |||
| <div className='pb-3 font-semibold'> | |||
| <div className='h-[30px]'>{t('common.modelProvider.configLoadBalancing')}</div> | |||
| <div className='h-[30px]'>{ | |||
| draftConfig?.enabled | |||
| ? t('common.modelProvider.auth.configLoadBalancing') | |||
| : t('common.modelProvider.auth.configModel') | |||
| }</div> | |||
| {Boolean(model) && ( | |||
| <div className='flex h-5 items-center'> | |||
| <ModelIcon | |||
| @@ -149,23 +198,51 @@ const ModelLoadBalancingModal = ({ provider, model, open = false, onClose, onSav | |||
| )} | |||
| </div> | |||
| <div className='grow'> | |||
| <div className='text-sm text-text-secondary'>{t('common.modelProvider.providerManaged')}</div> | |||
| <div className='text-xs text-text-tertiary'>{t('common.modelProvider.providerManagedDescription')}</div> | |||
| <div className='text-sm text-text-secondary'>{ | |||
| providerFormSchemaPredefined | |||
| ? t('common.modelProvider.auth.providerManaged') | |||
| : t('common.modelProvider.auth.specifyModelCredential') | |||
| }</div> | |||
| <div className='text-xs text-text-tertiary'>{ | |||
| providerFormSchemaPredefined | |||
| ? t('common.modelProvider.auth.providerManagedTip') | |||
| : t('common.modelProvider.auth.specifyModelCredentialTip') | |||
| }</div> | |||
| </div> | |||
| { | |||
| !providerFormSchemaPredefined && ( | |||
| <SwitchCredentialInLoadBalancing | |||
| provider={provider} | |||
| customModelCredential={initialCustomModelCredential ?? customModelCredential} | |||
| setCustomModelCredential={setCustomModelCredential} | |||
| model={model} | |||
| credentials={available_credentials} | |||
| /> | |||
| ) | |||
| } | |||
| </div> | |||
| </div> | |||
| <ModelLoadBalancingConfigs {...{ | |||
| draftConfig, | |||
| setDraftConfig, | |||
| provider, | |||
| currentCustomConfigurationModelFixedFields: { | |||
| __model_name: model.model, | |||
| __model_type: model.model_type, | |||
| }, | |||
| configurationMethod: model.fetch_from, | |||
| className: 'mt-2', | |||
| }} /> | |||
| { | |||
| modelCredential && ( | |||
| <ModelLoadBalancingConfigs {...{ | |||
| draftConfig, | |||
| setDraftConfig, | |||
| provider, | |||
| currentCustomConfigurationModelFixedFields: { | |||
| __model_name: model.model, | |||
| __model_type: model.model_type, | |||
| }, | |||
| configurationMethod: model.fetch_from, | |||
| className: 'mt-2', | |||
| modelCredential, | |||
| onUpdate: refetch, | |||
| model: { | |||
| model: model.model, | |||
| model_type: model.model_type, | |||
| }, | |||
| }} /> | |||
| ) | |||
| } | |||
| </div> | |||
| <div className='mt-6 flex items-center justify-end gap-2'> | |||
| @@ -176,6 +253,7 @@ const ModelLoadBalancingModal = ({ provider, model, open = false, onClose, onSav | |||
| disabled={ | |||
| loading | |||
| || (draftConfig?.enabled && (draftConfig?.configs.filter(config => config.enabled).length ?? 0) < 2) | |||
| || isLoading | |||
| } | |||
| >{t('common.operation.save')}</Button> | |||
| </div> | |||
| @@ -1,6 +1,5 @@ | |||
| import { ValidatedStatus } from '../key-validator/declarations' | |||
| import type { | |||
| CredentialFormSchemaRadio, | |||
| CredentialFormSchemaTextInput, | |||
| FormValue, | |||
| ModelLoadBalancingConfig, | |||
| @@ -82,12 +81,14 @@ export const saveCredentials = async (predefined: boolean, provider: string, v: | |||
| let body, url | |||
| if (predefined) { | |||
| const { __authorization_name__, ...rest } = v | |||
| body = { | |||
| config_from: ConfigurationMethodEnum.predefinedModel, | |||
| credentials: v, | |||
| credentials: rest, | |||
| load_balancing: loadBalancing, | |||
| name: __authorization_name__, | |||
| } | |||
| url = `/workspaces/current/model-providers/${provider}` | |||
| url = `/workspaces/current/model-providers/${provider}/credentials` | |||
| } | |||
| else { | |||
| const { __model_name, __model_type, ...credentials } = v | |||
| @@ -117,12 +118,17 @@ export const savePredefinedLoadBalancingConfig = async (provider: string, v: For | |||
| return setModelProvider({ url, body }) | |||
| } | |||
| export const removeCredentials = async (predefined: boolean, provider: string, v: FormValue) => { | |||
| export const removeCredentials = async (predefined: boolean, provider: string, v: FormValue, credentialId?: string) => { | |||
| let url = '' | |||
| let body | |||
| if (predefined) { | |||
| url = `/workspaces/current/model-providers/${provider}` | |||
| url = `/workspaces/current/model-providers/${provider}/credentials` | |||
| if (credentialId) { | |||
| body = { | |||
| credential_id: credentialId, | |||
| } | |||
| } | |||
| } | |||
| else { | |||
| if (v) { | |||
| @@ -174,7 +180,7 @@ export const genModelTypeFormSchema = (modelTypes: ModelTypeEnum[]) => { | |||
| show_on: [], | |||
| } | |||
| }), | |||
| } as CredentialFormSchemaRadio | |||
| } as any | |||
| } | |||
| export const genModelNameFormSchema = (model?: Pick<CredentialFormSchemaTextInput, 'label' | 'placeholder'>) => { | |||
| @@ -191,5 +197,5 @@ export const genModelNameFormSchema = (model?: Pick<CredentialFormSchemaTextInpu | |||
| zh_Hans: '请输入模型名称', | |||
| en_US: 'Please enter model name', | |||
| }, | |||
| } as CredentialFormSchemaTextInput | |||
| } as any | |||
| } | |||
| @@ -8,6 +8,8 @@ import type { AddOAuthButtonProps } from './add-oauth-button' | |||
| import AddApiKeyButton from './add-api-key-button' | |||
| import type { AddApiKeyButtonProps } from './add-api-key-button' | |||
| import type { PluginPayload } from '../types' | |||
| import cn from '@/utils/classnames' | |||
| import Tooltip from '@/app/components/base/tooltip' | |||
| type AuthorizeProps = { | |||
| pluginPayload: PluginPayload | |||
| @@ -17,6 +19,7 @@ type AuthorizeProps = { | |||
| canApiKey?: boolean | |||
| disabled?: boolean | |||
| onUpdate?: () => void | |||
| notAllowCustomCredential?: boolean | |||
| } | |||
| const Authorize = ({ | |||
| pluginPayload, | |||
| @@ -26,6 +29,7 @@ const Authorize = ({ | |||
| canApiKey, | |||
| disabled, | |||
| onUpdate, | |||
| notAllowCustomCredential, | |||
| }: AuthorizeProps) => { | |||
| const { t } = useTranslation() | |||
| const oAuthButtonProps: AddOAuthButtonProps = useMemo(() => { | |||
| @@ -62,18 +66,54 @@ const Authorize = ({ | |||
| } | |||
| }, [canOAuth, theme, pluginPayload, t]) | |||
| const OAuthButton = useMemo(() => { | |||
| const Item = ( | |||
| <div className={cn('min-w-0 flex-[1]', notAllowCustomCredential && 'opacity-50')}> | |||
| <AddOAuthButton | |||
| {...oAuthButtonProps} | |||
| disabled={disabled || notAllowCustomCredential} | |||
| onUpdate={onUpdate} | |||
| /> | |||
| </div> | |||
| ) | |||
| if (notAllowCustomCredential) { | |||
| return ( | |||
| <Tooltip popupContent={t('plugin.auth.credentialUnavailable')}> | |||
| {Item} | |||
| </Tooltip> | |||
| ) | |||
| } | |||
| return Item | |||
| }, [notAllowCustomCredential, oAuthButtonProps, disabled, onUpdate, t]) | |||
| const ApiKeyButton = useMemo(() => { | |||
| const Item = ( | |||
| <div className={cn('min-w-0 flex-[1]', notAllowCustomCredential && 'opacity-50')}> | |||
| <AddApiKeyButton | |||
| {...apiKeyButtonProps} | |||
| disabled={disabled || notAllowCustomCredential} | |||
| onUpdate={onUpdate} | |||
| /> | |||
| </div> | |||
| ) | |||
| if (notAllowCustomCredential) { | |||
| return ( | |||
| <Tooltip popupContent={t('plugin.auth.credentialUnavailable')}> | |||
| {Item} | |||
| </Tooltip> | |||
| ) | |||
| } | |||
| return Item | |||
| }, [notAllowCustomCredential, apiKeyButtonProps, disabled, onUpdate, t]) | |||
| return ( | |||
| <> | |||
| <div className='flex items-center space-x-1.5'> | |||
| { | |||
| canOAuth && ( | |||
| <div className='min-w-0 flex-[1]'> | |||
| <AddOAuthButton | |||
| {...oAuthButtonProps} | |||
| disabled={disabled} | |||
| onUpdate={onUpdate} | |||
| /> | |||
| </div> | |||
| OAuthButton | |||
| ) | |||
| } | |||
| { | |||
| @@ -87,13 +127,7 @@ const Authorize = ({ | |||
| } | |||
| { | |||
| canApiKey && ( | |||
| <div className='min-w-0 flex-[1]'> | |||
| <AddApiKeyButton | |||
| {...apiKeyButtonProps} | |||
| disabled={disabled} | |||
| onUpdate={onUpdate} | |||
| /> | |||
| </div> | |||
| ApiKeyButton | |||
| ) | |||
| } | |||
| </div> | |||
| @@ -35,10 +35,13 @@ const AuthorizedInNode = ({ | |||
| credentials, | |||
| disabled, | |||
| invalidPluginCredentialInfo, | |||
| notAllowCustomCredential, | |||
| } = usePluginAuth(pluginPayload, isOpen || !!credentialId) | |||
| const renderTrigger = useCallback((open?: boolean) => { | |||
| let label = '' | |||
| let removed = false | |||
| let unavailable = false | |||
| let color = 'green' | |||
| if (!credentialId) { | |||
| label = t('plugin.auth.workspaceDefault') | |||
| } | |||
| @@ -46,6 +49,12 @@ const AuthorizedInNode = ({ | |||
| const credential = credentials.find(c => c.id === credentialId) | |||
| label = credential ? credential.name : t('plugin.auth.authRemoved') | |||
| removed = !credential | |||
| unavailable = !!credential?.not_allowed_to_use && !credential?.from_enterprise | |||
| if (removed) | |||
| color = 'red' | |||
| else if (unavailable) | |||
| color = 'gray' | |||
| } | |||
| return ( | |||
| <Button | |||
| @@ -57,9 +66,12 @@ const AuthorizedInNode = ({ | |||
| > | |||
| <Indicator | |||
| className='mr-1.5' | |||
| color={removed ? 'red' : 'green'} | |||
| color={color as any} | |||
| /> | |||
| {label} | |||
| { | |||
| unavailable && t('plugin.auth.unavailable') | |||
| } | |||
| <RiArrowDownSLine | |||
| className={cn( | |||
| 'h-3.5 w-3.5 text-components-button-ghost-text', | |||
| @@ -106,6 +118,7 @@ const AuthorizedInNode = ({ | |||
| showItemSelectedIcon | |||
| selectedCredentialId={credentialId || '__workspace_default__'} | |||
| onUpdate={invalidPluginCredentialInfo} | |||
| notAllowCustomCredential={notAllowCustomCredential} | |||
| /> | |||
| ) | |||
| } | |||
| @@ -52,6 +52,7 @@ type AuthorizedProps = { | |||
| showItemSelectedIcon?: boolean | |||
| selectedCredentialId?: string | |||
| onUpdate?: () => void | |||
| notAllowCustomCredential?: boolean | |||
| } | |||
| const Authorized = ({ | |||
| pluginPayload, | |||
| @@ -72,6 +73,7 @@ const Authorized = ({ | |||
| showItemSelectedIcon, | |||
| selectedCredentialId, | |||
| onUpdate, | |||
| notAllowCustomCredential, | |||
| }: AuthorizedProps) => { | |||
| const { t } = useTranslation() | |||
| const { notify } = useToastContext() | |||
| @@ -171,6 +173,7 @@ const Authorized = ({ | |||
| handleSetDoingAction(false) | |||
| } | |||
| }, [updatePluginCredential, notify, t, handleSetDoingAction, onUpdate]) | |||
| const unavailableCredentials = credentials.filter(credential => credential.not_allowed_to_use) | |||
| return ( | |||
| <> | |||
| @@ -201,6 +204,11 @@ const Authorized = ({ | |||
| ? t('plugin.auth.authorizations') | |||
| : t('plugin.auth.authorization') | |||
| } | |||
| { | |||
| !!unavailableCredentials.length && ( | |||
| ` (${unavailableCredentials.length} ${t('plugin.auth.unavailable')})` | |||
| ) | |||
| } | |||
| <RiArrowDownSLine className='ml-0.5 h-4 w-4' /> | |||
| </Button> | |||
| ) | |||
| @@ -294,18 +302,24 @@ const Authorized = ({ | |||
| ) | |||
| } | |||
| </div> | |||
| <div className='h-px bg-divider-subtle'></div> | |||
| <div className='p-2'> | |||
| <Authorize | |||
| pluginPayload={pluginPayload} | |||
| theme='secondary' | |||
| showDivider={false} | |||
| canOAuth={canOAuth} | |||
| canApiKey={canApiKey} | |||
| disabled={disabled} | |||
| onUpdate={onUpdate} | |||
| /> | |||
| </div> | |||
| { | |||
| !notAllowCustomCredential && ( | |||
| <> | |||
| <div className='h-[1px] bg-divider-subtle'></div> | |||
| <div className='p-2'> | |||
| <Authorize | |||
| pluginPayload={pluginPayload} | |||
| theme='secondary' | |||
| showDivider={false} | |||
| canOAuth={canOAuth} | |||
| canApiKey={canApiKey} | |||
| disabled={disabled} | |||
| onUpdate={onUpdate} | |||
| /> | |||
| </div> | |||
| </> | |||
| ) | |||
| } | |||
| </div> | |||
| </PortalToFollowElemContent> | |||
| </PortalToFollowElem> | |||
| @@ -61,14 +61,19 @@ const Item = ({ | |||
| return !(disableRename && disableEdit && disableDelete && disableSetDefault) | |||
| }, [disableRename, disableEdit, disableDelete, disableSetDefault]) | |||
| return ( | |||
| const CredentialItem = ( | |||
| <div | |||
| key={credential.id} | |||
| className={cn( | |||
| 'group flex h-8 items-center rounded-lg p-1 hover:bg-state-base-hover', | |||
| renaming && 'bg-state-base-hover', | |||
| (disabled || credential.not_allowed_to_use) && 'cursor-not-allowed opacity-50', | |||
| )} | |||
| onClick={() => onItemClick?.(credential.id === '__workspace_default__' ? '' : credential.id)} | |||
| onClick={() => { | |||
| if (credential.not_allowed_to_use || disabled) | |||
| return | |||
| onItemClick?.(credential.id === '__workspace_default__' ? '' : credential.id) | |||
| }} | |||
| > | |||
| { | |||
| renaming && ( | |||
| @@ -121,7 +126,10 @@ const Item = ({ | |||
| </div> | |||
| ) | |||
| } | |||
| <Indicator className='ml-2 mr-1.5 shrink-0' /> | |||
| <Indicator | |||
| className='ml-2 mr-1.5 shrink-0' | |||
| color={credential.not_allowed_to_use ? 'gray' : 'green'} | |||
| /> | |||
| <div | |||
| className='system-md-regular truncate text-text-secondary' | |||
| title={credential.name} | |||
| @@ -138,11 +146,18 @@ const Item = ({ | |||
| </div> | |||
| ) | |||
| } | |||
| { | |||
| credential.from_enterprise && ( | |||
| <Badge className='shrink-0'> | |||
| Enterprise | |||
| </Badge> | |||
| ) | |||
| } | |||
| { | |||
| showAction && !renaming && ( | |||
| <div className='ml-2 hidden shrink-0 items-center group-hover:flex'> | |||
| { | |||
| !credential.is_default && !disableSetDefault && ( | |||
| !credential.is_default && !disableSetDefault && !credential.not_allowed_to_use && ( | |||
| <Button | |||
| size='small' | |||
| disabled={disabled} | |||
| @@ -156,7 +171,7 @@ const Item = ({ | |||
| ) | |||
| } | |||
| { | |||
| !disableRename && ( | |||
| !disableRename && !credential.from_enterprise && !credential.not_allowed_to_use && ( | |||
| <Tooltip popupContent={t('common.operation.rename')}> | |||
| <ActionButton | |||
| disabled={disabled} | |||
| @@ -172,7 +187,7 @@ const Item = ({ | |||
| ) | |||
| } | |||
| { | |||
| !isOAuth && !disableEdit && ( | |||
| !isOAuth && !disableEdit && !credential.from_enterprise && !credential.not_allowed_to_use && ( | |||
| <Tooltip popupContent={t('common.operation.edit')}> | |||
| <ActionButton | |||
| disabled={disabled} | |||
| @@ -194,7 +209,7 @@ const Item = ({ | |||
| ) | |||
| } | |||
| { | |||
| !disableDelete && ( | |||
| !disableDelete && !credential.from_enterprise && ( | |||
| <Tooltip popupContent={t('common.operation.delete')}> | |||
| <ActionButton | |||
| className='hover:bg-transparent' | |||
| @@ -214,6 +229,18 @@ const Item = ({ | |||
| } | |||
| </div> | |||
| ) | |||
| if (credential.not_allowed_to_use) { | |||
| return ( | |||
| <Tooltip popupContent={t('plugin.auth.customCredentialUnavailable')}> | |||
| {CredentialItem} | |||
| </Tooltip> | |||
| ) | |||
| } | |||
| return ( | |||
| CredentialItem | |||
| ) | |||
| } | |||
| export default memo(Item) | |||
| @@ -0,0 +1,125 @@ | |||
| import { | |||
| useCallback, | |||
| useRef, | |||
| useState, | |||
| } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import { useToastContext } from '@/app/components/base/toast' | |||
| import type { PluginPayload } from '@/app/components/plugins/plugin-auth/types' | |||
| import { | |||
| useDeletePluginCredentialHook, | |||
| useSetPluginDefaultCredentialHook, | |||
| useUpdatePluginCredentialHook, | |||
| } from '../hooks/use-credential' | |||
| export const usePluginAuthAction = ( | |||
| pluginPayload: PluginPayload, | |||
| onUpdate?: () => void, | |||
| ) => { | |||
| const { t } = useTranslation() | |||
| const { notify } = useToastContext() | |||
| const pendingOperationCredentialId = useRef<string | null>(null) | |||
| const [deleteCredentialId, setDeleteCredentialId] = useState<string | null>(null) | |||
| const { mutateAsync: deletePluginCredential } = useDeletePluginCredentialHook(pluginPayload) | |||
| const openConfirm = useCallback((credentialId?: string) => { | |||
| if (credentialId) | |||
| pendingOperationCredentialId.current = credentialId | |||
| setDeleteCredentialId(pendingOperationCredentialId.current) | |||
| }, []) | |||
| const closeConfirm = useCallback(() => { | |||
| setDeleteCredentialId(null) | |||
| pendingOperationCredentialId.current = null | |||
| }, []) | |||
| const [doingAction, setDoingAction] = useState(false) | |||
| const doingActionRef = useRef(doingAction) | |||
| const handleSetDoingAction = useCallback((doing: boolean) => { | |||
| doingActionRef.current = doing | |||
| setDoingAction(doing) | |||
| }, []) | |||
| const [editValues, setEditValues] = useState<Record<string, any> | null>(null) | |||
| const handleConfirm = useCallback(async () => { | |||
| if (doingActionRef.current) | |||
| return | |||
| if (!pendingOperationCredentialId.current) { | |||
| setDeleteCredentialId(null) | |||
| return | |||
| } | |||
| try { | |||
| handleSetDoingAction(true) | |||
| await deletePluginCredential({ credential_id: pendingOperationCredentialId.current }) | |||
| notify({ | |||
| type: 'success', | |||
| message: t('common.api.actionSuccess'), | |||
| }) | |||
| onUpdate?.() | |||
| setDeleteCredentialId(null) | |||
| pendingOperationCredentialId.current = null | |||
| setEditValues(null) | |||
| } | |||
| finally { | |||
| handleSetDoingAction(false) | |||
| } | |||
| }, [deletePluginCredential, onUpdate, notify, t, handleSetDoingAction]) | |||
| const handleEdit = useCallback((id: string, values: Record<string, any>) => { | |||
| pendingOperationCredentialId.current = id | |||
| setEditValues(values) | |||
| }, []) | |||
| const handleRemove = useCallback(() => { | |||
| setDeleteCredentialId(pendingOperationCredentialId.current) | |||
| }, []) | |||
| const { mutateAsync: setPluginDefaultCredential } = useSetPluginDefaultCredentialHook(pluginPayload) | |||
| const handleSetDefault = useCallback(async (id: string) => { | |||
| if (doingActionRef.current) | |||
| return | |||
| try { | |||
| handleSetDoingAction(true) | |||
| await setPluginDefaultCredential(id) | |||
| notify({ | |||
| type: 'success', | |||
| message: t('common.api.actionSuccess'), | |||
| }) | |||
| onUpdate?.() | |||
| } | |||
| finally { | |||
| handleSetDoingAction(false) | |||
| } | |||
| }, [setPluginDefaultCredential, onUpdate, notify, t, handleSetDoingAction]) | |||
| const { mutateAsync: updatePluginCredential } = useUpdatePluginCredentialHook(pluginPayload) | |||
| const handleRename = useCallback(async (payload: { | |||
| credential_id: string | |||
| name: string | |||
| }) => { | |||
| if (doingActionRef.current) | |||
| return | |||
| try { | |||
| handleSetDoingAction(true) | |||
| await updatePluginCredential(payload) | |||
| notify({ | |||
| type: 'success', | |||
| message: t('common.api.actionSuccess'), | |||
| }) | |||
| onUpdate?.() | |||
| } | |||
| finally { | |||
| handleSetDoingAction(false) | |||
| } | |||
| }, [updatePluginCredential, notify, t, handleSetDoingAction, onUpdate]) | |||
| return { | |||
| doingAction, | |||
| handleSetDoingAction, | |||
| openConfirm, | |||
| closeConfirm, | |||
| deleteCredentialId, | |||
| setDeleteCredentialId, | |||
| handleConfirm, | |||
| editValues, | |||
| setEditValues, | |||
| handleEdit, | |||
| handleRemove, | |||
| handleSetDefault, | |||
| handleRename, | |||
| pendingOperationCredentialId, | |||
| } | |||
| } | |||
| @@ -20,6 +20,7 @@ export const usePluginAuth = (pluginPayload: PluginPayload, enable?: boolean) => | |||
| canApiKey, | |||
| credentials: data?.credentials || [], | |||
| disabled: !isCurrentWorkspaceManager, | |||
| notAllowCustomCredential: data?.allow_custom_token === false, | |||
| invalidPluginCredentialInfo, | |||
| } | |||
| } | |||
| @@ -35,6 +35,7 @@ const PluginAuthInAgent = ({ | |||
| credentials, | |||
| disabled, | |||
| invalidPluginCredentialInfo, | |||
| notAllowCustomCredential, | |||
| } = usePluginAuth(pluginPayload, true) | |||
| const extraAuthorizationItems: Credential[] = [ | |||
| @@ -58,6 +59,8 @@ const PluginAuthInAgent = ({ | |||
| const renderTrigger = useCallback((isOpen?: boolean) => { | |||
| let label = '' | |||
| let removed = false | |||
| let unavailable = false | |||
| let color = 'green' | |||
| if (!credentialId) { | |||
| label = t('plugin.auth.workspaceDefault') | |||
| } | |||
| @@ -65,6 +68,11 @@ const PluginAuthInAgent = ({ | |||
| const credential = credentials.find(c => c.id === credentialId) | |||
| label = credential ? credential.name : t('plugin.auth.authRemoved') | |||
| removed = !credential | |||
| unavailable = !!credential?.not_allowed_to_use && !credential?.from_enterprise | |||
| if (removed) | |||
| color = 'red' | |||
| else if (unavailable) | |||
| color = 'gray' | |||
| } | |||
| return ( | |||
| <Button | |||
| @@ -75,9 +83,12 @@ const PluginAuthInAgent = ({ | |||
| )}> | |||
| <Indicator | |||
| className='mr-2' | |||
| color={removed ? 'red' : 'green'} | |||
| color={color as any} | |||
| /> | |||
| {label} | |||
| { | |||
| unavailable && t('plugin.auth.unavailable') | |||
| } | |||
| <RiArrowDownSLine className='ml-0.5 h-4 w-4' /> | |||
| </Button> | |||
| ) | |||
| @@ -93,6 +104,7 @@ const PluginAuthInAgent = ({ | |||
| canApiKey={canApiKey} | |||
| disabled={disabled} | |||
| onUpdate={invalidPluginCredentialInfo} | |||
| notAllowCustomCredential={notAllowCustomCredential} | |||
| /> | |||
| ) | |||
| } | |||
| @@ -113,6 +125,7 @@ const PluginAuthInAgent = ({ | |||
| onOpenChange={setIsOpen} | |||
| selectedCredentialId={credentialId || '__workspace_default__'} | |||
| onUpdate={invalidPluginCredentialInfo} | |||
| notAllowCustomCredential={notAllowCustomCredential} | |||
| /> | |||
| ) | |||
| } | |||
| @@ -22,6 +22,7 @@ const PluginAuth = ({ | |||
| credentials, | |||
| disabled, | |||
| invalidPluginCredentialInfo, | |||
| notAllowCustomCredential, | |||
| } = usePluginAuth(pluginPayload, !!pluginPayload.provider) | |||
| return ( | |||
| @@ -34,6 +35,7 @@ const PluginAuth = ({ | |||
| canApiKey={canApiKey} | |||
| disabled={disabled} | |||
| onUpdate={invalidPluginCredentialInfo} | |||
| notAllowCustomCredential={notAllowCustomCredential} | |||
| /> | |||
| ) | |||
| } | |||
| @@ -46,6 +48,7 @@ const PluginAuth = ({ | |||
| canApiKey={canApiKey} | |||
| disabled={disabled} | |||
| onUpdate={invalidPluginCredentialInfo} | |||
| notAllowCustomCredential={notAllowCustomCredential} | |||
| /> | |||
| ) | |||
| } | |||
| @@ -22,4 +22,6 @@ export type Credential = { | |||
| is_default: boolean | |||
| credentials?: Record<string, any> | |||
| isWorkspaceDefault?: boolean | |||
| from_enterprise?: boolean | |||
| not_allowed_to_use?: boolean | |||
| } | |||
| @@ -6,7 +6,9 @@ import { createContext, useContext, useContextSelector } from 'use-context-selec | |||
| import { useRouter, useSearchParams } from 'next/navigation' | |||
| import type { | |||
| ConfigurationMethodEnum, | |||
| Credential, | |||
| CustomConfigurationModelFixedFields, | |||
| CustomModel, | |||
| ModelLoadBalancingConfigEntry, | |||
| ModelProvider, | |||
| } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| @@ -55,9 +57,6 @@ const ExternalAPIModal = dynamic(() => import('@/app/components/datasets/externa | |||
| const ModelLoadBalancingModal = dynamic(() => import('@/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal'), { | |||
| ssr: false, | |||
| }) | |||
| const ModelLoadBalancingEntryModal = dynamic(() => import('@/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal'), { | |||
| ssr: false, | |||
| }) | |||
| const OpeningSettingModal = dynamic(() => import('@/app/components/base/features/new-feature-panel/conversation-opener/modal'), { | |||
| ssr: false, | |||
| }) | |||
| @@ -84,6 +83,9 @@ export type ModelModalType = { | |||
| currentProvider: ModelProvider | |||
| currentConfigurationMethod: ConfigurationMethodEnum | |||
| currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields | |||
| isModelCredential?: boolean | |||
| credential?: Credential | |||
| model?: CustomModel | |||
| } | |||
| export type LoadBalancingEntryModalType = ModelModalType & { | |||
| entry?: ModelLoadBalancingConfigEntry | |||
| @@ -100,7 +102,6 @@ export type ModalContextState = { | |||
| setShowModelModal: Dispatch<SetStateAction<ModalState<ModelModalType> | null>> | |||
| setShowExternalKnowledgeAPIModal: Dispatch<SetStateAction<ModalState<CreateExternalAPIReq> | null>> | |||
| setShowModelLoadBalancingModal: Dispatch<SetStateAction<ModelLoadBalancingModalProps | null>> | |||
| setShowModelLoadBalancingEntryModal: Dispatch<SetStateAction<ModalState<LoadBalancingEntryModalType> | null>> | |||
| setShowOpeningModal: Dispatch<SetStateAction<ModalState<OpeningStatement & { | |||
| promptVariables?: PromptVariable[] | |||
| workflowVariables?: InputVar[] | |||
| @@ -119,7 +120,6 @@ const ModalContext = createContext<ModalContextState>({ | |||
| setShowModelModal: noop, | |||
| setShowExternalKnowledgeAPIModal: noop, | |||
| setShowModelLoadBalancingModal: noop, | |||
| setShowModelLoadBalancingEntryModal: noop, | |||
| setShowOpeningModal: noop, | |||
| setShowUpdatePluginModal: noop, | |||
| setShowEducationExpireNoticeModal: noop, | |||
| @@ -145,7 +145,6 @@ export const ModalContextProvider = ({ | |||
| const [showModelModal, setShowModelModal] = useState<ModalState<ModelModalType> | null>(null) | |||
| const [showExternalKnowledgeAPIModal, setShowExternalKnowledgeAPIModal] = useState<ModalState<CreateExternalAPIReq> | null>(null) | |||
| const [showModelLoadBalancingModal, setShowModelLoadBalancingModal] = useState<ModelLoadBalancingModalProps | null>(null) | |||
| const [showModelLoadBalancingEntryModal, setShowModelLoadBalancingEntryModal] = useState<ModalState<LoadBalancingEntryModalType> | null>(null) | |||
| const [showOpeningModal, setShowOpeningModal] = useState<ModalState<OpeningStatement & { | |||
| promptVariables?: PromptVariable[] | |||
| workflowVariables?: InputVar[] | |||
| @@ -212,30 +211,12 @@ export const ModalContextProvider = ({ | |||
| setShowExternalKnowledgeAPIModal(null) | |||
| }, [showExternalKnowledgeAPIModal]) | |||
| const handleCancelModelLoadBalancingEntryModal = useCallback(() => { | |||
| showModelLoadBalancingEntryModal?.onCancelCallback?.() | |||
| setShowModelLoadBalancingEntryModal(null) | |||
| }, [showModelLoadBalancingEntryModal]) | |||
| const handleCancelOpeningModal = useCallback(() => { | |||
| setShowOpeningModal(null) | |||
| if (showOpeningModal?.onCancelCallback) | |||
| showOpeningModal.onCancelCallback() | |||
| }, [showOpeningModal]) | |||
| const handleSaveModelLoadBalancingEntryModal = useCallback((entry: ModelLoadBalancingConfigEntry) => { | |||
| showModelLoadBalancingEntryModal?.onSaveCallback?.({ | |||
| ...showModelLoadBalancingEntryModal.payload, | |||
| entry, | |||
| }) | |||
| setShowModelLoadBalancingEntryModal(null) | |||
| }, [showModelLoadBalancingEntryModal]) | |||
| const handleRemoveModelLoadBalancingEntry = useCallback(() => { | |||
| showModelLoadBalancingEntryModal?.onRemoveCallback?.(showModelLoadBalancingEntryModal.payload) | |||
| setShowModelLoadBalancingEntryModal(null) | |||
| }, [showModelLoadBalancingEntryModal]) | |||
| const handleSaveApiBasedExtension = (newApiBasedExtension: ApiBasedExtension) => { | |||
| if (showApiBasedExtensionModal?.onSaveCallback) | |||
| showApiBasedExtensionModal.onSaveCallback(newApiBasedExtension) | |||
| @@ -277,7 +258,6 @@ export const ModalContextProvider = ({ | |||
| setShowModelModal, | |||
| setShowExternalKnowledgeAPIModal, | |||
| setShowModelLoadBalancingModal, | |||
| setShowModelLoadBalancingEntryModal, | |||
| setShowOpeningModal, | |||
| setShowUpdatePluginModal, | |||
| setShowEducationExpireNoticeModal, | |||
| @@ -346,6 +326,9 @@ export const ModalContextProvider = ({ | |||
| provider={showModelModal.payload.currentProvider} | |||
| configurateMethod={showModelModal.payload.currentConfigurationMethod} | |||
| currentCustomConfigurationModelFixedFields={showModelModal.payload.currentCustomConfigurationModelFixedFields} | |||
| isModelCredential={showModelModal.payload.isModelCredential} | |||
| credential={showModelModal.payload.credential} | |||
| model={showModelModal.payload.model} | |||
| onCancel={handleCancelModelModal} | |||
| onSave={handleSaveModelModal} | |||
| /> | |||
| @@ -368,19 +351,6 @@ export const ModalContextProvider = ({ | |||
| <ModelLoadBalancingModal {...showModelLoadBalancingModal!} /> | |||
| ) | |||
| } | |||
| { | |||
| !!showModelLoadBalancingEntryModal && ( | |||
| <ModelLoadBalancingEntryModal | |||
| provider={showModelLoadBalancingEntryModal.payload.currentProvider} | |||
| configurationMethod={showModelLoadBalancingEntryModal.payload.currentConfigurationMethod} | |||
| currentCustomConfigurationModelFixedFields={showModelLoadBalancingEntryModal.payload.currentCustomConfigurationModelFixedFields} | |||
| entry={showModelLoadBalancingEntryModal.payload.entry} | |||
| onCancel={handleCancelModelLoadBalancingEntryModal} | |||
| onSave={handleSaveModelLoadBalancingEntryModal} | |||
| onRemove={handleRemoveModelLoadBalancingEntry} | |||
| /> | |||
| ) | |||
| } | |||
| {showOpeningModal && ( | |||
| <OpeningSettingModal | |||
| data={showOpeningModal.payload} | |||
| @@ -40,6 +40,7 @@ const translation = { | |||
| deleteApp: 'Delete App', | |||
| settings: 'Settings', | |||
| setup: 'Setup', | |||
| config: 'Config', | |||
| getForFree: 'Get for free', | |||
| reload: 'Reload', | |||
| ok: 'OK', | |||
| @@ -466,7 +467,7 @@ const translation = { | |||
| loadPresets: 'Load Presets', | |||
| parameters: 'PARAMETERS', | |||
| loadBalancing: 'Load balancing', | |||
| loadBalancingDescription: 'Reduce pressure with multiple sets of credentials.', | |||
| loadBalancingDescription: 'Configure multiple credentials for the model and invoke them automatically. ', | |||
| loadBalancingHeadline: 'Load Balancing', | |||
| configLoadBalancing: 'Config Load Balancing', | |||
| modelHasBeenDeprecated: 'This model has been deprecated', | |||
| @@ -486,6 +487,28 @@ const translation = { | |||
| discoverMore: 'Discover more in ', | |||
| emptyProviderTitle: 'Model provider not set up', | |||
| emptyProviderTip: 'Please install a model provider first.', | |||
| auth: { | |||
| unAuthorized: 'Unauthorized', | |||
| authRemoved: 'Auth removed', | |||
| apiKeys: 'API Keys', | |||
| addApiKey: 'Add API Key', | |||
| addNewModel: 'Add new model', | |||
| addCredential: 'Add credential', | |||
| addModelCredential: 'Add model credential', | |||
| modelCredentials: 'Model credentials', | |||
| configModel: 'Config model', | |||
| configLoadBalancing: 'Config Load Balancing', | |||
| authorizationError: 'Authorization error', | |||
| specifyModelCredential: 'Specify model credential', | |||
| specifyModelCredentialTip: 'Use a configured model credential.', | |||
| providerManaged: 'Provider managed', | |||
| providerManagedTip: 'The current configuration is hosted by the provider.', | |||
| apiKeyModal: { | |||
| title: 'API Key Authorization Configuration', | |||
| desc: 'After configuring credentials, all members within the workspace can use this model when orchestrating applications.', | |||
| addModel: 'Add model', | |||
| }, | |||
| }, | |||
| }, | |||
| dataSource: { | |||
| add: 'Add a data source', | |||
| @@ -297,6 +297,9 @@ const translation = { | |||
| authRemoved: 'Auth removed', | |||
| clientInfo: 'As no system client secrets found for this tool provider, setup it manually is required, for redirect_uri, please use', | |||
| oauthClient: 'OAuth Client', | |||
| credentialUnavailable: 'Credentials currently unavailable. Please contact admin.', | |||
| customCredentialUnavailable: 'Custom credentials currently unavailable', | |||
| unavailable: 'Unavailable', | |||
| }, | |||
| } | |||
| @@ -40,6 +40,7 @@ const translation = { | |||
| deleteApp: '删除应用', | |||
| settings: '设置', | |||
| setup: '设置', | |||
| config: '配置', | |||
| getForFree: '免费获取', | |||
| reload: '刷新', | |||
| ok: '好的', | |||
| @@ -465,7 +466,7 @@ const translation = { | |||
| loadPresets: '加载预设', | |||
| parameters: '参数', | |||
| loadBalancing: '负载均衡', | |||
| loadBalancingDescription: '为了减轻单组凭据的压力,您可以为模型调用配置多组凭据。', | |||
| loadBalancingDescription: '为模型配置多组凭据,并自动调用。', | |||
| loadBalancingHeadline: '负载均衡', | |||
| configLoadBalancing: '设置负载均衡', | |||
| modelHasBeenDeprecated: '该模型已废弃', | |||
| @@ -486,6 +487,28 @@ const translation = { | |||
| discoverMore: '发现更多就在', | |||
| emptyProviderTitle: '尚未安装模型供应商', | |||
| emptyProviderTip: '请安装模型供应商。', | |||
| auth: { | |||
| unAuthorized: '未授权', | |||
| authRemoved: '授权已移除', | |||
| apiKeys: 'API 密钥', | |||
| addApiKey: '添加 API 密钥', | |||
| addNewModel: '添加新模型', | |||
| addCredential: '添加凭据', | |||
| addModelCredential: '添加模型凭据', | |||
| modelCredentials: '模型凭据', | |||
| configModel: '配置模型', | |||
| configLoadBalancing: '配置负载均衡', | |||
| authorizationError: '授权错误', | |||
| specifyModelCredential: '指定模型凭据', | |||
| specifyModelCredentialTip: '使用已配置的模型凭据。', | |||
| providerManaged: '由模型供应商管理', | |||
| providerManagedTip: '使用模型供应商提供的单组凭据。', | |||
| apiKeyModal: { | |||
| title: 'API 密钥授权配置', | |||
| desc: '配置凭据后,工作空间中的所有成员都可以在编排应用时使用此模型。', | |||
| addModel: '添加模型', | |||
| }, | |||
| }, | |||
| }, | |||
| dataSource: { | |||
| add: '添加数据源', | |||
| @@ -297,6 +297,9 @@ const translation = { | |||
| authRemoved: '凭据已移除', | |||
| clientInfo: '由于未找到此工具提供者的系统客户端密钥,因此需要手动设置,对于 redirect_uri,请使用', | |||
| oauthClient: 'OAuth 客户端', | |||
| credentialUnavailable: '自定义凭据当前不可用,请联系管理员。', | |||
| customCredentialUnavailable: '自定义凭据当前不可用', | |||
| unavailable: '不可用', | |||
| }, | |||
| } | |||
| @@ -1,8 +1,18 @@ | |||
| import { get } from './base' | |||
| import { | |||
| del, | |||
| get, | |||
| post, | |||
| put, | |||
| } from './base' | |||
| import type { | |||
| ModelCredential, | |||
| ModelItem, | |||
| ModelLoadBalancingConfig, | |||
| ModelTypeEnum, | |||
| ProviderCredential, | |||
| } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| import { | |||
| useMutation, | |||
| useQuery, | |||
| // useQueryClient, | |||
| } from '@tanstack/react-query' | |||
| @@ -15,3 +25,131 @@ export const useModelProviderModelList = (provider: string) => { | |||
| queryFn: () => get<{ data: ModelItem[] }>(`/workspaces/current/model-providers/${provider}/models`), | |||
| }) | |||
| } | |||
| export const useGetProviderCredential = (enabled: boolean, provider: string, credentialId?: string) => { | |||
| return useQuery({ | |||
| enabled, | |||
| queryKey: [NAME_SPACE, 'model-list', provider, credentialId], | |||
| queryFn: () => get<ProviderCredential>(`/workspaces/current/model-providers/${provider}/credentials${credentialId ? `?credential_id=${credentialId}` : ''}`), | |||
| }) | |||
| } | |||
| export const useAddProviderCredential = (provider: string) => { | |||
| return useMutation({ | |||
| mutationFn: (data: ProviderCredential) => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/credentials`, { | |||
| body: data, | |||
| }), | |||
| }) | |||
| } | |||
| export const useEditProviderCredential = (provider: string) => { | |||
| return useMutation({ | |||
| mutationFn: (data: ProviderCredential) => put<{ result: string }>(`/workspaces/current/model-providers/${provider}/credentials`, { | |||
| body: data, | |||
| }), | |||
| }) | |||
| } | |||
| export const useDeleteProviderCredential = (provider: string) => { | |||
| return useMutation({ | |||
| mutationFn: (data: { | |||
| credential_id: string | |||
| }) => del<{ result: string }>(`/workspaces/current/model-providers/${provider}/credentials`, { | |||
| body: data, | |||
| }), | |||
| }) | |||
| } | |||
| export const useActiveProviderCredential = (provider: string) => { | |||
| return useMutation({ | |||
| mutationFn: (data: { | |||
| credential_id: string | |||
| model?: string | |||
| model_type?: ModelTypeEnum | |||
| }) => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/credentials/switch`, { | |||
| body: data, | |||
| }), | |||
| }) | |||
| } | |||
| export const useGetModelCredential = ( | |||
| enabled: boolean, | |||
| provider: string, | |||
| credentialId?: string, | |||
| model?: string, | |||
| modelType?: string, | |||
| configFrom?: string, | |||
| ) => { | |||
| return useQuery({ | |||
| enabled, | |||
| queryKey: [NAME_SPACE, 'model-list', provider, model, modelType, credentialId], | |||
| queryFn: () => get<ModelCredential>(`/workspaces/current/model-providers/${provider}/models/credentials?model=${model}&model_type=${modelType}&config_from=${configFrom}${credentialId ? `&credential_id=${credentialId}` : ''}`), | |||
| staleTime: 0, | |||
| gcTime: 0, | |||
| }) | |||
| } | |||
| export const useAddModelCredential = (provider: string) => { | |||
| return useMutation({ | |||
| mutationFn: (data: ModelCredential) => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials`, { | |||
| body: data, | |||
| }), | |||
| }) | |||
| } | |||
| export const useEditModelCredential = (provider: string) => { | |||
| return useMutation({ | |||
| mutationFn: (data: ModelCredential) => put<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials`, { | |||
| body: data, | |||
| }), | |||
| }) | |||
| } | |||
| export const useDeleteModelCredential = (provider: string) => { | |||
| return useMutation({ | |||
| mutationFn: (data: { | |||
| credential_id: string | |||
| model?: string | |||
| model_type?: ModelTypeEnum | |||
| }) => del<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials`, { | |||
| body: data, | |||
| }), | |||
| }) | |||
| } | |||
| export const useDeleteModel = (provider: string) => { | |||
| return useMutation({ | |||
| mutationFn: (data: { | |||
| model: string | |||
| model_type: ModelTypeEnum | |||
| }) => del<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials`, { | |||
| body: data, | |||
| }), | |||
| }) | |||
| } | |||
| export const useActiveModelCredential = (provider: string) => { | |||
| return useMutation({ | |||
| mutationFn: (data: { | |||
| credential_id: string | |||
| model?: string | |||
| model_type?: ModelTypeEnum | |||
| }) => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials/switch`, { | |||
| body: data, | |||
| }), | |||
| }) | |||
| } | |||
| export const useUpdateModelLoadBalancingConfig = (provider: string) => { | |||
| return useMutation({ | |||
| mutationFn: (data: { | |||
| config_from: string | |||
| model: string | |||
| model_type: ModelTypeEnum | |||
| load_balancing: ModelLoadBalancingConfig | |||
| credential_id?: string | |||
| }) => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/models`, { | |||
| body: data, | |||
| }), | |||
| }) | |||
| } | |||
| @@ -19,6 +19,7 @@ export const useGetPluginCredentialInfo = ( | |||
| enabled: !!url, | |||
| queryKey: [NAME_SPACE, 'credential-info', url], | |||
| queryFn: () => get<{ | |||
| allow_custom_token?: boolean | |||
| supported_credential_types: string[] | |||
| credentials: Credential[] | |||
| is_oauth_custom_client_enabled: boolean | |||