| @@ -2,20 +2,15 @@ from typing import Optional | |||
| import httpx | |||
| from core.model_runtime.entities.common_entities import I18nObject | |||
| from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType | |||
| from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult | |||
| from core.model_runtime.errors.invoke import ( | |||
| InvokeAuthorizationError, | |||
| InvokeBadRequestError, | |||
| InvokeConnectionError, | |||
| InvokeError, | |||
| InvokeRateLimitError, | |||
| InvokeServerUnavailableError, | |||
| ) | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.__base.rerank_model import RerankModel | |||
| from core.model_runtime.model_providers.wenxin._common import _CommonWenxin | |||
| from core.model_runtime.model_providers.wenxin.wenxin_errors import ( | |||
| InternalServerError, | |||
| invoke_error_mapping, | |||
| ) | |||
| class WenxinRerank(_CommonWenxin): | |||
| @@ -32,7 +27,7 @@ class WenxinRerank(_CommonWenxin): | |||
| response.raise_for_status() | |||
| return response.json() | |||
| except httpx.HTTPStatusError as e: | |||
| raise InvokeServerUnavailableError(str(e)) | |||
| raise InternalServerError(str(e)) | |||
| class WenxinRerankModel(RerankModel): | |||
| @@ -93,7 +88,7 @@ class WenxinRerankModel(RerankModel): | |||
| return RerankResult(model=model, docs=rerank_documents) | |||
| except httpx.HTTPStatusError as e: | |||
| raise InvokeServerUnavailableError(str(e)) | |||
| raise InternalServerError(str(e)) | |||
| def validate_credentials(self, model: str, credentials: dict) -> None: | |||
| """ | |||
| @@ -124,24 +119,4 @@ class WenxinRerankModel(RerankModel): | |||
| """ | |||
| Map model invoke error to unified error | |||
| """ | |||
| return { | |||
| InvokeConnectionError: [httpx.ConnectError], | |||
| InvokeServerUnavailableError: [httpx.RemoteProtocolError], | |||
| InvokeRateLimitError: [], | |||
| InvokeAuthorizationError: [httpx.HTTPStatusError], | |||
| InvokeBadRequestError: [httpx.RequestError], | |||
| } | |||
| def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: | |||
| """ | |||
| generate custom model entities from credentials | |||
| """ | |||
| entity = AIModelEntity( | |||
| model=model, | |||
| label=I18nObject(en_US=model), | |||
| model_type=ModelType.RERANK, | |||
| fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | |||
| model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, | |||
| ) | |||
| return entity | |||
| return invoke_error_mapping() | |||