| @@ -1,5 +1,6 @@ | |||
| import logging | |||
| import os | |||
| import re | |||
| import time | |||
| from abc import abstractmethod | |||
| from typing import Generator, List, Optional, Union | |||
| @@ -212,6 +213,10 @@ class LargeLanguageModel(AIModel): | |||
| """ | |||
| raise NotImplementedError | |||
| def enforce_stop_tokens(self, text: str, stop: List[str]) -> str: | |||
| """Cut off the text as soon as any stop words occur.""" | |||
| return re.split("|".join(re.escape(s) for s in stop), text, maxsplit=1)[0] | |||
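For illustration, a minimal self-contained sketch of the stop-token cut-off above (the input text and stop list are made up; stop sequences are escaped here so regex metacharacters are treated literally):

import re

def enforce_stop_tokens(text: str, stop: list[str]) -> str:
    # Keep only the text before the first occurrence of any stop sequence.
    return re.split("|".join(re.escape(s) for s in stop), text, maxsplit=1)[0]

print(enforce_stop_tokens("Hello World! How are you?", ["How", "Bye"]))  # "Hello World! "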
| def _llm_result_to_stream(self, result: LLMResult) -> Generator: | |||
| """ | |||
| Transform llm result to stream | |||
| @@ -14,9 +14,12 @@ help: | |||
| url: | |||
| en_US: https://dashboard.cohere.com/api-keys | |||
| supported_model_types: | |||
| - llm | |||
| - text-embedding | |||
| - rerank | |||
| configurate_methods: | |||
| - predefined-model | |||
| - customizable-model | |||
| provider_credential_schema: | |||
| credential_form_schemas: | |||
| - variable: api_key | |||
| @@ -26,6 +29,44 @@ provider_credential_schema: | |||
| type: secret-input | |||
| required: true | |||
| placeholder: | |||
| zh_Hans: 在此输入您的 API Key | |||
| en_US: Enter your API Key | |||
| show_on: [ ] | |||
| model_credential_schema: | |||
| model: | |||
| label: | |||
| en_US: Model Name | |||
| zh_Hans: 模型名称 | |||
| placeholder: | |||
| en_US: Enter your model name | |||
| zh_Hans: 输入模型名称 | |||
| credential_form_schemas: | |||
| - variable: mode | |||
| show_on: | |||
| - variable: __model_type | |||
| value: llm | |||
| label: | |||
| en_US: Completion mode | |||
| type: select | |||
| required: false | |||
| default: chat | |||
| placeholder: | |||
| zh_Hans: 选择对话类型 | |||
| en_US: Select completion mode | |||
| options: | |||
| - value: completion | |||
| label: | |||
| en_US: Completion | |||
| zh_Hans: 补全 | |||
| - value: chat | |||
| label: | |||
| en_US: Chat | |||
| zh_Hans: 对话 | |||
| - variable: api_key | |||
| label: | |||
| en_US: API Key | |||
| type: secret-input | |||
| required: true | |||
| placeholder: | |||
| zh_Hans: 在此输入您的 API Key | |||
| en_US: Enter your API Key | |||
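For reference, the credentials collected through the schemas above reach the model classes as a plain dict; a hedged sketch with made-up values (the key names mirror the form variables above, and the same shape appears in the integration tests later in this diff):

# Hypothetical values; "mode" is only present for customizable (fine-tuned) LLM models.
credentials = {
    "api_key": "co-xxxxxxxxxxxxxxxx",
    "mode": "chat",
}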
| @@ -0,0 +1,8 @@ | |||
| - command-chat | |||
| - command-light-chat | |||
| - command-nightly-chat | |||
| - command-light-nightly-chat | |||
| - command | |||
| - command-light | |||
| - command-nightly | |||
| - command-light-nightly | |||
| @@ -0,0 +1,62 @@ | |||
| model: command-chat | |||
| label: | |||
| zh_Hans: command-chat | |||
| en_US: command-chat | |||
| model_type: llm | |||
| features: | |||
| - agent-thought | |||
| model_properties: | |||
| mode: chat | |||
| context_size: 4096 | |||
| parameter_rules: | |||
| - name: temperature | |||
| use_template: temperature | |||
| max: 5.0 | |||
| - name: p | |||
| use_template: top_p | |||
| default: 0.75 | |||
| min: 0.01 | |||
| max: 0.99 | |||
| - name: k | |||
| label: | |||
| zh_Hans: 取样数量 | |||
| en_US: Top k | |||
| type: int | |||
| help: | |||
| zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 | |||
| en_US: Only sample from the top K options for each subsequent token. | |||
| required: false | |||
| default: 0 | |||
| min: 0 | |||
| max: 500 | |||
| - name: max_tokens | |||
| use_template: max_tokens | |||
| default: 256 | |||
| max: 4096 | |||
| - name: preamble_override | |||
| label: | |||
| zh_Hans: 前导文本 | |||
| en_US: Preamble | |||
| type: string | |||
| help: | |||
| zh_Hans: 当指定时,将使用提供的前导文本替换默认的 Cohere 前导文本。 | |||
| en_US: When specified, the default Cohere preamble will be replaced with the provided one. | |||
| required: false | |||
| - name: prompt_truncation | |||
| label: | |||
| zh_Hans: 提示截断 | |||
| en_US: Prompt Truncation | |||
| type: string | |||
| help: | |||
| zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时,将会丢弃一些来自聊天记录的元素,以尝试构造一个符合模型上下文长度限制的 Prompt。 | |||
| en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit. | |||
| required: true | |||
| default: 'AUTO' | |||
| options: | |||
| - 'AUTO' | |||
| - 'OFF' | |||
| pricing: | |||
| input: '1.0' | |||
| output: '2.0' | |||
| unit: '0.000001' | |||
| currency: USD | |||
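As a quick sanity check of how the pricing block above is read (assuming the price is USD per `unit` tokens; the token counts are made up):

# Assumed interpretation: cost = tokens * unit * price.
input_price, output_price, unit = 1.0, 2.0, 0.000001    # from the command-chat pricing above
prompt_tokens, completion_tokens = 1_000, 500            # hypothetical usage
total = prompt_tokens * unit * input_price + completion_tokens * unit * output_price
print(f"${total:.4f} USD")  # $0.0020 USD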
| @@ -0,0 +1,62 @@ | |||
| model: command-light-chat | |||
| label: | |||
| zh_Hans: command-light-chat | |||
| en_US: command-light-chat | |||
| model_type: llm | |||
| features: | |||
| - agent-thought | |||
| model_properties: | |||
| mode: chat | |||
| context_size: 4096 | |||
| parameter_rules: | |||
| - name: temperature | |||
| use_template: temperature | |||
| max: 5.0 | |||
| - name: p | |||
| use_template: top_p | |||
| default: 0.75 | |||
| min: 0.01 | |||
| max: 0.99 | |||
| - name: k | |||
| label: | |||
| zh_Hans: 取样数量 | |||
| en_US: Top k | |||
| type: int | |||
| help: | |||
| zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 | |||
| en_US: Only sample from the top K options for each subsequent token. | |||
| required: false | |||
| default: 0 | |||
| min: 0 | |||
| max: 500 | |||
| - name: max_tokens | |||
| use_template: max_tokens | |||
| default: 256 | |||
| max: 4096 | |||
| - name: preamble_override | |||
| label: | |||
| zh_Hans: 前导文本 | |||
| en_US: Preamble | |||
| type: string | |||
| help: | |||
| zh_Hans: 当指定时,将使用提供的前导文本替换默认的 Cohere 前导文本。 | |||
| en_US: When specified, the default Cohere preamble will be replaced with the provided one. | |||
| required: false | |||
| - name: prompt_truncation | |||
| label: | |||
| zh_Hans: 提示截断 | |||
| en_US: Prompt Truncation | |||
| type: string | |||
| help: | |||
| zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时,将会丢弃一些来自聊天记录的元素,以尝试构造一个符合模型上下文长度限制的 Prompt。 | |||
| en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit. | |||
| required: true | |||
| default: 'AUTO' | |||
| options: | |||
| - 'AUTO' | |||
| - 'OFF' | |||
| pricing: | |||
| input: '0.3' | |||
| output: '0.6' | |||
| unit: '0.000001' | |||
| currency: USD | |||
| @@ -0,0 +1,62 @@ | |||
| model: command-light-nightly-chat | |||
| label: | |||
| zh_Hans: command-light-nightly-chat | |||
| en_US: command-light-nightly-chat | |||
| model_type: llm | |||
| features: | |||
| - agent-thought | |||
| model_properties: | |||
| mode: chat | |||
| context_size: 4096 | |||
| parameter_rules: | |||
| - name: temperature | |||
| use_template: temperature | |||
| max: 5.0 | |||
| - name: p | |||
| use_template: top_p | |||
| default: 0.75 | |||
| min: 0.01 | |||
| max: 0.99 | |||
| - name: k | |||
| label: | |||
| zh_Hans: 取样数量 | |||
| en_US: Top k | |||
| type: int | |||
| help: | |||
| zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 | |||
| en_US: Only sample from the top K options for each subsequent token. | |||
| required: false | |||
| default: 0 | |||
| min: 0 | |||
| max: 500 | |||
| - name: max_tokens | |||
| use_template: max_tokens | |||
| default: 256 | |||
| max: 4096 | |||
| - name: preamble_override | |||
| label: | |||
| zh_Hans: 前导文本 | |||
| en_US: Preamble | |||
| type: string | |||
| help: | |||
| zh_Hans: 当指定时,将使用提供的前导文本替换默认的 Cohere 前导文本。 | |||
| en_US: When specified, the default Cohere preamble will be replaced with the provided one. | |||
| required: false | |||
| - name: prompt_truncation | |||
| label: | |||
| zh_Hans: 提示截断 | |||
| en_US: Prompt Truncation | |||
| type: string | |||
| help: | |||
| zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时,将会丢弃一些来自聊天记录的元素,以尝试构造一个符合模型上下文长度限制的 Prompt。 | |||
| en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit. | |||
| required: true | |||
| default: 'AUTO' | |||
| options: | |||
| - 'AUTO' | |||
| - 'OFF' | |||
| pricing: | |||
| input: '0.3' | |||
| output: '0.6' | |||
| unit: '0.000001' | |||
| currency: USD | |||
| @@ -0,0 +1,44 @@ | |||
| model: command-light-nightly | |||
| label: | |||
| zh_Hans: command-light-nightly | |||
| en_US: command-light-nightly | |||
| model_type: llm | |||
| features: | |||
| - agent-thought | |||
| model_properties: | |||
| mode: completion | |||
| context_size: 4096 | |||
| parameter_rules: | |||
| - name: temperature | |||
| use_template: temperature | |||
| max: 5.0 | |||
| - name: p | |||
| use_template: top_p | |||
| default: 0.75 | |||
| min: 0.01 | |||
| max: 0.99 | |||
| - name: k | |||
| label: | |||
| zh_Hans: 取样数量 | |||
| en_US: Top k | |||
| type: int | |||
| help: | |||
| zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 | |||
| en_US: Only sample from the top K options for each subsequent token. | |||
| required: false | |||
| default: 0 | |||
| min: 0 | |||
| max: 500 | |||
| - name: presence_penalty | |||
| use_template: presence_penalty | |||
| - name: frequency_penalty | |||
| use_template: frequency_penalty | |||
| - name: max_tokens | |||
| use_template: max_tokens | |||
| default: 256 | |||
| max: 4096 | |||
| pricing: | |||
| input: '0.3' | |||
| output: '0.6' | |||
| unit: '0.000001' | |||
| currency: USD | |||
| @@ -0,0 +1,44 @@ | |||
| model: command-light | |||
| label: | |||
| zh_Hans: command-light | |||
| en_US: command-light | |||
| model_type: llm | |||
| features: | |||
| - agent-thought | |||
| model_properties: | |||
| mode: completion | |||
| context_size: 4096 | |||
| parameter_rules: | |||
| - name: temperature | |||
| use_template: temperature | |||
| max: 5.0 | |||
| - name: p | |||
| use_template: top_p | |||
| default: 0.75 | |||
| min: 0.01 | |||
| max: 0.99 | |||
| - name: k | |||
| label: | |||
| zh_Hans: 取样数量 | |||
| en_US: Top k | |||
| type: int | |||
| help: | |||
| zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 | |||
| en_US: Only sample from the top K options for each subsequent token. | |||
| required: false | |||
| default: 0 | |||
| min: 0 | |||
| max: 500 | |||
| - name: presence_penalty | |||
| use_template: presence_penalty | |||
| - name: frequency_penalty | |||
| use_template: frequency_penalty | |||
| - name: max_tokens | |||
| use_template: max_tokens | |||
| default: 256 | |||
| max: 4096 | |||
| pricing: | |||
| input: '0.3' | |||
| output: '0.6' | |||
| unit: '0.000001' | |||
| currency: USD | |||
| @@ -0,0 +1,62 @@ | |||
| model: command-nightly-chat | |||
| label: | |||
| zh_Hans: command-nightly-chat | |||
| en_US: command-nightly-chat | |||
| model_type: llm | |||
| features: | |||
| - agent-thought | |||
| model_properties: | |||
| mode: chat | |||
| context_size: 4096 | |||
| parameter_rules: | |||
| - name: temperature | |||
| use_template: temperature | |||
| max: 5.0 | |||
| - name: p | |||
| use_template: top_p | |||
| default: 0.75 | |||
| min: 0.01 | |||
| max: 0.99 | |||
| - name: k | |||
| label: | |||
| zh_Hans: 取样数量 | |||
| en_US: Top k | |||
| type: int | |||
| help: | |||
| zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 | |||
| en_US: Only sample from the top K options for each subsequent token. | |||
| required: false | |||
| default: 0 | |||
| min: 0 | |||
| max: 500 | |||
| - name: max_tokens | |||
| use_template: max_tokens | |||
| default: 256 | |||
| max: 4096 | |||
| - name: preamble_override | |||
| label: | |||
| zh_Hans: 前导文本 | |||
| en_US: Preamble | |||
| type: string | |||
| help: | |||
| zh_Hans: 当指定时,将使用提供的前导文本替换默认的 Cohere 前导文本。 | |||
| en_US: When specified, the default Cohere preamble will be replaced with the provided one. | |||
| required: false | |||
| - name: prompt_truncation | |||
| label: | |||
| zh_Hans: 提示截断 | |||
| en_US: Prompt Truncation | |||
| type: string | |||
| help: | |||
| zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时,将会丢弃一些来自聊天记录的元素,以尝试构造一个符合模型上下文长度限制的 Prompt。 | |||
| en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit. | |||
| required: true | |||
| default: 'AUTO' | |||
| options: | |||
| - 'AUTO' | |||
| - 'OFF' | |||
| pricing: | |||
| input: '1.0' | |||
| output: '2.0' | |||
| unit: '0.000001' | |||
| currency: USD | |||
| @@ -0,0 +1,44 @@ | |||
| model: command-nightly | |||
| label: | |||
| zh_Hans: command-nightly | |||
| en_US: command-nightly | |||
| model_type: llm | |||
| features: | |||
| - agent-thought | |||
| model_properties: | |||
| mode: completion | |||
| context_size: 4096 | |||
| parameter_rules: | |||
| - name: temperature | |||
| use_template: temperature | |||
| max: 5.0 | |||
| - name: p | |||
| use_template: top_p | |||
| default: 0.75 | |||
| min: 0.01 | |||
| max: 0.99 | |||
| - name: k | |||
| label: | |||
| zh_Hans: 取样数量 | |||
| en_US: Top k | |||
| type: int | |||
| help: | |||
| zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 | |||
| en_US: Only sample from the top K options for each subsequent token. | |||
| required: false | |||
| default: 0 | |||
| min: 0 | |||
| max: 500 | |||
| - name: presence_penalty | |||
| use_template: presence_penalty | |||
| - name: frequency_penalty | |||
| use_template: frequency_penalty | |||
| - name: max_tokens | |||
| use_template: max_tokens | |||
| default: 256 | |||
| max: 4096 | |||
| pricing: | |||
| input: '1.0' | |||
| output: '2.0' | |||
| unit: '0.000001' | |||
| currency: USD | |||
| @@ -0,0 +1,44 @@ | |||
| model: command | |||
| label: | |||
| zh_Hans: command | |||
| en_US: command | |||
| model_type: llm | |||
| features: | |||
| - agent-thought | |||
| model_properties: | |||
| mode: completion | |||
| context_size: 4096 | |||
| parameter_rules: | |||
| - name: temperature | |||
| use_template: temperature | |||
| max: 5.0 | |||
| - name: p | |||
| use_template: top_p | |||
| default: 0.75 | |||
| min: 0.01 | |||
| max: 0.99 | |||
| - name: k | |||
| label: | |||
| zh_Hans: 取样数量 | |||
| en_US: Top k | |||
| type: int | |||
| help: | |||
| zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 | |||
| en_US: Only sample from the top K options for each subsequent token. | |||
| required: false | |||
| default: 0 | |||
| min: 0 | |||
| max: 500 | |||
| - name: presence_penalty | |||
| use_template: presence_penalty | |||
| - name: frequency_penalty | |||
| use_template: frequency_penalty | |||
| - name: max_tokens | |||
| use_template: max_tokens | |||
| default: 256 | |||
| max: 4096 | |||
| pricing: | |||
| input: '1.0' | |||
| output: '2.0' | |||
| unit: '0.000001' | |||
| currency: USD | |||
| @@ -0,0 +1,565 @@ | |||
| import logging | |||
| from typing import Generator, List, Optional, Union, cast, Tuple | |||
| import cohere | |||
| from cohere.responses import Chat, Generations | |||
| from cohere.responses.chat import StreamingChat, StreamTextGeneration, StreamEnd | |||
| from cohere.responses.generation import StreamingText, StreamingGenerations | |||
| from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta | |||
| from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, | |||
| PromptMessageContentType, SystemPromptMessage, | |||
| TextPromptMessageContent, UserPromptMessage, | |||
| PromptMessageTool) | |||
| from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType | |||
| from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeError, \ | |||
| InvokeRateLimitError, InvokeAuthorizationError, InvokeBadRequestError | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| logger = logging.getLogger(__name__) | |||
| class CohereLargeLanguageModel(LargeLanguageModel): | |||
| """ | |||
| Model class for Cohere large language model. | |||
| """ | |||
| def _invoke(self, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None) \ | |||
| -> Union[LLMResult, Generator]: | |||
| """ | |||
| Invoke large language model | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :param prompt_messages: prompt messages | |||
| :param model_parameters: model parameters | |||
| :param tools: tools for tool calling | |||
| :param stop: stop words | |||
| :param stream: is stream response | |||
| :param user: unique user id | |||
| :return: full response or stream response chunk generator result | |||
| """ | |||
| # get model mode | |||
| model_mode = self.get_model_mode(model, credentials) | |||
| if model_mode == LLMMode.CHAT: | |||
| return self._chat_generate( | |||
| model=model, | |||
| credentials=credentials, | |||
| prompt_messages=prompt_messages, | |||
| model_parameters=model_parameters, | |||
| stop=stop, | |||
| stream=stream, | |||
| user=user | |||
| ) | |||
| else: | |||
| return self._generate( | |||
| model=model, | |||
| credentials=credentials, | |||
| prompt_messages=prompt_messages, | |||
| model_parameters=model_parameters, | |||
| stop=stop, | |||
| stream=stream, | |||
| user=user | |||
| ) | |||
| def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], | |||
| tools: Optional[list[PromptMessageTool]] = None) -> int: | |||
| """ | |||
| Get number of tokens for given prompt messages | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :param prompt_messages: prompt messages | |||
| :param tools: tools for tool calling | |||
| :return: | |||
| """ | |||
| # get model mode | |||
| model_mode = self.get_model_mode(model) | |||
| try: | |||
| if model_mode == LLMMode.CHAT: | |||
| return self._num_tokens_from_messages(model, credentials, prompt_messages) | |||
| else: | |||
| return self._num_tokens_from_string(model, credentials, prompt_messages[0].content) | |||
| except Exception as e: | |||
| raise self._transform_invoke_error(e) | |||
| def validate_credentials(self, model: str, credentials: dict) -> None: | |||
| """ | |||
| Validate model credentials | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :return: | |||
| """ | |||
| try: | |||
| # get model mode | |||
| model_mode = self.get_model_mode(model) | |||
| if model_mode == LLMMode.CHAT: | |||
| self._chat_generate( | |||
| model=model, | |||
| credentials=credentials, | |||
| prompt_messages=[UserPromptMessage(content='ping')], | |||
| model_parameters={ | |||
| 'max_tokens': 20, | |||
| 'temperature': 0, | |||
| }, | |||
| stream=False | |||
| ) | |||
| else: | |||
| self._generate( | |||
| model=model, | |||
| credentials=credentials, | |||
| prompt_messages=[UserPromptMessage(content='ping')], | |||
| model_parameters={ | |||
| 'max_tokens': 20, | |||
| 'temperature': 0, | |||
| }, | |||
| stream=False | |||
| ) | |||
| except Exception as ex: | |||
| raise CredentialsValidateFailedError(str(ex)) | |||
| def _generate(self, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: | |||
| """ | |||
| Invoke llm model | |||
| :param model: model name | |||
| :param credentials: credentials | |||
| :param prompt_messages: prompt messages | |||
| :param model_parameters: model parameters | |||
| :param stop: stop words | |||
| :param stream: is stream response | |||
| :param user: unique user id | |||
| :return: full response or stream response chunk generator result | |||
| """ | |||
| # initialize client | |||
| client = cohere.Client(credentials.get('api_key')) | |||
| if stop: | |||
| model_parameters['end_sequences'] = stop | |||
| response = client.generate( | |||
| prompt=prompt_messages[0].content, | |||
| model=model, | |||
| stream=stream, | |||
| **model_parameters, | |||
| ) | |||
| if stream: | |||
| return self._handle_generate_stream_response(model, credentials, response, prompt_messages) | |||
| return self._handle_generate_response(model, credentials, response, prompt_messages) | |||
| def _handle_generate_response(self, model: str, credentials: dict, response: Generations, | |||
| prompt_messages: list[PromptMessage]) \ | |||
| -> LLMResult: | |||
| """ | |||
| Handle llm response | |||
| :param model: model name | |||
| :param credentials: credentials | |||
| :param response: response | |||
| :param prompt_messages: prompt messages | |||
| :return: llm response | |||
| """ | |||
| assistant_text = response.generations[0].text | |||
| # transform assistant message to prompt message | |||
| assistant_prompt_message = AssistantPromptMessage( | |||
| content=assistant_text | |||
| ) | |||
| # calculate num tokens | |||
| prompt_tokens = response.meta['billed_units']['input_tokens'] | |||
| completion_tokens = response.meta['billed_units']['output_tokens'] | |||
| # transform usage | |||
| usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) | |||
| # transform response | |||
| response = LLMResult( | |||
| model=model, | |||
| prompt_messages=prompt_messages, | |||
| message=assistant_prompt_message, | |||
| usage=usage | |||
| ) | |||
| return response | |||
| def _handle_generate_stream_response(self, model: str, credentials: dict, response: StreamingGenerations, | |||
| prompt_messages: list[PromptMessage]) -> Generator: | |||
| """ | |||
| Handle llm stream response | |||
| :param model: model name | |||
| :param response: response | |||
| :param prompt_messages: prompt messages | |||
| :return: llm response chunk generator | |||
| """ | |||
| index = 1 | |||
| full_assistant_content = '' | |||
| for chunk in response: | |||
| if isinstance(chunk, StreamingText): | |||
| chunk = cast(StreamingText, chunk) | |||
| text = chunk.text | |||
| if text is None: | |||
| continue | |||
| # transform assistant message to prompt message | |||
| assistant_prompt_message = AssistantPromptMessage( | |||
| content=text | |||
| ) | |||
| full_assistant_content += text | |||
| yield LLMResultChunk( | |||
| model=model, | |||
| prompt_messages=prompt_messages, | |||
| delta=LLMResultChunkDelta( | |||
| index=index, | |||
| message=assistant_prompt_message, | |||
| ) | |||
| ) | |||
| index += 1 | |||
| elif chunk is None: | |||
| # calculate num tokens | |||
| prompt_tokens = response.meta['billed_units']['input_tokens'] | |||
| completion_tokens = response.meta['billed_units']['output_tokens'] | |||
| # transform usage | |||
| usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) | |||
| yield LLMResultChunk( | |||
| model=model, | |||
| prompt_messages=prompt_messages, | |||
| delta=LLMResultChunkDelta( | |||
| index=index, | |||
| message=AssistantPromptMessage(content=''), | |||
| finish_reason=response.finish_reason, | |||
| usage=usage | |||
| ) | |||
| ) | |||
| break | |||
| def _chat_generate(self, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: | |||
| """ | |||
| Invoke llm chat model | |||
| :param model: model name | |||
| :param credentials: credentials | |||
| :param prompt_messages: prompt messages | |||
| :param model_parameters: model parameters | |||
| :param stop: stop words | |||
| :param stream: is stream response | |||
| :param user: unique user id | |||
| :return: full response or stream response chunk generator result | |||
| """ | |||
| # initialize client | |||
| client = cohere.Client(credentials.get('api_key')) | |||
| if user: | |||
| model_parameters['user_name'] = user | |||
| message, chat_histories = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages) | |||
| # chat model | |||
| real_model = model | |||
| if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL: | |||
| real_model = model.removesuffix('-chat') | |||
| response = client.chat( | |||
| message=message, | |||
| chat_history=chat_histories, | |||
| model=real_model, | |||
| stream=stream, | |||
| return_preamble=True, | |||
| **model_parameters, | |||
| ) | |||
| if stream: | |||
| return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, stop) | |||
| return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop) | |||
| def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat, | |||
| prompt_messages: list[PromptMessage], stop: Optional[List[str]] = None) \ | |||
| -> LLMResult: | |||
| """ | |||
| Handle llm chat response | |||
| :param model: model name | |||
| :param credentials: credentials | |||
| :param response: response | |||
| :param prompt_messages: prompt messages | |||
| :param stop: stop words | |||
| :return: llm response | |||
| """ | |||
| assistant_text = response.text | |||
| # transform assistant message to prompt message | |||
| assistant_prompt_message = AssistantPromptMessage( | |||
| content=assistant_text | |||
| ) | |||
| # calculate num tokens | |||
| prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) | |||
| completion_tokens = self._num_tokens_from_messages(model, credentials, [assistant_prompt_message]) | |||
| # transform usage | |||
| usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) | |||
| if stop: | |||
| # enforce stop tokens | |||
| assistant_text = self.enforce_stop_tokens(assistant_text, stop) | |||
| assistant_prompt_message = AssistantPromptMessage( | |||
| content=assistant_text | |||
| ) | |||
| # transform response | |||
| response = LLMResult( | |||
| model=model, | |||
| prompt_messages=prompt_messages, | |||
| message=assistant_prompt_message, | |||
| usage=usage, | |||
| system_fingerprint=response.preamble | |||
| ) | |||
| return response | |||
| def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat, | |||
| prompt_messages: list[PromptMessage], | |||
| stop: Optional[List[str]] = None) -> Generator: | |||
| """ | |||
| Handle llm chat stream response | |||
| :param model: model name | |||
| :param response: response | |||
| :param prompt_messages: prompt messages | |||
| :param stop: stop words | |||
| :return: llm response chunk generator | |||
| """ | |||
| def final_response(full_text: str, index: int, finish_reason: Optional[str] = None, | |||
| preamble: Optional[str] = None) -> LLMResultChunk: | |||
| # calculate num tokens | |||
| prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) | |||
| full_assistant_prompt_message = AssistantPromptMessage( | |||
| content=full_text | |||
| ) | |||
| completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message]) | |||
| # transform usage | |||
| usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) | |||
| return LLMResultChunk( | |||
| model=model, | |||
| prompt_messages=prompt_messages, | |||
| system_fingerprint=preamble, | |||
| delta=LLMResultChunkDelta( | |||
| index=index, | |||
| message=AssistantPromptMessage(content=''), | |||
| finish_reason=finish_reason, | |||
| usage=usage | |||
| ) | |||
| ) | |||
| index = 1 | |||
| full_assistant_content = '' | |||
| for chunk in response: | |||
| if isinstance(chunk, StreamTextGeneration): | |||
| chunk = cast(StreamTextGeneration, chunk) | |||
| text = chunk.text | |||
| if text is None: | |||
| continue | |||
| # transform assistant message to prompt message | |||
| assistant_prompt_message = AssistantPromptMessage( | |||
| content=text | |||
| ) | |||
| # stop | |||
| # notice: This logic can only cover a few stop scenarios | |||
| if stop and text in stop: | |||
| yield final_response(full_assistant_content, index, 'stop') | |||
| break | |||
| full_assistant_content += text | |||
| yield LLMResultChunk( | |||
| model=model, | |||
| prompt_messages=prompt_messages, | |||
| delta=LLMResultChunkDelta( | |||
| index=index, | |||
| message=assistant_prompt_message, | |||
| ) | |||
| ) | |||
| index += 1 | |||
| elif isinstance(chunk, StreamEnd): | |||
| chunk = cast(StreamEnd, chunk) | |||
| yield final_response(full_assistant_content, index, chunk.finish_reason, response.preamble) | |||
| index += 1 | |||
| def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \ | |||
| -> Tuple[str, list[dict]]: | |||
| """ | |||
| Convert prompt messages to message and chat histories | |||
| :param prompt_messages: prompt messages | |||
| :return: | |||
| """ | |||
| chat_histories = [] | |||
| for prompt_message in prompt_messages: | |||
| chat_histories.append(self._convert_prompt_message_to_dict(prompt_message)) | |||
| # get latest message from chat histories and pop it | |||
| if len(chat_histories) > 0: | |||
| latest_message = chat_histories.pop() | |||
| message = latest_message['message'] | |||
| else: | |||
| raise ValueError('Prompt messages are empty') | |||
| return message, chat_histories | |||
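To make the conversion concrete, a hedged sketch of what this method is expected to return for a short history (message texts are made up; the USER/CHATBOT role names come from _convert_prompt_message_to_dict below):

# Hypothetical converted history: the last entry becomes the `message`,
# everything before it is passed to Cohere as chat_history.
converted = [
    {"role": "USER", "message": "You are a helpful AI assistant."},  # SystemPromptMessage maps to USER
    {"role": "CHATBOT", "message": "Understood."},
    {"role": "USER", "message": "Hello World!"},
]
message = converted[-1]["message"]   # "Hello World!"
chat_histories = converted[:-1]      # the first two entries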
| def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: | |||
| """ | |||
| Convert PromptMessage to dict for Cohere model | |||
| """ | |||
| if isinstance(message, UserPromptMessage): | |||
| message = cast(UserPromptMessage, message) | |||
| if isinstance(message.content, str): | |||
| message_dict = {"role": "USER", "message": message.content} | |||
| else: | |||
| sub_message_text = '' | |||
| for message_content in message.content: | |||
| if message_content.type == PromptMessageContentType.TEXT: | |||
| message_content = cast(TextPromptMessageContent, message_content) | |||
| sub_message_text += message_content.data | |||
| message_dict = {"role": "USER", "message": sub_message_text} | |||
| elif isinstance(message, AssistantPromptMessage): | |||
| message = cast(AssistantPromptMessage, message) | |||
| message_dict = {"role": "CHATBOT", "message": message.content} | |||
| elif isinstance(message, SystemPromptMessage): | |||
| message = cast(SystemPromptMessage, message) | |||
| message_dict = {"role": "USER", "message": message.content} | |||
| else: | |||
| raise ValueError(f"Got unknown type {message}") | |||
| if message.name is not None: | |||
| message_dict["user_name"] = message.name | |||
| return message_dict | |||
| def _num_tokens_from_string(self, model: str, credentials: dict, text: str) -> int: | |||
| """ | |||
| Calculate num tokens for text completion model. | |||
| :param model: model name | |||
| :param credentials: credentials | |||
| :param text: prompt text | |||
| :return: number of tokens | |||
| """ | |||
| # initialize client | |||
| client = cohere.Client(credentials.get('api_key')) | |||
| response = client.tokenize( | |||
| text=text, | |||
| model=model | |||
| ) | |||
| return response.length | |||
| def _num_tokens_from_messages(self, model: str, credentials: dict, messages: List[PromptMessage]) -> int: | |||
| """Calculate num tokens Cohere model.""" | |||
| messages = [self._convert_prompt_message_to_dict(m) for m in messages] | |||
| message_strs = [f"{message['role']}: {message['message']}" for message in messages] | |||
| message_str = "\n".join(message_strs) | |||
| real_model = model | |||
| if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL: | |||
| real_model = model.removesuffix('-chat') | |||
| return self._num_tokens_from_string(real_model, credentials, message_str) | |||
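A small sketch of the token-counting path above (values are illustrative): messages are flattened into "ROLE: text" lines, and for predefined models the Dify-side "-chat" alias is stripped to obtain the real Cohere model name.

# Illustrative only: flatten messages for tokenization and derive the real model name.
messages = [
    {"role": "USER", "message": "Hello World!"},
    {"role": "CHATBOT", "message": "Hi there."},
]
message_str = "\n".join(f"{m['role']}: {m['message']}" for m in messages)
real_model = "command-light-chat".removesuffix("-chat")  # -> "command-light"
print(real_model, message_str, sep="\n")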
| def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: | |||
| """ | |||
| Cohere supports fine-tuning of their models. This method returns the schema of the base model | |||
| but renamed to the fine-tuned model name. | |||
| :param model: model name | |||
| :param credentials: credentials | |||
| :return: model schema | |||
| """ | |||
| # get model schema | |||
| models = self.predefined_models() | |||
| model_map = {m.model: m for m in models} | |||
| mode = credentials.get('mode') | |||
| if mode == 'chat': | |||
| base_model_schema = model_map['command-light-chat'] | |||
| else: | |||
| base_model_schema = model_map['command-light'] | |||
| base_model_schema = cast(AIModelEntity, base_model_schema) | |||
| base_model_schema_features = base_model_schema.features or [] | |||
| base_model_schema_model_properties = base_model_schema.model_properties or {} | |||
| base_model_schema_parameters_rules = base_model_schema.parameter_rules or [] | |||
| entity = AIModelEntity( | |||
| model=model, | |||
| label=I18nObject( | |||
| zh_Hans=model, | |||
| en_US=model | |||
| ), | |||
| model_type=ModelType.LLM, | |||
| features=[feature for feature in base_model_schema_features], | |||
| fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | |||
| model_properties={ | |||
| key: property for key, property in base_model_schema_model_properties.items() | |||
| }, | |||
| parameter_rules=[rule for rule in base_model_schema_parameters_rules], | |||
| pricing=base_model_schema.pricing | |||
| ) | |||
| return entity | |||
| @property | |||
| def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: | |||
| """ | |||
| Map model invoke error to unified error | |||
| The key is the error type thrown to the caller | |||
| The value is the error type thrown by the model, | |||
| which needs to be converted into a unified error type for the caller. | |||
| :return: Invoke error mapping | |||
| """ | |||
| return { | |||
| InvokeConnectionError: [ | |||
| cohere.CohereConnectionError | |||
| ], | |||
| InvokeServerUnavailableError: [], | |||
| InvokeRateLimitError: [], | |||
| InvokeAuthorizationError: [], | |||
| InvokeBadRequestError: [ | |||
| cohere.CohereAPIError, | |||
| cohere.CohereError, | |||
| ] | |||
| } | |||
| @@ -0,0 +1,7 @@ | |||
| - embed-multilingual-v3.0 | |||
| - embed-multilingual-light-v3.0 | |||
| - embed-english-v3.0 | |||
| - embed-english-light-v3.0 | |||
| - embed-multilingual-v2.0 | |||
| - embed-english-v2.0 | |||
| - embed-english-light-v2.0 | |||
| @@ -0,0 +1,9 @@ | |||
| model: embed-english-light-v2.0 | |||
| model_type: text-embedding | |||
| model_properties: | |||
| context_size: 1024 | |||
| max_chunks: 48 | |||
| pricing: | |||
| input: '0.1' | |||
| unit: '0.000001' | |||
| currency: USD | |||
| @@ -0,0 +1,9 @@ | |||
| model: embed-english-light-v3.0 | |||
| model_type: text-embedding | |||
| model_properties: | |||
| context_size: 384 | |||
| max_chunks: 48 | |||
| pricing: | |||
| input: '0.1' | |||
| unit: '0.000001' | |||
| currency: USD | |||
| @@ -0,0 +1,9 @@ | |||
| model: embed-english-v2.0 | |||
| model_type: text-embedding | |||
| model_properties: | |||
| context_size: 4096 | |||
| max_chunks: 48 | |||
| pricing: | |||
| input: '0.1' | |||
| unit: '0.000001' | |||
| currency: USD | |||
| @@ -0,0 +1,9 @@ | |||
| model: embed-english-v3.0 | |||
| model_type: text-embedding | |||
| model_properties: | |||
| context_size: 1024 | |||
| max_chunks: 48 | |||
| pricing: | |||
| input: '0.1' | |||
| unit: '0.000001' | |||
| currency: USD | |||
| @@ -0,0 +1,9 @@ | |||
| model: embed-multilingual-light-v3.0 | |||
| model_type: text-embedding | |||
| model_properties: | |||
| context_size: 384 | |||
| max_chunks: 48 | |||
| pricing: | |||
| input: '0.1' | |||
| unit: '0.000001' | |||
| currency: USD | |||
| @@ -0,0 +1,9 @@ | |||
| model: embed-multilingual-v2.0 | |||
| model_type: text-embedding | |||
| model_properties: | |||
| context_size: 768 | |||
| max_chunks: 48 | |||
| pricing: | |||
| input: '0.1' | |||
| unit: '0.000001' | |||
| currency: USD | |||
| @@ -0,0 +1,9 @@ | |||
| model: embed-multilingual-v3.0 | |||
| model_type: text-embedding | |||
| model_properties: | |||
| context_size: 1024 | |||
| max_chunks: 48 | |||
| pricing: | |||
| input: '0.1' | |||
| unit: '0.000001' | |||
| currency: USD | |||
| @@ -0,0 +1,234 @@ | |||
| import time | |||
| from typing import Optional, Tuple | |||
| import cohere | |||
| import numpy as np | |||
| from cohere.responses import Tokens | |||
| from core.model_runtime.entities.model_entities import PriceType | |||
| from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult | |||
| from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \ | |||
| InvokeAuthorizationError, InvokeBadRequestError, InvokeError | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | |||
| class CohereTextEmbeddingModel(TextEmbeddingModel): | |||
| """ | |||
| Model class for Cohere text embedding model. | |||
| """ | |||
| def _invoke(self, model: str, credentials: dict, | |||
| texts: list[str], user: Optional[str] = None) \ | |||
| -> TextEmbeddingResult: | |||
| """ | |||
| Invoke text embedding model | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :param texts: texts to embed | |||
| :param user: unique user id | |||
| :return: embeddings result | |||
| """ | |||
| # get model properties | |||
| context_size = self._get_context_size(model, credentials) | |||
| max_chunks = self._get_max_chunks(model, credentials) | |||
| embeddings: list[list[float]] = [[] for _ in range(len(texts))] | |||
| tokens = [] | |||
| indices = [] | |||
| used_tokens = 0 | |||
| for i, text in enumerate(texts): | |||
| tokenize_response = self._tokenize( | |||
| model=model, | |||
| credentials=credentials, | |||
| text=text | |||
| ) | |||
| for j in range(0, tokenize_response.length, context_size): | |||
| tokens += [tokenize_response.token_strings[j: j + context_size]] | |||
| indices += [i] | |||
| batched_embeddings = [] | |||
| _iter = range(0, len(tokens), max_chunks) | |||
| for i in _iter: | |||
| # call embedding model | |||
| embeddings_batch, embedding_used_tokens = self._embedding_invoke( | |||
| model=model, | |||
| credentials=credentials, | |||
| texts=["".join(token) for token in tokens[i: i + max_chunks]] | |||
| ) | |||
| used_tokens += embedding_used_tokens | |||
| batched_embeddings += embeddings_batch | |||
| results: list[list[list[float]]] = [[] for _ in range(len(texts))] | |||
| num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))] | |||
| for i in range(len(indices)): | |||
| results[indices[i]].append(batched_embeddings[i]) | |||
| num_tokens_in_batch[indices[i]].append(len(tokens[i])) | |||
| for i in range(len(texts)): | |||
| _result = results[i] | |||
| if len(_result) == 0: | |||
| embeddings_batch, embedding_used_tokens = self._embedding_invoke( | |||
| model=model, | |||
| credentials=credentials, | |||
| texts=[""] | |||
| ) | |||
| used_tokens += embedding_used_tokens | |||
| average = embeddings_batch[0] | |||
| else: | |||
| average = np.average(_result, axis=0, weights=num_tokens_in_batch[i]) | |||
| embeddings[i] = (average / np.linalg.norm(average)).tolist() | |||
| # calc usage | |||
| usage = self._calc_response_usage( | |||
| model=model, | |||
| credentials=credentials, | |||
| tokens=used_tokens | |||
| ) | |||
| return TextEmbeddingResult( | |||
| embeddings=embeddings, | |||
| usage=usage, | |||
| model=model | |||
| ) | |||
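The long-input handling in _invoke splits each text into context_size token chunks, embeds them in batches of max_chunks, then recombines the per-chunk vectors with a token-count-weighted average followed by L2 normalization. A standalone sketch of just that combining step (the vectors and token counts are made up):

import numpy as np

# Hypothetical per-chunk embeddings for one input text, with their token counts.
chunk_embeddings = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
chunk_token_counts = [300, 100]

# Token-count-weighted average, then L2 normalization, mirroring _invoke above.
average = np.average(chunk_embeddings, axis=0, weights=chunk_token_counts)
embedding = (average / np.linalg.norm(average)).tolist()
print(embedding)  # approximately [0.9487, 0.3162]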
| def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: | |||
| """ | |||
| Get number of tokens for given prompt messages | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :param texts: texts to embed | |||
| :return: | |||
| """ | |||
| if len(texts) == 0: | |||
| return 0 | |||
| full_text = ' '.join(texts) | |||
| try: | |||
| response = self._tokenize( | |||
| model=model, | |||
| credentials=credentials, | |||
| text=full_text | |||
| ) | |||
| except Exception as e: | |||
| raise self._transform_invoke_error(e) | |||
| return response.length | |||
| def _tokenize(self, model: str, credentials: dict, text: str) -> Tokens: | |||
| """ | |||
| Tokenize text | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :param text: text to tokenize | |||
| :return: | |||
| """ | |||
| # initialize client | |||
| client = cohere.Client(credentials.get('api_key')) | |||
| response = client.tokenize( | |||
| text=text, | |||
| model=model | |||
| ) | |||
| return response | |||
| def validate_credentials(self, model: str, credentials: dict) -> None: | |||
| """ | |||
| Validate model credentials | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :return: | |||
| """ | |||
| try: | |||
| # call embedding model | |||
| self._embedding_invoke( | |||
| model=model, | |||
| credentials=credentials, | |||
| texts=['ping'] | |||
| ) | |||
| except Exception as ex: | |||
| raise CredentialsValidateFailedError(str(ex)) | |||
| def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> Tuple[list[list[float]], int]: | |||
| """ | |||
| Invoke embedding model | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :param texts: texts to embed | |||
| :return: embeddings and used tokens | |||
| """ | |||
| # initialize client | |||
| client = cohere.Client(credentials.get('api_key')) | |||
| # call embedding model | |||
| response = client.embed( | |||
| texts=texts, | |||
| model=model, | |||
| input_type='search_document' if len(texts) > 1 else 'search_query' | |||
| ) | |||
| return response.embeddings, response.meta['billed_units']['input_tokens'] | |||
| def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: | |||
| """ | |||
| Calculate response usage | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :param tokens: input tokens | |||
| :return: usage | |||
| """ | |||
| # get input price info | |||
| input_price_info = self.get_price( | |||
| model=model, | |||
| credentials=credentials, | |||
| price_type=PriceType.INPUT, | |||
| tokens=tokens | |||
| ) | |||
| # transform usage | |||
| usage = EmbeddingUsage( | |||
| tokens=tokens, | |||
| total_tokens=tokens, | |||
| unit_price=input_price_info.unit_price, | |||
| price_unit=input_price_info.unit, | |||
| total_price=input_price_info.total_amount, | |||
| currency=input_price_info.currency, | |||
| latency=time.perf_counter() - self.started_at | |||
| ) | |||
| return usage | |||
| @property | |||
| def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: | |||
| """ | |||
| Map model invoke error to unified error | |||
| The key is the error type thrown to the caller | |||
| The value is the error type thrown by the model, | |||
| which needs to be converted into a unified error type for the caller. | |||
| :return: Invoke error mapping | |||
| """ | |||
| return { | |||
| InvokeConnectionError: [ | |||
| cohere.CohereConnectionError | |||
| ], | |||
| InvokeServerUnavailableError: [], | |||
| InvokeRateLimitError: [], | |||
| InvokeAuthorizationError: [], | |||
| InvokeBadRequestError: [ | |||
| cohere.CohereAPIError, | |||
| cohere.CohereError, | |||
| ] | |||
| } | |||
| @@ -24,6 +24,9 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): | |||
| **kwargs: Any, | |||
| ): | |||
| def _token_encoder(text: str) -> int: | |||
| if not text: | |||
| return 0 | |||
| if embedding_model_instance: | |||
| embedding_model_type_instance = embedding_model_instance.model_type_instance | |||
| embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) | |||
| @@ -54,7 +54,7 @@ zhipuai==1.0.7 | |||
| werkzeug==2.3.8 | |||
| pymilvus==2.3.0 | |||
| qdrant-client==1.6.4 | |||
| cohere~=4.32 | |||
| cohere~=4.44 | |||
| pyyaml~=6.0.1 | |||
| numpy~=1.25.2 | |||
| unstructured[docx,pptx,msg,md,ppt]~=0.10.27 | |||
| @@ -0,0 +1,272 @@ | |||
| import os | |||
| from typing import Generator | |||
| import pytest | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta | |||
| from core.model_runtime.entities.message_entities import (AssistantPromptMessage, SystemPromptMessage, | |||
| UserPromptMessage) | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel | |||
| def test_validate_credentials_for_chat_model(): | |||
| model = CohereLargeLanguageModel() | |||
| with pytest.raises(CredentialsValidateFailedError): | |||
| model.validate_credentials( | |||
| model='command-light-chat', | |||
| credentials={ | |||
| 'api_key': 'invalid_key' | |||
| } | |||
| ) | |||
| model.validate_credentials( | |||
| model='command-light-chat', | |||
| credentials={ | |||
| 'api_key': os.environ.get('COHERE_API_KEY') | |||
| } | |||
| ) | |||
| def test_validate_credentials_for_completion_model(): | |||
| model = CohereLargeLanguageModel() | |||
| with pytest.raises(CredentialsValidateFailedError): | |||
| model.validate_credentials( | |||
| model='command-light', | |||
| credentials={ | |||
| 'api_key': 'invalid_key' | |||
| } | |||
| ) | |||
| model.validate_credentials( | |||
| model='command-light', | |||
| credentials={ | |||
| 'api_key': os.environ.get('COHERE_API_KEY') | |||
| } | |||
| ) | |||
| def test_invoke_completion_model(): | |||
| model = CohereLargeLanguageModel() | |||
| credentials = { | |||
| 'api_key': os.environ.get('COHERE_API_KEY') | |||
| } | |||
| result = model.invoke( | |||
| model='command-light', | |||
| credentials=credentials, | |||
| prompt_messages=[ | |||
| UserPromptMessage( | |||
| content='Hello World!' | |||
| ) | |||
| ], | |||
| model_parameters={ | |||
| 'temperature': 0.0, | |||
| 'max_tokens': 1 | |||
| }, | |||
| stream=False, | |||
| user="abc-123" | |||
| ) | |||
| assert isinstance(result, LLMResult) | |||
| assert len(result.message.content) > 0 | |||
| assert model._num_tokens_from_string('command-light', credentials, result.message.content) == 1 | |||
| def test_invoke_stream_completion_model(): | |||
| model = CohereLargeLanguageModel() | |||
| result = model.invoke( | |||
| model='command-light', | |||
| credentials={ | |||
| 'api_key': os.environ.get('COHERE_API_KEY') | |||
| }, | |||
| prompt_messages=[ | |||
| UserPromptMessage( | |||
| content='Hello World!' | |||
| ) | |||
| ], | |||
| model_parameters={ | |||
| 'temperature': 0.0, | |||
| 'max_tokens': 100 | |||
| }, | |||
| stream=True, | |||
| user="abc-123" | |||
| ) | |||
| assert isinstance(result, Generator) | |||
| for chunk in result: | |||
| assert isinstance(chunk, LLMResultChunk) | |||
| assert isinstance(chunk.delta, LLMResultChunkDelta) | |||
| assert isinstance(chunk.delta.message, AssistantPromptMessage) | |||
| assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True | |||
| def test_invoke_chat_model(): | |||
| model = CohereLargeLanguageModel() | |||
| result = model.invoke( | |||
| model='command-light-chat', | |||
| credentials={ | |||
| 'api_key': os.environ.get('COHERE_API_KEY') | |||
| }, | |||
| prompt_messages=[ | |||
| SystemPromptMessage( | |||
| content='You are a helpful AI assistant.', | |||
| ), | |||
| UserPromptMessage( | |||
| content='Hello World!' | |||
| ) | |||
| ], | |||
| model_parameters={ | |||
| 'temperature': 0.0, | |||
| 'p': 0.99, | |||
| 'presence_penalty': 0.0, | |||
| 'frequency_penalty': 0.0, | |||
| 'max_tokens': 10 | |||
| }, | |||
| stop=['How'], | |||
| stream=False, | |||
| user="abc-123" | |||
| ) | |||
| assert isinstance(result, LLMResult) | |||
| assert len(result.message.content) > 0 | |||
| for chunk in model._llm_result_to_stream(result): | |||
| assert isinstance(chunk, LLMResultChunk) | |||
| assert isinstance(chunk.delta, LLMResultChunkDelta) | |||
| assert isinstance(chunk.delta.message, AssistantPromptMessage) | |||
| assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True | |||
| def test_invoke_stream_chat_model(): | |||
| model = CohereLargeLanguageModel() | |||
| result = model.invoke( | |||
| model='command-light-chat', | |||
| credentials={ | |||
| 'api_key': os.environ.get('COHERE_API_KEY') | |||
| }, | |||
| prompt_messages=[ | |||
| SystemPromptMessage( | |||
| content='You are a helpful AI assistant.', | |||
| ), | |||
| UserPromptMessage( | |||
| content='Hello World!' | |||
| ) | |||
| ], | |||
| model_parameters={ | |||
| 'temperature': 0.0, | |||
| 'max_tokens': 100 | |||
| }, | |||
| stream=True, | |||
| user="abc-123" | |||
| ) | |||
| assert isinstance(result, Generator) | |||
| for chunk in result: | |||
| assert isinstance(chunk, LLMResultChunk) | |||
| assert isinstance(chunk.delta, LLMResultChunkDelta) | |||
| assert isinstance(chunk.delta.message, AssistantPromptMessage) | |||
| assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True | |||
| if chunk.delta.finish_reason is not None: | |||
| assert chunk.delta.usage is not None | |||
| assert chunk.delta.usage.completion_tokens > 0 | |||
| def test_get_num_tokens(): | |||
| model = CohereLargeLanguageModel() | |||
| num_tokens = model.get_num_tokens( | |||
| model='command-light', | |||
| credentials={ | |||
| 'api_key': os.environ.get('COHERE_API_KEY') | |||
| }, | |||
| prompt_messages=[ | |||
| UserPromptMessage( | |||
| content='Hello World!' | |||
| ) | |||
| ] | |||
| ) | |||
| assert num_tokens == 3 | |||
| num_tokens = model.get_num_tokens( | |||
| model='command-light-chat', | |||
| credentials={ | |||
| 'api_key': os.environ.get('COHERE_API_KEY') | |||
| }, | |||
| prompt_messages=[ | |||
| SystemPromptMessage( | |||
| content='You are a helpful AI assistant.', | |||
| ), | |||
| UserPromptMessage( | |||
| content='Hello World!' | |||
| ) | |||
| ] | |||
| ) | |||
| assert num_tokens == 15 | |||
| def test_fine_tuned_model(): | |||
| model = CohereLargeLanguageModel() | |||
| # test invoke | |||
| result = model.invoke( | |||
| model='85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft', | |||
| credentials={ | |||
| 'api_key': os.environ.get('COHERE_API_KEY'), | |||
| 'mode': 'completion' | |||
| }, | |||
| prompt_messages=[ | |||
| SystemPromptMessage( | |||
| content='You are a helpful AI assistant.', | |||
| ), | |||
| UserPromptMessage( | |||
| content='Hello World!' | |||
| ) | |||
| ], | |||
| model_parameters={ | |||
| 'temperature': 0.0, | |||
| 'max_tokens': 100 | |||
| }, | |||
| stream=False, | |||
| user="abc-123" | |||
| ) | |||
| assert isinstance(result, LLMResult) | |||
| def test_fine_tuned_chat_model(): | |||
| model = CohereLargeLanguageModel() | |||
| # test invoke | |||
| result = model.invoke( | |||
| model='94f2d55a-4c79-4c00-bde4-23962e74b170-ft', | |||
| credentials={ | |||
| 'api_key': os.environ.get('COHERE_API_KEY'), | |||
| 'mode': 'chat' | |||
| }, | |||
| prompt_messages=[ | |||
| SystemPromptMessage( | |||
| content='You are a helpful AI assistant.', | |||
| ), | |||
| UserPromptMessage( | |||
| content='Hello World!' | |||
| ) | |||
| ], | |||
| model_parameters={ | |||
| 'temperature': 0.0, | |||
| 'max_tokens': 100 | |||
| }, | |||
| stream=False, | |||
| user="abc-123" | |||
| ) | |||
| assert isinstance(result, LLMResult) | |||
| @@ -0,0 +1,64 @@ | |||
| import os | |||
| import pytest | |||
| from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.cohere.text_embedding.text_embedding import CohereTextEmbeddingModel | |||
| def test_validate_credentials(): | |||
| model = CohereTextEmbeddingModel() | |||
| with pytest.raises(CredentialsValidateFailedError): | |||
| model.validate_credentials( | |||
| model='embed-multilingual-v3.0', | |||
| credentials={ | |||
| 'api_key': 'invalid_key' | |||
| } | |||
| ) | |||
| model.validate_credentials( | |||
| model='embed-multilingual-v3.0', | |||
| credentials={ | |||
| 'api_key': os.environ.get('COHERE_API_KEY') | |||
| } | |||
| ) | |||
| def test_invoke_model(): | |||
| model = CohereTextEmbeddingModel() | |||
| result = model.invoke( | |||
| model='embed-multilingual-v3.0', | |||
| credentials={ | |||
| 'api_key': os.environ.get('COHERE_API_KEY') | |||
| }, | |||
| texts=[ | |||
| "hello", | |||
| "world", | |||
| " ".join(["long_text"] * 100), | |||
| " ".join(["another_long_text"] * 100) | |||
| ], | |||
| user="abc-123" | |||
| ) | |||
| assert isinstance(result, TextEmbeddingResult) | |||
| assert len(result.embeddings) == 4 | |||
| assert result.usage.total_tokens == 811 | |||
| def test_get_num_tokens(): | |||
| model = CohereTextEmbeddingModel() | |||
| num_tokens = model.get_num_tokens( | |||
| model='embed-multilingual-v3.0', | |||
| credentials={ | |||
| 'api_key': os.environ.get('COHERE_API_KEY') | |||
| }, | |||
| texts=[ | |||
| "hello", | |||
| "world" | |||
| ] | |||
| ) | |||
| assert num_tokens == 3 | |||