|
|
|
@@ -2,6 +2,7 @@ import json |
|
|
|
from typing import Type |
|
|
|
|
|
|
|
import requests |
|
|
|
from langchain.embeddings import XinferenceEmbeddings |
|
|
|
|
|
|
|
from core.helper import encrypter |
|
|
|
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding |
|
|
|
@@ -97,11 +98,18 @@ class XinferenceProvider(BaseModelProvider): |
|
|
|
'model_uid': credentials['model_uid'], |
|
|
|
} |
|
|
|
|
|
|
|
llm = XinferenceLLM( |
|
|
|
**credential_kwargs |
|
|
|
) |
|
|
|
if model_type == ModelType.TEXT_GENERATION: |
|
|
|
llm = XinferenceLLM( |
|
|
|
**credential_kwargs |
|
|
|
) |
|
|
|
|
|
|
|
llm("ping") |
|
|
|
elif model_type == ModelType.EMBEDDINGS: |
|
|
|
embedding = XinferenceEmbeddings( |
|
|
|
**credential_kwargs |
|
|
|
) |
|
|
|
|
|
|
|
llm("ping") |
|
|
|
embedding.embed_query("ping") |
|
|
|
except Exception as ex: |
|
|
|
raise CredentialsValidateFailedError(str(ex)) |
|
|
|
|
|
|
|
@@ -117,8 +125,9 @@ class XinferenceProvider(BaseModelProvider): |
|
|
|
:param credentials: |
|
|
|
:return: |
|
|
|
""" |
|
|
|
extra_credentials = cls._get_extra_credentials(credentials) |
|
|
|
credentials.update(extra_credentials) |
|
|
|
if model_type == ModelType.TEXT_GENERATION: |
|
|
|
extra_credentials = cls._get_extra_credentials(credentials) |
|
|
|
credentials.update(extra_credentials) |
|
|
|
|
|
|
|
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url']) |
|
|
|
|