| import logging | import logging | ||||
| from typing import Optional, Union | from typing import Optional, Union | ||||
| import openai | |||||
| import requests | import requests | ||||
| from core.llm.provider.base import BaseProvider | from core.llm.provider.base import BaseProvider | ||||
| class AzureProvider(BaseProvider): | class AzureProvider(BaseProvider): | ||||
| def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]: | def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]: | ||||
| credentials = self.get_credentials(model_id) if not credentials else credentials | |||||
| url = "{}/openai/deployments?api-version={}".format( | |||||
| str(credentials.get('openai_api_base')), | |||||
| str(credentials.get('openai_api_version')) | |||||
| ) | |||||
| headers = { | |||||
| "api-key": str(credentials.get('openai_api_key')), | |||||
| "content-type": "application/json; charset=utf-8" | |||||
| } | |||||
| response = requests.get(url, headers=headers) | |||||
| if response.status_code == 200: | |||||
| result = response.json() | |||||
| return [{ | |||||
| 'id': deployment['id'], | |||||
| 'name': '{} ({})'.format(deployment['id'], deployment['model']) | |||||
| } for deployment in result['data'] if deployment['status'] == 'succeeded'] | |||||
| else: | |||||
| if response.status_code == 401: | |||||
| raise AzureAuthenticationError() | |||||
| return [] | |||||
| def check_embedding_model(self, credentials: Optional[dict] = None): | |||||
| credentials = self.get_credentials('text-embedding-ada-002') if not credentials else credentials | |||||
| try: | |||||
| result = openai.Embedding.create(input=['test'], | |||||
| engine='text-embedding-ada-0021', | |||||
| timeout=60, | |||||
| api_key=str(credentials.get('openai_api_key')), | |||||
| api_base=str(credentials.get('openai_api_base')), | |||||
| api_type='azure', | |||||
| api_version=str(credentials.get('openai_api_version')))["data"][0][ | |||||
| "embedding"] | |||||
| except openai.error.AuthenticationError as e: | |||||
| raise AzureAuthenticationError(str(e)) | |||||
| except openai.error.APIConnectionError as e: | |||||
| raise AzureRequestFailedError( | |||||
| 'Failed to request Azure OpenAI, please check your API Base Endpoint, The format is `https://xxx.openai.azure.com/`') | |||||
| except openai.error.InvalidRequestError as e: | |||||
| if e.http_status == 404: | |||||
| raise AzureRequestFailedError("Please check your 'gpt-3.5-turbo' or 'text-embedding-ada-002' " | |||||
| "deployment name is exists in Azure AI") | |||||
| else: | else: | ||||
| raise AzureRequestFailedError('Failed to request Azure OpenAI. Status code: {}'.format(response.status_code)) | |||||
| raise AzureRequestFailedError( | |||||
| 'Failed to request Azure OpenAI. cause: {}'.format(str(e))) | |||||
| except openai.error.OpenAIError as e: | |||||
| raise AzureRequestFailedError( | |||||
| 'Failed to request Azure OpenAI. cause: {}'.format(str(e))) | |||||
| if not isinstance(result, list): | |||||
| raise AzureRequestFailedError('Failed to request Azure OpenAI.') | |||||
| def get_credentials(self, model_id: Optional[str] = None) -> dict: | def get_credentials(self, model_id: Optional[str] = None) -> dict: | ||||
| """ | """ | ||||
| if 'openai_api_version' not in config: | if 'openai_api_version' not in config: | ||||
| config['openai_api_version'] = '2023-03-15-preview' | config['openai_api_version'] = '2023-03-15-preview' | ||||
| models = self.get_models(credentials=config) | |||||
| if not models: | |||||
| raise ValidateFailedError("Please add deployments for " | |||||
| "'gpt-3.5-turbo', 'text-embedding-ada-002' (required) " | |||||
| "and 'gpt-4', 'gpt-35-turbo-16k', 'text-davinci-003' (optional).") | |||||
| fixed_model_ids = [ | |||||
| 'gpt-35-turbo', | |||||
| 'text-embedding-ada-002' | |||||
| ] | |||||
| current_model_ids = [model['id'] for model in models] | |||||
| missing_model_ids = [fixed_model_id for fixed_model_id in fixed_model_ids if | |||||
| fixed_model_id not in current_model_ids] | |||||
| if missing_model_ids: | |||||
| raise ValidateFailedError("Please add deployments for '{}'.".format(", ".join(missing_model_ids))) | |||||
| self.check_embedding_model(credentials=config) | |||||
| except ValidateFailedError as e: | except ValidateFailedError as e: | ||||
| raise e | raise e | ||||
| except AzureAuthenticationError: | except AzureAuthenticationError: | ||||
| raise ValidateFailedError('Validation failed, please check your API Key.') | raise ValidateFailedError('Validation failed, please check your API Key.') | ||||
| except (requests.ConnectionError, requests.RequestException): | |||||
| raise ValidateFailedError('Validation failed, please check your API Base Endpoint.') | |||||
| except AzureRequestFailedError as ex: | except AzureRequestFailedError as ex: | ||||
| raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex))) | raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex))) | ||||
| except Exception as ex: | except Exception as ex: |