import logging
import os
import re
import time
from abc import abstractmethod
from typing import Generator, List, Optional, Union

        """
        raise NotImplementedError

    def enforce_stop_tokens(self, text: str, stop: List[str]) -> str:
        """Cut off the text as soon as any stop words occur."""
        return re.split("|".join(stop), text, maxsplit=1)[0]
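    # Quick illustration (added comment, not in the original file): with stop=["\nObservation:"],
    # enforce_stop_tokens("I should search.\nObservation: 42", ["\nObservation:"]) returns
    # "I should search.". The stop strings are joined into a regex alternation, so a stop word
    # containing regex metacharacters would need re.escape before being passed in.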
    def _llm_result_to_stream(self, result: LLMResult) -> Generator:
        """
        Transform llm result to stream
  url:
    en_US: https://dashboard.cohere.com/api-keys
supported_model_types:
  - llm
  - text-embedding
  - rerank
configurate_methods:
  - predefined-model
  - customizable-model
provider_credential_schema:
  credential_form_schemas:
    - variable: api_key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
      show_on: [ ]
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: mode
      show_on:
        - variable: __model_type
          value: llm
      label:
        en_US: Completion mode
      type: select
      required: false
      default: chat
      placeholder:
        zh_Hans: 选择对话类型
        en_US: Select completion mode
      options:
        - value: completion
          label:
            en_US: Completion
            zh_Hans: 补全
        - value: chat
          label:
            en_US: Chat
            zh_Hans: 对话
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
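# Hypothetical example (added for illustration, not part of the schema above): a user adding a
# Cohere fine-tuned model through the customizable-model flow would supply roughly:
#   model: 85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft   # fine-tuned model id
#   mode: chat                                        # or "completion"
#   api_key: <your Cohere API key>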
- command-chat
- command-light-chat
- command-nightly-chat
- command-light-nightly-chat
- command
- command-light
- command-nightly
- command-light-nightly
model: command-chat
label:
  zh_Hans: command-chat
  en_US: command-chat
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
    max: 5.0
  - name: p
    use_template: top_p
    default: 0.75
    min: 0.01
    max: 0.99
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: max_tokens
    use_template: max_tokens
    default: 256
    max: 4096
  - name: preamble_override
    label:
      zh_Hans: 前导文本
      en_US: Preamble
    type: string
    help:
      zh_Hans: 当指定时，将使用提供的前导文本替换默认的 Cohere 前导文本。
      en_US: When specified, the default Cohere preamble will be replaced with the provided one.
    required: false
  - name: prompt_truncation
    label:
      zh_Hans: 提示截断
      en_US: Prompt Truncation
    type: string
    help:
      zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时，将会丢弃一些来自聊天记录的元素，以尝试构造一个符合模型上下文长度限制的 Prompt。
      en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit.
    required: true
    default: 'AUTO'
    options:
      - 'AUTO'
      - 'OFF'
pricing:
  input: '1.0'
  output: '2.0'
  unit: '0.000001'
  currency: USD
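# Note (added for clarity; assumption about Dify's schema conventions): "use_template" reuses the
# shared default parameter rule of the same name, and with unit '0.000001' the pricing above reads
# as USD 1.0 per million input tokens and USD 2.0 per million output tokens.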
model: command-light-chat
label:
  zh_Hans: command-light-chat
  en_US: command-light-chat
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
    max: 5.0
  - name: p
    use_template: top_p
    default: 0.75
    min: 0.01
    max: 0.99
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: max_tokens
    use_template: max_tokens
    default: 256
    max: 4096
  - name: preamble_override
    label:
      zh_Hans: 前导文本
      en_US: Preamble
    type: string
    help:
      zh_Hans: 当指定时，将使用提供的前导文本替换默认的 Cohere 前导文本。
      en_US: When specified, the default Cohere preamble will be replaced with the provided one.
    required: false
  - name: prompt_truncation
    label:
      zh_Hans: 提示截断
      en_US: Prompt Truncation
    type: string
    help:
      zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时，将会丢弃一些来自聊天记录的元素，以尝试构造一个符合模型上下文长度限制的 Prompt。
      en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit.
    required: true
    default: 'AUTO'
    options:
      - 'AUTO'
      - 'OFF'
pricing:
  input: '0.3'
  output: '0.6'
  unit: '0.000001'
  currency: USD

model: command-light-nightly-chat
label:
  zh_Hans: command-light-nightly-chat
  en_US: command-light-nightly-chat
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
    max: 5.0
  - name: p
    use_template: top_p
    default: 0.75
    min: 0.01
    max: 0.99
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: max_tokens
    use_template: max_tokens
    default: 256
    max: 4096
  - name: preamble_override
    label:
      zh_Hans: 前导文本
      en_US: Preamble
    type: string
    help:
      zh_Hans: 当指定时，将使用提供的前导文本替换默认的 Cohere 前导文本。
      en_US: When specified, the default Cohere preamble will be replaced with the provided one.
    required: false
  - name: prompt_truncation
    label:
      zh_Hans: 提示截断
      en_US: Prompt Truncation
    type: string
    help:
      zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时，将会丢弃一些来自聊天记录的元素，以尝试构造一个符合模型上下文长度限制的 Prompt。
      en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit.
    required: true
    default: 'AUTO'
    options:
      - 'AUTO'
      - 'OFF'
pricing:
  input: '0.3'
  output: '0.6'
  unit: '0.000001'
  currency: USD

model: command-light-nightly
label:
  zh_Hans: command-light-nightly
  en_US: command-light-nightly
model_type: llm
features:
  - agent-thought
model_properties:
  mode: completion
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
    max: 5.0
  - name: p
    use_template: top_p
    default: 0.75
    min: 0.01
    max: 0.99
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 256
    max: 4096
pricing:
  input: '0.3'
  output: '0.6'
  unit: '0.000001'
  currency: USD

model: command-light
label:
  zh_Hans: command-light
  en_US: command-light
model_type: llm
features:
  - agent-thought
model_properties:
  mode: completion
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
    max: 5.0
  - name: p
    use_template: top_p
    default: 0.75
    min: 0.01
    max: 0.99
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 256
    max: 4096
pricing:
  input: '0.3'
  output: '0.6'
  unit: '0.000001'
  currency: USD

model: command-nightly-chat
label:
  zh_Hans: command-nightly-chat
  en_US: command-nightly-chat
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
    max: 5.0
  - name: p
    use_template: top_p
    default: 0.75
    min: 0.01
    max: 0.99
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: max_tokens
    use_template: max_tokens
    default: 256
    max: 4096
  - name: preamble_override
    label:
      zh_Hans: 前导文本
      en_US: Preamble
    type: string
    help:
      zh_Hans: 当指定时，将使用提供的前导文本替换默认的 Cohere 前导文本。
      en_US: When specified, the default Cohere preamble will be replaced with the provided one.
    required: false
  - name: prompt_truncation
    label:
      zh_Hans: 提示截断
      en_US: Prompt Truncation
    type: string
    help:
      zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时，将会丢弃一些来自聊天记录的元素，以尝试构造一个符合模型上下文长度限制的 Prompt。
      en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit.
    required: true
    default: 'AUTO'
    options:
      - 'AUTO'
      - 'OFF'
pricing:
  input: '1.0'
  output: '2.0'
  unit: '0.000001'
  currency: USD

model: command-nightly
label:
  zh_Hans: command-nightly
  en_US: command-nightly
model_type: llm
features:
  - agent-thought
model_properties:
  mode: completion
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
    max: 5.0
  - name: p
    use_template: top_p
    default: 0.75
    min: 0.01
    max: 0.99
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 256
    max: 4096
pricing:
  input: '1.0'
  output: '2.0'
  unit: '0.000001'
  currency: USD

model: command
label:
  zh_Hans: command
  en_US: command
model_type: llm
features:
  - agent-thought
model_properties:
  mode: completion
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
    max: 5.0
  - name: p
    use_template: top_p
    default: 0.75
    min: 0.01
    max: 0.99
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 256
    max: 4096
pricing:
  input: '1.0'
  output: '2.0'
  unit: '0.000001'
  currency: USD
import logging
from typing import Generator, List, Optional, Union, cast, Tuple

import cohere
from cohere.responses import Chat, Generations
from cohere.responses.chat import StreamingChat, StreamTextGeneration, StreamEnd
from cohere.responses.generation import StreamingText, StreamingGenerations

from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage,
                                                           PromptMessageContentType, SystemPromptMessage,
                                                           TextPromptMessageContent, UserPromptMessage,
                                                           PromptMessageTool)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeError, \
    InvokeRateLimitError, InvokeAuthorizationError, InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

logger = logging.getLogger(__name__)


class CohereLargeLanguageModel(LargeLanguageModel):
    """
    Model class for Cohere large language model.
    """

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # get model mode
        model_mode = self.get_model_mode(model, credentials)

        if model_mode == LLMMode.CHAT:
            return self._chat_generate(
                model=model,
                credentials=credentials,
                prompt_messages=prompt_messages,
                model_parameters=model_parameters,
                stop=stop,
                stream=stream,
                user=user
            )
        else:
            return self._generate(
                model=model,
                credentials=credentials,
                prompt_messages=prompt_messages,
                model_parameters=model_parameters,
                stop=stop,
                stream=stream,
                user=user
            )

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return:
        """
        # get model mode
        model_mode = self.get_model_mode(model)

        try:
            if model_mode == LLMMode.CHAT:
                return self._num_tokens_from_messages(model, credentials, prompt_messages)
            else:
                return self._num_tokens_from_string(model, credentials, prompt_messages[0].content)
        except Exception as e:
            raise self._transform_invoke_error(e)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            # get model mode
            model_mode = self.get_model_mode(model)

            if model_mode == LLMMode.CHAT:
                self._chat_generate(
                    model=model,
                    credentials=credentials,
                    prompt_messages=[UserPromptMessage(content='ping')],
                    model_parameters={
                        'max_tokens': 20,
                        'temperature': 0,
                    },
                    stream=False
                )
            else:
                self._generate(
                    model=model,
                    credentials=credentials,
                    prompt_messages=[UserPromptMessage(content='ping')],
                    model_parameters={
                        'max_tokens': 20,
                        'temperature': 0,
                    },
                    stream=False
                )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _generate(self, model: str, credentials: dict,
                  prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
                  stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
        """
        Invoke llm model

        :param model: model name
        :param credentials: credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # initialize client
        client = cohere.Client(credentials.get('api_key'))

        if stop:
            model_parameters['end_sequences'] = stop

        response = client.generate(
            prompt=prompt_messages[0].content,
            model=model,
            stream=stream,
            **model_parameters,
        )

        if stream:
            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)

        return self._handle_generate_response(model, credentials, response, prompt_messages)

    def _handle_generate_response(self, model: str, credentials: dict, response: Generations,
                                  prompt_messages: list[PromptMessage]) \
            -> LLMResult:
        """
        Handle llm response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :return: llm response
        """
        assistant_text = response.generations[0].text

        # transform assistant message to prompt message
        assistant_prompt_message = AssistantPromptMessage(
            content=assistant_text
        )

        # calculate num tokens
        prompt_tokens = response.meta['billed_units']['input_tokens']
        completion_tokens = response.meta['billed_units']['output_tokens']

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

        # transform response
        response = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
            usage=usage
        )

        return response

    def _handle_generate_stream_response(self, model: str, credentials: dict, response: StreamingGenerations,
                                         prompt_messages: list[PromptMessage]) -> Generator:
        """
        Handle llm stream response

        :param model: model name
        :param response: response
        :param prompt_messages: prompt messages
        :return: llm response chunk generator
        """
        index = 1
        full_assistant_content = ''
        for chunk in response:
            if isinstance(chunk, StreamingText):
                chunk = cast(StreamingText, chunk)
                text = chunk.text

                if text is None:
                    continue

                # transform assistant message to prompt message
                assistant_prompt_message = AssistantPromptMessage(
                    content=text
                )

                full_assistant_content += text

                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=index,
                        message=assistant_prompt_message,
                    )
                )

                index += 1
            elif chunk is None:
                # calculate num tokens
                prompt_tokens = response.meta['billed_units']['input_tokens']
                completion_tokens = response.meta['billed_units']['output_tokens']

                # transform usage
                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=index,
                        message=AssistantPromptMessage(content=''),
                        finish_reason=response.finish_reason,
                        usage=usage
                    )
                )
                break

    def _chat_generate(self, model: str, credentials: dict,
                       prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
                       stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
        """
        Invoke llm chat model

        :param model: model name
        :param credentials: credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # initialize client
        client = cohere.Client(credentials.get('api_key'))

        if user:
            model_parameters['user_name'] = user

        message, chat_histories = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)

        # chat model
        real_model = model
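        # Note (added comment): Dify's predefined Cohere chat models are exposed with a "-chat"
        # suffix; strip it so the request uses Cohere's real model name (e.g.
        # "command-light-chat" -> "command-light"). Customizable (fine-tuned) models keep their id.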
        if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
            real_model = model.removesuffix('-chat')

        response = client.chat(
            message=message,
            chat_history=chat_histories,
            model=real_model,
            stream=stream,
            return_preamble=True,
            **model_parameters,
        )

        if stream:
            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, stop)

        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop)

    def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat,
                                       prompt_messages: list[PromptMessage], stop: Optional[List[str]] = None) \
            -> LLMResult:
        """
        Handle llm chat response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :param stop: stop words
        :return: llm response
        """
        assistant_text = response.text

        # transform assistant message to prompt message
        assistant_prompt_message = AssistantPromptMessage(
            content=assistant_text
        )

        # calculate num tokens
        prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
        completion_tokens = self._num_tokens_from_messages(model, credentials, [assistant_prompt_message])

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

        if stop:
            # enforce stop tokens
            assistant_text = self.enforce_stop_tokens(assistant_text, stop)
            assistant_prompt_message = AssistantPromptMessage(
                content=assistant_text
            )

        # transform response
        response = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
            usage=usage,
            system_fingerprint=response.preamble
        )

        return response

    def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat,
                                              prompt_messages: list[PromptMessage],
                                              stop: Optional[List[str]] = None) -> Generator:
        """
        Handle llm chat stream response

        :param model: model name
        :param response: response
        :param prompt_messages: prompt messages
        :param stop: stop words
        :return: llm response chunk generator
        """

        def final_response(full_text: str, index: int, finish_reason: Optional[str] = None,
                           preamble: Optional[str] = None) -> LLMResultChunk:
            # calculate num tokens
            prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)

            full_assistant_prompt_message = AssistantPromptMessage(
                content=full_text
            )
            completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message])

            # transform usage
            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

            return LLMResultChunk(
                model=model,
                prompt_messages=prompt_messages,
                system_fingerprint=preamble,
                delta=LLMResultChunkDelta(
                    index=index,
                    message=AssistantPromptMessage(content=''),
                    finish_reason=finish_reason,
                    usage=usage
                )
            )

        index = 1
        full_assistant_content = ''
        for chunk in response:
            if isinstance(chunk, StreamTextGeneration):
                chunk = cast(StreamTextGeneration, chunk)
                text = chunk.text

                if text is None:
                    continue

                # transform assistant message to prompt message
                assistant_prompt_message = AssistantPromptMessage(
                    content=text
                )

                # stop
                # notice: This logic can only cover few stop scenarios
                if stop and text in stop:
                    yield final_response(full_assistant_content, index, 'stop')
                    break

                full_assistant_content += text

                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=index,
                        message=assistant_prompt_message,
                    )
                )

                index += 1
            elif isinstance(chunk, StreamEnd):
                chunk = cast(StreamEnd, chunk)
                yield final_response(full_assistant_content, index, chunk.finish_reason, response.preamble)
                index += 1

    def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \
            -> Tuple[str, list[dict]]:
        """
        Convert prompt messages to message and chat histories

        :param prompt_messages: prompt messages
        :return:
        """
        chat_histories = []
        for prompt_message in prompt_messages:
            chat_histories.append(self._convert_prompt_message_to_dict(prompt_message))

        # get latest message from chat histories and pop it
        if len(chat_histories) > 0:
            latest_message = chat_histories.pop()
            message = latest_message['message']
        else:
            raise ValueError('Prompt messages is empty')

        return message, chat_histories
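    # Illustration of the conversion above (hypothetical messages, added for clarity):
    #   [SystemPromptMessage("Be brief."), UserPromptMessage("Hi"),
    #    AssistantPromptMessage("Hello"), UserPromptMessage("What is Cohere?")]
    # yields message="What is Cohere?" and chat_history=
    #   [{"role": "USER", "message": "Be brief."}, {"role": "USER", "message": "Hi"},
    #    {"role": "CHATBOT", "message": "Hello"}],
    # since system prompts are mapped to the USER role by _convert_prompt_message_to_dict below.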
    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
        """
        Convert PromptMessage to dict for Cohere model
        """
        if isinstance(message, UserPromptMessage):
            message = cast(UserPromptMessage, message)
            if isinstance(message.content, str):
                message_dict = {"role": "USER", "message": message.content}
            else:
                sub_message_text = ''
                for message_content in message.content:
                    if message_content.type == PromptMessageContentType.TEXT:
                        message_content = cast(TextPromptMessageContent, message_content)
                        sub_message_text += message_content.data

                message_dict = {"role": "USER", "message": sub_message_text}
        elif isinstance(message, AssistantPromptMessage):
            message = cast(AssistantPromptMessage, message)
            message_dict = {"role": "CHATBOT", "message": message.content}
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            message_dict = {"role": "USER", "message": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")

        if message.name is not None:
            message_dict["user_name"] = message.name

        return message_dict

    def _num_tokens_from_string(self, model: str, credentials: dict, text: str) -> int:
        """
        Calculate num tokens for text completion model.

        :param model: model name
        :param credentials: credentials
        :param text: prompt text
        :return: number of tokens
        """
        # initialize client
        client = cohere.Client(credentials.get('api_key'))

        response = client.tokenize(
            text=text,
            model=model
        )

        return response.length

    def _num_tokens_from_messages(self, model: str, credentials: dict, messages: List[PromptMessage]) -> int:
        """Calculate num tokens Cohere model."""
        messages = [self._convert_prompt_message_to_dict(m) for m in messages]
        message_strs = [f"{message['role']}: {message['message']}" for message in messages]
        message_str = "\n".join(message_strs)

        real_model = model
        if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
            real_model = model.removesuffix('-chat')

        return self._num_tokens_from_string(real_model, credentials, message_str)

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """
        Cohere supports fine-tuning of their models. This method returns the schema of the base model
        but renamed to the fine-tuned model name.

        :param model: model name
        :param credentials: credentials
        :return: model schema
        """
        # get model schema
        models = self.predefined_models()
        model_map = {model.model: model for model in models}

        mode = credentials.get('mode')

        if mode == 'chat':
            base_model_schema = model_map['command-light-chat']
        else:
            base_model_schema = model_map['command-light']

        base_model_schema = cast(AIModelEntity, base_model_schema)

        base_model_schema_features = base_model_schema.features or []
        base_model_schema_model_properties = base_model_schema.model_properties or {}
        base_model_schema_parameters_rules = base_model_schema.parameter_rules or []

        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                zh_Hans=model,
                en_US=model
            ),
            model_type=ModelType.LLM,
            features=[feature for feature in base_model_schema_features],
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                key: property for key, property in base_model_schema_model_properties.items()
            },
            parameter_rules=[rule for rule in base_model_schema_parameters_rules],
            pricing=base_model_schema.pricing
        )

        return entity

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                cohere.CohereConnectionError
            ],
            InvokeServerUnavailableError: [],
            InvokeRateLimitError: [],
            InvokeAuthorizationError: [],
            InvokeBadRequestError: [
                cohere.CohereAPIError,
                cohere.CohereError,
            ]
        }
- embed-multilingual-v3.0
- embed-multilingual-light-v3.0
- embed-english-v3.0
- embed-english-light-v3.0
- embed-multilingual-v2.0
- embed-english-v2.0
- embed-english-light-v2.0
model: embed-english-light-v2.0
model_type: text-embedding
model_properties:
  context_size: 1024
  max_chunks: 48
pricing:
  input: '0.1'
  unit: '0.000001'
  currency: USD

model: embed-english-light-v3.0
model_type: text-embedding
model_properties:
  context_size: 384
  max_chunks: 48
pricing:
  input: '0.1'
  unit: '0.000001'
  currency: USD

model: embed-english-v2.0
model_type: text-embedding
model_properties:
  context_size: 4096
  max_chunks: 48
pricing:
  input: '0.1'
  unit: '0.000001'
  currency: USD

model: embed-english-v3.0
model_type: text-embedding
model_properties:
  context_size: 1024
  max_chunks: 48
pricing:
  input: '0.1'
  unit: '0.000001'
  currency: USD

model: embed-multilingual-light-v3.0
model_type: text-embedding
model_properties:
  context_size: 384
  max_chunks: 48
pricing:
  input: '0.1'
  unit: '0.000001'
  currency: USD

model: embed-multilingual-v2.0
model_type: text-embedding
model_properties:
  context_size: 768
  max_chunks: 48
pricing:
  input: '0.1'
  unit: '0.000001'
  currency: USD

model: embed-multilingual-v3.0
model_type: text-embedding
model_properties:
  context_size: 1024
  max_chunks: 48
pricing:
  input: '0.1'
  unit: '0.000001'
  currency: USD
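# Note (added for clarity): "context_size" is used by the embedding implementation below as the
# per-chunk token window when long inputs are split, and "max_chunks" caps how many chunks are
# sent to Cohere in a single embed call.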
import time
from typing import Optional, Tuple

import cohere
import numpy as np
from cohere.responses import Tokens

from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
    InvokeAuthorizationError, InvokeBadRequestError, InvokeError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel


class CohereTextEmbeddingModel(TextEmbeddingModel):
    """
    Model class for Cohere text embedding model.
    """

    def _invoke(self, model: str, credentials: dict,
                texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """
        # get model properties
        context_size = self._get_context_size(model, credentials)
        max_chunks = self._get_max_chunks(model, credentials)

        embeddings: list[list[float]] = [[] for _ in range(len(texts))]
        tokens = []
        indices = []
        used_tokens = 0
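        # Overall approach (added comment): tokenize each input, split it into windows of at most
        # context_size tokens, embed the windows in batches of at most max_chunks, then
        # token-weighted-average and L2-normalize the window embeddings back into one vector
        # per original text.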
        for i, text in enumerate(texts):
            tokenize_response = self._tokenize(
                model=model,
                credentials=credentials,
                text=text
            )

            for j in range(0, tokenize_response.length, context_size):
                tokens += [tokenize_response.token_strings[j: j + context_size]]
                indices += [i]

        batched_embeddings = []
        _iter = range(0, len(tokens), max_chunks)

        for i in _iter:
            # call embedding model
            embeddings_batch, embedding_used_tokens = self._embedding_invoke(
                model=model,
                credentials=credentials,
                texts=["".join(token) for token in tokens[i: i + max_chunks]]
            )

            used_tokens += embedding_used_tokens
            batched_embeddings += embeddings_batch

        results: list[list[list[float]]] = [[] for _ in range(len(texts))]
        num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
        for i in range(len(indices)):
            results[indices[i]].append(batched_embeddings[i])
            num_tokens_in_batch[indices[i]].append(len(tokens[i]))

        for i in range(len(texts)):
            _result = results[i]
            if len(_result) == 0:
                embeddings_batch, embedding_used_tokens = self._embedding_invoke(
                    model=model,
                    credentials=credentials,
                    texts=[""]
                )

                used_tokens += embedding_used_tokens
                average = embeddings_batch[0]
            else:
                average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
            embeddings[i] = (average / np.linalg.norm(average)).tolist()

        # calc usage
        usage = self._calc_response_usage(
            model=model,
            credentials=credentials,
            tokens=used_tokens
        )

        return TextEmbeddingResult(
            embeddings=embeddings,
            usage=usage,
            model=model
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return:
        """
        if len(texts) == 0:
            return 0

        full_text = ' '.join(texts)

        try:
            response = self._tokenize(
                model=model,
                credentials=credentials,
                text=full_text
            )
        except Exception as e:
            raise self._transform_invoke_error(e)

        return response.length

    def _tokenize(self, model: str, credentials: dict, text: str) -> Tokens:
        """
        Tokenize text

        :param model: model name
        :param credentials: model credentials
        :param text: text to tokenize
        :return:
        """
        # initialize client
        client = cohere.Client(credentials.get('api_key'))

        response = client.tokenize(
            text=text,
            model=model
        )

        return response

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            # call embedding model
            self._embedding_invoke(
                model=model,
                credentials=credentials,
                texts=['ping']
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> Tuple[list[list[float]], int]:
        """
        Invoke embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: embeddings and used tokens
        """
        # initialize client
        client = cohere.Client(credentials.get('api_key'))

        # call embedding model
        response = client.embed(
            texts=texts,
            model=model,
            input_type='search_document' if len(texts) > 1 else 'search_query'
        )

        return response.embeddings, response.meta['billed_units']['input_tokens']

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at
        )

        return usage

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                cohere.CohereConnectionError
            ],
            InvokeServerUnavailableError: [],
            InvokeRateLimitError: [],
            InvokeAuthorizationError: [],
            InvokeBadRequestError: [
                cohere.CohereAPIError,
                cohere.CohereError,
            ]
        }
        **kwargs: Any,
    ):
        def _token_encoder(text: str) -> int:
            if not text:
                return 0

            if embedding_model_instance:
                embedding_model_type_instance = embedding_model_instance.model_type_instance
                embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
werkzeug==2.3.8
pymilvus==2.3.0
qdrant-client==1.6.4
cohere~=4.44
pyyaml~=6.0.1
numpy~=1.25.2
unstructured[docx,pptx,msg,md,ppt]~=0.10.27
import os
from typing import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, SystemPromptMessage,
                                                           UserPromptMessage)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel


def test_validate_credentials_for_chat_model():
    model = CohereLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='command-light-chat',
            credentials={
                'api_key': 'invalid_key'
            }
        )

    model.validate_credentials(
        model='command-light-chat',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        }
    )


def test_validate_credentials_for_completion_model():
    model = CohereLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='command-light',
            credentials={
                'api_key': 'invalid_key'
            }
        )

    model.validate_credentials(
        model='command-light',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        }
    )


def test_invoke_completion_model():
    model = CohereLargeLanguageModel()

    credentials = {
        'api_key': os.environ.get('COHERE_API_KEY')
    }

    result = model.invoke(
        model='command-light',
        credentials=credentials,
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.0,
            'max_tokens': 1
        },
        stream=False,
        user="abc-123"
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
    assert model._num_tokens_from_string('command-light', credentials, result.message.content) == 1


def test_invoke_stream_completion_model():
    model = CohereLargeLanguageModel()

    result = model.invoke(
        model='command-light',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.0,
            'max_tokens': 100
        },
        stream=True,
        user="abc-123"
    )

    assert isinstance(result, Generator)

    for chunk in result:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True


def test_invoke_chat_model():
    model = CohereLargeLanguageModel()

    result = model.invoke(
        model='command-light-chat',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.0,
            'p': 0.99,
            'presence_penalty': 0.0,
            'frequency_penalty': 0.0,
            'max_tokens': 10
        },
        stop=['How'],
        stream=False,
        user="abc-123"
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0

    for chunk in model._llm_result_to_stream(result):
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True


def test_invoke_stream_chat_model():
    model = CohereLargeLanguageModel()

    result = model.invoke(
        model='command-light-chat',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.0,
            'max_tokens': 100
        },
        stream=True,
        user="abc-123"
    )

    assert isinstance(result, Generator)

    for chunk in result:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
        if chunk.delta.finish_reason is not None:
            assert chunk.delta.usage is not None
            assert chunk.delta.usage.completion_tokens > 0


def test_get_num_tokens():
    model = CohereLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model='command-light',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ]
    )

    assert num_tokens == 3

    num_tokens = model.get_num_tokens(
        model='command-light-chat',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ]
    )

    assert num_tokens == 15


def test_fine_tuned_model():
    model = CohereLargeLanguageModel()

    # test invoke
    result = model.invoke(
        model='85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY'),
            'mode': 'completion'
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.0,
            'max_tokens': 100
        },
        stream=False,
        user="abc-123"
    )

    assert isinstance(result, LLMResult)


def test_fine_tuned_chat_model():
    model = CohereLargeLanguageModel()

    # test invoke
    result = model.invoke(
        model='94f2d55a-4c79-4c00-bde4-23962e74b170-ft',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY'),
            'mode': 'chat'
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.0,
            'max_tokens': 100
        },
        stream=False,
        user="abc-123"
    )

    assert isinstance(result, LLMResult)
import os

import pytest

from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.cohere.text_embedding.text_embedding import CohereTextEmbeddingModel


def test_validate_credentials():
    model = CohereTextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='embed-multilingual-v3.0',
            credentials={
                'api_key': 'invalid_key'
            }
        )

    model.validate_credentials(
        model='embed-multilingual-v3.0',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        }
    )


def test_invoke_model():
    model = CohereTextEmbeddingModel()

    result = model.invoke(
        model='embed-multilingual-v3.0',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        },
        texts=[
            "hello",
            "world",
            " ".join(["long_text"] * 100),
            " ".join(["another_long_text"] * 100)
        ],
        user="abc-123"
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 4
    assert result.usage.total_tokens == 811


def test_get_num_tokens():
    model = CohereTextEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model='embed-multilingual-v3.0',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        },
        texts=[
            "hello",
            "world"
        ]
    )

    assert num_tokens == 3