@@ -1,30 +1,119 @@
from collections.abc import Generator
from typing import Optional, Union
from urllib.parse import urlparse

import tiktoken

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
)
-from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
+from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel


-class YiLargeLanguageModel(OAIAPICompatLargeLanguageModel):
+class YiLargeLanguageModel(OpenAILargeLanguageModel):
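    """Yi (01.AI) LLM runtime that reuses the OpenAI model runtime over Yi's OpenAI-compatible API."""
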
    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
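        # Normalize the Yi credentials into the fields the OpenAI runtime
        # expects before delegating to the parent implementation.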
        self._add_custom_parameters(credentials)

        # yi-vl-plus does not support system prompts yet, so filter them out.
        if model == "yi-vl-plus":
            prompt_message_except_system: list[PromptMessage] = []
            for message in prompt_messages:
                if not isinstance(message, SystemPromptMessage):
                    prompt_message_except_system.append(message)
            return super()._invoke(model, credentials, prompt_message_except_system,
                                   model_parameters, tools, stop, stream, user)

        return super()._invoke(model, credentials, prompt_messages,
                               model_parameters, tools, stop, stream, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    # Refactored from the OpenAI model runtime: use cl100k_base to count tokens.
    def _num_tokens_from_string(self, model: str, text: str,
                                tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Calculate the number of tokens for a text completion model with the tiktoken package.

        :param model: model name
        :param text: prompt text
        :param tools: tools for tool calling
        :return: number of tokens
        """
        encoding = tiktoken.get_encoding("cl100k_base")
        num_tokens = len(encoding.encode(text))

        if tools:
            num_tokens += self._num_tokens_for_tools(encoding, tools)

        return num_tokens

    # Refactored from the OpenAI model runtime: use cl100k_base to count tokens.
    def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
                                  tools: Optional[list[PromptMessageTool]] = None) -> int:
        """Calculate the number of tokens for chat messages with the tiktoken package.

        Adapted from the OpenAI cookbook:
        https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
        """
        encoding = tiktoken.get_encoding("cl100k_base")
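        # Per the OpenAI cookbook: every message carries a fixed 3-token overhead,
        # plus one extra token when a "name" field is present.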
        tokens_per_message = 3
        tokens_per_name = 1
        num_tokens = 0

        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
        for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
                # Cast str(value) in case the message value is not a string,
                # which can happen with function messages.
                # TODO: Token counting for image content is not implemented yet;
                # it would require downloading each image to determine its
                # resolution, which adds request latency.
                if isinstance(value, list):
                    text = ''
                    for item in value:
                        if isinstance(item, dict) and item.get('type') == 'text':
                            text += item['text']
                    value = text
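                # Tool call payloads are counted field by field: keys and
                # encoded values both contribute tokens.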
| if key == "tool_calls": | |||
| for tool_call in value: | |||
| for t_key, t_value in tool_call.items(): | |||
| num_tokens += len(encoding.encode(t_key)) | |||
| if t_key == "function": | |||
| for f_key, f_value in t_value.items(): | |||
| num_tokens += len(encoding.encode(f_key)) | |||
| num_tokens += len(encoding.encode(f_value)) | |||
| else: | |||
| num_tokens += len(encoding.encode(t_key)) | |||
| num_tokens += len(encoding.encode(t_value)) | |||
| else: | |||
| num_tokens += len(encoding.encode(str(value))) | |||
| if key == "name": | |||
| num_tokens += tokens_per_name | |||
| # every reply is primed with <im_start>assistant | |||
| num_tokens += 3 | |||
        if tools:
            num_tokens += self._num_tokens_for_tools(encoding, tools)

        return num_tokens

    @staticmethod
    def _add_custom_parameters(credentials: dict) -> None:
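        # Reuse the OpenAI runtime against Yi's OpenAI-compatible API:
        # force chat mode, alias the Yi API key onto the OpenAI key field,
        # and default the base URL to the official Yi endpoint.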
        credentials['mode'] = 'chat'
        credentials['openai_api_key'] = credentials['api_key']
        if not credentials.get('endpoint_url'):
            credentials['endpoint_url'] = 'https://api.lingyiwanwu.com/v1'
            credentials['openai_api_base'] = 'https://api.lingyiwanwu.com'
        else:
            # Keep only the scheme and host of the configured endpoint URL.
            parsed_url = urlparse(credentials['endpoint_url'])
            credentials['openai_api_base'] = f"{parsed_url.scheme}://{parsed_url.netloc}"