| @@ -5,6 +5,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, | |||
| from core.model_runtime.entities.message_entities import ( | |||
| AssistantPromptMessage, | |||
| PromptMessage, | |||
| PromptMessageContent, | |||
| PromptMessageContentType, | |||
| PromptMessageRole, | |||
| PromptMessageTool, | |||
| @@ -31,6 +32,7 @@ And you should always end the block with a "```" to indicate the end of the JSON | |||
| ```JSON""" | |||
| class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| def _invoke(self, model: str, credentials: dict, | |||
| @@ -159,7 +161,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| if len(prompt_messages) == 0: | |||
| raise ValueError('At least one message is required') | |||
| if prompt_messages[0].role == PromptMessageRole.SYSTEM: | |||
| if not prompt_messages[0].content: | |||
| prompt_messages = prompt_messages[1:] | |||
| @@ -185,7 +187,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| continue | |||
| if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER and \ | |||
| copy_prompt_message.role == PromptMessageRole.USER: | |||
| copy_prompt_message.role == PromptMessageRole.USER: | |||
| new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content | |||
| else: | |||
| if copy_prompt_message.role == PromptMessageRole.USER: | |||
| @@ -205,31 +207,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| new_prompt_messages.append(copy_prompt_message) | |||
| if model == 'glm-4v': | |||
| params = { | |||
| 'model': model, | |||
| 'messages': [{ | |||
| 'role': prompt_message.role.value, | |||
| 'content': | |||
| [ | |||
| { | |||
| 'type': 'text', | |||
| 'text': prompt_message.content | |||
| } | |||
| ] if isinstance(prompt_message.content, str) else | |||
| [ | |||
| { | |||
| 'type': 'image', | |||
| 'image_url': { | |||
| 'url': content.data | |||
| } | |||
| } if content.type == PromptMessageContentType.IMAGE else { | |||
| 'type': 'text', | |||
| 'text': content.data | |||
| } for content in prompt_message.content | |||
| ], | |||
| } for prompt_message in new_prompt_messages], | |||
| **model_parameters | |||
| } | |||
| params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters) | |||
| else: | |||
| params = { | |||
| 'model': model, | |||
| @@ -277,8 +255,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| for prompt_message in new_prompt_messages: | |||
| # merge system message to user message | |||
| if prompt_message.role == PromptMessageRole.SYSTEM or \ | |||
| prompt_message.role == PromptMessageRole.TOOL or \ | |||
| prompt_message.role == PromptMessageRole.USER: | |||
| prompt_message.role == PromptMessageRole.TOOL or \ | |||
| prompt_message.role == PromptMessageRole.USER: | |||
| if len(params['messages']) > 0 and params['messages'][-1]['role'] == 'user': | |||
| params['messages'][-1]['content'] += "\n\n" + prompt_message.content | |||
| else: | |||
| @@ -306,8 +284,44 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| response = client.chat.completions.create(**params, **extra_model_kwargs) | |||
| return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages) | |||
| def _handle_generate_response(self, model: str, | |||
| def _construct_glm_4v_parameter(self, model: str, prompt_messages: list[PromptMessage], | |||
| model_parameters: dict): | |||
| messages = [ | |||
| { | |||
| 'role': message.role.value, | |||
| 'content': self._construct_glm_4v_messages(message.content) | |||
| } | |||
| for message in prompt_messages | |||
| ] | |||
| params = { | |||
| 'model': model, | |||
| 'messages': messages, | |||
| **model_parameters | |||
| } | |||
| return params | |||
| def _construct_glm_4v_messages(self, prompt_message: Union[str | list[PromptMessageContent]]) -> list[dict]: | |||
| if isinstance(prompt_message, str): | |||
| return [{'type': 'text', 'text': prompt_message}] | |||
| return [ | |||
| {'type': 'image_url', 'image_url': {'url': self._remove_image_header(item.data)}} | |||
| if item.type == PromptMessageContentType.IMAGE else | |||
| {'type': 'text', 'text': item.data} | |||
| for item in prompt_message | |||
| ] | |||
| def _remove_image_header(self, image: str) -> str: | |||
| if image.startswith('data:image'): | |||
| return image.split(',')[1] | |||
| return image | |||
| def _handle_generate_response(self, model: str, | |||
| credentials: dict, | |||
| tools: Optional[list[PromptMessageTool]], | |||
| response: Completion, | |||
| @@ -338,7 +352,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| ) | |||
| text += choice.message.content or '' | |||
| prompt_usage = response.usage.prompt_tokens | |||
| completion_usage = response.usage.completion_tokens | |||
| @@ -358,7 +372,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| return result | |||
| def _handle_generate_stream_response(self, model: str, | |||
| def _handle_generate_stream_response(self, model: str, | |||
| credentials: dict, | |||
| tools: Optional[list[PromptMessageTool]], | |||
| responses: Generator[ChatCompletionChunk, None, None], | |||
| @@ -380,7 +394,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): | |||
| continue | |||
| assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = [] | |||
| for tool_call in delta.delta.tool_calls or []: | |||
| if tool_call.type == 'function': | |||
| @@ -454,8 +468,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| return message_text | |||
| def _convert_messages_to_prompt(self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str: | |||
| def _convert_messages_to_prompt(self, messages: list[PromptMessage], | |||
| tools: Optional[list[PromptMessageTool]] = None) -> str: | |||
| """ | |||
| :param messages: List of PromptMessage to combine. | |||
| :return: Combined string with necessary human_prompt and ai_prompt tags. | |||
| @@ -473,4 +487,4 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): | |||
| text += f"\n{tool.json()}" | |||
| # trim off the trailing ' ' that might come from the "Assistant: " | |||
| return text.rstrip() | |||
| return text.rstrip() | |||