| 
                        123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294 | 
                        - from abc import ABC, abstractmethod
 - from datetime import datetime
 - from typing import Type, Optional
 - 
 - from flask import current_app
 - from pydantic import BaseModel
 - 
 - from core.model_providers.error import QuotaExceededError, LLMBadRequestError
 - from extensions.ext_database import db
 - from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
 - from core.model_providers.models.entity.provider import ProviderQuotaUnit
 - from core.model_providers.rules import provider_rules
 - from models.provider import Provider, ProviderType, ProviderModel
 - 
 - 
 - class BaseModelProvider(BaseModel, ABC):
 - 
 -     provider: Provider
 - 
 -     class Config:
 -         """Configuration for this pydantic object."""
 - 
 -         arbitrary_types_allowed = True
 - 
 -     @property
 -     @abstractmethod
 -     def provider_name(self):
 -         """
 -         Returns the name of a provider.
 -         """
 -         raise NotImplementedError
 - 
 -     def get_rules(self):
 -         """
 -         Returns the rules of a provider.
 -         """
 -         return provider_rules[self.provider_name]
 - 
 -     def get_supported_model_list(self, model_type: ModelType) -> list[dict]:
 -         """
 -         get supported model object list for use.
 - 
 -         :param model_type:
 -         :return:
 -         """
 -         rules = self.get_rules()
 -         if 'custom' not in rules['support_provider_types']:
 -             return self._get_fixed_model_list(model_type)
 - 
 -         if 'model_flexibility' not in rules:
 -             return self._get_fixed_model_list(model_type)
 - 
 -         if rules['model_flexibility'] == 'fixed':
 -             return self._get_fixed_model_list(model_type)
 - 
 -         # get configurable provider models
 -         provider_models = db.session.query(ProviderModel).filter(
 -             ProviderModel.tenant_id == self.provider.tenant_id,
 -             ProviderModel.provider_name == self.provider.provider_name,
 -             ProviderModel.model_type == model_type.value,
 -             ProviderModel.is_valid == True
 -         ).order_by(ProviderModel.created_at.asc()).all()
 - 
 -         provider_model_list = []
 -         for provider_model in provider_models:
 -             provider_model_dict = {
 -                 'id': provider_model.model_name,
 -                 'name': provider_model.model_name
 -             }
 - 
 -             if model_type == ModelType.TEXT_GENERATION:
 -                 provider_model_dict['mode'] = self._get_text_generation_model_mode(provider_model.model_name)
 - 
 -             provider_model_list.append(provider_model_dict)
 - 
 -         return provider_model_list
 - 
 -     @abstractmethod
 -     def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
 -         """
 -         get supported model object list for use.
 - 
 -         :param model_type:
 -         :return:
 -         """
 -         raise NotImplementedError
 - 
 -     @abstractmethod
 -     def _get_text_generation_model_mode(self, model_name) -> str:
 -         """
 -         get text generation model mode.
 - 
 -         :param model_name:
 -         :return:
 -         """
 -         raise NotImplementedError
 - 
 -     @abstractmethod
 -     def get_model_class(self, model_type: ModelType) -> Type:
 -         """
 -         get specific model class.
 - 
 -         :param model_type:
 -         :return:
 -         """
 -         raise NotImplementedError
 - 
 -     @classmethod
 -     @abstractmethod
 -     def is_provider_credentials_valid_or_raise(cls, credentials: dict):
 -         """
 -         check provider credentials valid.
 - 
 -         :param credentials:
 -         """
 -         raise NotImplementedError
 - 
 -     @classmethod
 -     @abstractmethod
 -     def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
 -         """
 -         encrypt provider credentials for save.
 - 
 -         :param tenant_id:
 -         :param credentials:
 -         :return:
 -         """
 -         raise NotImplementedError
 - 
 -     @abstractmethod
 -     def get_provider_credentials(self, obfuscated: bool = False) -> dict:
 -         """
 -         get credentials for llm use.
 - 
 -         :param obfuscated:
 -         :return:
 -         """
 -         raise NotImplementedError
 - 
 -     @classmethod
 -     @abstractmethod
 -     def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
 -         """
 -         check model credentials valid.
 - 
 -         :param model_name:
 -         :param model_type:
 -         :param credentials:
 -         """
 -         raise NotImplementedError
 - 
 -     @classmethod
 -     @abstractmethod
 -     def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
 -                                   credentials: dict) -> dict:
 -         """
 -         encrypt model credentials for save.
 - 
 -         :param tenant_id:
 -         :param model_name:
 -         :param model_type:
 -         :param credentials:
 -         :return:
 -         """
 -         raise NotImplementedError
 - 
 -     @abstractmethod
 -     def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
 -         """
 -         get model parameter rules.
 - 
 -         :param model_name:
 -         :param model_type:
 -         :return:
 -         """
 -         raise NotImplementedError
 - 
 -     @abstractmethod
 -     def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
 -         """
 -         get credentials for llm use.
 - 
 -         :param model_name:
 -         :param model_type:
 -         :param obfuscated:
 -         :return:
 -         """
 -         raise NotImplementedError
 - 
 -     @classmethod
 -     def is_provider_type_system_supported(cls) -> bool:
 -         return current_app.config['EDITION'] == 'CLOUD'
 - 
 -     def check_quota_over_limit(self):
 -         """
 -         check provider quota over limit.
 - 
 -         :return:
 -         """
 -         if self.provider.provider_type != ProviderType.SYSTEM.value:
 -             return
 - 
 -         rules = self.get_rules()
 -         if 'system' not in rules['support_provider_types']:
 -             return
 - 
 -         provider = db.session.query(Provider).filter(
 -             db.and_(
 -                 Provider.id == self.provider.id,
 -                 Provider.is_valid == True,
 -                 Provider.quota_limit > Provider.quota_used
 -             )
 -         ).first()
 - 
 -         if not provider:
 -             raise QuotaExceededError()
 - 
 -     def deduct_quota(self, used_tokens: int = 0) -> None:
 -         """
 -         deduct available quota when provider type is system or paid.
 - 
 -         :return:
 -         """
 -         if self.provider.provider_type != ProviderType.SYSTEM.value:
 -             return
 - 
 -         rules = self.get_rules()
 -         if 'system' not in rules['support_provider_types']:
 -             return
 - 
 -         if not self.should_deduct_quota():
 -             return
 - 
 -         if 'system_config' not in rules:
 -             quota_unit = ProviderQuotaUnit.TIMES.value
 -         elif 'quota_unit' not in rules['system_config']:
 -             quota_unit = ProviderQuotaUnit.TIMES.value
 -         else:
 -             quota_unit = rules['system_config']['quota_unit']
 - 
 -         if quota_unit == ProviderQuotaUnit.TOKENS.value:
 -             used_quota = used_tokens
 -         else:
 -             used_quota = 1
 - 
 -         db.session.query(Provider).filter(
 -             Provider.tenant_id == self.provider.tenant_id,
 -             Provider.provider_name == self.provider.provider_name,
 -             Provider.provider_type == self.provider.provider_type,
 -             Provider.quota_type == self.provider.quota_type,
 -             Provider.quota_limit > Provider.quota_used
 -         ).update({'quota_used': Provider.quota_used + used_quota})
 -         db.session.commit()
 - 
 -     def should_deduct_quota(self):
 -         return False
 - 
 -     def update_last_used(self) -> None:
 -         """
 -         update last used time.
 - 
 -         :return:
 -         """
 -         db.session.query(Provider).filter(
 -             Provider.tenant_id == self.provider.tenant_id,
 -             Provider.provider_name == self.provider.provider_name
 -         ).update({'last_used': datetime.utcnow()})
 -         db.session.commit()
 - 
 -     def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel:
 -         """
 -         get provider model.
 - 
 -         :param model_name:
 -         :param model_type:
 -         :return:
 -         """
 -         provider_model = db.session.query(ProviderModel).filter(
 -             ProviderModel.tenant_id == self.provider.tenant_id,
 -             ProviderModel.provider_name == self.provider.provider_name,
 -             ProviderModel.model_name == model_name,
 -             ProviderModel.model_type == model_type.value,
 -             ProviderModel.is_valid == True
 -         ).first()
 - 
 -         if not provider_model:
 -             raise LLMBadRequestError(f"The model {model_name} does not exist. "
 -                                      f"Please check the configuration.")
 - 
 -         return provider_model
 - 
 - 
 - class CredentialsValidateFailedError(Exception):
 -     pass
 
 
  |