from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool

memory=memory,
)

# change function call strategy based on LLM model
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
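The hunk above switches the agent to function calling whenever the model schema advertises native tool-call support. A minimal sketch of that feature gating, using illustrative stand-ins rather than the real Dify entities:

```python
from enum import Enum

class ModelFeature(Enum):            # stand-in for the runtime's ModelFeature
    TOOL_CALL = "tool-call"
    MULTI_TOOL_CALL = "multi-tool-call"

class Strategy(Enum):                # stand-in for AgentEntity.Strategy
    CHAIN_OF_THOUGHT = "chain-of-thought"
    FUNCTION_CALLING = "function-calling"

def choose_strategy(features: list) -> Strategy:
    # prefer native function calling; otherwise fall back to ReACT-style chain of thought
    if {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection(features):
        return Strategy.FUNCTION_CALLING
    return Strategy.CHAIN_OF_THOUGHT

print(choose_strategy([ModelFeature.MULTI_TOOL_CALL]))   # Strategy.FUNCTION_CALLING
print(choose_strategy([]))                                # Strategy.CHAIN_OF_THOUGHT
```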
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
assistant_cot_runner = AssistantCotApplicationRunner(
prompt_messages=prompt_message,
variables_pool=tool_variables,
db_variables=tool_conversation_variables,
model_instance=model_instance
)
invoke_result = assistant_cot_runner.run(
model_instance=model_instance,
conversation=conversation,
message=message,
query=query,
memory=memory,

prompt_messages=prompt_message,
variables_pool=tool_variables,
db_variables=tool_conversation_variables
db_variables=tool_conversation_variables,
model_instance=model_instance
)
invoke_result = assistant_fc_runner.run(
model_instance=model_instance,
conversation=conversation,
message=message,
query=query,
import logging
import json
from typing import Optional, List, Tuple, Union
from typing import Optional, List, Tuple, Union, cast
from datetime import datetime
from mimetypes import guess_extension

AgentEntity, AppOrchestrationConfigEntity, ApplicationGenerateEntity, InvokeFrom
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.utils.encoders import jsonable_encoder
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_manager import ModelInstance
from core.file.message_file_parser import FileTransferMethod

logger = logging.getLogger(__name__)

prompt_messages: Optional[List[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None
) -> None:
"""
Agent runner

self.history_prompt_messages = prompt_messages
self.variables_pool = variables_pool
self.db_variables_pool = db_variables
self.model_instance = model_instance

# init callback
self.agent_callback = DifyAgentCallbackHandler()

MessageAgentThought.message_id == self.message.id,
).count()

# check if model supports stream tool call
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
self.stream_tool_call = True
else:
self.stream_tool_call = False
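The new `stream_tool_call` flag is derived from the model schema and later decides whether `invoke_llm` is called with `stream=True` or falls back to a blocking call. A self-contained sketch of that detection (stand-in types, not the actual runtime classes):

```python
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional

class ModelFeature(Enum):                    # stand-in for the runtime's ModelFeature
    STREAM_TOOL_CALL = "stream-tool-call"

@dataclass
class ModelSchema:                           # stand-in for the schema returned by get_model_schema
    features: Optional[List[ModelFeature]] = None

def supports_stream_tool_call(schema: Optional[ModelSchema]) -> bool:
    # a missing schema or feature list means the provider cannot stream tool calls
    return bool(schema) and ModelFeature.STREAM_TOOL_CALL in (schema.features or [])

print(supports_stream_tool_call(ModelSchema([ModelFeature.STREAM_TOOL_CALL])))  # True
print(supports_stream_tool_call(ModelSchema()))                                 # False
```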
def _repacket_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
"""
Repacket app orchestration config
from models.model import Conversation, Message

class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
def run(self, model_instance: ModelInstance,
conversation: Conversation,
def run(self, conversation: Conversation,
message: Message,
query: str,
) -> Union[Generator, LLMResult]:

llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price

model_instance = self.model_instance

while function_call_state and iteration_step <= max_iteration_steps:
# continue to run until there is not any tool call
function_call_state = False

# remove Action: xxx from agent thought
agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)

if action_name and action_input:
if action_name and action_input is not None:
return AgentScratchpadUnit(
agent_response=content,
thought=agent_thought,
from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage,\
SystemPromptMessage, AssistantPromptMessage, ToolPromptMessage, PromptMessageTool
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage, LLMResultChunkDelta
from core.model_manager import ModelInstance
from core.application_queue_manager import PublishFrom

logger = logging.getLogger(__name__)

class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
def run(self, model_instance: ModelInstance,
conversation: Conversation,
def run(self, conversation: Conversation,
message: Message,
query: str,
) -> Generator[LLMResultChunk, None, None]:

llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price

model_instance = self.model_instance

while function_call_state and iteration_step <= max_iteration_steps:
function_call_state = False

# recale llm max tokens
self.recale_llm_max_tokens(self.model_config, prompt_messages)

# invoke model
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters,
tools=prompt_messages_tools,
stop=app_orchestration_config.model_config.stop,
stream=True,
stream=self.stream_tool_call,
user=self.user_id,
callbacks=[],
)

current_llm_usage = None

for chunk in chunks:
if self.stream_tool_call:
for chunk in chunks:
# check if there is any tool call
if self.check_tool_calls(chunk):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk))
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
}, ensure_ascii=False)
except json.JSONDecodeError as e:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
})

if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list):
for content in chunk.delta.message.content:
response += content.data
else:
response += chunk.delta.message.content

if chunk.delta.usage:
increase_usage(llm_usage, chunk.delta.usage)
current_llm_usage = chunk.delta.usage

yield chunk
else:
result: LLMResult = chunks
# check if there is any tool call
if self.check_tool_calls(chunk):
if self.check_blocking_tool_calls(result):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk))
tool_calls.extend(self.extract_blocking_tool_calls(result))
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
})

if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list):
for content in chunk.delta.message.content:
if result.usage:
increase_usage(llm_usage, result.usage)
current_llm_usage = result.usage

if result.message and result.message.content:
if isinstance(result.message.content, list):
for content in result.message.content:
response += content.data
else:
response += chunk.delta.message.content
if chunk.delta.usage:
increase_usage(llm_usage, chunk.delta.usage)
current_llm_usage = chunk.delta.usage
response += result.message.content

if not result.message.content:
result.message.content = ''

yield LLMResultChunk(
model=model_instance.model,
prompt_messages=result.prompt_messages,
system_fingerprint=result.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=result.message,
usage=result.usage,
)
)
yield chunk

if tool_calls:
prompt_messages.append(AssistantPromptMessage(
content='',
name='',
tool_calls=[AssistantPromptMessage.ToolCall(
id=tool_call[0],
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_call[1],
arguments=json.dumps(tool_call[2], ensure_ascii=False)
)
) for tool_call in tool_calls]
))
# save thought
self.save_agent_thought(

final_answer += response + '\n'

# update prompt messages
if response.strip():
prompt_messages.append(AssistantPromptMessage(
content=response,
))

# call tools
tool_responses = []
for tool_call_id, tool_call_name, tool_call_args in tool_calls:

)
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

# update prompt messages
if response.strip():
prompt_messages.append(AssistantPromptMessage(
content=response,
))

# update prompt tool
for prompt_tool in prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

if llm_result_chunk.delta.message.tool_calls:
return True
return False

def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
"""
Check if there is any blocking tool call in llm result
"""
if llm_result.message.tool_calls:
return True
return False

def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
"""

))
return tool_calls

def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
"""
Extract blocking tool calls from llm result
Returns:
List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
"""
tool_calls = []
for prompt_message in llm_result.message.tool_calls:
tool_calls.append((
prompt_message.id,
prompt_message.function.name,
json.loads(prompt_message.function.arguments),
))
return tool_calls
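Both extractors normalize tool calls into `(tool_call_id, tool_call_name, tool_call_args)` tuples that the rest of the runner iterates over. A small usage sketch with illustrative stand-in classes:

```python
import json
from dataclasses import dataclass

@dataclass
class ToolCallFunction:      # stand-in for AssistantPromptMessage.ToolCall.ToolCallFunction
    name: str
    arguments: str

@dataclass
class ToolCall:              # stand-in for AssistantPromptMessage.ToolCall
    id: str
    function: ToolCallFunction

def extract_tool_calls(message_tool_calls):
    """Turn raw tool calls into (id, name, parsed-args) tuples."""
    return [(tc.id, tc.function.name, json.loads(tc.function.arguments))
            for tc in message_tool_calls]

calls = extract_tool_calls([ToolCall("call_1", ToolCallFunction("get_weather", '{"city": "Berlin"}'))])
for tool_call_id, tool_call_name, tool_call_args in calls:
    print(tool_call_id, tool_call_name, tool_call_args)   # call_1 get_weather {'city': 'Berlin'}
```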
def organize_prompt_messages(self, prompt_template: str,
query: str = None,
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = "vision"
STREAM_TOOL_CALL = "stream-tool-call"

class DefaultParameterName(Enum):
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={

features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={

features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={

features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={

features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
tools: Optional[list[PromptMessageTool]] = None) -> Generator:
index = 0
full_assistant_content = ''
delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
real_model = model
system_fingerprint = None
completion = ''

delta = chunk.choices[0]

if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == '') and \
delta.delta.function_call is None:
continue

# assistant_message_tool_calls = delta.delta.tool_calls
assistant_message_function_call = delta.delta.function_call

# extract tool calls from response
if delta_assistant_message_function_call_storage is not None:
# handle in-progress stream function call
if assistant_message_function_call:
# the function call has not finished yet
delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
continue
else:
# the function call has ended
assistant_message_function_call = delta_assistant_message_function_call_storage
delta_assistant_message_function_call_storage = None
else:
if assistant_message_function_call:
# start of stream function call
delta_assistant_message_function_call_storage = assistant_message_function_call
if delta_assistant_message_function_call_storage.arguments is None:
delta_assistant_message_function_call_storage.arguments = ''
continue

# extract tool calls from response
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
function_call = self._extract_response_function_call(assistant_message_function_call)
else:
raise ValueError(f"Got unknown type {message}")

if message.name is not None:
if message.name:
message_dict["name"] = message.name

return message_dict

num_tokens = 0
for tool in tools:
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(tool.get("type")))
num_tokens += len(encoding.encode('function'))

# calculate num tokens for function object
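The storage variable above buffers streamed `function_call` fragments until a chunk arrives without one, at which point the accumulated call is handed to `_extract_response_function_call`. A standalone sketch of that accumulation pattern (simplified, with a stand-in delta type rather than the real OpenAI SDK class):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class FunctionCallDelta:                     # stand-in for ChoiceDeltaFunctionCall
    name: Optional[str] = None
    arguments: Optional[str] = None

def accumulate_function_call(deltas):
    """Buffer streamed fragments; return the completed call once fragments stop arriving."""
    storage: Optional[FunctionCallDelta] = None
    for delta in deltas:
        if delta is not None:
            if storage is None:
                storage = FunctionCallDelta(name=delta.name, arguments=delta.arguments or '')
            else:
                storage.arguments += delta.arguments or ''
        elif storage is not None:
            return storage           # a chunk without a fragment marks the end of the call
    return storage

chunks = [FunctionCallDelta(name="get_weather", arguments='{"ci'),
          FunctionCallDelta(arguments='ty": "Berlin"}'),
          None]
print(accumulate_function_call(chunks))
# FunctionCallDelta(name='get_weather', arguments='{"city": "Berlin"}')
```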
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageFunction,
PromptMessageTool, SystemPromptMessage, UserPromptMessage)
PromptMessageTool, SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError

elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
# check if last message is user message
message = cast(ToolPromptMessage, message)
message_dict = {"role": "function", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")

model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 16384
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32768
| """ | """ | ||||
| def generate(self, model: str, api_key: str, group_id: str, | def generate(self, model: str, api_key: str, group_id: str, | ||||
| prompt_messages: List[MinimaxMessage], model_parameters: dict, | prompt_messages: List[MinimaxMessage], model_parameters: dict, | ||||
| tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \ | |||||
| tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \ | |||||
| -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: | -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: | ||||
| """ | """ | ||||
| generate chat completion | generate chat completion | ||||
| continue | continue | ||||
| for choice in choices: | for choice in choices: | ||||
| print(choice) | |||||
| message = choice['delta'] | message = choice['delta'] | ||||
| yield MinimaxMessage( | yield MinimaxMessage( | ||||
| content=message, | content=message, |
| """ | """ | ||||
| def generate(self, model: str, api_key: str, group_id: str, | def generate(self, model: str, api_key: str, group_id: str, | ||||
| prompt_messages: List[MinimaxMessage], model_parameters: dict, | prompt_messages: List[MinimaxMessage], model_parameters: dict, | ||||
| tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \ | |||||
| tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \ | |||||
| -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: | -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: | ||||
| """ | """ | ||||
| generate chat completion | generate chat completion | ||||
| **extra_kwargs | **extra_kwargs | ||||
| } | } | ||||
| if tools: | |||||
| body['functions'] = tools | |||||
| body['function_call'] = { 'type': 'auto' } | |||||
| try: | try: | ||||
| response = post( | response = post( | ||||
| url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) | url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) | ||||
| """ | """ | ||||
| handle stream chat generate response | handle stream chat generate response | ||||
| """ | """ | ||||
| function_call_storage = None | |||||
| for line in response.iter_lines(): | for line in response.iter_lines(): | ||||
| if not line: | if not line: | ||||
| continue | continue | ||||
| msg = data['base_resp']['status_msg'] | msg = data['base_resp']['status_msg'] | ||||
| self._handle_error(code, msg) | self._handle_error(code, msg) | ||||
| if data['reply']: | |||||
| if data['reply'] or 'usage' in data and data['usage']: | |||||
| total_tokens = data['usage']['total_tokens'] | total_tokens = data['usage']['total_tokens'] | ||||
| message = MinimaxMessage( | message = MinimaxMessage( | ||||
| role=MinimaxMessage.Role.ASSISTANT.value, | role=MinimaxMessage.Role.ASSISTANT.value, | ||||
| 'total_tokens': total_tokens | 'total_tokens': total_tokens | ||||
| } | } | ||||
| message.stop_reason = data['choices'][0]['finish_reason'] | message.stop_reason = data['choices'][0]['finish_reason'] | ||||
| if function_call_storage: | |||||
| function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value) | |||||
| function_call_message.function_call = function_call_storage | |||||
| yield function_call_message | |||||
| yield message | yield message | ||||
| return | return | ||||
| continue | continue | ||||
| for choice in choices: | for choice in choices: | ||||
| message = choice['messages'][0]['text'] | |||||
| if not message: | |||||
| continue | |||||
| message = choice['messages'][0] | |||||
| if 'function_call' in message: | |||||
| if not function_call_storage: | |||||
| function_call_storage = message['function_call'] | |||||
| if 'arguments' not in function_call_storage or not function_call_storage['arguments']: | |||||
| function_call_storage['arguments'] = '' | |||||
| continue | |||||
| else: | |||||
| function_call_storage['arguments'] += message['function_call']['arguments'] | |||||
| continue | |||||
| else: | |||||
| if function_call_storage: | |||||
| message['function_call'] = function_call_storage | |||||
| function_call_storage = None | |||||
| yield MinimaxMessage( | |||||
| content=message, | |||||
| role=MinimaxMessage.Role.ASSISTANT.value | |||||
| ) | |||||
| minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value) | |||||
| if 'function_call' in message: | |||||
| minimax_message.function_call = message['function_call'] | |||||
| if 'text' in message: | |||||
| minimax_message.content = message['text'] | |||||
| yield minimax_message |
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
SystemPromptMessage, UserPromptMessage)
SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError

"""
client: MinimaxChatCompletionPro = self.model_apis[model]()

if tools:
tools = [{
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters
} for tool in tools]

response = client.generate(
model=model,
api_key=credentials['minimax_api_key'],

elif isinstance(prompt_message, UserPromptMessage):
return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content)
elif isinstance(prompt_message, AssistantPromptMessage):
if prompt_message.tool_calls:
message = MinimaxMessage(
role=MinimaxMessage.Role.ASSISTANT.value,
content=''
)
message.function_call={
'name': prompt_message.tool_calls[0].function.name,
'arguments': prompt_message.tool_calls[0].function.arguments
}
return message
return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content)
elif isinstance(prompt_message, ToolPromptMessage):
return MinimaxMessage(role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content)
else:
raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported')

finish_reason=message.stop_reason if message.stop_reason else None,
),
)
elif message.function_call:
if 'name' not in message.function_call or 'arguments' not in message.function_call:
continue

yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content='',
tool_calls=[AssistantPromptMessage.ToolCall(
id='',
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=message.function_call['name'],
arguments=message.function_call['arguments']
)
)]
),
),
)
else:
yield LLMResultChunk(
model=model,
USER = 'USER'
ASSISTANT = 'BOT'
SYSTEM = 'SYSTEM'
FUNCTION = 'FUNCTION'

role: str = Role.USER.value
content: str
usage: Dict[str, int] = None
stop_reason: str = ''
function_call: Dict[str, Any] = None

def to_dict(self) -> Dict[str, Any]:
if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:
return {
'sender_type': 'BOT',
'sender_name': '专家',
'text': '',
'function_call': self.function_call
}

return {
'sender_type': self.role,
'sender_name': '我' if self.role == 'USER' else '专家',
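Taken together with the earlier `body['functions']` / `body['function_call']` hunk, an assistant turn that carried a function call is serialized with an empty `text` and a `function_call` payload. A sketch of the resulting request fragment; the message and tool contents are illustrative, and any Minimax request fields beyond those shown in this diff are assumptions:

```python
from json import dumps

tools = [{
    "name": "get_weather",
    "description": "Query current weather for a city",
    "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
}]

messages = [
    {"sender_type": "USER", "sender_name": "我", "text": "What's the weather in Berlin?"},
    # assistant turn replayed from a previous tool call: empty text plus function_call
    {"sender_type": "BOT", "sender_name": "专家", "text": "",
     "function_call": {"name": "get_weather", "arguments": '{"city": "Berlin"}'}},
    {"sender_type": "FUNCTION", "sender_name": "专家", "text": '{"temperature": 21}'},
]

body = {"messages": messages}
if tools:
    # mirrors the hunk above: expose tools as `functions` and let the model decide
    body["functions"] = tools
    body["function_call"] = {"type": "auto"}

print(dumps(body, ensure_ascii=False, indent=2))
```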
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 4096

features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 16385

features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 16385

features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 16385

features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 4096

features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 128000

features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 128000

features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 32768

features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 128000

features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 8192
else:
raise ValueError(f"Got unknown type {message}")

if message.name is not None:
if message.name:
message_dict["name"] = message.name

return message_dict
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
SystemPromptMessage, UserPromptMessage)
SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType,
ParameterRule, ParameterType)
ParameterRule, ParameterType, ModelFeature)
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.xinference.llm.xinference_helper import (XinferenceHelper,
from core.model_runtime.model_providers.xinference.xinference_helper import (XinferenceHelper,
XinferenceModelExtraParameter)
from core.model_runtime.utils import helper
from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError,

see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
"""
if 'temperature' in model_parameters:
if model_parameters['temperature'] < 0.01:
model_parameters['temperature'] = 0.01
elif model_parameters['temperature'] > 1.0:
model_parameters['temperature'] = 0.99

return self._generate(
model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
tools=tools, stop=stop, stream=stream, user=user,

credentials['completion_type'] = 'completion'
else:
raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported')

if extra_param.support_function_call:
credentials['support_function_call'] = True

except RuntimeError as e:
raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')

elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")

label=I18nObject(
zh_Hans='温度',
en_US='Temperature'
)
),
),
ParameterRule(
name='top_p',

completion_type = LLMMode.COMPLETION.value
else:
raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')

support_function_call = credentials.get('support_function_call', False)

entity = AIModelEntity(
model=model,
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
features=[
ModelFeature.TOOL_CALL
] if support_function_call else [],
model_properties={
ModelPropertyKey.MODE: completion_type,
},

extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter`
"""
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('server_url is required in credentials')

if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]

client = OpenAI(
base_url=f'{credentials["server_url"]}/v1',
api_key='abc',
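With `ToolPromptMessage` now mapped to an OpenAI-compatible `role: "tool"` entry, tool results can be replayed to the model between invocations. A minimal sketch of that mapping, mirroring the branch above with stand-in message classes:

```python
from dataclasses import dataclass

@dataclass
class SystemPromptMessage:     # stand-ins for the runtime's prompt message classes
    content: str

@dataclass
class ToolPromptMessage:
    content: str
    tool_call_id: str

def to_openai_dict(message):
    if isinstance(message, SystemPromptMessage):
        return {"role": "system", "content": message.content}
    if isinstance(message, ToolPromptMessage):
        # tool output is attached to the originating call via tool_call_id
        return {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
    raise ValueError(f"Unknown message type {type(message)}")

print(to_openai_dict(ToolPromptMessage(content='{"temperature": 21}', tool_call_id="call_1")))
```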
from typing import Optional

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType, ModelPropertyKey
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle, RESTfulModelHandle

from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper

class XinferenceTextEmbeddingModel(TextEmbeddingModel):
"""

"""
server_url = credentials['server_url']
model_uid = credentials['model_uid']

if server_url.endswith('/'):
server_url = server_url[:-1]

client = Client(base_url=server_url)

try:

:return:
"""
try:
server_url = credentials['server_url']
model_uid = credentials['model_uid']
extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
if extra_args.max_tokens:
credentials['max_tokens'] = extra_args.max_tokens

self._invoke(model=model, credentials=credentials, texts=['ping'])
except InvokeAuthorizationError:
except (InvokeAuthorizationError, RuntimeError):
raise CredentialsValidateFailedError('Invalid api key')

@property

"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={},
model_properties={
ModelPropertyKey.MAX_CHUNKS: 1,
ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512,
},
parameter_rules=[]
)
from threading import Lock
from time import time
from typing import List
from os import path

from requests import get
from requests.adapters import HTTPAdapter

model_format: str
model_handle_type: str
model_ability: List[str]
max_tokens: int = 512
support_function_call: bool = False

def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str]) -> None:
def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str],
support_function_call: bool, max_tokens: int) -> None:
self.model_format = model_format
self.model_handle_type = model_handle_type
self.model_ability = model_ability
self.support_function_call = support_function_call
self.max_tokens = max_tokens

cache = {}
cache_lock = Lock()

get xinference model extra parameter like model_format and model_handle_type
"""
url = f'{server_url}/v1/models/{model_uid}'
url = path.join(server_url, 'v1/models', model_uid)
# this method is surrounded by a lock, and default requests may hang forever, so we just set an Adapter with max_retries=3
session = Session()

response_json = response.json()

model_format = response_json['model_format']
model_ability = response_json['model_ability']
model_format = response_json.get('model_format', 'ggmlv3')
model_ability = response_json.get('model_ability', [])

if model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
if response_json.get('model_type') == 'embedding':
model_handle_type = 'embedding'
elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
model_handle_type = 'chatglm'
elif 'generate' in model_ability:
model_handle_type = 'generate'
else:
raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')

support_function_call = 'tools' in model_ability
max_tokens = response_json.get('max_tokens', 512)

return XinferenceModelExtraParameter(
model_format=model_format,
model_handle_type=model_handle_type,
model_ability=model_ability
model_ability=model_ability,
support_function_call=support_function_call,
max_tokens=max_tokens
)
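The helper now also reports whether the deployed model advertises a `tools` ability and how many tokens it accepts, and the callers above copy both into the credentials. A rough usage sketch under those assumptions (stand-in dataclass, not the real helper):

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class ExtraParameter:                       # stand-in for XinferenceModelExtraParameter
    model_format: str
    model_handle_type: str
    model_ability: List[str] = field(default_factory=list)
    support_function_call: bool = False
    max_tokens: int = 512

def enrich_credentials(credentials: dict, extra: ExtraParameter) -> dict:
    # mirror of the validation hunks: persist detected capabilities on the credentials dict
    if extra.support_function_call:
        credentials['support_function_call'] = True
    if extra.max_tokens:
        credentials['max_tokens'] = extra.max_tokens
    return credentials

ability = ['chat', 'tools']
extra = ExtraParameter('pytorch', 'chat', ability,
                       support_function_call='tools' in ability, max_tokens=2048)
print(enrich_credentials({'server_url': 'http://127.0.0.1:9997', 'model_uid': 'my-model'}, extra))
```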
label:
en_US: glm-3-turbo
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
parameter_rules:
label:
en_US: glm-4
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
parameter_rules:
'content': prompt_message.content,
'tool_call_id': prompt_message.tool_call_id
})
elif isinstance(prompt_message, AssistantPromptMessage):
if prompt_message.tool_calls:
params['messages'].append({
'role': 'assistant',
'content': prompt_message.content,
'tool_calls': [
{
'id': tool_call.id,
'type': tool_call.type,
'function': {
'name': tool_call.function.name,
'arguments': tool_call.function.arguments
}
} for tool_call in prompt_message.tool_calls
]
})
else:
params['messages'].append({
'role': 'assistant',
'content': prompt_message.content
})
else:
params['messages'].append({
'role': prompt_message.role.value,
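Assistant turns that triggered tool calls are now replayed with an explicit `tool_calls` array, so the follow-up `role: "tool"` messages can be matched by id. A short sketch of the dict this branch appends (stand-in classes for the tool-call entities):

```python
import json
from dataclasses import dataclass
from typing import List

@dataclass
class ToolCallFunction:          # stand-in for AssistantPromptMessage.ToolCall.ToolCallFunction
    name: str
    arguments: str

@dataclass
class ToolCall:                  # stand-in for AssistantPromptMessage.ToolCall
    id: str
    type: str
    function: ToolCallFunction

def assistant_entry(content: str, tool_calls: List[ToolCall]) -> dict:
    """Build the assistant message entry, mirroring the branch above."""
    if tool_calls:
        return {
            'role': 'assistant',
            'content': content,
            'tool_calls': [{'id': tc.id, 'type': tc.type,
                            'function': {'name': tc.function.name, 'arguments': tc.function.arguments}}
                           for tc in tool_calls],
        }
    return {'role': 'assistant', 'content': content}

call = ToolCall('call_1', 'function', ToolCallFunction('get_weather', json.dumps({'city': 'Berlin'})))
print(assistant_entry('', [call]))
```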
huggingface_hub~=0.16.4
transformers~=4.31.0
pandas==1.5.3
xinference-client~=0.6.4
xinference-client~=0.8.1
safetensors==0.3.2
zhipuai==1.0.7
werkzeug~=3.0.1
raise RuntimeError('404 Not Found')

if 'generate' == model_uid:
return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url)
return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'chat' == model_uid:
return RESTfulChatModelHandle(model_uid, base_url=self.base_url)
return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'embedding' == model_uid:
return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url)
return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'rerank' == model_uid:
return RESTfulRerankModelHandle(model_uid, base_url=self.base_url)
return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
raise RuntimeError('404 Not Found')

def get(self: Session, url: str, **kwargs):
if '/v1/models/' in url:
response = Response()
response = Response()
if 'v1/models/' in url:
# get model uid
model_uid = url.split('/')[-1]
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
response.status_code = 404
raise ConnectionError('404 Not Found')
return response

# check if url is valid
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
response.status_code = 404
raise ConnectionError('404 Not Found')
return response
if model_uid in ['generate', 'chat']:
response.status_code = 200
response._content = b'''{
"model_type": "LLM",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "chatglm3-6b",
"model_lang": [
"en"
],
"model_ability": [
"generate",
"chat"
],
"model_description": "latest chatglm3",
"model_format": "pytorch",
"model_size_in_billions": 7,
"quantization": "none",
"model_hub": "huggingface",
"revision": null,
"context_length": 2048,
"replica": 1
}'''
return response
elif model_uid == 'embedding':
response.status_code = 200
response._content = b'''{
"model_type": "embedding",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "bge",
"model_lang": [
"en"
],
"revision": null,
"max_tokens": 512
}'''
return response
elif 'v1/cluster/auth' in url:
response.status_code = 200
response._content = b'''{
"model_type": "LLM",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "chatglm3-6b",
"model_lang": [
"en"
],
"model_ability": [
"generate",
"chat"
],
"model_description": "latest chatglm3",
"model_format": "pytorch",
"model_size_in_billions": 7,
"quantization": "none",
"model_hub": "huggingface",
"revision": null,
"context_length": 2048,
"replica": 1
"auth": true
}'''
return response
def _check_cluster_authenticated(self):
self._cluster_authed = True

def rerank(self: RESTfulRerankModelHandle, documents: List[str], query: str, top_n: int) -> dict:
# check if self._model_uid is a valid uuid
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \

def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)