| @@ -47,6 +47,20 @@ class XinferenceSpeech2TextModel(Speech2TextModel): | |||
| if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: | |||
| raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") | |||
| if credentials['server_url'].endswith('/'): | |||
| credentials['server_url'] = credentials['server_url'][:-1] | |||
| # initialize client | |||
| client = Client( | |||
| base_url=credentials['server_url'] | |||
| ) | |||
| xinference_client = client.get_model(model_uid=credentials['model_uid']) | |||
| if not isinstance(xinference_client, RESTfulAudioModelHandle): | |||
| raise InvokeBadRequestError( | |||
| 'please check model type, the model you want to invoke is not a audio model') | |||
| audio_file_path = self._get_demo_file_path() | |||
| with open(audio_file_path, 'rb') as audio_file: | |||
| @@ -110,17 +124,8 @@ class XinferenceSpeech2TextModel(Speech2TextModel): | |||
| if credentials['server_url'].endswith('/'): | |||
| credentials['server_url'] = credentials['server_url'][:-1] | |||
| # initialize client | |||
| client = Client( | |||
| base_url=credentials['server_url'] | |||
| ) | |||
| xinference_client = client.get_model(model_uid=credentials['model_uid']) | |||
| if not isinstance(xinference_client, RESTfulAudioModelHandle): | |||
| raise InvokeBadRequestError('please check model type, the model you want to invoke is not a audio model') | |||
| response = xinference_client.transcriptions( | |||
| handle = RESTfulAudioModelHandle(credentials['model_uid'],credentials['server_url'],auth_headers={}) | |||
| response = handle.transcriptions( | |||
| audio=file, | |||
| language = language, | |||
| prompt = prompt, | |||