|
|
|
@@ -0,0 +1,81 @@ |
|
|
|
import copy
from typing import IO, Optional

from openai import AzureOpenAI

from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import SPEECH2TEXT_BASE_MODELS, AzureBaseModel
|
|
|
|
|
|
|
|
|
|
|
class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
    """
    Model class for Azure OpenAI Speech to text model.
    """

    def _invoke(self, model: str, credentials: dict,
                file: IO[bytes], user: Optional[str] = None) -> str:
        """
        Invoke speech2text model

        :param model: model name
        :param credentials: model credentials
        :param file: audio file
        :param user: unique user id (accepted for interface compatibility;
            it is not forwarded to the transcription request)
        :return: text for given audio file
        """
        return self._speech2text_invoke(model, credentials, file)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        Runs a transcription of the demo audio file returned by
        ``_get_demo_file_path()``; any failure along the way is reported as a
        credentials validation failure.

        :param model: model name
        :param credentials: model credentials
        :return:
        :raises CredentialsValidateFailedError: if the demo file cannot be
            opened or the transcription request fails
        """
        try:
            audio_file_path = self._get_demo_file_path()

            with open(audio_file_path, 'rb') as audio_file:
                self._speech2text_invoke(model, credentials, audio_file)
        except Exception as ex:
            # Keep the original exception attached as the cause so the real
            # failure stays visible in tracebacks.
            raise CredentialsValidateFailedError(str(ex)) from ex

    def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str:
        """
        Invoke speech2text model

        :param model: model name
        :param credentials: model credentials
        :param file: audio file
        :return: text for given audio file
        """
        # transform credentials to kwargs for model instance
        credentials_kwargs = self._to_credential_kwargs(credentials)

        # init model client
        client = AzureOpenAI(**credentials_kwargs)

        response = client.audio.transcriptions.create(model=model, file=file)

        return response.text

    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
        """
        Get the model schema for a customizable model

        :param model: model name
        :param credentials: model credentials; ``base_model_name`` is read
            from it to select the base model
        :return: model entity with its name and labels set to ``model``, or
            None when no speech2text base model matches ``base_model_name``
        """
        ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
        if ai_model_entity is None:
            # Unknown base model: honour the Optional return type rather than
            # failing with AttributeError on None.
            return None

        return ai_model_entity.entity

    @staticmethod
    def _get_ai_model_entity(base_model_name: str, model: str) -> Optional[AzureBaseModel]:
        """
        Find the speech2text base model and return a renamed copy of it

        :param base_model_name: base model name to look up in
            SPEECH2TEXT_BASE_MODELS
        :param model: model name written to the copy's ``entity.model`` and
            to its ``en_US`` / ``zh_Hans`` labels
        :return: deep copy of the matching base model, or None if no base
            model has the given name
        """
        for ai_model_entity in SPEECH2TEXT_BASE_MODELS:
            if ai_model_entity.base_model_name == base_model_name:
                # Work on a deep copy so the shared entries in
                # SPEECH2TEXT_BASE_MODELS are never mutated.
                ai_model_entity_copy = copy.deepcopy(ai_model_entity)
                ai_model_entity_copy.entity.model = model
                ai_model_entity_copy.entity.label.en_US = model
                ai_model_entity_copy.entity.label.zh_Hans = model
                return ai_model_entity_copy

        return None