Преглед на файлове

feat: use xinference client instead of xinference (#1339)

tags/0.3.28
takatost преди 2 години
родител
ревизия
3efaa713da
No account linked to committer's email address

+ 1
- 2
api/core/model_providers/models/embedding/xinference_embedding.py Целия файл

@@ -1,8 +1,7 @@
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings

from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings


class XinferenceEmbedding(BaseEmbedding):

+ 1
- 1
api/core/model_providers/providers/xinference_provider.py Целия файл

@@ -2,7 +2,6 @@ import json
from typing import Type

import requests
from langchain.embeddings import XinferenceEmbeddings

from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@@ -11,6 +10,7 @@ from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError

from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
from models.provider import ProviderType


+ 38
- 5
api/core/third_party/langchain/embeddings/xinference_embedding.py Целия файл

@@ -1,21 +1,54 @@
from typing import List
from typing import List, Optional, Any

import numpy as np
from langchain.embeddings import XinferenceEmbeddings
from langchain.embeddings.base import Embeddings
from xinference_client.client.restful.restful_client import Client


class XinferenceEmbeddings(Embeddings):
    """Embeddings backed by a remote Xinference server, accessed through the
    lightweight ``xinference_client`` RESTful client.

    Each text is embedded server-side and the resulting vector is
    L2-normalized before being returned.
    """

    client: Any
    server_url: Optional[str]  # URL of the xinference server
    model_uid: Optional[str]   # UID of the launched model

    def __init__(
        self, server_url: Optional[str] = None, model_uid: Optional[str] = None
    ):
        """Validate connection settings and create the RESTful client.

        :param server_url: base URL of the xinference server (required)
        :param model_uid: UID of the launched embedding model (required)
        :raises ValueError: if either argument is missing
        """
        super().__init__()

        if server_url is None:
            raise ValueError("Please provide server URL")

        if model_uid is None:
            raise ValueError("Please provide the model UID")

        self.server_url = server_url
        self.model_uid = model_uid
        self.client = Client(server_url)

    def _embed_normalized(self, model: Any, text: str) -> List[float]:
        """Embed one text with *model* and return the L2-normalized vector.

        Shared by embed_documents/embed_query so the response-unpacking and
        normalization logic lives in exactly one place.
        """
        embedding = model.create_embedding(text)["data"][0]["embedding"]
        vector = np.asarray(embedding, dtype=float)
        # NOTE(review): a zero vector would yield NaNs here; assumes the
        # server never returns an all-zero embedding — confirm upstream.
        return (vector / np.linalg.norm(vector)).tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one normalized embedding per input text (empty list for [])."""
        model = self.client.get_model(self.model_uid)
        return [self._embed_normalized(model, text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Return the normalized embedding for a single query text."""
        model = self.client.get_model(self.model_uid)
        return self._embed_normalized(model, text)

+ 42
- 5
api/core/third_party/langchain/llms/xinference_llm.py Целия файл

@@ -1,16 +1,53 @@
from typing import Optional, List, Any, Union, Generator
from typing import Optional, List, Any, Union, Generator, Mapping

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import Xinference
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from xinference.client import (
from xinference_client.client.restful.restful_client import (
RESTfulChatglmCppChatModelHandle,
RESTfulChatModelHandle,
RESTfulGenerateModelHandle,
RESTfulGenerateModelHandle, Client,
)


class XinferenceLLM(Xinference):
class XinferenceLLM(LLM):
client: Any
server_url: Optional[str]
"""URL of the xinference server"""
model_uid: Optional[str]
"""UID of the launched model"""

def __init__(
    self, server_url: Optional[str] = None, model_uid: Optional[str] = None
):
    """Validate connection settings and create the RESTful client.

    :param server_url: base URL of the xinference server (required)
    :param model_uid: UID of the launched model (required)
    :raises ValueError: if either argument is missing
    """
    # The pydantic-based LLM base consumes the declared fields directly.
    super().__init__(server_url=server_url, model_uid=model_uid)

    if self.server_url is None:
        raise ValueError("Please provide server URL")

    if self.model_uid is None:
        raise ValueError("Please provide the model UID")

    self.client = Client(server_url)

@property
def _llm_type(self) -> str:
    """Identifier langchain uses to label this LLM implementation."""
    return "xinference"

@property
def _identifying_params(self) -> Mapping[str, Any]:
    """Parameters that uniquely identify this LLM configuration."""
    return {"server_url": self.server_url, "model_uid": self.model_uid}

def _call(
self,
prompt: str,

+ 1
- 1
api/requirements.txt Целия файл

@@ -49,7 +49,7 @@ huggingface_hub~=0.16.4
transformers~=4.31.0
stripe~=5.5.0
pandas==1.5.3
xinference==0.5.2
xinference-client~=0.1.2
safetensors==0.3.2
zhipuai==1.0.7
werkzeug==2.3.7

Loading…
Отказ
Запис