Parcourir la source

feat: support weixin ernie-bot-4 and chat mode (#1375)

tags/0.3.29
takatost il y a 2 ans
Parent
révision
7c9b585a47
Aucun compte lié à l'adresse e-mail de l'auteur

+ 14
- 5
api/core/model_providers/models/llm/wenxin_model.py Voir le fichier

@@ -6,17 +6,16 @@ from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.wenxin import Wenxin


class WenxinModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
model_mode: ModelMode = ModelMode.CHAT

def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
# TODO load price_config from configs(db)
return Wenxin(
model=self.name,
streaming=self.streaming,
@@ -38,7 +37,13 @@ class WenxinModel(BaseLLM):
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)

generate_kwargs = {'stop': stop, 'callbacks': callbacks, 'messages': [prompts]}

if 'functions' in kwargs:
generate_kwargs['functions'] = kwargs['functions']

return self._client.generate(**generate_kwargs)

def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
@@ -48,7 +53,7 @@ class WenxinModel(BaseLLM):
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
return max(self._client.get_num_tokens_from_messages(prompts), 0)

def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
@@ -58,3 +63,7 @@ class WenxinModel(BaseLLM):

def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Wenxin: {str(ex)}")

@property
def support_streaming(self):
return True

+ 13
- 5
api/core/model_providers/providers/wenxin_provider.py Voir le fichier

@@ -2,6 +2,8 @@ import json
from json import JSONDecodeError
from typing import Type

from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
@@ -23,20 +25,25 @@ class WenxinProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'ernie-bot-4',
'name': 'ERNIE-Bot-4',
'mode': ModelMode.CHAT.value,
},
{
'id': 'ernie-bot',
'name': 'ERNIE-Bot',
'mode': ModelMode.COMPLETION.value,
'mode': ModelMode.CHAT.value,
},
{
'id': 'ernie-bot-turbo',
'name': 'ERNIE-Bot-turbo',
'mode': ModelMode.COMPLETION.value,
'mode': ModelMode.CHAT.value,
},
{
'id': 'bloomz-7b',
'name': 'BLOOMZ-7B',
'mode': ModelMode.COMPLETION.value,
'mode': ModelMode.CHAT.value,
}
]
else:
@@ -68,11 +75,12 @@ class WenxinProvider(BaseModelProvider):
:return:
"""
model_max_tokens = {
'ernie-bot-4': 4800,
'ernie-bot': 4800,
'ernie-bot-turbo': 11200,
}

if model_name in ['ernie-bot', 'ernie-bot-turbo']:
if model_name in ['ernie-bot-4', 'ernie-bot', 'ernie-bot-turbo']:
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
@@ -111,7 +119,7 @@ class WenxinProvider(BaseModelProvider):
**credential_kwargs
)

llm("ping")
llm([HumanMessage(content='ping')])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))


+ 6
- 0
api/core/model_providers/rules/wenxin.json Voir le fichier

@@ -5,6 +5,12 @@
"system_config": null,
"model_flexibility": "fixed",
"price_config": {
"ernie-bot-4": {
"prompt": "0",
"completion": "0",
"unit": "0.001",
"currency": "RMB"
},
"ernie-bot": {
"prompt": "0.012",
"completion": "0.012",

+ 135
- 63
api/core/third_party/langchain/llms/wenxin.py Voir le fichier

@@ -8,12 +8,15 @@ from typing import (
Any,
Dict,
List,
Optional, Iterator,
Optional, Iterator, Tuple,
)

import requests
from langchain.chat_models.base import BaseChatModel
from langchain.llms.utils import enforce_stop_tokens
from langchain.schema.output import GenerationChunk
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import GenerationChunk, ChatResult, ChatGenerationChunk, ChatGeneration
from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator

from langchain.callbacks.manager import (
@@ -61,6 +64,7 @@ class _WenxinEndpointClient(BaseModel):
raise ValueError(f"Wenxin Model name is required")

model_url_map = {
'ernie-bot-4': 'completions_pro',
'ernie-bot': 'completions',
'ernie-bot-turbo': 'eb-instant',
'bloomz-7b': 'bloomz_7b1',
@@ -70,6 +74,7 @@ class _WenxinEndpointClient(BaseModel):

access_token = self.get_access_token()
api_url = f"{self.base_url}{model_url_map[request['model']]}?access_token={access_token}"
del request['model']

headers = {"Content-Type": "application/json"}
response = requests.post(api_url,
@@ -86,22 +91,21 @@ class _WenxinEndpointClient(BaseModel):
f"Wenxin API {json_response['error_code']}"
f" error: {json_response['error_msg']}"
)
return json_response["result"]
return json_response
else:
return response


class Wenxin(LLM):
"""Wrapper around Wenxin large language models.
To use, you should have the environment variable
``WENXIN_API_KEY`` and ``WENXIN_SECRET_KEY`` set with your API key,
or pass them as a named parameter to the constructor.
Example:
.. code-block:: python
from langchain.llms.wenxin import Wenxin
wenxin = Wenxin(model="<model_name>", api_key="my-api-key",
secret_key="my-group-id")
"""
class Wenxin(BaseChatModel):
"""Wrapper around Wenxin large language models."""

@property
def lc_secrets(self) -> Dict[str, str]:
return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}

@property
def lc_serializable(self) -> bool:
return True

_client: _WenxinEndpointClient = PrivateAttr()
model: str = "ernie-bot"
@@ -161,64 +165,89 @@ class Wenxin(LLM):
secret_key=self.secret_key,
)

def _call(
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict

def _create_message_dicts(
self, messages: List[BaseMessage]
) -> Tuple[List[Dict[str, Any]], str]:
dict_messages = []
system = None
for m in messages:
message = self._convert_message_to_dict(m)
if message['role'] == 'system':
if not system:
system = message['content']
else:
system += f"\n{message['content']}"
continue

if dict_messages:
previous_message = dict_messages[-1]
if previous_message['role'] == message['role']:
dict_messages[-1]['content'] += f"\n{message['content']}"
else:
dict_messages.append(message)
else:
dict_messages.append(message)

return dict_messages, system

def _generate(
self,
prompt: str,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
r"""Call out to Wenxin's completion endpoint to chat
Args:
prompt: The prompt to pass into the model.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = wenxin("Tell me a joke.")
"""
) -> ChatResult:
if self.streaming:
completion = ""
generation: Optional[ChatGenerationChunk] = None
llm_output: Optional[Dict] = None
for chunk in self._stream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
completion += chunk.text
if chunk.generation_info is not None \
and 'token_usage' in chunk.generation_info:
llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}

if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation], llm_output=llm_output)
else:
message_dicts, system = self._create_message_dicts(messages)
request = self._default_params
request["messages"] = [{"role": "user", "content": prompt}]
request["messages"] = message_dicts
if system:
request["system"] = system
request.update(kwargs)
completion = self._client.post(request)

if stop is not None:
completion = enforce_stop_tokens(completion, stop)

return completion
response = self._client.post(request)
return self._create_chat_result(response)

def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
r"""Call wenxin completion_stream and return the resulting generator.

Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
A generator representing the stream of tokens from Wenxin.
Example:
.. code-block:: python

prompt = "Write a poem about a stream."
prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
generator = wenxin.stream(prompt)
for token in generator:
yield token
"""
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, system = self._create_message_dicts(messages)
request = self._default_params
request["messages"] = [{"role": "user", "content": prompt}]
request["messages"] = message_dicts
if system:
request["system"] = system
request.update(kwargs)

for token in self._client.post(request).iter_lines():
@@ -228,12 +257,18 @@ class Wenxin(LLM):
if token.startswith('data:'):
completion = json.loads(token[5:])

yield GenerationChunk(text=completion['result'])
if run_manager:
run_manager.on_llm_new_token(completion['result'])
chunk_dict = {
'message': AIMessageChunk(content=completion['result']),
}

if completion['is_end']:
break
token_usage = completion['usage']
token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
chunk_dict['generation_info'] = dict({'token_usage': token_usage})

yield ChatGenerationChunk(**chunk_dict)
if run_manager:
run_manager.on_llm_new_token(completion['result'])
else:
try:
json_response = json.loads(token)
@@ -245,3 +280,40 @@ class Wenxin(LLM):
f" error: {json_response['error_msg']}, "
f"please confirm if the model you have chosen is already paid for."
)

def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
generations = [ChatGeneration(
message=AIMessage(content=response['result']),
)]
token_usage = response.get("usage")
token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']

llm_output = {"token_usage": token_usage, "model_name": self.model}
return ChatResult(generations=generations, llm_output=llm_output)

def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in the messages.

Useful for checking if an input will fit in a model's context window.

Args:
messages: The message inputs to tokenize.

Returns:
The sum of the number of tokens across the messages.
"""
return sum([self.get_num_tokens(m.content) for m in messages])

def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
for output in llm_outputs:
if output is None:
# Happens in streaming
continue
token_usage = output["token_usage"]
for k, v in token_usage.items():
if k in overall_token_usage:
overall_token_usage[k] += v
else:
overall_token_usage[k] = v
return {"token_usage": overall_token_usage, "model_name": self.model}

+ 2
- 3
api/tests/integration_tests/models/llm/test_wenxin_model.py Voir le fichier

@@ -56,9 +56,8 @@ def test_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

model = get_mock_model('ernie-bot')
messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
messages = [PromptMessage(type=MessageType.USER, content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
rst = model.run(
messages,
stop=['\nHuman:'],
messages
)
assert len(rst.content) > 0

+ 4
- 1
api/tests/unit_tests/model_providers/test_wenxin_provider.py Voir le fichier

@@ -2,6 +2,8 @@ import pytest
from unittest.mock import patch
import json

from langchain.schema import AIMessage, ChatGeneration, ChatResult

from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.wenxin_provider import WenxinProvider
from models.provider import ProviderType, Provider
@@ -24,7 +26,8 @@ def decrypt_side_effect(tenant_id, encrypted_key):


def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._call', return_value="abc")
mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._generate',
return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))

MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)


Chargement…
Annuler
Enregistrer