|
|
|
@@ -4,7 +4,8 @@ from typing import Optional |
|
|
|
|
|
|
|
from requests import post |
|
|
|
|
|
|
|
from core.model_runtime.entities.model_entities import PriceType |
|
|
|
from core.model_runtime.entities.common_entities import I18nObject |
|
|
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType |
|
|
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult |
|
|
|
from core.model_runtime.errors.invoke import ( |
|
|
|
InvokeAuthorizationError, |
|
|
|
@@ -23,8 +24,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): |
|
|
|
""" |
|
|
|
Model class for Jina text embedding model. |
|
|
|
""" |
|
|
|
api_base: str = 'https://api.jina.ai/v1/embeddings' |
|
|
|
models: list[str] = ['jina-embeddings-v2-base-en', 'jina-embeddings-v2-small-en', 'jina-embeddings-v2-base-zh', 'jina-embeddings-v2-base-de'] |
|
|
|
api_base: str = 'https://api.jina.ai/v1' |
|
|
|
|
|
|
|
def _invoke(self, model: str, credentials: dict, |
|
|
|
texts: list[str], user: Optional[str] = None) \ |
|
|
|
@@ -39,11 +39,14 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): |
|
|
|
:return: embeddings result |
|
|
|
""" |
|
|
|
api_key = credentials['api_key'] |
|
|
|
if model not in self.models: |
|
|
|
raise InvokeBadRequestError('Invalid model name') |
|
|
|
if not api_key: |
|
|
|
raise CredentialsValidateFailedError('api_key is required') |
|
|
|
url = self.api_base |
|
|
|
|
|
|
|
base_url = credentials.get('base_url', self.api_base) |
|
|
|
if base_url.endswith('/'): |
|
|
|
base_url = base_url[:-1] |
|
|
|
|
|
|
|
url = base_url + '/embeddings' |
|
|
|
headers = { |
|
|
|
'Authorization': 'Bearer ' + api_key, |
|
|
|
'Content-Type': 'application/json' |
|
|
|
@@ -70,7 +73,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): |
|
|
|
elif response.status_code == 500: |
|
|
|
raise InvokeServerUnavailableError(msg) |
|
|
|
else: |
|
|
|
raise InvokeError(msg) |
|
|
|
raise InvokeBadRequestError(msg) |
|
|
|
except JSONDecodeError as e: |
|
|
|
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") |
|
|
|
|
|
|
|
@@ -118,8 +121,8 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): |
|
|
|
""" |
|
|
|
try: |
|
|
|
self._invoke(model=model, credentials=credentials, texts=['ping']) |
|
|
|
except InvokeAuthorizationError: |
|
|
|
raise CredentialsValidateFailedError('Invalid api key') |
|
|
|
except Exception as e: |
|
|
|
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') |
|
|
|
|
|
|
|
@property |
|
|
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: |
|
|
|
@@ -137,7 +140,8 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): |
|
|
|
InvokeAuthorizationError |
|
|
|
], |
|
|
|
InvokeBadRequestError: [ |
|
|
|
KeyError |
|
|
|
KeyError, |
|
|
|
InvokeBadRequestError |
|
|
|
] |
|
|
|
} |
|
|
|
|
|
|
|
@@ -170,3 +174,19 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): |
|
|
|
) |
|
|
|
|
|
|
|
return usage |
|
|
|
|
|
|
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: |
|
|
|
""" |
|
|
|
generate custom model entities from credentials |
|
|
|
""" |
|
|
|
entity = AIModelEntity( |
|
|
|
model=model, |
|
|
|
label=I18nObject(en_US=model), |
|
|
|
model_type=ModelType.TEXT_EMBEDDING, |
|
|
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, |
|
|
|
model_properties={ |
|
|
|
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')) |
|
|
|
} |
|
|
|
) |
|
|
|
|
|
|
|
return entity |