|
|
|
@@ -8,12 +8,15 @@ from typing import (
     Any,
     Dict,
     List,
-    Optional, Iterator,
+    Optional, Iterator, Tuple,
 )
 
 import requests
+from langchain.chat_models.base import BaseChatModel
 from langchain.llms.utils import enforce_stop_tokens
-from langchain.schema.output import GenerationChunk
+from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
+from langchain.schema.messages import AIMessageChunk
+from langchain.schema.output import GenerationChunk, ChatResult, ChatGenerationChunk, ChatGeneration
 from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator
 
 from langchain.callbacks.manager import (
@@ -61,6 +64,7 @@ class _WenxinEndpointClient(BaseModel):
             raise ValueError(f"Wenxin Model name is required")
 
         model_url_map = {
+            'ernie-bot-4': 'completions_pro',
             'ernie-bot': 'completions',
             'ernie-bot-turbo': 'eb-instant',
             'bloomz-7b': 'bloomz_7b1',
@@ -70,6 +74,7 @@ class _WenxinEndpointClient(BaseModel):
 
         access_token = self.get_access_token()
         api_url = f"{self.base_url}{model_url_map[request['model']]}?access_token={access_token}"
+        del request['model']
 
         headers = {"Content-Type": "application/json"}
         response = requests.post(api_url,
@@ -86,22 +91,21 @@ class _WenxinEndpointClient(BaseModel):
                     f"Wenxin API {json_response['error_code']}"
                     f" error: {json_response['error_msg']}"
                 )
-            return json_response["result"]
+            return json_response
         else:
             return response
 
 
-class Wenxin(LLM):
-    """Wrapper around Wenxin large language models.
-    To use, you should have the environment variable
-    ``WENXIN_API_KEY`` and ``WENXIN_SECRET_KEY`` set with your API key,
-    or pass them as a named parameter to the constructor.
-    Example:
-        .. code-block:: python
-            from langchain.llms.wenxin import Wenxin
-            wenxin = Wenxin(model="<model_name>", api_key="my-api-key",
-                            secret_key="my-group-id")
-    """
+class Wenxin(BaseChatModel):
+    """Wrapper around Wenxin large language models."""
+
+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}
+
+    @property
+    def lc_serializable(self) -> bool:
+        return True
 
     _client: _WenxinEndpointClient = PrivateAttr()
     model: str = "ernie-bot"
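Since the docstring example for the old string-in/string-out interface is dropped above, a minimal usage sketch of the chat-style interface follows. It assumes the constructor keeps the `model`, `api_key`, and `secret_key` fields shown elsewhere in this diff and that the module path is unchanged:

.. code-block:: python

    from langchain.schema import HumanMessage, SystemMessage
    from langchain.llms.wenxin import Wenxin  # module path assumed unchanged

    chat = Wenxin(model="ernie-bot", api_key="my-api-key", secret_key="my-secret-key")
    # BaseChatModel makes the instance callable with a list of messages and
    # returns the assistant reply as an AIMessage.
    reply = chat([
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Tell me a joke."),
    ])
    print(reply.content)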
|
|
|
@@ -161,64 +165,89 @@ class Wenxin(LLM):
             secret_key=self.secret_key,
         )
 
-    def _call(
+    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
+        if isinstance(message, ChatMessage):
+            message_dict = {"role": message.role, "content": message.content}
+        elif isinstance(message, HumanMessage):
+            message_dict = {"role": "user", "content": message.content}
+        elif isinstance(message, AIMessage):
+            message_dict = {"role": "assistant", "content": message.content}
+        elif isinstance(message, SystemMessage):
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+        return message_dict
+
+    def _create_message_dicts(
+        self, messages: List[BaseMessage]
+    ) -> Tuple[List[Dict[str, Any]], str]:
+        dict_messages = []
+        system = None
+        for m in messages:
+            message = self._convert_message_to_dict(m)
+            if message['role'] == 'system':
+                if not system:
+                    system = message['content']
+                else:
+                    system += f"\n{message['content']}"
+                continue
+
+            if dict_messages:
+                previous_message = dict_messages[-1]
+                if previous_message['role'] == message['role']:
+                    dict_messages[-1]['content'] += f"\n{message['content']}"
+                else:
+                    dict_messages.append(message)
+            else:
+                dict_messages.append(message)
+
+        return dict_messages, system
+
+    def _generate(
         self,
-        prompt: str,
+        messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        r"""Call out to Wenxin's completion endpoint to chat
-        Args:
-            prompt: The prompt to pass into the model.
-        Returns:
-            The string generated by the model.
-        Example:
-            .. code-block:: python
-                response = wenxin("Tell me a joke.")
-        """
+    ) -> ChatResult:
         if self.streaming:
-            completion = ""
+            generation: Optional[ChatGenerationChunk] = None
+            llm_output: Optional[Dict] = None
             for chunk in self._stream(
-                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+                messages=messages, stop=stop, run_manager=run_manager, **kwargs
             ):
-                completion += chunk.text
+                if chunk.generation_info is not None \
+                        and 'token_usage' in chunk.generation_info:
+                    llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
+
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+            return ChatResult(generations=[generation], llm_output=llm_output)
         else:
+            message_dicts, system = self._create_message_dicts(messages)
             request = self._default_params
-            request["messages"] = [{"role": "user", "content": prompt}]
+            request["messages"] = message_dicts
+            if system:
+                request["system"] = system
             request.update(kwargs)
-            completion = self._client.post(request)
-
-        if stop is not None:
-            completion = enforce_stop_tokens(completion, stop)
-
-        return completion
+            response = self._client.post(request)
+            return self._create_chat_result(response)
 
     def _stream(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[GenerationChunk]:
-        r"""Call wenxin completion_stream and return the resulting generator.
-
-        Args:
-            prompt: The prompt to pass into the model.
-            stop: Optional list of stop words to use when generating.
-        Returns:
-            A generator representing the stream of tokens from Wenxin.
-        Example:
-            .. code-block:: python
-
-                prompt = "Write a poem about a stream."
-                prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
-                generator = wenxin.stream(prompt)
-                for token in generator:
-                    yield token
-        """
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        message_dicts, system = self._create_message_dicts(messages)
         request = self._default_params
-        request["messages"] = [{"role": "user", "content": prompt}]
+        request["messages"] = message_dicts
+        if system:
+            request["system"] = system
         request.update(kwargs)
 
         for token in self._client.post(request).iter_lines():
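The role handling in `_create_message_dicts` above (system messages folded into a separate `system` string, consecutive same-role messages merged) presumably exists because the ERNIE Bot API expects an alternating user/assistant message list. A small worked sketch of what the new helper returns, with `wenxin` standing in for a `Wenxin` instance:

.. code-block:: python

    from langchain.schema import AIMessage, HumanMessage, SystemMessage

    messages = [
        SystemMessage(content="Answer in English."),
        HumanMessage(content="Hi"),
        HumanMessage(content="What is ERNIE?"),  # same role as the previous turn
        AIMessage(content="A model family from Baidu."),
    ]
    message_dicts, system = wenxin._create_message_dicts(messages)
    # message_dicts == [
    #     {"role": "user", "content": "Hi\nWhat is ERNIE?"},  # merged user turns
    #     {"role": "assistant", "content": "A model family from Baidu."},
    # ]
    # system == "Answer in English."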
|
|
|
@@ -228,12 +257,18 @@ class Wenxin(LLM):
                 if token.startswith('data:'):
                     completion = json.loads(token[5:])
 
-                    yield GenerationChunk(text=completion['result'])
-                    if run_manager:
-                        run_manager.on_llm_new_token(completion['result'])
+                    chunk_dict = {
+                        'message': AIMessageChunk(content=completion['result']),
+                    }
 
                     if completion['is_end']:
-                        break
+                        token_usage = completion['usage']
+                        token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
+                        chunk_dict['generation_info'] = dict({'token_usage': token_usage})
+
+                    yield ChatGenerationChunk(**chunk_dict)
+                    if run_manager:
+                        run_manager.on_llm_new_token(completion['result'])
                 else:
                     try:
                         json_response = json.loads(token)
@@ -245,3 +280,40 @@ class Wenxin(LLM):
                             f" error: {json_response['error_msg']}, "
                             f"please confirm if the model you have chosen is already paid for."
                         )
+
+    def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
+        generations = [ChatGeneration(
+            message=AIMessage(content=response['result']),
+        )]
+        token_usage = response.get("usage")
+        token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
+
+        llm_output = {"token_usage": token_usage, "model_name": self.model}
+        return ChatResult(generations=generations, llm_output=llm_output)
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Get the number of tokens in the messages.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            messages: The message inputs to tokenize.
+
+        Returns:
+            The sum of the number of tokens across the messages.
+        """
+        return sum([self.get_num_tokens(m.content) for m in messages])
+
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        overall_token_usage: dict = {}
+        for output in llm_outputs:
+            if output is None:
+                # Happens in streaming
+                continue
+            token_usage = output["token_usage"]
+            for k, v in token_usage.items():
+                if k in overall_token_usage:
+                    overall_token_usage[k] += v
+                else:
+                    overall_token_usage[k] = v
+        return {"token_usage": overall_token_usage, "model_name": self.model}
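`_combine_llm_outputs` is what lets a batched `generate` call report one token-usage total across several message lists. A short sketch of the aggregation, assuming each per-call `llm_output` has the shape built by `_create_chat_result` above (`None` entries, as can happen with streaming, are skipped):

.. code-block:: python

    outputs = [
        {"token_usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
         "model_name": "ernie-bot"},
        None,  # skipped: a streaming call may not attach an llm_output
        {"token_usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12},
         "model_name": "ernie-bot"},
    ]
    combined = wenxin._combine_llm_outputs(outputs)
    # combined == {"token_usage": {"prompt_tokens": 15, "completion_tokens": 27,
    #                              "total_tokens": 42},
    #              "model_name": "ernie-bot"}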