Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>tags/0.4.5
| import datetime | import datetime | ||||
| import json | import json | ||||
| import logging | import logging | ||||
| import time | |||||
| from json import JSONDecodeError | from json import JSONDecodeError | ||||
| from typing import Optional, List, Dict, Tuple, Iterator | from typing import Optional, List, Dict, Tuple, Iterator | ||||
| from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus | from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus | ||||
| from core.helper import encrypter | from core.helper import encrypter | ||||
| from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType | from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType | ||||
| from core.model_runtime.entities.model_entities import ModelType | |||||
| from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType | |||||
| from core.model_runtime.entities.model_entities import ModelType, FetchFrom | |||||
| from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \ | |||||
| ConfigurateMethod | |||||
| from core.model_runtime.model_providers import model_provider_factory | from core.model_runtime.model_providers import model_provider_factory | ||||
| from core.model_runtime.model_providers.__base.ai_model import AIModel | from core.model_runtime.model_providers.__base.ai_model import AIModel | ||||
| from core.model_runtime.model_providers.__base.model_provider import ModelProvider | from core.model_runtime.model_providers.__base.model_provider import ModelProvider | ||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| original_provider_configurate_methods = {} | |||||
| class ProviderConfiguration(BaseModel): | class ProviderConfiguration(BaseModel): | ||||
| """ | """ | ||||
| system_configuration: SystemConfiguration | system_configuration: SystemConfiguration | ||||
| custom_configuration: CustomConfiguration | custom_configuration: CustomConfiguration | ||||
def __init__(self, **data):
    """
    Build the configuration, then expose restricted hosted models as predefined.

    The first instance created for a provider stores a copy of that provider's
    configurate methods in the module-level
    ``original_provider_configurate_methods`` mapping. When the stored methods
    are exactly ``[ConfigurateMethod.CUSTOMIZABLE_MODEL]`` and at least one
    system quota configuration carries a non-empty ``restrict_models`` list,
    ``ConfigurateMethod.PREDEFINED_MODEL`` is appended to
    ``self.provider.configurate_methods`` unless it is already present.

    :param data: field values forwarded unchanged to the parent constructor
    """
    super().__init__(**data)

    provider_name = self.provider.provider

    # Snapshot the methods before the list may be extended below, so later
    # instances still compare against the provider's original declaration.
    if provider_name not in original_provider_configurate_methods:
        original_provider_configurate_methods[provider_name] = list(self.provider.configurate_methods)

    if original_provider_configurate_methods[provider_name] != [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
        return

    has_restricted_models = any(
        len(quota_configuration.restrict_models) > 0
        for quota_configuration in self.system_configuration.quota_configurations
    )
    if has_restricted_models and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods:
        self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
| def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: | def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: | ||||
| """ | """ | ||||
| Get current credentials. | Get current credentials. | ||||
| if provider_record: | if provider_record: | ||||
| try: | try: | ||||
| original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {} | |||||
| original_credentials = json.loads( | |||||
| provider_record.encrypted_config) if provider_record.encrypted_config else {} | |||||
| except JSONDecodeError: | except JSONDecodeError: | ||||
| original_credentials = {} | original_credentials = {} | ||||
| if provider_model_record: | if provider_model_record: | ||||
| try: | try: | ||||
| original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} | |||||
| original_credentials = json.loads( | |||||
| provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} | |||||
| except JSONDecodeError: | except JSONDecodeError: | ||||
| original_credentials = {} | original_credentials = {} | ||||
| ] | ] | ||||
| ) | ) | ||||
| if self.provider.provider not in original_provider_configurate_methods: | |||||
| original_provider_configurate_methods[self.provider.provider] = [] | |||||
| for configurate_method in provider_instance.get_provider_schema().configurate_methods: | |||||
| original_provider_configurate_methods[self.provider.provider].append(configurate_method) | |||||
| should_use_custom_model = False | |||||
| if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: | |||||
| should_use_custom_model = True | |||||
| for quota_configuration in self.system_configuration.quota_configurations: | for quota_configuration in self.system_configuration.quota_configurations: | ||||
| if self.system_configuration.current_quota_type != quota_configuration.quota_type: | if self.system_configuration.current_quota_type != quota_configuration.quota_type: | ||||
| continue | continue | ||||
| restrict_llms = quota_configuration.restrict_llms | |||||
| if not restrict_llms: | |||||
| restrict_models = quota_configuration.restrict_models | |||||
| if len(restrict_models) == 0: | |||||
| break | break | ||||
| if should_use_custom_model: | |||||
| if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: | |||||
| # only customizable model | |||||
| for restrict_model in restrict_models: | |||||
| copy_credentials = self.system_configuration.credentials.copy() | |||||
| if restrict_model.base_model_name: | |||||
| copy_credentials['base_model_name'] = restrict_model.base_model_name | |||||
| try: | |||||
| custom_model_schema = ( | |||||
| provider_instance.get_model_instance(restrict_model.model_type) | |||||
| .get_customizable_model_schema_from_credentials( | |||||
| restrict_model.model, | |||||
| copy_credentials | |||||
| ) | |||||
| ) | |||||
| except Exception as ex: | |||||
| logger.warning(f'get custom model schema failed, {ex}') | |||||
| continue | |||||
| if not custom_model_schema: | |||||
| continue | |||||
| if custom_model_schema.model_type not in model_types: | |||||
| continue | |||||
| provider_models.append( | |||||
| ModelWithProviderEntity( | |||||
| model=custom_model_schema.model, | |||||
| label=custom_model_schema.label, | |||||
| model_type=custom_model_schema.model_type, | |||||
| features=custom_model_schema.features, | |||||
| fetch_from=FetchFrom.PREDEFINED_MODEL, | |||||
| model_properties=custom_model_schema.model_properties, | |||||
| deprecated=custom_model_schema.deprecated, | |||||
| provider=SimpleModelProviderEntity(self.provider), | |||||
| status=ModelStatus.ACTIVE | |||||
| ) | |||||
| ) | |||||
| # if llm name not in restricted llm list, remove it | # if llm name not in restricted llm list, remove it | ||||
| restrict_model_names = [rm.model for rm in restrict_models] | |||||
| for m in provider_models: | for m in provider_models: | ||||
| if m.model_type == ModelType.LLM and m.model not in restrict_llms: | |||||
| if m.model_type == ModelType.LLM and m.model not in restrict_model_names: | |||||
| m.status = ModelStatus.NO_PERMISSION | m.status = ModelStatus.NO_PERMISSION | ||||
| elif not quota_configuration.is_valid: | elif not quota_configuration.is_valid: | ||||
| m.status = ModelStatus.QUOTA_EXCEEDED | m.status = ModelStatus.QUOTA_EXCEEDED | ||||
| return provider_models | return provider_models | ||||
| def _get_custom_provider_models(self, | def _get_custom_provider_models(self, |
| UNSUPPORTED = 'unsupported' | UNSUPPORTED = 'unsupported' | ||||
# One entry of a quota's restricted-model list: the model name, an optional
# base model name and the model's type.
class RestrictModel(BaseModel):
    # Name of the model the quota is restricted to.
    model: str
    # Optional base model name; defaults to None when the entry does not set it.
    base_model_name: Optional[str] = None
    # Type of the restricted model (e.g. LLM or text embedding).
    model_type: ModelType
| class QuotaConfiguration(BaseModel): | class QuotaConfiguration(BaseModel): | ||||
| """ | """ | ||||
| Model class for provider quota configuration. | Model class for provider quota configuration. | ||||
| quota_limit: int | quota_limit: int | ||||
| quota_used: int | quota_used: int | ||||
| is_valid: bool | is_valid: bool | ||||
| restrict_llms: list[str] = [] | |||||
| restrict_models: list[RestrictModel] = [] | |||||
| class SystemConfiguration(BaseModel): | class SystemConfiguration(BaseModel): |
| from flask import Flask | from flask import Flask | ||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||
| from core.entities.provider_entities import QuotaUnit | |||||
| from core.entities.provider_entities import QuotaUnit, RestrictModel | |||||
| from core.model_runtime.entities.model_entities import ModelType | |||||
| from models.provider import ProviderQuotaType | from models.provider import ProviderQuotaType | ||||
| class HostingQuota(BaseModel): | class HostingQuota(BaseModel): | ||||
| quota_type: ProviderQuotaType | quota_type: ProviderQuotaType | ||||
| restrict_llms: list[str] = [] | |||||
| restrict_models: list[RestrictModel] = [] | |||||
| class TrialHostingQuota(HostingQuota): | class TrialHostingQuota(HostingQuota): | ||||
| provider_map: dict[str, HostingProvider] = {} | provider_map: dict[str, HostingProvider] = {} | ||||
| moderation_config: HostedModerationConfig = None | moderation_config: HostedModerationConfig = None | ||||
| def init_app(self, app: Flask): | |||||
| if app.config.get('EDITION') != 'CLOUD': | |||||
| return | |||||
| def init_app(self, app: Flask) -> None: | |||||
| self.provider_map["azure_openai"] = self.init_azure_openai() | |||||
| self.provider_map["openai"] = self.init_openai() | self.provider_map["openai"] = self.init_openai() | ||||
| self.provider_map["anthropic"] = self.init_anthropic() | self.provider_map["anthropic"] = self.init_anthropic() | ||||
| self.provider_map["minimax"] = self.init_minimax() | self.provider_map["minimax"] = self.init_minimax() | ||||
| self.moderation_config = self.init_moderation_config() | self.moderation_config = self.init_moderation_config() | ||||
def init_azure_openai(self) -> HostingProvider:
    """
    Build the hosted Azure OpenAI provider settings from environment variables.

    Reads ``HOSTED_AZURE_OPENAI_ENABLED``, ``HOSTED_AZURE_OPENAI_API_KEY``,
    ``HOSTED_AZURE_OPENAI_API_BASE`` and ``HOSTED_AZURE_OPENAI_QUOTA_LIMIT``
    (default ``"1000"``).

    :return: an enabled provider with credentials and its quotas when
        ``HOSTED_AZURE_OPENAI_ENABLED`` equals ``true`` (case-insensitive),
        otherwise a disabled provider
    :raises ValueError: if ``HOSTED_AZURE_OPENAI_QUOTA_LIMIT`` is not an integer
    """
    quota_unit = QuotaUnit.TIMES

    # Unset or empty flag means the hosted provider is switched off.
    if os.environ.get("HOSTED_AZURE_OPENAI_ENABLED", "").lower() != 'true':
        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )

    credentials = {
        "openai_api_key": os.environ.get("HOSTED_AZURE_OPENAI_API_KEY"),
        "openai_api_base": os.environ.get("HOSTED_AZURE_OPENAI_API_BASE"),
        "base_model_name": "gpt-35-turbo"
    }

    quotas = []
    hosted_quota_limit = int(os.environ.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000"))
    # Same truth table as `!= -1 or > 0`: any positive value already differs
    # from -1, so only a limit of exactly -1 skips the trial quota.
    # NOTE(review): a limit of 0 still creates a trial quota - confirm intended.
    if hosted_quota_limit != -1:
        # Each restricted entry uses the same string for model and base model name.
        llm_names = (
            "gpt-4",
            "gpt-4-32k",
            "gpt-4-1106-preview",
            "gpt-4-vision-preview",
            "gpt-35-turbo",
            "gpt-35-turbo-1106",
            "gpt-35-turbo-instruct",
            "gpt-35-turbo-16k",
            "text-davinci-003",
        )
        restrict_models = [
            RestrictModel(model=name, base_model_name=name, model_type=ModelType.LLM)
            for name in llm_names
        ]
        restrict_models.append(
            RestrictModel(
                model="text-embedding-ada-002",
                base_model_name="text-embedding-ada-002",
                model_type=ModelType.TEXT_EMBEDDING,
            )
        )
        quotas.append(
            TrialHostingQuota(
                quota_limit=hosted_quota_limit,
                restrict_models=restrict_models,
            )
        )

    return HostingProvider(
        enabled=True,
        credentials=credentials,
        quota_unit=quota_unit,
        quotas=quotas
    )
| def init_openai(self) -> HostingProvider: | def init_openai(self) -> HostingProvider: | ||||
| quota_unit = QuotaUnit.TIMES | quota_unit = QuotaUnit.TIMES | ||||
| if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true': | if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true': | ||||
| if hosted_quota_limit != -1 or hosted_quota_limit > 0: | if hosted_quota_limit != -1 or hosted_quota_limit > 0: | ||||
| trial_quota = TrialHostingQuota( | trial_quota = TrialHostingQuota( | ||||
| quota_limit=hosted_quota_limit, | quota_limit=hosted_quota_limit, | ||||
| restrict_llms=[ | |||||
| "gpt-3.5-turbo", | |||||
| "gpt-3.5-turbo-1106", | |||||
| "gpt-3.5-turbo-instruct", | |||||
| "gpt-3.5-turbo-16k", | |||||
| "text-davinci-003" | |||||
| restrict_models=[ | |||||
| RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM), | |||||
| RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM), | |||||
| RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM), | |||||
| RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM), | |||||
| RestrictModel(model="text-davinci-003", model_type=ModelType.LLM), | |||||
| ] | ] | ||||
| ) | ) | ||||
| quotas.append(trial_quota) | quotas.append(trial_quota) |
| user=user | user=user | ||||
| ) | ) | ||||
| def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \ | |||||
| def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None, **params) \ | |||||
| -> str: | -> str: | ||||
| """ | """ | ||||
| Invoke large language model | Invoke large language model | ||||
| model=self.model, | model=self.model, | ||||
| credentials=self.credentials, | credentials=self.credentials, | ||||
| file=file, | file=file, | ||||
| user=user | |||||
| user=user, | |||||
| **params | |||||
| ) | ) | ||||
| return cls.TEXT_EMBEDDING | return cls.TEXT_EMBEDDING | ||||
| elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value: | elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value: | ||||
| return cls.RERANK | return cls.RERANK | ||||
| elif origin_model_type == cls.SPEECH2TEXT.value: | |||||
| elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value: | |||||
| return cls.SPEECH2TEXT | return cls.SPEECH2TEXT | ||||
| elif origin_model_type == cls.MODERATION.value: | elif origin_model_type == cls.MODERATION.value: | ||||
| return cls.MODERATION | return cls.MODERATION |
| from core.model_runtime.entities.llm_entities import LLMMode | from core.model_runtime.entities.llm_entities import LLMMode | ||||
| from core.model_runtime.entities.model_entities import ModelFeature, ModelType, FetchFrom, ParameterRule, \ | from core.model_runtime.entities.model_entities import ModelFeature, ModelType, FetchFrom, ParameterRule, \ | ||||
| DefaultParameterName, PriceConfig | |||||
| DefaultParameterName, PriceConfig, ModelPropertyKey | |||||
| from core.model_runtime.entities.model_entities import AIModelEntity, I18nObject | from core.model_runtime.entities.model_entities import AIModelEntity, I18nObject | ||||
| from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE | from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE | ||||
| fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | ||||
| model_type=ModelType.TEXT_EMBEDDING, | model_type=ModelType.TEXT_EMBEDDING, | ||||
| model_properties={ | model_properties={ | ||||
| 'context_size': 8097, | |||||
| 'max_chunks': 32, | |||||
| ModelPropertyKey.CONTEXT_SIZE: 8097, | |||||
| ModelPropertyKey.MAX_CHUNKS: 32, | |||||
| }, | }, | ||||
| pricing=PriceConfig( | pricing=PriceConfig( | ||||
| input=0.0001, | input=0.0001, |
| stream: bool = True, user: Optional[str] = None) \ | stream: bool = True, user: Optional[str] = None) \ | ||||
| -> Union[LLMResult, Generator]: | -> Union[LLMResult, Generator]: | ||||
| ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) | |||||
| ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model) | |||||
| if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: | if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: | ||||
| # chat model | # chat model | ||||
| def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], | def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], | ||||
| tools: Optional[list[PromptMessageTool]] = None) -> int: | tools: Optional[list[PromptMessageTool]] = None) -> int: | ||||
| model_mode = self._get_ai_model_entity(credentials['base_model_name'], model).entity.model_properties.get( | |||||
| model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get( | |||||
| ModelPropertyKey.MODE) | ModelPropertyKey.MODE) | ||||
| if model_mode == LLMMode.CHAT.value: | if model_mode == LLMMode.CHAT.value: | ||||
| if 'base_model_name' not in credentials: | if 'base_model_name' not in credentials: | ||||
| raise CredentialsValidateFailedError('Base Model Name is required') | raise CredentialsValidateFailedError('Base Model Name is required') | ||||
| ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) | |||||
| ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model) | |||||
| if not ai_model_entity: | if not ai_model_entity: | ||||
| raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid') | raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid') | ||||
| raise CredentialsValidateFailedError(str(ex)) | raise CredentialsValidateFailedError(str(ex)) | ||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional["AIModelEntity"]:
    """
    Resolve the schema of a customizable model from its credentials.

    :param model: model name passed through to the entity lookup
    :param credentials: model credentials; ``base_model_name`` is read with
        ``.get`` so a missing key yields ``None`` instead of raising KeyError
    :return: the ``entity`` of the matched model entry, or ``None`` when the
        lookup finds nothing
    """
    ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
    # The lookup can come back empty (e.g. unknown base model name); report
    # "no schema" rather than failing on attribute access.
    if not ai_model_entity:
        return None
    return ai_model_entity.entity
| def _generate(self, model: str, credentials: dict, | def _generate(self, model: str, credentials: dict, | ||||
| prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, | prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, |
| from core.helper import encrypter | from core.helper import encrypter | ||||
| from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType | from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType | ||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType | |||||
| from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \ | |||||
| ConfigurateMethod | |||||
| from core.model_runtime.model_providers import model_provider_factory | from core.model_runtime.model_providers import model_provider_factory | ||||
| from extensions import ext_hosting_provider | from extensions import ext_hosting_provider | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| quota_used=provider_record.quota_used, | quota_used=provider_record.quota_used, | ||||
| quota_limit=provider_record.quota_limit, | quota_limit=provider_record.quota_limit, | ||||
| is_valid=provider_record.quota_limit > provider_record.quota_used or provider_record.quota_limit == -1, | is_valid=provider_record.quota_limit > provider_record.quota_used or provider_record.quota_limit == -1, | ||||
| restrict_llms=provider_quota.restrict_llms | |||||
| restrict_models=provider_quota.restrict_models | |||||
| ) | ) | ||||
| quota_configurations.append(quota_configuration) | quota_configurations.append(quota_configuration) |