Co-authored-by: Claude <noreply@anthropic.com>tags/1.8.1
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | ||||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||||
| parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| model_provider_service = ModelProviderService() | model_provider_service = ModelProviderService() | ||||
| parser = reqparse.RequestParser() | parser = reqparse.RequestParser() | ||||
| parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") | parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") | ||||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | ||||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||||
| parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| model_provider_service = ModelProviderService() | model_provider_service = ModelProviderService() |
| model_load_balancing_service = ModelLoadBalancingService() | model_load_balancing_service = ModelLoadBalancingService() | ||||
| is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs( | is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs( | ||||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||||
| tenant_id=tenant_id, | |||||
| provider=provider, | |||||
| model=args["model"], | |||||
| model_type=args["model_type"], | |||||
| config_from=args.get("config_from", ""), | |||||
| ) | ) | ||||
| if args.get("config_from", "") == "predefined-model": | if args.get("config_from", "") == "predefined-model": | ||||
| choices=[mt.value for mt in ModelType], | choices=[mt.value for mt in ModelType], | ||||
| location="json", | location="json", | ||||
| ) | ) | ||||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||||
| parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") | |||||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | ||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| ) | ) | ||||
| parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") | parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") | ||||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | ||||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||||
| parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| model_provider_service = ModelProviderService() | model_provider_service = ModelProviderService() |
| import json | import json | ||||
| import logging | import logging | ||||
| import re | |||||
| from collections import defaultdict | from collections import defaultdict | ||||
| from collections.abc import Iterator, Sequence | from collections.abc import Iterator, Sequence | ||||
| from json import JSONDecodeError | from json import JSONDecodeError | ||||
| with Session(db.engine) as new_session: | with Session(db.engine) as new_session: | ||||
| return _validate(new_session) | return _validate(new_session) | ||||
| def create_provider_credential(self, credentials: dict, credential_name: str) -> None: | |||||
| def _generate_provider_credential_name(self, session) -> str: | |||||
| """ | |||||
| Generate a unique credential name for provider. | |||||
| :return: credential name | |||||
| """ | |||||
| return self._generate_next_api_key_name( | |||||
| session=session, | |||||
| query_factory=lambda: select(ProviderCredential).where( | |||||
| ProviderCredential.tenant_id == self.tenant_id, | |||||
| ProviderCredential.provider_name == self.provider.provider, | |||||
| ), | |||||
| ) | |||||
| def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str: | |||||
| """ | |||||
| Generate a unique credential name for custom model. | |||||
| :return: credential name | |||||
| """ | |||||
| return self._generate_next_api_key_name( | |||||
| session=session, | |||||
| query_factory=lambda: select(ProviderModelCredential).where( | |||||
| ProviderModelCredential.tenant_id == self.tenant_id, | |||||
| ProviderModelCredential.provider_name == self.provider.provider, | |||||
| ProviderModelCredential.model_name == model, | |||||
| ProviderModelCredential.model_type == model_type.to_origin_model_type(), | |||||
| ), | |||||
| ) | |||||
| def _generate_next_api_key_name(self, session, query_factory) -> str: | |||||
| """ | |||||
| Generate next available API KEY name by finding the highest numbered suffix. | |||||
| :param session: database session | |||||
| :param query_factory: function that returns the SQLAlchemy query | |||||
| :return: next available API KEY name | |||||
| """ | |||||
| try: | |||||
| stmt = query_factory() | |||||
| credential_records = session.execute(stmt).scalars().all() | |||||
| if not credential_records: | |||||
| return "API KEY 1" | |||||
| # Extract numbers from API KEY pattern using list comprehension | |||||
| pattern = re.compile(r"^API KEY\s+(\d+)$") | |||||
| numbers = [ | |||||
| int(match.group(1)) | |||||
| for cr in credential_records | |||||
| if cr.credential_name and (match := pattern.match(cr.credential_name.strip())) | |||||
| ] | |||||
| # Return next sequential number | |||||
| next_number = max(numbers, default=0) + 1 | |||||
| return f"API KEY {next_number}" | |||||
| except Exception as e: | |||||
| logger.warning("Error generating next credential name: %s", str(e)) | |||||
| return "API KEY 1" | |||||
| def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None: | |||||
| """ | """ | ||||
| Add custom provider credentials. | Add custom provider credentials. | ||||
| :param credentials: provider credentials | :param credentials: provider credentials | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| with Session(db.engine) as session: | with Session(db.engine) as session: | ||||
| if self._check_provider_credential_name_exists(credential_name=credential_name, session=session): | |||||
| if credential_name and self._check_provider_credential_name_exists( | |||||
| credential_name=credential_name, session=session | |||||
| ): | |||||
| raise ValueError(f"Credential with name '{credential_name}' already exists.") | raise ValueError(f"Credential with name '{credential_name}' already exists.") | ||||
| else: | |||||
| credential_name = self._generate_provider_credential_name(session) | |||||
| credentials = self.validate_provider_credentials(credentials=credentials, session=session) | credentials = self.validate_provider_credentials(credentials=credentials, session=session) | ||||
| provider_record = self._get_provider_record(session) | provider_record = self._get_provider_record(session) | ||||
| self, | self, | ||||
| credentials: dict, | credentials: dict, | ||||
| credential_id: str, | credential_id: str, | ||||
| credential_name: str, | |||||
| credential_name: str | None, | |||||
| ) -> None: | ) -> None: | ||||
| """ | """ | ||||
| update a saved provider credential (by credential_id). | update a saved provider credential (by credential_id). | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| with Session(db.engine) as session: | with Session(db.engine) as session: | ||||
| if self._check_provider_credential_name_exists( | |||||
| if credential_name and self._check_provider_credential_name_exists( | |||||
| credential_name=credential_name, session=session, exclude_id=credential_id | credential_name=credential_name, session=session, exclude_id=credential_id | ||||
| ): | ): | ||||
| raise ValueError(f"Credential with name '{credential_name}' already exists.") | raise ValueError(f"Credential with name '{credential_name}' already exists.") | ||||
| try: | try: | ||||
| # Update credential | # Update credential | ||||
| credential_record.encrypted_config = json.dumps(credentials) | credential_record.encrypted_config = json.dumps(credentials) | ||||
| credential_record.credential_name = credential_name | |||||
| credential_record.updated_at = naive_utc_now() | credential_record.updated_at = naive_utc_now() | ||||
| if credential_name: | |||||
| credential_record.credential_name = credential_name | |||||
| session.commit() | session.commit() | ||||
| if provider_record and provider_record.credential_id == credential_id: | if provider_record and provider_record.credential_id == credential_id: | ||||
| cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, | cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, | ||||
| ) | ) | ||||
| lb_credentials_cache.delete() | lb_credentials_cache.delete() | ||||
| lb_config.credential_id = None | |||||
| lb_config.encrypted_config = None | |||||
| lb_config.enabled = False | |||||
| lb_config.name = "__delete__" | |||||
| lb_config.updated_at = naive_utc_now() | |||||
| session.add(lb_config) | |||||
| session.delete(lb_config) | |||||
| # Check if this is the currently active credential | # Check if this is the currently active credential | ||||
| provider_record = self._get_provider_record(session) | provider_record = self._get_provider_record(session) | ||||
| return _validate(new_session) | return _validate(new_session) | ||||
| def create_custom_model_credential( | def create_custom_model_credential( | ||||
| self, model_type: ModelType, model: str, credentials: dict, credential_name: str | |||||
| self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None | |||||
| ) -> None: | ) -> None: | ||||
| """ | """ | ||||
| Create a custom model credential. | Create a custom model credential. | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| with Session(db.engine) as session: | with Session(db.engine) as session: | ||||
| if self._check_custom_model_credential_name_exists( | |||||
| if credential_name and self._check_custom_model_credential_name_exists( | |||||
| model=model, model_type=model_type, credential_name=credential_name, session=session | model=model, model_type=model_type, credential_name=credential_name, session=session | ||||
| ): | ): | ||||
| raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") | raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") | ||||
| else: | |||||
| credential_name = self._generate_custom_model_credential_name( | |||||
| model=model, model_type=model_type, session=session | |||||
| ) | |||||
| # validate custom model config | # validate custom model config | ||||
| credentials = self.validate_custom_model_credentials( | credentials = self.validate_custom_model_credentials( | ||||
| model_type=model_type, model=model, credentials=credentials, session=session | model_type=model_type, model=model, credentials=credentials, session=session | ||||
| raise | raise | ||||
| def update_custom_model_credential( | def update_custom_model_credential( | ||||
| self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str | |||||
| self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str | |||||
| ) -> None: | ) -> None: | ||||
| """ | """ | ||||
| Update a custom model credential. | Update a custom model credential. | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| with Session(db.engine) as session: | with Session(db.engine) as session: | ||||
| if self._check_custom_model_credential_name_exists( | |||||
| if credential_name and self._check_custom_model_credential_name_exists( | |||||
| model=model, | model=model, | ||||
| model_type=model_type, | model_type=model_type, | ||||
| credential_name=credential_name, | credential_name=credential_name, | ||||
| try: | try: | ||||
| # Update credential | # Update credential | ||||
| credential_record.encrypted_config = json.dumps(credentials) | credential_record.encrypted_config = json.dumps(credentials) | ||||
| credential_record.credential_name = credential_name | |||||
| credential_record.updated_at = naive_utc_now() | credential_record.updated_at = naive_utc_now() | ||||
| if credential_name: | |||||
| credential_record.credential_name = credential_name | |||||
| session.commit() | session.commit() | ||||
| if provider_model_record and provider_model_record.credential_id == credential_id: | if provider_model_record and provider_model_record.credential_id == credential_id: | ||||
| cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, | cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, | ||||
| ) | ) | ||||
| lb_credentials_cache.delete() | lb_credentials_cache.delete() | ||||
| lb_config.credential_id = None | |||||
| lb_config.encrypted_config = None | |||||
| lb_config.enabled = False | |||||
| lb_config.name = "__delete__" | |||||
| lb_config.updated_at = naive_utc_now() | |||||
| session.add(lb_config) | |||||
| session.delete(lb_config) | |||||
| # Check if this is the currently active credential | # Check if this is the currently active credential | ||||
| provider_model_record = self._get_custom_model_record(model_type, model, session=session) | provider_model_record = self._get_custom_model_record(model_type, model, session=session) | ||||
| provider_name=self.provider.provider, | provider_name=self.provider.provider, | ||||
| model_name=model, | model_name=model, | ||||
| model_type=model_type.to_origin_model_type(), | model_type=model_type.to_origin_model_type(), | ||||
| is_valid=True, | |||||
| credential_id=credential_id, | credential_id=credential_id, | ||||
| ) | ) | ||||
| else: | else: | ||||
| if config.credential_source_type != "custom_model" | if config.credential_source_type != "custom_model" | ||||
| ] | ] | ||||
| if len(provider_model_lb_configs) > 1: | |||||
| load_balancing_enabled = True | |||||
| if any(config.name == "__delete__" for config in provider_model_lb_configs): | |||||
| has_invalid_load_balancing_configs = True | |||||
| load_balancing_enabled = model_setting.load_balancing_enabled | |||||
| # when the user enable load_balancing but available configs are less than 2 display warning | |||||
| has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2 | |||||
| provider_models.append( | provider_models.append( | ||||
| ModelWithProviderEntity( | ModelWithProviderEntity( | ||||
| for model_configuration in self.custom_configuration.models: | for model_configuration in self.custom_configuration.models: | ||||
| if model_configuration.model_type not in model_types: | if model_configuration.model_type not in model_types: | ||||
| continue | continue | ||||
| if model_configuration.unadded_to_model_list: | |||||
| continue | |||||
| if model and model != model_configuration.model: | if model and model != model_configuration.model: | ||||
| continue | continue | ||||
| try: | try: | ||||
| if config.credential_source_type != "provider" | if config.credential_source_type != "provider" | ||||
| ] | ] | ||||
| if len(custom_model_lb_configs) > 1: | |||||
| load_balancing_enabled = True | |||||
| if any(config.name == "__delete__" for config in custom_model_lb_configs): | |||||
| has_invalid_load_balancing_configs = True | |||||
| load_balancing_enabled = model_setting.load_balancing_enabled | |||||
| # when the user enable load_balancing but available configs are less than 2 display warning | |||||
| has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2 | |||||
| if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials: | if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials: | ||||
| status = ModelStatus.CREDENTIAL_REMOVED | status = ModelStatus.CREDENTIAL_REMOVED |
| current_credential_id: Optional[str] = None | current_credential_id: Optional[str] = None | ||||
| current_credential_name: Optional[str] = None | current_credential_name: Optional[str] = None | ||||
| available_model_credentials: list[CredentialConfiguration] = [] | available_model_credentials: list[CredentialConfiguration] = [] | ||||
| unadded_to_model_list: Optional[bool] = False | |||||
| # pydantic configs | # pydantic configs | ||||
| model_config = ConfigDict(protected_namespaces=()) | model_config = ConfigDict(protected_namespaces=()) | ||||
| class UnaddedModelConfiguration(BaseModel): | |||||
| """ | |||||
| Model class for provider unadded model configuration. | |||||
| """ | |||||
| model: str | |||||
| model_type: ModelType | |||||
| class CustomConfiguration(BaseModel): | class CustomConfiguration(BaseModel): | ||||
| """ | """ | ||||
| Model class for provider custom configuration. | Model class for provider custom configuration. | ||||
| provider: Optional[CustomProviderConfiguration] = None | provider: Optional[CustomProviderConfiguration] = None | ||||
| models: list[CustomModelConfiguration] = [] | models: list[CustomModelConfiguration] = [] | ||||
| can_added_models: list[UnaddedModelConfiguration] = [] | |||||
| class ModelLoadBalancingConfiguration(BaseModel): | class ModelLoadBalancingConfiguration(BaseModel): | ||||
| model: str | model: str | ||||
| model_type: ModelType | model_type: ModelType | ||||
| enabled: bool = True | enabled: bool = True | ||||
| load_balancing_enabled: bool = False | |||||
| load_balancing_configs: list[ModelLoadBalancingConfiguration] = [] | load_balancing_configs: list[ModelLoadBalancingConfiguration] = [] | ||||
| # pydantic configs | # pydantic configs |
| import contextlib | import contextlib | ||||
| import json | import json | ||||
| from collections import defaultdict | from collections import defaultdict | ||||
| from collections.abc import Sequence | |||||
| from json import JSONDecodeError | from json import JSONDecodeError | ||||
| from typing import Any, Optional | |||||
| from typing import Any, Optional, cast | |||||
| from sqlalchemy import select | from sqlalchemy import select | ||||
| from sqlalchemy.exc import IntegrityError | from sqlalchemy.exc import IntegrityError | ||||
| QuotaConfiguration, | QuotaConfiguration, | ||||
| QuotaUnit, | QuotaUnit, | ||||
| SystemConfiguration, | SystemConfiguration, | ||||
| UnaddedModelConfiguration, | |||||
| ) | ) | ||||
| from core.helper import encrypter | from core.helper import encrypter | ||||
| from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType | from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType | ||||
| for credential in available_credentials | for credential in available_credentials | ||||
| ] | ] | ||||
| @staticmethod | |||||
| def get_credentials_from_provider_model(tenant_id: str, provider_name: str) -> Sequence[ProviderModelCredential]: | |||||
| """ | |||||
| Get all the credentials records from ProviderModelCredential by provider_name | |||||
| :param tenant_id: workspace id | |||||
| :param provider_name: provider name | |||||
| """ | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| stmt = select(ProviderModelCredential).where( | |||||
| ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.provider_name == provider_name | |||||
| ) | |||||
| all_credentials = session.scalars(stmt).all() | |||||
| return all_credentials | |||||
| @staticmethod | @staticmethod | ||||
| def _init_trial_provider_records( | def _init_trial_provider_records( | ||||
| tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]] | tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]] | ||||
| :param provider_model_records: provider model records | :param provider_model_records: provider model records | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| # Get custom provider configuration | |||||
| custom_provider_configuration = self._get_custom_provider_configuration( | |||||
| tenant_id, provider_entity, provider_records | |||||
| ) | |||||
| # Get all model credentials once | |||||
| all_model_credentials = self.get_credentials_from_provider_model(tenant_id, provider_entity.provider) | |||||
| # Get custom models which have not been added to the model list yet | |||||
| unadded_models = self._get_can_added_models(provider_model_records, all_model_credentials) | |||||
| # Get custom model configurations | |||||
| custom_model_configurations = self._get_custom_model_configurations( | |||||
| tenant_id, provider_entity, provider_model_records, unadded_models, all_model_credentials | |||||
| ) | |||||
| can_added_models = [ | |||||
| UnaddedModelConfiguration(model=model["model"], model_type=model["model_type"]) for model in unadded_models | |||||
| ] | |||||
| return CustomConfiguration( | |||||
| provider=custom_provider_configuration, | |||||
| models=custom_model_configurations, | |||||
| can_added_models=can_added_models, | |||||
| ) | |||||
| def _get_custom_provider_configuration( | |||||
| self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider] | |||||
| ) -> CustomProviderConfiguration | None: | |||||
| """Get custom provider configuration.""" | |||||
| # Find custom provider record (non-system) | |||||
| custom_provider_record = next( | |||||
| (record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None | |||||
| ) | |||||
| if not custom_provider_record: | |||||
| return None | |||||
| # Get provider credential secret variables | # Get provider credential secret variables | ||||
| provider_credential_secret_variables = self._extract_secret_variables( | provider_credential_secret_variables = self._extract_secret_variables( | ||||
| provider_entity.provider_credential_schema.credential_form_schemas | provider_entity.provider_credential_schema.credential_form_schemas | ||||
| else [] | else [] | ||||
| ) | ) | ||||
| # Get custom provider record | |||||
| custom_provider_record = None | |||||
| for provider_record in provider_records: | |||||
| if provider_record.provider_type == ProviderType.SYSTEM.value: | |||||
| continue | |||||
| # Get and decrypt provider credentials | |||||
| provider_credentials = self._get_and_decrypt_credentials( | |||||
| tenant_id=tenant_id, | |||||
| record_id=custom_provider_record.id, | |||||
| encrypted_config=custom_provider_record.encrypted_config, | |||||
| secret_variables=provider_credential_secret_variables, | |||||
| cache_type=ProviderCredentialsCacheType.PROVIDER, | |||||
| is_provider=True, | |||||
| ) | |||||
| custom_provider_record = provider_record | |||||
| return CustomProviderConfiguration( | |||||
| credentials=provider_credentials, | |||||
| current_credential_name=custom_provider_record.credential_name, | |||||
| current_credential_id=custom_provider_record.credential_id, | |||||
| available_credentials=self.get_provider_available_credentials( | |||||
| tenant_id, custom_provider_record.provider_name | |||||
| ), | |||||
| ) | |||||
| # Get custom provider credentials | |||||
| custom_provider_configuration = None | |||||
| if custom_provider_record: | |||||
| provider_credentials_cache = ProviderCredentialsCache( | |||||
| tenant_id=tenant_id, | |||||
| identity_id=custom_provider_record.id, | |||||
| cache_type=ProviderCredentialsCacheType.PROVIDER, | |||||
| ) | |||||
| def _get_can_added_models( | |||||
| self, provider_model_records: list[ProviderModel], all_model_credentials: Sequence[ProviderModelCredential] | |||||
| ) -> list[dict]: | |||||
| """Get the custom models and credentials from enterprise version which haven't add to the model list""" | |||||
| existing_model_set = {(record.model_name, record.model_type) for record in provider_model_records} | |||||
| # Get not added custom models credentials | |||||
| not_added_custom_models_credentials = [ | |||||
| credential | |||||
| for credential in all_model_credentials | |||||
| if (credential.model_name, credential.model_type) not in existing_model_set | |||||
| ] | |||||
| # Get cached provider credentials | |||||
| cached_provider_credentials = provider_credentials_cache.get() | |||||
| if not cached_provider_credentials: | |||||
| try: | |||||
| # fix origin data | |||||
| if custom_provider_record.encrypted_config is None: | |||||
| provider_credentials = {} | |||||
| elif not custom_provider_record.encrypted_config.startswith("{"): | |||||
| provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} | |||||
| else: | |||||
| provider_credentials = json.loads(custom_provider_record.encrypted_config) | |||||
| except JSONDecodeError: | |||||
| provider_credentials = {} | |||||
| # Get decoding rsa key and cipher for decrypting credentials | |||||
| if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: | |||||
| self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) | |||||
| for variable in provider_credential_secret_variables: | |||||
| if variable in provider_credentials: | |||||
| with contextlib.suppress(ValueError): | |||||
| provider_credentials[variable] = encrypter.decrypt_token_with_decoding( | |||||
| provider_credentials.get(variable) or "", # type: ignore | |||||
| self.decoding_rsa_key, | |||||
| self.decoding_cipher_rsa, | |||||
| ) | |||||
| # Group credentials by model | |||||
| model_to_credentials = defaultdict(list) | |||||
| for credential in not_added_custom_models_credentials: | |||||
| model_to_credentials[(credential.model_name, credential.model_type)].append(credential) | |||||
| # cache provider credentials | |||||
| provider_credentials_cache.set(credentials=provider_credentials) | |||||
| else: | |||||
| provider_credentials = cached_provider_credentials | |||||
| custom_provider_configuration = CustomProviderConfiguration( | |||||
| credentials=provider_credentials, | |||||
| current_credential_name=custom_provider_record.credential_name, | |||||
| current_credential_id=custom_provider_record.credential_id, | |||||
| available_credentials=self.get_provider_available_credentials( | |||||
| tenant_id, custom_provider_record.provider_name | |||||
| ), | |||||
| ) | |||||
| return [ | |||||
| { | |||||
| "model": model_key[0], | |||||
| "model_type": ModelType.value_of(model_key[1]), | |||||
| "available_model_credentials": [ | |||||
| CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name) | |||||
| for cred in creds | |||||
| ], | |||||
| } | |||||
| for model_key, creds in model_to_credentials.items() | |||||
| ] | |||||
| # Get provider model credential secret variables | |||||
| def _get_custom_model_configurations( | |||||
| self, | |||||
| tenant_id: str, | |||||
| provider_entity: ProviderEntity, | |||||
| provider_model_records: list[ProviderModel], | |||||
| can_added_models: list[dict], | |||||
| all_model_credentials: Sequence[ProviderModelCredential], | |||||
| ) -> list[CustomModelConfiguration]: | |||||
| """Get custom model configurations.""" | |||||
| # Get model credential secret variables | |||||
| model_credential_secret_variables = self._extract_secret_variables( | model_credential_secret_variables = self._extract_secret_variables( | ||||
| provider_entity.model_credential_schema.credential_form_schemas | provider_entity.model_credential_schema.credential_form_schemas | ||||
| if provider_entity.model_credential_schema | if provider_entity.model_credential_schema | ||||
| else [] | else [] | ||||
| ) | ) | ||||
| # Get custom provider model credentials | |||||
| # Create credentials lookup for efficient access | |||||
| credentials_map = defaultdict(list) | |||||
| for credential in all_model_credentials: | |||||
| credentials_map[(credential.model_name, credential.model_type)].append(credential) | |||||
| custom_model_configurations = [] | custom_model_configurations = [] | ||||
| # Process existing model records | |||||
| for provider_model_record in provider_model_records: | for provider_model_record in provider_model_records: | ||||
| available_model_credentials = self.get_provider_model_available_credentials( | |||||
| tenant_id, | |||||
| provider_model_record.provider_name, | |||||
| provider_model_record.model_name, | |||||
| provider_model_record.model_type, | |||||
| ) | |||||
| # Use pre-fetched credentials instead of individual database calls | |||||
| available_model_credentials = [ | |||||
| CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name) | |||||
| for cred in credentials_map.get( | |||||
| (provider_model_record.model_name, provider_model_record.model_type), [] | |||||
| ) | |||||
| ] | |||||
| provider_model_credentials_cache = ProviderCredentialsCache( | |||||
| tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL | |||||
| # Get and decrypt model credentials | |||||
| provider_model_credentials = self._get_and_decrypt_credentials( | |||||
| tenant_id=tenant_id, | |||||
| record_id=provider_model_record.id, | |||||
| encrypted_config=provider_model_record.encrypted_config, | |||||
| secret_variables=model_credential_secret_variables, | |||||
| cache_type=ProviderCredentialsCacheType.MODEL, | |||||
| is_provider=False, | |||||
| ) | ) | ||||
| # Get cached provider model credentials | |||||
| cached_provider_model_credentials = provider_model_credentials_cache.get() | |||||
| if not cached_provider_model_credentials and provider_model_record.encrypted_config: | |||||
| try: | |||||
| provider_model_credentials = json.loads(provider_model_record.encrypted_config) | |||||
| except JSONDecodeError: | |||||
| continue | |||||
| # Get decoding rsa key and cipher for decrypting credentials | |||||
| if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: | |||||
| self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) | |||||
| for variable in model_credential_secret_variables: | |||||
| if variable in provider_model_credentials: | |||||
| with contextlib.suppress(ValueError): | |||||
| provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( | |||||
| provider_model_credentials.get(variable), | |||||
| self.decoding_rsa_key, | |||||
| self.decoding_cipher_rsa, | |||||
| ) | |||||
| # cache provider model credentials | |||||
| provider_model_credentials_cache.set(credentials=provider_model_credentials) | |||||
| else: | |||||
| provider_model_credentials = cached_provider_model_credentials | |||||
| custom_model_configurations.append( | custom_model_configurations.append( | ||||
| CustomModelConfiguration( | CustomModelConfiguration( | ||||
| model=provider_model_record.model_name, | model=provider_model_record.model_name, | ||||
| ) | ) | ||||
| ) | ) | ||||
| return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations) | |||||
| # Add models that can be added | |||||
| for model in can_added_models: | |||||
| custom_model_configurations.append( | |||||
| CustomModelConfiguration( | |||||
| model=model["model"], | |||||
| model_type=model["model_type"], | |||||
| credentials=None, | |||||
| current_credential_id=None, | |||||
| current_credential_name=None, | |||||
| available_model_credentials=model["available_model_credentials"], | |||||
| unadded_to_model_list=True, | |||||
| ) | |||||
| ) | |||||
| return custom_model_configurations | |||||
| def _get_and_decrypt_credentials( | |||||
| self, | |||||
| tenant_id: str, | |||||
| record_id: str, | |||||
| encrypted_config: str | None, | |||||
| secret_variables: list[str], | |||||
| cache_type: ProviderCredentialsCacheType, | |||||
| is_provider: bool = False, | |||||
| ) -> dict: | |||||
| """Get and decrypt credentials with caching.""" | |||||
| credentials_cache = ProviderCredentialsCache( | |||||
| tenant_id=tenant_id, | |||||
| identity_id=record_id, | |||||
| cache_type=cache_type, | |||||
| ) | |||||
| # Try to get from cache first | |||||
| cached_credentials = credentials_cache.get() | |||||
| if cached_credentials: | |||||
| return cached_credentials | |||||
| # Parse encrypted config | |||||
| if not encrypted_config: | |||||
| return {} | |||||
| if is_provider and not encrypted_config.startswith("{"): | |||||
| return {"openai_api_key": encrypted_config} | |||||
| try: | |||||
| credentials = cast(dict, json.loads(encrypted_config)) | |||||
| except JSONDecodeError: | |||||
| return {} | |||||
| # Decrypt secret variables | |||||
| if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: | |||||
| self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) | |||||
| for variable in secret_variables: | |||||
| if variable in credentials: | |||||
| with contextlib.suppress(ValueError): | |||||
| credentials[variable] = encrypter.decrypt_token_with_decoding( | |||||
| credentials.get(variable) or "", | |||||
| self.decoding_rsa_key, | |||||
| self.decoding_cipher_rsa, | |||||
| ) | |||||
| # Cache the decrypted credentials | |||||
| credentials_cache.set(credentials=credentials) | |||||
| return credentials | |||||
| def _to_system_configuration( | def _to_system_configuration( | ||||
| self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider] | self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider] | ||||
| load_balancing_model_config.model_name == provider_model_setting.model_name | load_balancing_model_config.model_name == provider_model_setting.model_name | ||||
| and load_balancing_model_config.model_type == provider_model_setting.model_type | and load_balancing_model_config.model_type == provider_model_setting.model_type | ||||
| ): | ): | ||||
| if load_balancing_model_config.name == "__delete__": | |||||
| # to calculate current model whether has invalidate lb configs | |||||
| load_balancing_configs.append( | |||||
| ModelLoadBalancingConfiguration( | |||||
| id=load_balancing_model_config.id, | |||||
| name=load_balancing_model_config.name, | |||||
| credentials={}, | |||||
| credential_source_type=load_balancing_model_config.credential_source_type, | |||||
| ) | |||||
| ) | |||||
| continue | |||||
| if not load_balancing_model_config.enabled: | if not load_balancing_model_config.enabled: | ||||
| continue | continue | ||||
| model=provider_model_setting.model_name, | model=provider_model_setting.model_name, | ||||
| model_type=ModelType.value_of(provider_model_setting.model_type), | model_type=ModelType.value_of(provider_model_setting.model_type), | ||||
| enabled=provider_model_setting.enabled, | enabled=provider_model_setting.enabled, | ||||
| load_balancing_enabled=provider_model_setting.load_balancing_enabled, | |||||
| load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [], | load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [], | ||||
| ) | ) | ||||
| ) | ) |
| CustomModelConfiguration, | CustomModelConfiguration, | ||||
| ProviderQuotaType, | ProviderQuotaType, | ||||
| QuotaConfiguration, | QuotaConfiguration, | ||||
| UnaddedModelConfiguration, | |||||
| ) | ) | ||||
| from core.model_runtime.entities.common_entities import I18nObject | from core.model_runtime.entities.common_entities import I18nObject | ||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| current_credential_name: Optional[str] = None | current_credential_name: Optional[str] = None | ||||
| available_credentials: Optional[list[CredentialConfiguration]] = None | available_credentials: Optional[list[CredentialConfiguration]] = None | ||||
| custom_models: Optional[list[CustomModelConfiguration]] = None | custom_models: Optional[list[CustomModelConfiguration]] = None | ||||
| can_added_models: Optional[list[UnaddedModelConfiguration]] = None | |||||
| class SystemConfigurationResponse(BaseModel): | class SystemConfigurationResponse(BaseModel): |
| from json import JSONDecodeError | from json import JSONDecodeError | ||||
| from typing import Optional, Union | from typing import Optional, Union | ||||
| from sqlalchemy import or_ | |||||
| from constants import HIDDEN_VALUE | from constants import HIDDEN_VALUE | ||||
| from core.entities.provider_configuration import ProviderConfiguration | from core.entities.provider_configuration import ProviderConfiguration | ||||
| from core.helper import encrypter | from core.helper import encrypter | ||||
| provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type)) | provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type)) | ||||
| def get_load_balancing_configs( | def get_load_balancing_configs( | ||||
| self, tenant_id: str, provider: str, model: str, model_type: str | |||||
| self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = "" | |||||
| ) -> tuple[bool, list[dict]]: | ) -> tuple[bool, list[dict]]: | ||||
| """ | """ | ||||
| Get load balancing configurations. | Get load balancing configurations. | ||||
| if provider_model_setting and provider_model_setting.load_balancing_enabled: | if provider_model_setting and provider_model_setting.load_balancing_enabled: | ||||
| is_load_balancing_enabled = True | is_load_balancing_enabled = True | ||||
| if config_from == "predefined-model": | |||||
| credential_source_type = "provider" | |||||
| else: | |||||
| credential_source_type = "custom_model" | |||||
| # Get load balancing configurations | # Get load balancing configurations | ||||
| load_balancing_configs = ( | load_balancing_configs = ( | ||||
| db.session.query(LoadBalancingModelConfig) | db.session.query(LoadBalancingModelConfig) | ||||
| LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, | LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, | ||||
| LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), | LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), | ||||
| LoadBalancingModelConfig.model_name == model, | LoadBalancingModelConfig.model_name == model, | ||||
| or_( | |||||
| LoadBalancingModelConfig.credential_source_type == credential_source_type, | |||||
| LoadBalancingModelConfig.credential_source_type.is_(None), | |||||
| ), | |||||
| ) | ) | ||||
| .order_by(LoadBalancingModelConfig.created_at) | .order_by(LoadBalancingModelConfig.created_at) | ||||
| .all() | .all() | ||||
| self._clear_credentials_cache(tenant_id, config_id) | self._clear_credentials_cache(tenant_id, config_id) | ||||
| else: | else: | ||||
| # create load balancing config | # create load balancing config | ||||
| if name in {"__inherit__", "__delete__"}: | |||||
| if name == "__inherit__": | |||||
| raise ValueError("Invalid load balancing config name") | raise ValueError("Invalid load balancing config name") | ||||
| if credential_id: | if credential_id: |
| provider_config = provider_configuration.custom_configuration.provider | provider_config = provider_configuration.custom_configuration.provider | ||||
| model_config = provider_configuration.custom_configuration.models | model_config = provider_configuration.custom_configuration.models | ||||
| can_added_models = provider_configuration.custom_configuration.can_added_models | |||||
| provider_response = ProviderResponse( | provider_response = ProviderResponse( | ||||
| tenant_id=tenant_id, | tenant_id=tenant_id, | ||||
| current_credential_name=getattr(provider_config, "current_credential_name", None), | current_credential_name=getattr(provider_config, "current_credential_name", None), | ||||
| available_credentials=getattr(provider_config, "available_credentials", []), | available_credentials=getattr(provider_config, "available_credentials", []), | ||||
| custom_models=model_config, | custom_models=model_config, | ||||
| can_added_models=can_added_models, | |||||
| ), | ), | ||||
| system_configuration=SystemConfigurationResponse( | system_configuration=SystemConfigurationResponse( | ||||
| enabled=provider_configuration.system_configuration.enabled, | enabled=provider_configuration.system_configuration.enabled, | ||||
| provider_configuration.validate_provider_credentials(credentials) | provider_configuration.validate_provider_credentials(credentials) | ||||
| def create_provider_credential( | def create_provider_credential( | ||||
| self, tenant_id: str, provider: str, credentials: dict, credential_name: str | |||||
| self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None | |||||
| ) -> None: | ) -> None: | ||||
| """ | """ | ||||
| Create and save new provider credentials. | Create and save new provider credentials. | ||||
| provider: str, | provider: str, | ||||
| credentials: dict, | credentials: dict, | ||||
| credential_id: str, | credential_id: str, | ||||
| credential_name: str, | |||||
| credential_name: str | None, | |||||
| ) -> None: | ) -> None: | ||||
| """ | """ | ||||
| update a saved provider credential (by credential_id). | update a saved provider credential (by credential_id). | ||||
| ) | ) | ||||
| def create_model_credential( | def create_model_credential( | ||||
| self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | |||||
| self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None | |||||
| ) -> None: | ) -> None: | ||||
| """ | """ | ||||
| create and save model credentials. | create and save model credentials. | ||||
| model: str, | model: str, | ||||
| credentials: dict, | credentials: dict, | ||||
| credential_id: str, | credential_id: str, | ||||
| credential_name: str, | |||||
| credential_name: str | None, | |||||
| ) -> None: | ) -> None: | ||||
| """ | """ | ||||
| update model credentials. | update model credentials. |