| from datetime import datetime, timedelta | |||||
| from threading import Lock | |||||
| from requests import post | |||||
| from core.model_runtime.model_providers.wenxin.wenxin_errors import ( | |||||
| BadRequestError, | |||||
| InternalServerError, | |||||
| InvalidAPIKeyError, | |||||
| InvalidAuthenticationError, | |||||
| RateLimitReachedError, | |||||
| ) | |||||
# Cache of access tokens keyed by api_key, shared across threads.
baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
baidu_access_tokens_lock = Lock()


class BaiduAccessToken:
    """A Baidu OAuth access token together with its local cache expiry."""

    api_key: str        # Baidu API key this token belongs to
    access_token: str   # '' until a fetch succeeds
    expires: datetime   # local cache expiry, not the upstream token expiry

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key
        self.access_token = ''
        # Upstream tokens are valid for 30 days; we cache for only 3 days to
        # bound staleness and memory use.
        self.expires = datetime.now() + timedelta(days=3)

    @staticmethod
    def _get_access_token(api_key: str, secret_key: str) -> str:
        """Request a fresh access token from Baidu's OAuth endpoint.

        :raises InvalidAuthenticationError: when the HTTP request itself fails
        :raises InvalidAPIKeyError / InternalServerError / BadRequestError /
            RateLimitReachedError: on the corresponding OAuth error responses
        """
        try:
            response = post(
                url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
                headers={
                    'Content-Type': 'application/json',
                    'Accept': 'application/json'
                },
            )
        except Exception as e:
            raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')

        resp = response.json()
        if 'error' in resp:
            if resp['error'] == 'invalid_client':
                raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
            elif resp['error'] == 'unknown_error':
                raise InternalServerError(f'Internal server error: {resp["error_description"]}')
            elif resp['error'] == 'invalid_request':
                raise BadRequestError(f'Bad request: {resp["error_description"]}')
            elif resp['error'] == 'rate_limit_exceeded':
                raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
            else:
                raise Exception(f'Unknown error: {resp["error_description"]}')

        return resp['access_token']

    @staticmethod
    def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
        """Return a cached access token for *api_key*, fetching one if needed.

        LLM from Baidu requires an access token to invoke the API; the token
        is valid for 30 days upstream but cached here for only 3 days.

        Fix over the previous implementation: the old code inserted an empty
        placeholder token into the cache *before* fetching, so a concurrent
        caller — or any caller after a failed fetch, for up to 3 days — could
        receive a token whose ``access_token`` was ''.  We now fetch outside
        the lock and only ever publish fully-initialised tokens.
        """
        with baidu_access_tokens_lock:
            now = datetime.now()
            # Evict expired entries so the cache cannot grow without bound.
            for key in list(baidu_access_tokens.keys()):
                if baidu_access_tokens[key].expires < now:
                    baidu_access_tokens.pop(key)
            cached = baidu_access_tokens.get(api_key)
            if cached is not None:
                return cached

        # Cache miss: perform the network request without holding the lock.
        # Under contention two threads may both fetch, which is wasteful but
        # harmless; neither can ever observe a half-initialised token.
        token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
        token = BaiduAccessToken(api_key)
        token.access_token = token_str
        token.expires = datetime.now() + timedelta(days=3)

        with baidu_access_tokens_lock:
            baidu_access_tokens[api_key] = token
        return token
| class _CommonWenxin: | |||||
| api_bases = { | |||||
| 'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', | |||||
| 'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', | |||||
| 'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', | |||||
| 'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', | |||||
| 'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', | |||||
| 'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205', | |||||
| 'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222', | |||||
| 'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', | |||||
| 'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k', | |||||
| 'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', | |||||
| 'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', | |||||
| 'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed', | |||||
| 'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k', | |||||
| 'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas', | |||||
| 'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', | |||||
| 'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k', | |||||
| 'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', | |||||
| 'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', | |||||
| 'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k', | |||||
| 'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview', | |||||
| 'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat', | |||||
| 'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1', | |||||
| } | |||||
| function_calling_supports = [ | |||||
| 'ernie-bot', | |||||
| 'ernie-bot-8k', | |||||
| 'ernie-3.5-8k', | |||||
| 'ernie-3.5-8k-0205', | |||||
| 'ernie-3.5-8k-1222', | |||||
| 'ernie-3.5-4k-0205', | |||||
| 'ernie-3.5-128k', | |||||
| 'ernie-4.0-8k', | |||||
| 'ernie-4.0-turbo-8k', | |||||
| 'ernie-4.0-turbo-8k-preview', | |||||
| 'yi_34b_chat' | |||||
| ] | |||||
| api_key: str = '' | |||||
| secret_key: str = '' | |||||
| def __init__(self, api_key: str, secret_key: str): | |||||
| self.api_key = api_key | |||||
| self.secret_key = secret_key | |||||
| @staticmethod | |||||
| def _to_credential_kwargs(credentials: dict) -> dict: | |||||
| credentials_kwargs = { | |||||
| "api_key": credentials['api_key'], | |||||
| "secret_key": credentials['secret_key'] | |||||
| } | |||||
| return credentials_kwargs | |||||
| def _handle_error(self, code: int, msg: str): | |||||
| error_map = { | |||||
| 1: InternalServerError, | |||||
| 2: InternalServerError, | |||||
| 3: BadRequestError, | |||||
| 4: RateLimitReachedError, | |||||
| 6: InvalidAuthenticationError, | |||||
| 13: InvalidAPIKeyError, | |||||
| 14: InvalidAPIKeyError, | |||||
| 15: InvalidAPIKeyError, | |||||
| 17: RateLimitReachedError, | |||||
| 18: RateLimitReachedError, | |||||
| 19: RateLimitReachedError, | |||||
| 100: InvalidAPIKeyError, | |||||
| 111: InvalidAPIKeyError, | |||||
| 200: InternalServerError, | |||||
| 336000: InternalServerError, | |||||
| 336001: BadRequestError, | |||||
| 336002: BadRequestError, | |||||
| 336003: BadRequestError, | |||||
| 336004: InvalidAuthenticationError, | |||||
| 336005: InvalidAPIKeyError, | |||||
| 336006: BadRequestError, | |||||
| 336007: BadRequestError, | |||||
| 336008: BadRequestError, | |||||
| 336100: InternalServerError, | |||||
| 336101: BadRequestError, | |||||
| 336102: BadRequestError, | |||||
| 336103: BadRequestError, | |||||
| 336104: BadRequestError, | |||||
| 336105: BadRequestError, | |||||
| 336200: InternalServerError, | |||||
| 336303: BadRequestError, | |||||
| 337006: BadRequestError | |||||
| } | |||||
| if code in error_map: | |||||
| raise error_map[code](msg) | |||||
| else: | |||||
| raise InternalServerError(f'Unknown error: {msg}') | |||||
| def _get_access_token(self) -> str: | |||||
| token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key) | |||||
| return token.access_token | 
| from collections.abc import Generator | from collections.abc import Generator | ||||
| from datetime import datetime, timedelta | |||||
| from enum import Enum | from enum import Enum | ||||
| from json import dumps, loads | from json import dumps, loads | ||||
| from threading import Lock | |||||
| from typing import Any, Union | from typing import Any, Union | ||||
| from requests import Response, post | from requests import Response, post | ||||
| from core.model_runtime.entities.message_entities import PromptMessageTool | from core.model_runtime.entities.message_entities import PromptMessageTool | ||||
| from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( | |||||
| from core.model_runtime.model_providers.wenxin._common import _CommonWenxin | |||||
| from core.model_runtime.model_providers.wenxin.wenxin_errors import ( | |||||
| BadRequestError, | BadRequestError, | ||||
| InternalServerError, | InternalServerError, | ||||
| InvalidAPIKeyError, | |||||
| InvalidAuthenticationError, | |||||
| RateLimitReachedError, | |||||
| ) | ) | ||||
# map api_key to access_token
baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
baidu_access_tokens_lock = Lock()


class BaiduAccessToken:
    # NOTE(review): this class appears to duplicate BaiduAccessToken in
    # _common.py; consider importing the shared copy instead of keeping
    # two implementations in sync.
    api_key: str        # Baidu API key this token belongs to
    access_token: str   # '' until a fetch succeeds
    expires: datetime   # local cache expiry (3 days), not the upstream expiry

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key
        self.access_token = ''
        self.expires = datetime.now() + timedelta(days=3)

    # NOTE(review): missing @staticmethod; this only works because it is
    # always called via the class (BaiduAccessToken._get_access_token(...)).
    def _get_access_token(api_key: str, secret_key: str) -> str:
        """
        request access token from Baidu
        """
        try:
            response = post(
                url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
                headers={
                    'Content-Type': 'application/json',
                    'Accept': 'application/json'
                },
            )
        except Exception as e:
            # network-level failure (DNS, connect, TLS, ...)
            raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')
        resp = response.json()
        # OAuth errors come back as a JSON body with an 'error' key
        if 'error' in resp:
            if resp['error'] == 'invalid_client':
                raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
            elif resp['error'] == 'unknown_error':
                raise InternalServerError(f'Internal server error: {resp["error_description"]}')
            elif resp['error'] == 'invalid_request':
                raise BadRequestError(f'Bad request: {resp["error_description"]}')
            elif resp['error'] == 'rate_limit_exceeded':
                raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
            else:
                raise Exception(f'Unknown error: {resp["error_description"]}')
        return resp['access_token']

    @staticmethod
    def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
        """
        LLM from Baidu requires access token to invoke the API.
        however, we have api_key and secret_key, and access token is valid for 30 days.
        so we can cache the access token for 3 days. (avoid memory leak)

        it may be more efficient to use a ticker to refresh access token, but it will cause
        more complexity, so we just refresh access tokens when get_access_token is called.
        """
        # look up cache, remove expired access tokens
        baidu_access_tokens_lock.acquire()
        now = datetime.now()
        for key in list(baidu_access_tokens.keys()):
            token = baidu_access_tokens[key]
            if token.expires < now:
                baidu_access_tokens.pop(key)

        if api_key not in baidu_access_tokens:
            # if access token not in cache, request it
            # NOTE(review): an empty placeholder is published to the cache
            # before the fetch below; a concurrent caller hitting the cache
            # in that window receives a token with access_token == '' —
            # confirm whether callers tolerate this.
            token = BaiduAccessToken(api_key)
            baidu_access_tokens[api_key] = token
            # release it to enhance performance
            # btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock
            baidu_access_tokens_lock.release()
            # try to get access token
            token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
            token.access_token = token_str
            token.expires = now + timedelta(days=3)
            return token
        else:
            # if access token in cache, return it
            token = baidu_access_tokens[api_key]
            baidu_access_tokens_lock.release()
            return token
| class ErnieMessage: | class ErnieMessage: | ||||
| class Role(Enum): | class Role(Enum): | ||||
| self.content = content | self.content = content | ||||
| self.role = role | self.role = role | ||||
| class ErnieBotModel: | |||||
| api_bases = { | |||||
| 'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', | |||||
| 'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', | |||||
| 'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', | |||||
| 'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', | |||||
| 'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', | |||||
| 'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205', | |||||
| 'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222', | |||||
| 'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', | |||||
| 'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k', | |||||
| 'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', | |||||
| 'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', | |||||
| 'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed', | |||||
| 'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k', | |||||
| 'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas', | |||||
| 'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', | |||||
| 'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k', | |||||
| 'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', | |||||
| 'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', | |||||
| 'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k', | |||||
| 'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview', | |||||
| 'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat', | |||||
| } | |||||
| function_calling_supports = [ | |||||
| 'ernie-bot', | |||||
| 'ernie-bot-8k', | |||||
| 'ernie-3.5-8k', | |||||
| 'ernie-3.5-8k-0205', | |||||
| 'ernie-3.5-8k-1222', | |||||
| 'ernie-3.5-4k-0205', | |||||
| 'ernie-3.5-128k', | |||||
| 'ernie-4.0-8k', | |||||
| 'ernie-4.0-turbo-8k', | |||||
| 'ernie-4.0-turbo-8k-preview', | |||||
| 'yi_34b_chat' | |||||
| ] | |||||
| api_key: str = '' | |||||
| secret_key: str = '' | |||||
| def __init__(self, api_key: str, secret_key: str): | |||||
| self.api_key = api_key | |||||
| self.secret_key = secret_key | |||||
| class ErnieBotModel(_CommonWenxin): | |||||
| def generate(self, model: str, stream: bool, messages: list[ErnieMessage], | def generate(self, model: str, stream: bool, messages: list[ErnieMessage], | ||||
| parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \ | parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \ | ||||
| return self._handle_chat_stream_generate_response(resp) | return self._handle_chat_stream_generate_response(resp) | ||||
| return self._handle_chat_generate_response(resp) | return self._handle_chat_generate_response(resp) | ||||
    def _handle_error(self, code: int, msg: str):
        """Translate a Wenxin API error code into the matching typed exception.

        Always raises; codes outside the map become InternalServerError.
        """
        # Baidu error code -> exception class. Rough grouping visible from the
        # mapping: 1/2/200/3360xx-server -> InternalServerError, 3 and 336xxx
        # request-shape codes -> BadRequestError, 4/17/18/19 -> rate limits,
        # 13/14/15/100/111/336005 -> bad keys, 6/336004 -> authentication.
        error_map = {
            1: InternalServerError,
            2: InternalServerError,
            3: BadRequestError,
            4: RateLimitReachedError,
            6: InvalidAuthenticationError,
            13: InvalidAPIKeyError,
            14: InvalidAPIKeyError,
            15: InvalidAPIKeyError,
            17: RateLimitReachedError,
            18: RateLimitReachedError,
            19: RateLimitReachedError,
            100: InvalidAPIKeyError,
            111: InvalidAPIKeyError,
            200: InternalServerError,
            336000: InternalServerError,
            336001: BadRequestError,
            336002: BadRequestError,
            336003: BadRequestError,
            336004: InvalidAuthenticationError,
            336005: InvalidAPIKeyError,
            336006: BadRequestError,
            336007: BadRequestError,
            336008: BadRequestError,
            336100: InternalServerError,
            336101: BadRequestError,
            336102: BadRequestError,
            336103: BadRequestError,
            336104: BadRequestError,
            336105: BadRequestError,
            336200: InternalServerError,
            336303: BadRequestError,
            337006: BadRequestError
        }

        if code in error_map:
            raise error_map[code](msg)
        else:
            raise InternalServerError(f'Unknown error: {msg}')
| def _get_access_token(self) -> str: | |||||
| token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key) | |||||
| return token.access_token | |||||
| def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]: | def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]: | ||||
| return [ErnieMessage(message.content, message.role) for message in messages] | return [ErnieMessage(message.content, message.role) for message in messages] | ||||
class InvalidAuthenticationError(Exception):
    """Raised when Baidu rejects the request's authentication."""


class InvalidAPIKeyError(Exception):
    """Raised when the API key or secret key is invalid."""


class RateLimitReachedError(Exception):
    """Raised when the API rate limit has been reached."""


class InsufficientAccountBalance(Exception):
    """Raised when the account has insufficient balance for the call."""


class InternalServerError(Exception):
    """Raised for server-side failures reported by the API."""


class BadRequestError(Exception):
    """Raised when the request payload is malformed or invalid."""
| UserPromptMessage, | UserPromptMessage, | ||||
| ) | ) | ||||
| from core.model_runtime.errors.invoke import ( | from core.model_runtime.errors.invoke import ( | ||||
| InvokeAuthorizationError, | |||||
| InvokeBadRequestError, | |||||
| InvokeConnectionError, | |||||
| InvokeError, | InvokeError, | ||||
| InvokeRateLimitError, | |||||
| InvokeServerUnavailableError, | |||||
| ) | ) | ||||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | ||||
| from core.model_runtime.model_providers.wenxin.llm.ernie_bot import BaiduAccessToken, ErnieBotModel, ErnieMessage | |||||
| from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( | |||||
| BadRequestError, | |||||
| InsufficientAccountBalance, | |||||
| InternalServerError, | |||||
| InvalidAPIKeyError, | |||||
| InvalidAuthenticationError, | |||||
| RateLimitReachedError, | |||||
| ) | |||||
| from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken | |||||
| from core.model_runtime.model_providers.wenxin.llm.ernie_bot import ErnieBotModel, ErnieMessage | |||||
| from core.model_runtime.model_providers.wenxin.wenxin_errors import invoke_error_mapping | |||||
| ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. | ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. | ||||
| The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure | The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure | ||||
| api_key = credentials['api_key'] | api_key = credentials['api_key'] | ||||
| secret_key = credentials['secret_key'] | secret_key = credentials['secret_key'] | ||||
| try: | try: | ||||
| BaiduAccessToken._get_access_token(api_key, secret_key) | |||||
| BaiduAccessToken.get_access_token(api_key, secret_key) | |||||
| except Exception as e: | except Exception as e: | ||||
| raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') | raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') | ||||
| :return: Invoke error mapping | :return: Invoke error mapping | ||||
| """ | """ | ||||
| return { | |||||
| InvokeConnectionError: [ | |||||
| ], | |||||
| InvokeServerUnavailableError: [ | |||||
| InternalServerError | |||||
| ], | |||||
| InvokeRateLimitError: [ | |||||
| RateLimitReachedError | |||||
| ], | |||||
| InvokeAuthorizationError: [ | |||||
| InvalidAuthenticationError, | |||||
| InsufficientAccountBalance, | |||||
| InvalidAPIKeyError, | |||||
| ], | |||||
| InvokeBadRequestError: [ | |||||
| BadRequestError, | |||||
| KeyError | |||||
| ] | |||||
| } | |||||
| return invoke_error_mapping() | 
| model: embedding-v1 | |||||
| model_type: text-embedding | |||||
| model_properties: | |||||
| context_size: 384 | |||||
| max_chunks: 16 | |||||
| pricing: | |||||
| input: '0.0005' | |||||
| unit: '0.001' | |||||
| currency: RMB | 
| import time | |||||
| from abc import abstractmethod | |||||
| from collections.abc import Mapping | |||||
| from json import dumps | |||||
| from typing import Any, Optional | |||||
| import numpy as np | |||||
| from requests import Response, post | |||||
| from core.model_runtime.entities.model_entities import PriceType | |||||
| from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult | |||||
| from core.model_runtime.errors.invoke import InvokeError | |||||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||||
| from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | |||||
| from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken, _CommonWenxin | |||||
| from core.model_runtime.model_providers.wenxin.wenxin_errors import ( | |||||
| BadRequestError, | |||||
| InternalServerError, | |||||
| invoke_error_mapping, | |||||
| ) | |||||
class TextEmbedding:
    """Minimal interface for a batch text-embedding backend."""

    @abstractmethod
    def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]:
        """Embed *texts* with *model* on behalf of *user*.

        :return: (embeddings, prompt_tokens, total_tokens)

        Fix: the return annotation was ``(list[list[float]], int, int)`` — a
        runtime tuple of types, not a valid type hint; use ``tuple[...]``.
        """
        raise NotImplementedError
class WenxinTextEmbedding(_CommonWenxin, TextEmbedding):
    """TextEmbedding backend that calls Baidu's Wenxin embeddings endpoint."""

    def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]:
        """Embed *texts* via the endpoint registered for *model* in api_bases.

        :return: (embeddings, prompt_tokens, total_tokens)
        :raises InternalServerError: on a non-200 HTTP response

        Fix: the return annotation was the invalid ``(list[...], int, int)``
        tuple-of-types form; use ``tuple[...]``.
        """
        access_token = self._get_access_token()
        url = f'{self.api_bases[model]}?access_token={access_token}'
        body = self._build_embed_request_body(model, texts, user)
        headers = {
            'Content-Type': 'application/json',
        }

        # NOTE(review): no timeout is set, so a stalled connection blocks
        # forever — consider post(..., timeout=...) after confirming callers.
        resp = post(url, data=dumps(body), headers=headers)
        if resp.status_code != 200:
            raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}')
        return self._handle_embed_response(model, resp)

    def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]:
        """Build the JSON request body; rejects an empty batch up front."""
        if len(texts) == 0:
            raise BadRequestError('The number of texts should not be zero.')
        body = {
            'input': texts,
            'user_id': user,
        }
        return body

    def _handle_embed_response(self, model: str, response: Response) -> tuple[list[list[float]], int, int]:
        """Parse a 200 response; API-level errors are mapped via _handle_error.

        Fix: same invalid tuple return annotation corrected to ``tuple[...]``.
        """
        data = response.json()
        if 'error_code' in data:
            code = data['error_code']
            msg = data['error_msg']
            # raise error
            self._handle_error(code, msg)

        embeddings = [v['embedding'] for v in data['data']]
        _usage = data['usage']
        tokens = _usage['prompt_tokens']
        total_tokens = _usage['total_tokens']

        return embeddings, tokens, total_tokens
class WenxinTextEmbeddingModel(TextEmbeddingModel):
    """Model-runtime adapter exposing Wenxin embeddings through TextEmbeddingModel."""

    def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding:
        """Factory hook for the backend; tests override this to inject mocks."""
        return WenxinTextEmbedding(api_key, secret_key)

    def _invoke(self, model: str, credentials: dict, texts: list[str],
                user: Optional[str] = None) -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials (api_key / secret_key)
        :param texts: texts to embed
        :param user: unique user id; defaults to 'ErnieBotDefault'
        :return: embeddings result
        """
        api_key = credentials['api_key']
        secret_key = credentials['secret_key']
        embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key)
        user = user if user else 'ErnieBotDefault'

        context_size = self._get_context_size(model, credentials)
        max_chunks = self._get_max_chunks(model, credentials)
        inputs = []
        used_tokens = 0
        used_total_tokens = 0

        # Fix: dropped the `indices` list the previous version built but
        # never read.
        for text in texts:
            # Token count is only an approximation based on the GPT-2 tokenizer.
            num_tokens = self._get_num_tokens_by_gpt2(text)
            if num_tokens >= context_size:
                # Over-long input: keep a proportional prefix so the
                # (approximate) token count fits the context window.
                cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
                inputs.append(text[0:cutoff])
            else:
                inputs.append(text)

        # Send inputs in batches of at most max_chunks texts.
        batched_embeddings = []
        for i in range(0, len(inputs), max_chunks):
            embeddings_batch, _used_tokens, _total_used_tokens = embedding.embed_documents(
                model,
                inputs[i: i + max_chunks],
                user)
            used_tokens += _used_tokens
            used_total_tokens += _total_used_tokens
            batched_embeddings += embeddings_batch

        usage = self._calc_response_usage(model, credentials, used_tokens, used_total_tokens)
        return TextEmbeddingResult(
            model=model,
            embeddings=batched_embeddings,
            usage=usage,
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: approximate token count (GPT-2 tokenizer); 0 for an empty list
        """
        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)

    def validate_credentials(self, model: str, credentials: Mapping) -> None:
        """Validate credentials by actually requesting an access token.

        :raises CredentialsValidateFailedError: if the token fetch fails
        """
        api_key = credentials['api_key']
        secret_key = credentials['secret_key']
        try:
            BaiduAccessToken.get_access_token(api_key, secret_key)
        except Exception as e:
            raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        # delegate to the provider-wide shared mapping
        return invoke_error_mapping()

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input (prompt) tokens
        :param total_tokens: total tokens reported by the API
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=total_tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            # NOTE(review): relies on self.started_at being set by the base
            # model before invocation — confirm against TextEmbeddingModel.
            latency=time.perf_counter() - self.started_at
        )

        return usage
| en_US: https://cloud.baidu.com/wenxin.html | en_US: https://cloud.baidu.com/wenxin.html | ||||
| supported_model_types: | supported_model_types: | ||||
| - llm | - llm | ||||
| - text-embedding | |||||
| configurate_methods: | configurate_methods: | ||||
| - predefined-model | - predefined-model | ||||
| provider_credential_schema: | provider_credential_schema: | 
| from core.model_runtime.errors.invoke import ( | |||||
| InvokeAuthorizationError, | |||||
| InvokeBadRequestError, | |||||
| InvokeConnectionError, | |||||
| InvokeError, | |||||
| InvokeRateLimitError, | |||||
| InvokeServerUnavailableError, | |||||
| ) | |||||
def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]:
    """
    Map model invoke error to unified error
    The key is the error type thrown to the caller
    The value is the error type thrown by the model,
    which needs to be converted into a unified error type for the caller.

    :return: Invoke error mapping
    """
    mapping: dict[type[InvokeError], list[type[Exception]]] = {
        InvokeConnectionError: [],
        InvokeServerUnavailableError: [InternalServerError],
        InvokeRateLimitError: [RateLimitReachedError],
        InvokeAuthorizationError: [
            InvalidAuthenticationError,
            InsufficientAccountBalance,
            InvalidAPIKeyError,
        ],
        InvokeBadRequestError: [
            BadRequestError,
            KeyError,
        ],
    }
    return mapping
class InvalidAuthenticationError(Exception):
    """Raised when Baidu rejects the request's authentication."""


class InvalidAPIKeyError(Exception):
    """Raised when the API key or secret key is invalid."""


class RateLimitReachedError(Exception):
    """Raised when the API rate limit has been reached."""


class InsufficientAccountBalance(Exception):
    """Raised when the account has insufficient balance for the call."""


class InternalServerError(Exception):
    """Raised for server-side failures reported by the API."""


class BadRequestError(Exception):
    """Raised when the request payload is malformed or invalid."""
| import os | |||||
| from time import sleep | |||||
| from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult | |||||
| from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import WenxinTextEmbeddingModel | |||||
def test_invoke_embedding_model():
    """Live-API smoke test: embed three short texts with embedding-v1.

    Requires WENXIN_API_KEY / WENXIN_SECRET_KEY in the environment.
    """
    # brief pause to stay under the provider's request rate limit
    sleep(3)
    embedding_model = WenxinTextEmbeddingModel()
    credentials = {
        'api_key': os.environ.get('WENXIN_API_KEY'),
        'secret_key': os.environ.get('WENXIN_SECRET_KEY'),
    }
    result = embedding_model.invoke(
        model='embedding-v1',
        credentials=credentials,
        texts=['hello', '你好', 'xxxxx'],
        user="abc-123"
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 3
    assert isinstance(result.embeddings[0], list)
| import numpy as np | |||||
| from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult | |||||
| from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer | |||||
| from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import ( | |||||
| TextEmbedding, | |||||
| WenxinTextEmbeddingModel, | |||||
| ) | |||||
def test_max_chunks():
    """invoke() must batch inputs by max_chunks and concatenate the results."""

    class _MockTextEmbedding(TextEmbedding):
        # stand-in backend: fixed 3-dim vector per text, token count = char count
        def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
            vectors = [[1.0, 2.0, 3.0] for _ in texts]
            token_count = sum(len(t) for t in texts)
            return vectors, token_count, token_count

    def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding:
        return _MockTextEmbedding()

    model = 'embedding-v1'
    credentials = {
        'api_key': 'xxxx',
        'secret_key': 'yyyy',
    }
    embedding_model = WenxinTextEmbeddingModel()
    context_size = embedding_model._get_context_size(model, credentials)
    max_chunks = embedding_model._get_max_chunks(model, credentials)
    # inject the mock factory so invoke() never hits the network
    embedding_model._create_text_embedding = _create_text_embedding

    texts = ['0123456789'] * (max_chunks * 2)
    result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test')
    assert len(result.embeddings) == max_chunks * 2
def test_context_size():
    """Inputs longer than the context window must be truncated, so reported
    token usage equals the model's context size."""

    def get_num_tokens_by_gpt2(text: str) -> int:
        # same GPT-2 approximation the model uses internally
        return GPT2Tokenizer.get_num_tokens(text)

    def mock_text(token_size: int) -> str:
        # Build a string whose GPT-2 token count scales cleanly with its
        # character count: measure the char/token ratio of a '0'*token_size
        # seed, then repeat the seed that many times.
        _text = "".join(['0' for i in range(token_size)])
        num_tokens = get_num_tokens_by_gpt2(_text)
        ratio = int(np.floor(len(_text) / num_tokens))
        m_text = "".join([_text for i in range(ratio)])
        return m_text

    model = 'embedding-v1'
    credentials = {
        'api_key': 'xxxx',
        'secret_key': 'yyyy',
    }
    embedding_model = WenxinTextEmbeddingModel()
    context_size = embedding_model._get_context_size(model, credentials)

    class _MockTextEmbedding(TextEmbedding):
        # fake backend: fixed vectors; token usage = GPT-2 token count of inputs
        def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
            embeddings = [[1.0, 2.0, 3.0] for i in range(len(texts))]
            tokens = 0
            for text in texts:
                tokens += get_num_tokens_by_gpt2(text)
            return embeddings, tokens, tokens

    def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding:
        return _MockTextEmbedding()

    # inject the mock factory so invoke() never hits the network
    embedding_model._create_text_embedding = _create_text_embedding

    text = mock_text(context_size * 2)
    assert get_num_tokens_by_gpt2(text) == context_size * 2

    texts = [text]
    # the over-long input should be cut down to context_size tokens
    result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test')
    assert result.usage.tokens == context_size