@@ -11,6 +11,7 @@ from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
@@ -194,6 +195,13 @@ class AssistantApplicationRunner(AppRunner):
memory=memory,
)
# change function call strategy based on LLM model
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
assistant_cot_runner = AssistantCotApplicationRunner(
@@ -209,9 +217,9 @@ class AssistantApplicationRunner(AppRunner):
prompt_messages=prompt_message,
variables_pool=tool_variables,
db_variables=tool_conversation_variables,
model_instance=model_instance
)
invoke_result = assistant_cot_runner.run(
model_instance=model_instance,
conversation=conversation,
message=message,
query=query,
@@ -229,10 +237,10 @@ class AssistantApplicationRunner(AppRunner):
memory=memory,
prompt_messages=prompt_message,
variables_pool=tool_variables,
db_variables=tool_conversation_variables
db_variables=tool_conversation_variables,
model_instance=model_instance
)
invoke_result = assistant_fc_runner.run(
model_instance=model_instance,
conversation=conversation,
message=message,
query=query,
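
For readers scanning the hunk above: the runner now reads the model schema and switches the agent to native function calling whenever the model advertises a tool-call feature, falling back to chain-of-thought otherwise. A minimal sketch of that decision (not part of the diff; ModelFeature and AgentStrategy below are simplified stand-ins for the real entities):

from enum import Enum

class ModelFeature(Enum):
    TOOL_CALL = "tool-call"
    MULTI_TOOL_CALL = "multi-tool-call"

class AgentStrategy(Enum):
    CHAIN_OF_THOUGHT = "chain-of-thought"
    FUNCTION_CALLING = "function-calling"

def pick_strategy(features: list) -> AgentStrategy:
    # Prefer native function calling when the schema lists any tool-call capability.
    if {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection(features):
        return AgentStrategy.FUNCTION_CALLING
    return AgentStrategy.CHAIN_OF_THOUGHT
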
@@ -1,7 +1,7 @@
import logging
import json
from typing import Optional, List, Tuple, Union
from typing import Optional, List, Tuple, Union, cast
from datetime import datetime
from mimetypes import guess_extension
@@ -27,7 +27,10 @@ from core.entities.application_entities import ModelConfigEntity, \
AgentEntity, AppOrchestrationConfigEntity, ApplicationGenerateEntity, InvokeFrom
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.utils.encoders import jsonable_encoder
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_manager import ModelInstance
from core.file.message_file_parser import FileTransferMethod
logger = logging.getLogger(__name__)
@@ -45,6 +48,7 @@ class BaseAssistantApplicationRunner(AppRunner):
prompt_messages: Optional[List[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None
) -> None:
"""
Agent runner
@@ -71,6 +75,7 @@ class BaseAssistantApplicationRunner(AppRunner):
self.history_prompt_messages = prompt_messages
self.variables_pool = variables_pool
self.db_variables_pool = db_variables
self.model_instance = model_instance
# init callback
self.agent_callback = DifyAgentCallbackHandler()
@@ -95,6 +100,14 @@ class BaseAssistantApplicationRunner(AppRunner):
MessageAgentThought.message_id == self.message.id,
).count()
# check if model supports stream tool call
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
self.stream_tool_call = True
else:
self.stream_tool_call = False
def _repacket_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
"""
Repacket app orchestration config
@@ -20,8 +20,7 @@ from core.features.assistant_base_runner import BaseAssistantApplicationRunner
from models.model import Conversation, Message
class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
def run(self, model_instance: ModelInstance,
conversation: Conversation,
def run(self, conversation: Conversation,
message: Message,
query: str,
) -> Union[Generator, LLMResult]:
@@ -82,6 +81,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
model_instance = self.model_instance
while function_call_state and iteration_step <= max_iteration_steps:
# continue to run until there is not any tool call
function_call_state = False
@@ -390,7 +391,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
# remove Action: xxx from agent thought
agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
if action_name and action_input:
if action_name and action_input is not None:
return AgentScratchpadUnit(
agent_response=content,
thought=agent_thought,
@@ -5,7 +5,7 @@ from typing import Union, Generator, Dict, Any, Tuple, List
from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage,\
SystemPromptMessage, AssistantPromptMessage, ToolPromptMessage, PromptMessageTool
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage
from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage, LLMResultChunkDelta
from core.model_manager import ModelInstance
from core.application_queue_manager import PublishFrom
@@ -20,8 +20,7 @@ from models.model import Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__)
class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
def run(self, model_instance: ModelInstance,
conversation: Conversation,
def run(self, conversation: Conversation,
message: Message,
query: str,
) -> Generator[LLMResultChunk, None, None]:
@@ -81,6 +80,8 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
model_instance = self.model_instance
while function_call_state and iteration_step <= max_iteration_steps:
function_call_state = False
@@ -101,12 +102,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
# recale llm max tokens
self.recale_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters,
tools=prompt_messages_tools,
stop=app_orchestration_config.model_config.stop,
stream=True,
stream=self.stream_tool_call,
user=self.user_id,
callbacks=[],
)
@@ -122,11 +123,41 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
current_llm_usage = None
for chunk in chunks:
if self.stream_tool_call:
for chunk in chunks:
# check if there is any tool call
if self.check_tool_calls(chunk):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk))
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
}, ensure_ascii=False)
except json.JSONDecodeError as e:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list):
for content in chunk.delta.message.content:
response += content.data
else:
response += chunk.delta.message.content
if chunk.delta.usage:
increase_usage(llm_usage, chunk.delta.usage)
current_llm_usage = chunk.delta.usage
yield chunk
else:
result: LLMResult = chunks
# check if there is any tool call
if self.check_tool_calls(chunk):
if self.check_blocking_tool_calls(result):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk))
tool_calls.extend(self.extract_blocking_tool_calls(result))
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps({
@@ -138,18 +169,44 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list):
for content in chunk.delta.message.content:
if result.usage:
increase_usage(llm_usage, result.usage)
current_llm_usage = result.usage
if result.message and result.message.content:
if isinstance(result.message.content, list):
for content in result.message.content:
response += content.data
else:
response += chunk.delta.message.content
if chunk.delta.usage:
increase_usage(llm_usage, chunk.delta.usage)
current_llm_usage = chunk.delta.usage
response += result.message.content
if not result.message.content:
result.message.content = ''
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=result.prompt_messages,
system_fingerprint=result.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=result.message,
usage=result.usage,
)
)
yield chunk
if tool_calls:
prompt_messages.append(AssistantPromptMessage(
content='',
name='',
tool_calls=[AssistantPromptMessage.ToolCall(
id=tool_call[0],
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_call[1],
arguments=json.dumps(tool_call[2], ensure_ascii=False)
)
) for tool_call in tool_calls]
))
# save thought
self.save_agent_thought(
@@ -167,6 +224,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
final_answer += response + '\n'
# update prompt messages
if response.strip():
prompt_messages.append(AssistantPromptMessage(
content=response,
))
# call tools
tool_responses = []
for tool_call_id, tool_call_name, tool_call_args in tool_calls:
@@ -256,12 +319,6 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
)
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
# update prompt messages
if response.strip():
prompt_messages.append(AssistantPromptMessage(
content=response,
))
# update prompt tool
for prompt_tool in prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
@@ -287,6 +344,14 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
if llm_result_chunk.delta.message.tool_calls:
return True
return False
def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
"""
Check if there is any blocking tool call in llm result
"""
if llm_result.message.tool_calls:
return True
return False
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
"""
@@ -304,6 +369,23 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
))
return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
"""
Extract blocking tool calls from llm result
Returns:
List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
"""
tool_calls = []
for prompt_message in llm_result.message.tool_calls:
tool_calls.append((
prompt_message.id,
prompt_message.function.name,
json.loads(prompt_message.function.arguments),
))
return tool_calls
def organize_prompt_messages(self, prompt_template: str,
query: str = None,
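
The two extractor variants added above differ only in where the tool calls live: extract_tool_calls reads them from a streamed chunk's delta message, while extract_blocking_tool_calls reads them from a complete LLMResult. Both normalize to (id, name, parsed-arguments) tuples. A small sketch of that normalization, using hypothetical lightweight classes rather than the real entities:

import json
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple

@dataclass
class ToolCallFunction:
    name: str
    arguments: str  # JSON-encoded arguments as returned by the model

@dataclass
class ToolCall:
    id: str
    function: ToolCallFunction

@dataclass
class AssistantMessage:
    tool_calls: List[ToolCall] = field(default_factory=list)

def normalize_tool_calls(message: AssistantMessage) -> List[Tuple[str, str, Dict[str, Any]]]:
    # Decode the JSON arguments once so the tool layer receives plain dicts.
    return [(c.id, c.function.name, json.loads(c.function.arguments)) for c in message.tool_calls]
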
@@ -78,6 +78,7 @@ class ModelFeature(Enum):
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = "vision"
STREAM_TOOL_CALL = "stream-tool-call"
class DefaultParameterName(Enum):
@@ -36,6 +36,7 @@ LLM_BASE_MODELS = [
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
@@ -80,6 +81,7 @@ LLM_BASE_MODELS = [
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
@@ -124,6 +126,7 @@ LLM_BASE_MODELS = [
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
@@ -198,6 +201,7 @@ LLM_BASE_MODELS = [
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
@@ -272,6 +276,7 @@ LLM_BASE_MODELS = [
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
@@ -324,6 +324,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
tools: Optional[list[PromptMessageTool]] = None) -> Generator:
index = 0
full_assistant_content = ''
delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
real_model = model
system_fingerprint = None
completion = ''
@@ -333,12 +334,32 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
delta = chunk.choices[0]
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == '') and \
delta.delta.function_call is None:
continue
# assistant_message_tool_calls = delta.delta.tool_calls
assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if delta_assistant_message_function_call_storage is not None:
# handle process of stream function call
if assistant_message_function_call:
# message has not ended ever
delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
continue
else:
# message has ended
assistant_message_function_call = delta_assistant_message_function_call_storage
delta_assistant_message_function_call_storage = None
else:
if assistant_message_function_call:
# start of stream function call
delta_assistant_message_function_call_storage = assistant_message_function_call
if delta_assistant_message_function_call_storage.arguments is None:
delta_assistant_message_function_call_storage.arguments = ''
continue
# extract tool calls from response
# tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
function_call = self._extract_response_function_call(assistant_message_function_call)
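
The Azure hunk above buffers a streamed function call: argument fragments arrive across several deltas and are concatenated until a delta without a function_call signals that the call is complete. A simplified sketch of the same accumulation pattern (illustrative only; FunctionCallDelta is an assumed stand-in for the SDK's delta object):

from dataclasses import dataclass
from typing import Iterable, Optional

@dataclass
class FunctionCallDelta:
    name: Optional[str] = None
    arguments: Optional[str] = None

def accumulate_function_call(deltas: Iterable[Optional[FunctionCallDelta]]) -> Optional[FunctionCallDelta]:
    storage: Optional[FunctionCallDelta] = None
    for delta in deltas:
        if delta is not None:
            if storage is None:
                # first fragment: start buffering the call
                storage = FunctionCallDelta(name=delta.name, arguments=delta.arguments or '')
            else:
                # later fragments only extend the argument string
                storage.arguments += delta.arguments or ''
        elif storage is not None:
            # a delta without a function_call ends the buffered call
            return storage
    return storage
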
@@ -489,7 +510,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
else:
raise ValueError(f"Got unknown type {message}")
if message.name is not None:
if message.name:
message_dict["name"] = message.name
return message_dict
@@ -586,7 +607,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
num_tokens = 0
for tool in tools:
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(tool.get("type")))
num_tokens += len(encoding.encode('function'))
# calculate num tokens for function object
@@ -5,7 +5,7 @@ from typing import Generator, List, Optional, cast
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageFunction,
PromptMessageTool, SystemPromptMessage, UserPromptMessage)
PromptMessageTool, SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -194,6 +194,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
# check if last message is user message
message = cast(ToolPromptMessage, message)
message_dict = {"role": "function", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")
@@ -4,6 +4,8 @@ label:
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 16384
@@ -4,6 +4,8 @@ label:
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32768
@@ -16,7 +16,7 @@ class MinimaxChatCompletion(object):
"""
def generate(self, model: str, api_key: str, group_id: str,
prompt_messages: List[MinimaxMessage], model_parameters: dict,
tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \
tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
-> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
"""
generate chat completion
@@ -162,7 +162,6 @@ class MinimaxChatCompletion(object):
continue
for choice in choices:
print(choice)
message = choice['delta']
yield MinimaxMessage(
content=message,
@@ -17,7 +17,7 @@ class MinimaxChatCompletionPro(object):
"""
def generate(self, model: str, api_key: str, group_id: str,
prompt_messages: List[MinimaxMessage], model_parameters: dict,
tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \
tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
-> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
"""
generate chat completion
@@ -82,6 +82,10 @@ class MinimaxChatCompletionPro(object):
**extra_kwargs
}
if tools:
body['functions'] = tools
body['function_call'] = { 'type': 'auto' }
try:
response = post(
url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
@@ -135,6 +139,7 @@ class MinimaxChatCompletionPro(object):
"""
handle stream chat generate response
"""
function_call_storage = None
for line in response.iter_lines():
if not line:
continue
@@ -148,7 +153,7 @@ class MinimaxChatCompletionPro(object):
msg = data['base_resp']['status_msg']
self._handle_error(code, msg)
if data['reply']:
if data['reply'] or 'usage' in data and data['usage']:
total_tokens = data['usage']['total_tokens']
message = MinimaxMessage(
role=MinimaxMessage.Role.ASSISTANT.value,
@@ -160,6 +165,12 @@ class MinimaxChatCompletionPro(object):
'total_tokens': total_tokens
}
message.stop_reason = data['choices'][0]['finish_reason']
if function_call_storage:
function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
function_call_message.function_call = function_call_storage
yield function_call_message
yield message
return
@@ -168,11 +179,28 @@ class MinimaxChatCompletionPro(object):
continue
for choice in choices:
message = choice['messages'][0]['text']
if not message:
continue
message = choice['messages'][0]
if 'function_call' in message:
if not function_call_storage:
function_call_storage = message['function_call']
if 'arguments' not in function_call_storage or not function_call_storage['arguments']:
function_call_storage['arguments'] = ''
continue
else:
function_call_storage['arguments'] += message['function_call']['arguments']
continue
else:
if function_call_storage:
message['function_call'] = function_call_storage
function_call_storage = None
yield MinimaxMessage(
content=message,
role=MinimaxMessage.Role.ASSISTANT.value
)
minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
if 'function_call' in message:
minimax_message.function_call = message['function_call']
if 'text' in message:
minimax_message.content = message['text']
yield minimax_message
@@ -2,7 +2,7 @@ from typing import Generator, List
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
SystemPromptMessage, UserPromptMessage)
SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -84,6 +84,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
"""
client: MinimaxChatCompletionPro = self.model_apis[model]()
if tools:
tools = [{
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters
} for tool in tools]
response = client.generate(
model=model,
api_key=credentials['minimax_api_key'],
@@ -109,7 +116,19 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
elif isinstance(prompt_message, UserPromptMessage):
return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content)
elif isinstance(prompt_message, AssistantPromptMessage):
if prompt_message.tool_calls:
message = MinimaxMessage(
role=MinimaxMessage.Role.ASSISTANT.value,
content=''
)
message.function_call={
'name': prompt_message.tool_calls[0].function.name,
'arguments': prompt_message.tool_calls[0].function.arguments
}
return message
return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content)
elif isinstance(prompt_message, ToolPromptMessage):
return MinimaxMessage(role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content)
else:
raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported')
@@ -151,6 +170,28 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
finish_reason=message.stop_reason if message.stop_reason else None,
),
)
elif message.function_call:
if 'name' not in message.function_call or 'arguments' not in message.function_call:
continue
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content='',
tool_calls=[AssistantPromptMessage.ToolCall(
id='',
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=message.function_call['name'],
arguments=message.function_call['arguments']
)
)]
),
),
)
else:
yield LLMResultChunk(
model=model,
@@ -7,13 +7,23 @@ class MinimaxMessage:
USER = 'USER'
ASSISTANT = 'BOT'
SYSTEM = 'SYSTEM'
FUNCTION = 'FUNCTION'
role: str = Role.USER.value
content: str
usage: Dict[str, int] = None
stop_reason: str = ''
function_call: Dict[str, Any] = None
def to_dict(self) -> Dict[str, Any]:
if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:
return {
'sender_type': 'BOT',
'sender_name': '专家',
'text': '',
'function_call': self.function_call
}
return {
'sender_type': self.role,
'sender_name': '我' if self.role == 'USER' else '专家',
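
For orientation, the to_dict change above makes an assistant turn that carries a function call serialize with an empty text field plus the buffered function_call dict, while ordinary turns keep their text. A hypothetical example of the resulting ChatCompletion Pro message entries (values are made up for illustration):

assistant_function_call_entry = {
    'sender_type': 'BOT',
    'sender_name': '专家',
    'text': '',
    'function_call': {'name': 'get_weather', 'arguments': '{"city": "Berlin"}'},
}

user_entry = {
    'sender_type': 'USER',
    'sender_name': '我',
    'text': 'What is the weather in Berlin?',
}
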
@@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 4096
@@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 16385
@@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 16385
@@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 16385
@@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 4096
@@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 128000
@@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 128000
@@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 32768
@@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 128000
@@ -6,6 +6,7 @@ model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 8192
@@ -671,7 +671,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
else:
raise ValueError(f"Got unknown type {message}")
if message.name is not None:
if message.name:
message_dict["name"] = message.name
return message_dict
@@ -3,14 +3,14 @@ from typing import Generator, Iterator, List, Optional, Union, cast
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
SystemPromptMessage, UserPromptMessage)
SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType,
ParameterRule, ParameterType)
ParameterRule, ParameterType, ModelFeature)
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.xinference.llm.xinference_helper import (XinferenceHelper,
from core.model_runtime.model_providers.xinference.xinference_helper import (XinferenceHelper,
XinferenceModelExtraParameter)
from core.model_runtime.utils import helper
from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError,
@@ -33,6 +33,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
"""
if 'temperature' in model_parameters:
if model_parameters['temperature'] < 0.01:
model_parameters['temperature'] = 0.01
elif model_parameters['temperature'] > 1.0:
model_parameters['temperature'] = 0.99
return self._generate(
model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
tools=tools, stop=stop, stream=stream, user=user,
@@ -65,6 +71,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
credentials['completion_type'] = 'completion'
else:
raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported')
if extra_param.support_function_call:
credentials['support_function_call'] = True
except RuntimeError as e:
raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
@@ -220,6 +229,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
else:
raise ValueError(f"Unknown message type {type(message)}")
@@ -237,7 +249,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
label=I18nObject(
zh_Hans='温度',
en_US='Temperature'
)
),
),
ParameterRule(
name='top_p',
@@ -282,6 +294,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
completion_type = LLMMode.COMPLETION.value
else:
raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
support_function_call = credentials.get('support_function_call', False)
entity = AIModelEntity(
model=model,
@@ -290,6 +304,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
features=[
ModelFeature.TOOL_CALL
] if support_function_call else [],
model_properties={
ModelPropertyKey.MODE: completion_type,
},
@@ -310,6 +327,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter`
"""
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('server_url is required in credentials')
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
client = OpenAI(
base_url=f'{credentials["server_url"]}/v1',
api_key='abc',
@@ -2,7 +2,7 @@ import time
from typing import Optional
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType, ModelPropertyKey
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
@@ -10,6 +10,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle, RESTfulModelHandle
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
class XinferenceTextEmbeddingModel(TextEmbeddingModel):
"""
@@ -35,7 +36,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
"""
server_url = credentials['server_url']
model_uid = credentials['model_uid']
if server_url.endswith('/'):
server_url = server_url[:-1]
client = Client(base_url=server_url)
try:
@@ -102,8 +106,15 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
:return:
"""
try:
server_url = credentials['server_url']
model_uid = credentials['model_uid']
extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
if extra_args.max_tokens:
credentials['max_tokens'] = extra_args.max_tokens
self._invoke(model=model, credentials=credentials, texts=['ping'])
except InvokeAuthorizationError:
except (InvokeAuthorizationError, RuntimeError):
raise CredentialsValidateFailedError('Invalid api key')
@property
@@ -160,6 +171,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
@@ -167,7 +179,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={},
model_properties={
ModelPropertyKey.MAX_CHUNKS: 1,
ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512,
},
parameter_rules=[]
)
@@ -1,6 +1,7 @@
from threading import Lock
from time import time
from typing import List
from os import path
from requests import get
from requests.adapters import HTTPAdapter
@@ -12,11 +13,16 @@ class XinferenceModelExtraParameter(object):
model_format: str
model_handle_type: str
model_ability: List[str]
max_tokens: int = 512
support_function_call: bool = False
def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str]) -> None:
def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str],
support_function_call: bool, max_tokens: int) -> None:
self.model_format = model_format
self.model_handle_type = model_handle_type
self.model_ability = model_ability
self.support_function_call = support_function_call
self.max_tokens = max_tokens
cache = {}
cache_lock = Lock()
@@ -49,7 +55,7 @@ class XinferenceHelper:
get xinference model extra parameter like model_format and model_handle_type
"""
url = f'{server_url}/v1/models/{model_uid}'
url = path.join(server_url, 'v1/models', model_uid)
# this methid is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
session = Session()
@@ -66,10 +72,12 @@ class XinferenceHelper:
response_json = response.json()
model_format = response_json['model_format']
model_ability = response_json['model_ability']
model_format = response_json.get('model_format', 'ggmlv3')
model_ability = response_json.get('model_ability', [])
if model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
if response_json.get('model_type') == 'embedding':
model_handle_type = 'embedding'
elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
model_handle_type = 'chatglm'
elif 'generate' in model_ability:
model_handle_type = 'generate'
@@ -78,8 +86,13 @@ class XinferenceHelper:
else:
raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')
support_function_call = 'tools' in model_ability
max_tokens = response_json.get('max_tokens', 512)
return XinferenceModelExtraParameter(
model_format=model_format,
model_handle_type=model_handle_type,
model_ability=model_ability
model_ability=model_ability,
support_function_call=support_function_call,
max_tokens=max_tokens
)
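
The helper changes above derive two new fields from the /v1/models/<uid> response: support_function_call is true when 'tools' appears in the model's abilities, and max_tokens falls back to 512 when the server omits it. A standalone sketch of that parsing (the response keys mirror the hunk, but the exact Xinference schema is an assumption here):

def parse_extra_parameter(response_json: dict) -> dict:
    model_ability = response_json.get('model_ability', [])
    return {
        'model_format': response_json.get('model_format', 'ggmlv3'),
        'model_ability': model_ability,
        # models that list 'tools' among their abilities can be offered function calling
        'support_function_call': 'tools' in model_ability,
        # default to 512 when the server does not report max_tokens
        'max_tokens': response_json.get('max_tokens', 512),
    }

# parse_extra_parameter({'model_ability': ['chat', 'tools'], 'max_tokens': 2048})
# -> {'model_format': 'ggmlv3', 'model_ability': ['chat', 'tools'], 'support_function_call': True, 'max_tokens': 2048}
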
@@ -2,6 +2,10 @@ model: glm-3-turbo
label:
en_US: glm-3-turbo
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
parameter_rules:
@@ -2,6 +2,10 @@ model: glm-4
label:
en_US: glm-4
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
parameter_rules:
@@ -194,6 +194,27 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
'content': prompt_message.content,
'tool_call_id': prompt_message.tool_call_id
})
elif isinstance(prompt_message, AssistantPromptMessage):
if prompt_message.tool_calls:
params['messages'].append({
'role': 'assistant',
'content': prompt_message.content,
'tool_calls': [
{
'id': tool_call.id,
'type': tool_call.type,
'function': {
'name': tool_call.function.name,
'arguments': tool_call.function.arguments
}
} for tool_call in prompt_message.tool_calls
]
})
else:
params['messages'].append({
'role': 'assistant',
'content': prompt_message.content
})
else:
params['messages'].append({
'role': prompt_message.role.value,
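
To make the new assistant branch concrete: an assistant turn that previously invoked a tool is replayed to ZhipuAI as an assistant message carrying tool_calls, and the matching tool result follows it. A hypothetical request fragment (values invented for illustration; the tool message role assumes the branch handled just above this hunk):

messages_fragment = [
    {
        'role': 'assistant',
        'content': '',
        'tool_calls': [{
            'id': 'call_0',
            'type': 'function',
            'function': {'name': 'get_weather', 'arguments': '{"city": "Berlin"}'},
        }],
    },
    {'role': 'tool', 'content': '{"temp_c": 4}', 'tool_call_id': 'call_0'},
]
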
@@ -47,7 +47,7 @@ dashscope[tokenizer]~=1.14.0
huggingface_hub~=0.16.4
transformers~=4.31.0
pandas==1.5.3
xinference-client~=0.6.4
xinference-client~=0.8.1
safetensors==0.3.2
zhipuai==1.0.7
werkzeug~=3.0.1
@@ -19,58 +19,86 @@ class MockXinferenceClass(object):
raise RuntimeError('404 Not Found')
if 'generate' == model_uid:
return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url)
return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'chat' == model_uid:
return RESTfulChatModelHandle(model_uid, base_url=self.base_url)
return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'embedding' == model_uid:
return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url)
return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'rerank' == model_uid:
return RESTfulRerankModelHandle(model_uid, base_url=self.base_url)
return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
raise RuntimeError('404 Not Found')
def get(self: Session, url: str, **kwargs):
if '/v1/models/' in url:
response = Response()
response = Response()
if 'v1/models/' in url:
# get model uid
model_uid = url.split('/')[-1]
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
response.status_code = 404
raise ConnectionError('404 Not Found')
return response
# check if url is valid
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
response.status_code = 404
raise ConnectionError('404 Not Found')
return response
if model_uid in ['generate', 'chat']:
response.status_code = 200
response._content = b'''{
"model_type": "LLM",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "chatglm3-6b",
"model_lang": [
"en"
],
"model_ability": [
"generate",
"chat"
],
"model_description": "latest chatglm3",
"model_format": "pytorch",
"model_size_in_billions": 7,
"quantization": "none",
"model_hub": "huggingface",
"revision": null,
"context_length": 2048,
"replica": 1
}'''
return response
elif model_uid == 'embedding':
response.status_code = 200
response._content = b'''{
"model_type": "embedding",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "bge",
"model_lang": [
"en"
],
"revision": null,
"max_tokens": 512
}'''
return response
elif 'v1/cluster/auth' in url:
response.status_code = 200
response._content = b'''{
"model_type": "LLM",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "chatglm3-6b",
"model_lang": [
"en"
],
"model_ability": [
"generate",
"chat"
],
"model_description": "latest chatglm3",
"model_format": "pytorch",
"model_size_in_billions": 7,
"quantization": "none",
"model_hub": "huggingface",
"revision": null,
"context_length": 2048,
"replica": 1
"auth": true
}'''
return response
def _check_cluster_authenticated(self):
self._cluster_authed = True
def rerank(self: RESTfulRerankModelHandle, documents: List[str], query: str, top_n: int) -> dict:
# check if self._model_uid is a valid uuid
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
@@ -133,6 +161,7 @@ MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)