| @@ -246,7 +246,8 @@ class ModelProviderModelParameterRuleApi(Resource): | |||
| 'enabled': v.enabled, | |||
| 'min': v.min, | |||
| 'max': v.max, | |||
| 'default': v.default | |||
| 'default': v.default, | |||
| 'precision': v.precision | |||
| } | |||
| for k, v in vars(parameter_rules).items() | |||
| } | |||
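For context, a sketch of what one serialized parameter rule could look like once `precision` is included; the keys mirror the serializer above, but the values are illustrative only:

    # Illustrative entry in the parameter-rules response; the numbers are made up.
    example_rule = {
        'enabled': True,
        'min': 0.0,
        'max': 2.0,
        'default': 1.0,
        'precision': 2,  # presumably the number of decimal places allowed for this kwarg
    }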
| @@ -290,10 +291,15 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider_name: str): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('token', type=str, required=False, nullable=True, location='args') | |||
| args = parser.parse_args() | |||
| provider_service = ProviderService() | |||
| result = provider_service.free_quota_qualification_verify( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider_name=provider_name | |||
| provider_name=provider_name, | |||
| token=args['token'] | |||
| ) | |||
| return result | |||
| @@ -63,7 +63,18 @@ class LLMCallbackHandler(BaseCallbackHandler): | |||
| self.conversation_message_task.append_message_text(response.generations[0][0].text) | |||
| self.llm_message.completion = response.generations[0][0].text | |||
| self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)]) | |||
| if response.llm_output and 'token_usage' in response.llm_output: | |||
| if 'prompt_tokens' in response.llm_output['token_usage']: | |||
| self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens'] | |||
| if 'completion_tokens' in response.llm_output['token_usage']: | |||
| self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens'] | |||
| else: | |||
| self.llm_message.completion_tokens = self.model_instance.get_num_tokens( | |||
| [PromptMessage(content=self.llm_message.completion)]) | |||
| else: | |||
| self.llm_message.completion_tokens = self.model_instance.get_num_tokens( | |||
| [PromptMessage(content=self.llm_message.completion)]) | |||
| self.conversation_message_task.save_message(self.llm_message) | |||
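In short, the handler now prefers provider-reported token counts and only counts locally when the provider omits them. A minimal standalone sketch of that fallback (names simplified from the handler above):

    def resolve_completion_tokens(llm_output, count_locally):
        """Prefer provider-reported completion_tokens; otherwise fall back to local counting."""
        token_usage = (llm_output or {}).get('token_usage') or {}
        if 'completion_tokens' in token_usage:
            return token_usage['completion_tokens']
        return count_locally()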
| @@ -2,13 +2,8 @@ import enum | |||
| import logging | |||
| from typing import List, Dict, Optional, Any | |||
| import openai | |||
| from flask import current_app | |||
| from langchain.callbacks.manager import CallbackManagerForChainRun | |||
| from langchain.chains.base import Chain | |||
| from openai import InvalidRequestError | |||
| from openai.error import APIConnectionError, APIError, ServiceUnavailableError, Timeout, RateLimitError, \ | |||
| AuthenticationError, OpenAIError | |||
| from pydantic import BaseModel | |||
| from core.model_providers.error import LLMBadRequestError | |||
| @@ -86,6 +81,12 @@ class SensitiveWordAvoidanceChain(Chain): | |||
| result = self._check_moderation(text) | |||
| if not result: | |||
| raise LLMBadRequestError(self.sensitive_word_avoidance_rule.canned_response) | |||
| raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response) | |||
| return {self.output_key: text} | |||
| class SensitiveWordAvoidanceError(Exception): | |||
| def __init__(self, message): | |||
| super().__init__(message) | |||
| self.message = message | |||
| @@ -7,6 +7,7 @@ from requests.exceptions import ChunkedEncodingError | |||
| from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy | |||
| from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler | |||
| from core.callback_handler.llm_callback_handler import LLMCallbackHandler | |||
| from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError | |||
| from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ | |||
| @@ -76,28 +77,53 @@ class Completion: | |||
| app_model_config=app_model_config | |||
| ) | |||
| # parse sensitive_word_avoidance_chain | |||
| chain_callback = MainChainGatherCallbackHandler(conversation_message_task) | |||
| sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(final_model_instance, [chain_callback]) | |||
| if sensitive_word_avoidance_chain: | |||
| query = sensitive_word_avoidance_chain.run(query) | |||
| # get agent executor | |||
| agent_executor = orchestrator_rule_parser.to_agent_executor( | |||
| conversation_message_task=conversation_message_task, | |||
| memory=memory, | |||
| rest_tokens=rest_tokens_for_context_and_memory, | |||
| chain_callback=chain_callback | |||
| ) | |||
| # run agent executor | |||
| agent_execute_result = None | |||
| if agent_executor: | |||
| should_use_agent = agent_executor.should_use_agent(query) | |||
| if should_use_agent: | |||
| agent_execute_result = agent_executor.run(query) | |||
| # run the final llm | |||
| try: | |||
| # parse sensitive_word_avoidance_chain | |||
| chain_callback = MainChainGatherCallbackHandler(conversation_message_task) | |||
| sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain( | |||
| final_model_instance, [chain_callback]) | |||
| if sensitive_word_avoidance_chain: | |||
| try: | |||
| query = sensitive_word_avoidance_chain.run(query) | |||
| except SensitiveWordAvoidanceError as ex: | |||
| cls.run_final_llm( | |||
| model_instance=final_model_instance, | |||
| mode=app.mode, | |||
| app_model_config=app_model_config, | |||
| query=query, | |||
| inputs=inputs, | |||
| agent_execute_result=None, | |||
| conversation_message_task=conversation_message_task, | |||
| memory=memory, | |||
| fake_response=ex.message | |||
| ) | |||
| return | |||
| # get agent executor | |||
| agent_executor = orchestrator_rule_parser.to_agent_executor( | |||
| conversation_message_task=conversation_message_task, | |||
| memory=memory, | |||
| rest_tokens=rest_tokens_for_context_and_memory, | |||
| chain_callback=chain_callback, | |||
| retriever_from=retriever_from | |||
| ) | |||
| # run agent executor | |||
| agent_execute_result = None | |||
| if agent_executor: | |||
| should_use_agent = agent_executor.should_use_agent(query) | |||
| if should_use_agent: | |||
| agent_execute_result = agent_executor.run(query) | |||
| # When no extra pre-prompt is specified, | |||
| # the agent's output can be used directly as the final output without calling the LLM again | |||
| fake_response = None | |||
| if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \ | |||
| and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, | |||
| PlanningStrategy.REACT_ROUTER]: | |||
| fake_response = agent_execute_result.output | |||
| # run the final llm | |||
| cls.run_final_llm( | |||
| model_instance=final_model_instance, | |||
| mode=app.mode, | |||
| @@ -106,7 +132,8 @@ class Completion: | |||
| inputs=inputs, | |||
| agent_execute_result=agent_execute_result, | |||
| conversation_message_task=conversation_message_task, | |||
| memory=memory | |||
| memory=memory, | |||
| fake_response=fake_response | |||
| ) | |||
| except ConversationTaskStoppedException: | |||
| return | |||
| @@ -121,14 +148,8 @@ class Completion: | |||
| inputs: dict, | |||
| agent_execute_result: Optional[AgentExecuteResult], | |||
| conversation_message_task: ConversationMessageTask, | |||
| memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]): | |||
| # When no extra pre prompt is specified, | |||
| # the output of the agent can be used directly as the main output content without calling LLM again | |||
| fake_response = None | |||
| if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \ | |||
| and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, PlanningStrategy.REACT_ROUTER]: | |||
| fake_response = agent_execute_result.output | |||
| memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], | |||
| fake_response: Optional[str]): | |||
| # get llm prompt | |||
| prompt_messages, stop_words = model_instance.get_prompt( | |||
| mode=mode, | |||
| @@ -1,32 +1,34 @@ | |||
| import logging | |||
| import openai | |||
| from flask import current_app | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| from core.model_providers.providers.hosted import hosted_config, hosted_model_providers | |||
| from models.provider import ProviderType | |||
| def check_moderation(model_provider: BaseModelProvider, text: str) -> bool: | |||
| if current_app.config['HOSTED_MODERATION_ENABLED'] and current_app.config['HOSTED_MODERATION_PROVIDERS']: | |||
| moderation_providers = current_app.config['HOSTED_MODERATION_PROVIDERS'].split(',') | |||
| if hosted_config.moderation.enabled is True and hosted_model_providers.openai: | |||
| if model_provider.provider.provider_type == ProviderType.SYSTEM.value \ | |||
| and model_provider.provider_name in moderation_providers: | |||
| and model_provider.provider_name in hosted_config.moderation.providers: | |||
| # 2,000 characters per chunk | |||
| length = 2000 | |||
| chunks = [text[i:i + length] for i in range(0, len(text), length)] | |||
| try: | |||
| moderation_result = openai.Moderation.create(input=chunks, | |||
| api_key=current_app.config['HOSTED_OPENAI_API_KEY']) | |||
| except Exception as ex: | |||
| logging.exception(ex) | |||
| raise LLMBadRequestError('Rate limit exceeded, please try again later.') | |||
| for result in moderation_result.results: | |||
| if result['flagged'] is True: | |||
| return False | |||
| text_chunks = [text[i:i + length] for i in range(0, len(text), length)] | |||
| max_text_chunks = 32 | |||
| chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)] | |||
| for text_chunk in chunks: | |||
| try: | |||
| moderation_result = openai.Moderation.create(input=text_chunk, | |||
| api_key=hosted_model_providers.openai.api_key) | |||
| except Exception as ex: | |||
| logging.exception(ex) | |||
| raise LLMBadRequestError('Rate limit exceeded, please try again later.') | |||
| for result in moderation_result.results: | |||
| if result['flagged'] is True: | |||
| return False | |||
| return True | |||
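The moderation input is now split into 2,000-character chunks and sent at most 32 chunks per request, presumably to respect the endpoint's input-array limit. A standalone sketch of just the batching:

    def batch_for_moderation(text, chunk_size=2000, max_chunks_per_request=32):
        """Split text into chunk_size pieces, then group at most 32 pieces per request."""
        text_chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
        return [text_chunks[i:i + max_chunks_per_request]
                for i in range(0, len(text_chunks), max_chunks_per_request)]

    # e.g. a 70,000-character input yields 35 chunks, i.e. two requests (32 + 3 chunks)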
| @@ -45,6 +45,9 @@ class ModelProviderFactory: | |||
| elif provider_name == 'wenxin': | |||
| from core.model_providers.providers.wenxin_provider import WenxinProvider | |||
| return WenxinProvider | |||
| elif provider_name == 'zhipuai': | |||
| from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider | |||
| return ZhipuAIProvider | |||
| elif provider_name == 'chatglm': | |||
| from core.model_providers.providers.chatglm_provider import ChatGLMProvider | |||
| return ChatGLMProvider | |||
| @@ -0,0 +1,22 @@ | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| from core.model_providers.models.embedding.base import BaseEmbedding | |||
| from core.third_party.langchain.embeddings.zhipuai_embedding import ZhipuAIEmbeddings | |||
| class ZhipuAIEmbedding(BaseEmbedding): | |||
| def __init__(self, model_provider: BaseModelProvider, name: str): | |||
| credentials = model_provider.get_model_credentials( | |||
| model_name=name, | |||
| model_type=self.type | |||
| ) | |||
| client = ZhipuAIEmbeddings( | |||
| model=name, | |||
| **credentials, | |||
| ) | |||
| super().__init__(model_provider, client, name) | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| return LLMBadRequestError(f"ZhipuAI embedding: {str(ex)}") | |||
| @@ -49,6 +49,7 @@ class KwargRule(Generic[T], BaseModel): | |||
| max: Optional[T] = None | |||
| default: Optional[T] = None | |||
| alias: Optional[str] = None | |||
| precision: Optional[int] = None | |||
| class ModelKwargsRules(BaseModel): | |||
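One plausible consumer of the new `precision` field is rounding (and clamping) a user-supplied kwarg before it reaches the provider; this helper is an assumption about how the field might be used, not code from this change:

    def clamp_and_round(value, rule):
        """Clamp value into [rule.min, rule.max] and round to rule.precision decimals (hypothetical helper)."""
        if rule.min is not None:
            value = max(value, rule.min)
        if rule.max is not None:
            value = min(value, rule.max)
        if rule.precision is not None:
            value = round(value, rule.precision)
        return value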
| @@ -0,0 +1,61 @@ | |||
| from typing import List, Optional, Any | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.schema import LLMResult | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.model_providers.models.entity.message import PromptMessage | |||
| from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs | |||
| from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM | |||
| class ZhipuAIModel(BaseLLM): | |||
| model_mode: ModelMode = ModelMode.CHAT | |||
| def _init_client(self) -> Any: | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) | |||
| return ZhipuAIChatLLM( | |||
| streaming=self.streaming, | |||
| callbacks=self.callbacks, | |||
| **self.credentials, | |||
| **provider_model_kwargs | |||
| ) | |||
| def _run(self, messages: List[PromptMessage], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| **kwargs) -> LLMResult: | |||
| """ | |||
| Run prediction with the given prompt messages and stop words. | |||
| :param messages: | |||
| :param stop: | |||
| :param callbacks: | |||
| :return: | |||
| """ | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| return self._client.generate([prompts], stop, callbacks) | |||
| def get_num_tokens(self, messages: List[PromptMessage]) -> int: | |||
| """ | |||
| Get the number of tokens in the given prompt messages. | |||
| :param messages: | |||
| :return: | |||
| """ | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| return max(self._client.get_num_tokens_from_messages(prompts), 0) | |||
| def _set_model_kwargs(self, model_kwargs: ModelKwargs): | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) | |||
| for k, v in provider_model_kwargs.items(): | |||
| if hasattr(self.client, k): | |||
| setattr(self.client, k, v) | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| return LLMBadRequestError(f"ZhipuAI: {str(ex)}") | |||
| @property | |||
| def support_streaming(self): | |||
| return True | |||
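A rough usage sketch of the new model class, mirroring how the integration tests further down exercise it; the provider wiring and credentials are assumed to be set up as in those tests:

    from core.model_providers.models.entity.message import PromptMessage, MessageType

    # `model` is assumed to be a ZhipuAIModel built via ZhipuAIProvider, as in the tests below
    messages = [PromptMessage(type=MessageType.HUMAN, content='Hello')]
    token_estimate = model.get_num_tokens(messages)  # local estimate via the wrapped client
    result = model.run(messages)                     # delegates to ZhipuAIChatLLM.generate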
| @@ -23,14 +23,18 @@ class OpenAIModeration(BaseModeration): | |||
| # 2,000 characters per chunk | |||
| length = 2000 | |||
| chunks = [text[i:i + length] for i in range(0, len(text), length)] | |||
| text_chunks = [text[i:i + length] for i in range(0, len(text), length)] | |||
| moderation_result = self._client.create(input=chunks, | |||
| api_key=credentials['openai_api_key']) | |||
| max_text_chunks = 32 | |||
| chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)] | |||
| for result in moderation_result.results: | |||
| if result['flagged'] is True: | |||
| return False | |||
| for text_chunk in chunks: | |||
| moderation_result = self._client.create(input=text_chunk, | |||
| api_key=credentials['openai_api_key']) | |||
| for result in moderation_result.results: | |||
| if result['flagged'] is True: | |||
| return False | |||
| return True | |||
| @@ -69,11 +69,11 @@ class AnthropicProvider(BaseModelProvider): | |||
| :return: | |||
| """ | |||
| return ModelKwargsRules( | |||
| temperature=KwargRule[float](min=0, max=1, default=1), | |||
| top_p=KwargRule[float](min=0, max=1, default=0.7), | |||
| temperature=KwargRule[float](min=0, max=1, default=1, precision=2), | |||
| top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2), | |||
| presence_penalty=KwargRule[float](enabled=False), | |||
| frequency_penalty=KwargRule[float](enabled=False), | |||
| max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256), | |||
| max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256, precision=0), | |||
| ) | |||
| @classmethod | |||
| @@ -164,14 +164,14 @@ class AzureOpenAIProvider(BaseModelProvider): | |||
| model_credentials = self.get_model_credentials(model_name, model_type) | |||
| return ModelKwargsRules( | |||
| temperature=KwargRule[float](min=0, max=2, default=1), | |||
| top_p=KwargRule[float](min=0, max=1, default=1), | |||
| presence_penalty=KwargRule[float](min=-2, max=2, default=0), | |||
| frequency_penalty=KwargRule[float](min=-2, max=2, default=0), | |||
| temperature=KwargRule[float](min=0, max=2, default=1, precision=2), | |||
| top_p=KwargRule[float](min=0, max=1, default=1, precision=2), | |||
| presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2), | |||
| frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2), | |||
| max_tokens=KwargRule[int](min=10, max=base_model_max_tokens.get( | |||
| model_credentials['base_model_name'], | |||
| 4097 | |||
| ), default=16), | |||
| ), default=16, precision=0), | |||
| ) | |||
| @classmethod | |||
| @@ -64,11 +64,11 @@ class ChatGLMProvider(BaseModelProvider): | |||
| } | |||
| return ModelKwargsRules( | |||
| temperature=KwargRule[float](min=0, max=2, default=1), | |||
| top_p=KwargRule[float](min=0, max=1, default=0.7), | |||
| temperature=KwargRule[float](min=0, max=2, default=1, precision=2), | |||
| top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2), | |||
| presence_penalty=KwargRule[float](enabled=False), | |||
| frequency_penalty=KwargRule[float](enabled=False), | |||
| max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048), | |||
| max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0), | |||
| ) | |||
| @classmethod | |||
| @@ -45,6 +45,18 @@ class HostedModelProviders(BaseModel): | |||
| hosted_model_providers = HostedModelProviders() | |||
| class HostedModerationConfig(BaseModel): | |||
| enabled: bool = False | |||
| providers: list[str] = [] | |||
| class HostedConfig(BaseModel): | |||
| moderation = HostedModerationConfig() | |||
| hosted_config = HostedConfig() | |||
| def init_app(app: Flask): | |||
| if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': | |||
| langchain.verbose = True | |||
| @@ -78,3 +90,9 @@ def init_app(app: Flask): | |||
| paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"), | |||
| paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"), | |||
| ) | |||
| if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"): | |||
| hosted_config.moderation = HostedModerationConfig( | |||
| enabled=app.config.get("HOSTED_MODERATION_ENABLED"), | |||
| providers=app.config.get("HOSTED_MODERATION_PROVIDERS").split(',') | |||
| ) | |||
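For reference, hosted moderation is driven by the two app config values read here; a hedged sketch of the settings this code expects (values are examples, `app` is the Flask app passed to init_app):

    # Illustrative Flask config; HOSTED_MODERATION_PROVIDERS is a comma-separated string.
    app.config['HOSTED_MODERATION_ENABLED'] = True
    app.config['HOSTED_MODERATION_PROVIDERS'] = 'openai'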
| @@ -47,11 +47,11 @@ class HuggingfaceHubProvider(BaseModelProvider): | |||
| :return: | |||
| """ | |||
| return ModelKwargsRules( | |||
| temperature=KwargRule[float](min=0, max=2, default=1), | |||
| top_p=KwargRule[float](min=0.01, max=0.99, default=0.7), | |||
| temperature=KwargRule[float](min=0, max=2, default=1, precision=2), | |||
| top_p=KwargRule[float](min=0.01, max=0.99, default=0.7, precision=2), | |||
| presence_penalty=KwargRule[float](enabled=False), | |||
| frequency_penalty=KwargRule[float](enabled=False), | |||
| max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200), | |||
| max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200, precision=0), | |||
| ) | |||
| @classmethod | |||
| @@ -52,9 +52,9 @@ class LocalAIProvider(BaseModelProvider): | |||
| :return: | |||
| """ | |||
| return ModelKwargsRules( | |||
| temperature=KwargRule[float](min=0, max=2, default=0.7), | |||
| top_p=KwargRule[float](min=0, max=1, default=1), | |||
| max_tokens=KwargRule[int](min=10, max=4097, default=16), | |||
| temperature=KwargRule[float](min=0, max=2, default=0.7, precision=2), | |||
| top_p=KwargRule[float](min=0, max=1, default=1, precision=2), | |||
| max_tokens=KwargRule[int](min=10, max=4097, default=16, precision=0), | |||
| ) | |||
| @classmethod | |||
| @@ -74,11 +74,11 @@ class MinimaxProvider(BaseModelProvider): | |||
| } | |||
| return ModelKwargsRules( | |||
| temperature=KwargRule[float](min=0.01, max=1, default=0.9), | |||
| top_p=KwargRule[float](min=0, max=1, default=0.95), | |||
| temperature=KwargRule[float](min=0.01, max=1, default=0.9, precision=2), | |||
| top_p=KwargRule[float](min=0, max=1, default=0.95, precision=2), | |||
| presence_penalty=KwargRule[float](enabled=False), | |||
| frequency_penalty=KwargRule[float](enabled=False), | |||
| max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024), | |||
| max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024, precision=0), | |||
| ) | |||
| @classmethod | |||
| @@ -133,11 +133,11 @@ class OpenAIProvider(BaseModelProvider): | |||
| } | |||
| return ModelKwargsRules( | |||
| temperature=KwargRule[float](min=0, max=2, default=1), | |||
| top_p=KwargRule[float](min=0, max=1, default=1), | |||
| presence_penalty=KwargRule[float](min=-2, max=2, default=0), | |||
| frequency_penalty=KwargRule[float](min=-2, max=2, default=0), | |||
| max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16), | |||
| temperature=KwargRule[float](min=0, max=2, default=1, precision=2), | |||
| top_p=KwargRule[float](min=0, max=1, default=1, precision=2), | |||
| presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2), | |||
| frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2), | |||
| max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16, precision=0), | |||
| ) | |||
| @classmethod | |||
| @@ -45,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider): | |||
| :return: | |||
| """ | |||
| return ModelKwargsRules( | |||
| temperature=KwargRule[float](min=0.01, max=2, default=1), | |||
| top_p=KwargRule[float](min=0, max=1, default=0.7), | |||
| presence_penalty=KwargRule[float](min=-2, max=2, default=0), | |||
| frequency_penalty=KwargRule[float](min=-2, max=2, default=0), | |||
| max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128), | |||
| temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2), | |||
| top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2), | |||
| presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2), | |||
| frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2), | |||
| max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128, precision=0), | |||
| ) | |||
| @classmethod | |||
| @@ -72,6 +72,7 @@ class ReplicateProvider(BaseModelProvider): | |||
| min=float(value.get('minimum')) if value.get('minimum') is not None else None, | |||
| max=float(value.get('maximum')) if value.get('maximum') is not None else None, | |||
| default=float(value.get('default')) if value.get('default') is not None else None, | |||
| precision=2 | |||
| ) | |||
| if key == 'temperature': | |||
| model_kwargs_rules.temperature = kwarg_rule | |||
| @@ -84,6 +85,7 @@ class ReplicateProvider(BaseModelProvider): | |||
| min=int(value.get('minimum')) if value.get('minimum') is not None else 1, | |||
| max=int(value.get('maximum')) if value.get('maximum') is not None else 8000, | |||
| default=int(value.get('default')) if value.get('default') is not None else 500, | |||
| precision=0 | |||
| ) | |||
| return model_kwargs_rules | |||
| @@ -62,11 +62,11 @@ class SparkProvider(BaseModelProvider): | |||
| :return: | |||
| """ | |||
| return ModelKwargsRules( | |||
| temperature=KwargRule[float](min=0, max=1, default=0.5), | |||
| temperature=KwargRule[float](min=0, max=1, default=0.5, precision=2), | |||
| top_p=KwargRule[float](enabled=False), | |||
| presence_penalty=KwargRule[float](enabled=False), | |||
| frequency_penalty=KwargRule[float](enabled=False), | |||
| max_tokens=KwargRule[int](min=10, max=4096, default=2048), | |||
| max_tokens=KwargRule[int](min=10, max=4096, default=2048, precision=0), | |||
| ) | |||
| @classmethod | |||
| @@ -64,10 +64,10 @@ class TongyiProvider(BaseModelProvider): | |||
| return ModelKwargsRules( | |||
| temperature=KwargRule[float](enabled=False), | |||
| top_p=KwargRule[float](min=0, max=1, default=0.8), | |||
| top_p=KwargRule[float](min=0, max=1, default=0.8, precision=2), | |||
| presence_penalty=KwargRule[float](enabled=False), | |||
| frequency_penalty=KwargRule[float](enabled=False), | |||
| max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024), | |||
| max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024, precision=0), | |||
| ) | |||
| @classmethod | |||
| @@ -63,8 +63,8 @@ class WenxinProvider(BaseModelProvider): | |||
| """ | |||
| if model_name in ['ernie-bot', 'ernie-bot-turbo']: | |||
| return ModelKwargsRules( | |||
| temperature=KwargRule[float](min=0.01, max=1, default=0.95), | |||
| top_p=KwargRule[float](min=0.01, max=1, default=0.8), | |||
| temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2), | |||
| top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2), | |||
| presence_penalty=KwargRule[float](enabled=False), | |||
| frequency_penalty=KwargRule[float](enabled=False), | |||
| max_tokens=KwargRule[int](enabled=False), | |||
| @@ -53,27 +53,27 @@ class XinferenceProvider(BaseModelProvider): | |||
| credentials = self.get_model_credentials(model_name, model_type) | |||
| if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm": | |||
| return ModelKwargsRules( | |||
| temperature=KwargRule[float](min=0.01, max=2, default=1), | |||
| top_p=KwargRule[float](min=0, max=1, default=0.7), | |||
| temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2), | |||
| top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2), | |||
| presence_penalty=KwargRule[float](enabled=False), | |||
| frequency_penalty=KwargRule[float](enabled=False), | |||
| max_tokens=KwargRule[int](min=10, max=4000, default=256), | |||
| max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0), | |||
| ) | |||
| elif credentials['model_format'] == "ggmlv3": | |||
| return ModelKwargsRules( | |||
| temperature=KwargRule[float](min=0.01, max=2, default=1), | |||
| top_p=KwargRule[float](min=0, max=1, default=0.7), | |||
| presence_penalty=KwargRule[float](min=-2, max=2, default=0), | |||
| frequency_penalty=KwargRule[float](min=-2, max=2, default=0), | |||
| max_tokens=KwargRule[int](min=10, max=4000, default=256), | |||
| temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2), | |||
| top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2), | |||
| presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2), | |||
| frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2), | |||
| max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0), | |||
| ) | |||
| else: | |||
| return ModelKwargsRules( | |||
| temperature=KwargRule[float](min=0.01, max=2, default=1), | |||
| top_p=KwargRule[float](min=0, max=1, default=0.7), | |||
| temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2), | |||
| top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2), | |||
| presence_penalty=KwargRule[float](enabled=False), | |||
| frequency_penalty=KwargRule[float](enabled=False), | |||
| max_tokens=KwargRule[int](min=10, max=4000, default=256), | |||
| max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0), | |||
| ) | |||
| @@ -0,0 +1,176 @@ | |||
| import json | |||
| from json import JSONDecodeError | |||
| from typing import Type | |||
| from langchain.schema import HumanMessage | |||
| from core.helper import encrypter | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType | |||
| from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel | |||
| from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError | |||
| from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM | |||
| from models.provider import ProviderType, ProviderQuotaType | |||
| class ZhipuAIProvider(BaseModelProvider): | |||
| @property | |||
| def provider_name(self): | |||
| """ | |||
| Returns the name of a provider. | |||
| """ | |||
| return 'zhipuai' | |||
| def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: | |||
| if model_type == ModelType.TEXT_GENERATION: | |||
| return [ | |||
| { | |||
| 'id': 'chatglm_pro', | |||
| 'name': 'chatglm_pro', | |||
| }, | |||
| { | |||
| 'id': 'chatglm_std', | |||
| 'name': 'chatglm_std', | |||
| }, | |||
| { | |||
| 'id': 'chatglm_lite', | |||
| 'name': 'chatglm_lite', | |||
| }, | |||
| { | |||
| 'id': 'chatglm_lite_32k', | |||
| 'name': 'chatglm_lite_32k', | |||
| } | |||
| ] | |||
| elif model_type == ModelType.EMBEDDINGS: | |||
| return [ | |||
| { | |||
| 'id': 'text_embedding', | |||
| 'name': 'text_embedding', | |||
| } | |||
| ] | |||
| else: | |||
| return [] | |||
| def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: | |||
| """ | |||
| Returns the model class. | |||
| :param model_type: | |||
| :return: | |||
| """ | |||
| if model_type == ModelType.TEXT_GENERATION: | |||
| model_class = ZhipuAIModel | |||
| elif model_type == ModelType.EMBEDDINGS: | |||
| model_class = ZhipuAIEmbedding | |||
| else: | |||
| raise NotImplementedError | |||
| return model_class | |||
| def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules: | |||
| """ | |||
| get model parameter rules. | |||
| :param model_name: | |||
| :param model_type: | |||
| :return: | |||
| """ | |||
| return ModelKwargsRules( | |||
| temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2), | |||
| top_p=KwargRule[float](min=0.1, max=0.9, default=0.8, precision=1), | |||
| presence_penalty=KwargRule[float](enabled=False), | |||
| frequency_penalty=KwargRule[float](enabled=False), | |||
| max_tokens=KwargRule[int](enabled=False), | |||
| ) | |||
| @classmethod | |||
| def is_provider_credentials_valid_or_raise(cls, credentials: dict): | |||
| """ | |||
| Validates the given credentials. | |||
| """ | |||
| if 'api_key' not in credentials: | |||
| raise CredentialsValidateFailedError('ZhipuAI api_key must be provided.') | |||
| try: | |||
| credential_kwargs = { | |||
| 'api_key': credentials['api_key'] | |||
| } | |||
| llm = ZhipuAIChatLLM( | |||
| temperature=0.01, | |||
| **credential_kwargs | |||
| ) | |||
| llm([HumanMessage(content='ping')]) | |||
| except Exception as ex: | |||
| raise CredentialsValidateFailedError(str(ex)) | |||
| @classmethod | |||
| def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict: | |||
| credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key']) | |||
| return credentials | |||
| def get_provider_credentials(self, obfuscated: bool = False) -> dict: | |||
| if self.provider.provider_type == ProviderType.CUSTOM.value \ | |||
| or (self.provider.provider_type == ProviderType.SYSTEM.value | |||
| and self.provider.quota_type == ProviderQuotaType.FREE.value): | |||
| try: | |||
| credentials = json.loads(self.provider.encrypted_config) | |||
| except JSONDecodeError: | |||
| credentials = { | |||
| 'api_key': None, | |||
| } | |||
| if credentials['api_key']: | |||
| credentials['api_key'] = encrypter.decrypt_token( | |||
| self.provider.tenant_id, | |||
| credentials['api_key'] | |||
| ) | |||
| if obfuscated: | |||
| credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key']) | |||
| return credentials | |||
| else: | |||
| return {} | |||
| def should_deduct_quota(self): | |||
| return True | |||
| @classmethod | |||
| def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): | |||
| """ | |||
| Check whether the model credentials are valid. | |||
| :param model_name: | |||
| :param model_type: | |||
| :param credentials: | |||
| """ | |||
| return | |||
| @classmethod | |||
| def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, | |||
| credentials: dict) -> dict: | |||
| """ | |||
| Encrypt model credentials for saving. | |||
| :param tenant_id: | |||
| :param model_name: | |||
| :param model_type: | |||
| :param credentials: | |||
| :return: | |||
| """ | |||
| return {} | |||
| def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict: | |||
| """ | |||
| Get credentials for LLM use. | |||
| :param model_name: | |||
| :param model_type: | |||
| :param obfuscated: | |||
| :return: | |||
| """ | |||
| return self.get_provider_credentials(obfuscated) | |||
| @@ -6,6 +6,7 @@ | |||
| "tongyi", | |||
| "spark", | |||
| "wenxin", | |||
| "zhipuai", | |||
| "chatglm", | |||
| "replicate", | |||
| "huggingface_hub", | |||
| @@ -0,0 +1,44 @@ | |||
| { | |||
| "support_provider_types": [ | |||
| "system", | |||
| "custom" | |||
| ], | |||
| "system_config": { | |||
| "supported_quota_types": [ | |||
| "free" | |||
| ], | |||
| "quota_unit": "tokens" | |||
| }, | |||
| "model_flexibility": "fixed", | |||
| "price_config": { | |||
| "chatglm_pro": { | |||
| "prompt": "0.01", | |||
| "completion": "0.01", | |||
| "unit": "0.001", | |||
| "currency": "RMB" | |||
| }, | |||
| "chatglm_std": { | |||
| "prompt": "0.005", | |||
| "completion": "0.005", | |||
| "unit": "0.001", | |||
| "currency": "RMB" | |||
| }, | |||
| "chatglm_lite": { | |||
| "prompt": "0.002", | |||
| "completion": "0.002", | |||
| "unit": "0.001", | |||
| "currency": "RMB" | |||
| }, | |||
| "chatglm_lite_32k": { | |||
| "prompt": "0.0004", | |||
| "completion": "0.0004", | |||
| "unit": "0.001", | |||
| "currency": "RMB" | |||
| }, | |||
| "text_embedding": { | |||
| "completion": "0", | |||
| "unit": "0.001", | |||
| "currency": "RMB" | |||
| } | |||
| } | |||
| } | |||
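Assuming the usual convention for these price_config files (cost = tokens × unit × price, i.e. the listed prices are per 1,000 tokens when unit is 0.001), a worked example with the chatglm_pro figures above:

    from decimal import Decimal

    # Hypothetical cost estimate under the assumed convention (chatglm_pro prices).
    unit = Decimal('0.001')
    prompt_price = Decimal('0.01')       # RMB per 1,000 prompt tokens
    completion_price = Decimal('0.01')   # RMB per 1,000 completion tokens
    cost = Decimal(2000) * unit * prompt_price + Decimal(1000) * unit * completion_price
    # -> 0.03 RMB for 2,000 prompt tokens + 1,000 completion tokens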
| @@ -0,0 +1,64 @@ | |||
| """Wrapper around ZhipuAI embedding models.""" | |||
| from typing import Any, Dict, List, Optional | |||
| from pydantic import BaseModel, Extra, root_validator | |||
| from langchain.embeddings.base import Embeddings | |||
| from langchain.utils import get_from_dict_or_env | |||
| from core.third_party.langchain.llms.zhipuai_llm import ZhipuModelAPI | |||
| class ZhipuAIEmbeddings(BaseModel, Embeddings): | |||
| """Wrapper around ZhipuAI embedding models. | |||
| 1024 dimensions. | |||
| """ | |||
| client: Any #: :meta private: | |||
| model: str | |||
| """Model name to use.""" | |||
| base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api" | |||
| api_key: Optional[str] = None | |||
| class Config: | |||
| """Configuration for this pydantic object.""" | |||
| extra = Extra.forbid | |||
| @root_validator() | |||
| def validate_environment(cls, values: Dict) -> Dict: | |||
| """Validate that api key and python package exists in environment.""" | |||
| values["api_key"] = get_from_dict_or_env( | |||
| values, "api_key", "ZHIPUAI_API_KEY" | |||
| ) | |||
| values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url']) | |||
| return values | |||
| def embed_documents(self, texts: List[str]) -> List[List[float]]: | |||
| """Call out to ZhipuAI's embedding endpoint. | |||
| Args: | |||
| texts: The list of texts to embed. | |||
| Returns: | |||
| List of embeddings, one for each text. | |||
| """ | |||
| embeddings = [] | |||
| for text in texts: | |||
| response = self.client.invoke(model=self.model, prompt=text) | |||
| data = response["data"] | |||
| embeddings.append(data.get('embedding')) | |||
| return [list(map(float, e)) for e in embeddings] | |||
| def embed_query(self, text: str) -> List[float]: | |||
| """Call out to ZhipuAI's embedding endpoint. | |||
| Args: | |||
| text: The text to embed. | |||
| Returns: | |||
| Embeddings for the text. | |||
| """ | |||
| return self.embed_documents([text])[0] | |||
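A minimal usage sketch of the wrapper above (the API key value is a placeholder):

    embeddings = ZhipuAIEmbeddings(model='text_embedding', api_key='your-zhipuai-api-key')
    query_vector = embeddings.embed_query('hello world')              # one 1024-dimension vector
    doc_vectors = embeddings.embed_documents(['doc one', 'doc two'])  # one vector per document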
| @@ -0,0 +1,315 @@ | |||
| """Wrapper around ZhipuAI APIs.""" | |||
| from __future__ import annotations | |||
| import json | |||
| import logging | |||
| import posixpath | |||
| from typing import ( | |||
| Any, | |||
| Dict, | |||
| List, | |||
| Optional, Iterator, Sequence, | |||
| ) | |||
| import zhipuai | |||
| from langchain.chat_models.base import BaseChatModel | |||
| from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage | |||
| from langchain.schema.messages import AIMessageChunk | |||
| from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration | |||
| from pydantic import Extra, root_validator, BaseModel | |||
| from langchain.callbacks.manager import ( | |||
| CallbackManagerForLLMRun, | |||
| ) | |||
| from langchain.utils import get_from_dict_or_env | |||
| from zhipuai.model_api.api import InvokeType | |||
| from zhipuai.utils import jwt_token | |||
| from zhipuai.utils.http_client import post, stream | |||
| from zhipuai.utils.sse_client import SSEClient | |||
| logger = logging.getLogger(__name__) | |||
| class ZhipuModelAPI(BaseModel): | |||
| base_url: str | |||
| api_key: str | |||
| api_timeout_seconds = 60 | |||
| class Config: | |||
| """Configuration for this pydantic object.""" | |||
| extra = Extra.forbid | |||
| def invoke(self, **kwargs): | |||
| url = self._build_api_url(kwargs, InvokeType.SYNC) | |||
| response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds) | |||
| if not response['success']: | |||
| raise ValueError( | |||
| f"Error Code: {response['code']}, Message: {response['msg']} " | |||
| ) | |||
| return response | |||
| def sse_invoke(self, **kwargs): | |||
| url = self._build_api_url(kwargs, InvokeType.SSE) | |||
| data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds) | |||
| return SSEClient(data) | |||
| def _build_api_url(self, kwargs, *path): | |||
| if kwargs: | |||
| if "model" not in kwargs: | |||
| raise Exception("model param missed") | |||
| model = kwargs.pop("model") | |||
| else: | |||
| model = "-" | |||
| return posixpath.join(self.base_url, model, *path) | |||
| def _generate_token(self): | |||
| if not self.api_key: | |||
| raise Exception( | |||
| "api_key not provided, you could provide it." | |||
| ) | |||
| try: | |||
| return jwt_token.generate_token(self.api_key) | |||
| except Exception: | |||
| raise ValueError( | |||
| f"Your api_key is invalid, please check it." | |||
| ) | |||
| class ZhipuAIChatLLM(BaseChatModel): | |||
| """Wrapper around ZhipuAI large language models. | |||
| To use, you should pass the api_key as a named parameter to the constructor. | |||
| Example: | |||
| .. code-block:: python | |||
| from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM | |||
| model = ZhipuAIChatLLM(model="<model_name>", api_key="my-api-key") | |||
| """ | |||
| @property | |||
| def lc_secrets(self) -> Dict[str, str]: | |||
| return {"api_key": "API_KEY"} | |||
| @property | |||
| def lc_serializable(self) -> bool: | |||
| return True | |||
| client: Any = None #: :meta private: | |||
| model: str = "chatglm_lite" | |||
| """Model name to use.""" | |||
| temperature: float = 0.95 | |||
| """A non-negative float that tunes the degree of randomness in generation.""" | |||
| top_p: float = 0.7 | |||
| """Total probability mass of tokens to consider at each step.""" | |||
| streaming: bool = False | |||
| """Whether to stream the response or return it all at once.""" | |||
| api_key: Optional[str] = None | |||
| base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api" | |||
| class Config: | |||
| """Configuration for this pydantic object.""" | |||
| extra = Extra.forbid | |||
| @root_validator() | |||
| def validate_environment(cls, values: Dict) -> Dict: | |||
| """Validate that api key and python package exists in environment.""" | |||
| values["api_key"] = get_from_dict_or_env( | |||
| values, "api_key", "ZHIPUAI_API_KEY" | |||
| ) | |||
| if 'test' in values['base_url']: | |||
| values['model'] = 'chatglm_130b_test' | |||
| values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url']) | |||
| return values | |||
| @property | |||
| def _default_params(self) -> Dict[str, Any]: | |||
| """Get the default parameters for calling OpenAI API.""" | |||
| return { | |||
| "model": self.model, | |||
| "temperature": self.temperature, | |||
| "top_p": self.top_p | |||
| } | |||
| @property | |||
| def _identifying_params(self) -> Dict[str, Any]: | |||
| """Get the identifying parameters.""" | |||
| return self._default_params | |||
| @property | |||
| def _llm_type(self) -> str: | |||
| """Return type of llm.""" | |||
| return "zhipuai" | |||
| def _convert_message_to_dict(self, message: BaseMessage) -> dict: | |||
| if isinstance(message, ChatMessage): | |||
| message_dict = {"role": message.role, "content": message.content} | |||
| elif isinstance(message, HumanMessage): | |||
| message_dict = {"role": "user", "content": message.content} | |||
| elif isinstance(message, AIMessage): | |||
| message_dict = {"role": "assistant", "content": message.content} | |||
| elif isinstance(message, SystemMessage): | |||
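| # assumed mapping: the ZhipuAI prompt format has no dedicated system role, so system messages are sent with the user role | |||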
| message_dict = {"role": "user", "content": message.content} | |||
| else: | |||
| raise ValueError(f"Got unknown type {message}") | |||
| return message_dict | |||
| def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage: | |||
| role = _dict["role"] | |||
| if role == "user": | |||
| return HumanMessage(content=_dict["content"]) | |||
| elif role == "assistant": | |||
| return AIMessage(content=_dict["content"]) | |||
| elif role == "system": | |||
| return SystemMessage(content=_dict["content"]) | |||
| else: | |||
| return ChatMessage(content=_dict["content"], role=role) | |||
| def _create_message_dicts( | |||
| self, messages: List[BaseMessage] | |||
| ) -> List[Dict[str, Any]]: | |||
| dict_messages = [] | |||
| for m in messages: | |||
| message = self._convert_message_to_dict(m) | |||
| if dict_messages: | |||
| previous_message = dict_messages[-1] | |||
| if previous_message['role'] == message['role']: | |||
| dict_messages[-1]['content'] += f"\n{message['content']}" | |||
| else: | |||
| dict_messages.append(message) | |||
| else: | |||
| dict_messages.append(message) | |||
| return dict_messages | |||
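The loop above merges consecutive messages that share a role, presumably because the ZhipuAI prompt expects alternating user/assistant turns; a standalone sketch of the same merge (hypothetical helper name):

    def merge_same_role(messages):
        """Concatenate consecutive message dicts that share a role."""
        merged = []
        for message in messages:
            if merged and merged[-1]['role'] == message['role']:
                merged[-1]['content'] += f"\n{message['content']}"
            else:
                merged.append(dict(message))
        return merged

    # two consecutive 'user' dicts collapse into one dict whose contents are joined by a newline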
| def _generate( | |||
| self, | |||
| messages: List[BaseMessage], | |||
| stop: Optional[List[str]] = None, | |||
| run_manager: Optional[CallbackManagerForLLMRun] = None, | |||
| **kwargs: Any, | |||
| ) -> ChatResult: | |||
| if self.streaming: | |||
| generation: Optional[ChatGenerationChunk] = None | |||
| llm_output: Optional[Dict] = None | |||
| for chunk in self._stream( | |||
| messages=messages, stop=stop, run_manager=run_manager, **kwargs | |||
| ): | |||
| if chunk.generation_info is not None \ | |||
| and 'token_usage' in chunk.generation_info: | |||
| llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model} | |||
| continue | |||
| if generation is None: | |||
| generation = chunk | |||
| else: | |||
| generation += chunk | |||
| assert generation is not None | |||
| return ChatResult(generations=[generation], llm_output=llm_output) | |||
| else: | |||
| message_dicts = self._create_message_dicts(messages) | |||
| request = self._default_params | |||
| request["prompt"] = message_dicts | |||
| request.update(kwargs) | |||
| response = self.client.invoke(**request) | |||
| return self._create_chat_result(response) | |||
| def _stream( | |||
| self, | |||
| messages: List[BaseMessage], | |||
| stop: Optional[List[str]] = None, | |||
| run_manager: Optional[CallbackManagerForLLMRun] = None, | |||
| **kwargs: Any, | |||
| ) -> Iterator[ChatGenerationChunk]: | |||
| message_dicts = self._create_message_dicts(messages) | |||
| request = self._default_params | |||
| request["prompt"] = message_dicts | |||
| request.update(kwargs) | |||
| for event in self.client.sse_invoke(incremental=True, **request).events(): | |||
| if event.event == "add": | |||
| yield ChatGenerationChunk(message=AIMessageChunk(content=event.data)) | |||
| if run_manager: | |||
| run_manager.on_llm_new_token(event.data) | |||
| elif event.event == "error" or event.event == "interrupted": | |||
| raise ValueError( | |||
| f"{event.data}" | |||
| ) | |||
| elif event.event == "finish": | |||
| meta = json.loads(event.meta) | |||
| token_usage = meta['usage'] | |||
| if token_usage is not None: | |||
| if 'prompt_tokens' not in token_usage: | |||
| token_usage['prompt_tokens'] = 0 | |||
| if 'completion_tokens' not in token_usage: | |||
| token_usage['completion_tokens'] = token_usage['total_tokens'] | |||
| yield ChatGenerationChunk( | |||
| message=AIMessageChunk(content=event.data), | |||
| generation_info=dict({'token_usage': token_usage}) | |||
| ) | |||
| def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult: | |||
| data = response["data"] | |||
| generations = [] | |||
| for res in data["choices"]: | |||
| message = self._convert_dict_to_message(res) | |||
| gen = ChatGeneration( | |||
| message=message | |||
| ) | |||
| generations.append(gen) | |||
| token_usage = data.get("usage") | |||
| if token_usage is not None: | |||
| if 'prompt_tokens' not in token_usage: | |||
| token_usage['prompt_tokens'] = 0 | |||
| if 'completion_tokens' not in token_usage: | |||
| token_usage['completion_tokens'] = token_usage['total_tokens'] | |||
| llm_output = {"token_usage": token_usage, "model_name": self.model} | |||
| return ChatResult(generations=generations, llm_output=llm_output) | |||
| # def get_token_ids(self, text: str) -> List[int]: | |||
| # """Return the ordered ids of the tokens in a text. | |||
| # | |||
| # Args: | |||
| # text: The string input to tokenize. | |||
| # | |||
| # Returns: | |||
| # A list of ids corresponding to the tokens in the text, in order they occur | |||
| # in the text. | |||
| # """ | |||
| # from core.third_party.transformers.Token import ChatGLMTokenizer | |||
| # | |||
| # tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm2-6b") | |||
| # return tokenizer.encode(text) | |||
| def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: | |||
| """Get the number of tokens in the messages. | |||
| Useful for checking if an input will fit in a model's context window. | |||
| Args: | |||
| messages: The message inputs to tokenize. | |||
| Returns: | |||
| The sum of the number of tokens across the messages. | |||
| """ | |||
| return sum([self.get_num_tokens(m.content) for m in messages]) | |||
| def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: | |||
| overall_token_usage: dict = {} | |||
| for output in llm_outputs: | |||
| if output is None: | |||
| # Happens in streaming | |||
| continue | |||
| token_usage = output["token_usage"] | |||
| for k, v in token_usage.items(): | |||
| if k in overall_token_usage: | |||
| overall_token_usage[k] += v | |||
| else: | |||
| overall_token_usage[k] = v | |||
| return {"token_usage": overall_token_usage, "model_name": self.model} | |||
| @@ -50,4 +50,5 @@ transformers~=4.31.0 | |||
| stripe~=5.5.0 | |||
| pandas==1.5.3 | |||
| xinference==0.4.2 | |||
| safetensors==0.3.2 | |||
| safetensors==0.3.2 | |||
| zhipuai==1.0.7 | |||
| @@ -548,7 +548,7 @@ class ProviderService: | |||
| 'result': 'success' | |||
| } | |||
| def free_quota_qualification_verify(self, tenant_id: str, provider_name: str): | |||
| def free_quota_qualification_verify(self, tenant_id: str, provider_name: str, token: Optional[str]): | |||
| api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") | |||
| api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") | |||
| api_url = api_base_url + '/api/v1/providers/qualification-verify' | |||
| @@ -557,8 +557,11 @@ class ProviderService: | |||
| 'Content-Type': 'application/json', | |||
| 'Authorization': f"Bearer {api_key}" | |||
| } | |||
| json_data = {'workspace_id': tenant_id, 'provider_name': provider_name} | |||
| if token: | |||
| json_data['token'] = token | |||
| response = requests.post(api_url, headers=headers, | |||
| json={'workspace_id': tenant_id, 'provider_name': provider_name}) | |||
| json=json_data) | |||
| if not response.ok: | |||
| logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ") | |||
| raise ValueError(f"Error: {response.status_code} ") | |||
| @@ -31,6 +31,9 @@ TONGYI_DASHSCOPE_API_KEY= | |||
| WENXIN_API_KEY= | |||
| WENXIN_SECRET_KEY= | |||
| # ZhipuAI Credentials | |||
| ZHIPUAI_API_KEY= | |||
| # ChatGLM Credentials | |||
| CHATGLM_API_BASE= | |||
| @@ -0,0 +1,50 @@ | |||
| import json | |||
| import os | |||
| from unittest.mock import patch | |||
| from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding | |||
| from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider | |||
| from models.provider import Provider, ProviderType | |||
| def get_mock_provider(valid_api_key): | |||
| return Provider( | |||
| id='provider_id', | |||
| tenant_id='tenant_id', | |||
| provider_name='zhipuai', | |||
| provider_type=ProviderType.CUSTOM.value, | |||
| encrypted_config=json.dumps({ | |||
| 'api_key': valid_api_key | |||
| }), | |||
| is_valid=True, | |||
| ) | |||
| def get_mock_embedding_model(): | |||
| model_name = 'text_embedding' | |||
| valid_api_key = os.environ['ZHIPUAI_API_KEY'] | |||
| provider = ZhipuAIProvider(provider=get_mock_provider(valid_api_key)) | |||
| return ZhipuAIEmbedding( | |||
| model_provider=provider, | |||
| name=model_name | |||
| ) | |||
| def decrypt_side_effect(tenant_id, encrypted_api_key): | |||
| return encrypted_api_key | |||
| @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) | |||
| def test_embedding(mock_decrypt): | |||
| embedding_model = get_mock_embedding_model() | |||
| rst = embedding_model.client.embed_query('test') | |||
| assert isinstance(rst, list) | |||
| assert len(rst) == 1024 | |||
| @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) | |||
| def test_doc_embedding(mock_decrypt): | |||
| embedding_model = get_mock_embedding_model() | |||
| rst = embedding_model.client.embed_documents(['test', 'test2']) | |||
| assert isinstance(rst, list) | |||
| assert len(rst[0]) == 1024 | |||
| @@ -0,0 +1,79 @@ | |||
| import json | |||
| import os | |||
| from unittest.mock import patch | |||
| from core.model_providers.models.entity.message import PromptMessage, MessageType | |||
| from core.model_providers.models.entity.model_params import ModelKwargs | |||
| from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel | |||
| from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider | |||
| from models.provider import Provider, ProviderType | |||
| def get_mock_provider(valid_api_key): | |||
| return Provider( | |||
| id='provider_id', | |||
| tenant_id='tenant_id', | |||
| provider_name='zhipuai', | |||
| provider_type=ProviderType.CUSTOM.value, | |||
| encrypted_config=json.dumps({ | |||
| 'api_key': valid_api_key | |||
| }), | |||
| is_valid=True, | |||
| ) | |||
| def get_mock_model(model_name: str, streaming: bool = False): | |||
| model_kwargs = ModelKwargs( | |||
| temperature=0.01, | |||
| ) | |||
| valid_api_key = os.environ['ZHIPUAI_API_KEY'] | |||
| model_provider = ZhipuAIProvider(provider=get_mock_provider(valid_api_key)) | |||
| return ZhipuAIModel( | |||
| model_provider=model_provider, | |||
| name=model_name, | |||
| model_kwargs=model_kwargs, | |||
| streaming=streaming | |||
| ) | |||
| def decrypt_side_effect(tenant_id, encrypted_api_key): | |||
| return encrypted_api_key | |||
| @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) | |||
| def test_chat_get_num_tokens(mock_decrypt): | |||
| model = get_mock_model('chatglm_lite') | |||
| rst = model.get_num_tokens([ | |||
| PromptMessage(type=MessageType.SYSTEM, content='You are a kind assistant.'), | |||
| PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') | |||
| ]) | |||
| assert rst > 0 | |||
| @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) | |||
| def test_chat_run(mock_decrypt, mocker): | |||
| mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None) | |||
| model = get_mock_model('chatglm_lite') | |||
| messages = [ | |||
| PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?') | |||
| ] | |||
| rst = model.run( | |||
| messages, | |||
| ) | |||
| assert len(rst.content) > 0 | |||
| @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) | |||
| def test_chat_stream_run(mock_decrypt, mocker): | |||
| mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None) | |||
| model = get_mock_model('chatglm_lite', streaming=True) | |||
| messages = [ | |||
| PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?') | |||
| ] | |||
| rst = model.run( | |||
| messages | |||
| ) | |||
| assert len(rst.content) > 0 | |||
| @@ -39,7 +39,7 @@ def test_is_provider_credentials_valid_or_raise_invalid(): | |||
| MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({}) | |||
| credential = VALIDATE_CREDENTIAL.copy() | |||
| credential['api_key'] = 'invalid_key' | |||
| del credential['api_key'] | |||
| # raise CredentialsValidateFailedError if api_key is invalid | |||
| with pytest.raises(CredentialsValidateFailedError): | |||
| @@ -0,0 +1,88 @@ | |||
| import pytest | |||
| from unittest.mock import patch | |||
| import json | |||
| from langchain.schema import ChatResult, ChatGeneration, AIMessage | |||
| from core.model_providers.providers.base import CredentialsValidateFailedError | |||
| from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider | |||
| from models.provider import ProviderType, Provider | |||
| PROVIDER_NAME = 'zhipuai' | |||
| MODEL_PROVIDER_CLASS = ZhipuAIProvider | |||
| VALIDATE_CREDENTIAL = { | |||
| 'api_key': 'valid_key', | |||
| } | |||
| def encrypt_side_effect(tenant_id, encrypt_key): | |||
| return f'encrypted_{encrypt_key}' | |||
| def decrypt_side_effect(tenant_id, encrypted_key): | |||
| return encrypted_key.replace('encrypted_', '') | |||
| def test_is_provider_credentials_valid_or_raise_valid(mocker): | |||
| mocker.patch('core.third_party.langchain.llms.zhipuai_llm.ZhipuAIChatLLM._generate', | |||
| return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))])) | |||
| MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL) | |||
| def test_is_provider_credentials_valid_or_raise_invalid(): | |||
| # raise CredentialsValidateFailedError if api_key is not in credentials | |||
| with pytest.raises(CredentialsValidateFailedError): | |||
| MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({}) | |||
| credential = VALIDATE_CREDENTIAL.copy() | |||
| credential['api_key'] = 'invalid_key' | |||
| # raise CredentialsValidateFailedError if api_key is invalid | |||
| with pytest.raises(CredentialsValidateFailedError): | |||
| MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential) | |||
| @patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect) | |||
| def test_encrypt_credentials(mock_encrypt): | |||
| result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy()) | |||
| assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}' | |||
| @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) | |||
| def test_get_credentials_custom(mock_decrypt): | |||
| encrypted_credential = VALIDATE_CREDENTIAL.copy() | |||
| encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key'] | |||
| provider = Provider( | |||
| id='provider_id', | |||
| tenant_id='tenant_id', | |||
| provider_name=PROVIDER_NAME, | |||
| provider_type=ProviderType.CUSTOM.value, | |||
| encrypted_config=json.dumps(encrypted_credential), | |||
| is_valid=True, | |||
| ) | |||
| model_provider = MODEL_PROVIDER_CLASS(provider=provider) | |||
| result = model_provider.get_provider_credentials() | |||
| assert result['api_key'] == 'valid_key' | |||
| @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) | |||
| def test_get_credentials_obfuscated(mock_decrypt): | |||
| encrypted_credential = VALIDATE_CREDENTIAL.copy() | |||
| encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key'] | |||
| provider = Provider( | |||
| id='provider_id', | |||
| tenant_id='tenant_id', | |||
| provider_name=PROVIDER_NAME, | |||
| provider_type=ProviderType.CUSTOM.value, | |||
| encrypted_config=json.dumps(encrypted_credential), | |||
| is_valid=True, | |||
| ) | |||
| model_provider = MODEL_PROVIDER_CLASS(provider=provider) | |||
| result = model_provider.get_provider_credentials(obfuscated=True) | |||
| middle_token = result['api_key'][6:-2] | |||
| assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0) | |||
| assert all(char == '*' for char in middle_token) | |||