Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>tags/0.4.5
| @@ -1,7 +1,7 @@ | |||
| import datetime | |||
| import json | |||
| import logging | |||
| import time | |||
| from json import JSONDecodeError | |||
| from typing import Optional, List, Dict, Tuple, Iterator | |||
| @@ -11,8 +11,9 @@ from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, S | |||
| from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus | |||
| from core.helper import encrypter | |||
| from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType | |||
| from core.model_runtime.entities.model_entities import ModelType, FetchFrom | |||
| from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \ | |||
| ConfigurateMethod | |||
| from core.model_runtime.model_providers import model_provider_factory | |||
| from core.model_runtime.model_providers.__base.ai_model import AIModel | |||
| from core.model_runtime.model_providers.__base.model_provider import ModelProvider | |||
| @@ -22,6 +23,8 @@ from models.provider import ProviderType, Provider, ProviderModel, TenantPreferr | |||
| logger = logging.getLogger(__name__) | |||
| original_provider_configurate_methods = {} | |||
| class ProviderConfiguration(BaseModel): | |||
| """ | |||
| @@ -34,6 +37,20 @@ class ProviderConfiguration(BaseModel): | |||
| system_configuration: SystemConfiguration | |||
| custom_configuration: CustomConfiguration | |||
def __init__(self, **data):
    """
    Initialize the configuration and adjust the provider's configurate methods.

    The configurate methods of each provider are recorded once, the first time a
    configuration for that provider is built. If the recorded methods consist of
    customizable-model only and any system quota configuration carries restrict
    models, predefined-model is appended to the provider's configurate methods.
    """
    super().__init__(**data)

    provider_name = self.provider.provider

    # record the configurate methods only the first time this provider is seen
    if provider_name not in original_provider_configurate_methods:
        original_provider_configurate_methods[provider_name] = list(self.provider.configurate_methods)

    # nothing to adjust unless the provider was originally customizable-model only
    if original_provider_configurate_methods[provider_name] != [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
        return

    has_restrict_models = any(
        len(quota.restrict_models) > 0
        for quota in self.system_configuration.quota_configurations
    )
    if has_restrict_models and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods:
        self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
| def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: | |||
| """ | |||
| Get current credentials. | |||
| @@ -123,7 +140,8 @@ class ProviderConfiguration(BaseModel): | |||
| if provider_record: | |||
| try: | |||
| original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {} | |||
| original_credentials = json.loads( | |||
| provider_record.encrypted_config) if provider_record.encrypted_config else {} | |||
| except JSONDecodeError: | |||
| original_credentials = {} | |||
| @@ -265,7 +283,8 @@ class ProviderConfiguration(BaseModel): | |||
| if provider_model_record: | |||
| try: | |||
| original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} | |||
| original_credentials = json.loads( | |||
| provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} | |||
| except JSONDecodeError: | |||
| original_credentials = {} | |||
| @@ -534,21 +553,70 @@ class ProviderConfiguration(BaseModel): | |||
| ] | |||
| ) | |||
| if self.provider.provider not in original_provider_configurate_methods: | |||
| original_provider_configurate_methods[self.provider.provider] = [] | |||
| for configurate_method in provider_instance.get_provider_schema().configurate_methods: | |||
| original_provider_configurate_methods[self.provider.provider].append(configurate_method) | |||
| should_use_custom_model = False | |||
| if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: | |||
| should_use_custom_model = True | |||
| for quota_configuration in self.system_configuration.quota_configurations: | |||
| if self.system_configuration.current_quota_type != quota_configuration.quota_type: | |||
| continue | |||
| restrict_llms = quota_configuration.restrict_llms | |||
| if not restrict_llms: | |||
| restrict_models = quota_configuration.restrict_models | |||
| if len(restrict_models) == 0: | |||
| break | |||
| if should_use_custom_model: | |||
| if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: | |||
| # only customizable model | |||
| for restrict_model in restrict_models: | |||
| copy_credentials = self.system_configuration.credentials.copy() | |||
| if restrict_model.base_model_name: | |||
| copy_credentials['base_model_name'] = restrict_model.base_model_name | |||
| try: | |||
| custom_model_schema = ( | |||
| provider_instance.get_model_instance(restrict_model.model_type) | |||
| .get_customizable_model_schema_from_credentials( | |||
| restrict_model.model, | |||
| copy_credentials | |||
| ) | |||
| ) | |||
| except Exception as ex: | |||
| logger.warning(f'get custom model schema failed, {ex}') | |||
| continue | |||
| if not custom_model_schema: | |||
| continue | |||
| if custom_model_schema.model_type not in model_types: | |||
| continue | |||
| provider_models.append( | |||
| ModelWithProviderEntity( | |||
| model=custom_model_schema.model, | |||
| label=custom_model_schema.label, | |||
| model_type=custom_model_schema.model_type, | |||
| features=custom_model_schema.features, | |||
| fetch_from=FetchFrom.PREDEFINED_MODEL, | |||
| model_properties=custom_model_schema.model_properties, | |||
| deprecated=custom_model_schema.deprecated, | |||
| provider=SimpleModelProviderEntity(self.provider), | |||
| status=ModelStatus.ACTIVE | |||
| ) | |||
| ) | |||
| # if llm name not in restricted model list, mark it as no permission | |||
| restrict_model_names = [rm.model for rm in restrict_models] | |||
| for m in provider_models: | |||
| if m.model_type == ModelType.LLM and m.model not in restrict_llms: | |||
| if m.model_type == ModelType.LLM and m.model not in restrict_model_names: | |||
| m.status = ModelStatus.NO_PERMISSION | |||
| elif not quota_configuration.is_valid: | |||
| m.status = ModelStatus.QUOTA_EXCEEDED | |||
| return provider_models | |||
| def _get_custom_provider_models(self, | |||
| @@ -21,6 +21,12 @@ class SystemConfigurationStatus(Enum): | |||
| UNSUPPORTED = 'unsupported' | |||
class RestrictModel(BaseModel):
    """
    Model class for a model entry listed in a quota's restrict models.
    """
    # name of the model
    model: str
    # optional base model name; NOTE(review): appears to be passed along as the
    # 'base_model_name' credential when resolving customizable model schemas - confirm
    base_model_name: Optional[str] = None
    # type of the model
    model_type: ModelType
| class QuotaConfiguration(BaseModel): | |||
| """ | |||
| Model class for provider quota configuration. | |||
| @@ -30,7 +36,7 @@ class QuotaConfiguration(BaseModel): | |||
| quota_limit: int | |||
| quota_used: int | |||
| is_valid: bool | |||
| restrict_llms: list[str] = [] | |||
| restrict_models: list[RestrictModel] = [] | |||
| class SystemConfiguration(BaseModel): | |||
| @@ -4,13 +4,14 @@ from typing import Optional | |||
| from flask import Flask | |||
| from pydantic import BaseModel | |||
| from core.entities.provider_entities import QuotaUnit | |||
| from core.entities.provider_entities import QuotaUnit, RestrictModel | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from models.provider import ProviderQuotaType | |||
| class HostingQuota(BaseModel): | |||
| quota_type: ProviderQuotaType | |||
| restrict_llms: list[str] = [] | |||
| restrict_models: list[RestrictModel] = [] | |||
| class TrialHostingQuota(HostingQuota): | |||
| @@ -47,10 +48,9 @@ class HostingConfiguration: | |||
| provider_map: dict[str, HostingProvider] = {} | |||
| moderation_config: HostedModerationConfig = None | |||
| def init_app(self, app: Flask): | |||
| if app.config.get('EDITION') != 'CLOUD': | |||
| return | |||
| def init_app(self, app: Flask) -> None: | |||
| self.provider_map["azure_openai"] = self.init_azure_openai() | |||
| self.provider_map["openai"] = self.init_openai() | |||
| self.provider_map["anthropic"] = self.init_anthropic() | |||
| self.provider_map["minimax"] = self.init_minimax() | |||
| @@ -59,6 +59,47 @@ class HostingConfiguration: | |||
| self.moderation_config = self.init_moderation_config() | |||
def init_azure_openai(self) -> HostingProvider:
    """
    Build the hosted Azure OpenAI provider from HOSTED_AZURE_OPENAI_* environment variables.

    :return: an enabled provider with credentials and quotas when
        HOSTED_AZURE_OPENAI_ENABLED equals 'true' (case-insensitive),
        otherwise a disabled provider
    """
    quota_unit = QuotaUnit.TIMES

    enabled_flag = os.environ.get("HOSTED_AZURE_OPENAI_ENABLED")
    if not enabled_flag or enabled_flag.lower() != 'true':
        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )

    credentials = {
        "openai_api_key": os.environ.get("HOSTED_AZURE_OPENAI_API_KEY"),
        "openai_api_base": os.environ.get("HOSTED_AZURE_OPENAI_API_BASE"),
        "base_model_name": "gpt-35-turbo"
    }

    quotas = []
    hosted_quota_limit = int(os.environ.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000"))

    # NOTE(review): the original condition `!= -1 or > 0` is true exactly when the
    # limit is not -1, so it is written in that reduced form here - confirm intent
    if hosted_quota_limit != -1:
        llm_names = (
            "gpt-4",
            "gpt-4-32k",
            "gpt-4-1106-preview",
            "gpt-4-vision-preview",
            "gpt-35-turbo",
            "gpt-35-turbo-1106",
            "gpt-35-turbo-instruct",
            "gpt-35-turbo-16k",
            "text-davinci-003",
        )
        # each restrict model uses its own name as the base model name
        restrict_models = [
            RestrictModel(model=name, base_model_name=name, model_type=ModelType.LLM)
            for name in llm_names
        ]
        restrict_models.append(
            RestrictModel(
                model="text-embedding-ada-002",
                base_model_name="text-embedding-ada-002",
                model_type=ModelType.TEXT_EMBEDDING
            )
        )
        quotas.append(
            TrialHostingQuota(
                quota_limit=hosted_quota_limit,
                restrict_models=restrict_models
            )
        )

    return HostingProvider(
        enabled=True,
        credentials=credentials,
        quota_unit=quota_unit,
        quotas=quotas
    )
| def init_openai(self) -> HostingProvider: | |||
| quota_unit = QuotaUnit.TIMES | |||
| if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true': | |||
| @@ -77,12 +118,12 @@ class HostingConfiguration: | |||
| if hosted_quota_limit != -1 or hosted_quota_limit > 0: | |||
| trial_quota = TrialHostingQuota( | |||
| quota_limit=hosted_quota_limit, | |||
| restrict_llms=[ | |||
| "gpt-3.5-turbo", | |||
| "gpt-3.5-turbo-1106", | |||
| "gpt-3.5-turbo-instruct", | |||
| "gpt-3.5-turbo-16k", | |||
| "text-davinci-003" | |||
| restrict_models=[ | |||
| RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM), | |||
| RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM), | |||
| RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM), | |||
| RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM), | |||
| RestrictModel(model="text-davinci-003", model_type=ModelType.LLM), | |||
| ] | |||
| ) | |||
| quotas.append(trial_quota) | |||
| @@ -144,7 +144,7 @@ class ModelInstance: | |||
| user=user | |||
| ) | |||
| def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \ | |||
| def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None, **params) \ | |||
| -> str: | |||
| """ | |||
| Invoke speech2text model | |||
| @@ -161,7 +161,8 @@ class ModelInstance: | |||
| model=self.model, | |||
| credentials=self.credentials, | |||
| file=file, | |||
| user=user | |||
| user=user, | |||
| **params | |||
| ) | |||
| @@ -32,7 +32,7 @@ class ModelType(Enum): | |||
| return cls.TEXT_EMBEDDING | |||
| elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value: | |||
| return cls.RERANK | |||
| elif origin_model_type == cls.SPEECH2TEXT.value: | |||
| elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value: | |||
| return cls.SPEECH2TEXT | |||
| elif origin_model_type == cls.MODERATION.value: | |||
| return cls.MODERATION | |||
| @@ -2,7 +2,7 @@ from pydantic import BaseModel | |||
| from core.model_runtime.entities.llm_entities import LLMMode | |||
| from core.model_runtime.entities.model_entities import ModelFeature, ModelType, FetchFrom, ParameterRule, \ | |||
| DefaultParameterName, PriceConfig | |||
| DefaultParameterName, PriceConfig, ModelPropertyKey | |||
| from core.model_runtime.entities.model_entities import AIModelEntity, I18nObject | |||
| from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE | |||
| @@ -502,8 +502,8 @@ EMBEDDING_BASE_MODELS = [ | |||
| fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model_properties={ | |||
| 'context_size': 8097, | |||
| 'max_chunks': 32, | |||
| ModelPropertyKey.CONTEXT_SIZE: 8097, | |||
| ModelPropertyKey.MAX_CHUNKS: 32, | |||
| }, | |||
| pricing=PriceConfig( | |||
| input=0.0001, | |||
| @@ -30,7 +30,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): | |||
| stream: bool = True, user: Optional[str] = None) \ | |||
| -> Union[LLMResult, Generator]: | |||
| ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) | |||
| ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model) | |||
| if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: | |||
| # chat model | |||
| @@ -59,7 +59,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): | |||
| def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], | |||
| tools: Optional[list[PromptMessageTool]] = None) -> int: | |||
| model_mode = self._get_ai_model_entity(credentials['base_model_name'], model).entity.model_properties.get( | |||
| model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get( | |||
| ModelPropertyKey.MODE) | |||
| if model_mode == LLMMode.CHAT.value: | |||
| @@ -79,7 +79,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): | |||
| if 'base_model_name' not in credentials: | |||
| raise CredentialsValidateFailedError('Base Model Name is required') | |||
| ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) | |||
| ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model) | |||
| if not ai_model_entity: | |||
| raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid') | |||
| @@ -109,8 +109,8 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): | |||
| raise CredentialsValidateFailedError(str(ex)) | |||
| def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: | |||
| ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) | |||
| return ai_model_entity.entity | |||
| ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model) | |||
| return ai_model_entity.entity if ai_model_entity else None | |||
| def _generate(self, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, | |||
| @@ -12,7 +12,8 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC | |||
| from core.helper import encrypter | |||
| from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType | |||
| from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \ | |||
| ConfigurateMethod | |||
| from core.model_runtime.model_providers import model_provider_factory | |||
| from extensions import ext_hosting_provider | |||
| from extensions.ext_database import db | |||
| @@ -607,7 +608,7 @@ class ProviderManager: | |||
| quota_used=provider_record.quota_used, | |||
| quota_limit=provider_record.quota_limit, | |||
| is_valid=provider_record.quota_limit > provider_record.quota_used or provider_record.quota_limit == -1, | |||
| restrict_llms=provider_quota.restrict_llms | |||
| restrict_models=provider_quota.restrict_models | |||
| ) | |||
| quota_configurations.append(quota_configuration) | |||