| @@ -85,7 +85,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): | |||
| tools=tools, stop=stop, stream=stream, user=user, | |||
| extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter( | |||
| server_url=credentials['server_url'], | |||
| model_uid=credentials['model_uid'] | |||
| model_uid=credentials['model_uid'], | |||
| api_key=credentials.get('api_key'), | |||
| ) | |||
| ) | |||
| @@ -106,7 +107,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): | |||
| extra_param = XinferenceHelper.get_xinference_extra_parameter( | |||
| server_url=credentials['server_url'], | |||
| model_uid=credentials['model_uid'] | |||
| model_uid=credentials['model_uid'], | |||
| api_key=credentials.get('api_key') | |||
| ) | |||
| if 'completion_type' not in credentials: | |||
| if 'chat' in extra_param.model_ability: | |||
| @@ -396,7 +398,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): | |||
| else: | |||
| extra_args = XinferenceHelper.get_xinference_extra_parameter( | |||
| server_url=credentials['server_url'], | |||
| model_uid=credentials['model_uid'] | |||
| model_uid=credentials['model_uid'], | |||
| api_key=credentials.get('api_key') | |||
| ) | |||
| if 'chat' in extra_args.model_ability: | |||
| @@ -464,6 +467,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): | |||
| xinference_client = Client( | |||
| base_url=credentials['server_url'], | |||
| api_key=credentials.get('api_key'), | |||
| ) | |||
| xinference_model = xinference_client.get_model(credentials['model_uid']) | |||
| @@ -108,7 +108,8 @@ class XinferenceRerankModel(RerankModel): | |||
| # initialize client | |||
| client = Client( | |||
| base_url=credentials['server_url'] | |||
| base_url=credentials['server_url'], | |||
| api_key=credentials.get('api_key'), | |||
| ) | |||
| xinference_client = client.get_model(model_uid=credentials['model_uid']) | |||
| @@ -52,7 +52,8 @@ class XinferenceSpeech2TextModel(Speech2TextModel): | |||
| # initialize client | |||
| client = Client( | |||
| base_url=credentials['server_url'] | |||
| base_url=credentials['server_url'], | |||
| api_key=credentials.get('api_key'), | |||
| ) | |||
| xinference_client = client.get_model(model_uid=credentials['model_uid']) | |||
| @@ -110,14 +110,22 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): | |||
| server_url = credentials['server_url'] | |||
| model_uid = credentials['model_uid'] | |||
| extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid) | |||
| api_key = credentials.get('api_key') | |||
| extra_args = XinferenceHelper.get_xinference_extra_parameter( | |||
| server_url=server_url, | |||
| model_uid=model_uid, | |||
| api_key=api_key, | |||
| ) | |||
| if extra_args.max_tokens: | |||
| credentials['max_tokens'] = extra_args.max_tokens | |||
| if server_url.endswith('/'): | |||
| server_url = server_url[:-1] | |||
| client = Client(base_url=server_url) | |||
| client = Client( | |||
| base_url=server_url, | |||
| api_key=api_key, | |||
| ) | |||
| try: | |||
| handle = client.get_model(model_uid=model_uid) | |||
| @@ -81,7 +81,8 @@ class XinferenceText2SpeechModel(TTSModel): | |||
| extra_param = XinferenceHelper.get_xinference_extra_parameter( | |||
| server_url=credentials['server_url'], | |||
| model_uid=credentials['model_uid'] | |||
| model_uid=credentials['model_uid'], | |||
| api_key=credentials.get('api_key'), | |||
| ) | |||
| if 'text-to-audio' not in extra_param.model_ability: | |||
| @@ -203,7 +204,11 @@ class XinferenceText2SpeechModel(TTSModel): | |||
| credentials['server_url'] = credentials['server_url'][:-1] | |||
| try: | |||
| handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={}) | |||
| api_key = credentials.get('api_key') | |||
| auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} | |||
| handle = RESTfulAudioModelHandle( | |||
| credentials['model_uid'], credentials['server_url'], auth_headers=auth_headers | |||
| ) | |||
| model_support_voice = [x.get("value") for x in | |||
| self.get_tts_model_voices(model=model, credentials=credentials)] | |||
| @@ -35,13 +35,13 @@ cache_lock = Lock() | |||
| class XinferenceHelper: | |||
@staticmethod
def get_xinference_extra_parameter(
    server_url: str, model_uid: str, api_key: "str | None" = None
) -> XinferenceModelExtraParameter:
    """
    Return the extra parameters of a Xinference model, using the module-level cache.

    Expired cache entries are purged first via ``_clean_cache``. On a cache miss the
    parameters are fetched with ``_get_xinference_extra_parameter`` and stored with an
    expiry of 300 seconds from now.

    :param server_url: base URL of the Xinference server
    :param model_uid: UID of the model; also used as the cache key
    :param api_key: optional API key forwarded to the fetch helper. Callers pass
        ``credentials.get('api_key')``, which is ``None`` when no key is configured,
        so the parameter is optional and defaults to ``None`` to stay compatible
        with two-argument callers.
    :return: the cached (or freshly fetched) extra parameters for ``model_uid``
    """
    XinferenceHelper._clean_cache()
    with cache_lock:
        # NOTE(review): the cache is keyed by model_uid only, so the same UID on a
        # different server_url / api_key reuses the cached entry until it expires.
        # Confirm whether that is intended before widening the key.
        if model_uid not in cache:
            cache[model_uid] = {
                'expires': time() + 300,
                'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key),
            }
        return cache[model_uid]['value']
| @@ -56,7 +56,7 @@ class XinferenceHelper: | |||
| pass | |||
| @staticmethod | |||
| def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: | |||
| def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: | |||
| """ | |||
| get xinference model extra parameter like model_format and model_handle_type | |||
| """ | |||
| @@ -70,9 +70,10 @@ class XinferenceHelper: | |||
| session = Session() | |||
| session.mount('http://', HTTPAdapter(max_retries=3)) | |||
| session.mount('https://', HTTPAdapter(max_retries=3)) | |||
| headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} | |||
| try: | |||
| response = session.get(url, timeout=10) | |||
| response = session.get(url, headers=headers, timeout=10) | |||
| except (MissingSchema, ConnectionError, Timeout) as e: | |||
| raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}') | |||
| if response.status_code != 200: | |||