|
|
|
@@ -2,20 +2,15 @@ from typing import Optional |
|
|
|
|
|
|
|
import httpx |
|
|
|
|
|
|
|
from core.model_runtime.entities.common_entities import I18nObject |
|
|
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType |
|
|
|
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult |
|
|
|
from core.model_runtime.errors.invoke import ( |
|
|
|
InvokeAuthorizationError, |
|
|
|
InvokeBadRequestError, |
|
|
|
InvokeConnectionError, |
|
|
|
InvokeError, |
|
|
|
InvokeRateLimitError, |
|
|
|
InvokeServerUnavailableError, |
|
|
|
) |
|
|
|
from core.model_runtime.errors.invoke import InvokeError |
|
|
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError |
|
|
|
from core.model_runtime.model_providers.__base.rerank_model import RerankModel |
|
|
|
from core.model_runtime.model_providers.wenxin._common import _CommonWenxin |
|
|
|
from core.model_runtime.model_providers.wenxin.wenxin_errors import ( |
|
|
|
InternalServerError, |
|
|
|
invoke_error_mapping, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
class WenxinRerank(_CommonWenxin): |
|
|
|
@@ -32,7 +27,7 @@ class WenxinRerank(_CommonWenxin): |
|
|
|
response.raise_for_status() |
|
|
|
return response.json() |
|
|
|
except httpx.HTTPStatusError as e: |
|
|
|
raise InvokeServerUnavailableError(str(e)) |
|
|
|
raise InternalServerError(str(e)) |
|
|
|
|
|
|
|
|
|
|
|
class WenxinRerankModel(RerankModel): |
|
|
|
@@ -93,7 +88,7 @@ class WenxinRerankModel(RerankModel): |
|
|
|
|
|
|
|
return RerankResult(model=model, docs=rerank_documents) |
|
|
|
except httpx.HTTPStatusError as e: |
|
|
|
raise InvokeServerUnavailableError(str(e)) |
|
|
|
raise InternalServerError(str(e)) |
|
|
|
|
|
|
|
def validate_credentials(self, model: str, credentials: dict) -> None: |
|
|
|
""" |
|
|
|
@@ -124,24 +119,4 @@ class WenxinRerankModel(RerankModel): |
|
|
|
""" |
|
|
|
Map model invoke error to unified error |
|
|
|
""" |
|
|
|
return { |
|
|
|
InvokeConnectionError: [httpx.ConnectError], |
|
|
|
InvokeServerUnavailableError: [httpx.RemoteProtocolError], |
|
|
|
InvokeRateLimitError: [], |
|
|
|
InvokeAuthorizationError: [httpx.HTTPStatusError], |
|
|
|
InvokeBadRequestError: [httpx.RequestError], |
|
|
|
} |
|
|
|
|
|
|
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: |
|
|
|
""" |
|
|
|
generate custom model entities from credentials |
|
|
|
""" |
|
|
|
entity = AIModelEntity( |
|
|
|
model=model, |
|
|
|
label=I18nObject(en_US=model), |
|
|
|
model_type=ModelType.RERANK, |
|
|
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, |
|
|
|
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, |
|
|
|
) |
|
|
|
|
|
|
|
return entity |
|
|
|
return invoke_error_mapping() |