| @@ -2,6 +2,7 @@ import json | |||
| import logging | |||
| from typing import Optional, Union | |||
| import openai | |||
| import requests | |||
| from core.llm.provider.base import BaseProvider | |||
| @@ -11,30 +12,37 @@ from models.provider import ProviderName | |||
| class AzureProvider(BaseProvider): | |||
| def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]: | |||
| credentials = self.get_credentials(model_id) if not credentials else credentials | |||
| url = "{}/openai/deployments?api-version={}".format( | |||
| str(credentials.get('openai_api_base')), | |||
| str(credentials.get('openai_api_version')) | |||
| ) | |||
| headers = { | |||
| "api-key": str(credentials.get('openai_api_key')), | |||
| "content-type": "application/json; charset=utf-8" | |||
| } | |||
| response = requests.get(url, headers=headers) | |||
| if response.status_code == 200: | |||
| result = response.json() | |||
| return [{ | |||
| 'id': deployment['id'], | |||
| 'name': '{} ({})'.format(deployment['id'], deployment['model']) | |||
| } for deployment in result['data'] if deployment['status'] == 'succeeded'] | |||
| else: | |||
| if response.status_code == 401: | |||
| raise AzureAuthenticationError() | |||
| return [] | |||
| def check_embedding_model(self, credentials: Optional[dict] = None): | |||
| credentials = self.get_credentials('text-embedding-ada-002') if not credentials else credentials | |||
| try: | |||
| result = openai.Embedding.create(input=['test'], | |||
| engine='text-embedding-ada-0021', | |||
| timeout=60, | |||
| api_key=str(credentials.get('openai_api_key')), | |||
| api_base=str(credentials.get('openai_api_base')), | |||
| api_type='azure', | |||
| api_version=str(credentials.get('openai_api_version')))["data"][0][ | |||
| "embedding"] | |||
| except openai.error.AuthenticationError as e: | |||
| raise AzureAuthenticationError(str(e)) | |||
| except openai.error.APIConnectionError as e: | |||
| raise AzureRequestFailedError( | |||
| 'Failed to request Azure OpenAI, please check your API Base Endpoint, The format is `https://xxx.openai.azure.com/`') | |||
| except openai.error.InvalidRequestError as e: | |||
| if e.http_status == 404: | |||
| raise AzureRequestFailedError("Please check your 'gpt-3.5-turbo' or 'text-embedding-ada-002' " | |||
| "deployment name is exists in Azure AI") | |||
| else: | |||
| raise AzureRequestFailedError('Failed to request Azure OpenAI. Status code: {}'.format(response.status_code)) | |||
| raise AzureRequestFailedError( | |||
| 'Failed to request Azure OpenAI. cause: {}'.format(str(e))) | |||
| except openai.error.OpenAIError as e: | |||
| raise AzureRequestFailedError( | |||
| 'Failed to request Azure OpenAI. cause: {}'.format(str(e))) | |||
| if not isinstance(result, list): | |||
| raise AzureRequestFailedError('Failed to request Azure OpenAI.') | |||
| def get_credentials(self, model_id: Optional[str] = None) -> dict: | |||
| """ | |||
| @@ -94,31 +102,11 @@ class AzureProvider(BaseProvider): | |||
| if 'openai_api_version' not in config: | |||
| config['openai_api_version'] = '2023-03-15-preview' | |||
| models = self.get_models(credentials=config) | |||
| if not models: | |||
| raise ValidateFailedError("Please add deployments for " | |||
| "'gpt-3.5-turbo', 'text-embedding-ada-002' (required) " | |||
| "and 'gpt-4', 'gpt-35-turbo-16k', 'text-davinci-003' (optional).") | |||
| fixed_model_ids = [ | |||
| 'gpt-35-turbo', | |||
| 'text-embedding-ada-002' | |||
| ] | |||
| current_model_ids = [model['id'] for model in models] | |||
| missing_model_ids = [fixed_model_id for fixed_model_id in fixed_model_ids if | |||
| fixed_model_id not in current_model_ids] | |||
| if missing_model_ids: | |||
| raise ValidateFailedError("Please add deployments for '{}'.".format(", ".join(missing_model_ids))) | |||
| self.check_embedding_model(credentials=config) | |||
| except ValidateFailedError as e: | |||
| raise e | |||
| except AzureAuthenticationError: | |||
| raise ValidateFailedError('Validation failed, please check your API Key.') | |||
| except (requests.ConnectionError, requests.RequestException): | |||
| raise ValidateFailedError('Validation failed, please check your API Base Endpoint.') | |||
| except AzureRequestFailedError as ex: | |||
| raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex))) | |||
| except Exception as ex: | |||