|
|
|
@@ -10,6 +10,7 @@ from core.entities.provider_configuration import ProviderConfigurations, Provide |
|
|
|
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, CustomModelConfiguration, \ |
|
|
|
SystemConfiguration, QuotaConfiguration |
|
|
|
from core.helper import encrypter |
|
|
|
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType |
|
|
|
from core.model_runtime.entities.model_entities import ModelType |
|
|
|
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType |
|
|
|
from core.model_runtime.model_providers import model_provider_factory |
|
|
|
@@ -79,9 +80,6 @@ class ProviderManager: |
|
|
|
# Get All preferred provider types of the workspace |
|
|
|
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id) |
|
|
|
|
|
|
|
# Get decoding rsa key and cipher for decrypting credentials |
|
|
|
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) |
|
|
|
|
|
|
|
provider_configurations = ProviderConfigurations( |
|
|
|
tenant_id=tenant_id |
|
|
|
) |
|
|
|
@@ -100,19 +98,17 @@ class ProviderManager: |
|
|
|
|
|
|
|
# Convert to custom configuration |
|
|
|
custom_configuration = self._to_custom_configuration( |
|
|
|
tenant_id, |
|
|
|
provider_entity, |
|
|
|
provider_records, |
|
|
|
provider_model_records, |
|
|
|
decoding_rsa_key, |
|
|
|
decoding_cipher_rsa |
|
|
|
provider_model_records |
|
|
|
) |
|
|
|
|
|
|
|
# Convert to system configuration |
|
|
|
system_configuration = self._to_system_configuration( |
|
|
|
tenant_id, |
|
|
|
provider_entity, |
|
|
|
provider_records, |
|
|
|
decoding_rsa_key, |
|
|
|
decoding_cipher_rsa |
|
|
|
provider_records |
|
|
|
) |
|
|
|
|
|
|
|
# Get preferred provider type |
|
|
|
@@ -413,19 +409,17 @@ class ProviderManager: |
|
|
|
return provider_name_to_provider_records_dict |
|
|
|
|
|
|
|
def _to_custom_configuration(self, |
|
|
|
tenant_id: str, |
|
|
|
provider_entity: ProviderEntity, |
|
|
|
provider_records: list[Provider], |
|
|
|
provider_model_records: list[ProviderModel], |
|
|
|
decoding_rsa_key, |
|
|
|
decoding_cipher_rsa) -> CustomConfiguration: |
|
|
|
provider_model_records: list[ProviderModel]) -> CustomConfiguration: |
|
|
|
""" |
|
|
|
Convert to custom configuration. |
|
|
|
|
|
|
|
:param tenant_id: workspace id |
|
|
|
:param provider_entity: provider entity |
|
|
|
:param provider_records: provider records |
|
|
|
:param provider_model_records: provider model records |
|
|
|
:param decoding_rsa_key: decoding rsa key |
|
|
|
:param decoding_cipher_rsa: decoding cipher rsa |
|
|
|
:return: |
|
|
|
""" |
|
|
|
# Get provider credential secret variables |
|
|
|
@@ -448,28 +442,48 @@ class ProviderManager: |
|
|
|
# Get custom provider credentials |
|
|
|
custom_provider_configuration = None |
|
|
|
if custom_provider_record: |
|
|
|
try: |
|
|
|
# fix origin data |
|
|
|
if (custom_provider_record.encrypted_config |
|
|
|
and not custom_provider_record.encrypted_config.startswith("{")): |
|
|
|
provider_credentials = { |
|
|
|
"openai_api_key": custom_provider_record.encrypted_config |
|
|
|
} |
|
|
|
else: |
|
|
|
provider_credentials = json.loads(custom_provider_record.encrypted_config) |
|
|
|
except JSONDecodeError: |
|
|
|
provider_credentials = {} |
|
|
|
provider_credentials_cache = ProviderCredentialsCache( |
|
|
|
tenant_id=tenant_id, |
|
|
|
identity_id=custom_provider_record.id, |
|
|
|
cache_type=ProviderCredentialsCacheType.PROVIDER |
|
|
|
) |
|
|
|
|
|
|
|
for variable in provider_credential_secret_variables: |
|
|
|
if variable in provider_credentials: |
|
|
|
try: |
|
|
|
provider_credentials[variable] = encrypter.decrypt_token_with_decoding( |
|
|
|
provider_credentials.get(variable), |
|
|
|
decoding_rsa_key, |
|
|
|
decoding_cipher_rsa |
|
|
|
) |
|
|
|
except ValueError: |
|
|
|
pass |
|
|
|
# Get cached provider credentials |
|
|
|
cached_provider_credentials = provider_credentials_cache.get() |
|
|
|
|
|
|
|
if not cached_provider_credentials: |
|
|
|
try: |
|
|
|
# fix origin data |
|
|
|
if (custom_provider_record.encrypted_config |
|
|
|
and not custom_provider_record.encrypted_config.startswith("{")): |
|
|
|
provider_credentials = { |
|
|
|
"openai_api_key": custom_provider_record.encrypted_config |
|
|
|
} |
|
|
|
else: |
|
|
|
provider_credentials = json.loads(custom_provider_record.encrypted_config) |
|
|
|
except JSONDecodeError: |
|
|
|
provider_credentials = {} |
|
|
|
|
|
|
|
# Get decoding rsa key and cipher for decrypting credentials |
|
|
|
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) |
|
|
|
|
|
|
|
for variable in provider_credential_secret_variables: |
|
|
|
if variable in provider_credentials: |
|
|
|
try: |
|
|
|
provider_credentials[variable] = encrypter.decrypt_token_with_decoding( |
|
|
|
provider_credentials.get(variable), |
|
|
|
decoding_rsa_key, |
|
|
|
decoding_cipher_rsa |
|
|
|
) |
|
|
|
except ValueError: |
|
|
|
pass |
|
|
|
|
|
|
|
# cache provider credentials |
|
|
|
provider_credentials_cache.set( |
|
|
|
credentials=provider_credentials |
|
|
|
) |
|
|
|
else: |
|
|
|
provider_credentials = cached_provider_credentials |
|
|
|
|
|
|
|
custom_provider_configuration = CustomProviderConfiguration( |
|
|
|
credentials=provider_credentials |
|
|
|
@@ -487,21 +501,41 @@ class ProviderManager: |
|
|
|
if not provider_model_record.encrypted_config: |
|
|
|
continue |
|
|
|
|
|
|
|
try: |
|
|
|
provider_model_credentials = json.loads(provider_model_record.encrypted_config) |
|
|
|
except JSONDecodeError: |
|
|
|
continue |
|
|
|
provider_model_credentials_cache = ProviderCredentialsCache( |
|
|
|
tenant_id=tenant_id, |
|
|
|
identity_id=provider_model_record.id, |
|
|
|
cache_type=ProviderCredentialsCacheType.MODEL |
|
|
|
) |
|
|
|
|
|
|
|
for variable in model_credential_secret_variables: |
|
|
|
if variable in provider_model_credentials: |
|
|
|
try: |
|
|
|
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( |
|
|
|
provider_model_credentials.get(variable), |
|
|
|
decoding_rsa_key, |
|
|
|
decoding_cipher_rsa |
|
|
|
) |
|
|
|
except ValueError: |
|
|
|
pass |
|
|
|
# Get cached provider model credentials |
|
|
|
cached_provider_model_credentials = provider_model_credentials_cache.get() |
|
|
|
|
|
|
|
if not cached_provider_model_credentials: |
|
|
|
try: |
|
|
|
provider_model_credentials = json.loads(provider_model_record.encrypted_config) |
|
|
|
except JSONDecodeError: |
|
|
|
continue |
|
|
|
|
|
|
|
# Get decoding rsa key and cipher for decrypting credentials |
|
|
|
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) |
|
|
|
|
|
|
|
for variable in model_credential_secret_variables: |
|
|
|
if variable in provider_model_credentials: |
|
|
|
try: |
|
|
|
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( |
|
|
|
provider_model_credentials.get(variable), |
|
|
|
decoding_rsa_key, |
|
|
|
decoding_cipher_rsa |
|
|
|
) |
|
|
|
except ValueError: |
|
|
|
pass |
|
|
|
|
|
|
|
# cache provider model credentials |
|
|
|
provider_model_credentials_cache.set( |
|
|
|
credentials=provider_model_credentials |
|
|
|
) |
|
|
|
else: |
|
|
|
provider_model_credentials = cached_provider_model_credentials |
|
|
|
|
|
|
|
custom_model_configurations.append( |
|
|
|
CustomModelConfiguration( |
|
|
|
@@ -517,17 +551,15 @@ class ProviderManager: |
|
|
|
) |
|
|
|
|
|
|
|
def _to_system_configuration(self, |
|
|
|
tenant_id: str, |
|
|
|
provider_entity: ProviderEntity, |
|
|
|
provider_records: list[Provider], |
|
|
|
decoding_rsa_key, |
|
|
|
decoding_cipher_rsa) -> SystemConfiguration: |
|
|
|
provider_records: list[Provider]) -> SystemConfiguration: |
|
|
|
""" |
|
|
|
Convert to system configuration. |
|
|
|
|
|
|
|
:param tenant_id: workspace id |
|
|
|
:param provider_entity: provider entity |
|
|
|
:param provider_records: provider records |
|
|
|
:param decoding_rsa_key: decoding rsa key |
|
|
|
:param decoding_cipher_rsa: decoding cipher rsa |
|
|
|
:return: |
|
|
|
""" |
|
|
|
# Get hosting configuration |
|
|
|
@@ -580,29 +612,49 @@ class ProviderManager: |
|
|
|
provider_record = quota_type_to_provider_records_dict.get(current_quota_type) |
|
|
|
|
|
|
|
if provider_record: |
|
|
|
try: |
|
|
|
provider_credentials = json.loads(provider_record.encrypted_config) |
|
|
|
except JSONDecodeError: |
|
|
|
provider_credentials = {} |
|
|
|
|
|
|
|
# Get provider credential secret variables |
|
|
|
provider_credential_secret_variables = self._extract_secret_variables( |
|
|
|
provider_entity.provider_credential_schema.credential_form_schemas |
|
|
|
if provider_entity.provider_credential_schema else [] |
|
|
|
provider_credentials_cache = ProviderCredentialsCache( |
|
|
|
tenant_id=tenant_id, |
|
|
|
identity_id=provider_record.id, |
|
|
|
cache_type=ProviderCredentialsCacheType.PROVIDER |
|
|
|
) |
|
|
|
|
|
|
|
for variable in provider_credential_secret_variables: |
|
|
|
if variable in provider_credentials: |
|
|
|
try: |
|
|
|
provider_credentials[variable] = encrypter.decrypt_token_with_decoding( |
|
|
|
provider_credentials.get(variable), |
|
|
|
decoding_rsa_key, |
|
|
|
decoding_cipher_rsa |
|
|
|
) |
|
|
|
except ValueError: |
|
|
|
pass |
|
|
|
# Get cached provider credentials |
|
|
|
cached_provider_credentials = provider_credentials_cache.get() |
|
|
|
|
|
|
|
current_using_credentials = provider_credentials |
|
|
|
if not cached_provider_credentials: |
|
|
|
try: |
|
|
|
provider_credentials = json.loads(provider_record.encrypted_config) |
|
|
|
except JSONDecodeError: |
|
|
|
provider_credentials = {} |
|
|
|
|
|
|
|
# Get provider credential secret variables |
|
|
|
provider_credential_secret_variables = self._extract_secret_variables( |
|
|
|
provider_entity.provider_credential_schema.credential_form_schemas |
|
|
|
if provider_entity.provider_credential_schema else [] |
|
|
|
) |
|
|
|
|
|
|
|
# Get decoding rsa key and cipher for decrypting credentials |
|
|
|
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) |
|
|
|
|
|
|
|
for variable in provider_credential_secret_variables: |
|
|
|
if variable in provider_credentials: |
|
|
|
try: |
|
|
|
provider_credentials[variable] = encrypter.decrypt_token_with_decoding( |
|
|
|
provider_credentials.get(variable), |
|
|
|
decoding_rsa_key, |
|
|
|
decoding_cipher_rsa |
|
|
|
) |
|
|
|
except ValueError: |
|
|
|
pass |
|
|
|
|
|
|
|
current_using_credentials = provider_credentials |
|
|
|
|
|
|
|
# cache provider credentials |
|
|
|
provider_credentials_cache.set( |
|
|
|
credentials=current_using_credentials |
|
|
|
) |
|
|
|
else: |
|
|
|
current_using_credentials = cached_provider_credentials |
|
|
|
else: |
|
|
|
current_using_credentials = {} |
|
|
|
|