Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>tags/1.8.0
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||||
| from core.model_runtime.utils.encoders import jsonable_encoder | from core.model_runtime.utils.encoders import jsonable_encoder | ||||
| from libs.helper import StrLen, uuid_value | |||||
| from libs.login import login_required | from libs.login import login_required | ||||
| from services.billing_service import BillingService | from services.billing_service import BillingService | ||||
| from services.model_provider_service import ModelProviderService | from services.model_provider_service import ModelProviderService | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def get(self, provider: str): | def get(self, provider: str): | ||||
| tenant_id = current_user.current_tenant_id | tenant_id = current_user.current_tenant_id | ||||
| # if credential_id is not provided, return current used credential | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") | |||||
| args = parser.parse_args() | |||||
| model_provider_service = ModelProviderService() | model_provider_service = ModelProviderService() | ||||
| credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider) | |||||
| credentials = model_provider_service.get_provider_credential( | |||||
| tenant_id=tenant_id, provider=provider, credential_id=args.get("credential_id") | |||||
| ) | |||||
| return {"credentials": credentials} | return {"credentials": credentials} | ||||
| class ModelProviderValidateApi(Resource): | |||||
| @setup_required | @setup_required | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def post(self, provider: str): | def post(self, provider: str): | ||||
| if not current_user.is_admin_or_owner: | |||||
| raise Forbidden() | |||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | ||||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| tenant_id = current_user.current_tenant_id | |||||
| model_provider_service = ModelProviderService() | model_provider_service = ModelProviderService() | ||||
| result = True | |||||
| error = "" | |||||
| try: | try: | ||||
| model_provider_service.provider_credentials_validate( | |||||
| tenant_id=tenant_id, provider=provider, credentials=args["credentials"] | |||||
| model_provider_service.create_provider_credential( | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| provider=provider, | |||||
| credentials=args["credentials"], | |||||
| credential_name=args["name"], | |||||
| ) | ) | ||||
| except CredentialsValidateFailedError as ex: | except CredentialsValidateFailedError as ex: | ||||
| result = False | |||||
| error = str(ex) | |||||
| response = {"result": "success" if result else "error"} | |||||
| if not result: | |||||
| response["error"] = error or "Unknown error" | |||||
| return response | |||||
| raise ValueError(str(ex)) | |||||
| return {"result": "success"}, 201 | |||||
| class ModelProviderApi(Resource): | |||||
| @setup_required | @setup_required | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def post(self, provider: str): | |||||
| def put(self, provider: str): | |||||
| if not current_user.is_admin_or_owner: | if not current_user.is_admin_or_owner: | ||||
| raise Forbidden() | raise Forbidden() | ||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") | |||||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | ||||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| model_provider_service = ModelProviderService() | model_provider_service = ModelProviderService() | ||||
| try: | try: | ||||
| model_provider_service.save_provider_credentials( | |||||
| tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"] | |||||
| model_provider_service.update_provider_credential( | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| provider=provider, | |||||
| credentials=args["credentials"], | |||||
| credential_id=args["credential_id"], | |||||
| credential_name=args["name"], | |||||
| ) | ) | ||||
| except CredentialsValidateFailedError as ex: | except CredentialsValidateFailedError as ex: | ||||
| raise ValueError(str(ex)) | raise ValueError(str(ex)) | ||||
| return {"result": "success"}, 201 | |||||
| return {"result": "success"} | |||||
| @setup_required | @setup_required | ||||
| @login_required | @login_required | ||||
| def delete(self, provider: str): | def delete(self, provider: str): | ||||
| if not current_user.is_admin_or_owner: | if not current_user.is_admin_or_owner: | ||||
| raise Forbidden() | raise Forbidden() | ||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") | |||||
| args = parser.parse_args() | |||||
| model_provider_service = ModelProviderService() | model_provider_service = ModelProviderService() | ||||
| model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider) | |||||
| model_provider_service.remove_provider_credential( | |||||
| tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"] | |||||
| ) | |||||
| return {"result": "success"}, 204 | return {"result": "success"}, 204 | ||||
| class ModelProviderCredentialSwitchApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def post(self, provider: str): | |||||
| if not current_user.is_admin_or_owner: | |||||
| raise Forbidden() | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") | |||||
| args = parser.parse_args() | |||||
| service = ModelProviderService() | |||||
| service.switch_active_provider_credential( | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| provider=provider, | |||||
| credential_id=args["credential_id"], | |||||
| ) | |||||
| return {"result": "success"} | |||||
| class ModelProviderValidateApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def post(self, provider: str): | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||||
| args = parser.parse_args() | |||||
| tenant_id = current_user.current_tenant_id | |||||
| model_provider_service = ModelProviderService() | |||||
| result = True | |||||
| error = "" | |||||
| try: | |||||
| model_provider_service.validate_provider_credentials( | |||||
| tenant_id=tenant_id, provider=provider, credentials=args["credentials"] | |||||
| ) | |||||
| except CredentialsValidateFailedError as ex: | |||||
| result = False | |||||
| error = str(ex) | |||||
| response = {"result": "success" if result else "error"} | |||||
| if not result: | |||||
| response["error"] = error or "Unknown error" | |||||
| return response | |||||
| class ModelProviderIconApi(Resource): | class ModelProviderIconApi(Resource): | ||||
| """ | """ | ||||
| Get model provider icon | Get model provider icon | ||||
| api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") | api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") | ||||
| api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials") | api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials") | ||||
| api.add_resource( | |||||
| ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers/<path:provider>/credentials/switch" | |||||
| ) | |||||
| api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate") | api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate") | ||||
| api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<path:provider>") | |||||
| api.add_resource( | api.add_resource( | ||||
| PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type" | PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type" |
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||||
| from core.model_runtime.utils.encoders import jsonable_encoder | from core.model_runtime.utils.encoders import jsonable_encoder | ||||
| from libs.helper import StrLen, uuid_value | |||||
| from libs.login import login_required | from libs.login import login_required | ||||
| from services.model_load_balancing_service import ModelLoadBalancingService | from services.model_load_balancing_service import ModelLoadBalancingService | ||||
| from services.model_provider_service import ModelProviderService | from services.model_provider_service import ModelProviderService | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def post(self, provider: str): | def post(self, provider: str): | ||||
| # To save the model's load balance configs | |||||
| if not current_user.is_admin_or_owner: | if not current_user.is_admin_or_owner: | ||||
| raise Forbidden() | raise Forbidden() | ||||
| choices=[mt.value for mt in ModelType], | choices=[mt.value for mt in ModelType], | ||||
| location="json", | location="json", | ||||
| ) | ) | ||||
| parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") | |||||
| parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json") | parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json") | ||||
| parser.add_argument("config_from", type=str, required=False, nullable=True, location="json") | parser.add_argument("config_from", type=str, required=False, nullable=True, location="json") | ||||
| parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| model_load_balancing_service = ModelLoadBalancingService() | |||||
| if args.get("config_from", "") == "custom-model": | |||||
| if not args.get("credential_id"): | |||||
| raise ValueError("credential_id is required when configuring a custom-model") | |||||
| service = ModelProviderService() | |||||
| service.switch_active_custom_model_credential( | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| provider=provider, | |||||
| model_type=args["model_type"], | |||||
| model=args["model"], | |||||
| credential_id=args["credential_id"], | |||||
| ) | |||||
| if ( | |||||
| "load_balancing" in args | |||||
| and args["load_balancing"] | |||||
| and "enabled" in args["load_balancing"] | |||||
| and args["load_balancing"]["enabled"] | |||||
| ): | |||||
| if "configs" not in args["load_balancing"]: | |||||
| raise ValueError("invalid load balancing configs") | |||||
| model_load_balancing_service = ModelLoadBalancingService() | |||||
| if "load_balancing" in args and args["load_balancing"] and "configs" in args["load_balancing"]: | |||||
| # save load balancing configs | # save load balancing configs | ||||
| model_load_balancing_service.update_load_balancing_configs( | model_load_balancing_service.update_load_balancing_configs( | ||||
| tenant_id=tenant_id, | tenant_id=tenant_id, | ||||
| model=args["model"], | model=args["model"], | ||||
| model_type=args["model_type"], | model_type=args["model_type"], | ||||
| configs=args["load_balancing"]["configs"], | configs=args["load_balancing"]["configs"], | ||||
| config_from=args.get("config_from", ""), | |||||
| ) | ) | ||||
| # enable load balancing | |||||
| model_load_balancing_service.enable_model_load_balancing( | |||||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||||
| ) | |||||
| else: | |||||
| # disable load balancing | |||||
| model_load_balancing_service.disable_model_load_balancing( | |||||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||||
| ) | |||||
| if args.get("config_from", "") != "predefined-model": | |||||
| model_provider_service = ModelProviderService() | |||||
| try: | |||||
| model_provider_service.save_model_credentials( | |||||
| tenant_id=tenant_id, | |||||
| provider=provider, | |||||
| model=args["model"], | |||||
| model_type=args["model_type"], | |||||
| credentials=args["credentials"], | |||||
| ) | |||||
| except CredentialsValidateFailedError as ex: | |||||
| logging.exception( | |||||
| "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s", | |||||
| tenant_id, | |||||
| args.get("model"), | |||||
| args.get("model_type"), | |||||
| ) | |||||
| raise ValueError(str(ex)) | |||||
| if args.get("load_balancing", {}).get("enabled"): | |||||
| model_load_balancing_service.enable_model_load_balancing( | |||||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||||
| ) | |||||
| else: | |||||
| model_load_balancing_service.disable_model_load_balancing( | |||||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||||
| ) | |||||
| return {"result": "success"}, 200 | return {"result": "success"}, 200 | ||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| model_provider_service = ModelProviderService() | model_provider_service = ModelProviderService() | ||||
| model_provider_service.remove_model_credentials( | |||||
| model_provider_service.remove_model( | |||||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | ||||
| ) | ) | ||||
| choices=[mt.value for mt in ModelType], | choices=[mt.value for mt in ModelType], | ||||
| location="args", | location="args", | ||||
| ) | ) | ||||
| parser.add_argument("config_from", type=str, required=False, nullable=True, location="args") | |||||
| parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| model_provider_service = ModelProviderService() | model_provider_service = ModelProviderService() | ||||
| credentials = model_provider_service.get_model_credentials( | |||||
| tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"] | |||||
| current_credential = model_provider_service.get_model_credential( | |||||
| tenant_id=tenant_id, | |||||
| provider=provider, | |||||
| model_type=args["model_type"], | |||||
| model=args["model"], | |||||
| credential_id=args.get("credential_id"), | |||||
| ) | ) | ||||
| model_load_balancing_service = ModelLoadBalancingService() | model_load_balancing_service = ModelLoadBalancingService() | ||||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | ||||
| ) | ) | ||||
| return { | |||||
| "credentials": credentials, | |||||
| "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, | |||||
| } | |||||
| if args.get("config_from", "") == "predefined-model": | |||||
| available_credentials = model_provider_service.provider_manager.get_provider_available_credentials( | |||||
| tenant_id=tenant_id, provider_name=provider | |||||
| ) | |||||
| else: | |||||
| model_type = ModelType.value_of(args["model_type"]).to_origin_model_type() | |||||
| available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials( | |||||
| tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args["model"] | |||||
| ) | |||||
| return jsonable_encoder( | |||||
| { | |||||
| "credentials": current_credential.get("credentials") if current_credential else {}, | |||||
| "current_credential_id": current_credential.get("current_credential_id") | |||||
| if current_credential | |||||
| else None, | |||||
| "current_credential_name": current_credential.get("current_credential_name") | |||||
| if current_credential | |||||
| else None, | |||||
| "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, | |||||
| "available_credentials": available_credentials, | |||||
| } | |||||
| ) | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def post(self, provider: str): | |||||
| if not current_user.is_admin_or_owner: | |||||
| raise Forbidden() | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||||
| parser.add_argument( | |||||
| "model_type", | |||||
| type=str, | |||||
| required=True, | |||||
| nullable=False, | |||||
| choices=[mt.value for mt in ModelType], | |||||
| location="json", | |||||
| ) | |||||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||||
| args = parser.parse_args() | |||||
| tenant_id = current_user.current_tenant_id | |||||
| model_provider_service = ModelProviderService() | |||||
| try: | |||||
| model_provider_service.create_model_credential( | |||||
| tenant_id=tenant_id, | |||||
| provider=provider, | |||||
| model=args["model"], | |||||
| model_type=args["model_type"], | |||||
| credentials=args["credentials"], | |||||
| credential_name=args["name"], | |||||
| ) | |||||
| except CredentialsValidateFailedError as ex: | |||||
| logging.exception( | |||||
| "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s", | |||||
| tenant_id, | |||||
| args.get("model"), | |||||
| args.get("model_type"), | |||||
| ) | |||||
| raise ValueError(str(ex)) | |||||
| return {"result": "success"}, 201 | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def put(self, provider: str): | |||||
| if not current_user.is_admin_or_owner: | |||||
| raise Forbidden() | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||||
| parser.add_argument( | |||||
| "model_type", | |||||
| type=str, | |||||
| required=True, | |||||
| nullable=False, | |||||
| choices=[mt.value for mt in ModelType], | |||||
| location="json", | |||||
| ) | |||||
| parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") | |||||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||||
| args = parser.parse_args() | |||||
| model_provider_service = ModelProviderService() | |||||
| try: | |||||
| model_provider_service.update_model_credential( | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| provider=provider, | |||||
| model_type=args["model_type"], | |||||
| model=args["model"], | |||||
| credentials=args["credentials"], | |||||
| credential_id=args["credential_id"], | |||||
| credential_name=args["name"], | |||||
| ) | |||||
| except CredentialsValidateFailedError as ex: | |||||
| raise ValueError(str(ex)) | |||||
| return {"result": "success"} | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def delete(self, provider: str): | |||||
| if not current_user.is_admin_or_owner: | |||||
| raise Forbidden() | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||||
| parser.add_argument( | |||||
| "model_type", | |||||
| type=str, | |||||
| required=True, | |||||
| nullable=False, | |||||
| choices=[mt.value for mt in ModelType], | |||||
| location="json", | |||||
| ) | |||||
| parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") | |||||
| args = parser.parse_args() | |||||
| model_provider_service = ModelProviderService() | |||||
| model_provider_service.remove_model_credential( | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| provider=provider, | |||||
| model_type=args["model_type"], | |||||
| model=args["model"], | |||||
| credential_id=args["credential_id"], | |||||
| ) | |||||
| return {"result": "success"}, 204 | |||||
| class ModelProviderModelCredentialSwitchApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def post(self, provider: str): | |||||
| if not current_user.is_admin_or_owner: | |||||
| raise Forbidden() | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||||
| parser.add_argument( | |||||
| "model_type", | |||||
| type=str, | |||||
| required=True, | |||||
| nullable=False, | |||||
| choices=[mt.value for mt in ModelType], | |||||
| location="json", | |||||
| ) | |||||
| parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") | |||||
| args = parser.parse_args() | |||||
| service = ModelProviderService() | |||||
| service.add_model_credential_to_model_list( | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| provider=provider, | |||||
| model_type=args["model_type"], | |||||
| model=args["model"], | |||||
| credential_id=args["credential_id"], | |||||
| ) | |||||
| return {"result": "success"} | |||||
| class ModelProviderModelEnableApi(Resource): | class ModelProviderModelEnableApi(Resource): | ||||
| error = "" | error = "" | ||||
| try: | try: | ||||
| model_provider_service.model_credentials_validate( | |||||
| model_provider_service.validate_model_credentials( | |||||
| tenant_id=tenant_id, | tenant_id=tenant_id, | ||||
| provider=provider, | provider=provider, | ||||
| model=args["model"], | model=args["model"], | ||||
| api.add_resource( | api.add_resource( | ||||
| ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials" | ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials" | ||||
| ) | ) | ||||
| api.add_resource( | |||||
| ModelProviderModelCredentialSwitchApi, | |||||
| "/workspaces/current/model-providers/<path:provider>/models/credentials/switch", | |||||
| ) | |||||
| api.add_resource( | api.add_resource( | ||||
| ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate" | ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate" | ||||
| ) | ) |
| QUOTA_EXCEEDED = "quota-exceeded" | QUOTA_EXCEEDED = "quota-exceeded" | ||||
| NO_PERMISSION = "no-permission" | NO_PERMISSION = "no-permission" | ||||
| DISABLED = "disabled" | DISABLED = "disabled" | ||||
| CREDENTIAL_REMOVED = "credential-removed" | |||||
| class SimpleModelProviderEntity(BaseModel): | class SimpleModelProviderEntity(BaseModel): | ||||
| status: ModelStatus | status: ModelStatus | ||||
| load_balancing_enabled: bool = False | load_balancing_enabled: bool = False | ||||
| has_invalid_load_balancing_configs: bool = False | |||||
| def raise_for_status(self) -> None: | def raise_for_status(self) -> None: | ||||
| """ | """ |
| restrict_models: list[RestrictModel] = [] | restrict_models: list[RestrictModel] = [] | ||||
| class CredentialConfiguration(BaseModel): | |||||
| """ | |||||
| Model class for credential configuration. | |||||
| """ | |||||
| credential_id: str | |||||
| credential_name: str | |||||
| class SystemConfiguration(BaseModel): | class SystemConfiguration(BaseModel): | ||||
| """ | """ | ||||
| Model class for provider system configuration. | Model class for provider system configuration. | ||||
| """ | """ | ||||
| credentials: dict | credentials: dict | ||||
| current_credential_id: Optional[str] = None | |||||
| current_credential_name: Optional[str] = None | |||||
| available_credentials: list[CredentialConfiguration] = [] | |||||
| class CustomModelConfiguration(BaseModel): | class CustomModelConfiguration(BaseModel): | ||||
| model: str | model: str | ||||
| model_type: ModelType | model_type: ModelType | ||||
| credentials: dict | |||||
| credentials: dict | None | |||||
| current_credential_id: Optional[str] = None | |||||
| current_credential_name: Optional[str] = None | |||||
| available_model_credentials: list[CredentialConfiguration] = [] | |||||
| # pydantic configs | # pydantic configs | ||||
| model_config = ConfigDict(protected_namespaces=()) | model_config = ConfigDict(protected_namespaces=()) | ||||
| id: str | id: str | ||||
| name: str | name: str | ||||
| credentials: dict | credentials: dict | ||||
| credential_source_type: str | None = None | |||||
| class ModelSettings(BaseModel): | class ModelSettings(BaseModel): |
| return filtered_credentials | return filtered_credentials | ||||
| def get_model_schema( | def get_model_schema( | ||||
| self, *, provider: str, model_type: ModelType, model: str, credentials: dict | |||||
| self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None | |||||
| ) -> AIModelEntity | None: | ) -> AIModelEntity | None: | ||||
| """ | """ | ||||
| Get model schema | Get model schema |
| from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity | from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity | ||||
| from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle | from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle | ||||
| from core.entities.provider_entities import ( | from core.entities.provider_entities import ( | ||||
| CredentialConfiguration, | |||||
| CustomConfiguration, | CustomConfiguration, | ||||
| CustomModelConfiguration, | CustomModelConfiguration, | ||||
| CustomProviderConfiguration, | CustomProviderConfiguration, | ||||
| from models.provider import ( | from models.provider import ( | ||||
| LoadBalancingModelConfig, | LoadBalancingModelConfig, | ||||
| Provider, | Provider, | ||||
| ProviderCredential, | |||||
| ProviderModel, | ProviderModel, | ||||
| ProviderModelCredential, | |||||
| ProviderModelSetting, | ProviderModelSetting, | ||||
| ProviderType, | ProviderType, | ||||
| TenantDefaultModel, | TenantDefaultModel, | ||||
| return provider_name_to_provider_load_balancing_model_configs_dict | return provider_name_to_provider_load_balancing_model_configs_dict | ||||
| @staticmethod | |||||
| def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]: | |||||
| """ | |||||
| Get provider all credentials. | |||||
| :param tenant_id: workspace id | |||||
| :param provider_name: provider name | |||||
| :return: | |||||
| """ | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| stmt = ( | |||||
| select(ProviderCredential) | |||||
| .where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name) | |||||
| .order_by(ProviderCredential.created_at.desc()) | |||||
| ) | |||||
| available_credentials = session.scalars(stmt).all() | |||||
| return [ | |||||
| CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name) | |||||
| for credential in available_credentials | |||||
| ] | |||||
| @staticmethod | |||||
| def get_provider_model_available_credentials( | |||||
| tenant_id: str, provider_name: str, model_name: str, model_type: str | |||||
| ) -> list[CredentialConfiguration]: | |||||
| """ | |||||
| Get provider custom model all credentials. | |||||
| :param tenant_id: workspace id | |||||
| :param provider_name: provider name | |||||
| :param model_name: model name | |||||
| :param model_type: model type | |||||
| :return: | |||||
| """ | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| stmt = ( | |||||
| select(ProviderModelCredential) | |||||
| .where( | |||||
| ProviderModelCredential.tenant_id == tenant_id, | |||||
| ProviderModelCredential.provider_name == provider_name, | |||||
| ProviderModelCredential.model_name == model_name, | |||||
| ProviderModelCredential.model_type == model_type, | |||||
| ) | |||||
| .order_by(ProviderModelCredential.created_at.desc()) | |||||
| ) | |||||
| available_credentials = session.scalars(stmt).all() | |||||
| return [ | |||||
| CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name) | |||||
| for credential in available_credentials | |||||
| ] | |||||
| @staticmethod | @staticmethod | ||||
| def _init_trial_provider_records( | def _init_trial_provider_records( | ||||
| tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]] | tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]] | ||||
| if provider_record.provider_type == ProviderType.SYSTEM.value: | if provider_record.provider_type == ProviderType.SYSTEM.value: | ||||
| continue | continue | ||||
| if not provider_record.encrypted_config: | |||||
| continue | |||||
| custom_provider_record = provider_record | custom_provider_record = provider_record | ||||
| # Get custom provider credentials | # Get custom provider credentials | ||||
| try: | try: | ||||
| # fix origin data | # fix origin data | ||||
| if custom_provider_record.encrypted_config is None: | if custom_provider_record.encrypted_config is None: | ||||
| raise ValueError("No credentials found") | |||||
| if not custom_provider_record.encrypted_config.startswith("{"): | |||||
| provider_credentials = {} | |||||
| elif not custom_provider_record.encrypted_config.startswith("{"): | |||||
| provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} | provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} | ||||
| else: | else: | ||||
| provider_credentials = json.loads(custom_provider_record.encrypted_config) | provider_credentials = json.loads(custom_provider_record.encrypted_config) | ||||
| else: | else: | ||||
| provider_credentials = cached_provider_credentials | provider_credentials = cached_provider_credentials | ||||
| custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials) | |||||
| custom_provider_configuration = CustomProviderConfiguration( | |||||
| credentials=provider_credentials, | |||||
| current_credential_name=custom_provider_record.credential_name, | |||||
| current_credential_id=custom_provider_record.credential_id, | |||||
| available_credentials=self.get_provider_available_credentials( | |||||
| tenant_id, custom_provider_record.provider_name | |||||
| ), | |||||
| ) | |||||
| # Get provider model credential secret variables | # Get provider model credential secret variables | ||||
| model_credential_secret_variables = self._extract_secret_variables( | model_credential_secret_variables = self._extract_secret_variables( | ||||
| # Get custom provider model credentials | # Get custom provider model credentials | ||||
| custom_model_configurations = [] | custom_model_configurations = [] | ||||
| for provider_model_record in provider_model_records: | for provider_model_record in provider_model_records: | ||||
| if not provider_model_record.encrypted_config: | |||||
| continue | |||||
| available_model_credentials = self.get_provider_model_available_credentials( | |||||
| tenant_id, | |||||
| provider_model_record.provider_name, | |||||
| provider_model_record.model_name, | |||||
| provider_model_record.model_type, | |||||
| ) | |||||
| provider_model_credentials_cache = ProviderCredentialsCache( | provider_model_credentials_cache = ProviderCredentialsCache( | ||||
| tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL | tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL | ||||
| # Get cached provider model credentials | # Get cached provider model credentials | ||||
| cached_provider_model_credentials = provider_model_credentials_cache.get() | cached_provider_model_credentials = provider_model_credentials_cache.get() | ||||
| if not cached_provider_model_credentials: | |||||
| if not cached_provider_model_credentials and provider_model_record.encrypted_config: | |||||
| try: | try: | ||||
| provider_model_credentials = json.loads(provider_model_record.encrypted_config) | provider_model_credentials = json.loads(provider_model_record.encrypted_config) | ||||
| except JSONDecodeError: | except JSONDecodeError: | ||||
| model=provider_model_record.model_name, | model=provider_model_record.model_name, | ||||
| model_type=ModelType.value_of(provider_model_record.model_type), | model_type=ModelType.value_of(provider_model_record.model_type), | ||||
| credentials=provider_model_credentials, | credentials=provider_model_credentials, | ||||
| current_credential_id=provider_model_record.credential_id, | |||||
| current_credential_name=provider_model_record.credential_name, | |||||
| available_model_credentials=available_model_credentials, | |||||
| ) | ) | ||||
| ) | ) | ||||
| load_balancing_model_config.model_name == provider_model_setting.model_name | load_balancing_model_config.model_name == provider_model_setting.model_name | ||||
| and load_balancing_model_config.model_type == provider_model_setting.model_type | and load_balancing_model_config.model_type == provider_model_setting.model_type | ||||
| ): | ): | ||||
| if load_balancing_model_config.name == "__delete__": | |||||
| # to calculate current model whether has invalidate lb configs | |||||
| load_balancing_configs.append( | |||||
| ModelLoadBalancingConfiguration( | |||||
| id=load_balancing_model_config.id, | |||||
| name=load_balancing_model_config.name, | |||||
| credentials={}, | |||||
| credential_source_type=load_balancing_model_config.credential_source_type, | |||||
| ) | |||||
| ) | |||||
| continue | |||||
| if not load_balancing_model_config.enabled: | if not load_balancing_model_config.enabled: | ||||
| continue | continue | ||||
| id=load_balancing_model_config.id, | id=load_balancing_model_config.id, | ||||
| name=load_balancing_model_config.name, | name=load_balancing_model_config.name, | ||||
| credentials=provider_model_credentials, | credentials=provider_model_credentials, | ||||
| credential_source_type=load_balancing_model_config.credential_source_type, | |||||
| ) | ) | ||||
| ) | ) | ||||
| """Add provider multi credential support | |||||
| Revision ID: e8446f481c1e | |||||
| Revises: 8bcc02c9bd07 | |||||
| Create Date: 2025-08-09 15:53:54.341341 | |||||
| """ | |||||
| from alembic import op | |||||
| import models as models | |||||
| import sqlalchemy as sa | |||||
| from sqlalchemy.sql import table, column | |||||
| import uuid | |||||
| # revision identifiers, used by Alembic. | |||||
| revision = 'e8446f481c1e' | |||||
| down_revision = 'fa8b0fa6f407' | |||||
| branch_labels = None | |||||
| depends_on = None | |||||
| def upgrade(): | |||||
| # Create provider_credentials table | |||||
| op.create_table('provider_credentials', | |||||
| sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||||
| sa.Column('tenant_id', models.types.StringUUID(), nullable=False), | |||||
| sa.Column('provider_name', sa.String(length=255), nullable=False), | |||||
| sa.Column('credential_name', sa.String(length=255), nullable=False), | |||||
| sa.Column('encrypted_config', sa.Text(), nullable=False), | |||||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), | |||||
| sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), | |||||
| sa.PrimaryKeyConstraint('id', name='provider_credential_pkey') | |||||
| ) | |||||
| # Create index for provider_credentials | |||||
| with op.batch_alter_table('provider_credentials', schema=None) as batch_op: | |||||
| batch_op.create_index('provider_credential_tenant_provider_idx', ['tenant_id', 'provider_name'], unique=False) | |||||
| # Add credential_id to providers table | |||||
| with op.batch_alter_table('providers', schema=None) as batch_op: | |||||
| batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True)) | |||||
| # Add credential_id to load_balancing_model_configs table | |||||
| with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: | |||||
| batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True)) | |||||
| migrate_existing_providers_data() | |||||
| # Remove encrypted_config column from providers table after migration | |||||
| with op.batch_alter_table('providers', schema=None) as batch_op: | |||||
| batch_op.drop_column('encrypted_config') | |||||
| def migrate_existing_providers_data(): | |||||
| """migrate providers table data to provider_credentials""" | |||||
| # Define table structure for data manipulation | |||||
| providers_table = table('providers', | |||||
| column('id', models.types.StringUUID()), | |||||
| column('tenant_id', models.types.StringUUID()), | |||||
| column('provider_name', sa.String()), | |||||
| column('encrypted_config', sa.Text()), | |||||
| column('created_at', sa.DateTime()), | |||||
| column('updated_at', sa.DateTime()), | |||||
| column('credential_id', models.types.StringUUID()), | |||||
| ) | |||||
| provider_credential_table = table('provider_credentials', | |||||
| column('id', models.types.StringUUID()), | |||||
| column('tenant_id', models.types.StringUUID()), | |||||
| column('provider_name', sa.String()), | |||||
| column('credential_name', sa.String()), | |||||
| column('encrypted_config', sa.Text()), | |||||
| column('created_at', sa.DateTime()), | |||||
| column('updated_at', sa.DateTime()) | |||||
| ) | |||||
| # Get database connection | |||||
| conn = op.get_bind() | |||||
| # Query all existing providers data | |||||
| existing_providers = conn.execute( | |||||
| sa.select(providers_table.c.id, providers_table.c.tenant_id, | |||||
| providers_table.c.provider_name, providers_table.c.encrypted_config, | |||||
| providers_table.c.created_at, providers_table.c.updated_at) | |||||
| .where(providers_table.c.encrypted_config.isnot(None)) | |||||
| ).fetchall() | |||||
| # Iterate through each provider and insert into provider_credentials | |||||
| for provider in existing_providers: | |||||
| credential_id = str(uuid.uuid4()) | |||||
| if not provider.encrypted_config or provider.encrypted_config.strip() == '': | |||||
| continue | |||||
| # Insert into provider_credentials table | |||||
| conn.execute( | |||||
| provider_credential_table.insert().values( | |||||
| id=credential_id, | |||||
| tenant_id=provider.tenant_id, | |||||
| provider_name=provider.provider_name, | |||||
| credential_name='API_KEY1', # Use a default name | |||||
| encrypted_config=provider.encrypted_config, | |||||
| created_at=provider.created_at, | |||||
| updated_at=provider.updated_at | |||||
| ) | |||||
| ) | |||||
| # Update original providers table, set credential_id | |||||
| conn.execute( | |||||
| providers_table.update() | |||||
| .where(providers_table.c.id == provider.id) | |||||
| .values( | |||||
| credential_id=credential_id, | |||||
| ) | |||||
| ) | |||||
| def downgrade(): | |||||
| # Re-add encrypted_config column to providers table | |||||
| with op.batch_alter_table('providers', schema=None) as batch_op: | |||||
| batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True)) | |||||
| # Migrate data back from provider_credentials to providers | |||||
| migrate_data_back_to_providers() | |||||
| # Remove credential_id columns | |||||
| with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: | |||||
| batch_op.drop_column('credential_id') | |||||
| with op.batch_alter_table('providers', schema=None) as batch_op: | |||||
| batch_op.drop_column('credential_id') | |||||
| # Drop provider_credentials table | |||||
| op.drop_table('provider_credentials') | |||||
| def migrate_data_back_to_providers(): | |||||
| """Migrate data back from provider_credentials to providers table for downgrade""" | |||||
| # Define table structure for data manipulation | |||||
| providers_table = table('providers', | |||||
| column('id', models.types.StringUUID()), | |||||
| column('tenant_id', models.types.StringUUID()), | |||||
| column('provider_name', sa.String()), | |||||
| column('encrypted_config', sa.Text()), | |||||
| column('credential_id', models.types.StringUUID()), | |||||
| ) | |||||
| provider_credential_table = table('provider_credentials', | |||||
| column('id', models.types.StringUUID()), | |||||
| column('tenant_id', models.types.StringUUID()), | |||||
| column('provider_name', sa.String()), | |||||
| column('credential_name', sa.String()), | |||||
| column('encrypted_config', sa.Text()), | |||||
| ) | |||||
| # Get database connection | |||||
| conn = op.get_bind() | |||||
| # Query providers that have credential_id | |||||
| providers_with_credentials = conn.execute( | |||||
| sa.select(providers_table.c.id, providers_table.c.credential_id) | |||||
| .where(providers_table.c.credential_id.isnot(None)) | |||||
| ).fetchall() | |||||
| # For each provider, get the credential data and update providers table | |||||
| for provider in providers_with_credentials: | |||||
| credential = conn.execute( | |||||
| sa.select(provider_credential_table.c.encrypted_config) | |||||
| .where(provider_credential_table.c.id == provider.credential_id) | |||||
| ).fetchone() | |||||
| if credential: | |||||
| # Update providers table with encrypted_config from credential | |||||
| conn.execute( | |||||
| providers_table.update() | |||||
| .where(providers_table.c.id == provider.id) | |||||
| .values(encrypted_config=credential.encrypted_config) | |||||
| ) |
| """Add provider model multi credential support | |||||
| Revision ID: 0e154742a5fa | |||||
| Revises: e8446f481c1e | |||||
| Create Date: 2025-08-13 16:05:42.657730 | |||||
| """ | |||||
| import uuid | |||||
| from alembic import op | |||||
| import models as models | |||||
| import sqlalchemy as sa | |||||
| from sqlalchemy.sql import table, column | |||||
| # revision identifiers, used by Alembic. | |||||
| revision = '0e154742a5fa' | |||||
| down_revision = 'e8446f481c1e' | |||||
| branch_labels = None | |||||
| depends_on = None | |||||
| def upgrade(): | |||||
| # Create provider_model_credentials table | |||||
| op.create_table('provider_model_credentials', | |||||
| sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||||
| sa.Column('tenant_id', models.types.StringUUID(), nullable=False), | |||||
| sa.Column('provider_name', sa.String(length=255), nullable=False), | |||||
| sa.Column('model_name', sa.String(length=255), nullable=False), | |||||
| sa.Column('model_type', sa.String(length=40), nullable=False), | |||||
| sa.Column('credential_name', sa.String(length=255), nullable=False), | |||||
| sa.Column('encrypted_config', sa.Text(), nullable=False), | |||||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), | |||||
| sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), | |||||
| sa.PrimaryKeyConstraint('id', name='provider_model_credential_pkey') | |||||
| ) | |||||
| # Create index for provider_model_credentials | |||||
| with op.batch_alter_table('provider_model_credentials', schema=None) as batch_op: | |||||
| batch_op.create_index('provider_model_credential_tenant_provider_model_idx', ['tenant_id', 'provider_name', 'model_name', 'model_type'], unique=False) | |||||
| # Add credential_id to provider_models table | |||||
| with op.batch_alter_table('provider_models', schema=None) as batch_op: | |||||
| batch_op.add_column(sa.Column('credential_id', models.types.StringUUID(), nullable=True)) | |||||
| # Add credential_source_type to load_balancing_model_configs table | |||||
| with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: | |||||
| batch_op.add_column(sa.Column('credential_source_type', sa.String(length=40), nullable=True)) | |||||
| # Migrate existing provider_models data | |||||
| migrate_existing_provider_models_data() | |||||
| # Remove encrypted_config column from provider_models table after migration | |||||
| with op.batch_alter_table('provider_models', schema=None) as batch_op: | |||||
| batch_op.drop_column('encrypted_config') | |||||
| def migrate_existing_provider_models_data(): | |||||
| """migrate provider_models table data to provider_model_credentials""" | |||||
| # Define table structure for data manipulation | |||||
| provider_models_table = table('provider_models', | |||||
| column('id', models.types.StringUUID()), | |||||
| column('tenant_id', models.types.StringUUID()), | |||||
| column('provider_name', sa.String()), | |||||
| column('model_name', sa.String()), | |||||
| column('model_type', sa.String()), | |||||
| column('encrypted_config', sa.Text()), | |||||
| column('created_at', sa.DateTime()), | |||||
| column('updated_at', sa.DateTime()), | |||||
| column('credential_id', models.types.StringUUID()), | |||||
| ) | |||||
| provider_model_credentials_table = table('provider_model_credentials', | |||||
| column('id', models.types.StringUUID()), | |||||
| column('tenant_id', models.types.StringUUID()), | |||||
| column('provider_name', sa.String()), | |||||
| column('model_name', sa.String()), | |||||
| column('model_type', sa.String()), | |||||
| column('credential_name', sa.String()), | |||||
| column('encrypted_config', sa.Text()), | |||||
| column('created_at', sa.DateTime()), | |||||
| column('updated_at', sa.DateTime()) | |||||
| ) | |||||
| # Get database connection | |||||
| conn = op.get_bind() | |||||
| # Query all existing provider_models data with encrypted_config | |||||
| existing_provider_models = conn.execute( | |||||
| sa.select(provider_models_table.c.id, provider_models_table.c.tenant_id, | |||||
| provider_models_table.c.provider_name, provider_models_table.c.model_name, | |||||
| provider_models_table.c.model_type, provider_models_table.c.encrypted_config, | |||||
| provider_models_table.c.created_at, provider_models_table.c.updated_at) | |||||
| .where(provider_models_table.c.encrypted_config.isnot(None)) | |||||
| ).fetchall() | |||||
| # Iterate through each provider_model and insert into provider_model_credentials | |||||
| for provider_model in existing_provider_models: | |||||
| if not provider_model.encrypted_config or provider_model.encrypted_config.strip() == '': | |||||
| continue | |||||
| credential_id = str(uuid.uuid4()) | |||||
| # Insert into provider_model_credentials table | |||||
| conn.execute( | |||||
| provider_model_credentials_table.insert().values( | |||||
| id=credential_id, | |||||
| tenant_id=provider_model.tenant_id, | |||||
| provider_name=provider_model.provider_name, | |||||
| model_name=provider_model.model_name, | |||||
| model_type=provider_model.model_type, | |||||
| credential_name='API_KEY1', # Use a default name | |||||
| encrypted_config=provider_model.encrypted_config, | |||||
| created_at=provider_model.created_at, | |||||
| updated_at=provider_model.updated_at | |||||
| ) | |||||
| ) | |||||
| # Update original provider_models table, set credential_id | |||||
| conn.execute( | |||||
| provider_models_table.update() | |||||
| .where(provider_models_table.c.id == provider_model.id) | |||||
| .values(credential_id=credential_id) | |||||
| ) | |||||
| def downgrade(): | |||||
| # Re-add encrypted_config column to provider_models table | |||||
| with op.batch_alter_table('provider_models', schema=None) as batch_op: | |||||
| batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True)) | |||||
| # Migrate data back from provider_model_credentials to provider_models | |||||
| migrate_data_back_to_provider_models() | |||||
| with op.batch_alter_table('provider_models', schema=None) as batch_op: | |||||
| batch_op.drop_column('credential_id') | |||||
| # Remove credential_source_type column from load_balancing_model_configs | |||||
| with op.batch_alter_table('load_balancing_model_configs', schema=None) as batch_op: | |||||
| batch_op.drop_column('credential_source_type') | |||||
| # Drop provider_model_credentials table | |||||
| op.drop_table('provider_model_credentials') | |||||
| def migrate_data_back_to_provider_models(): | |||||
| """Migrate data back from provider_model_credentials to provider_models table for downgrade""" | |||||
| # Define table structure for data manipulation | |||||
| provider_models_table = table('provider_models', | |||||
| column('id', models.types.StringUUID()), | |||||
| column('encrypted_config', sa.Text()), | |||||
| column('credential_id', models.types.StringUUID()), | |||||
| ) | |||||
| provider_model_credentials_table = table('provider_model_credentials', | |||||
| column('id', models.types.StringUUID()), | |||||
| column('encrypted_config', sa.Text()), | |||||
| ) | |||||
| # Get database connection | |||||
| conn = op.get_bind() | |||||
| # Query provider_models that have credential_id | |||||
| provider_models_with_credentials = conn.execute( | |||||
| sa.select(provider_models_table.c.id, provider_models_table.c.credential_id) | |||||
| .where(provider_models_table.c.credential_id.isnot(None)) | |||||
| ).fetchall() | |||||
| # For each provider_model, get the credential data and update provider_models table | |||||
| for provider_model in provider_models_with_credentials: | |||||
| credential = conn.execute( | |||||
| sa.select(provider_model_credentials_table.c.encrypted_config) | |||||
| .where(provider_model_credentials_table.c.id == provider_model.credential_id) | |||||
| ).fetchone() | |||||
| if credential: | |||||
| # Update provider_models table with encrypted_config from credential | |||||
| conn.execute( | |||||
| provider_models_table.update() | |||||
| .where(provider_models_table.c.id == provider_model.id) | |||||
| .values(encrypted_config=credential.encrypted_config) | |||||
| ) |
| from datetime import datetime | from datetime import datetime | ||||
| from enum import Enum | from enum import Enum | ||||
| from functools import cached_property | |||||
| from typing import Optional | from typing import Optional | ||||
| import sqlalchemy as sa | import sqlalchemy as sa | ||||
| from sqlalchemy.orm import Mapped, mapped_column | from sqlalchemy.orm import Mapped, mapped_column | ||||
| from .base import Base | from .base import Base | ||||
| from .engine import db | |||||
| from .types import StringUUID | from .types import StringUUID | ||||
| provider_type: Mapped[str] = mapped_column( | provider_type: Mapped[str] = mapped_column( | ||||
| String(40), nullable=False, server_default=text("'custom'::character varying") | String(40), nullable=False, server_default=text("'custom'::character varying") | ||||
| ) | ) | ||||
| encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) | |||||
| is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) | is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) | ||||
| last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) | last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) | ||||
| credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) | |||||
| quota_type: Mapped[Optional[str]] = mapped_column( | quota_type: Mapped[Optional[str]] = mapped_column( | ||||
| String(40), nullable=True, server_default=text("''::character varying") | String(40), nullable=True, server_default=text("''::character varying") | ||||
| f" provider_type='{self.provider_type}')>" | f" provider_type='{self.provider_type}')>" | ||||
| ) | ) | ||||
| @cached_property | |||||
| def credential(self): | |||||
| if self.credential_id: | |||||
| return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first() | |||||
| @property | |||||
| def credential_name(self): | |||||
| credential = self.credential | |||||
| return credential.credential_name if credential else None | |||||
| @property | |||||
| def encrypted_config(self): | |||||
| credential = self.credential | |||||
| return credential.encrypted_config if credential else None | |||||
| @property | @property | ||||
| def token_is_set(self): | def token_is_set(self): | ||||
| """ | """ | ||||
| provider_name: Mapped[str] = mapped_column(String(255), nullable=False) | provider_name: Mapped[str] = mapped_column(String(255), nullable=False) | ||||
| model_name: Mapped[str] = mapped_column(String(255), nullable=False) | model_name: Mapped[str] = mapped_column(String(255), nullable=False) | ||||
| model_type: Mapped[str] = mapped_column(String(40), nullable=False) | model_type: Mapped[str] = mapped_column(String(40), nullable=False) | ||||
| encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) | |||||
| credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) | |||||
| is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) | is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) | ||||
| created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| @cached_property | |||||
| def credential(self): | |||||
| if self.credential_id: | |||||
| return ( | |||||
| db.session.query(ProviderModelCredential) | |||||
| .where(ProviderModelCredential.id == self.credential_id) | |||||
| .first() | |||||
| ) | |||||
| @property | |||||
| def credential_name(self): | |||||
| credential = self.credential | |||||
| return credential.credential_name if credential else None | |||||
| @property | |||||
| def encrypted_config(self): | |||||
| credential = self.credential | |||||
| return credential.encrypted_config if credential else None | |||||
| class TenantDefaultModel(Base): | class TenantDefaultModel(Base): | ||||
| __tablename__ = "tenant_default_models" | __tablename__ = "tenant_default_models" | ||||
| model_type: Mapped[str] = mapped_column(String(40), nullable=False) | model_type: Mapped[str] = mapped_column(String(40), nullable=False) | ||||
| name: Mapped[str] = mapped_column(String(255), nullable=False) | name: Mapped[str] = mapped_column(String(255), nullable=False) | ||||
| encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) | encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) | ||||
| credential_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) | |||||
| credential_source_type: Mapped[Optional[str]] = mapped_column(String(40), nullable=True) | |||||
| enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true")) | enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true")) | ||||
| created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| class ProviderCredential(Base): | |||||
| """ | |||||
| Provider credential - stores multiple named credentials for each provider | |||||
| """ | |||||
| __tablename__ = "provider_credentials" | |||||
| __table_args__ = ( | |||||
| sa.PrimaryKeyConstraint("id", name="provider_credential_pkey"), | |||||
| sa.Index("provider_credential_tenant_provider_idx", "tenant_id", "provider_name"), | |||||
| ) | |||||
| id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) | |||||
| tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||||
| provider_name: Mapped[str] = mapped_column(String(255), nullable=False) | |||||
| credential_name: Mapped[str] = mapped_column(String(255), nullable=False) | |||||
| encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False) | |||||
| created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| class ProviderModelCredential(Base): | |||||
| """ | |||||
| Provider model credential - stores multiple named credentials for each provider model | |||||
| """ | |||||
| __tablename__ = "provider_model_credentials" | |||||
| __table_args__ = ( | |||||
| sa.PrimaryKeyConstraint("id", name="provider_model_credential_pkey"), | |||||
| sa.Index( | |||||
| "provider_model_credential_tenant_provider_model_idx", | |||||
| "tenant_id", | |||||
| "provider_name", | |||||
| "model_name", | |||||
| "model_type", | |||||
| ), | |||||
| ) | |||||
| id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) | |||||
| tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||||
| provider_name: Mapped[str] = mapped_column(String(255), nullable=False) | |||||
| model_name: Mapped[str] = mapped_column(String(255), nullable=False) | |||||
| model_type: Mapped[str] = mapped_column(String(40), nullable=False) | |||||
| credential_name: Mapped[str] = mapped_column(String(255), nullable=False) | |||||
| encrypted_config: Mapped[str] = mapped_column(sa.Text, nullable=False) | |||||
| created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) |
| ModelWithProviderEntity, | ModelWithProviderEntity, | ||||
| ProviderModelWithStatusEntity, | ProviderModelWithStatusEntity, | ||||
| ) | ) | ||||
| from core.entities.provider_entities import ProviderQuotaType, QuotaConfiguration | |||||
| from core.entities.provider_entities import ( | |||||
| CredentialConfiguration, | |||||
| CustomModelConfiguration, | |||||
| ProviderQuotaType, | |||||
| QuotaConfiguration, | |||||
| ) | |||||
| from core.model_runtime.entities.common_entities import I18nObject | from core.model_runtime.entities.common_entities import I18nObject | ||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from core.model_runtime.entities.provider_entities import ( | from core.model_runtime.entities.provider_entities import ( | ||||
| """ | """ | ||||
| status: CustomConfigurationStatus | status: CustomConfigurationStatus | ||||
| current_credential_id: Optional[str] = None | |||||
| current_credential_name: Optional[str] = None | |||||
| available_credentials: Optional[list[CredentialConfiguration]] = None | |||||
| custom_models: Optional[list[CustomModelConfiguration]] = None | |||||
| class SystemConfigurationResponse(BaseModel): | class SystemConfigurationResponse(BaseModel): |
| class AppModelConfigBrokenError(BaseServiceError): | class AppModelConfigBrokenError(BaseServiceError): | ||||
| pass | pass | ||||
| class ProviderNotFoundError(BaseServiceError): | |||||
| pass |
| from core.provider_manager import ProviderManager | from core.provider_manager import ProviderManager | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from libs.datetime_utils import naive_utc_now | from libs.datetime_utils import naive_utc_now | ||||
| from models.provider import LoadBalancingModelConfig | |||||
| from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| "id": load_balancing_config.id, | "id": load_balancing_config.id, | ||||
| "name": load_balancing_config.name, | "name": load_balancing_config.name, | ||||
| "credentials": credentials, | "credentials": credentials, | ||||
| "credential_id": load_balancing_config.credential_id, | |||||
| "enabled": load_balancing_config.enabled, | "enabled": load_balancing_config.enabled, | ||||
| "in_cooldown": in_cooldown, | "in_cooldown": in_cooldown, | ||||
| "ttl": ttl, | "ttl": ttl, | ||||
| return inherit_config | return inherit_config | ||||
| def update_load_balancing_configs( | def update_load_balancing_configs( | ||||
| self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict] | |||||
| self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict], config_from: str | |||||
| ) -> None: | ) -> None: | ||||
| """ | """ | ||||
| Update load balancing configurations. | Update load balancing configurations. | ||||
| :param model: model name | :param model: model name | ||||
| :param model_type: model type | :param model_type: model type | ||||
| :param configs: load balancing configs | :param configs: load balancing configs | ||||
| :param config_from: predefined-model or custom-model | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| # Get all provider configurations of the current workspace | # Get all provider configurations of the current workspace | ||||
| config_id = config.get("id") | config_id = config.get("id") | ||||
| name = config.get("name") | name = config.get("name") | ||||
| credentials = config.get("credentials") | credentials = config.get("credentials") | ||||
| credential_id = config.get("credential_id") | |||||
| enabled = config.get("enabled") | enabled = config.get("enabled") | ||||
| if credential_id: | |||||
| credential_record: ProviderCredential | ProviderModelCredential | None = None | |||||
| if config_from == "predefined-model": | |||||
| credential_record = ( | |||||
| db.session.query(ProviderCredential) | |||||
| .filter_by( | |||||
| id=credential_id, | |||||
| tenant_id=tenant_id, | |||||
| provider_name=provider_configuration.provider.provider, | |||||
| ) | |||||
| .first() | |||||
| ) | |||||
| else: | |||||
| credential_record = ( | |||||
| db.session.query(ProviderModelCredential) | |||||
| .filter_by( | |||||
| id=credential_id, | |||||
| tenant_id=tenant_id, | |||||
| provider_name=provider_configuration.provider.provider, | |||||
| model_name=model, | |||||
| model_type=model_type_enum.to_origin_model_type(), | |||||
| ) | |||||
| .first() | |||||
| ) | |||||
| if not credential_record: | |||||
| raise ValueError(f"Provider credential with id {credential_id} not found") | |||||
| name = credential_record.credential_name | |||||
| if not name: | if not name: | ||||
| raise ValueError("Invalid load balancing config name") | raise ValueError("Invalid load balancing config name") | ||||
| load_balancing_config = current_load_balancing_configs_dict[config_id] | load_balancing_config = current_load_balancing_configs_dict[config_id] | ||||
| # check duplicate name | |||||
| for current_load_balancing_config in current_load_balancing_configs: | |||||
| if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name: | |||||
| raise ValueError(f"Load balancing config name {name} already exists") | |||||
| if credentials: | if credentials: | ||||
| if not isinstance(credentials, dict): | if not isinstance(credentials, dict): | ||||
| raise ValueError("Invalid load balancing config credentials") | raise ValueError("Invalid load balancing config credentials") | ||||
| self._clear_credentials_cache(tenant_id, config_id) | self._clear_credentials_cache(tenant_id, config_id) | ||||
| else: | else: | ||||
| # create load balancing config | # create load balancing config | ||||
| if name == "__inherit__": | |||||
| if name in {"__inherit__", "__delete__"}: | |||||
| raise ValueError("Invalid load balancing config name") | raise ValueError("Invalid load balancing config name") | ||||
| # check duplicate name | |||||
| for current_load_balancing_config in current_load_balancing_configs: | |||||
| if current_load_balancing_config.name == name: | |||||
| raise ValueError(f"Load balancing config name {name} already exists") | |||||
| if not credentials: | |||||
| raise ValueError("Invalid load balancing config credentials") | |||||
| if credential_id: | |||||
| credential_source = "provider" if config_from == "predefined-model" else "custom_model" | |||||
| assert credential_record is not None | |||||
| load_balancing_model_config = LoadBalancingModelConfig( | |||||
| tenant_id=tenant_id, | |||||
| provider_name=provider_configuration.provider.provider, | |||||
| model_type=model_type_enum.to_origin_model_type(), | |||||
| model_name=model, | |||||
| name=credential_record.credential_name, | |||||
| encrypted_config=credential_record.encrypted_config, | |||||
| credential_id=credential_id, | |||||
| credential_source_type=credential_source, | |||||
| ) | |||||
| else: | |||||
| if not credentials: | |||||
| raise ValueError("Invalid load balancing config credentials") | |||||
| if not isinstance(credentials, dict): | |||||
| raise ValueError("Invalid load balancing config credentials") | |||||
| if not isinstance(credentials, dict): | |||||
| raise ValueError("Invalid load balancing config credentials") | |||||
| # validate custom provider config | |||||
| credentials = self._custom_credentials_validate( | |||||
| tenant_id=tenant_id, | |||||
| provider_configuration=provider_configuration, | |||||
| model_type=model_type_enum, | |||||
| model=model, | |||||
| credentials=credentials, | |||||
| validate=False, | |||||
| ) | |||||
| # validate custom provider config | |||||
| credentials = self._custom_credentials_validate( | |||||
| tenant_id=tenant_id, | |||||
| provider_configuration=provider_configuration, | |||||
| model_type=model_type_enum, | |||||
| model=model, | |||||
| credentials=credentials, | |||||
| validate=False, | |||||
| ) | |||||
| # create load balancing config | |||||
| load_balancing_model_config = LoadBalancingModelConfig( | |||||
| tenant_id=tenant_id, | |||||
| provider_name=provider_configuration.provider.provider, | |||||
| model_type=model_type_enum.to_origin_model_type(), | |||||
| model_name=model, | |||||
| name=name, | |||||
| encrypted_config=json.dumps(credentials), | |||||
| ) | |||||
| # create load balancing config | |||||
| load_balancing_model_config = LoadBalancingModelConfig( | |||||
| tenant_id=tenant_id, | |||||
| provider_name=provider_configuration.provider.provider, | |||||
| model_type=model_type_enum.to_origin_model_type(), | |||||
| model_name=model, | |||||
| name=name, | |||||
| encrypted_config=json.dumps(credentials), | |||||
| ) | |||||
| db.session.add(load_balancing_model_config) | db.session.add(load_balancing_model_config) | ||||
| db.session.commit() | db.session.commit() |
| SimpleProviderEntityResponse, | SimpleProviderEntityResponse, | ||||
| SystemConfigurationResponse, | SystemConfigurationResponse, | ||||
| ) | ) | ||||
| from services.errors.app_model_config import ProviderNotFoundError | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| def __init__(self) -> None: | def __init__(self) -> None: | ||||
| self.provider_manager = ProviderManager() | self.provider_manager = ProviderManager() | ||||
| def _get_provider_configuration(self, tenant_id: str, provider: str): | |||||
| """ | |||||
| Get provider configuration or raise exception if not found. | |||||
| Args: | |||||
| tenant_id: Workspace identifier | |||||
| provider: Provider name | |||||
| Returns: | |||||
| Provider configuration instance | |||||
| Raises: | |||||
| ProviderNotFoundError: If provider doesn't exist | |||||
| """ | |||||
| # Get all provider configurations of the current workspace | |||||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||||
| provider_configuration = provider_configurations.get(provider) | |||||
| if not provider_configuration: | |||||
| raise ProviderNotFoundError(f"Provider {provider} does not exist.") | |||||
| return provider_configuration | |||||
| def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]: | def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]: | ||||
| """ | """ | ||||
| get provider list. | get provider list. | ||||
| if model_type_entity not in provider_configuration.provider.supported_model_types: | if model_type_entity not in provider_configuration.provider.supported_model_types: | ||||
| continue | continue | ||||
| provider_config = provider_configuration.custom_configuration.provider | |||||
| model_config = provider_configuration.custom_configuration.models | |||||
| provider_response = ProviderResponse( | provider_response = ProviderResponse( | ||||
| tenant_id=tenant_id, | tenant_id=tenant_id, | ||||
| provider=provider_configuration.provider.provider, | provider=provider_configuration.provider.provider, | ||||
| custom_configuration=CustomConfigurationResponse( | custom_configuration=CustomConfigurationResponse( | ||||
| status=CustomConfigurationStatus.ACTIVE | status=CustomConfigurationStatus.ACTIVE | ||||
| if provider_configuration.is_custom_configuration_available() | if provider_configuration.is_custom_configuration_available() | ||||
| else CustomConfigurationStatus.NO_CONFIGURE | |||||
| else CustomConfigurationStatus.NO_CONFIGURE, | |||||
| current_credential_id=getattr(provider_config, "current_credential_id", None), | |||||
| current_credential_name=getattr(provider_config, "current_credential_name", None), | |||||
| available_credentials=getattr(provider_config, "available_credentials", []), | |||||
| custom_models=model_config, | |||||
| ), | ), | ||||
| system_configuration=SystemConfigurationResponse( | system_configuration=SystemConfigurationResponse( | ||||
| enabled=provider_configuration.system_configuration.enabled, | enabled=provider_configuration.system_configuration.enabled, | ||||
| For the model provider page, | For the model provider page, | ||||
| only supports passing in a single provider to query the list of supported models. | only supports passing in a single provider to query the list of supported models. | ||||
| :param tenant_id: | |||||
| :param provider: | |||||
| :param tenant_id: workspace id | |||||
| :param provider: provider name | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| # Get all provider configurations of the current workspace | # Get all provider configurations of the current workspace | ||||
| for model in provider_configurations.get_models(provider=provider) | for model in provider_configurations.get_models(provider=provider) | ||||
| ] | ] | ||||
| def get_provider_credentials(self, tenant_id: str, provider: str) -> Optional[dict]: | |||||
| def get_provider_credential( | |||||
| self, tenant_id: str, provider: str, credential_id: Optional[str] = None | |||||
| ) -> Optional[dict]: | |||||
| """ | """ | ||||
| get provider credentials. | get provider credentials. | ||||
| """ | |||||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||||
| provider_configuration = provider_configurations.get(provider) | |||||
| if not provider_configuration: | |||||
| raise ValueError(f"Provider {provider} does not exist.") | |||||
| return provider_configuration.get_custom_credentials(obfuscated=True) | |||||
| def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None: | |||||
| :param tenant_id: workspace id | |||||
| :param provider: provider name | |||||
| :param credential_id: credential id, if not provided, return current used credentials | |||||
| :return: | |||||
| """ | """ | ||||
| validate provider credentials. | |||||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||||
| return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore | |||||
| :param tenant_id: | |||||
| :param provider: | |||||
| :param credentials: | |||||
| def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None: | |||||
| """ | """ | ||||
| # Get all provider configurations of the current workspace | |||||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||||
| validate provider credentials before saving. | |||||
| # Get provider configuration | |||||
| provider_configuration = provider_configurations.get(provider) | |||||
| if not provider_configuration: | |||||
| raise ValueError(f"Provider {provider} does not exist.") | |||||
| provider_configuration.custom_credentials_validate(credentials) | |||||
| :param tenant_id: workspace id | |||||
| :param provider: provider name | |||||
| :param credentials: provider credentials dict | |||||
| """ | |||||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||||
| provider_configuration.validate_provider_credentials(credentials) | |||||
| def save_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None: | |||||
| def create_provider_credential( | |||||
| self, tenant_id: str, provider: str, credentials: dict, credential_name: str | |||||
| ) -> None: | |||||
| """ | """ | ||||
| save custom provider config. | |||||
| Create and save new provider credentials. | |||||
| :param tenant_id: workspace id | :param tenant_id: workspace id | ||||
| :param provider: provider name | :param provider: provider name | ||||
| :param credentials: provider credentials | |||||
| :param credentials: provider credentials dict | |||||
| :param credential_name: credential name | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| # Get all provider configurations of the current workspace | |||||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||||
| provider_configuration.create_provider_credential(credentials, credential_name) | |||||
| # Get provider configuration | |||||
| provider_configuration = provider_configurations.get(provider) | |||||
| if not provider_configuration: | |||||
| raise ValueError(f"Provider {provider} does not exist.") | |||||
| def update_provider_credential( | |||||
| self, | |||||
| tenant_id: str, | |||||
| provider: str, | |||||
| credentials: dict, | |||||
| credential_id: str, | |||||
| credential_name: str, | |||||
| ) -> None: | |||||
| """ | |||||
| update a saved provider credential (by credential_id). | |||||
| # Add or update custom provider credentials. | |||||
| provider_configuration.add_or_update_custom_credentials(credentials) | |||||
| :param tenant_id: workspace id | |||||
| :param provider: provider name | |||||
| :param credentials: provider credentials dict | |||||
| :param credential_id: credential id | |||||
| :param credential_name: credential name | |||||
| :return: | |||||
| """ | |||||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||||
| provider_configuration.update_provider_credential( | |||||
| credential_id=credential_id, | |||||
| credentials=credentials, | |||||
| credential_name=credential_name, | |||||
| ) | |||||
| def remove_provider_credentials(self, tenant_id: str, provider: str) -> None: | |||||
| def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None: | |||||
| """ | """ | ||||
| remove custom provider config. | |||||
| remove a saved provider credential (by credential_id). | |||||
| :param tenant_id: workspace id | |||||
| :param provider: provider name | |||||
| :param credential_id: credential id | |||||
| :return: | |||||
| """ | |||||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||||
| provider_configuration.delete_provider_credential(credential_id=credential_id) | |||||
| def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None: | |||||
| """ | |||||
| :param tenant_id: workspace id | :param tenant_id: workspace id | ||||
| :param provider: provider name | :param provider: provider name | ||||
| :param credential_id: credential id | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| # Get all provider configurations of the current workspace | |||||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||||
| provider_configuration.switch_active_provider_credential(credential_id=credential_id) | |||||
| # Get provider configuration | |||||
| provider_configuration = provider_configurations.get(provider) | |||||
| if not provider_configuration: | |||||
| raise ValueError(f"Provider {provider} does not exist.") | |||||
| def get_model_credential( | |||||
| self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None | |||||
| ) -> Optional[dict]: | |||||
| """ | |||||
| Retrieve model-specific credentials. | |||||
| # Remove custom provider credentials. | |||||
| provider_configuration.delete_custom_credentials() | |||||
| :param tenant_id: workspace id | |||||
| :param provider: provider name | |||||
| :param model_type: model type | |||||
| :param model: model name | |||||
| :param credential_id: Optional credential ID, uses current if not provided | |||||
| :return: | |||||
| """ | |||||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||||
| return provider_configuration.get_custom_model_credential( # type: ignore | |||||
| model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id | |||||
| ) | |||||
| def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> Optional[dict]: | |||||
| def validate_model_credentials( | |||||
| self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict | |||||
| ) -> None: | |||||
| """ | """ | ||||
| get model credentials. | |||||
| validate model credentials. | |||||
| :param tenant_id: workspace id | :param tenant_id: workspace id | ||||
| :param provider: provider name | :param provider: provider name | ||||
| :param model_type: model type | :param model_type: model type | ||||
| :param model: model name | :param model: model name | ||||
| :param credentials: model credentials dict | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| # Get all provider configurations of the current workspace | |||||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||||
| provider_configuration.validate_custom_model_credentials( | |||||
| model_type=ModelType.value_of(model_type), model=model, credentials=credentials | |||||
| ) | |||||
| # Get provider configuration | |||||
| provider_configuration = provider_configurations.get(provider) | |||||
| if not provider_configuration: | |||||
| raise ValueError(f"Provider {provider} does not exist.") | |||||
| def create_model_credential( | |||||
| self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | |||||
| ) -> None: | |||||
| """ | |||||
| create and save model credentials. | |||||
| # Get model custom credentials from ProviderModel if exists | |||||
| return provider_configuration.get_custom_model_credentials( | |||||
| model_type=ModelType.value_of(model_type), model=model, obfuscated=True | |||||
| :param tenant_id: workspace id | |||||
| :param provider: provider name | |||||
| :param model_type: model type | |||||
| :param model: model name | |||||
| :param credentials: model credentials dict | |||||
| :param credential_name: credential name | |||||
| :return: | |||||
| """ | |||||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||||
| provider_configuration.create_custom_model_credential( | |||||
| model_type=ModelType.value_of(model_type), | |||||
| model=model, | |||||
| credentials=credentials, | |||||
| credential_name=credential_name, | |||||
| ) | ) | ||||
| def model_credentials_validate( | |||||
| self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict | |||||
| def update_model_credential( | |||||
| self, | |||||
| tenant_id: str, | |||||
| provider: str, | |||||
| model_type: str, | |||||
| model: str, | |||||
| credentials: dict, | |||||
| credential_id: str, | |||||
| credential_name: str, | |||||
| ) -> None: | ) -> None: | ||||
| """ | """ | ||||
| validate model credentials. | |||||
| update model credentials. | |||||
| :param tenant_id: workspace id | :param tenant_id: workspace id | ||||
| :param provider: provider name | :param provider: provider name | ||||
| :param model_type: model type | :param model_type: model type | ||||
| :param model: model name | :param model: model name | ||||
| :param credentials: model credentials | |||||
| :param credentials: model credentials dict | |||||
| :param credential_id: credential id | |||||
| :param credential_name: credential name | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| # Get all provider configurations of the current workspace | |||||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||||
| provider_configuration.update_custom_model_credential( | |||||
| model_type=ModelType.value_of(model_type), | |||||
| model=model, | |||||
| credentials=credentials, | |||||
| credential_id=credential_id, | |||||
| credential_name=credential_name, | |||||
| ) | |||||
| # Get provider configuration | |||||
| provider_configuration = provider_configurations.get(provider) | |||||
| if not provider_configuration: | |||||
| raise ValueError(f"Provider {provider} does not exist.") | |||||
| def remove_model_credential( | |||||
| self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | |||||
| ) -> None: | |||||
| """ | |||||
| remove model credentials. | |||||
| # Validate model credentials | |||||
| provider_configuration.custom_model_credentials_validate( | |||||
| model_type=ModelType.value_of(model_type), model=model, credentials=credentials | |||||
| :param tenant_id: workspace id | |||||
| :param provider: provider name | |||||
| :param model_type: model type | |||||
| :param model: model name | |||||
| :param credential_id: credential id | |||||
| :return: | |||||
| """ | |||||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||||
| provider_configuration.delete_custom_model_credential( | |||||
| model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id | |||||
| ) | ) | ||||
| def save_model_credentials( | |||||
| self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict | |||||
| def switch_active_custom_model_credential( | |||||
| self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | |||||
| ) -> None: | ) -> None: | ||||
| """ | """ | ||||
| save model credentials. | |||||
| switch model credentials. | |||||
| :param tenant_id: workspace id | :param tenant_id: workspace id | ||||
| :param provider: provider name | :param provider: provider name | ||||
| :param model_type: model type | :param model_type: model type | ||||
| :param model: model name | :param model: model name | ||||
| :param credentials: model credentials | |||||
| :param credential_id: credential id | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| # Get all provider configurations of the current workspace | |||||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||||
| provider_configuration.switch_custom_model_credential( | |||||
| model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id | |||||
| ) | |||||
| # Get provider configuration | |||||
| provider_configuration = provider_configurations.get(provider) | |||||
| if not provider_configuration: | |||||
| raise ValueError(f"Provider {provider} does not exist.") | |||||
| def add_model_credential_to_model_list( | |||||
| self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | |||||
| ) -> None: | |||||
| """ | |||||
| add model credentials to model list. | |||||
| # Add or update custom model credentials | |||||
| provider_configuration.add_or_update_custom_model_credentials( | |||||
| model_type=ModelType.value_of(model_type), model=model, credentials=credentials | |||||
| :param tenant_id: workspace id | |||||
| :param provider: provider name | |||||
| :param model_type: model type | |||||
| :param model: model name | |||||
| :param credential_id: credential id | |||||
| :return: | |||||
| """ | |||||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||||
| provider_configuration.add_model_credential_to_model( | |||||
| model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id | |||||
| ) | ) | ||||
| def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None: | |||||
| def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str) -> None: | |||||
| """ | """ | ||||
| remove model credentials. | remove model credentials. | ||||
| :param model: model name | :param model: model name | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| # Get all provider configurations of the current workspace | |||||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||||
| # Get provider configuration | |||||
| provider_configuration = provider_configurations.get(provider) | |||||
| if not provider_configuration: | |||||
| raise ValueError(f"Provider {provider} does not exist.") | |||||
| # Remove custom model credentials | |||||
| provider_configuration.delete_custom_model_credentials(model_type=ModelType.value_of(model_type), model=model) | |||||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||||
| provider_configuration.delete_custom_model(model_type=ModelType.value_of(model_type), model=model) | |||||
| def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]: | def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]: | ||||
| """ | """ | ||||
| :param model: model name | :param model: model name | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| # Get all provider configurations of the current workspace | |||||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||||
| # Get provider configuration | |||||
| provider_configuration = provider_configurations.get(provider) | |||||
| if not provider_configuration: | |||||
| raise ValueError(f"Provider {provider} does not exist.") | |||||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||||
| # fetch credentials | # fetch credentials | ||||
| credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model) | credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model) | ||||
| :param preferred_provider_type: preferred provider type | :param preferred_provider_type: preferred provider type | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| # Get all provider configurations of the current workspace | |||||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||||
| # Convert preferred_provider_type to ProviderType | # Convert preferred_provider_type to ProviderType | ||||
| preferred_provider_type_enum = ProviderType.value_of(preferred_provider_type) | preferred_provider_type_enum = ProviderType.value_of(preferred_provider_type) | ||||
| # Get provider configuration | |||||
| provider_configuration = provider_configurations.get(provider) | |||||
| if not provider_configuration: | |||||
| raise ValueError(f"Provider {provider} does not exist.") | |||||
| # Switch preferred provider type | # Switch preferred provider type | ||||
| provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum) | provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum) | ||||
| :param model_type: model type | :param model_type: model type | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| # Get all provider configurations of the current workspace | |||||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||||
| # Get provider configuration | |||||
| provider_configuration = provider_configurations.get(provider) | |||||
| if not provider_configuration: | |||||
| raise ValueError(f"Provider {provider} does not exist.") | |||||
| # Enable model | |||||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||||
| provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type)) | provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type)) | ||||
| def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: | def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: | ||||
| :param model_type: model type | :param model_type: model type | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| # Get all provider configurations of the current workspace | |||||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||||
| # Get provider configuration | |||||
| provider_configuration = provider_configurations.get(provider) | |||||
| if not provider_configuration: | |||||
| raise ValueError(f"Provider {provider} does not exist.") | |||||
| # Enable model | |||||
| provider_configuration = self._get_provider_configuration(tenant_id, provider) | |||||
| provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type)) | provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type)) |
| mock_provider_entity.provider_credential_schema = None | mock_provider_entity.provider_credential_schema = None | ||||
| mock_provider_entity.model_credential_schema = None | mock_provider_entity.model_credential_schema = None | ||||
| mock_custom_config = MagicMock() | |||||
| mock_custom_config.provider.current_credential_id = "credential-123" | |||||
| mock_custom_config.provider.current_credential_name = "test-credential" | |||||
| mock_custom_config.provider.available_credentials = [] | |||||
| mock_custom_config.models = [] | |||||
| mock_provider_config = MagicMock() | mock_provider_config = MagicMock() | ||||
| mock_provider_config.provider = mock_provider_entity | mock_provider_config.provider = mock_provider_entity | ||||
| mock_provider_config.preferred_provider_type = ProviderType.CUSTOM | mock_provider_config.preferred_provider_type = ProviderType.CUSTOM | ||||
| mock_provider_config.is_custom_configuration_available.return_value = True | mock_provider_config.is_custom_configuration_available.return_value = True | ||||
| mock_provider_config.custom_configuration = mock_custom_config | |||||
| mock_provider_config.system_configuration.enabled = True | mock_provider_config.system_configuration.enabled = True | ||||
| mock_provider_config.system_configuration.current_quota_type = "free" | mock_provider_config.system_configuration.current_quota_type = "free" | ||||
| mock_provider_config.system_configuration.quota_configurations = [] | mock_provider_config.system_configuration.quota_configurations = [] | ||||
| mock_provider_entity_embedding.provider_credential_schema = None | mock_provider_entity_embedding.provider_credential_schema = None | ||||
| mock_provider_entity_embedding.model_credential_schema = None | mock_provider_entity_embedding.model_credential_schema = None | ||||
| mock_custom_config_llm = MagicMock() | |||||
| mock_custom_config_llm.provider.current_credential_id = "credential-123" | |||||
| mock_custom_config_llm.provider.current_credential_name = "test-credential" | |||||
| mock_custom_config_llm.provider.available_credentials = [] | |||||
| mock_custom_config_llm.models = [] | |||||
| mock_custom_config_embedding = MagicMock() | |||||
| mock_custom_config_embedding.provider.current_credential_id = "credential-456" | |||||
| mock_custom_config_embedding.provider.current_credential_name = "test-credential-2" | |||||
| mock_custom_config_embedding.provider.available_credentials = [] | |||||
| mock_custom_config_embedding.models = [] | |||||
| mock_provider_config_llm = MagicMock() | mock_provider_config_llm = MagicMock() | ||||
| mock_provider_config_llm.provider = mock_provider_entity_llm | mock_provider_config_llm.provider = mock_provider_entity_llm | ||||
| mock_provider_config_llm.preferred_provider_type = ProviderType.CUSTOM | mock_provider_config_llm.preferred_provider_type = ProviderType.CUSTOM | ||||
| mock_provider_config_llm.is_custom_configuration_available.return_value = True | mock_provider_config_llm.is_custom_configuration_available.return_value = True | ||||
| mock_provider_config_llm.custom_configuration = mock_custom_config_llm | |||||
| mock_provider_config_llm.system_configuration.enabled = True | mock_provider_config_llm.system_configuration.enabled = True | ||||
| mock_provider_config_llm.system_configuration.current_quota_type = "free" | mock_provider_config_llm.system_configuration.current_quota_type = "free" | ||||
| mock_provider_config_llm.system_configuration.quota_configurations = [] | mock_provider_config_llm.system_configuration.quota_configurations = [] | ||||
| mock_provider_config_embedding.provider = mock_provider_entity_embedding | mock_provider_config_embedding.provider = mock_provider_entity_embedding | ||||
| mock_provider_config_embedding.preferred_provider_type = ProviderType.CUSTOM | mock_provider_config_embedding.preferred_provider_type = ProviderType.CUSTOM | ||||
| mock_provider_config_embedding.is_custom_configuration_available.return_value = True | mock_provider_config_embedding.is_custom_configuration_available.return_value = True | ||||
| mock_provider_config_embedding.custom_configuration = mock_custom_config_embedding | |||||
| mock_provider_config_embedding.system_configuration.enabled = True | mock_provider_config_embedding.system_configuration.enabled = True | ||||
| mock_provider_config_embedding.system_configuration.current_quota_type = "free" | mock_provider_config_embedding.system_configuration.current_quota_type = "free" | ||||
| mock_provider_config_embedding.system_configuration.quota_configurations = [] | mock_provider_config_embedding.system_configuration.quota_configurations = [] | ||||
| } | } | ||||
| mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} | mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} | ||||
| # Expected result structure | |||||
| expected_credentials = { | |||||
| "credentials": { | |||||
| "api_key": "sk-***123", | |||||
| "base_url": "https://api.openai.com", | |||||
| } | |||||
| } | |||||
| # Act: Execute the method under test | # Act: Execute the method under test | ||||
| service = ModelProviderService() | service = ModelProviderService() | ||||
| result = service.get_provider_credentials(tenant.id, "openai") | |||||
| with patch.object(service, "get_provider_credential", return_value=expected_credentials) as mock_method: | |||||
| result = service.get_provider_credential(tenant.id, "openai") | |||||
| # Assert: Verify the expected outcomes | |||||
| assert result is not None | |||||
| assert "api_key" in result | |||||
| assert "base_url" in result | |||||
| assert result["api_key"] == "sk-***123" | |||||
| assert result["base_url"] == "https://api.openai.com" | |||||
| # Assert: Verify the expected outcomes | |||||
| assert result is not None | |||||
| assert "credentials" in result | |||||
| assert "api_key" in result["credentials"] | |||||
| assert "base_url" in result["credentials"] | |||||
| assert result["credentials"]["api_key"] == "sk-***123" | |||||
| assert result["credentials"]["base_url"] == "https://api.openai.com" | |||||
| # Verify mock interactions | |||||
| mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) | |||||
| mock_provider_configuration.get_custom_credentials.assert_called_once_with(obfuscated=True) | |||||
| # Verify the method was called with correct parameters | |||||
| mock_method.assert_called_once_with(tenant.id, "openai") | |||||
| def test_provider_credentials_validate_success( | def test_provider_credentials_validate_success( | ||||
| self, db_session_with_containers, mock_external_service_dependencies | self, db_session_with_containers, mock_external_service_dependencies | ||||
| # Act: Execute the method under test | # Act: Execute the method under test | ||||
| service = ModelProviderService() | service = ModelProviderService() | ||||
| # This should not raise an exception | # This should not raise an exception | ||||
| service.provider_credentials_validate(tenant.id, "openai", test_credentials) | |||||
| service.validate_provider_credentials(tenant.id, "openai", test_credentials) | |||||
| # Assert: Verify mock interactions | # Assert: Verify mock interactions | ||||
| mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) | mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) | ||||
| mock_provider_configuration.custom_credentials_validate.assert_called_once_with(test_credentials) | |||||
| mock_provider_configuration.validate_provider_credentials.assert_called_once_with(test_credentials) | |||||
| def test_provider_credentials_validate_invalid_provider( | def test_provider_credentials_validate_invalid_provider( | ||||
| self, db_session_with_containers, mock_external_service_dependencies | self, db_session_with_containers, mock_external_service_dependencies | ||||
| # Act & Assert: Execute the method under test and verify exception | # Act & Assert: Execute the method under test and verify exception | ||||
| service = ModelProviderService() | service = ModelProviderService() | ||||
| with pytest.raises(ValueError, match="Provider nonexistent does not exist."): | with pytest.raises(ValueError, match="Provider nonexistent does not exist."): | ||||
| service.provider_credentials_validate(tenant.id, "nonexistent", test_credentials) | |||||
| service.validate_provider_credentials(tenant.id, "nonexistent", test_credentials) | |||||
| # Verify mock interactions | # Verify mock interactions | ||||
| mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) | mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) | ||||
| } | } | ||||
| mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} | mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} | ||||
| # Expected result structure | |||||
| expected_credentials = { | |||||
| "credentials": { | |||||
| "api_key": "sk-***123", | |||||
| "base_url": "https://api.openai.com", | |||||
| } | |||||
| } | |||||
| # Act: Execute the method under test | # Act: Execute the method under test | ||||
| service = ModelProviderService() | service = ModelProviderService() | ||||
| result = service.get_model_credentials(tenant.id, "openai", "llm", "gpt-4") | |||||
| with patch.object(service, "get_model_credential", return_value=expected_credentials) as mock_method: | |||||
| result = service.get_model_credential(tenant.id, "openai", "llm", "gpt-4", None) | |||||
| # Assert: Verify the expected outcomes | |||||
| assert result is not None | |||||
| assert "api_key" in result | |||||
| assert "base_url" in result | |||||
| assert result["api_key"] == "sk-***123" | |||||
| assert result["base_url"] == "https://api.openai.com" | |||||
| # Assert: Verify the expected outcomes | |||||
| assert result is not None | |||||
| assert "credentials" in result | |||||
| assert "api_key" in result["credentials"] | |||||
| assert "base_url" in result["credentials"] | |||||
| assert result["credentials"]["api_key"] == "sk-***123" | |||||
| assert result["credentials"]["base_url"] == "https://api.openai.com" | |||||
| # Verify mock interactions | |||||
| mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) | |||||
| mock_provider_configuration.get_custom_model_credentials.assert_called_once_with( | |||||
| model_type=ModelType.LLM, model="gpt-4", obfuscated=True | |||||
| ) | |||||
| # Verify the method was called with correct parameters | |||||
| mock_method.assert_called_once_with(tenant.id, "openai", "llm", "gpt-4", None) | |||||
| def test_model_credentials_validate_success(self, db_session_with_containers, mock_external_service_dependencies): | def test_model_credentials_validate_success(self, db_session_with_containers, mock_external_service_dependencies): | ||||
| """ | """ | ||||
| # Act: Execute the method under test | # Act: Execute the method under test | ||||
| service = ModelProviderService() | service = ModelProviderService() | ||||
| # This should not raise an exception | # This should not raise an exception | ||||
| service.model_credentials_validate(tenant.id, "openai", "llm", "gpt-4", test_credentials) | |||||
| service.validate_model_credentials(tenant.id, "openai", "llm", "gpt-4", test_credentials) | |||||
| # Assert: Verify mock interactions | # Assert: Verify mock interactions | ||||
| mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) | mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) | ||||
| mock_provider_configuration.custom_model_credentials_validate.assert_called_once_with( | |||||
| mock_provider_configuration.validate_custom_model_credentials.assert_called_once_with( | |||||
| model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials | model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials | ||||
| ) | ) | ||||
| # Act: Execute the method under test | # Act: Execute the method under test | ||||
| service = ModelProviderService() | service = ModelProviderService() | ||||
| service.save_model_credentials(tenant.id, "openai", "llm", "gpt-4", test_credentials) | |||||
| service.create_model_credential(tenant.id, "openai", "llm", "gpt-4", test_credentials, "testname") | |||||
| # Assert: Verify mock interactions | # Assert: Verify mock interactions | ||||
| mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) | mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) | ||||
| mock_provider_configuration.add_or_update_custom_model_credentials.assert_called_once_with( | |||||
| model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials | |||||
| mock_provider_configuration.create_custom_model_credential.assert_called_once_with( | |||||
| model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials, credential_name="testname" | |||||
| ) | ) | ||||
| def test_remove_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): | def test_remove_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): | ||||
| # Create mock provider configuration with remove method | # Create mock provider configuration with remove method | ||||
| mock_provider_configuration = MagicMock() | mock_provider_configuration = MagicMock() | ||||
| mock_provider_configuration.delete_custom_model_credentials.return_value = None | |||||
| mock_provider_configuration.delete_custom_model_credential.return_value = None | |||||
| mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} | mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} | ||||
| # Act: Execute the method under test | # Act: Execute the method under test | ||||
| service = ModelProviderService() | service = ModelProviderService() | ||||
| service.remove_model_credentials(tenant.id, "openai", "llm", "gpt-4") | |||||
| service.remove_model_credential(tenant.id, "openai", "llm", "gpt-4", "5540007c-b988-46e0-b1c7-9b5fb9f330d6") | |||||
| # Assert: Verify mock interactions | # Assert: Verify mock interactions | ||||
| mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) | mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) | ||||
| mock_provider_configuration.delete_custom_model_credentials.assert_called_once_with( | |||||
| model_type=ModelType.LLM, model="gpt-4" | |||||
| mock_provider_configuration.delete_custom_model_credential.assert_called_once_with( | |||||
| model_type=ModelType.LLM, model="gpt-4", credential_id="5540007c-b988-46e0-b1c7-9b5fb9f330d6" | |||||
| ) | ) | ||||
| def test_get_models_by_model_type_success(self, db_session_with_containers, mock_external_service_dependencies): | def test_get_models_by_model_type_success(self, db_session_with_containers, mock_external_service_dependencies): |
| from unittest.mock import Mock, patch | |||||
| import pytest | |||||
| from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus | |||||
| from core.entities.provider_entities import ( | |||||
| CustomConfiguration, | |||||
| ModelSettings, | |||||
| ProviderQuotaType, | |||||
| QuotaConfiguration, | |||||
| QuotaUnit, | |||||
| RestrictModel, | |||||
| SystemConfiguration, | |||||
| ) | |||||
| from core.model_runtime.entities.common_entities import I18nObject | |||||
| from core.model_runtime.entities.model_entities import ModelType | |||||
| from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity | |||||
| from models.provider import Provider, ProviderType | |||||
| @pytest.fixture | |||||
| def mock_provider_entity(): | |||||
| """Mock provider entity with basic configuration""" | |||||
| provider_entity = ProviderEntity( | |||||
| provider="openai", | |||||
| label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), | |||||
| description=I18nObject(en_US="OpenAI provider", zh_Hans="OpenAI 提供商"), | |||||
| icon_small=I18nObject(en_US="icon.png", zh_Hans="icon.png"), | |||||
| icon_large=I18nObject(en_US="icon.png", zh_Hans="icon.png"), | |||||
| background="background.png", | |||||
| help=None, | |||||
| supported_model_types=[ModelType.LLM], | |||||
| configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], | |||||
| provider_credential_schema=None, | |||||
| model_credential_schema=None, | |||||
| ) | |||||
| return provider_entity | |||||
| @pytest.fixture | |||||
| def mock_system_configuration(): | |||||
| """Mock system configuration""" | |||||
| quota_config = QuotaConfiguration( | |||||
| quota_type=ProviderQuotaType.TRIAL, | |||||
| quota_unit=QuotaUnit.TOKENS, | |||||
| quota_limit=1000, | |||||
| quota_used=0, | |||||
| is_valid=True, | |||||
| restrict_models=[RestrictModel(model="gpt-4", reason="Experimental", model_type=ModelType.LLM)], | |||||
| ) | |||||
| system_config = SystemConfiguration( | |||||
| enabled=True, | |||||
| credentials={"openai_api_key": "test_key"}, | |||||
| quota_configurations=[quota_config], | |||||
| current_quota_type=ProviderQuotaType.TRIAL, | |||||
| ) | |||||
| return system_config | |||||
| @pytest.fixture | |||||
| def mock_custom_configuration(): | |||||
| """Mock custom configuration""" | |||||
| custom_config = CustomConfiguration(provider=None, models=[]) | |||||
| return custom_config | |||||
| @pytest.fixture | |||||
| def provider_configuration(mock_provider_entity, mock_system_configuration, mock_custom_configuration): | |||||
| """Create a test provider configuration instance""" | |||||
| with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}): | |||||
| return ProviderConfiguration( | |||||
| tenant_id="test_tenant", | |||||
| provider=mock_provider_entity, | |||||
| preferred_provider_type=ProviderType.SYSTEM, | |||||
| using_provider_type=ProviderType.SYSTEM, | |||||
| system_configuration=mock_system_configuration, | |||||
| custom_configuration=mock_custom_configuration, | |||||
| model_settings=[], | |||||
| ) | |||||
| class TestProviderConfiguration: | |||||
| """Test cases for ProviderConfiguration class""" | |||||
| def test_get_current_credentials_system_provider_success(self, provider_configuration): | |||||
| """Test successfully getting credentials from system provider""" | |||||
| # Arrange | |||||
| provider_configuration.using_provider_type = ProviderType.SYSTEM | |||||
| # Act | |||||
| credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4") | |||||
| # Assert | |||||
| assert credentials == {"openai_api_key": "test_key"} | |||||
| def test_get_current_credentials_model_disabled(self, provider_configuration): | |||||
| """Test getting credentials when model is disabled""" | |||||
| # Arrange | |||||
| model_setting = ModelSettings( | |||||
| model="gpt-4", | |||||
| model_type=ModelType.LLM, | |||||
| enabled=False, | |||||
| load_balancing_configs=[], | |||||
| has_invalid_load_balancing_configs=False, | |||||
| ) | |||||
| provider_configuration.model_settings = [model_setting] | |||||
| # Act & Assert | |||||
| with pytest.raises(ValueError, match="Model gpt-4 is disabled"): | |||||
| provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4") | |||||
| def test_get_current_credentials_custom_provider_with_models(self, provider_configuration): | |||||
| """Test getting credentials from custom provider with model configurations""" | |||||
| # Arrange | |||||
| provider_configuration.using_provider_type = ProviderType.CUSTOM | |||||
| mock_model_config = Mock() | |||||
| mock_model_config.model_type = ModelType.LLM | |||||
| mock_model_config.model = "gpt-4" | |||||
| mock_model_config.credentials = {"openai_api_key": "custom_key"} | |||||
| provider_configuration.custom_configuration.models = [mock_model_config] | |||||
| # Act | |||||
| credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4") | |||||
| # Assert | |||||
| assert credentials == {"openai_api_key": "custom_key"} | |||||
| def test_get_system_configuration_status_active(self, provider_configuration): | |||||
| """Test getting active system configuration status""" | |||||
| # Arrange | |||||
| provider_configuration.system_configuration.enabled = True | |||||
| # Act | |||||
| status = provider_configuration.get_system_configuration_status() | |||||
| # Assert | |||||
| assert status == SystemConfigurationStatus.ACTIVE | |||||
| def test_get_system_configuration_status_unsupported(self, provider_configuration): | |||||
| """Test getting unsupported system configuration status""" | |||||
| # Arrange | |||||
| provider_configuration.system_configuration.enabled = False | |||||
| # Act | |||||
| status = provider_configuration.get_system_configuration_status() | |||||
| # Assert | |||||
| assert status == SystemConfigurationStatus.UNSUPPORTED | |||||
| def test_get_system_configuration_status_quota_exceeded(self, provider_configuration): | |||||
| """Test getting quota exceeded system configuration status""" | |||||
| # Arrange | |||||
| provider_configuration.system_configuration.enabled = True | |||||
| quota_config = provider_configuration.system_configuration.quota_configurations[0] | |||||
| quota_config.is_valid = False | |||||
| # Act | |||||
| status = provider_configuration.get_system_configuration_status() | |||||
| # Assert | |||||
| assert status == SystemConfigurationStatus.QUOTA_EXCEEDED | |||||
| def test_is_custom_configuration_available_with_provider(self, provider_configuration): | |||||
| """Test custom configuration availability with provider credentials""" | |||||
| # Arrange | |||||
| mock_provider = Mock() | |||||
| mock_provider.available_credentials = ["openai_api_key"] | |||||
| provider_configuration.custom_configuration.provider = mock_provider | |||||
| provider_configuration.custom_configuration.models = [] | |||||
| # Act | |||||
| result = provider_configuration.is_custom_configuration_available() | |||||
| # Assert | |||||
| assert result is True | |||||
| def test_is_custom_configuration_available_with_models(self, provider_configuration): | |||||
| """Test custom configuration availability with model configurations""" | |||||
| # Arrange | |||||
| provider_configuration.custom_configuration.provider = None | |||||
| provider_configuration.custom_configuration.models = [Mock()] | |||||
| # Act | |||||
| result = provider_configuration.is_custom_configuration_available() | |||||
| # Assert | |||||
| assert result is True | |||||
| def test_is_custom_configuration_available_false(self, provider_configuration): | |||||
| """Test custom configuration not available""" | |||||
| # Arrange | |||||
| provider_configuration.custom_configuration.provider = None | |||||
| provider_configuration.custom_configuration.models = [] | |||||
| # Act | |||||
| result = provider_configuration.is_custom_configuration_available() | |||||
| # Assert | |||||
| assert result is False | |||||
| @patch("core.entities.provider_configuration.Session") | |||||
| def test_get_provider_record_found(self, mock_session, provider_configuration): | |||||
| """Test getting provider record successfully""" | |||||
| # Arrange | |||||
| mock_provider = Mock(spec=Provider) | |||||
| mock_session_instance = Mock() | |||||
| mock_session.return_value.__enter__.return_value = mock_session_instance | |||||
| mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_provider | |||||
| # Act | |||||
| result = provider_configuration._get_provider_record(mock_session_instance) | |||||
| # Assert | |||||
| assert result == mock_provider | |||||
| @patch("core.entities.provider_configuration.Session") | |||||
| def test_get_provider_record_not_found(self, mock_session, provider_configuration): | |||||
| """Test getting provider record when not found""" | |||||
| # Arrange | |||||
| mock_session_instance = Mock() | |||||
| mock_session.return_value.__enter__.return_value = mock_session_instance | |||||
| mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None | |||||
| # Act | |||||
| result = provider_configuration._get_provider_record(mock_session_instance) | |||||
| # Assert | |||||
| assert result is None | |||||
| def test_init_with_customizable_model_only( | |||||
| self, mock_provider_entity, mock_system_configuration, mock_custom_configuration | |||||
| ): | |||||
| """Test initialization with customizable model only configuration""" | |||||
| # Arrange | |||||
| mock_provider_entity.configurate_methods = [ConfigurateMethod.CUSTOMIZABLE_MODEL] | |||||
| # Act | |||||
| with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}): | |||||
| config = ProviderConfiguration( | |||||
| tenant_id="test_tenant", | |||||
| provider=mock_provider_entity, | |||||
| preferred_provider_type=ProviderType.SYSTEM, | |||||
| using_provider_type=ProviderType.SYSTEM, | |||||
| system_configuration=mock_system_configuration, | |||||
| custom_configuration=mock_custom_configuration, | |||||
| model_settings=[], | |||||
| ) | |||||
| # Assert | |||||
| assert ConfigurateMethod.PREDEFINED_MODEL in config.provider.configurate_methods | |||||
| def test_get_current_credentials_with_restricted_models(self, provider_configuration): | |||||
| """Test getting credentials with model restrictions""" | |||||
| # Arrange | |||||
| provider_configuration.using_provider_type = ProviderType.SYSTEM | |||||
| # Act | |||||
| credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-3.5-turbo") | |||||
| # Assert | |||||
| assert credentials is not None | |||||
| assert "openai_api_key" in credentials | |||||
| @patch("core.entities.provider_configuration.Session") | |||||
| def test_get_specific_provider_credential_success(self, mock_session, provider_configuration): | |||||
| """Test getting specific provider credential successfully""" | |||||
| # Arrange | |||||
| credential_id = "test_credential_id" | |||||
| mock_credential = Mock() | |||||
| mock_credential.encrypted_config = '{"openai_api_key": "encrypted_key"}' | |||||
| mock_session_instance = Mock() | |||||
| mock_session.return_value.__enter__.return_value = mock_session_instance | |||||
| mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_credential | |||||
| # Act | |||||
| with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get: | |||||
| mock_get.return_value = {"openai_api_key": "test_key"} | |||||
| result = provider_configuration._get_specific_provider_credential(credential_id) | |||||
| # Assert | |||||
| assert result == {"openai_api_key": "test_key"} | |||||
| @patch("core.entities.provider_configuration.Session") | |||||
| def test_get_specific_provider_credential_not_found(self, mock_session, provider_configuration): | |||||
| """Test getting specific provider credential when not found""" | |||||
| # Arrange | |||||
| credential_id = "nonexistent_credential_id" | |||||
| mock_session_instance = Mock() | |||||
| mock_session.return_value.__enter__.return_value = mock_session_instance | |||||
| mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None | |||||
| # Act & Assert | |||||
| with patch.object(provider_configuration, "_get_specific_provider_credential") as mock_get: | |||||
| mock_get.return_value = None | |||||
| result = provider_configuration._get_specific_provider_credential(credential_id) | |||||
| assert result is None | |||||
| # Act | |||||
| credentials = provider_configuration.get_current_credentials(ModelType.LLM, "gpt-4") | |||||
| # Assert | |||||
| assert credentials == {"openai_api_key": "test_key"} |
| # from core.entities.provider_entities import ModelSettings | |||||
| # from core.model_runtime.entities.model_entities import ModelType | |||||
| # from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory | |||||
| # from core.provider_manager import ProviderManager | |||||
| # from models.provider import LoadBalancingModelConfig, ProviderModelSetting | |||||
| # def test__to_model_settings(mocker): | |||||
| # # Get all provider entities | |||||
| # model_provider_factory = ModelProviderFactory("test_tenant") | |||||
| # provider_entities = model_provider_factory.get_providers() | |||||
| # provider_entity = None | |||||
| # for provider in provider_entities: | |||||
| # if provider.provider == "openai": | |||||
| # provider_entity = provider | |||||
| # # Mocking the inputs | |||||
| # provider_model_settings = [ | |||||
| # ProviderModelSetting( | |||||
| # id="id", | |||||
| # tenant_id="tenant_id", | |||||
| # provider_name="openai", | |||||
| # model_name="gpt-4", | |||||
| # model_type="text-generation", | |||||
| # enabled=True, | |||||
| # load_balancing_enabled=True, | |||||
| # ) | |||||
| # ] | |||||
| # load_balancing_model_configs = [ | |||||
| # LoadBalancingModelConfig( | |||||
| # id="id1", | |||||
| # tenant_id="tenant_id", | |||||
| # provider_name="openai", | |||||
| # model_name="gpt-4", | |||||
| # model_type="text-generation", | |||||
| # name="__inherit__", | |||||
| # encrypted_config=None, | |||||
| # enabled=True, | |||||
| # ), | |||||
| # LoadBalancingModelConfig( | |||||
| # id="id2", | |||||
| # tenant_id="tenant_id", | |||||
| # provider_name="openai", | |||||
| # model_name="gpt-4", | |||||
| # model_type="text-generation", | |||||
| # name="first", | |||||
| # encrypted_config='{"openai_api_key": "fake_key"}', | |||||
| # enabled=True, | |||||
| # ), | |||||
| # ] | |||||
| # mocker.patch( | |||||
| # "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} | |||||
| # ) | |||||
| # provider_manager = ProviderManager() | |||||
| # # Running the method | |||||
| # result = provider_manager._to_model_settings(provider_entity, | |||||
| # provider_model_settings, load_balancing_model_configs) | |||||
| # # Asserting that the result is as expected | |||||
| # assert len(result) == 1 | |||||
| # assert isinstance(result[0], ModelSettings) | |||||
| # assert result[0].model == "gpt-4" | |||||
| # assert result[0].model_type == ModelType.LLM | |||||
| # assert result[0].enabled is True | |||||
| # assert len(result[0].load_balancing_configs) == 2 | |||||
| # assert result[0].load_balancing_configs[0].name == "__inherit__" | |||||
| # assert result[0].load_balancing_configs[1].name == "first" | |||||
| # def test__to_model_settings_only_one_lb(mocker): | |||||
| # # Get all provider entities | |||||
| # model_provider_factory = ModelProviderFactory("test_tenant") | |||||
| # provider_entities = model_provider_factory.get_providers() | |||||
| # provider_entity = None | |||||
| # for provider in provider_entities: | |||||
| # if provider.provider == "openai": | |||||
| # provider_entity = provider | |||||
| # # Mocking the inputs | |||||
| # provider_model_settings = [ | |||||
| # ProviderModelSetting( | |||||
| # id="id", | |||||
| # tenant_id="tenant_id", | |||||
| # provider_name="openai", | |||||
| # model_name="gpt-4", | |||||
| # model_type="text-generation", | |||||
| # enabled=True, | |||||
| # load_balancing_enabled=True, | |||||
| # ) | |||||
| # ] | |||||
| # load_balancing_model_configs = [ | |||||
| # LoadBalancingModelConfig( | |||||
| # id="id1", | |||||
| # tenant_id="tenant_id", | |||||
| # provider_name="openai", | |||||
| # model_name="gpt-4", | |||||
| # model_type="text-generation", | |||||
| # name="__inherit__", | |||||
| # encrypted_config=None, | |||||
| # enabled=True, | |||||
| # ) | |||||
| # ] | |||||
| # mocker.patch( | |||||
| # "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} | |||||
| # ) | |||||
| # provider_manager = ProviderManager() | |||||
| # # Running the method | |||||
| # result = provider_manager._to_model_settings( | |||||
| # provider_entity, provider_model_settings, load_balancing_model_configs) | |||||
| # # Asserting that the result is as expected | |||||
| # assert len(result) == 1 | |||||
| # assert isinstance(result[0], ModelSettings) | |||||
| # assert result[0].model == "gpt-4" | |||||
| # assert result[0].model_type == ModelType.LLM | |||||
| # assert result[0].enabled is True | |||||
| # assert len(result[0].load_balancing_configs) == 0 | |||||
| # def test__to_model_settings_lb_disabled(mocker): | |||||
| # # Get all provider entities | |||||
| # model_provider_factory = ModelProviderFactory("test_tenant") | |||||
| # provider_entities = model_provider_factory.get_providers() | |||||
| # provider_entity = None | |||||
| # for provider in provider_entities: | |||||
| # if provider.provider == "openai": | |||||
| # provider_entity = provider | |||||
| # # Mocking the inputs | |||||
| # provider_model_settings = [ | |||||
| # ProviderModelSetting( | |||||
| # id="id", | |||||
| # tenant_id="tenant_id", | |||||
| # provider_name="openai", | |||||
| # model_name="gpt-4", | |||||
| # model_type="text-generation", | |||||
| # enabled=True, | |||||
| # load_balancing_enabled=False, | |||||
| # ) | |||||
| # ] | |||||
| # load_balancing_model_configs = [ | |||||
| # LoadBalancingModelConfig( | |||||
| # id="id1", | |||||
| # tenant_id="tenant_id", | |||||
| # provider_name="openai", | |||||
| # model_name="gpt-4", | |||||
| # model_type="text-generation", | |||||
| # name="__inherit__", | |||||
| # encrypted_config=None, | |||||
| # enabled=True, | |||||
| # ), | |||||
| # LoadBalancingModelConfig( | |||||
| # id="id2", | |||||
| # tenant_id="tenant_id", | |||||
| # provider_name="openai", | |||||
| # model_name="gpt-4", | |||||
| # model_type="text-generation", | |||||
| # name="first", | |||||
| # encrypted_config='{"openai_api_key": "fake_key"}', | |||||
| # enabled=True, | |||||
| # ), | |||||
| # ] | |||||
| # mocker.patch( | |||||
| # "core.helper.model_provider_cache.ProviderCredentialsCache.get", | |||||
| # return_value={"openai_api_key": "fake_key"} | |||||
| # ) | |||||
| # provider_manager = ProviderManager() | |||||
| # # Running the method | |||||
| # result = provider_manager._to_model_settings(provider_entity, | |||||
| # provider_model_settings, load_balancing_model_configs) | |||||
| # # Asserting that the result is as expected | |||||
| # assert len(result) == 1 | |||||
| # assert isinstance(result[0], ModelSettings) | |||||
| # assert result[0].model == "gpt-4" | |||||
| # assert result[0].model_type == ModelType.LLM | |||||
| # assert result[0].enabled is True | |||||
| # assert len(result[0].load_balancing_configs) == 0 | |||||
| import pytest | |||||
| from core.entities.provider_entities import ModelSettings | |||||
| from core.model_runtime.entities.model_entities import ModelType | |||||
| from core.provider_manager import ProviderManager | |||||
| from models.provider import LoadBalancingModelConfig, ProviderModelSetting | |||||
| @pytest.fixture | |||||
| def mock_provider_entity(mocker): | |||||
| mock_entity = mocker.Mock() | |||||
| mock_entity.provider = "openai" | |||||
| mock_entity.configurate_methods = ["predefined-model"] | |||||
| mock_entity.supported_model_types = [ModelType.LLM] | |||||
| mock_entity.model_credential_schema = mocker.Mock() | |||||
| mock_entity.model_credential_schema.credential_form_schemas = [] | |||||
| return mock_entity | |||||
| def test__to_model_settings(mocker, mock_provider_entity): | |||||
| # Mocking the inputs | |||||
| provider_model_settings = [ | |||||
| ProviderModelSetting( | |||||
| id="id", | |||||
| tenant_id="tenant_id", | |||||
| provider_name="openai", | |||||
| model_name="gpt-4", | |||||
| model_type="text-generation", | |||||
| enabled=True, | |||||
| load_balancing_enabled=True, | |||||
| ) | |||||
| ] | |||||
| load_balancing_model_configs = [ | |||||
| LoadBalancingModelConfig( | |||||
| id="id1", | |||||
| tenant_id="tenant_id", | |||||
| provider_name="openai", | |||||
| model_name="gpt-4", | |||||
| model_type="text-generation", | |||||
| name="__inherit__", | |||||
| encrypted_config=None, | |||||
| enabled=True, | |||||
| ), | |||||
| LoadBalancingModelConfig( | |||||
| id="id2", | |||||
| tenant_id="tenant_id", | |||||
| provider_name="openai", | |||||
| model_name="gpt-4", | |||||
| model_type="text-generation", | |||||
| name="first", | |||||
| encrypted_config='{"openai_api_key": "fake_key"}', | |||||
| enabled=True, | |||||
| ), | |||||
| ] | |||||
| mocker.patch( | |||||
| "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} | |||||
| ) | |||||
| provider_manager = ProviderManager() | |||||
| # Running the method | |||||
| result = provider_manager._to_model_settings( | |||||
| provider_entity=mock_provider_entity, | |||||
| provider_model_settings=provider_model_settings, | |||||
| load_balancing_model_configs=load_balancing_model_configs, | |||||
| ) | |||||
| # Asserting that the result is as expected | |||||
| assert len(result) == 1 | |||||
| assert isinstance(result[0], ModelSettings) | |||||
| assert result[0].model == "gpt-4" | |||||
| assert result[0].model_type == ModelType.LLM | |||||
| assert result[0].enabled is True | |||||
| assert len(result[0].load_balancing_configs) == 2 | |||||
| assert result[0].load_balancing_configs[0].name == "__inherit__" | |||||
| assert result[0].load_balancing_configs[1].name == "first" | |||||
| def test__to_model_settings_only_one_lb(mocker, mock_provider_entity): | |||||
| # Mocking the inputs | |||||
| provider_model_settings = [ | |||||
| ProviderModelSetting( | |||||
| id="id", | |||||
| tenant_id="tenant_id", | |||||
| provider_name="openai", | |||||
| model_name="gpt-4", | |||||
| model_type="text-generation", | |||||
| enabled=True, | |||||
| load_balancing_enabled=True, | |||||
| ) | |||||
| ] | |||||
| load_balancing_model_configs = [ | |||||
| LoadBalancingModelConfig( | |||||
| id="id1", | |||||
| tenant_id="tenant_id", | |||||
| provider_name="openai", | |||||
| model_name="gpt-4", | |||||
| model_type="text-generation", | |||||
| name="__inherit__", | |||||
| encrypted_config=None, | |||||
| enabled=True, | |||||
| ) | |||||
| ] | |||||
| mocker.patch( | |||||
| "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} | |||||
| ) | |||||
| provider_manager = ProviderManager() | |||||
| # Running the method | |||||
| result = provider_manager._to_model_settings( | |||||
| provider_entity=mock_provider_entity, | |||||
| provider_model_settings=provider_model_settings, | |||||
| load_balancing_model_configs=load_balancing_model_configs, | |||||
| ) | |||||
| # Asserting that the result is as expected | |||||
| assert len(result) == 1 | |||||
| assert isinstance(result[0], ModelSettings) | |||||
| assert result[0].model == "gpt-4" | |||||
| assert result[0].model_type == ModelType.LLM | |||||
| assert result[0].enabled is True | |||||
| assert len(result[0].load_balancing_configs) == 0 | |||||
| def test__to_model_settings_lb_disabled(mocker, mock_provider_entity): | |||||
| # Mocking the inputs | |||||
| provider_model_settings = [ | |||||
| ProviderModelSetting( | |||||
| id="id", | |||||
| tenant_id="tenant_id", | |||||
| provider_name="openai", | |||||
| model_name="gpt-4", | |||||
| model_type="text-generation", | |||||
| enabled=True, | |||||
| load_balancing_enabled=False, | |||||
| ) | |||||
| ] | |||||
| load_balancing_model_configs = [ | |||||
| LoadBalancingModelConfig( | |||||
| id="id1", | |||||
| tenant_id="tenant_id", | |||||
| provider_name="openai", | |||||
| model_name="gpt-4", | |||||
| model_type="text-generation", | |||||
| name="__inherit__", | |||||
| encrypted_config=None, | |||||
| enabled=True, | |||||
| ), | |||||
| LoadBalancingModelConfig( | |||||
| id="id2", | |||||
| tenant_id="tenant_id", | |||||
| provider_name="openai", | |||||
| model_name="gpt-4", | |||||
| model_type="text-generation", | |||||
| name="first", | |||||
| encrypted_config='{"openai_api_key": "fake_key"}', | |||||
| enabled=True, | |||||
| ), | |||||
| ] | |||||
| mocker.patch( | |||||
| "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} | |||||
| ) | |||||
| provider_manager = ProviderManager() | |||||
| # Running the method | |||||
| result = provider_manager._to_model_settings( | |||||
| provider_entity=mock_provider_entity, | |||||
| provider_model_settings=provider_model_settings, | |||||
| load_balancing_model_configs=load_balancing_model_configs, | |||||
| ) | |||||
| # Asserting that the result is as expected | |||||
| assert len(result) == 1 | |||||
| assert isinstance(result[0], ModelSettings) | |||||
| assert result[0].model == "gpt-4" | |||||
| assert result[0].model_type == ModelType.LLM | |||||
| assert result[0].enabled is True | |||||
| assert len(result[0].load_balancing_configs) == 0 |
| inputClassName, | inputClassName, | ||||
| formSchema, | formSchema, | ||||
| field, | field, | ||||
| disabled, | |||||
| disabled: propsDisabled, | |||||
| }: BaseFieldProps) => { | }: BaseFieldProps) => { | ||||
| const renderI18nObject = useRenderI18nObject() | const renderI18nObject = useRenderI18nObject() | ||||
| const { | const { | ||||
| options, | options, | ||||
| labelClassName: formLabelClassName, | labelClassName: formLabelClassName, | ||||
| show_on = [], | show_on = [], | ||||
| disabled: formSchemaDisabled, | |||||
| } = formSchema | } = formSchema | ||||
| const disabled = propsDisabled || formSchemaDisabled | |||||
| const memorizedLabel = useMemo(() => { | const memorizedLabel = useMemo(() => { | ||||
| if (isValidElement(label)) | if (isValidElement(label)) | ||||
| }) | }) | ||||
| const memorizedOptions = useMemo(() => { | const memorizedOptions = useMemo(() => { | ||||
| return options?.filter((option) => { | return options?.filter((option) => { | ||||
| if (!option.show_on?.length) | |||||
| if (!option.show_on || option.show_on.length === 0) | |||||
| return true | return true | ||||
| return option.show_on.every((condition) => { | return option.show_on.every((condition) => { | ||||
| value: option.value, | value: option.value, | ||||
| } | } | ||||
| }) || [] | }) || [] | ||||
| }, [options, renderI18nObject]) | |||||
| }, [options, renderI18nObject, optionValues]) | |||||
| const value = useStore(field.form.store, s => s.values[field.name]) | const value = useStore(field.form.store, s => s.values[field.name]) | ||||
| const values = useStore(field.form.store, (s) => { | const values = useStore(field.form.store, (s) => { | ||||
| return show_on.reduce((acc, condition) => { | return show_on.reduce((acc, condition) => { | ||||
| className={cn( | className={cn( | ||||
| 'system-sm-regular hover:bg-components-option-card-option-hover-bg hover:border-components-option-card-option-hover-border flex h-8 flex-[1] grow cursor-pointer items-center justify-center rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg p-2 text-text-secondary', | 'system-sm-regular hover:bg-components-option-card-option-hover-bg hover:border-components-option-card-option-hover-border flex h-8 flex-[1] grow cursor-pointer items-center justify-center rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg p-2 text-text-secondary', | ||||
| value === option.value && 'border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary shadow-xs', | value === option.value && 'border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary shadow-xs', | ||||
| disabled && 'cursor-not-allowed opacity-50', | |||||
| inputClassName, | inputClassName, | ||||
| )} | )} | ||||
| onClick={() => field.handleChange(option.value)} | |||||
| onClick={() => !disabled && field.handleChange(option.value)} | |||||
| > | > | ||||
| { | { | ||||
| formSchema.showRadioUI && ( | formSchema.showRadioUI && ( |
| import { useCallback } from 'react' | |||||
| import { | |||||
| isValidElement, | |||||
| useCallback, | |||||
| } from 'react' | |||||
| import type { ReactNode } from 'react' | |||||
| import { useTranslation } from 'react-i18next' | import { useTranslation } from 'react-i18next' | ||||
| import type { FormSchema } from '../types' | import type { FormSchema } from '../types' | ||||
| import { useRenderI18nObject } from '@/hooks/use-i18n' | |||||
| export const useGetValidators = () => { | export const useGetValidators = () => { | ||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| const renderI18nObject = useRenderI18nObject() | |||||
| const getLabel = useCallback((label: string | Record<string, string> | ReactNode) => { | |||||
| if (isValidElement(label)) | |||||
| return '' | |||||
| if (typeof label === 'string') | |||||
| return label | |||||
| if (typeof label === 'object' && label !== null) | |||||
| return renderI18nObject(label as Record<string, string>) | |||||
| }, []) | |||||
| const getValidators = useCallback((formSchema: FormSchema) => { | const getValidators = useCallback((formSchema: FormSchema) => { | ||||
| const { | const { | ||||
| name, | name, | ||||
| validators, | validators, | ||||
| required, | required, | ||||
| label, | |||||
| } = formSchema | } = formSchema | ||||
| let mergedValidators = validators | let mergedValidators = validators | ||||
| const memorizedLabel = getLabel(label) | |||||
| if (required && !validators) { | if (required && !validators) { | ||||
| mergedValidators = { | mergedValidators = { | ||||
| onMount: ({ value }: any) => { | onMount: ({ value }: any) => { | ||||
| if (!value) | if (!value) | ||||
| return t('common.errorMsg.fieldRequired', { field: name }) | |||||
| return t('common.errorMsg.fieldRequired', { field: memorizedLabel || name }) | |||||
| }, | }, | ||||
| onChange: ({ value }: any) => { | onChange: ({ value }: any) => { | ||||
| if (!value) | if (!value) | ||||
| return t('common.errorMsg.fieldRequired', { field: name }) | |||||
| return t('common.errorMsg.fieldRequired', { field: memorizedLabel || name }) | |||||
| }, | }, | ||||
| onBlur: ({ value }: any) => { | onBlur: ({ value }: any) => { | ||||
| if (!value) | if (!value) | ||||
| return t('common.errorMsg.fieldRequired', { field: name }) | |||||
| return t('common.errorMsg.fieldRequired', { field: memorizedLabel }) | |||||
| }, | }, | ||||
| } | } | ||||
| } | } | ||||
| return mergedValidators | return mergedValidators | ||||
| }, [t]) | |||||
| }, [t, getLabel]) | |||||
| return { | return { | ||||
| getValidators, | getValidators, |
| labelClassName?: string | labelClassName?: string | ||||
| validators?: AnyValidators | validators?: AnyValidators | ||||
| showRadioUI?: boolean | showRadioUI?: boolean | ||||
| disabled?: boolean | |||||
| } | } | ||||
| export type FormValues = Record<string, any> | export type FormValues = Record<string, any> |
| quotaExceeded = 'quota-exceeded', | quotaExceeded = 'quota-exceeded', | ||||
| noPermission = 'no-permission', | noPermission = 'no-permission', | ||||
| disabled = 'disabled', | disabled = 'disabled', | ||||
| credentialRemoved = 'credential-removed', | |||||
| } | } | ||||
| export const MODEL_STATUS_TEXT: { [k: string]: TypeWithI18N } = { | export const MODEL_STATUS_TEXT: { [k: string]: TypeWithI18N } = { | ||||
| model_properties: Record<string, string | number> | model_properties: Record<string, string | number> | ||||
| load_balancing_enabled: boolean | load_balancing_enabled: boolean | ||||
| deprecated?: boolean | deprecated?: boolean | ||||
| has_invalid_load_balancing_configs?: boolean | |||||
| } | } | ||||
| export enum PreferredProviderTypeEnum { | export enum PreferredProviderTypeEnum { | ||||
| is_valid: boolean | is_valid: boolean | ||||
| } | } | ||||
| export type Credential = { | |||||
| credential_id: string | |||||
| credential_name?: string | |||||
| from_enterprise?: boolean | |||||
| not_allowed_to_use?: boolean | |||||
| } | |||||
| export type CustomModel = { | |||||
| model: string | |||||
| model_type: ModelTypeEnum | |||||
| } | |||||
| export type CustomModelCredential = CustomModel & { | |||||
| credentials?: Record<string, any> | |||||
| available_model_credentials?: Credential[] | |||||
| current_credential_id?: string | |||||
| } | |||||
| export type CredentialWithModel = Credential & { | |||||
| model: string | |||||
| model_type: ModelTypeEnum | |||||
| } | |||||
| export type ModelProvider = { | export type ModelProvider = { | ||||
| provider: string | provider: string | ||||
| label: TypeWithI18N | label: TypeWithI18N | ||||
| preferred_provider_type: PreferredProviderTypeEnum | preferred_provider_type: PreferredProviderTypeEnum | ||||
| custom_configuration: { | custom_configuration: { | ||||
| status: CustomConfigurationStatusEnum | status: CustomConfigurationStatusEnum | ||||
| current_credential_id?: string | |||||
| current_credential_name?: string | |||||
| available_credentials?: Credential[] | |||||
| custom_models?: CustomModelCredential[] | |||||
| } | } | ||||
| system_configuration: { | system_configuration: { | ||||
| enabled: boolean | enabled: boolean | ||||
| current_quota_type: CurrentSystemQuotaTypeEnum | current_quota_type: CurrentSystemQuotaTypeEnum | ||||
| quota_configurations: QuotaConfiguration[] | quota_configurations: QuotaConfiguration[] | ||||
| } | } | ||||
| allow_custom_token?: boolean | |||||
| } | } | ||||
| export type Model = { | export type Model = { | ||||
| in_cooldown?: boolean | in_cooldown?: boolean | ||||
| /** cooldown time (in seconds) */ | /** cooldown time (in seconds) */ | ||||
| ttl?: number | ttl?: number | ||||
| credential_id?: string | |||||
| } | } | ||||
| export type ModelLoadBalancingConfig = { | export type ModelLoadBalancingConfig = { | ||||
| enabled: boolean | enabled: boolean | ||||
| configs: ModelLoadBalancingConfigEntry[] | configs: ModelLoadBalancingConfigEntry[] | ||||
| } | } | ||||
| export type ProviderCredential = { | |||||
| credentials: Record<string, any> | |||||
| name: string | |||||
| credential_id: string | |||||
| } | |||||
| export type ModelCredential = { | |||||
| credentials: Record<string, any> | |||||
| load_balancing: ModelLoadBalancingConfig | |||||
| available_credentials: Credential[] | |||||
| current_credential_id?: string | |||||
| current_credential_name?: string | |||||
| } |
| import useSWR, { useSWRConfig } from 'swr' | import useSWR, { useSWRConfig } from 'swr' | ||||
| import { useContext } from 'use-context-selector' | import { useContext } from 'use-context-selector' | ||||
| import type { | import type { | ||||
| Credential, | |||||
| CustomConfigurationModelFixedFields, | CustomConfigurationModelFixedFields, | ||||
| CustomModel, | |||||
| DefaultModel, | DefaultModel, | ||||
| DefaultModelResponse, | DefaultModelResponse, | ||||
| Model, | Model, | ||||
| configurationMethod: ConfigurationMethodEnum, | configurationMethod: ConfigurationMethodEnum, | ||||
| configured?: boolean, | configured?: boolean, | ||||
| currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, | currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, | ||||
| credentialId?: string, | |||||
| ) => { | ) => { | ||||
| const { data: predefinedFormSchemasValue, mutate: mutatePredefined } = useSWR( | |||||
| (configurationMethod === ConfigurationMethodEnum.predefinedModel && configured) | |||||
| ? `/workspaces/current/model-providers/${provider}/credentials` | |||||
| const { data: predefinedFormSchemasValue, mutate: mutatePredefined, isLoading: isPredefinedLoading } = useSWR( | |||||
| (configurationMethod === ConfigurationMethodEnum.predefinedModel && configured && credentialId) | |||||
| ? `/workspaces/current/model-providers/${provider}/credentials${credentialId ? `?credential_id=${credentialId}` : ''}` | |||||
| : null, | : null, | ||||
| fetchModelProviderCredentials, | fetchModelProviderCredentials, | ||||
| ) | ) | ||||
| const { data: customFormSchemasValue, mutate: mutateCustomized } = useSWR( | |||||
| (configurationMethod === ConfigurationMethodEnum.customizableModel && currentCustomConfigurationModelFixedFields) | |||||
| ? `/workspaces/current/model-providers/${provider}/models/credentials?model=${currentCustomConfigurationModelFixedFields?.__model_name}&model_type=${currentCustomConfigurationModelFixedFields?.__model_type}` | |||||
| const { data: customFormSchemasValue, mutate: mutateCustomized, isLoading: isCustomizedLoading } = useSWR( | |||||
| (configurationMethod === ConfigurationMethodEnum.customizableModel && currentCustomConfigurationModelFixedFields && credentialId) | |||||
| ? `/workspaces/current/model-providers/${provider}/models/credentials?model=${currentCustomConfigurationModelFixedFields?.__model_name}&model_type=${currentCustomConfigurationModelFixedFields?.__model_type}${credentialId ? `&credential_id=${credentialId}` : ''}` | |||||
| : null, | : null, | ||||
| fetchModelProviderCredentials, | fetchModelProviderCredentials, | ||||
| ) | ) | ||||
| : undefined | : undefined | ||||
| }, [ | }, [ | ||||
| configurationMethod, | configurationMethod, | ||||
| credentialId, | |||||
| currentCustomConfigurationModelFixedFields, | currentCustomConfigurationModelFixedFields, | ||||
| customFormSchemasValue?.credentials, | customFormSchemasValue?.credentials, | ||||
| predefinedFormSchemasValue?.credentials, | predefinedFormSchemasValue?.credentials, | ||||
| : customFormSchemasValue | : customFormSchemasValue | ||||
| )?.load_balancing, | )?.load_balancing, | ||||
| mutate, | mutate, | ||||
| isLoading: isPredefinedLoading || isCustomizedLoading, | |||||
| } | } | ||||
| // as ([Record<string, string | boolean | undefined> | undefined, ModelLoadBalancingConfig | undefined]) | // as ([Record<string, string | boolean | undefined> | undefined, ModelLoadBalancingConfig | undefined]) | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| export const useModelModalHandler = () => { | |||||
| const setShowModelModal = useModalContextSelector(state => state.setShowModelModal) | |||||
| export const useRefreshModel = () => { | |||||
| const { eventEmitter } = useEventEmitterContextContext() | |||||
| const updateModelProviders = useUpdateModelProviders() | const updateModelProviders = useUpdateModelProviders() | ||||
| const updateModelList = useUpdateModelList() | const updateModelList = useUpdateModelList() | ||||
| const { eventEmitter } = useEventEmitterContextContext() | |||||
| const handleRefreshModel = useCallback((provider: ModelProvider, configurationMethod: ConfigurationMethodEnum, CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => { | |||||
| updateModelProviders() | |||||
| provider.supported_model_types.forEach((type) => { | |||||
| updateModelList(type) | |||||
| }) | |||||
| if (configurationMethod === ConfigurationMethodEnum.customizableModel | |||||
| && provider.custom_configuration.status === CustomConfigurationStatusEnum.active) { | |||||
| eventEmitter?.emit({ | |||||
| type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST, | |||||
| payload: provider.provider, | |||||
| } as any) | |||||
| if (CustomConfigurationModelFixedFields?.__model_type) | |||||
| updateModelList(CustomConfigurationModelFixedFields.__model_type) | |||||
| } | |||||
| }, [eventEmitter, updateModelList, updateModelProviders]) | |||||
| return { | |||||
| handleRefreshModel, | |||||
| } | |||||
| } | |||||
| export const useModelModalHandler = () => { | |||||
| const setShowModelModal = useModalContextSelector(state => state.setShowModelModal) | |||||
| const { handleRefreshModel } = useRefreshModel() | |||||
| return ( | return ( | ||||
| provider: ModelProvider, | provider: ModelProvider, | ||||
| configurationMethod: ConfigurationMethodEnum, | configurationMethod: ConfigurationMethodEnum, | ||||
| CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, | CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, | ||||
| isModelCredential?: boolean, | |||||
| credential?: Credential, | |||||
| model?: CustomModel, | |||||
| onUpdate?: () => void, | |||||
| ) => { | ) => { | ||||
| setShowModelModal({ | setShowModelModal({ | ||||
| payload: { | payload: { | ||||
| currentProvider: provider, | currentProvider: provider, | ||||
| currentConfigurationMethod: configurationMethod, | currentConfigurationMethod: configurationMethod, | ||||
| currentCustomConfigurationModelFixedFields: CustomConfigurationModelFixedFields, | currentCustomConfigurationModelFixedFields: CustomConfigurationModelFixedFields, | ||||
| isModelCredential, | |||||
| credential, | |||||
| model, | |||||
| }, | }, | ||||
| onSaveCallback: () => { | onSaveCallback: () => { | ||||
| updateModelProviders() | |||||
| provider.supported_model_types.forEach((type) => { | |||||
| updateModelList(type) | |||||
| }) | |||||
| if (configurationMethod === ConfigurationMethodEnum.customizableModel | |||||
| && provider.custom_configuration.status === CustomConfigurationStatusEnum.active) { | |||||
| eventEmitter?.emit({ | |||||
| type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST, | |||||
| payload: provider.provider, | |||||
| } as any) | |||||
| if (CustomConfigurationModelFixedFields?.__model_type) | |||||
| updateModelList(CustomConfigurationModelFixedFields.__model_type) | |||||
| } | |||||
| handleRefreshModel(provider, configurationMethod, CustomConfigurationModelFixedFields) | |||||
| onUpdate?.() | |||||
| }, | }, | ||||
| }) | }) | ||||
| } | } |
| import SystemModelSelector from './system-model-selector' | import SystemModelSelector from './system-model-selector' | ||||
| import ProviderAddedCard from './provider-added-card' | import ProviderAddedCard from './provider-added-card' | ||||
| import type { | import type { | ||||
| ConfigurationMethodEnum, | |||||
| CustomConfigurationModelFixedFields, | |||||
| ModelProvider, | ModelProvider, | ||||
| } from './declarations' | } from './declarations' | ||||
| import { | import { | ||||
| } from './declarations' | } from './declarations' | ||||
| import { | import { | ||||
| useDefaultModel, | useDefaultModel, | ||||
| useModelModalHandler, | |||||
| } from './hooks' | } from './hooks' | ||||
| import InstallFromMarketplace from './install-from-marketplace' | import InstallFromMarketplace from './install-from-marketplace' | ||||
| import { useProviderContext } from '@/context/provider-context' | import { useProviderContext } from '@/context/provider-context' | ||||
| return [filteredConfiguredProviders, filteredNotConfiguredProviders] | return [filteredConfiguredProviders, filteredNotConfiguredProviders] | ||||
| }, [configuredProviders, debouncedSearchText, notConfiguredProviders]) | }, [configuredProviders, debouncedSearchText, notConfiguredProviders]) | ||||
| const handleOpenModal = useModelModalHandler() | |||||
| return ( | return ( | ||||
| <div className='relative -mt-2 pt-1'> | <div className='relative -mt-2 pt-1'> | ||||
| <div className={cn('mb-2 flex items-center')}> | <div className={cn('mb-2 flex items-center')}> | ||||
| <ProviderAddedCard | <ProviderAddedCard | ||||
| key={provider.provider} | key={provider.provider} | ||||
| provider={provider} | provider={provider} | ||||
| onOpenModal={(configurationMethod: ConfigurationMethodEnum, currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => handleOpenModal(provider, configurationMethod, currentCustomConfigurationModelFixedFields)} | |||||
| /> | /> | ||||
| ))} | ))} | ||||
| </div> | </div> | ||||
| notConfigured | notConfigured | ||||
| key={provider.provider} | key={provider.provider} | ||||
| provider={provider} | provider={provider} | ||||
| onOpenModal={(configurationMethod: ConfigurationMethodEnum, currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => handleOpenModal(provider, configurationMethod, currentCustomConfigurationModelFixedFields)} | |||||
| /> | /> | ||||
| ))} | ))} | ||||
| </div> | </div> |
| import { | |||||
| memo, | |||||
| useCallback, | |||||
| useMemo, | |||||
| } from 'react' | |||||
| import { RiAddLine } from '@remixicon/react' | |||||
| import { useTranslation } from 'react-i18next' | |||||
| import { Authorized } from '@/app/components/header/account-setting/model-provider-page/model-auth' | |||||
| import cn from '@/utils/classnames' | |||||
| import type { | |||||
| Credential, | |||||
| CustomModelCredential, | |||||
| ModelCredential, | |||||
| ModelProvider, | |||||
| } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||||
| import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||||
| import Tooltip from '@/app/components/base/tooltip' | |||||
| type AddCredentialInLoadBalancingProps = { | |||||
| provider: ModelProvider | |||||
| model: CustomModelCredential | |||||
| configurationMethod: ConfigurationMethodEnum | |||||
| modelCredential: ModelCredential | |||||
| onSelectCredential: (credential: Credential) => void | |||||
| onUpdate?: () => void | |||||
| } | |||||
| const AddCredentialInLoadBalancing = ({ | |||||
| provider, | |||||
| model, | |||||
| configurationMethod, | |||||
| modelCredential, | |||||
| onSelectCredential, | |||||
| onUpdate, | |||||
| }: AddCredentialInLoadBalancingProps) => { | |||||
| const { t } = useTranslation() | |||||
| const { | |||||
| available_credentials, | |||||
| } = modelCredential | |||||
| const customModel = configurationMethod === ConfigurationMethodEnum.customizableModel | |||||
| const notAllowCustomCredential = provider.allow_custom_token === false | |||||
| const ButtonComponent = useMemo(() => { | |||||
| const Item = ( | |||||
| <div className={cn( | |||||
| 'system-sm-medium flex h-8 items-center rounded-lg px-3 text-text-accent hover:bg-state-base-hover', | |||||
| notAllowCustomCredential && 'cursor-not-allowed opacity-50', | |||||
| )}> | |||||
| <RiAddLine className='mr-2 h-4 w-4' /> | |||||
| { | |||||
| customModel | |||||
| ? t('common.modelProvider.auth.addCredential') | |||||
| : t('common.modelProvider.auth.addApiKey') | |||||
| } | |||||
| </div> | |||||
| ) | |||||
| if (notAllowCustomCredential) { | |||||
| return ( | |||||
| <Tooltip | |||||
| asChild | |||||
| popupContent={t('plugin.auth.credentialUnavailable')} | |||||
| > | |||||
| {Item} | |||||
| </Tooltip> | |||||
| ) | |||||
| } | |||||
| return Item | |||||
| }, [notAllowCustomCredential, t, customModel]) | |||||
| const renderTrigger = useCallback((open?: boolean) => { | |||||
| const Item = ( | |||||
| <div className={cn( | |||||
| 'system-sm-medium flex h-8 items-center rounded-lg px-3 text-text-accent hover:bg-state-base-hover', | |||||
| open && 'bg-state-base-hover', | |||||
| )}> | |||||
| <RiAddLine className='mr-2 h-4 w-4' /> | |||||
| { | |||||
| customModel | |||||
| ? t('common.modelProvider.auth.addCredential') | |||||
| : t('common.modelProvider.auth.addApiKey') | |||||
| } | |||||
| </div> | |||||
| ) | |||||
| return Item | |||||
| }, [t, customModel]) | |||||
| if (!available_credentials?.length) | |||||
| return ButtonComponent | |||||
| return ( | |||||
| <Authorized | |||||
| provider={provider} | |||||
| renderTrigger={renderTrigger} | |||||
| items={[ | |||||
| { | |||||
| title: customModel ? t('common.modelProvider.auth.modelCredentials') : t('common.modelProvider.auth.apiKeys'), | |||||
| model: customModel ? model : undefined, | |||||
| credentials: available_credentials ?? [], | |||||
| }, | |||||
| ]} | |||||
| configurationMethod={configurationMethod} | |||||
| currentCustomConfigurationModelFixedFields={customModel ? { | |||||
| __model_name: model.model, | |||||
| __model_type: model.model_type, | |||||
| } : undefined} | |||||
| onItemClick={onSelectCredential} | |||||
| placement='bottom-start' | |||||
| onUpdate={onUpdate} | |||||
| isModelCredential={customModel} | |||||
| /> | |||||
| ) | |||||
| } | |||||
| export default memo(AddCredentialInLoadBalancing) |
| import { | |||||
| memo, | |||||
| useCallback, | |||||
| useMemo, | |||||
| } from 'react' | |||||
| import { useTranslation } from 'react-i18next' | |||||
| import { | |||||
| RiAddCircleFill, | |||||
| } from '@remixicon/react' | |||||
| import { | |||||
| Button, | |||||
| } from '@/app/components/base/button' | |||||
| import type { | |||||
| CustomConfigurationModelFixedFields, | |||||
| ModelProvider, | |||||
| } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||||
| import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||||
| import Authorized from './authorized' | |||||
| import { | |||||
| useAuth, | |||||
| useCustomModels, | |||||
| } from './hooks' | |||||
| import cn from '@/utils/classnames' | |||||
| import Tooltip from '@/app/components/base/tooltip' | |||||
| type AddCustomModelProps = { | |||||
| provider: ModelProvider, | |||||
| configurationMethod: ConfigurationMethodEnum, | |||||
| currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, | |||||
| } | |||||
| const AddCustomModel = ({ | |||||
| provider, | |||||
| configurationMethod, | |||||
| currentCustomConfigurationModelFixedFields, | |||||
| }: AddCustomModelProps) => { | |||||
| const { t } = useTranslation() | |||||
| const customModels = useCustomModels(provider) | |||||
| const noModels = !customModels.length | |||||
| const { | |||||
| handleOpenModal, | |||||
| } = useAuth(provider, configurationMethod, currentCustomConfigurationModelFixedFields, true) | |||||
| const notAllowCustomCredential = provider.allow_custom_token === false | |||||
| const handleClick = useCallback(() => { | |||||
| if (notAllowCustomCredential) | |||||
| return | |||||
| handleOpenModal() | |||||
| }, [handleOpenModal, notAllowCustomCredential]) | |||||
| const ButtonComponent = useMemo(() => { | |||||
| const Item = ( | |||||
| <Button | |||||
| variant='ghost-accent' | |||||
| size='small' | |||||
| onClick={handleClick} | |||||
| className={cn( | |||||
| notAllowCustomCredential && 'cursor-not-allowed opacity-50', | |||||
| )} | |||||
| > | |||||
| <RiAddCircleFill className='mr-1 h-3.5 w-3.5' /> | |||||
| {t('common.modelProvider.addModel')} | |||||
| </Button> | |||||
| ) | |||||
| if (notAllowCustomCredential) { | |||||
| return ( | |||||
| <Tooltip | |||||
| asChild | |||||
| popupContent={t('plugin.auth.credentialUnavailable')} | |||||
| > | |||||
| {Item} | |||||
| </Tooltip> | |||||
| ) | |||||
| } | |||||
| return Item | |||||
| }, [handleClick, notAllowCustomCredential, t]) | |||||
| const renderTrigger = useCallback((open?: boolean) => { | |||||
| const Item = ( | |||||
| <Button | |||||
| variant='ghost' | |||||
| size='small' | |||||
| className={cn( | |||||
| open && 'bg-components-button-ghost-bg-hover', | |||||
| )} | |||||
| > | |||||
| <RiAddCircleFill className='mr-1 h-3.5 w-3.5' /> | |||||
| {t('common.modelProvider.addModel')} | |||||
| </Button> | |||||
| ) | |||||
| return Item | |||||
| }, [t]) | |||||
| if (noModels) | |||||
| return ButtonComponent | |||||
| return ( | |||||
| <Authorized | |||||
| provider={provider} | |||||
| configurationMethod={ConfigurationMethodEnum.customizableModel} | |||||
| items={customModels.map(model => ({ | |||||
| model, | |||||
| credentials: model.available_model_credentials ?? [], | |||||
| }))} | |||||
| renderTrigger={renderTrigger} | |||||
| isModelCredential | |||||
| enableAddModelCredential | |||||
| bottomAddModelCredentialText={t('common.modelProvider.auth.addNewModel')} | |||||
| /> | |||||
| ) | |||||
| } | |||||
| export default memo(AddCustomModel) |
| import { | |||||
| memo, | |||||
| useCallback, | |||||
| } from 'react' | |||||
| import { RiAddLine } from '@remixicon/react' | |||||
| import { useTranslation } from 'react-i18next' | |||||
| import CredentialItem from './credential-item' | |||||
| import type { | |||||
| Credential, | |||||
| CustomModel, | |||||
| CustomModelCredential, | |||||
| } from '../../declarations' | |||||
| import Button from '@/app/components/base/button' | |||||
| import Tooltip from '@/app/components/base/tooltip' | |||||
| type AuthorizedItemProps = { | |||||
| model?: CustomModelCredential | |||||
| title?: string | |||||
| disabled?: boolean | |||||
| onDelete?: (credential?: Credential, model?: CustomModel) => void | |||||
| onEdit?: (credential?: Credential, model?: CustomModel) => void | |||||
| showItemSelectedIcon?: boolean | |||||
| selectedCredentialId?: string | |||||
| credentials: Credential[] | |||||
| onItemClick?: (credential: Credential, model?: CustomModel) => void | |||||
| enableAddModelCredential?: boolean | |||||
| notAllowCustomCredential?: boolean | |||||
| } | |||||
| export const AuthorizedItem = ({ | |||||
| model, | |||||
| title, | |||||
| credentials, | |||||
| disabled, | |||||
| onDelete, | |||||
| onEdit, | |||||
| showItemSelectedIcon, | |||||
| selectedCredentialId, | |||||
| onItemClick, | |||||
| enableAddModelCredential, | |||||
| notAllowCustomCredential, | |||||
| }: AuthorizedItemProps) => { | |||||
| const { t } = useTranslation() | |||||
| const handleEdit = useCallback((credential?: Credential) => { | |||||
| onEdit?.(credential, model) | |||||
| }, [onEdit, model]) | |||||
| const handleDelete = useCallback((credential?: Credential) => { | |||||
| onDelete?.(credential, model) | |||||
| }, [onDelete, model]) | |||||
| const handleItemClick = useCallback((credential: Credential) => { | |||||
| onItemClick?.(credential, model) | |||||
| }, [onItemClick, model]) | |||||
| return ( | |||||
| <div className='p-1'> | |||||
| <div | |||||
| className='flex h-9 items-center' | |||||
| > | |||||
| <div className='h-5 w-5 shrink-0'></div> | |||||
| <div | |||||
| className='system-md-medium mx-1 grow truncate text-text-primary' | |||||
| title={title ?? model?.model} | |||||
| > | |||||
| {title ?? model?.model} | |||||
| </div> | |||||
| { | |||||
| enableAddModelCredential && !notAllowCustomCredential && ( | |||||
| <Tooltip | |||||
| asChild | |||||
| popupContent={t('common.modelProvider.auth.addModelCredential')} | |||||
| > | |||||
| <Button | |||||
| className='h-6 w-6 shrink-0 rounded-full p-0' | |||||
| size='small' | |||||
| variant='secondary-accent' | |||||
| onClick={() => handleEdit?.()} | |||||
| > | |||||
| <RiAddLine className='h-4 w-4' /> | |||||
| </Button> | |||||
| </Tooltip> | |||||
| ) | |||||
| } | |||||
| </div> | |||||
| { | |||||
| credentials.map(credential => ( | |||||
| <CredentialItem | |||||
| key={credential.credential_id} | |||||
| credential={credential} | |||||
| disabled={disabled} | |||||
| onDelete={handleDelete} | |||||
| onEdit={handleEdit} | |||||
| showSelectedIcon={showItemSelectedIcon} | |||||
| selectedCredentialId={selectedCredentialId} | |||||
| onItemClick={handleItemClick} | |||||
| /> | |||||
| )) | |||||
| } | |||||
| </div> | |||||
| ) | |||||
| } | |||||
| export default memo(AuthorizedItem) |
| import { | |||||
| memo, | |||||
| useMemo, | |||||
| } from 'react' | |||||
| import { useTranslation } from 'react-i18next' | |||||
| import { | |||||
| RiCheckLine, | |||||
| RiDeleteBinLine, | |||||
| RiEqualizer2Line, | |||||
| } from '@remixicon/react' | |||||
| import Indicator from '@/app/components/header/indicator' | |||||
| import ActionButton from '@/app/components/base/action-button' | |||||
| import Tooltip from '@/app/components/base/tooltip' | |||||
| import cn from '@/utils/classnames' | |||||
| import type { Credential } from '../../declarations' | |||||
| import Badge from '@/app/components/base/badge' | |||||
| type CredentialItemProps = { | |||||
| credential: Credential | |||||
| disabled?: boolean | |||||
| onDelete?: (credential: Credential) => void | |||||
| onEdit?: (credential?: Credential) => void | |||||
| onItemClick?: (credential: Credential) => void | |||||
| disableRename?: boolean | |||||
| disableEdit?: boolean | |||||
| disableDelete?: boolean | |||||
| showSelectedIcon?: boolean | |||||
| selectedCredentialId?: string | |||||
| } | |||||
| const CredentialItem = ({ | |||||
| credential, | |||||
| disabled, | |||||
| onDelete, | |||||
| onEdit, | |||||
| onItemClick, | |||||
| disableRename, | |||||
| disableEdit, | |||||
| disableDelete, | |||||
| showSelectedIcon, | |||||
| selectedCredentialId, | |||||
| }: CredentialItemProps) => { | |||||
| const { t } = useTranslation() | |||||
| const showAction = useMemo(() => { | |||||
| return !(disableRename && disableEdit && disableDelete) | |||||
| }, [disableRename, disableEdit, disableDelete]) | |||||
| const Item = ( | |||||
| <div | |||||
| key={credential.credential_id} | |||||
| className={cn( | |||||
| 'group flex h-8 items-center rounded-lg p-1 hover:bg-state-base-hover', | |||||
| (disabled || credential.not_allowed_to_use) && 'cursor-not-allowed opacity-50', | |||||
| )} | |||||
| onClick={() => { | |||||
| if (disabled || credential.not_allowed_to_use) | |||||
| return | |||||
| onItemClick?.(credential) | |||||
| }} | |||||
| > | |||||
| <div className='flex w-0 grow items-center space-x-1.5'> | |||||
| { | |||||
| showSelectedIcon && ( | |||||
| <div className='h-4 w-4'> | |||||
| { | |||||
| selectedCredentialId === credential.credential_id && ( | |||||
| <RiCheckLine className='h-4 w-4 text-text-accent' /> | |||||
| ) | |||||
| } | |||||
| </div> | |||||
| ) | |||||
| } | |||||
| <Indicator className='ml-2 mr-1.5 shrink-0' /> | |||||
| <div | |||||
| className='system-md-regular truncate text-text-secondary' | |||||
| title={credential.credential_name} | |||||
| > | |||||
| {credential.credential_name} | |||||
| </div> | |||||
| </div> | |||||
| { | |||||
| credential.from_enterprise && ( | |||||
| <Badge className='shrink-0'> | |||||
| Enterprise | |||||
| </Badge> | |||||
| ) | |||||
| } | |||||
| { | |||||
| showAction && ( | |||||
| <div className='ml-2 hidden shrink-0 items-center group-hover:flex'> | |||||
| { | |||||
| !disableEdit && !credential.not_allowed_to_use && !credential.from_enterprise && ( | |||||
| <Tooltip popupContent={t('common.operation.edit')}> | |||||
| <ActionButton | |||||
| disabled={disabled} | |||||
| onClick={(e) => { | |||||
| e.stopPropagation() | |||||
| onEdit?.(credential) | |||||
| }} | |||||
| > | |||||
| <RiEqualizer2Line className='h-4 w-4 text-text-tertiary' /> | |||||
| </ActionButton> | |||||
| </Tooltip> | |||||
| ) | |||||
| } | |||||
| { | |||||
| !disableDelete && !credential.from_enterprise && ( | |||||
| <Tooltip popupContent={t('common.operation.delete')}> | |||||
| <ActionButton | |||||
| className='hover:bg-transparent' | |||||
| disabled={disabled} | |||||
| onClick={(e) => { | |||||
| e.stopPropagation() | |||||
| onDelete?.(credential) | |||||
| }} | |||||
| > | |||||
| <RiDeleteBinLine className='h-4 w-4 text-text-tertiary hover:text-text-destructive' /> | |||||
| </ActionButton> | |||||
| </Tooltip> | |||||
| ) | |||||
| } | |||||
| </div> | |||||
| ) | |||||
| } | |||||
| </div> | |||||
| ) | |||||
| if (credential.not_allowed_to_use) { | |||||
| return ( | |||||
| <Tooltip popupContent={t('plugin.auth.customCredentialUnavailable')}> | |||||
| {Item} | |||||
| </Tooltip> | |||||
| ) | |||||
| } | |||||
| return Item | |||||
| } | |||||
| export default memo(CredentialItem) |
| import { | |||||
| memo, | |||||
| useCallback, | |||||
| useMemo, | |||||
| useState, | |||||
| } from 'react' | |||||
| import { | |||||
| RiAddLine, | |||||
| RiEqualizer2Line, | |||||
| } from '@remixicon/react' | |||||
| import { useTranslation } from 'react-i18next' | |||||
| import { | |||||
| PortalToFollowElem, | |||||
| PortalToFollowElemContent, | |||||
| PortalToFollowElemTrigger, | |||||
| } from '@/app/components/base/portal-to-follow-elem' | |||||
| import type { | |||||
| PortalToFollowElemOptions, | |||||
| } from '@/app/components/base/portal-to-follow-elem' | |||||
| import Button from '@/app/components/base/button' | |||||
| import cn from '@/utils/classnames' | |||||
| import Confirm from '@/app/components/base/confirm' | |||||
| import type { | |||||
| ConfigurationMethodEnum, | |||||
| Credential, | |||||
| CustomConfigurationModelFixedFields, | |||||
| CustomModel, | |||||
| ModelProvider, | |||||
| } from '../../declarations' | |||||
| import { useAuth } from '../hooks' | |||||
| import AuthorizedItem from './authorized-item' | |||||
| type AuthorizedProps = { | |||||
| provider: ModelProvider, | |||||
| configurationMethod: ConfigurationMethodEnum, | |||||
| currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, | |||||
| isModelCredential?: boolean | |||||
| items: { | |||||
| title?: string | |||||
| model?: CustomModel | |||||
| credentials: Credential[] | |||||
| }[] | |||||
| selectedCredential?: Credential | |||||
| disabled?: boolean | |||||
| renderTrigger?: (open?: boolean) => React.ReactNode | |||||
| isOpen?: boolean | |||||
| onOpenChange?: (open: boolean) => void | |||||
| offset?: PortalToFollowElemOptions['offset'] | |||||
| placement?: PortalToFollowElemOptions['placement'] | |||||
| triggerPopupSameWidth?: boolean | |||||
| popupClassName?: string | |||||
| showItemSelectedIcon?: boolean | |||||
| onUpdate?: () => void | |||||
| onItemClick?: (credential: Credential, model?: CustomModel) => void | |||||
| enableAddModelCredential?: boolean | |||||
| bottomAddModelCredentialText?: string | |||||
| } | |||||
| const Authorized = ({ | |||||
| provider, | |||||
| configurationMethod, | |||||
| currentCustomConfigurationModelFixedFields, | |||||
| items, | |||||
| isModelCredential, | |||||
| selectedCredential, | |||||
| disabled, | |||||
| renderTrigger, | |||||
| isOpen, | |||||
| onOpenChange, | |||||
| offset = 8, | |||||
| placement = 'bottom-end', | |||||
| triggerPopupSameWidth = false, | |||||
| popupClassName, | |||||
| showItemSelectedIcon, | |||||
| onUpdate, | |||||
| onItemClick, | |||||
| enableAddModelCredential, | |||||
| bottomAddModelCredentialText, | |||||
| }: AuthorizedProps) => { | |||||
| const { t } = useTranslation() | |||||
| const [isLocalOpen, setIsLocalOpen] = useState(false) | |||||
| const mergedIsOpen = isOpen ?? isLocalOpen | |||||
| const setMergedIsOpen = useCallback((open: boolean) => { | |||||
| if (onOpenChange) | |||||
| onOpenChange(open) | |||||
| setIsLocalOpen(open) | |||||
| }, [onOpenChange]) | |||||
| const { | |||||
| openConfirmDelete, | |||||
| closeConfirmDelete, | |||||
| doingAction, | |||||
| handleActiveCredential, | |||||
| handleConfirmDelete, | |||||
| deleteCredentialId, | |||||
| handleOpenModal, | |||||
| } = useAuth(provider, configurationMethod, currentCustomConfigurationModelFixedFields, isModelCredential, onUpdate) | |||||
| const handleEdit = useCallback((credential?: Credential, model?: CustomModel) => { | |||||
| handleOpenModal(credential, model) | |||||
| setMergedIsOpen(false) | |||||
| }, [handleOpenModal, setMergedIsOpen]) | |||||
| const handleItemClick = useCallback((credential: Credential, model?: CustomModel) => { | |||||
| if (onItemClick) | |||||
| onItemClick(credential, model) | |||||
| else | |||||
| handleActiveCredential(credential, model) | |||||
| setMergedIsOpen(false) | |||||
| }, [handleActiveCredential, onItemClick, setMergedIsOpen]) | |||||
| const notAllowCustomCredential = provider.allow_custom_token === false | |||||
| const Trigger = useMemo(() => { | |||||
| const Item = ( | |||||
| <Button | |||||
| className='grow' | |||||
| size='small' | |||||
| > | |||||
| <RiEqualizer2Line className='mr-1 h-3.5 w-3.5' /> | |||||
| {t('common.operation.config')} | |||||
| </Button> | |||||
| ) | |||||
| return Item | |||||
| }, [t]) | |||||
| return ( | |||||
| <> | |||||
| <PortalToFollowElem | |||||
| open={mergedIsOpen} | |||||
| onOpenChange={setMergedIsOpen} | |||||
| placement={placement} | |||||
| offset={offset} | |||||
| triggerPopupSameWidth={triggerPopupSameWidth} | |||||
| > | |||||
| <PortalToFollowElemTrigger | |||||
| onClick={() => { | |||||
| setMergedIsOpen(!mergedIsOpen) | |||||
| }} | |||||
| asChild | |||||
| > | |||||
| { | |||||
| renderTrigger | |||||
| ? renderTrigger(mergedIsOpen) | |||||
| : Trigger | |||||
| } | |||||
| </PortalToFollowElemTrigger> | |||||
| <PortalToFollowElemContent className='z-[100]'> | |||||
| <div className={cn( | |||||
| 'w-[360px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg', | |||||
| popupClassName, | |||||
| )}> | |||||
| <div className='max-h-[304px] overflow-y-auto'> | |||||
| { | |||||
| items.map((item, index) => ( | |||||
| <AuthorizedItem | |||||
| key={index} | |||||
| title={item.title} | |||||
| model={item.model} | |||||
| credentials={item.credentials} | |||||
| disabled={disabled} | |||||
| onDelete={openConfirmDelete} | |||||
| onEdit={handleEdit} | |||||
| showItemSelectedIcon={showItemSelectedIcon} | |||||
| selectedCredentialId={selectedCredential?.credential_id} | |||||
| onItemClick={handleItemClick} | |||||
| enableAddModelCredential={enableAddModelCredential} | |||||
| notAllowCustomCredential={notAllowCustomCredential} | |||||
| /> | |||||
| )) | |||||
| } | |||||
| </div> | |||||
| <div className='h-[1px] bg-divider-subtle'></div> | |||||
| { | |||||
| isModelCredential && !notAllowCustomCredential && ( | |||||
| <div | |||||
| onClick={() => handleEdit( | |||||
| undefined, | |||||
| currentCustomConfigurationModelFixedFields | |||||
| ? { | |||||
| model: currentCustomConfigurationModelFixedFields.__model_name, | |||||
| model_type: currentCustomConfigurationModelFixedFields.__model_type, | |||||
| } | |||||
| : undefined, | |||||
| )} | |||||
| className='system-xs-medium flex h-[30px] cursor-pointer items-center px-3 text-text-accent-light-mode-only' | |||||
| > | |||||
| <RiAddLine className='mr-1 h-4 w-4' /> | |||||
| {bottomAddModelCredentialText ?? t('common.modelProvider.auth.addModelCredential')} | |||||
| </div> | |||||
| ) | |||||
| } | |||||
| { | |||||
| !isModelCredential && !notAllowCustomCredential && ( | |||||
| <div className='p-2'> | |||||
| <Button | |||||
| onClick={() => handleEdit()} | |||||
| className='w-full' | |||||
| > | |||||
| {t('common.modelProvider.auth.addApiKey')} | |||||
| </Button> | |||||
| </div> | |||||
| ) | |||||
| } | |||||
| </div> | |||||
| </PortalToFollowElemContent> | |||||
| </PortalToFollowElem> | |||||
| { | |||||
| deleteCredentialId && ( | |||||
| <Confirm | |||||
| isShow | |||||
| title={t('common.modelProvider.confirmDelete')} | |||||
| isDisabled={doingAction} | |||||
| onCancel={closeConfirmDelete} | |||||
| onConfirm={handleConfirmDelete} | |||||
| /> | |||||
| ) | |||||
| } | |||||
| </> | |||||
| ) | |||||
| } | |||||
| export default memo(Authorized) |
| import { memo } from 'react' | |||||
| import { | |||||
| RiEqualizer2Line, | |||||
| RiScales3Line, | |||||
| } from '@remixicon/react' | |||||
| import { useTranslation } from 'react-i18next' | |||||
| import Button from '@/app/components/base/button' | |||||
| import Indicator from '@/app/components/header/indicator' | |||||
| import cn from '@/utils/classnames' | |||||
| type ConfigModelProps = { | |||||
| onClick?: () => void | |||||
| loadBalancingEnabled?: boolean | |||||
| loadBalancingInvalid?: boolean | |||||
| credentialRemoved?: boolean | |||||
| } | |||||
| const ConfigModel = ({ | |||||
| onClick, | |||||
| loadBalancingEnabled, | |||||
| loadBalancingInvalid, | |||||
| credentialRemoved, | |||||
| }: ConfigModelProps) => { | |||||
| const { t } = useTranslation() | |||||
| if (loadBalancingInvalid) { | |||||
| return ( | |||||
| <div | |||||
| className='system-2xs-medium-uppercase relative flex h-[18px] items-center rounded-[5px] border border-text-warning bg-components-badge-bg-dimm px-1.5 text-text-warning' | |||||
| onClick={onClick} | |||||
| > | |||||
| <RiScales3Line className='mr-0.5 h-3 w-3' /> | |||||
| {t('common.modelProvider.auth.authorizationError')} | |||||
| <Indicator color='orange' className='absolute right-[-1px] top-[-1px] h-1.5 w-1.5' /> | |||||
| </div> | |||||
| ) | |||||
| } | |||||
| return ( | |||||
| <Button | |||||
| variant='secondary' | |||||
| size='small' | |||||
| className={cn( | |||||
| 'hidden shrink-0 group-hover:flex', | |||||
| credentialRemoved && 'flex', | |||||
| )} | |||||
| onClick={onClick} | |||||
| > | |||||
| { | |||||
| credentialRemoved && ( | |||||
| <> | |||||
| {t('common.modelProvider.auth.credentialRemoved')} | |||||
| <Indicator color='red' className='ml-2' /> | |||||
| </> | |||||
| ) | |||||
| } | |||||
| { | |||||
| !loadBalancingEnabled && !credentialRemoved && !loadBalancingInvalid && ( | |||||
| <> | |||||
| <RiEqualizer2Line className='mr-1 h-4 w-4' /> | |||||
| {t('common.operation.config')} | |||||
| </> | |||||
| ) | |||||
| } | |||||
| { | |||||
| loadBalancingEnabled && !credentialRemoved && !loadBalancingInvalid && ( | |||||
| <> | |||||
| <RiScales3Line className='mr-1 h-4 w-4' /> | |||||
| {t('common.modelProvider.auth.configLoadBalancing')} | |||||
| </> | |||||
| ) | |||||
| } | |||||
| </Button> | |||||
| ) | |||||
| } | |||||
| export default memo(ConfigModel) |
| import { | |||||
| memo, | |||||
| useCallback, | |||||
| useMemo, | |||||
| } from 'react' | |||||
| import { useTranslation } from 'react-i18next' | |||||
| import { | |||||
| RiEqualizer2Line, | |||||
| } from '@remixicon/react' | |||||
| import { | |||||
| Button, | |||||
| } from '@/app/components/base/button' | |||||
| import type { | |||||
| CustomConfigurationModelFixedFields, | |||||
| ModelProvider, | |||||
| } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||||
| import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||||
| import Authorized from './authorized' | |||||
| import { useAuth, useCredentialStatus } from './hooks' | |||||
| import Tooltip from '@/app/components/base/tooltip' | |||||
| import cn from '@/utils/classnames' | |||||
| type ConfigProviderProps = { | |||||
| provider: ModelProvider, | |||||
| configurationMethod: ConfigurationMethodEnum, | |||||
| currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, | |||||
| } | |||||
| const ConfigProvider = ({ | |||||
| provider, | |||||
| configurationMethod, | |||||
| currentCustomConfigurationModelFixedFields, | |||||
| }: ConfigProviderProps) => { | |||||
| const { t } = useTranslation() | |||||
| const { | |||||
| handleOpenModal, | |||||
| } = useAuth(provider, configurationMethod, currentCustomConfigurationModelFixedFields) | |||||
| const { | |||||
| hasCredential, | |||||
| authorized, | |||||
| current_credential_id, | |||||
| current_credential_name, | |||||
| available_credentials, | |||||
| } = useCredentialStatus(provider) | |||||
| const notAllowCustomCredential = provider.allow_custom_token === false | |||||
| const handleClick = useCallback(() => { | |||||
| if (!hasCredential && !notAllowCustomCredential) | |||||
| handleOpenModal() | |||||
| }, [handleOpenModal, hasCredential, notAllowCustomCredential]) | |||||
| const ButtonComponent = useMemo(() => { | |||||
| const Item = ( | |||||
| <Button | |||||
| className={cn('grow', notAllowCustomCredential && 'cursor-not-allowed opacity-50')} | |||||
| size='small' | |||||
| onClick={handleClick} | |||||
| variant={!authorized ? 'secondary-accent' : 'secondary'} | |||||
| > | |||||
| <RiEqualizer2Line className='mr-1 h-3.5 w-3.5' /> | |||||
| {t('common.operation.setup')} | |||||
| </Button> | |||||
| ) | |||||
| if (notAllowCustomCredential) { | |||||
| return ( | |||||
| <Tooltip | |||||
| asChild | |||||
| popupContent={t('plugin.auth.credentialUnavailable')} | |||||
| > | |||||
| {Item} | |||||
| </Tooltip> | |||||
| ) | |||||
| } | |||||
| return Item | |||||
| }, [handleClick, authorized, notAllowCustomCredential, t]) | |||||
| if (!hasCredential) | |||||
| return ButtonComponent | |||||
| return ( | |||||
| <Authorized | |||||
| provider={provider} | |||||
| configurationMethod={ConfigurationMethodEnum.predefinedModel} | |||||
| items={[ | |||||
| { | |||||
| title: t('common.modelProvider.auth.apiKeys'), | |||||
| credentials: available_credentials ?? [], | |||||
| }, | |||||
| ]} | |||||
| selectedCredential={{ | |||||
| credential_id: current_credential_id ?? '', | |||||
| credential_name: current_credential_name ?? '', | |||||
| }} | |||||
| showItemSelectedIcon | |||||
| /> | |||||
| ) | |||||
| } | |||||
| export default memo(ConfigProvider) |
| export * from './use-model-form-schemas' | |||||
| export * from './use-credential-status' | |||||
| export * from './use-custom-models' | |||||
| export * from './use-auth' | |||||
| export * from './use-auth-service' | |||||
| export * from './use-credential-data' |
| import { useCallback } from 'react' | |||||
| import { | |||||
| useActiveModelCredential, | |||||
| useActiveProviderCredential, | |||||
| useAddModelCredential, | |||||
| useAddProviderCredential, | |||||
| useDeleteModelCredential, | |||||
| useDeleteProviderCredential, | |||||
| useEditModelCredential, | |||||
| useEditProviderCredential, | |||||
| useGetModelCredential, | |||||
| useGetProviderCredential, | |||||
| } from '@/service/use-models' | |||||
| import type { | |||||
| CustomModel, | |||||
| } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||||
| export const useGetCredential = (provider: string, isModelCredential?: boolean, credentialId?: string, model?: CustomModel, configFrom?: string) => { | |||||
| const providerData = useGetProviderCredential(!isModelCredential && !!credentialId, provider, credentialId) | |||||
| const modelData = useGetModelCredential(!!isModelCredential && !!credentialId, provider, credentialId, model?.model, model?.model_type, configFrom) | |||||
| return isModelCredential ? modelData : providerData | |||||
| } | |||||
| export const useAuthService = (provider: string) => { | |||||
| const { mutateAsync: addProviderCredential } = useAddProviderCredential(provider) | |||||
| const { mutateAsync: editProviderCredential } = useEditProviderCredential(provider) | |||||
| const { mutateAsync: deleteProviderCredential } = useDeleteProviderCredential(provider) | |||||
| const { mutateAsync: activeProviderCredential } = useActiveProviderCredential(provider) | |||||
| const { mutateAsync: addModelCredential } = useAddModelCredential(provider) | |||||
| const { mutateAsync: activeModelCredential } = useActiveModelCredential(provider) | |||||
| const { mutateAsync: deleteModelCredential } = useDeleteModelCredential(provider) | |||||
| const { mutateAsync: editModelCredential } = useEditModelCredential(provider) | |||||
| const getAddCredentialService = useCallback((isModel: boolean) => { | |||||
| return isModel ? addModelCredential : addProviderCredential | |||||
| }, [addModelCredential, addProviderCredential]) | |||||
| const getEditCredentialService = useCallback((isModel: boolean) => { | |||||
| return isModel ? editModelCredential : editProviderCredential | |||||
| }, [editModelCredential, editProviderCredential]) | |||||
| const getDeleteCredentialService = useCallback((isModel: boolean) => { | |||||
| return isModel ? deleteModelCredential : deleteProviderCredential | |||||
| }, [deleteModelCredential, deleteProviderCredential]) | |||||
| const getActiveCredentialService = useCallback((isModel: boolean) => { | |||||
| return isModel ? activeModelCredential : activeProviderCredential | |||||
| }, [activeModelCredential, activeProviderCredential]) | |||||
| return { | |||||
| getAddCredentialService, | |||||
| getEditCredentialService, | |||||
| getDeleteCredentialService, | |||||
| getActiveCredentialService, | |||||
| } | |||||
| } |
| import { | |||||
| useCallback, | |||||
| useRef, | |||||
| useState, | |||||
| } from 'react' | |||||
| import { useTranslation } from 'react-i18next' | |||||
| import { useToastContext } from '@/app/components/base/toast' | |||||
| import { useAuthService } from './use-auth-service' | |||||
| import type { | |||||
| ConfigurationMethodEnum, | |||||
| Credential, | |||||
| CustomConfigurationModelFixedFields, | |||||
| CustomModel, | |||||
| ModelProvider, | |||||
| } from '../../declarations' | |||||
| import { | |||||
| useModelModalHandler, | |||||
| useRefreshModel, | |||||
| } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||||
| export const useAuth = ( | |||||
| provider: ModelProvider, | |||||
| configurationMethod: ConfigurationMethodEnum, | |||||
| currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields, | |||||
| isModelCredential?: boolean, | |||||
| onUpdate?: () => void, | |||||
| ) => { | |||||
| const { t } = useTranslation() | |||||
| const { notify } = useToastContext() | |||||
| const { | |||||
| getDeleteCredentialService, | |||||
| getActiveCredentialService, | |||||
| getEditCredentialService, | |||||
| getAddCredentialService, | |||||
| } = useAuthService(provider.provider) | |||||
| const handleOpenModelModal = useModelModalHandler() | |||||
| const { handleRefreshModel } = useRefreshModel() | |||||
| const pendingOperationCredentialId = useRef<string | null>(null) | |||||
| const pendingOperationModel = useRef<CustomModel | null>(null) | |||||
| const [deleteCredentialId, setDeleteCredentialId] = useState<string | null>(null) | |||||
| const openConfirmDelete = useCallback((credential?: Credential, model?: CustomModel) => { | |||||
| if (credential) | |||||
| pendingOperationCredentialId.current = credential.credential_id | |||||
| if (model) | |||||
| pendingOperationModel.current = model | |||||
| setDeleteCredentialId(pendingOperationCredentialId.current) | |||||
| }, []) | |||||
| const closeConfirmDelete = useCallback(() => { | |||||
| setDeleteCredentialId(null) | |||||
| pendingOperationCredentialId.current = null | |||||
| }, []) | |||||
| const [doingAction, setDoingAction] = useState(false) | |||||
| const doingActionRef = useRef(doingAction) | |||||
| const handleSetDoingAction = useCallback((doing: boolean) => { | |||||
| doingActionRef.current = doing | |||||
| setDoingAction(doing) | |||||
| }, []) | |||||
| const handleActiveCredential = useCallback(async (credential: Credential, model?: CustomModel) => { | |||||
| if (doingActionRef.current) | |||||
| return | |||||
| try { | |||||
| handleSetDoingAction(true) | |||||
| await getActiveCredentialService(!!model)({ | |||||
| credential_id: credential.credential_id, | |||||
| model: model?.model, | |||||
| model_type: model?.model_type, | |||||
| }) | |||||
| notify({ | |||||
| type: 'success', | |||||
| message: t('common.api.actionSuccess'), | |||||
| }) | |||||
| onUpdate?.() | |||||
| handleRefreshModel(provider, configurationMethod, undefined) | |||||
| } | |||||
| finally { | |||||
| handleSetDoingAction(false) | |||||
| } | |||||
| }, [getActiveCredentialService, onUpdate, notify, t, handleSetDoingAction]) | |||||
| const handleConfirmDelete = useCallback(async () => { | |||||
| if (doingActionRef.current) | |||||
| return | |||||
| if (!pendingOperationCredentialId.current) { | |||||
| setDeleteCredentialId(null) | |||||
| return | |||||
| } | |||||
| try { | |||||
| handleSetDoingAction(true) | |||||
| await getDeleteCredentialService(!!isModelCredential)({ | |||||
| credential_id: pendingOperationCredentialId.current, | |||||
| model: pendingOperationModel.current?.model, | |||||
| model_type: pendingOperationModel.current?.model_type, | |||||
| }) | |||||
| notify({ | |||||
| type: 'success', | |||||
| message: t('common.api.actionSuccess'), | |||||
| }) | |||||
| onUpdate?.() | |||||
| handleRefreshModel(provider, configurationMethod, undefined) | |||||
| setDeleteCredentialId(null) | |||||
| pendingOperationCredentialId.current = null | |||||
| pendingOperationModel.current = null | |||||
| } | |||||
| finally { | |||||
| handleSetDoingAction(false) | |||||
| } | |||||
| }, [onUpdate, notify, t, handleSetDoingAction, getDeleteCredentialService, isModelCredential]) | |||||
| const handleAddCredential = useCallback((model?: CustomModel) => { | |||||
| if (model) | |||||
| pendingOperationModel.current = model | |||||
| }, []) | |||||
| const handleSaveCredential = useCallback(async (payload: Record<string, any>) => { | |||||
| if (doingActionRef.current) | |||||
| return | |||||
| try { | |||||
| handleSetDoingAction(true) | |||||
| let res: { result?: string } = {} | |||||
| if (payload.credential_id) | |||||
| res = await getEditCredentialService(!!isModelCredential)(payload as any) | |||||
| else | |||||
| res = await getAddCredentialService(!!isModelCredential)(payload as any) | |||||
| if (res.result === 'success') { | |||||
| notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) | |||||
| onUpdate?.() | |||||
| } | |||||
| } | |||||
| finally { | |||||
| handleSetDoingAction(false) | |||||
| } | |||||
| }, [onUpdate, notify, t, handleSetDoingAction, getEditCredentialService, getAddCredentialService]) | |||||
| const handleOpenModal = useCallback((credential?: Credential, model?: CustomModel) => { | |||||
| handleOpenModelModal( | |||||
| provider, | |||||
| configurationMethod, | |||||
| currentCustomConfigurationModelFixedFields, | |||||
| isModelCredential, | |||||
| credential, | |||||
| model, | |||||
| onUpdate, | |||||
| ) | |||||
| }, [handleOpenModelModal, provider, configurationMethod, currentCustomConfigurationModelFixedFields, isModelCredential, onUpdate]) | |||||
| return { | |||||
| pendingOperationCredentialId, | |||||
| pendingOperationModel, | |||||
| openConfirmDelete, | |||||
| closeConfirmDelete, | |||||
| doingAction, | |||||
| handleActiveCredential, | |||||
| handleConfirmDelete, | |||||
| handleAddCredential, | |||||
| deleteCredentialId, | |||||
| handleSaveCredential, | |||||
| handleOpenModal, | |||||
| } | |||||
| } |
| import { useMemo } from 'react' | |||||
| import { useGetCredential } from './use-auth-service' | |||||
| import type { | |||||
| Credential, | |||||
| CustomModelCredential, | |||||
| ModelProvider, | |||||
| } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||||
| export const useCredentialData = (provider: ModelProvider, providerFormSchemaPredefined: boolean, isModelCredential?: boolean, credential?: Credential, model?: CustomModelCredential) => { | |||||
| const configFrom = useMemo(() => { | |||||
| if (providerFormSchemaPredefined) | |||||
| return 'predefined-model' | |||||
| return 'custom-model' | |||||
| }, [providerFormSchemaPredefined]) | |||||
| const { | |||||
| isLoading, | |||||
| data: credentialData = {}, | |||||
| } = useGetCredential(provider.provider, isModelCredential, credential?.credential_id, model, configFrom) | |||||
| return { | |||||
| isLoading, | |||||
| credentialData, | |||||
| } | |||||
| } |
| import { useMemo } from 'react' | |||||
| import type { | |||||
| ModelProvider, | |||||
| } from '../../declarations' | |||||
| export const useCredentialStatus = (provider: ModelProvider) => { | |||||
| const { | |||||
| current_credential_id, | |||||
| current_credential_name, | |||||
| available_credentials, | |||||
| } = provider.custom_configuration | |||||
| const hasCredential = !!available_credentials?.length | |||||
| const authorized = current_credential_id && current_credential_name | |||||
| const authRemoved = hasCredential && !current_credential_id && !current_credential_name | |||||
| const currentCredential = available_credentials?.find(credential => credential.credential_id === current_credential_id) | |||||
| return useMemo(() => ({ | |||||
| hasCredential, | |||||
| authorized, | |||||
| authRemoved, | |||||
| current_credential_id, | |||||
| current_credential_name, | |||||
| available_credentials, | |||||
| notAllowedToUse: currentCredential?.not_allowed_to_use, | |||||
| }), [hasCredential, authorized, authRemoved, current_credential_id, current_credential_name, available_credentials]) | |||||
| } |
| import type { | |||||
| ModelProvider, | |||||
| } from '../../declarations' | |||||
| export const useCustomModels = (provider: ModelProvider) => { | |||||
| const { custom_models } = provider.custom_configuration | |||||
| return custom_models || [] | |||||
| } |
| import { useMemo } from 'react' | |||||
| import { useTranslation } from 'react-i18next' | |||||
| import type { | |||||
| Credential, | |||||
| CustomModelCredential, | |||||
| ModelLoadBalancingConfig, | |||||
| ModelProvider, | |||||
| } from '../../declarations' | |||||
| import { | |||||
| genModelNameFormSchema, | |||||
| genModelTypeFormSchema, | |||||
| } from '../../utils' | |||||
| import { FormTypeEnum } from '@/app/components/base/form/types' | |||||
| export const useModelFormSchemas = ( | |||||
| provider: ModelProvider, | |||||
| providerFormSchemaPredefined: boolean, | |||||
| credentials?: Record<string, any>, | |||||
| credential?: Credential, | |||||
| model?: CustomModelCredential, | |||||
| draftConfig?: ModelLoadBalancingConfig, | |||||
| ) => { | |||||
| const { t } = useTranslation() | |||||
| const { | |||||
| provider_credential_schema, | |||||
| supported_model_types, | |||||
| model_credential_schema, | |||||
| } = provider | |||||
| const formSchemas = useMemo(() => { | |||||
| const modelTypeSchema = genModelTypeFormSchema(supported_model_types) | |||||
| const modelNameSchema = genModelNameFormSchema(model_credential_schema?.model) | |||||
| if (!!model) { | |||||
| modelTypeSchema.disabled = true | |||||
| modelNameSchema.disabled = true | |||||
| } | |||||
| return providerFormSchemaPredefined | |||||
| ? provider_credential_schema.credential_form_schemas | |||||
| : [ | |||||
| modelTypeSchema, | |||||
| modelNameSchema, | |||||
| ...(draftConfig?.enabled ? [] : model_credential_schema.credential_form_schemas), | |||||
| ] | |||||
| }, [ | |||||
| providerFormSchemaPredefined, | |||||
| provider_credential_schema?.credential_form_schemas, | |||||
| supported_model_types, | |||||
| model_credential_schema?.credential_form_schemas, | |||||
| model_credential_schema?.model, | |||||
| draftConfig?.enabled, | |||||
| model, | |||||
| ]) | |||||
| const formSchemasWithAuthorizationName = useMemo(() => { | |||||
| const authorizationNameSchema = { | |||||
| type: FormTypeEnum.textInput, | |||||
| variable: '__authorization_name__', | |||||
| label: t('plugin.auth.authorizationName'), | |||||
| required: true, | |||||
| } | |||||
| return [ | |||||
| authorizationNameSchema, | |||||
| ...formSchemas, | |||||
| ] | |||||
| }, [formSchemas, t]) | |||||
| const formValues = useMemo(() => { | |||||
| let result = {} | |||||
| if (credential) { | |||||
| result = { ...result, __authorization_name__: credential?.credential_name } | |||||
| if (credentials) | |||||
| result = { ...result, ...credentials } | |||||
| } | |||||
| if (model) | |||||
| result = { ...result, __model_name: model?.model, __model_type: model?.model_type } | |||||
| return result | |||||
| }, [credentials, credential, model]) | |||||
| return { | |||||
| formSchemas: formSchemasWithAuthorizationName, | |||||
| formValues, | |||||
| } | |||||
| } |
| export { default as Authorized } from './authorized' | |||||
| export { default as SwitchCredentialInLoadBalancing } from './switch-credential-in-load-balancing' | |||||
| export { default as AddCredentialInLoadBalancing } from './add-credential-in-load-balancing' | |||||
| export { default as AddCustomModel } from './add-custom-model' | |||||
| export { default as ConfigProvider } from './config-provider' | |||||
| export { default as ConfigModel } from './config-model' |
| import type { Dispatch, SetStateAction } from 'react' | |||||
| import { | |||||
| memo, | |||||
| useCallback, | |||||
| } from 'react' | |||||
| import { useTranslation } from 'react-i18next' | |||||
| import { RiArrowDownSLine } from '@remixicon/react' | |||||
| import Button from '@/app/components/base/button' | |||||
| import Indicator from '@/app/components/header/indicator' | |||||
| import Authorized from './authorized' | |||||
| import type { | |||||
| Credential, | |||||
| CustomModel, | |||||
| ModelProvider, | |||||
| } from '../declarations' | |||||
| import { ConfigurationMethodEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||||
| import cn from '@/utils/classnames' | |||||
| import Tooltip from '@/app/components/base/tooltip' | |||||
| import Badge from '@/app/components/base/badge' | |||||
| type SwitchCredentialInLoadBalancingProps = { | |||||
| provider: ModelProvider | |||||
| model: CustomModel | |||||
| credentials?: Credential[] | |||||
| customModelCredential?: Credential | |||||
| setCustomModelCredential: Dispatch<SetStateAction<Credential | undefined>> | |||||
| } | |||||
| const SwitchCredentialInLoadBalancing = ({ | |||||
| provider, | |||||
| model, | |||||
| customModelCredential, | |||||
| setCustomModelCredential, | |||||
| credentials, | |||||
| }: SwitchCredentialInLoadBalancingProps) => { | |||||
| const { t } = useTranslation() | |||||
| const handleItemClick = useCallback((credential: Credential) => { | |||||
| setCustomModelCredential(credential) | |||||
| }, [setCustomModelCredential]) | |||||
| const renderTrigger = useCallback(() => { | |||||
| const selectedCredentialId = customModelCredential?.credential_id | |||||
| const authRemoved = !selectedCredentialId && !!credentials?.length | |||||
| let color = 'green' | |||||
| if (authRemoved && !customModelCredential?.not_allowed_to_use) | |||||
| color = 'red' | |||||
| if (customModelCredential?.not_allowed_to_use) | |||||
| color = 'gray' | |||||
| const Item = ( | |||||
| <Button | |||||
| variant='secondary' | |||||
| className={cn( | |||||
| 'shrink-0 space-x-1', | |||||
| authRemoved && 'text-components-button-destructive-secondary-text', | |||||
| customModelCredential?.not_allowed_to_use && 'cursor-not-allowed opacity-50', | |||||
| )} | |||||
| > | |||||
| <Indicator | |||||
| className='mr-2' | |||||
| color={color as any} | |||||
| /> | |||||
| { | |||||
| authRemoved && !customModelCredential?.not_allowed_to_use && t('common.modelProvider.auth.authRemoved') | |||||
| } | |||||
| { | |||||
| !authRemoved && customModelCredential?.not_allowed_to_use && t('plugin.auth.credentialUnavailable') | |||||
| } | |||||
| { | |||||
| !authRemoved && !customModelCredential?.not_allowed_to_use && customModelCredential?.credential_name | |||||
| } | |||||
| { | |||||
| customModelCredential?.from_enterprise && ( | |||||
| <Badge className='ml-2'>Enterprise</Badge> | |||||
| ) | |||||
| } | |||||
| <RiArrowDownSLine className='h-4 w-4' /> | |||||
| </Button> | |||||
| ) | |||||
| if (customModelCredential?.not_allowed_to_use) { | |||||
| return ( | |||||
| <Tooltip | |||||
| asChild | |||||
| popupContent={t('plugin.auth.credentialUnavailable')} | |||||
| > | |||||
| {Item} | |||||
| </Tooltip> | |||||
| ) | |||||
| } | |||||
| return Item | |||||
| }, [customModelCredential, t, credentials]) | |||||
| return ( | |||||
| <Authorized | |||||
| provider={provider} | |||||
| configurationMethod={ConfigurationMethodEnum.customizableModel} | |||||
| items={[ | |||||
| { | |||||
| title: t('common.modelProvider.auth.modelCredentials'), | |||||
| model, | |||||
| credentials: credentials || [], | |||||
| }, | |||||
| ]} | |||||
| renderTrigger={renderTrigger} | |||||
| onItemClick={handleItemClick} | |||||
| isModelCredential | |||||
| enableAddModelCredential | |||||
| bottomAddModelCredentialText={t('common.modelProvider.auth.addModelCredential')} | |||||
| selectedCredential={ | |||||
| customModelCredential | |||||
| ? { | |||||
| credential_id: customModelCredential?.credential_id || '', | |||||
| credential_name: customModelCredential?.credential_name || '', | |||||
| } | |||||
| : undefined | |||||
| } | |||||
| showItemSelectedIcon | |||||
| /> | |||||
| ) | |||||
| } | |||||
| export default memo(SwitchCredentialInLoadBalancing) |
| provider?: Model | ModelProvider | provider?: Model | ModelProvider | ||||
| modelName?: string | modelName?: string | ||||
| className?: string | className?: string | ||||
| iconClassName?: string | |||||
| isDeprecated?: boolean | isDeprecated?: boolean | ||||
| } | } | ||||
| const ModelIcon: FC<ModelIconProps> = ({ | const ModelIcon: FC<ModelIconProps> = ({ | ||||
| provider, | provider, | ||||
| className, | className, | ||||
| modelName, | modelName, | ||||
| iconClassName, | |||||
| isDeprecated = false, | isDeprecated = false, | ||||
| }) => { | }) => { | ||||
| const language = useLanguage() | const language = useLanguage() | ||||
| if (provider?.icon_small) { | if (provider?.icon_small) { | ||||
| return ( | return ( | ||||
| <div className={cn('flex h-5 w-5 items-center justify-center', isDeprecated && 'opacity-50', className)}> | <div className={cn('flex h-5 w-5 items-center justify-center', isDeprecated && 'opacity-50', className)}> | ||||
| <img alt='model-icon' src={renderI18nObject(provider.icon_small, language)}/> | |||||
| <img alt='model-icon' src={renderI18nObject(provider.icon_small, language)} className={iconClassName} /> | |||||
| </div> | </div> | ||||
| ) | ) | ||||
| } | } | ||||
| 'flex h-5 w-5 items-center justify-center rounded-md border-[0.5px] border-components-panel-border-subtle bg-background-default-subtle', | 'flex h-5 w-5 items-center justify-center rounded-md border-[0.5px] border-components-panel-border-subtle bg-background-default-subtle', | ||||
| className, | className, | ||||
| )}> | )}> | ||||
| <div className='flex h-5 w-5 items-center justify-center opacity-35'> | |||||
| <div className={cn('flex h-5 w-5 items-center justify-center opacity-35', iconClassName)}> | |||||
| <Group className='h-3 w-3 text-text-tertiary' /> | <Group className='h-3 w-3 text-text-tertiary' /> | ||||
| </div> | </div> | ||||
| </div> | </div> |
| import { | import { | ||||
| memo, | memo, | ||||
| useCallback, | useCallback, | ||||
| useEffect, | |||||
| useMemo, | useMemo, | ||||
| useState, | |||||
| useRef, | |||||
| } from 'react' | } from 'react' | ||||
| import { RiCloseLine } from '@remixicon/react' | |||||
| import { useTranslation } from 'react-i18next' | import { useTranslation } from 'react-i18next' | ||||
| import { | |||||
| RiErrorWarningFill, | |||||
| } from '@remixicon/react' | |||||
| import type { | import type { | ||||
| CredentialFormSchema, | |||||
| CredentialFormSchemaRadio, | |||||
| CredentialFormSchemaSelect, | |||||
| CustomConfigurationModelFixedFields, | CustomConfigurationModelFixedFields, | ||||
| FormValue, | |||||
| ModelLoadBalancingConfig, | |||||
| ModelLoadBalancingConfigEntry, | |||||
| ModelProvider, | ModelProvider, | ||||
| } from '../declarations' | } from '../declarations' | ||||
| import { | import { | ||||
| ConfigurationMethodEnum, | ConfigurationMethodEnum, | ||||
| CustomConfigurationStatusEnum, | |||||
| FormTypeEnum, | FormTypeEnum, | ||||
| } from '../declarations' | } from '../declarations' | ||||
| import { | |||||
| genModelNameFormSchema, | |||||
| genModelTypeFormSchema, | |||||
| removeCredentials, | |||||
| saveCredentials, | |||||
| } from '../utils' | |||||
| import { | import { | ||||
| useLanguage, | useLanguage, | ||||
| useProviderCredentialsAndLoadBalancing, | |||||
| } from '../hooks' | } from '../hooks' | ||||
| import { useValidate } from '../../key-validator/hooks' | |||||
| import { ValidatedStatus } from '../../key-validator/declarations' | |||||
| import ModelLoadBalancingConfigs from '../provider-added-card/model-load-balancing-configs' | |||||
| import Form from './Form' | |||||
| import Button from '@/app/components/base/button' | import Button from '@/app/components/base/button' | ||||
| import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' | import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' | ||||
| import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general' | import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general' | ||||
| PortalToFollowElem, | PortalToFollowElem, | ||||
| PortalToFollowElemContent, | PortalToFollowElemContent, | ||||
| } from '@/app/components/base/portal-to-follow-elem' | } from '@/app/components/base/portal-to-follow-elem' | ||||
| import { useToastContext } from '@/app/components/base/toast' | |||||
| import Confirm from '@/app/components/base/confirm' | import Confirm from '@/app/components/base/confirm' | ||||
| import { useAppContext } from '@/context/app-context' | import { useAppContext } from '@/context/app-context' | ||||
| import AuthForm from '@/app/components/base/form/form-scenarios/auth' | |||||
| import type { | |||||
| FormRefObject, | |||||
| FormSchema, | |||||
| } from '@/app/components/base/form/types' | |||||
| import { useModelFormSchemas } from '../model-auth/hooks' | |||||
| import type { | |||||
| Credential, | |||||
| CustomModel, | |||||
| } from '../declarations' | |||||
| import Loading from '@/app/components/base/loading' | |||||
| import { | |||||
| useAuth, | |||||
| useCredentialData, | |||||
| } from '@/app/components/header/account-setting/model-provider-page/model-auth/hooks' | |||||
| import ModelIcon from '@/app/components/header/account-setting/model-provider-page/model-icon' | |||||
| import Badge from '@/app/components/base/badge' | |||||
| import { useRenderI18nObject } from '@/hooks/use-i18n' | |||||
| type ModelModalProps = { | type ModelModalProps = { | ||||
| provider: ModelProvider | provider: ModelProvider | ||||
| currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields | currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields | ||||
| onCancel: () => void | onCancel: () => void | ||||
| onSave: () => void | onSave: () => void | ||||
| model?: CustomModel | |||||
| credential?: Credential | |||||
| isModelCredential?: boolean | |||||
| } | } | ||||
| const ModelModal: FC<ModelModalProps> = ({ | const ModelModal: FC<ModelModalProps> = ({ | ||||
| currentCustomConfigurationModelFixedFields, | currentCustomConfigurationModelFixedFields, | ||||
| onCancel, | onCancel, | ||||
| onSave, | onSave, | ||||
| model, | |||||
| credential, | |||||
| isModelCredential, | |||||
| }) => { | }) => { | ||||
| const renderI18nObject = useRenderI18nObject() | |||||
| const providerFormSchemaPredefined = configurateMethod === ConfigurationMethodEnum.predefinedModel | const providerFormSchemaPredefined = configurateMethod === ConfigurationMethodEnum.predefinedModel | ||||
| const { | |||||
| isLoading, | |||||
| credentialData, | |||||
| } = useCredentialData(provider, providerFormSchemaPredefined, isModelCredential, credential, model) | |||||
| const { | |||||
| handleSaveCredential, | |||||
| handleConfirmDelete, | |||||
| deleteCredentialId, | |||||
| closeConfirmDelete, | |||||
| openConfirmDelete, | |||||
| doingAction, | |||||
| } = useAuth(provider, configurateMethod, currentCustomConfigurationModelFixedFields, isModelCredential, onSave) | |||||
| const { | const { | ||||
| credentials: formSchemasValue, | credentials: formSchemasValue, | ||||
| loadBalancing: originalConfig, | |||||
| mutate, | |||||
| } = useProviderCredentialsAndLoadBalancing( | |||||
| provider.provider, | |||||
| configurateMethod, | |||||
| providerFormSchemaPredefined && provider.custom_configuration.status === CustomConfigurationStatusEnum.active, | |||||
| currentCustomConfigurationModelFixedFields, | |||||
| ) | |||||
| } = credentialData as any | |||||
| const { isCurrentWorkspaceManager } = useAppContext() | const { isCurrentWorkspaceManager } = useAppContext() | ||||
| const isEditMode = !!formSchemasValue && isCurrentWorkspaceManager | const isEditMode = !!formSchemasValue && isCurrentWorkspaceManager | ||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| const { notify } = useToastContext() | |||||
| const language = useLanguage() | const language = useLanguage() | ||||
| const [loading, setLoading] = useState(false) | |||||
| const [showConfirm, setShowConfirm] = useState(false) | |||||
| const [draftConfig, setDraftConfig] = useState<ModelLoadBalancingConfig>() | |||||
| const originalConfigMap = useMemo(() => { | |||||
| if (!originalConfig) | |||||
| return {} | |||||
| return originalConfig?.configs.reduce((prev, config) => { | |||||
| if (config.id) | |||||
| prev[config.id] = config | |||||
| return prev | |||||
| }, {} as Record<string, ModelLoadBalancingConfigEntry>) | |||||
| }, [originalConfig]) | |||||
| useEffect(() => { | |||||
| if (originalConfig && !draftConfig) | |||||
| setDraftConfig(originalConfig) | |||||
| }, [draftConfig, originalConfig]) | |||||
| const formSchemas = useMemo(() => { | |||||
| return providerFormSchemaPredefined | |||||
| ? provider.provider_credential_schema.credential_form_schemas | |||||
| : [ | |||||
| genModelTypeFormSchema(provider.supported_model_types), | |||||
| genModelNameFormSchema(provider.model_credential_schema?.model), | |||||
| ...(draftConfig?.enabled ? [] : provider.model_credential_schema.credential_form_schemas), | |||||
| ] | |||||
| }, [ | |||||
| providerFormSchemaPredefined, | |||||
| provider.provider_credential_schema?.credential_form_schemas, | |||||
| provider.supported_model_types, | |||||
| provider.model_credential_schema?.credential_form_schemas, | |||||
| provider.model_credential_schema?.model, | |||||
| draftConfig?.enabled, | |||||
| ]) | |||||
| const [ | |||||
| requiredFormSchemas, | |||||
| defaultFormSchemaValue, | |||||
| showOnVariableMap, | |||||
| ] = useMemo(() => { | |||||
| const requiredFormSchemas: CredentialFormSchema[] = [] | |||||
| const defaultFormSchemaValue: Record<string, string | number> = {} | |||||
| const showOnVariableMap: Record<string, string[]> = {} | |||||
| formSchemas.forEach((formSchema) => { | |||||
| if (formSchema.required) | |||||
| requiredFormSchemas.push(formSchema) | |||||
| if (formSchema.default) | |||||
| defaultFormSchemaValue[formSchema.variable] = formSchema.default | |||||
| if (formSchema.show_on.length) { | |||||
| formSchema.show_on.forEach((showOnItem) => { | |||||
| if (!showOnVariableMap[showOnItem.variable]) | |||||
| showOnVariableMap[showOnItem.variable] = [] | |||||
| if (!showOnVariableMap[showOnItem.variable].includes(formSchema.variable)) | |||||
| showOnVariableMap[showOnItem.variable].push(formSchema.variable) | |||||
| }) | |||||
| } | |||||
| if (formSchema.type === FormTypeEnum.select || formSchema.type === FormTypeEnum.radio) { | |||||
| (formSchema as (CredentialFormSchemaRadio | CredentialFormSchemaSelect)).options.forEach((option) => { | |||||
| if (option.show_on.length) { | |||||
| option.show_on.forEach((showOnItem) => { | |||||
| if (!showOnVariableMap[showOnItem.variable]) | |||||
| showOnVariableMap[showOnItem.variable] = [] | |||||
| if (!showOnVariableMap[showOnItem.variable].includes(formSchema.variable)) | |||||
| showOnVariableMap[showOnItem.variable].push(formSchema.variable) | |||||
| }) | |||||
| } | |||||
| }) | |||||
| } | |||||
| }) | |||||
| return [ | |||||
| requiredFormSchemas, | |||||
| defaultFormSchemaValue, | |||||
| showOnVariableMap, | |||||
| ] | |||||
| }, [formSchemas]) | |||||
| const initialFormSchemasValue: Record<string, string | number> = useMemo(() => { | |||||
| return { | |||||
| ...defaultFormSchemaValue, | |||||
| ...formSchemasValue, | |||||
| } as unknown as Record<string, string | number> | |||||
| }, [formSchemasValue, defaultFormSchemaValue]) | |||||
| const [value, setValue] = useState(initialFormSchemasValue) | |||||
| useEffect(() => { | |||||
| setValue(initialFormSchemasValue) | |||||
| }, [initialFormSchemasValue]) | |||||
| const [_, validating, validatedStatusState] = useValidate(value) | |||||
| const filteredRequiredFormSchemas = requiredFormSchemas.filter((requiredFormSchema) => { | |||||
| if (requiredFormSchema.show_on.length && requiredFormSchema.show_on.every(showOnItem => value[showOnItem.variable] === showOnItem.value)) | |||||
| return true | |||||
| if (!requiredFormSchema.show_on.length) | |||||
| return true | |||||
| const { | |||||
| formSchemas, | |||||
| formValues, | |||||
| } = useModelFormSchemas(provider, providerFormSchemaPredefined, formSchemasValue, credential, model) | |||||
| const formRef = useRef<FormRefObject>(null) | |||||
| return false | |||||
| }) | |||||
| const handleSave = useCallback(async () => { | |||||
| const { | |||||
| isCheckValidated, | |||||
| values, | |||||
| } = formRef.current?.getFormValues({ | |||||
| needCheckValidatedValues: true, | |||||
| needTransformWhenSecretFieldIsPristine: true, | |||||
| }) || { isCheckValidated: false, values: {} } | |||||
| if (!isCheckValidated) | |||||
| return | |||||
| const handleValueChange = (v: FormValue) => { | |||||
| setValue(v) | |||||
| } | |||||
| const { | |||||
| __authorization_name__, | |||||
| __model_name, | |||||
| __model_type, | |||||
| ...rest | |||||
| } = values | |||||
| if (__model_name && __model_type) { | |||||
| handleSaveCredential({ | |||||
| credential_id: credential?.credential_id, | |||||
| credentials: rest, | |||||
| name: __authorization_name__, | |||||
| model: __model_name, | |||||
| model_type: __model_type, | |||||
| }) | |||||
| } | |||||
| else { | |||||
| handleSaveCredential({ | |||||
| credential_id: credential?.credential_id, | |||||
| credentials: rest, | |||||
| name: __authorization_name__, | |||||
| }) | |||||
| } | |||||
| }, [handleSaveCredential, credential?.credential_id, model]) | |||||
| const extendedSecretFormSchemas = useMemo( | |||||
| () => | |||||
| (providerFormSchemaPredefined | |||||
| ? provider.provider_credential_schema.credential_form_schemas | |||||
| : [ | |||||
| genModelTypeFormSchema(provider.supported_model_types), | |||||
| genModelNameFormSchema(provider.model_credential_schema?.model), | |||||
| ...provider.model_credential_schema.credential_form_schemas, | |||||
| ]).filter(({ type }) => type === FormTypeEnum.secretInput), | |||||
| [ | |||||
| provider.model_credential_schema?.credential_form_schemas, | |||||
| provider.model_credential_schema?.model, | |||||
| provider.provider_credential_schema?.credential_form_schemas, | |||||
| provider.supported_model_types, | |||||
| providerFormSchemaPredefined, | |||||
| ], | |||||
| ) | |||||
| const modalTitle = useMemo(() => { | |||||
| if (!providerFormSchemaPredefined && !model) { | |||||
| return ( | |||||
| <div className='flex items-center'> | |||||
| <ModelIcon | |||||
| className='mr-2 h-10 w-10 shrink-0' | |||||
| iconClassName='h-10 w-10' | |||||
| provider={provider} | |||||
| /> | |||||
| <div> | |||||
| <div className='system-xs-medium-uppercase text-text-tertiary'>{t('common.modelProvider.auth.apiKeyModal.addModel')}</div> | |||||
| <div className='system-md-semibold text-text-primary'>{renderI18nObject(provider.label)}</div> | |||||
| </div> | |||||
| </div> | |||||
| ) | |||||
| } | |||||
| let label = t('common.modelProvider.auth.apiKeyModal.title') | |||||
| const encodeSecretValues = useCallback((v: FormValue) => { | |||||
| const result = { ...v } | |||||
| extendedSecretFormSchemas.forEach(({ variable }) => { | |||||
| if (result[variable] === formSchemasValue?.[variable] && result[variable] !== undefined) | |||||
| result[variable] = '[__HIDDEN__]' | |||||
| }) | |||||
| return result | |||||
| }, [extendedSecretFormSchemas, formSchemasValue]) | |||||
| if (model) | |||||
| label = t('common.modelProvider.auth.addModelCredential') | |||||
| const encodeConfigEntrySecretValues = useCallback((entry: ModelLoadBalancingConfigEntry) => { | |||||
| const result = { ...entry } | |||||
| extendedSecretFormSchemas.forEach(({ variable }) => { | |||||
| if (entry.id && result.credentials[variable] === originalConfigMap[entry.id]?.credentials?.[variable]) | |||||
| result.credentials[variable] = '[__HIDDEN__]' | |||||
| }) | |||||
| return result | |||||
| }, [extendedSecretFormSchemas, originalConfigMap]) | |||||
| return ( | |||||
| <div className='title-2xl-semi-bold text-text-primary'> | |||||
| {label} | |||||
| </div> | |||||
| ) | |||||
| }, [providerFormSchemaPredefined, t, model, renderI18nObject]) | |||||
| const handleSave = async () => { | |||||
| try { | |||||
| setLoading(true) | |||||
| const res = await saveCredentials( | |||||
| providerFormSchemaPredefined, | |||||
| provider.provider, | |||||
| encodeSecretValues(value), | |||||
| { | |||||
| ...draftConfig, | |||||
| enabled: Boolean(draftConfig?.enabled), | |||||
| configs: draftConfig?.configs.map(encodeConfigEntrySecretValues) || [], | |||||
| }, | |||||
| const modalDesc = useMemo(() => { | |||||
| if (providerFormSchemaPredefined) { | |||||
| return ( | |||||
| <div className='system-xs-regular mt-1 text-text-tertiary'> | |||||
| {t('common.modelProvider.auth.apiKeyModal.desc')} | |||||
| </div> | |||||
| ) | ) | ||||
| if (res.result === 'success') { | |||||
| notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) | |||||
| mutate() | |||||
| onSave() | |||||
| onCancel() | |||||
| } | |||||
| } | } | ||||
| finally { | |||||
| setLoading(false) | |||||
| } | |||||
| } | |||||
| const handleRemove = async () => { | |||||
| try { | |||||
| setLoading(true) | |||||
| return null | |||||
| }, [providerFormSchemaPredefined, t]) | |||||
| const res = await removeCredentials( | |||||
| providerFormSchemaPredefined, | |||||
| provider.provider, | |||||
| value, | |||||
| const modalModel = useMemo(() => { | |||||
| if (model) { | |||||
| return ( | |||||
| <div className='mt-2 flex items-center'> | |||||
| <ModelIcon | |||||
| className='mr-2 h-4 w-4 shrink-0' | |||||
| provider={provider} | |||||
| modelName={model.model} | |||||
| /> | |||||
| <div className='system-md-regular mr-1 text-text-secondary'>{model.model}</div> | |||||
| <Badge>{model.model_type}</Badge> | |||||
| </div> | |||||
| ) | ) | ||||
| if (res.result === 'success') { | |||||
| notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) | |||||
| mutate() | |||||
| onSave() | |||||
| onCancel() | |||||
| } | |||||
| } | |||||
| finally { | |||||
| setLoading(false) | |||||
| } | } | ||||
| } | |||||
| const renderTitlePrefix = () => { | |||||
| const prefix = isEditMode ? t('common.operation.setup') : t('common.operation.add') | |||||
| return `${prefix} ${provider.label[language] || provider.label.en_US}` | |||||
| } | |||||
| return null | |||||
| }, [model, provider]) | |||||
| return ( | return ( | ||||
| <PortalToFollowElem open> | <PortalToFollowElem open> | ||||
| <PortalToFollowElemContent className='z-[60] h-full w-full'> | <PortalToFollowElemContent className='z-[60] h-full w-full'> | ||||
| <div className='fixed inset-0 flex items-center justify-center bg-black/[.25]'> | <div className='fixed inset-0 flex items-center justify-center bg-black/[.25]'> | ||||
| <div className='mx-2 w-[640px] overflow-auto rounded-2xl bg-components-panel-bg shadow-xl'> | |||||
| <div className='px-8 pt-8'> | |||||
| <div className='mb-2 flex items-center'> | |||||
| <div className='text-xl font-semibold text-text-primary'>{renderTitlePrefix()}</div> | |||||
| <div className='relative w-[640px] rounded-2xl bg-components-panel-bg shadow-xl'> | |||||
| <div | |||||
| className='absolute right-5 top-5 flex h-8 w-8 cursor-pointer items-center justify-center' | |||||
| onClick={onCancel} | |||||
| > | |||||
| <RiCloseLine className='h-4 w-4 text-text-tertiary' /> | |||||
| </div> | |||||
| <div className='px-6 pt-6'> | |||||
| <div className='pb-3'> | |||||
| {modalTitle} | |||||
| {modalDesc} | |||||
| {modalModel} | |||||
| </div> | </div> | ||||
| <div className='max-h-[calc(100vh-320px)] overflow-y-auto'> | <div className='max-h-[calc(100vh-320px)] overflow-y-auto'> | ||||
| <Form | |||||
| value={value} | |||||
| onChange={handleValueChange} | |||||
| formSchemas={formSchemas} | |||||
| validating={validating} | |||||
| validatedSuccess={validatedStatusState.status === ValidatedStatus.Success} | |||||
| showOnVariableMap={showOnVariableMap} | |||||
| isEditMode={isEditMode} | |||||
| /> | |||||
| <div className='mb-4 mt-1 border-t-[0.5px] border-t-divider-regular' /> | |||||
| <ModelLoadBalancingConfigs withSwitch {...{ | |||||
| draftConfig, | |||||
| setDraftConfig, | |||||
| provider, | |||||
| currentCustomConfigurationModelFixedFields, | |||||
| configurationMethod: configurateMethod, | |||||
| }} /> | |||||
| { | |||||
| isLoading && ( | |||||
| <div className='flex items-center justify-center'> | |||||
| <Loading /> | |||||
| </div> | |||||
| ) | |||||
| } | |||||
| { | |||||
| !isLoading && ( | |||||
| <AuthForm | |||||
| formSchemas={formSchemas.map((formSchema) => { | |||||
| return { | |||||
| ...formSchema, | |||||
| name: formSchema.variable, | |||||
| showRadioUI: formSchema.type === FormTypeEnum.radio, | |||||
| } | |||||
| }) as FormSchema[]} | |||||
| defaultValues={formValues} | |||||
| inputClassName='justify-start' | |||||
| ref={formRef} | |||||
| /> | |||||
| ) | |||||
| } | |||||
| </div> | </div> | ||||
| <div className='sticky bottom-0 -mx-2 mt-2 flex flex-wrap items-center justify-between gap-y-2 bg-components-panel-bg px-2 pb-6 pt-4'> | <div className='sticky bottom-0 -mx-2 mt-2 flex flex-wrap items-center justify-between gap-y-2 bg-components-panel-bg px-2 pb-6 pt-4'> | ||||
| variant='warning' | variant='warning' | ||||
| size='large' | size='large' | ||||
| className='mr-2' | className='mr-2' | ||||
| onClick={() => setShowConfirm(true)} | |||||
| onClick={() => openConfirmDelete(credential, model)} | |||||
| > | > | ||||
| {t('common.operation.remove')} | {t('common.operation.remove')} | ||||
| </Button> | </Button> | ||||
| size='large' | size='large' | ||||
| variant='primary' | variant='primary' | ||||
| onClick={handleSave} | onClick={handleSave} | ||||
| disabled={ | |||||
| loading | |||||
| || filteredRequiredFormSchemas.some(item => value[item.variable] === undefined) | |||||
| || (draftConfig?.enabled && (draftConfig?.configs.filter(config => config.enabled).length ?? 0) < 2) | |||||
| } | |||||
| disabled={isLoading || doingAction} | |||||
| > | > | ||||
| {t('common.operation.save')} | {t('common.operation.save')} | ||||
| </Button> | </Button> | ||||
| </div> | </div> | ||||
| </div> | </div> | ||||
| <div className='border-t-[0.5px] border-t-divider-regular'> | <div className='border-t-[0.5px] border-t-divider-regular'> | ||||
| { | |||||
| (validatedStatusState.status === ValidatedStatus.Error && validatedStatusState.message) | |||||
| ? ( | |||||
| <div className='flex bg-background-section-burn px-[10px] py-3 text-xs text-[#D92D20]'> | |||||
| <RiErrorWarningFill className='mr-2 mt-[1px] h-[14px] w-[14px]' /> | |||||
| {validatedStatusState.message} | |||||
| </div> | |||||
| ) | |||||
| : ( | |||||
| <div className='flex items-center justify-center bg-background-section-burn py-3 text-xs text-text-tertiary'> | |||||
| <Lock01 className='mr-1 h-3 w-3 text-text-tertiary' /> | |||||
| {t('common.modelProvider.encrypted.front')} | |||||
| <a | |||||
| className='mx-1 text-text-accent' | |||||
| target='_blank' rel='noopener noreferrer' | |||||
| href='https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html' | |||||
| > | |||||
| PKCS1_OAEP | |||||
| </a> | |||||
| {t('common.modelProvider.encrypted.back')} | |||||
| </div> | |||||
| ) | |||||
| } | |||||
| <div className='flex items-center justify-center rounded-b-2xl bg-background-section-burn py-3 text-xs text-text-tertiary'> | |||||
| <Lock01 className='mr-1 h-3 w-3 text-text-tertiary' /> | |||||
| {t('common.modelProvider.encrypted.front')} | |||||
| <a | |||||
| className='mx-1 text-text-accent' | |||||
| target='_blank' rel='noopener noreferrer' | |||||
| href='https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html' | |||||
| > | |||||
| PKCS1_OAEP | |||||
| </a> | |||||
| {t('common.modelProvider.encrypted.back')} | |||||
| </div> | |||||
| </div> | </div> | ||||
| </div> | </div> | ||||
| { | { | ||||
| showConfirm && ( | |||||
| deleteCredentialId && ( | |||||
| <Confirm | <Confirm | ||||
| isShow | |||||
| title={t('common.modelProvider.confirmDelete')} | title={t('common.modelProvider.confirmDelete')} | ||||
| isShow={showConfirm} | |||||
| onCancel={() => setShowConfirm(false)} | |||||
| onConfirm={handleRemove} | |||||
| isDisabled={doingAction} | |||||
| onCancel={closeConfirmDelete} | |||||
| onConfirm={handleConfirmDelete} | |||||
| /> | /> | ||||
| ) | ) | ||||
| } | } |
| import type { FC } from 'react' | |||||
| import { | |||||
| memo, | |||||
| useCallback, | |||||
| useEffect, | |||||
| useMemo, | |||||
| useState, | |||||
| } from 'react' | |||||
| import { useTranslation } from 'react-i18next' | |||||
| import { | |||||
| RiErrorWarningFill, | |||||
| } from '@remixicon/react' | |||||
| import type { | |||||
| CredentialFormSchema, | |||||
| CredentialFormSchemaRadio, | |||||
| CredentialFormSchemaSelect, | |||||
| CredentialFormSchemaTextInput, | |||||
| CustomConfigurationModelFixedFields, | |||||
| FormValue, | |||||
| ModelLoadBalancingConfigEntry, | |||||
| ModelProvider, | |||||
| } from '../declarations' | |||||
| import { | |||||
| ConfigurationMethodEnum, | |||||
| FormTypeEnum, | |||||
| } from '../declarations' | |||||
| import { | |||||
| useLanguage, | |||||
| } from '../hooks' | |||||
| import { useValidate } from '../../key-validator/hooks' | |||||
| import { ValidatedStatus } from '../../key-validator/declarations' | |||||
| import { validateLoadBalancingCredentials } from '../utils' | |||||
| import Form from './Form' | |||||
| import Button from '@/app/components/base/button' | |||||
| import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' | |||||
| import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general' | |||||
| import { | |||||
| PortalToFollowElem, | |||||
| PortalToFollowElemContent, | |||||
| } from '@/app/components/base/portal-to-follow-elem' | |||||
| import { useToastContext } from '@/app/components/base/toast' | |||||
| import Confirm from '@/app/components/base/confirm' | |||||
| type ModelModalProps = { | |||||
| provider: ModelProvider | |||||
| configurationMethod: ConfigurationMethodEnum | |||||
| currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields | |||||
| entry?: ModelLoadBalancingConfigEntry | |||||
| onCancel: () => void | |||||
| onSave: (entry: ModelLoadBalancingConfigEntry) => void | |||||
| onRemove: () => void | |||||
| } | |||||
| const ModelLoadBalancingEntryModal: FC<ModelModalProps> = ({ | |||||
| provider, | |||||
| configurationMethod, | |||||
| currentCustomConfigurationModelFixedFields, | |||||
| entry, | |||||
| onCancel, | |||||
| onSave, | |||||
| onRemove, | |||||
| }) => { | |||||
| const providerFormSchemaPredefined = configurationMethod === ConfigurationMethodEnum.predefinedModel | |||||
| // const { credentials: formSchemasValue } = useProviderCredentialsAndLoadBalancing( | |||||
| // provider.provider, | |||||
| // configurationMethod, | |||||
| // providerFormSchemaPredefined && provider.custom_configuration.status === CustomConfigurationStatusEnum.active, | |||||
| // currentCustomConfigurationModelFixedFields, | |||||
| // ) | |||||
| const isEditMode = !!entry | |||||
| const { t } = useTranslation() | |||||
| const { notify } = useToastContext() | |||||
| const language = useLanguage() | |||||
| const [loading, setLoading] = useState(false) | |||||
| const [showConfirm, setShowConfirm] = useState(false) | |||||
| const formSchemas = useMemo(() => { | |||||
| return [ | |||||
| { | |||||
| type: FormTypeEnum.textInput, | |||||
| label: { | |||||
| en_US: 'Config Name', | |||||
| zh_Hans: '配置名称', | |||||
| }, | |||||
| variable: 'name', | |||||
| required: true, | |||||
| show_on: [], | |||||
| placeholder: { | |||||
| en_US: 'Enter your Config Name here', | |||||
| zh_Hans: '输入配置名称', | |||||
| }, | |||||
| } as CredentialFormSchemaTextInput, | |||||
| ...( | |||||
| providerFormSchemaPredefined | |||||
| ? provider.provider_credential_schema.credential_form_schemas | |||||
| : provider.model_credential_schema.credential_form_schemas | |||||
| ), | |||||
| ] | |||||
| }, [ | |||||
| providerFormSchemaPredefined, | |||||
| provider.provider_credential_schema?.credential_form_schemas, | |||||
| provider.model_credential_schema?.credential_form_schemas, | |||||
| ]) | |||||
| const [ | |||||
| requiredFormSchemas, | |||||
| secretFormSchemas, | |||||
| defaultFormSchemaValue, | |||||
| showOnVariableMap, | |||||
| ] = useMemo(() => { | |||||
| const requiredFormSchemas: CredentialFormSchema[] = [] | |||||
| const secretFormSchemas: CredentialFormSchema[] = [] | |||||
| const defaultFormSchemaValue: Record<string, string | number> = {} | |||||
| const showOnVariableMap: Record<string, string[]> = {} | |||||
| formSchemas.forEach((formSchema) => { | |||||
| if (formSchema.required) | |||||
| requiredFormSchemas.push(formSchema) | |||||
| if (formSchema.type === FormTypeEnum.secretInput) | |||||
| secretFormSchemas.push(formSchema) | |||||
| if (formSchema.default) | |||||
| defaultFormSchemaValue[formSchema.variable] = formSchema.default | |||||
| if (formSchema.show_on.length) { | |||||
| formSchema.show_on.forEach((showOnItem) => { | |||||
| if (!showOnVariableMap[showOnItem.variable]) | |||||
| showOnVariableMap[showOnItem.variable] = [] | |||||
| if (!showOnVariableMap[showOnItem.variable].includes(formSchema.variable)) | |||||
| showOnVariableMap[showOnItem.variable].push(formSchema.variable) | |||||
| }) | |||||
| } | |||||
| if (formSchema.type === FormTypeEnum.select || formSchema.type === FormTypeEnum.radio) { | |||||
| (formSchema as (CredentialFormSchemaRadio | CredentialFormSchemaSelect)).options.forEach((option) => { | |||||
| if (option.show_on.length) { | |||||
| option.show_on.forEach((showOnItem) => { | |||||
| if (!showOnVariableMap[showOnItem.variable]) | |||||
| showOnVariableMap[showOnItem.variable] = [] | |||||
| if (!showOnVariableMap[showOnItem.variable].includes(formSchema.variable)) | |||||
| showOnVariableMap[showOnItem.variable].push(formSchema.variable) | |||||
| }) | |||||
| } | |||||
| }) | |||||
| } | |||||
| }) | |||||
| return [ | |||||
| requiredFormSchemas, | |||||
| secretFormSchemas, | |||||
| defaultFormSchemaValue, | |||||
| showOnVariableMap, | |||||
| ] | |||||
| }, [formSchemas]) | |||||
| const [initialValue, setInitialValue] = useState<ModelLoadBalancingConfigEntry['credentials']>() | |||||
| useEffect(() => { | |||||
| if (entry && !initialValue) { | |||||
| setInitialValue({ | |||||
| ...defaultFormSchemaValue, | |||||
| ...entry.credentials, | |||||
| id: entry.id, | |||||
| name: entry.name, | |||||
| } as Record<string, string | undefined | boolean>) | |||||
| } | |||||
| }, [entry, defaultFormSchemaValue, initialValue]) | |||||
| const formSchemasValue = useMemo(() => ({ | |||||
| ...currentCustomConfigurationModelFixedFields, | |||||
| ...initialValue, | |||||
| }), [currentCustomConfigurationModelFixedFields, initialValue]) | |||||
| const initialFormSchemasValue: Record<string, string | number> = useMemo(() => { | |||||
| return { | |||||
| ...defaultFormSchemaValue, | |||||
| ...formSchemasValue, | |||||
| } as Record<string, string | number> | |||||
| }, [formSchemasValue, defaultFormSchemaValue]) | |||||
| const [value, setValue] = useState(initialFormSchemasValue) | |||||
| useEffect(() => { | |||||
| setValue(initialFormSchemasValue) | |||||
| }, [initialFormSchemasValue]) | |||||
| const [_, validating, validatedStatusState] = useValidate(value) | |||||
| const filteredRequiredFormSchemas = requiredFormSchemas.filter((requiredFormSchema) => { | |||||
| if (requiredFormSchema.show_on.length && requiredFormSchema.show_on.every(showOnItem => value[showOnItem.variable] === showOnItem.value)) | |||||
| return true | |||||
| if (!requiredFormSchema.show_on.length) | |||||
| return true | |||||
| return false | |||||
| }) | |||||
| const getSecretValues = useCallback((v: FormValue) => { | |||||
| return secretFormSchemas.reduce((prev, next) => { | |||||
| if (isEditMode && v[next.variable] && v[next.variable] === initialFormSchemasValue[next.variable]) | |||||
| prev[next.variable] = '[__HIDDEN__]' | |||||
| return prev | |||||
| }, {} as Record<string, string>) | |||||
| }, [initialFormSchemasValue, isEditMode, secretFormSchemas]) | |||||
| // const handleValueChange = ({ __model_type, __model_name, ...v }: FormValue) => { | |||||
| const handleValueChange = (v: FormValue) => { | |||||
| setValue(v) | |||||
| } | |||||
| const handleSave = async () => { | |||||
| try { | |||||
| setLoading(true) | |||||
| const res = await validateLoadBalancingCredentials( | |||||
| providerFormSchemaPredefined, | |||||
| provider.provider, | |||||
| { | |||||
| ...value, | |||||
| ...getSecretValues(value), | |||||
| }, | |||||
| entry?.id, | |||||
| ) | |||||
| if (res.status === ValidatedStatus.Success) { | |||||
| // notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) | |||||
| const { __model_type, __model_name, name, ...credentials } = value | |||||
| onSave({ | |||||
| ...(entry || {}), | |||||
| name: name as string, | |||||
| credentials: credentials as Record<string, string | boolean | undefined>, | |||||
| }) | |||||
| // onCancel() | |||||
| } | |||||
| else { | |||||
| notify({ type: 'error', message: res.message || '' }) | |||||
| } | |||||
| } | |||||
| finally { | |||||
| setLoading(false) | |||||
| } | |||||
| } | |||||
| const handleRemove = () => { | |||||
| onRemove?.() | |||||
| } | |||||
| return ( | |||||
| <PortalToFollowElem open> | |||||
| <PortalToFollowElemContent className='z-[60] h-full w-full'> | |||||
| <div className='fixed inset-0 flex items-center justify-center bg-black/[.25]'> | |||||
| <div className='mx-2 max-h-[calc(100vh-120px)] w-[640px] overflow-y-auto rounded-2xl bg-white shadow-xl'> | |||||
| <div className='px-8 pt-8'> | |||||
| <div className='mb-2 flex items-center justify-between'> | |||||
| <div className='text-xl font-semibold text-gray-900'>{t(isEditMode ? 'common.modelProvider.editConfig' : 'common.modelProvider.addConfig')}</div> | |||||
| </div> | |||||
| <Form | |||||
| value={value} | |||||
| onChange={handleValueChange} | |||||
| formSchemas={formSchemas} | |||||
| validating={validating} | |||||
| validatedSuccess={validatedStatusState.status === ValidatedStatus.Success} | |||||
| showOnVariableMap={showOnVariableMap} | |||||
| isEditMode={isEditMode} | |||||
| /> | |||||
| <div className='sticky bottom-0 flex flex-wrap items-center justify-between gap-y-2 bg-white py-6'> | |||||
| { | |||||
| (provider.help && (provider.help.title || provider.help.url)) | |||||
| ? ( | |||||
| <a | |||||
| href={provider.help?.url[language] || provider.help?.url.en_US} | |||||
| target='_blank' rel='noopener noreferrer' | |||||
| className='inline-flex items-center text-xs text-primary-600' | |||||
| onClick={e => !provider.help.url && e.preventDefault()} | |||||
| > | |||||
| {provider.help.title?.[language] || provider.help.url[language] || provider.help.title?.en_US || provider.help.url.en_US} | |||||
| <LinkExternal02 className='ml-1 h-3 w-3' /> | |||||
| </a> | |||||
| ) | |||||
| : <div /> | |||||
| } | |||||
| <div> | |||||
| { | |||||
| isEditMode && ( | |||||
| <Button | |||||
| size='large' | |||||
| className='mr-2 text-[#D92D20]' | |||||
| onClick={() => setShowConfirm(true)} | |||||
| > | |||||
| {t('common.operation.remove')} | |||||
| </Button> | |||||
| ) | |||||
| } | |||||
| <Button | |||||
| size='large' | |||||
| className='mr-2' | |||||
| onClick={onCancel} | |||||
| > | |||||
| {t('common.operation.cancel')} | |||||
| </Button> | |||||
| <Button | |||||
| size='large' | |||||
| variant='primary' | |||||
| onClick={handleSave} | |||||
| disabled={loading || filteredRequiredFormSchemas.some(item => value[item.variable] === undefined)} | |||||
| > | |||||
| {t('common.operation.save')} | |||||
| </Button> | |||||
| </div> | |||||
| </div> | |||||
| </div> | |||||
| <div className='border-t-[0.5px] border-t-black/5'> | |||||
| { | |||||
| (validatedStatusState.status === ValidatedStatus.Error && validatedStatusState.message) | |||||
| ? ( | |||||
| <div className='flex bg-[#FEF3F2] px-[10px] py-3 text-xs text-[#D92D20]'> | |||||
| <RiErrorWarningFill className='mr-2 mt-[1px] h-[14px] w-[14px]' /> | |||||
| {validatedStatusState.message} | |||||
| </div> | |||||
| ) | |||||
| : ( | |||||
| <div className='flex items-center justify-center bg-gray-50 py-3 text-xs text-gray-500'> | |||||
| <Lock01 className='mr-1 h-3 w-3 text-gray-500' /> | |||||
| {t('common.modelProvider.encrypted.front')} | |||||
| <a | |||||
| className='mx-1 text-primary-600' | |||||
| target='_blank' rel='noopener noreferrer' | |||||
| href='https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html' | |||||
| > | |||||
| PKCS1_OAEP | |||||
| </a> | |||||
| {t('common.modelProvider.encrypted.back')} | |||||
| </div> | |||||
| ) | |||||
| } | |||||
| </div> | |||||
| </div> | |||||
| { | |||||
| showConfirm && ( | |||||
| <Confirm | |||||
| title={t('common.modelProvider.confirmDelete')} | |||||
| isShow={showConfirm} | |||||
| onCancel={() => setShowConfirm(false)} | |||||
| onConfirm={handleRemove} | |||||
| /> | |||||
| ) | |||||
| } | |||||
| </div> | |||||
| </PortalToFollowElemContent> | |||||
| </PortalToFollowElem> | |||||
| ) | |||||
| } | |||||
| export default memo(ModelLoadBalancingEntryModal) |
| import type { FC } from 'react' | |||||
| import { useMemo } from 'react' | |||||
| import { useTranslation } from 'react-i18next' | import { useTranslation } from 'react-i18next' | ||||
| import { RiEqualizer2Line } from '@remixicon/react' | |||||
| import type { ModelProvider } from '../declarations' | |||||
| import type { | |||||
| ModelProvider, | |||||
| } from '../declarations' | |||||
| import { | import { | ||||
| ConfigurationMethodEnum, | ConfigurationMethodEnum, | ||||
| CustomConfigurationStatusEnum, | CustomConfigurationStatusEnum, | ||||
| import PriorityUseTip from './priority-use-tip' | import PriorityUseTip from './priority-use-tip' | ||||
| import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './index' | import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './index' | ||||
| import Indicator from '@/app/components/header/indicator' | import Indicator from '@/app/components/header/indicator' | ||||
| import Button from '@/app/components/base/button' | |||||
| import { changeModelProviderPriority } from '@/service/common' | import { changeModelProviderPriority } from '@/service/common' | ||||
| import { useToastContext } from '@/app/components/base/toast' | import { useToastContext } from '@/app/components/base/toast' | ||||
| import { useEventEmitterContextContext } from '@/context/event-emitter' | import { useEventEmitterContextContext } from '@/context/event-emitter' | ||||
| import cn from '@/utils/classnames' | |||||
| import { useCredentialStatus } from '@/app/components/header/account-setting/model-provider-page/model-auth/hooks' | |||||
| import { ConfigProvider } from '@/app/components/header/account-setting/model-provider-page/model-auth' | |||||
| type CredentialPanelProps = { | type CredentialPanelProps = { | ||||
| provider: ModelProvider | provider: ModelProvider | ||||
| onSetup: () => void | |||||
| } | } | ||||
| const CredentialPanel: FC<CredentialPanelProps> = ({ | |||||
| const CredentialPanel = ({ | |||||
| provider, | provider, | ||||
| onSetup, | |||||
| }) => { | |||||
| }: CredentialPanelProps) => { | |||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| const { notify } = useToastContext() | const { notify } = useToastContext() | ||||
| const { eventEmitter } = useEventEmitterContextContext() | const { eventEmitter } = useEventEmitterContextContext() | ||||
| const priorityUseType = provider.preferred_provider_type | const priorityUseType = provider.preferred_provider_type | ||||
| const isCustomConfigured = customConfig.status === CustomConfigurationStatusEnum.active | const isCustomConfigured = customConfig.status === CustomConfigurationStatusEnum.active | ||||
| const configurateMethods = provider.configurate_methods | const configurateMethods = provider.configurate_methods | ||||
| const { | |||||
| hasCredential, | |||||
| authorized, | |||||
| authRemoved, | |||||
| current_credential_name, | |||||
| notAllowedToUse, | |||||
| } = useCredentialStatus(provider) | |||||
| const handleChangePriority = async (key: PreferredProviderTypeEnum) => { | const handleChangePriority = async (key: PreferredProviderTypeEnum) => { | ||||
| const res = await changeModelProviderPriority({ | const res = await changeModelProviderPriority({ | ||||
| } as any) | } as any) | ||||
| } | } | ||||
| } | } | ||||
| const credentialLabel = useMemo(() => { | |||||
| if (!hasCredential) | |||||
| return t('common.modelProvider.auth.unAuthorized') | |||||
| if (authorized) | |||||
| return current_credential_name | |||||
| if (authRemoved) | |||||
| return t('common.modelProvider.auth.authRemoved') | |||||
| return '' | |||||
| }, [authorized, authRemoved, current_credential_name, hasCredential]) | |||||
| const color = useMemo(() => { | |||||
| if (authRemoved) | |||||
| return 'red' | |||||
| if (notAllowedToUse) | |||||
| return 'gray' | |||||
| return 'green' | |||||
| }, [authRemoved, notAllowedToUse]) | |||||
| return ( | return ( | ||||
| <> | <> | ||||
| { | { | ||||
| provider.provider_credential_schema && ( | provider.provider_credential_schema && ( | ||||
| <div className='relative ml-1 w-[112px] shrink-0 rounded-lg border-[0.5px] border-components-panel-border bg-white/[0.18] p-1'> | |||||
| <div className='system-xs-medium-uppercase mb-1 flex h-5 items-center justify-between pl-2 pr-[7px] pt-1 text-text-tertiary'> | |||||
| API-KEY | |||||
| <Indicator color={isCustomConfigured ? 'green' : 'red'} /> | |||||
| <div className={cn( | |||||
| 'relative ml-1 w-[120px] shrink-0 rounded-lg border-[0.5px] border-components-panel-border bg-white/[0.18] p-1', | |||||
| authRemoved && 'border-state-destructive-border bg-state-destructive-hover', | |||||
| )}> | |||||
| <div className='system-xs-medium mb-1 flex h-5 items-center justify-between pl-2 pr-[7px] pt-1 text-text-tertiary'> | |||||
| <div | |||||
| className={cn( | |||||
| 'grow truncate', | |||||
| authRemoved && 'text-text-destructive', | |||||
| )} | |||||
| title={credentialLabel} | |||||
| > | |||||
| {credentialLabel} | |||||
| </div> | |||||
| <Indicator className='shrink-0' color={color} /> | |||||
| </div> | </div> | ||||
| <div className='flex items-center gap-0.5'> | <div className='flex items-center gap-0.5'> | ||||
| <Button | |||||
| className='grow' | |||||
| size='small' | |||||
| onClick={onSetup} | |||||
| > | |||||
| <RiEqualizer2Line className='mr-1 h-3.5 w-3.5' /> | |||||
| {t('common.operation.setup')} | |||||
| </Button> | |||||
| <ConfigProvider | |||||
| provider={provider} | |||||
| configurationMethod={ConfigurationMethodEnum.predefinedModel} | |||||
| /> | |||||
| { | { | ||||
| systemConfig.enabled && isCustomConfigured && ( | systemConfig.enabled && isCustomConfigured && ( | ||||
| <PrioritySelector | <PrioritySelector |
| RiLoader2Line, | RiLoader2Line, | ||||
| } from '@remixicon/react' | } from '@remixicon/react' | ||||
| import type { | import type { | ||||
| CustomConfigurationModelFixedFields, | |||||
| ModelItem, | ModelItem, | ||||
| ModelProvider, | ModelProvider, | ||||
| } from '../declarations' | } from '../declarations' | ||||
| import CredentialPanel from './credential-panel' | import CredentialPanel from './credential-panel' | ||||
| import QuotaPanel from './quota-panel' | import QuotaPanel from './quota-panel' | ||||
| import ModelList from './model-list' | import ModelList from './model-list' | ||||
| import AddModelButton from './add-model-button' | |||||
| import { fetchModelProviderModelList } from '@/service/common' | import { fetchModelProviderModelList } from '@/service/common' | ||||
| import { useEventEmitterContextContext } from '@/context/event-emitter' | import { useEventEmitterContextContext } from '@/context/event-emitter' | ||||
| import { IS_CE_EDITION } from '@/config' | import { IS_CE_EDITION } from '@/config' | ||||
| import { useAppContext } from '@/context/app-context' | import { useAppContext } from '@/context/app-context' | ||||
| import cn from '@/utils/classnames' | import cn from '@/utils/classnames' | ||||
| import { AddCustomModel } from '@/app/components/header/account-setting/model-provider-page/model-auth' | |||||
| export const UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST = 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST' | export const UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST = 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST' | ||||
| type ProviderAddedCardProps = { | type ProviderAddedCardProps = { | ||||
| notConfigured?: boolean | notConfigured?: boolean | ||||
| provider: ModelProvider | provider: ModelProvider | ||||
| onOpenModal: (configurationMethod: ConfigurationMethodEnum, currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => void | |||||
| } | } | ||||
| const ProviderAddedCard: FC<ProviderAddedCardProps> = ({ | const ProviderAddedCard: FC<ProviderAddedCardProps> = ({ | ||||
| notConfigured, | notConfigured, | ||||
| provider, | provider, | ||||
| onOpenModal, | |||||
| }) => { | }) => { | ||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| const { eventEmitter } = useEventEmitterContextContext() | const { eventEmitter } = useEventEmitterContextContext() | ||||
| { | { | ||||
| showCredential && ( | showCredential && ( | ||||
| <CredentialPanel | <CredentialPanel | ||||
| onSetup={() => onOpenModal(ConfigurationMethodEnum.predefinedModel)} | |||||
| provider={provider} | provider={provider} | ||||
| /> | /> | ||||
| ) | ) | ||||
| )} | )} | ||||
| { | { | ||||
| configurationMethods.includes(ConfigurationMethodEnum.customizableModel) && isCurrentWorkspaceManager && ( | configurationMethods.includes(ConfigurationMethodEnum.customizableModel) && isCurrentWorkspaceManager && ( | ||||
| <AddModelButton | |||||
| onClick={() => onOpenModal(ConfigurationMethodEnum.customizableModel)} | |||||
| className='flex' | |||||
| <AddCustomModel | |||||
| provider={provider} | |||||
| configurationMethod={ConfigurationMethodEnum.customizableModel} | |||||
| /> | /> | ||||
| ) | ) | ||||
| } | } | ||||
| provider={provider} | provider={provider} | ||||
| models={modelList} | models={modelList} | ||||
| onCollapse={() => setCollapsed(true)} | onCollapse={() => setCollapsed(true)} | ||||
| onConfig={currentCustomConfigurationModelFixedFields => onOpenModal(ConfigurationMethodEnum.customizableModel, currentCustomConfigurationModelFixedFields)} | |||||
| onChange={(provider: string) => getModelList(provider)} | onChange={(provider: string) => getModelList(provider)} | ||||
| /> | /> | ||||
| ) | ) |
| import { memo, useCallback } from 'react' | import { memo, useCallback } from 'react' | ||||
| import { useTranslation } from 'react-i18next' | import { useTranslation } from 'react-i18next' | ||||
| import { useDebounceFn } from 'ahooks' | import { useDebounceFn } from 'ahooks' | ||||
| import type { CustomConfigurationModelFixedFields, ModelItem, ModelProvider } from '../declarations' | |||||
| import { ConfigurationMethodEnum, ModelStatusEnum } from '../declarations' | |||||
| import ModelBadge from '../model-badge' | |||||
| import type { ModelItem, ModelProvider } from '../declarations' | |||||
| import { ModelStatusEnum } from '../declarations' | |||||
| import ModelIcon from '../model-icon' | import ModelIcon from '../model-icon' | ||||
| import ModelName from '../model-name' | import ModelName from '../model-name' | ||||
| import classNames from '@/utils/classnames' | import classNames from '@/utils/classnames' | ||||
| import Button from '@/app/components/base/button' | |||||
| import { Balance } from '@/app/components/base/icons/src/vender/line/financeAndECommerce' | import { Balance } from '@/app/components/base/icons/src/vender/line/financeAndECommerce' | ||||
| import { Settings01 } from '@/app/components/base/icons/src/vender/line/general' | |||||
| import Switch from '@/app/components/base/switch' | import Switch from '@/app/components/base/switch' | ||||
| import Tooltip from '@/app/components/base/tooltip' | import Tooltip from '@/app/components/base/tooltip' | ||||
| import { useProviderContext, useProviderContextSelector } from '@/context/provider-context' | import { useProviderContext, useProviderContextSelector } from '@/context/provider-context' | ||||
| import { disableModel, enableModel } from '@/service/common' | import { disableModel, enableModel } from '@/service/common' | ||||
| import { Plan } from '@/app/components/billing/type' | import { Plan } from '@/app/components/billing/type' | ||||
| import { useAppContext } from '@/context/app-context' | import { useAppContext } from '@/context/app-context' | ||||
| import { ConfigModel } from '../model-auth' | |||||
| import Badge from '@/app/components/base/badge' | |||||
| export type ModelListItemProps = { | export type ModelListItemProps = { | ||||
| model: ModelItem | model: ModelItem | ||||
| provider: ModelProvider | provider: ModelProvider | ||||
| isConfigurable: boolean | isConfigurable: boolean | ||||
| onConfig: (currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => void | |||||
| onModifyLoadBalancing?: (model: ModelItem) => void | onModifyLoadBalancing?: (model: ModelItem) => void | ||||
| } | } | ||||
| const ModelListItem = ({ model, provider, isConfigurable, onConfig, onModifyLoadBalancing }: ModelListItemProps) => { | |||||
| const ModelListItem = ({ model, provider, isConfigurable, onModifyLoadBalancing }: ModelListItemProps) => { | |||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| const { plan } = useProviderContext() | const { plan } = useProviderContext() | ||||
| const modelLoadBalancingEnabled = useProviderContextSelector(state => state.modelLoadBalancingEnabled) | const modelLoadBalancingEnabled = useProviderContextSelector(state => state.modelLoadBalancingEnabled) | ||||
| return ( | return ( | ||||
| <div | <div | ||||
| key={model.model} | |||||
| key={`${model.model}-${model.fetch_from}`} | |||||
| className={classNames( | className={classNames( | ||||
| 'group flex h-8 items-center rounded-lg pl-2 pr-2.5', | 'group flex h-8 items-center rounded-lg pl-2 pr-2.5', | ||||
| isConfigurable && 'hover:bg-components-panel-on-panel-item-bg-hover', | isConfigurable && 'hover:bg-components-panel-on-panel-item-bg-hover', | ||||
| showMode | showMode | ||||
| showContextSize | showContextSize | ||||
| > | > | ||||
| {modelLoadBalancingEnabled && !model.deprecated && model.load_balancing_enabled && ( | |||||
| <ModelBadge className='ml-1 border-text-accent-secondary uppercase text-text-accent-secondary'> | |||||
| <Balance className='mr-0.5 h-3 w-3' /> | |||||
| {t('common.modelProvider.loadBalancingHeadline')} | |||||
| </ModelBadge> | |||||
| )} | |||||
| </ModelName> | </ModelName> | ||||
| <div className='flex shrink-0 items-center'> | <div className='flex shrink-0 items-center'> | ||||
| {modelLoadBalancingEnabled && !model.deprecated && model.load_balancing_enabled && !model.has_invalid_load_balancing_configs && ( | |||||
| <Badge className='mr-1 h-[18px] w-[18px] items-center justify-center border-text-accent-secondary p-0'> | |||||
| <Balance className='h-3 w-3 text-text-accent-secondary' /> | |||||
| </Badge> | |||||
| )} | |||||
| { | { | ||||
| model.fetch_from === ConfigurationMethodEnum.customizableModel | |||||
| ? (isCurrentWorkspaceManager && ( | |||||
| <Button | |||||
| size='small' | |||||
| className='hidden group-hover:flex' | |||||
| onClick={() => onConfig({ __model_name: model.model, __model_type: model.model_type })} | |||||
| > | |||||
| <Settings01 className='mr-1 h-3.5 w-3.5' /> | |||||
| {t('common.modelProvider.config')} | |||||
| </Button> | |||||
| )) | |||||
| : (isCurrentWorkspaceManager && (modelLoadBalancingEnabled || plan.type === Plan.sandbox) && !model.deprecated && [ModelStatusEnum.active, ModelStatusEnum.disabled].includes(model.status)) | |||||
| ? ( | |||||
| <Button | |||||
| size='small' | |||||
| className='opacity-0 transition-opacity group-hover:opacity-100' | |||||
| onClick={() => onModifyLoadBalancing?.(model)} | |||||
| > | |||||
| <Balance className='mr-1 h-3.5 w-3.5' /> | |||||
| {t('common.modelProvider.configLoadBalancing')} | |||||
| </Button> | |||||
| ) | |||||
| : null | |||||
| (isCurrentWorkspaceManager && (modelLoadBalancingEnabled || plan.type === Plan.sandbox) && !model.deprecated && [ModelStatusEnum.active, ModelStatusEnum.disabled].includes(model.status)) && ( | |||||
| <ConfigModel | |||||
| onClick={() => onModifyLoadBalancing?.(model)} | |||||
| loadBalancingEnabled={model.load_balancing_enabled} | |||||
| loadBalancingInvalid={model.has_invalid_load_balancing_configs} | |||||
| credentialRemoved={model.status === ModelStatusEnum.credentialRemoved} | |||||
| /> | |||||
| ) | |||||
| } | } | ||||
| { | { | ||||
| model.deprecated | model.deprecated |
| RiArrowRightSLine, | RiArrowRightSLine, | ||||
| } from '@remixicon/react' | } from '@remixicon/react' | ||||
| import type { | import type { | ||||
| CustomConfigurationModelFixedFields, | |||||
| Credential, | |||||
| ModelItem, | ModelItem, | ||||
| ModelProvider, | ModelProvider, | ||||
| } from '../declarations' | } from '../declarations' | ||||
| ConfigurationMethodEnum, | ConfigurationMethodEnum, | ||||
| } from '../declarations' | } from '../declarations' | ||||
| // import Tab from './tab' | // import Tab from './tab' | ||||
| import AddModelButton from './add-model-button' | |||||
| import ModelListItem from './model-list-item' | import ModelListItem from './model-list-item' | ||||
| import { useModalContextSelector } from '@/context/modal-context' | import { useModalContextSelector } from '@/context/modal-context' | ||||
| import { useAppContext } from '@/context/app-context' | import { useAppContext } from '@/context/app-context' | ||||
| import { AddCustomModel } from '@/app/components/header/account-setting/model-provider-page/model-auth' | |||||
| type ModelListProps = { | type ModelListProps = { | ||||
| provider: ModelProvider | provider: ModelProvider | ||||
| models: ModelItem[] | models: ModelItem[] | ||||
| onCollapse: () => void | onCollapse: () => void | ||||
| onConfig: (currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields) => void | |||||
| onChange?: (provider: string) => void | onChange?: (provider: string) => void | ||||
| } | } | ||||
| const ModelList: FC<ModelListProps> = ({ | const ModelList: FC<ModelListProps> = ({ | ||||
| provider, | provider, | ||||
| models, | models, | ||||
| onCollapse, | onCollapse, | ||||
| onConfig, | |||||
| onChange, | onChange, | ||||
| }) => { | }) => { | ||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| const configurativeMethods = provider.configurate_methods.filter(method => method !== ConfigurationMethodEnum.fetchFromRemote) | const configurativeMethods = provider.configurate_methods.filter(method => method !== ConfigurationMethodEnum.fetchFromRemote) | ||||
| const { isCurrentWorkspaceManager } = useAppContext() | const { isCurrentWorkspaceManager } = useAppContext() | ||||
| const isConfigurable = configurativeMethods.includes(ConfigurationMethodEnum.customizableModel) | const isConfigurable = configurativeMethods.includes(ConfigurationMethodEnum.customizableModel) | ||||
| const setShowModelLoadBalancingModal = useModalContextSelector(state => state.setShowModelLoadBalancingModal) | const setShowModelLoadBalancingModal = useModalContextSelector(state => state.setShowModelLoadBalancingModal) | ||||
| const onModifyLoadBalancing = useCallback((model: ModelItem) => { | |||||
| const onModifyLoadBalancing = useCallback((model: ModelItem, credential?: Credential) => { | |||||
| setShowModelLoadBalancingModal({ | setShowModelLoadBalancingModal({ | ||||
| provider, | provider, | ||||
| credential, | |||||
| configurateMethod: model.fetch_from, | |||||
| model: model!, | model: model!, | ||||
| open: !!model, | open: !!model, | ||||
| onClose: () => setShowModelLoadBalancingModal(null), | onClose: () => setShowModelLoadBalancingModal(null), | ||||
| <RiArrowRightSLine className='mr-0.5 h-4 w-4 rotate-90' /> | <RiArrowRightSLine className='mr-0.5 h-4 w-4 rotate-90' /> | ||||
| </span> | </span> | ||||
| </span> | </span> | ||||
| {/* { | |||||
| isConfigurable && canSystemConfig && ( | |||||
| <span className='flex items-center'> | |||||
| <Tab active='all' onSelect={() => {}} /> | |||||
| </span> | |||||
| ) | |||||
| } */} | |||||
| { | { | ||||
| isConfigurable && isCurrentWorkspaceManager && ( | isConfigurable && isCurrentWorkspaceManager && ( | ||||
| <div className='flex grow justify-end'> | <div className='flex grow justify-end'> | ||||
| <AddModelButton onClick={() => onConfig()} /> | |||||
| <AddCustomModel | |||||
| provider={provider} | |||||
| configurationMethod={ConfigurationMethodEnum.customizableModel} | |||||
| currentCustomConfigurationModelFixedFields={undefined} | |||||
| /> | |||||
| </div> | </div> | ||||
| ) | ) | ||||
| } | } | ||||
| { | { | ||||
| models.map(model => ( | models.map(model => ( | ||||
| <ModelListItem | <ModelListItem | ||||
| key={model.model} | |||||
| key={`${model.model}-${model.fetch_from}`} | |||||
| {...{ | {...{ | ||||
| model, | model, | ||||
| provider, | provider, | ||||
| isConfigurable, | isConfigurable, | ||||
| onConfig, | |||||
| onModifyLoadBalancing, | onModifyLoadBalancing, | ||||
| }} | }} | ||||
| /> | /> |
| import type { Dispatch, SetStateAction } from 'react' | import type { Dispatch, SetStateAction } from 'react' | ||||
| import { useCallback } from 'react' | |||||
| import { useCallback, useMemo } from 'react' | |||||
| import { useTranslation } from 'react-i18next' | import { useTranslation } from 'react-i18next' | ||||
| import { | import { | ||||
| RiDeleteBinLine, | RiDeleteBinLine, | ||||
| RiEqualizer2Line, | |||||
| } from '@remixicon/react' | } from '@remixicon/react' | ||||
| import type { ConfigurationMethodEnum, CustomConfigurationModelFixedFields, ModelLoadBalancingConfig, ModelLoadBalancingConfigEntry, ModelProvider } from '../declarations' | |||||
| import type { | |||||
| Credential, | |||||
| CustomConfigurationModelFixedFields, | |||||
| CustomModelCredential, | |||||
| ModelCredential, | |||||
| ModelLoadBalancingConfig, | |||||
| ModelLoadBalancingConfigEntry, | |||||
| ModelProvider, | |||||
| } from '../declarations' | |||||
| import { ConfigurationMethodEnum } from '../declarations' | |||||
| import Indicator from '../../../indicator' | import Indicator from '../../../indicator' | ||||
| import CooldownTimer from './cooldown-timer' | import CooldownTimer from './cooldown-timer' | ||||
| import classNames from '@/utils/classnames' | import classNames from '@/utils/classnames' | ||||
| import Tooltip from '@/app/components/base/tooltip' | import Tooltip from '@/app/components/base/tooltip' | ||||
| import Switch from '@/app/components/base/switch' | import Switch from '@/app/components/base/switch' | ||||
| import { Balance } from '@/app/components/base/icons/src/vender/line/financeAndECommerce' | import { Balance } from '@/app/components/base/icons/src/vender/line/financeAndECommerce' | ||||
| import { Edit02, Plus02 } from '@/app/components/base/icons/src/vender/line/general' | |||||
| import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' | import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' | ||||
| import { useModalContextSelector } from '@/context/modal-context' | |||||
| import UpgradeBtn from '@/app/components/billing/upgrade-btn' | import UpgradeBtn from '@/app/components/billing/upgrade-btn' | ||||
| import s from '@/app/components/custom/style.module.css' | import s from '@/app/components/custom/style.module.css' | ||||
| import GridMask from '@/app/components/base/grid-mask' | import GridMask from '@/app/components/base/grid-mask' | ||||
| import { useProviderContextSelector } from '@/context/provider-context' | import { useProviderContextSelector } from '@/context/provider-context' | ||||
| import { IS_CE_EDITION } from '@/config' | import { IS_CE_EDITION } from '@/config' | ||||
| import { AddCredentialInLoadBalancing } from '@/app/components/header/account-setting/model-provider-page/model-auth' | |||||
| import { useModelModalHandler } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||||
| import Badge from '@/app/components/base/badge/index' | |||||
| export type ModelLoadBalancingConfigsProps = { | export type ModelLoadBalancingConfigsProps = { | ||||
| draftConfig?: ModelLoadBalancingConfig | draftConfig?: ModelLoadBalancingConfig | ||||
| currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields | currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields | ||||
| withSwitch?: boolean | withSwitch?: boolean | ||||
| className?: string | className?: string | ||||
| modelCredential: ModelCredential | |||||
| onUpdate?: () => void | |||||
| model: CustomModelCredential | |||||
| } | } | ||||
| const ModelLoadBalancingConfigs = ({ | const ModelLoadBalancingConfigs = ({ | ||||
| draftConfig, | draftConfig, | ||||
| setDraftConfig, | setDraftConfig, | ||||
| provider, | provider, | ||||
| model, | |||||
| configurationMethod, | configurationMethod, | ||||
| currentCustomConfigurationModelFixedFields, | currentCustomConfigurationModelFixedFields, | ||||
| withSwitch = false, | withSwitch = false, | ||||
| className, | className, | ||||
| modelCredential, | |||||
| onUpdate, | |||||
| }: ModelLoadBalancingConfigsProps) => { | }: ModelLoadBalancingConfigsProps) => { | ||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| const providerFormSchemaPredefined = configurationMethod === ConfigurationMethodEnum.predefinedModel | |||||
| const modelLoadBalancingEnabled = useProviderContextSelector(state => state.modelLoadBalancingEnabled) | const modelLoadBalancingEnabled = useProviderContextSelector(state => state.modelLoadBalancingEnabled) | ||||
| const handleOpenModal = useModelModalHandler() | |||||
| const updateConfigEntry = useCallback( | const updateConfigEntry = useCallback( | ||||
| ( | ( | ||||
| [setDraftConfig], | [setDraftConfig], | ||||
| ) | ) | ||||
| const addConfigEntry = useCallback((credential: Credential) => { | |||||
| setDraftConfig((prev: any) => { | |||||
| if (!prev) | |||||
| return prev | |||||
| return { | |||||
| ...prev, | |||||
| configs: [...prev.configs, { | |||||
| credential_id: credential.credential_id, | |||||
| enabled: true, | |||||
| name: credential.credential_name, | |||||
| }], | |||||
| } | |||||
| }) | |||||
| }, [setDraftConfig]) | |||||
| const toggleModalBalancing = useCallback((enabled: boolean) => { | const toggleModalBalancing = useCallback((enabled: boolean) => { | ||||
| if ((modelLoadBalancingEnabled || !enabled) && draftConfig) { | if ((modelLoadBalancingEnabled || !enabled) && draftConfig) { | ||||
| setDraftConfig({ | setDraftConfig({ | ||||
| })) | })) | ||||
| }, [updateConfigEntry]) | }, [updateConfigEntry]) | ||||
| const setShowModelLoadBalancingEntryModal = useModalContextSelector(state => state.setShowModelLoadBalancingEntryModal) | |||||
| const toggleEntryModal = useCallback((index?: number, entry?: ModelLoadBalancingConfigEntry) => { | |||||
| setShowModelLoadBalancingEntryModal({ | |||||
| payload: { | |||||
| currentProvider: provider, | |||||
| currentConfigurationMethod: configurationMethod, | |||||
| currentCustomConfigurationModelFixedFields, | |||||
| entry, | |||||
| index, | |||||
| }, | |||||
| onSaveCallback: ({ entry: result }) => { | |||||
| if (entry) { | |||||
| // edit | |||||
| setDraftConfig(prev => ({ | |||||
| ...prev, | |||||
| enabled: !!prev?.enabled, | |||||
| configs: prev?.configs.map((config, i) => i === index ? result! : config) || [], | |||||
| })) | |||||
| } | |||||
| else { | |||||
| // add | |||||
| setDraftConfig(prev => ({ | |||||
| ...prev, | |||||
| enabled: !!prev?.enabled, | |||||
| configs: (prev?.configs || []).concat([{ ...result!, enabled: true }]), | |||||
| })) | |||||
| } | |||||
| }, | |||||
| onRemoveCallback: ({ index }) => { | |||||
| if (index !== undefined && (draftConfig?.configs?.length ?? 0) > index) { | |||||
| setDraftConfig(prev => ({ | |||||
| ...prev, | |||||
| enabled: !!prev?.enabled, | |||||
| configs: prev?.configs.filter((_, i) => i !== index) || [], | |||||
| })) | |||||
| } | |||||
| }, | |||||
| }) | |||||
| }, [ | |||||
| configurationMethod, | |||||
| currentCustomConfigurationModelFixedFields, | |||||
| draftConfig?.configs?.length, | |||||
| provider, | |||||
| setDraftConfig, | |||||
| setShowModelLoadBalancingEntryModal, | |||||
| ]) | |||||
| const clearCountdown = useCallback((index: number) => { | const clearCountdown = useCallback((index: number) => { | ||||
| updateConfigEntry(index, ({ ttl: _, ...entry }) => { | updateConfigEntry(index, ({ ttl: _, ...entry }) => { | ||||
| return { | return { | ||||
| }) | }) | ||||
| }, [updateConfigEntry]) | }, [updateConfigEntry]) | ||||
| const validDraftConfigList = useMemo(() => { | |||||
| if (!draftConfig) | |||||
| return [] | |||||
| return draftConfig.configs | |||||
| }, [draftConfig]) | |||||
| if (!draftConfig) | if (!draftConfig) | ||||
| return null | return null | ||||
| </div> | </div> | ||||
| {draftConfig.enabled && ( | {draftConfig.enabled && ( | ||||
| <div className='flex flex-col gap-1 px-3 pb-3'> | <div className='flex flex-col gap-1 px-3 pb-3'> | ||||
| {draftConfig.configs.map((config, index) => { | |||||
| {validDraftConfigList.map((config, index) => { | |||||
| const isProviderManaged = config.name === '__inherit__' | const isProviderManaged = config.name === '__inherit__' | ||||
| const credential = modelCredential.available_credentials.find(c => c.credential_id === config.credential_id) | |||||
| return ( | return ( | ||||
| <div key={config.id || index} className='group flex h-10 items-center rounded-lg border border-components-panel-border bg-components-panel-on-panel-item-bg px-3 shadow-xs'> | <div key={config.id || index} className='group flex h-10 items-center rounded-lg border border-components-panel-border bg-components-panel-on-panel-item-bg px-3 shadow-xs'> | ||||
| <div className='flex grow items-center'> | <div className='flex grow items-center'> | ||||
| <div className='mr-1 text-[13px]'> | <div className='mr-1 text-[13px]'> | ||||
| {isProviderManaged ? t('common.modelProvider.defaultConfig') : config.name} | {isProviderManaged ? t('common.modelProvider.defaultConfig') : config.name} | ||||
| </div> | </div> | ||||
| {isProviderManaged && ( | |||||
| <span className='rounded-[5px] border border-divider-regular px-1 text-2xs uppercase text-text-tertiary'>{t('common.modelProvider.providerManaged')}</span> | |||||
| {isProviderManaged && providerFormSchemaPredefined && ( | |||||
| <Badge className='ml-2'>{t('common.modelProvider.providerManaged')}</Badge> | |||||
| )} | )} | ||||
| { | |||||
| credential?.from_enterprise && ( | |||||
| <Badge className='ml-2'>Enterprise</Badge> | |||||
| ) | |||||
| } | |||||
| </div> | </div> | ||||
| <div className='flex items-center gap-1'> | <div className='flex items-center gap-1'> | ||||
| {!isProviderManaged && ( | {!isProviderManaged && ( | ||||
| <> | <> | ||||
| <div className='flex items-center gap-1 opacity-0 transition-opacity group-hover:opacity-100'> | <div className='flex items-center gap-1 opacity-0 transition-opacity group-hover:opacity-100'> | ||||
| <span | |||||
| className='flex h-8 w-8 cursor-pointer items-center justify-center rounded-lg bg-components-button-secondary-bg text-text-tertiary transition-colors hover:bg-components-button-secondary-bg-hover' | |||||
| onClick={() => toggleEntryModal(index, config)} | |||||
| > | |||||
| <Edit02 className='h-4 w-4' /> | |||||
| </span> | |||||
| { | |||||
| config.credential_id && !credential?.not_allowed_to_use && !credential?.from_enterprise && ( | |||||
| <span | |||||
| className='flex h-8 w-8 cursor-pointer items-center justify-center rounded-lg bg-components-button-secondary-bg text-text-tertiary transition-colors hover:bg-components-button-secondary-bg-hover' | |||||
| onClick={() => { | |||||
| handleOpenModal( | |||||
| provider, | |||||
| configurationMethod, | |||||
| currentCustomConfigurationModelFixedFields, | |||||
| configurationMethod === ConfigurationMethodEnum.customizableModel, | |||||
| (config.credential_id && config.name) ? { | |||||
| credential_id: config.credential_id, | |||||
| credential_name: config.name, | |||||
| } : undefined, | |||||
| model, | |||||
| ) | |||||
| }} | |||||
| > | |||||
| <RiEqualizer2Line className='h-4 w-4' /> | |||||
| </span> | |||||
| ) | |||||
| } | |||||
| <span | <span | ||||
| className='flex h-8 w-8 cursor-pointer items-center justify-center rounded-lg bg-components-button-secondary-bg text-text-tertiary transition-colors hover:bg-components-button-secondary-bg-hover' | className='flex h-8 w-8 cursor-pointer items-center justify-center rounded-lg bg-components-button-secondary-bg text-text-tertiary transition-colors hover:bg-components-button-secondary-bg-hover' | ||||
| onClick={() => updateConfigEntry(index, () => undefined)} | onClick={() => updateConfigEntry(index, () => undefined)} | ||||
| > | > | ||||
| <RiDeleteBinLine className='h-4 w-4' /> | <RiDeleteBinLine className='h-4 w-4' /> | ||||
| </span> | </span> | ||||
| <span className='mr-2 h-3 border-r border-r-divider-subtle' /> | |||||
| </div> | </div> | ||||
| </> | </> | ||||
| )} | )} | ||||
| <Switch | |||||
| defaultValue={Boolean(config.enabled)} | |||||
| size='md' | |||||
| className='justify-self-end' | |||||
| onChange={value => toggleConfigEntryEnabled(index, value)} | |||||
| /> | |||||
| { | |||||
| (config.credential_id || config.name === '__inherit__') && ( | |||||
| <> | |||||
| <span className='mr-2 h-3 border-r border-r-divider-subtle' /> | |||||
| <Switch | |||||
| defaultValue={Boolean(config.enabled)} | |||||
| size='md' | |||||
| className='justify-self-end' | |||||
| onChange={value => toggleConfigEntryEnabled(index, value)} | |||||
| disabled={credential?.not_allowed_to_use} | |||||
| /> | |||||
| </> | |||||
| ) | |||||
| } | |||||
| </div> | </div> | ||||
| </div> | </div> | ||||
| ) | ) | ||||
| })} | })} | ||||
| <div | |||||
| className='mt-1 flex h-8 items-center px-3 text-[13px] font-medium text-primary-600' | |||||
| onClick={() => toggleEntryModal()} | |||||
| > | |||||
| <div className='flex cursor-pointer items-center'> | |||||
| <Plus02 className='mr-2 h-3 w-3' />{t('common.modelProvider.addConfig')} | |||||
| </div> | |||||
| </div> | |||||
| <AddCredentialInLoadBalancing | |||||
| provider={provider} | |||||
| model={model} | |||||
| configurationMethod={configurationMethod} | |||||
| modelCredential={modelCredential} | |||||
| onSelectCredential={addConfigEntry} | |||||
| onUpdate={onUpdate} | |||||
| /> | |||||
| </div> | </div> | ||||
| )} | )} | ||||
| { | { | ||||
| draftConfig.enabled && draftConfig.configs.length < 2 && ( | |||||
| <div className='flex h-[34px] items-center border-t border-t-divider-subtle bg-components-panel-bg px-6 text-xs text-text-secondary'> | |||||
| draftConfig.enabled && validDraftConfigList.length < 2 && ( | |||||
| <div className='flex h-[34px] items-center rounded-b-xl border-t border-t-divider-subtle bg-components-panel-bg px-6 text-xs text-text-secondary'> | |||||
| <AlertTriangle className='mr-1 h-3 w-3 text-[#f79009]' /> | <AlertTriangle className='mr-1 h-3 w-3 text-[#f79009]' /> | ||||
| {t('common.modelProvider.loadBalancingLeastKeyWarning')} | {t('common.modelProvider.loadBalancingLeastKeyWarning')} | ||||
| </div> | </div> |
| import { memo, useCallback, useEffect, useMemo, useState } from 'react' | import { memo, useCallback, useEffect, useMemo, useState } from 'react' | ||||
| import { useTranslation } from 'react-i18next' | import { useTranslation } from 'react-i18next' | ||||
| import useSWR from 'swr' | |||||
| import type { ModelItem, ModelLoadBalancingConfig, ModelLoadBalancingConfigEntry, ModelProvider } from '../declarations' | |||||
| import { FormTypeEnum } from '../declarations' | |||||
| import type { | |||||
| Credential, | |||||
| ModelItem, | |||||
| ModelLoadBalancingConfig, | |||||
| ModelLoadBalancingConfigEntry, | |||||
| ModelProvider, | |||||
| } from '../declarations' | |||||
| import { | |||||
| ConfigurationMethodEnum, | |||||
| FormTypeEnum, | |||||
| } from '../declarations' | |||||
| import ModelIcon from '../model-icon' | import ModelIcon from '../model-icon' | ||||
| import ModelName from '../model-name' | import ModelName from '../model-name' | ||||
| import { savePredefinedLoadBalancingConfig } from '../utils' | |||||
| import ModelLoadBalancingConfigs from './model-load-balancing-configs' | import ModelLoadBalancingConfigs from './model-load-balancing-configs' | ||||
| import classNames from '@/utils/classnames' | import classNames from '@/utils/classnames' | ||||
| import Modal from '@/app/components/base/modal' | import Modal from '@/app/components/base/modal' | ||||
| import Button from '@/app/components/base/button' | import Button from '@/app/components/base/button' | ||||
| import { fetchModelLoadBalancingConfig } from '@/service/common' | |||||
| import Loading from '@/app/components/base/loading' | import Loading from '@/app/components/base/loading' | ||||
| import { useToastContext } from '@/app/components/base/toast' | import { useToastContext } from '@/app/components/base/toast' | ||||
| import { SwitchCredentialInLoadBalancing } from '@/app/components/header/account-setting/model-provider-page/model-auth' | |||||
| import { | |||||
| useGetModelCredential, | |||||
| useUpdateModelLoadBalancingConfig, | |||||
| } from '@/service/use-models' | |||||
| export type ModelLoadBalancingModalProps = { | export type ModelLoadBalancingModalProps = { | ||||
| provider: ModelProvider | provider: ModelProvider | ||||
| configurateMethod: ConfigurationMethodEnum | |||||
| model: ModelItem | model: ModelItem | ||||
| credential?: Credential | |||||
| open?: boolean | open?: boolean | ||||
| onClose?: () => void | onClose?: () => void | ||||
| onSave?: (provider: string) => void | onSave?: (provider: string) => void | ||||
| } | } | ||||
| // model balancing config modal | // model balancing config modal | ||||
| const ModelLoadBalancingModal = ({ provider, model, open = false, onClose, onSave }: ModelLoadBalancingModalProps) => { | |||||
| const ModelLoadBalancingModal = ({ | |||||
| provider, | |||||
| configurateMethod, | |||||
| model, | |||||
| credential, | |||||
| open = false, | |||||
| onClose, | |||||
| onSave, | |||||
| }: ModelLoadBalancingModalProps) => { | |||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| const { notify } = useToastContext() | const { notify } = useToastContext() | ||||
| const [loading, setLoading] = useState(false) | const [loading, setLoading] = useState(false) | ||||
| const { data, mutate } = useSWR( | |||||
| `/workspaces/current/model-providers/${provider.provider}/models/credentials?model=${model.model}&model_type=${model.model_type}`, | |||||
| fetchModelLoadBalancingConfig, | |||||
| ) | |||||
| const originalConfig = data?.load_balancing | |||||
| const providerFormSchemaPredefined = configurateMethod === ConfigurationMethodEnum.predefinedModel | |||||
| const configFrom = providerFormSchemaPredefined ? 'predefined-model' : 'custom-model' | |||||
| const { | |||||
| isLoading, | |||||
| data, | |||||
| refetch, | |||||
| } = useGetModelCredential(true, provider.provider, credential?.credential_id, model.model, model.model_type, configFrom) | |||||
| const modelCredential = data | |||||
| const { | |||||
| load_balancing, | |||||
| current_credential_id, | |||||
| available_credentials, | |||||
| current_credential_name, | |||||
| } = modelCredential ?? {} | |||||
| const originalConfig = load_balancing | |||||
| const [draftConfig, setDraftConfig] = useState<ModelLoadBalancingConfig>() | const [draftConfig, setDraftConfig] = useState<ModelLoadBalancingConfig>() | ||||
| const originalConfigMap = useMemo(() => { | const originalConfigMap = useMemo(() => { | ||||
| if (!originalConfig) | if (!originalConfig) | ||||
| }, [draftConfig]) | }, [draftConfig]) | ||||
| const extendedSecretFormSchemas = useMemo( | const extendedSecretFormSchemas = useMemo( | ||||
| () => provider.provider_credential_schema.credential_form_schemas.filter( | |||||
| ({ type }) => type === FormTypeEnum.secretInput, | |||||
| ), | |||||
| [provider.provider_credential_schema.credential_form_schemas], | |||||
| () => { | |||||
| if (providerFormSchemaPredefined) { | |||||
| return provider?.provider_credential_schema?.credential_form_schemas?.filter( | |||||
| ({ type }) => type === FormTypeEnum.secretInput, | |||||
| ) ?? [] | |||||
| } | |||||
| return provider?.model_credential_schema?.credential_form_schemas?.filter( | |||||
| ({ type }) => type === FormTypeEnum.secretInput, | |||||
| ) ?? [] | |||||
| }, | |||||
| [provider?.model_credential_schema?.credential_form_schemas, provider?.provider_credential_schema?.credential_form_schemas, providerFormSchemaPredefined], | |||||
| ) | ) | ||||
| const encodeConfigEntrySecretValues = useCallback((entry: ModelLoadBalancingConfigEntry) => { | const encodeConfigEntrySecretValues = useCallback((entry: ModelLoadBalancingConfigEntry) => { | ||||
| return result | return result | ||||
| }, [extendedSecretFormSchemas, originalConfigMap]) | }, [extendedSecretFormSchemas, originalConfigMap]) | ||||
| const { mutateAsync: updateModelLoadBalancingConfig } = useUpdateModelLoadBalancingConfig(provider.provider) | |||||
| const initialCustomModelCredential = useMemo(() => { | |||||
| if (!current_credential_id) | |||||
| return undefined | |||||
| return { | |||||
| credential_id: current_credential_id, | |||||
| credential_name: current_credential_name, | |||||
| } | |||||
| }, [current_credential_id, current_credential_name]) | |||||
| const [customModelCredential, setCustomModelCredential] = useState<Credential | undefined>(initialCustomModelCredential) | |||||
| const handleSave = async () => { | const handleSave = async () => { | ||||
| try { | try { | ||||
| setLoading(true) | setLoading(true) | ||||
| const res = await savePredefinedLoadBalancingConfig( | |||||
| provider.provider, | |||||
| ({ | |||||
| ...(data?.credentials ?? {}), | |||||
| __model_type: model.model_type, | |||||
| __model_name: model.model, | |||||
| }), | |||||
| const res = await updateModelLoadBalancingConfig( | |||||
| { | { | ||||
| ...draftConfig, | |||||
| enabled: Boolean(draftConfig?.enabled), | |||||
| configs: draftConfig!.configs.map(encodeConfigEntrySecretValues), | |||||
| credential_id: customModelCredential?.credential_id || current_credential_id, | |||||
| config_from: configFrom, | |||||
| model: model.model, | |||||
| model_type: model.model_type, | |||||
| load_balancing: { | |||||
| ...draftConfig, | |||||
| configs: draftConfig!.configs.map(encodeConfigEntrySecretValues), | |||||
| enabled: Boolean(draftConfig?.enabled), | |||||
| }, | |||||
| }, | }, | ||||
| ) | ) | ||||
| if (res.result === 'success') { | if (res.result === 'success') { | ||||
| notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) | notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) | ||||
| mutate() | |||||
| onSave?.(provider.provider) | onSave?.(provider.provider) | ||||
| onClose?.() | onClose?.() | ||||
| } | } | ||||
| className='w-[640px] max-w-none px-8 pt-8' | className='w-[640px] max-w-none px-8 pt-8' | ||||
| title={ | title={ | ||||
| <div className='pb-3 font-semibold'> | <div className='pb-3 font-semibold'> | ||||
| <div className='h-[30px]'>{t('common.modelProvider.configLoadBalancing')}</div> | |||||
| <div className='h-[30px]'>{ | |||||
| draftConfig?.enabled | |||||
| ? t('common.modelProvider.auth.configLoadBalancing') | |||||
| : t('common.modelProvider.auth.configModel') | |||||
| }</div> | |||||
| {Boolean(model) && ( | {Boolean(model) && ( | ||||
| <div className='flex h-5 items-center'> | <div className='flex h-5 items-center'> | ||||
| <ModelIcon | <ModelIcon | ||||
| )} | )} | ||||
| </div> | </div> | ||||
| <div className='grow'> | <div className='grow'> | ||||
| <div className='text-sm text-text-secondary'>{t('common.modelProvider.providerManaged')}</div> | |||||
| <div className='text-xs text-text-tertiary'>{t('common.modelProvider.providerManagedDescription')}</div> | |||||
| <div className='text-sm text-text-secondary'>{ | |||||
| providerFormSchemaPredefined | |||||
| ? t('common.modelProvider.auth.providerManaged') | |||||
| : t('common.modelProvider.auth.specifyModelCredential') | |||||
| }</div> | |||||
| <div className='text-xs text-text-tertiary'>{ | |||||
| providerFormSchemaPredefined | |||||
| ? t('common.modelProvider.auth.providerManagedTip') | |||||
| : t('common.modelProvider.auth.specifyModelCredentialTip') | |||||
| }</div> | |||||
| </div> | </div> | ||||
| { | |||||
| !providerFormSchemaPredefined && ( | |||||
| <SwitchCredentialInLoadBalancing | |||||
| provider={provider} | |||||
| customModelCredential={initialCustomModelCredential ?? customModelCredential} | |||||
| setCustomModelCredential={setCustomModelCredential} | |||||
| model={model} | |||||
| credentials={available_credentials} | |||||
| /> | |||||
| ) | |||||
| } | |||||
| </div> | </div> | ||||
| </div> | </div> | ||||
| <ModelLoadBalancingConfigs {...{ | |||||
| draftConfig, | |||||
| setDraftConfig, | |||||
| provider, | |||||
| currentCustomConfigurationModelFixedFields: { | |||||
| __model_name: model.model, | |||||
| __model_type: model.model_type, | |||||
| }, | |||||
| configurationMethod: model.fetch_from, | |||||
| className: 'mt-2', | |||||
| }} /> | |||||
| { | |||||
| modelCredential && ( | |||||
| <ModelLoadBalancingConfigs {...{ | |||||
| draftConfig, | |||||
| setDraftConfig, | |||||
| provider, | |||||
| currentCustomConfigurationModelFixedFields: { | |||||
| __model_name: model.model, | |||||
| __model_type: model.model_type, | |||||
| }, | |||||
| configurationMethod: model.fetch_from, | |||||
| className: 'mt-2', | |||||
| modelCredential, | |||||
| onUpdate: refetch, | |||||
| model: { | |||||
| model: model.model, | |||||
| model_type: model.model_type, | |||||
| }, | |||||
| }} /> | |||||
| ) | |||||
| } | |||||
| </div> | </div> | ||||
| <div className='mt-6 flex items-center justify-end gap-2'> | <div className='mt-6 flex items-center justify-end gap-2'> | ||||
| disabled={ | disabled={ | ||||
| loading | loading | ||||
| || (draftConfig?.enabled && (draftConfig?.configs.filter(config => config.enabled).length ?? 0) < 2) | || (draftConfig?.enabled && (draftConfig?.configs.filter(config => config.enabled).length ?? 0) < 2) | ||||
| || isLoading | |||||
| } | } | ||||
| >{t('common.operation.save')}</Button> | >{t('common.operation.save')}</Button> | ||||
| </div> | </div> |
| import { ValidatedStatus } from '../key-validator/declarations' | import { ValidatedStatus } from '../key-validator/declarations' | ||||
| import type { | import type { | ||||
| CredentialFormSchemaRadio, | |||||
| CredentialFormSchemaTextInput, | CredentialFormSchemaTextInput, | ||||
| FormValue, | FormValue, | ||||
| ModelLoadBalancingConfig, | ModelLoadBalancingConfig, | ||||
| let body, url | let body, url | ||||
| if (predefined) { | if (predefined) { | ||||
| const { __authorization_name__, ...rest } = v | |||||
| body = { | body = { | ||||
| config_from: ConfigurationMethodEnum.predefinedModel, | config_from: ConfigurationMethodEnum.predefinedModel, | ||||
| credentials: v, | |||||
| credentials: rest, | |||||
| load_balancing: loadBalancing, | load_balancing: loadBalancing, | ||||
| name: __authorization_name__, | |||||
| } | } | ||||
| url = `/workspaces/current/model-providers/${provider}` | |||||
| url = `/workspaces/current/model-providers/${provider}/credentials` | |||||
| } | } | ||||
| else { | else { | ||||
| const { __model_name, __model_type, ...credentials } = v | const { __model_name, __model_type, ...credentials } = v | ||||
| return setModelProvider({ url, body }) | return setModelProvider({ url, body }) | ||||
| } | } | ||||
| export const removeCredentials = async (predefined: boolean, provider: string, v: FormValue) => { | |||||
| export const removeCredentials = async (predefined: boolean, provider: string, v: FormValue, credentialId?: string) => { | |||||
| let url = '' | let url = '' | ||||
| let body | let body | ||||
| if (predefined) { | if (predefined) { | ||||
| url = `/workspaces/current/model-providers/${provider}` | |||||
| url = `/workspaces/current/model-providers/${provider}/credentials` | |||||
| if (credentialId) { | |||||
| body = { | |||||
| credential_id: credentialId, | |||||
| } | |||||
| } | |||||
| } | } | ||||
| else { | else { | ||||
| if (v) { | if (v) { | ||||
| show_on: [], | show_on: [], | ||||
| } | } | ||||
| }), | }), | ||||
| } as CredentialFormSchemaRadio | |||||
| } as any | |||||
| } | } | ||||
| export const genModelNameFormSchema = (model?: Pick<CredentialFormSchemaTextInput, 'label' | 'placeholder'>) => { | export const genModelNameFormSchema = (model?: Pick<CredentialFormSchemaTextInput, 'label' | 'placeholder'>) => { | ||||
| zh_Hans: '请输入模型名称', | zh_Hans: '请输入模型名称', | ||||
| en_US: 'Please enter model name', | en_US: 'Please enter model name', | ||||
| }, | }, | ||||
| } as CredentialFormSchemaTextInput | |||||
| } as any | |||||
| } | } |
| import AddApiKeyButton from './add-api-key-button' | import AddApiKeyButton from './add-api-key-button' | ||||
| import type { AddApiKeyButtonProps } from './add-api-key-button' | import type { AddApiKeyButtonProps } from './add-api-key-button' | ||||
| import type { PluginPayload } from '../types' | import type { PluginPayload } from '../types' | ||||
| import cn from '@/utils/classnames' | |||||
| import Tooltip from '@/app/components/base/tooltip' | |||||
| type AuthorizeProps = { | type AuthorizeProps = { | ||||
| pluginPayload: PluginPayload | pluginPayload: PluginPayload | ||||
| canApiKey?: boolean | canApiKey?: boolean | ||||
| disabled?: boolean | disabled?: boolean | ||||
| onUpdate?: () => void | onUpdate?: () => void | ||||
| notAllowCustomCredential?: boolean | |||||
| } | } | ||||
| const Authorize = ({ | const Authorize = ({ | ||||
| pluginPayload, | pluginPayload, | ||||
| canApiKey, | canApiKey, | ||||
| disabled, | disabled, | ||||
| onUpdate, | onUpdate, | ||||
| notAllowCustomCredential, | |||||
| }: AuthorizeProps) => { | }: AuthorizeProps) => { | ||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| const oAuthButtonProps: AddOAuthButtonProps = useMemo(() => { | const oAuthButtonProps: AddOAuthButtonProps = useMemo(() => { | ||||
| } | } | ||||
| }, [canOAuth, theme, pluginPayload, t]) | }, [canOAuth, theme, pluginPayload, t]) | ||||
| const OAuthButton = useMemo(() => { | |||||
| const Item = ( | |||||
| <div className={cn('min-w-0 flex-[1]', notAllowCustomCredential && 'opacity-50')}> | |||||
| <AddOAuthButton | |||||
| {...oAuthButtonProps} | |||||
| disabled={disabled || notAllowCustomCredential} | |||||
| onUpdate={onUpdate} | |||||
| /> | |||||
| </div> | |||||
| ) | |||||
| if (notAllowCustomCredential) { | |||||
| return ( | |||||
| <Tooltip popupContent={t('plugin.auth.credentialUnavailable')}> | |||||
| {Item} | |||||
| </Tooltip> | |||||
| ) | |||||
| } | |||||
| return Item | |||||
| }, [notAllowCustomCredential, oAuthButtonProps, disabled, onUpdate, t]) | |||||
| const ApiKeyButton = useMemo(() => { | |||||
| const Item = ( | |||||
| <div className={cn('min-w-0 flex-[1]', notAllowCustomCredential && 'opacity-50')}> | |||||
| <AddApiKeyButton | |||||
| {...apiKeyButtonProps} | |||||
| disabled={disabled || notAllowCustomCredential} | |||||
| onUpdate={onUpdate} | |||||
| /> | |||||
| </div> | |||||
| ) | |||||
| if (notAllowCustomCredential) { | |||||
| return ( | |||||
| <Tooltip popupContent={t('plugin.auth.credentialUnavailable')}> | |||||
| {Item} | |||||
| </Tooltip> | |||||
| ) | |||||
| } | |||||
| return Item | |||||
| }, [notAllowCustomCredential, apiKeyButtonProps, disabled, onUpdate, t]) | |||||
| return ( | return ( | ||||
| <> | <> | ||||
| <div className='flex items-center space-x-1.5'> | <div className='flex items-center space-x-1.5'> | ||||
| { | { | ||||
| canOAuth && ( | canOAuth && ( | ||||
| <div className='min-w-0 flex-[1]'> | |||||
| <AddOAuthButton | |||||
| {...oAuthButtonProps} | |||||
| disabled={disabled} | |||||
| onUpdate={onUpdate} | |||||
| /> | |||||
| </div> | |||||
| OAuthButton | |||||
| ) | ) | ||||
| } | } | ||||
| { | { | ||||
| } | } | ||||
| { | { | ||||
| canApiKey && ( | canApiKey && ( | ||||
| <div className='min-w-0 flex-[1]'> | |||||
| <AddApiKeyButton | |||||
| {...apiKeyButtonProps} | |||||
| disabled={disabled} | |||||
| onUpdate={onUpdate} | |||||
| /> | |||||
| </div> | |||||
| ApiKeyButton | |||||
| ) | ) | ||||
| } | } | ||||
| </div> | </div> |
| credentials, | credentials, | ||||
| disabled, | disabled, | ||||
| invalidPluginCredentialInfo, | invalidPluginCredentialInfo, | ||||
| notAllowCustomCredential, | |||||
| } = usePluginAuth(pluginPayload, isOpen || !!credentialId) | } = usePluginAuth(pluginPayload, isOpen || !!credentialId) | ||||
| const renderTrigger = useCallback((open?: boolean) => { | const renderTrigger = useCallback((open?: boolean) => { | ||||
| let label = '' | let label = '' | ||||
| let removed = false | let removed = false | ||||
| let unavailable = false | |||||
| let color = 'green' | |||||
| if (!credentialId) { | if (!credentialId) { | ||||
| label = t('plugin.auth.workspaceDefault') | label = t('plugin.auth.workspaceDefault') | ||||
| } | } | ||||
| const credential = credentials.find(c => c.id === credentialId) | const credential = credentials.find(c => c.id === credentialId) | ||||
| label = credential ? credential.name : t('plugin.auth.authRemoved') | label = credential ? credential.name : t('plugin.auth.authRemoved') | ||||
| removed = !credential | removed = !credential | ||||
| unavailable = !!credential?.not_allowed_to_use && !credential?.from_enterprise | |||||
| if (removed) | |||||
| color = 'red' | |||||
| else if (unavailable) | |||||
| color = 'gray' | |||||
| } | } | ||||
| return ( | return ( | ||||
| <Button | <Button | ||||
| > | > | ||||
| <Indicator | <Indicator | ||||
| className='mr-1.5' | className='mr-1.5' | ||||
| color={removed ? 'red' : 'green'} | |||||
| color={color as any} | |||||
| /> | /> | ||||
| {label} | {label} | ||||
| { | |||||
| unavailable && t('plugin.auth.unavailable') | |||||
| } | |||||
| <RiArrowDownSLine | <RiArrowDownSLine | ||||
| className={cn( | className={cn( | ||||
| 'h-3.5 w-3.5 text-components-button-ghost-text', | 'h-3.5 w-3.5 text-components-button-ghost-text', | ||||
| showItemSelectedIcon | showItemSelectedIcon | ||||
| selectedCredentialId={credentialId || '__workspace_default__'} | selectedCredentialId={credentialId || '__workspace_default__'} | ||||
| onUpdate={invalidPluginCredentialInfo} | onUpdate={invalidPluginCredentialInfo} | ||||
| notAllowCustomCredential={notAllowCustomCredential} | |||||
| /> | /> | ||||
| ) | ) | ||||
| } | } |
| showItemSelectedIcon?: boolean | showItemSelectedIcon?: boolean | ||||
| selectedCredentialId?: string | selectedCredentialId?: string | ||||
| onUpdate?: () => void | onUpdate?: () => void | ||||
| notAllowCustomCredential?: boolean | |||||
| } | } | ||||
| const Authorized = ({ | const Authorized = ({ | ||||
| pluginPayload, | pluginPayload, | ||||
| showItemSelectedIcon, | showItemSelectedIcon, | ||||
| selectedCredentialId, | selectedCredentialId, | ||||
| onUpdate, | onUpdate, | ||||
| notAllowCustomCredential, | |||||
| }: AuthorizedProps) => { | }: AuthorizedProps) => { | ||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| const { notify } = useToastContext() | const { notify } = useToastContext() | ||||
| handleSetDoingAction(false) | handleSetDoingAction(false) | ||||
| } | } | ||||
| }, [updatePluginCredential, notify, t, handleSetDoingAction, onUpdate]) | }, [updatePluginCredential, notify, t, handleSetDoingAction, onUpdate]) | ||||
| const unavailableCredentials = credentials.filter(credential => credential.not_allowed_to_use) | |||||
| return ( | return ( | ||||
| <> | <> | ||||
| ? t('plugin.auth.authorizations') | ? t('plugin.auth.authorizations') | ||||
| : t('plugin.auth.authorization') | : t('plugin.auth.authorization') | ||||
| } | } | ||||
| { | |||||
| !!unavailableCredentials.length && ( | |||||
| ` (${unavailableCredentials.length} ${t('plugin.auth.unavailable')})` | |||||
| ) | |||||
| } | |||||
| <RiArrowDownSLine className='ml-0.5 h-4 w-4' /> | <RiArrowDownSLine className='ml-0.5 h-4 w-4' /> | ||||
| </Button> | </Button> | ||||
| ) | ) | ||||
| ) | ) | ||||
| } | } | ||||
| </div> | </div> | ||||
| <div className='h-px bg-divider-subtle'></div> | |||||
| <div className='p-2'> | |||||
| <Authorize | |||||
| pluginPayload={pluginPayload} | |||||
| theme='secondary' | |||||
| showDivider={false} | |||||
| canOAuth={canOAuth} | |||||
| canApiKey={canApiKey} | |||||
| disabled={disabled} | |||||
| onUpdate={onUpdate} | |||||
| /> | |||||
| </div> | |||||
| { | |||||
| !notAllowCustomCredential && ( | |||||
| <> | |||||
| <div className='h-[1px] bg-divider-subtle'></div> | |||||
| <div className='p-2'> | |||||
| <Authorize | |||||
| pluginPayload={pluginPayload} | |||||
| theme='secondary' | |||||
| showDivider={false} | |||||
| canOAuth={canOAuth} | |||||
| canApiKey={canApiKey} | |||||
| disabled={disabled} | |||||
| onUpdate={onUpdate} | |||||
| /> | |||||
| </div> | |||||
| </> | |||||
| ) | |||||
| } | |||||
| </div> | </div> | ||||
| </PortalToFollowElemContent> | </PortalToFollowElemContent> | ||||
| </PortalToFollowElem> | </PortalToFollowElem> |
| return !(disableRename && disableEdit && disableDelete && disableSetDefault) | return !(disableRename && disableEdit && disableDelete && disableSetDefault) | ||||
| }, [disableRename, disableEdit, disableDelete, disableSetDefault]) | }, [disableRename, disableEdit, disableDelete, disableSetDefault]) | ||||
| return ( | |||||
| const CredentialItem = ( | |||||
| <div | <div | ||||
| key={credential.id} | key={credential.id} | ||||
| className={cn( | className={cn( | ||||
| 'group flex h-8 items-center rounded-lg p-1 hover:bg-state-base-hover', | 'group flex h-8 items-center rounded-lg p-1 hover:bg-state-base-hover', | ||||
| renaming && 'bg-state-base-hover', | renaming && 'bg-state-base-hover', | ||||
| (disabled || credential.not_allowed_to_use) && 'cursor-not-allowed opacity-50', | |||||
| )} | )} | ||||
| onClick={() => onItemClick?.(credential.id === '__workspace_default__' ? '' : credential.id)} | |||||
| onClick={() => { | |||||
| if (credential.not_allowed_to_use || disabled) | |||||
| return | |||||
| onItemClick?.(credential.id === '__workspace_default__' ? '' : credential.id) | |||||
| }} | |||||
| > | > | ||||
| { | { | ||||
| renaming && ( | renaming && ( | ||||
| </div> | </div> | ||||
| ) | ) | ||||
| } | } | ||||
| <Indicator className='ml-2 mr-1.5 shrink-0' /> | |||||
| <Indicator | |||||
| className='ml-2 mr-1.5 shrink-0' | |||||
| color={credential.not_allowed_to_use ? 'gray' : 'green'} | |||||
| /> | |||||
| <div | <div | ||||
| className='system-md-regular truncate text-text-secondary' | className='system-md-regular truncate text-text-secondary' | ||||
| title={credential.name} | title={credential.name} | ||||
| </div> | </div> | ||||
| ) | ) | ||||
| } | } | ||||
| { | |||||
| credential.from_enterprise && ( | |||||
| <Badge className='shrink-0'> | |||||
| Enterprise | |||||
| </Badge> | |||||
| ) | |||||
| } | |||||
| { | { | ||||
| showAction && !renaming && ( | showAction && !renaming && ( | ||||
| <div className='ml-2 hidden shrink-0 items-center group-hover:flex'> | <div className='ml-2 hidden shrink-0 items-center group-hover:flex'> | ||||
| { | { | ||||
| !credential.is_default && !disableSetDefault && ( | |||||
| !credential.is_default && !disableSetDefault && !credential.not_allowed_to_use && ( | |||||
| <Button | <Button | ||||
| size='small' | size='small' | ||||
| disabled={disabled} | disabled={disabled} | ||||
| ) | ) | ||||
| } | } | ||||
| { | { | ||||
| !disableRename && ( | |||||
| !disableRename && !credential.from_enterprise && !credential.not_allowed_to_use && ( | |||||
| <Tooltip popupContent={t('common.operation.rename')}> | <Tooltip popupContent={t('common.operation.rename')}> | ||||
| <ActionButton | <ActionButton | ||||
| disabled={disabled} | disabled={disabled} | ||||
| ) | ) | ||||
| } | } | ||||
| { | { | ||||
| !isOAuth && !disableEdit && ( | |||||
| !isOAuth && !disableEdit && !credential.from_enterprise && !credential.not_allowed_to_use && ( | |||||
| <Tooltip popupContent={t('common.operation.edit')}> | <Tooltip popupContent={t('common.operation.edit')}> | ||||
| <ActionButton | <ActionButton | ||||
| disabled={disabled} | disabled={disabled} | ||||
| ) | ) | ||||
| } | } | ||||
| { | { | ||||
| !disableDelete && ( | |||||
| !disableDelete && !credential.from_enterprise && ( | |||||
| <Tooltip popupContent={t('common.operation.delete')}> | <Tooltip popupContent={t('common.operation.delete')}> | ||||
| <ActionButton | <ActionButton | ||||
| className='hover:bg-transparent' | className='hover:bg-transparent' | ||||
| } | } | ||||
| </div> | </div> | ||||
| ) | ) | ||||
| if (credential.not_allowed_to_use) { | |||||
| return ( | |||||
| <Tooltip popupContent={t('plugin.auth.customCredentialUnavailable')}> | |||||
| {CredentialItem} | |||||
| </Tooltip> | |||||
| ) | |||||
| } | |||||
| return ( | |||||
| CredentialItem | |||||
| ) | |||||
| } | } | ||||
| export default memo(Item) | export default memo(Item) |
| import { | |||||
| useCallback, | |||||
| useRef, | |||||
| useState, | |||||
| } from 'react' | |||||
| import { useTranslation } from 'react-i18next' | |||||
| import { useToastContext } from '@/app/components/base/toast' | |||||
| import type { PluginPayload } from '@/app/components/plugins/plugin-auth/types' | |||||
| import { | |||||
| useDeletePluginCredentialHook, | |||||
| useSetPluginDefaultCredentialHook, | |||||
| useUpdatePluginCredentialHook, | |||||
| } from '../hooks/use-credential' | |||||
| export const usePluginAuthAction = ( | |||||
| pluginPayload: PluginPayload, | |||||
| onUpdate?: () => void, | |||||
| ) => { | |||||
| const { t } = useTranslation() | |||||
| const { notify } = useToastContext() | |||||
| const pendingOperationCredentialId = useRef<string | null>(null) | |||||
| const [deleteCredentialId, setDeleteCredentialId] = useState<string | null>(null) | |||||
| const { mutateAsync: deletePluginCredential } = useDeletePluginCredentialHook(pluginPayload) | |||||
| const openConfirm = useCallback((credentialId?: string) => { | |||||
| if (credentialId) | |||||
| pendingOperationCredentialId.current = credentialId | |||||
| setDeleteCredentialId(pendingOperationCredentialId.current) | |||||
| }, []) | |||||
| const closeConfirm = useCallback(() => { | |||||
| setDeleteCredentialId(null) | |||||
| pendingOperationCredentialId.current = null | |||||
| }, []) | |||||
| const [doingAction, setDoingAction] = useState(false) | |||||
| const doingActionRef = useRef(doingAction) | |||||
| const handleSetDoingAction = useCallback((doing: boolean) => { | |||||
| doingActionRef.current = doing | |||||
| setDoingAction(doing) | |||||
| }, []) | |||||
| const [editValues, setEditValues] = useState<Record<string, any> | null>(null) | |||||
| const handleConfirm = useCallback(async () => { | |||||
| if (doingActionRef.current) | |||||
| return | |||||
| if (!pendingOperationCredentialId.current) { | |||||
| setDeleteCredentialId(null) | |||||
| return | |||||
| } | |||||
| try { | |||||
| handleSetDoingAction(true) | |||||
| await deletePluginCredential({ credential_id: pendingOperationCredentialId.current }) | |||||
| notify({ | |||||
| type: 'success', | |||||
| message: t('common.api.actionSuccess'), | |||||
| }) | |||||
| onUpdate?.() | |||||
| setDeleteCredentialId(null) | |||||
| pendingOperationCredentialId.current = null | |||||
| setEditValues(null) | |||||
| } | |||||
| finally { | |||||
| handleSetDoingAction(false) | |||||
| } | |||||
| }, [deletePluginCredential, onUpdate, notify, t, handleSetDoingAction]) | |||||
| const handleEdit = useCallback((id: string, values: Record<string, any>) => { | |||||
| pendingOperationCredentialId.current = id | |||||
| setEditValues(values) | |||||
| }, []) | |||||
| const handleRemove = useCallback(() => { | |||||
| setDeleteCredentialId(pendingOperationCredentialId.current) | |||||
| }, []) | |||||
| const { mutateAsync: setPluginDefaultCredential } = useSetPluginDefaultCredentialHook(pluginPayload) | |||||
| const handleSetDefault = useCallback(async (id: string) => { | |||||
| if (doingActionRef.current) | |||||
| return | |||||
| try { | |||||
| handleSetDoingAction(true) | |||||
| await setPluginDefaultCredential(id) | |||||
| notify({ | |||||
| type: 'success', | |||||
| message: t('common.api.actionSuccess'), | |||||
| }) | |||||
| onUpdate?.() | |||||
| } | |||||
| finally { | |||||
| handleSetDoingAction(false) | |||||
| } | |||||
| }, [setPluginDefaultCredential, onUpdate, notify, t, handleSetDoingAction]) | |||||
| const { mutateAsync: updatePluginCredential } = useUpdatePluginCredentialHook(pluginPayload) | |||||
| const handleRename = useCallback(async (payload: { | |||||
| credential_id: string | |||||
| name: string | |||||
| }) => { | |||||
| if (doingActionRef.current) | |||||
| return | |||||
| try { | |||||
| handleSetDoingAction(true) | |||||
| await updatePluginCredential(payload) | |||||
| notify({ | |||||
| type: 'success', | |||||
| message: t('common.api.actionSuccess'), | |||||
| }) | |||||
| onUpdate?.() | |||||
| } | |||||
| finally { | |||||
| handleSetDoingAction(false) | |||||
| } | |||||
| }, [updatePluginCredential, notify, t, handleSetDoingAction, onUpdate]) | |||||
| return { | |||||
| doingAction, | |||||
| handleSetDoingAction, | |||||
| openConfirm, | |||||
| closeConfirm, | |||||
| deleteCredentialId, | |||||
| setDeleteCredentialId, | |||||
| handleConfirm, | |||||
| editValues, | |||||
| setEditValues, | |||||
| handleEdit, | |||||
| handleRemove, | |||||
| handleSetDefault, | |||||
| handleRename, | |||||
| pendingOperationCredentialId, | |||||
| } | |||||
| } |
| canApiKey, | canApiKey, | ||||
| credentials: data?.credentials || [], | credentials: data?.credentials || [], | ||||
| disabled: !isCurrentWorkspaceManager, | disabled: !isCurrentWorkspaceManager, | ||||
| notAllowCustomCredential: data?.allow_custom_token === false, | |||||
| invalidPluginCredentialInfo, | invalidPluginCredentialInfo, | ||||
| } | } | ||||
| } | } |
| credentials, | credentials, | ||||
| disabled, | disabled, | ||||
| invalidPluginCredentialInfo, | invalidPluginCredentialInfo, | ||||
| notAllowCustomCredential, | |||||
| } = usePluginAuth(pluginPayload, true) | } = usePluginAuth(pluginPayload, true) | ||||
| const extraAuthorizationItems: Credential[] = [ | const extraAuthorizationItems: Credential[] = [ | ||||
| const renderTrigger = useCallback((isOpen?: boolean) => { | const renderTrigger = useCallback((isOpen?: boolean) => { | ||||
| let label = '' | let label = '' | ||||
| let removed = false | let removed = false | ||||
| let unavailable = false | |||||
| let color = 'green' | |||||
| if (!credentialId) { | if (!credentialId) { | ||||
| label = t('plugin.auth.workspaceDefault') | label = t('plugin.auth.workspaceDefault') | ||||
| } | } | ||||
| const credential = credentials.find(c => c.id === credentialId) | const credential = credentials.find(c => c.id === credentialId) | ||||
| label = credential ? credential.name : t('plugin.auth.authRemoved') | label = credential ? credential.name : t('plugin.auth.authRemoved') | ||||
| removed = !credential | removed = !credential | ||||
| unavailable = !!credential?.not_allowed_to_use && !credential?.from_enterprise | |||||
| if (removed) | |||||
| color = 'red' | |||||
| else if (unavailable) | |||||
| color = 'gray' | |||||
| } | } | ||||
| return ( | return ( | ||||
| <Button | <Button | ||||
| )}> | )}> | ||||
| <Indicator | <Indicator | ||||
| className='mr-2' | className='mr-2' | ||||
| color={removed ? 'red' : 'green'} | |||||
| color={color as any} | |||||
| /> | /> | ||||
| {label} | {label} | ||||
| { | |||||
| unavailable && t('plugin.auth.unavailable') | |||||
| } | |||||
| <RiArrowDownSLine className='ml-0.5 h-4 w-4' /> | <RiArrowDownSLine className='ml-0.5 h-4 w-4' /> | ||||
| </Button> | </Button> | ||||
| ) | ) | ||||
| canApiKey={canApiKey} | canApiKey={canApiKey} | ||||
| disabled={disabled} | disabled={disabled} | ||||
| onUpdate={invalidPluginCredentialInfo} | onUpdate={invalidPluginCredentialInfo} | ||||
| notAllowCustomCredential={notAllowCustomCredential} | |||||
| /> | /> | ||||
| ) | ) | ||||
| } | } | ||||
| onOpenChange={setIsOpen} | onOpenChange={setIsOpen} | ||||
| selectedCredentialId={credentialId || '__workspace_default__'} | selectedCredentialId={credentialId || '__workspace_default__'} | ||||
| onUpdate={invalidPluginCredentialInfo} | onUpdate={invalidPluginCredentialInfo} | ||||
| notAllowCustomCredential={notAllowCustomCredential} | |||||
| /> | /> | ||||
| ) | ) | ||||
| } | } |
| credentials, | credentials, | ||||
| disabled, | disabled, | ||||
| invalidPluginCredentialInfo, | invalidPluginCredentialInfo, | ||||
| notAllowCustomCredential, | |||||
| } = usePluginAuth(pluginPayload, !!pluginPayload.provider) | } = usePluginAuth(pluginPayload, !!pluginPayload.provider) | ||||
| return ( | return ( | ||||
| canApiKey={canApiKey} | canApiKey={canApiKey} | ||||
| disabled={disabled} | disabled={disabled} | ||||
| onUpdate={invalidPluginCredentialInfo} | onUpdate={invalidPluginCredentialInfo} | ||||
| notAllowCustomCredential={notAllowCustomCredential} | |||||
| /> | /> | ||||
| ) | ) | ||||
| } | } | ||||
| canApiKey={canApiKey} | canApiKey={canApiKey} | ||||
| disabled={disabled} | disabled={disabled} | ||||
| onUpdate={invalidPluginCredentialInfo} | onUpdate={invalidPluginCredentialInfo} | ||||
| notAllowCustomCredential={notAllowCustomCredential} | |||||
| /> | /> | ||||
| ) | ) | ||||
| } | } |
| is_default: boolean | is_default: boolean | ||||
| credentials?: Record<string, any> | credentials?: Record<string, any> | ||||
| isWorkspaceDefault?: boolean | isWorkspaceDefault?: boolean | ||||
| from_enterprise?: boolean | |||||
| not_allowed_to_use?: boolean | |||||
| } | } |
| import { useRouter, useSearchParams } from 'next/navigation' | import { useRouter, useSearchParams } from 'next/navigation' | ||||
| import type { | import type { | ||||
| ConfigurationMethodEnum, | ConfigurationMethodEnum, | ||||
| Credential, | |||||
| CustomConfigurationModelFixedFields, | CustomConfigurationModelFixedFields, | ||||
| CustomModel, | |||||
| ModelLoadBalancingConfigEntry, | ModelLoadBalancingConfigEntry, | ||||
| ModelProvider, | ModelProvider, | ||||
| } from '@/app/components/header/account-setting/model-provider-page/declarations' | } from '@/app/components/header/account-setting/model-provider-page/declarations' | ||||
| const ModelLoadBalancingModal = dynamic(() => import('@/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal'), { | const ModelLoadBalancingModal = dynamic(() => import('@/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal'), { | ||||
| ssr: false, | ssr: false, | ||||
| }) | }) | ||||
| const ModelLoadBalancingEntryModal = dynamic(() => import('@/app/components/header/account-setting/model-provider-page/model-modal/model-load-balancing-entry-modal'), { | |||||
| ssr: false, | |||||
| }) | |||||
| const OpeningSettingModal = dynamic(() => import('@/app/components/base/features/new-feature-panel/conversation-opener/modal'), { | const OpeningSettingModal = dynamic(() => import('@/app/components/base/features/new-feature-panel/conversation-opener/modal'), { | ||||
| ssr: false, | ssr: false, | ||||
| }) | }) | ||||
| currentProvider: ModelProvider | currentProvider: ModelProvider | ||||
| currentConfigurationMethod: ConfigurationMethodEnum | currentConfigurationMethod: ConfigurationMethodEnum | ||||
| currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields | currentCustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields | ||||
| isModelCredential?: boolean | |||||
| credential?: Credential | |||||
| model?: CustomModel | |||||
| } | } | ||||
| export type LoadBalancingEntryModalType = ModelModalType & { | export type LoadBalancingEntryModalType = ModelModalType & { | ||||
| entry?: ModelLoadBalancingConfigEntry | entry?: ModelLoadBalancingConfigEntry | ||||
| setShowModelModal: Dispatch<SetStateAction<ModalState<ModelModalType> | null>> | setShowModelModal: Dispatch<SetStateAction<ModalState<ModelModalType> | null>> | ||||
| setShowExternalKnowledgeAPIModal: Dispatch<SetStateAction<ModalState<CreateExternalAPIReq> | null>> | setShowExternalKnowledgeAPIModal: Dispatch<SetStateAction<ModalState<CreateExternalAPIReq> | null>> | ||||
| setShowModelLoadBalancingModal: Dispatch<SetStateAction<ModelLoadBalancingModalProps | null>> | setShowModelLoadBalancingModal: Dispatch<SetStateAction<ModelLoadBalancingModalProps | null>> | ||||
| setShowModelLoadBalancingEntryModal: Dispatch<SetStateAction<ModalState<LoadBalancingEntryModalType> | null>> | |||||
| setShowOpeningModal: Dispatch<SetStateAction<ModalState<OpeningStatement & { | setShowOpeningModal: Dispatch<SetStateAction<ModalState<OpeningStatement & { | ||||
| promptVariables?: PromptVariable[] | promptVariables?: PromptVariable[] | ||||
| workflowVariables?: InputVar[] | workflowVariables?: InputVar[] | ||||
| setShowModelModal: noop, | setShowModelModal: noop, | ||||
| setShowExternalKnowledgeAPIModal: noop, | setShowExternalKnowledgeAPIModal: noop, | ||||
| setShowModelLoadBalancingModal: noop, | setShowModelLoadBalancingModal: noop, | ||||
| setShowModelLoadBalancingEntryModal: noop, | |||||
| setShowOpeningModal: noop, | setShowOpeningModal: noop, | ||||
| setShowUpdatePluginModal: noop, | setShowUpdatePluginModal: noop, | ||||
| setShowEducationExpireNoticeModal: noop, | setShowEducationExpireNoticeModal: noop, | ||||
| const [showModelModal, setShowModelModal] = useState<ModalState<ModelModalType> | null>(null) | const [showModelModal, setShowModelModal] = useState<ModalState<ModelModalType> | null>(null) | ||||
| const [showExternalKnowledgeAPIModal, setShowExternalKnowledgeAPIModal] = useState<ModalState<CreateExternalAPIReq> | null>(null) | const [showExternalKnowledgeAPIModal, setShowExternalKnowledgeAPIModal] = useState<ModalState<CreateExternalAPIReq> | null>(null) | ||||
| const [showModelLoadBalancingModal, setShowModelLoadBalancingModal] = useState<ModelLoadBalancingModalProps | null>(null) | const [showModelLoadBalancingModal, setShowModelLoadBalancingModal] = useState<ModelLoadBalancingModalProps | null>(null) | ||||
| const [showModelLoadBalancingEntryModal, setShowModelLoadBalancingEntryModal] = useState<ModalState<LoadBalancingEntryModalType> | null>(null) | |||||
| const [showOpeningModal, setShowOpeningModal] = useState<ModalState<OpeningStatement & { | const [showOpeningModal, setShowOpeningModal] = useState<ModalState<OpeningStatement & { | ||||
| promptVariables?: PromptVariable[] | promptVariables?: PromptVariable[] | ||||
| workflowVariables?: InputVar[] | workflowVariables?: InputVar[] | ||||
| setShowExternalKnowledgeAPIModal(null) | setShowExternalKnowledgeAPIModal(null) | ||||
| }, [showExternalKnowledgeAPIModal]) | }, [showExternalKnowledgeAPIModal]) | ||||
| const handleCancelModelLoadBalancingEntryModal = useCallback(() => { | |||||
| showModelLoadBalancingEntryModal?.onCancelCallback?.() | |||||
| setShowModelLoadBalancingEntryModal(null) | |||||
| }, [showModelLoadBalancingEntryModal]) | |||||
| const handleCancelOpeningModal = useCallback(() => { | const handleCancelOpeningModal = useCallback(() => { | ||||
| setShowOpeningModal(null) | setShowOpeningModal(null) | ||||
| if (showOpeningModal?.onCancelCallback) | if (showOpeningModal?.onCancelCallback) | ||||
| showOpeningModal.onCancelCallback() | showOpeningModal.onCancelCallback() | ||||
| }, [showOpeningModal]) | }, [showOpeningModal]) | ||||
| const handleSaveModelLoadBalancingEntryModal = useCallback((entry: ModelLoadBalancingConfigEntry) => { | |||||
| showModelLoadBalancingEntryModal?.onSaveCallback?.({ | |||||
| ...showModelLoadBalancingEntryModal.payload, | |||||
| entry, | |||||
| }) | |||||
| setShowModelLoadBalancingEntryModal(null) | |||||
| }, [showModelLoadBalancingEntryModal]) | |||||
| const handleRemoveModelLoadBalancingEntry = useCallback(() => { | |||||
| showModelLoadBalancingEntryModal?.onRemoveCallback?.(showModelLoadBalancingEntryModal.payload) | |||||
| setShowModelLoadBalancingEntryModal(null) | |||||
| }, [showModelLoadBalancingEntryModal]) | |||||
| const handleSaveApiBasedExtension = (newApiBasedExtension: ApiBasedExtension) => { | const handleSaveApiBasedExtension = (newApiBasedExtension: ApiBasedExtension) => { | ||||
| if (showApiBasedExtensionModal?.onSaveCallback) | if (showApiBasedExtensionModal?.onSaveCallback) | ||||
| showApiBasedExtensionModal.onSaveCallback(newApiBasedExtension) | showApiBasedExtensionModal.onSaveCallback(newApiBasedExtension) | ||||
| setShowModelModal, | setShowModelModal, | ||||
| setShowExternalKnowledgeAPIModal, | setShowExternalKnowledgeAPIModal, | ||||
| setShowModelLoadBalancingModal, | setShowModelLoadBalancingModal, | ||||
| setShowModelLoadBalancingEntryModal, | |||||
| setShowOpeningModal, | setShowOpeningModal, | ||||
| setShowUpdatePluginModal, | setShowUpdatePluginModal, | ||||
| setShowEducationExpireNoticeModal, | setShowEducationExpireNoticeModal, | ||||
| provider={showModelModal.payload.currentProvider} | provider={showModelModal.payload.currentProvider} | ||||
| configurateMethod={showModelModal.payload.currentConfigurationMethod} | configurateMethod={showModelModal.payload.currentConfigurationMethod} | ||||
| currentCustomConfigurationModelFixedFields={showModelModal.payload.currentCustomConfigurationModelFixedFields} | currentCustomConfigurationModelFixedFields={showModelModal.payload.currentCustomConfigurationModelFixedFields} | ||||
| isModelCredential={showModelModal.payload.isModelCredential} | |||||
| credential={showModelModal.payload.credential} | |||||
| model={showModelModal.payload.model} | |||||
| onCancel={handleCancelModelModal} | onCancel={handleCancelModelModal} | ||||
| onSave={handleSaveModelModal} | onSave={handleSaveModelModal} | ||||
| /> | /> | ||||
| <ModelLoadBalancingModal {...showModelLoadBalancingModal!} /> | <ModelLoadBalancingModal {...showModelLoadBalancingModal!} /> | ||||
| ) | ) | ||||
| } | } | ||||
| { | |||||
| !!showModelLoadBalancingEntryModal && ( | |||||
| <ModelLoadBalancingEntryModal | |||||
| provider={showModelLoadBalancingEntryModal.payload.currentProvider} | |||||
| configurationMethod={showModelLoadBalancingEntryModal.payload.currentConfigurationMethod} | |||||
| currentCustomConfigurationModelFixedFields={showModelLoadBalancingEntryModal.payload.currentCustomConfigurationModelFixedFields} | |||||
| entry={showModelLoadBalancingEntryModal.payload.entry} | |||||
| onCancel={handleCancelModelLoadBalancingEntryModal} | |||||
| onSave={handleSaveModelLoadBalancingEntryModal} | |||||
| onRemove={handleRemoveModelLoadBalancingEntry} | |||||
| /> | |||||
| ) | |||||
| } | |||||
| {showOpeningModal && ( | {showOpeningModal && ( | ||||
| <OpeningSettingModal | <OpeningSettingModal | ||||
| data={showOpeningModal.payload} | data={showOpeningModal.payload} |
| deleteApp: 'Delete App', | deleteApp: 'Delete App', | ||||
| settings: 'Settings', | settings: 'Settings', | ||||
| setup: 'Setup', | setup: 'Setup', | ||||
| config: 'Config', | |||||
| getForFree: 'Get for free', | getForFree: 'Get for free', | ||||
| reload: 'Reload', | reload: 'Reload', | ||||
| ok: 'OK', | ok: 'OK', | ||||
| loadPresets: 'Load Presets', | loadPresets: 'Load Presets', | ||||
| parameters: 'PARAMETERS', | parameters: 'PARAMETERS', | ||||
| loadBalancing: 'Load balancing', | loadBalancing: 'Load balancing', | ||||
| loadBalancingDescription: 'Reduce pressure with multiple sets of credentials.', | |||||
| loadBalancingDescription: 'Configure multiple credentials for the model and invoke them automatically. ', | |||||
| loadBalancingHeadline: 'Load Balancing', | loadBalancingHeadline: 'Load Balancing', | ||||
| configLoadBalancing: 'Config Load Balancing', | configLoadBalancing: 'Config Load Balancing', | ||||
| modelHasBeenDeprecated: 'This model has been deprecated', | modelHasBeenDeprecated: 'This model has been deprecated', | ||||
| discoverMore: 'Discover more in ', | discoverMore: 'Discover more in ', | ||||
| emptyProviderTitle: 'Model provider not set up', | emptyProviderTitle: 'Model provider not set up', | ||||
| emptyProviderTip: 'Please install a model provider first.', | emptyProviderTip: 'Please install a model provider first.', | ||||
| auth: { | |||||
| unAuthorized: 'Unauthorized', | |||||
| authRemoved: 'Auth removed', | |||||
| apiKeys: 'API Keys', | |||||
| addApiKey: 'Add API Key', | |||||
| addNewModel: 'Add new model', | |||||
| addCredential: 'Add credential', | |||||
| addModelCredential: 'Add model credential', | |||||
| modelCredentials: 'Model credentials', | |||||
| configModel: 'Config model', | |||||
| configLoadBalancing: 'Config Load Balancing', | |||||
| authorizationError: 'Authorization error', | |||||
| specifyModelCredential: 'Specify model credential', | |||||
| specifyModelCredentialTip: 'Use a configured model credential.', | |||||
| providerManaged: 'Provider managed', | |||||
| providerManagedTip: 'The current configuration is hosted by the provider.', | |||||
| apiKeyModal: { | |||||
| title: 'API Key Authorization Configuration', | |||||
| desc: 'After configuring credentials, all members within the workspace can use this model when orchestrating applications.', | |||||
| addModel: 'Add model', | |||||
| }, | |||||
| }, | |||||
| }, | }, | ||||
| dataSource: { | dataSource: { | ||||
| add: 'Add a data source', | add: 'Add a data source', |
| authRemoved: 'Auth removed', | authRemoved: 'Auth removed', | ||||
| clientInfo: 'As no system client secrets found for this tool provider, setup it manually is required, for redirect_uri, please use', | clientInfo: 'As no system client secrets found for this tool provider, setup it manually is required, for redirect_uri, please use', | ||||
| oauthClient: 'OAuth Client', | oauthClient: 'OAuth Client', | ||||
| credentialUnavailable: 'Credentials currently unavailable. Please contact admin.', | |||||
| customCredentialUnavailable: 'Custom credentials currently unavailable', | |||||
| unavailable: 'Unavailable', | |||||
| }, | }, | ||||
| } | } | ||||
| deleteApp: '删除应用', | deleteApp: '删除应用', | ||||
| settings: '设置', | settings: '设置', | ||||
| setup: '设置', | setup: '设置', | ||||
| config: '配置', | |||||
| getForFree: '免费获取', | getForFree: '免费获取', | ||||
| reload: '刷新', | reload: '刷新', | ||||
| ok: '好的', | ok: '好的', | ||||
| loadPresets: '加载预设', | loadPresets: '加载预设', | ||||
| parameters: '参数', | parameters: '参数', | ||||
| loadBalancing: '负载均衡', | loadBalancing: '负载均衡', | ||||
| loadBalancingDescription: '为了减轻单组凭据的压力,您可以为模型调用配置多组凭据。', | |||||
| loadBalancingDescription: '为模型配置多组凭据,并自动调用。', | |||||
| loadBalancingHeadline: '负载均衡', | loadBalancingHeadline: '负载均衡', | ||||
| configLoadBalancing: '设置负载均衡', | configLoadBalancing: '设置负载均衡', | ||||
| modelHasBeenDeprecated: '该模型已废弃', | modelHasBeenDeprecated: '该模型已废弃', | ||||
| discoverMore: '发现更多就在', | discoverMore: '发现更多就在', | ||||
| emptyProviderTitle: '尚未安装模型供应商', | emptyProviderTitle: '尚未安装模型供应商', | ||||
| emptyProviderTip: '请安装模型供应商。', | emptyProviderTip: '请安装模型供应商。', | ||||
| auth: { | |||||
| unAuthorized: '未授权', | |||||
| authRemoved: '授权已移除', | |||||
| apiKeys: 'API 密钥', | |||||
| addApiKey: '添加 API 密钥', | |||||
| addNewModel: '添加新模型', | |||||
| addCredential: '添加凭据', | |||||
| addModelCredential: '添加模型凭据', | |||||
| modelCredentials: '模型凭据', | |||||
| configModel: '配置模型', | |||||
| configLoadBalancing: '配置负载均衡', | |||||
| authorizationError: '授权错误', | |||||
| specifyModelCredential: '指定模型凭据', | |||||
| specifyModelCredentialTip: '使用已配置的模型凭据。', | |||||
| providerManaged: '由模型供应商管理', | |||||
| providerManagedTip: '使用模型供应商提供的单组凭据。', | |||||
| apiKeyModal: { | |||||
| title: 'API 密钥授权配置', | |||||
| desc: '配置凭据后,工作空间中的所有成员都可以在编排应用时使用此模型。', | |||||
| addModel: '添加模型', | |||||
| }, | |||||
| }, | |||||
| }, | }, | ||||
| dataSource: { | dataSource: { | ||||
| add: '添加数据源', | add: '添加数据源', |
| authRemoved: '凭据已移除', | authRemoved: '凭据已移除', | ||||
| clientInfo: '由于未找到此工具提供者的系统客户端密钥,因此需要手动设置,对于 redirect_uri,请使用', | clientInfo: '由于未找到此工具提供者的系统客户端密钥,因此需要手动设置,对于 redirect_uri,请使用', | ||||
| oauthClient: 'OAuth 客户端', | oauthClient: 'OAuth 客户端', | ||||
| credentialUnavailable: '自定义凭据当前不可用,请联系管理员。', | |||||
| customCredentialUnavailable: '自定义凭据当前不可用', | |||||
| unavailable: '不可用', | |||||
| }, | }, | ||||
| } | } | ||||
| import { get } from './base' | |||||
| import { | |||||
| del, | |||||
| get, | |||||
| post, | |||||
| put, | |||||
| } from './base' | |||||
| import type { | import type { | ||||
| ModelCredential, | |||||
| ModelItem, | ModelItem, | ||||
| ModelLoadBalancingConfig, | |||||
| ModelTypeEnum, | |||||
| ProviderCredential, | |||||
| } from '@/app/components/header/account-setting/model-provider-page/declarations' | } from '@/app/components/header/account-setting/model-provider-page/declarations' | ||||
| import { | import { | ||||
| useMutation, | |||||
| useQuery, | useQuery, | ||||
| // useQueryClient, | // useQueryClient, | ||||
| } from '@tanstack/react-query' | } from '@tanstack/react-query' | ||||
| queryFn: () => get<{ data: ModelItem[] }>(`/workspaces/current/model-providers/${provider}/models`), | queryFn: () => get<{ data: ModelItem[] }>(`/workspaces/current/model-providers/${provider}/models`), | ||||
| }) | }) | ||||
| } | } | ||||
| export const useGetProviderCredential = (enabled: boolean, provider: string, credentialId?: string) => { | |||||
| return useQuery({ | |||||
| enabled, | |||||
| queryKey: [NAME_SPACE, 'model-list', provider, credentialId], | |||||
| queryFn: () => get<ProviderCredential>(`/workspaces/current/model-providers/${provider}/credentials${credentialId ? `?credential_id=${credentialId}` : ''}`), | |||||
| }) | |||||
| } | |||||
| export const useAddProviderCredential = (provider: string) => { | |||||
| return useMutation({ | |||||
| mutationFn: (data: ProviderCredential) => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/credentials`, { | |||||
| body: data, | |||||
| }), | |||||
| }) | |||||
| } | |||||
| export const useEditProviderCredential = (provider: string) => { | |||||
| return useMutation({ | |||||
| mutationFn: (data: ProviderCredential) => put<{ result: string }>(`/workspaces/current/model-providers/${provider}/credentials`, { | |||||
| body: data, | |||||
| }), | |||||
| }) | |||||
| } | |||||
| export const useDeleteProviderCredential = (provider: string) => { | |||||
| return useMutation({ | |||||
| mutationFn: (data: { | |||||
| credential_id: string | |||||
| }) => del<{ result: string }>(`/workspaces/current/model-providers/${provider}/credentials`, { | |||||
| body: data, | |||||
| }), | |||||
| }) | |||||
| } | |||||
| export const useActiveProviderCredential = (provider: string) => { | |||||
| return useMutation({ | |||||
| mutationFn: (data: { | |||||
| credential_id: string | |||||
| model?: string | |||||
| model_type?: ModelTypeEnum | |||||
| }) => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/credentials/switch`, { | |||||
| body: data, | |||||
| }), | |||||
| }) | |||||
| } | |||||
| export const useGetModelCredential = ( | |||||
| enabled: boolean, | |||||
| provider: string, | |||||
| credentialId?: string, | |||||
| model?: string, | |||||
| modelType?: string, | |||||
| configFrom?: string, | |||||
| ) => { | |||||
| return useQuery({ | |||||
| enabled, | |||||
| queryKey: [NAME_SPACE, 'model-list', provider, model, modelType, credentialId], | |||||
| queryFn: () => get<ModelCredential>(`/workspaces/current/model-providers/${provider}/models/credentials?model=${model}&model_type=${modelType}&config_from=${configFrom}${credentialId ? `&credential_id=${credentialId}` : ''}`), | |||||
| staleTime: 0, | |||||
| gcTime: 0, | |||||
| }) | |||||
| } | |||||
| export const useAddModelCredential = (provider: string) => { | |||||
| return useMutation({ | |||||
| mutationFn: (data: ModelCredential) => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials`, { | |||||
| body: data, | |||||
| }), | |||||
| }) | |||||
| } | |||||
| export const useEditModelCredential = (provider: string) => { | |||||
| return useMutation({ | |||||
| mutationFn: (data: ModelCredential) => put<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials`, { | |||||
| body: data, | |||||
| }), | |||||
| }) | |||||
| } | |||||
| export const useDeleteModelCredential = (provider: string) => { | |||||
| return useMutation({ | |||||
| mutationFn: (data: { | |||||
| credential_id: string | |||||
| model?: string | |||||
| model_type?: ModelTypeEnum | |||||
| }) => del<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials`, { | |||||
| body: data, | |||||
| }), | |||||
| }) | |||||
| } | |||||
| export const useDeleteModel = (provider: string) => { | |||||
| return useMutation({ | |||||
| mutationFn: (data: { | |||||
| model: string | |||||
| model_type: ModelTypeEnum | |||||
| }) => del<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials`, { | |||||
| body: data, | |||||
| }), | |||||
| }) | |||||
| } | |||||
| export const useActiveModelCredential = (provider: string) => { | |||||
| return useMutation({ | |||||
| mutationFn: (data: { | |||||
| credential_id: string | |||||
| model?: string | |||||
| model_type?: ModelTypeEnum | |||||
| }) => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/models/credentials/switch`, { | |||||
| body: data, | |||||
| }), | |||||
| }) | |||||
| } | |||||
| export const useUpdateModelLoadBalancingConfig = (provider: string) => { | |||||
| return useMutation({ | |||||
| mutationFn: (data: { | |||||
| config_from: string | |||||
| model: string | |||||
| model_type: ModelTypeEnum | |||||
| load_balancing: ModelLoadBalancingConfig | |||||
| credential_id?: string | |||||
| }) => post<{ result: string }>(`/workspaces/current/model-providers/${provider}/models`, { | |||||
| body: data, | |||||
| }), | |||||
| }) | |||||
| } |
| enabled: !!url, | enabled: !!url, | ||||
| queryKey: [NAME_SPACE, 'credential-info', url], | queryKey: [NAME_SPACE, 'credential-info', url], | ||||
| queryFn: () => get<{ | queryFn: () => get<{ | ||||
| allow_custom_token?: boolean | |||||
| supported_credential_types: string[] | supported_credential_types: string[] | ||||
| credentials: Credential[] | credentials: Credential[] | ||||
| is_oauth_custom_client_enabled: boolean | is_oauth_custom_client_enabled: boolean |