from collections.abc import Generator
from typing import Optional, Union
from urllib.parse import urlparse

import tiktoken

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
)
from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
class YiLargeLanguageModel(OAIAPICompatLargeLanguageModel):
    """Model runtime for 01.AI (Yi / lingyiwanwu) chat models.

    Delegates all request handling to the OpenAI-API-compatible base class
    after injecting Yi-specific credential defaults (chat mode and the
    official endpoint URL). Token counting is done locally with tiktoken's
    ``cl100k_base`` encoding, since the compatible base class has no
    model-specific tokenizer for Yi models.
    """

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """Invoke the Yi large language model.

        :param model: model name
        :param credentials: provider credentials (mutated in place with defaults)
        :param prompt_messages: conversation messages to send
        :param model_parameters: sampling parameters (temperature, etc.)
        :param tools: tools for tool calling
        :param stop: stop sequences
        :param stream: whether to stream the response
        :param user: unique end-user identifier (not forwarded by the base call)
        :return: full LLMResult, or a Generator of chunks when streaming
        """
        # Fill in Yi defaults (mode / endpoint_url) before delegating.
        self._add_custom_parameters(credentials)
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """Validate the given credentials against the Yi endpoint.

        :param model: model name
        :param credentials: provider credentials (mutated in place with defaults)
        :raises CredentialsValidateFailedError: propagated from the base class
            when validation fails
        """
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    # refactored from openai model runtime, use cl100k_base for calculate token number
    def _num_tokens_from_string(self, model: str, text: str,
                                tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Calculate num tokens for text completion model with tiktoken package.

        :param model: model name (unused; cl100k_base is always assumed)
        :param text: prompt text
        :param tools: tools for tool calling
        :return: number of tokens
        """
        encoding = tiktoken.get_encoding("cl100k_base")
        num_tokens = len(encoding.encode(text))

        if tools:
            num_tokens += self._num_tokens_for_tools(encoding, tools)

        return num_tokens

    # refactored from openai model runtime, use cl100k_base for calculate token number
    def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
                                  tools: Optional[list[PromptMessageTool]] = None) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
        encoding = tiktoken.get_encoding("cl100k_base")
        # Per-message and per-name overheads from the OpenAI cookbook
        # accounting for gpt-3.5-turbo / gpt-4 style chat formatting.
        tokens_per_message = 3
        tokens_per_name = 1

        num_tokens = 0
        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
        for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
                # Cast str(value) in case the message value is not a string
                # This occurs with function messages
                # TODO: The current token calculation method for the image type is not implemented,
                #  which need to download the image and then get the resolution for calculation,
                #  and will increase the request delay
                if isinstance(value, list):
                    # Multimodal content: only the text parts are counted.
                    text = ''
                    for item in value:
                        if isinstance(item, dict) and item['type'] == 'text':
                            text += item['text']
                    value = text

                if key == "tool_calls":
                    # Tool calls are dicts of dicts: count every key and value.
                    for tool_call in value:
                        for t_key, t_value in tool_call.items():
                            num_tokens += len(encoding.encode(t_key))
                            if t_key == "function":
                                for f_key, f_value in t_value.items():
                                    num_tokens += len(encoding.encode(f_key))
                                    num_tokens += len(encoding.encode(f_value))
                            else:
                                num_tokens += len(encoding.encode(t_key))
                                num_tokens += len(encoding.encode(t_value))
                else:
                    num_tokens += len(encoding.encode(str(value)))

                if key == "name":
                    num_tokens += tokens_per_name

        # every reply is primed with <im_start>assistant
        num_tokens += 3

        if tools:
            num_tokens += self._num_tokens_for_tools(encoding, tools)

        return num_tokens

    @staticmethod
    def _add_custom_parameters(credentials: dict) -> None:
        """Inject Yi-specific defaults into *credentials* (mutated in place).

        Forces chat-completion mode and falls back to the official
        lingyiwanwu endpoint when no custom endpoint_url is configured.
        """
        credentials['mode'] = 'chat'
        if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
            credentials['endpoint_url'] = 'https://api.lingyiwanwu.com/v1'