                'enabled': v.enabled,
                'min': v.min,
                'max': v.max,
                'default': v.default,
                'precision': v.precision
            }
            for k, v in vars(parameter_rules).items()
        }
    @login_required
    @account_initialization_required
    def get(self, provider_name: str):
        parser = reqparse.RequestParser()
        parser.add_argument('token', type=str, required=False, nullable=True, location='args')
        args = parser.parse_args()

        provider_service = ProviderService()
        result = provider_service.free_quota_qualification_verify(
            tenant_id=current_user.current_tenant_id,
            provider_name=provider_name,
            token=args['token']
        )

        return result
        self.conversation_message_task.append_message_text(response.generations[0][0].text)
        self.llm_message.completion = response.generations[0][0].text

        if response.llm_output and 'token_usage' in response.llm_output:
            if 'prompt_tokens' in response.llm_output['token_usage']:
                self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']

            if 'completion_tokens' in response.llm_output['token_usage']:
                self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
            else:
                self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                    [PromptMessage(content=self.llm_message.completion)])
        else:
            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                [PromptMessage(content=self.llm_message.completion)])

        self.conversation_message_task.save_message(self.llm_message)
import logging
from typing import List, Dict, Optional, Any

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import BaseModel

from core.model_providers.error import LLMBadRequestError
        result = self._check_moderation(text)
        if not result:
            raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response)

        return {self.output_key: text}


class SensitiveWordAvoidanceError(Exception):
    def __init__(self, message):
        super().__init__(message)
        self.message = message
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
            app_model_config=app_model_config
        )

        try:
            # parse sensitive_word_avoidance_chain
            chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
            sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(
                final_model_instance, [chain_callback])
            if sensitive_word_avoidance_chain:
                try:
                    query = sensitive_word_avoidance_chain.run(query)
                except SensitiveWordAvoidanceError as ex:
                    cls.run_final_llm(
                        model_instance=final_model_instance,
                        mode=app.mode,
                        app_model_config=app_model_config,
                        query=query,
                        inputs=inputs,
                        agent_execute_result=None,
                        conversation_message_task=conversation_message_task,
                        memory=memory,
                        fake_response=ex.message
                    )
                    return

            # get agent executor
            agent_executor = orchestrator_rule_parser.to_agent_executor(
                conversation_message_task=conversation_message_task,
                memory=memory,
                rest_tokens=rest_tokens_for_context_and_memory,
                chain_callback=chain_callback,
                retriever_from=retriever_from
            )

            # run agent executor
            agent_execute_result = None
            if agent_executor:
                should_use_agent = agent_executor.should_use_agent(query)
                if should_use_agent:
                    agent_execute_result = agent_executor.run(query)

            # When no extra pre prompt is specified,
            # the output of the agent can be used directly as the main output content without calling LLM again
            fake_response = None
            if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
                    and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
                                                              PlanningStrategy.REACT_ROUTER]:
                fake_response = agent_execute_result.output

            # run the final llm
            cls.run_final_llm(
                model_instance=final_model_instance,
                mode=app.mode,
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                agent_execute_result=agent_execute_result,
                conversation_message_task=conversation_message_task,
                memory=memory,
                fake_response=fake_response
            )
        except ConversationTaskStoppedException:
            return
                      inputs: dict,
                      agent_execute_result: Optional[AgentExecuteResult],
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
                      fake_response: Optional[str]):
        # get llm prompt
        prompt_messages, stop_words = model_instance.get_prompt(
            mode=mode,
import logging

import openai

from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.providers.hosted import hosted_config, hosted_model_providers
from models.provider import ProviderType


def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
    if hosted_config.moderation.enabled is True and hosted_model_providers.openai:
        if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
                and model_provider.provider_name in hosted_config.moderation.providers:
            # 2000 characters per chunk
            length = 2000
            text_chunks = [text[i:i + length] for i in range(0, len(text), length)]

            max_text_chunks = 32
            chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]

            for text_chunk in chunks:
                try:
                    moderation_result = openai.Moderation.create(input=text_chunk,
                                                                 api_key=hosted_model_providers.openai.api_key)
                except Exception as ex:
                    logging.exception(ex)
                    raise LLMBadRequestError('Rate limit exceeded, please try again later.')

                for result in moderation_result.results:
                    if result['flagged'] is True:
                        return False

    return True
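# Illustrative sketch (not part of the diff): the helper above first splits the input into
# 2000-character chunks, then submits them in batches of at most `max_text_chunks` inputs per
# moderation request. The sample text below is arbitrary.
sample_text = "x" * 9000
length = 2000
max_text_chunks = 32

sample_chunks = [sample_text[i:i + length] for i in range(0, len(sample_text), length)]
sample_batches = [sample_chunks[i:i + max_text_chunks] for i in range(0, len(sample_chunks), max_text_chunks)]
print(len(sample_chunks), [len(batch) for batch in sample_batches])  # 5 chunks, sent as one batch of 5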
        elif provider_name == 'wenxin':
            from core.model_providers.providers.wenxin_provider import WenxinProvider
            return WenxinProvider
        elif provider_name == 'zhipuai':
            from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
            return ZhipuAIProvider
        elif provider_name == 'chatglm':
            from core.model_providers.providers.chatglm_provider import ChatGLMProvider
            return ChatGLMProvider
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
from core.third_party.langchain.embeddings.zhipuai_embedding import ZhipuAIEmbeddings


class ZhipuAIEmbedding(BaseEmbedding):
    def __init__(self, model_provider: BaseModelProvider, name: str):
        credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = ZhipuAIEmbeddings(
            model=name,
            **credentials,
        )

        super().__init__(model_provider, client, name)

    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(f"ZhipuAI embedding: {str(ex)}")
    max: Optional[T] = None
    default: Optional[T] = None
    alias: Optional[str] = None
    precision: Optional[int] = None


class ModelKwargsRules(BaseModel):
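# Hypothetical helper (not part of this diff) showing how the new `precision` field could be
# applied when normalising user-supplied kwargs: float rules keep `precision` decimal places,
# integer-like rules use precision=0, as in the provider rule updates further below.
from typing import Optional, Union

def apply_precision(value: Union[int, float], precision: Optional[int]) -> Union[int, float]:
    if precision is None:
        return value
    if precision == 0:
        return int(round(value))
    return round(value, precision)

print(apply_precision(0.12345, 2))  # 0.12
print(apply_precision(256.7, 0))    # 257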
from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM


class ZhipuAIModel(BaseLLM):
    model_mode: ModelMode = ModelMode.CHAT

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return ZhipuAIChatLLM(
            streaming=self.streaming,
            callbacks=self.callbacks,
            **self.credentials,
            **provider_model_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens_from_messages(prompts), 0)

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for k, v in provider_model_kwargs.items():
            if hasattr(self.client, k):
                setattr(self.client, k, v)

    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(f"ZhipuAI: {str(ex)}")

    @property
    def support_streaming(self):
        return True
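# Minimal usage sketch mirroring the integration tests added later in this diff; the provider
# fixture and ZHIPUAI_API_KEY environment variable come from those tests, which also patch
# core.helper.encrypter.decrypt_token so the plain-text key round-trips unchanged.
import json
import os

from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
from models.provider import Provider, ProviderType

provider = ZhipuAIProvider(provider=Provider(
    id='provider_id',
    tenant_id='tenant_id',
    provider_name='zhipuai',
    provider_type=ProviderType.CUSTOM.value,
    encrypted_config=json.dumps({'api_key': os.environ['ZHIPUAI_API_KEY']}),
    is_valid=True,
))
model = ZhipuAIModel(
    model_provider=provider,
    name='chatglm_lite',
    model_kwargs=ModelKwargs(temperature=0.01),
)
print(model.run([PromptMessage(type=MessageType.HUMAN, content='ping')]).content)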
        # 2000 characters per chunk
        length = 2000
        text_chunks = [text[i:i + length] for i in range(0, len(text), length)]

        max_text_chunks = 32
        chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]

        for text_chunk in chunks:
            moderation_result = self._client.create(input=text_chunk,
                                                    api_key=credentials['openai_api_key'])

            for result in moderation_result.results:
                if result['flagged'] is True:
                    return False

        return True
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=1, default=1, precision=2),
            top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256, precision=0),
        )

    @classmethod
        model_credentials = self.get_model_credentials(model_name, model_type)

        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
            top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
            presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
            frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
            max_tokens=KwargRule[int](min=10, max=base_model_max_tokens.get(
                model_credentials['base_model_name'],
                4097
            ), default=16, precision=0),
        )

    @classmethod
        }

        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
            top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
        )

    @classmethod
hosted_model_providers = HostedModelProviders()


class HostedModerationConfig(BaseModel):
    enabled: bool = False
    providers: list[str] = []


class HostedConfig(BaseModel):
    moderation = HostedModerationConfig()


hosted_config = HostedConfig()


def init_app(app: Flask):
    if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
        langchain.verbose = True

            paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
            paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
        )

    if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"):
        hosted_config.moderation = HostedModerationConfig(
            enabled=app.config.get("HOSTED_MODERATION_ENABLED"),
            providers=app.config.get("HOSTED_MODERATION_PROVIDERS").split(',')
        )
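# Illustrative wiring (values are examples, not part of the diff): with the two settings below
# in the app config, init_app() populates hosted_config.moderation, which core.helper.moderation
# consults instead of reading Flask config at call time.
#   HOSTED_MODERATION_ENABLED=true
#   HOSTED_MODERATION_PROVIDERS=openai
example_moderation = HostedModerationConfig(enabled=True, providers='openai'.split(','))
assert example_moderation.enabled is True
assert example_moderation.providers == ['openai']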
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
            top_p=KwargRule[float](min=0.01, max=0.99, default=0.7, precision=2),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200, precision=0),
        )

    @classmethod
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=2, default=0.7, precision=2),
            top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
            max_tokens=KwargRule[int](min=10, max=4097, default=16, precision=0),
        )

    @classmethod
        }

        return ModelKwargsRules(
            temperature=KwargRule[float](min=0.01, max=1, default=0.9, precision=2),
            top_p=KwargRule[float](min=0, max=1, default=0.95, precision=2),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024, precision=0),
        )

    @classmethod
        }

        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
            top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
            presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
            frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16, precision=0),
        )

    @classmethod
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
            top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
            presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
            frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128, precision=0),
        )

    @classmethod
                    min=float(value.get('minimum')) if value.get('minimum') is not None else None,
                    max=float(value.get('maximum')) if value.get('maximum') is not None else None,
                    default=float(value.get('default')) if value.get('default') is not None else None,
                    precision=2
                )

                if key == 'temperature':
                    model_kwargs_rules.temperature = kwarg_rule

                    min=int(value.get('minimum')) if value.get('minimum') is not None else 1,
                    max=int(value.get('maximum')) if value.get('maximum') is not None else 8000,
                    default=int(value.get('default')) if value.get('default') is not None else 500,
                    precision=0
                )

        return model_kwargs_rules
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=1, default=0.5, precision=2),
            top_p=KwargRule[float](enabled=False),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](min=10, max=4096, default=2048, precision=0),
        )

    @classmethod
        return ModelKwargsRules(
            temperature=KwargRule[float](enabled=False),
            top_p=KwargRule[float](min=0, max=1, default=0.8, precision=2),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024, precision=0),
        )

    @classmethod
| """ | """ | ||||
| if model_name in ['ernie-bot', 'ernie-bot-turbo']: | if model_name in ['ernie-bot', 'ernie-bot-turbo']: | ||||
| return ModelKwargsRules( | return ModelKwargsRules( | ||||
| temperature=KwargRule[float](min=0.01, max=1, default=0.95), | |||||
| top_p=KwargRule[float](min=0.01, max=1, default=0.8), | |||||
| temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2), | |||||
| top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2), | |||||
| presence_penalty=KwargRule[float](enabled=False), | presence_penalty=KwargRule[float](enabled=False), | ||||
| frequency_penalty=KwargRule[float](enabled=False), | frequency_penalty=KwargRule[float](enabled=False), | ||||
| max_tokens=KwargRule[int](enabled=False), | max_tokens=KwargRule[int](enabled=False), |
        credentials = self.get_model_credentials(model_name, model_type)

        if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
            return ModelKwargsRules(
                temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
                top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
                presence_penalty=KwargRule[float](enabled=False),
                frequency_penalty=KwargRule[float](enabled=False),
                max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
            )
        elif credentials['model_format'] == "ggmlv3":
            return ModelKwargsRules(
                temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
                top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
                presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
                frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
                max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
            )
        else:
            return ModelKwargsRules(
                temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
                top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
                presence_penalty=KwargRule[float](enabled=False),
                frequency_penalty=KwargRule[float](enabled=False),
                max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
            )
import json
from json import JSONDecodeError
from typing import Type

from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
from models.provider import ProviderType, ProviderQuotaType


class ZhipuAIProvider(BaseModelProvider):

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'zhipuai'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        if model_type == ModelType.TEXT_GENERATION:
            return [
                {
                    'id': 'chatglm_pro',
                    'name': 'chatglm_pro',
                },
                {
                    'id': 'chatglm_std',
                    'name': 'chatglm_std',
                },
                {
                    'id': 'chatglm_lite',
                    'name': 'chatglm_lite',
                },
                {
                    'id': 'chatglm_lite_32k',
                    'name': 'chatglm_lite_32k',
                }
            ]
        elif model_type == ModelType.EMBEDDINGS:
            return [
                {
                    'id': 'text_embedding',
                    'name': 'text_embedding',
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = ZhipuAIModel
        elif model_type == ModelType.EMBEDDINGS:
            model_class = ZhipuAIEmbedding
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
            top_p=KwargRule[float](min=0.1, max=0.9, default=0.8, precision=1),
            presence_penalty=KwargRule[float](enabled=False),
            frequency_penalty=KwargRule[float](enabled=False),
            max_tokens=KwargRule[int](enabled=False),
        )

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.
        """
        if 'api_key' not in credentials:
            raise CredentialsValidateFailedError('ZhipuAI api_key must be provided.')

        try:
            credential_kwargs = {
                'api_key': credentials['api_key']
            }

            llm = ZhipuAIChatLLM(
                temperature=0.01,
                **credential_kwargs
            )

            llm([HumanMessage(content='ping')])
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        if self.provider.provider_type == ProviderType.CUSTOM.value \
                or (self.provider.provider_type == ProviderType.SYSTEM.value
                    and self.provider.quota_type == ProviderQuotaType.FREE.value):
            try:
                credentials = json.loads(self.provider.encrypted_config)
            except JSONDecodeError:
                credentials = {
                    'api_key': None,
                }

            if credentials['api_key']:
                credentials['api_key'] = encrypter.decrypt_token(
                    self.provider.tenant_id,
                    credentials['api_key']
                )

                if obfuscated:
                    credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])

            return credentials
        else:
            return {}

    def should_deduct_quota(self):
        return True

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        return self.get_provider_credentials(obfuscated)
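# Quick sanity check of the provider wiring above (the API key value is a placeholder):
# is_provider_credentials_valid_or_raise sends a one-message 'ping' completion and raises
# CredentialsValidateFailedError if the key is missing or rejected.
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider

try:
    ZhipuAIProvider.is_provider_credentials_valid_or_raise({'api_key': 'your-zhipuai-api-key'})
except CredentialsValidateFailedError as e:
    print(f"invalid credentials: {e}")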
| "tongyi", | "tongyi", | ||||
| "spark", | "spark", | ||||
| "wenxin", | "wenxin", | ||||
| "zhipuai", | |||||
| "chatglm", | "chatglm", | ||||
| "replicate", | "replicate", | ||||
| "huggingface_hub", | "huggingface_hub", |
{
    "support_provider_types": [
        "system",
        "custom"
    ],
    "system_config": {
        "supported_quota_types": [
            "free"
        ],
        "quota_unit": "tokens"
    },
    "model_flexibility": "fixed",
    "price_config": {
        "chatglm_pro": {
            "prompt": "0.01",
            "completion": "0.01",
            "unit": "0.001",
            "currency": "RMB"
        },
        "chatglm_std": {
            "prompt": "0.005",
            "completion": "0.005",
            "unit": "0.001",
            "currency": "RMB"
        },
        "chatglm_lite": {
            "prompt": "0.002",
            "completion": "0.002",
            "unit": "0.001",
            "currency": "RMB"
        },
        "chatglm_lite_32k": {
            "prompt": "0.0004",
            "completion": "0.0004",
            "unit": "0.001",
            "currency": "RMB"
        },
        "text_embedding": {
            "completion": "0",
            "unit": "0.001",
            "currency": "RMB"
        }
    }
}
| """Wrapper around ZhipuAI embedding models.""" | |||||
| from typing import Any, Dict, List, Optional | |||||
| from pydantic import BaseModel, Extra, root_validator | |||||
| from langchain.embeddings.base import Embeddings | |||||
| from langchain.utils import get_from_dict_or_env | |||||
| from core.third_party.langchain.llms.zhipuai_llm import ZhipuModelAPI | |||||
| class ZhipuAIEmbeddings(BaseModel, Embeddings): | |||||
| """Wrapper around ZhipuAI embedding models. | |||||
| 1024 dimensions. | |||||
| """ | |||||
| client: Any #: :meta private: | |||||
| model: str | |||||
| """Model name to use.""" | |||||
| base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api" | |||||
| api_key: Optional[str] = None | |||||
| class Config: | |||||
| """Configuration for this pydantic object.""" | |||||
| extra = Extra.forbid | |||||
| @root_validator() | |||||
| def validate_environment(cls, values: Dict) -> Dict: | |||||
| """Validate that api key and python package exists in environment.""" | |||||
| values["api_key"] = get_from_dict_or_env( | |||||
| values, "api_key", "ZHIPUAI_API_KEY" | |||||
| ) | |||||
| values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url']) | |||||
| return values | |||||
| def embed_documents(self, texts: List[str]) -> List[List[float]]: | |||||
| """Call out to ZhipuAI's embedding endpoint. | |||||
| Args: | |||||
| texts: The list of texts to embed. | |||||
| Returns: | |||||
| List of embeddings, one for each text. | |||||
| """ | |||||
| embeddings = [] | |||||
| for text in texts: | |||||
| response = self.client.invoke(model=self.model, prompt=text) | |||||
| data = response["data"] | |||||
| embeddings.append(data.get('embedding')) | |||||
| return [list(map(float, e)) for e in embeddings] | |||||
| def embed_query(self, text: str) -> List[float]: | |||||
| """Call out to ZhipuAI's embedding endpoint. | |||||
| Args: | |||||
| text: The text to embed. | |||||
| Returns: | |||||
| Embeddings for the text. | |||||
| """ | |||||
| return self.embed_documents([text])[0] |
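# Minimal usage sketch of the wrapper above (API key value is a placeholder); the embedding
# tests later in this diff expect 1024-dimensional vectors from the 'text_embedding' model.
from core.third_party.langchain.embeddings.zhipuai_embedding import ZhipuAIEmbeddings

embeddings = ZhipuAIEmbeddings(model='text_embedding', api_key='your-zhipuai-api-key')
query_vector = embeddings.embed_query('hello')
doc_vectors = embeddings.embed_documents(['hello', 'world'])
print(len(query_vector), len(doc_vectors))  # expected: 1024 2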
| """Wrapper around ZhipuAI APIs.""" | |||||
| from __future__ import annotations | |||||
| import json | |||||
| import logging | |||||
| import posixpath | |||||
| from typing import ( | |||||
| Any, | |||||
| Dict, | |||||
| List, | |||||
| Optional, Iterator, Sequence, | |||||
| ) | |||||
| import zhipuai | |||||
| from langchain.chat_models.base import BaseChatModel | |||||
| from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage | |||||
| from langchain.schema.messages import AIMessageChunk | |||||
| from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration | |||||
| from pydantic import Extra, root_validator, BaseModel | |||||
| from langchain.callbacks.manager import ( | |||||
| CallbackManagerForLLMRun, | |||||
| ) | |||||
| from langchain.utils import get_from_dict_or_env | |||||
| from zhipuai.model_api.api import InvokeType | |||||
| from zhipuai.utils import jwt_token | |||||
| from zhipuai.utils.http_client import post, stream | |||||
| from zhipuai.utils.sse_client import SSEClient | |||||
| logger = logging.getLogger(__name__) | |||||
| class ZhipuModelAPI(BaseModel): | |||||
| base_url: str | |||||
| api_key: str | |||||
| api_timeout_seconds = 60 | |||||
| class Config: | |||||
| """Configuration for this pydantic object.""" | |||||
| extra = Extra.forbid | |||||
| def invoke(self, **kwargs): | |||||
| url = self._build_api_url(kwargs, InvokeType.SYNC) | |||||
| response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds) | |||||
| if not response['success']: | |||||
| raise ValueError( | |||||
| f"Error Code: {response['code']}, Message: {response['msg']} " | |||||
| ) | |||||
| return response | |||||
| def sse_invoke(self, **kwargs): | |||||
| url = self._build_api_url(kwargs, InvokeType.SSE) | |||||
| data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds) | |||||
| return SSEClient(data) | |||||
| def _build_api_url(self, kwargs, *path): | |||||
| if kwargs: | |||||
| if "model" not in kwargs: | |||||
| raise Exception("model param missed") | |||||
| model = kwargs.pop("model") | |||||
| else: | |||||
| model = "-" | |||||
| return posixpath.join(self.base_url, model, *path) | |||||
| def _generate_token(self): | |||||
| if not self.api_key: | |||||
| raise Exception( | |||||
| "api_key not provided, you could provide it." | |||||
| ) | |||||
| try: | |||||
| return jwt_token.generate_token(self.api_key) | |||||
| except Exception: | |||||
| raise ValueError( | |||||
| f"Your api_key is invalid, please check it." | |||||
| ) | |||||
| class ZhipuAIChatLLM(BaseChatModel): | |||||
| """Wrapper around ZhipuAI large language models. | |||||
| To use, you should pass the api_key as a named parameter to the constructor. | |||||
| Example: | |||||
| .. code-block:: python | |||||
| from core.third_party.langchain.llms.zhipuai import ZhipuAI | |||||
| model = ZhipuAI(model="<model_name>", api_key="my-api-key") | |||||
| """ | |||||
| @property | |||||
| def lc_secrets(self) -> Dict[str, str]: | |||||
| return {"api_key": "API_KEY"} | |||||
| @property | |||||
| def lc_serializable(self) -> bool: | |||||
| return True | |||||
| client: Any = None #: :meta private: | |||||
| model: str = "chatglm_lite" | |||||
| """Model name to use.""" | |||||
| temperature: float = 0.95 | |||||
| """A non-negative float that tunes the degree of randomness in generation.""" | |||||
| top_p: float = 0.7 | |||||
| """Total probability mass of tokens to consider at each step.""" | |||||
| streaming: bool = False | |||||
| """Whether to stream the response or return it all at once.""" | |||||
| api_key: Optional[str] = None | |||||
| base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api" | |||||
| class Config: | |||||
| """Configuration for this pydantic object.""" | |||||
| extra = Extra.forbid | |||||
| @root_validator() | |||||
| def validate_environment(cls, values: Dict) -> Dict: | |||||
| """Validate that api key and python package exists in environment.""" | |||||
| values["api_key"] = get_from_dict_or_env( | |||||
| values, "api_key", "ZHIPUAI_API_KEY" | |||||
| ) | |||||
| if 'test' in values['base_url']: | |||||
| values['model'] = 'chatglm_130b_test' | |||||
| values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url']) | |||||
| return values | |||||
| @property | |||||
| def _default_params(self) -> Dict[str, Any]: | |||||
| """Get the default parameters for calling OpenAI API.""" | |||||
| return { | |||||
| "model": self.model, | |||||
| "temperature": self.temperature, | |||||
| "top_p": self.top_p | |||||
| } | |||||
| @property | |||||
| def _identifying_params(self) -> Dict[str, Any]: | |||||
| """Get the identifying parameters.""" | |||||
| return self._default_params | |||||
| @property | |||||
| def _llm_type(self) -> str: | |||||
| """Return type of llm.""" | |||||
| return "zhipuai" | |||||
| def _convert_message_to_dict(self, message: BaseMessage) -> dict: | |||||
| if isinstance(message, ChatMessage): | |||||
| message_dict = {"role": message.role, "content": message.content} | |||||
| elif isinstance(message, HumanMessage): | |||||
| message_dict = {"role": "user", "content": message.content} | |||||
| elif isinstance(message, AIMessage): | |||||
| message_dict = {"role": "assistant", "content": message.content} | |||||
| elif isinstance(message, SystemMessage): | |||||
| message_dict = {"role": "user", "content": message.content} | |||||
| else: | |||||
| raise ValueError(f"Got unknown type {message}") | |||||
| return message_dict | |||||
| def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage: | |||||
| role = _dict["role"] | |||||
| if role == "user": | |||||
| return HumanMessage(content=_dict["content"]) | |||||
| elif role == "assistant": | |||||
| return AIMessage(content=_dict["content"]) | |||||
| elif role == "system": | |||||
| return SystemMessage(content=_dict["content"]) | |||||
| else: | |||||
| return ChatMessage(content=_dict["content"], role=role) | |||||
| def _create_message_dicts( | |||||
| self, messages: List[BaseMessage] | |||||
| ) -> List[Dict[str, Any]]: | |||||
| dict_messages = [] | |||||
| for m in messages: | |||||
| message = self._convert_message_to_dict(m) | |||||
| if dict_messages: | |||||
| previous_message = dict_messages[-1] | |||||
| if previous_message['role'] == message['role']: | |||||
| dict_messages[-1]['content'] += f"\n{message['content']}" | |||||
| else: | |||||
| dict_messages.append(message) | |||||
| else: | |||||
| dict_messages.append(message) | |||||
| return dict_messages | |||||
| def _generate( | |||||
| self, | |||||
| messages: List[BaseMessage], | |||||
| stop: Optional[List[str]] = None, | |||||
| run_manager: Optional[CallbackManagerForLLMRun] = None, | |||||
| **kwargs: Any, | |||||
| ) -> ChatResult: | |||||
| if self.streaming: | |||||
| generation: Optional[ChatGenerationChunk] = None | |||||
| llm_output: Optional[Dict] = None | |||||
| for chunk in self._stream( | |||||
| messages=messages, stop=stop, run_manager=run_manager, **kwargs | |||||
| ): | |||||
| if chunk.generation_info is not None \ | |||||
| and 'token_usage' in chunk.generation_info: | |||||
| llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model} | |||||
| continue | |||||
| if generation is None: | |||||
| generation = chunk | |||||
| else: | |||||
| generation += chunk | |||||
| assert generation is not None | |||||
| return ChatResult(generations=[generation], llm_output=llm_output) | |||||
| else: | |||||
| message_dicts = self._create_message_dicts(messages) | |||||
| request = self._default_params | |||||
| request["prompt"] = message_dicts | |||||
| request.update(kwargs) | |||||
| response = self.client.invoke(**request) | |||||
| return self._create_chat_result(response) | |||||
| def _stream( | |||||
| self, | |||||
| messages: List[BaseMessage], | |||||
| stop: Optional[List[str]] = None, | |||||
| run_manager: Optional[CallbackManagerForLLMRun] = None, | |||||
| **kwargs: Any, | |||||
| ) -> Iterator[ChatGenerationChunk]: | |||||
| message_dicts = self._create_message_dicts(messages) | |||||
| request = self._default_params | |||||
| request["prompt"] = message_dicts | |||||
| request.update(kwargs) | |||||
| for event in self.client.sse_invoke(incremental=True, **request).events(): | |||||
| if event.event == "add": | |||||
| yield ChatGenerationChunk(message=AIMessageChunk(content=event.data)) | |||||
| if run_manager: | |||||
| run_manager.on_llm_new_token(event.data) | |||||
| elif event.event == "error" or event.event == "interrupted": | |||||
| raise ValueError( | |||||
| f"{event.data}" | |||||
| ) | |||||
| elif event.event == "finish": | |||||
| meta = json.loads(event.meta) | |||||
| token_usage = meta['usage'] | |||||
| if token_usage is not None: | |||||
| if 'prompt_tokens' not in token_usage: | |||||
| token_usage['prompt_tokens'] = 0 | |||||
| if 'completion_tokens' not in token_usage: | |||||
| token_usage['completion_tokens'] = token_usage['total_tokens'] | |||||
| yield ChatGenerationChunk( | |||||
| message=AIMessageChunk(content=event.data), | |||||
| generation_info=dict({'token_usage': token_usage}) | |||||
| ) | |||||
| def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult: | |||||
| data = response["data"] | |||||
| generations = [] | |||||
| for res in data["choices"]: | |||||
| message = self._convert_dict_to_message(res) | |||||
| gen = ChatGeneration( | |||||
| message=message | |||||
| ) | |||||
| generations.append(gen) | |||||
| token_usage = data.get("usage") | |||||
| if token_usage is not None: | |||||
| if 'prompt_tokens' not in token_usage: | |||||
| token_usage['prompt_tokens'] = 0 | |||||
| if 'completion_tokens' not in token_usage: | |||||
| token_usage['completion_tokens'] = token_usage['total_tokens'] | |||||
| llm_output = {"token_usage": token_usage, "model_name": self.model} | |||||
| return ChatResult(generations=generations, llm_output=llm_output) | |||||
| # def get_token_ids(self, text: str) -> List[int]: | |||||
| # """Return the ordered ids of the tokens in a text. | |||||
| # | |||||
| # Args: | |||||
| # text: The string input to tokenize. | |||||
| # | |||||
| # Returns: | |||||
| # A list of ids corresponding to the tokens in the text, in order they occur | |||||
| # in the text. | |||||
| # """ | |||||
| # from core.third_party.transformers.Token import ChatGLMTokenizer | |||||
| # | |||||
| # tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm2-6b") | |||||
| # return tokenizer.encode(text) | |||||
| def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: | |||||
| """Get the number of tokens in the messages. | |||||
| Useful for checking if an input will fit in a model's context window. | |||||
| Args: | |||||
| messages: The message inputs to tokenize. | |||||
| Returns: | |||||
| The sum of the number of tokens across the messages. | |||||
| """ | |||||
| return sum([self.get_num_tokens(m.content) for m in messages]) | |||||
| def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: | |||||
| overall_token_usage: dict = {} | |||||
| for output in llm_outputs: | |||||
| if output is None: | |||||
| # Happens in streaming | |||||
| continue | |||||
| token_usage = output["token_usage"] | |||||
| for k, v in token_usage.items(): | |||||
| if k in overall_token_usage: | |||||
| overall_token_usage[k] += v | |||||
| else: | |||||
| overall_token_usage[k] = v | |||||
| return {"token_usage": overall_token_usage, "model_name": self.model} |
stripe~=5.5.0
pandas==1.5.3
xinference==0.4.2
safetensors==0.3.2
zhipuai==1.0.7
            'result': 'success'
        }

    def free_quota_qualification_verify(self, tenant_id: str, provider_name: str, token: Optional[str]):
        api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
        api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
        api_url = api_base_url + '/api/v1/providers/qualification-verify'

        headers = {
            'Content-Type': 'application/json',
            'Authorization': f"Bearer {api_key}"
        }

        json_data = {'workspace_id': tenant_id, 'provider_name': provider_name}
        if token:
            json_data['token'] = token

        response = requests.post(api_url, headers=headers,
                                 json=json_data)

        if not response.ok:
            logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
            raise ValueError(f"Error: {response.status_code} ")
WENXIN_API_KEY=
WENXIN_SECRET_KEY=

# ZhipuAI Credentials
ZHIPUAI_API_KEY=

# ChatGLM Credentials
CHATGLM_API_BASE=
import json
import os
from unittest.mock import patch

from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
from models.provider import Provider, ProviderType


def get_mock_provider(valid_api_key):
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='zhipuai',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps({
            'api_key': valid_api_key
        }),
        is_valid=True,
    )


def get_mock_embedding_model():
    model_name = 'text_embedding'
    valid_api_key = os.environ['ZHIPUAI_API_KEY']
    provider = ZhipuAIProvider(provider=get_mock_provider(valid_api_key))
    return ZhipuAIEmbedding(
        model_provider=provider,
        name=model_name
    )


def decrypt_side_effect(tenant_id, encrypted_api_key):
    return encrypted_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embedding(mock_decrypt):
    embedding_model = get_mock_embedding_model()
    rst = embedding_model.client.embed_query('test')
    assert isinstance(rst, list)
    assert len(rst) == 1024


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_doc_embedding(mock_decrypt):
    embedding_model = get_mock_embedding_model()
    rst = embedding_model.client.embed_documents(['test', 'test2'])
    assert isinstance(rst, list)
    assert len(rst[0]) == 1024
import json
import os
from unittest.mock import patch

from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
from models.provider import Provider, ProviderType


def get_mock_provider(valid_api_key):
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='zhipuai',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps({
            'api_key': valid_api_key
        }),
        is_valid=True,
    )


def get_mock_model(model_name: str, streaming: bool = False):
    model_kwargs = ModelKwargs(
        temperature=0.01,
    )
    valid_api_key = os.environ['ZHIPUAI_API_KEY']
    model_provider = ZhipuAIProvider(provider=get_mock_provider(valid_api_key))
    return ZhipuAIModel(
        model_provider=model_provider,
        name=model_name,
        model_kwargs=model_kwargs,
        streaming=streaming
    )


def decrypt_side_effect(tenant_id, encrypted_api_key):
    return encrypted_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt):
    model = get_mock_model('chatglm_lite')
    rst = model.get_num_tokens([
        PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
    ])
    assert rst > 0


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('chatglm_lite')
    messages = [
        PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
    ]
    rst = model.run(
        messages,
    )
    assert len(rst.content) > 0


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_stream_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    model = get_mock_model('chatglm_lite', streaming=True)
    messages = [
        PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
    ]
    rst = model.run(
        messages
    )
    assert len(rst.content) > 0
| MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({}) | MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({}) | ||||
| credential = VALIDATE_CREDENTIAL.copy() | credential = VALIDATE_CREDENTIAL.copy() | ||||
| credential['api_key'] = 'invalid_key' | |||||
| del credential['api_key'] | |||||
| # raise CredentialsValidateFailedError if api_key is invalid | # raise CredentialsValidateFailedError if api_key is invalid | ||||
| with pytest.raises(CredentialsValidateFailedError): | with pytest.raises(CredentialsValidateFailedError): |
| import pytest | |||||
| from unittest.mock import patch | |||||
| import json | |||||
| from langchain.schema import ChatResult, ChatGeneration, AIMessage | |||||
| from core.model_providers.providers.base import CredentialsValidateFailedError | |||||
| from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider | |||||
| from models.provider import ProviderType, Provider | |||||
| PROVIDER_NAME = 'zhipuai' | |||||
| MODEL_PROVIDER_CLASS = ZhipuAIProvider | |||||
| VALIDATE_CREDENTIAL = { | |||||
| 'api_key': 'valid_key', | |||||
| } | |||||
| def encrypt_side_effect(tenant_id, encrypt_key): | |||||
| return f'encrypted_{encrypt_key}' | |||||
| def decrypt_side_effect(tenant_id, encrypted_key): | |||||
| return encrypted_key.replace('encrypted_', '') | |||||
| def test_is_provider_credentials_valid_or_raise_valid(mocker): | |||||
| mocker.patch('core.third_party.langchain.llms.zhipuai_llm.ZhipuAIChatLLM._generate', | |||||
| return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))])) | |||||
| MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL) | |||||
| def test_is_provider_credentials_valid_or_raise_invalid(): | |||||
| # raise CredentialsValidateFailedError if api_key is not in credentials | |||||
| with pytest.raises(CredentialsValidateFailedError): | |||||
| MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({}) | |||||
| credential = VALIDATE_CREDENTIAL.copy() | |||||
| credential['api_key'] = 'invalid_key' | |||||
| # raise CredentialsValidateFailedError if api_key is invalid | |||||
| with pytest.raises(CredentialsValidateFailedError): | |||||
| MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential) | |||||
| @patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect) | |||||
| def test_encrypt_credentials(mock_encrypt): | |||||
| result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy()) | |||||
| assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}' | |||||
| @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) | |||||
| def test_get_credentials_custom(mock_decrypt): | |||||
| encrypted_credential = VALIDATE_CREDENTIAL.copy() | |||||
| encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key'] | |||||
| provider = Provider( | |||||
| id='provider_id', | |||||
| tenant_id='tenant_id', | |||||
| provider_name=PROVIDER_NAME, | |||||
| provider_type=ProviderType.CUSTOM.value, | |||||
| encrypted_config=json.dumps(encrypted_credential), | |||||
| is_valid=True, | |||||
| ) | |||||
| model_provider = MODEL_PROVIDER_CLASS(provider=provider) | |||||
| result = model_provider.get_provider_credentials() | |||||
| assert result['api_key'] == 'valid_key' | |||||
| @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) | |||||
| def test_get_credentials_obfuscated(mock_decrypt): | |||||
| encrypted_credential = VALIDATE_CREDENTIAL.copy() | |||||
| encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key'] | |||||
| provider = Provider( | |||||
| id='provider_id', | |||||
| tenant_id='tenant_id', | |||||
| provider_name=PROVIDER_NAME, | |||||
| provider_type=ProviderType.CUSTOM.value, | |||||
| encrypted_config=json.dumps(encrypted_credential), | |||||
| is_valid=True, | |||||
| ) | |||||
| model_provider = MODEL_PROVIDER_CLASS(provider=provider) | |||||
| result = model_provider.get_provider_credentials(obfuscated=True) | |||||
| middle_token = result['api_key'][6:-2] | |||||
| assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0) | |||||
| assert all(char == '*' for char in middle_token) |
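The last test pins down the masking rule for obfuscated credentials: the first six and the last two characters stay visible, and everything in between is replaced by `*`. A minimal sketch of that rule for keys longer than eight characters (the `obfuscate` helper name is hypothetical, not part of the project's API):

```python
def obfuscate(token: str) -> str:
    # Keep the first 6 and last 2 characters and star out the middle,
    # mirroring the assertions in test_get_credentials_obfuscated above.
    return token[:6] + '*' * max(len(token) - 8, 0) + token[-2:]

assert obfuscate('valid_key') == 'valid_*ey'
```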