Sfoglia il codice sorgente

chore: Extract common functions of the base model in Azure OpenAI Provider (#9907)

tags/0.10.2
ice yao 1 anno fa
parent
commit
22776f24ab
Nessun account collegato all'indirizzo email del committer

+ 3
- 0
api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml Vedi File

type: select type: select
required: true required: true
options: options:
- label:
en_US: 2024-10-01-preview
value: 2024-10-01-preview
- label: - label:
en_US: 2024-09-01-preview en_US: 2024-09-01-preview
value: 2024-09-01-preview value: 2024-09-01-preview

+ 10
- 17
api/core/model_runtime/model_providers/azure_openai/llm/llm.py Vedi File

stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> Union[LLMResult, Generator]: ) -> Union[LLMResult, Generator]:
base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise ValueError("Base Model Name is required")
base_model_name = self._get_base_model_name(credentials)
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)


if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
) -> int: ) -> int:
base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise ValueError("Base Model Name is required")
base_model_name = self._get_base_model_name(credentials)
model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if not model_entity: if not model_entity:
raise ValueError(f"Base Model Name {base_model_name} is invalid") raise ValueError(f"Base Model Name {base_model_name} is invalid")
if "base_model_name" not in credentials: if "base_model_name" not in credentials:
raise CredentialsValidateFailedError("Base Model Name is required") raise CredentialsValidateFailedError("Base Model Name is required")


base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise CredentialsValidateFailedError("Base Model Name is required")
base_model_name = self._get_base_model_name(credentials)
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)


if not ai_model_entity: if not ai_model_entity:
raise CredentialsValidateFailedError(str(ex)) raise CredentialsValidateFailedError(str(ex))


def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise ValueError("Base Model Name is required")
base_model_name = self._get_base_model_name(credentials)
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
return ai_model_entity.entity if ai_model_entity else None return ai_model_entity.entity if ai_model_entity else None




if tools: if tools:
extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
# extra_model_kwargs['functions'] = [{
# "name": tool.name,
# "description": tool.description,
# "parameters": tool.parameters
# } for tool in tools]


if stop: if stop:
extra_model_kwargs["stop"] = stop extra_model_kwargs["stop"] = stop
ai_model_entity_copy.entity.label.en_US = model ai_model_entity_copy.entity.label.en_US = model
ai_model_entity_copy.entity.label.zh_Hans = model ai_model_entity_copy.entity.label.zh_Hans = model
return ai_model_entity_copy return ai_model_entity_copy

def _get_base_model_name(self, credentials: dict) -> str:
base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise ValueError("Base Model Name is required")
return base_model_name

Loading…
Annulla
Salva