Co-authored-by: Claude <noreply@anthropic.com>tags/1.8.1
| @@ -67,7 +67,7 @@ class ModelProviderCredentialApi(Resource): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||
| parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| @@ -94,7 +94,7 @@ class ModelProviderCredentialApi(Resource): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||
| parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| @@ -219,7 +219,11 @@ class ModelProviderModelCredentialApi(Resource): | |||
| model_load_balancing_service = ModelLoadBalancingService() | |||
| is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs( | |||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args["model"], | |||
| model_type=args["model_type"], | |||
| config_from=args.get("config_from", ""), | |||
| ) | |||
| if args.get("config_from", "") == "predefined-model": | |||
| @@ -263,7 +267,7 @@ class ModelProviderModelCredentialApi(Resource): | |||
| choices=[mt.value for mt in ModelType], | |||
| location="json", | |||
| ) | |||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||
| parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| @@ -309,7 +313,7 @@ class ModelProviderModelCredentialApi(Resource): | |||
| ) | |||
| parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||
| parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| @@ -1,5 +1,6 @@ | |||
| import json | |||
| import logging | |||
| import re | |||
| from collections import defaultdict | |||
| from collections.abc import Iterator, Sequence | |||
| from json import JSONDecodeError | |||
| @@ -343,7 +344,65 @@ class ProviderConfiguration(BaseModel): | |||
| with Session(db.engine) as new_session: | |||
| return _validate(new_session) | |||
| def create_provider_credential(self, credentials: dict, credential_name: str) -> None: | |||
| def _generate_provider_credential_name(self, session) -> str: | |||
| """ | |||
| Generate a unique credential name for provider. | |||
| :return: credential name | |||
| """ | |||
| return self._generate_next_api_key_name( | |||
| session=session, | |||
| query_factory=lambda: select(ProviderCredential).where( | |||
| ProviderCredential.tenant_id == self.tenant_id, | |||
| ProviderCredential.provider_name == self.provider.provider, | |||
| ), | |||
| ) | |||
| def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str: | |||
| """ | |||
| Generate a unique credential name for custom model. | |||
| :return: credential name | |||
| """ | |||
| return self._generate_next_api_key_name( | |||
| session=session, | |||
| query_factory=lambda: select(ProviderModelCredential).where( | |||
| ProviderModelCredential.tenant_id == self.tenant_id, | |||
| ProviderModelCredential.provider_name == self.provider.provider, | |||
| ProviderModelCredential.model_name == model, | |||
| ProviderModelCredential.model_type == model_type.to_origin_model_type(), | |||
| ), | |||
| ) | |||
| def _generate_next_api_key_name(self, session, query_factory) -> str: | |||
| """ | |||
| Generate next available API KEY name by finding the highest numbered suffix. | |||
| :param session: database session | |||
| :param query_factory: function that returns the SQLAlchemy query | |||
| :return: next available API KEY name | |||
| """ | |||
| try: | |||
| stmt = query_factory() | |||
| credential_records = session.execute(stmt).scalars().all() | |||
| if not credential_records: | |||
| return "API KEY 1" | |||
| # Extract numbers from API KEY pattern using list comprehension | |||
| pattern = re.compile(r"^API KEY\s+(\d+)$") | |||
| numbers = [ | |||
| int(match.group(1)) | |||
| for cr in credential_records | |||
| if cr.credential_name and (match := pattern.match(cr.credential_name.strip())) | |||
| ] | |||
| # Return next sequential number | |||
| next_number = max(numbers, default=0) + 1 | |||
| return f"API KEY {next_number}" | |||
| except Exception as e: | |||
| logger.warning("Error generating next credential name: %s", str(e)) | |||
| return "API KEY 1" | |||
| def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None: | |||
| """ | |||
| Add custom provider credentials. | |||
| :param credentials: provider credentials | |||
| @@ -351,8 +410,12 @@ class ProviderConfiguration(BaseModel): | |||
| :return: | |||
| """ | |||
| with Session(db.engine) as session: | |||
| if self._check_provider_credential_name_exists(credential_name=credential_name, session=session): | |||
| if credential_name and self._check_provider_credential_name_exists( | |||
| credential_name=credential_name, session=session | |||
| ): | |||
| raise ValueError(f"Credential with name '{credential_name}' already exists.") | |||
| else: | |||
| credential_name = self._generate_provider_credential_name(session) | |||
| credentials = self.validate_provider_credentials(credentials=credentials, session=session) | |||
| provider_record = self._get_provider_record(session) | |||
| @@ -395,7 +458,7 @@ class ProviderConfiguration(BaseModel): | |||
| self, | |||
| credentials: dict, | |||
| credential_id: str, | |||
| credential_name: str, | |||
| credential_name: str | None, | |||
| ) -> None: | |||
| """ | |||
| update a saved provider credential (by credential_id). | |||
| @@ -406,7 +469,7 @@ class ProviderConfiguration(BaseModel): | |||
| :return: | |||
| """ | |||
| with Session(db.engine) as session: | |||
| if self._check_provider_credential_name_exists( | |||
| if credential_name and self._check_provider_credential_name_exists( | |||
| credential_name=credential_name, session=session, exclude_id=credential_id | |||
| ): | |||
| raise ValueError(f"Credential with name '{credential_name}' already exists.") | |||
| @@ -428,9 +491,9 @@ class ProviderConfiguration(BaseModel): | |||
| try: | |||
| # Update credential | |||
| credential_record.encrypted_config = json.dumps(credentials) | |||
| credential_record.credential_name = credential_name | |||
| credential_record.updated_at = naive_utc_now() | |||
| if credential_name: | |||
| credential_record.credential_name = credential_name | |||
| session.commit() | |||
| if provider_record and provider_record.credential_id == credential_id: | |||
| @@ -532,13 +595,7 @@ class ProviderConfiguration(BaseModel): | |||
| cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, | |||
| ) | |||
| lb_credentials_cache.delete() | |||
| lb_config.credential_id = None | |||
| lb_config.encrypted_config = None | |||
| lb_config.enabled = False | |||
| lb_config.name = "__delete__" | |||
| lb_config.updated_at = naive_utc_now() | |||
| session.add(lb_config) | |||
| session.delete(lb_config) | |||
| # Check if this is the currently active credential | |||
| provider_record = self._get_provider_record(session) | |||
| @@ -822,7 +879,7 @@ class ProviderConfiguration(BaseModel): | |||
| return _validate(new_session) | |||
| def create_custom_model_credential( | |||
| self, model_type: ModelType, model: str, credentials: dict, credential_name: str | |||
| self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None | |||
| ) -> None: | |||
| """ | |||
| Create a custom model credential. | |||
| @@ -833,10 +890,14 @@ class ProviderConfiguration(BaseModel): | |||
| :return: | |||
| """ | |||
| with Session(db.engine) as session: | |||
| if self._check_custom_model_credential_name_exists( | |||
| if credential_name and self._check_custom_model_credential_name_exists( | |||
| model=model, model_type=model_type, credential_name=credential_name, session=session | |||
| ): | |||
| raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") | |||
| else: | |||
| credential_name = self._generate_custom_model_credential_name( | |||
| model=model, model_type=model_type, session=session | |||
| ) | |||
| # validate custom model config | |||
| credentials = self.validate_custom_model_credentials( | |||
| model_type=model_type, model=model, credentials=credentials, session=session | |||
| @@ -880,7 +941,7 @@ class ProviderConfiguration(BaseModel): | |||
| raise | |||
| def update_custom_model_credential( | |||
| self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str | |||
| self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str | |||
| ) -> None: | |||
| """ | |||
| Update a custom model credential. | |||
| @@ -893,7 +954,7 @@ class ProviderConfiguration(BaseModel): | |||
| :return: | |||
| """ | |||
| with Session(db.engine) as session: | |||
| if self._check_custom_model_credential_name_exists( | |||
| if credential_name and self._check_custom_model_credential_name_exists( | |||
| model=model, | |||
| model_type=model_type, | |||
| credential_name=credential_name, | |||
| @@ -925,8 +986,9 @@ class ProviderConfiguration(BaseModel): | |||
| try: | |||
| # Update credential | |||
| credential_record.encrypted_config = json.dumps(credentials) | |||
| credential_record.credential_name = credential_name | |||
| credential_record.updated_at = naive_utc_now() | |||
| if credential_name: | |||
| credential_record.credential_name = credential_name | |||
| session.commit() | |||
| if provider_model_record and provider_model_record.credential_id == credential_id: | |||
| @@ -982,12 +1044,7 @@ class ProviderConfiguration(BaseModel): | |||
| cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, | |||
| ) | |||
| lb_credentials_cache.delete() | |||
| lb_config.credential_id = None | |||
| lb_config.encrypted_config = None | |||
| lb_config.enabled = False | |||
| lb_config.name = "__delete__" | |||
| lb_config.updated_at = naive_utc_now() | |||
| session.add(lb_config) | |||
| session.delete(lb_config) | |||
| # Check if this is the currently active credential | |||
| provider_model_record = self._get_custom_model_record(model_type, model, session=session) | |||
| @@ -1054,6 +1111,7 @@ class ProviderConfiguration(BaseModel): | |||
| provider_name=self.provider.provider, | |||
| model_name=model, | |||
| model_type=model_type.to_origin_model_type(), | |||
| is_valid=True, | |||
| credential_id=credential_id, | |||
| ) | |||
| else: | |||
| @@ -1605,11 +1663,9 @@ class ProviderConfiguration(BaseModel): | |||
| if config.credential_source_type != "custom_model" | |||
| ] | |||
| if len(provider_model_lb_configs) > 1: | |||
| load_balancing_enabled = True | |||
| if any(config.name == "__delete__" for config in provider_model_lb_configs): | |||
| has_invalid_load_balancing_configs = True | |||
| load_balancing_enabled = model_setting.load_balancing_enabled | |||
| # when the user enable load_balancing but available configs are less than 2 display warning | |||
| has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2 | |||
| provider_models.append( | |||
| ModelWithProviderEntity( | |||
| @@ -1631,6 +1687,8 @@ class ProviderConfiguration(BaseModel): | |||
| for model_configuration in self.custom_configuration.models: | |||
| if model_configuration.model_type not in model_types: | |||
| continue | |||
| if model_configuration.unadded_to_model_list: | |||
| continue | |||
| if model and model != model_configuration.model: | |||
| continue | |||
| try: | |||
| @@ -1663,11 +1721,9 @@ class ProviderConfiguration(BaseModel): | |||
| if config.credential_source_type != "provider" | |||
| ] | |||
| if len(custom_model_lb_configs) > 1: | |||
| load_balancing_enabled = True | |||
| if any(config.name == "__delete__" for config in custom_model_lb_configs): | |||
| has_invalid_load_balancing_configs = True | |||
| load_balancing_enabled = model_setting.load_balancing_enabled | |||
| # when the user enable load_balancing but available configs are less than 2 display warning | |||
| has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2 | |||
| if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials: | |||
| status = ModelStatus.CREDENTIAL_REMOVED | |||
| @@ -111,11 +111,21 @@ class CustomModelConfiguration(BaseModel): | |||
| current_credential_id: Optional[str] = None | |||
| current_credential_name: Optional[str] = None | |||
| available_model_credentials: list[CredentialConfiguration] = [] | |||
| unadded_to_model_list: Optional[bool] = False | |||
| # pydantic configs | |||
| model_config = ConfigDict(protected_namespaces=()) | |||
| class UnaddedModelConfiguration(BaseModel): | |||
| """ | |||
| Model class for provider unadded model configuration. | |||
| """ | |||
| model: str | |||
| model_type: ModelType | |||
| class CustomConfiguration(BaseModel): | |||
| """ | |||
| Model class for provider custom configuration. | |||
| @@ -123,6 +133,7 @@ class CustomConfiguration(BaseModel): | |||
| provider: Optional[CustomProviderConfiguration] = None | |||
| models: list[CustomModelConfiguration] = [] | |||
| can_added_models: list[UnaddedModelConfiguration] = [] | |||
| class ModelLoadBalancingConfiguration(BaseModel): | |||
| @@ -144,6 +155,7 @@ class ModelSettings(BaseModel): | |||
| model: str | |||
| model_type: ModelType | |||
| enabled: bool = True | |||
| load_balancing_enabled: bool = False | |||
| load_balancing_configs: list[ModelLoadBalancingConfiguration] = [] | |||
| # pydantic configs | |||
| @@ -1,8 +1,9 @@ | |||
| import contextlib | |||
| import json | |||
| from collections import defaultdict | |||
| from collections.abc import Sequence | |||
| from json import JSONDecodeError | |||
| from typing import Any, Optional | |||
| from typing import Any, Optional, cast | |||
| from sqlalchemy import select | |||
| from sqlalchemy.exc import IntegrityError | |||
| @@ -22,6 +23,7 @@ from core.entities.provider_entities import ( | |||
| QuotaConfiguration, | |||
| QuotaUnit, | |||
| SystemConfiguration, | |||
| UnaddedModelConfiguration, | |||
| ) | |||
| from core.helper import encrypter | |||
| from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType | |||
| @@ -537,6 +539,23 @@ class ProviderManager: | |||
| for credential in available_credentials | |||
| ] | |||
| @staticmethod | |||
| def get_credentials_from_provider_model(tenant_id: str, provider_name: str) -> Sequence[ProviderModelCredential]: | |||
| """ | |||
| Get all the credentials records from ProviderModelCredential by provider_name | |||
| :param tenant_id: workspace id | |||
| :param provider_name: provider name | |||
| """ | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| stmt = select(ProviderModelCredential).where( | |||
| ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.provider_name == provider_name | |||
| ) | |||
| all_credentials = session.scalars(stmt).all() | |||
| return all_credentials | |||
| @staticmethod | |||
| def _init_trial_provider_records( | |||
| tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]] | |||
| @@ -623,6 +642,44 @@ class ProviderManager: | |||
| :param provider_model_records: provider model records | |||
| :return: | |||
| """ | |||
| # Get custom provider configuration | |||
| custom_provider_configuration = self._get_custom_provider_configuration( | |||
| tenant_id, provider_entity, provider_records | |||
| ) | |||
| # Get all model credentials once | |||
| all_model_credentials = self.get_credentials_from_provider_model(tenant_id, provider_entity.provider) | |||
| # Get custom models which have not been added to the model list yet | |||
| unadded_models = self._get_can_added_models(provider_model_records, all_model_credentials) | |||
| # Get custom model configurations | |||
| custom_model_configurations = self._get_custom_model_configurations( | |||
| tenant_id, provider_entity, provider_model_records, unadded_models, all_model_credentials | |||
| ) | |||
| can_added_models = [ | |||
| UnaddedModelConfiguration(model=model["model"], model_type=model["model_type"]) for model in unadded_models | |||
| ] | |||
| return CustomConfiguration( | |||
| provider=custom_provider_configuration, | |||
| models=custom_model_configurations, | |||
| can_added_models=can_added_models, | |||
| ) | |||
| def _get_custom_provider_configuration( | |||
| self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider] | |||
| ) -> CustomProviderConfiguration | None: | |||
| """Get custom provider configuration.""" | |||
| # Find custom provider record (non-system) | |||
| custom_provider_record = next( | |||
| (record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None | |||
| ) | |||
| if not custom_provider_record: | |||
| return None | |||
| # Get provider credential secret variables | |||
| provider_credential_secret_variables = self._extract_secret_variables( | |||
| provider_entity.provider_credential_schema.credential_form_schemas | |||
| @@ -630,113 +687,98 @@ class ProviderManager: | |||
| else [] | |||
| ) | |||
| # Get custom provider record | |||
| custom_provider_record = None | |||
| for provider_record in provider_records: | |||
| if provider_record.provider_type == ProviderType.SYSTEM.value: | |||
| continue | |||
| # Get and decrypt provider credentials | |||
| provider_credentials = self._get_and_decrypt_credentials( | |||
| tenant_id=tenant_id, | |||
| record_id=custom_provider_record.id, | |||
| encrypted_config=custom_provider_record.encrypted_config, | |||
| secret_variables=provider_credential_secret_variables, | |||
| cache_type=ProviderCredentialsCacheType.PROVIDER, | |||
| is_provider=True, | |||
| ) | |||
| custom_provider_record = provider_record | |||
| return CustomProviderConfiguration( | |||
| credentials=provider_credentials, | |||
| current_credential_name=custom_provider_record.credential_name, | |||
| current_credential_id=custom_provider_record.credential_id, | |||
| available_credentials=self.get_provider_available_credentials( | |||
| tenant_id, custom_provider_record.provider_name | |||
| ), | |||
| ) | |||
| # Get custom provider credentials | |||
| custom_provider_configuration = None | |||
| if custom_provider_record: | |||
| provider_credentials_cache = ProviderCredentialsCache( | |||
| tenant_id=tenant_id, | |||
| identity_id=custom_provider_record.id, | |||
| cache_type=ProviderCredentialsCacheType.PROVIDER, | |||
| ) | |||
| def _get_can_added_models( | |||
| self, provider_model_records: list[ProviderModel], all_model_credentials: Sequence[ProviderModelCredential] | |||
| ) -> list[dict]: | |||
| """Get the custom models and credentials from enterprise version which haven't add to the model list""" | |||
| existing_model_set = {(record.model_name, record.model_type) for record in provider_model_records} | |||
| # Get not added custom models credentials | |||
| not_added_custom_models_credentials = [ | |||
| credential | |||
| for credential in all_model_credentials | |||
| if (credential.model_name, credential.model_type) not in existing_model_set | |||
| ] | |||
| # Get cached provider credentials | |||
| cached_provider_credentials = provider_credentials_cache.get() | |||
| if not cached_provider_credentials: | |||
| try: | |||
| # fix origin data | |||
| if custom_provider_record.encrypted_config is None: | |||
| provider_credentials = {} | |||
| elif not custom_provider_record.encrypted_config.startswith("{"): | |||
| provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} | |||
| else: | |||
| provider_credentials = json.loads(custom_provider_record.encrypted_config) | |||
| except JSONDecodeError: | |||
| provider_credentials = {} | |||
| # Get decoding rsa key and cipher for decrypting credentials | |||
| if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: | |||
| self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) | |||
| for variable in provider_credential_secret_variables: | |||
| if variable in provider_credentials: | |||
| with contextlib.suppress(ValueError): | |||
| provider_credentials[variable] = encrypter.decrypt_token_with_decoding( | |||
| provider_credentials.get(variable) or "", # type: ignore | |||
| self.decoding_rsa_key, | |||
| self.decoding_cipher_rsa, | |||
| ) | |||
| # Group credentials by model | |||
| model_to_credentials = defaultdict(list) | |||
| for credential in not_added_custom_models_credentials: | |||
| model_to_credentials[(credential.model_name, credential.model_type)].append(credential) | |||
| # cache provider credentials | |||
| provider_credentials_cache.set(credentials=provider_credentials) | |||
| else: | |||
| provider_credentials = cached_provider_credentials | |||
| custom_provider_configuration = CustomProviderConfiguration( | |||
| credentials=provider_credentials, | |||
| current_credential_name=custom_provider_record.credential_name, | |||
| current_credential_id=custom_provider_record.credential_id, | |||
| available_credentials=self.get_provider_available_credentials( | |||
| tenant_id, custom_provider_record.provider_name | |||
| ), | |||
| ) | |||
| return [ | |||
| { | |||
| "model": model_key[0], | |||
| "model_type": ModelType.value_of(model_key[1]), | |||
| "available_model_credentials": [ | |||
| CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name) | |||
| for cred in creds | |||
| ], | |||
| } | |||
| for model_key, creds in model_to_credentials.items() | |||
| ] | |||
| # Get provider model credential secret variables | |||
| def _get_custom_model_configurations( | |||
| self, | |||
| tenant_id: str, | |||
| provider_entity: ProviderEntity, | |||
| provider_model_records: list[ProviderModel], | |||
| can_added_models: list[dict], | |||
| all_model_credentials: Sequence[ProviderModelCredential], | |||
| ) -> list[CustomModelConfiguration]: | |||
| """Get custom model configurations.""" | |||
| # Get model credential secret variables | |||
| model_credential_secret_variables = self._extract_secret_variables( | |||
| provider_entity.model_credential_schema.credential_form_schemas | |||
| if provider_entity.model_credential_schema | |||
| else [] | |||
| ) | |||
| # Get custom provider model credentials | |||
| # Create credentials lookup for efficient access | |||
| credentials_map = defaultdict(list) | |||
| for credential in all_model_credentials: | |||
| credentials_map[(credential.model_name, credential.model_type)].append(credential) | |||
| custom_model_configurations = [] | |||
| # Process existing model records | |||
| for provider_model_record in provider_model_records: | |||
| available_model_credentials = self.get_provider_model_available_credentials( | |||
| tenant_id, | |||
| provider_model_record.provider_name, | |||
| provider_model_record.model_name, | |||
| provider_model_record.model_type, | |||
| ) | |||
| # Use pre-fetched credentials instead of individual database calls | |||
| available_model_credentials = [ | |||
| CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name) | |||
| for cred in credentials_map.get( | |||
| (provider_model_record.model_name, provider_model_record.model_type), [] | |||
| ) | |||
| ] | |||
| provider_model_credentials_cache = ProviderCredentialsCache( | |||
| tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL | |||
| # Get and decrypt model credentials | |||
| provider_model_credentials = self._get_and_decrypt_credentials( | |||
| tenant_id=tenant_id, | |||
| record_id=provider_model_record.id, | |||
| encrypted_config=provider_model_record.encrypted_config, | |||
| secret_variables=model_credential_secret_variables, | |||
| cache_type=ProviderCredentialsCacheType.MODEL, | |||
| is_provider=False, | |||
| ) | |||
| # Get cached provider model credentials | |||
| cached_provider_model_credentials = provider_model_credentials_cache.get() | |||
| if not cached_provider_model_credentials and provider_model_record.encrypted_config: | |||
| try: | |||
| provider_model_credentials = json.loads(provider_model_record.encrypted_config) | |||
| except JSONDecodeError: | |||
| continue | |||
| # Get decoding rsa key and cipher for decrypting credentials | |||
| if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: | |||
| self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) | |||
| for variable in model_credential_secret_variables: | |||
| if variable in provider_model_credentials: | |||
| with contextlib.suppress(ValueError): | |||
| provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( | |||
| provider_model_credentials.get(variable), | |||
| self.decoding_rsa_key, | |||
| self.decoding_cipher_rsa, | |||
| ) | |||
| # cache provider model credentials | |||
| provider_model_credentials_cache.set(credentials=provider_model_credentials) | |||
| else: | |||
| provider_model_credentials = cached_provider_model_credentials | |||
| custom_model_configurations.append( | |||
| CustomModelConfiguration( | |||
| model=provider_model_record.model_name, | |||
| @@ -748,7 +790,71 @@ class ProviderManager: | |||
| ) | |||
| ) | |||
| return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations) | |||
| # Add models that can be added | |||
| for model in can_added_models: | |||
| custom_model_configurations.append( | |||
| CustomModelConfiguration( | |||
| model=model["model"], | |||
| model_type=model["model_type"], | |||
| credentials=None, | |||
| current_credential_id=None, | |||
| current_credential_name=None, | |||
| available_model_credentials=model["available_model_credentials"], | |||
| unadded_to_model_list=True, | |||
| ) | |||
| ) | |||
| return custom_model_configurations | |||
| def _get_and_decrypt_credentials( | |||
| self, | |||
| tenant_id: str, | |||
| record_id: str, | |||
| encrypted_config: str | None, | |||
| secret_variables: list[str], | |||
| cache_type: ProviderCredentialsCacheType, | |||
| is_provider: bool = False, | |||
| ) -> dict: | |||
| """Get and decrypt credentials with caching.""" | |||
| credentials_cache = ProviderCredentialsCache( | |||
| tenant_id=tenant_id, | |||
| identity_id=record_id, | |||
| cache_type=cache_type, | |||
| ) | |||
| # Try to get from cache first | |||
| cached_credentials = credentials_cache.get() | |||
| if cached_credentials: | |||
| return cached_credentials | |||
| # Parse encrypted config | |||
| if not encrypted_config: | |||
| return {} | |||
| if is_provider and not encrypted_config.startswith("{"): | |||
| return {"openai_api_key": encrypted_config} | |||
| try: | |||
| credentials = cast(dict, json.loads(encrypted_config)) | |||
| except JSONDecodeError: | |||
| return {} | |||
| # Decrypt secret variables | |||
| if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: | |||
| self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) | |||
| for variable in secret_variables: | |||
| if variable in credentials: | |||
| with contextlib.suppress(ValueError): | |||
| credentials[variable] = encrypter.decrypt_token_with_decoding( | |||
| credentials.get(variable) or "", | |||
| self.decoding_rsa_key, | |||
| self.decoding_cipher_rsa, | |||
| ) | |||
| # Cache the decrypted credentials | |||
| credentials_cache.set(credentials=credentials) | |||
| return credentials | |||
| def _to_system_configuration( | |||
| self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider] | |||
| @@ -956,18 +1062,6 @@ class ProviderManager: | |||
| load_balancing_model_config.model_name == provider_model_setting.model_name | |||
| and load_balancing_model_config.model_type == provider_model_setting.model_type | |||
| ): | |||
| if load_balancing_model_config.name == "__delete__": | |||
| # to calculate current model whether has invalidate lb configs | |||
| load_balancing_configs.append( | |||
| ModelLoadBalancingConfiguration( | |||
| id=load_balancing_model_config.id, | |||
| name=load_balancing_model_config.name, | |||
| credentials={}, | |||
| credential_source_type=load_balancing_model_config.credential_source_type, | |||
| ) | |||
| ) | |||
| continue | |||
| if not load_balancing_model_config.enabled: | |||
| continue | |||
| @@ -1033,6 +1127,7 @@ class ProviderManager: | |||
| model=provider_model_setting.model_name, | |||
| model_type=ModelType.value_of(provider_model_setting.model_type), | |||
| enabled=provider_model_setting.enabled, | |||
| load_balancing_enabled=provider_model_setting.load_balancing_enabled, | |||
| load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [], | |||
| ) | |||
| ) | |||
| @@ -13,6 +13,7 @@ from core.entities.provider_entities import ( | |||
| CustomModelConfiguration, | |||
| ProviderQuotaType, | |||
| QuotaConfiguration, | |||
| UnaddedModelConfiguration, | |||
| ) | |||
| from core.model_runtime.entities.common_entities import I18nObject | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| @@ -45,6 +46,7 @@ class CustomConfigurationResponse(BaseModel): | |||
| current_credential_name: Optional[str] = None | |||
| available_credentials: Optional[list[CredentialConfiguration]] = None | |||
| custom_models: Optional[list[CustomModelConfiguration]] = None | |||
| can_added_models: Optional[list[UnaddedModelConfiguration]] = None | |||
| class SystemConfigurationResponse(BaseModel): | |||
| @@ -3,6 +3,8 @@ import logging | |||
| from json import JSONDecodeError | |||
| from typing import Optional, Union | |||
| from sqlalchemy import or_ | |||
| from constants import HIDDEN_VALUE | |||
| from core.entities.provider_configuration import ProviderConfiguration | |||
| from core.helper import encrypter | |||
| @@ -69,7 +71,7 @@ class ModelLoadBalancingService: | |||
| provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type)) | |||
| def get_load_balancing_configs( | |||
| self, tenant_id: str, provider: str, model: str, model_type: str | |||
| self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = "" | |||
| ) -> tuple[bool, list[dict]]: | |||
| """ | |||
| Get load balancing configurations. | |||
| @@ -100,6 +102,11 @@ class ModelLoadBalancingService: | |||
| if provider_model_setting and provider_model_setting.load_balancing_enabled: | |||
| is_load_balancing_enabled = True | |||
| if config_from == "predefined-model": | |||
| credential_source_type = "provider" | |||
| else: | |||
| credential_source_type = "custom_model" | |||
| # Get load balancing configurations | |||
| load_balancing_configs = ( | |||
| db.session.query(LoadBalancingModelConfig) | |||
| @@ -108,6 +115,10 @@ class ModelLoadBalancingService: | |||
| LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, | |||
| LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), | |||
| LoadBalancingModelConfig.model_name == model, | |||
| or_( | |||
| LoadBalancingModelConfig.credential_source_type == credential_source_type, | |||
| LoadBalancingModelConfig.credential_source_type.is_(None), | |||
| ), | |||
| ) | |||
| .order_by(LoadBalancingModelConfig.created_at) | |||
| .all() | |||
| @@ -405,7 +416,7 @@ class ModelLoadBalancingService: | |||
| self._clear_credentials_cache(tenant_id, config_id) | |||
| else: | |||
| # create load balancing config | |||
| if name in {"__inherit__", "__delete__"}: | |||
| if name == "__inherit__": | |||
| raise ValueError("Invalid load balancing config name") | |||
| if credential_id: | |||
| @@ -72,6 +72,7 @@ class ModelProviderService: | |||
| provider_config = provider_configuration.custom_configuration.provider | |||
| model_config = provider_configuration.custom_configuration.models | |||
| can_added_models = provider_configuration.custom_configuration.can_added_models | |||
| provider_response = ProviderResponse( | |||
| tenant_id=tenant_id, | |||
| @@ -95,6 +96,7 @@ class ModelProviderService: | |||
| current_credential_name=getattr(provider_config, "current_credential_name", None), | |||
| available_credentials=getattr(provider_config, "available_credentials", []), | |||
| custom_models=model_config, | |||
| can_added_models=can_added_models, | |||
| ), | |||
| system_configuration=SystemConfigurationResponse( | |||
| enabled=provider_configuration.system_configuration.enabled, | |||
| @@ -152,7 +154,7 @@ class ModelProviderService: | |||
| provider_configuration.validate_provider_credentials(credentials) | |||
| def create_provider_credential( | |||
| self, tenant_id: str, provider: str, credentials: dict, credential_name: str | |||
| self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None | |||
| ) -> None: | |||
| """ | |||
| Create and save new provider credentials. | |||
| @@ -172,7 +174,7 @@ class ModelProviderService: | |||
| provider: str, | |||
| credentials: dict, | |||
| credential_id: str, | |||
| credential_name: str, | |||
| credential_name: str | None, | |||
| ) -> None: | |||
| """ | |||
| update a saved provider credential (by credential_id). | |||
| @@ -249,7 +251,7 @@ class ModelProviderService: | |||
| ) | |||
| def create_model_credential( | |||
| self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | |||
| self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None | |||
| ) -> None: | |||
| """ | |||
| create and save model credentials. | |||
| @@ -278,7 +280,7 @@ class ModelProviderService: | |||
| model: str, | |||
| credentials: dict, | |||
| credential_id: str, | |||
| credential_name: str, | |||
| credential_name: str | None, | |||
| ) -> None: | |||
| """ | |||
| update model credentials. | |||