| 
                        123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596 | 
                        - import datetime
 - import json
 - import logging
 - import os
 - from collections import defaultdict
 - from typing import Optional
 - 
 - import requests
 - 
 - from core.model_providers.model_factory import ModelFactory
 - from extensions.ext_database import db
 - from core.model_providers.model_provider_factory import ModelProviderFactory
 - from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
 - from models.provider import Provider, ProviderModel, TenantPreferredModelProvider, ProviderType, ProviderQuotaType, \
 -     TenantDefaultModel
 - 
 - 
 - class ProviderService:
 - 
 -     def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list:
 -         """
 -         get provider list of tenant.
 - 
 -         :param tenant_id: workspace id
 -         :param model_type: filter by model type
 -         :return:
 -         """
 -         # get rules for all providers
 -         model_provider_rules = ModelProviderFactory.get_provider_rules()
 -         model_provider_names = [model_provider_name for model_provider_name, _ in model_provider_rules.items()]
 - 
 -         for model_provider_name, model_provider_rule in model_provider_rules.items():
 -             if ProviderType.SYSTEM.value in model_provider_rule['support_provider_types'] \
 -                     and 'system_config' in model_provider_rule and model_provider_rule['system_config'] \
 -                     and 'supported_quota_types' in model_provider_rule['system_config'] \
 -                     and 'trial' in model_provider_rule['system_config']['supported_quota_types']:
 -                 ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
 - 
 -         configurable_model_provider_names = [
 -             model_provider_name
 -             for model_provider_name, model_provider_rules in model_provider_rules.items()
 -             if 'custom' in model_provider_rules['support_provider_types']
 -                and model_provider_rules['model_flexibility'] == 'configurable'
 -         ]
 - 
 -         # get all providers for the tenant
 -         providers = db.session.query(Provider) \
 -             .filter(
 -             Provider.tenant_id == tenant_id,
 -             Provider.provider_name.in_(model_provider_names),
 -             Provider.is_valid == True
 -         ).order_by(Provider.created_at.desc()).all()
 - 
 -         provider_name_to_provider_dict = defaultdict(list)
 -         for provider in providers:
 -             provider_name_to_provider_dict[provider.provider_name].append(provider)
 - 
 -         # get all configurable provider models for the tenant
 -         provider_models = db.session.query(ProviderModel) \
 -             .filter(
 -             ProviderModel.tenant_id == tenant_id,
 -             ProviderModel.provider_name.in_(configurable_model_provider_names),
 -             ProviderModel.is_valid == True
 -         ).order_by(ProviderModel.created_at.desc()).all()
 - 
 -         provider_name_to_provider_model_dict = defaultdict(list)
 -         for provider_model in provider_models:
 -             provider_name_to_provider_model_dict[provider_model.provider_name].append(provider_model)
 - 
 -         # get all preferred provider type for the tenant
 -         preferred_provider_types = db.session.query(TenantPreferredModelProvider) \
 -             .filter(
 -             TenantPreferredModelProvider.tenant_id == tenant_id,
 -             TenantPreferredModelProvider.provider_name.in_(model_provider_names)
 -         ).all()
 - 
 -         provider_name_to_preferred_provider_type_dict = {preferred_provider_type.provider_name: preferred_provider_type
 -                                                          for preferred_provider_type in preferred_provider_types}
 - 
 -         providers_list = {}
 - 
 -         for model_provider_name, model_provider_rule in model_provider_rules.items():
 -             if model_type and model_type not in model_provider_rule.get('supported_model_types', []):
 -                 continue
 - 
 -             # get preferred provider type
 -             preferred_model_provider = provider_name_to_preferred_provider_type_dict.get(model_provider_name)
 -             preferred_provider_type = ModelProviderFactory.get_preferred_type_by_preferred_model_provider(
 -                 tenant_id,
 -                 model_provider_name,
 -                 preferred_model_provider
 -             )
 - 
 -             provider_config_dict = {
 -                 "preferred_provider_type": preferred_provider_type,
 -                 "model_flexibility": model_provider_rule['model_flexibility'],
 -                 "supported_model_types": model_provider_rule.get("supported_model_types", []),
 -             }
 - 
 -             provider_parameter_dict = {}
 -             if ProviderType.SYSTEM.value in model_provider_rule['support_provider_types']:
 -                 for quota_type_enum in ProviderQuotaType:
 -                     quota_type = quota_type_enum.value
 -                     if quota_type in model_provider_rule['system_config']['supported_quota_types']:
 -                         key = ProviderType.SYSTEM.value + ':' + quota_type
 -                         provider_parameter_dict[key] = {
 -                             "provider_name": model_provider_name,
 -                             "provider_type": ProviderType.SYSTEM.value,
 -                             "config": None,
 -                             "is_valid": False,  # need update
 -                             "quota_type": quota_type,
 -                             "quota_unit": model_provider_rule['system_config']['quota_unit'],  # need update
 -                             "quota_limit": 0 if quota_type != ProviderQuotaType.TRIAL.value else
 -                             model_provider_rule['system_config']['quota_limit'],  # need update
 -                             "quota_used": 0,  # need update
 -                             "last_used": None  # need update
 -                         }
 - 
 -             if ProviderType.CUSTOM.value in model_provider_rule['support_provider_types']:
 -                 provider_parameter_dict[ProviderType.CUSTOM.value] = {
 -                     "provider_name": model_provider_name,
 -                     "provider_type": ProviderType.CUSTOM.value,
 -                     "config": None,  # need update
 -                     "models": [],  # need update
 -                     "is_valid": False,
 -                     "last_used": None  # need update
 -                 }
 - 
 -             model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)
 - 
 -             current_providers = provider_name_to_provider_dict[model_provider_name]
 -             for provider in current_providers:
 -                 if provider.provider_type == ProviderType.SYSTEM.value:
 -                     quota_type = provider.quota_type
 -                     key = f'{ProviderType.SYSTEM.value}:{quota_type}'
 - 
 -                     if key in provider_parameter_dict:
 -                         provider_parameter_dict[key]['is_valid'] = provider.is_valid
 -                         provider_parameter_dict[key]['quota_used'] = provider.quota_used
 -                         provider_parameter_dict[key]['quota_limit'] = provider.quota_limit
 -                         provider_parameter_dict[key]['last_used'] = int(provider.last_used.timestamp()) \
 -                             if provider.last_used else None
 -                 elif provider.provider_type == ProviderType.CUSTOM.value \
 -                         and ProviderType.CUSTOM.value in provider_parameter_dict:
 -                     # if custom
 -                     key = ProviderType.CUSTOM.value
 -                     provider_parameter_dict[key]['last_used'] = int(provider.last_used.timestamp()) \
 -                             if provider.last_used else None
 -                     provider_parameter_dict[key]['is_valid'] = provider.is_valid
 - 
 -                     if model_provider_rule['model_flexibility'] == 'fixed':
 -                         provider_parameter_dict[key]['config'] = model_provider_class(provider=provider) \
 -                             .get_provider_credentials(obfuscated=True)
 -                     else:
 -                         models = []
 -                         provider_models = provider_name_to_provider_model_dict[model_provider_name]
 -                         for provider_model in provider_models:
 -                             models.append({
 -                                 "model_name": provider_model.model_name,
 -                                 "model_type": provider_model.model_type,
 -                                 "config": model_provider_class(provider=provider) \
 -                                     .get_model_credentials(provider_model.model_name,
 -                                                            ModelType.value_of(provider_model.model_type),
 -                                                            obfuscated=True),
 -                                 "is_valid": provider_model.is_valid
 -                             })
 -                         provider_parameter_dict[key]['models'] = models
 - 
 -             provider_config_dict['providers'] = list(provider_parameter_dict.values())
 -             providers_list[model_provider_name] = provider_config_dict
 - 
 -         return providers_list
 - 
 -     def custom_provider_config_validate(self, provider_name: str, config: dict) -> None:
 -         """
 -         validate custom provider config.
 - 
 -         :param provider_name:
 -         :param config:
 -         :return:
 -         :raises CredentialsValidateFailedError: When the config credential verification fails.
 -         """
 -         # get model provider rules
 -         model_provider_rules = ModelProviderFactory.get_provider_rule(provider_name)
 - 
 -         if model_provider_rules['model_flexibility'] != 'fixed':
 -             raise ValueError('Only support fixed model provider')
 - 
 -         # only support provider type CUSTOM
 -         if ProviderType.CUSTOM.value not in model_provider_rules['support_provider_types']:
 -             raise ValueError('Only support provider type CUSTOM')
 - 
 -         # validate provider config
 -         model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
 -         model_provider_class.is_provider_credentials_valid_or_raise(config)
 - 
 -     def save_custom_provider_config(self, tenant_id: str, provider_name: str, config: dict) -> None:
 -         """
 -         save custom provider config.
 - 
 -         :param tenant_id:
 -         :param provider_name:
 -         :param config:
 -         :return:
 -         """
 -         # validate custom provider config
 -         self.custom_provider_config_validate(provider_name, config)
 - 
 -         # get provider
 -         provider = db.session.query(Provider) \
 -             .filter(
 -             Provider.tenant_id == tenant_id,
 -             Provider.provider_name == provider_name,
 -             Provider.provider_type == ProviderType.CUSTOM.value
 -         ).first()
 - 
 -         model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
 -         encrypted_config = model_provider_class.encrypt_provider_credentials(tenant_id, config)
 - 
 -         # save provider
 -         if provider:
 -             provider.encrypted_config = json.dumps(encrypted_config)
 -             provider.is_valid = True
 -             provider.updated_at = datetime.datetime.utcnow()
 -             db.session.commit()
 -         else:
 -             provider = Provider(
 -                 tenant_id=tenant_id,
 -                 provider_name=provider_name,
 -                 provider_type=ProviderType.CUSTOM.value,
 -                 encrypted_config=json.dumps(encrypted_config),
 -                 is_valid=True
 -             )
 -             db.session.add(provider)
 -             db.session.commit()
 - 
 -     def delete_custom_provider(self, tenant_id: str, provider_name: str) -> None:
 -         """
 -         delete custom provider.
 - 
 -         :param tenant_id:
 -         :param provider_name:
 -         :return:
 -         """
 -         # get provider
 -         provider = db.session.query(Provider) \
 -             .filter(
 -             Provider.tenant_id == tenant_id,
 -             Provider.provider_name == provider_name,
 -             Provider.provider_type == ProviderType.CUSTOM.value
 -         ).first()
 - 
 -         if provider:
 -             try:
 -                 self.switch_preferred_provider(tenant_id, provider_name, ProviderType.SYSTEM.value)
 -             except ValueError:
 -                 pass
 - 
 -             db.session.delete(provider)
 -             db.session.commit()
 - 
 -     def custom_provider_model_config_validate(self,
 -                                               provider_name: str,
 -                                               model_name: str,
 -                                               model_type: str,
 -                                               config: dict) -> None:
 -         """
 -         validate custom provider model config.
 - 
 -         :param provider_name:
 -         :param model_name:
 -         :param model_type:
 -         :param config:
 -         :return:
 -         :raises CredentialsValidateFailedError: When the config credential verification fails.
 -         """
 -         # get model provider rules
 -         model_provider_rules = ModelProviderFactory.get_provider_rule(provider_name)
 - 
 -         if model_provider_rules['model_flexibility'] != 'configurable':
 -             raise ValueError('Only support configurable model provider')
 - 
 -         # only support provider type CUSTOM
 -         if ProviderType.CUSTOM.value not in model_provider_rules['support_provider_types']:
 -             raise ValueError('Only support provider type CUSTOM')
 - 
 -         # validate provider model config
 -         model_type = ModelType.value_of(model_type)
 -         model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
 -         model_provider_class.is_model_credentials_valid_or_raise(model_name, model_type, config)
 - 
 -     def add_or_save_custom_provider_model_config(self,
 -                                                  tenant_id: str,
 -                                                  provider_name: str,
 -                                                  model_name: str,
 -                                                  model_type: str,
 -                                                  config: dict) -> None:
 -         """
 -         Add or save custom provider model config.
 - 
 -         :param tenant_id:
 -         :param provider_name:
 -         :param model_name:
 -         :param model_type:
 -         :param config:
 -         :return:
 -         """
 -         # validate custom provider model config
 -         self.custom_provider_model_config_validate(provider_name, model_name, model_type, config)
 - 
 -         # get provider
 -         provider = db.session.query(Provider) \
 -             .filter(
 -             Provider.tenant_id == tenant_id,
 -             Provider.provider_name == provider_name,
 -             Provider.provider_type == ProviderType.CUSTOM.value
 -         ).first()
 - 
 -         if not provider:
 -             provider = Provider(
 -                 tenant_id=tenant_id,
 -                 provider_name=provider_name,
 -                 provider_type=ProviderType.CUSTOM.value,
 -                 is_valid=True
 -             )
 -             db.session.add(provider)
 -             db.session.commit()
 -         elif not provider.is_valid:
 -             provider.is_valid = True
 -             provider.encrypted_config = None
 -             db.session.commit()
 - 
 -         model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
 -         encrypted_config = model_provider_class.encrypt_model_credentials(
 -             tenant_id,
 -             model_name,
 -             ModelType.value_of(model_type),
 -             config
 -         )
 - 
 -         # get provider model
 -         provider_model = db.session.query(ProviderModel) \
 -             .filter(
 -             ProviderModel.tenant_id == tenant_id,
 -             ProviderModel.provider_name == provider_name,
 -             ProviderModel.model_name == model_name,
 -             ProviderModel.model_type == model_type
 -         ).first()
 - 
 -         if provider_model:
 -             provider_model.encrypted_config = json.dumps(encrypted_config)
 -             provider_model.is_valid = True
 -             db.session.commit()
 -         else:
 -             provider_model = ProviderModel(
 -                 tenant_id=tenant_id,
 -                 provider_name=provider_name,
 -                 model_name=model_name,
 -                 model_type=model_type,
 -                 encrypted_config=json.dumps(encrypted_config),
 -                 is_valid=True
 -             )
 -             db.session.add(provider_model)
 -             db.session.commit()
 - 
 -     def delete_custom_provider_model(self,
 -                                      tenant_id: str,
 -                                      provider_name: str,
 -                                      model_name: str,
 -                                      model_type: str) -> None:
 -         """
 -         delete custom provider model.
 - 
 -         :param tenant_id:
 -         :param provider_name:
 -         :param model_name:
 -         :param model_type:
 -         :return:
 -         """
 -         # get provider model
 -         provider_model = db.session.query(ProviderModel) \
 -             .filter(
 -             ProviderModel.tenant_id == tenant_id,
 -             ProviderModel.provider_name == provider_name,
 -             ProviderModel.model_name == model_name,
 -             ProviderModel.model_type == model_type
 -         ).first()
 - 
 -         if provider_model:
 -             db.session.delete(provider_model)
 -             db.session.commit()
 - 
 -     def switch_preferred_provider(self, tenant_id: str, provider_name: str, preferred_provider_type: str) -> None:
 -         """
 -         switch preferred provider.
 - 
 -         :param tenant_id:
 -         :param provider_name:
 -         :param preferred_provider_type:
 -         :return:
 -         """
 -         provider_type = ProviderType.value_of(preferred_provider_type)
 -         if not provider_type:
 -             raise ValueError(f'Invalid preferred provider type: {preferred_provider_type}')
 - 
 -         model_provider_rules = ModelProviderFactory.get_provider_rule(provider_name)
 -         if preferred_provider_type not in model_provider_rules['support_provider_types']:
 -             raise ValueError(f'Not support provider type: {preferred_provider_type}')
 - 
 -         model_provider = ModelProviderFactory.get_model_provider_class(provider_name)
 -         if not model_provider.is_provider_type_system_supported():
 -             return
 - 
 -         # get preferred provider
 -         preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
 -             .filter(
 -             TenantPreferredModelProvider.tenant_id == tenant_id,
 -             TenantPreferredModelProvider.provider_name == provider_name
 -         ).first()
 - 
 -         if preferred_model_provider:
 -             preferred_model_provider.preferred_provider_type = preferred_provider_type
 -         else:
 -             preferred_model_provider = TenantPreferredModelProvider(
 -                 tenant_id=tenant_id,
 -                 provider_name=provider_name,
 -                 preferred_provider_type=preferred_provider_type
 -             )
 -             db.session.add(preferred_model_provider)
 - 
 -         db.session.commit()
 - 
 -     def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[TenantDefaultModel]:
 -         """
 -         get default model of model type.
 - 
 -         :param tenant_id:
 -         :param model_type:
 -         :return:
 -         """
 -         return ModelFactory.get_default_model(tenant_id, ModelType.value_of(model_type))
 - 
 -     def update_default_model_of_model_type(self,
 -                                            tenant_id: str,
 -                                            model_type: str,
 -                                            provider_name: str,
 -                                            model_name: str) -> TenantDefaultModel:
 -         """
 -         update default model of model type.
 - 
 -         :param tenant_id:
 -         :param model_type:
 -         :param provider_name:
 -         :param model_name:
 -         :return:
 -         """
 -         return ModelFactory.update_default_model(tenant_id, ModelType.value_of(model_type), provider_name, model_name)
 - 
 -     def get_valid_model_list(self, tenant_id: str, model_type: str) -> list:
 -         """
 -         get valid model list.
 - 
 -         :param tenant_id:
 -         :param model_type:
 -         :return:
 -         """
 -         valid_model_list = []
 - 
 -         # get model provider rules
 -         model_provider_rules = ModelProviderFactory.get_provider_rules()
 -         for model_provider_name, model_provider_rule in model_provider_rules.items():
 -             model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
 -             if not model_provider:
 -                 continue
 - 
 -             model_list = model_provider.get_supported_model_list(ModelType.value_of(model_type))
 -             provider = model_provider.provider
 -             for model in model_list:
 -                 valid_model_dict = {
 -                     "model_name": model['id'],
 -                     "model_display_name": model['name'],
 -                     "model_type": model_type,
 -                     "model_provider": {
 -                         "provider_name": provider.provider_name,
 -                         "provider_type": provider.provider_type
 -                     },
 -                     'features': []
 -                 }
 - 
 -                 if 'mode' in model:
 -                     valid_model_dict['model_mode'] = model['mode']
 - 
 -                 if 'features' in model:
 -                     valid_model_dict['features'] = model['features']
 - 
 -                 if provider.provider_type == ProviderType.SYSTEM.value:
 -                     valid_model_dict['model_provider']['quota_type'] = provider.quota_type
 -                     valid_model_dict['model_provider']['quota_unit'] = model_provider_rule['system_config']['quota_unit']
 -                     valid_model_dict['model_provider']['quota_limit'] = provider.quota_limit
 -                     valid_model_dict['model_provider']['quota_used'] = provider.quota_used
 - 
 -                 valid_model_list.append(valid_model_dict)
 - 
 -         return valid_model_list
 - 
 -     def get_model_parameter_rules(self, tenant_id: str, model_provider_name: str, model_name: str, model_type: str) \
 -             -> ModelKwargsRules:
 -         """
 -         get model parameter rules.
 -         It depends on preferred provider in use.
 - 
 -         :param tenant_id:
 -         :param model_provider_name:
 -         :param model_name:
 -         :param model_type:
 -         :return:
 -         """
 -         # get model provider
 -         model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
 -         if not model_provider:
 -             # get empty model provider
 -             return ModelKwargsRules()
 - 
 -         # get model parameter rules
 -         return model_provider.get_model_parameter_rules(model_name, ModelType.value_of(model_type))
 - 
 -     def free_quota_submit(self, tenant_id: str, provider_name: str):
 -         api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
 -         api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
 -         api_url = api_base_url + '/api/v1/providers/apply'
 - 
 -         headers = {
 -             'Content-Type': 'application/json',
 -             'Authorization': f"Bearer {api_key}"
 -         }
 -         response = requests.post(api_url, headers=headers, json={'workspace_id': tenant_id, 'provider_name': provider_name})
 -         if not response.ok:
 -             logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
 -             raise ValueError(f"Error: {response.status_code} ")
 - 
 -         if response.json()["code"] != 'success':
 -             raise ValueError(
 -                 f"error: {response.json()['message']}"
 -             )
 - 
 -         rst = response.json()
 - 
 -         if rst['type'] == 'redirect':
 -             return {
 -                 'type': rst['type'],
 -                 'redirect_url': rst['redirect_url']
 -             }
 -         else:
 -             return {
 -                 'type': rst['type'],
 -                 'result': 'success'
 -             }
 - 
 -     def free_quota_qualification_verify(self, tenant_id: str, provider_name: str, token: Optional[str]):
 -         api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
 -         api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
 -         api_url = api_base_url + '/api/v1/providers/qualification-verify'
 - 
 -         headers = {
 -             'Content-Type': 'application/json',
 -             'Authorization': f"Bearer {api_key}"
 -         }
 -         json_data = {'workspace_id': tenant_id, 'provider_name': provider_name}
 -         if token:
 -             json_data['token'] = token
 -         response = requests.post(api_url, headers=headers,
 -                                  json=json_data)
 -         if not response.ok:
 -             logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
 -             raise ValueError(f"Error: {response.status_code} ")
 - 
 -         rst = response.json()
 -         if rst["code"] != 'success':
 -             raise ValueError(
 -                 f"error: {rst['message']}"
 -             )
 - 
 -         data = rst['data']
 -         if data['qualified'] is True:
 -             return {
 -                 'result': 'success',
 -                 'provider_name': provider_name,
 -                 'flag': True
 -             }
 -         else:
 -             return {
 -                 'result': 'success',
 -                 'provider_name': provider_name,
 -                 'flag': False,
 -                 'reason': data['reason']
 -             }
 
 
  |