Browse Source

Azure openai init (#1929)

Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
tags/0.4.5
Charlie.Wei 1 year ago
parent
commit
5b24d7129e
No account linked to committer's email address

+ 77
- 9
api/core/entities/provider_configuration.py View File

import datetime import datetime
import json import json
import logging import logging
import time
from json import JSONDecodeError from json import JSONDecodeError
from typing import Optional, List, Dict, Tuple, Iterator from typing import Optional, List, Dict, Tuple, Iterator


from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
from core.helper import encrypter from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
from core.model_runtime.entities.model_entities import ModelType, FetchFrom
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \
ConfigurateMethod
from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.__base.model_provider import ModelProvider from core.model_runtime.model_providers.__base.model_provider import ModelProvider


logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)


original_provider_configurate_methods = {}



class ProviderConfiguration(BaseModel): class ProviderConfiguration(BaseModel):
""" """
system_configuration: SystemConfiguration system_configuration: SystemConfiguration
custom_configuration: CustomConfiguration custom_configuration: CustomConfiguration


def __init__(self, **data):
    """
    Initialize the configuration and normalize the provider's configurate methods.

    The configurate methods a provider declares are snapshotted into the
    module-level ``original_provider_configurate_methods`` the first time the
    provider is seen. If that snapshot is exactly ``[CUSTOMIZABLE_MODEL]`` and
    any system quota configuration has a non-empty ``restrict_models`` list,
    ``PREDEFINED_MODEL`` is appended to the provider's configurate methods
    (unless it is already there).

    :param data: field values forwarded unchanged to the base model constructor
    """
    super().__init__(**data)

    provider_name = self.provider.provider
    declared_methods = self.provider.configurate_methods

    # Snapshot only once per provider name: later instances may see a list
    # that has already been extended below, so the first sighting is kept.
    if provider_name not in original_provider_configurate_methods:
        original_provider_configurate_methods[provider_name] = list(declared_methods)

    # Only providers that originally offered nothing but customizable models
    # are candidates for the extra PREDEFINED_MODEL method.
    if original_provider_configurate_methods[provider_name] != [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
        return

    if ConfigurateMethod.PREDEFINED_MODEL in declared_methods:
        return

    has_restricted_models = any(
        len(quota.restrict_models) > 0
        for quota in self.system_configuration.quota_configurations
    )
    if has_restricted_models:
        declared_methods.append(ConfigurateMethod.PREDEFINED_MODEL)

def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
""" """
Get current credentials. Get current credentials.


if provider_record: if provider_record:
try: try:
original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {}
original_credentials = json.loads(
provider_record.encrypted_config) if provider_record.encrypted_config else {}
except JSONDecodeError: except JSONDecodeError:
original_credentials = {} original_credentials = {}




if provider_model_record: if provider_model_record:
try: try:
original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
original_credentials = json.loads(
provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
except JSONDecodeError: except JSONDecodeError:
original_credentials = {} original_credentials = {}


] ]
) )


if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in provider_instance.get_provider_schema().configurate_methods:
original_provider_configurate_methods[self.provider.provider].append(configurate_method)

should_use_custom_model = False
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
should_use_custom_model = True

for quota_configuration in self.system_configuration.quota_configurations: for quota_configuration in self.system_configuration.quota_configurations:
if self.system_configuration.current_quota_type != quota_configuration.quota_type: if self.system_configuration.current_quota_type != quota_configuration.quota_type:
continue continue


restrict_llms = quota_configuration.restrict_llms
if not restrict_llms:
restrict_models = quota_configuration.restrict_models
if len(restrict_models) == 0:
break break


if should_use_custom_model:
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
# only customizable model
for restrict_model in restrict_models:
copy_credentials = self.system_configuration.credentials.copy()
if restrict_model.base_model_name:
copy_credentials['base_model_name'] = restrict_model.base_model_name

try:
custom_model_schema = (
provider_instance.get_model_instance(restrict_model.model_type)
.get_customizable_model_schema_from_credentials(
restrict_model.model,
copy_credentials
)
)
except Exception as ex:
logger.warning(f'get custom model schema failed, {ex}')
continue

if not custom_model_schema:
continue

if custom_model_schema.model_type not in model_types:
continue

provider_models.append(
ModelWithProviderEntity(
model=custom_model_schema.model,
label=custom_model_schema.label,
model_type=custom_model_schema.model_type,
features=custom_model_schema.features,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
)
)

# if llm name not in restricted llm list, remove it # if llm name not in restricted llm list, remove it
restrict_model_names = [rm.model for rm in restrict_models]
for m in provider_models: for m in provider_models:
if m.model_type == ModelType.LLM and m.model not in restrict_llms:
if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
m.status = ModelStatus.NO_PERMISSION m.status = ModelStatus.NO_PERMISSION
elif not quota_configuration.is_valid: elif not quota_configuration.is_valid:
m.status = ModelStatus.QUOTA_EXCEEDED m.status = ModelStatus.QUOTA_EXCEEDED

return provider_models return provider_models


def _get_custom_provider_models(self, def _get_custom_provider_models(self,

+ 7
- 1
api/core/entities/provider_entities.py View File

UNSUPPORTED = 'unsupported' UNSUPPORTED = 'unsupported'




class RestrictModel(BaseModel):
    """
    Model class for a single entry of a quota's ``restrict_models`` list.
    """
    # Name of the model.
    model: str
    # Optional base model name; defaults to None when not supplied.
    base_model_name: Optional[str] = None
    # Type of the model (e.g. LLM or text embedding).
    model_type: ModelType


class QuotaConfiguration(BaseModel): class QuotaConfiguration(BaseModel):
""" """
Model class for provider quota configuration. Model class for provider quota configuration.
quota_limit: int quota_limit: int
quota_used: int quota_used: int
is_valid: bool is_valid: bool
restrict_llms: list[str] = []
restrict_models: list[RestrictModel] = []




class SystemConfiguration(BaseModel): class SystemConfiguration(BaseModel):

+ 52
- 11
api/core/hosting_configuration.py View File

from flask import Flask from flask import Flask
from pydantic import BaseModel from pydantic import BaseModel


from core.entities.provider_entities import QuotaUnit
from core.entities.provider_entities import QuotaUnit, RestrictModel
from core.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderQuotaType from models.provider import ProviderQuotaType




class HostingQuota(BaseModel): class HostingQuota(BaseModel):
quota_type: ProviderQuotaType quota_type: ProviderQuotaType
restrict_llms: list[str] = []
restrict_models: list[RestrictModel] = []




class TrialHostingQuota(HostingQuota): class TrialHostingQuota(HostingQuota):
provider_map: dict[str, HostingProvider] = {} provider_map: dict[str, HostingProvider] = {}
moderation_config: HostedModerationConfig = None moderation_config: HostedModerationConfig = None


def init_app(self, app: Flask):
if app.config.get('EDITION') != 'CLOUD':
return
def init_app(self, app: Flask) -> None:


self.provider_map["azure_openai"] = self.init_azure_openai()
self.provider_map["openai"] = self.init_openai() self.provider_map["openai"] = self.init_openai()
self.provider_map["anthropic"] = self.init_anthropic() self.provider_map["anthropic"] = self.init_anthropic()
self.provider_map["minimax"] = self.init_minimax() self.provider_map["minimax"] = self.init_minimax()


self.moderation_config = self.init_moderation_config() self.moderation_config = self.init_moderation_config()


def init_azure_openai(self) -> HostingProvider:
    """
    Build the hosted Azure OpenAI provider from environment variables.

    Environment variables read:
    - ``HOSTED_AZURE_OPENAI_ENABLED``: provider is enabled only when this is
      ``true`` (case-insensitive)
    - ``HOSTED_AZURE_OPENAI_API_KEY`` / ``HOSTED_AZURE_OPENAI_API_BASE``:
      stored as the provider credentials
    - ``HOSTED_AZURE_OPENAI_QUOTA_LIMIT``: trial quota limit, parsed with
      ``int`` (defaults to ``"1000"``)

    :return: an enabled provider with credentials and quotas, or a disabled
        provider carrying only the quota unit
    """
    quota_unit = QuotaUnit.TIMES

    enabled_flag = os.environ.get("HOSTED_AZURE_OPENAI_ENABLED", "")
    if enabled_flag.lower() != 'true':
        return HostingProvider(
            enabled=False,
            quota_unit=quota_unit,
        )

    credentials = {
        "openai_api_key": os.environ.get("HOSTED_AZURE_OPENAI_API_KEY"),
        "openai_api_base": os.environ.get("HOSTED_AZURE_OPENAI_API_BASE"),
        "base_model_name": "gpt-35-turbo"
    }

    # Every restricted entry uses the same string for model and base_model_name,
    # so the LLM entries are generated from a plain list of names.
    llm_names = (
        "gpt-4",
        "gpt-4-32k",
        "gpt-4-1106-preview",
        "gpt-4-vision-preview",
        "gpt-35-turbo",
        "gpt-35-turbo-1106",
        "gpt-35-turbo-instruct",
        "gpt-35-turbo-16k",
        "text-davinci-003",
    )
    embedding_name = "text-embedding-ada-002"

    quotas = []
    hosted_quota_limit = int(os.environ.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000"))
    # NOTE(review): the `> 0` clause is redundant; the whole condition is
    # equivalent to `hosted_quota_limit != -1`. Confirm the intended rule.
    if hosted_quota_limit != -1 or hosted_quota_limit > 0:
        restrict_models = [
            RestrictModel(model=name, base_model_name=name, model_type=ModelType.LLM)
            for name in llm_names
        ]
        restrict_models.append(
            RestrictModel(
                model=embedding_name,
                base_model_name=embedding_name,
                model_type=ModelType.TEXT_EMBEDDING
            )
        )
        quotas.append(
            TrialHostingQuota(
                quota_limit=hosted_quota_limit,
                restrict_models=restrict_models
            )
        )

    return HostingProvider(
        enabled=True,
        credentials=credentials,
        quota_unit=quota_unit,
        quotas=quotas
    )

def init_openai(self) -> HostingProvider: def init_openai(self) -> HostingProvider:
quota_unit = QuotaUnit.TIMES quota_unit = QuotaUnit.TIMES
if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true': if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true':
if hosted_quota_limit != -1 or hosted_quota_limit > 0: if hosted_quota_limit != -1 or hosted_quota_limit > 0:
trial_quota = TrialHostingQuota( trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit, quota_limit=hosted_quota_limit,
restrict_llms=[
"gpt-3.5-turbo",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-instruct",
"gpt-3.5-turbo-16k",
"text-davinci-003"
restrict_models=[
RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
] ]
) )
quotas.append(trial_quota) quotas.append(trial_quota)

+ 3
- 2
api/core/model_manager.py View File

user=user user=user
) )


def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None, **params) \
-> str: -> str:
""" """
Invoke large language model Invoke large language model
model=self.model, model=self.model,
credentials=self.credentials, credentials=self.credentials,
file=file, file=file,
user=user
user=user,
**params
) )





+ 1
- 1
api/core/model_runtime/entities/model_entities.py View File

return cls.TEXT_EMBEDDING return cls.TEXT_EMBEDDING
elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value: elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value:
return cls.RERANK return cls.RERANK
elif origin_model_type == cls.SPEECH2TEXT.value:
elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value:
return cls.SPEECH2TEXT return cls.SPEECH2TEXT
elif origin_model_type == cls.MODERATION.value: elif origin_model_type == cls.MODERATION.value:
return cls.MODERATION return cls.MODERATION

+ 3
- 3
api/core/model_runtime/model_providers/azure_openai/_constant.py View File



from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelFeature, ModelType, FetchFrom, ParameterRule, \ from core.model_runtime.entities.model_entities import ModelFeature, ModelType, FetchFrom, ParameterRule, \
DefaultParameterName, PriceConfig
DefaultParameterName, PriceConfig, ModelPropertyKey
from core.model_runtime.entities.model_entities import AIModelEntity, I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE


fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model_properties={ model_properties={
'context_size': 8097,
'max_chunks': 32,
ModelPropertyKey.CONTEXT_SIZE: 8097,
ModelPropertyKey.MAX_CHUNKS: 32,
}, },
pricing=PriceConfig( pricing=PriceConfig(
input=0.0001, input=0.0001,

+ 5
- 5
api/core/model_runtime/model_providers/azure_openai/llm/llm.py View File

stream: bool = True, user: Optional[str] = None) \ stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]: -> Union[LLMResult, Generator]:


ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)


if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
# chat model # chat model
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int: tools: Optional[list[PromptMessageTool]] = None) -> int:


model_mode = self._get_ai_model_entity(credentials['base_model_name'], model).entity.model_properties.get(
model_mode = self._get_ai_model_entity(credentials.get('base_model_name'), model).entity.model_properties.get(
ModelPropertyKey.MODE) ModelPropertyKey.MODE)


if model_mode == LLMMode.CHAT.value: if model_mode == LLMMode.CHAT.value:
if 'base_model_name' not in credentials: if 'base_model_name' not in credentials:
raise CredentialsValidateFailedError('Base Model Name is required') raise CredentialsValidateFailedError('Base Model Name is required')


ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)


if not ai_model_entity: if not ai_model_entity:
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid') raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
raise CredentialsValidateFailedError(str(ex)) raise CredentialsValidateFailedError(str(ex))


def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
return ai_model_entity.entity
ai_model_entity = self._get_ai_model_entity(credentials.get('base_model_name'), model)
return ai_model_entity.entity if ai_model_entity else None


def _generate(self, model: str, credentials: dict, def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,

+ 3
- 2
api/core/provider_manager.py View File

from core.helper import encrypter from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \
ConfigurateMethod
from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers import model_provider_factory
from extensions import ext_hosting_provider from extensions import ext_hosting_provider
from extensions.ext_database import db from extensions.ext_database import db
quota_used=provider_record.quota_used, quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit, quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used or provider_record.quota_limit == -1, is_valid=provider_record.quota_limit > provider_record.quota_used or provider_record.quota_limit == -1,
restrict_llms=provider_quota.restrict_llms
restrict_models=provider_quota.restrict_models
) )


quota_configurations.append(quota_configuration) quota_configurations.append(quota_configuration)

Loading…
Cancel
Save