Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
tags/0.3.25
elif provider_name == 'chatglm':
    from core.model_providers.providers.chatglm_provider import ChatGLMProvider
    return ChatGLMProvider
elif provider_name == 'baichuan':
    from core.model_providers.providers.baichuan_provider import BaichuanProvider
    return BaichuanProvider
elif provider_name == 'azure_openai':
    from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
    return AzureOpenAIProvider
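
For context, each branch above imports its provider lazily, so a provider's SDK dependencies are only loaded when that provider is actually used. Resolving the newly added provider would look roughly like this (a sketch; the factory module path and method name are assumed, since the enclosing function is not shown in this diff):

from core.model_providers.model_provider_factory import ModelProviderFactory

# Hypothetical usage: 'baichuan' matches the new elif branch above.
provider_class = ModelProviderFactory.get_model_provider_class('baichuan')
assert provider_class.__name__ == 'BaichuanProvider'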

from typing import Any, List, Optional

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM


class BaichuanModel(BaseLLM):
    model_mode: ModelMode = ModelMode.CHAT

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return BaichuanChatLLM(
            streaming=self.streaming,
            callbacks=self.callbacks,
            **self.credentials,
            **provider_model_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run a prediction from the given prompt messages, honoring optional stop words.

        :param messages: prompt messages to send to the model
        :param stop: optional stop words
        :param callbacks: optional LangChain callbacks
        :return: the generation result
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Get the number of tokens in the prompt messages.

        :param messages: prompt messages to count
        :return: the token count (never negative)
        """
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens_from_messages(prompts), 0)

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for k, v in provider_model_kwargs.items():
            if hasattr(self.client, k):
                setattr(self.client, k, v)

    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(f"Baichuan: {str(ex)}")

    @property
    def support_streaming(self):
        return True

import json
from json import JSONDecodeError
from typing import Type

from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
from models.provider import ProviderType


class BaichuanProvider(BaseModelProvider):
    @property
    def provider_name(self):
        """
        Returns the name of the provider.
        """
        return 'baichuan'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        if model_type == ModelType.TEXT_GENERATION:
            return [
                {
                    'id': 'baichuan2-53b',
                    'name': 'Baichuan2-53B',
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class for the given model type.

        :param model_type: the model type to look up
        :return: the model class
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = BaichuanModel
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        Get the parameter rules for the given model.

        :param model_name: the model name
        :param model_type: the model type
        :return: the parameter rules
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2),
            top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](enabled=False),
        )
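
    # For reference (illustrative comment, not in the original file): kwargs that
    # satisfy these rules look like ModelKwargs(temperature=0.3, top_p=0.85),
    # i.e. temperature in [0, 1] and top_p in [0, 0.99]; presence_penalty,
    # frequency_penalty and max_tokens are disabled above.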

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.
        """
        if 'api_key' not in credentials:
            raise CredentialsValidateFailedError('Baichuan api_key must be provided.')

        if 'secret_key' not in credentials:
            raise CredentialsValidateFailedError('Baichuan secret_key must be provided.')

        try:
            credential_kwargs = {
                'api_key': credentials['api_key'],
                'secret_key': credentials['secret_key'],
            }

            llm = BaichuanChatLLM(
                temperature=0,
                **credential_kwargs
            )

            # Issues a real one-message request; any auth failure surfaces here.
            llm([HumanMessage(content='ping')])
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
        credentials['secret_key'] = encrypter.encrypt_token(tenant_id, credentials['secret_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        if self.provider.provider_type == ProviderType.CUSTOM.value:
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                credentials = {
                    'api_key': None,
                    'secret_key': None,
                }

            if credentials['api_key']:
                credentials['api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['api_key']
                )

                if obfuscated:
                    credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])

            if credentials['secret_key']:
                credentials['secret_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['secret_key']
                )

                if obfuscated:
                    credentials['secret_key'] = encrypter.obfuscated_token(credentials['secret_key'])

            return credentials
        else:
            return {}

    def should_deduct_quota(self):
        return True

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        Check whether the model credentials are valid; models share the
        provider-level credentials, so there is nothing to validate here.

        :param model_name: the model name
        :param model_type: the model type
        :param credentials: the credentials to check
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        Encrypt model credentials for persistence.

        :param tenant_id: the tenant id
        :param model_name: the model name
        :param model_type: the model type
        :param credentials: the credentials to encrypt
        :return: the encrypted credentials
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        Get credentials for LLM use; delegates to the provider-level credentials.

        :param model_name: the model name
        :param model_type: the model type
        :param obfuscated: whether to obfuscate the returned tokens
        :return: the credentials
        """
        return self.get_provider_credentials(obfuscated)
| "spark", | "spark", | ||||
| "wenxin", | "wenxin", | ||||
| "zhipuai", | "zhipuai", | ||||
| "baichuan", | |||||
| "chatglm", | "chatglm", | ||||
| "replicate", | "replicate", | ||||
| "huggingface_hub", | "huggingface_hub", | ||||
| "xinference", | "xinference", | ||||
| "openllm", | "openllm", | ||||
| "localai" | "localai" | ||||
| ] | |||||
| ] |

{
    "support_provider_types": [
        "custom"
    ],
    "system_config": null,
    "model_flexibility": "fixed",
    "price_config": {
        "baichuan2-53b": {
            "prompt": "0.01",
            "completion": "0.01",
            "unit": "0.001",
            "currency": "RMB"
        }
    }
}
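
Assuming the usual cost formula for these price configs, cost = tokens × unit × price (an assumption here, since the billing code is not part of this diff), a 1,000-token prompt at the rates above works out as follows:

from decimal import Decimal

# Illustrative only: 1000 tokens * 0.001 (unit) * 0.01 (prompt price) = 0.01 RMB
prompt_tokens = 1000
cost_rmb = prompt_tokens * Decimal("0.001") * Decimal("0.01")
print(cost_rmb)  # 0.01000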
| """Wrapper around Baichuan APIs.""" | |||||
| from __future__ import annotations | |||||
| import hashlib | |||||
| import json | |||||
| import logging | |||||
| import time | |||||
| from typing import ( | |||||
| Any, | |||||
| Dict, | |||||
| List, | |||||
| Optional, Iterator, | |||||
| ) | |||||
| import requests | |||||
| from langchain.chat_models.base import BaseChatModel | |||||
| from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage | |||||
| from langchain.schema.messages import AIMessageChunk | |||||
| from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration | |||||
| from pydantic import Extra, root_validator, BaseModel | |||||
| from langchain.callbacks.manager import ( | |||||
| CallbackManagerForLLMRun, | |||||
| ) | |||||
| from langchain.utils import get_from_dict_or_env | |||||
| logger = logging.getLogger(__name__) | |||||
| class BaichuanModelAPI(BaseModel): | |||||
| api_key: str | |||||
| secret_key: str | |||||
| base_url: str = "https://api.baichuan-ai.com/v1" | |||||
| class Config: | |||||
| """Configuration for this pydantic object.""" | |||||
| extra = Extra.forbid | |||||

    def do_request(self, model: str, messages: list[dict], parameters: dict, **kwargs: Any):
        stream = 'stream' in kwargs and kwargs['stream']
        url = self.base_url + ("/stream/chat" if stream else "/chat")

        data = {
            "model": model,
            "messages": messages,
            "parameters": parameters
        }
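
        # Request signing: the signature is MD5(secret_key + JSON body + timestamp),
        # sent alongside the bearer api_key in the headers built below.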
        json_data = json.dumps(data)
        time_stamp = int(time.time())
        signature = self._calculate_md5(self.secret_key + json_data + str(time_stamp))

        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + self.api_key,
            "X-BC-Request-Id": "your requestId",  # placeholder request id from the API sample
            "X-BC-Timestamp": str(time_stamp),
            "X-BC-Signature": signature,
            "X-BC-Sign-Algo": "MD5",
        }

        response = requests.post(url, data=json_data, headers=headers, stream=stream, timeout=(5, 60))

        if not response.ok:
            raise ValueError(f"HTTP {response.status_code} error: {response.text}")

        if not stream:
            json_response = response.json()
            if json_response['code'] != 0:
                raise ValueError(
                    f"API {json_response['code']}"
                    f" error: {json_response['msg']}"
                )
            return json_response
        else:
            # Streaming responses are consumed line by line by the caller.
            return response

    def _calculate_md5(self, input_string):
        md5 = hashlib.md5()
        md5.update(input_string.encode('utf-8'))
        return md5.hexdigest()
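
# For illustration (not part of the original diff), a direct non-streaming call
# through the helper above would look like:
#
#     api = BaichuanModelAPI(api_key='...', secret_key='...')
#     response = api.do_request(
#         model='Baichuan2-53B',
#         messages=[{'role': 'user', 'content': 'ping'}],
#         parameters={'temperature': 0.3, 'top_p': 0.85},
#     )
#     response['data']['messages'][0]['content']  # the assistant's reply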


class BaichuanChatLLM(BaseChatModel):
    """Wrapper around Baichuan large language models.

    To use, pass the api_key and secret_key as named parameters to the
    constructor, or set the BAICHUAN_API_KEY and BAICHUAN_SECRET_KEY
    environment variables.

    Example:
        .. code-block:: python

            from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM

            model = BaichuanChatLLM(model="<model_name>", api_key="my-api-key", secret_key="my-secret-key")
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}

    @property
    def lc_serializable(self) -> bool:
        return True

    client: Any = None  #: :meta private:
    model: str = "Baichuan2-53B"
    """Model name to use."""
    temperature: float = 0.3
    """A non-negative float that tunes the degree of randomness in generation."""
    top_p: float = 0.85
    """Total probability mass of tokens to consider at each step."""
    streaming: bool = False
    """Whether to stream the response or return it all at once."""
    api_key: Optional[str] = None
    secret_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the api key and secret key exist in the environment."""
        values["api_key"] = get_from_dict_or_env(
            values, "api_key", "BAICHUAN_API_KEY"
        )
        values["secret_key"] = get_from_dict_or_env(
            values, "secret_key", "BAICHUAN_SECRET_KEY"
        )

        values['client'] = BaichuanModelAPI(
            api_key=values['api_key'],
            secret_key=values['secret_key']
        )
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Baichuan API."""
        return {
            "model": self.model,
            "parameters": {
                "temperature": self.temperature,
                "top_p": self.top_p
            }
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return self._default_params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "baichuan"

    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            # System messages are deliberately mapped to the 'user' role here;
            # the request only carries user/assistant turns.
            message_dict = {"role": "user", "content": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")

        return message_dict

    def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage:
        role = _dict["role"]
        if role == "user":
            return HumanMessage(content=_dict["content"])
        elif role == "assistant":
            return AIMessage(content=_dict["content"])
        elif role == "system":
            return SystemMessage(content=_dict["content"])
        else:
            return ChatMessage(content=_dict["content"], role=role)

    def _create_message_dicts(
            self, messages: List[BaseMessage]
    ) -> List[Dict[str, Any]]:
        dict_messages = []
        for m in messages:
            message = self._convert_message_to_dict(m)
            if dict_messages:
                previous_message = dict_messages[-1]
                # Merge consecutive messages that share a role into a single
                # entry, joined by a newline.
                if previous_message['role'] == message['role']:
                    dict_messages[-1]['content'] += f"\n{message['content']}"
                else:
                    dict_messages.append(message)
            else:
                dict_messages.append(message)
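
        # Illustration (not in the original file): a SystemMessage followed by a
        # HumanMessage both convert to role "user" above, so they collapse into
        # one entry: [{"role": "user", "content": "You are helpful.\nHi"}]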
        return dict_messages

    def _generate(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            generation: Optional[ChatGenerationChunk] = None
            llm_output: Optional[Dict] = None
            for chunk in self._stream(
                    messages=messages, stop=stop, run_manager=run_manager, **kwargs
            ):
                # Chunks concatenate via ChatGenerationChunk.__add__.
                if generation is None:
                    generation = chunk
                else:
                    generation += chunk

                if chunk.generation_info is not None \
                        and 'token_usage' in chunk.generation_info:
                    llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}

            assert generation is not None
            return ChatResult(generations=[generation], llm_output=llm_output)
        else:
            message_dicts = self._create_message_dicts(messages)
            params = self._default_params
            params["messages"] = message_dicts
            params.update(kwargs)
            response = self.client.do_request(**params)
            return self._create_chat_result(response)

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts = self._create_message_dicts(messages)
        params = self._default_params
        params["messages"] = message_dicts
        params.update(kwargs)

        for event in self.client.do_request(stream=True, **params).iter_lines():
            if event:
                event = event.decode("utf-8")
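                # Each non-empty line is a JSON payload shaped like (illustrative):
                #   {"code": 0, "data": {"messages": [{"role": "assistant", "content": "..."}]},
                #    "usage": {"prompt_tokens": 9, "answer_tokens": 1, "total_tokens": 10}}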
                meta = json.loads(event)
                if meta['code'] != 0:
                    raise ValueError(
                        f"API {meta['code']}"
                        f" error: {meta['msg']}"
                    )

                content = meta['data']['messages'][0]['content']

                chunk_kwargs = {
                    'message': AIMessageChunk(content=content),
                }

                if 'usage' in meta:
                    token_usage = meta['usage']
                    overall_token_usage = {
                        'prompt_tokens': token_usage.get('prompt_tokens', 0),
                        'completion_tokens': token_usage.get('answer_tokens', 0),
                        'total_tokens': token_usage.get('total_tokens', 0)
                    }
                    chunk_kwargs['generation_info'] = {'token_usage': overall_token_usage}

                yield ChatGenerationChunk(**chunk_kwargs)

                if run_manager:
                    run_manager.on_llm_new_token(content)

    def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
        data = response["data"]
        generations = []
        for res in data["messages"]:
            message = self._convert_dict_to_message(res)
            gen = ChatGeneration(
                message=message
            )
            generations.append(gen)

        # Baichuan reports completion usage as 'answer_tokens'; normalize it to
        # the LangChain-style 'completion_tokens' key. Guard against a missing
        # usage block rather than dereferencing None.
        usage = response.get("usage") or {}
        token_usage = {
            'prompt_tokens': usage.get('prompt_tokens', 0),
            'completion_tokens': usage.get('answer_tokens', 0),
            'total_tokens': usage.get('total_tokens', 0)
        }

        llm_output = {"token_usage": token_usage, "model_name": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in the messages.

        Useful for checking if an input will fit in a model's context window.

        Args:
            messages: The message inputs to tokenize.

        Returns:
            The sum of the number of tokens across the messages.
        """
        return sum(self.get_num_tokens(m.content) for m in messages)

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        token_usage: dict = {}
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            # Keep the last non-null usage reported.
            token_usage = output["token_usage"]

        return {"token_usage": token_usage, "model_name": self.model}

# ZhipuAI Credentials
ZHIPUAI_API_KEY=

# Baichuan Credentials
BAICHUAN_API_KEY=
BAICHUAN_SECRET_KEY=

# ChatGLM Credentials
CHATGLM_API_BASE=

import json
import os
from unittest.mock import patch

from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.providers.baichuan_provider import BaichuanProvider
from models.provider import Provider, ProviderType


def get_mock_provider(valid_api_key, valid_secret_key):
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='baichuan',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps({
            'api_key': valid_api_key,
            'secret_key': valid_secret_key,
        }),
        is_valid=True,
    )


def get_mock_model(model_name: str, streaming: bool = False):
    model_kwargs = ModelKwargs(
        temperature=0.01,
    )
    valid_api_key = os.environ['BAICHUAN_API_KEY']
    valid_secret_key = os.environ['BAICHUAN_SECRET_KEY']
    model_provider = BaichuanProvider(provider=get_mock_provider(valid_api_key, valid_secret_key))
    return BaichuanModel(
        model_provider=model_provider,
        name=model_name,
        model_kwargs=model_kwargs,
        streaming=streaming
    )


def decrypt_side_effect(tenant_id, encrypted_api_key):
    return encrypted_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt):
    model = get_mock_model('baichuan2-53b')
    rst = model.get_num_tokens([
        PromptMessage(type=MessageType.SYSTEM, content='You are a kind assistant.'),
        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
    ])
    assert rst > 0


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('baichuan2-53b')
    messages = [
        PromptMessage(type=MessageType.HUMAN, content='Are you human? You MUST answer only `y` or `n`.')
    ]
    rst = model.run(
        messages,
    )
    assert len(rst.content) > 0


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_stream_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('baichuan2-53b', streaming=True)
    messages = [
        PromptMessage(type=MessageType.HUMAN, content='Are you human? You MUST answer only `y` or `n`.')
    ]
    rst = model.run(
        messages
    )
    assert len(rst.content) > 0

import pytest
from unittest.mock import patch
import json

from langchain.schema import ChatResult, ChatGeneration, AIMessage

from core.model_providers.providers.baichuan_provider import BaichuanProvider
from core.model_providers.providers.base import CredentialsValidateFailedError
from models.provider import ProviderType, Provider

PROVIDER_NAME = 'baichuan'
MODEL_PROVIDER_CLASS = BaichuanProvider
VALIDATE_CREDENTIAL = {
    'api_key': 'valid_key',
    'secret_key': 'valid_key',
}


def encrypt_side_effect(tenant_id, encrypt_key):
    return f'encrypted_{encrypt_key}'


def decrypt_side_effect(tenant_id, encrypted_key):
    return encrypted_key.replace('encrypted_', '')


def test_is_provider_credentials_valid_or_raise_valid(mocker):
    mocker.patch('core.third_party.langchain.llms.baichuan_llm.BaichuanChatLLM._generate',
                 return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))
    MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)


def test_is_provider_credentials_valid_or_raise_invalid():
    # raise CredentialsValidateFailedError if the api_key is missing from the credentials
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})

    credential = VALIDATE_CREDENTIAL.copy()
    credential['api_key'] = 'invalid_key'
    credential['secret_key'] = 'invalid_key'

    # raise CredentialsValidateFailedError if the api_key is invalid
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)


@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
    result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
    assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
    assert result['secret_key'] == f'encrypted_{VALIDATE_CREDENTIAL["secret_key"]}'


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
    encrypted_credential = VALIDATE_CREDENTIAL.copy()
    encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
    encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']

    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps(encrypted_credential),
        is_valid=True,
    )
    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_provider_credentials()
    assert result['api_key'] == 'valid_key'
    assert result['secret_key'] == 'valid_key'


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
    encrypted_credential = VALIDATE_CREDENTIAL.copy()
    encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
    encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']

    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps(encrypted_credential),
        is_valid=True,
    )
    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_provider_credentials(obfuscated=True)

    middle_token = result['api_key'][6:-2]
    secret_key_middle_token = result['secret_key'][6:-2]

    assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
    assert len(secret_key_middle_token) == max(len(VALIDATE_CREDENTIAL['secret_key']) - 8, 0)
    assert all(char == '*' for char in middle_token)
    assert all(char == '*' for char in secret_key_middle_token)