Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
@@ -51,6 +51,9 @@ class ModelProviderFactory:
        elif provider_name == 'chatglm':
            from core.model_providers.providers.chatglm_provider import ChatGLMProvider
            return ChatGLMProvider
        elif provider_name == 'baichuan':
            from core.model_providers.providers.baichuan_provider import BaichuanProvider
            return BaichuanProvider
        elif provider_name == 'azure_openai':
            from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
            return AzureOpenAIProvider
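With this branch in place, resolving the `baichuan` provider name through `ModelProviderFactory` yields the new `BaichuanProvider` class. A rough illustration, assuming the elif chain above sits inside the usual name-to-class lookup method (the method name is not visible in this hunk and is only illustrative):

# Illustrative only: `get_model_provider_class` is an assumed name for the
# factory method that wraps the elif chain shown above.
provider_class = ModelProviderFactory.get_model_provider_class('baichuan')
assert provider_class is BaichuanProvider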
@@ -0,0 +1,61 @@
from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM


class BaichuanModel(BaseLLM):
    model_mode: ModelMode = ModelMode.CHAT

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return BaichuanChatLLM(
            streaming=self.streaming,
            callbacks=self.callbacks,
            **self.credentials,
            **provider_model_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens_from_messages(prompts), 0)

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for k, v in provider_model_kwargs.items():
            if hasattr(self.client, k):
                setattr(self.client, k, v)

    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(f"Baichuan: {str(ex)}")

    @property
    def support_streaming(self):
        return True
@@ -0,0 +1,167 @@
import json
from json import JSONDecodeError
from typing import Type

from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
from models.provider import ProviderType


class BaichuanProvider(BaseModelProvider):

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'baichuan'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        if model_type == ModelType.TEXT_GENERATION:
            return [
                {
                    'id': 'baichuan2-53b',
                    'name': 'Baichuan2-53B',
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = BaichuanModel
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2),
            top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](enabled=False),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.
        """
        if 'api_key' not in credentials:
            raise CredentialsValidateFailedError('Baichuan api_key must be provided.')

        if 'secret_key' not in credentials:
            raise CredentialsValidateFailedError('Baichuan secret_key must be provided.')

        try:
            credential_kwargs = {
                'api_key': credentials['api_key'],
                'secret_key': credentials['secret_key'],
            }

            llm = BaichuanChatLLM(
                temperature=0,
                **credential_kwargs
            )

            llm([HumanMessage(content='ping')])
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
        credentials['secret_key'] = encrypter.encrypt_token(tenant_id, credentials['secret_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        if self.provider.provider_type == ProviderType.CUSTOM.value:
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                credentials = {
                    'api_key': None,
                    'secret_key': None,
                }

            if credentials['api_key']:
                credentials['api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['api_key']
                )

                if obfuscated:
                    credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])

            if credentials['secret_key']:
                credentials['secret_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['secret_key']
                )

                if obfuscated:
                    credentials['secret_key'] = encrypter.obfuscated_token(credentials['secret_key'])

            return credentials
        else:
            return {}

    def should_deduct_quota(self):
        return True

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        return self.get_provider_credentials(obfuscated)
@@ -7,10 +7,11 @@
    "spark",
    "wenxin",
    "zhipuai",
    "baichuan",
    "chatglm",
    "replicate",
    "huggingface_hub",
    "xinference",
    "openllm",
    "localai"
]
@@ -0,0 +1,15 @@
{
    "support_provider_types": [
        "custom"
    ],
    "system_config": null,
    "model_flexibility": "fixed",
    "price_config": {
        "baichuan2-53b": {
            "prompt": "0.01",
            "completion": "0.01",
            "unit": "0.001",
            "currency": "RMB"
        }
    }
}
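For reference, the `unit` field scales raw token counts before the per-unit price applies. Assuming the usual `tokens * unit * price` convention (an assumption about how this config is consumed; the actual cost calculation is not part of this diff), 1,000 prompt tokens on baichuan2-53b would cost 1000 * 0.001 * 0.01 = 0.01 RMB. A minimal sketch:

from decimal import Decimal

# Hypothetical helper mirroring the assumed tokens * unit * price formula;
# values default to the baichuan2-53b price_config above.
def estimate_cost(tokens: int, price: str = "0.01", unit: str = "0.001") -> Decimal:
    return Decimal(tokens) * Decimal(unit) * Decimal(price)

print(estimate_cost(1000))  # -> 0.01 RMB (printed as 0.01000) for 1,000 prompt tokens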
@@ -0,0 +1,315 @@
"""Wrapper around Baichuan APIs."""
from __future__ import annotations

import hashlib
import json
import logging
import time
from typing import (
    Any,
    Dict,
    List,
    Optional, Iterator,
)

import requests
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration
from pydantic import Extra, root_validator, BaseModel

from langchain.callbacks.manager import (
    CallbackManagerForLLMRun,
)
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class BaichuanModelAPI(BaseModel):
    api_key: str
    secret_key: str

    base_url: str = "https://api.baichuan-ai.com/v1"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def do_request(self, model: str, messages: list[dict], parameters: dict, **kwargs: Any):
        stream = 'stream' in kwargs and kwargs['stream']
        url = self.base_url + ("/stream/chat" if stream else "/chat")

        data = {
            "model": model,
            "messages": messages,
            "parameters": parameters
        }

        json_data = json.dumps(data)
        time_stamp = int(time.time())
        # Sign the request: MD5 over secret_key + JSON body + timestamp, sent as X-BC-Signature.
        signature = self._calculate_md5(self.secret_key + json_data + str(time_stamp))

        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + self.api_key,
            "X-BC-Request-Id": "your requestId",
            "X-BC-Timestamp": str(time_stamp),
            "X-BC-Signature": signature,
            "X-BC-Sign-Algo": "MD5",
        }

        response = requests.post(url, data=json_data, headers=headers, stream=stream, timeout=(5, 60))

        if not response.ok:
            raise ValueError(f"HTTP {response.status_code} error: {response.text}")

        if not stream:
            json_response = response.json()
            if json_response['code'] != 0:
                raise ValueError(
                    f"API {json_response['code']}"
                    f" error: {json_response['msg']}"
                )
            return json_response
        else:
            return response

    def _calculate_md5(self, input_string):
        md5 = hashlib.md5()
        md5.update(input_string.encode('utf-8'))
        encrypted = md5.hexdigest()
        return encrypted


class BaichuanChatLLM(BaseChatModel):
    """Wrapper around Baichuan large language models.

    To use, you should pass the api_key and secret_key as named parameters to the constructor.

    Example:
        .. code-block:: python

            from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM

            model = BaichuanChatLLM(model="<model_name>", api_key="my-api-key", secret_key="my-secret-key")
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}

    @property
    def lc_serializable(self) -> bool:
        return True

    client: Any = None  #: :meta private:
    model: str = "Baichuan2-53B"
    """Model name to use."""
    temperature: float = 0.3
    """A non-negative float that tunes the degree of randomness in generation."""
    top_p: float = 0.85
    """Total probability mass of tokens to consider at each step."""
    streaming: bool = False
    """Whether to stream the response or return it all at once."""
    api_key: Optional[str] = None
    secret_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the api key and secret key exist in the environment."""
        values["api_key"] = get_from_dict_or_env(
            values, "api_key", "BAICHUAN_API_KEY"
        )
        values["secret_key"] = get_from_dict_or_env(
            values, "secret_key", "BAICHUAN_SECRET_KEY"
        )

        values['client'] = BaichuanModelAPI(
            api_key=values['api_key'],
            secret_key=values['secret_key']
        )
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Baichuan API."""
        return {
            "model": self.model,
            "parameters": {
                "temperature": self.temperature,
                "top_p": self.top_p
            }
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return self._default_params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "baichuan"

    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "user", "content": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")

        return message_dict

    def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage:
        role = _dict["role"]
        if role == "user":
            return HumanMessage(content=_dict["content"])
        elif role == "assistant":
            return AIMessage(content=_dict["content"])
        elif role == "system":
            return SystemMessage(content=_dict["content"])
        else:
            return ChatMessage(content=_dict["content"], role=role)

    def _create_message_dicts(
        self, messages: List[BaseMessage]
    ) -> List[Dict[str, Any]]:
        # Merge consecutive messages that share a role, since the API expects
        # alternating roles.
        dict_messages = []
        for m in messages:
            message = self._convert_message_to_dict(m)
            if dict_messages:
                previous_message = dict_messages[-1]
                if previous_message['role'] == message['role']:
                    dict_messages[-1]['content'] += f"\n{message['content']}"
                else:
                    dict_messages.append(message)
            else:
                dict_messages.append(message)

        return dict_messages

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            generation: Optional[ChatGenerationChunk] = None
            llm_output: Optional[Dict] = None
            for chunk in self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            ):
                if generation is None:
                    generation = chunk
                else:
                    generation += chunk

                if chunk.generation_info is not None \
                        and 'token_usage' in chunk.generation_info:
                    llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}

            assert generation is not None
            return ChatResult(generations=[generation], llm_output=llm_output)
        else:
            message_dicts = self._create_message_dicts(messages)
            params = self._default_params
            params["messages"] = message_dicts
            params.update(kwargs)

            response = self.client.do_request(**params)
            return self._create_chat_result(response)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts = self._create_message_dicts(messages)
        params = self._default_params
        params["messages"] = message_dicts
        params.update(kwargs)

        for event in self.client.do_request(stream=True, **params).iter_lines():
            if event:
                event = event.decode("utf-8")
                meta = json.loads(event)

                if meta['code'] != 0:
                    raise ValueError(
                        f"API {meta['code']}"
                        f" error: {meta['msg']}"
                    )

                content = meta['data']['messages'][0]['content']

                chunk_kwargs = {
                    'message': AIMessageChunk(content=content),
                }

                if 'usage' in meta:
                    token_usage = meta['usage']
                    overall_token_usage = {
                        'prompt_tokens': token_usage.get('prompt_tokens', 0),
                        'completion_tokens': token_usage.get('answer_tokens', 0),
                        'total_tokens': token_usage.get('total_tokens', 0)
                    }
                    chunk_kwargs['generation_info'] = {'token_usage': overall_token_usage}

                yield ChatGenerationChunk(**chunk_kwargs)

                if run_manager:
                    run_manager.on_llm_new_token(content)

    def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
        data = response["data"]
        generations = []
        for res in data["messages"]:
            message = self._convert_dict_to_message(res)
            gen = ChatGeneration(
                message=message
            )
            generations.append(gen)

        usage = response.get("usage")
        token_usage = {
            'prompt_tokens': usage.get('prompt_tokens', 0),
            'completion_tokens': usage.get('answer_tokens', 0),
            'total_tokens': usage.get('total_tokens', 0)
        }

        llm_output = {"token_usage": token_usage, "model_name": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in the messages.

        Useful for checking if an input will fit in a model's context window.

        Args:
            messages: The message inputs to tokenize.

        Returns:
            The sum of the number of tokens across the messages.
        """
        return sum([self.get_num_tokens(m.content) for m in messages])

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        token_usage: dict = {}
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]

        return {"token_usage": token_usage, "model_name": self.model}
@@ -35,6 +35,10 @@ WENXIN_SECRET_KEY=
# ZhipuAI Credentials
ZHIPUAI_API_KEY=

# Baichuan Credentials
BAICHUAN_API_KEY=
BAICHUAN_SECRET_KEY=

# ChatGLM Credentials
CHATGLM_API_BASE=
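When these variables are set, `BaichuanChatLLM` can be constructed without passing keys explicitly: `validate_environment` falls back to `BAICHUAN_API_KEY` and `BAICHUAN_SECRET_KEY` via `get_from_dict_or_env`. A minimal sketch (the prompt text is illustrative):

from langchain.schema import HumanMessage

from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM

# Credentials are resolved from BAICHUAN_API_KEY / BAICHUAN_SECRET_KEY
# because no api_key/secret_key arguments are supplied here.
llm = BaichuanChatLLM(temperature=0.3)
result = llm([HumanMessage(content="ping")])
print(result.content)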
@@ -0,0 +1,81 @@
import json
import os
from unittest.mock import patch

from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.providers.baichuan_provider import BaichuanProvider
from models.provider import Provider, ProviderType


def get_mock_provider(valid_api_key, valid_secret_key):
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='baichuan',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps({
            'api_key': valid_api_key,
            'secret_key': valid_secret_key,
        }),
        is_valid=True,
    )


def get_mock_model(model_name: str, streaming: bool = False):
    model_kwargs = ModelKwargs(
        temperature=0.01,
    )
    valid_api_key = os.environ['BAICHUAN_API_KEY']
    valid_secret_key = os.environ['BAICHUAN_SECRET_KEY']
    model_provider = BaichuanProvider(provider=get_mock_provider(valid_api_key, valid_secret_key))
    return BaichuanModel(
        model_provider=model_provider,
        name=model_name,
        model_kwargs=model_kwargs,
        streaming=streaming
    )


def decrypt_side_effect(tenant_id, encrypted_api_key):
    return encrypted_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt):
    model = get_mock_model('baichuan2-53b')
    rst = model.get_num_tokens([
        PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
    ])
    assert rst > 0


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('baichuan2-53b')
    messages = [
        PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
    ]
    rst = model.run(
        messages,
    )
    assert len(rst.content) > 0


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_stream_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('baichuan2-53b', streaming=True)
    messages = [
        PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
    ]
    rst = model.run(
        messages
    )
    assert len(rst.content) > 0
@@ -0,0 +1,97 @@
import pytest
from unittest.mock import patch
import json

from langchain.schema import ChatResult, ChatGeneration, AIMessage

from core.model_providers.providers.baichuan_provider import BaichuanProvider
from core.model_providers.providers.base import CredentialsValidateFailedError
from models.provider import ProviderType, Provider

PROVIDER_NAME = 'baichuan'
MODEL_PROVIDER_CLASS = BaichuanProvider
VALIDATE_CREDENTIAL = {
    'api_key': 'valid_key',
    'secret_key': 'valid_key',
}


def encrypt_side_effect(tenant_id, encrypt_key):
    return f'encrypted_{encrypt_key}'


def decrypt_side_effect(tenant_id, encrypted_key):
    return encrypted_key.replace('encrypted_', '')


def test_is_provider_credentials_valid_or_raise_valid(mocker):
    mocker.patch('core.third_party.langchain.llms.baichuan_llm.BaichuanChatLLM._generate',
                 return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))

    MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)


def test_is_provider_credentials_valid_or_raise_invalid():
    # raise CredentialsValidateFailedError if api_key is not in credentials
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})

    credential = VALIDATE_CREDENTIAL.copy()
    credential['api_key'] = 'invalid_key'
    credential['secret_key'] = 'invalid_key'

    # raise CredentialsValidateFailedError if api_key is invalid
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)


@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
    result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
    assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
    assert result['secret_key'] == f'encrypted_{VALIDATE_CREDENTIAL["secret_key"]}'


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
    encrypted_credential = VALIDATE_CREDENTIAL.copy()
    encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
    encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']

    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps(encrypted_credential),
        is_valid=True,
    )
    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_provider_credentials()
    assert result['api_key'] == 'valid_key'
    assert result['secret_key'] == 'valid_key'


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
    encrypted_credential = VALIDATE_CREDENTIAL.copy()
    encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
    encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']

    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps(encrypted_credential),
        is_valid=True,
    )
    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_provider_credentials(obfuscated=True)

    middle_token = result['api_key'][6:-2]
    secret_key_middle_token = result['secret_key'][6:-2]
    assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
    assert len(secret_key_middle_token) == max(len(VALIDATE_CREDENTIAL['secret_key']) - 8, 0)
    assert all(char == '*' for char in middle_token)
    assert all(char == '*' for char in secret_key_middle_token)