| @@ -1,4 +1,4 @@ | |||
| from langchain.embeddings import XinferenceEmbeddings | |||
| from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings | |||
| from replicate.exceptions import ModelError, ReplicateError | |||
| from core.model_providers.error import LLMBadRequestError | |||
| @@ -14,7 +14,8 @@ class XinferenceEmbedding(BaseEmbedding): | |||
| ) | |||
| client = XinferenceEmbeddings( | |||
| **credentials, | |||
| server_url=credentials['server_url'], | |||
| model_uid=credentials['model_uid'], | |||
| ) | |||
| super().__init__(model_provider, client, name) | |||
| @@ -0,0 +1,21 @@ | |||
| from typing import List | |||
| import numpy as np | |||
| from langchain.embeddings import XinferenceEmbeddings | |||
class XinferenceEmbedding(XinferenceEmbeddings):
    """Xinference embeddings whose returned vectors are scaled to unit L2 norm.

    Both public methods delegate to the parent ``XinferenceEmbeddings``
    implementation and then divide each returned vector by its Euclidean
    norm before handing it back as a plain list of floats.
    """

    @staticmethod
    def _normalize(vector: List[float]) -> List[float]:
        """Return *vector* divided by its L2 norm, as a list of floats.

        A vector whose norm is exactly zero is returned with its values
        unchanged: dividing by zero would turn every component into NaN.
        """
        array = np.asarray(vector, dtype=float)
        norm = np.linalg.norm(array)
        if norm == 0:
            return array.tolist()
        return (array / norm).tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* via the parent class and L2-normalize each vector.

        Args:
            texts: The documents to embed.

        Returns:
            One normalized vector per vector returned by the parent call.
        """
        vectors = super().embed_documents(texts)
        return [self._normalize(vector) for vector in vectors]

    def embed_query(self, text: str) -> List[float]:
        """Embed *text* via the parent class and L2-normalize the vector.

        Args:
            text: The query string to embed.

        Returns:
            The normalized embedding vector.
        """
        return self._normalize(super().embed_query(text))