|
|
|
@@ -45,9 +45,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): |
|
|
|
stream: bool = True, |
|
|
|
user: Optional[str] = None, |
|
|
|
) -> Union[LLMResult, Generator]: |
|
|
|
base_model_name = credentials.get("base_model_name") |
|
|
|
if not base_model_name: |
|
|
|
raise ValueError("Base Model Name is required") |
|
|
|
base_model_name = self._get_base_model_name(credentials) |
|
|
|
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) |
|
|
|
|
|
|
|
if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: |
|
|
|
@@ -81,9 +79,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): |
|
|
|
prompt_messages: list[PromptMessage], |
|
|
|
tools: Optional[list[PromptMessageTool]] = None, |
|
|
|
) -> int: |
|
|
|
base_model_name = credentials.get("base_model_name") |
|
|
|
if not base_model_name: |
|
|
|
raise ValueError("Base Model Name is required") |
|
|
|
base_model_name = self._get_base_model_name(credentials) |
|
|
|
model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) |
|
|
|
if not model_entity: |
|
|
|
raise ValueError(f"Base Model Name {base_model_name} is invalid") |
|
|
|
@@ -108,9 +104,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): |
|
|
|
if "base_model_name" not in credentials: |
|
|
|
raise CredentialsValidateFailedError("Base Model Name is required") |
|
|
|
|
|
|
|
base_model_name = credentials.get("base_model_name") |
|
|
|
if not base_model_name: |
|
|
|
raise CredentialsValidateFailedError("Base Model Name is required") |
|
|
|
base_model_name = self._get_base_model_name(credentials) |
|
|
|
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) |
|
|
|
|
|
|
|
if not ai_model_entity: |
|
|
|
@@ -149,9 +143,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): |
|
|
|
raise CredentialsValidateFailedError(str(ex)) |
|
|
|
|
|
|
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: |
|
|
|
base_model_name = credentials.get("base_model_name") |
|
|
|
if not base_model_name: |
|
|
|
raise ValueError("Base Model Name is required") |
|
|
|
base_model_name = self._get_base_model_name(credentials) |
|
|
|
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) |
|
|
|
return ai_model_entity.entity if ai_model_entity else None |
|
|
|
|
|
|
|
@@ -308,11 +300,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): |
|
|
|
|
|
|
|
if tools: |
|
|
|
extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] |
|
|
|
# extra_model_kwargs['functions'] = [{ |
|
|
|
# "name": tool.name, |
|
|
|
# "description": tool.description, |
|
|
|
# "parameters": tool.parameters |
|
|
|
# } for tool in tools] |
|
|
|
|
|
|
|
if stop: |
|
|
|
extra_model_kwargs["stop"] = stop |
|
|
|
@@ -769,3 +756,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): |
|
|
|
ai_model_entity_copy.entity.label.en_US = model |
|
|
|
ai_model_entity_copy.entity.label.zh_Hans = model |
|
|
|
return ai_model_entity_copy |
|
|
|
|
|
|
|
def _get_base_model_name(self, credentials: dict) -> str: |
|
|
|
base_model_name = credentials.get("base_model_name") |
|
|
|
if not base_model_name: |
|
|
|
raise ValueError("Base Model Name is required") |
|
|
|
return base_model_name |