@@ -1,8 +1,7 @@
-from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
-
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.providers.base import BaseModelProvider
 from core.model_providers.models.embedding.base import BaseEmbedding
+from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings


 class XinferenceEmbedding(BaseEmbedding):
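(Note: the removed `as XinferenceEmbeddings` alias existed only to dodge a name clash with the local `XinferenceEmbedding` model class above; the in-repo wrapper is renamed to `XinferenceEmbeddings` further down this diff, so the alias is no longer needed.)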
@@ -2,7 +2,6 @@ import json
 from typing import Type

 import requests
-from langchain.embeddings import XinferenceEmbeddings

 from core.helper import encrypter
 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@@ -11,6 +10,7 @@ from core.model_providers.models.llm.xinference_model import XinferenceModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings
 from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
 from models.provider import ProviderType

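Both hunks above belong to the provider module (presumably `core/model_providers/providers/xinference_provider.py`): its `XinferenceEmbeddings` import moves from `langchain.embeddings` to the in-repo wrapper under `core.third_party`, in line with the dependency swap in `requirements.txt` at the bottom of this diff.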
@@ -1,21 +1,54 @@
-from typing import List
+from typing import List, Optional, Any

 import numpy as np
-from langchain.embeddings import XinferenceEmbeddings
+from langchain.embeddings.base import Embeddings
+from xinference_client.client.restful.restful_client import Client


-class XinferenceEmbedding(XinferenceEmbeddings):
+class XinferenceEmbeddings(Embeddings):
+    client: Any
+    server_url: Optional[str]
+    """URL of the xinference server"""
+    model_uid: Optional[str]
+    """UID of the launched model"""
+
+    def __init__(
+        self, server_url: Optional[str] = None, model_uid: Optional[str] = None
+    ):
+        super().__init__()
+
+        if server_url is None:
+            raise ValueError("Please provide server URL")
+
+        if model_uid is None:
+            raise ValueError("Please provide the model UID")
+
+        self.server_url = server_url
+        self.model_uid = model_uid
+        self.client = Client(server_url)

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        vectors = super().embed_documents(texts)
+        model = self.client.get_model(self.model_uid)
+        embeddings = [
+            model.create_embedding(text)["data"][0]["embedding"] for text in texts
+        ]
+        vectors = [list(map(float, e)) for e in embeddings]
         normalized_vectors = [(vector / np.linalg.norm(vector)).tolist() for vector in vectors]

         return normalized_vectors

     def embed_query(self, text: str) -> List[float]:
-        vector = super().embed_query(text)
+        model = self.client.get_model(self.model_uid)
+        embedding_res = model.create_embedding(text)
+        embedding = embedding_res["data"][0]["embedding"]
+        vector = list(map(float, embedding))
         normalized_vector = (vector / np.linalg.norm(vector)).tolist()

         return normalized_vector
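A quick usage sketch of the rewritten wrapper, for review context only (not part of the diff). The server URL and model UID are placeholders; both depend on your own Xinference deployment and on which embedding model you have launched there.

```python
# Usage sketch (not part of the diff). Assumes an Xinference server is
# reachable and an embedding model has already been launched on it; the
# URL and UID below are placeholders.
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings

embeddings = XinferenceEmbeddings(
    server_url="http://127.0.0.1:9997",  # placeholder endpoint
    model_uid="my-embedding-model-uid",  # placeholder UID from `xinference launch`
)

doc_vectors = embeddings.embed_documents(["first document", "second document"])
query_vector = embeddings.embed_query("a search query")
```

Both methods L2-normalize the raw embeddings (`vector / np.linalg.norm(vector)`), so downstream cosine-similarity search reduces to a plain dot product.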
@@ -1,16 +1,53 @@
-from typing import Optional, List, Any, Union, Generator
+from typing import Optional, List, Any, Union, Generator, Mapping

 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms import Xinference
+from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
-from xinference.client import (
+from xinference_client.client.restful.restful_client import (
     RESTfulChatglmCppChatModelHandle,
     RESTfulChatModelHandle,
-    RESTfulGenerateModelHandle,
+    RESTfulGenerateModelHandle, Client,
 )


-class XinferenceLLM(Xinference):
+class XinferenceLLM(LLM):
+    client: Any
+    server_url: Optional[str]
+    """URL of the xinference server"""
+    model_uid: Optional[str]
+    """UID of the launched model"""
+
+    def __init__(
+        self, server_url: Optional[str] = None, model_uid: Optional[str] = None
+    ):
+        super().__init__(
+            **{
+                "server_url": server_url,
+                "model_uid": model_uid,
+            }
+        )
+
+        if self.server_url is None:
+            raise ValueError("Please provide server URL")
+
+        if self.model_uid is None:
+            raise ValueError("Please provide the model UID")
+
+        self.client = Client(server_url)
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "xinference"
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {
+            **{"server_url": self.server_url},
+            **{"model_uid": self.model_uid},
+        }
+
     def _call(
         self,
         prompt: str,
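The rest of `_call` falls outside this hunk (only its first three lines appear as trailing context), so its body is unchanged. A usage sketch under the same placeholder assumptions as above:

```python
# Usage sketch (not part of the diff). Placeholder URL and model UID, as above.
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM

llm = XinferenceLLM(
    server_url="http://127.0.0.1:9997",
    model_uid="my-llm-model-uid",
)

# XinferenceLLM now subclasses langchain's LLM base class directly, so the
# standard callable interface dispatches to _call() under the hood.
print(llm("Q: What is the capital of France?\nA:"))
```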
@@ -49,7 +49,7 @@ huggingface_hub~=0.16.4
 transformers~=4.31.0
 stripe~=5.5.0
 pandas==1.5.3
-xinference==0.5.2
+xinference-client~=0.1.2
 safetensors==0.3.2
 zhipuai==1.0.7
 werkzeug==2.3.7
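The `requirements.txt` change is the crux of the diff: `xinference` pulls in the full inference server, while `xinference-client` ships only the RESTful client. That is also why both wrappers above stop subclassing langchain's `Xinference` / `XinferenceEmbeddings` integrations, which depend on the server package, and instead talk to `xinference_client`'s `Client` directly.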