Преглед на файлове

feat(api): support wenxin text embedding (#7377)

tags/0.7.1
Chengyu Yan преди 1 година
родител
ревизия
bfd905602f
No account linked to committer's email address

+ 195
- 0
api/core/model_runtime/model_providers/wenxin/_common.py Целия файл

@@ -0,0 +1,195 @@
from datetime import datetime, timedelta
from threading import Lock

from requests import post

from core.model_runtime.model_providers.wenxin.wenxin_errors import (
BadRequestError,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
RateLimitReachedError,
)

baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
baidu_access_tokens_lock = Lock()


class BaiduAccessToken:
api_key: str
access_token: str
expires: datetime

def __init__(self, api_key: str) -> None:
self.api_key = api_key
self.access_token = ''
self.expires = datetime.now() + timedelta(days=3)

@staticmethod
def _get_access_token(api_key: str, secret_key: str) -> str:
"""
request access token from Baidu
"""
try:
response = post(
url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
headers={
'Content-Type': 'application/json',
'Accept': 'application/json'
},
)
except Exception as e:
raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')

resp = response.json()
if 'error' in resp:
if resp['error'] == 'invalid_client':
raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
elif resp['error'] == 'unknown_error':
raise InternalServerError(f'Internal server error: {resp["error_description"]}')
elif resp['error'] == 'invalid_request':
raise BadRequestError(f'Bad request: {resp["error_description"]}')
elif resp['error'] == 'rate_limit_exceeded':
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
else:
raise Exception(f'Unknown error: {resp["error_description"]}')

return resp['access_token']

@staticmethod
def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
"""
LLM from Baidu requires access token to invoke the API.
however, we have api_key and secret_key, and access token is valid for 30 days.
so we can cache the access token for 3 days. (avoid memory leak)

it may be more efficient to use a ticker to refresh access token, but it will cause
more complexity, so we just refresh access tokens when get_access_token is called.
"""

# loop up cache, remove expired access token
baidu_access_tokens_lock.acquire()
now = datetime.now()
for key in list(baidu_access_tokens.keys()):
token = baidu_access_tokens[key]
if token.expires < now:
baidu_access_tokens.pop(key)

if api_key not in baidu_access_tokens:
# if access token not in cache, request it
token = BaiduAccessToken(api_key)
baidu_access_tokens[api_key] = token
# release it to enhance performance
# btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock
baidu_access_tokens_lock.release()
# try to get access token
token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
token.access_token = token_str
token.expires = now + timedelta(days=3)
return token
else:
# if access token in cache, return it
token = baidu_access_tokens[api_key]
baidu_access_tokens_lock.release()
return token


class _CommonWenxin:
api_bases = {
'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1',
}

function_calling_supports = [
'ernie-bot',
'ernie-bot-8k',
'ernie-3.5-8k',
'ernie-3.5-8k-0205',
'ernie-3.5-8k-1222',
'ernie-3.5-4k-0205',
'ernie-3.5-128k',
'ernie-4.0-8k',
'ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview',
'yi_34b_chat'
]

api_key: str = ''
secret_key: str = ''

def __init__(self, api_key: str, secret_key: str):
self.api_key = api_key
self.secret_key = secret_key

@staticmethod
def _to_credential_kwargs(credentials: dict) -> dict:
credentials_kwargs = {
"api_key": credentials['api_key'],
"secret_key": credentials['secret_key']
}
return credentials_kwargs

def _handle_error(self, code: int, msg: str):
error_map = {
1: InternalServerError,
2: InternalServerError,
3: BadRequestError,
4: RateLimitReachedError,
6: InvalidAuthenticationError,
13: InvalidAPIKeyError,
14: InvalidAPIKeyError,
15: InvalidAPIKeyError,
17: RateLimitReachedError,
18: RateLimitReachedError,
19: RateLimitReachedError,
100: InvalidAPIKeyError,
111: InvalidAPIKeyError,
200: InternalServerError,
336000: InternalServerError,
336001: BadRequestError,
336002: BadRequestError,
336003: BadRequestError,
336004: InvalidAuthenticationError,
336005: InvalidAPIKeyError,
336006: BadRequestError,
336007: BadRequestError,
336008: BadRequestError,
336100: InternalServerError,
336101: BadRequestError,
336102: BadRequestError,
336103: BadRequestError,
336104: BadRequestError,
336105: BadRequestError,
336200: InternalServerError,
336303: BadRequestError,
337006: BadRequestError
}

if code in error_map:
raise error_map[code](msg)
else:
raise InternalServerError(f'Unknown error: {msg}')

def _get_access_token(self) -> str:
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
return token.access_token

+ 3
- 177
api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py Целия файл

@@ -1,102 +1,17 @@
from collections.abc import Generator
from datetime import datetime, timedelta
from enum import Enum
from json import dumps, loads
from threading import Lock
from typing import Any, Union

from requests import Response, post

from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
from core.model_runtime.model_providers.wenxin._common import _CommonWenxin
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
BadRequestError,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
RateLimitReachedError,
)

# map api_key to access_token
baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
baidu_access_tokens_lock = Lock()

class BaiduAccessToken:
api_key: str
access_token: str
expires: datetime

def __init__(self, api_key: str) -> None:
self.api_key = api_key
self.access_token = ''
self.expires = datetime.now() + timedelta(days=3)

def _get_access_token(api_key: str, secret_key: str) -> str:
"""
request access token from Baidu
"""
try:
response = post(
url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
headers={
'Content-Type': 'application/json',
'Accept': 'application/json'
},
)
except Exception as e:
raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')

resp = response.json()
if 'error' in resp:
if resp['error'] == 'invalid_client':
raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
elif resp['error'] == 'unknown_error':
raise InternalServerError(f'Internal server error: {resp["error_description"]}')
elif resp['error'] == 'invalid_request':
raise BadRequestError(f'Bad request: {resp["error_description"]}')
elif resp['error'] == 'rate_limit_exceeded':
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
else:
raise Exception(f'Unknown error: {resp["error_description"]}')

return resp['access_token']

@staticmethod
def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
"""
LLM from Baidu requires access token to invoke the API.
however, we have api_key and secret_key, and access token is valid for 30 days.
so we can cache the access token for 3 days. (avoid memory leak)

it may be more efficient to use a ticker to refresh access token, but it will cause
more complexity, so we just refresh access tokens when get_access_token is called.
"""

# loop up cache, remove expired access token
baidu_access_tokens_lock.acquire()
now = datetime.now()
for key in list(baidu_access_tokens.keys()):
token = baidu_access_tokens[key]
if token.expires < now:
baidu_access_tokens.pop(key)

if api_key not in baidu_access_tokens:
# if access token not in cache, request it
token = BaiduAccessToken(api_key)
baidu_access_tokens[api_key] = token
# release it to enhance performance
# btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock
baidu_access_tokens_lock.release()
# try to get access token
token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
token.access_token = token_str
token.expires = now + timedelta(days=3)
return token
else:
# if access token in cache, return it
token = baidu_access_tokens[api_key]
baidu_access_tokens_lock.release()
return token


class ErnieMessage:
class Role(Enum):
@@ -120,51 +35,7 @@ class ErnieMessage:
self.content = content
self.role = role

class ErnieBotModel:
api_bases = {
'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
}

function_calling_supports = [
'ernie-bot',
'ernie-bot-8k',
'ernie-3.5-8k',
'ernie-3.5-8k-0205',
'ernie-3.5-8k-1222',
'ernie-3.5-4k-0205',
'ernie-3.5-128k',
'ernie-4.0-8k',
'ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview',
'yi_34b_chat'
]

api_key: str = ''
secret_key: str = ''

def __init__(self, api_key: str, secret_key: str):
self.api_key = api_key
self.secret_key = secret_key
class ErnieBotModel(_CommonWenxin):

def generate(self, model: str, stream: bool, messages: list[ErnieMessage],
parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \
@@ -199,51 +70,6 @@ class ErnieBotModel:
return self._handle_chat_stream_generate_response(resp)
return self._handle_chat_generate_response(resp)

def _handle_error(self, code: int, msg: str):
error_map = {
1: InternalServerError,
2: InternalServerError,
3: BadRequestError,
4: RateLimitReachedError,
6: InvalidAuthenticationError,
13: InvalidAPIKeyError,
14: InvalidAPIKeyError,
15: InvalidAPIKeyError,
17: RateLimitReachedError,
18: RateLimitReachedError,
19: RateLimitReachedError,
100: InvalidAPIKeyError,
111: InvalidAPIKeyError,
200: InternalServerError,
336000: InternalServerError,
336001: BadRequestError,
336002: BadRequestError,
336003: BadRequestError,
336004: InvalidAuthenticationError,
336005: InvalidAPIKeyError,
336006: BadRequestError,
336007: BadRequestError,
336008: BadRequestError,
336100: InternalServerError,
336101: BadRequestError,
336102: BadRequestError,
336103: BadRequestError,
336104: BadRequestError,
336105: BadRequestError,
336200: InternalServerError,
336303: BadRequestError,
337006: BadRequestError
}

if code in error_map:
raise error_map[code](msg)
else:
raise InternalServerError(f'Unknown error: {msg}')

def _get_access_token(self) -> str:
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
return token.access_token

def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]:
return [ErnieMessage(message.content, message.role) for message in messages]


+ 0
- 17
api/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py Целия файл

@@ -1,17 +0,0 @@
class InvalidAuthenticationError(Exception):
pass

class InvalidAPIKeyError(Exception):
pass

class RateLimitReachedError(Exception):
pass

class InsufficientAccountBalance(Exception):
pass

class InternalServerError(Exception):
pass

class BadRequestError(Exception):
pass

+ 5
- 34
api/core/model_runtime/model_providers/wenxin/llm/llm.py Целия файл

@@ -11,24 +11,13 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.wenxin.llm.ernie_bot import BaiduAccessToken, ErnieBotModel, ErnieMessage
from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
BadRequestError,
InsufficientAccountBalance,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
RateLimitReachedError,
)
from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken
from core.model_runtime.model_providers.wenxin.llm.ernie_bot import ErnieBotModel, ErnieMessage
from core.model_runtime.model_providers.wenxin.wenxin_errors import invoke_error_mapping

ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
@@ -140,7 +129,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
api_key = credentials['api_key']
secret_key = credentials['secret_key']
try:
BaiduAccessToken._get_access_token(api_key, secret_key)
BaiduAccessToken.get_access_token(api_key, secret_key)
except Exception as e:
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')

@@ -254,22 +243,4 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):

:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
],
InvokeServerUnavailableError: [
InternalServerError
],
InvokeRateLimitError: [
RateLimitReachedError
],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalance,
InvalidAPIKeyError,
],
InvokeBadRequestError: [
BadRequestError,
KeyError
]
}
return invoke_error_mapping()

+ 0
- 0
api/core/model_runtime/model_providers/wenxin/text_embedding/__init__.py Целия файл


+ 9
- 0
api/core/model_runtime/model_providers/wenxin/text_embedding/embedding-v1.yaml Целия файл

@@ -0,0 +1,9 @@
model: embedding-v1
model_type: text-embedding
model_properties:
context_size: 384
max_chunks: 16
pricing:
input: '0.0005'
unit: '0.001'
currency: RMB

+ 184
- 0
api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py Целия файл

@@ -0,0 +1,184 @@
import time
from abc import abstractmethod
from collections.abc import Mapping
from json import dumps
from typing import Any, Optional

import numpy as np
from requests import Response, post

from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import InvokeError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken, _CommonWenxin
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
BadRequestError,
InternalServerError,
invoke_error_mapping,
)


class TextEmbedding:
@abstractmethod
def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
raise NotImplementedError


class WenxinTextEmbedding(_CommonWenxin, TextEmbedding):
def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
access_token = self._get_access_token()
url = f'{self.api_bases[model]}?access_token={access_token}'
body = self._build_embed_request_body(model, texts, user)
headers = {
'Content-Type': 'application/json',
}

resp = post(url, data=dumps(body), headers=headers)
if resp.status_code != 200:
raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}')
return self._handle_embed_response(model, resp)

def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]:
if len(texts) == 0:
raise BadRequestError('The number of texts should not be zero.')
body = {
'input': texts,
'user_id': user,
}
return body

def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int):
data = response.json()
if 'error_code' in data:
code = data['error_code']
msg = data['error_msg']
# raise error
self._handle_error(code, msg)

embeddings = [v['embedding'] for v in data['data']]
_usage = data['usage']
tokens = _usage['prompt_tokens']
total_tokens = _usage['total_tokens']

return embeddings, tokens, total_tokens


class WenxinTextEmbeddingModel(TextEmbeddingModel):
def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding:
return WenxinTextEmbedding(api_key, secret_key)

def _invoke(self, model: str, credentials: dict, texts: list[str],
user: Optional[str] = None) -> TextEmbeddingResult:
"""
Invoke text embedding model

:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""

api_key = credentials['api_key']
secret_key = credentials['secret_key']
embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key)
user = user if user else 'ErnieBotDefault'

context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)
inputs = []
indices = []
used_tokens = 0
used_total_tokens = 0

for i, text in enumerate(texts):

# Here token count is only an approximation based on the GPT2 tokenizer
num_tokens = self._get_num_tokens_by_gpt2(text)

if num_tokens >= context_size:
cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
# if num tokens is larger than context length, only use the start
inputs.append(text[0:cutoff])
else:
inputs.append(text)
indices += [i]

batched_embeddings = []
_iter = range(0, len(inputs), max_chunks)
for i in _iter:
embeddings_batch, _used_tokens, _total_used_tokens = embedding.embed_documents(
model,
inputs[i: i + max_chunks],
user)
used_tokens += _used_tokens
used_total_tokens += _total_used_tokens
batched_embeddings += embeddings_batch

usage = self._calc_response_usage(model, credentials, used_tokens, used_total_tokens)
return TextEmbeddingResult(
model=model,
embeddings=batched_embeddings,
usage=usage,
)

def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages

:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
if len(texts) == 0:
return 0
total_num_tokens = 0
for text in texts:
total_num_tokens += self._get_num_tokens_by_gpt2(text)

return total_num_tokens

def validate_credentials(self, model: str, credentials: Mapping) -> None:
api_key = credentials['api_key']
secret_key = credentials['secret_key']
try:
BaiduAccessToken.get_access_token(api_key, secret_key)
except Exception as e:
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return invoke_error_mapping()

def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage:
"""
Calculate response usage

:param model: model name
:param credentials: model credentials
:param tokens: input tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
)

# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=total_tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
)

return usage

+ 1
- 0
api/core/model_runtime/model_providers/wenxin/wenxin.yaml Целия файл

@@ -17,6 +17,7 @@ help:
en_US: https://cloud.baidu.com/wenxin.html
supported_model_types:
- llm
- text-embedding
configurate_methods:
- predefined-model
provider_credential_schema:

+ 57
- 0
api/core/model_runtime/model_providers/wenxin/wenxin_errors.py Целия файл

@@ -0,0 +1,57 @@
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)


def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.

:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
],
InvokeServerUnavailableError: [
InternalServerError
],
InvokeRateLimitError: [
RateLimitReachedError
],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalance,
InvalidAPIKeyError,
],
InvokeBadRequestError: [
BadRequestError,
KeyError
]
}


class InvalidAuthenticationError(Exception):
pass

class InvalidAPIKeyError(Exception):
pass

class RateLimitReachedError(Exception):
pass

class InsufficientAccountBalance(Exception):
pass

class InternalServerError(Exception):
pass

class BadRequestError(Exception):
pass

+ 24
- 0
api/tests/integration_tests/model_runtime/wenxin/test_embedding.py Целия файл

@@ -0,0 +1,24 @@
import os
from time import sleep

from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import WenxinTextEmbeddingModel


def test_invoke_embedding_model():
sleep(3)
model = WenxinTextEmbeddingModel()

response = model.invoke(
model='embedding-v1',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
texts=['hello', '你好', 'xxxxx'],
user="abc-123"
)

assert isinstance(response, TextEmbeddingResult)
assert len(response.embeddings) == 3
assert isinstance(response.embeddings[0], list)

+ 0
- 0
api/tests/unit_tests/core/model_runtime/__init__.py Целия файл


+ 0
- 0
api/tests/unit_tests/core/model_runtime/model_providers/__init__.py Целия файл


+ 0
- 0
api/tests/unit_tests/core/model_runtime/model_providers/wenxin/__init__.py Целия файл


+ 75
- 0
api/tests/unit_tests/core/model_runtime/model_providers/wenxin/test_text_embedding.py Целия файл

@@ -0,0 +1,75 @@
import numpy as np

from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import (
TextEmbedding,
WenxinTextEmbeddingModel,
)


def test_max_chunks():
class _MockTextEmbedding(TextEmbedding):
def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
embeddings = [[1.0, 2.0, 3.0] for i in range(len(texts))]
tokens = 0
for text in texts:
tokens += len(text)

return embeddings, tokens, tokens

def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding:
return _MockTextEmbedding()

model = 'embedding-v1'
credentials = {
'api_key': 'xxxx',
'secret_key': 'yyyy',
}
embedding_model = WenxinTextEmbeddingModel()
context_size = embedding_model._get_context_size(model, credentials)
max_chunks = embedding_model._get_max_chunks(model, credentials)
embedding_model._create_text_embedding = _create_text_embedding

texts = ['0123456789' for i in range(0, max_chunks * 2)]
result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test')
assert len(result.embeddings) == max_chunks * 2


def test_context_size():
def get_num_tokens_by_gpt2(text: str) -> int:
return GPT2Tokenizer.get_num_tokens(text)

def mock_text(token_size: int) -> str:
_text = "".join(['0' for i in range(token_size)])
num_tokens = get_num_tokens_by_gpt2(_text)
ratio = int(np.floor(len(_text) / num_tokens))
m_text = "".join([_text for i in range(ratio)])
return m_text

model = 'embedding-v1'
credentials = {
'api_key': 'xxxx',
'secret_key': 'yyyy',
}
embedding_model = WenxinTextEmbeddingModel()
context_size = embedding_model._get_context_size(model, credentials)

class _MockTextEmbedding(TextEmbedding):
def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
embeddings = [[1.0, 2.0, 3.0] for i in range(len(texts))]
tokens = 0
for text in texts:
tokens += get_num_tokens_by_gpt2(text)
return embeddings, tokens, tokens

def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding:
return _MockTextEmbedding()

embedding_model._create_text_embedding = _create_text_embedding
text = mock_text(context_size * 2)
assert get_num_tokens_by_gpt2(text) == context_size * 2

texts = [text]
result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test')
assert result.usage.tokens == context_size

Loading…
Отказ
Запис