Bladeren bron

fix: inference embedding validate (#1187)

tags/0.3.23
takatost 2 jaar geleden
bovenliggende
commit
c8bd76cd66
No account linked to committer's email address
2 gewijzigde bestanden met 17 toevoegingen en 8 verwijderingen
  1. 15
    6
      api/core/model_providers/providers/xinference_provider.py
  2. 2
    2
      api/requirements.txt

+ 15
- 6
api/core/model_providers/providers/xinference_provider.py Bestand weergeven

@@ -2,6 +2,7 @@ import json
from typing import Type

import requests
from langchain.embeddings import XinferenceEmbeddings

from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@@ -97,11 +98,18 @@ class XinferenceProvider(BaseModelProvider):
'model_uid': credentials['model_uid'],
}

llm = XinferenceLLM(
**credential_kwargs
)
if model_type == ModelType.TEXT_GENERATION:
llm = XinferenceLLM(
**credential_kwargs
)

llm("ping")
elif model_type == ModelType.EMBEDDINGS:
embedding = XinferenceEmbeddings(
**credential_kwargs
)

llm("ping")
embedding.embed_query("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

@@ -117,8 +125,9 @@ class XinferenceProvider(BaseModelProvider):
:param credentials:
:return:
"""
extra_credentials = cls._get_extra_credentials(credentials)
credentials.update(extra_credentials)
if model_type == ModelType.TEXT_GENERATION:
extra_credentials = cls._get_extra_credentials(credentials)
credentials.update(extra_credentials)

credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])


+ 2
- 2
api/requirements.txt Bestand weergeven

@@ -19,7 +19,7 @@ pytest~=7.3.1
pytest-mock~=3.11.1
tiktoken==0.3.3
Authlib==1.2.0
boto3~=1.26.123
boto3==1.28.17
tenacity==8.2.2
cachetools~=5.3.0
weaviate-client~=3.21.0
@@ -49,5 +49,5 @@ huggingface_hub~=0.16.4
transformers~=4.31.0
stripe~=5.5.0
pandas==1.5.3
xinference==0.2.1
xinference==0.4.2
safetensors==0.3.2

Laden…
Annuleren
Opslaan