| @@ -0,0 +1,195 @@ | |||
| from datetime import datetime, timedelta | |||
| from threading import Lock | |||
| from requests import post | |||
| from core.model_runtime.model_providers.wenxin.wenxin_errors import ( | |||
| BadRequestError, | |||
| InternalServerError, | |||
| InvalidAPIKeyError, | |||
| InvalidAuthenticationError, | |||
| RateLimitReachedError, | |||
| ) | |||
| baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {} | |||
| baidu_access_tokens_lock = Lock() | |||
| class BaiduAccessToken: | |||
| api_key: str | |||
| access_token: str | |||
| expires: datetime | |||
| def __init__(self, api_key: str) -> None: | |||
| self.api_key = api_key | |||
| self.access_token = '' | |||
| self.expires = datetime.now() + timedelta(days=3) | |||
| @staticmethod | |||
| def _get_access_token(api_key: str, secret_key: str) -> str: | |||
| """ | |||
| request access token from Baidu | |||
| """ | |||
| try: | |||
| response = post( | |||
| url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}', | |||
| headers={ | |||
| 'Content-Type': 'application/json', | |||
| 'Accept': 'application/json' | |||
| }, | |||
| ) | |||
| except Exception as e: | |||
| raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}') | |||
| resp = response.json() | |||
| if 'error' in resp: | |||
| if resp['error'] == 'invalid_client': | |||
| raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}') | |||
| elif resp['error'] == 'unknown_error': | |||
| raise InternalServerError(f'Internal server error: {resp["error_description"]}') | |||
| elif resp['error'] == 'invalid_request': | |||
| raise BadRequestError(f'Bad request: {resp["error_description"]}') | |||
| elif resp['error'] == 'rate_limit_exceeded': | |||
| raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}') | |||
| else: | |||
| raise Exception(f'Unknown error: {resp["error_description"]}') | |||
| return resp['access_token'] | |||
| @staticmethod | |||
| def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken': | |||
| """ | |||
| LLM from Baidu requires access token to invoke the API. | |||
| however, we have api_key and secret_key, and access token is valid for 30 days. | |||
| so we can cache the access token for 3 days. (avoid memory leak) | |||
| it may be more efficient to use a ticker to refresh access token, but it will cause | |||
| more complexity, so we just refresh access tokens when get_access_token is called. | |||
| """ | |||
| # loop up cache, remove expired access token | |||
| baidu_access_tokens_lock.acquire() | |||
| now = datetime.now() | |||
| for key in list(baidu_access_tokens.keys()): | |||
| token = baidu_access_tokens[key] | |||
| if token.expires < now: | |||
| baidu_access_tokens.pop(key) | |||
| if api_key not in baidu_access_tokens: | |||
| # if access token not in cache, request it | |||
| token = BaiduAccessToken(api_key) | |||
| baidu_access_tokens[api_key] = token | |||
| # release it to enhance performance | |||
| # btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock | |||
| baidu_access_tokens_lock.release() | |||
| # try to get access token | |||
| token_str = BaiduAccessToken._get_access_token(api_key, secret_key) | |||
| token.access_token = token_str | |||
| token.expires = now + timedelta(days=3) | |||
| return token | |||
| else: | |||
| # if access token in cache, return it | |||
| token = baidu_access_tokens[api_key] | |||
| baidu_access_tokens_lock.release() | |||
| return token | |||
| class _CommonWenxin: | |||
| api_bases = { | |||
| 'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', | |||
| 'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', | |||
| 'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', | |||
| 'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', | |||
| 'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', | |||
| 'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205', | |||
| 'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222', | |||
| 'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', | |||
| 'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k', | |||
| 'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', | |||
| 'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', | |||
| 'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed', | |||
| 'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k', | |||
| 'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas', | |||
| 'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', | |||
| 'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k', | |||
| 'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', | |||
| 'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', | |||
| 'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k', | |||
| 'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview', | |||
| 'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat', | |||
| 'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1', | |||
| } | |||
| function_calling_supports = [ | |||
| 'ernie-bot', | |||
| 'ernie-bot-8k', | |||
| 'ernie-3.5-8k', | |||
| 'ernie-3.5-8k-0205', | |||
| 'ernie-3.5-8k-1222', | |||
| 'ernie-3.5-4k-0205', | |||
| 'ernie-3.5-128k', | |||
| 'ernie-4.0-8k', | |||
| 'ernie-4.0-turbo-8k', | |||
| 'ernie-4.0-turbo-8k-preview', | |||
| 'yi_34b_chat' | |||
| ] | |||
| api_key: str = '' | |||
| secret_key: str = '' | |||
| def __init__(self, api_key: str, secret_key: str): | |||
| self.api_key = api_key | |||
| self.secret_key = secret_key | |||
| @staticmethod | |||
| def _to_credential_kwargs(credentials: dict) -> dict: | |||
| credentials_kwargs = { | |||
| "api_key": credentials['api_key'], | |||
| "secret_key": credentials['secret_key'] | |||
| } | |||
| return credentials_kwargs | |||
| def _handle_error(self, code: int, msg: str): | |||
| error_map = { | |||
| 1: InternalServerError, | |||
| 2: InternalServerError, | |||
| 3: BadRequestError, | |||
| 4: RateLimitReachedError, | |||
| 6: InvalidAuthenticationError, | |||
| 13: InvalidAPIKeyError, | |||
| 14: InvalidAPIKeyError, | |||
| 15: InvalidAPIKeyError, | |||
| 17: RateLimitReachedError, | |||
| 18: RateLimitReachedError, | |||
| 19: RateLimitReachedError, | |||
| 100: InvalidAPIKeyError, | |||
| 111: InvalidAPIKeyError, | |||
| 200: InternalServerError, | |||
| 336000: InternalServerError, | |||
| 336001: BadRequestError, | |||
| 336002: BadRequestError, | |||
| 336003: BadRequestError, | |||
| 336004: InvalidAuthenticationError, | |||
| 336005: InvalidAPIKeyError, | |||
| 336006: BadRequestError, | |||
| 336007: BadRequestError, | |||
| 336008: BadRequestError, | |||
| 336100: InternalServerError, | |||
| 336101: BadRequestError, | |||
| 336102: BadRequestError, | |||
| 336103: BadRequestError, | |||
| 336104: BadRequestError, | |||
| 336105: BadRequestError, | |||
| 336200: InternalServerError, | |||
| 336303: BadRequestError, | |||
| 337006: BadRequestError | |||
| } | |||
| if code in error_map: | |||
| raise error_map[code](msg) | |||
| else: | |||
| raise InternalServerError(f'Unknown error: {msg}') | |||
| def _get_access_token(self) -> str: | |||
| token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key) | |||
| return token.access_token | |||
| @@ -1,102 +1,17 @@ | |||
| from collections.abc import Generator | |||
| from datetime import datetime, timedelta | |||
| from enum import Enum | |||
| from json import dumps, loads | |||
| from threading import Lock | |||
| from typing import Any, Union | |||
| from requests import Response, post | |||
| from core.model_runtime.entities.message_entities import PromptMessageTool | |||
| from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( | |||
| from core.model_runtime.model_providers.wenxin._common import _CommonWenxin | |||
| from core.model_runtime.model_providers.wenxin.wenxin_errors import ( | |||
| BadRequestError, | |||
| InternalServerError, | |||
| InvalidAPIKeyError, | |||
| InvalidAuthenticationError, | |||
| RateLimitReachedError, | |||
| ) | |||
| # map api_key to access_token | |||
| baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {} | |||
| baidu_access_tokens_lock = Lock() | |||
| class BaiduAccessToken: | |||
| api_key: str | |||
| access_token: str | |||
| expires: datetime | |||
| def __init__(self, api_key: str) -> None: | |||
| self.api_key = api_key | |||
| self.access_token = '' | |||
| self.expires = datetime.now() + timedelta(days=3) | |||
| def _get_access_token(api_key: str, secret_key: str) -> str: | |||
| """ | |||
| request access token from Baidu | |||
| """ | |||
| try: | |||
| response = post( | |||
| url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}', | |||
| headers={ | |||
| 'Content-Type': 'application/json', | |||
| 'Accept': 'application/json' | |||
| }, | |||
| ) | |||
| except Exception as e: | |||
| raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}') | |||
| resp = response.json() | |||
| if 'error' in resp: | |||
| if resp['error'] == 'invalid_client': | |||
| raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}') | |||
| elif resp['error'] == 'unknown_error': | |||
| raise InternalServerError(f'Internal server error: {resp["error_description"]}') | |||
| elif resp['error'] == 'invalid_request': | |||
| raise BadRequestError(f'Bad request: {resp["error_description"]}') | |||
| elif resp['error'] == 'rate_limit_exceeded': | |||
| raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}') | |||
| else: | |||
| raise Exception(f'Unknown error: {resp["error_description"]}') | |||
| return resp['access_token'] | |||
| @staticmethod | |||
| def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken': | |||
| """ | |||
| LLM from Baidu requires access token to invoke the API. | |||
| however, we have api_key and secret_key, and access token is valid for 30 days. | |||
| so we can cache the access token for 3 days. (avoid memory leak) | |||
| it may be more efficient to use a ticker to refresh access token, but it will cause | |||
| more complexity, so we just refresh access tokens when get_access_token is called. | |||
| """ | |||
| # loop up cache, remove expired access token | |||
| baidu_access_tokens_lock.acquire() | |||
| now = datetime.now() | |||
| for key in list(baidu_access_tokens.keys()): | |||
| token = baidu_access_tokens[key] | |||
| if token.expires < now: | |||
| baidu_access_tokens.pop(key) | |||
| if api_key not in baidu_access_tokens: | |||
| # if access token not in cache, request it | |||
| token = BaiduAccessToken(api_key) | |||
| baidu_access_tokens[api_key] = token | |||
| # release it to enhance performance | |||
| # btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock | |||
| baidu_access_tokens_lock.release() | |||
| # try to get access token | |||
| token_str = BaiduAccessToken._get_access_token(api_key, secret_key) | |||
| token.access_token = token_str | |||
| token.expires = now + timedelta(days=3) | |||
| return token | |||
| else: | |||
| # if access token in cache, return it | |||
| token = baidu_access_tokens[api_key] | |||
| baidu_access_tokens_lock.release() | |||
| return token | |||
| class ErnieMessage: | |||
| class Role(Enum): | |||
| @@ -120,51 +35,7 @@ class ErnieMessage: | |||
| self.content = content | |||
| self.role = role | |||
| class ErnieBotModel: | |||
| api_bases = { | |||
| 'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', | |||
| 'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', | |||
| 'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', | |||
| 'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', | |||
| 'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', | |||
| 'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205', | |||
| 'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222', | |||
| 'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', | |||
| 'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k', | |||
| 'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', | |||
| 'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', | |||
| 'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed', | |||
| 'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k', | |||
| 'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas', | |||
| 'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', | |||
| 'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k', | |||
| 'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', | |||
| 'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', | |||
| 'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k', | |||
| 'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview', | |||
| 'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat', | |||
| } | |||
| function_calling_supports = [ | |||
| 'ernie-bot', | |||
| 'ernie-bot-8k', | |||
| 'ernie-3.5-8k', | |||
| 'ernie-3.5-8k-0205', | |||
| 'ernie-3.5-8k-1222', | |||
| 'ernie-3.5-4k-0205', | |||
| 'ernie-3.5-128k', | |||
| 'ernie-4.0-8k', | |||
| 'ernie-4.0-turbo-8k', | |||
| 'ernie-4.0-turbo-8k-preview', | |||
| 'yi_34b_chat' | |||
| ] | |||
| api_key: str = '' | |||
| secret_key: str = '' | |||
| def __init__(self, api_key: str, secret_key: str): | |||
| self.api_key = api_key | |||
| self.secret_key = secret_key | |||
| class ErnieBotModel(_CommonWenxin): | |||
| def generate(self, model: str, stream: bool, messages: list[ErnieMessage], | |||
| parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \ | |||
| @@ -199,51 +70,6 @@ class ErnieBotModel: | |||
| return self._handle_chat_stream_generate_response(resp) | |||
| return self._handle_chat_generate_response(resp) | |||
| def _handle_error(self, code: int, msg: str): | |||
| error_map = { | |||
| 1: InternalServerError, | |||
| 2: InternalServerError, | |||
| 3: BadRequestError, | |||
| 4: RateLimitReachedError, | |||
| 6: InvalidAuthenticationError, | |||
| 13: InvalidAPIKeyError, | |||
| 14: InvalidAPIKeyError, | |||
| 15: InvalidAPIKeyError, | |||
| 17: RateLimitReachedError, | |||
| 18: RateLimitReachedError, | |||
| 19: RateLimitReachedError, | |||
| 100: InvalidAPIKeyError, | |||
| 111: InvalidAPIKeyError, | |||
| 200: InternalServerError, | |||
| 336000: InternalServerError, | |||
| 336001: BadRequestError, | |||
| 336002: BadRequestError, | |||
| 336003: BadRequestError, | |||
| 336004: InvalidAuthenticationError, | |||
| 336005: InvalidAPIKeyError, | |||
| 336006: BadRequestError, | |||
| 336007: BadRequestError, | |||
| 336008: BadRequestError, | |||
| 336100: InternalServerError, | |||
| 336101: BadRequestError, | |||
| 336102: BadRequestError, | |||
| 336103: BadRequestError, | |||
| 336104: BadRequestError, | |||
| 336105: BadRequestError, | |||
| 336200: InternalServerError, | |||
| 336303: BadRequestError, | |||
| 337006: BadRequestError | |||
| } | |||
| if code in error_map: | |||
| raise error_map[code](msg) | |||
| else: | |||
| raise InternalServerError(f'Unknown error: {msg}') | |||
| def _get_access_token(self) -> str: | |||
| token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key) | |||
| return token.access_token | |||
| def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]: | |||
| return [ErnieMessage(message.content, message.role) for message in messages] | |||
| @@ -1,17 +0,0 @@ | |||
| class InvalidAuthenticationError(Exception): | |||
| pass | |||
| class InvalidAPIKeyError(Exception): | |||
| pass | |||
| class RateLimitReachedError(Exception): | |||
| pass | |||
| class InsufficientAccountBalance(Exception): | |||
| pass | |||
| class InternalServerError(Exception): | |||
| pass | |||
| class BadRequestError(Exception): | |||
| pass | |||
| @@ -11,24 +11,13 @@ from core.model_runtime.entities.message_entities import ( | |||
| UserPromptMessage, | |||
| ) | |||
| from core.model_runtime.errors.invoke import ( | |||
| InvokeAuthorizationError, | |||
| InvokeBadRequestError, | |||
| InvokeConnectionError, | |||
| InvokeError, | |||
| InvokeRateLimitError, | |||
| InvokeServerUnavailableError, | |||
| ) | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.model_runtime.model_providers.wenxin.llm.ernie_bot import BaiduAccessToken, ErnieBotModel, ErnieMessage | |||
| from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( | |||
| BadRequestError, | |||
| InsufficientAccountBalance, | |||
| InternalServerError, | |||
| InvalidAPIKeyError, | |||
| InvalidAuthenticationError, | |||
| RateLimitReachedError, | |||
| ) | |||
| from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken | |||
| from core.model_runtime.model_providers.wenxin.llm.ernie_bot import ErnieBotModel, ErnieMessage | |||
| from core.model_runtime.model_providers.wenxin.wenxin_errors import invoke_error_mapping | |||
| ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. | |||
| The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure | |||
| @@ -140,7 +129,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): | |||
| api_key = credentials['api_key'] | |||
| secret_key = credentials['secret_key'] | |||
| try: | |||
| BaiduAccessToken._get_access_token(api_key, secret_key) | |||
| BaiduAccessToken.get_access_token(api_key, secret_key) | |||
| except Exception as e: | |||
| raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') | |||
| @@ -254,22 +243,4 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel): | |||
| :return: Invoke error mapping | |||
| """ | |||
| return { | |||
| InvokeConnectionError: [ | |||
| ], | |||
| InvokeServerUnavailableError: [ | |||
| InternalServerError | |||
| ], | |||
| InvokeRateLimitError: [ | |||
| RateLimitReachedError | |||
| ], | |||
| InvokeAuthorizationError: [ | |||
| InvalidAuthenticationError, | |||
| InsufficientAccountBalance, | |||
| InvalidAPIKeyError, | |||
| ], | |||
| InvokeBadRequestError: [ | |||
| BadRequestError, | |||
| KeyError | |||
| ] | |||
| } | |||
| return invoke_error_mapping() | |||
| @@ -0,0 +1,9 @@ | |||
| model: embedding-v1 | |||
| model_type: text-embedding | |||
| model_properties: | |||
| context_size: 384 | |||
| max_chunks: 16 | |||
| pricing: | |||
| input: '0.0005' | |||
| unit: '0.001' | |||
| currency: RMB | |||
| @@ -0,0 +1,184 @@ | |||
| import time | |||
| from abc import abstractmethod | |||
| from collections.abc import Mapping | |||
| from json import dumps | |||
| from typing import Any, Optional | |||
| import numpy as np | |||
| from requests import Response, post | |||
| from core.model_runtime.entities.model_entities import PriceType | |||
| from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | |||
| from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken, _CommonWenxin | |||
| from core.model_runtime.model_providers.wenxin.wenxin_errors import ( | |||
| BadRequestError, | |||
| InternalServerError, | |||
| invoke_error_mapping, | |||
| ) | |||
| class TextEmbedding: | |||
| @abstractmethod | |||
| def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): | |||
| raise NotImplementedError | |||
| class WenxinTextEmbedding(_CommonWenxin, TextEmbedding): | |||
| def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): | |||
| access_token = self._get_access_token() | |||
| url = f'{self.api_bases[model]}?access_token={access_token}' | |||
| body = self._build_embed_request_body(model, texts, user) | |||
| headers = { | |||
| 'Content-Type': 'application/json', | |||
| } | |||
| resp = post(url, data=dumps(body), headers=headers) | |||
| if resp.status_code != 200: | |||
| raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}') | |||
| return self._handle_embed_response(model, resp) | |||
| def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]: | |||
| if len(texts) == 0: | |||
| raise BadRequestError('The number of texts should not be zero.') | |||
| body = { | |||
| 'input': texts, | |||
| 'user_id': user, | |||
| } | |||
| return body | |||
| def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int): | |||
| data = response.json() | |||
| if 'error_code' in data: | |||
| code = data['error_code'] | |||
| msg = data['error_msg'] | |||
| # raise error | |||
| self._handle_error(code, msg) | |||
| embeddings = [v['embedding'] for v in data['data']] | |||
| _usage = data['usage'] | |||
| tokens = _usage['prompt_tokens'] | |||
| total_tokens = _usage['total_tokens'] | |||
| return embeddings, tokens, total_tokens | |||
| class WenxinTextEmbeddingModel(TextEmbeddingModel): | |||
| def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding: | |||
| return WenxinTextEmbedding(api_key, secret_key) | |||
| def _invoke(self, model: str, credentials: dict, texts: list[str], | |||
| user: Optional[str] = None) -> TextEmbeddingResult: | |||
| """ | |||
| Invoke text embedding model | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :param texts: texts to embed | |||
| :param user: unique user id | |||
| :return: embeddings result | |||
| """ | |||
| api_key = credentials['api_key'] | |||
| secret_key = credentials['secret_key'] | |||
| embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key) | |||
| user = user if user else 'ErnieBotDefault' | |||
| context_size = self._get_context_size(model, credentials) | |||
| max_chunks = self._get_max_chunks(model, credentials) | |||
| inputs = [] | |||
| indices = [] | |||
| used_tokens = 0 | |||
| used_total_tokens = 0 | |||
| for i, text in enumerate(texts): | |||
| # Here token count is only an approximation based on the GPT2 tokenizer | |||
| num_tokens = self._get_num_tokens_by_gpt2(text) | |||
| if num_tokens >= context_size: | |||
| cutoff = int(np.floor(len(text) * (context_size / num_tokens))) | |||
| # if num tokens is larger than context length, only use the start | |||
| inputs.append(text[0:cutoff]) | |||
| else: | |||
| inputs.append(text) | |||
| indices += [i] | |||
| batched_embeddings = [] | |||
| _iter = range(0, len(inputs), max_chunks) | |||
| for i in _iter: | |||
| embeddings_batch, _used_tokens, _total_used_tokens = embedding.embed_documents( | |||
| model, | |||
| inputs[i: i + max_chunks], | |||
| user) | |||
| used_tokens += _used_tokens | |||
| used_total_tokens += _total_used_tokens | |||
| batched_embeddings += embeddings_batch | |||
| usage = self._calc_response_usage(model, credentials, used_tokens, used_total_tokens) | |||
| return TextEmbeddingResult( | |||
| model=model, | |||
| embeddings=batched_embeddings, | |||
| usage=usage, | |||
| ) | |||
| def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: | |||
| """ | |||
| Get number of tokens for given prompt messages | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :param texts: texts to embed | |||
| :return: | |||
| """ | |||
| if len(texts) == 0: | |||
| return 0 | |||
| total_num_tokens = 0 | |||
| for text in texts: | |||
| total_num_tokens += self._get_num_tokens_by_gpt2(text) | |||
| return total_num_tokens | |||
| def validate_credentials(self, model: str, credentials: Mapping) -> None: | |||
| api_key = credentials['api_key'] | |||
| secret_key = credentials['secret_key'] | |||
| try: | |||
| BaiduAccessToken.get_access_token(api_key, secret_key) | |||
| except Exception as e: | |||
| raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') | |||
| @property | |||
| def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: | |||
| return invoke_error_mapping() | |||
| def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage: | |||
| """ | |||
| Calculate response usage | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :param tokens: input tokens | |||
| :return: usage | |||
| """ | |||
| # get input price info | |||
| input_price_info = self.get_price( | |||
| model=model, | |||
| credentials=credentials, | |||
| price_type=PriceType.INPUT, | |||
| tokens=tokens | |||
| ) | |||
| # transform usage | |||
| usage = EmbeddingUsage( | |||
| tokens=tokens, | |||
| total_tokens=total_tokens, | |||
| unit_price=input_price_info.unit_price, | |||
| price_unit=input_price_info.unit, | |||
| total_price=input_price_info.total_amount, | |||
| currency=input_price_info.currency, | |||
| latency=time.perf_counter() - self.started_at | |||
| ) | |||
| return usage | |||
| @@ -17,6 +17,7 @@ help: | |||
| en_US: https://cloud.baidu.com/wenxin.html | |||
| supported_model_types: | |||
| - llm | |||
| - text-embedding | |||
| configurate_methods: | |||
| - predefined-model | |||
| provider_credential_schema: | |||
| @@ -0,0 +1,57 @@ | |||
| from core.model_runtime.errors.invoke import ( | |||
| InvokeAuthorizationError, | |||
| InvokeBadRequestError, | |||
| InvokeConnectionError, | |||
| InvokeError, | |||
| InvokeRateLimitError, | |||
| InvokeServerUnavailableError, | |||
| ) | |||
| def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]: | |||
| """ | |||
| Map model invoke error to unified error | |||
| The key is the error type thrown to the caller | |||
| The value is the error type thrown by the model, | |||
| which needs to be converted into a unified error type for the caller. | |||
| :return: Invoke error mapping | |||
| """ | |||
| return { | |||
| InvokeConnectionError: [ | |||
| ], | |||
| InvokeServerUnavailableError: [ | |||
| InternalServerError | |||
| ], | |||
| InvokeRateLimitError: [ | |||
| RateLimitReachedError | |||
| ], | |||
| InvokeAuthorizationError: [ | |||
| InvalidAuthenticationError, | |||
| InsufficientAccountBalance, | |||
| InvalidAPIKeyError, | |||
| ], | |||
| InvokeBadRequestError: [ | |||
| BadRequestError, | |||
| KeyError | |||
| ] | |||
| } | |||
| class InvalidAuthenticationError(Exception): | |||
| pass | |||
| class InvalidAPIKeyError(Exception): | |||
| pass | |||
| class RateLimitReachedError(Exception): | |||
| pass | |||
| class InsufficientAccountBalance(Exception): | |||
| pass | |||
| class InternalServerError(Exception): | |||
| pass | |||
| class BadRequestError(Exception): | |||
| pass | |||
| @@ -0,0 +1,24 @@ | |||
| import os | |||
| from time import sleep | |||
| from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult | |||
| from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import WenxinTextEmbeddingModel | |||
| def test_invoke_embedding_model(): | |||
| sleep(3) | |||
| model = WenxinTextEmbeddingModel() | |||
| response = model.invoke( | |||
| model='embedding-v1', | |||
| credentials={ | |||
| 'api_key': os.environ.get('WENXIN_API_KEY'), | |||
| 'secret_key': os.environ.get('WENXIN_SECRET_KEY') | |||
| }, | |||
| texts=['hello', '你好', 'xxxxx'], | |||
| user="abc-123" | |||
| ) | |||
| assert isinstance(response, TextEmbeddingResult) | |||
| assert len(response.embeddings) == 3 | |||
| assert isinstance(response.embeddings[0], list) | |||
| @@ -0,0 +1,75 @@ | |||
| import numpy as np | |||
| from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult | |||
| from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer | |||
| from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import ( | |||
| TextEmbedding, | |||
| WenxinTextEmbeddingModel, | |||
| ) | |||
| def test_max_chunks(): | |||
| class _MockTextEmbedding(TextEmbedding): | |||
| def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): | |||
| embeddings = [[1.0, 2.0, 3.0] for i in range(len(texts))] | |||
| tokens = 0 | |||
| for text in texts: | |||
| tokens += len(text) | |||
| return embeddings, tokens, tokens | |||
| def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding: | |||
| return _MockTextEmbedding() | |||
| model = 'embedding-v1' | |||
| credentials = { | |||
| 'api_key': 'xxxx', | |||
| 'secret_key': 'yyyy', | |||
| } | |||
| embedding_model = WenxinTextEmbeddingModel() | |||
| context_size = embedding_model._get_context_size(model, credentials) | |||
| max_chunks = embedding_model._get_max_chunks(model, credentials) | |||
| embedding_model._create_text_embedding = _create_text_embedding | |||
| texts = ['0123456789' for i in range(0, max_chunks * 2)] | |||
| result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test') | |||
| assert len(result.embeddings) == max_chunks * 2 | |||
| def test_context_size(): | |||
| def get_num_tokens_by_gpt2(text: str) -> int: | |||
| return GPT2Tokenizer.get_num_tokens(text) | |||
| def mock_text(token_size: int) -> str: | |||
| _text = "".join(['0' for i in range(token_size)]) | |||
| num_tokens = get_num_tokens_by_gpt2(_text) | |||
| ratio = int(np.floor(len(_text) / num_tokens)) | |||
| m_text = "".join([_text for i in range(ratio)]) | |||
| return m_text | |||
| model = 'embedding-v1' | |||
| credentials = { | |||
| 'api_key': 'xxxx', | |||
| 'secret_key': 'yyyy', | |||
| } | |||
| embedding_model = WenxinTextEmbeddingModel() | |||
| context_size = embedding_model._get_context_size(model, credentials) | |||
| class _MockTextEmbedding(TextEmbedding): | |||
| def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): | |||
| embeddings = [[1.0, 2.0, 3.0] for i in range(len(texts))] | |||
| tokens = 0 | |||
| for text in texts: | |||
| tokens += get_num_tokens_by_gpt2(text) | |||
| return embeddings, tokens, tokens | |||
| def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding: | |||
| return _MockTextEmbedding() | |||
| embedding_model._create_text_embedding = _create_text_embedding | |||
| text = mock_text(context_size * 2) | |||
| assert get_num_tokens_by_gpt2(text) == context_size * 2 | |||
| texts = [text] | |||
| result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test') | |||
| assert result.usage.tokens == context_size | |||