| import httpx | import httpx | ||||
| from core.model_runtime.entities.common_entities import I18nObject | |||||
| from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType | |||||
| from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult | from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult | ||||
| from core.model_runtime.errors.invoke import ( | |||||
| InvokeAuthorizationError, | |||||
| InvokeBadRequestError, | |||||
| InvokeConnectionError, | |||||
| InvokeError, | |||||
| InvokeRateLimitError, | |||||
| InvokeServerUnavailableError, | |||||
| ) | |||||
| from core.model_runtime.errors.invoke import InvokeError | |||||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||||
| from core.model_runtime.model_providers.__base.rerank_model import RerankModel | from core.model_runtime.model_providers.__base.rerank_model import RerankModel | ||||
| from core.model_runtime.model_providers.wenxin._common import _CommonWenxin | from core.model_runtime.model_providers.wenxin._common import _CommonWenxin | ||||
| from core.model_runtime.model_providers.wenxin.wenxin_errors import ( | |||||
| InternalServerError, | |||||
| invoke_error_mapping, | |||||
| ) | |||||
| class WenxinRerank(_CommonWenxin): | class WenxinRerank(_CommonWenxin): | ||||
| response.raise_for_status() | response.raise_for_status() | ||||
| return response.json() | return response.json() | ||||
| except httpx.HTTPStatusError as e: | except httpx.HTTPStatusError as e: | ||||
| raise InvokeServerUnavailableError(str(e)) | |||||
| raise InternalServerError(str(e)) | |||||
| class WenxinRerankModel(RerankModel): | class WenxinRerankModel(RerankModel): | ||||
| return RerankResult(model=model, docs=rerank_documents) | return RerankResult(model=model, docs=rerank_documents) | ||||
| except httpx.HTTPStatusError as e: | except httpx.HTTPStatusError as e: | ||||
| raise InvokeServerUnavailableError(str(e)) | |||||
| raise InternalServerError(str(e)) | |||||
| def validate_credentials(self, model: str, credentials: dict) -> None: | def validate_credentials(self, model: str, credentials: dict) -> None: | ||||
| """ | """ | ||||
| """ | """ | ||||
| Map model invoke error to unified error | Map model invoke error to unified error | ||||
| """ | """ | ||||
| return { | |||||
| InvokeConnectionError: [httpx.ConnectError], | |||||
| InvokeServerUnavailableError: [httpx.RemoteProtocolError], | |||||
| InvokeRateLimitError: [], | |||||
| InvokeAuthorizationError: [httpx.HTTPStatusError], | |||||
| InvokeBadRequestError: [httpx.RequestError], | |||||
| } | |||||
| def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: | |||||
| """ | |||||
| generate custom model entities from credentials | |||||
| """ | |||||
| entity = AIModelEntity( | |||||
| model=model, | |||||
| label=I18nObject(en_US=model), | |||||
| model_type=ModelType.RERANK, | |||||
| fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | |||||
| model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, | |||||
| ) | |||||
| return entity | |||||
| return invoke_error_mapping() |